3D ResNet-18 for classifying laryngeal cancer CT scans as T4 vs. non-T4 (stratified 5-fold cross-validation)

In [None]:
# Install required libraries (Run once)
!pip install monai torch torchvision pandas openpyxl scikit-learn torchio --quiet

In [None]:
!pip install monai==1.2.0 --quiet

In [None]:
# Imports
import os, numpy as np, pandas as pd, torch, torch.nn as nn
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import classification_report, roc_auc_score, roc_curve, confusion_matrix, ConfusionMatrixDisplay
import torchio as tio
import SimpleITK as sitk
from scipy.ndimage import zoom
import matplotlib.pyplot as plt
from monai.networks.nets import resnet18

# Config
IMAGE_DIR = "/data/cropped_nrrds"                            # per-study .nrrd volumes
LABEL_FILE = "/data/annotations/LaryngealCT_metadata.xlsx"  # spreadsheet with Study_ID / Label columns
OUTPUT_DIR = "/data/resnet18_cv_results"                     # per-fold checkpoints, metrics and plots
TARGET_SHAPE = (32, 96, 96)  # resample target, in sitk array order (z, y, x)
EPOCHS, PATIENCE, LR, BATCH_SIZE, FOLDS = 100, 10, 1e-4, 4, 5
SEED = 42

# The CV split is already seeded; also seed the torch RNG (model init,
# WeightedRandomSampler draws and, presumably, torchio's random transforms)
# and NumPy so repeated runs are comparable. cuDNN kernels may still add
# small run-to-run nondeterminism on GPU.
torch.manual_seed(SEED)
np.random.seed(SEED)

os.makedirs(OUTPUT_DIR, exist_ok=True)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Loss
class FocalLoss(nn.Module):
    """Focal loss over class logits, averaged over the batch.

    For each sample: ``alpha[y] * (1 - p) ** gamma * CE``, where ``CE`` is the
    per-sample cross-entropy and ``p = exp(-CE)`` is the predicted probability
    of the true class.

    Args:
        alpha: optional per-class weights, as a tensor or any sequence with
            one entry per class. ``None`` disables class weighting.
        gamma: focusing parameter; ``0`` reduces to (weighted) cross-entropy.
    """

    def __init__(self, alpha=None, gamma=2.0):
        super().__init__()
        # A non-persistent buffer follows ``module.to(device)`` but is not
        # written into ``state_dict``. ``as_tensor`` also accepts lists/tuples.
        self.register_buffer(
            'alpha',
            None if alpha is None else torch.as_tensor(alpha, dtype=torch.float32),
            persistent=False,
        )
        self.gamma = gamma

    def forward(self, inputs, targets):
        """Return the mean focal loss.

        Args:
            inputs: raw logits accepted by ``cross_entropy`` (e.g. ``(N, C)``).
            targets: integer class indices matching ``inputs``.
        """
        ce = nn.functional.cross_entropy(inputs, targets, reduction='none')
        pt = torch.exp(-ce)
        loss = (1 - pt) ** self.gamma * ce
        if self.alpha is not None:
            # Index on the logits' device in case the module itself was not moved.
            loss = self.alpha.to(inputs.device)[targets] * loss
        return loss.mean()

# Dataset
class CTSubjectDataset(tio.SubjectsDataset):
    """Preprocessed CT volumes wrapped as a torchio dataset.

    Every volume listed in ``df`` is read eagerly in ``__init__``, clipped,
    z-score normalised and linearly resampled to ``TARGET_SHAPE``. Items are
    returned as ``(image, label)``: ``image`` is a float tensor of shape
    ``(1, *TARGET_SHAPE)`` and ``label`` is a 0-dim integer tensor.

    Args:
        df: table with a ``Filename`` column (relative to ``IMAGE_DIR``) and an
            integer-convertible ``Label`` column.
        augment: if True, a random augmentation pipeline is applied on every
            ``__getitem__`` call.
    """

    def __init__(self, df, augment=False):
        subjects = []
        for _, row in df.iterrows():
            img = self._load_volume(row['Filename'])
            subjects.append(tio.Subject(
                image=tio.ScalarImage(tensor=torch.tensor(img)),
                label=int(row['Label']),
            ))
        super().__init__(subjects, transform=self._augmentation() if augment else None)

    @staticmethod
    def _load_volume(filename):
        """Read one NRRD from ``IMAGE_DIR`` and return a ``(1, *TARGET_SHAPE)`` float32 array."""
        path = os.path.join(IMAGE_DIR, filename)
        # sitk returns voxels in (z, y, x) order, the same order as TARGET_SHAPE.
        img = sitk.GetArrayFromImage(sitk.ReadImage(path)).astype(np.float32)
        # Clip intensities to [-300, 300] (presumably HU) before normalising.
        img = np.clip(img, -300, 300)
        img = (img - img.mean()) / (img.std() + 1e-5)
        img = zoom(img, [t / s for t, s in zip(TARGET_SHAPE, img.shape)], order=1)
        return np.expand_dims(img, 0)  # add the channel axis

    @staticmethod
    def _augmentation():
        """Build the random training-time augmentation pipeline."""
        return tio.Compose([
            tio.RandomAffine(scales=(0.9, 1.1), degrees=10, translation=5, p=0.7),
            tio.RandomElasticDeformation(num_control_points=5, max_displacement=3, p=0.5),
            tio.RandomGamma(log_gamma=(-0.3, 0.3), p=0.5),
            tio.RandomNoise(p=0.3),
            # The tensors carry torchio's default identity affine, so the
            # anatomical code 'LR' would resolve to spatial axis 0 -- the z
            # (slice) axis -- and flip the volume head-to-foot. Flip the last
            # array axis (x) instead, which is left-right for axial CT
            # (NOTE(review): confirm acquisition orientation for this dataset).
            tio.RandomFlip(axes=(2,), p=0.5),
        ])

    def __getitem__(self, i):
        subj = super().__getitem__(i)
        return subj['image'].data.float(), torch.tensor(subj['label'])

# Model
class ResNet18_3D(nn.Module):
    """Two-class MONAI 3D ResNet-18 for single-channel volumes, randomly initialised."""

    def __init__(self):
        super().__init__()
        # Attribute name ``model`` is part of the saved state_dict keys.
        self.model = resnet18(
            pretrained=False,
            spatial_dims=3,
            n_input_channels=1,
            num_classes=2,
        )

    def forward(self, x):
        """Return raw (pre-softmax) logits, one column per class."""
        logits = self.model(x)
        return logits

# Data
df = pd.read_excel(LABEL_FILE)[['Study_ID', 'Label']].dropna().copy()

# A numeric ID column that contained blank cells is read as float and stays
# float after dropna; turn 17.0 back into 17 so the filename is
# "17_0000.nrrd", not "17.0_0000.nrrd".
if pd.api.types.is_float_dtype(df['Study_ID']) and (df['Study_ID'] % 1 == 0).all():
    df['Study_ID'] = df['Study_ID'].astype(int)
df['Filename'] = df['Study_ID'].astype(str) + "_0000.nrrd"

# Fail early with a clear message instead of letting unmapped labels become
# NaN and crash later inside StratifiedKFold or int() in the dataset.
label_map = {'Non_T4': 0, 'T4': 1}
raw_labels = df['Label'].astype(str).str.strip()
unknown_labels = sorted(set(raw_labels) - set(label_map))
if unknown_labels:
    raise ValueError(
        f"Unexpected values in 'Label' column: {unknown_labels}; "
        f"expected one of {sorted(label_map)}"
    )
df['Label'] = raw_labels.map(label_map).astype(int)

# Cross-Validation Loop
def _train_one_epoch(model, loader, optimizer, loss_fn):
    """Run one optimisation pass over ``loader``; return the mean batch loss."""
    model.train()
    total_loss = 0.0
    for x, y in loader:
        optimizer.zero_grad()
        loss = loss_fn(model(x.to(device)), y.to(device))
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(loader)


def _evaluate(model, loader, loss_fn):
    """Score ``loader`` without gradients.

    Returns ``(mean batch loss, argmax predictions, P(class 1), true labels)``.
    """
    model.eval()
    total_loss, preds, probs, trues = 0.0, [], [], []
    with torch.no_grad():
        for x, y in loader:
            out = model(x.to(device))
            total_loss += loss_fn(out, y.to(device)).item()
            probs += torch.softmax(out, 1)[:, 1].cpu().tolist()
            preds += torch.argmax(out, 1).cpu().tolist()
            trues += y.tolist()
    return total_loss / len(loader), preds, probs, trues


def _save_fold_plots(fold, fold_dir, trues, probs, preds, auc):
    """Write the ROC curve and confusion matrix for one fold into ``fold_dir``."""
    fpr, tpr, _ = roc_curve(trues, probs)
    plt.figure()
    plt.plot(fpr, tpr, label=f"AUC={auc:.2f}")
    plt.plot([0, 1], [0, 1], 'k--')
    plt.xlabel("FPR"); plt.ylabel("TPR"); plt.title(f"Fold {fold} ROC"); plt.legend()
    plt.savefig(os.path.join(fold_dir, "roc.png"))
    plt.close()

    # labels=[0, 1] keeps the matrix 2x2 even if one class is never predicted.
    cm = confusion_matrix(trues, preds, labels=[0, 1])
    disp = ConfusionMatrixDisplay(cm, display_labels=['Non-T4', 'T4'])
    disp.plot(cmap='Blues')
    plt.title(f"Fold {fold} Confusion Matrix")
    plt.savefig(os.path.join(fold_dir, "confmat.png"))
    plt.close()


skf = StratifiedKFold(n_splits=FOLDS, shuffle=True, random_state=42)
for fold, (train_idx, val_idx) in enumerate(skf.split(df, df['Label']), 1):
    print(f"\n📂 Fold {fold}")
    fold_dir = os.path.join(OUTPUT_DIR, f"fold_{fold}")
    os.makedirs(fold_dir, exist_ok=True)
    train_df, val_df = df.iloc[train_idx], df.iloc[val_idx]

    # Inverse class-frequency sample weights: each class is drawn about equally often.
    weights = 1. / train_df['Label'].value_counts()
    sampler = WeightedRandomSampler(train_df['Label'].map(weights).values, len(train_df), replacement=True)

    train_loader = DataLoader(CTSubjectDataset(train_df, augment=True), batch_size=BATCH_SIZE, sampler=sampler)
    val_loader = DataLoader(CTSubjectDataset(val_df), batch_size=BATCH_SIZE)

    model = ResNet18_3D().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=LR)
    # NOTE(review): the sampler already balances the classes; the extra 4x
    # weight on class 1 (T4) compounds that correction -- confirm intended.
    loss_fn = FocalLoss(alpha=torch.tensor([1.0, 4.0]).to(device))
    # best_f1 starts below any reachable F1 so the first epoch always
    # checkpoints, even if the model never predicts T4.
    metrics, best_f1, patience_counter, best_eval = [], -1.0, 0, None

    for epoch in range(1, EPOCHS + 1):
        train_loss = _train_one_epoch(model, train_loader, optimizer, loss_fn)
        val_loss, preds, probs, trues = _evaluate(model, val_loader, loss_fn)

        # labels=[0, 1] guarantees report['1'] exists even when class 1 is
        # absent from both trues and preds.
        report = classification_report(trues, preds, labels=[0, 1], output_dict=True, zero_division=0)
        # roc_auc_score raises on a single-class validation set.
        auc = roc_auc_score(trues, probs) if len(set(trues)) > 1 else float('nan')
        row = {
            'epoch': epoch,
            'train_loss': train_loss,
            'val_loss': val_loss,
            'accuracy': report['accuracy'],
            'precision': report['1']['precision'],
            'recall': report['1']['recall'],
            'f1_score': report['1']['f1-score'],
            'auc': auc
        }
        metrics.append(row)
        print(f"📊 Epoch {epoch}: F1={row['f1_score']:.2f}, AUC={auc:.2f}")

        # Model selection uses the same validation fold that is reported, so
        # these per-fold numbers are optimistically biased.
        if row['f1_score'] > best_f1:
            best_f1 = row['f1_score']
            best_eval = (trues, probs, preds, auc)
            torch.save(model.state_dict(), os.path.join(fold_dir, "best_model.pth"))
            patience_counter = 0
        else:
            patience_counter += 1
        if patience_counter >= PATIENCE:
            print("⏹️ Early stopping"); break

    pd.DataFrame(metrics).to_csv(os.path.join(fold_dir, "metrics.csv"), index=False)

    # 🔍 ROC + Confusion Matrix for the checkpointed (best-F1) epoch, not the last one.
    if best_eval is not None:
        best_trues, best_probs, best_preds, best_auc = best_eval
        _save_fold_plots(fold, fold_dir, best_trues, best_probs, best_preds, best_auc)
