In [None]:
# install + mount
!pip -q install pydicom
from google.colab import drive
drive.mount("/content/drive")

# paths
import os
base = "/content/drive/MyDrive/STAT362 Final Project_RSNA"
zip_path = f"{base}/images.zip"
out_dir = "/content/train_images"
img_dir = f"{out_dir}/stage_2_train_images"

# unzip
!mkdir -p /content/train_images
!unzip -n -q "{zip_path}" -d "{out_dir}"

# imports
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import pydicom
from PIL import Image
import torchvision.transforms as transforms
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix, balanced_accuracy_score
from tqdm import tqdm
import copy
from torchvision.models import densenet121, DenseNet121_Weights

# apply_voi_lut import: try the `pydicom.pixels` location used by newer
# pydicom releases first, then fall back to the legacy
# `pydicom.pixel_data_handlers.util` location used by older releases.
# Only an import failure should trigger the fallback; any other error is real.
try:
    from pydicom.pixels import apply_voi_lut
except ImportError:
    from pydicom.pixel_data_handlers.util import apply_voi_lut

# Device selection: query CUDA availability once and derive everything from it.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda") if use_cuda else torch.device("cpu")
print("device:", device)
if use_cuda:
    print("gpu:", torch.cuda.get_device_name(0))

# Let cuDNN benchmark its algorithms (set regardless of the chosen device).
torch.backends.cudnn.benchmark = True

# Read the two label tables stored in the project folder.
detailed_class = pd.read_csv(f"{base}/stage_2_detailed_class_info.csv")
labels = pd.read_csv(f"{base}/stage_2_train_labels.csv")

# Keep one row per patient, then attach the integer class id ("target") and
# the DICOM file location ("path") derived from the patient id.
class_mapping = {"Normal": 0, "Lung Opacity": 1, "No Lung Opacity / Not Normal": 2}
detailed_class = detailed_class.drop_duplicates(subset=["patientId"]).assign(
    target=lambda frame: frame["class"].map(class_mapping),
    path=lambda frame: [f"{img_dir}/{pid}.dcm" for pid in frame["patientId"]],
)
print("total:", len(detailed_class))
print(detailed_class["target"].value_counts())

# dataset
class RSNADataset(Dataset):
    """Map-style dataset that loads DICOM files as 8-bit RGB PIL images.

    Args:
        dataframe: table with a ``path`` column (DICOM file path) and a
            ``target`` column (integer class id). The index is reset so rows
            can be addressed by position 0..len-1.
        transform: optional callable applied to the PIL image before it is
            returned.
    """

    def __init__(self, dataframe, transform=None):
        self.df = dataframe.reset_index(drop=True)
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for row ``idx``.

        ``image`` is the transformed image (or the RGB PIL image when no
        transform is set); ``label`` is a scalar ``torch.long`` tensor.
        """
        img_path = self.df.loc[idx, "path"]
        label = int(self.df.loc[idx, "target"])

        ds = pydicom.dcmread(img_path)
        raw = ds.pixel_array

        # Apply the VOI LUT / windowing to the stored pixel values *before*
        # any float cast: a LUT lookup indexes by pixel value, so running it
        # on a pre-cast float array may give incorrect results.
        try:
            img = apply_voi_lut(raw, ds)
        except Exception:
            # Deliberate best-effort: a malformed VOI LUT / window should not
            # abort the whole epoch, so fall back to the stored values.
            img = raw
        img = np.asarray(img, dtype=np.float32)

        # MONOCHROME1 stores inverted intensities; flip so that all images
        # share the same polarity.
        if getattr(ds, "PhotometricInterpretation", "") == "MONOCHROME1":
            img = img.max() - img

        # Robust rescale to [0, 1]: clip to the 1st/99th percentiles; the
        # epsilon avoids a division by zero on constant images.
        lo, hi = np.percentile(img, (1, 99))
        img = np.clip(img, lo, hi)
        img = (img - lo) / (hi - lo + 1e-6)

        # Quantise to 8 bits and replicate the grey channel to RGB.
        img = (img * 255.0).astype(np.uint8)
        image = Image.fromarray(img).convert("RGB")

        if self.transform:
            image = self.transform(image)

        return image, torch.tensor(label, dtype=torch.long)

# transforms
IMG_SIZE = 320

# Tail shared by both pipelines: tensor conversion followed by per-channel
# normalisation (the mean/std values conventionally paired with
# ImageNet-pretrained torchvision weights).
_tensor_and_normalize = [
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]

# Training: light random crop / flip / rotation augmentation.
_train_augmentations = [
    transforms.RandomResizedCrop(IMG_SIZE, scale=(0.85, 1.0), ratio=(0.9, 1.1)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(10),
]
train_transforms = transforms.Compose(_train_augmentations + _tensor_and_normalize)

# Evaluation: deterministic resize (10% larger than the crop) + centre crop.
_eval_geometry = [
    transforms.Resize(int(IMG_SIZE * 1.1)),
    transforms.CenterCrop(IMG_SIZE),
]
eval_transforms = transforms.Compose(_eval_geometry + _tensor_and_normalize)

# split: hold out 30% of the patients, then halve that hold-out into
# validation and test; every split is stratified on the class id.
def _stratified_split(frame, holdout_fraction):
    """Split ``frame`` with a fixed seed, stratifying on its ``target`` column."""
    return train_test_split(
        frame,
        test_size=holdout_fraction,
        stratify=frame["target"],
        random_state=42,
    )


train_df, temp_df = _stratified_split(detailed_class, 0.3)
val_df, test_df = _stratified_split(temp_df, 0.5)
print(f"Train: {len(train_df)}, Val: {len(val_df)}, Test: {len(test_df)}")

# loaders
BATCH_SIZE = 32
NUM_WORKERS = 2  # set to 0 if DataLoader workers crash


def _make_loader(dataset, shuffle):
    """Build a DataLoader with the batch/worker/pinning settings shared by all splits."""
    return DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=shuffle,
        num_workers=NUM_WORKERS,
        pin_memory=use_cuda,
    )


# Only the training split uses augmentation and shuffling.
train_dataset = RSNADataset(train_df, transform=train_transforms)
val_dataset = RSNADataset(val_df, transform=eval_transforms)
test_dataset = RSNADataset(test_df, transform=eval_transforms)

train_loader = _make_loader(train_dataset, shuffle=True)
val_loader = _make_loader(val_dataset, shuffle=False)
test_loader = _make_loader(test_dataset, shuffle=False)

# Smoke test: pull one batch and show its shapes.
x0, y0 = next(iter(train_loader))
print("batch:", x0.shape, y0.shape)

# model: DenseNet-121 with its default pretrained weights; the stock
# classifier is swapped for a small MLP head with 3 outputs (one per class).
model = densenet121(weights=DenseNet121_Weights.DEFAULT)
nf = model.classifier.in_features
_head_layers = [
    nn.Linear(nf, 512),
    nn.ReLU(),
    nn.Dropout(0.4),
    nn.Linear(512, 3),
]
model.classifier = nn.Sequential(*_head_layers)
model = model.to(device)

# loss: inverse-frequency class weights, rescaled so they sum to the number
# of classes, combined with light label smoothing.
class_counts = train_df["target"].value_counts().sort_index().values
_inverse_freq = 1.0 / torch.tensor(class_counts, dtype=torch.float32)
class_weights = (_inverse_freq / _inverse_freq.sum()) * len(class_counts)
criterion = nn.CrossEntropyLoss(
    weight=class_weights.to(device),
    label_smoothing=0.05,
)

# optimizer + scheduler: the scheduler halves the LR when the monitored
# metric (mode="max", i.e. higher is better) stops improving.
optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode="max",
    factor=0.5,
    patience=1,
)

# amp: a gradient scaler is only created when running on CUDA.
scaler = None
if use_cuda:
    scaler = torch.amp.GradScaler("cuda")

# mixup
def mixup(x, y, alpha=0.2):
    """Blend every sample in a batch with a randomly chosen partner.

    Args:
        x: batch of inputs; the first dimension is the batch dimension.
        y: labels aligned with ``x`` along the first dimension.
        alpha: parameter of the symmetric Beta(alpha, alpha) distribution the
            mixing weight is drawn from. ``alpha <= 0`` disables mixing.

    Returns:
        ``(mixed_x, y_a, y_b, lam)`` where ``mixed_x = lam * x + (1 - lam) *
        x[perm]``, ``y_a`` is ``y``, ``y_b`` is ``y[perm]`` and ``lam`` is the
        weight given to the original sample (``1.0`` when mixing is disabled).
    """
    if alpha <= 0:
        # Mixing disabled: hand the batch back untouched.
        return x, y, y, 1.0

    # One mixing weight for the whole batch, one random partner per sample.
    mix_weight = np.random.beta(alpha, alpha)
    partner = torch.randperm(x.size(0), device=x.device)

    blended = mix_weight * x + (1 - mix_weight) * x[partner]
    return blended, y, y[partner], mix_weight

# train
# Each epoch: one mixup-augmented pass over train_loader, then a clean pass
# over val_loader. Validation balanced accuracy drives the LR scheduler,
# best-checkpoint selection and early stopping.
num_epochs = 15
patience = 4          # epochs without val_bal improvement before stopping
best_val_bal = -1.0
best_state = None
no_improve = 0

for epoch in range(num_epochs):
    model.train()
    tr_loss = 0.0
    tr_total = 0
    tr_correct = 0.0

    loop = tqdm(train_loader, desc=f"train {epoch+1}/{num_epochs}")
    for x, y in loop:
        x = x.to(device, non_blocking=True)
        y = y.to(device, non_blocking=True)

        x, y_a, y_b, lam = mixup(x, y, alpha=0.2)

        optimizer.zero_grad(set_to_none=True)

        if use_cuda:
            # Mixed precision: forward/loss under autocast, scaled backward.
            with torch.amp.autocast("cuda"):
                out = model(x)
                loss = lam * criterion(out, y_a) + (1 - lam) * criterion(out, y_b)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            out = model(x)
            loss = lam * criterion(out, y_a) + (1 - lam) * criterion(out, y_b)
            loss.backward()
            optimizer.step()

        # Read the loss once (each .item() forces a host/device sync).
        batch_loss = float(loss.item())
        preds = out.argmax(1)

        tr_loss += batch_loss * x.size(0)
        tr_total += y.size(0)
        # The inputs are mixtures of two samples, so score predictions against
        # both mixup targets, weighted by lam. Comparing against the un-mixed
        # labels alone under-reports training accuracy.
        tr_correct += float(
            lam * (preds == y_a).sum().item()
            + (1 - lam) * (preds == y_b).sum().item()
        )
        loop.set_postfix(loss=batch_loss)

    tr_loss /= tr_total
    tr_acc = tr_correct / tr_total

    # Validation: no mixup, no gradients.
    model.eval()
    va_loss = 0.0
    va_total = 0
    va_preds = []
    va_labels = []

    with torch.no_grad():
        for x, y in val_loader:
            x = x.to(device, non_blocking=True)
            y = y.to(device, non_blocking=True)

            if use_cuda:
                with torch.amp.autocast("cuda"):
                    out = model(x)
                    loss = criterion(out, y)
            else:
                out = model(x)
                loss = criterion(out, y)

            va_loss += loss.item() * x.size(0)
            va_total += y.size(0)

            va_preds.append(out.argmax(1).cpu().numpy())
            va_labels.append(y.cpu().numpy())

    va_loss /= va_total
    va_preds = np.concatenate(va_preds)
    va_labels = np.concatenate(va_labels)

    va_acc = (va_preds == va_labels).mean()
    va_bal = balanced_accuracy_score(va_labels, va_preds)

    # The scheduler is stepped with the validation balanced accuracy.
    scheduler.step(va_bal)

    print(f"epoch {epoch+1} train_loss {tr_loss:.4f} train_acc {tr_acc:.4f} val_loss {va_loss:.4f} val_acc {va_acc:.4f} val_bal {va_bal:.4f}")

    if va_bal > best_val_bal:
        # New best: snapshot the weights (deep copy, so later optimizer steps
        # do not mutate the saved tensors).
        best_val_bal = va_bal
        best_state = copy.deepcopy(model.state_dict())
        no_improve = 0
    else:
        no_improve += 1

    if no_improve >= patience:
        print("early stop")
        break

# Restore the best-scoring checkpoint before the final test evaluation.
if best_state is not None:
    model.load_state_dict(best_state)
print("best val_bal:", best_val_bal)

# test: evaluate the current model weights on the held-out test split.
model.eval()


def _score_batch(inputs, targets):
    """Return ``(logits, loss)`` for one batch, under autocast when on CUDA."""
    if use_cuda:
        with torch.amp.autocast("cuda"):
            logits = model(inputs)
            return logits, criterion(logits, targets)
    logits = model(inputs)
    return logits, criterion(logits, targets)


te_loss = 0.0
te_total = 0
te_preds = []
te_labels = []

with torch.no_grad():
    for x, y in test_loader:
        x = x.to(device, non_blocking=True)
        y = y.to(device, non_blocking=True)

        out, loss = _score_batch(x, y)

        # Weight the batch-mean loss by batch size so the final figure is a
        # per-sample mean even when the last batch is smaller.
        te_loss += loss.item() * x.size(0)
        te_total += y.size(0)

        te_preds.append(out.argmax(1).cpu().numpy())
        te_labels.append(y.cpu().numpy())

te_loss /= te_total
te_preds = np.concatenate(te_preds)
te_labels = np.concatenate(te_labels)

# Summary metrics over the whole test split.
te_acc = (te_preds == te_labels).mean()
te_bal = balanced_accuracy_score(te_labels, te_preds)
cm = confusion_matrix(te_labels, te_preds)

print("test_loss:", round(float(te_loss), 4))
print("test_acc:", round(float(te_acc), 4))
print("test_balanced_acc:", round(float(te_bal), 4))
print("confusion_matrix:\n", cm)
print("\nclassification_report:\n", classification_report(te_labels, te_preds, digits=4))


[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/2.4 MB[0m [31m?[0m eta [36m-:--:--[0m[2K   [91m━━━━━━━━━━━━[0m[91m╸[0m[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.7/2.4 MB[0m [31m21.4 MB/s[0m eta [36m0:00:01[0m[2K   [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m [32m2.4/2.4 MB[0m [31m34.8 MB/s[0m eta [36m0:00:01[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.4/2.4 MB[0m [31m22.3 MB/s[0m eta [36m0:00:00[0m
[?25hMounted at /content/drive
device: cuda
gpu: Tesla T4
total: 26684
target
2    11821
0     8851
1     6012
Name: count, dtype: int64
Train: 18678, Val: 4003, Test: 4003
batch: torch.Size([32, 3, 320, 320]) torch.Size([32])
Downloading: "https://download.pytorch.org/models/densenet121-a639ec97.pth" to /root/.cache/torch/hub/checkpoints/densenet121-a639ec97.pth


100%|██████████| 30.8M/30.8M [00:00<00:00, 165MB/s]
train 1/15: 100%|██████████| 584/584 [10:09<00:00,  1.04s/it, loss=0.966]


epoch 1 train_loss 0.8282 train_acc 0.4942 val_loss 0.7073 val_acc 0.6605 val_bal 0.7000


train 2/15: 100%|██████████| 584/584 [09:17<00:00,  1.05it/s, loss=0.695]


epoch 2 train_loss 0.7791 train_acc 0.5271 val_loss 0.6782 val_acc 0.7110 val_bal 0.7275


train 3/15: 100%|██████████| 584/584 [09:19<00:00,  1.04it/s, loss=0.63]


epoch 3 train_loss 0.7587 train_acc 0.5259 val_loss 0.6950 val_acc 0.6920 val_bal 0.7218


train 4/15: 100%|██████████| 584/584 [09:14<00:00,  1.05it/s, loss=0.615]


epoch 4 train_loss 0.7461 train_acc 0.5314 val_loss 0.6966 val_acc 0.6787 val_bal 0.7161


train 5/15: 100%|██████████| 584/584 [09:17<00:00,  1.05it/s, loss=0.548]


epoch 5 train_loss 0.7131 train_acc 0.5377 val_loss 0.6757 val_acc 0.7292 val_bal 0.7310


train 6/15: 100%|██████████| 584/584 [09:14<00:00,  1.05it/s, loss=1.05]


epoch 6 train_loss 0.7029 train_acc 0.5645 val_loss 0.6668 val_acc 0.7332 val_bal 0.7324


train 7/15: 100%|██████████| 584/584 [09:16<00:00,  1.05it/s, loss=0.62]


epoch 7 train_loss 0.6888 train_acc 0.5661 val_loss 0.6710 val_acc 0.6970 val_bal 0.7222


train 8/15: 100%|██████████| 584/584 [09:14<00:00,  1.05it/s, loss=0.611]


epoch 8 train_loss 0.6874 train_acc 0.5529 val_loss 0.6778 val_acc 0.7280 val_bal 0.7337


train 9/15: 100%|██████████| 584/584 [09:13<00:00,  1.05it/s, loss=0.406]


epoch 9 train_loss 0.6719 train_acc 0.5520 val_loss 0.6753 val_acc 0.7277 val_bal 0.7349


train 10/15:  63%|██████▎   | 369/584 [05:52<03:10,  1.13it/s, loss=0.487]