In [None]:


import os, glob, random, shutil, math, time
import numpy as np
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
import torchvision.transforms.functional as TF
from tqdm import tqdm


ORIGINAL_DATA_DIR = "project/train_data_tiff"
PROCESSED_DIR     = "project/processed_tiles"

TILE_SIZE = 512
STRIDE = 400
CLASS_VALUES = [0, 50, 100, 150, 200]
NUM_CLASSES = 5

VAL_SPLIT = 0.2
SEED = 42


BATCH_SIZE = 4
ACCUM_STEPS = 4
LR = 3e-4
EPOCHS = 30
BASE_CH = 48
USE_AMP = True

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


def set_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

def cleanup_dir(d):
    if os.path.exists(d):
        shutil.rmtree(d)
    os.makedirs(d, exist_ok=True)

def robust_normalize(img):

    img = img.astype(np.float32)
    p2, p98 = np.percentile(img, (2, 98))
    if p98 - p2 < 1e-6:
        return np.zeros_like(img, dtype=np.float32)
    img = (img - p2) / (p98 - p2)
    return np.clip(img, 0.0, 1.0).astype(np.float32)

def map_labels_nearest(lbl_arr):
    
    lbl = lbl_arr.astype(np.int32)
    diffs = np.abs(lbl[..., None] - np.array(CLASS_VALUES)[None, None, :])
    return diffs.argmin(axis=-1).astype(np.uint8) 

def compute_starts(L, tile, stride):
    if L <= tile:
        return [0]
    starts = list(range(0, L - tile + 1, stride))
    if starts[-1] != L - tile:
        starts.append(L - tile)
    return starts

def find_raw_files(data_dir):
    exts = (".tif", ".tiff", ".png")
    candidates = sorted(glob.glob(os.path.join(data_dir, "raw_*")))
    raws = [p for p in candidates if os.path.splitext(p)[1].lower() in exts]
    if len(raws) == 0:
     
        candidates = sorted(glob.glob(os.path.join(data_dir, "*")))
        raws = [p for p in candidates if ("raw" in os.path.basename(p).lower()) and (os.path.splitext(p)[1].lower() in exts)]
    return sorted(raws)

def corresponding_label_path(raw_path):

    dn = os.path.dirname(raw_path)
    bn = os.path.basename(raw_path)

    p1 = os.path.join(dn, bn.replace("raw_", "label_", 1))
    if os.path.exists(p1):
        return p1

    if bn.startswith("raw_"):
        p2 = os.path.join(dn, "label_" + bn[len("raw_"):])
        if os.path.exists(p2):
            return p2
    return None

def tile_and_save(raw_paths, out_dir, subset):
    import os
    import numpy as np
    from PIL import Image

    os.makedirs(os.path.join(out_dir, subset, "images"), exist_ok=True)
    os.makedirs(os.path.join(out_dir, subset, "masks"), exist_ok=True)

    count = 0
    for raw_p in raw_paths:
        base = os.path.splitext(os.path.basename(raw_p))[0]

        lbl_p = raw_p.replace("raw_", "label_")
        if not os.path.exists(lbl_p):
            lbl_p = os.path.join(os.path.dirname(raw_p), "label_" + base.replace("raw_", "") + ".tif")
        if not os.path.exists(lbl_p):
            lbl_p = os.path.join(os.path.dirname(raw_p), "label_" + base.replace("raw_", "") + ".tiff")
        if not os.path.exists(lbl_p):
            continue

        raw = robust_normalize(np.array(Image.open(raw_p)))
        lbl = map_labels_nearest(np.array(Image.open(lbl_p)))

        h, w = raw.shape
        ys = compute_starts(h, TILE_SIZE, STRIDE)
        xs = compute_starts(w, TILE_SIZE, STRIDE)

        for y in ys:
            for x in xs:
                img_tile = raw[y:y+TILE_SIZE, x:x+TILE_SIZE]
                lbl_tile = lbl[y:y+TILE_SIZE, x:x+TILE_SIZE]

                unique = np.unique(lbl_tile)

                if subset == "val":
                    keep = True
                else:
                    keep = True
                    if len(unique) == 1 and unique[0] == 0:
                        keep = (random.random() < 0.10)  
                    if (3 in unique) or (4 in unique) or (1 in unique):
                        keep = True

                if keep:
                    np.save(os.path.join(out_dir, subset, "images", f"tile_{count}.npy"),
                            img_tile.astype(np.float32))
    
                    np.save(os.path.join(out_dir, subset, "masks", f"tile_{count}.npy"),
                            lbl_tile.astype(np.uint8))
                    count += 1

    print(f"[{subset}] saved tiles: {count}")


def class_frequencies_from_tiles(processed_dir, subset, num_classes=5):
    import glob, os, numpy as np
    mpaths = sorted(glob.glob(os.path.join(processed_dir, subset, "masks", "*.npy")))
    counts = np.zeros(num_classes, dtype=np.int64)

    for p in mpaths:
        m = np.load(p, mmap_mode="r")
        bc = np.bincount(m.reshape(-1), minlength=num_classes)
        counts[:num_classes] += bc[:num_classes]

    total = counts.sum()
    freqs = counts / (total + 1e-12)
    return counts, freqs, total, len(mpaths)


class TiledNPYDataset(Dataset):
    def __init__(self, root_dir, subset="train"):
        self.subset = subset
        self.img_paths = sorted(glob.glob(os.path.join(root_dir, subset, "images", "*.npy")))
        self.msk_paths = sorted(glob.glob(os.path.join(root_dir, subset, "masks", "*.npy")))
        assert len(self.img_paths) == len(self.msk_paths)

    def __len__(self):
        return len(self.img_paths)

    def __getitem__(self, idx):
        img = np.load(self.img_paths[idx], mmap_mode="r").astype(np.float32)
        msk = np.load(self.msk_paths[idx], mmap_mode="r")  

        img = torch.from_numpy(img).float().unsqueeze(0)
        msk = torch.from_numpy(msk).long()

        if self.subset == "train":
            if random.random() < 0.5:
                img = TF.hflip(img); msk = TF.hflip(msk)
            if random.random() < 0.5:
                img = TF.vflip(img); msk = TF.vflip(msk)
            if random.random() < 0.3:
                img = torch.clamp(img + 0.05 * torch.randn_like(img), 0.0, 1.0)

        return img, msk


def smoke_test_loader(loader):
    t0 = time.time()
    batch = next(iter(loader))
    print("First batch fetched in", round(time.time()-t0, 2), "sec",
          "| img", batch[0].shape, "| msk", batch[1].shape)


def gn(ch, groups=8):
    groups = min(groups, ch)
    while groups > 1 and (ch % groups != 0):
        groups -= 1
    return nn.GroupNorm(groups, ch)

class ConvGNAct(nn.Module):
    def __init__(self, in_ch, out_ch, k=3, p=1):
        super().__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, k, padding=p, bias=False)
        self.gn = gn(out_ch)
        self.act = nn.SiLU(inplace=True)
    def forward(self, x):
        return self.act(self.gn(self.conv(x)))

class ResBlock(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.c1 = ConvGNAct(in_ch, out_ch)
        self.c2 = nn.Conv2d(out_ch, out_ch, 3, padding=1, bias=False)
        self.g2 = gn(out_ch)
        self.act = nn.SiLU(inplace=True)
        self.skip = nn.Identity() if in_ch == out_ch else nn.Conv2d(in_ch, out_ch, 1, bias=False)
    def forward(self, x):
        h = self.c1(x)
        h = self.g2(self.c2(h))
        return self.act(h + self.skip(x))

class AttnGate(nn.Module):
    def __init__(self, skip_ch, gate_ch, inter_ch):
        super().__init__()
        self.theta = nn.Conv2d(skip_ch, inter_ch, 1, bias=False)
        self.phi   = nn.Conv2d(gate_ch, inter_ch, 1, bias=False)
        self.psi   = nn.Conv2d(inter_ch, 1, 1, bias=True)
        self.act   = nn.SiLU(inplace=True)
        self.sig   = nn.Sigmoid()
    def forward(self, skip, gate):
        g = F.interpolate(gate, size=skip.shape[-2:], mode="bilinear", align_corners=False)
        a = self.act(self.theta(skip) + self.phi(g))
        a = self.sig(self.psi(a))
        return skip * a

class Down(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.pool = nn.MaxPool2d(2)
        self.block = ResBlock(in_ch, out_ch)
    def forward(self, x):
        return self.block(self.pool(x))

class Up(nn.Module):
    def __init__(self, in_ch, skip_ch, out_ch):
        super().__init__()
        self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False)
        self.reduce = nn.Conv2d(in_ch, out_ch, 1, bias=False)
    
        self.attn = AttnGate(skip_ch, out_ch, inter_ch=max(16, out_ch // 2))
        self.block = ResBlock(out_ch + skip_ch, out_ch)
    def forward(self, x, skip):
        x = self.up(x)
        x = self.reduce(x)
        skip = self.attn(skip, x)
        x = torch.cat([skip, x], dim=1)
        return self.block(x)

class ResAttnUNetDS(nn.Module):
    def __init__(self, n_classes=NUM_CLASSES, base=48):
        super().__init__()
        c1, c2, c3, c4, c5 = base, base*2, base*4, base*8, base*12

        self.stem = ResBlock(1, c1)     
        self.d1 = Down(c1, c2)          
        self.d2 = Down(c2, c3)          
        self.d3 = Down(c3, c4)          
        self.d4 = Down(c4, c5)       

        self.bottleneck = ResBlock(c5, c5)  

        self.u3 = Up(c5, c4, c4)       
        self.u2 = Up(c4, c3, c3)        
        self.u1 = Up(c3, c2, c2)         
        self.u0 = Up(c2, c1, c1)        

        self.head0 = nn.Conv2d(c1, n_classes, 1)  
        self.head1 = nn.Conv2d(c2, n_classes, 1) 
        self.head2 = nn.Conv2d(c3, n_classes, 1)

    def forward(self, x):
        s0 = self.stem(x)
        s1 = self.d1(s0)
        s2 = self.d2(s1)
        s3 = self.d3(s2)
        s4 = self.d4(s3)

        b  = self.bottleneck(s4)

        x3 = self.u3(b,  s3)
        x2 = self.u2(x3, s2)
        x1 = self.u1(x2, s1)
        x0 = self.u0(x1, s0)

        out0 = self.head0(x0)
        out1 = self.head1(x1)
        out2 = self.head2(x2)
        return out0, out1, out2

class SoftDiceLoss(nn.Module):
    def __init__(self, n_classes, include_bg=False, eps=1e-6):
        super().__init__()
        self.n_classes = n_classes
        self.include_bg = include_bg
        self.eps = eps
    def forward(self, logits, target):
        probs = torch.softmax(logits, dim=1)
        tgt = F.one_hot(target, num_classes=self.n_classes).permute(0,3,1,2).float()
        if not self.include_bg:
            probs = probs[:, 1:]
            tgt   = tgt[:, 1:]
        dims = (0,2,3)
        inter = (probs * tgt).sum(dims)
        union = probs.sum(dims) + tgt.sum(dims)
        dice = (2*inter + self.eps) / (union + self.eps)
        return 1.0 - dice.mean()

class ComboLoss(nn.Module):
    def __init__(self, class_weights, dice_w=0.6, ce_w=0.4):
        super().__init__()
        self.ce = nn.CrossEntropyLoss(weight=class_weights)
        self.dice = SoftDiceLoss(n_classes=len(class_weights), include_bg=False)
        self.dice_w = dice_w
        self.ce_w = ce_w
    def forward(self, logits, target):
        return self.ce_w * self.ce(logits, target) + self.dice_w * self.dice(logits, target)


@torch.no_grad()
def macro_dice_ex_bg(pred, tgt, n_classes=NUM_CLASSES, eps=1e-6):
    dices=[]
    for c in range(1, n_classes):
        p = (pred==c).float()
        t = (tgt==c).float()
        d = (2*(p*t).sum() + eps) / (p.sum() + t.sum() + eps)
        dices.append(d.item())
    return float(np.mean(dices)) if dices else 0.0


def compute_class_weights_from_mask_files(mask_paths, n_classes, device):
    counts = np.zeros(n_classes, dtype=np.float64)

    for p in mask_paths:
        m = np.load(p, mmap_mode="r")
        bc = np.bincount(m.reshape(-1), minlength=n_classes)
        counts[:n_classes] += bc[:n_classes]

    freq = counts / (counts.sum() + 1e-12)
    w = 1.0 / np.sqrt(freq + 1e-12)


    w[0] = 0.2

    w = w / w.mean()
    return torch.tensor(w, device=device).float()


def train(
    processed_dir,
    epochs = 30,
    batch_size=4,
    accum_steps=4,
    lr=3e-4,
    base_ch=48,
    use_amp=True,
    device="cuda",
    num_workers=0,
):
  
    DO_FULL_VAL_EVERY = 10     
    FAST_VAL_BATCHES  = 50    

    @torch.no_grad()
    def fast_val(model, val_loader, device, n_batches):
        model.eval()
        dices = []
        for b, (img, msk) in enumerate(val_loader):
            if b >= n_batches:
                break
            img = img.to(device, non_blocking=True)
            msk = msk.to(device, non_blocking=True)
            out0, _, _ = model(img)
            pred = out0.argmax(dim=1)
            dices.append(macro_dice_ex_bg(pred, msk, n_classes=NUM_CLASSES))
        return float(np.mean(dices)) if dices else 0.0

    @torch.no_grad()
    def full_val(model, val_loader, device):
        model.eval()
        dices = []
        for img, msk in val_loader:
            img = img.to(device, non_blocking=True)
            msk = msk.to(device, non_blocking=True)
            out0, _, _ = model(img)
            pred = out0.argmax(dim=1)
            dices.append(macro_dice_ex_bg(pred, msk, n_classes=NUM_CLASSES))
        return float(np.mean(dices)) if dices else 0.0

    torch.backends.cudnn.benchmark = True
    amp_on = (use_amp and str(device).startswith("cuda"))
    scaler = torch.amp.GradScaler('cuda', enabled=amp_on)

    train_ds = TiledNPYDataset(processed_dir, "train")
    val_ds   = TiledNPYDataset(processed_dir, "val")


    class_w = compute_class_weights_from_mask_files(train_ds.msk_paths, NUM_CLASSES, device)


    w_path = os.path.join(processed_dir, "train", "tile_weights.npy")
    if not os.path.exists(w_path):
        build_and_save_tile_weights(processed_dir, "train", num_classes=NUM_CLASSES)
    tile_weights = np.load(w_path)

    sampler = WeightedRandomSampler(
        weights=torch.DoubleTensor(tile_weights),
        num_samples=len(tile_weights),
        replacement=True
    )

    train_loader = DataLoader(
        train_ds, batch_size=batch_size, sampler=sampler,
        num_workers=num_workers, pin_memory=str(device).startswith("cuda")
    )
    val_loader = DataLoader(
        val_ds, batch_size=batch_size, shuffle=False,
        num_workers=num_workers, pin_memory=str(device).startswith("cuda")
    )

    smoke_test_loader(train_loader)

    model = ResAttnUNetDS(n_classes=NUM_CLASSES, base=base_ch).to(device)
    opt = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4)

    steps_per_epoch = math.ceil(len(train_loader) / accum_steps)
    sched = torch.optim.lr_scheduler.OneCycleLR(
        opt, max_lr=lr, epochs=epochs, steps_per_epoch=steps_per_epoch
    )

    criterion = ComboLoss(class_w, dice_w=0.7, ce_w=0.3)

    best = -1.0
    best_is_full = False
    opt.zero_grad(set_to_none=True)

    for ep in range(epochs):
        model.train()
        run_loss = 0.0

        pbar = tqdm(enumerate(train_loader), total=len(train_loader), desc=f"Epoch {ep+1}/{epochs}")
        for i, (img, msk) in pbar:
            img = img.to(device, non_blocking=True)
            msk = msk.to(device, non_blocking=True)

            with torch.amp.autocast('cuda', enabled=amp_on):
                out0, out1, out2 = model(img)

                m1 = F.interpolate(msk.unsqueeze(1).float(), scale_factor=0.5, mode="nearest").squeeze(1).long()
                m2 = F.interpolate(msk.unsqueeze(1).float(), scale_factor=0.25, mode="nearest").squeeze(1).long()

                loss0 = criterion(out0, msk)
                loss1 = criterion(out1, m1)
                loss2 = criterion(out2, m2)

                loss = (loss0 + 0.5*loss1 + 0.25*loss2) / accum_steps

            scaler.scale(loss).backward()

            if (i + 1) % accum_steps == 0:
                scaler.unscale_(opt)
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                scaler.step(opt)
                scaler.update()
                opt.zero_grad(set_to_none=True)
                sched.step()

            run_loss += loss.item() * accum_steps
            pbar.set_postfix(loss=f"{run_loss/(i+1):.4f}", lr=f"{sched.get_last_lr()[0]:.2e}")

        train_loss_ep = run_loss / max(1, len(train_loader))

        if (ep + 1) % DO_FULL_VAL_EVERY == 0:
            val_d = full_val(model, val_loader, device)
            is_full = True
            print(f"Epoch {ep+1}: train_loss={train_loss_ep:.4f}  FULL val_macroDice(ex-bg)={val_d:.4f}")
        else:
            val_d = fast_val(model, val_loader, device, FAST_VAL_BATCHES)
            is_full = False
            print(f"Epoch {ep+1}: train_loss={train_loss_ep:.4f}  fast val_macroDice(ex-bg)~={val_d:.4f} ({FAST_VAL_BATCHES} batches)")

 
        if val_d > best:
            best = val_d
            best_is_full = is_full
            torch.save(model.state_dict(), "best_resattn_unet_ds.pth")
            tag = "FULL" if is_full else "FAST"
            print(f"   saved best ({tag}): {best:.4f}")

    print("Best val dice:", best, "| based on", ("FULL" if best_is_full else "FAST"))



def build_and_save_tile_weights(processed_dir, subset="train", num_classes=5,
                                w_c4=10.0, w_c1=4.0, w_c3=2.0, w_pure_bg=0.25):
    import os, glob
    import numpy as np

    mpaths = sorted(glob.glob(os.path.join(processed_dir, subset, "masks", "*.npy")))
    weights = np.ones(len(mpaths), dtype=np.float32)

    for i, p in enumerate(mpaths):
        m = np.load(p, mmap_mode="r")  
        bc = np.bincount(m.reshape(-1), minlength=num_classes)

        w = 1.0
        if bc[4] > 0: w *= w_c4
        if bc[1] > 0: w *= w_c1
        if bc[3] > 0: w *= w_c3
        if bc[0] == m.size: w *= w_pure_bg
        weights[i] = w

    out_path = os.path.join(processed_dir, subset, "tile_weights.npy")
    np.save(out_path, weights)
    print(f"Saved {len(weights)} sampling weights -> {out_path}")
    return out_path


def split_images(raw_paths, val_split=VAL_SPLIT, seed=SEED):
    idx = list(range(len(raw_paths)))
    rng = random.Random(seed)
    rng.shuffle(idx)
    n_val = max(1, int(round(len(raw_paths) * val_split))) if len(raw_paths) > 1 else 0
    val_idx = set(idx[:n_val])
    train_paths = [raw_paths[i] for i in range(len(raw_paths)) if i not in val_idx]
    val_paths   = [raw_paths[i] for i in range(len(raw_paths)) if i in val_idx]
    return train_paths, val_paths



In [None]:
def processed_exists(processed_dir):
    return (
        len(glob.glob(os.path.join(processed_dir, "train", "images", "*.npy"))) > 0 and
        len(glob.glob(os.path.join(processed_dir, "val", "images", "*.npy"))) > 0 and
        os.path.exists(os.path.join(processed_dir, "train", "tile_weights.npy"))
    )





In [None]:
def main():
    set_seed(SEED)

    raw_paths = find_raw_files(ORIGINAL_DATA_DIR)
    if len(raw_paths) == 0:
        raise RuntimeError(
            f"No raw files found in {ORIGINAL_DATA_DIR}. Expected raw_*.tif/.tiff (and matching label_*)"
        )

    train_paths, val_paths = split_images(raw_paths, VAL_SPLIT, SEED)

    print("=== IMAGE SPLIT ===")
    print("Train raw images:", len(train_paths))
    print("Val raw images:  ", len(val_paths))
    print("Train files:", [os.path.basename(p) for p in train_paths])
    print("Val files:  ", [os.path.basename(p) for p in val_paths])

   
    if not processed_exists(PROCESSED_DIR):
        print("\n[cache miss] building tiled dataset...")
        cleanup_dir(PROCESSED_DIR)
        tile_and_save(train_paths, PROCESSED_DIR, subset="train")
        tile_and_save(val_paths,   PROCESSED_DIR, subset="val")
        build_and_save_tile_weights(PROCESSED_DIR, "train", num_classes=NUM_CLASSES)
    else:
        print("\n[cache hit] using existing tiles in:", PROCESSED_DIR)

    print("\n=== TILE COUNTS ===")
    train_tiles = len(glob.glob(os.path.join(PROCESSED_DIR, "train", "masks", "*.npy")))
    val_tiles   = len(glob.glob(os.path.join(PROCESSED_DIR, "val",   "masks", "*.npy")))
    print(f"Train tiles: {train_tiles}")
    print(f"Val tiles:   {val_tiles}")

    for split in ["train", "val"]:
        counts, freqs, total, ntiles = class_frequencies_from_tiles(PROCESSED_DIR, split, num_classes=NUM_CLASSES)
        print(f"\n=== {split.upper()} CLASS FREQUENCIES ===")
        print(f"Tiles: {ntiles} | Total pixels: {total:,}")
        for k in range(NUM_CLASSES):
            print(f"  Class {k}: {counts[k]:,} px  ({freqs[k]*100:.3f}%)")

 
    train(
        processed_dir=PROCESSED_DIR,
        epochs=EPOCHS,
        batch_size=BATCH_SIZE,
        accum_steps=ACCUM_STEPS,
        lr=LR,
        base_ch=BASE_CH,
        use_amp=USE_AMP,
        device=DEVICE,
        num_workers=0 
    )


if __name__ == "__main__":
    main()


=== IMAGE SPLIT ===
Train raw images: 9
Val raw images:   2
Train files: ['raw_01.tiff', 'raw_02.tiff', 'raw_04.tiff', 'raw_06.tiff', 'raw_07.tiff', 'raw_08.tiff', 'raw_10.tiff', 'raw_11.tiff', 'raw_12.tiff']
Val files:   ['raw_05.tiff', 'raw_09.tiff']

[cache hit] using existing tiles in: project/processed_tiles

=== TILE COUNTS ===
Train tiles: 2226
Val tiles:   782

=== TRAIN CLASS FREQUENCIES ===
Tiles: 2226 | Total pixels: 583,532,544
  Class 0: 279,387,728 px  (47.879%)
  Class 1: 28,325,689 px  (4.854%)
  Class 2: 165,279,928 px  (28.324%)
  Class 3: 84,708,720 px  (14.517%)
  Class 4: 25,830,479 px  (4.427%)

=== VAL CLASS FREQUENCIES ===
Tiles: 782 | Total pixels: 204,996,608
  Class 0: 101,261,018 px  (49.396%)
  Class 1: 1,042,584 px  (0.509%)
  Class 2: 61,640,157 px  (30.069%)
  Class 3: 33,805,780 px  (16.491%)
  Class 4: 7,247,069 px  (3.535%)


  msk = torch.from_numpy(msk).long()


First batch fetched in 0.01 sec | img torch.Size([4, 1, 512, 512]) | msk torch.Size([4, 512, 512])


Epoch 1/30: 100%|██████████| 557/557 [02:44<00:00,  3.39it/s, loss=1.4530, lr=2.06e-05]


Epoch 1: train_loss=1.4530  fast val_macroDice(ex-bg)~=0.3379 (50 batches)
  ✅ saved best (FAST): 0.3379


Epoch 2/30: 100%|██████████| 557/557 [02:03<00:00,  4.50it/s, loss=1.0944, lr=4.53e-05]


Epoch 2: train_loss=1.0944  fast val_macroDice(ex-bg)~=0.3848 (50 batches)
  ✅ saved best (FAST): 0.3848


Epoch 3/30: 100%|██████████| 557/557 [02:03<00:00,  4.51it/s, loss=0.9069, lr=8.32e-05]


Epoch 3: train_loss=0.9069  fast val_macroDice(ex-bg)~=0.4433 (50 batches)
  ✅ saved best (FAST): 0.4433


Epoch 4/30: 100%|██████████| 557/557 [02:02<00:00,  4.53it/s, loss=0.7906, lr=1.30e-04]


Epoch 4: train_loss=0.7906  fast val_macroDice(ex-bg)~=0.4979 (50 batches)
  ✅ saved best (FAST): 0.4979


Epoch 5/30: 100%|██████████| 557/557 [02:02<00:00,  4.54it/s, loss=0.7518, lr=1.79e-04]


Epoch 5: train_loss=0.7518  fast val_macroDice(ex-bg)~=0.5401 (50 batches)
  ✅ saved best (FAST): 0.5401


Epoch 6/30: 100%|██████████| 557/557 [02:02<00:00,  4.54it/s, loss=0.6400, lr=2.26e-04]


Epoch 6: train_loss=0.6400  fast val_macroDice(ex-bg)~=0.5340 (50 batches)


Epoch 7/30: 100%|██████████| 557/557 [02:03<00:00,  4.50it/s, loss=0.5410, lr=2.65e-04]


Epoch 7: train_loss=0.5410  fast val_macroDice(ex-bg)~=0.4753 (50 batches)


Epoch 8/30: 100%|██████████| 557/557 [02:06<00:00,  4.42it/s, loss=0.4983, lr=2.90e-04]


Epoch 8: train_loss=0.4983  fast val_macroDice(ex-bg)~=0.5418 (50 batches)
  ✅ saved best (FAST): 0.5418


Epoch 9/30: 100%|██████████| 557/557 [02:03<00:00,  4.50it/s, loss=0.4729, lr=3.00e-04]


Epoch 9: train_loss=0.4729  fast val_macroDice(ex-bg)~=0.5449 (50 batches)
  ✅ saved best (FAST): 0.5449


Epoch 10/30: 100%|██████████| 557/557 [02:01<00:00,  4.58it/s, loss=0.4158, lr=2.99e-04]


Epoch 10: train_loss=0.4158  FULL val_macroDice(ex-bg)=0.7011
  ✅ saved best (FULL): 0.7011


Epoch 11/30: 100%|██████████| 557/557 [02:02<00:00,  4.53it/s, loss=0.3772, lr=2.94e-04]


Epoch 11: train_loss=0.3772  fast val_macroDice(ex-bg)~=0.5322 (50 batches)


Epoch 12/30: 100%|██████████| 557/557 [02:00<00:00,  4.61it/s, loss=0.3555, lr=2.86e-04]


Epoch 12: train_loss=0.3555  fast val_macroDice(ex-bg)~=0.5879 (50 batches)


Epoch 13/30: 100%|██████████| 557/557 [02:01<00:00,  4.58it/s, loss=0.3308, lr=2.75e-04]


Epoch 13: train_loss=0.3308  fast val_macroDice(ex-bg)~=0.5645 (50 batches)


Epoch 14/30: 100%|██████████| 557/557 [02:01<00:00,  4.58it/s, loss=0.3287, lr=2.61e-04]


Epoch 14: train_loss=0.3287  fast val_macroDice(ex-bg)~=0.6581 (50 batches)


Epoch 15/30: 100%|██████████| 557/557 [02:02<00:00,  4.56it/s, loss=0.3090, lr=2.45e-04]


Epoch 15: train_loss=0.3090  fast val_macroDice(ex-bg)~=0.6769 (50 batches)


Epoch 16/30: 100%|██████████| 557/557 [02:01<00:00,  4.57it/s, loss=0.2865, lr=2.27e-04]


Epoch 16: train_loss=0.2865  fast val_macroDice(ex-bg)~=0.6734 (50 batches)


Epoch 17/30: 100%|██████████| 557/557 [02:02<00:00,  4.56it/s, loss=0.2920, lr=2.07e-04]


Epoch 17: train_loss=0.2920  fast val_macroDice(ex-bg)~=0.6327 (50 batches)


Epoch 18/30: 100%|██████████| 557/557 [02:01<00:00,  4.58it/s, loss=0.2850, lr=1.86e-04]


Epoch 18: train_loss=0.2850  fast val_macroDice(ex-bg)~=0.6718 (50 batches)


Epoch 19/30: 100%|██████████| 557/557 [02:02<00:00,  4.57it/s, loss=0.2596, lr=1.64e-04]


Epoch 19: train_loss=0.2596  fast val_macroDice(ex-bg)~=0.7071 (50 batches)
  ✅ saved best (FAST): 0.7071


Epoch 20/30: 100%|██████████| 557/557 [02:03<00:00,  4.52it/s, loss=0.2641, lr=1.42e-04]


Epoch 20: train_loss=0.2641  FULL val_macroDice(ex-bg)=0.6772


Epoch 21/30: 100%|██████████| 557/557 [02:03<00:00,  4.52it/s, loss=0.2658, lr=1.20e-04]


Epoch 21: train_loss=0.2658  fast val_macroDice(ex-bg)~=0.7152 (50 batches)
  ✅ saved best (FAST): 0.7152


Epoch 22/30: 100%|██████████| 557/557 [02:03<00:00,  4.52it/s, loss=0.2582, lr=9.83e-05]


Epoch 22: train_loss=0.2582  fast val_macroDice(ex-bg)~=0.7106 (50 batches)


Epoch 23/30: 100%|██████████| 557/557 [02:03<00:00,  4.52it/s, loss=0.2395, lr=7.81e-05]


Epoch 23: train_loss=0.2395  fast val_macroDice(ex-bg)~=0.7178 (50 batches)
  ✅ saved best (FAST): 0.7178


Epoch 24/30: 100%|██████████| 557/557 [02:05<00:00,  4.44it/s, loss=0.2377, lr=5.94e-05]


Epoch 24: train_loss=0.2377  fast val_macroDice(ex-bg)~=0.7869 (50 batches)
  ✅ saved best (FAST): 0.7869


Epoch 25/30: 100%|██████████| 557/557 [02:10<00:00,  4.26it/s, loss=0.2411, lr=4.27e-05]


Epoch 25: train_loss=0.2411  fast val_macroDice(ex-bg)~=0.7746 (50 batches)


Epoch 26/30: 100%|██████████| 557/557 [02:14<00:00,  4.16it/s, loss=0.2377, lr=2.84e-05]


Epoch 26: train_loss=0.2377  fast val_macroDice(ex-bg)~=0.7703 (50 batches)


Epoch 27/30: 100%|██████████| 557/557 [02:08<00:00,  4.33it/s, loss=0.2360, lr=1.67e-05]


Epoch 27: train_loss=0.2360  fast val_macroDice(ex-bg)~=0.7912 (50 batches)
  ✅ saved best (FAST): 0.7912


Epoch 28/30: 100%|██████████| 557/557 [02:18<00:00,  4.03it/s, loss=0.2381, lr=8.00e-06]


Epoch 28: train_loss=0.2381  fast val_macroDice(ex-bg)~=0.7849 (50 batches)


Epoch 29/30: 100%|██████████| 557/557 [02:05<00:00,  4.45it/s, loss=0.2359, lr=2.41e-06]


Epoch 29: train_loss=0.2359  fast val_macroDice(ex-bg)~=0.7799 (50 batches)


Epoch 30/30: 100%|██████████| 557/557 [02:00<00:00,  4.61it/s, loss=0.2250, lr=7.32e-08]


Epoch 30: train_loss=0.2250  FULL val_macroDice(ex-bg)=0.7978
  ✅ saved best (FULL): 0.7978
Best val dice: 0.7978042656478336 | based on FULL


In [None]:
import csv
import matplotlib.pyplot as plt


CLASS_COLORS = np.array([
    [0,   0,   0],    
    [255, 0,   0],    
    [0,   255, 0],    
    [0,   0,   255],  
    [255, 255, 0],    
], dtype=np.uint8)

def colorize_mask(mask_2d):

    mask_2d = mask_2d.astype(np.int64)
    return CLASS_COLORS[np.clip(mask_2d, 0, len(CLASS_COLORS)-1)]

@torch.no_grad()
def save_qualitative_batch(model, val_loader, device, out_dir, epoch, max_items=4):
    
    os.makedirs(out_dir, exist_ok=True)
    model.eval()

    img, msk = next(iter(val_loader))
    img = img.to(device, non_blocking=True)
    msk = msk.to(device, non_blocking=True)

    out0, _, _ = model(img)
    probs = torch.softmax(out0, dim=1)          
    pred = probs.argmax(dim=1)                  

    B = min(img.shape[0], max_items)

    for i in range(B):
        raw = img[i,0].detach().cpu().numpy()
        gt  = msk[i].detach().cpu().numpy()
        pr  = pred[i].detach().cpu().numpy()

        gt_rgb = colorize_mask(gt)
        pr_rgb = colorize_mask(pr)

        err = (pr != gt).astype(np.uint8) * 255

 
        fig = plt.figure(figsize=(14, 8))
        gs = fig.add_gridspec(2, 4)

        ax = fig.add_subplot(gs[0,0]); ax.imshow(raw, cmap="gray"); ax.set_title("Raw"); ax.axis("off")
        ax = fig.add_subplot(gs[0,1]); ax.imshow(gt_rgb); ax.set_title("GT mask"); ax.axis("off")
        ax = fig.add_subplot(gs[0,2]); ax.imshow(pr_rgb); ax.set_title("Pred mask"); ax.axis("off")
        ax = fig.add_subplot(gs[0,3]); ax.imshow(err, cmap="gray"); ax.set_title("Error (pred!=gt)"); ax.axis("off")

        for c in range(1, NUM_CLASSES):
            ax = fig.add_subplot(gs[1, c-1])
            hm = probs[i, c].detach().cpu().numpy()
            ax.imshow(hm, vmin=0.0, vmax=1.0)
            ax.set_title(f"P(class {c})")
            ax.axis("off")

        fig.suptitle(f"Epoch {epoch} | val sample {i}", fontsize=12)
        fig.tight_layout()
        fig.savefig(os.path.join(out_dir, f"epoch_{epoch:03d}_val_{i}.png"), dpi=150)
        plt.close(fig)

def append_csv_row(csv_path, row_dict, header=None):
    os.makedirs(os.path.dirname(csv_path), exist_ok=True)
    file_exists = os.path.exists(csv_path)
    with open(csv_path, "a", newline="") as f:
        w = csv.DictWriter(f, fieldnames=header or list(row_dict.keys()))
        if not file_exists:
            w.writeheader()
        w.writerow(row_dict)

def plot_training_curves(log_csv, out_path_prefix):

    import pandas as pd
    df = pd.read_csv(log_csv)


    plt.figure(figsize=(10,4))
    plt.plot(df["epoch"], df["train_loss"], label="train_loss")
    plt.twinx()
    plt.plot(df["epoch"], df["lr"], label="lr")
    plt.title("Train loss + LR")
    plt.savefig(out_path_prefix + "_loss_lr.png", dpi=150)
    plt.close()


    plt.figure(figsize=(10,4))
    plt.plot(df["epoch"], df["val_dice"], label="val_dice")
    plt.title("Val macroDice (ex-bg)")
    plt.savefig(out_path_prefix + "_dice.png", dpi=150)
    plt.close()


In [None]:
def train(
    processed_dir,
    epochs=30,
    batch_size=4,
    accum_steps=4,
    lr=3e-4,
    base_ch=48,
    use_amp=True,
    device="cuda",
    num_workers=0,
):
   
    DO_FULL_VAL_EVERY = 10
    FAST_VAL_BATCHES  = 50
    FORCE_FULL_AT     = 30

    SAVE_VIS_EVERY    = 5   
    MAX_VIS_ITEMS     = 4   

    LOG_DIR = os.path.join(processed_dir, "runs", "resattn_unet")
    VIS_DIR = os.path.join(LOG_DIR, "viz")
    LOG_CSV = os.path.join(LOG_DIR, "metrics.csv")
    os.makedirs(LOG_DIR, exist_ok=True)

    @torch.no_grad()
    def fast_val(model, val_loader, device, n_batches):
        model.eval()
        dices = []
        for b, (img, msk) in enumerate(val_loader):
            if b >= n_batches:
                break
            img = img.to(device, non_blocking=True)
            msk = msk.to(device, non_blocking=True)
            out0, _, _ = model(img)
            pred = out0.argmax(dim=1)
            dices.append(macro_dice_ex_bg(pred, msk, n_classes=NUM_CLASSES))
        return float(np.mean(dices)) if dices else 0.0

    @torch.no_grad()
    def full_val(model, val_loader, device):
        model.eval()
        dices = []
        for img, msk in val_loader:
            img = img.to(device, non_blocking=True)
            msk = msk.to(device, non_blocking=True)
            out0, _, _ = model(img)
            pred = out0.argmax(dim=1)
            dices.append(macro_dice_ex_bg(pred, msk, n_classes=NUM_CLASSES))
        return float(np.mean(dices)) if dices else 0.0

    torch.backends.cudnn.benchmark = True
    amp_on = (use_amp and str(device).startswith("cuda"))
    scaler = torch.amp.GradScaler('cuda', enabled=amp_on)

    train_ds = TiledNPYDataset(processed_dir, "train")
    val_ds   = TiledNPYDataset(processed_dir, "val")

    class_w = compute_class_weights_from_mask_files(train_ds.msk_paths, NUM_CLASSES, device)

    w_path = os.path.join(processed_dir, "train", "tile_weights.npy")
    if not os.path.exists(w_path):
        build_and_save_tile_weights(processed_dir, "train", num_classes=NUM_CLASSES)
    tile_weights = np.load(w_path)

    sampler = WeightedRandomSampler(
        weights=torch.DoubleTensor(tile_weights),
        num_samples=len(tile_weights),
        replacement=True
    )

    train_loader = DataLoader(
        train_ds, batch_size=batch_size, sampler=sampler,
        num_workers=num_workers, pin_memory=str(device).startswith("cuda")
    )
    val_loader = DataLoader(
        val_ds, batch_size=batch_size, shuffle=False,
        num_workers=num_workers, pin_memory=str(device).startswith("cuda")
    )

    smoke_test_loader(train_loader)

    model = ResAttnUNetDS(n_classes=NUM_CLASSES, base=base_ch).to(device)
    opt = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4)

    steps_per_epoch = math.ceil(len(train_loader) / accum_steps)
    sched = torch.optim.lr_scheduler.OneCycleLR(
        opt, max_lr=lr, epochs=epochs, steps_per_epoch=steps_per_epoch
    )

    criterion = ComboLoss(class_w, dice_w=0.7, ce_w=0.3)

    best = -1.0
    best_is_full = False
    opt.zero_grad(set_to_none=True)

    
    header = ["epoch", "train_loss", "val_dice", "val_mode", "lr"]

    for ep in range(epochs):
        model.train()
        run_loss = 0.0

        pbar = tqdm(enumerate(train_loader), total=len(train_loader), desc=f"Epoch {ep+1}/{epochs}")
        for i, (img, msk) in pbar:
            img = img.to(device, non_blocking=True)
            msk = msk.to(device, non_blocking=True)

            with torch.amp.autocast('cuda', enabled=amp_on):
                out0, out1, out2 = model(img)

                m1 = F.interpolate(msk.unsqueeze(1).float(), scale_factor=0.5, mode="nearest").squeeze(1).long()
                m2 = F.interpolate(msk.unsqueeze(1).float(), scale_factor=0.25, mode="nearest").squeeze(1).long()

                loss0 = criterion(out0, msk)
                loss1 = criterion(out1, m1)
                loss2 = criterion(out2, m2)

                loss = (loss0 + 0.5*loss1 + 0.25*loss2) / accum_steps

            scaler.scale(loss).backward()

            if (i + 1) % accum_steps == 0:
                scaler.unscale_(opt)
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                scaler.step(opt)
                scaler.update()
                opt.zero_grad(set_to_none=True)
                sched.step()

            run_loss += loss.item() * accum_steps
            pbar.set_postfix(loss=f"{run_loss/(i+1):.4f}", lr=f"{sched.get_last_lr()[0]:.2e}")

        train_loss_ep = run_loss / max(1, len(train_loader))
        cur_lr = float(sched.get_last_lr()[0])

     
        if (ep + 1) == FORCE_FULL_AT:
            val_d = full_val(model, val_loader, device)
            is_full = True
            mode = "FORCED_FULL"
            print(f"Epoch {ep+1}: train_loss={train_loss_ep:.4f}  FORCED FULL val_macroDice(ex-bg)={val_d:.4f}")
        elif (ep + 1) % DO_FULL_VAL_EVERY == 0:
            val_d = full_val(model, val_loader, device)
            is_full = True
            mode = "FULL"
            print(f"Epoch {ep+1}: train_loss={train_loss_ep:.4f}  FULL val_macroDice(ex-bg)={val_d:.4f}")
        else:
            val_d = fast_val(model, val_loader, device, FAST_VAL_BATCHES)
            is_full = False
            mode = "FAST"
            print(f"Epoch {ep+1}: train_loss={train_loss_ep:.4f}  fast val_macroDice(ex-bg)~={val_d:.4f} ({FAST_VAL_BATCHES} batches)")

        append_csv_row(LOG_CSV, {
            "epoch": ep+1,
            "train_loss": train_loss_ep,
            "val_dice": val_d,
            "val_mode": mode,
            "lr": cur_lr
        }, header=header)


        if val_d > best:
            best = val_d
            best_is_full = is_full
            torch.save(model.state_dict(), "best_resattn_unet_ds.pth")
            tag = "FULL" if is_full else "FAST"
            print(f"   saved best ({tag}): {best:.4f}")

        if (ep + 1) % SAVE_VIS_EVERY == 0 or (ep + 1) == 1 or (ep + 1) == FORCE_FULL_AT:
            save_qualitative_batch(model, val_loader, device, VIS_DIR, epoch=ep+1, max_items=MAX_VIS_ITEMS)
  
            try:
                plot_training_curves(LOG_CSV, os.path.join(LOG_DIR, "curves"))
            except Exception as e:
                print("Plotting failed (non-fatal):", e)

    print("Best val dice:", best, "| based on", ("FULL" if best_is_full else "FAST"))
    print("Logs saved to:", LOG_CSV)
    print("Viz saved to:", VIS_DIR)


In [None]:
@torch.no_grad()
def per_class_dice(pred, tgt, n_classes=NUM_CLASSES, eps=1e-6):
  
    dices = np.zeros(n_classes, dtype=np.float64)
    for c in range(n_classes):
        p = (pred == c)
        t = (tgt == c)
        inter = (p & t).sum()
        denom = p.sum() + t.sum()
        dices[c] = (2.0 * inter + eps) / (denom + eps)
    return dices

@torch.no_grad()
def confusion_matrix_fast(pred, tgt, n_classes=NUM_CLASSES):
  
    pred = pred.reshape(-1)
    tgt  = tgt.reshape(-1)
    k = (tgt >= 0) & (tgt < n_classes)
    idx = n_classes * tgt[k] + pred[k]
    cm = np.bincount(idx, minlength=n_classes*n_classes).reshape(n_classes, n_classes)
    return cm

@torch.no_grad()
def evaluate_full_val(model, val_loader, device, n_classes=NUM_CLASSES):
    model.eval()
    dice_list = []
    cm_all = np.zeros((n_classes, n_classes), dtype=np.int64)

    for img, msk in tqdm(val_loader, desc="FULL EVAL", leave=False):
        img = img.to(device, non_blocking=True)
        msk = msk.to(device, non_blocking=True)

        out0, _, _ = model(img)                 
        pred = out0.argmax(dim=1)              

        pred_np = pred.detach().cpu().numpy().astype(np.int64)
        msk_np  = msk.detach().cpu().numpy().astype(np.int64)

        for b in range(pred_np.shape[0]):
            dices = per_class_dice(pred_np[b], msk_np[b], n_classes=n_classes)
            dice_list.append(dices)
            cm_all += confusion_matrix_fast(pred_np[b], msk_np[b], n_classes=n_classes)

    dice_mean = np.mean(np.stack(dice_list, axis=0), axis=0) if dice_list else np.zeros(n_classes)
    return dice_mean, cm_all

def save_confusion_heatmap(cm, out_path, title="Confusion Matrix"):
    import matplotlib.pyplot as plt
    plt.figure(figsize=(6,5))
    plt.imshow(cm)
    plt.title(title)
    plt.xlabel("Pred")
    plt.ylabel("GT")
    plt.colorbar()
    plt.tight_layout()
    plt.savefig(out_path, dpi=150)
    plt.close()

@torch.no_grad()
def save_predictions_gallery(model, val_ds, device, out_dir, k=12, seed=0):

    os.makedirs(out_dir, exist_ok=True)
    rng = np.random.RandomState(seed)
    idxs = rng.choice(len(val_ds), size=min(k, len(val_ds)), replace=False)

    model.eval()
    for j, idx in enumerate(idxs):
        img, msk = val_ds[idx]
        img_b = img.unsqueeze(0).to(device)
        msk_b = msk.unsqueeze(0).to(device)

        out0, _, _ = model(img_b)
        probs = torch.softmax(out0, dim=1)[0]         
        pred  = probs.argmax(dim=0)                

        raw = img[0].numpy()
        gt  = msk.numpy().astype(np.int64)
        pr  = pred.detach().cpu().numpy().astype(np.int64)

        gt_rgb = colorize_mask(gt)
        pr_rgb = colorize_mask(pr)
        err = (pr != gt).astype(np.uint8) * 255

        import matplotlib.pyplot as plt
        fig = plt.figure(figsize=(14, 8))
        gs = fig.add_gridspec(2, 4)

        ax = fig.add_subplot(gs[0,0]); ax.imshow(raw, cmap="gray"); ax.set_title("Raw"); ax.axis("off")
        ax = fig.add_subplot(gs[0,1]); ax.imshow(gt_rgb); ax.set_title("GT mask"); ax.axis("off")
        ax = fig.add_subplot(gs[0,2]); ax.imshow(pr_rgb); ax.set_title("Pred mask"); ax.axis("off")
        ax = fig.add_subplot(gs[0,3]); ax.imshow(err, cmap="gray"); ax.set_title("Error"); ax.axis("off")

        for c in range(1, NUM_CLASSES):
            ax = fig.add_subplot(gs[1, c-1])
            hm = probs[c].detach().cpu().numpy()
            ax.imshow(hm, vmin=0.0, vmax=1.0)
            ax.set_title(f"P(class {c})")
            ax.axis("off")

        fig.suptitle(f"val tile idx={idx}", fontsize=12)
        fig.tight_layout()
        fig.savefig(os.path.join(out_dir, f"val_pred_{j:02d}_idx{idx}.png"), dpi=150)
        plt.close(fig)

def evaluate_and_visualize(
    processed_dir=PROCESSED_DIR,
    ckpt_path="best_resattn_unet_ds.pth",
    base_ch=BASE_CH,
    batch_size=4,
    device=DEVICE,
    num_workers=0,
    out_dir=None,
    k_examples=12
):
    if out_dir is None:
        out_dir = os.path.join(processed_dir, "eval_report")

    os.makedirs(out_dir, exist_ok=True)

    val_ds = TiledNPYDataset(processed_dir, "val")
    val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False,
                            num_workers=num_workers, pin_memory=str(device).startswith("cuda"))

    model = ResAttnUNetDS(n_classes=NUM_CLASSES, base=base_ch).to(device)
    sd = torch.load(ckpt_path, map_location=device)
    model.load_state_dict(sd)


    dice_mean, cm = evaluate_full_val(model, val_loader, device, n_classes=NUM_CLASSES)

    print("\n=== FULL VAL REPORT (per-class Dice) ===")
    for c in range(NUM_CLASSES):
        print(f"Class {c}: Dice = {dice_mean[c]:.4f}")
    macro_ex_bg = float(np.mean(dice_mean[1:])) if NUM_CLASSES > 1 else float(dice_mean[0])
    print(f"\nMacro Dice (exclude bg classes 1..{NUM_CLASSES-1}): {macro_ex_bg:.4f}")


    save_confusion_heatmap(cm, os.path.join(out_dir, "confusion_matrix.png"),
                           title="Confusion Matrix (rows=GT, cols=Pred)")


    save_predictions_gallery(model, val_ds, device, os.path.join(out_dir, "examples"),
                             k=k_examples, seed=SEED)

    print("\nSaved report to:", out_dir)
    print(" - confusion_matrix.png")
    print(" - examples/val_pred_*.png")


In [7]:
if __name__ == "__main__":
    evaluate_and_visualize(
        processed_dir=PROCESSED_DIR,
        ckpt_path="best_resattn_unet_ds.pth",
        base_ch=BASE_CH,
        batch_size=4,
        device=DEVICE,
        num_workers=0,
        k_examples=12
    )


                                                            


=== FULL VAL REPORT (per-class Dice) ===
Class 0: Dice = 0.8436
Class 1: Dice = 0.8705
Class 2: Dice = 0.8867
Class 3: Dice = 0.7751
Class 4: Dice = 0.7953

Macro Dice (exclude bg classes 1..4): 0.8319

Saved report to: project/processed_tiles\eval_report
 - confusion_matrix.png
 - examples/val_pred_*.png
