In [54]:
import os
import random
import numpy as np
import nibabel as nib
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import GradScaler
import matplotlib.pyplot as plt

In [55]:
data_root ="C:\\Users\\ASUS\\Desktop\\liver_seg\\data" 
images_dir = os.path.join(data_root, "images")  
labels_dir = os.path.join(data_root, "labels") 

In [56]:
patch_size = (96,160,160)  
batch_size = 1
num_epochs = 15 
num_train_steps_per_epoch = 100 
num_classes = 3            
base_lr = 1e-3


In [57]:
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)
torch.cuda.manual_seed_all(42)

In [58]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [59]:
def load_nifti(path):
    nii = nib.load(path)
    vol = nii.get_fdata().astype(np.float32) 
    if vol.ndim == 3:
        if vol.shape[0] != vol.shape[1] and vol.shape[0] != vol.shape[2]:
            vol = np.transpose(vol, (2, 0, 1))
    return vol


In [60]:
def normalize_volume(vol):
    non_zero = vol[vol > 0]
    if non_zero.size > 0:
        m = non_zero.mean()
        s = non_zero.std() + 1e-8
        vol = (vol - m) / s
    else:
        vol = (vol - vol.mean()) / (vol.std() + 1e-8)
    return vol


In [61]:
image_files = sorted([f for f in os.listdir(images_dir) if f.endswith(".nii")])
label_files = sorted([f for f in os.listdir(labels_dir) if f.endswith(".nii")])


In [62]:
len(image_files)

123

In [63]:
len(label_files)

123

In [64]:
pairs = []
for img_name, lbl_name in zip(image_files, label_files):
    img_path = os.path.join(images_dir, img_name)
    lbl_path = os.path.join(labels_dir, lbl_name)
    pairs.append((img_path, lbl_path))

In [65]:
pairs[0]

('C:\\Users\\ASUS\\Desktop\\liver_seg\\data\\images\\liver_0.nii',
 'C:\\Users\\ASUS\\Desktop\\liver_seg\\data\\labels\\liver_0.nii')

In [66]:
num_total = len(pairs)
num_train = int(0.8 * num_total)
indices = list(range(num_total))
random.shuffle(indices)

In [67]:
train_idx = indices[:num_train]
val_idx   = indices[num_train:]

In [68]:
train_pairs = [pairs[i] for i in train_idx]
val_pairs   = [pairs[i] for i in val_idx]

In [69]:
len(train_pairs)

98

In [70]:
train_images = []
train_labels = []

for img_path, lbl_path in train_pairs:
    img = load_nifti(img_path)
    lbl = load_nifti(lbl_path)
    img = normalize_volume(img)
    lbl = lbl.astype(np.int64)
    train_images.append(img)
    train_labels.append(lbl)



In [71]:
len(train_images)

98

In [72]:
val_images = []
val_labels = []

for img_path, lbl_path in val_pairs:
    img = load_nifti(img_path)
    lbl = load_nifti(lbl_path)
    img = normalize_volume(img)
    lbl = lbl.astype(np.int64)
    val_images.append(img)
    val_labels.append(lbl)




In [73]:
len(val_images)

25

In [74]:
train_images[0].shape

(187, 187, 229)

In [75]:
np.unique(train_labels[0])

array([0, 1, 2])

In [76]:
def get_random_patch(img, lbl, patch_size, min_dim=16):
    D, H, W = img.shape
    pd_t, ph_t, pw_t = patch_size

    pd = min(pd_t, D)
    ph = min(ph_t, H)
    pw = min(pw_t, W)

    tumor_vox = np.argwhere(lbl == 2)
    liver_vox = np.argwhere(lbl == 1)

    if tumor_vox.size > 0 and random.random() < 0.7:
    # 70% of the time if tumor exists → center on tumor
        cz, cy, cx = tumor_vox[random.randint(0, len(tumor_vox) - 1)]
    elif liver_vox.size > 0:
        cz, cy, cx = liver_vox[random.randint(0, len(liver_vox) - 1)]
    else:
        cz, cy, cx = D // 2, H // 2, W // 2

    z1 = cz - pd // 2
    y1 = cy - ph // 2
    x1 = cx - pw // 2
    z1 = max(0, min(z1, D - pd))
    y1 = max(0, min(y1, H - ph))
    x1 = max(0, min(x1, W - pw))
    z1, y1, x1 = int(z1), int(y1), int(x1)
    img_patch = img[z1:z1 + pd, y1:y1 + ph, x1:x1 + pw]
    lbl_patch = lbl[z1:z1 + pd, y1:y1 + ph, x1:x1 + pw]


    cd, ch, cw = img_patch.shape
    pad_d = max(min_dim - cd, 0)
    pad_h = max(min_dim - ch, 0)
    pad_w = max(min_dim - cw, 0)

    if pad_d > 0 or pad_h > 0 or pad_w > 0:
        pd0 = pad_d // 2
        pd1 = pad_d - pd0
        ph0 = pad_h // 2
        ph1 = pad_h - ph0
        pw0 = pad_w // 2
        pw1 = pad_w - pw0
        img_patch = np.pad(img_patch, ((pd0, pd1), (ph0, ph1), (pw0, pw1)), mode="constant", constant_values=0)
        lbl_patch = np.pad(lbl_patch, ((pd0, pd1), (ph0, ph1), (pw0, pw1)), mode="constant", constant_values=0)


    img_patch = img_patch[np.newaxis, ...]
    img_tensor = torch.from_numpy(img_patch).float()
    lbl_tensor = torch.from_numpy(lbl_patch).long()
    return img_tensor, lbl_tensor


In [77]:
def train_batch_generator(images, labels, patch_size, batch_size):
    while True:
        batch_imgs = []
        batch_lbls = []
        for _ in range(batch_size):
            idx = random.randint(0, len(images) - 1)
            img = images[idx]
            lbl = labels[idx]
            p_img, p_lbl = get_random_patch(img, lbl, patch_size)
            batch_imgs.append(p_img)
            batch_lbls.append(p_lbl)
        yield torch.stack(batch_imgs), torch.stack(batch_lbls)


In [78]:
train_gen = train_batch_generator(train_images, train_labels, patch_size, batch_size)


In [79]:
class DoubleConv3D(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv3d(in_ch, out_ch, 3, padding=1),
            nn.InstanceNorm3d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv3d(out_ch, out_ch, 3, padding=1),
            nn.InstanceNorm3d(out_ch),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.conv(x)


In [80]:
class UNet3D(nn.Module):
    def __init__(self, in_channels=1, num_classes=3, base_ch=16):
        super().__init__()
        self.enc1 = DoubleConv3D(in_channels, base_ch)
        self.enc2 = DoubleConv3D(base_ch, base_ch * 2)
        self.enc3 = DoubleConv3D(base_ch * 2, base_ch * 4)
        self.enc4 = DoubleConv3D(base_ch * 4, base_ch * 8)

        self.pool = nn.MaxPool3d(2)

        self.bottleneck = DoubleConv3D(base_ch * 8, base_ch * 16)

        self.up4 = nn.ConvTranspose3d(base_ch * 16, base_ch * 8, 2, 2)
        self.dec4 = DoubleConv3D(base_ch * 16, base_ch * 8)

        self.up3 = nn.ConvTranspose3d(base_ch * 8, base_ch * 4, 2, 2)
        self.dec3 = DoubleConv3D(base_ch * 8, base_ch * 4)

        self.up2 = nn.ConvTranspose3d(base_ch * 4, base_ch * 2, 2, 2)
        self.dec2 = DoubleConv3D(base_ch * 4, base_ch * 2)

        self.up1 = nn.ConvTranspose3d(base_ch * 2, base_ch, 2, 2)
        self.dec1 = DoubleConv3D(base_ch * 2, base_ch)

        self.out_conv = nn.Conv3d(base_ch, num_classes, 1)

    def forward(self, x):
        e1 = self.enc1(x)
        e2 = self.enc2(self.pool(e1))
        e3 = self.enc3(self.pool(e2))
        e4 = self.enc4(self.pool(e3))

        b = self.bottleneck(self.pool(e4))

        d4 = self.up4(b)
        if d4.shape[2:] != e4.shape[2:]:
            d4 = F.interpolate(d4, size=e4.shape[2:], mode="trilinear", align_corners=False)
        d4 = self.dec4(torch.cat([d4, e4], dim=1))

        d3 = self.up3(d4)
        if d3.shape[2:] != e3.shape[2:]:
            d3 = F.interpolate(d3, size=e3.shape[2:], mode="trilinear", align_corners=False)
        d3 = self.dec3(torch.cat([d3, e3], dim=1))

        d2 = self.up2(d3)
        if d2.shape[2:] != e2.shape[2:]:
            d2 = F.interpolate(d2, size=e2.shape[2:], mode="trilinear", align_corners=False)
        d2 = self.dec2(torch.cat([d2, e2], dim=1))

        d1 = self.up1(d2)
        if d1.shape[2:] != e1.shape[2:]:
            d1 = F.interpolate(d1, size=e1.shape[2:], mode="trilinear", align_corners=False)
        d1 = self.dec1(torch.cat([d1, e1], dim=1))

        out = self.out_conv(d1)
        return out


In [81]:
def dice_loss_multi(logits, targets, num_classes, eps=1e-6):
    probs = F.softmax(logits, dim=1)
    one_hot = F.one_hot(targets, num_classes=num_classes)  
    one_hot = one_hot.permute(0, 4, 1, 2, 3).float()     
    dims = (0, 2, 3, 4)
    intersection = torch.sum(probs * one_hot, dims)
    union = torch.sum(probs + one_hot, dims)
    dice = (2 * intersection + eps) / (union + eps)
    return 1 - dice[1:].mean()


In [82]:
model = UNet3D(in_channels=1, num_classes=num_classes, base_ch=16).to(device)
dummy = torch.randn(1, 1, *patch_size).to(device)
with torch.no_grad():
    out = model(dummy)



In [83]:
out.shape

torch.Size([1, 3, 96, 160, 160])

In [84]:
ce_weights = torch.tensor([0.2, 0.4, 0.4], device=device)
ce_loss_fn = nn.CrossEntropyLoss(weight=ce_weights)
optimizer = torch.optim.AdamW(model.parameters(), lr=base_lr)
scaler = GradScaler()


  scaler = GradScaler()


In [85]:
def train_one_epoch(epoch):
    model.train()
    total_ce = 0.0
    total_dice = 0.0
    for step in range(num_train_steps_per_epoch):
        imgs, lbls = next(train_gen)  
        imgs = imgs.to(device)
        lbls = lbls.to(device)
        optimizer.zero_grad(set_to_none=True)

        with torch.cuda.amp.autocast(enabled=(device.type == "cuda")):
            logits = model(imgs)
            loss_ce = ce_loss_fn(logits, lbls)
            loss_dice = dice_loss_multi(logits, lbls, num_classes)
            loss = loss_ce + loss_dice

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        total_ce += loss_ce.item()
        total_dice += loss_dice.item()
        if step % 10 == 0:
            print(f"[Epoch {epoch} Step {step}] CE: {loss_ce.item():.4f}, DiceLoss: {loss_dice.item():.4f}")

    avg_ce = total_ce / num_train_steps_per_epoch
    avg_dice = total_dice / num_train_steps_per_epoch
    print(f"Epoch {epoch} TRAIN — CE: {avg_ce:.4f}, DiceLoss: {avg_dice:.4f}")


class1: liver

class2: tumor

In [86]:
@torch.no_grad()
def validate(epoch, num_val_samples=20):
    model.eval()
    liver_dice_sum = 0.0
    tumor_dice_sum = 0.0
    liver_count = 0
    tumor_count = 0
    for _ in range(num_val_samples):
        idx = random.randint(0, len(val_images) - 1)
        img = val_images[idx]
        lbl = val_labels[idx]
        p_img, p_lbl = get_random_patch(img, lbl, patch_size)
        p_img = p_img.unsqueeze(0).to(device)   
        p_lbl = p_lbl.unsqueeze(0).to(device)   
        with torch.cuda.amp.autocast(enabled=(device.type == "cuda")):
            logits = model(p_img)
        preds = logits.argmax(dim=1)         
        pred_b = preds[0]
        lbl_b  = p_lbl[0]

        # class1
        pred_liver = (pred_b == 1)
        lbl_liver  = (lbl_b == 1)
        denom_liver = pred_liver.sum() + lbl_liver.sum()
        if denom_liver > 0:
            inter_liver = (pred_liver & lbl_liver).sum()
            dice_liver = (2.0 * inter_liver.float() / denom_liver.float()).item()
            liver_dice_sum += dice_liver
            liver_count += 1

        # class2
        pred_tumor = (pred_b == 2)
        lbl_tumor  = (lbl_b == 2)
        denom_tumor = pred_tumor.sum() + lbl_tumor.sum()
        if denom_tumor > 0:
            inter_tumor = (pred_tumor & lbl_tumor).sum()
            dice_tumor = (2.0 * inter_tumor.float() / denom_tumor.float()).item()
            tumor_dice_sum += dice_tumor
            tumor_count += 1

    avg_liver = liver_dice_sum / liver_count if liver_count > 0 else 0.0
    avg_tumor = tumor_dice_sum / tumor_count if tumor_count > 0 else 0.0

    print(f"Epoch {epoch} VAL — liver Dice: {avg_liver:.4f}, tumor Dice: {avg_tumor:.4f}")
    return avg_liver, avg_tumor


In [87]:
best_val_score = -1.0

In [88]:
for epoch in range(1, num_epochs + 1):
    train_one_epoch(epoch)
    avg_liver, avg_tumor = validate(epoch)
    print(f"Epoch {epoch} — liver Dice: {avg_liver:.4f},tumor Dice:{avg_tumor:.4f}")
    score = (avg_liver + avg_tumor) / 2.0
    print(f"Epoch {epoch} mean Dice (liver+tumor): {score:.4f}")

    if score > best_val_score:
        best_val_score = score
        torch.save(model.state_dict(), "best_unet3d_liver_tumor.pth")
        print("best_val_score:",best_val_score)


  with torch.cuda.amp.autocast(enabled=(device.type == "cuda")):


[Epoch 1 Step 0] CE: 1.2751, DiceLoss: 0.9772
[Epoch 1 Step 10] CE: 1.0756, DiceLoss: 0.9029
[Epoch 1 Step 20] CE: 0.9195, DiceLoss: 0.8882
[Epoch 1 Step 30] CE: 0.7862, DiceLoss: 0.8291
[Epoch 1 Step 40] CE: 0.8798, DiceLoss: 0.9339
[Epoch 1 Step 50] CE: 0.6719, DiceLoss: 0.7601
[Epoch 1 Step 60] CE: 0.6838, DiceLoss: 0.8731
[Epoch 1 Step 70] CE: 0.6505, DiceLoss: 0.8800
[Epoch 1 Step 80] CE: 0.6768, DiceLoss: 0.9278
[Epoch 1 Step 90] CE: 0.5855, DiceLoss: 0.8040
Epoch 1 TRAIN — CE: 0.7859, DiceLoss: 0.8774


  with torch.cuda.amp.autocast(enabled=(device.type == "cuda")):


Epoch 1 VAL — liver Dice: 0.7529, tumor Dice: 0.0000
Epoch 1 — liver Dice: 0.7529,tumor Dice:0.0000
Epoch 1 mean Dice (liver+tumor): 0.3765
best_val_score: 0.3764700435101986
[Epoch 2 Step 0] CE: 0.5312, DiceLoss: 0.7319
[Epoch 2 Step 10] CE: 0.6177, DiceLoss: 0.9295
[Epoch 2 Step 20] CE: 0.4996, DiceLoss: 0.8576
[Epoch 2 Step 30] CE: 0.4598, DiceLoss: 0.8420
[Epoch 2 Step 40] CE: 0.3804, DiceLoss: 0.8199
[Epoch 2 Step 50] CE: 0.3910, DiceLoss: 0.8570
[Epoch 2 Step 60] CE: 0.3175, DiceLoss: 0.7306
[Epoch 2 Step 70] CE: 0.3870, DiceLoss: 0.7049
[Epoch 2 Step 80] CE: 0.2963, DiceLoss: 0.7799
[Epoch 2 Step 90] CE: 0.2710, DiceLoss: 0.7939
Epoch 2 TRAIN — CE: 0.4127, DiceLoss: 0.8123
Epoch 2 VAL — liver Dice: 0.8285, tumor Dice: 0.0000
Epoch 2 — liver Dice: 0.8285,tumor Dice:0.0000
Epoch 2 mean Dice (liver+tumor): 0.4142
best_val_score: 0.4142269916832447
[Epoch 3 Step 0] CE: 0.2671, DiceLoss: 0.7759
[Epoch 3 Step 10] CE: 0.2736, DiceLoss: 0.7848
[Epoch 3 Step 20] CE: 0.2863, DiceLoss: 0.8

In [102]:
@torch.no_grad()
def visualize_liver_and_tumor(filename="liver_tumor_1.png"):
    desktop = os.path.join(os.path.expanduser("~"), "Desktop")
    save_path = os.path.join(desktop, filename)

    model.eval()

   
    idx = random.randint(0, len(val_images) - 1)
    img = val_images[idx]
    lbl = val_labels[idx]
    img_patch, lbl_patch = get_random_patch(img, lbl, patch_size)

  
    x = img_patch.unsqueeze(0).to(device)          
    with torch.cuda.amp.autocast(enabled=(device.type == "cuda")):
        logits = model(x)
    pred = logits.argmax(1)[0].cpu().numpy()       

    img_np = img_patch[0].numpy()                 
    lbl_np = lbl_patch.numpy()                    
    d = img_np.shape[0] // 2
    ct = img_np[d]
    gt_slice = lbl_np[d]
    pr_slice = pred[d]
    gt_liver = (gt_slice == 1)
    gt_tumor = (gt_slice == 2)
    pr_liver = (pr_slice == 1)
    pr_tumor = (pr_slice == 2)


    gt_rgb = np.zeros((*ct.shape, 3), dtype=float)
    gt_rgb[gt_liver] = [0, 1, 0]   # green
    gt_rgb[gt_tumor] = [1, 0, 0]   # red

    pr_rgb = np.zeros((*ct.shape, 3), dtype=float)
    pr_rgb[pr_liver] = [0, 1, 0] 
    pr_rgb[pr_tumor] = [1, 0, 0]  

  
    fig, axes = plt.subplots(1, 3, figsize=(18, 8))

    axes[0].imshow(ct, cmap="gray")
    axes[0].set_title("CT")
    axes[0].axis("off")

    axes[1].imshow(ct, cmap="gray")
    axes[1].imshow(gt_rgb, alpha=0.4)
    axes[1].set_title("Ground Truth\nGreen=liver, Red=tumor")
    axes[1].axis("off")

    axes[2].imshow(ct, cmap="gray")
    axes[2].imshow(pr_rgb, alpha=0.4)
    axes[2].set_title("Prediction\nGreen=liver, Red=tumor")
    axes[2].axis("off")

    plt.tight_layout()
    plt.savefig(save_path, dpi=200)
    plt.close(fig)

    


In [103]:
for i in range(1, 6):
    visualize_liver_and_tumor(filename=f"liver_tumor_vis_{i}.png")


  with torch.cuda.amp.autocast(enabled=(device.type == "cuda")):
