<a href="https://colab.research.google.com/github/edwardleetenafly/LA-net/blob/main/lanet_week5.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [18]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [19]:
import numpy as np
import pandas as pd
from pathlib import Path
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

root = Path("/content/drive/MyDrive/la_net")
prep_dir = root / "preprocessed"
results_dir = root / "results"
results_dir.mkdir(exist_ok=True)

print("root:", root)
print("prep exists:", prep_dir.exists())
print("npy count:", len(list(prep_dir.glob("*_image.npy"))))


root: /content/drive/MyDrive/la_net
prep exists: True
npy count: 20


In [20]:
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Device:", device)


Device: cuda


In [21]:
splits = pd.read_csv(root / "splits.csv")

train_df = splits[splits.split == "train"].reset_index(drop=True)
val_df   = splits[splits.split == "val"].reset_index(drop=True)
test_df  = splits[splits.split == "test"].reset_index(drop=True)

print("Train/Val/Test:", len(train_df), len(val_df), len(test_df))
train_df.head()


Train/Val/Test: 14 3 3


Unnamed: 0,case_id,split
0,la_020,train
1,la_030,train
2,la_019,train
3,la_011,train
4,la_014,train


In [34]:
import numpy as np
from scipy.spatial import ConvexHull, distance_matrix
from scipy.ndimage import binary_erosion

def slice_diameter_pixels(binary2d: np.ndarray) -> float:
    """
    Computes max in-slice diameter (pixels) using boundary + convex hull.
    Deterministic and efficient.
    """
    b = (binary2d > 0)

    pts = np.argwhere(b)
    if pts.shape[0] < 2:
        return 0.0

    # boundary pixels only
    boundary = b & ~binary_erosion(b)
    bpts = np.argwhere(boundary)

    # fallback for tiny objects
    if bpts.shape[0] < 2:
        bpts = pts

    # deterministic cap (no randomness)
    max_pts = 1000
    if bpts.shape[0] > max_pts:
        step = int(np.ceil(bpts.shape[0] / max_pts))
        bpts = bpts[::step]

    # convex hull
    if bpts.shape[0] >= 3:
        try:
            hull = ConvexHull(bpts)
            hpts = bpts[hull.vertices]
        except:
            hpts = bpts
    else:
        hpts = bpts

    if hpts.shape[0] < 2:
        return 0.0

    D = distance_matrix(hpts, hpts)
    return float(D.max())


def compute_diameter_mm(
    mask3d: np.ndarray,
    spacing_xyz=(1.0, 1.0, 1.0),
    slice_axis=2,
) -> float:
    """
    Computes max diameter across all slices (mm).
    Assumes isotropic 1mm spacing after resampling.
    """
    m = (mask3d > 0).astype(np.uint8)
    m = np.moveaxis(m, slice_axis, 0)  # (S, H, W)

    best_px = 0.0
    for s in range(m.shape[0]):
        dpx = slice_diameter_pixels(m[s])
        if dpx > best_px:
            best_px = dpx

    # isotropic spacing → px == mm
    return float(best_px)

In [37]:
diam_rows = []
all_case_ids = splits["case_id"].values

for cid in tqdm(all_case_ids):
    msk = np.load(prep_dir / f"{cid}_mask.npy")  # (128,128,128)
    d_mm = compute_diameter_mm(msk, slice_axis=2)
    diam_rows.append({"case_id": cid, "d_mm": d_mm})

diam_df = pd.DataFrame(diam_rows)

diam_csv = results_dir / "diameters_mm.csv"
diam_df.to_csv(diam_csv, index=False)

print("Saved:", diam_csv)
diam_df["d_mm"].describe()


100%|██████████| 20/20 [00:01<00:00, 13.56it/s]

Saved: /content/drive/MyDrive/la_net/results/diameters_mm.csv





Unnamed: 0,d_mm
count,20.0
mean,71.052865
std,8.732301
min,47.88528
25%,66.775334
50%,71.568149
75%,76.795188
max,85.375641


In [38]:
diam_train = diam_df.merge(train_df[["case_id"]], on="case_id", how="inner")

d_mean = float(diam_train["d_mm"].mean())
d_std  = float(diam_train["d_mm"].std() + 1e-8)

diam_df["d_z"] = (diam_df["d_mm"] - d_mean) / d_std

diam_z_csv = results_dir / "diameters_z.csv"
diam_df.to_csv(diam_z_csv, index=False)

print("Train mean/std (mm):", d_mean, d_std)
print("Saved:", diam_z_csv)
diam_df.head()


Train mean/std (mm): 68.75822612963175 8.867784197244093
Saved: /content/drive/MyDrive/la_net/results/diameters_z.csv


Unnamed: 0,case_id,d_mm,d_z
0,la_020,47.88528,-2.353795
1,la_030,68.410526,-0.039209
2,la_019,74.886581,0.691081
3,la_011,82.097503,1.50424
4,la_014,78.294317,1.075363


In [40]:
class LADatasetMulti(Dataset):
    def __init__(self, df_split: pd.DataFrame, prep_dir: Path, diam_table: pd.DataFrame):
        self.df = df_split.reset_index(drop=True)
        self.prep_dir = prep_dir
        self.diam = diam_table.set_index("case_id")

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        cid = self.df.loc[idx, "case_id"]

        img = np.load(self.prep_dir / f"{cid}_image.npy").astype(np.float32)   # (128,128,128)
        msk = np.load(self.prep_dir / f"{cid}_mask.npy").astype(np.float32)    # (128,128,128) in {0,1}

        img = torch.from_numpy(img[None, ...])   # (1,128,128,128)
        msk = torch.from_numpy(msk[None, ...])   # (1,128,128,128)

        d_z = torch.tensor(float(self.diam.loc[cid, "d_z"]), dtype=torch.float32)  # scalar

        return img, msk, d_z, cid


In [42]:
diam_df_loaded = pd.read_csv(results_dir / "diameters_z.csv")

train_loader = DataLoader(
    LADatasetMulti(train_df, prep_dir, diam_df_loaded),
    batch_size=1, shuffle=True, num_workers=0, pin_memory=True
)
val_loader = DataLoader(
    LADatasetMulti(val_df, prep_dir, diam_df_loaded),
    batch_size=1, shuffle=False, num_workers=0, pin_memory=True
)
test_loader = DataLoader(
    LADatasetMulti(test_df, prep_dir, diam_df_loaded),
    batch_size=1, shuffle=False, num_workers=0, pin_memory=True
)

# sanity batch
img, msk, d_z, cid = next(iter(train_loader))
print("img:", img.shape, img.dtype)
print("msk:", msk.shape, msk.dtype)
print("d_z:", d_z, "cid:", cid[0])


img: torch.Size([1, 1, 128, 128, 128]) torch.float32
msk: torch.Size([1, 1, 128, 128, 128]) torch.float32
d_z: tensor([-0.2084]) cid: la_018


In [43]:
class DoubleConv(nn.Module):
    def __init__(self, in_c, out_c):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv3d(in_c, out_c, 3, padding=1, bias=False),
            nn.InstanceNorm3d(out_c),
            nn.ReLU(inplace=True),

            nn.Conv3d(out_c, out_c, 3, padding=1, bias=False),
            nn.InstanceNorm3d(out_c),
            nn.ReLU(inplace=True),
        )
    def forward(self, x):
        return self.net(x)

class ChannelAttention(nn.Module):
    def __init__(self, channels, reduction=8):
        super().__init__()
        self.avg = nn.AdaptiveAvgPool3d(1)
        self.mx  = nn.AdaptiveMaxPool3d(1)
        hidden = max(channels // reduction, 1)
        self.fc = nn.Sequential(
            nn.Conv3d(channels, hidden, 1),
            nn.ReLU(inplace=True),
            nn.Conv3d(hidden, channels, 1),
        )
    def forward(self, x):
        return torch.sigmoid(self.fc(self.avg(x)) + self.fc(self.mx(x)))

class SpatialAttention(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv3d(2, 1, kernel_size=7, padding=3, bias=False)
    def forward(self, x):
        avg = torch.mean(x, dim=1, keepdim=True)
        mx, _ = torch.max(x, dim=1, keepdim=True)
        s = torch.cat([avg, mx], dim=1)
        return torch.sigmoid(self.conv(s))

class CBAM(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.ca = ChannelAttention(channels)
        self.sa = SpatialAttention()
    def forward(self, x):
        x = x * self.ca(x)
        x = x * self.sa(x)
        return x


In [45]:
class LANetMultiTask(nn.Module):
    def __init__(self, in_channels=1, out_channels=1, base=16):
        super().__init__()
        self.enc1 = nn.Sequential(DoubleConv(in_channels, base), CBAM(base))
        self.enc2 = nn.Sequential(DoubleConv(base, base*2), CBAM(base*2))
        self.enc3 = nn.Sequential(DoubleConv(base*2, base*4), CBAM(base*4))

        self.pool = nn.MaxPool3d(2)

        self.bottleneck_conv = DoubleConv(base*4, base*8)
        self.bottleneck_attn = CBAM(base*8)

        # decoder
        self.up3 = nn.ConvTranspose3d(base*8, base*4, 2, 2)
        self.dec3 = DoubleConv(base*8, base*4)

        self.up2 = nn.ConvTranspose3d(base*4, base*2, 2, 2)
        self.dec2 = DoubleConv(base*4, base*2)

        self.up1 = nn.ConvTranspose3d(base*2, base, 2, 2)
        self.dec1 = DoubleConv(base*2, base)

        self.seg_head = nn.Conv3d(base, out_channels, 1)

        # regression head (predict d_z)
        self.gap = nn.AdaptiveAvgPool3d(1)
        self.reg_mlp = nn.Sequential(
            nn.Linear(base*8, base*4),
            nn.ReLU(inplace=True),
            nn.Linear(base*4, 1)
        )

    def forward(self, x):
        e1 = self.enc1(x); p1 = self.pool(e1)
        e2 = self.enc2(p1); p2 = self.pool(e2)
        e3 = self.enc3(p2); p3 = self.pool(e3)

        b = self.bottleneck_attn(self.bottleneck_conv(p3))

        # regression from bottleneck
        v = self.gap(b).flatten(1)                 # (B, base*8)
        d_z_hat = self.reg_mlp(v).squeeze(1)       # (B,)

        # segmentation decoder
        u3 = torch.cat([self.up3(b), e3], dim=1)
        d3 = self.dec3(u3)

        u2 = torch.cat([self.up2(d3), e2], dim=1)
        d2 = self.dec2(u2)

        u1 = torch.cat([self.up1(d2), e1], dim=1)
        d1 = self.dec1(u1)

        seg_logits = self.seg_head(d1)             # (B,1,128,128,128)

        return seg_logits, d_z_hat

In [46]:
model = LANetMultiTask().to(device)
x = torch.randn(1,1,128,128,128).to(device)
seg_logits, d_z_hat = model(x)
print("seg logits:", seg_logits.shape)
print("d_z_hat:", d_z_hat.shape)


seg logits: torch.Size([1, 1, 128, 128, 128])
d_z_hat: torch.Size([1])


In [102]:
@torch.no_grad()
def dice_score_from_logits(logits, target, eps=1e-6):
    prob = torch.sigmoid(logits)
    pred = (prob > 0.5).float()
    pred = pred.view(pred.size(0), -1)
    target = target.view(target.size(0), -1)
    inter = (pred * target).sum(dim=1)
    union = pred.sum(dim=1) + target.sum(dim=1)
    dice = (2*inter + eps) / (union + eps)
    return dice.mean().item()

@torch.no_grad()
def soft_dice_from_logits(logits, target, eps=1e-6):
    prob = torch.sigmoid(logits)
    prob = prob.view(prob.size(0), -1)
    target = target.view(target.size(0), -1)
    inter = (prob * target).sum(dim=1)
    union = prob.sum(dim=1) + target.sum(dim=1)
    dice = (2*inter + eps) / (union + eps)
    return dice.mean().item()

class SoftDiceLoss(nn.Module):
    def __init__(self, eps=1e-6):
        super().__init__()
        self.eps = eps
    def forward(self, logits, targets):
        probs = torch.sigmoid(logits)
        probs = probs.view(probs.size(0), -1)
        targets = targets.view(targets.size(0), -1)
        inter = (probs * targets).sum(dim=1)
        union = probs.sum(dim=1) + targets.sum(dim=1)
        dice = (2*inter + self.eps) / (union + self.eps)
        return 1 - dice.mean()

bce = nn.BCEWithLogitsLoss()
dice_loss = SoftDiceLoss()

def seg_loss_fn(seg_logits, mask):
    return 0.6*bce(seg_logits, mask) + 0.4*dice_loss(seg_logits, mask)


In [184]:
img, msk, d_z, cid = next(iter(train_loader))
print("pos frac:", msk.mean().item(), "sum:", msk.sum().item())

pos frac: 0.030888080596923828 sum: 64777.0


In [185]:
with torch.no_grad():
    img, msk, d_z, cid = next(iter(val_loader))
    logits, _ = model(img.to(device))
    prob = torch.sigmoid(logits)[0,0]
    print("pred frac:", (prob > 0.5).float().mean().item())


pred frac: 0.06282758712768555


In [186]:
# from Cell 7
print("d_mean, d_std:", d_mean, d_std)

def z_to_mm(d_z_tensor):
    return d_z_tensor * d_std + d_mean


d_mean, d_std: 68.75822612963175 8.867784197244093


In [187]:
lambda_reg = 0.2
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

num_epochs = 30
best_val_dice = -1.0
best_path = root / "la_net_multitask_best.ckpt"

print("Ready. best_path:", best_path)


Ready. best_path: /content/drive/MyDrive/la_net/la_net_multitask_best.ckpt


In [188]:
with torch.no_grad():
  img, msk, d_true_z, cid = next(iter(val_loader))
  img, msk = img.to(device), msk.to(device)
  seg_logits, d_hat_z = model (img)

  prob = torch.sigmoid(seg_logits)
  pred = (prob > 0.5).float()
  print("logits min/max", seg_logits.min().item(), seg_logits.max().item())
  print("prob min/max", prob.min().item(), prob.max().item())
  print("pred unique", torch.unique(pred))
  print("mask unique", torch.unique(msk))
  print("pred sum:", pred.sum().item(), "mask sum:", msk.sum().item())

logits min/max -0.8474089503288269 9.086922645568848
prob min/max 0.2999766767024994 0.9998868703842163
pred unique tensor([0., 1.], device='cuda:0')
mask unique tensor([0., 1.], device='cuda:0')
pred sum: 131759.0 mask sum: 104355.0


In [189]:
train_history = []
val_history = []

for epoch in range(1, num_epochs+1):
    # ---- train ----
    model.train()
    tr_seg_loss = 0.0
    tr_reg_loss = 0.0

    for img, msk, d_z_true, cid in train_loader:
        img = img.to(device)
        msk = msk.to(device)
        d_z_true = d_z_true.to(device)  # (B,1)

        optimizer.zero_grad()
        seg_logits, d_z_hat = model(img)

        Lseg = seg_loss_fn(seg_logits, msk)
        Lreg = l1(d_z_hat, d_z_true)
        loss = Lseg + lambda_reg * Lreg

        loss.backward()
        optimizer.step()

        tr_seg_loss += float(Lseg.item())
        tr_reg_loss += float(Lreg.item())

    tr_seg_loss /= max(len(train_loader), 1)
    tr_reg_loss /= max(len(train_loader), 1)

    # ---- val ----
    model.eval()
    val_dice_list = []
    val_mae_mm_list = []

    with torch.no_grad():
        for img, msk, d_z_true, cid in val_loader:
            img = img.to(device)
            msk = msk.to(device)
            d_z_true = d_z_true.to(device)

            seg_logits, d_z_hat = model(img)

            # dice
            val_dice_list.append(dice_score_from_logits(seg_logits, msk))

            # MAE in mm (convert both)
            d_hat_mm = z_to_mm(d_z_hat).cpu().numpy().reshape(-1)
            d_true_mm = z_to_mm(d_z_true).cpu().numpy().reshape(-1)
            val_mae_mm_list.append(float(np.mean(np.abs(d_hat_mm - d_true_mm))))

    val_dice = float(np.mean(val_dice_list)) if len(val_dice_list) else 0.0
    val_mae_mm = float(np.mean(val_mae_mm_list)) if len(val_mae_mm_list) else 0.0

    train_history.append({"epoch": epoch, "seg_loss": tr_seg_loss, "reg_loss_z": tr_reg_loss})
    val_history.append({"epoch": epoch, "val_dice": val_dice, "val_mae_mm": val_mae_mm})

    # save best by val dice
    if val_dice > best_val_dice:
        best_val_dice = val_dice
        torch.save(model.state_dict(), best_path)

    print(f"Epoch {epoch:02d} | seg_loss {tr_seg_loss:.4f} | reg_L1(z) {tr_reg_loss:.4f} | val_dice {val_dice:.4f} | val_MAE(mm) {val_mae_mm:.2f} | best {best_val_dice:.4f}")


Epoch 01 | seg_loss 0.5921 | reg_L1(z) 0.3491 | val_dice 0.8147 | val_MAE(mm) 3.97 | best 0.8147
Epoch 02 | seg_loss 0.5882 | reg_L1(z) 0.3672 | val_dice 0.8524 | val_MAE(mm) 3.03 | best 0.8524
Epoch 03 | seg_loss 0.5859 | reg_L1(z) 0.3250 | val_dice 0.8698 | val_MAE(mm) 5.15 | best 0.8698
Epoch 04 | seg_loss 0.5838 | reg_L1(z) 0.3369 | val_dice 0.8720 | val_MAE(mm) 3.93 | best 0.8720
Epoch 05 | seg_loss 0.5819 | reg_L1(z) 0.3099 | val_dice 0.8446 | val_MAE(mm) 3.14 | best 0.8720
Epoch 06 | seg_loss 0.5800 | reg_L1(z) 0.3193 | val_dice 0.8484 | val_MAE(mm) 3.21 | best 0.8720
Epoch 07 | seg_loss 0.5774 | reg_L1(z) 0.2872 | val_dice 0.8680 | val_MAE(mm) 2.31 | best 0.8720
Epoch 08 | seg_loss 0.5756 | reg_L1(z) 0.3084 | val_dice 0.8850 | val_MAE(mm) 3.27 | best 0.8850
Epoch 09 | seg_loss 0.5739 | reg_L1(z) 0.2917 | val_dice 0.8818 | val_MAE(mm) 3.71 | best 0.8850
Epoch 10 | seg_loss 0.5720 | reg_L1(z) 0.2810 | val_dice 0.8832 | val_MAE(mm) 3.69 | best 0.8850
Epoch 11 | seg_loss 0.5701 | r

In [191]:
train_log = pd.DataFrame(train_history)
val_log = pd.DataFrame(val_history)

train_log_path = results_dir / "la_net_multitask_train_log.csv"
val_log_path   = results_dir / "la_net_multitask_val_log.csv"

train_log.to_csv(train_log_path, index=False)
val_log.to_csv(val_log_path, index=False)

print("Saved:", train_log_path)
print("Saved:", val_log_path)
val_log.tail()


Saved: /content/drive/MyDrive/la_net/results/la_net_multitask_train_log.csv
Saved: /content/drive/MyDrive/la_net/results/la_net_multitask_val_log.csv


Unnamed: 0,epoch,val_dice,val_mae_mm
25,26,0.859043,2.078662
26,27,0.876279,4.314654
27,28,0.881313,3.289714
28,29,0.87772,4.882609
29,30,0.861505,2.086782


In [192]:
model = LANetMultiTask().to(device)
model.load_state_dict(torch.load(best_path, map_location=device))
model.eval()
print("Loaded best:", best_path, "best_val_dice:", best_val_dice)


Loaded best: /content/drive/MyDrive/la_net/la_net_multitask_best.ckpt best_val_dice: 0.8904640674591064


In [194]:
import numpy as np
import scipy.ndimage as ndimage
from scipy.ndimage import distance_transform_edt

def hd95_np(pred, gt, spacing=(1.0, 1.0, 1.0)):
    pred = pred.astype(bool)
    gt   = gt.astype(bool)

    if pred.sum() == 0 and gt.sum() == 0:
        return 0.0
    if pred.sum() == 0 or gt.sum() == 0:
        return float("inf")

    struct = ndimage.generate_binary_structure(3, 1)

    pred_er = ndimage.binary_erosion(pred, structure=struct)
    gt_er   = ndimage.binary_erosion(gt,   structure=struct)

    pred_surface = pred ^ pred_er
    gt_surface   = gt   ^ gt_er

    # fallback for tiny objects where erosion removes everything
    if pred_surface.sum() == 0:
        pred_surface = pred
    if gt_surface.sum() == 0:
        gt_surface = gt

    dt_pred = distance_transform_edt(~pred, sampling=spacing)
    dt_gt   = distance_transform_edt(~gt,   sampling=spacing)

    d1 = dt_gt[pred_surface]  # pred surface to nearest gt surface
    d2 = dt_pred[gt_surface]  # gt surface to nearest pred surface

    all_d = np.concatenate([d1, d2])
    return float(np.percentile(all_d, 95))



In [197]:
import numpy as np
import scipy.ndimage as ndimage

def largest_cc(binary3d: np.ndarray) -> np.ndarray:
    """
    Keeps only the largest connected component in a 3D binary mask.
    """
    lab, n = ndimage.label(binary3d)
    if n == 0:
        return binary3d

    counts = np.bincount(lab.ravel())
    counts[0] = 0  # ignore background
    keep = counts.argmax()

    return (lab == keep)


In [198]:
rows = []
model.eval()

with torch.no_grad():
    for img, msk, d_z_true, cid in tqdm(val_loader):
        cid = cid[0]  # ok because batch_size=1

        img = img.to(device)
        msk = msk.to(device)
        d_z_true = d_z_true.to(device)

        seg_logits, d_z_hat = model(img)

        # hard pred for metrics
        prob = torch.sigmoid(seg_logits)
        pred = (prob > 0.5).float()

        # dice (torch)
        inter = (pred * msk).sum(dim=(1,2,3,4))
        union = pred.sum(dim=(1,2,3,4)) + msk.sum(dim=(1,2,3,4))
        dice = ((2*inter + 1e-6) / (union + 1e-6)).item()

        # numpy for hd95
        pred_np = pred.detach().cpu().numpy()[0,0]  # (D,H,W)
        msk_np  = msk.detach().cpu().numpy()[0,0]

        # binarize + largest CC (recommended)
        pred_np = largest_cc(pred_np > 0.5).astype(np.uint8)
        msk_np  = (msk_np > 0.5).astype(np.uint8)

        h = hd95_np(pred_np, msk_np, spacing=(1.0,1.0,1.0))

        # diameter MAE (mm)
        d_hat_mm  = float(z_to_mm(d_z_hat).detach().cpu().numpy().reshape(-1)[0])
        d_true_mm = float(z_to_mm(d_z_true).detach().cpu().numpy().reshape(-1)[0])
        mae_mm = abs(d_hat_mm - d_true_mm)

        rows.append({
            "case_id": cid,
            "dice": float(dice),
            "hd95": float(h),
            "d_hat_mm": d_hat_mm,
            "d_true_mm": d_true_mm,
            "abs_err_mm": mae_mm
        })

val_metrics = pd.DataFrame(rows)
val_metrics_path = results_dir / "la_net_val_metrics.csv"
val_metrics.to_csv(val_metrics_path, index=False)

print("Saved:", val_metrics_path)
print("Mean Dice:", val_metrics["dice"].mean())
print("Mean HD95:", val_metrics["hd95"].replace(np.inf, np.nan).mean())
print("Mean MAE (mm):", val_metrics["abs_err_mm"].mean())
val_metrics.head()




100%|██████████| 3/3 [00:02<00:00,  1.18it/s]

Saved: /content/drive/MyDrive/la_net/results/la_net_val_metrics.csv
Mean Dice: 0.8904640674591064
Mean HD95: 3.6138730843212596
Mean MAE (mm): 4.031974792480469





Unnamed: 0,case_id,dice,hd95,d_hat_mm,d_true_mm,abs_err_mm
0,la_016,0.929599,2.236068,67.496132,69.375786,1.879654
1,la_009,0.839608,5.0,68.49295,76.295479,7.802528
2,la_023,0.902185,3.605551,69.154404,71.568146,2.413742


In [200]:
import scipy.ndimage as ndimage

def largest_connected_component(binary3d: np.ndarray):
    lab, n = ndimage.label(binary3d.astype(np.uint8))
    if n == 0:
        return binary3d.astype(np.uint8)
    counts = np.bincount(lab.ravel())
    counts[0] = 0
    keep = counts.argmax()
    return (lab == keep).astype(np.uint8)


In [201]:
rows_pp = []
model.eval()

with torch.no_grad():
    for img, msk, d_z_true, cid in tqdm(val_loader):
        cid = cid[0]
        img = img.to(device)
        msk = msk.to(device)

        seg_logits, d_z_hat = model(img)
        prob = torch.sigmoid(seg_logits)

        pred_np = (prob.cpu().numpy()[0,0] > 0.5).astype(np.uint8)
        pred_pp = largest_connected_component(pred_np)

        msk_np = (msk.cpu().numpy()[0,0] > 0.5).astype(np.uint8)

        inter = (pred_pp * msk_np).sum()
        union = pred_pp.sum() + msk_np.sum()
        dice = (2*inter + 1e-6) / (union + 1e-6)

        h = hd95_np(pred_pp, msk_np, spacing=(1.0,1.0,1.0))

        rows_pp.append({"case_id": cid, "dice_pp": float(dice), "hd95_pp": float(h)})

val_pp = pd.DataFrame(rows_pp)
print("Mean Dice (PP):", val_pp["dice_pp"].mean())
print("Mean HD95 (PP):", val_pp["hd95_pp"].replace(np.inf, np.nan).mean())
val_pp.head()



100%|██████████| 3/3 [00:02<00:00,  1.16it/s]

Mean Dice (PP): 0.8924827268869805
Mean HD95 (PP): 3.6138730843212596





Unnamed: 0,case_id,dice_pp,hd95_pp
0,la_016,0.930316,2.236068
1,la_009,0.840511,5.0
2,la_023,0.906621,3.605551


In [203]:
val_pp_path = results_dir / "la_net_val_metrics_postprocessed.csv"
val_pp.to_csv(val_pp_path, index=False)
print("Saved:", val_pp_path)


Saved: /content/drive/MyDrive/la_net/results/la_net_val_metrics_postprocessed.csv


In [204]:
def count_components(binary3d: np.ndarray) -> int:
    _, n = ndimage.label(binary3d.astype(np.uint8))
    return int(n)

def hole_voxels(binary3d: np.ndarray) -> int:
    # Fill holes then compare
    filled = ndimage.binary_fill_holes(binary3d.astype(bool))
    holes = filled & (~binary3d.astype(bool))
    return int(holes.sum())


In [205]:
rows_shape = []
with torch.no_grad():
    for img, msk, d_z_true, cid in tqdm(val_loader):
        cid = cid[0]
        img = img.to(device)

        seg_logits, d_z_hat = model(img)
        prob = torch.sigmoid(seg_logits)
        pred = (prob > 0.5).float()

        pred_np = pred.cpu().numpy()[0,0].astype(np.uint8)

        n_comp = count_components(pred_np)
        holes = hole_voxels(pred_np)

        rows_shape.append({"case_id": cid, "n_components": n_comp, "hole_voxels": holes})

shape_df = pd.DataFrame(rows_shape)
shape_path = results_dir / "la_net_val_shape_metrics.csv"
shape_df.to_csv(shape_path, index=False)

print("Saved:", shape_path)
print(shape_df.describe())


100%|██████████| 3/3 [00:01<00:00,  2.95it/s]

Saved: /content/drive/MyDrive/la_net/results/la_net_val_shape_metrics.csv
       n_components  hole_voxels
count           3.0          3.0
mean            8.0          0.0
std             5.0          0.0
min             3.0          0.0
25%             5.5          0.0
50%             8.0          0.0
75%            10.5          0.0
max            13.0          0.0



