# Download and Extract Datasets

In [None]:
# Dataset: https://www.kaggle.com/datasets/masoudnickparvar/brain-tumor-mri-dataset
# Download the archive (-L makes curl follow HTTP redirects).

!curl -L -o brain-tumor-mri-dataset.zip\
  https://www.kaggle.com/api/v1/datasets/download/masoudnickparvar/brain-tumor-mri-dataset

# -o overwrites already-extracted files without asking, so re-running this cell
# does not stall on unzip's interactive "replace?" prompt; -q keeps output short.
!mkdir -p mri_datasets
!unzip -q -o brain-tumor-mri-dataset.zip -d mri_datasets

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
  0     0    0     0    0     0      0      0 --:--:-- --:--:-- --:--:--     0
100  148M  100  148M    0     0   135M      0  0:00:01  0:00:01 --:--:--  197M


# Import Libraries

In [None]:
# Load the TensorBoard notebook extension so the %tensorboard magic used further down is available.
%load_ext tensorboard

In [None]:
# Core
import os
import glob
import shutil
import random
import math
from pathlib import Path

# Data handling
import numpy as np
import pandas as pd
from PIL import Image
from sklearn.model_selection import train_test_split
from sklearn.metrics import (
    accuracy_score,
    precision_score,
    recall_score,
    classification_report,
    confusion_matrix,
    ConfusionMatrixDisplay,
)
import seaborn as sns

# Torch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
from torch.utils.tensorboard import SummaryWriter

# Vision
import torchvision
from torchvision import datasets, transforms, models
from torchvision.models import resnet50, ResNet50_Weights

# Transformers and schedulers
from transformers import get_linear_schedule_with_warmup

# Visualization
import matplotlib.pyplot as plt
from matplotlib import colormaps
import cv2
from tqdm import tqdm


In [None]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(device)

# Training hyper-parameters.
batch_size = 32
num_epochs = 10
patience = 3  # presumably early-stopping patience in epochs — TODO confirm against the loop
experiment_path = "runs/experiment_1"

# Class names in label order: a sample's integer label is its index in this list.
class_names = ['glioma', 'meningioma', 'notumor', 'pituitary']

# Generate metadata

In [None]:
train_dir = "mri_datasets/Training/"
test_dir = "mri_datasets/Testing/"


def _collect_metadata(root_dir):
    """Find every ``.jpg`` under *root_dir* and pair it with its class index.

    The class is the first directory below *root_dir* (``<root>/<class>/...``),
    mapped to an integer label with ``class_names.index``.

    Returns:
        ``(files, rows)``: the sorted file paths and the matching
        ``[filepath, label]`` rows.

    Raises:
        ValueError: if a file's class folder is not listed in ``class_names``.
    """
    root = Path(root_dir)
    # glob's order depends on the filesystem; sorting makes the seeded
    # train/valid split below reproducible across machines.
    files = sorted(glob.glob(root_dir + "**/*.jpg", recursive=True))
    rows = []
    for file in files:
        # Path.parts is separator-agnostic and does not depend on how many
        # directory levels root_dir itself contains.
        class_name = Path(file).relative_to(root).parts[0]
        rows.append([file, class_names.index(class_name)])
    return files, rows


# Training images -> stratified 80/20 train/validation split.
train_files, train_rows = _collect_metadata(train_dir)
df_train = pd.DataFrame(train_rows, columns=["filepath", "label"])
df_train, df_valid = train_test_split(df_train,
                                      test_size=0.2,
                                      stratify=df_train['label'],
                                      random_state=42)

df_train.to_csv("train_metadata.csv", index=False)
df_valid.to_csv("valid_metadata.csv", index=False)

# Held-out test images.
test_files, test_rows = _collect_metadata(test_dir)
df_test = pd.DataFrame(test_rows, columns=["filepath", "label"])
df_test.to_csv("test_metadata.csv", index=False)

In [None]:
# Show one test image per class; the seed makes the figure reproducible.
example_images = df_test.groupby("label").sample(n=1, random_state=42)
plt.figure(figsize=(18, 5))
for position, (_, row) in enumerate(example_images.iterrows(), start=1):
    # Convert to RGB (as the Grad-CAM section does): if a JPEG is single-channel,
    # imshow would otherwise colour it with the default colormap instead of gray.
    # The context manager closes the file handle.
    with Image.open(row["filepath"]) as img:
        rgb_image = img.convert("RGB")
    plt.subplot(1, len(example_images), position)
    plt.imshow(rgb_image)
    plt.title(class_names[row["label"]])
    plt.axis("off")
plt.tight_layout()
plt.show()

# Prepare Dataloader

In [None]:
# https://docs.pytorch.org/vision/main/models.html#classification

In [None]:
class BrainMRIDatasets(Dataset):
    """Map-style dataset over an MRI metadata table.

    Args:
        dataset: a ``pandas.DataFrame`` with ``filepath`` and ``label`` columns
            (the layout written in "Generate metadata"), or a path to a CSV
            file with those columns, e.g. ``"train_metadata.csv"``.
        transforms: callable applied to each PIL image (e.g. ``train_transforms``),
            or ``None`` to return the PIL image unchanged.
        classes: class names in label order, kept as ``self.classes``.

    Each item is ``(image, label)``; ``label`` is a Python ``int``.
    """

    def __init__(self, dataset, transforms, classes):
        if isinstance(dataset, (str, os.PathLike)):
            dataset = pd.read_csv(dataset)
        self.dataset = dataset
        self.transforms = transforms
        self.classes = list(classes)

    def __getitem__(self, idx):
        # iloc is positional, so idx in 0..len-1 works even though the split
        # DataFrames keep their shuffled original index.
        row = self.dataset.iloc[idx]
        # Force 3 channels so ToTensor + the 3-channel Normalize always line up.
        with Image.open(row["filepath"]) as img:
            image = img.convert("RGB")
        if self.transforms is not None:
            image = self.transforms(image)
        return image, int(row["label"])

    def __len__(self):
        return len(self.dataset)

In [None]:
# ImageNet per-channel statistics matching the ImageNet-pretrained ResNet-50 backbone.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

# Geometric augmentation, applied as one group 80% of the time; inside the
# group each flip still fires with its own probability.
_geometric_augment = transforms.RandomApply(
    [
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomVerticalFlip(p=0.2),
        transforms.RandomRotation(degrees=15),
        transforms.RandomAffine(degrees=0, translate=(0.1, 0.1), scale=(0.9, 1.1), shear=5),
    ],
    p=0.8,
)

# Photometric augmentation, 60% of the time.
_color_augment = transforms.RandomApply(
    [transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.2, hue=0.05)],
    p=0.6,
)

# Light blur plus an occasional grayscale conversion, 30% of the time.
_blur_augment = transforms.RandomApply(
    [
        transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 1.5)),
        transforms.RandomGrayscale(p=0.1),
    ],
    p=0.3,
)

# Training: resize/crop to 224x224, augment, then tensor + ImageNet normalisation.
train_transforms = transforms.Compose([
    transforms.Resize(232),
    transforms.CenterCrop(224),
    _geometric_augment,
    _color_augment,
    _blur_augment,
    transforms.ToTensor(),
    transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
])

# Evaluation: the same deterministic resize/crop and normalisation, no augmentation.
test_transforms = transforms.Compose([
    transforms.Resize(232),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
])

# Define Model

In [None]:
# https://docs.pytorch.org/vision/main/models.html#classification

In [None]:
class MRIResnet50Model(nn.Module):
    """torchvision ResNet-50 backbone with a new ``num_classes``-way linear head.

    Args:
        num_classes: number of output logits.
        weights: weights passed to ``torchvision.models.resnet50``. If omitted,
            ``ResNet50_Weights.IMAGENET1K_V2`` is used (the previous hard-coded
            choice); pass ``None`` for random initialisation without a download.
    """

    # Sentinel so the default can be resolved lazily inside __init__.
    _DEFAULT_WEIGHTS = object()

    def __init__(self, num_classes, weights=_DEFAULT_WEIGHTS):
        super().__init__()
        if weights is MRIResnet50Model._DEFAULT_WEIGHTS:
            weights = ResNet50_Weights.IMAGENET1K_V2

        self.resnet = models.resnet50(weights=weights)
        # Every child except the last two (avgpool, fc): the convolutional
        # feature extractor whose output forward() can also return.
        self.features = nn.Sequential(*list(self.resnet.children())[:-2])
        self.avgpool = self.resnet.avgpool
        num_features = self.resnet.fc.in_features
        self.fc = nn.Linear(num_features, num_classes)
        # NOTE(review): self.resnet stays registered, so its original ImageNet
        # `fc` layer shows up in parameters()/state_dict() but is never used in
        # forward(). Kept to preserve attribute names and checkpoint keys.

    def forward(self, x, return_features=False):
        """Return logits, plus the last conv feature map if *return_features*.

        The feature map is what a Grad-CAM style visualisation would hook
        (presumably its use in the Grad-CAM section — confirm).
        """
        feats = self.features(x)
        pooled = torch.flatten(self.avgpool(feats), 1)
        logits = self.fc(pooled)

        if return_features:
            return logits, feats
        return logits

# Training Setup

In [None]:
# Total optimizer steps for the whole run: epochs x batches per epoch, e.g. for a
# warmup/decay schedule such as get_linear_schedule_with_warmup.
# NOTE(review): train_dataloader is not created anywhere in the visible
# notebook; it must be defined before this cell runs.
num_training_steps = num_epochs * len(train_dataloader)

# TODO: define criterion, optimizer_p, scheduler_p

In [None]:
# Tensorboard

In [None]:
%tensorboard --logdir

# Training Loop

In [None]:
# Early-stopping state for the training loop.
# NOTE(review): the "vit" in the name does not match the ResNet-50 model
# defined above; consider renaming (e.g. best_val_loss) together with the loop.
best_vit_loss = math.inf  # presumably the best validation loss seen so far
patience_val = []

# Testing

In [None]:
# Fix the label set to every class index so the matrix is always
# len(class_names) x len(class_names) and lines up with the tick labels, even
# if a class never occurs in test_labels or test_preds.
# (Assumes both hold integer class indices — confirm against the test loop.)
cm = confusion_matrix(test_labels, test_preds, labels=list(range(len(class_names))))
plt.figure(figsize=(10, 8))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=class_names, yticklabels=class_names)
plt.xlabel('Predicted Label')
plt.ylabel('Truth Label')
plt.title('Confusion Matrix')
plt.show()

# Grad-CAM

In [None]:
# Same resize + center-crop geometry as test_transforms but without ToTensor /
# Normalize, so a PIL input stays a displayable PIL image (presumably used to
# show the original next to the heatmap).
original_transforms = transforms.Compose(
    [transforms.Resize(232), transforms.CenterCrop(224)]
)

In [None]:
# Buffers the hook functions below append to.
activations = []
gradients = []

# Drop any forward/backward hooks already attached to resnet_model's submodules
# (e.g. from an earlier run of this cell) so they do not fire more than once.
# NOTE(review): this clears private nn.Module dicts; keeping the handles
# returned by register_*_hook and calling .remove() is the supported route.
for module in resnet_model.modules():
    module._forward_hooks.clear()
    module._backward_hooks.clear()


def forward_hook(module, input, output):
    """Store the hooked module's forward output."""
    activations.append(output)


def backward_hook(module, grad_input, grad_output):
    """Store the gradient w.r.t. the hooked module's first output."""
    gradients.append(grad_output[0])

In [None]:
# One random training row: [filepath, label] in the metadata column order.
sample_data = df_train.sample(n=1).values[0]
# Bind this sample's ground-truth label explicitly so the "GT:" title in the
# plot reflects this image rather than whatever `label` held beforehand.
label = int(sample_data[1])
# Context manager closes the file; RGB matches the 3-channel model input.
with Image.open(sample_data[0]) as img:
    image_original = img.convert("RGB")

In [None]:
cmap = colormaps["jet"]
# The 256-entry jet lookup table as RGB rows (alpha column dropped).
colors = cmap(np.arange(256))[:, :3]
# Colour the 2-D Grad-CAM map with the same jet colormap, dropping alpha -> (H, W, 3).
rgb_heatmap = cmap(heatmap)[:, :, :3]

In [None]:
images = [image_original, rgb_heatmap, combined_image]

titles = [
    f"GT: {class_names[label]}",
    f"Heatmap (Pred: {class_names[pred_class]})",
    "Overlayed Heatmap",
]

# A single row of panels (original, heatmap, overlay); the height is sized
# for one row rather than leaving most of an 18-inch-tall figure empty.
plt.figure(figsize=(21, 7))
for position, (image, title) in enumerate(zip(images, titles), start=1):
    plt.subplot(1, len(images), position)
    plt.imshow(image)
    plt.axis('off')
    plt.title(title, fontdict={"size": 18})
# Lay out once, after every panel exists, instead of once per panel.
plt.tight_layout()
plt.show()