In [1]:
import timm
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms
from sklearn.metrics import accuracy_score, f1_score, recall_score, precision_score
from dataset import BrainTumorDataset

In [6]:
# import os
# from PIL import Image
# import torch
# from torch.utils.data import Dataset
# from torchvision import transforms

# class BrainTumorDataset(Dataset):
#     def __init__(self, root_dir, transform=None, img_size=224):
#         self.root_dir = root_dir
#         self.transform = transform
#         self.img_size = img_size
#         self.classes = ['glioma_tumor', 'meningioma_tumor', 'no_tumor', 'pituitary_tumor']
#         self.class_to_idx = {cls: idx for idx, cls in enumerate(self.classes)}
#         self.image_paths, self.labels = self.process_images(root_dir)

#     def __len__(self):
#         return len(self.image_paths)
    
#     def __getitem__(self, idx):
#         img_path = self.image_paths[idx]
#         image = Image.open(img_path).convert('RGB')
#         label = self.labels[idx]
        
#         if self.transform:
#             image = self.transform(image)
#         else:
#             image = image.resize((self.img_size, self.img_size))
#             image = transforms.ToTensor()(image)
        
#         return image, label

#     def process_images(self, path):
#         images = []
#         labels = []
#         for category in os.listdir(path):
#             category_path = os.path.join(path, category)
#             if os.path.isdir(category_path):  # Ensure it's a directory
#                 for img_name in os.listdir(category_path):
#                     img_path = os.path.join(category_path, img_name)
#                     if self.is_image_file(img_path):
#                         images.append(img_path)
#                         labels.append(self.class_to_idx[category])
#                     else:
#                         print(f"Warning: Skipping file {img_path} as it is not a valid image.")
#         return images, labels

#     def is_image_file(self, filename):
#         valid_image_extensions = [".jpg", ".jpeg", ".png", ".bmp", ".tiff"]
#         if not any(filename.lower().endswith(ext) for ext in valid_image_extensions):
#             return False
#         try:
#             Image.open(filename).verify()
#             return True
#         except (IOError, SyntaxError):
#             return False

In [2]:
# Data Transformations
data_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'val': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}

In [3]:
train_dataset = BrainTumorDataset(root_dir='/Users/nasifsafwan/Downloads/ML/BrainTumorResearch/tumordata/Training/',
                                 transform=data_transforms['train'])
val_dataset = BrainTumorDataset(root_dir='/Users/nasifsafwan/Downloads/ML/BrainTumorResearch/tumordata/Testing/'
                                 ,transform=data_transforms['val'])



In [4]:
train_loader = DataLoader(train_dataset,batch_size=32,shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=4)

In [5]:
model = timm.create_model('resnet18', pretrained=True, num_classes=4)

In [6]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

In [7]:
from tqdm import tqdm

num_epochs = 10
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    train_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}", ncols=100)  # Add tqdm progress bar
    
    for inputs, labels in train_bar:
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item() * inputs.size(0)
        train_bar.set_postfix(loss=running_loss / ((train_bar.n + 1) * train_loader.batch_size))

    epoch_loss = running_loss / len(train_loader.dataset)
    print(f'Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}')

Epoch 1/10: 100%|███████████████████████████████████████| 90/90 [02:49<00:00,  1.89s/it, loss=0.716]


Epoch 1/10, Loss: 0.7185


Epoch 2/10: 100%|███████████████████████████████████████| 90/90 [02:49<00:00,  1.88s/it, loss=0.366]


Epoch 2/10, Loss: 0.3672


Epoch 3/10: 100%|███████████████████████████████████████| 90/90 [02:49<00:00,  1.88s/it, loss=0.294]


Epoch 3/10, Loss: 0.2952


Epoch 4/10: 100%|███████████████████████████████████████| 90/90 [02:53<00:00,  1.92s/it, loss=0.254]


Epoch 4/10, Loss: 0.2550


Epoch 5/10: 100%|███████████████████████████████████████| 90/90 [07:30<00:00,  5.01s/it, loss=0.203]


Epoch 5/10, Loss: 0.2042


Epoch 6/10: 100%|███████████████████████████████████████| 90/90 [02:49<00:00,  1.88s/it, loss=0.206]


Epoch 6/10, Loss: 0.2063


Epoch 7/10: 100%|███████████████████████████████████████| 90/90 [02:49<00:00,  1.89s/it, loss=0.188]


Epoch 7/10, Loss: 0.1886


Epoch 8/10: 100%|███████████████████████████████████████| 90/90 [02:52<00:00,  1.92s/it, loss=0.207]


Epoch 8/10, Loss: 0.2080


Epoch 9/10: 100%|███████████████████████████████████████| 90/90 [02:53<00:00,  1.93s/it, loss=0.203]


Epoch 9/10, Loss: 0.2041


Epoch 10/10: 100%|██████████████████████████████████████| 90/90 [02:52<00:00,  1.92s/it, loss=0.157]

Epoch 10/10, Loss: 0.1571





In [14]:
from scipy.spatial.distance import directed_hausdorff

# Dice Coefficient
def dice_coefficient(pred, target):
    smooth = 1e-6
    pred = pred.view(-1)
    target = target.view(-1)
    intersection = (pred * target).sum()
    return (2. * intersection + smooth) / (pred.sum() + target.sum() + smooth)

# Hausdorff Distance
def hausdorff_distance(pred, target):
    pred_points = np.argwhere(pred.cpu().numpy() == 1)
    target_points = np.argwhere(target.cpu().numpy() == 1)
    return max(directed_hausdorff(pred_points, target_points)[0], directed_hausdorff(target_points, pred_points)[0])

# Mean Absolute Error
def mean_absolute_error(pred, target):
    return torch.mean(torch.abs(pred.float() - target.float()))

# Mean Squared Error
def mean_squared_error(pred, target):
    return torch.mean((pred.float() - target.float()) ** 2)

In [16]:
import numpy as np
# Validation step
model.eval()
running_corrects = 0
dice_scores = []
hausdorff_distances = []
mae_scores = []
mse_scores = []
    
val_bar = tqdm(val_loader, desc="Validation", ncols=100)  # Add tqdm progress bar for validation
with torch.no_grad():
    for inputs, labels in val_bar:
        outputs = model(inputs)
        _, preds = torch.max(outputs, 1)
        running_corrects += torch.sum(preds == labels.data)
            
        # Calculate metrics
        dice_scores.append(dice_coefficient(preds, labels).item())
        hausdorff_distances.append(hausdorff_distance(preds, labels))
        mae_scores.append(mean_absolute_error(preds, labels).item())
        mse_scores.append(mean_squared_error(preds, labels).item())
    
epoch_acc = running_corrects.double() / len(val_loader.dataset)
epoch_dice = np.mean(dice_scores)
epoch_hausdorff = np.mean(hausdorff_distances)
epoch_mae = np.mean(mae_scores)
epoch_mse = np.mean(mse_scores)
    
print(f'Epoch {epoch+1}/{num_epochs}, Accuracy: {epoch_acc:.4f}, Dice: {epoch_dice:.4f}, Hausdorff: {epoch_hausdorff:.4f}, MAE: {epoch_mae:.4f}, MSE: {epoch_mse:.4f}')

Validation: 100%|███████████████████████████████████████████████████| 13/13 [00:29<00:00,  2.30s/it]

Epoch 10/10, Accuracy: 0.7462, Dice: 1.4284, Hausdorff: inf, MAE: 0.3870, MSE: 0.6582





In [17]:
torch.save(model.state_dict(), 'resnet_model.pth')