Libraries

In [None]:
import os
import glob
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
from sklearn.svm import SVC
from sklearn.metrics import accuracy_score, classification_report, precision_recall_fscore_support
from transformers import ViTFeatureExtractor, ViTForImageClassification
from torchvision import models
import kagglehub
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt

# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
_accelerator = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(_accelerator)


Dataset Path

In [None]:
# Download the Kaggle dataset and point data_dir at the nested folder that is
# later handed to ImageFolder.
dataset_root = kagglehub.dataset_download("orvile/brain-cancer-mri-dataset")
data_dir = os.path.join(dataset_root, "Brain_Cancer raw MRI data", "Brain_Cancer")

Transformations and Dataset

In [None]:
# Preprocessing applied to every image: resize to 224x224, convert to a tensor,
# then normalise with mean 0.5 / std 0.5.
_preprocessing_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5]),
]
transform = transforms.Compose(_preprocessing_steps)

# ImageFolder takes the class labels from the sub-directory names of data_dir.
dataset = ImageFolder(root=data_dir, transform=transform)
class_names = dataset.classes


Split into train/test

In [None]:
# 70/30 train/test split. The split generator is seeded so the partition is the
# same on every run; without it the CNN, ViT and SVM metrics below would change
# from run to run simply because the test set changed.
SPLIT_SEED = 42
train_size = int(0.7 * len(dataset))
test_size = len(dataset) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(
    dataset,
    [train_size, test_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

# Only the training loader shuffles; the test loader keeps the subset order.
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32)


In [None]:
# from collections import defaultdict
# import random
# random.seed(42)

# class_to_indices = defaultdict(list)
# for idx, (_, label) in enumerate(dataset.samples):
#     class_to_indices[label].append(idx)

# # Sample 100 train and 100 test indices per class
# train_indices = []
# test_indices = []

# for class_id, indices in class_to_indices.items():
#     random.shuffle(indices)
#     train_indices.extend(indices[:100])
#     test_indices.extend(indices[100:200])

# # Create subsets using the sampled indices
# from torch.utils.data import Subset

# train_dataset = Subset(dataset, train_indices)
# test_dataset = Subset(dataset, test_indices)

# train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
# test_loader = DataLoader(test_dataset, batch_size=32)

CNN Model

In [None]:
class SimpleCNN(nn.Module):
    """Small two-stage convolutional classifier for 3-channel images.

    Two ``Conv2d -> ReLU -> MaxPool2d`` stages feed a fully connected head
    with dropout.

    Args:
        num_classes: Number of output logits.
        input_size: Spatial size of the input images, either one int for
            square images or a ``(height, width)`` pair. Defaults to 224,
            which gives the original ``64 * 54 * 54`` flattened feature size.
    """

    def __init__(self, num_classes, input_size=224):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 32, 3),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 3),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),
            # Sized from the conv stack instead of a constant that is only
            # valid for 224x224 inputs.
            nn.Linear(self._flattened_size(input_size), 128),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(128, num_classes)
        )

    def _flattened_size(self, input_size):
        """Return how many values ``self.features`` emits for one image."""
        if isinstance(input_size, int):
            height, width = input_size, input_size
        else:
            height, width = input_size
        # A zero-filled probe image is pushed through the conv stack once; it
        # draws no random numbers, so weight initialisation is unaffected.
        with torch.no_grad():
            probe = self.features(torch.zeros(1, 3, height, width))
        return probe[0].numel()

    def forward(self, x):
        """Return the class logits for a batch of images ``x``."""
        x = self.features(x)
        return self.classifier(x)

Train and Evaluate CNN

In [None]:
# Build the CNN with one output per class, then move it to the selected device.
cnn = SimpleCNN(num_classes=len(class_names))
cnn = cnn.to(device)

# Adam (lr = 1e-4) over the CNN's parameters, trained against cross-entropy.
optimizer = torch.optim.Adam(cnn.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss()

def train_model(model, loader, epochs=3, optimizer=None, criterion=None):
    """Train ``model`` on the batches yielded by ``loader``.

    Args:
        model: Network to optimise; it is switched to training mode.
        loader: Iterable of ``(inputs, labels)`` batches.
        epochs: Number of passes over ``loader`` (default 3).
        optimizer: Optimizer that updates ``model``. When omitted, an Adam
            optimizer with ``lr=1e-4`` is created over ``model``'s own
            parameters.
        criterion: Loss callable taking ``(outputs, labels)``. Defaults to
            ``nn.CrossEntropyLoss()``.

    Returns:
        A list holding the summed batch loss of each epoch (the same values
        that are printed).
    """
    # Bind the optimizer to the model actually being trained: an optimizer
    # built for a different model would leave this model's weights untouched.
    if optimizer is None:
        optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    if criterion is None:
        criterion = nn.CrossEntropyLoss()
    # Batches are sent to wherever the model's parameters live.
    target_device = next(model.parameters()).device

    model.train()
    epoch_losses = []
    for epoch in range(epochs):
        running_loss = 0.0
        for inputs, labels in loader:
            inputs, labels = inputs.to(target_device), labels.to(target_device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        print(f"Epoch {epoch+1} - Loss: {running_loss:.4f}")
        epoch_losses.append(running_loss)
    return epoch_losses

# Fit the CNN on the training split.
train_model(cnn, train_loader)

def evaluate_model(model, loader):
    """Print a per-class classification report for ``model`` over ``loader``.

    The model is put in eval mode and run without gradient tracking; the
    arg-max of its outputs along dim 1 is taken as the predicted class. Uses
    the module-level ``device`` and ``class_names``.
    """
    model.eval()
    predicted, expected = [], []
    with torch.no_grad():
        for batch, targets in loader:
            logits = model(batch.to(device))
            predicted.extend(logits.argmax(dim=1).cpu().numpy())
            expected.extend(targets.numpy())
    print("CNN Evaluation:")
    report = classification_report(expected, predicted, target_names=class_names)
    print(report)

# Report the trained CNN's metrics on the test split.
evaluate_model(cnn, test_loader)

ViT Fine-Tuning

In [None]:
# The feature extractor and the classifier are loaded from the same checkpoint.
VIT_CHECKPOINT = "google/vit-base-patch16-224-in21k"
feature_extractor = ViTFeatureExtractor.from_pretrained(VIT_CHECKPOINT)

# num_labels sizes the classification output to the dataset's classes.
vit_model = ViTForImageClassification.from_pretrained(VIT_CHECKPOINT, num_labels=len(class_names))
vit_model = vit_model.to(device)

from transformers import Trainer, TrainingArguments
from torch.utils.data import Dataset

class CustomDataset(Dataset):
    """Expose ``(image, label)`` samples as dicts for the Hugging Face Trainer.

    Each item of the wrapped ``subset`` is returned as
    ``{"pixel_values": image, "label": label}``.
    """

    def __init__(self, subset):
        # Any indexable, sized collection of (image, label) pairs.
        self.data = subset

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        pixel_values, label = self.data[idx]
        return dict(pixel_values=pixel_values, label=label)

def compute_metrics(pred):
    """Return accuracy and weighted precision/recall/F1 for a prediction set.

    Args:
        pred: Object exposing ``label_ids`` (true labels) and ``predictions``
            (scores whose arg-max over the last axis is the predicted class).

    Returns:
        Dict with the keys ``accuracy``, ``precision``, ``recall`` and ``f1``.
    """
    y_true = pred.label_ids
    y_pred = pred.predictions.argmax(-1)
    accuracy = accuracy_score(y_true, y_pred)
    # The fourth returned value (support) is not reported; zip() drops it.
    weighted = precision_recall_fscore_support(y_true, y_pred, average='weighted')
    metrics = {'accuracy': accuracy}
    metrics.update(zip(('precision', 'recall', 'f1'), weighted))
    return metrics

# Re-use the existing train/test splits, wrapped so that every sample is a dict.
vit_train_dataset, vit_test_dataset = CustomDataset(train_dataset), CustomDataset(test_dataset)

# Three epochs, batch size 32, a log entry every 10 steps, no checkpoints.
# remove_unused_columns=False keeps the sample dicts as CustomDataset builds them.
training_args = TrainingArguments(
    output_dir="./vit_output",
    num_train_epochs=3,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=32,
    logging_dir="./logs",
    logging_steps=10,
    save_strategy="no",
    remove_unused_columns=False,
)

trainer = Trainer(
    model=vit_model,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=vit_train_dataset,
    eval_dataset=vit_test_dataset,
)

# Fine-tune, then evaluate on the test split.
trainer.train()
print("ViT Evaluation:")
trainer.evaluate()


SVM on ResNet-18 Features

In [None]:
def extract_features(dataloader, model=None):
    """Run every batch through a feature extractor and collect the outputs.

    Args:
        dataloader: Iterable of ``(images, labels)`` batches.
        model: Network used to embed the images. Defaults to the module-level
            ``resnet`` so existing calls keep working.

    Returns:
        ``(features, labels)`` as NumPy arrays with one row/entry per sample,
        in the order the batches were yielded.
    """
    if model is None:
        # Looked up at call time, so ``resnet`` may be defined after this function.
        model = resnet
    # Send batches to the model's own device; only a parameter-less model
    # falls back to the module-level ``device``.
    first_param = next(model.parameters(), None)
    target_device = device if first_param is None else first_param.device

    features = []
    labels = []
    with torch.no_grad():
        for imgs, lbls in dataloader:
            output = model(imgs.to(target_device)).cpu().numpy()
            features.extend(output)
            labels.extend(lbls.numpy())
    return np.array(features), np.array(labels)

# Load ResNet once and use it as a fixed feature extractor. The explicit
# ``weights`` enum replaces the deprecated ``pretrained=True`` flag and selects
# the same ImageNet checkpoint.
# NOTE(review): the loaders normalise with mean/std 0.5 rather than the ImageNet
# statistics these weights were trained with; confirm that is intended.
resnet = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
resnet.fc = nn.Identity()  # drop the classification layer so the model outputs features
resnet = resnet.to(device)
resnet.eval()

# Extract features from train and test sets
svm_train_features, svm_train_labels = extract_features(train_loader)
svm_test_features, svm_test_labels = extract_features(test_loader)

# Train and evaluate SVM
svm = SVC(kernel='linear')
svm.fit(svm_train_features, svm_train_labels)
svm_preds = svm.predict(svm_test_features)

# Print metrics
print("SVM Evaluation:")
print(classification_report(svm_test_labels, svm_preds, target_names=class_names))