<h1>AlexNet Majority Vote</h1>
We start by recreating the model used in the original MRNet paper to get a performance baseline.
This model uses three AlexNet backbones, each followed by a dense classification head and trained on axial, coronal and sagittal MRIs respectively. The three per-view predictions are then combined by a (soft) majority vote, i.e. the average of the three predicted probabilities, to give the final output.

<h3>Model class</h3>
We start by defining the model class.

In [19]:
import torch
import torch.nn as nn
from torchvision import models
from torchvision.models import AlexNet_Weights

class MRNet3(nn.Module):
    """Three-view MRNet baseline: one AlexNet branch per MRI plane, fused by a soft vote.

    Each branch (axial -> ``model1``, coronal -> ``model2``, sagittal -> ``model3``)
    runs the AlexNet feature extractor over every slice of its view, max-pools each
    slice to a 256-d vector, max-pools across slices, applies dropout and a
    two-layer head, producing one logit per exam. The three per-view probabilities
    are then averaged (the "majority vote").

    Args:
        use_batchnorm: insert ``BatchNorm1d(256)`` after the first head layer.
            In training mode BatchNorm needs more than one sample per batch.
    """

    def __init__(self, use_batchnorm=True):
        super().__init__()
        # One ImageNet-pretrained AlexNet backbone per view.
        self.model1 = models.alexnet(weights=AlexNet_Weights.DEFAULT)  # axial
        self.model2 = models.alexnet(weights=AlexNet_Weights.DEFAULT)  # coronal
        self.model3 = models.alexnet(weights=AlexNet_Weights.DEFAULT)  # sagittal
        # Despite the attribute name this is global *max* pooling
        # (nn.AdaptiveAvgPool2d(1) is the average-pooling alternative).
        self.gap = nn.AdaptiveMaxPool2d(1)
        self.use_batchnorm = use_batchnorm
        n = 0.15
        # Dropout for each view's pooled features
        self.dropout_view1 = nn.Dropout(p=n)
        self.dropout_view2 = nn.Dropout(p=n)
        self.dropout_view3 = nn.Dropout(p=n)

        print(f"Dropout of {n}")

        # First head layer per view. Created in the same order as the output
        # layers below so parameter names and initialisation order are stable.
        self.classifier1_axial = self._make_classifier1(use_batchnorm)
        self.classifier1_coronal = self._make_classifier1(use_batchnorm)
        self.classifier1_sagittal = self._make_classifier1(use_batchnorm)

        # Separate output layer (one logit) for each view
        self.classifier2_axial = nn.Linear(256, 1)
        self.classifier2_coronal = nn.Linear(256, 1)
        self.classifier2_sagittal = nn.Linear(256, 1)

    @staticmethod
    def _make_classifier1(use_batchnorm):
        """Build ``Linear(256, 256)``, optionally followed by ``BatchNorm1d(256)``."""
        layers = [nn.Linear(256, 256)]
        if use_batchnorm:
            layers.append(nn.BatchNorm1d(256))
        return nn.Sequential(*layers)

    @staticmethod
    def _fuse_view_logits(logits):
        """Return ``logit(mean_i sigmoid(logits[i]))`` over dim 0, computed stably.

        ``log(mean p) - log(mean(1 - p))`` is evaluated with logsumexp/logsigmoid;
        the ``log(num_views)`` normalisers cancel. This stays finite even when
        the per-view probabilities saturate at 0 or 1 in float32.
        """
        log_p = torch.logsumexp(nn.functional.logsigmoid(logits), dim=0)
        log_not_p = torch.logsumexp(nn.functional.logsigmoid(-logits), dim=0)
        return log_p - log_not_p

    def forward(self, x):
        """Return fused exam-level logits.

        Args:
            x: sequence of samples, each indexable as ``[axial, coronal, sagittal]``,
                where every view is a ``[num_slices, 3, H, W]`` tensor
                (``num_slices`` may differ between samples and views).

        Returns:
            ``[batch_size, 1]`` logits equal to ``logit(mean(sigmoid(view_logits)))``.
            ``torch.sigmoid(output)`` is the averaged per-view probability, and the
            output is suitable for ``binary_cross_entropy_with_logits``.
        """
        # Separate by view
        axial_views = [sample[0] for sample in x]
        coronal_views = [sample[1] for sample in x]
        sagittal_views = [sample[2] for sample in x]

        def process_view(view_list, model, dropout, classifier1, classifier2):
            device = next(model.parameters()).device
            features = []
            for view in view_list:
                slices = view.size(0)
                feat = model.features(view.to(device))          # [slices, 256, h', w']
                feat = self.gap(feat).view(slices, 256)         # [slices, 256]
                feat = torch.max(feat, dim=0)[0]                # [256], max over slices
                features.append(dropout(feat))
            features = torch.stack(features)                    # [batch_size, 256]
            features = classifier1(features)                    # [batch_size, 256]
            return classifier2(features)                        # [batch_size, 1]

        logit_axial = process_view(axial_views, self.model1, self.dropout_view1,
                                   self.classifier1_axial, self.classifier2_axial)
        logit_coronal = process_view(coronal_views, self.model2, self.dropout_view2,
                                     self.classifier1_coronal, self.classifier2_coronal)
        logit_sagittal = process_view(sagittal_views, self.model3, self.dropout_view3,
                                      self.classifier1_sagittal, self.classifier2_sagittal)

        logits = torch.stack([logit_axial, logit_coronal, logit_sagittal], dim=0)  # [3, batch_size, 1]
        return self._fuse_view_logits(logits)                                      # [batch_size, 1]

<h3>Loader</h3>
Next, we define the code that loads and preprocesses the data.

In [None]:
import numpy as np
import os
import torch
import torch.nn.functional as F
import torch.utils.data as data
import pandas as pd
from sklearn.model_selection import train_test_split
import kornia.augmentation as K
import random

INPUT_DIM = 224
MAX_PIXEL_VAL = 255
MEAN = 58.09
STDDEV = 49.73


class Dataset3(data.Dataset):
    """Three-view (axial, coronal, sagittal) MRNet dataset.

    Each item is ``(vol_list, label)``. ``vol_list`` holds exactly one float tensor
    per view, in the order axial, coronal, sagittal, each shaped
    ``[num_slices, 3, INPUT_DIM, INPUT_DIM]``. ``label`` is a float tensor of shape ``[1]``.

    Args:
        data_dir: directory with ``axial/``, ``coronal/`` and ``sagittal/``
            sub-directories of ``.npy`` volumes.
        file_list: file names (e.g. ``"0000.npy"``) to include.
        labels_dict: maps each file name to its 0/1 label.
        device: stored as ``self.device``; volumes themselves are kept on CPU.
        train: whether this is the training split.
        augment: request augmentation (only applied when ``train`` is also True).
    """

    def __init__(self, data_dir, file_list, labels_dict, device, train=False, augment=False):
        super().__init__()
        self.device = device
        self.data_dir_axial = f"{data_dir}/axial"
        self.data_dir_coronal = f"{data_dir}/coronal"
        self.data_dir_sagittal = f"{data_dir}/sagittal"

        self.paths_axial = [os.path.join(self.data_dir_axial, file) for file in file_list]
        self.paths_coronal = [os.path.join(self.data_dir_coronal, file) for file in file_list]
        self.paths_sagittal = [os.path.join(self.data_dir_sagittal, file) for file in file_list]

        # Index 0/1/2 = axial/coronal/sagittal; the model reads sample[0..2] in this order.
        self.paths = [self.paths_axial, self.paths_coronal, self.paths_sagittal]

        self.labels = [labels_dict[file] for file in file_list]

        # neg_weight (the positive fraction) weights class 0 and its complement
        # weights class 1, so the rarer class receives the larger loss weight.
        neg_weight = np.mean(self.labels)
        self.weights = [neg_weight, 1 - neg_weight]

        # Augmentation requires both flags, so train=False sets (validation/test)
        # are never augmented even if augment=True.
        self.train = train
        self.augment = augment

    def weighted_loss(self, prediction, target, eps: float = 0.0):
        """Class-weighted BCE-with-logits loss with optional label smoothing.

        Args:
            prediction: raw logits, shaped ``[batch_size, 1]``.
            target: 0/1 labels with ``batch_size`` elements (reshaped to ``[batch_size, 1]``).
            eps: label-smoothing factor; targets 1/0 become ``1 - eps`` / ``eps``.

        Returns:
            Scalar mean of the per-sample weighted losses.
        """
        target = target.view(-1, 1)

        # Look up each sample's class weight in one vectorised gather.
        class_weights = torch.tensor([float(w) for w in self.weights],
                                     dtype=torch.float32, device=target.device)
        weights_tensor = class_weights[target.long()]  # [batch_size, 1]

        smoothed = target * (1 - eps) + (1 - target) * eps
        return F.binary_cross_entropy_with_logits(prediction, smoothed, weight=weights_tensor)

    @staticmethod
    def _preprocess(vol):
        """Center-crop to INPUT_DIM, rescale to [0, MAX_PIXEL_VAL], standardise, repeat to 3 channels.

        The crop offset is derived from the last axis and applied to both spatial
        axes (the volumes are assumed square; TODO confirm). A volume already
        INPUT_DIM wide is left uncropped, and a constant volume maps to zeros
        instead of dividing by zero.
        """
        pad = max((vol.shape[2] - INPUT_DIM) // 2, 0)
        vol = vol[:, pad:pad + INPUT_DIM, pad:pad + INPUT_DIM]
        vmin, vmax = np.min(vol), np.max(vol)
        value_range = vmax - vmin
        if value_range > 0:
            vol = (vol - vmin) / value_range * MAX_PIXEL_VAL
        else:
            vol = np.zeros(vol.shape, dtype=np.float64)
        vol = (vol - MEAN) / STDDEV
        return np.stack((vol,) * 3, axis=1)  # [slices, 3, H, W]

    def __getitem__(self, index):
        vol_list = []
        for view_paths in self.paths:
            vol = np.load(view_paths[index]).astype(np.int32)
            vol_tensor = torch.FloatTensor(self._preprocess(vol))  # Keep on CPU

            # Apply augmentations if train and augment flags are True
            if self.train and self.augment:
                vol_tensor = self.apply_augmentations(vol_tensor)

            vol_list.append(vol_tensor)  # exactly one tensor per view

        label_tensor = torch.FloatTensor([self.labels[index]])  # Shape: [1]
        return vol_list, label_tensor

    def apply_augmentations(self, vol_tensor):
        """Randomly rotate (+/-25 deg), translate (up to 25 px) and h-flip (p=0.5) a volume.

        Slices form the batch dimension, so ``same_on_batch=True`` makes every
        slice of the volume receive identical rotation/translation parameters.
        """
        vol_tensor = K.RandomRotation(degrees=25, p=1.0, same_on_batch=True)(vol_tensor)
        vol_tensor = K.RandomAffine(degrees=0, translate=(25 / INPUT_DIM, 25 / INPUT_DIM),
                                    p=1.0, same_on_batch=True)(vol_tensor)
        if random.random() > 0.5:
            vol_tensor = K.RandomHorizontalFlip(p=1.0)(vol_tensor)
        return vol_tensor

    def __len__(self):
        return len(self.labels)

def custom_collate_fn(batch):
    """Collate samples whose views have differing slice counts.

    Volumes of different lengths cannot be stacked into a single tensor, so
    they are passed through as a plain list (one ``[axial, coronal, sagittal]``
    entry per sample). Only the labels are stacked.

    Returns:
        ``(vol_lists, labels)`` with ``labels`` shaped ``[batch_size, 1]``.
    """
    vol_lists = []
    label_list = []
    for sample in batch:
        vol_lists.append(sample[0])
        label_list.append(sample[1])
    labels = torch.stack(label_list, dim=0)
    return vol_lists, labels

def load_data3(device, data_dir, labels_csv, batch_size=1, augment=False,
               test_size=0.2, random_state=42):
    """Build stratified train/validation DataLoaders for the three-view data.

    Reads ``labels_csv`` (no header; columns: numeric exam id, 0/1 label) and keeps
    only exams whose zero-padded ``NNNN.npy`` file exists under all of ``axial/``,
    ``coronal/`` and ``sagittal/`` in ``data_dir``. The kept exams are split with a
    label-stratified ``train_test_split``.

    Args:
        device: torch device or device string; stored on the datasets and used to
            decide ``pin_memory``.
        data_dir: directory with the three view sub-directories.
        labels_csv: path to the labels CSV.
        batch_size: DataLoader batch size.
        augment: enable augmentation for the training set only.
        test_size: fraction of exams held out for validation (default 0.2).
        random_state: seed for the split (default 42).

    Returns:
        ``(train_loader, valid_loader)``.
    """
    device = torch.device(device)
    labels_df = pd.read_csv(labels_csv, header=None, names=['filename', 'label'])
    labels_df['filename'] = labels_df['filename'].apply(lambda x: f"{int(x):04d}.npy")

    # Filter files that exist in all 3 views
    views = ('axial', 'coronal', 'sagittal')
    kept_files = []
    kept_labels = []
    for fname, label in zip(labels_df['filename'], labels_df['label']):
        if all(os.path.exists(os.path.join(data_dir, view, fname)) for view in views):
            kept_files.append(fname)
            kept_labels.append(label)

    labels_dict = dict(zip(kept_files, kept_labels))

    # Stratified split: both subsets keep (approximately) the same positive rate.
    train_files, val_files = train_test_split(
        kept_files,
        test_size=test_size,
        random_state=random_state,
        stratify=kept_labels,
    )

    train_dataset = Dataset3(data_dir, train_files, labels_dict, device, train=True, augment=augment)
    valid_dataset = Dataset3(data_dir, val_files, labels_dict, device, train=False, augment=False)

    pin_memory = device.type == 'cuda'
    train_loader = data.DataLoader(
        train_dataset,
        batch_size=batch_size,
        num_workers=0,
        shuffle=True,
        pin_memory=pin_memory,
        collate_fn=custom_collate_fn,
    )

    valid_loader = data.DataLoader(
        valid_dataset,
        batch_size=batch_size,
        num_workers=0,
        shuffle=False,
        pin_memory=pin_memory,
        collate_fn=custom_collate_fn,
    )

    print(f"Training samples: {len(train_dataset)}, Validation samples: {len(valid_dataset)}")
    return train_loader, valid_loader

def load_data_test(device, data_dir, labels_csv, batch_size=1, label_smoothing=0):
    """Build a non-shuffled DataLoader over every labelled exam in ``data_dir``.

    An exam is included when its ``.npy`` file is listed in ``axial/``, has a label
    in ``labels_csv`` (no header; numeric exam id, 0/1 label), and also exists in
    ``coronal/`` and ``sagittal/``. Files are sorted by name.

    Args:
        device: torch device or device string.
        data_dir: directory with the three view sub-directories.
        labels_csv: path to the labels CSV.
        batch_size: DataLoader batch size.
        label_smoothing: accepted for backward compatibility and ignored; the
            dataset applies smoothing only through ``weighted_loss(..., eps=...)``.

    Returns:
        A ``DataLoader`` yielding ``(vol_lists, labels)`` batches.
    """
    device = torch.device(device)
    labels_df = pd.read_csv(labels_csv, header=None, names=['filename', 'label'])
    labels_df['filename'] = labels_df['filename'].apply(lambda x: f"{int(x):04d}.npy")
    labels_dict = dict(zip(labels_df['filename'], labels_df['label']))

    test_files = sorted(
        f for f in os.listdir(os.path.join(data_dir, "axial"))
        if f.endswith(".npy")
        and f in labels_dict
        and all(os.path.exists(os.path.join(data_dir, view, f)) for view in ("coronal", "sagittal"))
    )

    test_dataset = Dataset3(data_dir, test_files, labels_dict, device, train=False, augment=False)

    test_loader = data.DataLoader(
        test_dataset,
        batch_size=batch_size,
        num_workers=0,
        shuffle=False,
        pin_memory=device.type == 'cuda',
        collate_fn=custom_collate_fn,
    )

    return test_loader


<h3>Training and evaluation functions</h3>
Lastly, we define the functions that run training and evaluation.

In [None]:
import argparse
import matplotlib.pyplot as plt
import os
import numpy as np
import torch
from sklearn import metrics
from torch.autograd import Variable
from tqdm import tqdm

def get_device(use_gpu, use_mps):
    """Pick a torch device: CUDA if requested and present, else MPS if requested and present, else CPU.

    Args:
        use_gpu: allow an NVIDIA CUDA device.
        use_mps: allow an Apple Metal (MPS) device.
    """
    if use_gpu and torch.cuda.is_available():
        name = "cuda"
    elif use_mps and torch.backends.mps.is_available():
        name = "mps"
    else:
        name = "cpu"
    return torch.device(name)

def run_model(model, loader, train=False, optimizer=None, eps: float = 0.0):
    """Run one pass over ``loader``, optionally updating the model.

    Args:
        model: an MRNet3 instance whose output is treated as logits.
        loader: DataLoader yielding ``(vol_lists, label)``; ``loader.dataset`` must
            expose ``device`` and ``weighted_loss``.
        train: if True, put the model in train mode and step ``optimizer``;
            otherwise use eval mode with autograd disabled.
        optimizer: optimizer to step (required when ``train=True``).
        eps: label-smoothing factor used for the training loss only
            (validation/test loss is always unsmoothed).

    Returns:
        ``(avg_loss, auc, preds, labels)``. ``preds`` are sigmoid probabilities and
        ``labels`` the matching targets, both flat Python lists.

    Raises:
        ValueError: if ``train`` is True without an optimizer, or the loader is empty.
    """
    if train and optimizer is None:
        raise ValueError("run_model: an optimizer is required when train=True")

    preds = []
    labels = []
    total_loss = 0.0
    num_batches = 0

    if train:
        model.train()
    else:
        model.eval()

    device = loader.dataset.device
    loss_eps = eps if train else 0.0  # no smoothing in val/test

    # Building autograd graphs is only needed for training; skipping it during
    # evaluation saves memory and time.
    with torch.set_grad_enabled(bool(train)):
        for vol_lists, label in tqdm(loader, desc="Processing batches", total=len(loader)):
            label = label.to(device)                                              # [batch_size, 1]
            vol_lists = [[view.to(device) for view in views] for views in vol_lists]

            logits = model(vol_lists)                                             # [batch_size, 1]
            probs = torch.sigmoid(logits)                                         # [batch_size, 1]
            loss = loader.dataset.weighted_loss(logits, label, eps=loss_eps)

            if train:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            total_loss += loss.item()
            preds.extend(probs.detach().cpu().view(-1).tolist())
            labels.extend(label.detach().cpu().view(-1).tolist())
            num_batches += 1

    if num_batches == 0:
        raise ValueError("run_model: loader yielded no batches")

    avg_loss = total_loss / num_batches
    fpr, tpr, _ = metrics.roc_curve(labels, preds)
    auc = metrics.auc(fpr, tpr)

    return avg_loss, auc, preds, labels

def evaluate(split, model_path, use_gpu, mps, data_dir, labels_csv):
    """Evaluate a saved MRNet3 ``state_dict`` on one split and print its loss and AUC.

    Args:
        split: ``'train'`` or ``'valid'``, which select that half of the stratified
            split that ``load_data3`` makes of ``data_dir``, or ``'test'``, which uses
            every labelled exam in ``data_dir`` via ``load_data_test``.
        model_path: path to a checkpoint saved with ``torch.save(model.state_dict(), ...)``.
        use_gpu, mps: device preferences passed to ``get_device``.
        data_dir, labels_csv: data directory and labels CSV.

    Returns:
        ``(preds, labels)`` as returned by ``run_model``.

    Raises:
        ValueError: for an unknown ``split``.
    """
    device = get_device(use_gpu, mps)
    print(f"Using device: {device}")

    if split == 'test':
        loader = load_data_test(device, data_dir, labels_csv)
    elif split in ('train', 'valid'):
        # Evaluation should see unmodified volumes, so augmentation stays off.
        train_loader, valid_loader = load_data3(device, data_dir, labels_csv, augment=False)
        loader = train_loader if split == 'train' else valid_loader
    else:
        raise ValueError("split must be 'train', 'valid', or 'test'")

    state_dict = torch.load(model_path, map_location=device)
    # Training with batch_size == 1 builds the heads without BatchNorm, so match
    # whatever the checkpoint contains to keep load_state_dict strict and valid.
    use_batchnorm = any(key.startswith('classifier1_axial.1.') for key in state_dict)
    model = MRNet3(use_batchnorm=use_batchnorm)
    model.load_state_dict(state_dict)
    model = model.to(device)

    loss, auc, preds, labels = run_model(model, loader, train=False)
    print(f'{split} loss: {loss:.4f}')
    print(f'{split} AUC: {auc:.4f}')
    return preds, labels


In [22]:
import argparse
import json
import numpy as np
import os
import torch
from datetime import datetime
from pathlib import Path
from sklearn import metrics

def get_device(use_gpu, use_mps):
    """Return the first requested-and-available accelerator (CUDA, then MPS), else the CPU.

    Redefined in this cell with the same behaviour as the earlier definition,
    so the training cell can run on its own.
    """
    candidates = (
        (use_gpu, torch.cuda.is_available, "cuda"),
        (use_mps, torch.backends.mps.is_available, "mps"),
    )
    for wanted, available, name in candidates:
        if wanted and available():
            return torch.device(name)
    return torch.device("cpu")

def train3(rundir, epochs, learning_rate, use_gpu, use_mps, data_dir, labels_csv,
           weight_decay, max_patience, batch_size, augment, epsilon, factor=0.3):
    """Train MRNet3 with Adam + ReduceLROnPlateau, checkpointing on validation-loss improvement.

    Every time the validation loss improves, the model's ``state_dict`` is saved to
    ``rundir/val{val_loss}_train{train_loss}_epoch{N}``. Per-epoch metrics are
    appended to ``rundir/metrics.txt``.

    Args:
        rundir: output directory (created if missing).
        epochs: number of training epochs.
        learning_rate, weight_decay: Adam hyperparameters.
        use_gpu, use_mps: device preferences passed to ``get_device``.
        data_dir, labels_csv: data directory and labels CSV for ``load_data3``.
        max_patience: ReduceLROnPlateau patience (epochs without improvement).
        batch_size: DataLoader batch size; BatchNorm is disabled when it is 1.
        augment: enable training-set augmentation.
        epsilon: label-smoothing factor for the training loss.
        factor: ReduceLROnPlateau multiplicative LR decay (default 0.3).

    Returns:
        The best validation loss seen.
    """
    device = get_device(use_gpu, use_mps)
    print(f"Using device: {device}")
    os.makedirs(rundir, exist_ok=True)
    train_loader, valid_loader = load_data3(device, data_dir, labels_csv, batch_size=batch_size, augment=augment)

    # BatchNorm1d cannot compute batch statistics from a single sample in
    # training mode, so it is only used for batch sizes above 1.
    use_batchnorm = batch_size > 1
    model = MRNet3(use_batchnorm=use_batchnorm)
    model = model.to(device)

    print(f"Using BatchNorm: {use_batchnorm}")

    optimizer = torch.optim.Adam(model.parameters(), learning_rate, weight_decay=weight_decay)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, patience=max_patience, factor=factor, threshold=1e-4)

    best_val_loss = float('inf')

    start_time = datetime.now()

    print(f"Value of eps:{epsilon}")
    for epoch in range(epochs):
        change = datetime.now() - start_time
        print('starting epoch {}. time passed: {}'.format(epoch + 1, str(change)))

        train_loss, train_auc, _, _ = run_model(model, train_loader, train=True, optimizer=optimizer, eps=epsilon)
        print(f'train loss: {train_loss:0.4f}')
        print(f'train AUC: {train_auc:0.4f}')

        val_loss, val_auc, _, _ = run_model(model, valid_loader, eps=0.0)
        print(f'valid loss: {val_loss:0.4f}')
        print(f'valid AUC: {val_auc:0.4f}')

        scheduler.step(val_loss)

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            file_name = f'val{val_loss:0.4f}_train{train_loss:0.4f}_epoch{epoch+1}'
            save_path = Path(rundir) / file_name
            torch.save(model.state_dict(), save_path)

        # Log metrics to file
        with open(os.path.join(rundir, 'metrics.txt'), 'a') as f:
            f.write(f"Epoch {epoch+1}: train_loss={train_loss:.4f}, val_loss={val_loss:.4f}, "
                    f"train_auc={train_auc:.4f}, val_auc={val_auc:.4f}\n")

    return best_val_loss


<h3>Training</h3>
Now we train the model using the best hyperparameters we found (set in the cell below). As in the original paper, each sample was randomly rotated by an angle between -25 and 25 degrees, randomly translated by up to 25 pixels, and flipped horizontally with probability 50%.

In [23]:
import random

rundir = "/Users/matteobruno/Desktop/runs"  # "directory/to/store/runs"
data_dir = "/Users/matteobruno/Desktop/MRNet-v1.0/train"  # "Directory/containing/.npy_files'"
labels_csv = "/Users/matteobruno/Desktop/MRNet-v1.0/train/train-acl.csv"  # "Path/to/labels/CSV/file"
seed = 42
gpu = False  # If true runs on Nvidia GPU
mps = True  # If true runs on Apple MPS
learning_rate = 1e-05
weight_decay = 0.01
epochs = 50
max_patience = 5
factor = 0.3  # NOTE(review): not passed to train3 below, which decays the LR by 0.3
batch_size = 1
eps = 0.0  # Label smoothing factor (0.0 = no smoothing)
augment = True  # Apply data augmentation during training

# Seed every RNG the pipeline touches: Python's `random` (used for the
# horizontal-flip decision), NumPy, and torch (manual_seed seeds all devices).
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if gpu and torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed)

os.makedirs(rundir, exist_ok=True)

# Save parameters to args.json. The file is closed before the long training
# run starts, so it is written even if training is interrupted.
params = {
    "rundir": rundir,
    "data_dir": data_dir,
    "labels_csv": labels_csv,
    "seed": seed,
    "gpu": gpu,
    "mps": mps,
    "learning_rate": learning_rate,
    "weight_decay": weight_decay,
    "epochs": epochs,
    "max_patience": max_patience,
    "batch_size": batch_size,
    "label_smoothing": eps,
    "augment": augment
}
with open(Path(rundir) / 'args.json', 'w') as out:
    json.dump(params, out, indent=4)

train3(rundir, epochs, learning_rate,
       gpu, mps, data_dir, labels_csv, weight_decay, max_patience, batch_size, augment, eps)

Using device: mps
Training samples: 904, Validation samples: 226
Dropout of 0.15
Using BatchNorm: False
Value of eps:0.0
starting epoch 1. time passed: 0:00:00.000013


Processing batches:   3%|▎         | 30/904 [00:07<03:26,  4.24it/s]


KeyboardInterrupt: 

In [None]:
model_path = "path/to/your/model.pth"  # Path to the saved model
split = "test"  # or "train", "valid"
data_dir = "/Users/matteobruno/Desktop/MRNet-v1.0/test"  # "Directory/containing/.npy_files'"
# NOTE(review): labels are read from the *train* CSV while data_dir points at the
# test split; confirm this is the intended label file for the chosen split.
labels_csv = "/Users/matteobruno/Desktop/MRNet-v1.0/train/train-acl.csv"  # "Path/to/labels/CSV/file"
gpu = False  # If true runs on Nvidia GPU
mps = True  # If true runs on Apple MPS


# Parameters of this evaluation run, collected for reference (not written to disk).
params = {
    "model_path": model_path,
    "split": split,
    "data_dir": data_dir,
    "labels_csv": labels_csv,
    "gpu": gpu,
    "mps": mps
}

preds, labels = evaluate(split, model_path, gpu, mps, data_dir, labels_csv)

usage: ipykernel_launcher.py [-h] --model_path MODEL_PATH --split SPLIT
                             --diagnosis DIAGNOSIS --data_dir DATA_DIR
                             --labels_csv LABELS_CSV [--gpu] [--mps]
ipykernel_launcher.py: error: the following arguments are required: --model_path, --split, --diagnosis, --data_dir, --labels_csv


SystemExit: 2

  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)
