In [1]:
import timm
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms
from sklearn.metrics import accuracy_score, f1_score, recall_score, precision_score
from dataset import BrainTumorDataset

In [2]:
# Data Transformations
data_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'val': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}

In [3]:
train_dataset = BrainTumorDataset(root_dir='/Users/nasifsafwan/Downloads/ML/BrainTumorResearch/tumordata/Training/',
                                 transform=data_transforms['train'])
val_dataset = BrainTumorDataset(root_dir='/Users/nasifsafwan/Downloads/ML/BrainTumorResearch/tumordata/Testing/'
                                 ,transform=data_transforms['val'])



In [4]:
train_loader = DataLoader(train_dataset,batch_size=32,shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=4)

In [5]:
model = timm.create_model('vgg16.tv_in1k', pretrained=True, num_classes=4)

model.safetensors:   0%|          | 0.00/553M [00:00<?, ?B/s]

In [6]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

In [9]:
from tqdm import tqdm

num_epochs = 10
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    train_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}", ncols=100)  # Add tqdm progress bar
    
    for inputs, labels in train_bar:
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item() * inputs.size(0)
        train_bar.set_postfix(loss=running_loss / ((train_bar.n + 1) * train_loader.batch_size))

    epoch_loss = running_loss / len(train_loader.dataset)
    print(f'Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}')
    
    # Validation step
    model.eval()
    running_corrects = 0
    
    val_bar = tqdm(val_loader, desc="Validation", ncols=100)
    with torch.no_grad():
        for inputs, labels in val_bar:
            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)
            running_corrects += torch.sum(preds == labels.data)
    
    epoch_acc = running_corrects.double() / len(val_loader.dataset)
    print(f'Epoch {epoch+1}/{num_epochs}, Accuracy: {epoch_acc:.4f}')

Epoch 1/10: 100%|████████████████████████████████████████| 90/90 [24:10<00:00, 16.12s/it, loss=1.38]


Epoch 1/10, Loss: 1.3860


Validation: 100%|███████████████████████████████████████████████████| 13/13 [01:30<00:00,  6.96s/it]


Epoch 1/10, Accuracy: 0.2919


Epoch 2/10: 100%|████████████████████████████████████████| 90/90 [22:42<00:00, 15.14s/it, loss=1.35]


Epoch 2/10, Loss: 1.3516


Validation: 100%|███████████████████████████████████████████████████| 13/13 [00:56<00:00,  4.32s/it]


Epoch 2/10, Accuracy: 0.2919


Epoch 3/10:  21%|████████▍                               | 19/90 [03:39<13:41, 11.57s/it, loss=1.36]


KeyboardInterrupt: 

In [None]:
from scipy.spatial.distance import directed_hausdorff

# Dice Coefficient
def dice_coefficient(pred, target):
    smooth = 1e-6
    pred = pred.view(-1)
    target = target.view(-1)
    intersection = (pred * target).sum()
    return (2. * intersection + smooth) / (pred.sum() + target.sum() + smooth)

# Hausdorff Distance
def hausdorff_distance(pred, target):
    pred_points = np.argwhere(pred.cpu().numpy() == 1)
    target_points = np.argwhere(target.cpu().numpy() == 1)
    return max(directed_hausdorff(pred_points, target_points)[0], directed_hausdorff(target_points, pred_points)[0])

# Mean Absolute Error
def mean_absolute_error(pred, target):
    return torch.mean(torch.abs(pred.float() - target.float()))

# Mean Squared Error
def mean_squared_error(pred, target):
    return torch.mean((pred.float() - target.float()) ** 2)

In [None]:
import numpy as np
# Validation step
model.eval()
running_corrects = 0
dice_scores = []
hausdorff_distances = []
mae_scores = []
mse_scores = []
    
val_bar = tqdm(val_loader, desc="Validation", ncols=100)  # Add tqdm progress bar for validation
with torch.no_grad():
    for inputs, labels in val_bar:
        outputs = model(inputs)
        _, preds = torch.max(outputs, 1)
        running_corrects += torch.sum(preds == labels.data)
            
        # Calculate metrics
        dice_scores.append(dice_coefficient(preds, labels).item())
        hausdorff_distances.append(hausdorff_distance(preds, labels))
        mae_scores.append(mean_absolute_error(preds, labels).item())
        mse_scores.append(mean_squared_error(preds, labels).item())
    
epoch_acc = running_corrects.double() / len(val_loader.dataset)
epoch_dice = np.mean(dice_scores)
epoch_hausdorff = np.mean(hausdorff_distances)
epoch_mae = np.mean(mae_scores)
epoch_mse = np.mean(mse_scores)
    
print(f'Epoch {epoch+1}/{num_epochs}, Accuracy: {epoch_acc:.4f}, Dice: {epoch_dice:.4f}, Hausdorff: {epoch_hausdorff:.4f}, MAE: {epoch_mae:.4f}, MSE: {epoch_mse:.4f}')

In [None]:
torch.save(model.state_dict(), 'vgg16_model.pth')