In [None]:
import timm
import wandb
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split
from sklearn.metrics import precision_score, recall_score, f1_score, confusion_matrix, classification_report
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import random
from io import StringIO

# Open a Weights & Biases run. Every hyperparameter the rest of the script
# uses is read back through `config`, so the run page records the exact
# settings that were trained with.
wandb.init(
    project="mnist-baseline-final",
    name="1337",
    config=dict(
        epochs=10,
        batch_size=64,
        lr=1e-3,
        model_name="resnet18",
        seed=1337,
    ),
)
config = wandb.config

# Set random seed
def set_seed(seed):
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_seed(config.seed)

# Preprocessing: resize every image to 224x224, convert to a tensor, then
# normalize the single channel with mean 0.1307 / std 0.3081.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ]
)

# A single seeded generator drives the train/val split and the loaders below,
# so both are reproducible for a given config.seed.
generator = torch.Generator().manual_seed(config.seed)

# MNIST training set (60,000 images), split 75% train / 25% validation.
train_val_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
n_total = len(train_val_dataset)
train_size = int(0.75 * n_total)     # 45,000
val_size = n_total - train_size      # 15,000
train_set, val_set = random_split(train_val_dataset, [train_size, val_size], generator=generator)

# Print the first split indices so two runs can be compared for reproducibility.
train_indices = getattr(train_set, "indices", None)
val_indices = getattr(val_set, "indices", None)
print("Train indices (first 10):", train_indices[:10])
print("Val indices (first 10):", val_indices[:10])

# Official MNIST test set (10,000 images).
test_set = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

# Only the training loader shuffles; all three share batch size and generator.
loader_kwargs = dict(batch_size=config.batch_size, generator=generator)
train_loader = DataLoader(train_set, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_set, shuffle=False, **loader_kwargs)
test_loader = DataLoader(test_set, shuffle=False, **loader_kwargs)

# Randomly initialised single-channel, 10-class model from timm.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = timm.create_model(config.model_name, pretrained=False, num_classes=10, in_chans=1).to(device)

# Optimisation objects and display names for the ten digit classes.
optimizer = optim.Adam(model.parameters(), lr=config.lr)
criterion = nn.CrossEntropyLoss()
class_names = list(map(str, range(10)))

# Evaluation helpers shared by the per-epoch validation and the final test.

def _collect_predictions(model, data_loader, criterion, device):
    """Run ``model`` over every batch of ``data_loader`` without gradients.

    Returns:
        ``(avg_loss, accuracy, labels, preds)``:
        - ``avg_loss`` is the mean of the per-batch losses.
        - ``accuracy`` is correct predictions divided by ``len(data_loader.dataset)``.
        - ``labels`` and ``preds`` are 1-D NumPy arrays of the true and predicted
          (argmax) class indices.
    """
    model.eval()
    total_loss, total_correct = 0.0, 0
    all_preds, all_labels = [], []

    with torch.no_grad():
        for images, labels in data_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)

            total_loss += criterion(outputs, labels).item()
            preds = outputs.argmax(dim=1)
            total_correct += (preds == labels).sum().item()

            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    accuracy = total_correct / len(data_loader.dataset)
    avg_loss = total_loss / len(data_loader)
    return avg_loss, accuracy, np.array(all_labels), np.array(all_preds)


def _report_rows(report, class_names, per_class_acc, accuracy):
    """Flatten a ``classification_report(..., output_dict=True)`` dict into table rows.

    Each row follows ``["Class", "Precision", "Recall", "F1-Score", "Support", "Accuracy"]``:
    - one row per class, in ``class_names`` order, with its per-class accuracy;
    - then the "macro avg" and "weighted avg" rows, which have no accuracy;
    - then an "accuracy" row that holds only the total support and the overall accuracy.
    """
    rows = []
    for name, class_acc in zip(class_names, per_class_acc):
        stats = report[name]
        rows.append([name, stats["precision"], stats["recall"], stats["f1-score"],
                     int(stats["support"]), float(class_acc)])
    for avg_name in ("macro avg", "weighted avg"):
        stats = report[avg_name]
        rows.append([avg_name, stats["precision"], stats["recall"], stats["f1-score"],
                     int(stats["support"]), None])
    total_support = int(report["macro avg"]["support"])
    rows.append(["accuracy", None, None, None, total_support, float(accuracy)])
    return rows


def evaluate_model(
    model,
    data_loader,
    criterion,
    class_names,
    mode="val",        # "val" or "test"
    epoch=None,
    device="cuda" if torch.cuda.is_available() else "cpu"
):
    """Evaluate ``model`` on ``data_loader`` and log the results to the active W&B run.

    The function computes loss, accuracy, macro precision/recall/F1, a confusion
    matrix and a per-class classification report. It writes the text report to
    disk, prints a summary table, and logs scalars, the confusion-matrix image
    and two tables to W&B.

    Args:
        model: network to evaluate; it is switched to eval mode.
        data_loader: yields ``(images, labels)`` batches. Labels are assumed to
            be integer class indices in ``range(len(class_names))``.
        criterion: loss function called as ``criterion(outputs, labels)``.
        class_names: display name for each class index, in index order.
        mode: prefix for logged keys and report file names ("val" or "test").
        epoch: zero-based epoch index, or None for a one-off evaluation.
        device: device the batches are moved to.

    Returns:
        ``(avg_loss, accuracy)``. ``avg_loss`` is the mean per-batch loss;
        ``accuracy`` is the fraction of correctly classified samples.
    """
    avg_loss, accuracy, all_labels, all_preds = _collect_predictions(
        model, data_loader, criterion, device
    )

    # Pin the label set to every class index. This keeps the macro averages,
    # the confusion matrix and the report aligned with class_names even when a
    # class never occurs in this split. zero_division=0 scores such classes as
    # 0 without emitting warnings.
    label_ids = list(range(len(class_names)))
    precision = precision_score(all_labels, all_preds, labels=label_ids,
                                average='macro', zero_division=0)
    recall = recall_score(all_labels, all_preds, labels=label_ids,
                          average='macro', zero_division=0)
    f1 = f1_score(all_labels, all_preds, labels=label_ids,
                  average='macro', zero_division=0)
    conf_mat = confusion_matrix(all_labels, all_preds, labels=label_ids)

    # Per-class accuracy = diagonal / row support. Classes with no true
    # samples get 0.0 instead of a 0/0 NaN.
    class_support = conf_mat.sum(axis=1)
    per_class_acc = np.divide(conf_mat.diagonal(), class_support,
                              out=np.zeros(len(class_support), dtype=float),
                              where=class_support > 0)

    # === Classification report (text for the file, dict for the tables) ===
    report_kwargs = dict(labels=label_ids, target_names=class_names, zero_division=0)
    report_text = classification_report(all_labels, all_preds, **report_kwargs)
    report_dict = classification_report(all_labels, all_preds, output_dict=True, **report_kwargs)
    print(f"\n {mode.upper()} Classification Report:\n")

    file_suffix = f"_epoch{epoch + 1}" if epoch is not None else ""
    report_path = f"classification_report_{mode}{file_suffix}.txt"
    with open(report_path, "w", encoding="utf-8") as f:
        f.write(report_text)

    # === Confusion matrix plot ===
    fig, ax = plt.subplots()
    try:
        sns.heatmap(conf_mat, annot=True, fmt='d', cmap='Blues',
                    xticklabels=class_names, yticklabels=class_names, ax=ax)
        ax.set_xlabel('Predicted')
        ax.set_ylabel('True')
        ax.set_title(f"{mode.capitalize()} Confusion Matrix")
        fig.tight_layout()

        # === Per-class accuracy table ===
        acc_table = wandb.Table(columns=["Class", "Accuracy"])
        for cls, acc in zip(class_names, per_class_acc):
            acc_table.add_data(cls, float(acc))

        # === Classification report as W&B table with an accuracy column ===
        columns = ["Class", "Precision", "Recall", "F1-Score", "Support", "Accuracy"]
        rows = _report_rows(report_dict, class_names, per_class_acc, accuracy)
        report_table = wandb.Table(columns=columns)
        for row in rows:
            report_table.add_data(*row)

        print("\nClassification Report Table (with Accuracy):")
        print(pd.DataFrame(rows, columns=columns).to_string(index=False))

        # === W&B logging ===
        # precision/recall here are the macro scores computed above.
        log_data = {
            f"{mode}_loss": avg_loss,
            f"{mode}_accuracy": accuracy,
            f"{mode}_precision_macro": precision,
            f"{mode}_recall_macro": recall,
            f"{mode}_f1_score_macro": f1,
            f"{mode}_classification_report_path": report_path
        }

        key_suffix = f"_epoch_{epoch + 1}" if epoch is not None else ""
        if epoch is not None:
            log_data["epoch"] = epoch + 1
        log_data[f"{mode}_confusion_matrix_image{key_suffix}"] = wandb.Image(fig)
        log_data[f"{mode}_per_class_accuracy_table{key_suffix}"] = acc_table
        log_data[f"{mode}_classification_report_table{key_suffix}"] = report_table

        wandb.log(log_data)
    finally:
        # Always release the figure, even if plotting or logging fails.
        plt.close(fig)

    return avg_loss, accuracy


# Training + evaluation loop. The weights with the best validation accuracy
# are saved to best_model.pth and reloaded for the final test evaluation.
# Starting from -inf (not 0.0) guarantees the first epoch always writes a
# checkpoint. The reload below therefore never hits a missing file or a stale
# one left by an earlier run.
best_val_accuracy = float("-inf")
for epoch in range(config.epochs):
    model.train()
    total_loss, correct = 0.0, 0
    train_preds, train_labels = [], []

    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Training metrics come from the same forward pass used for the update,
        # so they reflect weights that change within the epoch.
        total_loss += loss.item()
        preds = outputs.argmax(dim=1)
        correct += (preds == labels).sum().item()
        train_preds.extend(preds.cpu().numpy())
        train_labels.extend(labels.cpu().numpy())

    train_accuracy = correct / len(train_loader.dataset)
    avg_train_loss = total_loss / len(train_loader)  # mean of per-batch losses
    train_precision = precision_score(train_labels, train_preds, average='macro')
    train_recall = recall_score(train_labels, train_preds, average='macro')
    train_f1 = f1_score(train_labels, train_preds, average='macro')

    print(f"[Epoch {epoch+1}] Train Loss: {avg_train_loss:.4f}, Train Acc: {train_accuracy:.4f}, "
        f"Train Precision: {train_precision:.4f}, Train Recall: {train_recall:.4f}, Train F1: {train_f1:.4f}")

    wandb.log({
        "epoch": epoch + 1,
        "train_loss": avg_train_loss,
        "train_accuracy": train_accuracy,
        "train_precision_macro": train_precision,
        "train_recall_macro": train_recall,
        "train_f1_score_macro": train_f1
    })

    # Validation; keep the checkpoint with the highest validation accuracy.
    val_loss, val_accuracy = evaluate_model(model, val_loader, criterion, class_names,
                                            mode="val", epoch=epoch, device=device)
    if val_accuracy > best_val_accuracy:
        best_val_accuracy = val_accuracy
        torch.save(model.state_dict(), "best_model.pth")
        print(f"Best model saved with val accuracy: {val_accuracy:.4f}")

# Final test on the best checkpoint. map_location loads the tensors onto the
# training device regardless of where they were saved from.
model.load_state_dict(torch.load("best_model.pth", map_location=device))
evaluate_model(model, test_loader, criterion, class_names, mode="test", epoch=None, device=device)

# Mark the W&B run finished and upload any remaining data. In a notebook the
# run otherwise stays open after this cell completes.
wandb.finish()


0,1
epoch,▁▁▂▂▃▃▅▅▆▆▇▇██
train_accuracy,▁▇▇████
train_f1_score_macro,▁▇█████
train_loss,█▂▁▁▁▁▁
train_precision_macro,▁▇▇████
train_recall_macro,▁▇█████
val_accuracy,▁▅▆▅█▆▇
val_f1_score_macro,▁▅▆▅█▆▇
val_loss,█▃▃▄▁▃▂
val_precision_macro,▁▅▅▅█▅█

0,1
epoch,7
train_accuracy,0.99324
train_f1_score_macro,0.99323
train_loss,0.02068
train_precision_macro,0.99324
train_recall_macro,0.99321
val_accuracy,0.98867
val_classification_report_path,classification_repor...
val_f1_score_macro,0.98865
val_loss,0.03658


Train indices (first 10): [27415, 57082, 25212, 22657, 23939, 28148, 28313, 3041, 10690, 18767]
Val indices (first 10): [6991, 39897, 3982, 52646, 14149, 38675, 18106, 42669, 33438, 48256]
[Epoch 1] Train Loss: 0.2998, Train Acc: 0.9148, Train Precision: 0.9152, Train Recall: 0.9138, Train F1: 0.9142

 VAL Classification Report:


Classification Report Table (with Accuracy):
       Class  Precision  Recall  F1-Score  Support  Accuracy
           0       1.00    0.89      0.94     1442  0.893897
           1       0.98    0.98      0.98     1672  0.982656
           2       0.97    0.98      0.97     1504  0.978059
           3       0.98    0.98      0.98     1500  0.983333
           4       0.98    0.98      0.98     1460  0.976027
           5       0.98    0.99      0.98     1326  0.987179
           6       0.98    0.99      0.98     1486  0.986541
           7       0.99    0.95      0.97     1624  0.946429
           8       0.89    0.98      0.93     1473  0.983707
           9

KeyboardInterrupt: 