3D ResNet-50 model for classifying laryngeal cancer CT scans as T4 vs. non-T4

In [None]:
!pip install torch torchvision monai torchio pandas SimpleITK scikit-learn matplotlib

In [None]:
!pip install torch==2.7.0 torchvision==0.22.0 torchaudio==2.7.0

In [None]:
import os, numpy as np, pandas as pd, torch, torch.nn as nn
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from sklearn.metrics import classification_report, roc_auc_score, roc_curve
from sklearn.model_selection import StratifiedKFold
import torchio as tio
import SimpleITK as sitk
from scipy.ndimage import zoom
import matplotlib.pyplot as plt
import torchvision



# CONFIG
# Input data: cropped CT volumes and the spreadsheet holding Study_ID / Label.
IMAGE_DIR = "/data/cropped_nrrds"
LABEL_FILE = "/data/annotations/LaryngealCT_metadata.xlsx"

# Every volume is resampled to this array shape before it reaches the network.
TARGET_SHAPE = (32, 96, 96)

# Optimisation and cross-validation settings.
BATCH_SIZE, EPOCHS, LR = 4, 100, 1e-4
PATIENCE = 10  # epochs without a validation-F1 improvement before early stopping
NUM_FOLDS = 5

# Per-fold checkpoints and metrics are written below this directory.
OUTPUT_DIR = "/data/resnet50_cv_augmented"
os.makedirs(OUTPUT_DIR, exist_ok=True)

# FOCAL LOSS
class FocalLoss(nn.Module):
    """Multi-class focal loss built on per-sample cross-entropy.

    For each sample i: loss_i = alpha[y_i] * (1 - p_t,i) ** gamma * CE_i,
    where CE_i is the cross-entropy and p_t,i = exp(-CE_i) is the predicted
    probability of the true class.

    Args:
        alpha: optional 1-D tensor of per-class weights, indexed by the target
            class. It is moved to the logits' device at call time.
        gamma: focusing exponent; ``gamma=0`` reduces to (alpha-weighted)
            cross-entropy.
        reduction: ``'mean'`` (default), ``'sum'`` or ``'none'`` (per-sample
            losses of shape (N,)).

    Raises:
        ValueError: if ``reduction`` is not one of the supported values.
    """

    _REDUCTIONS = ('mean', 'sum', 'none')

    def __init__(self, alpha=None, gamma=2.0, reduction='mean'):
        super().__init__()
        if reduction not in self._REDUCTIONS:
            # Previously any unknown value (including 'none') silently summed.
            raise ValueError(f"reduction must be one of {self._REDUCTIONS}, got {reduction!r}")
        self.alpha = alpha.float() if alpha is not None else None
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs, targets):
        """Compute the focal loss.

        Args:
            inputs: (N, C) unnormalised logits.
            targets: (N,) integer class indices.
        """
        ce_loss = nn.functional.cross_entropy(inputs, targets, reduction='none')
        pt = torch.exp(-ce_loss)
        loss = (1 - pt) ** self.gamma * ce_loss
        if self.alpha is not None:
            # Index on the loss's device so a CPU-held alpha also works with GPU logits.
            loss = self.alpha.to(loss.device)[targets] * loss
        if self.reduction == 'mean':
            return loss.mean()
        if self.reduction == 'sum':
            return loss.sum()
        return loss

# DATASET
class CTSubjectDataset(tio.SubjectsDataset):
    """In-memory torchio dataset of preprocessed CT volumes with integer labels.

    All volumes listed in ``df`` are read and preprocessed eagerly in
    ``__init__``. Items are ``(image, label)`` pairs, where ``image`` is a
    float tensor of shape ``(1, *TARGET_SHAPE)``.

    Args:
        df: DataFrame with a ``Filename`` column (path relative to
            ``IMAGE_DIR``) and an integer-castable ``Label`` column.
        augment: if True, random spatial and intensity augmentation is applied
            each time an item is fetched; otherwise items are returned as loaded.
    """

    def __init__(self, df, augment=False):
        self.df = df.reset_index(drop=True)
        subjects = []
        for _, row in self.df.iterrows():
            img = self._load_volume(os.path.join(IMAGE_DIR, row['Filename']))
            subjects.append(tio.Subject(
                image=tio.ScalarImage(tensor=torch.tensor(img), type=tio.INTENSITY),
                label=int(row['Label']),
            ))

        transform = None
        if augment:
            transform = tio.Compose([
                tio.RandomAffine(scales=(0.9, 1.1), degrees=10, translation=5, p=0.7),
                tio.RandomElasticDeformation(num_control_points=5, max_displacement=3.0, p=0.5),
                tio.RandomGamma(log_gamma=(-0.3, 0.3), p=0.5),
                tio.RandomNoise(p=0.3),
                # Left-right flip along spatial axis 2.
                # SimpleITK arrays are ordered (z, y, x) and the tensor is wrapped with
                # torchio's default identity affine, so the anatomical name 'LR' resolves
                # to spatial axis 0 and would flip the slice (z) order instead.
                # Axis 2 is the image x axis, i.e. left-right for axial scans (assumed).
                # flip_probability alone gives a 50% flip; an extra p=0.5 would gate it
                # a second time and flip only ~25% of samples.
                tio.RandomFlip(axes=(2,), flip_probability=0.5),
            ])

        super().__init__(subjects, transform=transform)

    @staticmethod
    def _load_volume(path):
        """Read one volume and return a float32 array of shape (1, *TARGET_SHAPE).

        Intensities are clipped to [-300, 300] (file units, presumably HU),
        z-scored per volume, then linearly resampled to ``TARGET_SHAPE``.
        """
        img = sitk.GetArrayFromImage(sitk.ReadImage(path)).astype(np.float32)
        img = np.clip(img, -300, 300)
        img = (img - img.mean()) / (img.std() + 1e-5)
        img = zoom(img, [t / s for t, s in zip(TARGET_SHAPE, img.shape)], order=1)
        return np.expand_dims(img, 0)  # add channel axis -> (1, D, H, W)

    def __getitem__(self, index):
        """Return ``(image, label)``: a float tensor (1, D, H, W) and a 0-dim integer tensor."""
        subject = super().__getitem__(index)
        image = subject['image'].data.float()
        label = torch.tensor(subject['label'])
        return image, label

from monai.networks.nets import resnet50

class ResNet50_3D(nn.Module):
    """3D ResNet-50 from MONAI for single-channel volumes and two classes.

    The backbone is created with ``pretrained=False`` (random initialisation),
    and its raw logits, of shape (N, 2), are returned unchanged.
    """

    def __init__(self):
        super().__init__()
        # Keep the attribute name ``model``: checkpoint keys saved from this
        # class are prefixed with "model.".
        backbone_config = dict(spatial_dims=3, n_input_channels=1, num_classes=2, pretrained=False)
        self.model = resnet50(**backbone_config)

    def forward(self, x):
        """Return class logits for a batch ``x`` (presumably (N, 1, D, H, W))."""
        logits = self.model(x)
        return logits



# DATASET
# One row per study: Study_ID, binary Label (Non_T4 -> 0, T4 -> 1) and the NRRD filename.
df = pd.read_excel(LABEL_FILE)[['Study_ID', 'Label']].dropna()
df['Filename'] = df['Study_ID'].astype(str) + "_0000.nrrd"
df['Label'] = df['Label'].map({'Non_T4': 0, 'T4': 1})
# Labels outside the mapping become NaN. Fail here and name the offending studies,
# instead of hitting an opaque int(NaN) error while the dataset is being built.
unmapped = df['Label'].isna()
if unmapped.any():
    raise ValueError(
        "Unrecognised Label values (expected 'Non_T4' or 'T4') for Study_IDs: "
        f"{df.loc[unmapped, 'Study_ID'].tolist()}"
    )
df['Label'] = df['Label'].astype(int)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# CV
skf = StratifiedKFold(n_splits=NUM_FOLDS, shuffle=True, random_state=42)
for fold, (train_idx, val_idx) in enumerate(skf.split(df, df['Label']), 1):
    print(f"\n===== Fold {fold}/{NUM_FOLDS} =====")
    fold_dir = os.path.join(OUTPUT_DIR, f"fold_{fold}")
    os.makedirs(fold_dir, exist_ok=True)
    # Seed per fold so weight init, sampling and augmentation are repeatable
    # (cuDNN kernels may still introduce some non-determinism on GPU).
    torch.manual_seed(42 + fold)

    train_df, val_df = df.iloc[train_idx], df.iloc[val_idx]
    # Record the held-out studies so this fold can be re-evaluated on exactly them.
    val_df[['Study_ID', 'Label']].to_csv(os.path.join(fold_dir, "val_split.csv"), index=False)

    # Inverse-frequency sample weights give class-balanced batches (sampling with replacement).
    class_weights = 1. / train_df['Label'].value_counts()
    sample_weights = train_df['Label'].map(class_weights).values
    sampler = WeightedRandomSampler(sample_weights, len(sample_weights), replacement=True)

    train_dataset = CTSubjectDataset(train_df, augment=True)
    val_dataset = CTSubjectDataset(val_df, augment=False)

    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, sampler=sampler)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE)

    model = ResNet50_3D().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=LR)
    # NOTE(review): the sampler already balances the classes. alpha=[1, 4] up-weights T4
    # on top of that, which biases predictions toward T4. Confirm this double correction
    # is intended.
    loss_fn = FocalLoss(alpha=torch.tensor([1.0, 4.0]).to(device), gamma=2.0)

    # Start below any attainable F1 (>= 0) so the first epoch always writes best_f1.pth.
    # Starting at 0, a fold whose F1 never rises above 0 would leave no checkpoint,
    # and the evaluation step would fail to load one.
    best_f1, patience_counter = -1.0, 0
    metrics = []

    for epoch in range(1, EPOCHS + 1):
        model.train()
        total_loss = 0.0
        for x, y in train_loader:
            optimizer.zero_grad()
            out = model(x.to(device))
            loss = loss_fn(out, y.to(device))
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        train_loss = total_loss / len(train_loader)

        model.eval()
        val_loss, preds, trues, probs = 0.0, [], [], []
        with torch.no_grad():
            for x, y in val_loader:
                out = model(x.to(device))
                val_loss += loss_fn(out, y.to(device)).item()
                probs.extend(torch.softmax(out, dim=1)[:, 1].cpu().tolist())  # P(T4)
                preds.extend(torch.argmax(out, dim=1).cpu().tolist())
                trues.extend(y.tolist())
        val_loss /= len(val_loader)

        report = classification_report(trues, preds, output_dict=True, zero_division=0)
        auc = roc_auc_score(trues, probs)
        row = {
            'epoch': epoch,
            'train_loss': train_loss,
            'val_loss': val_loss,
            'accuracy': report['accuracy'],
            'precision': report['1']['precision'],
            'recall': report['1']['recall'],
            'f1_score': report['1']['f1-score'],
            'auc': auc
        }
        metrics.append(row)
        print(f"Epoch {epoch}: Train Loss={train_loss:.4f}, Val Loss={val_loss:.4f}, Acc={row['accuracy']:.2%}, Prec={row['precision']:.2f}, Rec={row['recall']:.2f}, F1={row['f1_score']:.2f}, AUC={auc:.2f}")

        # Checkpoint on the best validation F1 for class 1 (T4). Stop after PATIENCE
        # epochs without improvement. NOTE: the checkpoint is selected on this same fold,
        # so metrics later reported on it are optimistically biased.
        if row['f1_score'] > best_f1:
            torch.save(model.state_dict(), os.path.join(fold_dir, "best_f1.pth"))
            best_f1 = row['f1_score']
            patience_counter = 0
        else:
            patience_counter += 1
        if patience_counter >= PATIENCE:
            print("⏹️ Early stopping triggered.")
            break

    pd.DataFrame(metrics).to_csv(os.path.join(fold_dir, "metrics.csv"), index=False)

In [None]:
import os, torch, pandas as pd, numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import classification_report, confusion_matrix, roc_curve, roc_auc_score, ConfusionMatrixDisplay
import torchio as tio
from torch.utils.data import DataLoader
import SimpleITK as sitk
from scipy.ndimage import zoom
from monai.networks.nets import resnet50
import torch.nn as nn

# Config (evaluation)
# Data locations and preprocessing shape; these must match the training cell.
IMAGE_DIR = "/data/cropped_nrrds"
LABEL_FILE = "/data/annotations/LaryngealCT_metadata.xlsx"
TARGET_SHAPE = (32, 96, 96)
BATCH_SIZE = 2

# Which cross-validation fold to evaluate (1-based). Change it and re-run for each fold.
FOLD = 1
OUTPUT_DIR = "/data/resnet50_cv_augmented/fold_" + str(FOLD)

# Dataset
class CTDataset(tio.SubjectsDataset):
    """Non-augmented dataset of preprocessed CT volumes for evaluation.

    Each file named in ``df['Filename']`` (relative to ``IMAGE_DIR``) is read
    eagerly. It is clipped to [-300, 300], z-scored per volume and linearly
    resampled to ``TARGET_SHAPE``. Items are ``(image, label)`` pairs, where
    ``image`` has shape (1, *TARGET_SHAPE).
    """

    def __init__(self, df):
        def load(filename):
            volume = sitk.GetArrayFromImage(
                sitk.ReadImage(os.path.join(IMAGE_DIR, filename))
            ).astype(np.float32)
            volume = np.clip(volume, -300, 300)
            volume = (volume - volume.mean()) / (volume.std() + 1e-5)
            factors = [t / s for t, s in zip(TARGET_SHAPE, volume.shape)]
            return zoom(volume, factors, order=1)[np.newaxis]  # prepend channel axis

        super().__init__([
            tio.Subject(image=tio.ScalarImage(tensor=torch.tensor(load(fname))), label=int(lbl))
            for fname, lbl in zip(df['Filename'], df['Label'])
        ])

    def __getitem__(self, index):
        """Return ``(image, label)`` as a float tensor and a 0-dim integer tensor."""
        subject = super().__getitem__(index)
        image = subject['image'].data.float()
        label = torch.tensor(subject['label'])
        return image, label

# Model
class ResNet50_3D(nn.Module):
    """3D ResNet-50 (MONAI): one input channel, two output logits, no pretraining.

    The attribute must be named ``model`` so the keys of a saved training
    checkpoint ("model.*") load into this class.
    """

    def __init__(self):
        super().__init__()
        net_kwargs = dict(spatial_dims=3, n_input_channels=1, num_classes=2, pretrained=False)
        self.model = resnet50(**net_kwargs)

    def forward(self, x):
        """Return raw class logits for the batch ``x``."""
        logits = self.model(x)
        return logits

# Prepare
df = pd.read_excel(LABEL_FILE)[['Study_ID', 'Label']].dropna()
df['Filename'] = df['Study_ID'].astype(str) + "_0000.nrrd"
df['Label'] = df['Label'].map({'Non_T4': 0, 'T4': 1})

metrics_df = pd.read_csv(os.path.join(OUTPUT_DIR, "metrics.csv"))
val_idx = metrics_df.index.tolist()
val_df = df.iloc[val_idx]

val_loader = DataLoader(CTDataset(val_df), batch_size=BATCH_SIZE)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = ResNet50_3D().to(device)
model.load_state_dict(torch.load(os.path.join(OUTPUT_DIR, "best_f1.pth")))
model.eval()

# Evaluation
probs, preds, trues = [], [], []
with torch.no_grad():
    for x, y in val_loader:
        out = model(x.to(device))
        prob = torch.softmax(out, dim=1)[:, 1]
        preds += torch.argmax(out, dim=1).cpu().tolist()
        probs += prob.cpu().tolist()
        trues += y.tolist()

# Save plots
fpr, tpr, _ = roc_curve(trues, probs)
auc = roc_auc_score(trues, probs)
cm = confusion_matrix(trues, preds)
report = classification_report(trues, preds, output_dict=True)

# ROC Curve
plt.figure()
plt.plot(fpr, tpr, label=f"AUC = {auc:.2f}")
plt.plot([0, 1], [0, 1], 'k--')
plt.title("ROC - ResNet50")
plt.xlabel("FPR"); plt.ylabel("TPR"); plt.grid(True); plt.legend()
plt.savefig(os.path.join(OUTPUT_DIR, "roc_test.png")); plt.close()

# Confusion Matrix
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=["Non-T4", "T4"])
disp.plot(cmap='Blues')
plt.title("Confusion Matrix - ResNet50")
plt.savefig(os.path.join(OUTPUT_DIR, "confmat_test.png")); plt.close()

print("✅ Evaluation completed and plots saved.")

# Re-run this cell with FOLD = 1, 2, 3, 4 and 5 to evaluate all five folds.