In [None]:
import os
import numpy as np
from PIL import Image
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader

In [None]:
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader,Dataset

class BrainTumorDataset(Dataset):
    """Brain-MRI image dataset laid out as ``root_dir/<class_name>/<image file>``.

    Every image file (by extension, see ``is_image_file``) found directly inside
    a class sub-directory becomes one sample; its label is the index of that
    directory's name in ``classes``.

    Args:
        root_dir: Directory whose immediate sub-directories are class folders.
        transform: Optional callable applied to each RGB ``PIL.Image``.
        classes: Ordered class names defining the label indices. Defaults to
            ``DEFAULT_CLASSES``.

    Raises:
        ValueError: if a sub-directory that contains images is not listed in
            ``classes`` (detected up front instead of on first access).
    """

    DEFAULT_CLASSES = ('glioma_tumor', 'meningioma_tumor', 'no_tumor', 'pituitary_tumor')
    # Lower-case suffixes accepted by is_image_file; matching is case-insensitive.
    IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff")

    def __init__(self, root_dir, transform=None, classes=None):
        self.root_dir = root_dir
        self.transform = transform
        self.classes = list(self.DEFAULT_CLASSES if classes is None else classes)
        # O(1) name -> label lookup instead of list.index on every sample.
        self.class_to_idx = {name: index for index, name in enumerate(self.classes)}
        self.image_paths, self.labels = self.process_images(root_dir)

        # Fail fast: an unexpected folder would otherwise only blow up inside
        # __getitem__ halfway through an epoch.
        unknown = sorted(set(self.labels).difference(self.class_to_idx))
        if unknown:
            raise ValueError(
                f"Found image folders {unknown} under {root_dir!r} "
                f"that are not in classes {self.classes}"
            )

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, label_index)`` for sample ``idx``; image is transformed if a transform is set."""
        img_path = self.image_paths[idx]
        # The context manager closes the file handle; convert() returns a
        # separate, fully loaded image that stays valid after the close.
        with Image.open(img_path) as img:
            image = img.convert('RGB')
        label = self.class_to_idx[self.labels[idx]]

        if self.transform is not None:
            image = self.transform(image)
        return image, label

    def process_images(self, path):
        """Collect ``(image_paths, labels)`` for every image inside ``path``'s class folders.

        Entries are visited in sorted order so sample indices are reproducible
        (``os.listdir`` order is arbitrary). Loose files directly in ``path``
        and non-image files are ignored. Labels are the folder names (strings).
        """
        images, labels = [], []
        for category in sorted(os.listdir(path)):
            category_path = os.path.join(path, category)
            if not os.path.isdir(category_path):
                continue
            for img_name in sorted(os.listdir(category_path)):
                img_path = os.path.join(category_path, img_name)
                if os.path.isfile(img_path) and self.is_image_file(img_path):
                    images.append(img_path)
                    labels.append(category)
        return images, labels

    def is_image_file(self, filename):
        """Return True if ``filename`` ends with a known image extension (case-insensitive)."""
        return filename.lower().endswith(self.IMAGE_EXTENSIONS)

In [None]:
# ImageNet channel statistics: the backbone is ImageNet-pretrained, so inputs
# are normalised with the statistics it was trained on.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# 'train': mild random crop / flip / rotation / colour jitter, then normalise.
# 'val': deterministic resize + centre crop, then the same normalisation.
data_transforms = dict(
    train=transforms.Compose([
        transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(10),
        transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.05),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]),
    val=transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]),
)

In [None]:
# Kaggle dataset mount point; 'Training' and 'Testing' are its two splits.
# NOTE(review): the 'Testing' split doubles as the validation set here.
DATA_ROOT = '/kaggle/input/brain-tumor-data-18k/tumordata'

train_dataset = BrainTumorDataset(
    os.path.join(DATA_ROOT, 'Training'),
    transform=data_transforms['train'],
)
val_dataset = BrainTumorDataset(
    os.path.join(DATA_ROOT, 'Testing'),
    transform=data_transforms['val'],
)

In [None]:
# Shared loader settings; only shuffling differs between the splits.
LOADER_KWARGS = dict(batch_size=32, num_workers=4)

train_loader = DataLoader(train_dataset, shuffle=True, **LOADER_KWARGS)
val_loader = DataLoader(val_dataset, shuffle=False, **LOADER_KWARGS)

In [None]:
import torch
from torch import nn, optim
import torchvision.models as models
from sklearn.metrics import precision_score, recall_score, f1_score
import os
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.functional import normalize

class SimCLR(nn.Module):
    """SimCLR model: a shared encoder followed by an MLP projection head.

    Both views go through the same ``base_model``. The encoder output is
    flattened to ``(batch, feature_dim)``, projected, and L2-normalised along
    dim 1 so the embeddings can be fed to an NT-Xent loss.

    Args:
        base_model: Encoder module. Its output, flattened from dim 1, must have
            ``feature_dim`` features per sample.
        projection_size: Width of the final embedding.
        projection_hidden_size: Width of the projector's hidden layer.
        feature_dim: Flattened width of the encoder output (1000 for a full
            torchvision ResNet; 2048 for ResNet-50 with its ``fc`` removed).
    """

    def __init__(self, base_model, projection_size=128, projection_hidden_size=2048, feature_dim=1000):
        super().__init__()
        self.encoder = base_model
        self.feature_dim = feature_dim

        self.projector = nn.Sequential(
            nn.Linear(feature_dim, projection_hidden_size),
            nn.ReLU(),
            nn.Linear(projection_hidden_size, projection_size)
        )

    def _embed(self, x):
        """Encode, flatten, project and L2-normalise one batch of inputs."""
        # Flatten so encoders ending in a pooling layer ((B, C, 1, 1)) also work.
        h = torch.flatten(self.encoder(x), start_dim=1)
        if h.size(1) != self.feature_dim:
            raise ValueError(
                f"Encoder produced {h.size(1)} features per sample but the projector "
                f"expects feature_dim={self.feature_dim}"
            )
        return normalize(self.projector(h), dim=1)

    def forward(self, x_i, x_j):
        """Return the normalised projections ``(z_i, z_j)`` of the two views."""
        return self._embed(x_i), self._embed(x_j)

def nt_xent_loss(z_i, z_j, temperature=0.5):
    """Normalised temperature-scaled cross-entropy (NT-Xent) loss used by SimCLR.

    Row ``k`` of ``z_i`` and row ``k`` of ``z_j`` are the two views of the same
    sample (the positive pair). Every other row in the combined ``2B`` batch is
    a negative. Inputs are L2-normalised here, so the similarity is cosine.

    Args:
        z_i: Embeddings of the first views, shape ``(B, d)``.
        z_j: Embeddings of the second views, shape ``(B, d)``.
        temperature: Softmax temperature applied to the similarities.

    Returns:
        Scalar mean loss over all ``2B`` anchors.
    """
    batch_size = z_i.size(0)
    n = 2 * batch_size
    z = F.normalize(torch.cat([z_i, z_j], dim=0), dim=1)
    similarity_matrix = torch.matmul(z, z.T)

    rows = torch.arange(n, device=z.device)
    # Anchor k's positive is k + B (first half) or k - B (second half).
    positive_cols = (rows + batch_size) % n
    positives = similarity_matrix[rows, positive_cols]

    # Exclude both self-similarity and the positive from the negatives;
    # otherwise the positive is counted twice in the softmax denominator.
    excluded = torch.eye(n, dtype=torch.bool, device=z.device)
    excluded[rows, positive_cols] = True
    negatives = similarity_matrix[~excluded].view(n, n - 2)

    # Column 0 holds the positive, so the target class is 0 for every anchor.
    logits = torch.cat([positives.unsqueeze(1), negatives], dim=1) / temperature
    labels = torch.zeros(n, dtype=torch.long, device=z.device)
    return F.cross_entropy(logits, labels)


# Seed so the projector initialisation (and later global-RNG draws) is reproducible.
SEED = 42
torch.manual_seed(SEED)

# Keep the full ResNet-50, including its pretrained `fc` head, so the encoder
# emits a flat (batch, 1000) tensor: the input width SimCLR's projector expects.
# Dropping the last child instead leaves a (batch, 2048, 1, 1) pooled map that
# nn.Linear(1000, ...) cannot consume.
# NOTE(review): `pretrained=` is deprecated in newer torchvision in favour of
# `weights=`; switch once the runtime version is confirmed.
resnet_encoder = models.resnet50(pretrained=True)

simclr_model = SimCLR(resnet_encoder)

optimizer = optim.Adam(simclr_model.parameters(), lr=1e-3)

def _simclr_views(images, device, generator=None, noise_std=0.1):
    """Build two views of an image batch on ``device`` for contrastive training.

    If ``images`` is already a pair of tensors (e.g. the dataset transform
    returns two augmentations), both are moved to ``device`` and returned.
    Otherwise the single batch tensor (assumed (B, C, H, W); TODO confirm for
    other loaders) is duplicated. Each copy gets an independent per-sample
    horizontal flip plus additive Gaussian noise with std ``noise_std``.
    """
    if isinstance(images, (list, tuple)):
        view_a, view_b = images
        return view_a.to(device), view_b.to(device)

    x = images.to(device)
    broadcast_shape = (x.size(0),) + (1,) * (x.dim() - 1)
    views = []
    for _ in range(2):
        flip = torch.rand(x.size(0), device=x.device, generator=generator) < 0.5
        view = torch.where(flip.view(broadcast_shape), x.flip(-1), x)
        noise = torch.randn(x.shape, device=x.device, dtype=x.dtype, generator=generator)
        views.append(view + noise_std * noise)
    return views[0], views[1]


def _contrastive_hits(z_i, z_j):
    """Count anchors whose most similar other embedding is their positive pair.

    Returns ``(hits, anchors)`` over the combined ``2B`` batch.
    """
    z = torch.cat([z_i, z_j], dim=0).detach()
    n = z.size(0)
    similarity = z @ z.T
    similarity.fill_diagonal_(float('-inf'))  # an anchor must not match itself
    targets = (torch.arange(n, device=z.device) + n // 2) % n
    return (similarity.argmax(dim=1) == targets).sum().item(), n


def _run_simclr_epoch(model, loader, device, optimizer=None, temperature=0.5, generator=None):
    """Make one pass over ``loader``, training when ``optimizer`` is given.

    Returns ``(mean NT-Xent loss per batch, contrastive top-1 accuracy)``.
    Dataset labels are ignored.
    """
    training = optimizer is not None
    model.train(training)
    total_loss, hits, seen = 0.0, 0, 0

    with torch.set_grad_enabled(training):
        for images, _labels in loader:
            x_i, x_j = _simclr_views(images, device, generator)
            if training:
                optimizer.zero_grad()
            z_i, z_j = model(x_i, x_j)
            loss = nt_xent_loss(z_i, z_j, temperature=temperature)
            if training:
                loss.backward()
                optimizer.step()

            total_loss += loss.item()
            batch_hits, batch_seen = _contrastive_hits(z_i, z_j)
            hits += batch_hits
            seen += batch_seen

    num_batches = len(loader)
    mean_loss = total_loss / num_batches if num_batches else 0.0
    accuracy = hits / seen if seen else 0.0
    return mean_loss, accuracy


def train_and_evaluate_simclr(model, train_loader, val_loader, optimizer, device, num_epochs=10, temperature=0.5):
    """Train ``model`` contrastively with the NT-Xent loss and report per-epoch metrics.

    ``model(x_i, x_j)`` must return a pair of embeddings ``(z_i, z_j)``.
    Two views per batch come from ``_simclr_views``. Validation views use a
    generator re-seeded every epoch, so validation metrics are comparable
    across epochs.

    "Accuracy" here is contrastive top-1 accuracy: the fraction of anchors
    whose nearest neighbour in the batch is their own positive view. Class-level
    precision/recall/F1 need a downstream classifier (e.g. a linear probe on the
    encoder) and are not computed here.

    Loss/accuracy curves are plotted every 5 epochs and after the last epoch.

    Returns:
        dict with per-epoch lists under 'train_loss', 'val_loss',
        'train_accuracy' and 'val_accuracy'.
    """
    history = {'train_loss': [], 'val_loss': [], 'train_accuracy': [], 'val_accuracy': []}

    for epoch in range(num_epochs):
        train_loss, train_accuracy = _run_simclr_epoch(
            model, train_loader, device, optimizer=optimizer, temperature=temperature)

        val_generator = torch.Generator(device=device)
        val_generator.manual_seed(0)
        val_loss, val_accuracy = _run_simclr_epoch(
            model, val_loader, device, optimizer=None, temperature=temperature, generator=val_generator)

        history['train_loss'].append(train_loss)
        history['val_loss'].append(val_loss)
        history['train_accuracy'].append(train_accuracy)
        history['val_accuracy'].append(val_accuracy)

        print(f'Epoch {epoch + 1}/{num_epochs}')
        print(f'Train Loss: {train_loss:.4f}, Train Contrastive Accuracy: {train_accuracy:.4f}')
        print(f'Validation Loss: {val_loss:.4f}, Validation Contrastive Accuracy: {val_accuracy:.4f}')

        if (epoch + 1) % 5 == 0 or epoch == num_epochs - 1:
            plot_loss_and_accuracy(history['train_loss'], history['val_loss'],
                                   history['train_accuracy'], history['val_accuracy'], epoch + 1)

    return history

def plot_loss_and_accuracy(train_losses, val_losses, train_accuracies, val_accuracies, epoch, save_dir='plots'):
    """Plot train/validation loss and accuracy curves side by side.

    Each history list must hold one value per epoch ``1..epoch``. The figure is
    written to ``<save_dir>/epoch_<epoch>.png`` (``save_dir`` is created if it
    does not exist) and then shown.
    """
    x_axis = range(1, epoch + 1)

    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    # (metric name, train series, validation series, panel title)
    panels = (
        ('Loss', train_losses, val_losses, 'Loss Curve'),
        ('Accuracy', train_accuracies, val_accuracies, 'Accuracy Curve'),
    )

    plt.figure(figsize=(10, 5))
    for position, (metric, train_series, val_series, title) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        plt.plot(x_axis, train_series, label=f'Train {metric}')
        plt.plot(x_axis, val_series, label=f'Validation {metric}')
        plt.title(title)
        plt.xlabel('Epoch')
        plt.ylabel(metric)
        plt.legend()
    plt.tight_layout()

    plt.savefig(os.path.join(save_dir, f'epoch_{epoch}.png'))
    plt.show()

def plot_precision_recall_f1(train_precisions, val_precisions, train_recalls, val_recalls, train_f1s, val_f1s, epoch, save_dir='plots'):
    """Plot train/validation precision, recall and F1 curves in one row of three panels.

    Each history list must hold one value per epoch ``1..epoch``. The figure is
    written to ``<save_dir>/epoch_<epoch>_prf1.png`` and then shown.
    """
    epochs = range(1, epoch + 1)

    # savefig does not create directories; make sure the target exists so this
    # function also works on its own or with a fresh save_dir.
    os.makedirs(save_dir, exist_ok=True)

    # (legend name, y-axis label, panel title, train series, validation series)
    panels = (
        ('Precision', 'Precision', 'Precision Curve', train_precisions, val_precisions),
        ('Recall', 'Recall', 'Recall Curve', train_recalls, val_recalls),
        ('F1', 'F1 Score', 'F1 Score Curve', train_f1s, val_f1s),
    )

    plt.figure(figsize=(15, 5))
    for position, (legend_name, y_label, title, train_series, val_series) in enumerate(panels, start=1):
        plt.subplot(1, 3, position)
        plt.plot(epochs, train_series, label=f'Train {legend_name}')
        plt.plot(epochs, val_series, label=f'Validation {legend_name}')
        plt.title(title)
        plt.xlabel('Epoch')
        plt.ylabel(y_label)
        plt.legend()
    plt.tight_layout()

    save_path = os.path.join(save_dir, f'epoch_{epoch}_prf1.png')
    plt.savefig(save_path)

    plt.show()

# Prefer the GPU when one is visible; otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device('cuda') if use_cuda else torch.device('cpu')
simclr_model.to(device)

In [None]:
train_and_evaluate_simclr(
    model=simclr_model,
    train_loader=train_loader,
    val_loader=val_loader,
    optimizer=optimizer,
    device=device,
    num_epochs=10,
)