5-layer 3D CNN (4 convolutional + 1 fully connected) for classifying T4 vs. non-T4 laryngeal cancer on CT

In [None]:
!pip install torch torchvision monai torchio SimpleITK pandas scikit-learn matplotlib

In [None]:
import os, numpy as np, pandas as pd, SimpleITK as sitk
import torch, torch.nn as nn
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import classification_report, roc_curve, roc_auc_score, confusion_matrix, ConfusionMatrixDisplay
from scipy.ndimage import zoom
from collections import Counter
import torchio as tio
import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings("ignore")

# CONFIG
# Input locations: cropped CT volumes and the spreadsheet holding the labels.
IMAGE_DIR = "/data/cropped_nrrds"
LABEL_FILE = "/data/annotations/LaryngealCT_metadata.xlsx"

# Every volume is resampled to this shape, given in the (z, y, x) order of the
# numpy arrays produced by SimpleITK.
TARGET_SHAPE = (32, 96, 96)

# Cross-validation / optimisation hyperparameters.
NUM_FOLDS = 5
EPOCHS = 100
PATIENCE = 10  # epochs without validation-F1 improvement before stopping a fold
BATCH_SIZE = 4
LR = 1e-4

# Checkpoints, per-epoch metric CSVs and figures are written here (created if missing).
OUTPUT_DIR = "outputs_3dcnn_augmented"
os.makedirs(OUTPUT_DIR, exist_ok=True)

# CNN MODEL
class Deep3DCNN(nn.Module):
    """Small 3D CNN producing 2 class logits from a single-channel volume.

    Four stages of Conv3d(kernel 3, padding 1) -> BatchNorm3d -> ReLU with
    16, 32, 64 and 128 channels. The first three stages end in 2x max-pooling,
    the last in global average pooling. A single linear layer maps the pooled
    128 features to 2 logits.
    """

    def __init__(self):
        super().__init__()
        widths = [1, 16, 32, 64, 128]
        last_stage = len(widths) - 2
        stages = []
        for stage_idx, (c_in, c_out) in enumerate(zip(widths[:-1], widths[1:])):
            stages += [
                nn.Conv3d(c_in, c_out, 3, padding=1),
                nn.BatchNorm3d(c_out),
                nn.ReLU(),
                nn.AdaptiveAvgPool3d(1) if stage_idx == last_stage else nn.MaxPool3d(2),
            ]
        # Module order matches an explicit listing (conv.0 ... conv.15, fc.1),
        # so state_dict keys of saved checkpoints stay compatible.
        self.conv = nn.Sequential(*stages)
        self.fc = nn.Sequential(nn.Flatten(), nn.Linear(widths[-1], 2))

    def forward(self, x):
        features = self.conv(x)
        return self.fc(features)

# FOCAL LOSS
class FocalLoss(nn.Module):
    """Multi-class focal loss computed from raw logits.

    loss_i = alpha[y_i] * (1 - p_i) ** gamma * CE_i, where CE_i is the per-sample
    cross-entropy and p_i = exp(-CE_i) is the predicted probability of the true class.

    Args:
        alpha: optional per-class weights (tensor or sequence), indexed by the
            target class. ``None`` disables class weighting.
        gamma: focusing parameter; ``0`` reduces to (weighted) cross-entropy.
        reduction: ``'mean'``, ``'sum'`` or ``'none'`` (per-sample losses).

    Raises:
        ValueError: if ``reduction`` is not one of the supported values.
    """

    def __init__(self, alpha=None, gamma=2.0, reduction='mean'):
        super().__init__()
        if reduction not in ('mean', 'sum', 'none'):
            raise ValueError(f"reduction must be 'mean', 'sum' or 'none', got {reduction!r}")
        # as_tensor accepts a tensor (returned as-is) or a plain list of weights.
        self.alpha = torch.as_tensor(alpha).float() if alpha is not None else None
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs, targets):
        ce_loss = nn.functional.cross_entropy(inputs, targets, reduction='none')
        pt = torch.exp(-ce_loss)
        modulating = (1 - pt) ** self.gamma
        if self.alpha is not None:
            # Move the weights to the targets' device so indexing works on CPU and GPU.
            at = self.alpha.to(targets.device)[targets]
            loss = at * modulating * ce_loss
        else:
            loss = modulating * ce_loss
        if self.reduction == 'mean':
            return loss.mean()
        if self.reduction == 'sum':
            return loss.sum()
        return loss

# DATASET
class CTDataset(Dataset):
    """CT volumes with integer class labels, optionally augmented with torchio.

    Each row of ``df`` must provide ``Filename`` (relative to ``image_dir``) and
    ``Label`` (an integer class index). Each item is returned as
    ``(image, label)``:
    - ``image`` is a float32 tensor of shape ``(1, *target_shape)``. The volume
      is clipped to [-300, 300], z-scored per volume and linearly resampled.
    - ``label`` is an int64 scalar tensor.
    """

    def __init__(self, df, image_dir, target_shape, augment=False):
        self.df = df.reset_index(drop=True)
        self.image_dir = image_dir
        self.target_shape = target_shape
        self.augment = augment
        self.transform = tio.Compose([
            tio.RandomAffine(scales=(0.9, 1.1), degrees=10, translation=5, p=0.7),
            tio.RandomElasticDeformation(num_control_points=7, max_displacement=5.0, p=0.5),
            tio.RandomGamma(log_gamma=(-0.3, 0.3), p=0.5),
            tio.RandomNoise(p=0.3),
            # Flip spatial axis 2, the last array axis (x). sitk.GetArrayFromImage
            # yields (z, y, x), and the tio.Image built in __getitem__ has torchio's
            # default identity affine. With that affine the anatomical label 'LR'
            # resolves to axis 0 (z), which flips the neck head-to-foot instead of
            # left-right.
            # NOTE(review): assumes x is the patient's left-right axis, as in
            # standard axial CT; confirm for the cropped volumes.
            tio.RandomFlip(axes=(2,), p=0.5),
        ])

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        path = os.path.join(self.image_dir, row['Filename'])
        # GetArrayFromImage returns the voxel array in (z, y, x) order.
        image = sitk.GetArrayFromImage(sitk.ReadImage(path)).astype(np.float32)
        # Intensity window (presumably Hounsfield units), then per-volume z-score.
        image = np.clip(image, -300, 300)
        image = (image - image.mean()) / (image.std() + 1e-5)
        factors = [t / s for t, s in zip(self.target_shape, image.shape)]
        image = zoom(image, factors, order=1)
        image = np.expand_dims(image, 0)  # add channel axis -> (1, *target_shape)
        if self.augment:
            image = self.transform(tio.Image(tensor=image, type=tio.INTENSITY)).tensor
        # as_tensor handles both the numpy array and torchio's tensor output without
        # copy-constructing a tensor from a tensor.
        return torch.as_tensor(image, dtype=torch.float32), torch.tensor(int(row['Label']))

# MAIN TRAINING
# Load the label sheet and derive the image filename for each study.
df = pd.read_excel(LABEL_FILE)[['Study_ID', 'Label']].dropna()
# NOTE(review): if Study_ID is numeric and the sheet had blank IDs, pandas reads
# the column as float and filenames become e.g. "12.0_0000.nrrd"; confirm the IDs.
df['Filename'] = df['Study_ID'].astype(str) + "_0000.nrrd"

# Map text labels to class indices (1 = T4). Stray whitespace is tolerated.
# Unknown labels fail here with a clear message instead of surfacing later as NaN.
label_map = {'Non_T4': 0, 'T4': 1}
raw_labels = df['Label'].astype(str).str.strip()
df['Label'] = raw_labels.map(label_map)
unmapped = sorted(raw_labels[df['Label'].isna()].unique())
if unmapped:
    raise ValueError(
        f"Unexpected values in 'Label' column: {unmapped}; expected one of {sorted(label_map)}"
    )
df['Label'] = df['Label'].astype(int)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
skf = StratifiedKFold(n_splits=NUM_FOLDS, shuffle=True, random_state=42)

for fold, (train_idx, val_idx) in enumerate(skf.split(df, df['Label'])):
    print(f"\n🔁 Fold {fold+1}/{NUM_FOLDS}")
    train_df, val_df = df.iloc[train_idx], df.iloc[val_idx]
    ckpt_path = f"{OUTPUT_DIR}/fold_{fold+1}_best.pth"

    # Weighted sampling: each sample is weighted by 1 / (size of its class), so
    # both classes are drawn with equal expected frequency.
    class_counts = train_df['Label'].value_counts()
    class_weights = 1. / class_counts
    sample_weights = train_df['Label'].map(class_weights).values
    sampler = WeightedRandomSampler(weights=sample_weights, num_samples=len(sample_weights), replacement=True)

    train_loader = DataLoader(CTDataset(train_df, IMAGE_DIR, TARGET_SHAPE, augment=True), batch_size=BATCH_SIZE, sampler=sampler)
    val_loader = DataLoader(CTDataset(val_df, IMAGE_DIR, TARGET_SHAPE, augment=False), batch_size=BATCH_SIZE)

    model = Deep3DCNN().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=LR)
    # NOTE(review): class imbalance is compensated twice -- by the sampler above
    # and by alpha=[1, 4] here -- which may over-favour T4; confirm this is intended.
    loss_fn = FocalLoss(alpha=torch.tensor([1.0, 4.0]).to(device), gamma=2.0)
    # Start below the lowest possible F1 (0.0) so the first epoch always writes a
    # checkpoint. Starting at 0 meant a fold whose F1 never exceeded 0 saved
    # nothing; the torch.load below then failed, or silently loaded a stale file
    # from a previous run.
    best_f1, patience_counter = -1.0, 0
    metrics = []

    for epoch in range(1, EPOCHS + 1):
        # Training pass
        model.train()
        total_loss = 0
        for x, y in train_loader:
            optimizer.zero_grad()
            pred = model(x.to(device))
            loss = loss_fn(pred, y.to(device))
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        train_loss = total_loss / len(train_loader)

        # Validation
        model.eval()
        val_loss, preds, trues = 0, [], []
        with torch.no_grad():
            for x, y in val_loader:
                out = model(x.to(device))
                val_loss += loss_fn(out, y.to(device)).item()
                preds.extend(torch.argmax(out, dim=1).cpu().tolist())
                trues.extend(y.tolist())
        val_loss /= len(val_loader)

        # Precision / recall / F1 are reported for the positive class (1 = T4).
        report = classification_report(trues, preds, output_dict=True, zero_division=0)
        row = {
            'epoch': epoch,
            'train_loss': train_loss, 'val_loss': val_loss,
            'accuracy': report['accuracy'],
            'precision': report['1']['precision'],
            'recall': report['1']['recall'],
            'f1_score': report['1']['f1-score']
        }
        metrics.append(row)
        print(f"Epoch {epoch}: Train Loss={train_loss:.4f}, Val Loss={val_loss:.4f}, "
              f"Acc={row['accuracy']:.2%}, Prec={row['precision']:.2f}, "
              f"Rec={row['recall']:.2f}, F1={row['f1_score']:.2f}")

        # Early stopping on validation F1 of the T4 class
        if row['f1_score'] > best_f1:
            best_f1 = row['f1_score']
            patience_counter = 0
            torch.save(model.state_dict(), ckpt_path)
        else:
            patience_counter += 1
        if patience_counter >= PATIENCE:
            print("⏹️ Early stopping triggered.")
            break

    # Save per-epoch metrics
    pd.DataFrame(metrics).to_csv(f"{OUTPUT_DIR}/fold_{fold+1}_metrics.csv", index=False)

    # ROC + Confusion Matrix from the best checkpoint.
    # NOTE(review): this is the same fold that selected the checkpoint, so these
    # numbers are optimistically biased; a held-out test split would be unbiased.
    model.load_state_dict(torch.load(ckpt_path, map_location=device))
    model.eval()
    probs, y_true = [], []
    with torch.no_grad():
        for x, y in val_loader:
            out = model(x.to(device))
            probs.extend(torch.softmax(out, dim=1)[:, 1].cpu().tolist())
            y_true.extend(y.tolist())
    fpr, tpr, _ = roc_curve(y_true, probs)
    auc = roc_auc_score(y_true, probs)
    plt.figure(); plt.plot(fpr, tpr, label=f"AUC={auc:.2f}")
    plt.plot([0,1], [0,1], 'k--'); plt.legend()
    plt.title(f"Fold {fold+1} ROC"); plt.xlabel("FPR"); plt.ylabel("TPR"); plt.grid()
    plt.savefig(f"{OUTPUT_DIR}/fold_{fold+1}_roc.png"); plt.close()

    # A strict 0.5 threshold on P(T4) agrees with the argmax used during validation.
    y_pred = [1 if p > 0.5 else 0 for p in probs]
    cm = confusion_matrix(y_true, y_pred)
    disp = ConfusionMatrixDisplay(cm, display_labels=['Non-T4', 'T4'])
    disp.plot(cmap='Blues'); plt.title(f"Confusion Matrix Fold {fold+1}")
    plt.savefig(f"{OUTPUT_DIR}/fold_{fold+1}_confmat.png"); plt.close()
