In [None]:
import os
import numpy as np
from PIL import Image
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
import matplotlib.pyplot as plt

In [None]:
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader,Dataset

class BrainTumorDataset(Dataset):
    """Folder-per-class image dataset for the four brain-tumor categories.

    ``root_dir`` is scanned once at construction time: every immediate
    sub-directory is treated as a class, and every file inside it whose name
    ends in a known image extension becomes one sample. Directories and files
    are visited in sorted order so the sample order is reproducible across
    runs and machines (``os.listdir`` alone returns entries in arbitrary
    order).

    Args:
        root_dir: Directory containing one sub-directory per class.
        transform: Optional callable applied to the RGB ``PIL.Image`` before
            it is returned from ``__getitem__``.

    Raises:
        ValueError: If a sub-directory that contains image files is not listed
            in ``self.classes``. Failing here gives a clear message instead of
            an opaque ``list.index`` error in the middle of an epoch.
    """

    # Tuple form so it can be handed directly to ``str.endswith``.
    IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png", ".bmp", ".tiff")

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.classes = ['glioma_tumor', 'meningioma_tumor', 'no_tumor', 'pituitary_tumor']
        # Constant-time label lookup used by __getitem__.
        self.class_to_idx = {name: index for index, name in enumerate(self.classes)}
        self.image_paths, self.labels = self.process_images(root_dir)

        unknown = sorted(set(self.labels) - set(self.classes))
        if unknown:
            raise ValueError(
                f"Unexpected class directories under {root_dir!r}: {unknown}; "
                f"expected a subset of {self.classes}"
            )

    def __len__(self):
        """Return the number of image files found under ``root_dir``."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, label_index)``."""
        img_path = self.image_paths[idx]
        # The context manager releases the underlying file as soon as the
        # pixel data has been converted, rather than waiting for GC.
        with Image.open(img_path) as handle:
            image = handle.convert('RGB')
        label = self.class_to_idx[self.labels[idx]]

        if self.transform:
            image = self.transform(image)
        return image, label

    def process_images(self, path):
        """Collect image paths and their class-directory names under ``path``.

        Returns:
            ``(images, labels)``: two parallel lists, the full path of each
            image file and the name of the sub-directory it was found in.
        """
        images = []
        labels = []
        for category in sorted(os.listdir(path)):
            category_path = os.path.join(path, category)
            if not os.path.isdir(category_path):
                continue
            for img_name in sorted(os.listdir(category_path)):
                img_path = os.path.join(category_path, img_name)
                if self.is_image_file(img_path):
                    images.append(img_path)
                    labels.append(category)
        return images, labels

    def is_image_file(self, filename):
        """Return True when ``filename`` ends in a known image extension (case-insensitive)."""
        return filename.lower().endswith(self.IMAGE_EXTENSIONS)

In [None]:
# Preprocessing pipelines keyed by split name. Both end with the same
# ToTensor + Normalize step; the mean/std values are the ones commonly used
# with ImageNet-pretrained torchvision models.
data_transforms = dict(
    # Training: mild random augmentation, final size 224x224.
    train=transforms.Compose([
        # Keep at least 80% of the image area when cropping.
        transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
        transforms.RandomHorizontalFlip(),
        # Small rotations only (up to +/-10 degrees).
        transforms.RandomRotation(10),
        # Gentle photometric jitter.
        transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.05),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]),
    # Validation: deterministic resize followed by a central 224x224 crop.
    val=transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]),
)

In [None]:
# Datasets: the "Training" folder gets the augmenting pipeline, the "Testing"
# folder (used here as the validation split) gets the deterministic one.
train_dataset = BrainTumorDataset(
    '/kaggle/input/brain-tumor-data-18k/tumordata/Training',
    transform=data_transforms['train'],
)
val_dataset = BrainTumorDataset(
    '/kaggle/input/brain-tumor-data-18k/tumordata/Testing',
    transform=data_transforms['val'],
)

# Loaders: only the training loader shuffles.
train_loader = DataLoader(
    train_dataset,
    batch_size=32,
    shuffle=True,
    num_workers=4,
)
val_loader = DataLoader(
    val_dataset,
    batch_size=32,
    shuffle=False,
    num_workers=4,
)

In [None]:
class BYOL(nn.Module):
    """Online/target network pair with an exponential-moving-average target.

    The online branch is ``online_encoder -> projector -> online_predictor``
    and is trained by gradient descent. The target branch is
    ``target_encoder -> target_projector``; it receives no gradients and its
    weights track the online encoder/projector through an exponential moving
    average (EMA).

    Args:
        base_model: Backbone used as the online encoder. It is deep-copied to
            create the target encoder, so the object passed in stays
            trainable and is never frozen.
        projection_size: Output width of the projector and predictor.
        projection_hidden_size: Hidden width of the projector and predictor.
        momentum: EMA coefficient; the target keeps ``momentum`` of its old
            value on every update.
        encoder_output_size: Width of ``base_model``'s output, i.e. the input
            width of the projector. Defaults to 1000.
    """

    def __init__(self, base_model, projection_size=256, projection_hidden_size=4096,
                 momentum=0.996, encoder_output_size=1000):
        super(BYOL, self).__init__()
        # Local import keeps this class self-contained in the notebook cell.
        import copy

        self.momentum = momentum
        self.online_encoder = base_model

        self.projector = nn.Sequential(
            nn.Linear(encoder_output_size, projection_hidden_size),
            nn.BatchNorm1d(projection_hidden_size),
            nn.ReLU(),
            nn.Linear(projection_hidden_size, projection_size)
        )
        self.online_predictor = nn.Sequential(
            nn.Linear(projection_size, projection_hidden_size),
            nn.BatchNorm1d(projection_hidden_size),
            nn.ReLU(),
            nn.Linear(projection_hidden_size, projection_size)
        )

        # Independent copies: assigning ``base_model`` to both attributes would
        # make "target" an alias of "online", which freezes the online encoder
        # and turns the EMA update into a no-op.
        self.target_encoder = copy.deepcopy(self.online_encoder)
        self.target_projector = copy.deepcopy(self.projector)

        self._initialize_target_encoder()

    def _online_target_parameters(self):
        """Yield ``(online_param, target_param)`` pairs for encoder and projector."""
        module_pairs = (
            (self.online_encoder, self.target_encoder),
            (self.projector, self.target_projector),
        )
        for online_module, target_module in module_pairs:
            yield from zip(online_module.parameters(), target_module.parameters())

    def _initialize_target_encoder(self):
        """Copy online weights into the target branch and exclude it from autograd."""
        for param_q, param_k in self._online_target_parameters():
            param_k.data.copy_(param_q.data)
            param_k.requires_grad = False

    @torch.no_grad()
    def _momentum_update_target_encoder(self):
        """EMA step: ``target = momentum * target + (1 - momentum) * online``.

        Only parameters are averaged; module buffers (e.g. BatchNorm running
        statistics) are left to the target branch's own forward passes.
        """
        for param_q, param_k in self._online_target_parameters():
            param_k.mul_(self.momentum).add_(param_q.detach(), alpha=1.0 - self.momentum)

    def forward(self, x1, x2):
        """Embed two views of a batch with both branches.

        Args:
            x1: First view of the batch.
            x2: Second view of the same batch.

        Returns:
            ``(online_pred_1, online_pred_2, target_proj_1, target_proj_2)``.
            The two target tensors are computed under ``torch.no_grad()`` and
            therefore carry no gradient.
        """
        online_pred_1 = self.online_predictor(self.projector(self.online_encoder(x1)))
        online_pred_2 = self.online_predictor(self.projector(self.online_encoder(x2)))

        with torch.no_grad():
            # Fold the latest online weights into the target before using it.
            # Skipped in eval mode so that inference never mutates the model.
            if self.training:
                self._momentum_update_target_encoder()

            target_proj_1 = self.target_projector(self.target_encoder(x1))
            target_proj_2 = self.target_projector(self.target_encoder(x2))

        return online_pred_1, online_pred_2, target_proj_1, target_proj_2

def byol_loss_fn(pred, target):
    """Return ``2 - 2 * <pred_hat, target_hat>`` for each row.

    Both inputs are L2-normalised along their last dimension and the dot
    product is taken along that same dimension, so the result has the input
    shape minus the last axis. No reduction over the remaining axes is applied.
    """
    unit_pred = nn.functional.normalize(pred, dim=-1)
    unit_target = nn.functional.normalize(target, dim=-1)
    similarity = (unit_pred * unit_target).sum(dim=-1)
    return 2 - 2 * similarity

In [None]:
def train_epoch(model, train_loader, criterion, optimizer, device):
    """Run one optimisation pass over ``train_loader``.

    Puts ``model`` in training mode, takes one optimizer step per batch, and
    collects every label and arg-max prediction seen along the way.

    Returns:
        ``(avg_loss, accuracy, precision, recall, f1)``. ``avg_loss`` is the
        mean of the per-batch loss values; precision, recall and F1 are
        macro-averaged.
    """
    model.train()
    running_loss = 0.0
    seen_targets = []
    seen_predictions = []

    for batch_images, batch_labels in train_loader:
        batch_images = batch_images.to(device)
        batch_labels = batch_labels.to(device)

        optimizer.zero_grad()
        logits = model(batch_images)
        batch_loss = criterion(logits, batch_labels)
        batch_loss.backward()
        optimizer.step()

        running_loss += batch_loss.item()
        # Predicted class = index of the largest score along dim 1.
        batch_predictions = logits.max(dim=1).indices
        seen_targets.extend(batch_labels.cpu().numpy())
        seen_predictions.extend(batch_predictions.cpu().numpy())

    mean_loss = running_loss / len(train_loader)
    accuracy = accuracy_score(seen_targets, seen_predictions)
    precision, recall, f1 = (
        scorer(seen_targets, seen_predictions, average='macro')
        for scorer in (precision_score, recall_score, f1_score)
    )

    return mean_loss, accuracy, precision, recall, f1

# Validation Loop
def validate_epoch(model, val_loader, criterion, device):
    """Evaluate ``model`` on ``val_loader`` without updating any weights.

    Puts ``model`` in eval mode and runs every batch under
    ``torch.no_grad()``, collecting labels and arg-max predictions.

    Returns:
        ``(avg_loss, accuracy, precision, recall, f1)``. ``avg_loss`` is the
        mean of the per-batch loss values; precision, recall and F1 are
        macro-averaged.
    """
    model.eval()
    running_loss = 0.0
    seen_targets = []
    seen_predictions = []

    with torch.no_grad():
        for batch_images, batch_labels in val_loader:
            batch_images = batch_images.to(device)
            batch_labels = batch_labels.to(device)

            logits = model(batch_images)
            running_loss += criterion(logits, batch_labels).item()

            # Predicted class = index of the largest score along dim 1.
            batch_predictions = logits.max(dim=1).indices
            seen_targets.extend(batch_labels.cpu().numpy())
            seen_predictions.extend(batch_predictions.cpu().numpy())

    mean_loss = running_loss / len(val_loader)
    accuracy = accuracy_score(seen_targets, seen_predictions)
    precision, recall, f1 = (
        scorer(seen_targets, seen_predictions, average='macro')
        for scorer in (precision_score, recall_score, f1_score)
    )

    return mean_loss, accuracy, precision, recall, f1

def train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs, device, plot_every=5):
    """Train and validate ``model`` for ``num_epochs`` epochs.

    Each epoch runs ``train_epoch`` followed by ``validate_epoch``, prints the
    metrics of both, and records loss and accuracy. Curves are drawn with
    ``plot_curves`` every ``plot_every`` epochs and always after the final
    epoch, so a run whose length is not a multiple of ``plot_every`` still
    ends with an up-to-date figure.

    Args:
        model: Network to optimise.
        train_loader: Iterable of training batches.
        val_loader: Iterable of validation batches.
        criterion: Loss callable applied to ``(outputs, labels)``.
        optimizer: Optimizer that updates ``model``'s parameters.
        num_epochs: Number of epochs to run.
        device: Device the batches are moved to.
        plot_every: Plot interval in epochs. A falsy value disables the
            periodic plots and leaves only the final one.

    Returns:
        dict with keys ``'train_loss'``, ``'val_loss'``, ``'train_acc'`` and
        ``'val_acc'``, each a list holding one value per completed epoch.
    """
    history = {'train_loss': [], 'val_loss': [], 'train_acc': [], 'val_acc': []}

    for epoch in range(1, num_epochs + 1):
        print(f"Epoch {epoch}/{num_epochs}:")
        train_loss, train_acc, train_prec, train_rec, train_f1 = train_epoch(model, train_loader, criterion, optimizer, device)
        val_loss, val_acc, val_prec, val_rec, val_f1 = validate_epoch(model, val_loader, criterion, device)

        print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}, Prec: {train_prec:.4f}, Recall: {train_rec:.4f}, F1: {train_f1:.4f}")
        print(f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}, Prec: {val_prec:.4f}, Recall: {val_rec:.4f}, F1: {val_f1:.4f}")

        history['train_loss'].append(train_loss)
        history['val_loss'].append(val_loss)
        history['train_acc'].append(train_acc)
        history['val_acc'].append(val_acc)

        is_periodic = bool(plot_every) and epoch % plot_every == 0
        is_last = epoch == num_epochs
        if is_periodic or is_last:
            plot_curves(epoch, history['train_loss'], history['val_loss'],
                        history['train_acc'], history['val_acc'])

    return history

def plot_curves(epoch, train_losses, val_losses, train_accuracies, val_accuracies):
    """Draw loss and accuracy curves side by side, save them, and show them.

    The x-axis runs over epochs ``1..epoch``, so each series must contain
    ``epoch`` values. The figure is written to
    ``accuracy_loss_epoch_<epoch>.png`` in the working directory before being
    displayed.
    """
    epochs = range(1, epoch + 1)
    fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(12, 5))

    # (axis, metric name, training series, validation series) per panel.
    panels = (
        (loss_ax, 'Loss', train_losses, val_losses),
        (acc_ax, 'Accuracy', train_accuracies, val_accuracies),
    )
    for ax, metric, train_series, val_series in panels:
        ax.plot(epochs, train_series, label=f'Train {metric}')
        ax.plot(epochs, val_series, label=f'Validation {metric}')
        ax.set_title(f'{metric} over epochs')
        ax.set_xlabel('Epoch')
        ax.set_ylabel(metric)
        ax.legend()

    fig.savefig(f'accuracy_loss_epoch_{epoch}.png')
    plt.show()

In [None]:
# Run on the GPU when PyTorch can see one, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# ResNet-50 fetched through torch.hub from the torchvision v0.10.0 tag,
# requested with pretrained weights.
base_model = torch.hub.load(
    'pytorch/vision:v0.10.0',
    'resnet50',
    pretrained=True,
)

In [None]:
# Wrap the backbone in the BYOL module and move it to the selected device.
# NOTE(review): none of the cells shown here call byol_model or run a BYOL
# training loop — confirm whether a pre-training step is missing.
byol_model = BYOL(base_model).to(device)
# Adam over all BYOL parameters. The name `optimizer` is rebound to the
# classifier's optimizer in a later cell, so this instance is only reachable
# until that cell runs.
optimizer = optim.Adam(byol_model.parameters(), lr=1e-3)
# Cross-entropy loss; this is the `criterion` later passed to train_model.
criterion = nn.CrossEntropyLoss()

In [None]:
class Classifier(nn.Module):
    """Backbone followed by a single linear classification head.

    NOTE(review): no definition of ``Classifier`` appears in any other cell of
    this notebook, so a fresh-kernel run raised ``NameError`` here. This
    minimal head is an assumption about the intended architecture — confirm
    against the original experiment.

    Args:
        encoder: Backbone whose output is fed to the head.
        num_classes: Number of output logits.
        encoder_output_size: Width of the backbone's output. Defaults to 1000,
            the same width the BYOL projector in this notebook expects.
    """

    def __init__(self, encoder, num_classes=4, encoder_output_size=1000):
        super().__init__()
        self.encoder = encoder
        self.head = nn.Linear(encoder_output_size, num_classes)

    def forward(self, x):
        """Return class logits of shape ``(..., num_classes)`` for input ``x``."""
        return self.head(self.encoder(x))


# One logit per dataset class, so the outputs line up with the integer labels
# produced by the dataset.
classifier_model = Classifier(base_model, num_classes=len(train_dataset.classes)).to(device)
optimizer = optim.Adam(classifier_model.parameters(), lr=1e-3)

In [None]:
# Supervised training of the classifier for 10 epochs.
train_model(
    model=classifier_model,
    train_loader=train_loader,
    val_loader=val_loader,
    criterion=criterion,
    optimizer=optimizer,
    num_epochs=10,
    device=device,
)