## Introduction
Here we train our tumor segmentation network. <br />

## Imports:

* Pathlib for easy path handling
* torch for tensor handling
* pytorch lightning for efficient and easy training implementation
* ModelCheckpoint and TensorBoardLogger for checkpoint saving and logging
* imgaug for Data Augmentation
* numpy for file loading and array ops
* matplotlib for visualizing some images
* tqdm for a progress bar when validating the model
* celluloid for easy video generation
* Our dataset and model

In [None]:
#!pip install pytorch_lightning celluloid numpy==1.26.4 scipy==1.11.4 imgaug==0.4.0
#!pip install lightning

In [None]:
from pathlib import Path

import torch
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
import imgaug.augmenters as iaa
import numpy as np
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
from celluloid import Camera

In [None]:
from pathlib import Path
import os
from google.colab import drive

# Mount Google Drive unless something is already mounted at that path.
DRIVE_MOUNT_POINT = '/content/drive'
already_mounted = os.path.ismount(DRIVE_MOUNT_POINT)
if not already_mounted:
    drive.mount(DRIVE_MOUNT_POINT)

# Folder that holds the notebook; it is created (with parents) when missing.
notebook_dir = Path("/content/drive/MyDrive/Colab Notebooks")  # <-- set to the folder containing your .ipynb
notebook_dir.mkdir(parents=True, exist_ok=True)

In [None]:
from pathlib import Path
import importlib.util, os

# Point to the folder that contains BOTH dataset.py and model.py.
# It also becomes the working directory, so relative paths used later in the
# notebook (e.g. "./ckpts" and "./logs") resolve inside this folder.
notebook_dir = Path("/content/drive/MyDrive/Colab Notebooks")
os.chdir(notebook_dir)

def import_from_file(fname: str, as_name: str, base_dir: "str | os.PathLike | None" = None):
    """Load a Python source file as a module by explicit path.

    Args:
        fname: File name of the module source, relative to ``base_dir``.
        as_name: Name given to the loaded module (also its ``sys.modules`` key).
        base_dir: Directory that contains ``fname``. When omitted, the global
            ``notebook_dir`` is used.

    Returns:
        The executed module object.

    Raises:
        FileNotFoundError: If the source file does not exist.
        ImportError: If no import spec/loader can be built for the file.
    """
    import sys

    root = Path(notebook_dir if base_dir is None else base_dir)
    p = root / fname
    # Explicit raise instead of `assert`, which is stripped under `python -O`.
    if not p.is_file():
        raise FileNotFoundError(f"Not found: {p}")

    spec = importlib.util.spec_from_file_location(as_name, str(p))
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot create an import spec for: {p}")

    mod = importlib.util.module_from_spec(spec)
    # Register before executing so that code relying on sys.modules[__name__]
    # (e.g. pickling of classes defined in the module) can find the module.
    sys.modules[as_name] = mod
    try:
        spec.loader.exec_module(mod)
    except BaseException:
        # Do not leave a half-initialised module behind.
        sys.modules.pop(as_name, None)
        raise
    return mod

# Load the two project files by explicit path rather than via sys.path.
dataset_local = import_from_file("dataset.py", "dataset_local")
model_local = import_from_file("model.py", "model_local")

# Bind the two classes the rest of the notebook works with.
LungDataset, UNet = dataset_local.LungDataset, model_local.UNet

print("Loaded:", LungDataset, UNet)

## Dataset Creation
Here we create the train and validation dataset. <br />
Additionally we define our data augmentation pipeline.
Subsequently the two dataloaders are created

In [None]:
# Augmentation pipeline passed to the training dataset.
seq = iaa.Sequential([
    iaa.Affine(
        # A (low, high) tuple is sampled per image. The previous `(0.15)` was a
        # plain float, i.e. the same fixed +15% shift applied to every image.
        translate_percent=(-0.15, 0.15),
        scale=(0.85, 1.15),  # zoom in or out
        rotate=(-45, 45),  # rotate by up to 45 degrees in either direction
    ),
    iaa.ElasticTransformation(),  # elastic deformation with imgaug's defaults
])

In [None]:
# Locations of the preprocessed train / val / test splits on Google Drive.
BASE = Path("/content/drive/MyDrive/CIS 5810/LA-NSCLC dataset/selected_150_split")
train_path = BASE / "Preprocessed_for_2D_Unet/train"
val_path = BASE / "Preprocessed_for_2D_Unet/val"
test_path = BASE / "Preprocessed_for_2D_Unet/test"

# Only the training split receives the augmentation pipeline; the two
# evaluation splits are created with `None` in its place.
train_dataset = LungDataset(train_path, seq)
val_dataset, test_dataset = (
    LungDataset(split_dir, None) for split_dir in (val_path, test_path)
)

print(f"There are {len(train_dataset)} train images, {len(val_dataset)} val images and {len(test_dataset)} test images")

## Oversampling to tackle strong class imbalance
Lung tumors are often very small, thus we need to make sure that our model does not learn a trivial solution which simply outputs 0 for all voxels.<br />
In this notebook we use oversampling to sample slices which contain a tumor more often.

To do so we can use the WeightedRandomSampler provided by pytorch which needs a weight for each sample in the dataset.
Typically we have one weight per class, so we need to calculate two weights (one for slices without a tumor and one for slices with a tumor) and build a list that assigns each sample in the dataset its corresponding weight.

To do so, we first need to create a list containing only the class labels:

In [None]:
# One binary label per training slice: 1 when the mask contains any non-zero
# pixel (tumor present), 0 otherwise.
# NOTE(review): train_dataset was built with the augmentation pipeline; if the
# dataset augments on access, these labels come from augmented masks - confirm.
target_list = [
    1 if np.any(label) else 0
    for _, label in tqdm(train_dataset)
]

Then we can calculate the weight for each class: we simply compute the ratio between the two class counts and then create the weight list.

In [None]:
# !pip uninstall -y opencv-python opencv-contrib-python opencv-python-headless thinc
# !pip install -U --force-reinstall --no-cache-dir numpy==1.26.4 scipy==1.11.4
# !pip install -U imgaug==0.4.0
# !pip install -U "opencv-python==4.8.1.78" "opencv-contrib-python==4.8.1.78" "opencv-python-headless==4.8.1.78"
# # (Only if you truly need thinc/spaCy here)
# # pip install "thinc<8.3"


In [None]:
# Distinct class labels together with how often each occurs:
# uniques[0] holds the labels, uniques[1] the matching counts.
uniques = np.unique(target_list, return_counts=True)
uniques

In [None]:
# Ratio of tumor-free slices to tumor slices; used as the oversampling weight
# for tumor slices. bincount(minlength=2) keeps both classes addressable even
# when one is absent, so a degenerate dataset fails with a clear message
# instead of an IndexError.
class_counts = np.bincount(target_list, minlength=2)
if class_counts[0] == 0 or class_counts[1] == 0:
    raise ValueError(
        "Need at least one slice with and one without a tumor to compute "
        f"class weights, got counts {class_counts.tolist()}"
    )
fraction = class_counts[0] / class_counts[1]
fraction

Subsequently we assign the weight 1 to each slice without a tumor and ~9 to each slice with a tumor

In [None]:
# Per-sample sampling weight: 1 for tumor-free slices, `fraction` for slices
# that contain a tumor.
weight_list = [1 if target == 0 else fraction for target in target_list]

In [None]:
# Show the first 50 weights as a quick sanity check.
weight_list[:50]

Finally we create the sampler which we can pass to the DataLoader. We only use a sampler for the train loader: the validation data must stay unchanged so that validation results remain representative.

In [None]:
# Weighted sampler over the training set: the first argument is the per-sample
# weight list, the second the number of samples to draw (here one per entry).
sampler = torch.utils.data.sampler.WeightedRandomSampler(weight_list, len(weight_list))

In [None]:
batch_size = 8
num_workers = 4

# Training batches are drawn through the weighted sampler; validation data is
# read without shuffling.
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    sampler=sampler,
    batch_size=batch_size,
    num_workers=num_workers,
)
val_loader = torch.utils.data.DataLoader(
    val_dataset,
    shuffle=False,
    batch_size=batch_size,
    num_workers=num_workers,
)

We can verify that our sampler works by taking a batch from the train loader and counting how many masks in it contain at least one tumor pixel.

In [None]:
# Pull a single batch from the training loader to inspect what the sampler yields.
verify_sampler = next(iter(train_loader))

In [None]:
# Per-sample flag: True where the mask of that sample has a positive sum,
# i.e. contains tumor pixels. With the class-balancing weights above, about
# half of the entries are expected to be True.
(verify_sampler[1][:,0]).sum([1, 2]) > 0  # ~ half the batch size

## Loss

We use Binary Cross Entropy, computed directly on the raw network outputs (logits) via `BCEWithLogitsLoss`.

## Full Segmentation Model

We now combine everything into the full pytorch lightning model

In [None]:
import torch
import torch.nn as nn
import lightning.pytorch as pl
import numpy as np
import matplotlib.pyplot as plt

class TumorSegmentation(pl.LightningModule):
    """Lightning module that trains a UNet for binary tumor segmentation.

    The wrapped network returns raw logits. ``BCEWithLogitsLoss`` is the
    objective; a sigmoid is applied only when predictions are rendered for
    the logger.
    """

    def __init__(self, lr: float = 1e-4):
        """Build the UNet, store the learning rate and create the loss.

        Args:
            lr: Learning rate handed to the Adam optimizer.
        """
        super().__init__()
        self.lr = lr
        self.model = UNet().float()  # keep the parameters in float32
        self.loss_fn = nn.BCEWithLogitsLoss()  # takes logits and float targets

    def forward(self, x):
        """Return the logits produced by the wrapped UNet for ``x``."""
        return self.model(x)

    def _run_step(self, batch, batch_idx, loss_name: str, split_name: str):
        """Shared body of the training and validation steps.

        Casts the batch to float32, computes the loss on the logits, logs it
        as an epoch-level metric under ``loss_name`` and, on every 50th batch
        when a logger is attached, logs an example figure tagged with
        ``split_name``.
        """
        scan, target = batch
        scan, target = scan.float(), target.float()

        logits = self(scan)
        loss = self.loss_fn(logits, target)
        self.log(loss_name, loss, on_step=False, on_epoch=True, prog_bar=True)

        wants_figure = batch_idx % 50 == 0
        if wants_figure and self.logger is not None:
            self.log_images(scan, logits, target, split_name)
        return loss

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss for one batch; returns the loss."""
        return self._run_step(batch, batch_idx, "train_loss", "Train")

    def validation_step(self, batch, batch_idx):
        """Compute and log the validation loss for one batch; returns the loss."""
        return self._run_step(batch, batch_idx, "val_loss", "Val")

    def log_images(self, ct, logits, mask, split_name: str):
        """Log a two-panel figure (ground truth vs. prediction) to the logger.

        Only the first sample and first channel of the batch are drawn. The
        prediction is the sigmoid of the logits binarized at 0.5.
        """
        def first_plane(tensor):
            # First sample, first channel, as a NumPy array on the CPU.
            return tensor[0, 0].detach().cpu().numpy()

        img = first_plane(ct)
        gt = first_plane(mask)
        pr = first_plane(torch.sigmoid(logits) > 0.5)

        fig, (gt_ax, pred_ax) = plt.subplots(1, 2, figsize=(8, 4))
        panels = (
            (gt_ax, gt, "Ground Truth", {}),
            (pred_ax, pr, "Prediction", {"cmap": "autumn"}),
        )
        for ax, overlay, title, overlay_kwargs in panels:
            ax.imshow(img, cmap="bone")
            # Mask out background so only the segmented area is overlaid.
            ax.imshow(np.ma.masked_where(overlay == 0, overlay), alpha=0.6, **overlay_kwargs)
            ax.set_title(title)
            ax.axis("off")

        self.logger.experiment.add_figure(f"{split_name} Prediction vs Label", fig, self.global_step)
        plt.close(fig)

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters using ``self.lr``."""
        return torch.optim.Adam(self.parameters(), lr=self.lr)

In [None]:
# Instantiate the model (uses the default learning rate of the constructor)
model = TumorSegmentation()

In [None]:
# Create the checkpoint callback.
# The monitored key must be a metric the model actually logs: validation_step
# logs "val_loss" (no "Val Dice" metric is logged anywhere), and for a loss
# lower is better, hence mode='min'.
checkpoint_callback = ModelCheckpoint(
    monitor='val_loss',
    save_top_k=30,
    mode='min')

In [None]:
# Create the trainer
# Adjust `accelerator` / `devices` to your hardware (e.g. accelerator="cpu" for CPU training)

import lightning.pytorch as pl
from lightning.pytorch.loggers import TensorBoardLogger
from lightning.pytorch.callbacks import ModelCheckpoint

# Keep only the single checkpoint with the lowest validation loss, in ./ckpts.
checkpoint_callback = ModelCheckpoint(
    dirpath="./ckpts",
    monitor="val_loss",
    mode="min",
    save_top_k=1,
)

# TensorBoard logs go to ./logs; metrics are flushed on every step.
tb_logger = TensorBoardLogger(save_dir="./logs")
trainer = pl.Trainer(
    accelerator="gpu",  # or accelerator="auto", devices="auto"
    devices=1,
    logger=tb_logger,
    log_every_n_steps=1,
    callbacks=[checkpoint_callback],
    max_epochs=30,
)

In [None]:
# Run training on train_loader, validating on val_loader.
trainer.fit(model, train_loader, val_loader)

## Evaluation:
Evaluate the results

In [None]:
class DiceScore(torch.nn.Module):
    """Compute the Dice score (not the loss) between a prediction and a mask.

    Both inputs are flattened, so the score is computed over all elements at
    once: ``2 * sum(pred * mask) / (sum(pred) + sum(mask))``.
    """

    def __init__(self):
        super().__init__()

    def forward(self, pred, mask):
        """Return the Dice score as a 0-dim tensor.

        Args:
            pred: Predicted segmentation (binary or probabilities).
            mask: Ground-truth segmentation with the same number of elements.

        Returns:
            The Dice score. When both inputs sum to zero (nothing predicted and
            nothing to find) the score is defined as 1 instead of the NaN that
            a plain 0/0 division would give.
        """
        # Flatten label and prediction tensors
        pred = torch.flatten(pred)
        mask = torch.flatten(mask)

        counter = (pred * mask).sum()  # overlap (numerator / 2)
        denum = pred.sum() + mask.sum()  # denominator
        dice = (2 * counter) / denum

        # Empty prediction and empty mask agree perfectly.
        return torch.where(denum == 0, torch.ones_like(dice), dice)

In [None]:
from pathlib import Path
import torch

def resolve_ckpt_path(ckpt_cb) -> Path:
    """Find a checkpoint file to load, trying progressively broader sources.

    Lookup order:
      1. ``ckpt_cb.best_model_path``
      2. ``ckpt_cb.last_model_path``
      3. the most recently modified ``*.ckpt`` below ``ckpt_cb.dirpath``
         (current directory when ``dirpath`` is missing or ``None``)
      4. the most recently modified ``*.ckpt`` below ``lightning_logs``

    Args:
        ckpt_cb: Checkpoint callback (any object exposing the attributes above).

    Returns:
        Path of the selected checkpoint file.

    Raises:
        FileNotFoundError: If no checkpoint file can be found anywhere.
    """
    # 1) / 2) Prefer the "best" checkpoint, then fall back to "last".
    for attr in ("best_model_path", "last_model_path"):
        p = getattr(ckpt_cb, attr, "") or ""
        if p and Path(p).is_file():
            return Path(p)

    # 3) Search the callback dirpath for any .ckpt (pick most recent).
    # `or "."` also covers an attribute that exists but is None, which would
    # otherwise make Path(...) raise a TypeError.
    root = Path(getattr(ckpt_cb, "dirpath", None) or ".")
    cands = list(root.rglob("*.ckpt")) if root.exists() else []
    if not cands:
        # 4) Also search Lightning's default logs/checkpoints tree.
        cands = list(Path("lightning_logs").rglob("*.ckpt"))
    if not cands:
        raise FileNotFoundError(
            "No checkpoint found. Ensure ModelCheckpoint has dirpath/monitor or use save_last=True."
        )
    return max(cands, key=lambda x: x.stat().st_mtime)

# Pick the checkpoint file and report which one is loaded.
ckpt_path = resolve_ckpt_path(checkpoint_callback)
print(f"Loading checkpoint: {ckpt_path}")

# Use the GPU when one is available, otherwise stay on the CPU.
device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)

# Pass strict=False to load_from_checkpoint if the model signature changed
# since the checkpoint was saved.
model = TumorSegmentation.load_from_checkpoint(ckpt_path, map_location=device_name)
model.eval()
model.to(device)

In [None]:
# Collect sigmoid probabilities and ground-truth masks for every validation slice.
preds = []
labels = []

# The loop variable is named `ct_slice` so the builtin `slice` is not shadowed
# for the rest of the notebook session.
with torch.no_grad():
    for ct_slice, label in tqdm(val_dataset):
        batch = torch.tensor(ct_slice).float().to(device).unsqueeze(0)  # add batch axis
        pred = torch.sigmoid(model(batch))
        preds.append(pred.cpu().numpy())
        labels.append(label)

preds = np.array(preds)
labels = np.array(labels)

Compute overall Dice Score on the Validation Set

In [None]:
# Dice over the whole validation set at once. DiceScore flattens both inputs,
# so only the element count has to match, not the exact shape.
val_pred_tensor = torch.from_numpy(preds)
val_label_tensor = torch.from_numpy(labels).unsqueeze(0).float()
dice_score = DiceScore()(val_pred_tensor, val_label_tensor)
print(f"The Val Dice Score is: {dice_score}")

In [None]:
# Collect sigmoid probabilities and ground-truth masks for every test slice.
preds = []
labels = []

# The loop variable is named `ct_slice` so the builtin `slice` is not shadowed
# for the rest of the notebook session.
with torch.no_grad():
    for ct_slice, label in tqdm(test_dataset):
        batch = torch.tensor(ct_slice).float().to(device).unsqueeze(0)  # add batch axis
        pred = torch.sigmoid(model(batch))
        preds.append(pred.cpu().numpy())
        labels.append(label)

preds = np.array(preds)
labels = np.array(labels)

Compute overall Dice Score on the Test Set

In [None]:
# Dice over the whole test set at once (preds/labels were just rebuilt from
# test_dataset). The message now names the test set instead of "Val".
dice_score = DiceScore()(torch.from_numpy(preds), torch.from_numpy(labels).unsqueeze(0).float())
print(f"The Test Dice Score is: {dice_score}")

## Visualization

In [None]:
# Cut-off used to binarize the model prediction in the visualization below.
THRESHOLD = 0.5

In [None]:
import nibabel as nib
import cv2

In [None]:
BASE = Path("/content/drive/MyDrive/CIS 5810/LA-NSCLC dataset/selected_150_split")

# One test scan, loaded as a float array.
subject = BASE / "test" / "image" / "Lung_054_0000.nii.gz"
volume = nib.load(subject).get_fdata()
# Standardize by the constant 3071, then crop away the first 30 slices along
# the last axis.
ct = (volume / 3071)[:, :, 30:]

In [None]:
segmentation = []  # one boolean mask per slice
# Only the image volume is loaded for this scan, so there is no ground truth
# to collect; the old `label.append(segmentation)` stored the prediction list
# itself on every iteration and has been dropped.
label = []
scan = []  # resized input slices, kept for plotting

with torch.no_grad():
    for i in range(ct.shape[-1]):
        # `ct_slice` instead of `slice` so the builtin is not shadowed.
        ct_slice = torch.tensor(cv2.resize(ct[:, :, i], (256, 256)))
        scan.append(ct_slice)
        batch = ct_slice.unsqueeze(0).unsqueeze(0).float().to(device)  # add batch + channel axes

        # The model returns logits; convert them to probabilities first so
        # THRESHOLD is applied on the [0, 1] scale.
        pred = torch.sigmoid(model(batch))[0][0].cpu() > THRESHOLD
        segmentation.append(pred)

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import animation
from IPython.display import HTML, display

def reorient_2d(a: np.ndarray) -> np.ndarray:
    """Return ``a`` with its first two axes exchanged.

    A vertical flip followed by a 90 degree clockwise rotation maps element
    ``[i, j]`` to ``[j, i]``, so the combined operation is an axis swap.
    """
    return np.swapaxes(a, 0, 1)

# Turn the per-slice lists into arrays; N is the number of slices both have.
scan = np.asarray(scan)
seg = np.asarray(segmentation)
N = min(len(scan), len(seg))

# Reorient every slice of the scan and of the segmentation in the same way.
scan_r = np.array(list(map(reorient_2d, scan)))
seg_r = np.array(list(map(reorient_2d, seg)))

# Initial frame: the scan in grayscale with the segmentation overlaid;
# background (zero) pixels of the segmentation are masked out.
fig, ax = plt.subplots(figsize=(5,5))
first_overlay = np.ma.masked_where(seg_r[0]==0, seg_r[0])
im = ax.imshow(scan_r[0], cmap="bone", interpolation="nearest")
ov = ax.imshow(first_overlay, cmap="autumn", alpha=0.5, interpolation="nearest")
ax.axis("off")

def update(i):
    """Show slice ``i``: swap in its scan image and its masked overlay.

    Returns the two image artists that were updated.
    """
    overlay = seg_r[i]
    im.set_data(scan_r[i])
    ov.set_data(np.ma.masked_where(overlay == 0, overlay))
    return im, ov

# Animate every second slice with 80 ms between frames and embed the result
# as an interactive HTML player.
frame_ids = range(0, N, 2)
anim = animation.FuncAnimation(fig, update, frames=frame_ids, interval=80, blit=False)
player_html = anim.to_jshtml()
display(HTML(player_html))
plt.close(fig)