# MNIST Digit Classification with PyTorch

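This notebook trains a small fully connected neural network with PyTorch to classify handwritten digits, logs the parameters, metrics, and model to MLflow, and registers the model in the MLflow Model Registry.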

**Dataset**: MNIST Handwritten Digits
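- 60,000 training samples, 10,000 test samples (by default this notebook uses only the first 8,000 training and 2,000 test samples for faster training)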
- Target: Digit class (0-9)

**Features**:
- 28x28 grayscale images (784 pixels)
- Normalized pixel values (0-1)


In [None]:
%pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu
%pip install pandas numpy scikit-learn mlflow

In [None]:
import os

# Force CPU-only execution (disable CUDA even if available)
os.environ['CUDA_VISIBLE_DEVICES'] = ''

# ============================================================================
# Standard imports (after environment configuration)
# ============================================================================
import argparse
import json
import tempfile
import numpy as np
import pandas as pd
from datetime import datetime

# PyTorch imports (CPU-only)
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from torchvision import datasets, transforms

# Use float32 as the default floating-point dtype and select the CPU device.
# NOTE(review): `device` is not referenced by the functions shown in this
# notebook; tensors are created on the default (CPU) device instead.
torch.set_default_dtype(torch.float32)
device = torch.device('cpu')

# Print the PyTorch version, whether CUDA is visible, and the selected device
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
print(f"Using device: {device}")

# MLflow imports
import mlflow
import mlflow.pytorch
from mlflow import set_tracking_uri, set_experiment
from mlflow.client import MlflowClient
from mlflow.models.signature import infer_signature

# Scikit-learn imports
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix

In [None]:
def setup_mlflow(mlflow_uri: str, username: str, password: str) -> MlflowClient:
    """Point MLflow at a tracking server and build a client for it.

    The credentials are exported through the ``MLFLOW_TRACKING_USERNAME`` and
    ``MLFLOW_TRACKING_PASSWORD`` environment variables before the tracking
    URI is set.

    Args:
        mlflow_uri: Tracking server URI.
        username: Value stored in ``MLFLOW_TRACKING_USERNAME``.
        password: Value stored in ``MLFLOW_TRACKING_PASSWORD``.

    Returns:
        An ``MlflowClient`` constructed with ``mlflow_uri``.
    """
    os.environ.update(
        MLFLOW_TRACKING_USERNAME=username,
        MLFLOW_TRACKING_PASSWORD=password,
    )

    set_tracking_uri(mlflow_uri)
    tracking_client = MlflowClient(mlflow_uri)

    print(f"MLflow tracking URI: {mlflow_uri}")
    return tracking_client


def _mnist_to_arrays(dataset, limit):
    """Return ``(pixels, labels)`` numpy arrays for the first ``limit`` samples.

    Pixels are flattened to 784 values per image and scaled from 0-255 to
    0-1 as float32. Slicing happens before the float conversion so only the
    requested samples are converted.
    """
    pixels = dataset.data[:limit].numpy().reshape(-1, 784).astype('float32') / 255.0
    labels = dataset.targets[:limit].numpy()
    return pixels, labels


def load_and_prepare_data(train_size=8000, test_size=2000, root="./data"):
    """Load MNIST via torchvision and return flattened train/test arrays.

    Args:
        train_size: Number of leading training samples to keep. ``None``
            keeps the whole training split.
        test_size: Number of leading test samples to keep. ``None`` keeps
            the whole test split.
        root: Directory where torchvision stores (and downloads) the data.

    Returns:
        ``(X_train, X_test, y_train, y_test)`` as numpy arrays. The feature
        arrays are float32 with shape ``(n, 784)`` and values in 0-1; the
        label arrays hold the digit classes.
    """
    print("\n" + "=" * 80)
    print("LOADING DATASET (torchvision MNIST)")
    print("=" * 80)

    # No `transform` is passed: the raw `.data` / `.targets` tensors are read
    # directly below, and transforms only run on per-item access.
    train_dataset = datasets.MNIST(root=root, train=True, download=True)
    test_dataset = datasets.MNIST(root=root, train=False, download=True)

    print("MNIST download / load completed")

    # Use a subset for faster training
    X_train, y_train = _mnist_to_arrays(train_dataset, train_size)
    X_test, y_test = _mnist_to_arrays(test_dataset, test_size)

    print("Dataset: MNIST")
    print("Features: 784 (28x28 pixels)")
    print("Classes: 10 (digits 0-9)")
    print(f"\nTrain samples: {len(X_train):,}")
    print(f"Test samples: {len(X_test):,}")

    return X_train, X_test, y_train, y_test


class MNISTNet(nn.Module):
    """Fully connected classifier for 28x28 MNIST digits.

    Layers: 784 -> 128 -> 64 -> 10, with ReLU and dropout after each of the
    two hidden layers. ``forward`` returns raw logits (no softmax).

    Args:
        dropout: Dropout probability applied after each hidden layer.
            Defaults to 0.2.
    """

    def __init__(self, dropout: float = 0.2):
        super().__init__()
        self.fc1 = nn.Linear(784, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 10)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Return logits of shape ``(batch, 10)`` for input reshapeable to ``(-1, 784)``."""
        # reshape (unlike view) also accepts non-contiguous inputs such as
        # transposed or sliced image batches.
        x = x.reshape(-1, 784)
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = F.relu(self.fc2(x))
        x = self.dropout(x)
        x = self.fc3(x)
        return x


def train_model(X_train, y_train, X_test, y_test, hyperparams: dict):
    """Train an ``MNISTNet`` on the training split and predict both splits.

    Args:
        X_train: Training features, convertible to a float32 tensor.
        y_train: Integer class labels for ``X_train``.
        X_test: Test features in the same layout as ``X_train``.
        y_test: Test labels. Not used for training or prediction; accepted
            so callers can pass the full train/test tuple.
        hyperparams: Must contain ``epochs``, ``batch_size`` and
            ``learning_rate``. An optional ``dropout`` entry sets the
            network's dropout probability.

    Returns:
        ``(model, y_train_pred, y_test_pred)``. The predictions are numpy
        arrays of argmax class indices; the model is returned in eval mode.
    """
    print("\n" + "=" * 80)
    print("TRAINING MODEL")
    print("=" * 80)

    print("Hyperparameters:")
    for key, value in hyperparams.items():
        print(f"  {key}: {value}")

    # Convert to tensors
    X_train_tensor = torch.tensor(X_train, dtype=torch.float32)
    y_train_tensor = torch.tensor(y_train, dtype=torch.long)
    X_test_tensor = torch.tensor(X_test, dtype=torch.float32)

    # Create data loader
    train_loader = DataLoader(
        TensorDataset(X_train_tensor, y_train_tensor),
        batch_size=hyperparams['batch_size'],
        shuffle=True,
    )

    # Initialize model
    model = MNISTNet()
    # Apply the configured dropout so the value reported in `hyperparams` is
    # the one actually used during training.
    if "dropout" in hyperparams:
        model.dropout.p = float(hyperparams["dropout"])
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=hyperparams['learning_rate'])

    # Training loop
    epochs = hyperparams['epochs']
    model.train()
    for epoch in range(epochs):
        total_loss = 0.0
        for batch_X, batch_y in train_loader:
            optimizer.zero_grad()
            outputs = model(batch_X)
            loss = criterion(outputs, batch_y)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        # Report every 5th epoch, and always the last one so short runs
        # still show their final loss.
        if (epoch + 1) % 5 == 0 or epoch + 1 == epochs:
            print(f"Epoch [{epoch+1}/{epochs}], Loss: {total_loss/len(train_loader):.4f}")

    print("Training completed!")

    # Predictions (eval mode disables dropout)
    model.eval()
    with torch.no_grad():
        y_train_pred = torch.argmax(model(X_train_tensor), dim=1).numpy()
        y_test_pred = torch.argmax(model(X_test_tensor), dim=1).numpy()

    return model, y_train_pred, y_test_pred


def calculate_metrics(y_true, y_pred, dataset_name="Test"):
    """Compute accuracy and weighted precision/recall/F1 for one split.

    Args:
        y_true: Ground-truth labels.
        y_pred: Predicted labels.
        dataset_name: Split name; its lowercase form prefixes every key.

    Returns:
        Dict with keys ``<name>_accuracy``, ``<name>_precision``,
        ``<name>_recall`` and ``<name>_f1``.
    """
    weighted_scorers = (
        ("precision", precision_score),
        ("recall", recall_score),
        ("f1", f1_score),
    )

    scores = {"accuracy": accuracy_score(y_true, y_pred)}
    for label, scorer in weighted_scorers:
        scores[label] = scorer(y_true, y_pred, average='weighted')

    prefix = dataset_name.lower()
    return {f"{prefix}_{label}": value for label, value in scores.items()}


def log_to_mlflow(model, X_train, y_train, X_test, y_test, 
                  y_train_pred, y_test_pred, hyperparams):
    """Log hyperparameters, metrics, and the model to the active MLflow run.

    Args:
        model: Trained PyTorch module. It is switched to eval mode here.
        X_train: Training features; the first rows provide the signature and
            input example.
        y_train: Training labels.
        X_test: Test features. Not used here; accepted for call symmetry.
        y_test: Test labels.
        y_train_pred: Predicted labels for the training split.
        y_test_pred: Predicted labels for the test split.
        hyperparams: Dict of hyperparameters to log as run parameters.

    Returns:
        Dict containing the train and test metrics that were logged.
    """
    print("\n" + "=" * 80)
    print("LOGGING TO MLFLOW")
    print("=" * 80)

    # Log all hyperparameters in one batched call
    mlflow.log_params(hyperparams)

    # Calculate metrics and log them in one batched call
    train_metrics = calculate_metrics(y_train, y_train_pred, "Train")
    test_metrics = calculate_metrics(y_test, y_test_pred, "Test")
    all_metrics = {**train_metrics, **test_metrics}
    mlflow.log_metrics(all_metrics)

    print("\nModel Performance:")
    print(f"  Training Accuracy: {train_metrics['train_accuracy']:.4f}")
    print(f"  Training F1: {train_metrics['train_f1']:.4f}")
    print(f"  Test Accuracy: {test_metrics['test_accuracy']:.4f}")
    print(f"  Test Precision: {test_metrics['test_precision']:.4f}")
    print(f"  Test Recall: {test_metrics['test_recall']:.4f}")
    print(f"  Test F1: {test_metrics['test_f1']:.4f}")

    # Print confusion matrix
    cm = confusion_matrix(y_test, y_test_pred)
    print("\n  Confusion Matrix (first 5x5):")
    print(f"  {cm[:5, :5]}")

    # Named-column input frame for the signature. The schema is derived from
    # column names and dtypes, so a few rows are enough and the full training
    # set does not need to be copied into a DataFrame.
    X_sample_df = pd.DataFrame(X_train[:5])
    X_sample_df.columns = [f"pixel_{i}" for i in range(X_sample_df.shape[1])]

    # The output schema must match what the model returns: raw logits, one
    # score per class, not the argmax.
    model.eval()
    with torch.no_grad():
        X_sample = torch.tensor(X_train[:1], dtype=torch.float32)
        logits = model(X_sample).numpy()  # Shape: (1, 10)

    # One column per class (digit 0-9)
    y_pred_df = pd.DataFrame(logits, columns=[f"digit_{i}_score" for i in range(10)])

    # Create signature and input example
    signature = infer_signature(X_sample_df, y_pred_df)
    input_example = X_sample_df.head(1)

    # Save the model to a temporary directory, then upload that directory as
    # run artifacts under "model".
    with tempfile.TemporaryDirectory() as tmpdir:
        local_model_path = os.path.join(tmpdir, "model")

        mlflow.pytorch.save_model(
            model,
            local_model_path,
            signature=signature,
            input_example=input_example
        )

        mlflow.log_artifacts(local_model_path, artifact_path="model")
        print("Model artifacts logged successfully!")

    return all_metrics


def create_sample_payload(X_test, y_test, model, sample_idx: int = 0):
    """Build an example prediction payload from one test sample.

    Args:
        X_test: Test features; ``X_test[sample_idx]`` must be a 1-D sequence
            of pixel values.
        y_test: Test labels aligned with ``X_test``.
        model: Trained PyTorch module returning one row of logits per input
            row. It is switched to eval mode here.
        sample_idx: Index of the sample to use. Defaults to the first one.

    Returns:
        Dict with ``features`` (``pixel_<i>`` -> float), ``actual_digit``,
        ``predicted_digit`` (argmax of the logits) and ``raw_logits``.
    """
    sample = X_test[sample_idx]
    actual_digit = y_test[sample_idx]

    # Raw logits are what the model itself returns for this sample
    model.eval()
    with torch.no_grad():
        sample_tensor = torch.tensor(sample, dtype=torch.float32).unsqueeze(0)
        logits = model(sample_tensor).numpy()[0]  # one score per class
    predicted_digit = int(logits.argmax())

    # Plain Python floats keep the payload JSON-serializable
    features = {f"pixel_{i}": float(value) for i, value in enumerate(sample)}

    return {
        "features": features,
        "actual_digit": int(actual_digit),
        "predicted_digit": predicted_digit,
        "raw_logits": logits.tolist()  # Include raw scores for reference
    }


def register_model(client: MlflowClient, model_name: str, run_id: str, experiment_id: str):
    """Register the run's ``model`` artifacts as a new registry version.

    Registration is best-effort: every failure is printed rather than raised.

    Args:
        client: Client used for all registry calls.
        model_name: Name of the registered model; created if the lookup fails.
        run_id: Run whose ``model`` artifact directory is the version source.
        experiment_id: Only used to print a fallback artifact URI on failure.

    Returns:
        The new model version, or ``None`` if creating the version failed.
    """
    divider = "=" * 80
    print("\n" + divider)
    print("REGISTERING MODEL")
    print(divider)

    # Make sure a registry entry exists; a failed lookup triggers creation
    try:
        client.get_registered_model(model_name)
    except Exception:
        try:
            client.create_registered_model(model_name)
        except Exception as create_error:
            print(f"Could not create registered model: {create_error}")
        else:
            print(f"Created registered model: {model_name}")
    else:
        print(f"Model '{model_name}' already exists in registry")

    # Attach the run's model artifacts as a new version
    try:
        registered = client.create_model_version(
            name=model_name,
            source=f"runs:/{run_id}/model",
            run_id=run_id,
        )
        print("Model version registered successfully!")
        print(f"   Model Name: {model_name}")
        print(f"   Version: {registered.version}")
        print(f"   Run ID: {run_id}")
        return registered.version
    except Exception as registration_error:
        print(f"Model registration failed (model still usable via run URI): {registration_error}")
        print(f"   You can deploy using: mlflow-artifacts:/{experiment_id}/{run_id}/artifacts/model")
        return None


def print_deployment_info(run_id: str, experiment_id: str, sample_payload: dict):
    """Print the run summary, a deploy request body, and a sample prediction.

    Args:
        run_id: MLflow run ID used in the printed model URI.
        experiment_id: MLflow experiment ID used in the printed model URI.
        sample_payload: Dict with ``features``, ``actual_digit`` and
            ``predicted_digit`` entries; only the first 10 features are shown.
    """
    rule = "=" * 80
    artifact_uri = f"mlflow-artifacts:/{experiment_id}/{run_id}/artifacts/model"

    def section(title):
        # Blank line, then the title framed by two rules
        print("\n" + rule)
        print(title)
        print(rule)

    section("TRAINING COMPLETE!")

    print("\nRun Information:")
    print(f"  Run ID: {run_id}")
    print(f"  Experiment ID: {experiment_id}")
    print(f"  Model URI: {artifact_uri}")

    section("DEPLOYMENT PAYLOAD (deploy-model API)")

    print(json.dumps(
        {
            "serve_name": "mnist-pytorch-classifier",
            "model_uri": artifact_uri,
            "env": "local",
            "cores": 2,
            "memory": 4,
            "node_capacity": "spot",
            "min_replicas": 1,
            "max_replicas": 3,
        },
        indent=2,
    ))

    section("SAMPLE PREDICTION PAYLOAD (first 10 pixels shown)")

    # Keep the first 10 features and add a placeholder entry for the rest
    preview = dict(list(sample_payload["features"].items())[:10])
    preview["..."] = "... (784 pixels total)"

    print(json.dumps({"features": preview}, indent=2))

    print("\nExpected Output (Raw Logits):")
    print(f"  Actual Digit: {sample_payload['actual_digit']}")
    print(f"  Predicted Digit: {sample_payload['predicted_digit']} (argmax of logits)")
    print("\n  Model returns raw logits (10 scores, one per digit):")
    print("  To get predicted digit: argmax(scores)")
    print("  To get probabilities: softmax(scores)")

In [None]:
def main():
    """Run the full pipeline: configure MLflow, train, log, and register.

    Unknown command-line arguments are ignored (``parse_known_args``), so the
    function can also be called from a notebook kernel. The username and
    password default to the ``MLFLOW_TRACKING_USERNAME`` and
    ``MLFLOW_TRACKING_PASSWORD`` environment variables when those are set,
    and otherwise to the built-in placeholder values.
    """
    parser = argparse.ArgumentParser(description="Train PyTorch MNIST Classification Model")
    parser.add_argument(
        "--mlflow-uri",
        default="http://darwin-mlflow-lib.darwin.svc.cluster.local:8080",
        help="MLflow tracking URI"
    )
    parser.add_argument(
        "--username",
        # Prefer credentials already present in the environment
        default=os.environ.get("MLFLOW_TRACKING_USERNAME", "abc@gmail.com"),
        help="MLflow username"
    )
    parser.add_argument(
        "--password",
        default=os.environ.get("MLFLOW_TRACKING_PASSWORD", "password"),
        help="MLflow password"
    )
    parser.add_argument(
        "--experiment-name",
        default="mnist_pytorch_classification",
        help="MLflow experiment name"
    )
    parser.add_argument(
        "--model-name",
        default="MNISTPyTorchClassifier",
        help="Registered model name"
    )
    parser.add_argument(
        "--epochs",
        type=int,
        default=20,
        help="Number of training epochs"
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=128,
        help="Training batch size"
    )
    parser.add_argument(
        "--learning-rate",
        type=float,
        default=0.001,
        help="Adam learning rate"
    )

    # Ignore arguments this script does not define
    args, _ = parser.parse_known_args()

    print("\n" + "=" * 80)
    print("MNIST DIGIT CLASSIFICATION WITH PYTORCH")
    print("=" * 80)
    print(f"Started at: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")

    # Setup MLflow
    client = setup_mlflow(args.mlflow_uri, args.username, args.password)
    set_experiment(experiment_name=args.experiment_name)
    print(f"Experiment: {args.experiment_name}")

    # Load data
    X_train, X_test, y_train, y_test = load_and_prepare_data()

    # Define hyperparameters
    hyperparams = {
        "epochs": args.epochs,
        "batch_size": args.batch_size,
        "learning_rate": args.learning_rate,
        "dropout": 0.2
    }

    # Start MLflow run; the context manager yields the active run
    run_name = f"pytorch_mnist_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
    with mlflow.start_run(run_name=run_name) as run:
        # Train model
        model, y_train_pred, y_test_pred = train_model(
            X_train, y_train, X_test, y_test, hyperparams
        )

        # Log to MLflow
        log_to_mlflow(
            model, X_train, y_train, X_test, y_test,
            y_train_pred, y_test_pred, hyperparams
        )

        # Get run information
        run_id = run.info.run_id
        experiment_id = run.info.experiment_id

        # Create sample payload
        sample_payload = create_sample_payload(X_test, y_test, model)

    # Register model (outside of run context)
    register_model(client, args.model_name, run_id, experiment_id)

    # Print deployment information
    print_deployment_info(run_id, experiment_id, sample_payload)

    print("\nScript completed successfully!")
    print("=" * 80 + "\n")


if __name__ == "__main__":
    main()
