# Breast Cancer Classification with CatBoost

This example demonstrates training a **CatBoostClassifier** to predict whether a breast tumor is malignant or benign.

## Dataset

- **Name**: Breast Cancer Wisconsin (Diagnostic)
- **Samples**: 569
- **Features**: 30 numeric features
- **Target**: Binary classification (0=malignant, 1=benign); the reported precision, recall, and F1 treat benign (1) as the positive class

## What this notebook does

- Trains a CatBoost model
- Logs params + metrics to MLflow
- Saves model artifacts + signature
- Registers the model (best-effort)
- Prints sample payloads for deployment and prediction


In [None]:
%pip install catboost pandas numpy scikit-learn mlflow


In [None]:
import os
import argparse
import json
import tempfile
import numpy as np
import pandas as pd
from datetime import datetime

# CatBoost imports
from catboost import CatBoostClassifier

# MLflow imports
import mlflow
import mlflow.catboost
from mlflow import set_tracking_uri, set_experiment
from mlflow.client import MlflowClient
from mlflow.models import infer_signature

# Scikit-learn imports
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix


def setup_mlflow(mlflow_uri: str, username: str, password: str) -> MlflowClient:
    """Point MLflow at a tracking server and build a client for it.

    The credentials are exported through the ``MLFLOW_TRACKING_USERNAME`` and
    ``MLFLOW_TRACKING_PASSWORD`` environment variables (overwriting any values
    already present) before the global tracking URI is set.

    Args:
        mlflow_uri: Address of the MLflow tracking server.
        username: Value written to ``MLFLOW_TRACKING_USERNAME``.
        password: Value written to ``MLFLOW_TRACKING_PASSWORD``.

    Returns:
        An ``MlflowClient`` constructed with ``mlflow_uri``.
    """
    credentials = {
        "MLFLOW_TRACKING_USERNAME": username,
        "MLFLOW_TRACKING_PASSWORD": password,
    }
    os.environ.update(credentials)

    set_tracking_uri(mlflow_uri)
    tracking_client = MlflowClient(mlflow_uri)

    print(f"MLflow tracking URI: {mlflow_uri}")
    return tracking_client


def load_and_prepare_data():
    """Load the Breast Cancer dataset and split it into train and test parts.

    Prints a short summary (sample count, feature count, class distribution,
    split sizes) along the way.

    Returns:
        A 5-tuple ``(X_train, X_test, y_train, y_test, feature_names)`` where
        the split holds out 20% of the rows, is stratified on the target, and
        uses ``random_state=42``; ``feature_names`` is the list of column
        names of the feature frame.
    """
    banner = "=" * 80
    print("\n" + banner)
    print("LOADING DATASET")
    print(banner)

    dataset = load_breast_cancer(as_frame=True)
    features = dataset.data
    labels = pd.Series(dataset.target, name="target")

    n_samples, n_features = features.shape
    print("Dataset: Breast Cancer Wisconsin (Diagnostic)")
    print(f"Samples: {n_samples:,}")
    print(f"Features: {n_features}")
    print("\nTarget distribution:")
    class_counts = labels.value_counts().sort_index()
    print(class_counts.rename(index={0: "malignant(0)", 1: "benign(1)"}))

    split_parts = train_test_split(
        features,
        labels,
        test_size=0.2,
        random_state=42,
        shuffle=True,
        stratify=labels,
    )
    X_train, X_test, y_train, y_test = split_parts

    print(f"\nTrain samples: {X_train.shape[0]:,}")
    print(f"Test samples: {X_test.shape[0]:,}")

    feature_names = features.columns.tolist()
    return X_train, X_test, y_train, y_test, feature_names


def train_model(X_train, y_train, X_test, y_test, hyperparams: dict):
    """Fit a ``CatBoostClassifier`` and predict on both splits.

    Args:
        X_train: Training features.
        y_train: Training labels.
        X_test: Held-out features; also passed to ``fit`` as the eval set.
        y_test: Held-out labels; also passed to ``fit`` as the eval set.
        hyperparams: Keyword arguments forwarded to ``CatBoostClassifier``.
            Its ``"verbose"`` entry (default 100 when absent) is additionally
            passed to ``fit``.

    Returns:
        ``(model, y_train_pred, y_test_pred, y_train_proba, y_test_proba)``:
        the fitted model, its ``predict`` output for train and test, and its
        ``predict_proba`` output for train and test.
    """
    separator = "=" * 80
    print("\n" + separator)
    print("TRAINING MODEL")
    print(separator)

    print("Hyperparameters:")
    for name in hyperparams:
        print(f"  {name}: {hyperparams[name]}")

    classifier = CatBoostClassifier(**hyperparams)
    log_period = hyperparams.get("verbose", 100)

    # NOTE(review): the test split doubles as the eval set here; presumably
    # CatBoost's best-iteration selection then sees the test data, which would
    # make the reported test metrics optimistic - confirm.
    classifier.fit(X_train, y_train, eval_set=(X_test, y_test), verbose=log_period)

    print("Training completed!")

    splits = (X_train, X_test)
    y_train_pred, y_test_pred = (classifier.predict(part) for part in splits)
    y_train_proba, y_test_proba = (classifier.predict_proba(part) for part in splits)

    return classifier, y_train_pred, y_test_pred, y_train_proba, y_test_proba


def calculate_metrics(y_true, y_pred, dataset_name="Test"):
    """Score predictions against the true labels.

    Args:
        y_true: Ground-truth labels.
        y_pred: Predicted labels.
        dataset_name: Label for the split; its lowercase form prefixes every
            key of the result.

    Returns:
        A dict with the keys ``<prefix>_accuracy``, ``<prefix>_precision``,
        ``<prefix>_recall`` and ``<prefix>_f1``; precision, recall and F1 are
        computed with ``average="binary"``.
    """
    prefix = dataset_name.lower()

    scores = {"accuracy": accuracy_score(y_true, y_pred)}
    binary_scorers = (
        ("precision", precision_score),
        ("recall", recall_score),
        ("f1", f1_score),
    )
    for label, scorer in binary_scorers:
        scores[label] = scorer(y_true, y_pred, average="binary")

    return {f"{prefix}_{label}": value for label, value in scores.items()}


def log_to_mlflow(
    model,
    X_train,
    y_train,
    X_test,
    y_test,
    y_train_pred,
    y_test_pred,
    y_train_proba,
    y_test_proba,
    hyperparams,
    feature_names,
):
    """Log hyperparameters, metrics, and model artifacts to the current MLflow run.

    Args:
        model: Fitted CatBoost model to save.
        X_train: Training features; used to infer the signature inputs.
        y_train: Training labels, scored against ``y_train_pred``.
        X_test: Test features; its first row becomes the input example.
        y_test: Test labels, scored against ``y_test_pred``.
        y_train_pred: Predicted labels for the training split.
        y_test_pred: Predicted labels for the test split.
        y_train_proba: Training-split probabilities; the first row defines the
            signature outputs.
        y_test_proba: Unused; kept so existing callers keep working.
        hyperparams: Mapping logged as MLflow params.
        feature_names: Unused; kept so existing callers keep working.

    Returns:
        The combined train and test metrics dict that was logged.
    """
    print("\n" + "=" * 80)
    print("LOGGING TO MLFLOW")
    print("=" * 80)

    # Batched call: one request for all params instead of one per key.
    mlflow.log_params(hyperparams)

    train_metrics = calculate_metrics(y_train, y_train_pred, "Train")
    test_metrics = calculate_metrics(y_test, y_test_pred, "Test")
    all_metrics = {**train_metrics, **test_metrics}

    # Same batching for metrics.
    mlflow.log_metrics(all_metrics)

    print("\nModel Performance:")
    print(f"  Training Accuracy: {train_metrics['train_accuracy']:.4f}")
    print(f"  Training F1:       {train_metrics['train_f1']:.4f}")
    print(f"  Test Accuracy:     {test_metrics['test_accuracy']:.4f}")
    print(f"  Test Precision:    {test_metrics['test_precision']:.4f}")
    print(f"  Test Recall:       {test_metrics['test_recall']:.4f}")
    print(f"  Test F1:           {test_metrics['test_f1']:.4f}")

    cm = confusion_matrix(y_test, y_test_pred)
    print("\nConfusion Matrix:")
    print(cm)

    # Explicit outputs for signature: one row is enough to fix the column
    # names and dtypes.
    # NOTE(review): the signature advertises two probability columns; confirm
    # the serving layer returns predict_proba output rather than class labels.
    train_proba_df = pd.DataFrame(y_train_proba[:1], columns=["prob_malignant_0", "prob_benign_1"])
    signature = infer_signature(X_train, train_proba_df)
    input_example = X_test.head(1)

    # Save to a scratch directory first, then upload that directory under the
    # run's "model" artifact path.
    with tempfile.TemporaryDirectory() as tmpdir:
        local_model_path = os.path.join(tmpdir, "model")
        mlflow.catboost.save_model(
            cb_model=model,
            path=local_model_path,
            signature=signature,
            input_example=input_example,
        )
        mlflow.log_artifacts(local_model_path, artifact_path="model")
        print("Model artifacts logged successfully!")

    return all_metrics


def create_sample_payload(X_test, y_test, model, feature_names):
    """Build an example prediction payload from the first test row.

    Args:
        X_test: Test features; only the first row is used.
        y_test: Test labels; only the first entry is used.
        model: Object whose ``predict`` is called with the row as a
            ``(1, n_features)`` array.
        feature_names: Unused; kept so existing callers keep working.

    Returns:
        A dict with ``"features"`` (column name -> value for the row),
        ``"actual_class"`` and ``"predicted_class"`` (both ``int``).
    """
    first_row = 0
    row = X_test.iloc[first_row]

    # The model is handed a single-row 2-D array.
    row_matrix = row.to_numpy().reshape(1, -1)
    prediction = model.predict(row_matrix)

    payload = {"features": row.to_dict()}
    payload["actual_class"] = int(y_test.iloc[first_row])
    payload["predicted_class"] = int(prediction[0])
    return payload


def register_model(client: MlflowClient, model_name: str, run_id: str, experiment_id: str):
    """Register the run's model as a new version, on a best-effort basis.

    The registered model is looked up first and created when the lookup
    fails; a failure to create it is only reported. A model version pointing
    at ``runs:/<run_id>/model`` is then requested.

    Args:
        client: Client used for all registry calls.
        model_name: Name of the registered model.
        run_id: Run whose ``model`` artifacts are registered.
        experiment_id: Only used to print a fallback artifact URI on failure.

    Returns:
        The new model version, or ``None`` when creating the version failed.
    """
    rule = "=" * 80
    print("\n" + rule)
    print("REGISTERING MODEL")
    print(rule)

    source_uri = f"runs:/{run_id}/model"

    # Ensure the registered model exists; any lookup error is treated as "missing".
    try:
        client.get_registered_model(model_name)
    except Exception:
        try:
            client.create_registered_model(model_name)
        except Exception as create_error:
            print(f"Could not create registered model: {create_error}")
        else:
            print(f"Created registered model: {model_name}")
    else:
        print(f"Model '{model_name}' already exists in registry")

    try:
        version_info = client.create_model_version(
            name=model_name,
            source=source_uri,
            run_id=run_id,
        )
        print("Model version registered successfully!")
        print(f"   Model Name: {model_name}")
        print(f"   Version: {version_info.version}")
        print(f"   Run ID: {run_id}")
        return version_info.version
    except Exception as registration_error:
        # Registration is optional: report it and point at the raw artifacts.
        fallback_uri = f"mlflow-artifacts:/{experiment_id}/{run_id}/artifacts/model"
        print(f"Model registration failed (model still usable via run URI): {registration_error}")
        print(f"   You can deploy using: {fallback_uri}")
        return None


def print_deployment_info(run_id: str, experiment_id: str, sample_payload: dict):
    """Print the run summary plus example deploy and predict payloads.

    Args:
        run_id: MLflow run id, shown and embedded in the model URI.
        experiment_id: MLflow experiment id, shown and embedded in the model URI.
        sample_payload: Dict providing ``"features"``, ``"actual_class"`` and
            ``"predicted_class"``.
    """
    divider = "=" * 80
    artifact_uri = f"mlflow-artifacts:/{experiment_id}/{run_id}/artifacts/model"

    def section(title):
        # Blank line, rule, title, rule.
        print("\n" + divider)
        print(title)
        print(divider)

    section("TRAINING COMPLETE!")

    print("\nRun Information:")
    print(f"  Run ID: {run_id}")
    print(f"  Experiment ID: {experiment_id}")
    print(f"  Model URI: {artifact_uri}")

    section("DEPLOYMENT PAYLOAD (deploy-model API)")

    deployment_request = dict(
        serve_name="breast-cancer-catboost-classifier",
        model_uri=artifact_uri,
        env="local",
        cores=2,
        memory=4,
        node_capacity="spot",
        min_replicas=1,
        max_replicas=3,
    )
    print(json.dumps(deployment_request, indent=2))

    section("SAMPLE PREDICTION PAYLOAD")

    prediction_request = {"features": sample_payload["features"]}
    print(json.dumps(prediction_request, indent=2))

    actual = sample_payload["actual_class"]
    predicted = sample_payload["predicted_class"]
    print("\nExpected Output:")
    print(f"  Actual Class:      {actual}  (0=malignant, 1=benign)")
    print(f"  Model Prediction:  {predicted}")


In [None]:
def main():
    """Run the end-to-end pipeline: train, log to MLflow, register, print payloads.

    Command-line options (all optional): ``--mlflow-uri``, ``--username``,
    ``--password``, ``--experiment-name`` and ``--model-name``. Unknown
    arguments are ignored, so the function can also be called from a notebook
    kernel that was started with its own arguments.

    The username and password default to ``MLFLOW_TRACKING_USERNAME`` and
    ``MLFLOW_TRACKING_PASSWORD`` from the environment when those are set, and
    to the placeholder values otherwise.
    """
    parser = argparse.ArgumentParser(description="Train CatBoost Breast Cancer Classification Model")
    parser.add_argument(
        "--mlflow-uri",
        default="http://darwin-mlflow-lib.darwin.svc.cluster.local:8080",
        help="MLflow tracking URI",
    )
    # setup_mlflow() writes these values back into the environment, so fixed
    # placeholder defaults would silently clobber credentials that were
    # already exported. Prefer the environment; keep the placeholders only as
    # a last resort.
    parser.add_argument(
        "--username",
        default=os.environ.get("MLFLOW_TRACKING_USERNAME", "abc@gmail.com"),
        help="MLflow username (default: $MLFLOW_TRACKING_USERNAME if set)",
    )
    parser.add_argument(
        "--password",
        default=os.environ.get("MLFLOW_TRACKING_PASSWORD", "password"),
        help="MLflow password (default: $MLFLOW_TRACKING_PASSWORD if set)",
    )
    parser.add_argument(
        "--experiment-name",
        default="breast_cancer_catboost_classification",
        help="MLflow experiment name",
    )
    parser.add_argument(
        "--model-name",
        default="BreastCancerCatBoostClassifier",
        help="Registered model name",
    )

    # parse_known_args tolerates extra arguments that are not ours.
    args, _ = parser.parse_known_args()

    print("\n" + "=" * 80)
    print("BREAST CANCER CLASSIFICATION WITH CATBOOST")
    print("=" * 80)
    print(f"Started at: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")

    client = setup_mlflow(args.mlflow_uri, args.username, args.password)
    set_experiment(experiment_name=args.experiment_name)
    print(f"Experiment: {args.experiment_name}")

    X_train, X_test, y_train, y_test, feature_names = load_and_prepare_data()

    hyperparams = {
        "loss_function": "Logloss",
        "eval_metric": "AUC",
        "iterations": 300,
        "learning_rate": 0.05,
        "depth": 6,
        "l2_leaf_reg": 3.0,
        "random_seed": 42,
        "verbose": 50,
    }

    run_name = f"catboost_breast_cancer_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
    with mlflow.start_run(run_name=run_name) as run:
        # Train on a copy so the dict that gets logged is the one defined above.
        model, y_train_pred, y_test_pred, y_train_proba, y_test_proba = train_model(
            X_train, y_train, X_test, y_test, hyperparams.copy()
        )

        _ = log_to_mlflow(
            model,
            X_train,
            y_train,
            X_test,
            y_test,
            y_train_pred,
            y_test_pred,
            y_train_proba,
            y_test_proba,
            hyperparams,
            feature_names,
        )

        # Read the identifiers from the run object returned by start_run.
        run_id = run.info.run_id
        experiment_id = run.info.experiment_id
        sample_payload = create_sample_payload(X_test, y_test, model, feature_names)

    # Registration and the summary run after the run has been closed.
    _ = register_model(client, args.model_name, run_id, experiment_id)
    print_deployment_info(run_id, experiment_id, sample_payload)

    print("\nScript completed successfully!")
    print("=" * 80 + "\n")

# Run the full training pipeline when this file/cell is executed as the main module.
if __name__ == "__main__":
    main()
