# Preprocessing

## 1. CONFIG

In [1]:
import os
import numpy as np
import nibabel as nib
import matplotlib.pyplot as plt
from skimage.transform import resize
from sklearn.model_selection import train_test_split
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import segmentation_models_pytorch as smp
import albumentations as A
from albumentations.pytorch import ToTensorV2
import warnings
warnings.filterwarnings("ignore")

# --- Paths ---------------------------------------------------------------
BASE_PATH = 'archive'                                   # dataset root containing the volume_pt* folders
SEG_FOLDER = os.path.join(BASE_PATH, 'segmentations')   # segmentation files are looked up in this folder

# --- Preprocessing -------------------------------------------------------
IMAGE_SIZE = (224, 224)          # target size passed to skimage's resize for every slice
MAX_HEALTHY_PER_PATIENT = 200    # upper bound used when collecting mask-free (label 0) slices

# --- Reproducibility -----------------------------------------------------
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

# --- Device --------------------------------------------------------------
# Selected and reported exactly once (the cell used to do this twice).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")


  from .autonotebook import tqdm as notebook_tqdm


Using device: cuda
Using device: cuda


## 2. UTILS

In [2]:
def load_nifti(path):
    """Load the NIfTI file at `path` with nibabel and return its data array (via `get_fdata`)."""
    image = nib.load(path)
    voxels = image.get_fdata()
    return voxels

def normalize(volume, clip_min=-1000, clip_max=400):
    """Clip a volume to an intensity window and min-max scale it to [0, 1].

    Args:
        volume: numeric array of any shape.
        clip_min: lower bound of the clipping window (default -1000).
        clip_max: upper bound of the clipping window (default 400).
            NOTE(review): the defaults look like a CT Hounsfield-unit window —
            confirm against the dataset.

    Returns:
        float32 array of the same shape as `volume`.

    Note:
        Scaling uses the min/max of the *clipped data*, not the window bounds,
        so a volume that does not span the full window is still stretched to
        the whole [0, 1] range.
    """
    clipped = np.clip(volume, clip_min, clip_max)
    lowest = clipped.min()
    highest = clipped.max()
    # The 1e-8 keeps a constant volume from dividing by zero.
    scaled = (clipped - lowest) / (highest - lowest + 1e-8)
    return scaled.astype(np.float32)

def preprocess_volume_and_mask(vol_path, seg_path):
    """Extract resized (image, binary mask) pairs for every labelled slice of a volume.

    The volume is loaded and normalised, the segmentation is loaded, and both
    are walked along the third axis. Slices whose mask has no non-zero voxel
    are skipped.

    Args:
        vol_path: path to the volume NIfTI file.
        seg_path: path to the matching segmentation NIfTI file.

    Returns:
        list of `(img, msk)` tuples; `img` is the slice resized to IMAGE_SIZE,
        `msk` is a uint8 array of 0/1 with the same size.

    NOTE(review): any non-zero label marks a slice as positive. If the
    segmentation files use several label values (e.g. organ vs. lesion), every
    one of them is treated as foreground — confirm this is intended.
    """
    vol = normalize(load_nifti(vol_path))
    mask = load_nifti(seg_path)
    slices = []
    for z in range(vol.shape[2]):
        mask_slice = mask[:, :, z]
        if not np.any(mask_slice):  # no labelled voxel in this slice
            continue
        img = resize(vol[:, :, z], IMAGE_SIZE, preserve_range=True)
        # Binarise BEFORE resizing. Interpolating raw label values and then
        # thresholding at 0.5 makes labels > 1 bleed further across their
        # borders than label 1; for masks that are already 0/1 this is
        # identical to the previous behaviour.
        foreground = (mask_slice != 0).astype(mask.dtype)
        msk = resize(foreground, IMAGE_SIZE, preserve_range=True)
        msk = (msk > 0.5).astype(np.uint8)
        slices.append((img, msk))
    return slices

## 3. LOAD DATA WITH PATIENT TRACKING


In [3]:
print("Loading tumor slices with patient ID tracking...")
# os.listdir returns entries in arbitrary, platform-dependent order; sorting
# makes the slice order (and everything derived from it) reproducible.
# The isdir check skips stray files whose name happens to start with "volume_pt".
volume_folders = sorted(
    f for f in os.listdir(BASE_PATH)
    if f.startswith("volume_pt") and os.path.isdir(os.path.join(BASE_PATH, f))
)
all_slices_with_patient = []  # (img, mask, patient_id)

for patient_folder in volume_folders:
    # The folder name is used as the patient/group id for the later split.
    patient_id = patient_folder
    patient_path = os.path.join(BASE_PATH, patient_folder)
    for file in sorted(os.listdir(patient_path)):
        if not (file.startswith("volume") and file.endswith(".nii")):
            continue
        vol_path = os.path.join(patient_path, file)
        # volume-N.nii -> segmentation-N.nii inside SEG_FOLDER
        seg_path = os.path.join(SEG_FOLDER, file.replace("volume", "segmentation"))
        if not os.path.exists(seg_path):
            continue  # volumes without a segmentation file are skipped
        slices = preprocess_volume_and_mask(vol_path, seg_path)
        for img, mask in slices:
            all_slices_with_patient.append((img, mask, patient_id))
        print(f"  {patient_id}/{file}: {len(slices)} tumor slices")

print(f"\nTotal tumor slices: {len(all_slices_with_patient)} from {len(volume_folders)} patients")


Loading tumor slices with patient ID tracking...
  volume_pt1/volume-0.nii: 29 tumor slices
  volume_pt1/volume-1.nii: 29 tumor slices
  volume_pt1/volume-10.nii: 181 tumor slices
  volume_pt1/volume-2.nii: 139 tumor slices
  volume_pt1/volume-3.nii: 169 tumor slices
  volume_pt1/volume-4.nii: 250 tumor slices
  volume_pt1/volume-5.nii: 176 tumor slices
  volume_pt1/volume-6.nii: 186 tumor slices
  volume_pt1/volume-7.nii: 177 tumor slices
  volume_pt1/volume-8.nii: 179 tumor slices
  volume_pt1/volume-9.nii: 173 tumor slices
  volume_pt2/volume-11.nii: 167 tumor slices
  volume_pt2/volume-12.nii: 189 tumor slices
  volume_pt2/volume-13.nii: 142 tumor slices
  volume_pt2/volume-14.nii: 139 tumor slices
  volume_pt2/volume-15.nii: 133 tumor slices
  volume_pt2/volume-16.nii: 187 tumor slices
  volume_pt2/volume-17.nii: 198 tumor slices
  volume_pt2/volume-18.nii: 189 tumor slices
  volume_pt2/volume-19.nii: 188 tumor slices
  volume_pt2/volume-20.nii: 194 tumor slices
  volume_pt3/volum

## 4. BUILD CLASSIFICATION DATASET


In [4]:
# Classification samples as (img, label, patient_id).
# Every slice that has a mask → label 1
labeled_data = [(img, 1, pid) for img, _mask, pid in all_slices_with_patient]

# Slices without any mask voxel → label 0
# The cap is applied per volume file (the counter restarts for each file),
# so the message says "per volume" and uses the configured constant.
print(f"Collecting healthy slices (max {MAX_HEALTHY_PER_PATIENT} per volume)...")
healthy_per_patient = {}  # patient_id -> list of resized healthy slices
for patient_folder in volume_folders:
    patient_id = patient_folder
    patient_path = os.path.join(BASE_PATH, patient_folder)
    # Sorted so the collected slices come out in a reproducible order.
    for file in sorted(os.listdir(patient_path)):
        if not (file.startswith("volume") and file.endswith(".nii")):
            continue
        vol_path = os.path.join(patient_path, file)
        seg_path = os.path.join(SEG_FOLDER, file.replace("volume", "segmentation"))
        if not os.path.exists(seg_path):
            continue
        vol = normalize(load_nifti(vol_path))
        mask = load_nifti(seg_path)
        count = 0
        # Takes the FIRST mask-free slices along the third axis until the cap
        # is hit; this is not a random sample of the volume.
        for z in range(vol.shape[2]):
            if count >= MAX_HEALTHY_PER_PATIENT:
                break
            # Same "has a mask" test as the positive-slice extraction.
            if not np.any(mask[:, :, z]):
                img = resize(vol[:, :, z], IMAGE_SIZE, preserve_range=True)
                healthy_per_patient.setdefault(patient_id, []).append(img)
                count += 1

for pid, imgs in healthy_per_patient.items():
    for img in imgs:
        labeled_data.append((img, 0, pid))

print(f"Healthy slices added: {len(labeled_data) - len(all_slices_with_patient)}")
print(f"Total classification samples: {len(labeled_data)}")

Collecting healthy slices (max 200 per patient)...
Healthy slices added: 19920
Total classification samples: 39083


## 5. PATIENT-LEVEL SPLIT (ZERO LEAKAGE)


In [5]:
# Sort the unique ids: iterating a set of strings depends on Python's per-process
# hash seed, so without sorting the "fixed" random_state could still produce a
# different train/val split after every kernel restart.
all_patients = sorted({pid for _, _, pid in labeled_data})
train_patients, val_patients = train_test_split(all_patients, test_size=0.2, random_state=SEED)

# Sets give O(1) membership tests for the per-slice filters below.
train_patient_ids = set(train_patients)
val_patient_ids = set(val_patients)
assert train_patient_ids.isdisjoint(val_patient_ids), "patient id present in both train and val"

# Classification split
train_cls = [(img, label) for img, label, pid in labeled_data if pid in train_patient_ids]
val_cls   = [(img, label) for img, label, pid in labeled_data if pid in val_patient_ids]

# Segmentation split (only slices that carry a mask)
train_seg = [(img, mask) for img, mask, pid in all_slices_with_patient if pid in train_patient_ids]
val_seg   = [(img, mask) for img, mask, pid in all_slices_with_patient if pid in val_patient_ids]

print("\nSplit Summary:")
print(f"  Train patients: {len(train_patients)} → Cls: {len(train_cls)} | Seg: {len(train_seg)}")
print(f"  Val   patients: {len(val_patients)} → Cls: {len(val_cls)} | Seg: {len(val_seg)}")



Split Summary:
  Train patients: 5 → Cls: 25805 | Seg: 12583
  Val   patients: 2 → Cls: 13278 | Seg: 6580


## 6. DATASETS


In [6]:
class TumorClassificationDataset(Dataset):
    """Dataset of (image, label) pairs for binary slice classification.

    Args:
        data: indexable sequence of `(img, label)` where `img` is a 2-D array.
        transform: optional callable invoked as `transform(image=img)` that
            returns a mapping with an `"image"` entry (the calling convention
            of an albumentations `Compose`). When omitted, images are returned
            unchanged.

    Each item is `(x, y)`: `x` is a float32 tensor with a leading channel
    dimension added when the image is 2-D, `y` is a float32 scalar tensor.
    """

    def __init__(self, data, transform=None):
        self.data = data
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        img, label = self.data[idx]
        target = torch.tensor(label, dtype=torch.float32)
        if self.transform is None:
            return torch.tensor(img, dtype=torch.float32).unsqueeze(0), target

        # Hand the transform a float32 array; it may return an array or a tensor.
        augmented = self.transform(image=np.asarray(img, dtype=np.float32))["image"]
        x = torch.as_tensor(augmented, dtype=torch.float32)
        if x.dim() == 2:  # transform did not add a channel axis
            x = x.unsqueeze(0)
        return x, target

class SegmentationDataset(Dataset):
    """Dataset of (image, mask) pairs for binary segmentation.

    Args:
        data: indexable sequence of `(img, mask)` 2-D arrays.
        transform: optional callable invoked as `transform(image=img, mask=mask)`
            that returns a mapping with `"image"` and `"mask"` entries, so that
            geometric augmentations hit both consistently. When omitted, the
            pair is returned unchanged.

    Each item is `(x, m)`: both float32 tensors with a leading channel
    dimension added when the underlying array is 2-D.
    """

    def __init__(self, data, transform=None):
        self.data = data
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        img, mask = self.data[idx]
        if self.transform is None:
            x = torch.tensor(img, dtype=torch.float32).unsqueeze(0)
            m = torch.tensor(mask, dtype=torch.float32).unsqueeze(0)
            return x, m

        out = self.transform(image=np.asarray(img, dtype=np.float32), mask=np.asarray(mask))
        x = torch.as_tensor(out["image"], dtype=torch.float32)
        m = torch.as_tensor(out["mask"], dtype=torch.float32)
        if x.dim() == 2:  # transform did not add a channel axis
            x = x.unsqueeze(0)
        if m.dim() == 2:
            m = m.unsqueeze(0)
        return x, m

# Loaders: classification batches hold 32 samples, segmentation batches hold 8.
# Only the training loaders shuffle; all four load in the main process
# (num_workers=0) and request pinned memory.
_shared_loader_kwargs = {"num_workers": 0, "pin_memory": True}

cls_train_loader = DataLoader(
    TumorClassificationDataset(train_cls), batch_size=32, shuffle=True, **_shared_loader_kwargs
)
cls_val_loader = DataLoader(
    TumorClassificationDataset(val_cls), batch_size=32, shuffle=False, **_shared_loader_kwargs
)
seg_train_loader = DataLoader(
    SegmentationDataset(train_seg), batch_size=8, shuffle=True, **_shared_loader_kwargs
)
seg_val_loader = DataLoader(
    SegmentationDataset(val_seg), batch_size=8, shuffle=False, **_shared_loader_kwargs
)


# Modeling

## 1. CLASSIFICATION AUGMENTATION (Fixed: Add Normalize!)


In [7]:
# Augmentation pipelines for the classifier.
#
# A.Normalize computes (img - mean * max_pixel_value) / (std * max_pixel_value)
# and max_pixel_value defaults to 255. The slices produced by normalize() are
# floats in [0, 1], so the default would squash every image to roughly -1.
# max_pixel_value=1.0 gives the intended (img - 0.5) / 0.5 mapping to [-1, 1].
cls_train_transform = A.Compose([
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.3),
    A.Rotate(limit=20, p=0.5),
    A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.5),
    # NOTE(review): var_limit=(5, 30) looks chosen for 0-255 images; depending
    # on the installed albumentations version this may be far too strong for
    # [0, 1] float inputs — confirm before enabling this pipeline.
    A.GaussNoise(var_limit=(5, 30), p=0.3),
    A.ShiftScaleRotate(shift_limit=0.06, scale_limit=0.1, rotate_limit=25, p=0.6),
    A.CoarseDropout(max_holes=8, max_height=16, max_width=16, p=0.4),
    A.Normalize(mean=0.5, std=0.5, max_pixel_value=1.0),
    ToTensorV2(),
])

# Validation: no augmentation, only the same normalisation and tensor conversion.
cls_val_transform = A.Compose([
    A.Normalize(mean=0.5, std=0.5, max_pixel_value=1.0),
    ToTensorV2()
])


## 2. CLASSIFICATION MODEL (Fixed: logits output + BN)


In [8]:
class SimpleCNN(nn.Module):
    """Three-stage CNN that maps a 1-channel image to one raw logit per sample.

    Each stage is Conv(3x3, padding=1) -> BatchNorm -> ReLU -> MaxPool(2).
    The head is Linear(128*28*28 -> 256) -> ReLU -> Dropout(0.6) -> Linear(256 -> 1).

    The feature map is adaptively pooled to 28x28 before the head. For 224x224
    inputs the map is already 28x28, so the pooling is an identity there; for
    any other input size it keeps the fixed-size linear layer usable instead of
    raising a shape error. The pooling layer has no parameters, so the
    state_dict layout is unchanged.

    forward(x):
        x: tensor of shape (batch, 1, H, W).
        returns: tensor of shape (batch,) with raw logits (no sigmoid applied).
    """

    FEATURE_GRID = (28, 28)  # spatial size the linear head was dimensioned for

    def __init__(self):
        super().__init__()
        self.conv1 = self._stage(1, 32)
        self.conv2 = self._stage(32, 64)
        self.conv3 = self._stage(64, 128)
        self.pool = nn.AdaptiveAvgPool2d(self.FEATURE_GRID)
        self.fc = nn.Sequential(
            nn.Linear(128 * self.FEATURE_GRID[0] * self.FEATURE_GRID[1], 256),
            nn.ReLU(),
            nn.Dropout(0.6),
            nn.Linear(256, 1)  # raw logit; pair with BCEWithLogitsLoss
        )

    @staticmethod
    def _stage(in_channels, out_channels):
        """Build one Conv -> BatchNorm -> ReLU -> MaxPool(2) stage."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.pool(x)
        # flatten(1) keeps the batch axis and also copes with non-contiguous input.
        x = x.flatten(1)
        return self.fc(x).squeeze(1)

# Classifier setup: the network emits raw logits, so the loss is
# BCEWithLogitsLoss (sigmoid is applied inside the loss).
cls_model = SimpleCNN()
cls_model = cls_model.to(device)
cls_criterion = nn.BCEWithLogitsLoss()
cls_optimizer = torch.optim.AdamW(
    cls_model.parameters(),
    lr=1e-4,
    weight_decay=1e-4,
)

## 3. SEGMENTATION MODEL (Fixed: pretrained + no activation)


In [9]:
# U-Net configuration: ImageNet-pretrained ResNet-34 encoder, single input
# channel, single output class, raw logits out (activation=None) and SCSE
# attention in the decoder.
_unet_options = dict(
    encoder_name="resnet34",
    encoder_weights="imagenet",
    in_channels=1,
    classes=1,
    activation=None,
    decoder_attention_type="scse",
)
seg_model = smp.Unet(**_unet_options).to(device)

# Both loss terms consume logits directly.
dice_loss = smp.losses.DiceLoss(mode='binary', from_logits=True)
bce_loss = nn.BCEWithLogitsLoss()

def seg_criterion(pred, target, dice_weight=0.7, bce_weight=0.3):
    """Weighted sum of the Dice loss and the BCE-with-logits loss.

    Args:
        pred: raw logits from the segmentation model (both underlying losses
            are configured to take logits).
        target: ground-truth mask tensor matching `pred`.
        dice_weight: multiplier for the Dice term (default 0.7).
        bce_weight: multiplier for the BCE term (default 0.3).

    Returns:
        scalar loss tensor `dice_weight * dice + bce_weight * bce`.
    """
    dice_term = dice_loss(pred, target)
    bce_term = bce_loss(pred, target)
    return dice_weight * dice_term + bce_weight * bce_term

# AdamW over every parameter of the segmentation network.
seg_optimizer = torch.optim.AdamW(
    seg_model.parameters(),
    lr=1e-4,
    weight_decay=1e-5,
)

# small helper: Dice score used by the segmentation validation loop
def dice(pred, target, eps=1e-6, average='micro'):
    """
    Compute the Dice coefficient 2*|P∩T| / (|P| + |T|) between pred and target.

    pred: Tensor, binary (0/1) or probabilities (cast to float).
    target: Tensor, binary (0/1) (cast to float).
    eps: added to numerator and denominator, so two empty masks score 1.0
         instead of dividing by zero.
    average: 'none' returns the per-sample scores, shape [B].
             Any other value (including the default 'micro') returns the MEAN
             of the per-sample scores as a scalar tensor. Despite its name,
             'micro' does not pool intersections over the batch.

    Two input layouts are supported:
      * both tensors 2-D [H, W]: a single scalar score is returned and
        `average` is ignored;
      * otherwise the first dimension is treated as the batch and all
        remaining dimensions are flattened per sample.
    """
    pred = pred.float()
    target = target.float()

    # single mask [H, W]
    if pred.dim() == 2 and target.dim() == 2:
        inter = (pred * target).sum()
        denom = pred.sum() + target.sum()
        return (2.0 * inter + eps) / (denom + eps)

    # batch dimension first: [B, ...]
    # reshape (not view) so non-contiguous inputs, e.g. transposed or sliced
    # tensors, do not raise a RuntimeError.
    pred_flat = pred.reshape(pred.size(0), -1)
    target_flat = target.reshape(target.size(0), -1)
    inter = (pred_flat * target_flat).sum(dim=1)
    denom = pred_flat.sum(dim=1) + target_flat.sum(dim=1)
    scores = (2.0 * inter + eps) / (denom + eps)

    if average == 'none':
        return scores
    return scores.mean()

## 4. TRAINING FUNCTIONS (Fixed accuracy calculation)


In [10]:
def train_cls_epoch():
    """Train the classifier for one pass over `cls_train_loader`.

    Returns:
        (mean loss per batch, accuracy over all samples seen), where a sample
        counts as positive when sigmoid(logit) > 0.5.
    """
    cls_model.train()
    running_loss = 0
    n_correct = 0
    n_seen = 0

    for images, labels in cls_train_loader:
        images = images.to(device)
        labels = labels.to(device)

        logits = cls_model(images)
        batch_loss = cls_criterion(logits, labels)

        cls_optimizer.zero_grad()
        batch_loss.backward()
        cls_optimizer.step()

        running_loss += batch_loss.item()
        predicted = torch.sigmoid(logits) > 0.5
        n_correct += (predicted == labels).sum().item()
        n_seen += labels.size(0)

    mean_loss = running_loss / len(cls_train_loader)
    accuracy = n_correct / n_seen
    return mean_loss, accuracy

def val_cls_epoch():
    """Evaluate the classifier on `cls_val_loader` and return its accuracy.

    Runs in eval mode without gradient tracking; a sample is predicted
    positive when sigmoid(logit) > 0.5.
    """
    cls_model.eval()
    n_correct = 0
    n_seen = 0
    with torch.no_grad():
        for images, labels in cls_val_loader:
            images = images.to(device)
            labels = labels.to(device)
            probabilities = torch.sigmoid(cls_model(images))
            predicted = probabilities > 0.5
            n_correct += (predicted == labels).sum().item()
            n_seen += labels.size(0)
    accuracy = n_correct / n_seen
    return accuracy

def train_seg_epoch():
    """Train the segmentation model for one pass over `seg_train_loader`.

    Returns:
        mean loss per batch.
    """
    seg_model.train()
    running_loss = 0

    for images, masks in seg_train_loader:
        images = images.to(device)
        masks = masks.to(device)

        logits = seg_model(images)
        batch_loss = seg_criterion(logits, masks)

        seg_optimizer.zero_grad()
        batch_loss.backward()
        seg_optimizer.step()

        running_loss += batch_loss.item()

    mean_loss = running_loss / len(seg_train_loader)
    return mean_loss

def val_seg_epoch():
    """Evaluate the segmentation model on `seg_val_loader`.

    Returns:
        (mean loss per batch, mean of the per-batch Dice scores). Predictions
        are binarised at sigmoid(logit) > 0.5 before the Dice score is taken.
    """
    seg_model.eval()
    running_loss = 0
    batch_dice = []

    with torch.no_grad():
        for images, masks in seg_val_loader:
            images = images.to(device)
            masks = masks.to(device)

            logits = seg_model(images)
            running_loss += seg_criterion(logits, masks).item()

            # Hard 0/1 prediction; the target is cast to float for dice().
            binary_pred = (torch.sigmoid(logits) > 0.5).float()
            score = dice(binary_pred, masks.float(), average='micro')
            batch_dice.append(score.item())

    mean_loss = running_loss / len(seg_val_loader)
    return mean_loss, np.mean(batch_dice)

## 5. TRAIN LOOP + EARLY STOPPING + REAL DICE


In [11]:
# Bookkeeping for checkpointing and early stopping.
best_val_acc = 0.0             # highest classification validation accuracy so far
best_seg_loss = float('inf')   # lowest segmentation validation loss so far
patience = 15                  # epochs without any improvement before stopping
counter = 0                    # consecutive epochs in which neither model improved

print("\n" + "=" * 70)
print("FINAL TRAINING – Patient-Independent, No Leakage, Real Metrics")
print("=" * 70)

for epoch in range(1, 101):
    # One epoch of each task: classifier first, then segmentation.
    cls_train_loss, cls_train_acc = train_cls_epoch()
    cls_val_acc = val_cls_epoch()
    seg_train_loss = train_seg_epoch()
    seg_val_loss, seg_val_dice = val_seg_epoch()

    print(" | ".join([
        f"Epoch {epoch:2d}",
        f"Cls: {cls_train_acc:.4f} → {cls_val_acc:.4f}",
        f"Seg: {seg_train_loss:.4f} → {seg_val_loss:.4f}",
        f"Dice: {seg_val_dice:.4f}",
    ]))

    # Classifier checkpoint: keyed on validation accuracy.
    improved_cls = cls_val_acc > best_val_acc
    if improved_cls:
        best_val_acc = cls_val_acc
        cls_checkpoint = {
            'cls_model': cls_model.state_dict(),
            'cls_acc': cls_val_acc,
            'epoch': epoch
        }
        torch.save(cls_checkpoint, "best_cls_model.pth")
        print(f"  New best classification! Val Acc: {best_val_acc:.4f}")

    # Segmentation checkpoint: keyed on validation loss (Dice is stored alongside).
    improved_seg = seg_val_loss < best_seg_loss
    if improved_seg:
        best_seg_loss = seg_val_loss
        seg_checkpoint = {
            'seg_model': seg_model.state_dict(),
            'seg_loss': seg_val_loss,
            'seg_dice': seg_val_dice,
            'epoch': epoch
        }
        torch.save(seg_checkpoint, "best_seg_model.pth")
        print(f"  New best segmentation! Val Loss: {best_seg_loss:.4f} | Dice: {seg_val_dice:.4f}")

    # The two models share one patience counter: an improvement in either
    # resets it, and it only advances when both stalled in the same epoch.
    if improved_cls or improved_seg:
        counter = 0
    else:
        counter += 1
        print(f"  No improvement. Counter = {counter}/{patience}")

    if counter >= patience:
        print("Early stopping!")
        break

print("\nTraining complete!")
print(f"Best classification accuracy: {best_val_acc:.4f}")
print(f"Best segmentation loss: {best_seg_loss:.4f}")
print("Models saved as 'best_cls_model.pth' and 'best_seg_model.pth'")



FINAL TRAINING – Patient-Independent, No Leakage, Real Metrics
Epoch  1 | Cls: 0.9470 → 0.9099 | Seg: 0.1671 → 0.0886 | Dice: 0.8979
  New best classification! Val Acc: 0.9099
  New best segmentation! Val Loss: 0.0886 | Dice: 0.8979
Epoch  2 | Cls: 0.9781 → 0.9243 | Seg: 0.0314 → 0.0674 | Dice: 0.9092
  New best classification! Val Acc: 0.9243
  New best segmentation! Val Loss: 0.0674 | Dice: 0.9092
Epoch  3 | Cls: 0.9826 → 0.9354 | Seg: 0.0234 → 0.0659 | Dice: 0.9064
  New best classification! Val Acc: 0.9354
  New best segmentation! Val Loss: 0.0659 | Dice: 0.9064
Epoch  4 | Cls: 0.9859 → 0.9350 | Seg: 0.0198 → 0.0783 | Dice: 0.8917
  No improvement. Counter = 1/15
Epoch  5 | Cls: 0.9875 → 0.9302 | Seg: 0.0178 → 0.0695 | Dice: 0.9048
  No improvement. Counter = 2/15
Epoch  6 | Cls: 0.9892 → 0.9363 | Seg: 0.0181 → 0.0684 | Dice: 0.9055
  New best classification! Val Acc: 0.9363
Epoch  7 | Cls: 0.9906 → 0.9382 | Seg: 0.0151 → 0.0711 | Dice: 0.9031
  New best classification! Val Acc: 0