ResNet101 model for classification of T4 and Non-T4 laryngeal cancer

In [None]:
!pip install torch torchvision monai torchio pandas SimpleITK scikit-learn matplotlib

In [None]:
!pip install monai==1.2.0 --quiet

In [None]:
import os, numpy as np, pandas as pd, torch, torch.nn as nn
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from sklearn.metrics import classification_report
from sklearn.model_selection import StratifiedKFold
from collections import Counter
import torchio as tio
import SimpleITK as sitk
from scipy.ndimage import zoom
from monai.networks.nets import ResNet
import warnings

warnings.filterwarnings("ignore", category=UserWarning, module="torchio.data.image")

# ----------------------------- CONFIG ----------------------------- #
# All inputs and outputs live under the absolute /data root. IMAGE_DIR used
# to be the relative "data/cropped_nrrds", which resolved against the
# notebook's working directory and disagreed with the evaluation cell's
# "/data/cropped_nrrds".
IMAGE_DIR = "/data/cropped_nrrds"
LABEL_FILE = "/data/annotations/LaryngealCT_metadata.xlsx"
TARGET_SHAPE = (32, 96, 96)  # (D, H, W) every volume is resampled to
BATCH_SIZE = 4
EPOCHS = 100                 # upper bound on epochs per fold
PATIENCE = 10                # epochs without an F1 improvement before stopping
LR = 1e-4                    # Adam learning rate
NUM_FOLDS = 5                # stratified cross-validation folds
OUTPUT_BASE = "/data/resnet101_cv_results"
# Created eagerly so per-fold sub-directories can be added beneath it.
os.makedirs(OUTPUT_BASE, exist_ok=True)

# ------------------------- FOCAL LOSS ----------------------------- #
class FocalLoss(nn.Module):
    """Multi-class focal loss built on top of cross-entropy.

    For every sample the cross-entropy ``ce`` is scaled by
    ``(1 - pt) ** gamma``, where ``pt = exp(-ce)`` is the probability the
    model assigns to the true class, so confidently correct samples
    contribute less to the total.

    Args:
        alpha: Optional 1-D tensor of per-class weights, indexed by the
            target class id. ``None`` disables class weighting.
        gamma: Focusing exponent; ``0`` reduces the loss to (optionally
            weighted) cross-entropy.
        reduction: ``'mean'``, ``'sum'`` or ``'none'`` (per-sample losses).

    Raises:
        ValueError: If ``reduction`` is not one of the supported modes.
    """

    _REDUCTIONS = ('mean', 'sum', 'none')

    def __init__(self, alpha=None, gamma=2.0, reduction='mean'):
        super().__init__()
        # Previously any value other than 'mean' (including 'none') silently
        # fell through to a sum; reject unknown modes up front instead.
        if reduction not in self._REDUCTIONS:
            raise ValueError(
                f"reduction must be one of {self._REDUCTIONS}, got {reduction!r}"
            )
        self.alpha = alpha.float() if alpha is not None else None
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs, targets):
        """Return the focal loss for ``inputs`` (logits) against ``targets``.

        Args:
            inputs: Raw class scores, as accepted by
                ``torch.nn.functional.cross_entropy``.
            targets: Integer class indices, one per sample.

        Returns:
            A scalar tensor for ``'mean'``/``'sum'``, or the unreduced
            per-sample loss tensor for ``'none'``.
        """
        ce_loss = nn.functional.cross_entropy(inputs, targets, reduction='none')
        pt = torch.exp(-ce_loss)
        modulator = (1 - pt) ** self.gamma
        if self.alpha is not None:
            # Look up each sample's class weight; move the table to the
            # targets' device so a CPU-built alpha still works on GPU batches.
            modulator = self.alpha.to(targets.device)[targets] * modulator
        loss = modulator * ce_loss

        if self.reduction == 'mean':
            return loss.mean()
        if self.reduction == 'sum':
            return loss.sum()
        return loss

# ------------------------ Dataset with Aug ------------------------ #
class CTDataset(Dataset):
    """CT volumes with binary labels, optionally augmented with TorchIO.

    Every item is read from ``IMAGE_DIR``, clipped to [-300, 300]
    (presumably Hounsfield units -- confirm against the data), z-score
    normalised per volume and linearly resampled to ``TARGET_SHAPE``.

    Args:
        df: Table with a ``Filename`` column (file name inside ``IMAGE_DIR``)
            and a ``Label`` column convertible to ``int``.
        augment: When true, the random augmentation pipeline is applied to
            each item on access.
    """

    def __init__(self, df, augment=False):
        self.df = df.reset_index(drop=True)
        self.augment = augment
        self.transform = tio.Compose([
            tio.RandomAffine(scales=(0.9, 1.1), degrees=10, translation=5, p=0.7),
            tio.RandomElasticDeformation(num_control_points=5, max_displacement=3.0, p=0.5),
            tio.RandomGamma(log_gamma=(-0.3, 0.3), p=0.5),
            tio.RandomNoise(p=0.3),
            # The tensor handed to TorchIO is laid out (C, D, H, W) and has no
            # affine, so the anatomical label 'LR' would resolve to spatial
            # axis 0 -- the slice axis here -- and flip the volume upside down.
            # Flip by index on the last spatial axis (W, the x axis of the
            # SimpleITK array) instead; assumes standard axial orientation.
            tio.RandomFlip(axes=(2,), p=0.5)
        ])

    def __len__(self):
        """Return the number of studies in the table."""
        return len(self.df)

    def __getitem__(self, idx):
        """Load, preprocess and (optionally) augment the ``idx``-th volume.

        Returns:
            ``(image, label)`` where ``image`` is a float32 tensor with a
            leading channel axis and ``label`` is a 0-dim integer tensor.
        """
        row = self.df.iloc[idx]
        path = os.path.join(IMAGE_DIR, row['Filename'])
        # SimpleITK returns the array indexed (z, y, x).
        img = sitk.GetArrayFromImage(sitk.ReadImage(path)).astype(np.float32)
        img = np.clip(img, -300, 300)
        # Epsilon guards against a constant volume (zero std).
        img = (img - img.mean()) / (img.std() + 1e-5)
        # Per-axis zoom factors that bring the volume to TARGET_SHAPE.
        img = zoom(img, [t / s for t, s in zip(TARGET_SHAPE, img.shape)], order=1)
        img = torch.tensor(img, dtype=torch.float32).unsqueeze(0)  # [1, D, H, W]

        if self.augment:
            subject = tio.Subject(image=tio.ScalarImage(tensor=img))
            img = self.transform(subject)['image'].data

        return img.clone().detach(), torch.tensor(int(row['Label']))

# ---------------------- ResNet101 Model ---------------------- #
class ResNet101_3D(nn.Module):
    """Single-channel 3D MONAI ``ResNet`` producing two-class logits.

    The backbone uses "basic" residual blocks with stage depths
    (3, 4, 23, 3) and stage widths 64/128/256/512.

    NOTE(review): the (3, 4, 23, 3) layout is the ResNet-101 stage layout,
    but the canonical ResNet-101 is built from bottleneck blocks -- confirm
    the "basic" choice is intended. Changing it would make previously saved
    weights unloadable.
    """

    def __init__(self):
        super().__init__()
        backbone_kwargs = dict(
            block="basic",
            layers=(3, 4, 23, 3),
            block_inplanes=[64, 128, 256, 512],
            spatial_dims=3,
            n_input_channels=1,
            num_classes=2,
        )
        # Attribute name `model` is part of the state-dict key prefix.
        self.model = ResNet(**backbone_kwargs)

    def forward(self, x):
        """Run the wrapped backbone on ``x`` and return its output."""
        logits = self.model(x)
        return logits

# -------------------------- Data Load ----------------------------- #
# Build the label table: one row per study with its volume file name and a
# binary target (Non_T4 -> 0, T4 -> 1).
df = pd.read_excel(LABEL_FILE)[['Study_ID', 'Label']].dropna()
df['Filename'] = df['Study_ID'].astype(str) + "_0000.nrrd"
# Reject label spellings outside the mapping explicitly; Series.map would
# otherwise turn them into NaN without any warning.
label_map = {'Non_T4': 0, 'T4': 1}
unknown_labels = df.loc[~df['Label'].isin(list(label_map)), 'Label'].unique().tolist()
if unknown_labels:
    raise ValueError(f"Unexpected values in 'Label' column: {unknown_labels}")
df['Label'] = df['Label'].map(label_map)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Stratified K-fold with a fixed seed so the same folds can be regenerated
# later for evaluation.
skf = StratifiedKFold(n_splits=NUM_FOLDS, shuffle=True, random_state=42)
for fold_id, (train_idx, val_idx) in enumerate(skf.split(df, df['Label']), 1):
    print(f"\n===== Fold {fold_id}/{NUM_FOLDS} =====")
    fold_output = os.path.join(OUTPUT_BASE, f"fold_{fold_id}")
    os.makedirs(fold_output, exist_ok=True)

    train_df, val_df = df.iloc[train_idx], df.iloc[val_idx]

    # Oversample the rarer class: each sample is drawn with probability
    # inversely proportional to its class frequency in this fold.
    class_counts = train_df['Label'].value_counts()
    weights = 1. / class_counts
    sample_weights = train_df['Label'].map(weights).values
    sampler = WeightedRandomSampler(sample_weights, len(sample_weights), replacement=True)

    train_loader = DataLoader(CTDataset(train_df, augment=True), batch_size=BATCH_SIZE, sampler=sampler)
    val_loader = DataLoader(CTDataset(val_df, augment=False), batch_size=BATCH_SIZE)

    # Fresh model, optimizer and loss for every fold.
    model = ResNet101_3D().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=LR)
    loss_fn = FocalLoss(alpha=torch.tensor([1.0, 4.0]).to(device), gamma=2.0)

    # Start below any attainable F1 so the first epoch always writes
    # best_model.pth. With the former initial value of 0 and a strict ">", a
    # fold whose F1 never left 0 produced no best model at all, and the later
    # evaluation step had nothing to load.
    best_f1, patience, metrics = -1.0, 0, []
    checkpoint_path = os.path.join(fold_output, "checkpoint.pth")

    for epoch in range(1, EPOCHS + 1):
        # ---- training pass ----
        model.train()
        total_loss = 0
        for x, y in train_loader:
            optimizer.zero_grad()
            out = model(x.to(device))
            loss = loss_fn(out, y.to(device))
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        train_loss = total_loss / len(train_loader)

        # ---- validation pass ----
        model.eval()
        val_loss, preds, trues = 0, [], []
        with torch.no_grad():
            for x, y in val_loader:
                out = model(x.to(device))
                val_loss += loss_fn(out, y.to(device)).item()
                preds.extend(torch.argmax(out, dim=1).cpu().tolist())
                trues.extend(y.tolist())
        val_loss /= len(val_loader)

        # Precision / recall / F1 are reported for the positive class (label 1).
        report = classification_report(trues, preds, output_dict=True, zero_division=0)
        row = {
            'epoch': epoch,
            'train_loss': train_loss,
            'val_loss': val_loss,
            'accuracy': report['accuracy'],
            'precision': report['1']['precision'],
            'recall': report['1']['recall'],
            'f1_score': report['1']['f1-score']
        }
        metrics.append(row)
        print(f"Epoch {epoch}: Train Loss={train_loss:.4f}, Val Loss={val_loss:.4f}, Acc={row['accuracy']:.2%}, Prec={row['precision']:.2f}, Rec={row['recall']:.2f}, F1={row['f1_score']:.2f}")

        # Keep the weights of the first epoch that reaches each new best F1;
        # otherwise count towards early stopping.
        if row['f1_score'] > best_f1:
            best_f1 = row['f1_score']
            torch.save(model.state_dict(), os.path.join(fold_output, "best_model.pth"))
            patience = 0
        else:
            patience += 1

        # Rolling checkpoint of the latest epoch, overwritten every epoch.
        torch.save({
            'epoch': epoch,
            'model': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'best_f1': best_f1
        }, checkpoint_path)

        if patience >= PATIENCE:
            print("⏹️ Early stopping triggered.")
            break

    # Per-epoch metrics for this fold.
    pd.DataFrame(metrics).to_csv(os.path.join(fold_output, "metrics.csv"), index=False)


In [None]:
import os
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import roc_auc_score, roc_curve, confusion_matrix, ConfusionMatrixDisplay
import matplotlib.pyplot as plt
import torchio as tio
import SimpleITK as sitk
from scipy.ndimage import zoom
from monai.networks.nets import ResNet

# CONFIG
# Filesystem locations: label spreadsheet, cropped volumes, and the results
# directory that holds one sub-directory per fold.
LABEL_FILE = "/data/annotations/LaryngealCT_metadata.xlsx"
IMAGE_DIR = "/data/cropped_nrrds"
OUTPUT_BASE = "/data/resnet101_cv_results"

# Cross-validation and loading parameters.
NUM_FOLDS = 5
BATCH_SIZE = 4
TARGET_SHAPE = (32, 96, 96)  # size every volume is resampled to

# Prefer the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Dataset
class CTDataset(Dataset):
    """Evaluation dataset: preprocessed CT volumes without augmentation.

    Args:
        df: Table with a ``Filename`` column (file name inside ``IMAGE_DIR``)
            and a ``Label`` column convertible to ``int``.
    """

    def __init__(self, df):
        # Positional access below relies on a clean 0..n-1 index.
        self.df = df.reset_index(drop=True)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(volume, label)`` for the ``idx``-th row.

        The volume is clipped to [-300, 300], z-score normalised, linearly
        resampled to ``TARGET_SHAPE`` and given a leading channel axis; the
        label is a plain ``int``.
        """
        record = self.df.iloc[idx]
        image = sitk.ReadImage(os.path.join(IMAGE_DIR, record['Filename']))
        voxels = sitk.GetArrayFromImage(image).astype(np.float32)
        voxels = np.clip(voxels, -300, 300)
        # Epsilon keeps the division finite for a constant volume.
        voxels = (voxels - voxels.mean()) / (voxels.std() + 1e-5)
        factors = [target / current for target, current in zip(TARGET_SHAPE, voxels.shape)]
        voxels = zoom(voxels, factors, order=1)
        volume = torch.tensor(voxels, dtype=torch.float32).unsqueeze(0)
        return volume, int(record['Label'])

# Model
class ResNet101_3D(nn.Module):
    """3D MONAI ``ResNet`` wrapper used to reload the trained fold weights.

    Built with "basic" blocks, stage depths (3, 4, 23, 3), stage widths
    64/128/256/512, one input channel and two output classes. The
    constructor arguments and the ``model`` attribute name determine the
    state-dict layout, so they must match whatever produced the saved weights.
    """

    def __init__(self):
        super().__init__()
        stage_depths = (3, 4, 23, 3)
        stage_widths = [64, 128, 256, 512]
        self.model = ResNet(
            block="basic",
            layers=stage_depths,
            block_inplanes=stage_widths,
            spatial_dims=3,
            n_input_channels=1,
            num_classes=2,
        )

    def forward(self, x):
        """Delegate to the wrapped backbone and return its output."""
        scores = self.model(x)
        return scores

# Load label file
# Rebuild the label table: volume file name plus binary target
# (Non_T4 -> 0, T4 -> 1).
df = pd.read_excel(LABEL_FILE)[['Study_ID', 'Label']].dropna()
df['Filename'] = df['Study_ID'].astype(str) + "_0000.nrrd"
# Reject label spellings outside the mapping explicitly; Series.map would
# otherwise turn them into NaN without any warning.
label_map = {'Non_T4': 0, 'T4': 1}
unknown_labels = df.loc[~df['Label'].isin(list(label_map)), 'Label'].unique().tolist()
if unknown_labels:
    raise ValueError(f"Unexpected values in 'Label' column: {unknown_labels}")
df['Label'] = df['Label'].map(label_map)

# K-Fold Evaluation
# n_splits / shuffle / random_state are fixed so the validation indices are
# regenerated deterministically from the same table.
from sklearn.model_selection import StratifiedKFold
skf = StratifiedKFold(n_splits=NUM_FOLDS, shuffle=True, random_state=42)
for fold_id, (_, val_idx) in enumerate(skf.split(df, df['Label']), 1):
    print(f"🔍 Evaluating Fold {fold_id}")
    val_df = df.iloc[val_idx]
    val_loader = DataLoader(CTDataset(val_df), batch_size=BATCH_SIZE)

    fold_dir = os.path.join(OUTPUT_BASE, f"fold_{fold_id}")
    model = ResNet101_3D().to(device)
    # map_location lets weights saved from a GPU run be loaded on a CPU-only
    # host (and vice versa); without it torch.load tries to restore tensors
    # onto the device they were saved from.
    state_dict = torch.load(os.path.join(fold_dir, "best_model.pth"), map_location=device)
    model.load_state_dict(state_dict)
    model.eval()

    # Collect positive-class probabilities, hard predictions and ground truth.
    probs, preds, trues = [], [], []
    with torch.no_grad():
        for x, y in val_loader:
            out = model(x.to(device))
            soft = torch.softmax(out, dim=1)
            probs.extend(soft[:, 1].cpu().numpy())
            preds.extend(torch.argmax(soft, dim=1).cpu().numpy())
            trues.extend(y.numpy())

    # AUC
    auc = roc_auc_score(trues, probs)
    with open(os.path.join(fold_dir, "auc.txt"), "w") as f:
        f.write(f"AUC: {auc:.4f}")

    # ROC Curve
    fpr, tpr, _ = roc_curve(trues, probs)
    plt.figure()
    plt.plot(fpr, tpr, label=f"AUC={auc:.2f}")
    plt.plot([0, 1], [0, 1], 'k--')  # chance diagonal
    plt.xlabel("FPR")
    plt.ylabel("TPR")
    plt.title(f"Fold {fold_id} ROC")
    plt.legend()
    plt.savefig(os.path.join(fold_dir, "roc.png"))
    plt.close()

    # Confusion Matrix
    cm = confusion_matrix(trues, preds)
    disp = ConfusionMatrixDisplay(cm, display_labels=['Non-T4', 'T4'])
    disp.plot(cmap='Blues')
    plt.title(f"Fold {fold_id} Confusion Matrix")
    plt.savefig(os.path.join(fold_dir, "confmat.png"))
    plt.close()


In [None]:
import os
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

base_dir = "/data/resnet101_cv_results"
num_folds = 5
summary_rows = []


def _read_fold_auc(auc_path):
    """Parse the ``AUC: <value>`` line written per fold; NaN if unavailable."""
    try:
        with open(auc_path, encoding="utf-8") as f:
            return float(f.read().split(":", 1)[1])
    except (OSError, IndexError, ValueError):
        # File missing or malformed: report "no AUC" rather than abort the
        # whole summary.
        return float("nan")


# Collect, for every fold that has a metrics file, the row of its best-F1 epoch.
for fold in range(1, num_folds + 1):
    fold_dir = os.path.join(base_dir, f"fold_{fold}")
    metrics_path = os.path.join(fold_dir, "metrics.csv")
    if not os.path.exists(metrics_path):
        continue
    df = pd.read_csv(metrics_path)
    if df.empty:
        # idxmax would raise on a metrics file with no epochs.
        continue
    # idxmax returns the first epoch that reached the maximum F1.
    best_epoch = df.loc[df["f1_score"].idxmax()]

    # metrics.csv has no "auc" column in the files written by the training
    # loop; the per-fold AUC is stored in auc.txt. Use the column if it ever
    # exists, otherwise fall back to that file. NaN (not None) keeps the
    # column numeric for plotting.
    auc = best_epoch.get("auc", None)
    if auc is None or pd.isna(auc):
        auc = _read_fold_auc(os.path.join(fold_dir, "auc.txt"))

    summary_rows.append({
        "Fold": fold,
        "Epoch": int(best_epoch["epoch"]),
        "Accuracy": best_epoch["accuracy"],
        "Precision": best_epoch["precision"],
        "Recall": best_epoch["recall"],
        "F1-score": best_epoch["f1_score"],
        "AUC": auc
    })

summary_df = pd.DataFrame(summary_rows)
display(summary_df)

if summary_df.empty:
    # melt() would fail on a frame without the expected columns.
    print(f"No fold metrics found under {base_dir}; nothing to plot.")
else:
    # Long format: one (Fold, metric, value) row per bar.
    melted = summary_df.melt(id_vars=["Fold"], value_vars=["Accuracy", "Precision", "Recall", "F1-score", "AUC"])
    plt.figure(figsize=(10, 6))
    sns.barplot(data=melted, x="variable", y="value", hue="Fold")
    plt.title("ResNet101 Fold-wise Performance")
    plt.ylabel("Score")
    plt.ylim(0, 1)
    plt.legend(title="Fold")
    plt.tight_layout()
    plt.show()
