In [None]:
# Colab-only: mount Google Drive at /content/drive so the MRI dataset under
# /content/drive/MyDrive/tumour_data_archive is readable by the cells below.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# Create the directory the final torch.save() writes the checkpoint into.
# os.makedirs(..., exist_ok=True) is plain Python (unlike the IPython `mkdir`
# automagic) and does not fail when the cell is re-run.
import os
os.makedirs('callback', exist_ok=True)

In [None]:
# NOTE: this cell uses BrainTumorDataset, `transform` and the torchvision
# imports defined in the training cell below -- run that cell first.
# The test split gets a deterministic transform (no random flip/rotation) so
# that the reported test accuracy is not perturbed by augmentation.
eval_transform = Compose([
    Resize((224, 224)),
    ToTensor(),
    Normalize((0.5,), (0.5,)),
])
train_dataset = BrainTumorDataset(root_dir='/content/drive/MyDrive/tumour_data_archive/Training', transform=transform)
test_dataset = BrainTumorDataset(root_dir='/content/drive/MyDrive/tumour_data_archive/Testing', transform=eval_transform)

In [None]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, random_split
from torchvision import transforms
from torchvision.transforms import Compose, Resize, ToTensor, Normalize, RandomHorizontalFlip, RandomRotation
from PIL import Image
from tqdm import tqdm
from sklearn.model_selection import KFold

# Pick the compute device once: the first CUDA GPU when one is visible,
# otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {device}")

# custom dataset class
class BrainTumorDataset(Dataset):
    """Image-folder dataset with one sub-directory of ``root_dir`` per class.

    Class indices follow the sorted order of the sub-directory names. Images
    within each class are also listed in sorted order, so sample indices (and
    therefore any KFold split over them) are reproducible across machines.

    Args:
        root_dir: directory containing one sub-directory per class.
        transform: optional callable applied to each loaded PIL image.
    """

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        # Only visible sub-directories are classes. A stray file such as
        # `.DS_Store` would otherwise become a bogus class and make the
        # os.listdir() call below raise NotADirectoryError.
        self.classes = sorted(
            name for name in os.listdir(root_dir)
            if not name.startswith('.') and os.path.isdir(os.path.join(root_dir, name))
        )
        self.image_paths = []
        self.labels = []
        for label, class_dir in enumerate(self.classes):
            class_dir_path = os.path.join(root_dir, class_dir)
            # os.listdir order is arbitrary; sort it so sample indices are stable.
            for image_name in sorted(os.listdir(class_dir_path)):
                image_path = os.path.join(class_dir_path, image_name)
                if image_name.startswith('.') or not os.path.isfile(image_path):
                    continue  # skip hidden files and nested directories
                self.image_paths.append(image_path)
                self.labels.append(label)

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, label)``; the image is loaded as single-channel greyscale ("L")."""
        # The context manager closes the file; convert() returns an independent copy.
        with Image.open(self.image_paths[idx]) as img:
            image = img.convert("L")
        label = self.labels[idx]
        if self.transform:
            image = self.transform(image)
        return image, label

# Training-time preprocessing: resize to the model's 224x224 input, apply light
# augmentation (random horizontal flip, random rotation within +/-10 degrees),
# then convert to a tensor and map the greyscale channel from [0, 1] to [-1, 1].
transform = Compose([
    Resize(size=(224, 224)),
    RandomHorizontalFlip(p=0.5),
    RandomRotation(degrees=10),
    ToTensor(),
    Normalize(mean=(0.5,), std=(0.5,)),
])

# The Training split, which the k-fold cross-validation below partitions.
root_dir = '/content/drive/MyDrive/tumour_data_archive/Training'
dataset = BrainTumorDataset(root_dir, transform)

# define Vision Transformer model
class PatchEmbedding(nn.Module):
    """Cut a square image into non-overlapping square patches and embed each one.

    A Conv2d whose kernel size and stride both equal ``patch_size`` applies the
    same linear projection to every patch.

    Args:
        img_size: side length of the square input image.
        patch_size: side length of each square patch.
        in_chans: number of input channels.
        embed_dim: length of each patch embedding.
    """

    def __init__(self, img_size, patch_size, in_chans, embed_dim):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.grid_size = img_size // patch_size  # patches along each side
        self.num_patches = self.grid_size * self.grid_size
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        """Map (B, in_chans, img_size, img_size) images to (B, num_patches, embed_dim) tokens."""
        patch_grid = self.proj(x)  # (B, embed_dim, grid_size, grid_size)
        return patch_grid.flatten(start_dim=2).transpose(1, 2)

class MultiHeadSelfAttention(nn.Module):
    """Multi-head scaled dot-product self-attention (no masking, no attention dropout).

    Args:
        embed_dim: token embedding size; must be divisible by ``num_heads``.
        num_heads: number of attention heads (positive).

    Raises:
        ValueError: if ``num_heads`` is not positive or does not divide ``embed_dim``.
    """

    def __init__(self, embed_dim, num_heads):
        super().__init__()
        # Explicit check rather than `assert`, which is stripped under `python -O`.
        if num_heads <= 0 or embed_dim % num_heads != 0:
            raise ValueError("embed_dim must be divisible by num_heads")
        self.num_heads = num_heads
        self.embed_dim = embed_dim
        self.head_dim = embed_dim // num_heads

        self.query = nn.Linear(embed_dim, embed_dim)
        self.key = nn.Linear(embed_dim, embed_dim)
        self.value = nn.Linear(embed_dim, embed_dim)
        self.proj = nn.Linear(embed_dim, embed_dim)
        # 1 / sqrt(d_k), with d_k the per-head dimension.
        self.scale = self.head_dim ** -0.5

    def _split_heads(self, t, B, N):
        # (B, N, embed_dim) -> (B, num_heads, N, head_dim)
        return t.reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)

    def forward(self, x):
        """Attend over the token axis of ``x`` (B, N, embed_dim); returns the same shape."""
        B, N, C = x.shape
        q = self._split_heads(self.query(x), B, N)
        k = self._split_heads(self.key(x), B, N)
        v = self._split_heads(self.value(x), B, N)

        attn = (q @ k.transpose(-2, -1)) * self.scale  # (B, heads, N, N)
        attn = attn.softmax(dim=-1)

        # Merge the heads back into a single (B, N, C) tensor, then project.
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        return self.proj(x)

class TransformerBlock(nn.Module):
    """Pre-norm Transformer encoder block.

    LayerNorm -> self-attention and LayerNorm -> MLP, each added back onto its
    input through a residual connection.

    Args:
        embed_dim: token embedding size.
        num_heads: number of attention heads.
        mlp_dim: hidden width of the two-layer MLP.
        dropout: dropout probability used inside the MLP.
    """

    def __init__(self, embed_dim, num_heads, mlp_dim, dropout):
        super().__init__()
        self.norm1 = nn.LayerNorm(embed_dim)
        self.attn = MultiHeadSelfAttention(embed_dim, num_heads)
        self.norm2 = nn.LayerNorm(embed_dim)
        mlp_layers = [
            nn.Linear(embed_dim, mlp_dim),  # expand
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(mlp_dim, embed_dim),  # project back
            nn.Dropout(dropout),
        ]
        self.mlp = nn.Sequential(*mlp_layers)

    def forward(self, x):
        attended = x + self.attn(self.norm1(x))
        return attended + self.mlp(self.norm2(attended))

class VisionTransformer(nn.Module):
    """Vision Transformer (ViT) image classifier.

    Images are split into patches by ``PatchEmbedding``. A learnable class token
    is prepended and learnable positional embeddings are added. The sequence then
    passes through ``depth`` pre-norm ``TransformerBlock``s, and the final
    LayerNorm-ed class token is mapped to ``num_classes`` logits.

    Args:
        img_size: side length of the square input image.
        patch_size: side length of each square patch.
        num_classes: number of output logits.
        embed_dim: token embedding size.
        depth: number of Transformer blocks.
        num_heads: attention heads per block.
        mlp_dim: hidden width of each block's MLP.
        dropout: dropout probability (after the positional embedding and in the MLPs).
        in_chans: number of input channels; defaults to 1 for the greyscale
            images produced by ``BrainTumorDataset``.
    """

    def __init__(self, img_size=224, patch_size=16, num_classes=4, embed_dim=128, depth=4,
                 num_heads=4, mlp_dim=256, dropout=0.1, in_chans=1):
        super().__init__()
        self.patch_embed = PatchEmbedding(img_size, patch_size, in_chans=in_chans, embed_dim=embed_dim)
        # Learnable class token and positional embeddings (one per patch plus one
        # for the class token), drawn from a standard normal distribution.
        # NOTE(review): reference ViT implementations typically use a much
        # smaller init (e.g. truncated normal, std 0.02); worth revisiting.
        self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.randn(1, self.patch_embed.num_patches + 1, embed_dim))
        self.dropout = nn.Dropout(dropout)

        self.blocks = nn.ModuleList([
            TransformerBlock(embed_dim, num_heads, mlp_dim, dropout)
            for _ in range(depth)
        ])

        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        """Map (B, in_chans, img_size, img_size) images to (B, num_classes) logits."""
        B = x.shape[0]
        x = self.patch_embed(x)                        # (B, num_patches, embed_dim)
        cls_tokens = self.cls_token.expand(B, -1, -1)  # (B, 1, embed_dim)
        x = torch.cat((cls_tokens, x), dim=1)          # (B, num_patches + 1, embed_dim)
        x = x + self.pos_embed
        x = self.dropout(x)

        for blk in self.blocks:
            x = blk(x)

        x = self.norm(x)
        return self.head(x[:, 0])  # classify from the class token only

# ---------------------------------------------------------------------------
# K-fold cross-validation on the Training split, followed by an evaluation of
# the last fold's model on the held-out Testing split.
# ---------------------------------------------------------------------------
import copy

# Deterministic preprocessing for validation/testing: the same resize and
# normalisation as the training `transform`, but without the random flip and
# rotation, so evaluation metrics are not perturbed by augmentation.
eval_transform = Compose([
    Resize((224, 224)),
    ToTensor(),
    Normalize((0.5,), (0.5,)),
])


class EarlyStopping:
    """Signal a stop once the monitored loss stops improving.

    A loss counts as an improvement only if it is more than ``min_delta`` below
    the best loss seen so far; otherwise a counter is incremented. Calling the
    instance returns True once ``patience`` consecutive non-improving epochs
    have been observed.
    """

    def __init__(self, patience=5, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best_loss = None
        self.counter = 0

    def __call__(self, val_loss):
        if self.best_loss is None or val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.counter = 0
        else:
            self.counter += 1
        return self.counter >= self.patience


def _train_one_epoch(model, loader, criterion, optimizer, desc):
    """Run one optimisation pass over ``loader``; return (mean loss, accuracy %)."""
    model.train()
    running_loss, correct, total = 0.0, 0, 0
    progress_bar = tqdm(loader, desc=desc)
    for images, labels in progress_bar:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        batch_size = labels.size(0)
        running_loss += loss.item() * batch_size
        correct += (outputs.argmax(dim=1) == labels).sum().item()
        total += batch_size
        progress_bar.set_postfix(loss=loss.item(), accuracy=100 * correct / total)
    # Normalise by the number of samples actually seen. With a sampler,
    # len(loader.dataset) is the size of the *whole* dataset, not this fold.
    return running_loss / total, 100 * correct / total


@torch.no_grad()
def _evaluate(model, loader, criterion=None, desc=None):
    """Evaluate ``model`` on ``loader``; return (mean loss or None, accuracy %)."""
    model.eval()
    loss_sum, correct, total = 0.0, 0, 0
    batches = tqdm(loader, desc=desc) if desc else loader
    for images, labels in batches:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        if criterion is not None:
            loss_sum += criterion(outputs, labels).item() * labels.size(0)
        correct += (outputs.argmax(dim=1) == labels).sum().item()
        total += labels.size(0)
    mean_loss = loss_sum / total if criterion is not None else None
    return mean_loss, 100 * correct / total


k_folds = 5
epochs = 50
batch_size = 32
kf = KFold(n_splits=k_folds, shuffle=True, random_state=42)

# Validation folds index into a shallow copy of `dataset`. The copy shares the
# same image/label lists, so fold indices stay aligned, but applies
# `eval_transform` instead of the augmenting training transform.
val_dataset = copy.copy(dataset)
val_dataset.transform = eval_transform

fold_val_accuracies = []
for fold, (train_idx, val_idx) in enumerate(kf.split(dataset)):
    print(f'FOLD {fold+1}')
    print('--------------------------------')

    # Sample elements randomly from the fold's index lists, without replacement.
    train_loader = DataLoader(dataset, batch_size=batch_size,
                              sampler=torch.utils.data.SubsetRandomSampler(train_idx))
    val_loader = DataLoader(val_dataset, batch_size=batch_size,
                            sampler=torch.utils.data.SubsetRandomSampler(val_idx))

    # Fresh model, loss function and optimizer for every fold.
    model = VisionTransformer(num_classes=4).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.0005)
    # Anneal over the whole epoch budget. With T_max=10 the LR reached 0 during
    # epoch 11 and then climbed back up.
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
    early_stopping = EarlyStopping(patience=5, min_delta=0.01)

    for epoch in range(epochs):
        epoch_loss, epoch_acc = _train_one_epoch(
            model, train_loader, criterion, optimizer, desc=f'Epoch [{epoch+1}/{epochs}]')
        scheduler.step()
        print(f'Epoch [{epoch+1}/{epochs}], Loss: {epoch_loss:.4f}, Accuracy: {epoch_acc:.2f}%')

        val_loss, val_acc = _evaluate(model, val_loader, criterion)
        print(f'Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_acc:.2f}%')

        if early_stopping(val_loss):
            print("Early stopping triggered")
            break

    fold_val_accuracies.append(val_acc)  # accuracy at the fold's final epoch
    print('--------------------------------')

print(f'Cross-validation accuracy: {np.mean(fold_val_accuracies):.2f}% '
      f'± {np.std(fold_val_accuracies):.2f}%')

# Final evaluation on the held-out Testing split.
# NOTE: `model` is the network trained in the *last* fold only.
test_dataset = BrainTumorDataset(root_dir='/content/drive/MyDrive/tumour_data_archive/Testing',
                                 transform=eval_transform)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
_, test_acc = _evaluate(model, test_loader, desc='Testing')
print(f'Test Accuracy: {test_acc:.2f}%')

# Save the model (make sure the checkpoint directory exists first).
os.makedirs('callback', exist_ok=True)
torch.save(model.state_dict(), 'callback/vision_transformer_model.pth')


Using device: cuda
FOLD 1
--------------------------------


Epoch [1/50]: 100%|██████████| 143/143 [00:39<00:00,  3.64it/s, accuracy=51.5, loss=0.834]


Epoch [1/50], Loss: 0.8790, Accuracy: 51.54%
Validation Loss: 0.1573, Validation Accuracy: 69.47%


Epoch [2/50]: 100%|██████████| 143/143 [00:31<00:00,  4.60it/s, accuracy=71.7, loss=0.974]


Epoch [2/50], Loss: 0.5826, Accuracy: 71.72%
Validation Loss: 0.1321, Validation Accuracy: 72.70%


Epoch [3/50]: 100%|██████████| 143/143 [00:33<00:00,  4.26it/s, accuracy=76.5, loss=0.697]


Epoch [3/50], Loss: 0.4850, Accuracy: 76.47%
Validation Loss: 0.1080, Validation Accuracy: 79.44%


Epoch [4/50]: 100%|██████████| 143/143 [00:31<00:00,  4.50it/s, accuracy=78.5, loss=0.365]


Epoch [4/50], Loss: 0.4376, Accuracy: 78.51%
Validation Loss: 0.1018, Validation Accuracy: 80.93%


Epoch [5/50]: 100%|██████████| 143/143 [00:32<00:00,  4.43it/s, accuracy=81.4, loss=0.45]


Epoch [5/50], Loss: 0.3925, Accuracy: 81.40%
Validation Loss: 0.0932, Validation Accuracy: 82.76%


Epoch [6/50]: 100%|██████████| 143/143 [00:31<00:00,  4.49it/s, accuracy=82.1, loss=0.627]


Epoch [6/50], Loss: 0.3666, Accuracy: 82.14%
Validation Loss: 0.0857, Validation Accuracy: 83.11%


Epoch [7/50]: 100%|██████████| 143/143 [00:31<00:00,  4.50it/s, accuracy=84.3, loss=0.579]


Epoch [7/50], Loss: 0.3292, Accuracy: 84.29%
Validation Loss: 0.0789, Validation Accuracy: 85.30%


Epoch [8/50]: 100%|██████████| 143/143 [00:31<00:00,  4.52it/s, accuracy=84.9, loss=0.351]


Epoch [8/50], Loss: 0.3197, Accuracy: 84.90%
Validation Loss: 0.0796, Validation Accuracy: 85.65%


Epoch [9/50]: 100%|██████████| 143/143 [00:31<00:00,  4.53it/s, accuracy=85.4, loss=0.81]


Epoch [9/50], Loss: 0.3067, Accuracy: 85.45%
Validation Loss: 0.0766, Validation Accuracy: 85.48%


Epoch [10/50]: 100%|██████████| 143/143 [00:31<00:00,  4.52it/s, accuracy=85.8, loss=0.348]


Epoch [10/50], Loss: 0.2926, Accuracy: 85.75%
Validation Loss: 0.0770, Validation Accuracy: 85.39%


Epoch [11/50]: 100%|██████████| 143/143 [00:31<00:00,  4.57it/s, accuracy=85.5, loss=0.25]


Epoch [11/50], Loss: 0.2914, Accuracy: 85.51%
Validation Loss: 0.0774, Validation Accuracy: 84.78%


Epoch [12/50]: 100%|██████████| 143/143 [00:31<00:00,  4.53it/s, accuracy=86.6, loss=0.371]


Epoch [12/50], Loss: 0.2868, Accuracy: 86.56%
Validation Loss: 0.0794, Validation Accuracy: 85.83%
Early stopping triggered
--------------------------------
FOLD 2
--------------------------------


Epoch [1/50]: 100%|██████████| 143/143 [00:30<00:00,  4.62it/s, accuracy=56.7, loss=0.911]


Epoch [1/50], Loss: 0.8121, Accuracy: 56.73%
Validation Loss: 0.1668, Validation Accuracy: 67.54%


Epoch [2/50]: 100%|██████████| 143/143 [00:31<00:00,  4.48it/s, accuracy=73.8, loss=0.436]


Epoch [2/50], Loss: 0.5516, Accuracy: 73.80%
Validation Loss: 0.1408, Validation Accuracy: 71.65%


Epoch [3/50]: 100%|██████████| 143/143 [00:31<00:00,  4.56it/s, accuracy=76.8, loss=0.598]


Epoch [3/50], Loss: 0.4776, Accuracy: 76.80%
Validation Loss: 0.1258, Validation Accuracy: 74.02%


Epoch [4/50]: 100%|██████████| 143/143 [00:31<00:00,  4.56it/s, accuracy=78.8, loss=0.532]


Epoch [4/50], Loss: 0.4351, Accuracy: 78.84%
Validation Loss: 0.1202, Validation Accuracy: 76.64%


Epoch [5/50]: 100%|██████████| 143/143 [00:31<00:00,  4.60it/s, accuracy=80.5, loss=0.349]


Epoch [5/50], Loss: 0.4013, Accuracy: 80.54%
Validation Loss: 0.1133, Validation Accuracy: 77.78%


Epoch [6/50]: 100%|██████████| 143/143 [00:31<00:00,  4.53it/s, accuracy=82.5, loss=0.552]


Epoch [6/50], Loss: 0.3626, Accuracy: 82.47%
Validation Loss: 0.0987, Validation Accuracy: 81.28%


Epoch [7/50]: 100%|██████████| 143/143 [00:32<00:00,  4.36it/s, accuracy=83.4, loss=0.425]


Epoch [7/50], Loss: 0.3431, Accuracy: 83.39%
Validation Loss: 0.0921, Validation Accuracy: 82.68%


Epoch [8/50]: 100%|██████████| 143/143 [00:30<00:00,  4.65it/s, accuracy=84, loss=0.443]


Epoch [8/50], Loss: 0.3333, Accuracy: 83.96%
Validation Loss: 0.0851, Validation Accuracy: 83.03%


Epoch [9/50]: 100%|██████████| 143/143 [00:30<00:00,  4.62it/s, accuracy=85.3, loss=0.261]


Epoch [9/50], Loss: 0.3015, Accuracy: 85.27%
Validation Loss: 0.0913, Validation Accuracy: 83.11%


Epoch [10/50]: 100%|██████████| 143/143 [00:31<00:00,  4.50it/s, accuracy=85.3, loss=0.355]


Epoch [10/50], Loss: 0.3049, Accuracy: 85.29%
Validation Loss: 0.0869, Validation Accuracy: 83.38%


Epoch [11/50]: 100%|██████████| 143/143 [00:31<00:00,  4.57it/s, accuracy=85.7, loss=0.343]


Epoch [11/50], Loss: 0.2963, Accuracy: 85.71%
Validation Loss: 0.0862, Validation Accuracy: 83.29%


Epoch [12/50]: 100%|██████████| 143/143 [00:32<00:00,  4.45it/s, accuracy=85.5, loss=0.151]


Epoch [12/50], Loss: 0.2967, Accuracy: 85.47%
Validation Loss: 0.0846, Validation Accuracy: 83.11%


Epoch [13/50]: 100%|██████████| 143/143 [00:31<00:00,  4.57it/s, accuracy=85.7, loss=0.463]


Epoch [13/50], Loss: 0.2982, Accuracy: 85.66%
Validation Loss: 0.0842, Validation Accuracy: 83.29%
Early stopping triggered
--------------------------------
FOLD 3
--------------------------------


Epoch [1/50]: 100%|██████████| 143/143 [00:32<00:00,  4.45it/s, accuracy=54.6, loss=0.645]


Epoch [1/50], Loss: 0.8433, Accuracy: 54.60%
Validation Loss: 0.1584, Validation Accuracy: 68.56%


Epoch [2/50]: 100%|██████████| 143/143 [00:31<00:00,  4.60it/s, accuracy=71.1, loss=0.699]


Epoch [2/50], Loss: 0.5808, Accuracy: 71.05%
Validation Loss: 0.1287, Validation Accuracy: 75.57%


Epoch [3/50]: 100%|██████████| 143/143 [00:31<00:00,  4.61it/s, accuracy=74.9, loss=0.891]


Epoch [3/50], Loss: 0.5040, Accuracy: 74.92%
Validation Loss: 0.1288, Validation Accuracy: 76.97%


Epoch [4/50]: 100%|██████████| 143/143 [00:32<00:00,  4.38it/s, accuracy=77.4, loss=0.653]


Epoch [4/50], Loss: 0.4550, Accuracy: 77.37%
Validation Loss: 0.1047, Validation Accuracy: 81.96%


Epoch [5/50]: 100%|██████████| 143/143 [00:31<00:00,  4.57it/s, accuracy=80, loss=0.864]


Epoch [5/50], Loss: 0.4113, Accuracy: 80.02%
Validation Loss: 0.1042, Validation Accuracy: 81.52%


Epoch [6/50]: 100%|██████████| 143/143 [00:31<00:00,  4.61it/s, accuracy=81.9, loss=0.322]


Epoch [6/50], Loss: 0.3825, Accuracy: 81.95%
Validation Loss: 0.1065, Validation Accuracy: 79.68%


Epoch [7/50]: 100%|██████████| 143/143 [00:31<00:00,  4.55it/s, accuracy=83.1, loss=0.383]


Epoch [7/50], Loss: 0.3483, Accuracy: 83.09%
Validation Loss: 0.0872, Validation Accuracy: 83.71%


Epoch [8/50]: 100%|██████████| 143/143 [00:32<00:00,  4.34it/s, accuracy=83.8, loss=0.287]


Epoch [8/50], Loss: 0.3349, Accuracy: 83.76%
Validation Loss: 0.0806, Validation Accuracy: 85.99%


Epoch [9/50]: 100%|██████████| 143/143 [00:32<00:00,  4.37it/s, accuracy=84.4, loss=0.331]


Epoch [9/50], Loss: 0.3216, Accuracy: 84.42%
Validation Loss: 0.0846, Validation Accuracy: 84.24%


Epoch [10/50]: 100%|██████████| 143/143 [00:31<00:00,  4.59it/s, accuracy=85.1, loss=0.34]


Epoch [10/50], Loss: 0.3052, Accuracy: 85.05%
Validation Loss: 0.0809, Validation Accuracy: 85.55%


Epoch [11/50]: 100%|██████████| 143/143 [00:32<00:00,  4.45it/s, accuracy=85.6, loss=0.306]


Epoch [11/50], Loss: 0.3048, Accuracy: 85.60%
Validation Loss: 0.0811, Validation Accuracy: 84.76%


Epoch [12/50]: 100%|██████████| 143/143 [00:31<00:00,  4.48it/s, accuracy=85.5, loss=0.492]


Epoch [12/50], Loss: 0.2994, Accuracy: 85.54%
Validation Loss: 0.0806, Validation Accuracy: 85.99%
Early stopping triggered
--------------------------------
FOLD 4
--------------------------------


Epoch [1/50]: 100%|██████████| 143/143 [00:32<00:00,  4.39it/s, accuracy=56.3, loss=0.904]


Epoch [1/50], Loss: 0.8240, Accuracy: 56.26%
Validation Loss: 0.1531, Validation Accuracy: 70.67%


Epoch [2/50]: 100%|██████████| 143/143 [00:33<00:00,  4.27it/s, accuracy=71.1, loss=0.86]


Epoch [2/50], Loss: 0.5709, Accuracy: 71.07%
Validation Loss: 0.1467, Validation Accuracy: 72.24%


Epoch [3/50]: 100%|██████████| 143/143 [00:32<00:00,  4.34it/s, accuracy=75.8, loss=0.471]


Epoch [3/50], Loss: 0.4930, Accuracy: 75.75%
Validation Loss: 0.1161, Validation Accuracy: 78.11%


Epoch [4/50]: 100%|██████████| 143/143 [00:32<00:00,  4.45it/s, accuracy=78.9, loss=0.658]


Epoch [4/50], Loss: 0.4376, Accuracy: 78.86%
Validation Loss: 0.1092, Validation Accuracy: 78.11%


Epoch [5/50]: 100%|██████████| 143/143 [00:32<00:00,  4.46it/s, accuracy=80.5, loss=0.455]


Epoch [5/50], Loss: 0.3980, Accuracy: 80.48%
Validation Loss: 0.1008, Validation Accuracy: 79.86%


Epoch [6/50]: 100%|██████████| 143/143 [00:32<00:00,  4.43it/s, accuracy=82.6, loss=0.361]


Epoch [6/50], Loss: 0.3585, Accuracy: 82.58%
Validation Loss: 0.0923, Validation Accuracy: 81.96%


Epoch [7/50]: 100%|██████████| 143/143 [00:33<00:00,  4.27it/s, accuracy=83.5, loss=0.464]


Epoch [7/50], Loss: 0.3393, Accuracy: 83.46%
Validation Loss: 0.0876, Validation Accuracy: 84.06%


Epoch [8/50]: 100%|██████████| 143/143 [00:32<00:00,  4.37it/s, accuracy=84.5, loss=0.504]


Epoch [8/50], Loss: 0.3231, Accuracy: 84.51%
Validation Loss: 0.0885, Validation Accuracy: 84.24%


Epoch [9/50]: 100%|██████████| 143/143 [00:33<00:00,  4.33it/s, accuracy=85.8, loss=0.535]


Epoch [9/50], Loss: 0.3000, Accuracy: 85.80%
Validation Loss: 0.0886, Validation Accuracy: 83.01%


Epoch [10/50]: 100%|██████████| 143/143 [00:31<00:00,  4.52it/s, accuracy=85.5, loss=0.701]


Epoch [10/50], Loss: 0.2987, Accuracy: 85.54%
Validation Loss: 0.0800, Validation Accuracy: 83.80%


Epoch [11/50]: 100%|██████████| 143/143 [00:32<00:00,  4.42it/s, accuracy=85.6, loss=0.238]


Epoch [11/50], Loss: 0.3027, Accuracy: 85.56%
Validation Loss: 0.0813, Validation Accuracy: 84.41%


Epoch [12/50]: 100%|██████████| 143/143 [00:34<00:00,  4.16it/s, accuracy=85, loss=0.358]


Epoch [12/50], Loss: 0.3014, Accuracy: 85.01%
Validation Loss: 0.0823, Validation Accuracy: 83.71%
Early stopping triggered
--------------------------------
FOLD 5
--------------------------------


Epoch [1/50]: 100%|██████████| 143/143 [00:32<00:00,  4.41it/s, accuracy=57.9, loss=0.802]


Epoch [1/50], Loss: 0.8129, Accuracy: 57.88%
Validation Loss: 0.1424, Validation Accuracy: 71.89%


Epoch [2/50]: 100%|██████████| 143/143 [00:31<00:00,  4.55it/s, accuracy=72.7, loss=0.697]


Epoch [2/50], Loss: 0.5593, Accuracy: 72.74%
Validation Loss: 0.1255, Validation Accuracy: 74.69%


Epoch [3/50]: 100%|██████████| 143/143 [00:31<00:00,  4.58it/s, accuracy=76.7, loss=0.444]


Epoch [3/50], Loss: 0.4763, Accuracy: 76.67%
Validation Loss: 0.1021, Validation Accuracy: 78.81%


Epoch [4/50]: 100%|██████████| 143/143 [00:32<00:00,  4.38it/s, accuracy=79.2, loss=0.659]


Epoch [4/50], Loss: 0.4381, Accuracy: 79.21%
Validation Loss: 0.0981, Validation Accuracy: 80.47%


Epoch [5/50]: 100%|██████████| 143/143 [00:31<00:00,  4.50it/s, accuracy=81.3, loss=0.481]


Epoch [5/50], Loss: 0.3905, Accuracy: 81.29%
Validation Loss: 0.0891, Validation Accuracy: 82.66%


Epoch [6/50]: 100%|██████████| 143/143 [00:30<00:00,  4.62it/s, accuracy=82.1, loss=0.635]


Epoch [6/50], Loss: 0.3711, Accuracy: 82.06%
Validation Loss: 0.0910, Validation Accuracy: 81.35%


Epoch [7/50]: 100%|██████████| 143/143 [00:31<00:00,  4.50it/s, accuracy=83.2, loss=0.216]


Epoch [7/50], Loss: 0.3477, Accuracy: 83.22%
Validation Loss: 0.0778, Validation Accuracy: 84.76%


Epoch [8/50]: 100%|██████████| 143/143 [00:31<00:00,  4.57it/s, accuracy=84.5, loss=0.333]


Epoch [8/50], Loss: 0.3231, Accuracy: 84.51%
Validation Loss: 0.0757, Validation Accuracy: 85.38%


Epoch [9/50]: 100%|██████████| 143/143 [00:32<00:00,  4.34it/s, accuracy=85.3, loss=0.38]


Epoch [9/50], Loss: 0.3116, Accuracy: 85.32%
Validation Loss: 0.0764, Validation Accuracy: 85.73%


Epoch [10/50]: 100%|██████████| 143/143 [00:32<00:00,  4.42it/s, accuracy=84.8, loss=0.274]


Epoch [10/50], Loss: 0.3149, Accuracy: 84.79%
Validation Loss: 0.0740, Validation Accuracy: 86.43%


Epoch [11/50]: 100%|██████████| 143/143 [00:31<00:00,  4.55it/s, accuracy=86.5, loss=0.383]


Epoch [11/50], Loss: 0.2974, Accuracy: 86.48%
Validation Loss: 0.0735, Validation Accuracy: 85.46%


Epoch [12/50]: 100%|██████████| 143/143 [00:32<00:00,  4.44it/s, accuracy=85.7, loss=0.279]


Epoch [12/50], Loss: 0.3020, Accuracy: 85.71%
Validation Loss: 0.0712, Validation Accuracy: 85.55%
Early stopping triggered
--------------------------------


Testing: 100%|██████████| 41/41 [00:07<00:00,  5.44it/s]

Test Accuracy: 84.06%





# Vision Transformer for Brain Tumor classification (experiments)

## Introduction

Brain tumors are abnormal growths of cells in the brain, which can be either benign (non-cancerous) or malignant (cancerous). Early detection and classification of brain tumors are crucial for effective treatment and improving patient outcomes.

This notebook implements a Vision Transformer (ViT) model for classifying brain tumors from MRI images. The model is trained and evaluated on a dataset of brain MRI images, and its performance is analyzed using K-fold cross-validation.

## Dataset

The dataset consists of MRI images of brain tumors classified into four categories:
- Pituitary
- No tumor
- Meningioma
- Glioma


### Training
- pituitary/
- notumor/
- meningioma/
- glioma/

### Testing
- pituitary/
- notumor/
- meningioma/
- glioma/


Each directory contains images of brain MRIs corresponding to the respective class.

## Data preprocessing

### Transformations

The following transformations are applied to the images:
- Resize the images to 224 x 224 pixels.
- Random horizontal flip.
- Random rotation of 10 degrees.
- Convert the images to tensors.
- Normalize the images with a mean and standard deviation of 0.5.

### Custom Dataset class

A custom dataset class is used to load the images and their corresponding labels from the directory structure. This class handles the loading of image paths, converting images to grayscale, and applying the transformations.

## Vision Transformer model

### Patch Embedding

The Patch Embedding layer splits the input image into patches and projects each patch into a high-dimensional space using a convolutional layer. An image of size $H \times W \times C$ (here $C = 1$, since images are loaded as greyscale) is divided into non-overlapping patches of size $P \times P$. The convolution's kernel size and stride both equal $P$, so it applies the same linear projection to every patch and turns each one into a vector of size $D$.

The number of patches $N$ is given by:

$$
N = \frac{H}{P} \times \frac{W}{P}
$$

For the $224 \times 224$ inputs and $16 \times 16$ patches used here, $N = 14 \times 14 = 196$.

The patches are then flattened and transposed to form a sequence of patches.

### Multi-Head Self-Attention

The Multi-Head Self-Attention layer allows the model to focus on different parts of the input sequence (patches) simultaneously. It computes the attention weights and applies them to the value vectors. The attention mechanism is defined as:

$$
\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right) V
$$

where $Q$, $K$, and $V$ are the query, key, and value matrices, respectively, and $d_k$ is the dimension of the key vectors (the per-head dimension, `embed_dim / num_heads`). The attention scores are the dot products of the query and key matrices divided by $\sqrt{d_k}$. A softmax over these scores gives the attention weights.

### Transformer Block

The Transformer Block consists of a multi-head self-attention layer followed by a feed-forward neural network. Each of these sub-layers is preceded by layer normalization (a pre-norm design) and wrapped in a residual connection. The feed-forward network is defined as:

$$
\text{FFN}(x) = \text{GELU}(xW_1 + b_1)W_2 + b_2
$$

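where $W_1$, $W_2$, $b_1$, and $b_2$ are learnable parameters, and GELU is the Gaussian Error Linear Unit activation function. Dropout is applied after the GELU and after the second linear layer. With separate layer normalizations $\text{LN}_1$ and $\text{LN}_2$ applied before each sub-layer, and MSA denoting multi-head self-attention, the residual connections (which help stabilize training) are:

$$
x' = x + \text{MSA}(\text{LN}_1(x))
$$


$$
x'' = x' + \text{FFN}(\text{LN}_2(x'))
$$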

### Vision Transformer

The Vision Transformer model consists of a patch embedding layer, multiple transformer blocks, and a final classification head. The input image is first divided into patches and embedded into a higher-dimensional space. A class token is added to the sequence of patches, and positional embeddings are added to retain spatial information. The sequence is then passed through multiple transformer blocks, and the final class token is used for classification.

## Training and Evaluation

### Cross-Validation Setup

K-fold cross-validation is used to evaluate the model. The dataset is split into $k$ folds, and the model is trained and validated on different subsets of the data. In this setup, 5-fold cross-validation is used, where the dataset is divided into 5 subsets. In each iteration, one subset is used for validation, and the remaining four are used for training.

### Training Loop

The training loop includes forward and backward passes, loss computation, and optimization steps. The learning rate is adjusted using a cosine annealing scheduler, and early stopping is implemented to prevent overfitting. The loss function used is cross-entropy loss, which is suitable for multi-class classification tasks.

The optimization process involves minimizing the cross-entropy loss:

$$
\text{CrossEntropyLoss} = -\sum_{c=1}^{C} y_c \log(\hat{y}_c)
$$

where $C$ is the number of classes, $y_c$ is the true label for class $c$ (1 for the correct class and 0 otherwise), and $\hat{y}_c$ is the predicted probability for class $c$.

### Early Stopping

Early stopping is used to prevent overfitting by monitoring the validation loss. If the validation loss fails to improve by more than `min_delta` (0.01) for `patience` (5) consecutive epochs, training for that fold is stopped.

### Results

The training process produces the following results:

- **FOLD 1**
    - Epoch 1: Loss: 0.8790, Accuracy: 51.54%, Validation Loss: 0.1573, Validation Accuracy: 69.47%
    - Epoch 2: Loss: 0.5826, Accuracy: 71.72%, Validation Loss: 0.1321, Validation Accuracy: 72.70%
    - ...
    - Epoch 12: Loss: 0.2868, Accuracy: 86.56%, Validation Loss: 0.0794, Validation Accuracy: 85.83%
    - Early stopping triggered

- **FOLD 2**
    - Epoch 1: Loss: 0.8121, Accuracy: 56.73%, Validation Loss: 0.1668, Validation Accuracy: 67.54%
    - Epoch 2: Loss: 0.5516, Accuracy: 73.80%, Validation Loss: 0.1408, Validation Accuracy: 71.65%
    - ...
    - Epoch 13: Loss: 0.2982, Accuracy: 85.66%, Validation Loss: 0.0842, Validation Accuracy: 83.29%
    - Early stopping triggered

- **FOLD 3**
    - Epoch 1: Loss: 0.8433, Accuracy: 54.60%, Validation Loss: 0.1584, Validation Accuracy: 68.56%
    - Epoch 2: Loss: 0.5808, Accuracy: 71.05%, Validation Loss: 0.1287, Validation Accuracy: 75.57%
    - ...
    - Epoch 12: Loss: 0.2994, Accuracy: 85.54%, Validation Loss: 0.0806, Validation Accuracy: 85.99%
    - Early stopping triggered

- **FOLD 4**
    - Epoch 1: Loss: 0.8240, Accuracy: 56.26%, Validation Loss: 0.1531, Validation Accuracy: 70.67%
    - Epoch 2: Loss: 0.5709, Accuracy: 71.07%, Validation Loss: 0.1467, Validation Accuracy: 72.24%
    - ...
    - Epoch 12: Loss: 0.3014, Accuracy: 85.01%, Validation Loss: 0.0823, Validation Accuracy: 83.71%
    - Early stopping triggered

- **FOLD 5**
    - Epoch 1: Loss: 0.8129, Accuracy: 57.88%, Validation Loss: 0.1424, Validation Accuracy: 71.89%
    - Epoch 2: Loss: 0.5593, Accuracy: 72.74%, Validation Loss: 0.1255, Validation Accuracy: 74.69%
    - ...
    - Epoch 12: Loss: 0.3020, Accuracy: 85.71%, Validation Loss: 0.0712, Validation Accuracy: 85.55%
    - Early stopping triggered

- **Testing**: Test Accuracy: 84.06%

## Conclusion

The Vision Transformer model for brain tumor classification using MRI images demonstrates promising performance, with a test accuracy of 84.06%. Note that this figure comes from the model trained in the final cross-validation fold; the models from the other folds are not used for testing. The use of data augmentation, cross-validation, and early stopping contributes to the model's effectiveness in classifying brain tumors. Further improvements can be made by fine-tuning hyperparameters, increasing the dataset size, and exploring more advanced architectures.
