In [None]:
import os
import numpy as np
from PIL import Image
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader

In [None]:
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader,Dataset

class BrainTumorDataset(Dataset):
    """Image-folder dataset for brain MRI tumour classification.

    ``root_dir`` is expected to contain one sub-directory per class, named after
    an entry in ``classes``. Every file with an image extension inside a class
    directory becomes one sample, labelled with that directory's name.

    Args:
        root_dir: dataset split directory (e.g. ``.../Training``).
        transform: optional callable applied to each RGB PIL image.
        classes: ordered class names. A sample's integer label is the position
            of its class in this list. Defaults to the four tumour categories.

    Raises:
        ValueError: if an image is found under a directory whose name is not
            in ``classes``. This is detected at construction time rather than
            on first access in ``__getitem__``.
    """

    DEFAULT_CLASSES = ('glioma_tumor', 'meningioma_tumor', 'no_tumor', 'pituitary_tumor')

    def __init__(self, root_dir, transform=None, classes=None):
        self.root_dir = root_dir
        self.transform = transform
        # Set before scanning so process_images can validate directory names.
        self.classes = list(classes) if classes is not None else list(self.DEFAULT_CLASSES)
        self.class_to_idx = {name: i for i, name in enumerate(self.classes)}
        self.image_paths, self.labels = self.process_images(root_dir)

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, label)``: the (optionally transformed) RGB image and its class index."""
        # The context manager closes the file handle opened by Image.open.
        # convert() returns an independent image, so it is safe to use after close.
        with Image.open(self.image_paths[idx]) as img:
            image = img.convert('RGB')
        label = self.class_to_idx[self.labels[idx]]

        if self.transform:
            image = self.transform(image)
        return image, label

    def process_images(self, path):
        """Scan ``path`` and return ``(image_paths, labels)`` in a deterministic order.

        Directory listings are sorted so the sample order does not depend on the
        filesystem. Plain files directly inside ``path``, non-image files, and
        directories are ignored. Labels are the class-directory names (strings).
        """
        images = []
        labels = []
        for category in sorted(os.listdir(path)):
            category_path = os.path.join(path, category)
            if not os.path.isdir(category_path):
                continue
            for img_name in sorted(os.listdir(category_path)):
                img_path = os.path.join(category_path, img_name)
                if not (os.path.isfile(img_path) and self.is_image_file(img_path)):
                    continue
                if category not in self.class_to_idx:
                    raise ValueError(
                        f"Found image {img_path!r} in unknown class directory {category!r}; "
                        f"expected one of {self.classes}"
                    )
                images.append(img_path)
                labels.append(category)
        return images, labels

    def is_image_file(self, filename):
        """Return True if ``filename`` has a supported image extension (case-insensitive)."""
        valid_image_extensions = (".jpg", ".jpeg", ".png", ".bmp", ".tiff")
        return filename.lower().endswith(valid_image_extensions)

In [None]:
# Standard ImageNet channel statistics, as expected by ImageNet-pretrained backbones.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Training pipeline: deliberately mild augmentation.
# - RandomResizedCrop keeps 80-100% of the image area.
# - Rotation is limited to +/-10 degrees.
# - Colour jitter is small.
_train_steps = [
    transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.05),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
]

# Evaluation pipeline: deterministic resize + centre crop to the same 224x224 input size.
_val_steps = [
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
]

data_transforms = dict(
    train=transforms.Compose(_train_steps),
    val=transforms.Compose(_val_steps),
)

In [None]:
# Each split folder holds one sub-folder per class (see BrainTumorDataset.process_images).
# NOTE(review): the 'Testing' split is used as the validation set, so the per-epoch
# "validation" metrics are really test-set metrics.
DATA_ROOT = '/kaggle/input/brain-tumor-data-18k/tumordata'

train_dataset = BrainTumorDataset(f'{DATA_ROOT}/Training', transform=data_transforms['train'])
val_dataset = BrainTumorDataset(f'{DATA_ROOT}/Testing', transform=data_transforms['val'])

In [None]:
# Only the training split is shuffled; validation batches are read in a fixed order.
BATCH_SIZE = 32
NUM_WORKERS = 4

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import models
from timm import create_model
from sklearn.metrics import precision_score, recall_score, f1_score, roc_auc_score
import matplotlib.pyplot as plt
import os

# SimCLR Model with EfficientNet_B0 as backbone
class SimCLR(nn.Module):
    """Backbone followed by a two-layer MLP projection head (SimCLR-style architecture).

    NOTE(review): in the visible notebook this model is only trained end-to-end
    with cross-entropy through ``TumorClassifier``. No contrastive loss is
    applied, so the name describes the architecture, not the training scheme.

    Args:
        base_model: module mapping an image batch to features. Its output is
            flattened from dim 1 and must have ``feature_dim`` values per sample.
        feature_dim: backbone feature size. The default of 1280 matches
            EfficientNet-B0 with its classifier removed.
        hidden_dim: width of the projection MLP's hidden layer.
        projection_dim: size of the projected embedding returned by ``forward``.
    """

    def __init__(self, base_model, feature_dim=1280, hidden_dim=512, projection_dim=128):
        super().__init__()
        self.backbone = base_model
        # Same Sequential layout as before (indices 0/1/2), so saved state_dicts still load.
        self.projection_head = nn.Sequential(
            nn.Linear(feature_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, projection_dim),
        )

    def forward(self, x):
        """Return the ``(batch, projection_dim)`` projection of the backbone features of ``x``."""
        features = self.backbone(x).flatten(1)
        return self.projection_head(features)

# Tumor Classifier Head
class TumorClassifier(nn.Module):
    """Linear classification head on top of a feature extractor.

    Args:
        base_model: module mapping an input batch to ``(batch, in_features)`` features.
        num_classes: number of output logits.
        in_features: feature size produced by ``base_model``. The default of 128
            matches the projection size of ``SimCLR`` in this notebook.
    """

    def __init__(self, base_model, num_classes, in_features=128):
        super().__init__()
        self.base_model = base_model
        self.classifier = nn.Linear(in_features, num_classes)

    def forward(self, x):
        """Return raw (un-normalised) ``(batch, num_classes)`` logits for ``x``."""
        features = self.base_model(x)
        return self.classifier(features)

In [None]:
# ImageNet-pretrained EfficientNet-B0 from timm. num_classes=0 drops the final
# classification layer, so the model yields backbone features rather than logits.
BACKBONE_NAME = 'efficientnet_b0'
efficientnet_b0 = create_model(BACKBONE_NAME, pretrained=True, num_classes=0)

In [None]:
# Reproducibility: seed torch's global RNG (all devices) before the projection head
# and classifier are randomly initialised.
# Provided fine_tune runs next without other RNG use in between, this seed also
# fixes the DataLoader shuffle order and the base seed of its augmentation workers.
# NOTE(review): it does not make cuDNN kernels deterministic.
SEED = 42
torch.manual_seed(SEED)

simclr_model = SimCLR(efficientnet_b0)
tumor_classifier = TumorClassifier(simclr_model, num_classes=4)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tumor_classifier = tumor_classifier.to(device)

# Optimizer and loss function.
# NOTE(review): lr=1e-3 is applied to the pretrained backbone as well as the
# freshly initialised heads.
optimizer = optim.Adam(tumor_classifier.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

In [None]:
def calculate_metrics(preds, labels, num_classes):
    """Compute macro-averaged precision, recall, F1 and one-vs-rest ROC AUC.

    Args:
        preds: ``(N, num_classes)`` tensor of raw logits. It may live on any
            device and may require grad.
        labels: ``(N,)`` tensor of integer class indices in ``[0, num_classes)``.
        num_classes: number of classes.

    Returns:
        ``(precision, recall, f1, auc)`` as floats.
        - Precision and recall count a class with no predicted (resp. true)
          samples as 0, so a collapsed model is not rewarded.
        - ``auc`` is computed from softmax probabilities.
        - ``auc`` is ``nan`` when undefined, e.g. when some class has no positive
          sample in ``labels``.
    """
    # Detach so logits collected during training (which require grad) can become NumPy.
    logits = preds.detach().cpu()
    y_true = labels.detach().cpu().numpy()
    y_pred = logits.argmax(dim=1).numpy()

    precision = precision_score(y_true, y_pred, average='macro', zero_division=0)
    recall = recall_score(y_true, y_pred, average='macro', zero_division=0)
    f1 = f1_score(y_true, y_pred, average='macro', zero_division=0)

    # ROC AUC must be computed from scores, not hard (one-hot) predictions.
    # float64 keeps each softmax row summing to 1 within sklearn's tolerance.
    y_prob = torch.softmax(logits.to(torch.float64), dim=1).numpy()
    try:
        if num_classes == 2:
            # Binary AUC expects the positive-class score as a 1-D array.
            auc = roc_auc_score(y_true, y_prob[:, 1])
        else:
            auc = roc_auc_score(
                y_true, y_prob, average='macro', multi_class='ovr',
                labels=list(range(num_classes)),
            )
    except ValueError:
        # sklearn raises ValueError when AUC is undefined, e.g. a class has
        # no positive or no negative sample in y_true.
        auc = float('nan')

    return precision, recall, f1, auc

In [None]:
def _run_epoch(model, loader, criterion, device, optimizer=None):
    """Run one pass over ``loader``, training if ``optimizer`` is given and evaluating otherwise.

    Returns a 4-tuple:
        - mean per-sample loss;
        - accuracy;
        - concatenated logits (detached, on CPU);
        - concatenated labels (on CPU).

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    training = optimizer is not None
    model.train(training)
    total_loss, n_correct, n_seen = 0.0, 0, 0
    all_logits, all_labels = [], []

    with torch.set_grad_enabled(training):
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)

            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            batch_size = labels.size(0)
            n_seen += batch_size
            n_correct += (outputs.argmax(dim=1) == labels).sum().item()
            total_loss += loss.item() * batch_size

            # Detach and move to CPU so the epoch's collected predictions do not
            # keep autograd-linked GPU tensors alive until the epoch ends.
            all_logits.append(outputs.detach().cpu())
            all_labels.append(labels.cpu())

    if n_seen == 0:
        raise ValueError('loader yielded no samples')
    return total_loss / n_seen, n_correct / n_seen, torch.cat(all_logits), torch.cat(all_labels)


def _save_curves(history, path):
    """Save side-by-side loss and accuracy curves from ``history`` to the PNG at ``path``."""
    epoch_axis = range(1, len(history['train_loss']) + 1)
    fig, (ax_loss, ax_acc) = plt.subplots(1, 2, figsize=(12, 5))

    ax_loss.plot(epoch_axis, history['train_loss'], label='Train Loss')
    ax_loss.plot(epoch_axis, history['val_loss'], label='Validation Loss')
    ax_loss.set(xlabel='Epoch', ylabel='Loss', title='Loss Curve')
    ax_loss.legend()

    ax_acc.plot(epoch_axis, history['train_acc'], label='Train Accuracy')
    ax_acc.plot(epoch_axis, history['val_acc'], label='Validation Accuracy')
    ax_acc.set(xlabel='Epoch', ylabel='Accuracy', title='Accuracy Curve')
    ax_acc.legend()

    fig.tight_layout()
    fig.savefig(path)
    plt.close(fig)  # avoid accumulating open figures across epochs


def fine_tune(model, train_loader, val_loader, optimizer, criterion, num_classes, epochs=10,
              device=None, plot_every=5, plot_dir='plots'):
    """Supervised fine-tuning loop with per-epoch train/validation metrics.

    Each epoch does two passes:
        - one optimisation pass over ``train_loader``;
        - one gradient-free pass over ``val_loader``.
    It prints loss, accuracy and the precision/recall/F1/AUC returned by
    ``calculate_metrics``. Every ``plot_every`` epochs, and after the last epoch,
    the loss and accuracy curves so far are saved as
    ``<plot_dir>/epoch_<n>_metrics.png``.

    Args:
        model: module mapping an image batch to ``(batch, num_classes)`` logits.
        train_loader: iterable of ``(images, labels)`` training batches.
        val_loader: iterable of ``(images, labels)`` validation batches.
        optimizer: optimizer over ``model``'s parameters.
        criterion: loss taking ``(logits, labels)``.
        num_classes: number of classes, forwarded to ``calculate_metrics``.
        epochs: number of passes over the training data.
        device: device batches are moved to. Defaults to the device of the
            model's first parameter (CPU if it has none), so no global ``device``
            is required.
        plot_every: save curves every this many epochs; ``<= 0`` disables plotting.
        plot_dir: directory for the plots; created if missing.

    Returns:
        dict with per-epoch lists under 'train_loss', 'train_acc', 'val_loss', 'val_acc'.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')
    os.makedirs(plot_dir, exist_ok=True)

    history = {'train_loss': [], 'train_acc': [], 'val_loss': [], 'val_acc': []}

    for epoch in range(epochs):
        train_loss, train_acc, train_logits, train_targets = _run_epoch(
            model, train_loader, criterion, device, optimizer=optimizer)
        train_precision, train_recall, train_f1, train_auc = calculate_metrics(
            train_logits, train_targets, num_classes)
        print(f'Epoch [{epoch+1}/{epochs}], Loss: {train_loss:.4f}, Accuracy: {train_acc:.4f}, Precision: {train_precision:.4f}, Recall: {train_recall:.4f}, F1: {train_f1:.4f}, AUC: {train_auc:.4f}')
        history['train_loss'].append(train_loss)
        history['train_acc'].append(train_acc)

        val_loss, val_acc, val_logits, val_targets = _run_epoch(
            model, val_loader, criterion, device)
        val_precision, val_recall, val_f1, val_auc = calculate_metrics(
            val_logits, val_targets, num_classes)
        print(f'Validation - Loss: {val_loss:.4f}, Accuracy: {val_acc:.4f}, Precision: {val_precision:.4f}, Recall: {val_recall:.4f}, F1: {val_f1:.4f}, AUC: {val_auc:.4f}')
        history['val_loss'].append(val_loss)
        history['val_acc'].append(val_acc)

        is_last_epoch = epoch + 1 == epochs
        if plot_every > 0 and ((epoch + 1) % plot_every == 0 or is_last_epoch):
            _save_curves(history, os.path.join(plot_dir, f'epoch_{epoch+1}_metrics.png'))

    return history

In [None]:
# 30 epochs of supervised fine-tuning; loss/accuracy curves are saved under plots/ every 5 epochs.
# Binding the result keeps any return value from being echoed as the cell output.
training_history = fine_tune(
    tumor_classifier,
    train_loader,
    val_loader,
    optimizer,
    criterion,
    num_classes=4,
    epochs=30,
)