<a href="https://colab.research.google.com/github/edwardleetenafly/LA-net/blob/main/lanet_week4.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Colab-only step: mount Google Drive at /content/drive so that files stored
# under it can be read and written by the cells below.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [2]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

import numpy as np
import pandas as pd
from pathlib import Path
from tqdm import tqdm
from scipy.spatial.distance import cdist

print("Week 4 notebook loaded!")

# Project folders on the mounted Drive.
root = Path("/content/drive/MyDrive/la_net")
prep_dir = root / "preprocessed"

# Read the split table and keep the rows whose `split` column equals "test".
splits = pd.read_csv(root / "splits.csv")
is_test_case = splits["split"] == "test"
test_df = splits[is_test_case]

# Run on the GPU when PyTorch can see one, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
device


Week 4 notebook loaded!


'cuda'

In [3]:
class LADataset(Dataset):
    """Dataset of preprocessed image/mask volumes stored as ``.npy`` files.

    The sample for a row of ``df`` is read from
    ``<prep_dir>/<case_id>_image.npy`` and ``<prep_dir>/<case_id>_mask.npy``,
    where ``case_id`` is taken from that row.

    Args:
        df: DataFrame with a ``case_id`` column, one row per case.
        prep_dir: Folder containing the ``.npy`` files. Accepts a ``str`` or
            any path-like object; it is normalised to ``pathlib.Path``.
    """

    def __init__(self, df, prep_dir):
        self.df = df
        # Coerce to Path so the `/` joins in __getitem__ also work when the
        # caller passes a plain string.
        self.prep_dir = Path(prep_dir)

    def __len__(self):
        """Number of cases (rows of ``df``)."""
        return len(self.df)

    def __getitem__(self, idx):
        """Load the case at positional index ``idx``.

        Returns:
            ``(image, mask)`` as float32 tensors, each with a leading channel
            axis of size 1 prepended to the stored array.
        """
        row = self.df.iloc[idx]
        cid = row["case_id"]

        img = np.load(self.prep_dir / f"{cid}_image.npy")
        msk = np.load(self.prep_dir / f"{cid}_mask.npy")

        # Prepend the channel axis expected by the 3D convolutions.
        img = img[np.newaxis, ...]
        msk = msk[np.newaxis, ...]

        return torch.tensor(img, dtype=torch.float32), torch.tensor(msk, dtype=torch.float32)


In [4]:
import torch
import torch.nn as nn

class DoubleConv(nn.Module):
    """Two consecutive (3x3x3 conv -> instance norm -> ReLU) stages.

    The first stage maps ``in_c`` to ``out_c`` channels, the second keeps
    ``out_c``. Padding of 1 preserves the spatial size.
    """

    def __init__(self, in_c, out_c):
        super().__init__()
        stages = []
        # Built in order so the Sequential indices (0..5) stay the same as a
        # hand-written six-layer stack.
        for stage_in in (in_c, out_c):
            stages.append(nn.Conv3d(stage_in, out_c, 3, padding=1))
            stages.append(nn.InstanceNorm3d(out_c))
            stages.append(nn.ReLU(inplace=True))
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Apply both conv stages to ``x``."""
        return self.net(x)


class UNet3D(nn.Module):
    """Three-level 3D encoder/decoder with skip concatenations.

    Channel widths double at every level (``base_channels`` x 1, 2, 4) up to
    the bottleneck (x 8). The input is max-pooled by 2 three times, so each
    spatial size should be divisible by 8 for the upsampled features to line
    up with the skip tensors. The output holds raw logits (no activation).
    """

    def __init__(self, in_channels=1, out_channels=1, base_channels=16):
        super().__init__()
        c1 = base_channels
        c2 = base_channels * 2
        c3 = base_channels * 4
        c4 = base_channels * 8

        # Contracting path.
        self.enc1 = DoubleConv(in_channels, c1)
        self.enc2 = DoubleConv(c1, c2)
        self.enc3 = DoubleConv(c2, c3)

        self.bottleneck = DoubleConv(c3, c4)

        # Expanding path: each decoder block sees the upsampled features
        # concatenated with the skip tensor, hence twice the up-conv width.
        self.up3 = nn.ConvTranspose3d(c4, c3, 2, 2)
        self.dec3 = DoubleConv(c4, c3)

        self.up2 = nn.ConvTranspose3d(c3, c2, 2, 2)
        self.dec2 = DoubleConv(c3, c2)

        self.up1 = nn.ConvTranspose3d(c2, c1, 2, 2)
        self.dec1 = DoubleConv(c2, c1)

        self.out_conv = nn.Conv3d(c1, out_channels, 1)

        self.pool = nn.MaxPool3d(2)

    def forward(self, x):
        """Return per-voxel logits with the same spatial size as ``x``."""
        skip1 = self.enc1(x)
        skip2 = self.enc2(self.pool(skip1))
        skip3 = self.enc3(self.pool(skip2))

        feat = self.bottleneck(self.pool(skip3))

        # Upsample, then concatenate with the matching encoder output.
        feat = self.dec3(torch.cat([self.up3(feat), skip3], dim=1))
        feat = self.dec2(torch.cat([self.up2(feat), skip2], dim=1))
        feat = self.dec1(torch.cat([self.up1(feat), skip1], dim=1))

        return self.out_conv(feat)


In [5]:
# Rebuild the baseline architecture and restore its saved weights.
model = UNet3D().to(device)

ckpt_path = root / "baseline_unet_best.ckpt"
# The file is fed straight into load_state_dict, i.e. it is a plain state
# dict. weights_only=True restricts unpickling to tensors and basic
# containers, so a tampered checkpoint cannot run arbitrary code on load.
model.load_state_dict(
    torch.load(ckpt_path, map_location=device, weights_only=True)
)
model.eval()  # switch the module to evaluation mode before scoring

print("Loaded baseline model from:", ckpt_path)


Loaded baseline model from: /content/drive/MyDrive/la_net/baseline_unet_best.ckpt


In [6]:
def dice_coef(pred, target, eps=1e-6):
    """Smoothed Dice overlap between two masks of equal shape.

    Computes ``(2 * sum(pred * target) + eps) / (sum(pred) + sum(target) + eps)``.
    ``eps`` keeps the ratio defined (and equal to 1) when both masks are empty.
    """
    overlap = (pred * target).sum()
    total_mass = pred.sum() + target.sum()
    numerator = 2*overlap + eps
    denominator = total_mass + eps
    return numerator / denominator

import scipy.ndimage as ndimage
from scipy.ndimage import distance_transform_edt

def hd95(pred, gt, spacing=None):
    """95th-percentile symmetric surface distance between two binary masks.

    Args:
        pred: Predicted mask (any array-like; non-zero means foreground).
        gt: Reference mask with the same shape as ``pred``.
        spacing: Optional voxel size, a scalar or one value per mask axis,
            forwarded to ``distance_transform_edt(sampling=...)``. ``None``
            (default) measures distances in voxel units.

    Returns:
        ``0.0`` when both masks are empty, ``inf`` when exactly one is empty,
        otherwise the 95th percentile of the pooled pred->gt and gt->pred
        surface-voxel distances.
    """
    # convert to binary
    pred = np.asarray(pred).astype(bool)
    gt   = np.asarray(gt).astype(bool)

    # handle empty cases
    pred_empty = not pred.any()
    gt_empty = not gt.any()
    if pred_empty and gt_empty:
        return 0.0
    if pred_empty or gt_empty:
        return np.inf

    # Distance from every voxel to the nearest foreground voxel of each mask
    # (zero inside the mask itself).
    # NOTE(review): distances are taken to the other *mask*, not to its
    # surface, so a surface voxel lying inside the other mask contributes 0.
    # MedPy/MONAI measure surface-to-surface; confirm which convention the
    # reported numbers should follow before changing this.
    dt_pred = distance_transform_edt(~pred, sampling=spacing)
    dt_gt   = distance_transform_edt(~gt, sampling=spacing)

    # surface voxels: foreground voxels removed by one erosion step
    pred_surface = pred ^ ndimage.binary_erosion(pred)
    gt_surface   = gt ^ ndimage.binary_erosion(gt)

    # distances from pred surface -> gt, and from gt surface -> pred
    d_pred_to_gt = dt_gt[pred_surface]
    d_gt_to_pred = dt_pred[gt_surface]

    all_d = np.concatenate([d_pred_to_gt, d_gt_to_pred])
    return float(np.percentile(all_d, 95))


In [7]:
# Score the baseline on every test case, one volume per batch.
test_loader = DataLoader(LADataset(test_df, prep_dir), batch_size=1, shuffle=False)

rows = []

with torch.no_grad():
    for i, (img, msk) in enumerate(tqdm(test_loader)):
        # batch_size=1 with shuffle=False keeps the loader position aligned
        # with the row position in test_df.
        cid = test_df.iloc[i]["case_id"]

        img, msk = img.to(device), msk.to(device)

        # Logits -> probabilities -> hard mask at a 0.5 threshold.
        logits = model(img)
        prob = logits.sigmoid()
        pred = prob.gt(0.5).float()

        # Drop the batch and channel axes before handing over to NumPy.
        pred_np = pred[0, 0].cpu().numpy()
        msk_np  = msk[0, 0].cpu().numpy()

        d, h = dice_coef(pred_np, msk_np), hd95(pred_np, msk_np)

        rows.append({"case_id": cid, "dice": float(d), "hd95": float(h)})

results = pd.DataFrame(rows)
results


100%|██████████| 3/3 [00:05<00:00,  1.84s/it]


Unnamed: 0,case_id,dice,hd95
0,la_026,0.885755,4.898979
1,la_003,0.89918,3.162278
2,la_004,0.8835,5.385165


In [8]:
# Write the per-case baseline metrics to <root>/results/baseline_metrics.csv.
out_dir = root / "results"
csv_path = out_dir / "baseline_metrics.csv"

out_dir.mkdir(exist_ok=True)  # no error if the folder is already there
results.to_csv(csv_path, index=False)

print("Saved:", csv_path)


Saved: /content/drive/MyDrive/la_net/results/baseline_metrics.csv


In [9]:
# Mean and sample standard deviation of each metric across the test cases.
dice_scores = results["dice"]
hd95_scores = results["hd95"]

print("Mean Dice:", dice_scores.mean())
print("Std Dice:", dice_scores.std())
print("Mean HD95:", hd95_scores.mean())
print("Std HD95:", hd95_scores.std())


Mean Dice: 0.8894783059755961
Std Dice: 0.008477127982771731
Mean HD95: 4.4821406509564135
Std HD95: 1.1685986383553724


In [10]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class ChannelAttention(nn.Module):
    """Per-channel gate computed from globally pooled features.

    Global average- and max-pooled descriptors go through the same two-layer
    1x1x1 bottleneck; their sum is squashed with a sigmoid.

    Args:
        channels: Number of channels of the incoming feature map.
        reduction: Bottleneck ratio. The hidden width is
            ``channels // reduction`` but never less than 1, so the gate
            stays usable when ``channels < reduction``.
    """

    def __init__(self, channels, reduction=8):
        super().__init__()
        # Integer division alone yields 0 hidden channels for narrow inputs,
        # which would leave the bottleneck with nothing to pass through.
        hidden = max(1, channels // reduction)
        self.avg = nn.AdaptiveAvgPool3d(1)
        self.max = nn.AdaptiveMaxPool3d(1)
        self.fc = nn.Sequential(
            nn.Conv3d(channels, hidden, 1),
            nn.ReLU(),
            nn.Conv3d(hidden, channels, 1)
        )

    def forward(self, x):
        """Return a gate of shape (N, C, 1, 1, 1) with values in [0, 1]."""
        avg = self.fc(self.avg(x))
        mx  = self.fc(self.max(x))
        return torch.sigmoid(avg + mx)

class SpatialAttention(nn.Module):
    """Per-voxel gate built from channel-wise mean and max maps.

    The two maps are stacked (mean first, max second) and reduced to a single
    channel by a 7x7x7 convolution followed by a sigmoid.
    """

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv3d(2, 1, kernel_size=7, padding=3)

    def forward(self, x):
        """Return a gate of shape (N, 1, D, H, W) for input (N, C, D, H, W)."""
        channel_mean = x.mean(dim=1, keepdim=True)
        channel_max = x.max(dim=1, keepdim=True).values
        stacked = torch.cat([channel_mean, channel_max], dim=1)
        gate_logits = self.conv(stacked)
        return torch.sigmoid(gate_logits)

class CBAM(nn.Module):
    """Channel gate followed by spatial gate, each applied multiplicatively."""

    def __init__(self, channels):
        super().__init__()
        self.ca = ChannelAttention(channels)
        self.sa = SpatialAttention()

    def forward(self, x):
        """Re-weight ``x`` per channel, then re-weight the result per voxel."""
        channel_refined = x * self.ca(x)
        # The spatial gate is computed from the already channel-refined map.
        return channel_refined * self.sa(channel_refined)


In [11]:
class LANet3D(nn.Module):
    """3D encoder/decoder whose encoder and bottleneck blocks end in CBAM.

    Same layout as a three-level U-shaped network: widths ``base`` x 1, 2, 4
    in the encoder, x 8 in the bottleneck, and a plain ``DoubleConv`` decoder
    fed by skip concatenations. Three 2x poolings mean each spatial size
    should be divisible by 8. The output holds raw logits.
    """

    @staticmethod
    def _attended(in_c, out_c):
        """DoubleConv followed by CBAM on its output channels."""
        return nn.Sequential(DoubleConv(in_c, out_c), CBAM(out_c))

    def __init__(self, in_channels=1, out_channels=1, base=16):
        super().__init__()
        w1, w2, w3, w4 = base, base*2, base*4, base*8

        # Attention-augmented contracting path.
        self.enc1 = self._attended(in_channels, w1)
        self.enc2 = self._attended(w1, w2)
        self.enc3 = self._attended(w2, w3)

        self.bottleneck = self._attended(w3, w4)

        # Expanding path; decoder inputs are (upsampled + skip) channels.
        self.up3 = nn.ConvTranspose3d(w4, w3, 2, 2)
        self.dec3 = DoubleConv(w4, w3)

        self.up2 = nn.ConvTranspose3d(w3, w2, 2, 2)
        self.dec2 = DoubleConv(w3, w2)

        self.up1 = nn.ConvTranspose3d(w2, w1, 2, 2)
        self.dec1 = DoubleConv(w2, w1)

        self.out_conv = nn.Conv3d(w1, out_channels, 1)
        self.pool = nn.MaxPool3d(2)

    def forward(self, x):
        """Return per-voxel logits with the same spatial size as ``x``."""
        skip1 = self.enc1(x)
        skip2 = self.enc2(self.pool(skip1))
        skip3 = self.enc3(self.pool(skip2))

        feat = self.bottleneck(self.pool(skip3))

        # Walk back up, pairing each upsampled map with its encoder skip.
        for up, dec, skip in (
            (self.up3, self.dec3, skip3),
            (self.up2, self.dec2, skip2),
            (self.up1, self.dec1, skip1),
        ):
            feat = dec(torch.cat([up(feat), skip], dim=1))

        return self.out_conv(feat)


In [12]:
# Shape sanity check: push one random 128^3 volume through a fresh LANet3D.
x = torch.randn(1,1,128,128,128).to(device)
model_lanet = LANet3D().to(device)

# Only the output shape is inspected, so skip autograd. Without this, `y`
# keeps the whole graph of intermediate activations alive on the device.
with torch.no_grad():
    y = model_lanet(x)
print("Output shape:", y.shape)


Output shape: torch.Size([1, 1, 128, 128, 128])


In [13]:
def compute_diameter(mask_np, spacing=(1,1,1)):
    """Largest in-plane extent of a mask over all slices along axis 0.

    For every slice ``mask_np[z]`` the maximum Euclidean distance between two
    foreground pixels (value ``1``) is computed; the largest such distance
    over all slices is returned.

    Args:
        mask_np: 3D array indexed as (D, H, W).
        spacing: Voxel size in the same axis order as ``mask_np``. With three
            values, the in-plane distances use ``spacing[1]`` (H) and
            ``spacing[2]`` (W); the slicing-axis value ``spacing[0]`` does not
            enter an in-plane measurement. With any other length the first
            value is applied to both in-plane axes.

    Returns:
        The diameter as a float in the units of ``spacing``; ``0.0`` when the
        mask has no foreground.
    """
    mask_np = np.asarray(mask_np)
    _depth, _height, _width = mask_np.shape  # still rejects non-3D input

    spacing = np.atleast_1d(np.asarray(spacing, dtype=float))
    if spacing.size == 3:
        in_plane = spacing[1:]
    else:
        in_plane = np.repeat(spacing[0], 2)

    best = 0.0
    for plane in mask_np:
        foreground = plane == 1
        if not foreground.any():
            continue

        # The farthest pair of a point set lies on its outline, so only the
        # boundary pixels are compared. This keeps the pairwise matrix small.
        boundary = foreground ^ ndimage.binary_erosion(foreground)
        points = np.argwhere(boundary) * in_plane

        # max pairwise distance in this slice, already in physical units
        best = max(best, float(cdist(points, points).max()))

    return best
