In [None]:
## Imports and Environment Setup

# Display system information (hostname and kernel) - useful for verifying the environment
!hostnamectl

# Change working directory to the project source folder
#%cd /home/ir739wb/ilyarekun/bc_project/centralized-learning/src/

import sys
import os

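# Add '../src' to the module search path.
# NOTE(review): the imports below use the `src.` package prefix, which is resolved
# against the directory that CONTAINS `src` (e.g. '..' or the project root), not
# against '../src' itself - confirm the working directory this is meant to run from.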
sys.path.append('../src')

# Import custom data preprocessing function
from src.data_preprocessing import data_preprocessing_tumor_stratified

# Import model definition and early stopping utility
from src.model import BrainCNN, EarlyStopping

# Utility for counting occurrences per class
from collections import defaultdict

# Standard libraries for numeric operations, randomness, and file handling
import numpy as np
import random
import torch
import pickle

# Libraries for plotting and visualization
import seaborn as sns
import matplotlib.pyplot as plt

# PyTorch building blocks: neural network modules and optimizers
from torch import nn
from torch import optim

# Scikit-learn metrics for evaluation
from sklearn.metrics import precision_score, recall_score, f1_score, confusion_matrix


## Reproducibility: Set Random Seeds

seed = 42

# Feed the same seed to every random number generator used in this run:
# Python's built-in `random`, NumPy, PyTorch on CPU, the current CUDA device,
# and all CUDA devices (relevant for multi-GPU runs).
for _seed_fn in (
    random.seed,
    np.random.seed,
    torch.manual_seed,
    torch.cuda.manual_seed,
    torch.cuda.manual_seed_all,
):
    _seed_fn(seed)

# cuDNN: turn off the benchmark-based algorithm auto-tuner and request
# deterministic kernels, trading some speed for repeatable results
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True


## Load and Inspect Data

# Run the project's stratified tumor preprocessing routine; it hands back
# three loaders, in the order train / validation / test
loaders = data_preprocessing_tumor_stratified()
train_loader, valid_loader, test_loader = loaders

# Report how many samples sit behind each loader (taken from its `.dataset`)
print("Data was successfully loaded")
print(f"Train dataset size: {len(train_loader.dataset)}")
print(f"Validation dataset size: {len(valid_loader.dataset)}")
print(f"Test dataset size: {len(test_loader.dataset)}")


def count_images_per_class(loader):
    """
    Tally the number of samples per class label in a DataLoader.

    Every batch yielded by `loader` is unpacked as an (inputs, labels) pair;
    the inputs are ignored and each label is converted to a plain Python
    scalar with `.item()` before being counted.

    Returns a defaultdict(int) mapping class_label -> count, with keys in
    the order the labels were first encountered.
    """
    tally = defaultdict(int)

    for _inputs, batch_labels in loader:
        # .item() unwraps each single-element label into a hashable scalar
        for class_id in (lbl.item() for lbl in batch_labels):
            tally[class_id] += 1

    return tally


# Tally the labels of every split up front, before anything is printed
train_class_counts = count_images_per_class(train_loader)
valid_class_counts = count_images_per_class(valid_loader)
test_class_counts = count_images_per_class(test_loader)

# Report the per-class totals split by split. The validation and test headers
# start with a newline so a blank line separates the sections.
for header, split_counts in (
    ("Train loader class counts:", train_class_counts),
    ("\nValidation loader class counts:", valid_class_counts),
    ("\nTest loader class counts:", test_class_counts),
):
    print(header)
    for class_label, count in split_counts.items():
        print(f"Class {class_label}: {count} images")


## Instantiate and Train Model

# Start from an untrained BrainCNN
model = BrainCNN()

# Keyword arguments forwarded to BrainCNN.train_model (implemented in src.model).
# Meanings as described in this notebook:
#   num_epochs    - upper bound on the number of training epochs
#   patience      - epochs to wait for an improvement before early stopping
#   delta         - smallest validation-loss change that counts as an improvement
#   learning_rate, momentum, weight_decay - optimizer hyperparameters
#   save_path     - where the best model weights are written
training_kwargs = dict(
    num_epochs=50,
    patience=6,
    delta=0.004,
    learning_rate=0.002,
    momentum=0.85,
    weight_decay=0.07,
    save_path="./braincnn_prototype.weights",
)

# The call yields five values: per-epoch train/validation loss and accuracy
# histories, followed by the early-stopping object
training_outputs = model.train_model(train_loader, valid_loader, **training_kwargs)
(
    train_loss_metr,
    val_loss_metr,
    train_acc_metr,
    val_acc_metr,
    early_stopping,
) = training_outputs


## Plot Training and Validation Curves (Loss and Accuracy)

# One x position per recorded epoch, numbered from 1
epochs = range(1, len(train_loss_metr) + 1)

plt.figure(figsize=(12, 5))

# Each row describes one panel of the 1x2 grid:
# (grid slot, train series, train legend, validation series, validation legend,
#  y-axis label, panel title)
panels = (
    (1, train_loss_metr, "Train Loss", val_loss_metr, "Validation Loss",
     "Loss", "Training and Validation Loss"),
    (2, train_acc_metr, "Train Accuracy", val_acc_metr, "Validation Accuracy",
     "Accuracy", "Training and Validation Accuracy"),
)

for slot, train_series, train_label, val_series, val_label, y_label, panel_title in panels:
    plt.subplot(1, 2, slot)
    plt.plot(epochs, train_series, label=train_label, marker="o")
    plt.plot(epochs, val_series, label=val_label, marker="o")
    plt.xlabel("Epochs")
    plt.ylabel(y_label)
    plt.title(panel_title)
    plt.legend()

plt.tight_layout()
plt.show()


## Load the Best Model Weights for Final Evaluation

save_path = "./braincnn_prototype.weights"
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Instantiate a fresh model and move it to the chosen device
model = BrainCNN()
model.to(device)

# Load the saved state dict, mapping tensors to the target device.
# NOTE: torch.load unpickles the file, so only load checkpoints produced by
# this project (recent PyTorch versions also offer `weights_only=True`).
state_dict = torch.load(save_path, map_location=device)

# A checkpoint saved from an nn.DataParallel wrapper has every key prefixed
# with "module.". Strip that prefix only where it LEADS the key: a blanket
# str.replace would also rewrite keys that merely contain "module."
# (e.g. "block.submodule.weight") and break load_state_dict.
dp_prefix = "module."
new_state_dict = {
    (k[len(dp_prefix):] if k.startswith(dp_prefix) else k): v
    for k, v in state_dict.items()
}
model.load_state_dict(new_state_dict)

# If multiple GPUs are available, wrap the model in DataParallel
if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model)


## Evaluate Model on Test Set

# Switch layers such as dropout / batch norm to inference behaviour
model.eval()

correct = 0
total = 0
test_targets = []
test_preds = []

# No autograd bookkeeping is needed while scoring
with torch.no_grad():
    for data, target in test_loader:
        # Inputs and labels must live on the same device as the model
        data = data.to(device)
        target = target.to(device)

        # Forward pass; the predicted class is the index of the largest output
        outputs = model(data)
        predicted = outputs.argmax(dim=1)

        # Running totals for the overall accuracy
        total += target.size(0)
        correct += (predicted == target).sum().item()

        # Keep per-sample labels and predictions for the sklearn metrics below
        test_targets.extend(target.cpu().numpy())
        test_preds.extend(predicted.cpu().numpy())

# Fraction of test samples classified correctly
test_accuracy = correct / total

# Support-weighted precision, recall and F1, computed in that order
precision, recall, f1 = (
    metric_fn(test_targets, test_preds, average='weighted')
    for metric_fn in (precision_score, recall_score, f1_score)
)

# Report everything to four decimal places
print('Metrics of the model on the test images:')
print(f'Accuracy: {test_accuracy:.4f}')
print(f'Precision: {precision:.4f}')
print(f'Recall: {recall:.4f}')
print(f'F1 Score: {f1:.4f}')


## Save Training and Evaluation Metrics to Disk

# Bundle the per-epoch training histories with the final test-set scores
metrics_to_save = {
    "train_loss": train_loss_metr,
    "val_loss": val_loss_metr,
    "train_acc": train_acc_metr,
    "val_acc": val_acc_metr,
    "accuracy": test_accuracy,
    "precision": precision,
    "recall": recall,
    "f1_score": f1
}

# Pickle the bundle into the current working directory
with open("training_metrics.pkl", "wb") as f:
    pickle.dump(metrics_to_save, f)


## Reload Metrics and Plot Curves Again (e.g., from a Notebook Path)

# Read the metrics back from the same relative path this notebook writes them
# to ("training_metrics.pkl" in the current working directory). The previous
# hard-coded absolute path pointed into one user's home directory, so it
# failed (or silently picked up a stale file) on any other machine or
# working directory. Adjust the path if the file lives elsewhere.
# NOTE: pickle.load can execute arbitrary code - only load files that this
# project produced itself.
with open("training_metrics.pkl", "rb") as f:
    metrics = pickle.load(f)

train_loss_metr = metrics["train_loss"]
val_loss_metr = metrics["val_loss"]
train_acc_metr = metrics["train_acc"]
val_acc_metr = metrics["val_acc"]

# Number epochs from 1 so the saved figure matches the live training plot
# (without explicit x values matplotlib would start the axis at 0)
epochs = range(1, len(train_loss_metr) + 1)

plt.figure(figsize=(12, 5))

# Plot Loss Curve (reloaded from pickle)
plt.subplot(1, 2, 1)
plt.plot(epochs, train_loss_metr, label='Train Loss')
plt.plot(epochs, val_loss_metr, label='Validation Loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.title('Loss Curve')

# Plot Accuracy Curve (reloaded from pickle)
plt.subplot(1, 2, 2)
plt.plot(epochs, train_acc_metr, label='Train Accuracy')
plt.plot(epochs, val_acc_metr, label='Validation Accuracy')
plt.xlabel('Epochs')
plt.ylabel('Accuracy')
plt.legend()
plt.title('Accuracy Curve')

# Save the combined plot to a PNG file with high resolution
plt.savefig("training_plots.png", dpi=300, bbox_inches="tight")
plt.show()


## Compute and Display Confusion Matrix

# Compute confusion matrix given true labels and model predictions
cm = confusion_matrix(test_targets, test_preds)

# Tick labels shared by both axes, defined once so they cannot drift apart.
# Fixes the misspelled 'Putuitary' label that ended up in the saved figure.
# NOTE(review): the order is assumed to follow the dataset's class indices
# (0..3) - confirm against the class-to-index mapping used in preprocessing.
class_names = ['Glioma', 'Meningioma', 'notumor', 'Pituitary']

plt.figure(figsize=(8, 6))
sns.heatmap(
    cm,
    annot=True,   # write the count inside every cell
    fmt="d",      # counts are integers
    cmap="Blues",
    xticklabels=class_names,
    yticklabels=class_names
)
plt.xlabel('Predicted')
plt.ylabel('True')
plt.title('Confusion Matrix')

# Save confusion matrix figure to file
plt.savefig("confusion_matrix.png", dpi=300, bbox_inches="tight")
plt.show()
