In [None]:
import os, re, glob, random, time
from dataclasses import dataclass
from typing import List

import numpy as np
import cv2
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
import nibabel as nib

In [3]:
import kagglehub

# Fetch the LiTS train/val dataset through kagglehub; the call returns the local
# directory holding the files (on Kaggle this is under /kaggle/input).
path = kagglehub.dataset_download("javariatahir/litstrain-val")
print(f"Path to dataset files: {path}")

Path to dataset files: /kaggle/input/litstrain-val


In [4]:
# -----------------------
# 1) Config
# -----------------------
@dataclass
class CFG:
    """Central configuration: device, dataset paths, preprocessing and loading.

    Every setting is declared exactly once; tune values here rather than in the
    cells that consume them.

    NOTE(review): BASE_PATH is hard-coded to the Kaggle input mount; when running
    elsewhere, point it at the directory returned by kagglehub.dataset_download.
    """
    DEVICE: str = "cuda" if torch.cuda.is_available() else "cpu"

    # Dataset layout: CT volumes and masks live in sub-folders of BASE_PATH
    BASE_PATH: str = "/kaggle/input/litstrain-val/LiTS(train_test)"
    TRAIN_CT_DIR: str = "train_CT"
    TRAIN_MSK_DIR: str = "train_mask"

    # LiverTumor segmentation cache (pre-processed .npy slices)
    CACHE_DIR: str = "/kaggle/working/lits_cache_LiverTumor256"

    IMG_SIZE: int = 256  # slices are resized to IMG_SIZE x IMG_SIZE

    # Volume-level train/val split and global RNG seed
    VAL_SPLIT: float = 0.2
    SEED: int = 42

    # HU window for hu_window_and_scale; the cached 3-channel input built by
    # create_multi_channel_image uses its own hard-coded windows instead.
    HU_MIN: int = -100
    HU_MAX: int = 400

    KEEP_NEGATIVE_PROB: float = 0.05  # share of low-foreground train slices kept in the cache
    ONLY_TUMOR_SLICES: bool = True    # NOTE(review): not read by the cache builder in this notebook

    BATCH_SIZE: int = 16
    NUM_WORKERS: int = 2
    PIN_MEMORY: bool = True

cfg = CFG()

# Seed every RNG the notebook draws from (Python, NumPy, PyTorch) with one value.
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(cfg.SEED)

os.makedirs(cfg.CACHE_DIR, exist_ok=True)

In [5]:
# -----------------------
# 2) Pairing helpers (same approach)
# -----------------------
def extract_id(path: str) -> int:
    """Return the integer case id encoded in a LiTS file name.

    The id is the digit run in the first ``-<digits>.nii`` match of the base
    name, e.g. ``volume-12.nii`` -> 12 and ``segmentation-7.nii.gz`` -> 7.

    Raises:
        ValueError: if the base name contains no ``-<digits>.nii`` pattern.
    """
    match = re.search(r"-(\d+)\.nii", os.path.basename(path))
    if match is None:
        raise ValueError(f"Cannot parse id from {path}")
    return int(match.group(1))

# Collect CT volumes and their segmentation masks, then pair them by case id.
train_ct_paths = sorted(
    glob.glob(os.path.join(cfg.BASE_PATH, cfg.TRAIN_CT_DIR, "volume-*.nii"))
)
train_msk_paths = sorted(
    glob.glob(os.path.join(cfg.BASE_PATH, cfg.TRAIN_MSK_DIR, "segmentation-*.nii"))
)

ct_map = dict((extract_id(p), p) for p in train_ct_paths)
msk_map = dict((extract_id(p), p) for p in train_msk_paths)
ids = sorted(ct_map.keys() & msk_map.keys())  # only ids that have both a CT and a mask
assert ids, "No matched CT/mask pairs found."

print(f"Matched volumes: {len(ids)}")

# Split at the volume level so slices of one patient never straddle train and val.
train_ids, val_ids = train_test_split(ids, test_size=cfg.VAL_SPLIT, random_state=cfg.SEED)
train_ids, val_ids = sorted(train_ids), sorted(val_ids)
print(f"Train volumes: {len(train_ids)} | Val volumes: {len(val_ids)}")

Matched volumes: 111
Train volumes: 88 | Val volumes: 23


In [6]:
# -----------------------
# 3) Preprocess helpers
# -----------------------
def hu_window_and_scale(img2d: np.ndarray, hu_min: int, hu_max: int) -> np.ndarray:
    """Clip an image to the HU window [hu_min, hu_max] and map it linearly to [0, 1].

    The result is float32 and has the same shape as ``img2d``.
    """
    span = float(hu_max - hu_min)
    windowed = np.clip(img2d, hu_min, hu_max).astype(np.float32)
    return (windowed - hu_min) / span

def resize2d(img2d: np.ndarray, size: int, is_mask: bool) -> np.ndarray:
    """Resize a 2-D array to ``size`` x ``size`` with OpenCV.

    Masks use nearest-neighbour interpolation so label values are never blended;
    images use area interpolation.
    """
    if is_mask:
        interpolation = cv2.INTER_NEAREST
    else:
        interpolation = cv2.INTER_AREA
    return cv2.resize(img2d, (size, size), interpolation=interpolation)

def cache_paths(split: str):
    """Create (if needed) and return ``(images_dir, masks_dir)`` for ``split``.

    Both directories live under ``cfg.CACHE_DIR/<split>/``.
    """
    split_root = os.path.join(cfg.CACHE_DIR, split)
    dirs = tuple(os.path.join(split_root, name) for name in ("images", "masks"))
    for d in dirs:
        os.makedirs(d, exist_ok=True)
    return dirs

In [7]:
def create_multi_channel_image(ct_slice: np.ndarray) -> np.ndarray:
    """Turn one 2-D CT slice into a 3-channel float32 image with values in [0, 1].

    Channels (last axis):
      0. window 40..400, linearly scaled to [0, 1] (called the "liver window" here);
      1. soft-tissue window -100..200, linearly scaled to [0, 1];
      2. Sobel gradient magnitude of a copy clipped to its 1st..99th percentile and
         min-max normalised, divided by the gradient maximum.

    Returns an array of shape [H, W, 3].
    """
    hu = ct_slice.astype(np.float32, copy=False)

    def _window(lo, hi):
        # clip to [lo, hi] and rescale linearly to [0, 1]
        return (np.clip(hu, lo, hi) - lo) / (hi - lo)

    liver = _window(40, 400)
    soft = _window(-100, 200)

    # Robust normalisation of the gradient source
    lo_pct, hi_pct = np.percentile(hu, (1, 99))
    base = np.clip(hu, lo_pct, hi_pct).astype(np.float32, copy=False)
    base = (base - base.min()) / (base.max() - base.min() + 1e-8)
    base = np.ascontiguousarray(base, dtype=np.float32)  # OpenCV wants contiguous float32

    gx = cv2.Sobel(base, ddepth=cv2.CV_32F, dx=1, dy=0, ksize=3)
    gy = cv2.Sobel(base, ddepth=cv2.CV_32F, dx=0, dy=1, ksize=3)
    edges = cv2.magnitude(gx, gy)
    edges = edges / (edges.max() + 1e-8)

    stacked = np.stack((liver, soft, edges), axis=-1).astype(np.float32)  # [H,W,3]
    return np.clip(stacked, 0, 1)


In [8]:
# -----------------------
# 5) Fast dataset
# -----------------------
class NPYSliceDataset(Dataset):
    """Paired (image, mask) slices read from a pre-built ``.npy`` cache.

    Expects ``<cache_root>/<split>/images/<stem>.npy`` holding an [H, W, 3] image
    and ``<cache_root>/<split>/masks/<stem>.npy`` holding the matching [H, W]
    mask under the same file name. Items are ``(img [3, H, W], mask [1, H, W])``
    float32 tensors; with ``augment=True`` random horizontal/vertical flips are
    applied identically to image and mask.

    Raises (at construction):
        AssertionError: if no cached images exist for ``split``.
        FileNotFoundError: if some cached image has no matching mask file.
    """

    def __init__(self, cache_root: str, split: str, augment: bool = False, show_scan_progress: bool = True):
        self.img_dir = os.path.join(cache_root, split, "images")
        self.msk_dir = os.path.join(cache_root, split, "masks")
        self.augment = augment

        img_paths = sorted(glob.glob(os.path.join(self.img_dir, "*.npy")))
        assert len(img_paths) > 0, f"No cached images found in {self.img_dir}"

        # Verify every image has its mask up front so a broken cache fails here
        # with a clear message instead of inside a DataLoader worker mid-epoch.
        # This per-file check is the work the optional progress bar reports on.
        scan = tqdm(img_paths, desc=f"Indexing {split} .npy slices") if show_scan_progress else img_paths
        missing = [p for p in scan if not os.path.exists(self._mask_path(p))]
        if missing:
            raise FileNotFoundError(
                f"{len(missing)} cached image(s) in {self.img_dir} have no mask in "
                f"{self.msk_dir}; first: {os.path.basename(missing[0])}"
            )
        self.img_paths = img_paths

    def _mask_path(self, img_path: str) -> str:
        """Mask file paired with ``img_path`` (same file name, masks directory)."""
        return os.path.join(self.msk_dir, os.path.basename(img_path))

    def __len__(self):
        return len(self.img_paths)

    def _augment_pair(self, img, msk):
        """Randomly flip image and mask together (each axis with probability 0.5)."""
        if random.random() < 0.5:
            img = np.fliplr(img).copy()
            msk = np.fliplr(msk).copy()
        if random.random() < 0.5:
            img = np.flipud(img).copy()
            msk = np.flipud(msk).copy()
        return img, msk

    def __getitem__(self, idx):
        img_path = self.img_paths[idx]
        img = np.load(img_path).astype(np.float32)                 # [H,W,3]
        msk = np.load(self._mask_path(img_path)).astype(np.float32)  # [H,W]

        if self.augment:
            img, msk = self._augment_pair(img, msk)

        img_t = torch.from_numpy(img).permute(2, 0, 1)  # [3,H,W]
        msk_t = torch.from_numpy(msk).unsqueeze(0)      # [1,H,W]
        return img_t, msk_t

In [9]:
# ============================================================
# 6) Base U-Net
# ============================================================
class DoubleConvNoBN(nn.Module):
    """Two (3x3 conv, padding=1) + ReLU layers without batch norm; keeps H and W."""

    def __init__(self, in_ch, out_ch):
        super().__init__()
        # The attribute name and layer order define checkpoint keys
        # ("<name>.net.0.weight", ...); keep them stable.
        self.net = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1, bias=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1, bias=True),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.net(x)


class UNetBase(nn.Module):
    """Four-level U-Net (channels base, 2b, 4b, 8b, bottleneck 16b) returning raw logits.

    Input [B, in_channels, H, W] -> output [B, out_channels, H, W].
    H and W need not be multiples of 16. Each upsampled map is zero-padded on the
    bottom/right to its skip connection's size before concatenation. For sizes
    divisible by 16 (e.g. 256) no padding happens, so results are unchanged.
    H and W must be at least 16 so the bottleneck is non-empty.
    """

    def __init__(self, in_channels=1, out_channels=1, base=64):
        super().__init__()
        self.enc1 = DoubleConvNoBN(in_channels, base)
        self.pool1 = nn.MaxPool2d(2)

        self.enc2 = DoubleConvNoBN(base, base*2)
        self.pool2 = nn.MaxPool2d(2)

        self.enc3 = DoubleConvNoBN(base*2, base*4)
        self.pool3 = nn.MaxPool2d(2)

        self.enc4 = DoubleConvNoBN(base*4, base*8)
        self.pool4 = nn.MaxPool2d(2)

        self.bot = DoubleConvNoBN(base*8, base*16)

        self.up4 = nn.ConvTranspose2d(base*16, base*8, 2, stride=2)
        self.dec4 = DoubleConvNoBN(base*16, base*8)

        self.up3 = nn.ConvTranspose2d(base*8, base*4, 2, stride=2)
        self.dec3 = DoubleConvNoBN(base*8, base*4)

        self.up2 = nn.ConvTranspose2d(base*4, base*2, 2, stride=2)
        self.dec2 = DoubleConvNoBN(base*4, base*2)

        self.up1 = nn.ConvTranspose2d(base*2, base, 2, stride=2)
        self.dec1 = DoubleConvNoBN(base*2, base)

        self.outc = nn.Conv2d(base, out_channels, 1)

    @staticmethod
    def _match_skip(up, skip):
        """Zero-pad ``up`` so its H/W equal ``skip``'s (pooling floors odd sizes)."""
        dh = skip.shape[-2] - up.shape[-2]
        dw = skip.shape[-1] - up.shape[-1]
        if dh or dw:
            up = F.pad(up, [dw // 2, dw - dw // 2, dh // 2, dh - dh // 2])
        return up

    def forward(self, x):
        x1 = self.enc1(x)
        x2 = self.enc2(self.pool1(x1))
        x3 = self.enc3(self.pool2(x2))
        x4 = self.enc4(self.pool3(x3))
        xb = self.bot(self.pool4(x4))

        u4 = self._match_skip(self.up4(xb), x4)
        d4 = self.dec4(torch.cat([u4, x4], dim=1))

        u3 = self._match_skip(self.up3(d4), x3)
        d3 = self.dec3(torch.cat([u3, x3], dim=1))

        u2 = self._match_skip(self.up2(d3), x2)
        d2 = self.dec2(torch.cat([u2, x2], dim=1))

        u1 = self._match_skip(self.up1(d2), x1)
        d1 = self.dec1(torch.cat([u1, x1], dim=1))

        return self.outc(d1)

In [10]:
# ============================================================
# 7) Loss + Metrics
# ============================================================
def dice_loss_from_logits(logits, y, eps=1e-6):
    """Soft Dice loss computed from raw logits.

    Dice is computed separately for every entry of the first two axes by summing
    over axes 2 and 3 (presumably [B, C, H, W] -> per sample and channel), then
    averaged; the loss is ``1 - mean Dice``. ``eps`` is added to numerator and
    denominator so empty masks stay finite.
    """
    probs = torch.sigmoid(logits)
    spatial = (2, 3)
    intersection = (probs * y).sum(dim=spatial)
    totals = (probs + y).sum(dim=spatial)
    dice = (2 * intersection + eps) / (totals + eps)
    return 1 - dice.mean()

class BCEDice(nn.Module):
    """Weighted sum of BCE-with-logits (using ``pos_weight``) and soft Dice loss.

    loss = bce_weight * BCE + (1 - bce_weight) * Dice
    """

    def __init__(self, pos_weight=1.0, bce_weight=0.5):
        super().__init__()
        # Stored as a buffer so .to(device) moves it together with the module.
        self.register_buffer("pw", torch.tensor([pos_weight], dtype=torch.float32))
        self.bce_weight = bce_weight

    def forward(self, logits, y):
        w = self.bce_weight
        bce_term = F.binary_cross_entropy_with_logits(logits, y, pos_weight=self.pw)
        dice_term = dice_loss_from_logits(logits, y)
        return w * bce_term + (1 - w) * dice_term

@torch.no_grad()
def estimate_pos_weight(loader, device, max_batches=40):
    """Estimate BCE ``pos_weight`` as (#negative / #positive) target pixels.

    Only the first ``max_batches`` batches of ``loader`` contribute. Returns 1.0
    when less than one positive pixel was seen.
    """
    n_pos = 0.0
    n_neg = 0.0
    for batch_idx, (_, target) in enumerate(loader):
        if batch_idx >= max_batches:
            break
        target = target.to(device)
        n_pos += target.sum().item()
        n_neg += (1 - target).sum().item()
    if n_pos < 1:
        return 1.0
    return float(n_neg / n_pos)

@torch.no_grad()
def metrics_from_logits(logits, y, thr=0.5, eps=1e-8):
    """Pixel-level Dice, IoU, accuracy, sensitivity and specificity.

    Predictions are ``sigmoid(logits) > thr``. All pixels of the batch are pooled
    into one confusion matrix before the scores are computed. Returns a dict with
    keys "dice", "iou", "acc", "sens", "spec" (Python floats).
    """
    pred = (torch.sigmoid(logits) > thr).float()
    not_pred, not_y = 1 - pred, 1 - y
    tp = (pred * y).sum().item()
    tn = (not_pred * not_y).sum().item()
    fp = (pred * not_y).sum().item()
    fn = (not_pred * y).sum().item()
    return {
        "dice": (2 * tp) / (2 * tp + fp + fn + eps),
        "iou": tp / (tp + fp + fn + eps),
        "acc": (tp + tn) / (tp + tn + fp + fn + eps),
        "sens": tp / (tp + fn + eps),
        "spec": tn / (tn + fp + eps),
    }

@torch.no_grad()
def evaluate(model, loader, loss_fn, device):
    """Run ``model`` over ``loader`` in eval mode and return averaged loss/metrics.

    Each batch's loss and metrics_from_logits scores are weighted by batch size,
    so every value is a size-weighted mean of per-batch numbers.
    Keys: "loss", "dice", "iou", "acc", "sens", "spec".
    """
    model.eval()
    totals = dict.fromkeys(("loss", "dice", "iou", "acc", "sens", "spec"), 0.0)
    seen = 0
    progress = tqdm(loader, desc="Validation", leave=False)
    for inputs, targets in progress:
        inputs = inputs.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)
        logits = model(inputs)
        batch_loss = loss_fn(logits, targets).item()
        batch_metrics = metrics_from_logits(logits, targets)
        weight = inputs.size(0)
        totals["loss"] += batch_loss * weight
        for name, value in batch_metrics.items():
            totals[name] += value * weight
        seen += weight
        progress.set_postfix(
            loss=f"{batch_loss:.4f}",
            dice=f"{batch_metrics['dice']:.3f}",
            acc=f"{batch_metrics['acc']:.3f}",
        )
    denom = max(seen, 1)
    return {name: total / denom for name, total in totals.items()}

def train_one_epoch(model, loader, opt, loss_fn, device):
    """Train ``model`` for one pass over ``loader``.

    Gradients are clipped to a global L2 norm of 1.0 before each optimizer step.
    Returns the batch-size-weighted mean training loss.
    """
    model.train()
    loss_sum = 0.0
    n_seen = 0
    progress = tqdm(loader, desc="Training", leave=False)
    for inputs, targets in progress:
        inputs = inputs.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)

        opt.zero_grad(set_to_none=True)
        logits = model(inputs)
        loss = loss_fn(logits, targets)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        opt.step()

        batch_size = inputs.size(0)
        loss_value = loss.item()
        loss_sum += loss_value * batch_size
        n_seen += batch_size

        # live per-batch loss / Dice / accuracy in the progress bar
        batch_metrics = metrics_from_logits(logits.detach(), targets.detach())
        progress.set_postfix(
            loss=f"{loss_value:.4f}",
            dice=f"{batch_metrics['dice']:.3f}",
            acc=f"{batch_metrics['acc']:.3f}",
        )

    return loss_sum / max(n_seen, 1)

# Liver Tumor Segmentation

## Visualization Helpers

In [11]:
import pandas as pd
import matplotlib.pyplot as plt

class HistoryLogger:
    """Collect per-epoch training/validation scalars and export them.

    Rows are stored as plain dicts. ``to_df`` turns them into a DataFrame.
    ``save_csv`` writes ``<out_dir>/history_<tag>.csv``. ``plot`` saves and shows
    one PNG per metric group in ``out_dir``.
    """

    # (file suffix, y-label, title suffix, columns) for each figure drawn by plot().
    # A group is drawn when its first column exists; the others are added if present.
    _PLOT_GROUPS = (
        ("loss", "Loss", "Loss", ("train_loss", "val_loss")),
        ("metrics", "Score", "Dice/IoU", ("dice", "iou")),
        ("clsmetrics", "Score", "Acc/Sens/Spec", ("acc", "sens", "spec")),
    )

    def __init__(self, tag: str, out_dir="/kaggle/working"):
        self.tag = tag
        self.out_dir = out_dir
        self.rows = []

    def add(self, epoch: int, train_loss: float, val_dict: dict, lr: float, sec: float):
        """Append one epoch. ``val_dict`` must contain "loss"; missing metrics become NaN."""
        row = {
            "epoch": epoch,
            "train_loss": float(train_loss),
            "val_loss": float(val_dict["loss"]),
            "dice": float(val_dict.get("dice", np.nan)),
            "iou": float(val_dict.get("iou", np.nan)),
            "acc": float(val_dict.get("acc", np.nan)),
            "sens": float(val_dict.get("sens", np.nan)),
            "spec": float(val_dict.get("spec", np.nan)),
            "lr": float(lr),
            "sec": float(sec),
        }
        self.rows.append(row)

    def to_df(self):
        return pd.DataFrame(self.rows)

    def save_csv(self):
        """Write the history CSV (creating ``out_dir`` if needed) and return its path."""
        df = self.to_df()
        os.makedirs(self.out_dir, exist_ok=True)
        path = f"{self.out_dir}/history_{self.tag}.csv"
        df.to_csv(path, index=False)
        print(f"saved {path}")
        return path

    def plot(self):
        """Save and show loss, Dice/IoU and Acc/Sens/Spec curves versus epoch."""
        df = self.to_df()
        if len(df) == 0:
            print("No history to plot.")
            return
        os.makedirs(self.out_dir, exist_ok=True)
        for suffix, ylabel, title, columns in self._PLOT_GROUPS:
            if columns[0] not in df.columns:
                continue
            present = [c for c in columns if c in df.columns]
            self._plot_group(df, present, ylabel, title, suffix)

    def _plot_group(self, df, columns, ylabel, title, suffix):
        """Draw ``columns`` against epoch on one axes, save as PNG, show, then close."""
        fig, ax = plt.subplots()
        for col in columns:
            ax.plot(df["epoch"], df[col], label=col)
        ax.set_xlabel("Epoch")
        ax.set_ylabel(ylabel)
        ax.set_title(f"{self.tag} {title}")
        ax.legend()
        ax.grid(True, alpha=0.3)
        path = f"{self.out_dir}/plot_{self.tag}_{suffix}.png"
        fig.savefig(path, dpi=150, bbox_inches="tight")
        plt.show()
        plt.close(fig)  # release the figure; repeated plot() calls would otherwise accumulate them
        print(f"saved {path}")

In [12]:
import random

@torch.no_grad()
def visualize_predictions(model, loader, device, n=6, thr=0.5, title="preds", out_dir="/kaggle/working"):
    """Plot image / ground truth / overlay rows for ``n`` random samples of one batch.

    Only the first batch yielded by ``loader`` is used, so an unshuffled loader
    always shows samples from the same batch. Predictions are
    ``sigmoid(logits) > thr``. The displayed image is input channel 0. In the
    overlay, prediction is added to the red channel and ground truth to green.
    The figure is saved to ``<out_dir>/<title>.png``, shown, then closed.
    """
    model.eval()
    x, y = next(iter(loader))
    x = x.to(device)
    y = y.to(device)

    pred = (torch.sigmoid(model(x)) > thr).float()

    n = min(n, x.shape[0])
    if n <= 0:
        print("No samples to visualize.")
        return
    idxs = random.sample(range(x.shape[0]), n)

    fig, axes = plt.subplots(n, 3, figsize=(12, 2 * n), squeeze=False)
    for row, idx in enumerate(idxs):
        img = x[idx, 0].cpu().numpy()
        gt = y[idx, 0].cpu().numpy()
        pr = pred[idx, 0].cpu().numpy()

        overlay = np.stack([img, img, img], axis=-1)
        overlay[..., 0] = np.clip(overlay[..., 0] + 0.6 * pr, 0, 1)
        overlay[..., 1] = np.clip(overlay[..., 1] + 0.6 * gt, 0, 1)

        panels = (
            (img, "gray", "image"),
            (gt, "gray", "gt"),
            (overlay, None, "overlay (pred red, gt green)"),
        )
        for ax, (data, cmap, name) in zip(axes[row], panels):
            ax.imshow(data, cmap=cmap)
            ax.set_title(name)
            ax.axis("off")

    fig.suptitle(title)
    fig.tight_layout()
    os.makedirs(out_dir, exist_ok=True)
    out_path = os.path.join(out_dir, f"{title}.png")
    fig.savefig(out_path, dpi=150, bbox_inches="tight")
    plt.show()
    plt.close(fig)
    print(f"saved {out_path}")

## Making Cache

In [13]:
def build_cache_LiverTumor_for_ids(volume_ids, split: str, keep_neg_prob: float = None, min_fg_pixels: int = 20):
    """Slice CT volumes into resized ``.npy`` image/mask pairs for ``split``.

    For every axial slice z of each volume id:
      * the mask is binarised as ``label > 0`` (liver and tumour merged);
      * slices with at least ``min_fg_pixels`` foreground pixels ("positives") are kept;
      * the remaining slices ("negatives": fewer foreground pixels, possibly zero)
        are kept with probability ``keep_neg_prob``.
    Kept slices are written to the split's cache as ``id<vid>_z<z>.npy``. The image
    is the 3-channel create_multi_channel_image output resized bilinearly to
    cfg.IMG_SIZE. The mask is the binary mask resized nearest-neighbour.

    Args:
        volume_ids: case ids present in both ``ct_map`` and ``msk_map``.
        split: cache sub-folder name (e.g. "train", "val").
        keep_neg_prob: chance of keeping a negative slice. ``None`` (default)
            uses cfg.KEEP_NEGATIVE_PROB for "train" and 0.0 for any other split.
        min_fg_pixels: foreground-pixel count at which a slice counts as positive.

    Raises:
        ValueError: if a CT volume and its mask have different shapes.

    NOTE(review): with the defaults the validation cache holds only slices that
    contain liver/tumour, so validation metrics cannot reveal false positives on
    liver-free slices; pass keep_neg_prob=1.0 for "val" to evaluate on all slices.
    NOTE(review): existing cache files are not removed, so slices from an earlier
    run with different settings would still be picked up by the dataset.
    """
    img_dir, msk_dir = cache_paths(split)  # also creates both directories

    if keep_neg_prob is None:
        keep_neg_prob = cfg.KEEP_NEGATIVE_PROB if split == "train" else 0.0

    print(f"[{split}] Building LiverTumor cache ...")
    kept_pos = 0
    kept_neg = 0
    out_size = (cfg.IMG_SIZE, cfg.IMG_SIZE)

    for vid in tqdm(volume_ids, desc=f"Building {split} LiverTumor cache (volumes)"):
        ct = nib.load(ct_map[vid]).get_fdata(dtype=np.float32)     # (H,W,Z)
        msk = nib.load(msk_map[vid]).get_fdata(dtype=np.float32)   # (H,W,Z)
        if ct.shape != msk.shape:
            raise ValueError(f"Volume {vid}: CT shape {ct.shape} != mask shape {msk.shape}")

        for z in range(ct.shape[2]):
            fg = (msk[:, :, z] > 0).astype(np.float32)
            is_positive = fg.sum() >= min_fg_pixels

            # Negative slice: keep it only with probability keep_neg_prob.
            # random.random() is drawn only for negatives, as before.
            if not is_positive and random.random() > keep_neg_prob:
                continue

            stem = f"id{vid:03d}_z{z:03d}"

            img3 = create_multi_channel_image(ct[:, :, z])  # [H,W,3] in [0,1]
            img3 = cv2.resize(img3, out_size, interpolation=cv2.INTER_LINEAR)
            fg_r = cv2.resize(fg, out_size, interpolation=cv2.INTER_NEAREST)

            np.save(os.path.join(img_dir, f"{stem}.npy"), img3.astype(np.float32))
            np.save(os.path.join(msk_dir, f"{stem}.npy"), fg_r.astype(np.float32))

            if is_positive:
                kept_pos += 1
            else:
                kept_neg += 1

    print(f"[{split}] Done. Cached positives: {kept_pos} | Cached negatives: {kept_neg}")

## Training flow

In [15]:
# Build the slice caches for both splits (the builder keeps a fraction of
# negative slices for "train" and none for other splits by default).
for split_ids, split_name in ((train_ids, "train"), (val_ids, "val")):
    build_cache_LiverTumor_for_ids(split_ids, split_name)

[train] Building LiverTumor cache ...


Building train LiverTumor cache (volumes): 100%|██████████| 88/88 [10:59<00:00,  7.49s/it]


[train] Done. Cached positives: 12493 | Cached negatives: 1252
[val] Building LiverTumor cache ...


Building val LiverTumor cache (volumes): 100%|██████████| 23/23 [03:00<00:00,  7.86s/it]

[val] Done. Cached positives: 3069 | Cached negatives: 0





In [46]:
train_ds = NPYSliceDataset(cfg.CACHE_DIR, "train", augment=True)
val_ds   = NPYSliceDataset(cfg.CACHE_DIR, "val", augment=False)

# Shared DataLoader settings; only shuffling differs between the two splits.
loader_kwargs = dict(
    batch_size=cfg.BATCH_SIZE,
    num_workers=cfg.NUM_WORKERS,
    pin_memory=cfg.PIN_MEMORY,
    persistent_workers=(cfg.NUM_WORKERS > 0),
)
train_loader = DataLoader(train_ds, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_ds, shuffle=False, **loader_kwargs)

print(f"[LIVER Tumor] Cached train slices: {len(train_ds)} | Cached val slices: {len(val_ds)}")

# Speed check: time how long the train loader needs to yield N_SPEED_BATCHES batches.
# Counting from 1 and breaking at N loads exactly N batches (fewer if the loader is shorter).
N_SPEED_BATCHES = 50
t0 = time.perf_counter()
n_loaded = 0
for n_loaded, _ in enumerate(train_loader, start=1):
    if n_loaded == N_SPEED_BATCHES:
        break
print(f"{n_loaded} batches load time:", time.perf_counter() - t0, "sec")

Indexing train .npy slices: 100%|██████████| 13745/13745 [00:00<00:00, 4817071.23it/s]
Indexing val .npy slices: 100%|██████████| 3069/3069 [00:00<00:00, 3897159.85it/s]

[LIVER Tumor] Cached train slices: 13745 | Cached val slices: 3069





50 batches load time: 3.3808228969573975 sec


In [None]:
hist1 = HistoryLogger("LiverTumor")

device = torch.device(cfg.DEVICE)
model = UNetBase(in_channels=3, out_channels=1, base=64).to(device)

# Positive-class weight for the BCE term, estimated from training masks.
pos_w = estimate_pos_weight(train_loader, device)
loss_fn = BCEDice(pos_weight=pos_w, bce_weight=0.3).to(device)

opt = torch.optim.Adam(model.parameters(), lr=1e-3)
# Validation Dice is maximised; halve the LR after 5 epochs without improvement.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(opt, mode="max", factor=0.5, patience=5)

EPOCHS = 40
BEST_CKPT = "best_unet_LiverTumor.pt"
best_dice = -1.0

for ep in range(1, EPOCHS + 1):
    epoch_start = time.time()

    tr_loss = train_one_epoch(model, train_loader, opt, loss_fn, device)
    val = evaluate(model, val_loader, loss_fn, device)

    scheduler.step(val["dice"])
    lr = opt.param_groups[0]["lr"]
    dt = time.time() - epoch_start

    # Keep only the checkpoint with the best validation Dice so far.
    improved = val["dice"] > best_dice
    if improved:
        best_dice = val["dice"]
        torch.save({"model": model.state_dict(), "epoch": ep, "val": val}, BEST_CKPT)
        print(f"saved {BEST_CKPT} (best_dice={best_dice:.4f})")

    print(
        f"Epoch {ep:02d}/{EPOCHS} | train_loss={tr_loss:.4f} | val_loss={val['loss']:.4f} | "
        f"dice={val['dice']:.4f} iou={val['iou']:.4f} acc={val['acc']:.4f} "
        f"sens={val['sens']:.4f} spec={val['spec']:.4f} | lr={lr:.2e} | {dt:.1f}s"
    )
    hist1.add(ep, tr_loss, val, lr, dt)

hist1.save_csv()
hist1.plot()

# Reload the best checkpoint before visualising predictions.
ck = torch.load(BEST_CKPT, map_location=device)
model.load_state_dict(ck["model"])
model.eval()

visualize_predictions(model, val_loader, device, n=6, thr=0.5, title="LiverTumor_preds")

# Inference

In [39]:
# ============================================================
# Download the pretrained best_unet_LiverTumor.pt from Google Drive
# ============================================================

import os
import re
import subprocess

# Public share link of the checkpoint; the file id is parsed out of it below.
GDRIVE_URL = "https://drive.google.com/file/d/1kCiERrbDukSMnBhFfVqTF05GFQswv3dA/view?usp=sharing"

# Where to save.
# NOTE(review): the training cell saves "best_unet_LiverTumor.pt" relative to the
# working directory; if that directory is /kaggle/working, this download
# overwrites the checkpoint produced by training. Pick another name to keep both.
OUT_PATH = "/kaggle/working/best_unet_LiverTumor.pt"

def extract_file_id(url: str):
    """Return the Google Drive file id embedded in a share/download URL.

    Supported forms:
      * ``.../file/d/<id>`` optionally followed by ``/view``, ``?usp=...`` or ``#...``
      * ``...?id=<id>`` or ``...&id=<id>``

    Raises:
        ValueError: if neither form is present.
    """
    patterns = (
        r"/file/d/([^/?#]+)",  # /file/d/<id>[/view][?...]
        r"[?&]id=([^&#]+)",    # uc?id=<id> / open?id=<id>
    )
    for pattern in patterns:
        match = re.search(pattern, url)
        if match:
            return match.group(1)
    raise ValueError("Could not extract file ID from URL")

file_id = extract_file_id(GDRIVE_URL)
print("Google Drive File ID:", file_id)

# Install gdown into the kernel's own interpreter; a bare "pip" on PATH may
# belong to a different environment. NOTE(review): consider pinning a version.
import sys
subprocess.check_call([sys.executable, "-m", "pip", "-q", "install", "gdown"])

import gdown

print("Downloading model...")
# gdown.download returns the output path on success. Depending on the version it
# either returns None or raises on failure. Checking only os.path.exists would
# report success for a checkpoint left over from an earlier run or training.
downloaded = gdown.download(id=file_id, output=OUT_PATH, quiet=False)

if downloaded is None or not os.path.exists(OUT_PATH):
    print("Download failed.")
else:
    print("Download complete:", OUT_PATH)
    print("File size (MB):", os.path.getsize(OUT_PATH) / 1024 / 1024)

Google Drive File ID: 1kCiERrbDukSMnBhFfVqTF05GFQswv3dA
Downloading model...


Downloading...
From (original): https://drive.google.com/uc?id=1kCiERrbDukSMnBhFfVqTF05GFQswv3dA
From (redirected): https://drive.google.com/uc?id=1kCiERrbDukSMnBhFfVqTF05GFQswv3dA&confirm=t&uuid=4614341f-26c9-467b-9abf-852d39f9a6af
To: /kaggle/working/best_unet_LiverTumor.pt
100%|██████████| 124M/124M [00:02<00:00, 58.3MB/s] 

Download complete: /kaggle/working/best_unet_LiverTumor.pt
File size (MB): 118.3936128616333





In [40]:
# ============================================================
# Inference on a single NIfTI slice + Visualization
# ============================================================

import numpy as np
import cv2
import nibabel as nib
import torch
import matplotlib.pyplot as plt

def _unwrap_state_dict(ckpt):
    """Return the model state_dict held by ``ckpt`` with any "module." key prefix removed.

    Accepts a raw state_dict, or a dict wrapping it under "model",
    "model_state_dict" or "state_dict" (checked in that order). Only a *leading*
    "module." (as added by nn.DataParallel/DDP) is stripped, so keys that merely
    contain that substring elsewhere are left intact.
    """
    state_dict = None
    if isinstance(ckpt, dict):
        for key in ("model", "model_state_dict", "state_dict"):
            if ckpt.get(key) is not None:
                state_dict = ckpt[key]
                break
    if state_dict is None:
        state_dict = ckpt  # raw state_dict case
    prefix = "module."
    return {(k[len(prefix):] if k.startswith(prefix) else k): v for k, v in state_dict.items()}


@torch.no_grad()
def run_inference(volume_path: str,
                  ckpt_path: str,
                  slice_idx: int,
                  mask_path: str = None,
                  img_size: int = None,
                  thr: float = 0.5,
                  device: torch.device = None,
                  base: int = 64):
    """
    Load a trained UNetBase checkpoint and run inference on one CT slice.

    Notes
    -----
    - This assumes you trained the model with:
        * UNetBase(in_channels=3, out_channels=1, base=64)
        * create_multi_channel_image(...) preprocessing
        * resizing to cfg.IMG_SIZE (default 256)
    - If you change any of those settings during training, update them here too.
    - The checkpoint may be a raw state_dict or a dict holding it under "model",
      "model_state_dict" or "state_dict"; a leading "module." prefix is removed.

    Parameters
    ----------
    volume_path : str
        Path to CT volume (e.g., volume-XXX.nii).
    ckpt_path : str
        Path to saved checkpoint (e.g., 'best_unet_LiverTumor.pt').
    slice_idx : int
        Axial slice index (0..Z-1). Negative values work like Python indexing.
    mask_path : str, optional
        Path to GT mask (e.g., segmentation-XXX.nii) to visualize ground truth (mask > 0).
        If None, only the prediction will be shown.
    img_size : int, optional
        Spatial size for the network input. Defaults to cfg.IMG_SIZE if present, else 256.
    thr : float
        Threshold on sigmoid probability (prob >= thr) to produce the binary mask.
    device : torch.device, optional
        If None, uses CUDA if available.
    base : int
        UNetBase base feature size (must match training).

    Returns
    -------
    vis_img : np.ndarray
        2D image used for display (channel 0 = liver window), shape [img_size, img_size].
    pred_mask : np.ndarray
        Binary predicted mask (0/1), shape [img_size, img_size].
    bg_mask : np.ndarray
        Background mask (1 - pred_mask), shape [img_size, img_size].

    Raises
    ------
    ValueError
        If slice_idx is out of range, or the mask volume's shape differs from the CT's.
    """
    # --- basic config ---
    if img_size is None:
        img_size = cfg.IMG_SIZE if "cfg" in globals() else 256

    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Running inference on: {device}")

    # --- Build model (MUST match training settings) and load weights ---
    model = UNetBase(in_channels=3, out_channels=1, base=base).to(device)
    ckpt = torch.load(ckpt_path, map_location=device)
    model.load_state_dict(_unwrap_state_dict(ckpt), strict=True)
    model.eval()

    # --- Load CT slice ---
    ct = nib.load(volume_path).get_fdata(dtype=np.float32)  # LiTS: (H,W,Z)
    Z = ct.shape[2]

    if slice_idx < 0:
        slice_idx = Z + slice_idx
    if not (0 <= slice_idx < Z):
        raise ValueError(f"slice_idx={slice_idx} out of range for Z={Z}")

    ct_slice = ct[:, :, slice_idx]

    # --- Preprocess exactly like training ---
    x3 = create_multi_channel_image(ct_slice)  # [H,W,3] in [0,1]
    x3 = cv2.resize(x3, (img_size, img_size), interpolation=cv2.INTER_LINEAR)

    # tensor: (B,C,H,W)
    x = torch.from_numpy(x3).permute(2, 0, 1).unsqueeze(0).to(device)

    # --- Predict ---
    logits = model(x)
    prob = torch.sigmoid(logits).squeeze(0).squeeze(0).cpu().numpy()
    pred_mask = (prob >= thr).astype(np.uint8)

    # --- Optional ground truth mask (mask > 0 = liver + tumor) ---
    true_mask = None
    if mask_path is not None:
        msk_vol = nib.load(mask_path).get_fdata(dtype=np.float32)
        if msk_vol.shape != ct.shape:
            raise ValueError(f"Mask shape {msk_vol.shape} != CT shape {ct.shape}")
        true_mask = (msk_vol[:, :, slice_idx] > 0).astype(np.uint8)
        true_mask = cv2.resize(true_mask, (img_size, img_size), interpolation=cv2.INTER_NEAREST)

    # For visualization, show the "liver window" channel (channel 0)
    vis_img = x3[..., 0]

    # --- Plot ---
    ncols = 3 if true_mask is not None else 2
    fig, axes = plt.subplots(1, ncols, figsize=(5 * ncols, 5))

    axes[0].imshow(vis_img, cmap="gray")
    axes[0].set_title(f"CT (slice {slice_idx})")
    axes[0].axis("off")

    if true_mask is not None:
        axes[1].imshow(vis_img, cmap="gray")
        axes[1].imshow(true_mask, cmap="Greens", alpha=0.5)
        axes[1].set_title("Ground truth (green)")
        axes[1].axis("off")
        ax_pred = axes[2]
    else:
        ax_pred = axes[1]

    ax_pred.imshow(vis_img, cmap="gray")
    ax_pred.imshow(pred_mask, cmap="Reds", alpha=0.5)
    ax_pred.set_title(f"Prediction (red) | thr={thr}")
    ax_pred.axis("off")

    fig.tight_layout()
    plt.show()
    plt.close(fig)

    bg_mask = 1 - pred_mask
    return vis_img, pred_mask, bg_mask

In [41]:
# ============================================================
# Validation Inference + Average Dice and Accuracy
# ============================================================

import numpy as np
import torch
from tqdm import tqdm

CKPT_PATH = "best_unet_LiverTumor.pt"   # <-- change if your checkpoint name/path differs
# NOTE(review): the download cell writes /kaggle/working/best_unet_LiverTumor.pt;
# this relative path resolves to that file only when the working dir is /kaggle/working.
THR = 0.5                               # sigmoid threshold

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)

# --- Load model (architecture must match training: 3-channel input, base=64) ---
model = UNetBase(in_channels=3, out_channels=1, base=64).to(device)

ckpt = torch.load(CKPT_PATH, map_location=device)

# Handle different checkpoint formats: {"model": ...}, {"model_state_dict": ...},
# {"state_dict": ...} or a raw state_dict.
state_dict = ckpt
if isinstance(ckpt, dict):
    for _key in ("model", "model_state_dict", "state_dict"):
        if ckpt.get(_key) is not None:
            state_dict = ckpt[_key]
            break

# Strip only a *leading* DataParallel/DDP "module." prefix; a blanket str.replace
# would also mangle keys that merely contain "module." elsewhere.
_prefix = "module."
state_dict = {(k[len(_prefix):] if k.startswith(_prefix) else k): v for k, v in state_dict.items()}

model.load_state_dict(state_dict, strict=True)
model.eval()


@torch.no_grad()
def eval_val_loader(model, loader, device, thr=0.5, eps=1e-8):
    """Evaluate binary segmentation predictions over a whole loader.

    Predictions are ``sigmoid(logits) >= thr``; ``model`` is used as given
    (no ``model.eval()`` call here). Returns a dict with:
      dice_macro / acc_macro: per-sample Dice / pixel accuracy averaged over
        samples (a sample with empty prediction and empty target scores Dice 1.0);
      dice_micro / acc_micro: Dice / accuracy from TP/TN/FP/FN pooled over all pixels;
      TP, TN, FP, FN: the pooled confusion counts;
      n_samples: number of samples evaluated.
    """
    totals = {"TP": 0.0, "TN": 0.0, "FP": 0.0, "FN": 0.0}
    dice_sum = acc_sum = 0.0
    n_samples = 0
    per_sample = (1, 2, 3)  # reduce over C,H,W -> one value per sample

    for x, y in tqdm(loader, desc="Running val inference"):
        x = x.to(device, non_blocking=True)  # [B,3,H,W]
        y = y.to(device, non_blocking=True)  # [B,1,H,W]

        pred = (torch.sigmoid(model(x)) >= thr).float()
        inv_pred, inv_y = 1 - pred, 1 - y

        tp = (pred * y).sum(dim=per_sample)
        tn = (inv_pred * inv_y).sum(dim=per_sample)
        fp = (pred * inv_y).sum(dim=per_sample)
        fn = (inv_pred * y).sum(dim=per_sample)

        den = 2 * tp + fp + fn
        dice = torch.where(den > 0, (2 * tp + eps) / (den + eps), torch.ones_like(den))
        acc = (tp + tn + eps) / (tp + tn + fp + fn + eps)

        dice_sum += dice.sum().item()
        acc_sum += acc.sum().item()
        n_samples += x.size(0)

        for name, counts in (("TP", tp), ("TN", tn), ("FP", fp), ("FN", fn)):
            totals[name] += counts.sum().item()

    denom = max(n_samples, 1)
    TP, TN, FP, FN = totals["TP"], totals["TN"], totals["FP"], totals["FN"]
    return {
        "dice_macro": dice_sum / denom,
        "acc_macro": acc_sum / denom,
        "dice_micro": (2 * TP) / (2 * TP + FP + FN + eps),
        "acc_micro": (TP + TN) / (TP + TN + FP + FN + eps),
        "TP": TP, "TN": TN, "FP": FP, "FN": FN,
        "n_samples": n_samples,
    }


# --- Run evaluation ---
metrics = eval_val_loader(model, val_loader, device, thr=THR)

for line in (
    "\n===== Validation Metrics =====",
    f"Samples (slices): {metrics['n_samples']}",
    f"Dice (macro avg):  {metrics['dice_macro']:.4f}",
    f"Acc  (macro avg):  {metrics['acc_macro']:.4f}",
    f"Dice (micro/global): {metrics['dice_micro']:.4f}",
    f"Acc  (micro/global): {metrics['acc_micro']:.4f}",
    f"TP={metrics['TP']:.0f} TN={metrics['TN']:.0f} FP={metrics['FP']:.0f} FN={metrics['FN']:.0f}",
):
    print(line)

Device: cuda


Running val inference: 100%|██████████| 192/192 [00:23<00:00,  8.09it/s]


===== Validation Metrics =====
Samples (slices): 3069
Dice (macro avg):  0.9129
Acc  (macro avg):  0.9938
Dice (micro/global): 0.9554
Acc  (micro/global): 0.9938
TP=13400452 TN=186479875 FP=893579 FN=356078



