In [2]:
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt
import numpy as np

from datetime import datetime
from sklearn.metrics import confusion_matrix, f1_score
from sklearn.model_selection import train_test_split
from PIL import Image
from torch.utils.data import DataLoader, Dataset


from torchvision import datasets, transforms

In [3]:
def get_device():
    """Pick the torch device to run on.

    CUDA is preferred when it is available, then Apple's MPS backend, and the
    CPU is the fallback.

    Returns:
        torch.device: a ``cuda``, ``mps`` or ``cpu`` device.
    """
    if torch.cuda.is_available():
        backend = "cuda"
    else:
        # MPS is only probed when CUDA is absent, mirroring the priority order.
        backend = "mps" if torch.backends.mps.is_available() else "cpu"
    return torch.device(backend)


print(f"Using device: {get_device()}")

Using device: mps


In [4]:
class MRIDataset(Dataset):
    """Map-style dataset that loads images from disk on demand.

    Args:
        image_paths: sequence of image file paths.
        labels: sequence of labels, aligned index-for-index with
            ``image_paths``.
        transform: optional callable applied to the loaded PIL image before it
            is returned.
    """

    def __init__(self, image_paths, labels, transform=None):
        self.image_paths = image_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        """Return the number of images in the dataset."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load the image at ``idx`` as RGB and return ``(image, label)``.

        The image is always converted to 3-channel RGB: the model's first
        convolution takes 3 input channels and the normalisation transform
        uses 3-channel statistics, so a file stored as grayscale would
        otherwise produce a tensor with the wrong channel count.
        """
        # Image.open is lazy and keeps the file open; convert() reads the
        # pixel data and returns an independent image, so the handle can be
        # closed as soon as the conversion is done.
        with Image.open(self.image_paths[idx]) as raw_image:
            image = raw_image.convert("RGB")
        label = self.labels[idx]

        if self.transform is not None:
            image = self.transform(image)

        return image, label

In [5]:
def load_oasis_data(data_dir="../OasisImages", extensions=(".jpg",)):
    """Index an image-folder dataset laid out as ``data_dir/<class>/<image>``.

    Every sub-directory of ``data_dir`` is treated as one class. Classes are
    numbered in sorted directory-name order, and the files inside each class
    are listed in sorted order as well, so the returned lists are identical
    from run to run (``os.listdir`` alone gives no ordering guarantee, which
    would make a seeded train/validation split non-reproducible).

    Args:
        data_dir: root directory containing one sub-directory per class.
        extensions: file-name suffix, or iterable of suffixes, to accept.
            Matching is case-insensitive. Defaults to ``.jpg`` only.

    Returns:
        tuple: ``(image_paths, labels, label_names)`` where ``image_paths``
        and ``labels`` are parallel lists and ``label_names[i]`` is the
        directory name of class index ``i``.

    Raises:
        OSError: if ``data_dir`` (or a class directory) cannot be listed.
    """
    if isinstance(extensions, str):
        extensions = (extensions,)
    suffixes = tuple(ext.lower() for ext in extensions)

    image_paths = []
    labels = []

    class_dirs = sorted(
        entry for entry in os.listdir(data_dir)
        if os.path.isdir(os.path.join(data_dir, entry))
    )

    print(f"Found {len(class_dirs)} classes: {class_dirs}")

    for class_idx, class_name in enumerate(class_dirs):
        class_path = os.path.join(data_dir, class_name)

        class_files = sorted(
            file_name for file_name in os.listdir(class_path)
            if file_name.lower().endswith(suffixes)
        )
        image_paths.extend(os.path.join(class_path, file_name) for file_name in class_files)
        labels.extend([class_idx] * len(class_files))

        print(f"Class '{class_name}' (label {class_idx}): {len(class_files)} images")

    print(f"Total images loaded: {len(image_paths)}")
    return image_paths, labels, class_dirs



# Model

In [6]:
class OasisModel(nn.Module):
    """Four-stage CNN feature extractor followed by a three-layer MLP head.

    Each stage is a 3x3 convolution (padding 1), a ReLU and a 2x2 max-pool, so
    every stage halves the spatial size. The first linear layer is sized for a
    14x14 feature map, i.e. a 224x224 input after four poolings.

    Args:
        num_classes: number of output logits.
    """

    def __init__(self, num_classes=4):
        super().__init__()

        # Convolutional stages: 3 -> 32 -> 64 -> 128 -> 256 channels.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1)
        self.conv4 = nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, padding=1)

        # Parameter-free layers shared by every stage / hidden layer.
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.dropout = nn.Dropout(p=0.5)

        # Classifier head; 256 channels on a 14x14 map for 224x224 inputs.
        flattened_features = 256 * 14 * 14
        self.fc1 = nn.Linear(flattened_features, 512)
        self.fc2 = nn.Linear(512, 128)
        self.fc3 = nn.Linear(128, num_classes)

    def forward(self, x):
        """Map a batch of images to per-class logits (no softmax applied)."""
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
            x = self.pool(F.relu(conv(x)))

        # Collapse (channels, height, width) into one feature axis per sample.
        batch = x.size(0)
        x = x.view(batch, -1)

        for hidden in (self.fc1, self.fc2):
            x = self.dropout(F.relu(hidden(x)))

        return self.fc3(x)
        

In [7]:
def get_transforms():
    """Build the training and validation preprocessing pipelines.

    Both pipelines resize to 224x224, convert to a tensor and normalise with
    the same per-channel mean/std. The training pipeline additionally applies
    a random horizontal flip, a small random rotation and brightness/contrast
    jitter between the resize and the tensor conversion.

    Returns:
        tuple: ``(train_transform, val_transform)``.
    """
    side = 224

    def _resize():
        return transforms.Resize((side, side))

    def _tensor_and_normalize():
        # Fresh objects on every call so the two pipelines share no state.
        return [
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]

    augmentations = [
        transforms.RandomHorizontalFlip(p=0.3),
        transforms.RandomRotation(10),
        transforms.ColorJitter(brightness=0.2, contrast=0.2),
    ]

    train_transform = transforms.Compose([_resize(), *augmentations, *_tensor_and_normalize()])
    val_transform = transforms.Compose([_resize(), *_tensor_and_normalize()])

    return train_transform, val_transform

# Training & Validation

In [25]:
def train(model, train_loader, optimizer, criterion, device, epoch, num_epochs):
    """Run one training epoch and return its mean loss and accuracy.

    Args:
        model: network to optimise; switched to training mode.
        train_loader: iterable yielding ``(data, target)`` batches.
        optimizer: optimizer stepping ``model``'s parameters.
        criterion: loss callable taking ``(output, target)``.
        device: device the batches are moved to.
        epoch: current epoch number, used only for logging.
        num_epochs: total number of epochs, used only for logging.

    Returns:
        tuple[float, float]: mean per-batch loss and accuracy in percent over
        all samples seen. Both are ``0.0`` if the loader yields no batches.
    """
    model.train()
    running_loss = 0.0
    total_correct = 0
    total_samples = 0
    num_batches = 0

    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)

        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        num_batches += 1

        # Accuracy is measured on the outputs produced before this step's
        # weight update, i.e. the same forward pass the loss came from.
        pred = output.argmax(dim=1, keepdim=True)
        total_correct += pred.eq(target.view_as(pred)).sum().item()
        total_samples += target.size(0)

        if batch_idx == 0:
            running_accuracy = 100. * total_correct / max(total_samples, 1)
            print(f'Epoch: {epoch}/{num_epochs}, Batch: {batch_idx}, ' f'Loss: {loss.item():.6f}, Accuracy: {running_accuracy:.2f}')

    # max(..., 1) keeps an empty loader from raising ZeroDivisionError.
    epoch_loss = running_loss / max(num_batches, 1)
    epoch_accuracy = 100. * total_correct / max(total_samples, 1)
    return epoch_loss, epoch_accuracy


In [26]:
def validate(model, val_loader, criterion, deivce):
    """Evaluate ``model`` on ``val_loader`` without tracking gradients.

    Args:
        model: network to evaluate; switched to evaluation mode.
        val_loader: iterable yielding ``(data, target)`` batches.
        criterion: loss callable taking ``(output, target)``.
        deivce: device the batches are moved to. The misspelling (of
            "device") is kept so existing keyword callers keep working.

    Returns:
        tuple[float, float]: mean per-batch loss and accuracy in percent over
        all samples. Both are ``0.0`` if the loader yields no batches.
    """
    model.eval()
    validation_loss = 0.0
    total_correct = 0
    total_samples = 0
    num_batches = 0

    with torch.no_grad():
        for data, target in val_loader:
            data, target = data.to(deivce), target.to(deivce)
            output = model(data)
            validation_loss += criterion(output, target).item()
            num_batches += 1

            pred = output.argmax(dim=1, keepdim=True)
            total_correct += pred.eq(target.view_as(pred)).sum().item()
            total_samples += target.size(0)

    # max(..., 1) keeps an empty loader from raising ZeroDivisionError.
    validation_loss /= max(num_batches, 1)
    validation_accuracy = 100. * total_correct / max(total_samples, 1)
    return validation_loss, validation_accuracy

In [None]:
def train_oasis_model(batch_size=32, num_epochs=10, learning_rate=0.001, data_dir=None):
    """Train ``OasisModel`` on the image folder and checkpoint the results.

    The images are split 80/20 (stratified, fixed seed) into training and
    validation sets. After every epoch the model is validated; whenever the
    validation accuracy improves, a checkpoint is written to
    ``best_oasis_model.pth``. The weights after the last epoch are written to
    ``final_oasis_model.pth``.

    Args:
        batch_size: batch size for both data loaders.
        num_epochs: number of training epochs.
        learning_rate: initial Adam learning rate.
        data_dir: dataset root passed to ``load_oasis_data``; ``None`` uses
            that function's default.

    Returns:
        tuple: ``(model, label_names)`` - the model after the final epoch
        (not necessarily the best checkpoint) and the class names.
    """
    device = get_device()

    print(f"Using device: {device}")

    # Load data
    if data_dir is None:
        image_paths, labels, label_names = load_oasis_data()
    else:
        image_paths, labels, label_names = load_oasis_data(data_dir)

    # Split data into training and validation sets
    train_paths, val_paths, train_labels, val_labels = train_test_split(
        image_paths, labels, test_size=0.2, random_state=42, stratify=labels
    )

    print(f"Training samples: {len(train_paths)}")
    print(f"Validation samples: {len(val_paths)}")

    train_transform, val_transform = get_transforms()

    train_dataset = MRIDataset(train_paths, train_labels, transform=train_transform)
    validation_dataset = MRIDataset(val_paths, val_labels, transform=val_transform)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(validation_dataset, batch_size=batch_size, shuffle=False)

    # Size the output layer from the classes actually found on disk so the
    # saved weights match what load_saved_model rebuilds from label_names.
    model = OasisModel(num_classes=len(label_names)).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-4)
    # NOTE: with step_size=10 the learning rate only drops after the 10th
    # epoch, so it has no effect on a run of 10 epochs or fewer.
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)

    best_val_accuracy = 0.0
    train_losses, val_losses = [], []
    train_accuracies, val_accuracies = [], []

    for epoch in range(1, num_epochs + 1):
        train_loss, train_accuracy = train(model, train_loader, optimizer, criterion, device, epoch, num_epochs)
        validation_loss, validation_accuracy = validate(model, val_loader, criterion, device)

        scheduler.step()
        train_losses.append(train_loss)
        val_losses.append(validation_loss)
        train_accuracies.append(train_accuracy)
        val_accuracies.append(validation_accuracy)

        print(f'Epoch {epoch}: Train Loss: {train_loss:.4f}, Train Acc: {train_accuracy:.2f}%, '
              f'Val Loss: {validation_loss:.4f}, Val Acc: {validation_accuracy:.2f}%')

        if validation_accuracy > best_val_accuracy:
            best_val_accuracy = validation_accuracy
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'best_val_accuracy': best_val_accuracy,
                'label_names': label_names
            }, 'best_oasis_model.pth')
            print(f'Best Validation Accuracy: {best_val_accuracy:.2f}%')

    # The final checkpoint carries only the weights and class names.
    torch.save({
        'model_state_dict': model.state_dict(),
        'label_names': label_names
    }, "final_oasis_model.pth")

    return model, label_names



# Run the full training pipeline. The returned model holds the weights from
# the final epoch; train_oasis_model also writes best_oasis_model.pth and
# final_oasis_model.pth to the working directory.
model, label_names = train_oasis_model()


# Testing

In [None]:
# Test Model on Random Samples using saved model
import random

def load_saved_model(model_path='best_oasis_model.pth'):
    """Rebuild an ``OasisModel`` from a saved checkpoint, ready for inference.

    The checkpoint must contain ``'label_names'`` and ``'model_state_dict'``.
    ``'best_val_accuracy'`` is optional: the "best" checkpoint stores it, the
    "final" checkpoint does not, and both can be loaded.

    Args:
        model_path: path to the checkpoint file.

    Returns:
        tuple: ``(model, label_names)`` with the model on the selected device
        and in evaluation mode.
    """
    device = get_device()

    # NOTE(security): torch.load is pickle-based; only load checkpoint files
    # that come from a trusted source.
    checkpoint = torch.load(model_path, map_location=device)

    # The number of classes is recovered from the stored class names.
    label_names = checkpoint['label_names']
    num_classes = len(label_names)

    # Initialize model
    model = OasisModel(num_classes=num_classes)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(device)
    model.eval()

    print(f"Model loaded from {model_path}")
    best_val_accuracy = checkpoint.get('best_val_accuracy')
    if best_val_accuracy is not None:
        print(f"Best validation accuracy: {best_val_accuracy:.2f}%")
    print(f"Classes: {label_names}")

    return model, label_names

def test_model_random_samples(model, label_names, num_samples=10):
    """Run the model on randomly chosen images and print per-sample results.

    Samples are drawn without replacement from everything
    ``load_oasis_data()`` returns, i.e. the whole dataset. That includes
    images the model was trained on, so the printed accuracy is not a
    held-out estimate.

    Args:
        model: trained classifier returning per-class logits.
        label_names: class names indexed by class id.
        num_samples: how many images to evaluate. Values larger than the
            dataset are clamped to the dataset size; values <= 0 evaluate
            nothing.

    Returns:
        None. Results are printed.
    """
    device = get_device()
    model.eval()

    # Load data and get validation transform
    image_paths, labels, _ = load_oasis_data()
    _, val_transform = get_transforms()

    # random.sample raises ValueError when asked for more items than exist,
    # so cap the request at the number of available images.
    available = len(image_paths)
    if num_samples > available:
        print(f"Requested {num_samples} samples but only {available} images are available; using {available}.")
        num_samples = available
    if num_samples <= 0:
        print("No samples to evaluate.")
        return

    # Randomly select samples
    random_indices = random.sample(range(available), num_samples)

    print(f"\nTesting model on {num_samples} random samples:")
    print("=" * 80)

    correct_predictions = 0

    with torch.no_grad():
        for i, idx in enumerate(random_indices):
            # Load and preprocess image
            image_path = image_paths[idx]
            true_label = labels[idx]

            # Close the file as soon as the pixel data has been read.
            with Image.open(image_path) as raw_image:
                image = raw_image.convert('RGB')
            image_tensor = val_transform(image).unsqueeze(0).to(device)

            # Get model prediction
            output = model(image_tensor)
            probabilities = F.softmax(output, dim=1)
            confidence_scores = probabilities.cpu().numpy()[0]

            predicted_class = torch.argmax(output, dim=1).item()
            max_confidence = confidence_scores[predicted_class]

            # Check if prediction is correct
            is_correct = predicted_class == true_label
            if is_correct:
                correct_predictions += 1

            print(f"Sample {i+1}:")
            print(f"  Image: {os.path.basename(image_path)}")
            print(f"  True Label: {label_names[true_label]} (index {true_label})")
            print(f"  Predicted: {label_names[predicted_class]} (index {predicted_class})")
            print(f"  Confidence: {max_confidence:.4f} ({max_confidence*100:.2f}%)")
            print(f"  Correct: {'✓' if is_correct else '✗'}")

            # Print all class probabilities
            print(f"  All class probabilities:")
            for j, class_name in enumerate(label_names):
                print(f"    {class_name}: {confidence_scores[j]:.4f} ({confidence_scores[j]*100:.2f}%)")
            print("-" * 60)

    accuracy = correct_predictions / num_samples
    print(f"\nTest Results:")
    print(f"Accuracy: {correct_predictions}/{num_samples} = {accuracy:.2f} ({accuracy*100:.1f}%)")

# Load the saved model and run tests
# This rebinds `model` and `label_names` from the training cell to the
# checkpoint with the best validation accuracy.
# NOTE(review): 86000 is a hard-coded sample count, presumably about the size
# of the full dataset - confirm it does not exceed the number of images
# actually on disk.
model, label_names = load_saved_model('best_oasis_model.pth')
test_model_random_samples(model, label_names, num_samples=86000)