# WBC Segmentation (Unstained) - Kaggle GPU Notebook

Run this end-to-end in a Kaggle Notebook with GPU. Adjust the dataset path as noted below.

In [None]:
# Clone the repo if not already present (Kaggle starts in /kaggle/working)
import os, subprocess
from pathlib import Path

# Upstream project and the location of its checkout inside the Kaggle session.
REPO_URL = "https://github.com/mpotalib/dip-blood_cell_segmentations.git"
REPO_DIR = Path("/kaggle/working/dip-blood_cell_segmentations")

# Clone only when the checkout is missing, so re-running the cell reuses it.
already_cloned = REPO_DIR.exists()
if not already_cloned:
    clone_cmd = ["git", "clone", REPO_URL, str(REPO_DIR)]
    subprocess.run(clone_cmd, check=True)

# The following cells use repo-relative paths, so work from inside the checkout.
os.chdir(REPO_DIR)
print("CWD:", Path.cwd())

In [None]:
# Install dependencies
!pip install -r requirements.txt

In [None]:
# Point to your Kaggle dataset containing data/train|val|test with images + masks (or annotations)
# Example: upload a dataset and set DATASET_BASE to /kaggle/input/your-dataset-name
from pathlib import Path
import os, shutil

DATASET_BASE = Path("/kaggle/input/dip-wbc-dataset")  # TODO: set to your dataset name
TARGET = Path("data")

if not DATASET_BASE.exists():
    raise FileNotFoundError(f"Dataset path not found: {DATASET_BASE}. Update DATASET_BASE above.")

# If data is already structured as data/train/images etc, just symlink
if TARGET.exists():
    if TARGET.is_symlink():
        TARGET.unlink()
    else:
        shutil.rmtree(TARGET)
os.symlink(DATASET_BASE, TARGET)
print("Linked", DATASET_BASE, "->", TARGET)

# If masks are not pre-generated but annotations exist, create masks
for split in ["train", "val", "test"]:
    img_dir = TARGET / split / "images"
    ann_dir = TARGET / split / "annotations"
    mask_dir = TARGET / split / "masks"
    if img_dir.exists() and ann_dir.exists() and not mask_dir.exists():
        mask_dir.mkdir(parents=True, exist_ok=True)
        !python prepare_masks.py --images-dir {img_dir} --annotations-dir {ann_dir} --output-dir {mask_dir}
    else:
        print(f"Split {split}: images={img_dir.exists()}, annotations={ann_dir.exists()}, masks={mask_dir.exists()}")

In [None]:
# Train (choose config: baseline or deeplab)
# The experiment is selected through --config; this cell uses experiments/deeplab.yaml.
# Note: the "Train with log and plot curves" section further down runs the same command
# again with its output piped to train.log, so run only one of the two cells if a single
# training run is enough.
!python train.py --config experiments/deeplab.yaml


In [None]:
# Evaluate on val/test and export qualitative masks for the report/deck
# Both runs load outputs/deeplab/checkpoints/best.pt and write to their own --save-dir
# (outputs/deeplab/preds_val and outputs/deeplab/preds_test).
# NOTE(review): --limit 20 presumably caps how many samples are exported - confirm in
# evaluate.py. The overlay cell further down reads files from outputs/deeplab/preds_val and
# falls back to a black panel when a prediction file is missing.
!python evaluate.py --config experiments/deeplab.yaml --checkpoint outputs/deeplab/checkpoints/best.pt --split val --save-dir outputs/deeplab/preds_val --limit 20
!python evaluate.py --config experiments/deeplab.yaml --checkpoint outputs/deeplab/checkpoints/best.pt --split test --save-dir outputs/deeplab/preds_test --limit 20


## Train with log and plot curves
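Use `tee` to capture the training output in `train.log`, then parse the log and plot the train/val curves. Note that this runs training again with the same config as the earlier training cell, so skip that cell if you only need one run.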


In [None]:
# Train DeepLab and save log for plotting
!python train.py --config experiments/deeplab.yaml | tee train.log


In [None]:
# Parse train.log and plot loss/Dice curves
import re, matplotlib.pyplot as plt
# A number as Python prints it: plain decimals with an optional exponent (e.g. 3.2e-05).
# Without the exponent part a tiny loss such as 3.2e-05 would be read as 3.2.
number = r'([0-9.]+(?:[eE][-+]?[0-9]+)?)'
# Compile both patterns once instead of rebuilding them for every log line.
train_pattern = re.compile(r'Epoch (\d+)/(\d+).*train loss ' + number)
val_pattern = re.compile(r'Validation - loss: ' + number + r' \| dice: ' + number)

# One entry per matching log line, in the order the lines appear in the log.
train_loss, val_loss, val_dice = [], [], []
with open('train.log') as f:
    for line in f:
        m = train_pattern.search(line)
        if m:
            # groups 1 and 2 are the epoch counters; group 3 is the loss value
            train_loss.append(float(m.group(3)))
        m = val_pattern.search(line)
        if m:
            val_loss.append(float(m.group(1)))
            val_dice.append(float(m.group(2)))

# Make an empty result visible instead of silently drawing nothing.
if not (train_loss or val_loss or val_dice):
    print('No metrics found in train.log - check that the log lines match the patterns above.')

# The loss figure is drawn only when both curves are available.
if train_loss and val_loss:
    plt.figure(figsize=(8,4))
    plt.plot(train_loss, label='train loss')
    plt.plot(val_loss, label='val loss')
    plt.xlabel('epoch'); plt.ylabel('loss'); plt.legend(); plt.tight_layout()
if val_dice:
    plt.figure(figsize=(6,4))
    plt.plot(val_dice, label='val Dice')
    plt.xlabel('epoch'); plt.ylabel('Dice'); plt.legend(); plt.tight_layout()
plt.show()


## Visualizations and Metrics
Per-class metrics, overlays, augmentation previews, and worst cases.


In [None]:
# Compute per-class metrics on val and list worst cases by Dice
import torch, yaml, numpy as np
from pathlib import Path
from torch.utils.data import DataLoader
from src.dataset import WBCDataset
from src.transforms import get_val_transforms
from src.models import build_model
from src.utils import set_seed

# Load the experiment config; the context manager closes the file handle right away.
with open('experiments/deeplab.yaml') as cfg_file:
    cfg = yaml.safe_load(cfg_file)
class_mapping = cfg['data'].get('class_mapping', {'background':0, 'n':1, 'b':2})
# Number of classes used by the Dice computation below.
num_classes = len(class_mapping)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
set_seed(cfg['logging'].get('seed', 1337))

# The model is built from cfg['model']['num_classes'] while the metric uses
# len(class_mapping); flag a disagreement early because predictions outside
# [0, num_classes) cannot be one-hot encoded.
if num_classes != cfg['model']['num_classes']:
    print('Warning: class_mapping has', num_classes, 'classes but the model is configured for',
          cfg['model']['num_classes'])

# batch_size=1 and shuffle=False: one image per step, in dataset order.
val_ds = WBCDataset(Path(cfg['data']['val_images']), mask_dir=Path(cfg['data']['val_masks']),
                    annotation_dir=None, transforms=get_val_transforms(), class_mapping=class_mapping)
val_loader = DataLoader(val_ds, batch_size=1, shuffle=False, num_workers=2)
model = build_model(name=cfg['model']['name'], in_channels=cfg['model']['in_channels'],
                    num_classes=cfg['model']['num_classes'],
                    pretrained_backbone=cfg['model'].get('pretrained_backbone', False)).to(device)
ckpt = torch.load('outputs/deeplab/checkpoints/best.pt', map_location=device)
# strict=False tolerates key mismatches, so report them rather than silently
# evaluating a model whose weights were only partially restored.
load_result = model.load_state_dict(ckpt['state_dict'], strict=False)
if load_result.missing_keys or load_result.unexpected_keys:
    print('Warning: checkpoint/model key mismatch - missing:', list(load_result.missing_keys),
          'unexpected:', list(load_result.unexpected_keys))
model.eval()

def per_class_dice(pred, target, k):
    """Return the Dice score of every class, pooled over the whole batch.

    Args:
        pred: integer tensor of predicted class indices, laid out as
            (batch, height, width).
        target: integer tensor of ground-truth class indices with the same
            shape as ``pred``.
        k: number of classes; every index must lie in ``[0, k)``.

    Returns:
        Float tensor of shape ``(k,)`` with one Dice score per class. A class
        that is absent from both ``pred`` and ``target`` scores 1 because of
        the smoothing term.
    """
    # one_hot only accepts int64 indices, so cast masks stored as uint8/int32.
    pred_oh = torch.nn.functional.one_hot(pred.long(), num_classes=k).permute(0,3,1,2)
    tgt_oh = torch.nn.functional.one_hot(target.long(), num_classes=k).permute(0,3,1,2)
    # Sum over batch and both spatial axes, leaving only the class axis.
    dims = (0,2,3)
    inter = (pred_oh * tgt_oh).sum(dim=dims)
    card = (pred_oh + tgt_oh).sum(dim=dims)
    # The result is already (k,); no squeeze, which would drop the class axis for k == 1.
    return (2*inter + 1e-6)/(card + 1e-6)

# Score every batch from val_loader individually and remember its position in the loader.
per_image = []
with torch.no_grad():
    for i,(imgs,masks) in enumerate(val_loader):
        imgs, masks = imgs.to(device), masks.to(device)
        out = model(imgs)
        # Some models return a dict with the logits stored under 'out'.
        logits = out['out'] if isinstance(out, dict) else out
        # softmax is monotonic, so argmax over the raw logits picks the same class
        # without the extra pass over the tensor.
        preds = torch.argmax(logits, dim=1)
        d = per_class_dice(preds, masks, num_classes).cpu().numpy()
        per_image.append({'idx': i, 'mean_dice': float(d.mean()), 'per_class': d})

# Ascending by mean Dice, so the worst images come first.
per_image = sorted(per_image, key=lambda x: x['mean_dice'])
mean_per_class = np.mean([p['per_class'] for p in per_image], axis=0)
print('Mean Dice per class (bg, nucleus, boundary):', np.round(mean_per_class, 3))
print('Worst 5 images by Dice:')
for w in per_image[:5]:
    print('{} mean {:.3f} per-class {}'.format(w['idx'], w['mean_dice'], np.round(w['per_class'],3)))
# Six indices are stored (one more than printed).
worst_indices = [w['idx'] for w in per_image[:6]]
# Written as raw int64 bytes, not in the .npy container format, so the file must be
# read back with np.frombuffer(..., dtype=np.int64) rather than np.load.
Path('tmp_worst.npy').write_bytes(np.array(worst_indices, dtype=np.int64).tobytes())


In [None]:
# Show overlays (input / GT / Pred) for a few val samples (including worst cases)
import numpy as np, matplotlib.pyplot as plt, cv2
from pathlib import Path
from src.dataset import WBCDataset
from src.transforms import get_val_transforms

def colorize(mask):
    """Map an integer class mask to an RGB image.

    Class 0 is drawn black, class 1 white and class 2 red. Indices outside
    that range are clamped to it before the colour lookup.
    """
    palette = np.array([[0,0,0],[255,255,255],[255,0,0]], dtype=np.uint8)
    highest_class = len(palette) - 1
    clamped = np.clip(mask, 0, highest_class)
    # Indexing with the mask appends a trailing RGB axis to the mask's shape.
    return palette[clamped]

val_ds = WBCDataset(Path(cfg['data']['val_images']), mask_dir=Path(cfg['data']['val_masks']),
                    annotation_dir=None, transforms=get_val_transforms(), class_mapping=class_mapping)
# tmp_worst.npy is optional and holds raw int64 bytes (read with frombuffer, not np.load).
worst_file = Path('tmp_worst.npy')
worst_indices = np.frombuffer(worst_file.read_bytes(), dtype=np.int64).tolist() if worst_file.exists() else []
# Worst cases first, padded with the first samples. Duplicates and indices the dataset
# does not have are skipped so a small validation set cannot raise IndexError.
to_show = []
for idx in worst_indices + list(range(6)):
    if 0 <= idx < len(val_ds) and idx not in to_show:
        to_show.append(idx)
to_show = to_show[:6]

rows = len(to_show)
if rows == 0:
    print('No validation samples available to display.')
else:
    # squeeze=False keeps axes two-dimensional even when only one row is drawn.
    fig, axes = plt.subplots(rows, 3, figsize=(9, 3*rows), squeeze=False)
    for r, idx in enumerate(to_show):
        img, mask = val_ds[idx]
        # Clip before the uint8 cast so out-of-range values cannot wrap around.
        img_np = (img.permute(1,2,0).numpy()*255).clip(0,255).astype(np.uint8)
        pred_path = Path(f'outputs/deeplab/preds_val/{idx:04d}_0_pred.png')
        # cv2.imread returns None for an unreadable file; treat that like a missing one.
        pred_bgr = cv2.imread(str(pred_path)) if pred_path.exists() else None
        if pred_bgr is None:
            # Black placeholder, labelled so it is not mistaken for an all-background prediction.
            pred, pred_title = np.zeros_like(img_np), 'Pred (missing)'
        else:
            pred, pred_title = cv2.cvtColor(pred_bgr, cv2.COLOR_BGR2RGB), 'Pred'
        axes[r,0].imshow(img_np); axes[r,0].axis('off'); axes[r,0].set_title(f'Idx {idx} Input')
        axes[r,1].imshow(colorize(mask.numpy())); axes[r,1].axis('off'); axes[r,1].set_title('GT')
        axes[r,2].imshow(pred); axes[r,2].axis('off'); axes[r,2].set_title(pred_title)
    plt.tight_layout(); plt.show()


In [None]:
# Preview augmentations (image + mask)
import matplotlib.pyplot as plt
from src.transforms import get_train_transforms
# Training dataset with the augmentation pipeline built from the config.
train_ds = WBCDataset(
    Path(cfg['data']['train_images']),
    mask_dir=Path(cfg['data']['train_masks']),
    annotation_dir=None,
    transforms=get_train_transforms(cfg),
    class_mapping=class_mapping,
)
# 2x3 grid: the first six samples, each with its colourised mask blended on top.
fig, axes = plt.subplots(2, 3, figsize=(9,6))
for i, ax in enumerate(axes.flat):
    img, m = train_ds[i]
    scaled = img.permute(1,2,0).numpy() * 255
    img_np = scaled.clip(0,255).astype(np.uint8)
    ax.imshow(img_np)
    ax.imshow(colorize(m.numpy()), alpha=0.4)
    ax.axis('off')
plt.tight_layout(); plt.show()
