In [None]:
"""
Filename: resnet50_model.ipynb
Description: MRI brain images classification analysis with ResNet50 model

Author: Ng, Wee Ding
Date Created: 2024-11-30
Last Modified: 2024-12-06
Version: 1.0

License: MIT
"""

In [None]:

import os
import random
import numpy as np
from datetime import datetime 
import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns
from PIL import Image

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import models, transforms
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv, global_mean_pool, knn_graph
from sklearn.metrics import precision_recall_fscore_support, classification_report, confusion_matrix, accuracy_score, precision_recall_fscore_support, roc_auc_score

matplotlib.style.use('ggplot')

# Seed every RNG this notebook draws from so Restart & Run All starts from the
# same random state: Python `random` (image sampling in the plot helpers),
# NumPy, and torch (new-layer init and DataLoader shuffling).
# NOTE: GPU kernels (e.g. cuDNN) may still be nondeterministic.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

np.__version__

In [None]:
# helper functions

def save_model(epochs, model, optimizer, criterion, modelname="default", output_dir="./output"):
    """
    Save a training checkpoint to disk.

    The checkpoint is a dict with the epoch count, the model and optimizer
    ``state_dict``s, and the loss criterion object (stored under ``'loss'``).
    It is written to ``<output_dir>/model-<modelname>-<YYYYmmdd-HHMMSS>.pth``.

    Args:
        epochs: Number of epochs the model was trained for.
        model: The ``nn.Module`` whose ``state_dict`` is saved.
        optimizer: The optimizer whose ``state_dict`` is saved.
        criterion: Loss module; pickled as-is under the ``'loss'`` key.
        modelname: Tag embedded in the file name.
        output_dir: Target directory; created if it does not exist.

    Returns:
        Path of the written checkpoint file.
    """
    # torch.save does not create parent directories.
    os.makedirs(output_dir, exist_ok=True)
    # Single quotes inside the replacement field: reusing the outer quote
    # character in an f-string is a SyntaxError before Python 3.12 (PEP 701).
    timestamp = datetime.now().strftime('%Y%m%d-%H%M%S')
    path = os.path.join(output_dir, f"model-{modelname}-{timestamp}.pth")
    torch.save({
        'epoch': epochs,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': criterion,
    }, path)
    return path


# Dataset Class
class BrainTumorDataset:
    """
    Map-style dataset pairing image files with class labels.

    Indexing returns ``(image, label)``. The image is opened with PIL and
    converted to RGB, then passed through ``transform`` when one is set. The
    label is ``labels[idx]`` wrapped in a ``torch.float32`` tensor.

    Args:
        image_paths: Sequence of image file paths.
        labels: Sequence of labels aligned index-by-index with ``image_paths``.
        transform: Optional callable applied to the PIL image.
    """

    def __init__(self, image_paths, labels, transform=None):
        self.image_paths, self.labels, self.transform = image_paths, labels, transform

    def __len__(self):
        # The dataset size is the number of image paths.
        return len(self.image_paths)

    def __getitem__(self, idx):
        path = self.image_paths[idx]
        target = torch.tensor(self.labels[idx], dtype=torch.float32)
        sample = Image.open(path).convert('RGB')
        if self.transform:
            sample = self.transform(sample)
        return sample, target


def _classification_metrics(labels, preds):
    """Return macro precision/recall/F1 and accuracy for integer class labels."""
    # zero_division=1 scores a class that is never predicted as precision 1.0,
    # which can flatter the macro average on early, degenerate epochs.
    precision, recall, f1, _ = precision_recall_fscore_support(
        labels, preds, average='macro', zero_division=1
    )
    return {'precision': precision, 'recall': recall, 'f1': f1,
            'accuracy': accuracy_score(labels, preds)}


def _run_epoch(model, loader, criterion, device, optimizer=None):
    """One pass over `loader`: trains when `optimizer` is given, otherwise evaluates.

    Returns (mean per-batch loss, true labels, predicted labels).
    """
    training = optimizer is not None
    model.train(training)  # train(False) is equivalent to eval()
    total_loss = 0.0
    all_labels, all_preds = [], []

    with torch.set_grad_enabled(training):
        for images, labels in loader:
            # CrossEntropyLoss expects integer class indices.
            images, labels = images.to(device), labels.to(device).long()
            if training:
                optimizer.zero_grad()
            outputs = model(images)  # logits
            loss = criterion(outputs, labels)
            if training:
                loss.backward()
                optimizer.step()
            total_loss += loss.item()
            all_preds.extend(torch.argmax(outputs, dim=1).cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    return total_loss / max(len(loader), 1), all_labels, all_preds


def train_model(model, dataloaders, device, num_epochs=10, modelname="default", lr=0.001):
    """
    Train ``model`` with Adam and cross-entropy, tracking per-epoch metrics.

    Each epoch runs one training pass over ``dataloaders['train']`` and one
    evaluation pass over ``dataloaders['val']``. After the last epoch the
    model and optimizer state is checkpointed via ``save_model``.

    Args:
        model: Classifier returning one logit per class.
        dataloaders: Dict with 'train' and 'val' iterables of ``(images, labels)``.
            Labels are cast to ``long`` class indices for the loss.
        device: Device the batches are moved to.
        num_epochs: Number of epochs to train.
        modelname: Tag used in the checkpoint file name.
        lr: Adam learning rate.

    Returns:
        ``(model, train_metrics, val_metrics, num_epochs)``. Both metric dicts map
        'loss', 'precision', 'recall', 'f1' and 'accuracy' to per-epoch lists.
        Loss is the mean per-batch loss; precision, recall and F1 are macro-averaged.
    """
    criterion = nn.CrossEntropyLoss()  # multi-class loss on raw logits
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    train_metrics = {'loss': [], 'precision': [], 'recall': [], 'f1': [], 'accuracy': []}
    val_metrics = {'loss': [], 'precision': [], 'recall': [], 'f1': [], 'accuracy': []}

    for epoch in range(num_epochs):
        train_loss, y_true, y_pred = _run_epoch(model, dataloaders['train'], criterion, device, optimizer)
        train_metrics['loss'].append(train_loss)
        for key, value in _classification_metrics(y_true, y_pred).items():
            train_metrics[key].append(value)

        val_loss, y_true, y_pred = _run_epoch(model, dataloaders['val'], criterion, device)
        val_metrics['loss'].append(val_loss)
        for key, value in _classification_metrics(y_true, y_pred).items():
            val_metrics[key].append(value)

        # Report the same averaged loss that is stored in the metric dicts.
        print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, "
              f"Train Acc: {train_metrics['accuracy'][-1]:.4f}, Val Acc: {val_metrics['accuracy'][-1]:.4f}")

    save_model(num_epochs, model, optimizer, criterion, modelname)
    return model, train_metrics, val_metrics, num_epochs


def show_metrics(model, train_metrics, val_metrics, num_epochs=10, modelname="default", output_dir="./output"):
    """
    Print final-epoch metrics and plot per-epoch training/validation curves.

    The figure is saved to ``<output_dir>/metric-<modelname>-<YYYYmmdd-HHMMSS>.png``
    and then shown.

    Args:
        model: Unused; kept so existing call sites keep working.
        train_metrics: Dict of per-epoch lists with keys 'loss', 'precision',
            'recall', 'f1' and 'accuracy'.
        val_metrics: Dict of per-epoch lists with at least 'precision', 'recall',
            'f1' and 'accuracy'. A 'loss' list is also plotted when present.
        num_epochs: Kept for backward compatibility. The x-axis is derived from the
            length of the recorded lists, so a mismatch cannot break the plot.
        modelname: Tag embedded in the saved file name.
        output_dir: Directory for the saved figure; created if it does not exist.
    """
    print("\nFinal Metrics After All Epochs:")
    for heading, metrics in (("Training Metrics:", train_metrics), ("\nValidation Metrics:", val_metrics)):
        print(heading)
        for metric, values in metrics.items():
            print(f"{metric.capitalize()}: {values[-1]:.4f}")

    # Colour encodes the split (black/gray = loss, blue = train, red = val);
    # the line style encodes the metric.
    line_styles = {'loss': '-', 'precision': '-', 'recall': '--', 'f1': '-.', 'accuracy': ':'}
    score_metrics = ['precision', 'recall', 'f1', 'accuracy']
    epochs = range(1, len(train_metrics['loss']) + 1)

    plt.figure(figsize=(8, 6))
    plt.plot(epochs, train_metrics['loss'], label='Train Loss', color='black', linestyle=line_styles['loss'])
    if 'loss' in val_metrics:
        plt.plot(epochs, val_metrics['loss'], label='Val Loss', color='gray', linestyle=line_styles['loss'])
    for metric in score_metrics:
        plt.plot(epochs, train_metrics[metric], label=f'Train {metric.capitalize()}',
                 color='blue', linestyle=line_styles[metric])
    for metric in score_metrics:
        plt.plot(epochs, val_metrics[metric], label=f'Val {metric.capitalize()}',
                 color='red', linestyle=line_styles[metric])

    plt.title('Metrics over Epochs - CNN ResNet50')
    plt.xlabel('epochs')
    plt.ylabel('score')
    plt.legend()
    plt.grid(True)

    os.makedirs(output_dir, exist_ok=True)
    # Single quotes inside the replacement field keep this valid before Python 3.12.
    timestamp = datetime.now().strftime('%Y%m%d-%H%M%S')
    plt.savefig(os.path.join(output_dir, f"metric-{modelname}-{timestamp}.png"))
    plt.show()
    print(f"Validation Accuracy: {val_metrics['accuracy'][-1]:.4f}")
    print(f"Training Accuracy: {train_metrics['accuracy'][-1]:.4f}")

def evaluate_model(model, dataloader, device, return_images=False):
    """
    Run inference over a dataloader and collect labels, predictions and probabilities.

    The model is put in eval mode and run under ``torch.no_grad()``.

    Args:
        model: Classifier returning logits; softmax is taken over dim 1.
        dataloader: Iterable of ``(images, labels)`` batches.
        device: Device the images are moved to for the forward pass.
        return_images: If True, also return every input image concatenated on CPU.

    Returns:
        ``(labels, predictions, probs)`` as NumPy arrays, plus a CPU tensor of the
        images when ``return_images`` is True. ``probs`` is the softmax of the
        logits and ``predictions`` is its argmax.

    Raises:
        ValueError: If ``dataloader`` yields no batches.
    """
    model.eval()
    all_labels, all_predictions, all_probs, all_images = [], [], [], []

    with torch.no_grad():
        for images, labels in dataloader:
            # Only the images go to the device; labels are merely collected, so
            # they stay on the host instead of making a device round trip.
            probs = torch.softmax(model(images.to(device)), dim=1)
            all_labels.append(labels.cpu().numpy())
            all_predictions.append(torch.argmax(probs, dim=1).cpu().numpy())
            all_probs.append(probs.cpu().numpy())
            if return_images:
                all_images.append(images.cpu())

    if not all_labels:
        raise ValueError("evaluate_model: dataloader produced no batches")

    labels_np = np.concatenate(all_labels)
    predictions_np = np.concatenate(all_predictions)
    probs_np = np.concatenate(all_probs)

    if return_images:
        return labels_np, predictions_np, probs_np, torch.cat(all_images)
    return labels_np, predictions_np, probs_np

def plot_results(true_labels, predicted_labels, predicted_probs, class_names):
    """
    Report accuracy, a classification report, a confusion matrix and one-vs-rest ROC curves.

    Class names are used on the confusion-matrix axes, in the report and in the
    ROC legend.

    Args:
        true_labels: Ground-truth class indices (any numeric dtype; cast to int).
        predicted_labels: Predicted class indices.
        predicted_probs: Per-class scores indexed ``[sample, class]``; column ``i``
            is the score used for class ``i``.
        class_names: Class names ordered by class index.

    A class whose ROC curve is undefined on this split (it has only positive or
    only negative samples) is skipped in the plot and reported as ``nan``.
    """
    from sklearn.metrics import roc_curve  # only this function needs it

    true_idx = np.asarray(true_labels, dtype=int)
    predicted_probs = np.asarray(predicted_probs)
    n_classes = len(class_names)
    # Pin the label set so the matrix and the report always have one row per
    # class name, even when some class is absent from this split.
    class_indices = list(range(n_classes))

    accuracy = accuracy_score(true_idx, predicted_labels)
    print(f'Validation Accuracy: {accuracy:.4f}')

    print("\nClassification Report:")
    print(classification_report(true_idx, predicted_labels, labels=class_indices, target_names=class_names))

    cm = confusion_matrix(true_idx, predicted_labels, labels=class_indices)
    plt.figure(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt="d", cmap="Blues", xticklabels=class_names, yticklabels=class_names)
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.title('Confusion Matrix - CNN ResNet50')
    plt.show()

    # One-vs-rest: column i of the one-hot matrix is the indicator "is class i".
    true_one_hot = np.eye(n_classes)[true_idx]
    roc_auc_scores = []
    plt.figure(figsize=(8, 6))
    for i, name in enumerate(class_names):
        y_i = true_one_hot[:, i]
        if len(np.unique(y_i)) < 2:
            # ROC/AUC is undefined without both positive and negative samples.
            roc_auc_scores.append(float('nan'))
            continue
        fpr, tpr, _ = roc_curve(y_i, predicted_probs[:, i])
        score = roc_auc_score(y_i, predicted_probs[:, i])
        roc_auc_scores.append(score)
        plt.plot(fpr, tpr, label=f'{name} (AUC = {score:.2f})')

    plt.title("ROC Curves - CNN ResNet50")
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    plt.legend(loc="lower right")
    plt.grid()
    plt.show()

    print("ROC-AUC Scores for each class:")
    for name, score in zip(class_names, roc_auc_scores):
        print(f"{name}: {score:.4f}")
    


In [None]:
def _plot_prediction_grid(images, true_labels, predicted_labels, want_correct, title,
                          n_rows, n_cols, class_names):
    """Shared body of the grid plotters: pick matching samples at random and draw them."""
    true_labels = np.asarray(true_labels)
    predicted_labels = np.asarray(predicted_labels)
    # Accept one-hot (2-D) or class-index (1-D) ground truth.
    true_idx = np.argmax(true_labels, axis=1) if true_labels.ndim > 1 else true_labels

    matches = (true_idx == predicted_labels) if want_correct else (true_idx != predicted_labels)
    candidates = np.flatnonzero(matches).tolist()
    selected = random.sample(candidates, min(len(candidates), n_rows * n_cols))
    if not selected:
        print(f"{title}: no matching images to plot.")
        return

    def _label_text(label):
        # Labels may arrive as floats (e.g. 0.0); show them as ints or class names.
        idx = int(label)
        return class_names[idx] if class_names is not None else idx

    plt.figure(figsize=(2 * n_cols, 2 * n_rows))
    plt.suptitle(title)
    for pos, index in enumerate(selected, start=1):
        plt.subplot(n_rows, n_cols, pos)
        # Channel-first tensor -> channel-last array for imshow.
        plt.imshow(images[index].permute(1, 2, 0).cpu().numpy())
        plt.title(f'True: {_label_text(true_idx[index])}\nPred: {_label_text(predicted_labels[index])}',
                  fontsize=12)
        plt.axis('off')

    plt.tight_layout()
    plt.show()


def plot_misclassified_images(images, true_labels, predicted_labels, n_rows=2, n_cols=5, class_names=None):
    """
    Plot up to ``n_rows * n_cols`` randomly chosen misclassified images in a grid.

    Args:
        images: Indexable collection of channel-first image tensors, aligned with the labels.
        true_labels: Ground-truth labels (class indices or one-hot rows).
        predicted_labels: Predicted class indices.
        n_rows: Number of grid rows.
        n_cols: Number of grid columns.
        class_names: Optional names indexed by class; titles show indices when None.
    """
    _plot_prediction_grid(images, true_labels, predicted_labels, want_correct=False,
                          title="Misclassified Images - CNN ResNet50",
                          n_rows=n_rows, n_cols=n_cols, class_names=class_names)


def plot_correctly_predicted_images(images, true_labels, predicted_labels, n_rows=2, n_cols=5, class_names=None):
    """
    Plot up to ``n_rows * n_cols`` randomly chosen correctly classified images in a grid.

    Args:
        images: Indexable collection of channel-first image tensors, aligned with the labels.
        true_labels: Ground-truth labels (class indices or one-hot rows).
        predicted_labels: Predicted class indices.
        n_rows: Number of grid rows.
        n_cols: Number of grid columns.
        class_names: Optional names indexed by class; titles show indices when None.
    """
    _plot_prediction_grid(images, true_labels, predicted_labels, want_correct=True,
                          title="Correctly Classified Images - CNN ResNet50",
                          n_rows=n_rows, n_cols=n_cols, class_names=class_names)



print("Loaded utilities libraries")

In [None]:
# ResNet Model
class ResNetMultiLabel(nn.Module):
    """
    ResNet50 backbone whose final layer is replaced by a ``num_classes``-wide
    linear layer followed by a sigmoid (one independent score per class).

    Because the sigmoid is inside the model, the outputs are already
    probabilities. Presumably pair it with ``nn.BCELoss``, not
    ``nn.BCEWithLogitsLoss``, which would apply a second sigmoid.

    Args:
        num_classes: Width of the output layer.
        pretrained: When True, load the ImageNet ``IMAGENET1K_V1`` weights (the
            set the deprecated ``pretrained=True`` flag selected). When False,
            start from random initialisation.
    """
    def __init__(self, num_classes=4, pretrained=True):
        super().__init__()
        # The `weights=` enum replaces the deprecated `pretrained=` flag.
        weights = models.ResNet50_Weights.IMAGENET1K_V1 if pretrained else None
        self.base_model = models.resnet50(weights=weights)
        # Swap the ImageNet classification head for a task-specific one.
        num_ftrs = self.base_model.fc.in_features
        self.base_model.fc = nn.Sequential(
            nn.Linear(num_ftrs, num_classes),
            nn.Sigmoid(),  # multi-label output
        )

    def forward(self, x):
        return self.base_model(x)

class ResNetMultiClass(nn.Module):
    """
    ResNet50 backbone whose final layer is replaced by a ``num_classes``-wide
    linear layer with no activation, so ``forward`` returns raw logits.

    This suits ``nn.CrossEntropyLoss``, which applies log-softmax internally.

    Args:
        num_classes: Width of the output layer.
        pretrained: When True, load the ImageNet ``IMAGENET1K_V1`` weights (the
            set the deprecated ``pretrained=True`` flag selected). When False,
            start from random initialisation.
    """
    def __init__(self, num_classes=4, pretrained=True):
        super().__init__()
        # The `weights=` enum replaces the deprecated `pretrained=` flag.
        weights = models.ResNet50_Weights.IMAGENET1K_V1 if pretrained else None
        self.base_model = models.resnet50(weights=weights)
        # Swap the ImageNet classification head for a task-specific one.
        num_ftrs = self.base_model.fc.in_features
        self.base_model.fc = nn.Linear(num_ftrs, num_classes)  # logits, no sigmoid

    def forward(self, x):
        return self.base_model(x)  # logits


In [None]:
def _collect_image_paths(root_dir, class_list):
    """
    List every file under ``root_dir/<class>/`` for each class in ``class_list``.

    Returns ``(paths, labels)``, where each label is the index of its class in
    ``class_list``. File names are sorted so the order does not depend on the
    filesystem's directory-listing order.
    """
    paths, labels = [], []
    for idx, cls in enumerate(class_list):
        class_dir = os.path.join(root_dir, cls)
        for fname in sorted(os.listdir(class_dir)):
            paths.append(os.path.join(class_dir, fname))
            labels.append(idx)  # integer label (multi-class)
    return paths, labels


if __name__ == "__main__":
    # --- configuration ---------------------------------------------------
    data_dir = "test_data"
    train_dir = os.path.join(data_dir, "Training")
    # NOTE(review): the "Testing" split doubles as the validation set that is
    # monitored every epoch, so it is not an untouched held-out test set.
    val_dir = os.path.join(data_dir, "Testing")
    image_size = (128, 128)
    batch_size = 16
    num_epochs = 10
    model_name = "resnet50"

    # List order defines each class's integer label.
    classes = ['glioma', 'meningioma', 'notumor', 'pituitary']
    class_names = list(classes)  # same classes; later cells use this name

    # --- data --------------------------------------------------------------
    train_paths, train_labels = _collect_image_paths(train_dir, classes)
    val_paths, val_labels = _collect_image_paths(val_dir, classes)

    # NOTE(review): no ImageNet mean/std normalisation is applied even though the
    # backbone is ImageNet-pretrained. These tensors are also what the image-grid
    # plots display, so adding it would need a de-normalise step there.
    transform = transforms.Compose([
        transforms.Resize(image_size),
        transforms.ToTensor(),
    ])

    train_dataset = BrainTumorDataset(train_paths, train_labels, transform=transform)
    val_dataset = BrainTumorDataset(val_paths, val_labels, transform=transform)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    dataloaders = {'train': train_loader, 'val': val_loader}

    # --- model & training --------------------------------------------------
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model_resnet50 = ResNetMultiClass(num_classes=len(classes)).to(device)

    model, train_metrics, val_metrics, num_epochs = train_model(
        model_resnet50, dataloaders, device, num_epochs, model_name
    )

    print("\nTraining Completed!")
    for heading, metrics in (("Train Metrics (Final Epoch):", train_metrics),
                             ("\nValidation Metrics (Final Epoch):", val_metrics)):
        print(heading)
        for metric, values in metrics.items():
            print(f"{metric.capitalize()}: {values[-1]:.4f}")

    # --- evaluation --------------------------------------------------------
    true_labels, predicted_labels, all_probs = evaluate_model(model, val_loader, device)
    print("\nClassification Report:")
    print(classification_report(true_labels, predicted_labels, target_names=classes))


In [None]:

show_metrics(model, train_metrics, val_metrics, num_epochs=num_epochs, modelname="resnet50")


In [None]:
# https://github.com/FarahElshenawi/brain-tumor-classification/blob/main/Notebook/Brain_tumor_mri_Pytorch.ipynb
# Re-run validation inference, this time keeping the input images for the grid
# plots below. The sklearn metrics used later come from the imports cell.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# NOTE: despite its name, `misclassified_images` holds *every* validation image
# in loader order; the plotting helpers pick the subset they need.
true_labels, predicted_labels, predicted_probs, misclassified_images = evaluate_model(
    model, val_loader, device, return_images=True
)
n_classes = predicted_probs.shape[1]  # number of classes = width of the probability matrix


In [None]:
# Accuracy, classification report, confusion matrix and per-class ROC curves
# for the validation predictions computed in the previous cell.
plot_results(true_labels, predicted_labels, predicted_probs, class_names)

In [None]:
plot_misclassified_images(
    misclassified_images,  # all validation images; the helper picks the misclassified ones
    true_labels,
    predicted_labels,
    n_rows=2,
    n_cols=5,
)

In [None]:
# Inputs come from the evaluation cell above:
#   misclassified_images: all validation images in loader order (presumably
#       shaped [num_samples, channels, height, width] — confirm)
#   true_labels: ground-truth labels, class indices or one-hot rows
#   predicted_labels: predicted class indices
plot_correctly_predicted_images(
    misclassified_images, true_labels, predicted_labels, n_rows=2, n_cols=5
)
