In [1]:
import os
import numpy as np
import pandas as pd
from PIL import Image
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
from sklearn.model_selection import train_test_split
from dataset import BrainTumorDataset

In [2]:
# class BrainTumorDataset(Dataset):
#     def __init__(self, root_dir, transform=None, img_size=224):
#         self.root_dir = root_dir
#         self.transform = transform
#         self.img_size = img_size
#         self.image_paths, self.labels = self.process_images(root_dir)
#         self.classes = ['glioma_tumor', 'meningioma_tumor', 'no_tumor', 'pituitary_tumor']

#     def __len__(self):
#         return len(self.image_paths)
    
#     def __getitem__(self, idx):
#         img_path = self.image_paths[idx]
#         image = Image.open(img_path).convert('RGB')
#         label = self.classes.index(self.labels[idx])
        
#         if self.transform:
#             image = self.transform(image)
#         else:
#             image = image.resize((self.img_size, self.img_size))
#             image = transforms.ToTensor()(image)
        
#         return image, label

#     def process_images(self, path):
#         images = []
#         labels = []
#         for category in os.listdir(path):
#             category_path = os.path.join(path, category)
#             if os.path.isdir(category_path):  # Ensure it's a directory
#                 for img_name in os.listdir(category_path):
#                     img_path = os.path.join(category_path, img_name)
#                     if self.is_image_file(img_path):
#                         images.append(img_path)
#                         labels.append(category)
#                     else:
#                         print(f"Warning: Skipping file {img_path} as it is not a valid image.")
#         return images, labels

#     def is_image_file(self, filename):
#         valid_image_extensions = [".jpg", ".jpeg", ".png", ".bmp", ".tiff"]
#         if not any(filename.lower().endswith(ext) for ext in valid_image_extensions):
#             return False
#         try:
#             Image.open(filename).verify()
#             return True
#         except (IOError, SyntaxError):
#             return False

In [2]:
# Preprocessing pipelines, keyed by split name.
# Both splits finish with ToTensor followed by the same per-channel
# normalization; only the training split applies random augmentation
# (random resized crop to 224 and random horizontal flip).
norm_mean = [0.485, 0.456, 0.406]
norm_std = [0.229, 0.224, 0.225]

train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])

# Deterministic evaluation pipeline: resize to 256, then centre-crop to 224.
val_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])

data_transforms = {
    'train': train_transform,
    'val': val_transform,
}

In [3]:
# Root folders of the two splits; the "Testing" folder serves as the
# validation set for this notebook.
train_root = '/Users/nasifsafwan/Downloads/ML/BrainTumorResearch/tumordata/Training/'
val_root = '/Users/nasifsafwan/Downloads/ML/BrainTumorResearch/tumordata/Testing/'

# Each dataset gets the transform pipeline that matches its split.
train_dataset = BrainTumorDataset(root_dir=train_root, transform=data_transforms['train'])
val_dataset = BrainTumorDataset(root_dir=val_root, transform=data_transforms['val'])



In [4]:
# Settings shared by both loaders; only the training loader shuffles.
loader_kwargs = dict(batch_size=32, num_workers=4)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

In [9]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models
import timm
import torch.nn as nn
from torch.optim.lr_scheduler import StepLR

# Build a ResNet-18 through timm, starting from pretrained weights.
model = timm.create_model('resnet18', pretrained=True)

# Replace the classifier head with a fresh linear layer sized for this task.
num_classes = 4  # number of target classes in the dataset
head_in_features = model.fc.in_features
model.fc = nn.Linear(head_in_features, num_classes)

# Cross-entropy loss; Adam updates every parameter of the network.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)

# Step schedule: the learning rate is multiplied by 0.1 every 7 scheduler steps.
scheduler = StepLR(optimizer, step_size=7, gamma=0.1)

### Fine Tuning

In [6]:
from scipy.spatial.distance import directed_hausdorff

# Dice Coefficient
def dice_coefficient(pred, target):
    """Smoothed Dice overlap between two tensors of the same number of elements.

    Computes ``(2 * sum(pred * target) + s) / (sum(pred) + sum(target) + s)``
    with ``s = 1e-6`` over the flattened inputs.

    The result lies in [0, 1] only when both inputs are binary (0/1) masks or
    probabilities in [0, 1]; feeding arbitrary integer values (for example
    class indices) can produce values above 1.

    Args:
        pred: predicted mask tensor (any shape, any memory layout).
        target: reference mask tensor with the same number of elements.

    Returns:
        A 0-dim tensor holding the Dice score. Two all-zero inputs give 1.0
        because of the smoothing term.
    """
    smooth = 1e-6
    # reshape (unlike view) also accepts non-contiguous tensors such as
    # transposed or sliced masks.
    pred = pred.reshape(-1)
    target = target.reshape(-1)
    # Do the arithmetic in floating point for bool/integer masks, while
    # leaving inputs that are already floating point at their own precision.
    if not pred.is_floating_point():
        pred = pred.float()
    if not target.is_floating_point():
        target = target.float()
    intersection = (pred * target).sum()
    return (2. * intersection + smooth) / (pred.sum() + target.sum() + smooth)

# Hausdorff Distance
def hausdorff_distance(pred, target):
    """Symmetric Hausdorff distance between the elements equal to 1 in two tensors.

    The coordinates of every element equal to 1 are collected from each tensor
    and the larger of the two directed Hausdorff distances is returned.

    Args:
        pred: predicted binary mask tensor.
        target: reference binary mask tensor with the same number of dimensions.

    Returns:
        The distance as a Python float, measured in index units.
        ``0.0`` when neither tensor contains a 1, and ``inf`` when exactly one
        of them contains no 1 (the distance is undefined in that case).
    """
    # detach() lets tensors that require grad be converted to NumPy.
    pred_points = np.argwhere(pred.detach().cpu().numpy() == 1)
    target_points = np.argwhere(target.detach().cpu().numpy() == 1)

    # Handle empty point sets explicitly instead of relying on how SciPy
    # happens to treat zero-length inputs.
    if len(pred_points) == 0 or len(target_points) == 0:
        if len(pred_points) == len(target_points):
            return 0.0
        return float('inf')

    forward = directed_hausdorff(pred_points, target_points)[0]
    backward = directed_hausdorff(target_points, pred_points)[0]
    return float(max(forward, backward))

# Mean Absolute Error
def mean_absolute_error(pred, target):
    """Return the mean of |pred - target| as a 0-dim float tensor.

    Both inputs are cast to float32 before subtracting, so integer or boolean
    tensors are accepted.
    """
    difference = pred.float() - target.float()
    return difference.abs().mean()

# Mean Squared Error
def mean_squared_error(pred, target):
    """Return the mean of (pred - target)^2 as a 0-dim float tensor.

    Both inputs are cast to float32 before subtracting, so integer or boolean
    tensors are accepted.
    """
    residual = pred.float() - target.float()
    return torch.mean(residual * residual)

In [8]:
# #from dataset import BrainTumorDataset
# class FirstSSLModel(nn.Module):
#     def __init__(self, base_model, num_classes):
#         super(FirstSSLModel, self).__init__()
#         self.encoder = nn.Sequential(*list(base_model.children())[:-1])
#         self.fc = nn.Linear(base_model.fc.in_features, num_classes)
        
#     def forward(self, x):
#         x = self.encoder(x)
#         x = torch.flatten(x, 1)
#         x = self.fc(x)
#         return x
# num_classes = 4
# classification_model = FirstSSLModel(base_model, num_classes)

# criterion = nn.CrossEntropyLoss()
# optimizer = optim.Adam(classification_model.parameters(), lr = 0.001)

In [11]:
from tqdm import tqdm

num_epochs = 10
for epoch in range(num_epochs):
    # ---- Training pass ----
    model.train()
    running_loss = 0.0   # sum of per-sample losses seen so far this epoch
    samples_seen = 0     # number of samples those losses cover
    train_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}", ncols=100)

    for inputs, labels in train_bar:
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # The criterion returns the batch mean, so scale by the batch size to
        # accumulate a per-sample total.
        n_batch = inputs.size(0)
        running_loss += loss.item() * n_batch
        samples_seen += n_batch
        # Divide by the samples actually processed. Using the nominal batch
        # size over-counts whenever a batch is short (typically the last one),
        # which made the bar disagree with the epoch loss printed below.
        train_bar.set_postfix(loss=running_loss / samples_seen)

    epoch_loss = running_loss / len(train_loader.dataset)
    print(f'Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}')

    # ---- Validation pass ----
    model.eval()
    running_corrects = 0

    val_bar = tqdm(val_loader, desc="Validation", ncols=100)
    with torch.no_grad():
        for inputs, labels in val_bar:
            outputs = model(inputs)
            # Predicted class = index of the largest output along dim 1.
            _, preds = torch.max(outputs, 1)
            running_corrects += torch.sum(preds == labels)

    epoch_acc = running_corrects.double() / len(val_loader.dataset)
    print(f'Epoch {epoch+1}/{num_epochs}, Accuracy: {epoch_acc:.4f}')

    # Advance the learning-rate schedule once per epoch.
    scheduler.step()

Epoch 1/10: 100%|████████████████████████████████████████| 90/90 [02:53<00:00,  1.93s/it, loss=1.23]


Epoch 1/10, Loss: 1.2318


Validation: 100%|███████████████████████████████████████████████████| 13/13 [00:30<00:00,  2.34s/it]


Epoch 1/10, Accuracy: 0.4264


Epoch 2/10: 100%|███████████████████████████████████████| 90/90 [02:52<00:00,  1.91s/it, loss=0.872]


Epoch 2/10, Loss: 0.8755


Validation: 100%|███████████████████████████████████████████████████| 13/13 [00:30<00:00,  2.36s/it]


Epoch 2/10, Accuracy: 0.4695


Epoch 3/10: 100%|███████████████████████████████████████| 90/90 [02:55<00:00,  1.95s/it, loss=0.627]


Epoch 3/10, Loss: 0.6299


Validation: 100%|███████████████████████████████████████████████████| 13/13 [00:30<00:00,  2.35s/it]


Epoch 3/10, Accuracy: 0.5152


Epoch 4/10: 100%|███████████████████████████████████████| 90/90 [02:57<00:00,  1.97s/it, loss=0.504]


Epoch 4/10, Loss: 0.5057


Validation: 100%|███████████████████████████████████████████████████| 13/13 [00:30<00:00,  2.36s/it]


Epoch 4/10, Accuracy: 0.5482


Epoch 5/10: 100%|███████████████████████████████████████| 90/90 [02:56<00:00,  1.96s/it, loss=0.439]


Epoch 5/10, Loss: 0.4410


Validation: 100%|███████████████████████████████████████████████████| 13/13 [00:30<00:00,  2.35s/it]


Epoch 5/10, Accuracy: 0.5838


Epoch 6/10: 100%|███████████████████████████████████████| 90/90 [02:56<00:00,  1.96s/it, loss=0.372]


Epoch 6/10, Loss: 0.3738


Validation: 100%|███████████████████████████████████████████████████| 13/13 [00:30<00:00,  2.36s/it]


Epoch 6/10, Accuracy: 0.6168


Epoch 7/10: 100%|███████████████████████████████████████| 90/90 [02:58<00:00,  1.99s/it, loss=0.331]


Epoch 7/10, Loss: 0.3324


Validation: 100%|███████████████████████████████████████████████████| 13/13 [00:30<00:00,  2.35s/it]


Epoch 7/10, Accuracy: 0.6497


Epoch 8/10: 100%|████████████████████████████████████████| 90/90 [02:54<00:00,  1.93s/it, loss=0.32]


Epoch 8/10, Loss: 0.3209


Validation: 100%|███████████████████████████████████████████████████| 13/13 [00:30<00:00,  2.34s/it]


Epoch 8/10, Accuracy: 0.6371


Epoch 9/10: 100%|███████████████████████████████████████| 90/90 [02:57<00:00,  1.97s/it, loss=0.304]


Epoch 9/10, Loss: 0.3055


Validation: 100%|███████████████████████████████████████████████████| 13/13 [00:30<00:00,  2.33s/it]


Epoch 9/10, Accuracy: 0.6447


Epoch 10/10: 100%|██████████████████████████████████████| 90/90 [02:54<00:00,  1.94s/it, loss=0.299]


Epoch 10/10, Loss: 0.3000


Validation: 100%|███████████████████████████████████████████████████| 13/13 [00:30<00:00,  2.33s/it]

Epoch 10/10, Accuracy: 0.6599





In [12]:
import numpy as np
# Validation step
# Runs one pass over val_loader with gradients disabled, accumulating accuracy
# plus per-batch Dice / Hausdorff / MAE / MSE values that are averaged below.
# NOTE(review): `preds` is the argmax index along dim 1 of the model output, so
# the metric helpers receive class-index vectors here rather than binary masks.
# The output recorded for this cell (Dice: 1.3938, Hausdorff: inf) is outside
# what these metrics yield on binary masks -- confirm they are meaningful for
# this classification task before reporting them.
model.eval()
running_corrects = 0
dice_scores = []
hausdorff_distances = []
mae_scores = []
mse_scores = []
    
val_bar = tqdm(val_loader, desc="Validation", ncols=100)  # Add tqdm progress bar for validation
with torch.no_grad():
    for inputs, labels in val_bar:
        outputs = model(inputs)
        _, preds = torch.max(outputs, 1)
        running_corrects += torch.sum(preds == labels.data)
            
        # Calculate metrics
        # Each helper is called once per batch; one value per batch is stored.
        dice_scores.append(dice_coefficient(preds, labels).item())
        hausdorff_distances.append(hausdorff_distance(preds, labels))
        mae_scores.append(mean_absolute_error(preds, labels).item())
        mse_scores.append(mean_squared_error(preds, labels).item())
    
# Accuracy is normalised by the dataset size; the other metrics are plain means
# of the per-batch values, so every batch counts equally regardless of its size.
epoch_acc = running_corrects.double() / len(val_loader.dataset)
epoch_dice = np.mean(dice_scores)
epoch_hausdorff = np.mean(hausdorff_distances)
epoch_mae = np.mean(mae_scores)
epoch_mse = np.mean(mse_scores)
    
# `epoch` and `num_epochs` are not defined in this cell; the label printed here
# uses whatever values the training-loop cell left behind.
print(f'Epoch {epoch+1}/{num_epochs}, Accuracy: {epoch_acc:.4f}, Dice: {epoch_dice:.4f}, Hausdorff: {epoch_hausdorff:.4f}, MAE: {epoch_mae:.4f}, MSE: {epoch_mse:.4f}')

Validation: 100%|███████████████████████████████████████████████████| 13/13 [00:30<00:00,  2.34s/it]

Epoch 10/10, Accuracy: 0.6599, Dice: 1.3938, Hausdorff: inf, MAE: 0.5154, MSE: 0.8740





In [13]:
# Persist only the model's state_dict (its parameters and buffers), written to
# a file in the current working directory.
checkpoint_path = 'fine_tuned_resnet18.pth'
torch.save(model.state_dict(), checkpoint_path)