# HybridNetv7

**Version**: 7.0.0<br>
**Date**: Mon Aug 18th, 2025<br>
**Author**: Jakob Balkovec<br>

**Model Description**:

```txt
From latest run:

=== TEST SUMMARY ===
Dice mean: 0.3749 | IoU mean: 0.3060
Dice per class [MA, HE, EX, SE]: ['0.1362', '0.4267', '0.4164', '0.5201']
IoU per class  [MA, HE, EX, SE]: ['0.0881', '0.3237', '0.3286', '0.4835']
Pearson per class [MA, HE, EX, SE]: ['0.1509', '0.3786', '0.2981', '0.1454']
```

---
## Packages

In [None]:
!pip install scikit-multilearn

Collecting scikit-multilearn
  Downloading scikit_multilearn-0.2.0-py3-none-any.whl.metadata (6.0 kB)
Downloading scikit_multilearn-0.2.0-py3-none-any.whl (89 kB)
[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/89.4 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m89.4/89.4 kB[0m [31m4.4 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: scikit-multilearn
Successfully installed scikit-multilearn-0.2.0


## Imports

In [None]:
# standard library
import os
import random
import importlib
import math
import time
import csv
import gc
from datetime import datetime
import json
import shutil

# TP
import numpy as np
import pandas as pd
from PIL import Image
from google.colab import drive
import matplotlib.pyplot as plt
from tqdm import tqdm
import cv2
import albumentations as A
from albumentations.core.transforms_interface import ImageOnlyTransform
from albumentations.pytorch import ToTensorV2
from scipy.stats import pearsonr

# PyT
import torch
from torch import amp
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torch.multiprocessing as mp
from torch.utils.data import WeightedRandomSampler
from torchvision import models
from torch.special import logit

# TV
from torchvision import transforms, models
from torchvision.transforms import InterpolationMode

# SKL
from skmultilearn.model_selection import iterative_train_test_split

# Device: use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print(f"\nusing: {device}")
# Stop early when no GPU runtime is attached.
# NOTE(review): `assert` is stripped under `python -O`; an explicit raise would be sturdier.
assert device.type == "cuda", "CUDA device not available || still in CPU mode"

KeyboardInterrupt: 


## Configuration

In [None]:
# Central dictionary of paths and hyper-parameters read by the cells below.
# Class-indexed lists follow the lesion order [MA, HE, EX, SE].
config = {
    # ----- paths -----
    "MASTER_DF_PATH": "/content/patches/patches/master_df.pkl",
    "IMAGE_ROOT": "/content/images/Original_Images",
    "PATCHES_ROOT": "/content/patches/",
    "MASK_DIR": "/content/masks",
    "OUTPUT_DIR": "/content/output",
    "DRIVE_MODEL_DIR": "/content/drive/MyDrive/models",
    "DRIVE_MODEL_PATH": "/content/drive/MyDrive/models/hybridnet_v7.pth",
    # ----- run shape -----
    "SEED": 42,
    "IMG_SIZE": 512,
    "PATCH_SIZE": 128,
    "BATCH_SIZE": 1,
    "EPOCHS": 25,
    "VAL_INTERVAL": 1,

    # ----- patches sampled per image -----
    "FIXED_PATCHES_VAL": 96,
    "FIXED_PATCHES_TRAIN": 48,
    "PATCH_CLASS_QUOTA": [0.35, 0.25, 0.25, 0.15],  # MA, HE, EX, SE within positives

    "ACCUM_STEPS": 4,
    # Same worker-count heuristic as mp_config["NUM_WORKERS"].
    "NUM_WORKERS": max(0, min(4, (os.cpu_count() or 2)//2)),
    "NUM_CLASSES": 4,

    "LOSS_TYPE": "ft_w_bce",  # "dice" | "dice_bce" | "tversky" | "dice_bce_w" | "ft_w_bce" | "focal_tversky"
    "BCE_WEIGHT": 0.7,

    # ===== FINE TUNE =======
    "TV_ALPHA": 0.3,
    "TV_BETA": 0.7,
    "FT_GAMMA": 0.75, # from 0.75
    "FT_LAMBDA": 0.7,

    # Fraction of the per-image training patch budget reserved for lesion-positive patches.
    "PATCH_POS_RATIO": 0.7, # from 0.5

    "FT_CLASS_WEIGHTS": [0.25, 1.0, 1.0, 1.0], # [MA, HE, EX, SE]
    "FT_CLAMP_PROBS": True,  # set to True
    "FT_IGNORE_EMPTY": True, # set to True
    "FT_EPS": 1e-6,
    "FT_BCE_ALPHA": 0.05,
    # ===== FINE TUNE =======

    "POSW_CAP": 12.0,
    "CLASSW_POWER": 1.0,
    "CLASSW_CAP": 4.0,

    "WARMUP_STEPS": 500,

    "COV_RATIOS": [0.01, 0.01, 0.01, 0.01],
    "POS_WEIGHT": [1.0, 1.0, 1.0, 1.0],
    "CLASS_WEIGHT": [1.0, 1.0, 1.0, 1.0],

    # ----- optimisation -----
    "LR": 1e-4,
    "LR_MIN": 1e-6,
    "WEIGHT_DECAY": 1e-4,
    "LR_FACTOR": 0.5,
    "LR_PATIENCE": 2,
    "MAX_NORM": 1.0,
    "USE_EMA": True,

    "PATCH_CHUNK": 25,

    "USE_GLOBAL_FILM": False,
    "AUX_WEIGHTS": [1.0, 0.4, 0.2],

    # Green-channel CLAHE variant passed to the GreenCLAHE transforms.
    "MODE": "blend", # replace (moderate emph) | replicate (strong emph) | blend (default)
    "POST_MIN_AREA_PER_CLASS": [0, 20, 10, 20],

    # ----- early stopping -----
    "EARLY_STOP_MONITOR": "val_dice_mean_ema", # | val_dice_mean
    "EARLY_STOP_MODE": "max",
    "EARLY_STOP_PATIENCE": 8,
    "EARLY_STOP_MIN_DELTA": 0.002,
    "MAX_TIME_MIN": 120,
    "LOSS_SPIKE_FACTOR": 3.0,

    # ----- image-level split; the three fractions are asserted to sum to 1 in the Split cell -----
    "TRAIN_FRAC": 0.70,
    "VAL_FRAC": 0.15,
    "TEST_FRAC": 0.15,
    # When False, a previously saved split CSV is reused if it exists.
    "FORCE_RESPLIT": False,

    "RUN_MODE": "dbg", # | ""

    "THRESH_PER_CLASS": [0.30, 0.45, 0.45, 0.45],  # [MA, HE, EX, SE]: from [0.90, 0.90, 0.80, 0.90]
}

# DataLoader / multiprocessing settings, kept apart from the training hyper-parameters.
# NOTE(review): presumably consumed where the real DataLoaders are built — confirm against that cell.
mp_config = {
    # Half the visible CPUs, capped at 4; falls back to 2 CPUs when os.cpu_count() returns None.
    "NUM_WORKERS": max(0, min(4, (os.cpu_count() or 2)//2)),  # heuristic
    "PERSISTENT_WORKERS": True,
    "PREFETCH_FACTOR": 2,
    # Pinned host memory is only requested when running on a CUDA device.
    "PIN_MEMORY": (device.type == 'cuda'),
    "TIMEOUT": 0,
    # "fork" is the only start method that succeeded in the probe cell's recorded output.
    "MP_CONTEXT": "fork"
}


## History Tracking

In [None]:
# Metric log: one (initially empty) list per tracked quantity.
# Keys are "<phase>_<metric>" for the train and val phases, plus a smoothed
# validation Dice entry; presumably each list grows by one value per epoch.
HISTORY = {
    f"{phase}_{metric}": []
    for phase in ("train", "val")
    for metric in (
        "loss",
        "dice_mean",
        "dice_per_class",
        "iou_mean",
        "iou_per_class",
        "pearson_mean",
        "pearson_per_class",
    )
}
HISTORY["val_dice_mean_ema"] = []

## Environment

In [None]:
# Start from a clean slate: collect Python garbage and release cached CUDA memory.
gc.collect()
torch.cuda.empty_cache()

# ImageNet channel statistics used by the A.Normalize steps of the transform pipelines.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD  = [0.229, 0.224, 0.225]

# Per-lesion multipliers; "SE" has no entry.
# NOTE(review): not referenced in the cells visible here — confirm where these are applied.
BOOSTS = {
    "MA": 4.0,
    "EX": 1.5,
    "HE": 1.2,
}

# RGB colour (components in 0..1) per lesion type.
LESION_COLORS = {
    "MA": (1.0, 0.0, 0.0),   # red
    "HE": (0.0, 1.0, 0.0),   # green
    "EX": (0.0, 0.0, 1.0),   # blue
    "SE": (1.0, 1.0, 0.0),   # yellow
}

# Lesion names in dict insertion order: MA, HE, EX, SE.
LESION_NAMES = list(LESION_COLORS.keys())

# Seed the torch, NumPy and stdlib RNGs from the same configured value.
torch.manual_seed(config["SEED"])
np.random.seed(config["SEED"])
random.seed(config["SEED"])

# Allow TF32 matmul/cuDNN kernels and cuDNN autotuning.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cudnn.benchmark = True
torch.set_float32_matmul_precision("high")

# NOTE(review): NumPy/torch are already imported at this point, so these thread
# limits may only affect libraries or worker processes started later — confirm.
os.environ["OMP_NUM_THREADS"] = "1"
os.environ["MKL_NUM_THREADS"] = "1"

# Output locations (local run directory and the Drive model folder).
os.makedirs(config["OUTPUT_DIR"], exist_ok=True)
os.makedirs(config["DRIVE_MODEL_DIR"], exist_ok=True)

METRICS_DIR = os.path.join(config["OUTPUT_DIR"], "metrics_logs")
VISUALS_DIR = os.path.join(config["OUTPUT_DIR"], "visuals")
PREDICTIONS_DIR = os.path.join(config["OUTPUT_DIR"], "predictions")

for d in [METRICS_DIR, VISUALS_DIR, PREDICTIONS_DIR]:
    os.makedirs(d, exist_ok=True)

# Files that persist the image-level split and its per-class summary.
SPLIT_CSV = os.path.join(METRICS_DIR, "image_split.csv")
SPLIT_SUMMARY_CSV = os.path.join(METRICS_DIR, "split_summary.csv")
# dirname(SPLIT_CSV) is METRICS_DIR, already created by the loop above; harmless repeat.
os.makedirs(os.path.dirname(SPLIT_CSV), exist_ok=True)

NameError: name 'gc' is not defined

## Multiprocessing Context

In [None]:
class NopDS(Dataset):
    """Synthetic 64-item dataset used to smoke-test DataLoader worker start-up.

    Every item is an all-zero ``(3, 16, 16)`` tensor paired with its own index.
    """

    def __len__(self):
        return 64

    def __getitem__(self, i):
        blank = torch.zeros(3, 16, 16)
        return blank, i

def try_ctx(ctx_name):
    """Pull one batch through a 2-worker DataLoader using the named start method.

    Prints a banner, builds a loader over ``NopDS`` with the multiprocessing
    context returned by ``mp.get_context(ctx_name)``, fetches the first batch
    and prints its tensor shape and first four indices. Any exception raised
    while starting workers or fetching the batch propagates to the caller.
    """
    print(f"\n=== testing multiprocessing_context='{ctx_name}' ===")
    loader_kwargs = dict(
        batch_size=8,
        shuffle=False,
        num_workers=2,
        multiprocessing_context=mp.get_context(ctx_name),
        persistent_workers=False,
        pin_memory=False,
        prefetch_factor=2,
        timeout=0,
    )
    first_batch = next(iter(DataLoader(NopDS(), **loader_kwargs)))
    tensors, indices = first_batch[0], first_batch[1]
    print("OK:", tensors.shape, indices[:4])

# Probe each start method; a failure is reported, not raised, so all three are tried.
for name in ["forkserver", "spawn", "fork"]:
    try:
        try_ctx(name)
    except Exception as e:
        print(f"FAILED ({name}): {type(e).__name__}: {e}")

# Share tensors between processes through the file system; errors are only reported.
try:
    mp.set_sharing_strategy("file_system")
except Exception as e:
    print("set_sharing_strategy:", e)

# Force "fork" as the global start method.
# NOTE(review): the purpose of reloading torch.utils.data afterwards is not evident
# from this cell; classes imported earlier (Dataset, DataLoader) are not rebound by it.
try:
    mp.set_start_method("fork", force=True) # or forkserver
    importlib.reload(torch.utils.data)
except RuntimeError as e:
    print("set_start_method:", e)

# Report the settings actually in effect.
print("start_method:    ", mp.get_start_method(allow_none=True))
print("sharing_strategy:", mp.get_sharing_strategy())


=== testing multiprocessing_context='forkserver' ===
FAILED (forkserver): RuntimeError: DataLoader worker (pid(s) 1269, 1270) exited unexpectedly

=== testing multiprocessing_context='spawn' ===
FAILED (spawn): RuntimeError: DataLoader worker (pid(s) 1339, 1340) exited unexpectedly

=== testing multiprocessing_context='fork' ===
OK: torch.Size([8, 3, 16, 16]) tensor([0, 1, 2, 3])


## Drive

In [None]:
if os.path.isdir('/content/drive') and os.listdir('/content/drive'):
    print("non-empty mountpoint")

    # unmount
    if os.path.ismount('/content/drive'):
        drive.flush_and_unmount()
        print("drive unmounted.")

    !rm -rf /content/drive/*

# mount
drive.mount('/content/drive', force_remount=True)

non-empty mountpoint
Mounted at /content/drive


## Data

In [None]:
# extract "patches.7z"
!7z x "/content/drive/MyDrive/patches.7z" -o"/content/patches" -y

# TODO: remove before sharing this notebook (plain-text archive password below)
# passwd: DRpatchesWedJul9

# unzip "images.zip"
!unzip "/content/drive/MyDrive/images.zip" -d /content/images

# unzip "masks.zip"
!unzip "/content/drive/MyDrive/masks.zip" -d /content/masks

[1;30;43mIzpis pretočnega predvajanja je skrajšan na toliko zadnjih vrstic: 5000.[0m
  inflating: /content/masks/Hemorrhage_Masks/0921_1.png  
  inflating: /content/masks/Hemorrhage_Masks/1080_3.png  
  inflating: /content/masks/Hemorrhage_Masks/1226_1.png  
  inflating: /content/masks/Hemorrhage_Masks/0226_1.png  
  inflating: /content/masks/Hemorrhage_Masks/0263_1.png  
  inflating: /content/masks/Hemorrhage_Masks/0999_1.png  
  inflating: /content/masks/Hemorrhage_Masks/1750_3.png  
  inflating: /content/masks/Hemorrhage_Masks/0803_3.png  
  inflating: /content/masks/Hemorrhage_Masks/1278_2.png  
  inflating: /content/masks/Hemorrhage_Masks/1491_3.png  
  inflating: /content/masks/Hemorrhage_Masks/1573_1.png  
  inflating: /content/masks/Hemorrhage_Masks/0728_3.png  
  inflating: /content/masks/Hemorrhage_Masks/0304_3.png  
  inflating: /content/masks/Hemorrhage_Masks/0341_3.png  
  inflating: /content/masks/Hemorrhage_Masks/0959_1.png  
  inflating: /content/masks/Hemorrhage_Mask

## Frame

In [None]:
# Load the master patch table (the preview below shows one row per patch_id).
# NOTE(review): read_pickle can execute code embedded in the file — only load archives you produced yourself.
patch_df = pd.read_pickle(config["MASTER_DF_PATH"])
patch_df.head(n=2)

Unnamed: 0,image_id,patch_id,file_path,filter_tag,coordinates,center,label_vector
0,0000_1,0000_1_0450_0194,patches/0000_1/all/0000_1_0450_0194.png,healthy,"{'top-left': (450, 194), 'top-right': (577, 19...","{'x': 450, 'y': 194}","[0, 0, 0, 0]"
1,0000_1,0000_1_0278_0279,patches/0000_1/all/0000_1_0278_0279.png,healthy,"{'top-left': (278, 279), 'top-right': (405, 27...","{'x': 278, 'y': 279}","[0, 0, 0, 0]"


## Utility Functions `SPLIT`

In [None]:
def _merge_labels(label_list):
    arr = np.vstack(label_list).astype(int)
    return np.clip(arr.sum(axis=0), 0, 1).tolist()


def summarize_split(name, idx, Y):
    """Describe the rows of ``Y`` selected by ``idx``.

    Returns a dict with the split ``name``, the number of selected rows ``n``,
    the per-column sums (``counts``) and the per-column means rounded to four
    decimals (``rates``).
    """
    picked = Y[idx]
    return {
        "name": name,
        "n": int(len(idx)),
        "counts": picked.sum(axis=0).tolist(),
        "rates": picked.mean(axis=0).round(4).tolist(),
    }


def stratified_split(image_df, train_frac=0.7, val_frac=0.15, seed=42):
    """Assign every image to "train", "val" or "test" with multi-label stratification.

    Two successive ``iterative_train_test_split`` calls are used: the first
    carves off the training part, the second divides the remainder into
    validation and test.

    Args:
        image_df: frame with an ``image_label`` column of equal-length label vectors.
        train_frac: fraction of images for training; must lie strictly in (0, 1).
        val_frac: fraction of images for validation; must lie in [0, 1 - train_frac].
        seed: seed applied to NumPy's global RNG for the duration of the split.

    Returns:
        A copy of ``image_df`` with an added ``split`` column.

    Raises:
        ValueError: if the fractions are outside the ranges above.
    """
    # Validate first: train_frac == 1 would otherwise divide by zero below.
    if not 0.0 < train_frac < 1.0:
        raise ValueError(f"train_frac must be in (0, 1), got {train_frac}")
    if not 0.0 <= val_frac <= (1 - train_frac) + 1e-9:
        raise ValueError(
            f"val_frac must be in [0, 1 - train_frac], got {val_frac} (train_frac={train_frac})"
        )

    # X carries row positions so the split halves can be mapped back to image_df.
    X = np.arange(len(image_df)).reshape(-1, 1)
    Y = np.array(image_df["image_label"].to_list())

    # iterative_train_test_split takes no seed argument; it is assumed to draw its
    # tie-breaks from NumPy's global RNG (TODO confirm), so seed that RNG for the
    # two calls and restore the caller's state afterwards.
    saved_state = np.random.get_state()
    np.random.seed(seed)
    try:
        X_train, Y_train, X_tmp, Y_tmp = iterative_train_test_split(
            X, Y, test_size=(1 - train_frac)
        )

        # Share of the held-out part that goes to validation.
        val_frac_rel = val_frac / (1 - train_frac)
        X_val, Y_val, X_test, Y_test = iterative_train_test_split(
            X_tmp, Y_tmp, test_size=(1 - val_frac_rel)
        )
    finally:
        np.random.set_state(saved_state)

    idx_val = X_val.ravel()
    idx_test = X_test.ravel()

    # Everything not claimed by val/test stays "train".
    split = np.full(len(image_df), "train", dtype=object)
    split[idx_val] = "val"
    split[idx_test] = "test"

    image_df = image_df.copy()
    image_df["split"] = split
    return image_df


## Split

In [None]:
# sanity: the three configured fractions must add up to 1
assert abs(config["TRAIN_FRAC"] + config["VAL_FRAC"] + config["TEST_FRAC"] - 1.0) < 1e-9

# The patch table must carry the columns used below and by the dataset.
need_cols = {"image_id", "label_vector", "file_path"}
missing = need_cols - set(patch_df.columns)
assert not missing, f"patch_df missing columns: {missing}"

# Image-level table: merge all patch label vectors of an image into one
# multi-hot label via _merge_labels.
image_df = (
    patch_df.groupby("image_id", as_index=False)["label_vector"]
            .apply(lambda s: _merge_labels(list(s)))
            .rename(columns={"label_vector": "image_label"})
)
image_df["image_path"] = image_df["image_id"].apply(
    lambda x: os.path.join(config["IMAGE_ROOT"], f"{x}.png")
)

# K: label-vector length (not referenced again in the cells visible here).
# Y: (N, K) integer label matrix, row-aligned with image_df.
K = len(image_df["image_label"].iloc[0])
Y = np.array(image_df["image_label"].to_list(), dtype=int)
N = len(image_df)

# Compute a fresh split when forced or when none was saved; otherwise reuse the CSV.
if config["FORCE_RESPLIT"] or not os.path.isfile(SPLIT_CSV):
    image_df = stratified_split(
        image_df,
        train_frac=config["TRAIN_FRAC"],
        val_frac=config["VAL_FRAC"],
        seed=config["SEED"],
    )
    image_df[["image_id", "split"]].to_csv(SPLIT_CSV, index=False)
else:
    prev = pd.read_csv(SPLIT_CSV)
    image_df = image_df.merge(prev, on="image_id", how="left")
    # Images absent from the saved CSV silently fall back to "train".
    image_df["split"] = image_df["split"].fillna("train")

# Propagate the image-level assignment to every patch of that image.
image_split_map = dict(zip(image_df["image_id"], image_df["split"]))
patch_df["split"] = patch_df["image_id"].map(image_split_map)
assert patch_df["split"].notna().all()

# Index labels of each split; these are later used to index rows of Y.
idx_train = image_df.index[image_df["split"] == "train"].to_numpy()
idx_val   = image_df.index[image_df["split"] == "val"].to_numpy()
idx_test  = image_df.index[image_df["split"] == "test"].to_numpy()

def _summarize(name, idx, Y):
    subset = Y[idx]
    counts = subset.sum(axis=0).astype(int).tolist()   # [+] image counts per class
    rates  = subset.mean(axis=0).round(4).tolist()     # [+] image rates per class
    return {"name": name, "n": int(len(idx)), "counts": counts, "rates": rates}

# One summary row per split plus one for the whole image set.
summary = [
    _summarize("train",   idx_train, Y),
    _summarize("val",     idx_val,   Y),
    _summarize("test",    idx_test,  Y),
    _summarize("overall", np.arange(N), Y),
]

# Persist and display the summary table.
df_summary = pd.DataFrame(summary)
df_summary.to_csv(SPLIT_SUMMARY_CSV, index=False)
print(df_summary)

# make sure every class appears in every split (warn only, do not abort)
cls_names = ["MA", "HE", "EX", "SE"]
for split_name in ["train", "val", "test"]:
    row = df_summary[df_summary["name"] == split_name].iloc[0]
    zero_classes = [cls_names[i] for i, c in enumerate(row["counts"]) if c == 0]
    if zero_classes:
        print(f"[WARN] Split '{split_name}' has zero positives for: {zero_classes}")

### IMPORTANT ###
# check distro before proceeding
### IMPORTANT ###

      name     n                   counts                             rates
0    train  1217    [994, 1016, 893, 437]  [0.8168, 0.8348, 0.7338, 0.3591]
1      val   258      [213, 217, 191, 93]  [0.8256, 0.8411, 0.7403, 0.3605]
2     test   262      [213, 218, 192, 94]   [0.813, 0.8321, 0.7328, 0.3588]
3  overall  1737  [1420, 1451, 1276, 624]  [0.8175, 0.8353, 0.7346, 0.3592]


## Dataset

In [None]:
class PatchSegDataset(Dataset):
    """Per-image segmentation dataset that also returns a bag of patches.

    Each item is ``(image_tensor, patch_tensor, mask_tensor)``:

    * ``image_tensor`` – the transformed full image, channels first.
    * ``patch_tensor`` – the selected patches of that image stacked on dim 0.
    * ``mask_tensor``  – four binary lesion masks (MA, HE, EX, SE), channels
      first, clamped to [0, 1].

    In train mode patches are drawn with a positive/negative ratio and a
    per-class quota read from the module-level ``config``; the RNG can be
    reseeded per epoch with :meth:`set_epoch`. In eval mode either all patches
    are used (``fixed_patches_val is None``) or a fixed subset is chosen once
    per image and cached.

    Args:
        patch_df: frame with ``image_id``, ``file_path`` and ``label_vector`` columns.
        image_dir: directory holding ``<image_id>.png`` full images.
        patches_root: prefix joined to relative ``file_path`` values.
        mask_dir: directory holding one sub-folder per lesion type.
        transform_img: callable applied as ``f(image=..., mask=...)`` when
            ``joint_tf`` is None; ``None`` means pass-through.
        transform_patch: callable applied as ``f(image=...)`` to every patch.
        transform_mask: stored for compatibility; not used by this class.
        fixed_patches_train: patch budget per image in train mode.
        fixed_patches_val: patch budget per image in eval mode (None = all).
        seed: base seed for patch sampling.
        train_mode: selects train or eval sampling.
        joint_tf: optional joint image+mask transform; takes precedence over
            ``transform_img``.
    """

    _CLASS_NAMES = ("MA", "HE", "EX", "SE")

    class _AlbuNoOp:
        """Stand-in transform that returns its keyword arguments unchanged."""

        def __call__(self, **kwargs):
            return kwargs

    def __init__(self,
                 patch_df,
                 image_dir,
                 patches_root,
                 mask_dir,
                 transform_img,
                 transform_patch,
                 transform_mask,
                 fixed_patches_train=48,
                 fixed_patches_val=None,
                 seed=42,
                 train_mode=True,
                 joint_tf=None):
        self.image_dir = image_dir
        self.patches_root = patches_root
        self.mask_dir = mask_dir

        self.transform_img   = transform_img or self._AlbuNoOp()
        self.transform_patch = transform_patch or self._AlbuNoOp()
        self.transform_mask  = transform_mask or self._AlbuNoOp()
        self.joint_tf        = joint_tf

        self.train_mode = bool(train_mode)
        self.fixed_patches_train = fixed_patches_train
        self.fixed_patches_val   = fixed_patches_val

        # remember seed; rng will be reseeded per-epoch for TRAIN only
        self.seed = int(seed)
        self.rng = np.random.default_rng(self.seed)

        # Parallel per-image buffers, indexed like self.image_ids.
        self.image_ids = []
        self.pos_lists = []  # all positive patches (any class)
        self.neg_lists = []  # patches whose label vector sums to zero
        self.cls_lists = []  # per image: {class name -> positive patch paths}

        for image_id, g in patch_df.groupby("image_id"):
            # Optional spatial ordering; skipped when the frame has no center_y column.
            g = g.sort_values(by=["center_y", "center_x"], ascending=True) if "center_y" in g.columns else g

            pos_paths, neg_paths = [], []
            per_class = {name: [] for name in self._CLASS_NAMES}

            for pth, lv in zip(g["file_path"].tolist(), g["label_vector"].tolist()):
                full = pth if os.path.isabs(pth) else os.path.join(self.patches_root, pth)
                if sum(lv) > 0:
                    pos_paths.append(full)
                    # A patch may be listed under several classes.
                    for name, flag in zip(self._CLASS_NAMES, lv):
                        if flag:
                            per_class[name].append(full)
                else:
                    neg_paths.append(full)

            self.image_ids.append(image_id)
            self.pos_lists.append(pos_paths)
            self.neg_lists.append(neg_paths)
            self.cls_lists.append(per_class)

        self.mask_subdirs = {
            "mask_MA": "Microaneurysms_Masks",
            "mask_HE": "Hemorrhage_Masks",
            "mask_EX": "HardExudate_Masks",
            "mask_SE": "SoftExudate_Masks",
        }

        # Eval-only cache of the fixed patch subset chosen for each image id.
        self.val_cache = {} if (not self.train_mode and self.fixed_patches_val is not None) else None

    def __len__(self):
        return len(self.image_ids)

    @property
    def fixed_patches(self):
        """Patch budget for the active mode (train or eval); may be None."""
        return self.fixed_patches_train if self.train_mode else self.fixed_patches_val

    def set_epoch(self, epoch: int):
        """Reseed RNG for training; validation stays deterministic & cached."""
        if self.train_mode:
            self.rng = np.random.default_rng(self.seed + int(epoch))

    def _stable_seed(self, image_id):
        """Return a 32-bit seed derived from the image id and the base seed.

        CRC32 is used instead of the builtin ``hash`` because string hashes are
        salted per interpreter, which would change the eval subset between runs.
        """
        import zlib
        return zlib.crc32(f"{image_id}|{self.seed}".encode("utf-8")) & 0xFFFFFFFF

    def _choose_patches_unbiased(self, paths):
        """Trim or pad ``paths`` towards the active budget without class bias.

        With no budget, or an exact match, the list is returned unchanged. A
        longer list is subsampled without replacement. A shorter, non-empty
        list is extended by at most one extra copy of its own head, so the
        result can still be shorter than the budget.
        """
        budget = self.fixed_patches
        if (budget is None) or (len(paths) == budget):
            return paths
        if len(paths) > budget:
            idx = self.rng.choice(len(paths), size=budget, replace=False)
            return [paths[i] for i in idx]
        pad = budget - len(paths)
        return paths + paths[:min(pad, len(paths))] if len(paths) > 0 else paths

    def _sample_without_replacement(self, arr, k):
        """Return up to ``k`` distinct elements of ``arr`` (empty list if none requested)."""
        if k <= 0 or len(arr) == 0:
            return []
        k = min(k, len(arr))
        idx = self.rng.choice(len(arr), size=k, replace=False)
        return [arr[i] for i in idx]

    def _sample_with_replacement(self, arr, k):
        """Return ``k`` elements of ``arr`` drawn with replacement (empty list if none requested)."""
        if k <= 0 or len(arr) == 0:
            return []
        idx = self.rng.integers(0, len(arr), size=k)
        return [arr[i] for i in idx]

    def _choose_patches_biased(self, idx, k, pos_ratio, class_quota):
        """Select up to ``k`` patch paths for image ``idx`` with a lesion-biased mix.

        ``pos_ratio`` of the budget is reserved for positive patches and
        divided across classes according to ``class_quota`` (largest-remainder
        rounding). Class shortfalls are topped up with replacement from the
        same class, the rest is filled with negatives and finally with
        positives drawn with replacement. ``k is None`` returns every patch.
        """
        pos = self.pos_lists[idx]
        neg = self.neg_lists[idx]
        cls = self.cls_lists[idx]

        if k is None:
            return pos + neg

        k_pos = int(round(k * float(pos_ratio)))

        names = list(self._CLASS_NAMES)
        q = np.array(class_quota, dtype=np.float64)
        q = np.clip(q, 0.0, 1.0)
        if q.sum() <= 0:
            q = np.array([0.35, 0.25, 0.25, 0.15], dtype=np.float64)  # fallback
        q = q / q.sum()

        # Largest-remainder apportionment of k_pos across the classes.
        desired = q * k_pos
        base = np.floor(desired).astype(int)
        rem = k_pos - base.sum()

        frac_order = np.argsort(-(desired - base))
        for i in range(rem):
            base[frac_order[i]] += 1
        want_per_class = dict(zip(names, base.tolist()))

        # Pass 1: distinct patches per class, as many as the class can supply.
        sel = []
        for name in names:
            need = want_per_class[name]
            have = len(cls[name])
            take = min(need, have)
            if take > 0:
                sel += self._sample_without_replacement(cls[name], take)
            want_per_class[name] = need - take  # residual need

        # Pass 2: cover residual class demand with repeats from the same class.
        for name in names:
            residual = want_per_class[name]
            if residual > 0 and len(cls[name]) > 0:
                sel += self._sample_with_replacement(cls[name], residual)

        # Pass 3: fill what is left with negatives.
        need_neg = max(0, k - len(sel))
        if need_neg > 0 and len(neg) > 0:
            sel += self._sample_without_replacement(neg, need_neg)

        # Pass 4: still short -> repeat positives.
        need_final = max(0, k - len(sel))
        if need_final > 0 and len(pos) > 0:
            sel += self._sample_with_replacement(pos, need_final)

        return sel[:k]

    @staticmethod
    def _to_chw_float(arr, channels, what):
        """Convert a transform output to a float tensor with channels first.

        Tensors are accepted channels-first or channels-last (decided by which
        end has ``channels`` entries); anything else raises ``ValueError``.
        NumPy arrays are assumed to be channels-last.
        """
        if isinstance(arr, torch.Tensor):
            if arr.ndim == 3 and arr.shape[0] == channels:
                return arr.contiguous().float()
            if arr.ndim == 3 and arr.shape[-1] == channels:
                return arr.permute(2, 0, 1).contiguous().float()
            raise ValueError(f"Unexpected {what} shape {tuple(arr.shape)}")
        return torch.from_numpy(arr.transpose(2, 0, 1)).float()

    def _load_mask_stack(self, image_id, height, width):
        """Load the four lesion masks of an image as one (H, W, 4) uint8 array.

        Pixels above 127 become 1; a missing mask file yields an all-zero channel.
        """
        channels = []
        for key in ("mask_MA", "mask_HE", "mask_EX", "mask_SE"):
            mask_path = os.path.join(self.mask_dir, self.mask_subdirs[key], f"{image_id}.png")
            if os.path.isfile(mask_path):
                with Image.open(mask_path) as m:
                    m_np = np.array(m.convert("L"), dtype=np.uint8)
                m_np = (m_np > 127).astype(np.uint8)
            else:
                m_np = np.zeros((height, width), dtype=np.uint8)
            channels.append(m_np)
        return np.stack(channels, axis=-1)

    def _select_eval_paths(self, idx, image_id):
        """Return the eval-mode patch paths for one image (cached when a budget is set)."""
        k = self.fixed_patches_val
        all_paths = self.pos_lists[idx] + self.neg_lists[idx]
        if k is None:
            return all_paths
        if self.val_cache is None:
            return self._sample_without_replacement(all_paths, k)

        if image_id not in self.val_cache:
            if len(all_paths) <= k:
                chosen = all_paths
            else:
                # Per-image RNG so the subset does not depend on access order.
                rng_local = np.random.default_rng(self._stable_seed(image_id))
                picks = rng_local.choice(len(all_paths), size=k, replace=False)
                chosen = [all_paths[i] for i in picks]
            self.val_cache[image_id] = chosen
        return self.val_cache[image_id]

    def _load_patches(self, paths):
        """Load, transform and stack the patches at ``paths``.

        Paths that are not files are skipped; if nothing could be loaded a
        single all-zero patch is returned so the stack is never empty.
        """
        tensors = []
        for p in paths:
            if not os.path.isfile(p):
                continue
            with Image.open(p) as im:
                p_np = np.array(im.convert("RGB"))
            out_p = self.transform_patch(image=p_np)
            tensors.append(self._to_chw_float(out_p.get("image", p_np), 3, "patch"))

        if len(tensors) == 0:
            P = int(config.get("PATCH_SIZE", config["IMG_SIZE"]))
            tensors = [torch.zeros(3, P, P, dtype=torch.float32)]

        return torch.stack(tensors, dim=0)  # (N,3,P,P)

    def __getitem__(self, idx):
        image_id = self.image_ids[idx]
        img_path = os.path.join(self.image_dir, f"{image_id}.png")
        with Image.open(img_path) as img_pil:
            img_np = np.array(img_pil.convert("RGB"))  # (H,W,3) uint8

        H, W = img_np.shape[:2]
        mask_hwc = self._load_mask_stack(image_id, H, W)  # (H,W,4)

        # joint_tf (if given) is mandatory-keyed; transform_img may be a no-op.
        if self.joint_tf is not None:
            out = self.joint_tf(image=img_np, mask=mask_hwc)
            img_tf = out["image"]
            mask_tf = out["mask"]
        else:
            out = self.transform_img(image=img_np, mask=mask_hwc)
            img_tf = out.get("image", img_np)
            mask_tf = out.get("mask", mask_hwc)

        image_tensor = self._to_chw_float(img_tf, 3, "image")
        mask_tensor = self._to_chw_float(mask_tf, 4, "mask").clamp(0, 1)

        if self.train_mode:
            k = self.fixed_patches_train  # e.g., 48
            pos_ratio = float(config.get("PATCH_POS_RATIO", 0.7))
            class_quota = config.get("PATCH_CLASS_QUOTA", [0.35, 0.25, 0.25, 0.15])
            paths = self._choose_patches_biased(idx, k, pos_ratio, class_quota)
        else:
            paths = self._select_eval_paths(idx, image_id)

        patch_tensor = self._load_patches(paths)
        return image_tensor, patch_tensor, mask_tensor

## Utility Functions `TRANSFORMS`

In [None]:
def green_clahe(img, clip=2.0, grid=(8,8), mode="blend", alpha=0.75):
    """Apply CLAHE to channel 1 (green) of ``img`` and merge the result back.

    Args:
        img: image array indexed as ``img[..., channel]``; presumably (H, W, 3) uint8.
        clip: CLAHE clip limit.
        grid: CLAHE tile grid size.
        mode: ``"replace"`` swaps only the green channel, ``"replicate"``
            returns the equalised green channel copied into three channels,
            ``"blend"`` mixes it into channels 0 and 2 with weight ``alpha``
            and uses it directly as channel 1.
        alpha: blend weight of the equalised channel (``"blend"`` only).

    Raises:
        ValueError: for an unknown ``mode``.
    """
    equalizer = cv2.createCLAHE(clipLimit=float(clip), tileGridSize=tuple(grid))
    green_eq = equalizer.apply(img[..., 1])

    if mode == "replace":
        result = img.copy()
        result[..., 1] = green_eq
        return result

    if mode == "replicate":
        return np.stack([green_eq] * 3, axis=-1)

    if mode == "blend":
        mixed = img.astype(np.float32)
        green_f = green_eq.astype(np.float32)
        for ch in (0, 2):
            mixed[..., ch] = (1 - alpha) * mixed[..., ch] + alpha * green_f
        mixed[..., 1] = green_f
        # Back to 8-bit after the float blend.
        return np.clip(mixed, 0, 255).astype(np.uint8)

    raise ValueError("mode must be one of {'replace','replicate','blend'}")

# A.Lambda isn't compatible with multiprocessing, so these module-level wrappers are used instead.

def green_clahe_transform(x, **kwargs):
    """Run ``green_clahe`` on ``x`` with fixed settings and the global ``config_mode``.

    Extra keyword arguments are accepted and ignored. ``config_mode`` is looked
    up when the function is called, so it must be defined by then.
    """
    clahe_args = {"clip": 2.0, "grid": (8, 8), "mode": config_mode, "alpha": 0.75}
    return green_clahe(x, **clahe_args)

def green_clahe_torchvision(pil_img):
    """PIL-in / PIL-out wrapper around ``green_clahe`` using the global ``config_mode``."""
    enhanced = green_clahe(
        np.array(pil_img),
        clip=2.0,
        grid=(8, 8),
        mode=config_mode,
        alpha=0.75,
    )
    return Image.fromarray(enhanced)

def identity_mask(x, **kwargs):
    """Return ``x`` untouched; any keyword arguments are ignored."""
    unchanged = x
    return unchanged

class GreenCLAHE(ImageOnlyTransform):
    """Albumentations image-only transform wrapping ``green_clahe``.

    Args:
        clip: CLAHE clip limit.
        grid: CLAHE tile grid size.
        mode: one of ``"replace"``, ``"replicate"``, ``"blend"``.
        alpha: blend weight used in ``"blend"`` mode.
        always_apply: kept for backward compatibility; when true the transform
            is given probability 1.0.
        p: probability of applying the transform.
    """

    def __init__(self, clip=2.0, grid=(8,8), mode="blend", alpha=0.75, always_apply=False, p=1.0):
        # Pass `p` by keyword only: the base constructor's positional order and
        # its `always_apply` argument are not stable across Albumentations
        # releases, so the former positional call could swap the two values or fail.
        super().__init__(p=1.0 if always_apply else p)
        self.clip = clip
        self.grid = grid
        self.mode = mode
        self.alpha = alpha

    def apply(self, img, **params):
        """Return ``img`` enhanced with the configured green-channel CLAHE."""
        return green_clahe(img, clip=self.clip, grid=self.grid, mode=self.mode, alpha=self.alpha)

## Transforms

In [None]:
# Global read at call time by green_clahe_transform / green_clahe_torchvision.
config_mode = config["MODE"]

P = int(config["PATCH_SIZE"])
IMG_MEAN = IMAGENET_MEAN
IMG_STD  = IMAGENET_STD
IMG_H = IMG_W = config["IMG_SIZE"]

# Patch pipeline: resize to P x P, green-channel CLAHE, random flip, normalise, to tensor.
# NOTE(review): PatchSegDataset applies its transform_patch in both train and eval mode,
# so if this pipeline is also used for validation the patches get a random flip there too,
# and patch flips are independent of the full-image flip — confirm this is intended.
transform_patch = A.Compose([
    A.Resize(P, P, interpolation=cv2.INTER_LINEAR),
    GreenCLAHE(clip=2.0, grid=(8, 8), mode=config["MODE"], alpha=0.75, p=1.0),
    A.HorizontalFlip(p=0.5),
    A.Normalize(mean=IMG_MEAN, std=IMG_STD),
    ToTensorV2(),
])

# Training pipeline for the full image and its mask: pad/resize to IMG_SIZE,
# light geometric augmentation, CLAHE, normalise, to tensor (mask transposed too).
albu_train = A.Compose([
    A.PadIfNeeded(min_height=IMG_H, min_width=IMG_W, border_mode=cv2.BORDER_CONSTANT),
    A.Resize(IMG_H, IMG_W, interpolation=cv2.INTER_LINEAR),
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.1),
    A.Rotate(limit=10, border_mode=cv2.BORDER_CONSTANT, fill=0),
    GreenCLAHE(clip=2.0, grid=(8,8), mode=config["MODE"], alpha=0.75, p=1.0),
    A.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ToTensorV2(transpose_mask=True),
])

# Evaluation pipeline: same as training minus the random flips and rotation.
albu_val = A.Compose([
    A.PadIfNeeded(min_height=IMG_H, min_width=IMG_W, border_mode=cv2.BORDER_CONSTANT),
    A.Resize(IMG_H, IMG_W, interpolation=cv2.INTER_LINEAR),
    GreenCLAHE(clip=2.0, grid=(8,8), mode=config["MODE"], alpha=0.75, p=1.0),
    A.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ToTensorV2(transpose_mask=True),
])

---
## Model

#### ResNet stride OS8

In [None]:
def resnet34_os8(backbone: nn.Module) -> nn.Module:
    """Remove the stride from the first block of ``layer3`` and ``layer4`` in place.

    For each of the two blocks, ``conv1`` and (when present) the first module
    of ``downsample`` get stride ``(1, 1)``. Dilation is left untouched. The
    same ``backbone`` object is returned.
    """
    unit_stride = (1, 1)
    first_blocks = [stage[0] for stage in (backbone.layer3, backbone.layer4)]
    for block in first_blocks:
        if hasattr(block.conv1, "stride"):
            block.conv1.stride = unit_stride
        shortcut = block.downsample
        if shortcut is not None and hasattr(shortcut[0], "stride"):
            shortcut[0].stride = unit_stride
    return backbone

#### Blocks | ResDoubleConv and `conv_gn_act`

In [None]:
class ResDoubleConv(nn.Module):
    """Residual block: two 3x3 conv + GroupNorm layers with a skip connection.

    The skip path is an identity when ``in_ch == out_ch`` and a bias-free 1x1
    convolution otherwise. SiLU follows the first norm and the residual sum;
    ``Dropout2d(drop_p)`` closes the main path when ``drop_p > 0``.
    """

    def __init__(self, in_ch, out_ch, groups=32, drop_p=0.0):
        super().__init__()
        if in_ch == out_ch:
            self.proj = nn.Identity()
        else:
            self.proj = nn.Conv2d(in_ch, out_ch, 1, bias=False)

        # GroupNorm cannot use more groups than channels.
        n_groups = min(groups, out_ch)
        layers = [
            nn.Conv2d(in_ch, out_ch, 3, 1, 1, bias=False),
            nn.GroupNorm(num_groups=n_groups, num_channels=out_ch),
            nn.SiLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, 1, 1, bias=False),
            nn.GroupNorm(num_groups=n_groups, num_channels=out_ch),
            nn.Dropout2d(drop_p) if drop_p > 0 else nn.Identity(),
        ]
        self.block = nn.Sequential(*layers)
        self.act = nn.SiLU(inplace=True)

    def forward(self, x):
        main = self.block(x)
        skip = self.proj(x)
        return self.act(main + skip)

def conv_gn_act(in_ch, out_ch, k=3, s=1, p=1, groups=32):
    """Build a bias-free Conv2d -> GroupNorm -> SiLU stack.

    ``k``, ``s`` and ``p`` are the convolution's kernel size, stride and
    padding; the group count is capped at ``out_ch``.
    """
    conv = nn.Conv2d(in_ch, out_ch, k, s, p, bias=False)
    norm = nn.GroupNorm(num_groups=min(groups, out_ch), num_channels=out_ch)
    return nn.Sequential(conv, norm, nn.SiLU(inplace=True))

#### Image Backbone OS8

In [None]:
class ImgBackboneOS8(nn.Module):
    """ResNet-34 feature extractor patched by ``resnet34_os8``.

    ``forward`` returns the outputs of the four residual stages as a tuple
    ``(c1, c2, c3, c4)``. With ``pretrained=True`` the ImageNet-1k V1 weights
    are requested from torchvision.
    """

    def __init__(self, pretrained=True):
        super().__init__()
        weights = None
        if pretrained:
            weights = models.ResNet34_Weights.IMAGENET1K_V1
        resnet = resnet34_os8(models.resnet34(weights=weights))

        # Stem: conv + bn + relu + maxpool (overall stride 4).
        self.stem = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool)
        self.layer1 = resnet.layer1   # 64 channels, stride 4
        self.layer2 = resnet.layer2   # 128 channels, stride 8
        self.layer3 = resnet.layer3   # 256 channels, stride kept at 8
        self.layer4 = resnet.layer4   # 512 channels, stride kept at 8

    def forward(self, x):
        feats = []
        h = self.stem(x)              # B, 64, H/4, W/4
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            h = stage(h)
            feats.append(h)
        return tuple(feats)

#### CrossPatchAttention

In [None]:
class CrossPatchAttention(nn.Module):
    """Cross-attention from image feature locations (queries) to patch vectors (keys/values).

    Args:
        img_ch: channels of the image feature map.
        patch_ch: feature size of each patch vector.
        hid: attention embedding size, which is also the output channel count.
        heads: number of attention heads.
        dropout: attention dropout probability.
    """

    def __init__(self, img_ch=128, patch_ch=512, hid=256, heads=4, dropout=0.0):
        super().__init__()
        # All projections are bias-free.
        self.q_proj = nn.Conv2d(img_ch, hid, 1, bias=False)
        self.k_proj = nn.Linear(patch_ch, hid, bias=False)
        self.v_proj = nn.Linear(patch_ch, hid, bias=False)
        self.attn = nn.MultiheadAttention(
            embed_dim=hid,
            num_heads=heads,
            batch_first=True,
            dropout=dropout,
        )
        self.out = nn.Conv2d(hid, hid, 1, bias=False)

    def forward(self, img_feat, patch_vecs):
        """Return a ``(B, hid, H, W)`` context map for ``img_feat`` of shape ``(B, img_ch, H, W)``."""
        batch, _, height, width = img_feat.shape
        # One query token per spatial location: (B, H*W, hid).
        queries = self.q_proj(img_feat).flatten(2).transpose(1, 2)
        keys = self.k_proj(patch_vecs)      # (B, N, hid)
        values = self.v_proj(patch_vecs)    # (B, N, hid)
        attended, _ = self.attn(queries, keys, values)
        # Fold the token axis back into the spatial grid.
        ctx_map = attended.transpose(1, 2).reshape(batch, -1, height, width)
        return self.out(ctx_map)

#### FPNDecoder

In [None]:
class FPNDecoder(nn.Module):
    """Top-down FPN-style decoder producing one main and two auxiliary logit maps.

    Four encoder feature maps (c1..c4) are reduced by 1x1 lateral convs and
    merged coarse-to-fine (p3 -> p2 -> p1). Merging is element-wise addition,
    or channel concatenation when ``fusion == "concat"``. An optional context
    map can be concatenated into the p2 stage.

    Args:
        c1, c2, c3, c4: channel counts of the four encoder feature maps.
        ctx_ch: channels of the optional context map fed to the p2 stage;
            a falsy value sizes p2 without context channels.
        out_ch: number of output logit channels.
        C: decoder width (p1 and the head use ``C // 2`` / ``C // 4``).
        fusion: ``"concat"`` to concatenate lateral and upsampled features,
            anything else adds them.
        blocks: ``"res"`` selects ``ResDoubleConv``; anything else selects
            ``conv_gn_act`` (which here ignores ``drop_p``).
        drop_p: dropout probability for the ``Dropout2d`` after each stage
            (also forwarded to ``ResDoubleConv``).
        groups: forwarded as ``groups`` to ``conv_gn_act``.
    """
    def __init__(self, c1=64, c2=128, c3=256, c4=512, ctx_ch=256, out_ch=4,
                 C=256, fusion="add", blocks="basic", drop_p=0.1, groups=32):
        super().__init__()
        self.fusion = fusion
        # Block factory with a common (in_ch, out_ch, **kw) call shape; the
        # "basic" variant swallows extra keywords such as drop_p.
        Block = ResDoubleConv if blocks == "res" else lambda ic, oc, **kw: conv_gn_act(ic, oc, k=3, s=1, p=1, groups=groups)

        # lateral reducers (1x1 convs bringing each encoder stage to decoder width)
        self.l4 = nn.Conv2d(c4, C, 1, bias=False)
        self.l3 = nn.Conv2d(c3, C, 1, bias=False)
        self.l2 = nn.Conv2d(c2, C, 1, bias=False)
        self.l1 = nn.Conv2d(c1, C // 2, 1, bias=False)

        # p3 stage: input is 2*C channels when concatenating, C when adding
        in_p3 = C + C if fusion == "concat" else C
        self.p3 = nn.Sequential(Block(in_p3, C, drop_p=drop_p), nn.Dropout2d(drop_p))

        # p2 stage (+ optional ctx)
        # NOTE(review): when ctx_ch is truthy the p2 block is sized for the
        # extra context channels, so forward() must then receive ctx_c2 with
        # exactly ctx_ch channels — confirm every caller passes it.
        in_p2 = (C + C) if fusion == "concat" else C
        if ctx_ch:
            in_p2 += ctx_ch
        self.p2 = nn.Sequential(Block(in_p2, C, drop_p=drop_p), nn.Dropout2d(drop_p))

        # p1 stage: p2 is reduced to C // 2 so it can be added to the c1 lateral
        self.p2_reduce_to_p1 = nn.Conv2d(C, C // 2, 1, bias=False)
        self.p1 = nn.Sequential(Block(C // 2, C // 2, drop_p=drop_p), nn.Dropout2d(drop_p))

        # heads
        self.head = nn.Sequential(conv_gn_act(C // 2, C // 4, 3, 1, 1, groups=groups),
                                  nn.Conv2d(C // 4, out_ch, 1))
        # Alias of the final 1x1 conv; the same module object is registered
        # under both `head` and `classifier`.
        self.classifier = self.head[-1]  # convenience
        self.aux2 = nn.Conv2d(C, out_ch, 1)
        self.aux3 = nn.Conv2d(C, out_ch, 1)

    def forward(self, c1, c2, c3, c4, ctx_c2=None):
        """Decode encoder features into ``(main, aux2, aux3)`` logit maps.

        Args:
            c1, c2, c3, c4: encoder feature maps, finest to coarsest.
            ctx_c2: optional context map concatenated to the p2 input; it is
                concatenated on dim 1, so it must share p2's spatial size.

        Returns:
            Tuple ``(main, aux2, aux3)``: main is the head applied to p1
            upsampled x4; aux2 / aux3 are 1x1 projections of p2 / p3
            upsampled x4 / x8 respectively.
        """
        # coarsest level: reduce c4 and bring it to c3's spatial size
        p4 = self.l4(c4)
        up4 = F.interpolate(p4, size=c3.shape[-2:], mode="bilinear", align_corners=False)

        l3 = self.l3(c3)
        p3_in = torch.cat([l3, up4], dim=1) if self.fusion == "concat" else (l3 + up4)
        p3 = self.p3(p3_in)

        # p2 lives at c2's spatial size
        up3 = F.interpolate(p3, size=c2.shape[-2:], mode="bilinear", align_corners=False)
        l2 = self.l2(c2)
        p2_in = torch.cat([l2, up3], dim=1) if self.fusion == "concat" else (l2 + up3)
        if ctx_c2 is not None:
            p2_in = torch.cat([p2_in, ctx_c2], dim=1)
        p2 = self.p2(p2_in)

        # p1 lives at c1's spatial size; fusion here is always additive
        p1 = self.l1(c1) + self.p2_reduce_to_p1(F.interpolate(p2, size=c1.shape[-2:], mode="bilinear", align_corners=False))
        p1 = self.p1(p1)

        main = self.head(F.interpolate(p1, scale_factor=4, mode="bilinear", align_corners=False))
        # NOTE(review): p2 has c2's resolution; if c2 is at 1/8 scale (as the
        # backbone annotations suggest) a x4 upsample yields a half-resolution
        # aux2 map, which callers would need to resize — confirm intended.
        aux2 = F.interpolate(self.aux2(p2), scale_factor=4, mode="bilinear", align_corners=False)
        aux3 = F.interpolate(self.aux3(p3), scale_factor=8, mode="bilinear", align_corners=False)
        return main, aux2, aux3

#### PatchSegNet

In [None]:
class PatchSegNet(nn.Module):
    """Segmentation network fusing a full-image encoder with per-patch vectors.

    The image goes through ``ImgBackboneOS8``; each patch is encoded by a
    second ResNet-34 trunk and pooled to a 512-d vector. Image features at
    the c2 level cross-attend over the patch vectors and the resulting
    context map is fed to the FPN decoder. Optionally (``use_global_film``)
    a global patch summary modulates c4 FiLM-style.

    Args:
        pretrained: load ImageNet weights for both ResNet-34 trunks.
        hidden: width of the attention / context features.
        num_classes: number of output logit channels.
        patch_chunk: how many patches (per sample) are encoded per forward
            pass of the patch encoder, to bound peak memory.
        heads: number of attention heads.
        use_global_film: enable FiLM modulation of c4 (overridden by
            ``config["USE_GLOBAL_FILM"]`` when present).
        config: optional dict of decoder knobs (``DEC_C``, ``DEC_FUSION``,
            ``DEC_BLOCK``, ``DEC_DROPOUT``, ``DEC_GN_GROUPS``).
    """

    def __init__(self, pretrained=True, hidden=256, num_classes=4, patch_chunk=25, heads=4,
                 use_global_film=False, config=None):
        super().__init__()
        cfg = {} if config is None else dict(config)
        self.patch_chunk = int(patch_chunk)
        self.hidden = int(hidden)
        # Exposed so helpers that probe `model.num_classes` see the real value.
        self.num_classes = int(num_classes)
        self.use_global_film = bool(cfg.get("USE_GLOBAL_FILM", use_global_film))

        # encoders
        self.img_enc   = ImgBackboneOS8(pretrained=pretrained)
        w = models.ResNet34_Weights.IMAGENET1K_V1 if pretrained else None
        self.patch_enc = nn.Sequential(*list(resnet34_os8(models.resnet34(weights=w)).children())[:-2])  # 512-ch

        self.patch_fc = nn.Linear(512, hidden)
        self.xattn = CrossPatchAttention(img_ch=128, patch_ch=512, hid=hidden, heads=heads)

        if self.use_global_film:
            self.img_pool = nn.AdaptiveAvgPool2d(1)
            self.img_fc   = nn.Linear(512, hidden)
            self.global_attn = nn.MultiheadAttention(embed_dim=hidden, num_heads=heads, batch_first=True)
            self.gamma = nn.Linear(hidden, 512)
            self.beta  = nn.Linear(hidden, 512)
            # Zero-init so FiLM starts as the identity (gamma=1, beta=0).
            # Done here rather than lazily in forward(): a lazy init would
            # wipe weights restored by load_state_dict() before the first call.
            for film_layer in (self.gamma, self.beta):
                nn.init.zeros_(film_layer.weight)
                nn.init.zeros_(film_layer.bias)

        # decoder with knobs
        self.decoder = FPNDecoder(
            c1=64, c2=128, c3=256, c4=512, ctx_ch=hidden, out_ch=num_classes,
            C=cfg.get("DEC_C", 256),
            fusion=cfg.get("DEC_FUSION", "add"),        # "add" or "concat"
            blocks=cfg.get("DEC_BLOCK", "basic"),       # "basic" or "res"
            drop_p=cfg.get("DEC_DROPOUT", 0.10),
            groups=cfg.get("DEC_GN_GROUPS", 32),
        )

    def encode_patches_chunked(self, patches):
        """Encode ``patches`` (B, N, C, H, W) into per-patch vectors (B, N, D).

        Patches are processed ``patch_chunk`` at a time along N. Each chunk's
        vectors are reshaped back to (B, n_chunk, D) *before* concatenating
        along the patch axis; concatenating the flat (B*n_chunk, D) rows and
        reshaping afterwards would interleave samples whenever B > 1 and more
        than one chunk is needed.
        """
        B, N, C, H, W = patches.shape
        chunks = []
        for s in range(0, N, self.patch_chunk):
            e = min(N, s + self.patch_chunk)
            pfm  = self.patch_enc(patches[:, s:e].reshape(-1, C, H, W))
            pvec = F.adaptive_avg_pool2d(pfm, 1).flatten(1)      # (B*(e-s), D)
            chunks.append(pvec.reshape(B, e - s, -1))            # (B, e-s, D)
        return torch.cat(chunks, dim=1)                          # (B, N, D)

    def forward(self, image, patches):
        """Return ``(main, aux2, aux3)`` logits, each resized to ``image``'s H x W."""
        c1, c2, c3, c4 = self.img_enc(image)
        pvec512 = self.encode_patches_chunked(patches).float()

        ctx_c2 = self.xattn(c2, pvec512)  # context map at c2 resolution

        if self.use_global_film:
            # The projected patch vectors are only needed on the FiLM path.
            pvec_h = self.patch_fc(pvec512)
            g = self.img_pool(c4).flatten(1)
            qi = self.img_fc(g).unsqueeze(1)
            ctx_vec, _ = self.global_attn(qi, pvec_h, pvec_h)
            ctx_vec = ctx_vec.squeeze(1)
            # Bounded scale in (0.5, 1.5) and an unbounded half-weighted shift.
            gamma = 1 + 0.5 * torch.tanh(self.gamma(ctx_vec)).unsqueeze(-1).unsqueeze(-1)
            beta  = 0.5 * self.beta(ctx_vec).unsqueeze(-1).unsqueeze(-1)
            c4 = gamma * c4 + beta

        main, aux2, aux3 = self.decoder(c1, c2, c3, c4, ctx_c2=ctx_c2)

        def _resize(x):  # safety: decoder outputs may not match the input size
            return x if x.shape[-2:] == image.shape[-2:] else F.interpolate(x, size=image.shape[-2:], mode="bilinear", align_corners=False)
        return tuple(_resize(x) for x in (main, aux2, aux3))

## Loss Functions

#### Utility Functions `LOSS`

In [None]:
class AlignLoss(torch.nn.Module):
    """Wrap a base loss so logits/targets are auto-aligned to (B,C,H,W).

    Args:
        base_loss: callable ``(logits, targets) -> loss`` applied after
            alignment via ``_align_logits_targets``.
    """

    def __init__(self, base_loss):
        super().__init__()
        self.base_loss = base_loss

    def forward(self, logits, targets):
        """Align ``logits``/``targets`` and return ``base_loss`` on the result."""
        # No per-call shape printing here: this runs on every training step.
        logits, targets = _align_logits_targets(logits, targets)
        return self.base_loss(logits, targets)

def build_pos_weight(cov_ratios, cap=20.0, device=None):
    """Turn per-class positive coverage ratios into BCE ``pos_weight`` values.

    Each weight is ``(1 - r) / r`` (negatives per positive), limited to the
    range ``[1, cap]``. Ratios are floored at 1e-8 to avoid division by zero.

    Returns:
        A float32 tensor with one weight per class, moved to ``device`` when given.
    """
    floor = 1e-8
    safe_ratios = []
    for ratio in cov_ratios:
        safe_ratios.append(max(floor, float(ratio)))
    r = torch.tensor(safe_ratios, dtype=torch.float32)

    weights = torch.clamp((1.0 - r) / r, min=1.0, max=float(cap))
    return weights if device is None else weights.to(device)


def build_class_weights(cov_ratios, power=1.0, normalize_to_mean=True, cap=10.0, device=None):
    """Build inverse-frequency class weights from per-class coverage ratios.

    Each weight is ``1 / r**power`` capped at ``cap`` (ratios floored at
    1e-8). With ``normalize_to_mean`` the weights are rescaled to mean 1.

    Returns:
        A float32 tensor with one weight per class, moved to ``device`` when given.
    """
    floor = 1e-8
    safe_ratios = []
    for ratio in cov_ratios:
        safe_ratios.append(max(floor, float(ratio)))
    r = torch.tensor(safe_ratios, dtype=torch.float32)

    weights = torch.clamp(1.0 / (r ** power), max=float(cap))
    if normalize_to_mean:
        mean_weight = weights.mean().clamp_min(1e-8)
        weights = weights / mean_weight
    return weights if device is None else weights.to(device)


@torch.no_grad()
def compute_cov_ratios(dataset, thr=0.5):
    """Return, per class, the fraction of mask pixels that are positive.

    ``dataset[i]`` must yield a 3-tuple whose last element is a mask indexed
    as (class, height, width). Masks are binarized with ``>= thr`` before
    counting. The result is a plain list of floats, one per class.
    """
    # Probe the first item only to learn the number of classes.
    _img, _patches, probe = dataset[0]
    if not torch.is_tensor(probe):
        probe = torch.as_tensor(probe)
    num_classes = int(probe.shape[0])

    positive_px = torch.zeros(num_classes, dtype=torch.float64)
    total_px = torch.zeros(num_classes, dtype=torch.float64)

    indices = range(len(dataset))
    for idx in tqdm(indices, desc="Computing coverage ratios", leave=False):
        _img, _patches, mask = dataset[idx]
        if not torch.is_tensor(mask):
            mask = torch.as_tensor(mask)
        binary = (mask >= thr).float()  # {0,1}
        positive_px += binary.sum(dim=(1, 2))
        # every class channel contributes the same H*W pixel count
        total_px += float(binary.shape[1] * binary.shape[2])

    return (positive_px / total_px).to(torch.float32).tolist()


def compute_seg_loss(logits_tuple, target, loss_main, loss_aux=None, weights=(1.0, 0.4, 0.2)):
    """Combine the main and two auxiliary segmentation losses.

    Args:
        logits_tuple: either a single logits tensor or a 3-sequence
            ``(main, aux2, aux3)``.
        target: ground-truth passed unchanged to every loss call.
        loss_main: loss applied to the main output.
        loss_aux: loss applied to the auxiliary outputs; falls back to
            ``loss_main`` when ``None``.
        weights: multipliers for ``(main, aux2, aux3)``.

    Returns:
        ``(total_loss, parts)`` where ``parts`` maps ``"main"``, ``"aux2"``
        and ``"aux3"`` to Python floats (all NaN for a single-tensor input).
    """
    if not isinstance(logits_tuple, (list, tuple)):
        nan = float("nan")
        return loss_main(logits_tuple, target), {"main": nan, "aux2": nan, "aux3": nan}

    main_logits, aux2_logits, aux3_logits = logits_tuple
    aux_fn = loss_main if loss_aux is None else loss_aux

    main_loss = loss_main(main_logits, target)
    aux2_loss = aux_fn(aux2_logits, target)
    aux3_loss = aux_fn(aux3_logits, target)

    w_main, w_aux2, w_aux3 = weights
    total = w_main*main_loss + w_aux2*aux2_loss + w_aux3*aux3_loss
    parts = {
        "main": main_loss.item(),
        "aux2": aux2_loss.item(),
        "aux3": aux3_loss.item(),
    }
    return total, parts


def build_class_weights_from_image_prev(prev, power=1.0, cap=4.0, normalize_to_mean=True, device=None):
    """Build inverse-prevalence class weights.

    Each weight is ``1 / p**power`` capped at ``cap`` (prevalences floored at
    1e-6). With ``normalize_to_mean`` the weights are rescaled to mean 1.

    Returns:
        A float32 tensor with one weight per class, moved to ``device`` when given.
    """
    floor = 1e-6
    safe_prev = []
    for value in prev:
        safe_prev.append(max(floor, float(value)))
    p = torch.tensor(safe_prev, dtype=torch.float32)

    weights = torch.clamp(1.0 / (p ** power), max=float(cap))
    if normalize_to_mean:
        weights = weights / weights.mean().clamp_min(1e-8)

    if device is None:
        return weights
    return weights.to(device)

#### SoftDice

In [None]:
class SoftDiceLoss(nn.Module):
    """Soft Dice loss with a squared-term denominator, averaged over classes.

    Per class, over batch and spatial dims of (B, C, H, W) inputs:
    ``dice = (2*sum(p*t) + smooth) / (sum(p^2 + t^2) + smooth + eps)``
    with ``p = sigmoid(logits)``. The loss is ``mean(1 - dice)``.

    Args:
        smooth: additive smoothing on numerator and denominator.
        eps: extra denominator stabilizer.
    """

    def __init__(self, smooth=1.0, eps=1e-7):
        super().__init__()
        self.smooth, self.eps = smooth, eps

    def forward(self, logits, targets):
        """Return the scalar Dice loss for 4-D ``logits`` and ``targets``."""
        # No per-call shape printing here: this runs on every training step.
        p = torch.sigmoid(logits).float()
        t = targets.float()
        dims = (0, 2, 3)  # sum over B,H,W per class
        num = 2 * (p * t).sum(dims) + self.smooth
        den = (p.pow(2) + t.pow(2)).sum(dims) + self.smooth
        dice = num / (den + self.eps)
        return (1.0 - dice).mean()

#### Dice with BCE

In [None]:
class DiceBCELoss(nn.Module):
    """Convex blend of BCE-with-logits and soft Dice.

    ``loss = bce_weight * BCE + (1 - bce_weight) * Dice`` computed on inputs
    aligned by ``_align_logits_targets``.
    """

    def __init__(self, bce_weight=0.5, smooth=1.0, eps=1e-7):
        super().__init__()
        self.w = bce_weight
        self.dice = SoftDiceLoss(smooth, eps)
        self.bce = nn.BCEWithLogitsLoss()

    def forward(self, logits, targets):
        """Return the blended scalar loss."""
        aligned_logits, aligned_targets = _align_logits_targets(logits, targets)
        bce_term = self.bce(aligned_logits, aligned_targets)
        dice_term = self.dice(aligned_logits, aligned_targets)
        return self.w * bce_term + (1 - self.w) * dice_term

#### Dice with BCE (weighted)

In [None]:
class DiceBCELossW(nn.Module):
    """Class-weighted blend of positive-weighted BCE and weighted Dice.

    Args:
        class_weights: 1-D tensor of per-class weights used to average both
            the per-class BCE and the per-class Dice loss.
        pos_weight: 1-D tensor of per-class ``pos_weight`` values for BCE.
        bce_w: multiplier of the BCE term.
        dice_w: multiplier of the Dice term.
    """

    def __init__(self, class_weights: torch.Tensor, pos_weight: torch.Tensor,
                 bce_w: float = 0.5, dice_w: float = 0.5):
        super().__init__()
        assert class_weights.ndim == 1
        assert pos_weight.ndim == 1
        self.register_buffer("cw", class_weights.detach().clone().float())
        self.register_buffer("pw", pos_weight.detach().clone().float())
        self.bw = float(bce_w); self.dw = float(dice_w)

    def forward(self, logits, targets):
        """Return ``bce_w * weighted_BCE + dice_w * weighted_Dice``."""
        # No per-call shape printing here: this runs on every training step
        # (the previous debug output was also mislabelled as "TverskyLoss").
        logits, targets = _align_logits_targets(logits, targets)  # (N,C,H,W)

        # pos_weight must broadcast over the channel axis of (N,C,H,W)
        pw = self.pw.to(logits.device).view(1, -1, 1, 1)
        bce_el = F.binary_cross_entropy_with_logits(
            logits, targets.float(), pos_weight=pw, reduction="none"
        )  # (N,C,H,W)

        # mean over (N,H,W) => (C,), then class-weighted average
        bce_pc = bce_el.mean(dim=(0, 2, 3))
        cw = self.cw.to(bce_pc.device)
        bce = (bce_pc * cw).sum() / cw.sum().clamp_min(1e-8)

        dice = _dice_loss_weighted(logits, targets, class_weights=self.cw)
        return self.bw * bce + self.dw * dice

def _dice_loss_weighted(logits, targets, class_weights=None, eps=1e-6):
    """Per-class soft Dice loss on (N, C, H, W) inputs, optionally class-weighted.

    Dice per class is ``(2*sum(p*t) + eps) / (sum(p) + sum(t) + eps)`` with
    ``p = sigmoid(logits)``. Without ``class_weights`` the per-class losses
    are averaged; otherwise a weighted average is returned.
    """
    probs = torch.sigmoid(logits).float()
    truth = targets.float()

    reduce_dims = (0, 2, 3)
    overlap = (probs * truth).sum(dim=reduce_dims)
    mass = probs.sum(dim=reduce_dims) + truth.sum(dim=reduce_dims)
    per_class_loss = 1.0 - (2 * overlap + eps) / (mass + eps)

    if class_weights is not None:
        weights = class_weights.to(per_class_loss.device, dtype=per_class_loss.dtype)
        return (per_class_loss * weights).sum() / weights.sum().clamp_min(1e-8)
    return per_class_loss.mean()

#### Tversky Loss

In [None]:
class TverskyLoss(nn.Module):
    """Tversky loss averaged over classes for (B, C, H, W) inputs.

    Per class: ``TI = (TP + eps) / (TP + alpha*FP + beta*FN + eps)`` using
    soft counts from ``sigmoid(logits)``; the loss is ``mean(1 - TI)``.

    Args:
        alpha: weight of soft false positives.
        beta: weight of soft false negatives.
        eps: numerical stabilizer.
    """

    def __init__(self, alpha=0.5, beta=0.5, eps=1e-7):
        super().__init__()
        self.a, self.b, self.eps = alpha, beta, eps

    def forward(self, logits, targets):
        """Return the scalar Tversky loss."""
        # No per-call shape printing here: this runs on every training step.
        p = torch.sigmoid(logits).float()
        t = targets.float()
        dims = (0, 2, 3)  # reduce over batch and spatial dims, keep classes
        tp = (p * t).sum(dims)
        fp = (p * (1 - t)).sum(dims)
        fn = ((1 - p) * t).sum(dims)
        tv = (tp + self.eps) / (tp + self.a*fp + self.b*fn + self.eps)
        return (1.0 - tv).mean()

#### Focal Tversky Loss

In [None]:
class FocalTverskyLoss(nn.Module):
    """Focal Tversky loss on (N, C, H, W) inputs.

    Per class: ``TI = TP / (TP + alpha*FP + beta*FN + eps)`` from soft
    counts, then ``FT = (1 - TI) ** gamma``.

    Args:
        alpha: weight of soft false positives.
        beta: weight of soft false negatives.
        gamma: focal exponent applied to ``1 - TI``.
        class_weights: optional per-class multipliers (tensor or sequence)
            applied to the per-class loss.
        eps: denominator stabilizer (also the clamp margin for probabilities).
        clamp_probs: clamp probabilities to ``[eps, 1 - eps]``.
        ignore_empty: average only over classes that have at least one
            positive target pixel in the batch (denominator floored at 1).
    """

    def __init__(self, alpha=0.3, beta=0.7, gamma=0.75,
                 class_weights=None, eps=1e-6,
                 clamp_probs=False, ignore_empty=False):
        super().__init__()
        self.alpha = float(alpha)
        self.beta  = float(beta)
        self.gamma = float(gamma)
        self.eps   = float(eps)
        self.clamp_probs  = bool(clamp_probs)
        self.ignore_empty = bool(ignore_empty)
        if class_weights is not None:
            cw = class_weights
            if isinstance(cw, torch.Tensor):
                cw = cw.detach().clone().to(torch.float32)
            else:
                cw = torch.tensor(cw, dtype=torch.float32)
            self.register_buffer("class_weights", cw)
        else:
            self.class_weights = None

    def forward(self, logits, targets):
        """Return the scalar focal Tversky loss, computed in float32."""
        logits, targets = _align_logits_targets(logits, targets)  # (N,C,H,W)
        # Disable autocast on the device the logits actually live on, rather
        # than relying on a module-level `device` global being defined and
        # matching the inputs.
        with amp.autocast(device_type=logits.device.type, enabled=False):
            logits  = logits.float()
            targets = targets.float()
            probs   = torch.sigmoid(logits)
            if self.clamp_probs:
                probs = torch.clamp(probs, min=self.eps, max=1.0 - self.eps)

            dims = (0, 2, 3)
            tp = (probs * targets).sum(dim=dims)
            fp = (probs * (1 - targets)).sum(dim=dims)
            fn = ((1 - probs) * targets).sum(dim=dims)

            ti = tp / (tp + self.alpha * fp + self.beta * fn + self.eps)   # (C,)
            ft = (1.0 - ti).pow(self.gamma)                                # (C,)

            if self.class_weights is not None:
                cw = self.class_weights.to(ft.device, dtype=ft.dtype)
                ft = ft * cw

            if self.ignore_empty:
                # Classes absent from the batch would otherwise contribute a
                # constant loss of ~1 regardless of the prediction.
                has_pos = (targets.sum(dim=dims) > 0).float()  # (C,)
                denom = has_pos.sum().clamp_min(1.0)
                return (ft * has_pos).sum() / denom
            return ft.mean()

#### Focal Tversky with Tiny BCE

In [None]:
class FTwithTinyBCE(nn.Module):
    """Focal Tversky loss plus an optional, lightly weighted BCE term.

    ``loss = ft_loss(logits, targets) + bce_alpha * BCE`` where the BCE term
    is only evaluated when ``bce_alpha > 0`` and uses the optional per-class
    ``pos_weight``.
    """

    def __init__(self, ft_loss: FocalTverskyLoss, bce_alpha: float = 0.0, pos_weight: torch.Tensor | None = None):
        super().__init__()
        self.ft = ft_loss
        self.bce_alpha = float(bce_alpha)
        if pos_weight is None:
            self.pos_weight = None
        else:
            self.pos_weight = pos_weight.detach().clone().float()

    def forward(self, logits, targets):
        """Return the combined scalar loss on aligned (N,C,H,W) inputs."""
        logits, targets = _align_logits_targets(logits, targets)

        total = self.ft(logits, targets)
        if not (self.bce_alpha > 0.0):
            return total

        pos_w = self.pos_weight
        if pos_w is not None:
            # shape (1,C,1,1) so it broadcasts over the channel axis
            pos_w = pos_w.to(logits.device).view(1, -1, 1, 1)
        bce_term = F.binary_cross_entropy_with_logits(
            logits, targets.float(), pos_weight=pos_w, reduction="mean"
        )
        return total + self.bce_alpha * bce_term

#### Focal Tversky Dice Loss (Hybrid)

In [None]:
class FTplusDice(nn.Module):
    """Convex blend of focal Tversky and soft Dice losses.

    ``loss = lam * FocalTversky + (1 - lam) * SoftDice``.

    Args:
        alpha, beta, gamma: forwarded to ``FocalTverskyLoss``.
        lam: weight of the focal Tversky term.
        smooth: forwarded to ``SoftDiceLoss``.
        eps: numerical stabilizer forwarded to both component losses.
    """

    def __init__(self, alpha=0.3, beta=0.7, gamma=1.0, lam=0.7, smooth=1.0, eps=1e-7):
        super().__init__()
        self.lam = lam
        # eps must be passed by keyword: the 4th positional parameter of
        # FocalTverskyLoss is `class_weights`, so a positional eps would be
        # used as a 1e-7 class weight and scale the FT term to ~0.
        self.ft = FocalTverskyLoss(alpha, beta, gamma, eps=eps)
        self.dice = SoftDiceLoss(smooth, eps)

    def forward(self, logits, targets):
        """Return the blended scalar loss."""
        # No per-call shape printing here: this runs on every training step.
        return self.lam * self.ft(logits, targets) + (1 - self.lam) * self.dice(logits, targets)

#### EMA (Exponential Moving Average)

In [None]:
class EMA:
    """Exponential moving average of a model's floating-point state.

    Keeps a shadow copy of every floating-point entry of
    ``model.state_dict()`` and can temporarily swap it into the model.

    Args:
        model: module whose state is tracked.
        decay: EMA decay; ``shadow = decay * shadow + (1 - decay) * value``.
    """

    def __init__(self, model, decay=0.999):
        self.decay = float(decay)

        with torch.no_grad():
            self.shadow = {
                k: v.detach().clone()
                for k, v in model.state_dict().items()
                if torch.is_floating_point(v)
            }
        # Snapshot of the live weights while the EMA weights are applied.
        self.backup = None

    @torch.no_grad()
    def update(self, model):
        """Fold the model's current floating-point state into the shadow."""
        msd = model.state_dict()
        for k, v in msd.items():
            if k in self.shadow and torch.is_floating_point(v):
                self.shadow[k].mul_(self.decay).add_(v.detach(), alpha=1.0 - self.decay)

    @torch.no_grad()
    def apply_to(self, model):
        """Copy the EMA weights into ``model``, backing up the live weights.

        If a backup already exists (``apply_to`` called again before
        ``restore``), it is kept: re-snapshotting would overwrite the live
        weights with EMA values and make them unrecoverable.
        """
        msd = model.state_dict()
        take_backup = self.backup is None
        if take_backup:
            self.backup = {}
        for k, v in self.shadow.items():
            if k in msd:
                if take_backup:
                    self.backup[k] = msd[k].detach().clone()
                # state_dict tensors share storage with the model, so the
                # in-place copy updates the model itself.
                msd[k].copy_(v, non_blocking=True)

    @torch.no_grad()
    def restore(self, model):
        """Put back the weights saved by ``apply_to``; no-op without a backup."""
        if self.backup is None:
            return
        msd = model.state_dict()
        for k, v in self.backup.items():
            if k in msd:
                msd[k].copy_(v, non_blocking=True)
        self.backup = None

#### Loss Factory Pattern

In [None]:
def build_loss_safe(pos_weight: torch.Tensor | None = None, class_weights: torch.Tensor | None = None):
    """Instantiate the loss selected by ``config["LOSS_TYPE"]``.

    Hyper-parameters are read from the module-level ``config`` mapping.
    ``pos_weight`` / ``class_weights`` are forwarded to the losses that
    accept them and are mandatory for ``"dice_bce_w"``.

    Raises:
        ValueError: for an unrecognised ``LOSS_TYPE``.
    """
    loss_type = config.get("LOSS_TYPE", "focal_tversky")

    def _focal_tversky():
        # shared by the "focal_tversky" and "ft_w_bce" variants
        return FocalTverskyLoss(
            alpha=config.get("TV_ALPHA", 0.3),
            beta=config.get("TV_BETA", 0.7),
            gamma=config.get("FT_GAMMA", 0.75),
            class_weights=class_weights,
            eps=1e-6,
            clamp_probs=config.get("FT_CLAMP_PROBS", True),
            ignore_empty=config.get("FT_IGNORE_EMPTY", True),
        )

    if loss_type == "dice":
        return SoftDiceLoss(smooth=1.0)

    if loss_type == "dice_bce":
        return DiceBCELoss(bce_weight=config.get("BCE_WEIGHT", 0.5))

    if loss_type == "dice_bce_w":
        assert class_weights is not None and pos_weight is not None, \
            "dice_bce_w requires class_weights and pos_weight"
        bce_weight = config.get("BCE_WEIGHT", 0.5)
        return DiceBCELossW(
            class_weights=class_weights,
            pos_weight=pos_weight,
            bce_w=bce_weight,
            dice_w=1.0 - float(bce_weight),
        )

    if loss_type == "tversky":
        return TverskyLoss(
            alpha=config.get("TV_ALPHA", 0.5),
            beta=config.get("TV_BETA", 0.5),
        )

    if loss_type == "focal_tversky":
        return _focal_tversky()

    if loss_type == "ft_plus_dice":
        return FTplusDice(
            alpha=config.get("TV_ALPHA", 0.3),
            beta=config.get("TV_BETA", 0.7),
            gamma=config.get("FT_GAMMA", 1.0),
            lam=config.get("FT_LAMBDA", 0.7),
        )

    if loss_type == "ft_w_bce":
        return FTwithTinyBCE(
            ft_loss=_focal_tversky(),
            bce_alpha=config.get("FT_BCE_ALPHA", 0.05),
            pos_weight=pos_weight,
        )

    raise ValueError(f"Unknown LOSS_TYPE: {loss_type}")

## Utility Functions `TRAINING`

In [None]:
def _has_plateaued(xs, mode, patience, min_delta):
    """Tell whether the last ``patience + 1`` values show no improvement.

    The first value of that window is the reference: for ``mode == "max"``
    a plateau means no value exceeds it by more than ``min_delta``; for any
    other mode, no value is lower than it by more than ``min_delta``.
    Returns ``False`` while fewer than ``patience + 1`` values exist.
    """
    window = patience + 1
    if len(xs) < window:
        return False

    recent = xs[-window:]
    reference = recent[0]
    if mode == "max":
        return max(recent) <= reference + min_delta
    return min(recent) >= reference - min_delta

def _bad_loss(loss, prev_losses, spike_factor):
    """Flag a loss that is non-finite or spikes above the recent baseline.

    The baseline is the mean of the last (up to) three previous losses; a
    spike is ``loss > spike_factor * baseline``. With no history, or a
    non-positive baseline, only non-finite losses are flagged.
    """
    if not math.isfinite(loss):
        return True
    if not prev_losses:
        return False

    tail = prev_losses[-3:]
    baseline = sum(tail) / min(3, len(prev_losses))
    if baseline > 0:
        return loss > spike_factor * baseline
    return False

def _set_requires_grad(module, flag: bool):
    """Set ``requires_grad`` to ``flag`` on every parameter of ``module``.

    A ``None`` module is silently ignored.
    """
    if module is not None:
        for param in module.parameters():
            param.requires_grad = flag

def _collect_head_modules(model):
    """List the model's non-backbone modules, de-duplicated by identity.

    Always considers ``decoder``, ``xattn`` and ``patch_fc``; adds
    ``decoder.classifier`` when present and the FiLM modules when
    ``model.use_global_film`` is truthy. ``None`` entries are dropped and
    first-seen order is preserved.
    """
    candidates = [model.decoder, model.xattn, model.patch_fc]

    if hasattr(model.decoder, "classifier"):
        candidates.append(model.decoder.classifier)

    if getattr(model, "use_global_film", False):
        candidates += [model.img_fc, model.global_attn, model.gamma, model.beta]

    # dicts keep insertion order, so keying by id() de-duplicates while
    # preserving the first occurrence of each module object
    by_identity = {}
    for module in candidates:
        if module is not None:
            by_identity.setdefault(id(module), module)
    return list(by_identity.values())

@torch.no_grad()
def tune_thresholds_per_class(model, loader_val, device, loss_fn=None, base_thresholds=None, grid=None, verbose=True):
    """Coordinate-wise search for per-class thresholds maximizing per-class Dice.

    For each class in turn, every candidate in ``grid`` is tried for that
    class (other classes keep their current best threshold) by running
    ``evaluate`` on ``loader_val``; the candidate with the highest Dice for
    that class is kept.

    Args:
        model: model to evaluate; ``model.num_classes`` is used when present
            (default 4).
        loader_val: validation loader passed to ``evaluate``.
        device: device passed to ``evaluate`` and used for threshold tensors.
        loss_fn: loss passed to ``evaluate``. When ``None`` a constant-zero
            loss is substituted, because ``evaluate`` calls the loss
            unconditionally; the loss value is not used by the tuner.
        base_thresholds: starting thresholds; replaced by 0.5 per class when
            missing or of the wrong length.
        grid: candidate thresholds; defaults to 0.20..0.70 in steps of 0.05.
        verbose: print the tuned thresholds and their Dice scores.

    Returns:
        ``(best_thresholds, best_dice_per_class)`` as two lists of floats. A
        class whose every evaluation failed keeps its base threshold and a
        Dice of -1.0.
    """
    model.eval()

    if loss_fn is None:
        # Without this, every evaluate() call would raise on `None(...)`, be
        # swallowed by the except below, and the tuner would silently return
        # the base thresholds.
        def eval_loss_fn(logits, masks):
            return logits.new_zeros(())
    else:
        eval_loss_fn = loss_fn

    C_model = int(getattr(model, "num_classes", 4))
    if base_thresholds is None or len(base_thresholds) != C_model:
        base_thresholds = [0.5] * C_model
    C = C_model

    if grid is None:
        grid = np.round(np.arange(0.20, 0.70 + 1e-9, 0.05), 2).tolist()
    else:
        grid = [float(t) for t in grid]

    best = list(map(float, base_thresholds))
    best_dpc = [-1.0] * C

    for c in range(C):
        best_t = best[c]
        best_dc = -1.0

        for t in grid:
            # vary only class c; the other classes keep their current best
            thr_vec = torch.tensor(best, dtype=torch.float32, device=device)
            thr_vec[c] = float(t)
            thr_t = thr_vec.view(1, C, 1, 1)

            try:
                _vl, _vd, dpc, *_ = evaluate(
                    model,
                    loader_val,
                    device=device,
                    loss_fn=eval_loss_fn,
                    threshold=thr_t
                )
            except Exception as e:
                # best-effort: a failing grid point is reported and skipped
                print(f"[Tuner] Eval failed at class {c}, t={t:.2f}: {e}")
                continue

            dc = float(dpc[c])
            if dc > best_dc:
                best_dc = dc
                best_t = float(t)

        best[c] = best_t
        best_dpc[c] = best_dc

    if verbose:
        print(f"[Tuner] New per-class thresholds: {best}")
        print(f"[Tuner] Dice per class at tuned thresholds: {['%.4f' % d for d in best_dpc]}")

    return best, best_dpc

def set_epoch(self, epoch:int):
    """Re-seed ``self.rng`` deterministically from the epoch number.

    Written with an explicit ``self`` so it can be attached to an object as
    a method. The epoch is XOR-ed with a fixed 64-bit constant to form the
    seed of a fresh NumPy ``Generator``.
    """
    seed_mask = 0x9E3779B97F4A7C15
    seed = int(epoch) ^ seed_mask
    self.rng = np.random.default_rng(seed)

@torch.no_grad()
def soft_dice_from_logits(logits, targets, eps=1e-7):
    """Soft Dice score (not loss) from raw logits.

    Inputs are aligned with ``_align_logits_targets``; Dice per class is
    ``(2*sum(p*t) + eps) / (sum(p^2 + t^2) + eps)`` with
    ``p = sigmoid(logits)``, reduced over batch and spatial dims.

    Returns:
        ``(mean_dice, dice_per_class)``.
    """
    logits, targets = _align_logits_targets(logits, targets)
    probs = torch.sigmoid(logits).float()
    truth = targets.float()

    reduce_dims = (0,2,3)
    overlap = 2 * (probs * truth).sum(dim=reduce_dims)
    energy = (probs.pow(2) + truth.pow(2)).sum(dim=reduce_dims)
    per_class = (overlap + eps) / (energy + eps)  # (C,)
    return per_class.mean(), per_class

def _ema(prev, x, a=0.3):
    """One exponential-moving-average step; the first sample is returned as is."""
    if prev is None:
        return x
    return a*x + (1-a)*prev

def apply_freeze_policy(ep: int):
    """Set which parts of the global ``model`` are trainable at epoch ``ep``.

    Everything is frozen first; the decoder, cross-attention and patch
    projection (plus the FiLM modules when enabled) are always re-enabled.
    Backbone layers are then unfrozen progressively: nothing extra for
    ``ep <= 2``, image-encoder ``layer3``/``layer4`` for ``3 <= ep <= 5``,
    and both full encoders otherwise.
    """
    _set_requires_grad(model, False)

    always_trainable = [model.decoder, model.xattn, model.patch_fc]
    if getattr(model, "use_global_film", False):
        always_trainable += [model.gamma, model.beta, model.img_fc, model.global_attn]
    for module in always_trainable:
        _set_requires_grad(module, True)

    if ep <= 2:
        return

    if 3 <= ep <= 5:
        for module in (model.img_enc.layer4, model.img_enc.layer3):
            _set_requires_grad(module, True)
        return

    for module in (model.img_enc, model.patch_enc):
        _set_requires_grad(module, True)

## Utility Functions `METRICS`

In [None]:
def dice_per_class(preds, targets, eps=1e-6):
    """Per-class Dice of (B, C, H, W) tensors, reduced over batch and space."""
    reduce_dims = (0,2,3)
    overlap = (preds * targets).sum(dim=reduce_dims)
    total = preds.sum(dim=reduce_dims) + targets.sum(dim=reduce_dims) + eps
    return (2*overlap + eps) / total

def iou_per_class(preds, targets, eps=1e-6):
    """Per-class IoU of (B, C, H, W) tensors, reduced over batch and space."""
    reduce_dims = (0,2,3)
    overlap = (preds * targets).sum(dim=reduce_dims)
    union = preds.sum(dim=reduce_dims) + targets.sum(dim=reduce_dims) - overlap + eps
    return (overlap + eps) / union

def _pearson_per_class(probs, targets, eps=1e-8):
    """Pearson correlation between ``probs`` and ``targets`` for each class.

    Both (B, C, H, W) inputs are flattened per class over batch and spatial
    dims. Classes where either side has (near-)zero variance get 0. The
    result is a detached CPU tensor of length C.
    """
    _, num_classes, _, _ = probs.shape

    # (B,C,H,W) -> (C, B*H*W)
    flat_p = probs.float().permute(1,0,2,3).contiguous().view(num_classes, -1)
    flat_t = targets.float().permute(1,0,2,3).contiguous().view(num_classes, -1)

    mean_p = flat_p.mean(dim=1, keepdim=True)
    mean_t = flat_t.mean(dim=1, keepdim=True)
    var_p = flat_p.var(dim=1, unbiased=False)
    var_t = flat_t.var(dim=1, unbiased=False)

    covariance = ((flat_p - mean_p) * (flat_t - mean_t)).mean(dim=1)
    std_product = var_p.clamp_min(eps).sqrt() * var_t.clamp_min(eps).sqrt()
    corr = covariance / (std_product + eps)

    # correlation is undefined for a constant signal; report 0 instead
    degenerate = (var_t < eps) | (var_p < eps)
    corr[degenerate] = 0.0
    return corr.detach().cpu()

def prepare_threshold(thr_cfg, probs=None, device="cpu"):
    """Normalize a threshold spec to a float or a ``(1, C, 1, 1)`` tensor.

    Lists, tuples and NumPy arrays become a float32 tensor on ``device``
    shaped for channel-wise broadcasting. A single-element tensor collapses
    to a float; other tensors are moved to ``device`` and reshaped. Anything
    else is converted with ``float()``. ``probs`` is accepted but unused.
    """
    if isinstance(thr_cfg, (list, tuple, np.ndarray)):
        per_class = torch.as_tensor(thr_cfg, dtype=torch.float32, device=device)
        return per_class.view(1, -1, 1, 1)

    if not isinstance(thr_cfg, torch.Tensor):
        return float(thr_cfg)

    if thr_cfg.numel() == 1:
        return float(thr_cfg.item())
    return thr_cfg.to(device).view(1, -1, 1, 1)

## Utility Functions `FORMAT`

In [None]:
def _wd_filter(n, p):
    """Select parameters for weight decay.

    True only for tensors with at least two dimensions whose name contains
    none of ``"bias"``, ``"bn"`` or ``"norm"``.
    """
    if p.ndimension() < 2:
        return False
    excluded_tags = ("bias", "bn", "norm")
    return not any(tag in n for tag in excluded_tags)

def _align_logits_targets(logits: torch.Tensor, targets: torch.Tensor):
    """Bring ``logits`` and ``targets`` to matching (B, C, H, W) layouts.

    Steps: cast targets to float; add a channel axis to 3-D inputs; move a
    trailing channel axis of size 1-4 to position 1 when position 1 is not
    itself of size 1-4; require equal channel counts; bilinearly resize the
    logits to the targets' spatial size if they differ.

    Raises:
        ValueError: if either tensor is not 4-D after the optional
            unsqueeze, or the channel counts differ after alignment.
    """
    targets = targets.float()

    # (B,H,W) -> (B,1,H,W)
    if logits.ndim == 3:
        logits = logits.unsqueeze(1)
    if targets.ndim == 3:
        targets = targets.unsqueeze(1)

    if logits.ndim != 4 or targets.ndim != 4:
        raise ValueError(f"Expected 4D tensors, got logits {tuple(logits.shape)}, targets {tuple(targets.shape)}")

    small_channel_counts = (1, 2, 3, 4)

    def _channels_first(t):
        # Heuristic: treat the tensor as NHWC only when dim 1 does not look
        # like a channel axis but the last dim does.
        if t.shape[1] not in small_channel_counts and t.shape[-1] in small_channel_counts:
            return t.permute(0, 3, 1, 2).contiguous()
        return t

    logits = _channels_first(logits)
    targets = _channels_first(targets)

    if logits.shape[1] != targets.shape[1]:
        raise ValueError(
            f"Channel mismatch after alignment: logits C={logits.shape[1]} vs targets C={targets.shape[1]} "
            f" | logits={tuple(logits.shape)} targets={tuple(targets.shape)}"
        )

    if logits.shape[-2:] != targets.shape[-2:]:
        logits = F.interpolate(logits, size=targets.shape[-2:], mode="bilinear", align_corners=False)

    return logits, targets

def _force_bchw(x: torch.Tensor, num_classes: int) -> torch.Tensor:
    """Heuristically convert a channels-last 4-D tensor to channels-first.

    A 4-D tensor whose dim 1 differs from ``num_classes`` is permuted from
    (B, H, W, C) to (B, C, H, W) when either its last dim equals
    ``num_classes``, or dims 1 and 2 are both >= 64 while the last dim is
    <= 64. Every other input is returned unchanged.
    """
    if x.ndim != 4:
        return x

    dim1, dim2, last = x.shape[1], x.shape[2], x.shape[-1]
    if dim1 != num_classes:
        last_is_classes = last == num_classes
        looks_spatial_first = dim1 >= 64 and dim2 >= 64 and last <= 64
        if last_is_classes or looks_spatial_first:
            return x.permute(0, 3, 1, 2).contiguous()
    return x

def _ensure_bchw(x: torch.Tensor, c_hint: int | None) -> torch.Tensor:
    if x.ndim == 4 and c_hint is not None:
        if x.shape[-1] == c_hint and x.shape[1] != c_hint:
            return x.permute(0, 3, 1, 2).contiguous()
    return x

## Utility Functions `UTILITY`

In [None]:
def remove_small(preds: torch.Tensor, per_class_min_areas, *, binarize: bool = False,
    thr: float = 0.5, connectivity: int = 8) -> torch.Tensor:
    """Drop connected components smaller than a per-class minimum area.

    Args:
        preds: (B, C, H, W) predictions. With ``binarize=False`` they are
            clamped to [0, 1] and truncated to uint8, so they are expected to
            be binary already (fractional values below 1 become 0).
        per_class_min_areas: one minimum pixel area per class; a value <= 0
            leaves that class untouched.
        binarize: threshold ``preds`` with ``>= thr`` first.
        thr: threshold used when ``binarize`` is set.
        connectivity: pixel connectivity passed to OpenCV (4 or 8).

    Returns:
        A float32 {0, 1} tensor on ``preds``' device with the same shape.
    """
    assert preds.ndim == 4, f"expected (B,C,H,W), got {tuple(preds.shape)}"
    C = preds.size(1)
    assert len(per_class_min_areas) == C, "per_class_min_areas length must match channels"

    if binarize:
        preds_bin = (preds >= thr).to(torch.uint8)
    else:
        preds_bin = preds.clamp(0, 1).to(torch.uint8)

    # One device->host transfer for the whole batch; filtering happens on the
    # CPU and the result is moved back once at the end.
    bin_cpu = preds_bin.detach().cpu()
    out = bin_cpu.clone()

    for c in range(C):
        min_area = int(per_class_min_areas[c])
        if min_area <= 0:
            continue  # keep this class as is

        for b in range(bin_cpu.size(0)):
            mask = np.ascontiguousarray(bin_cpu[b, c].numpy())
            num, labels = cv2.connectedComponents(mask, connectivity=connectivity)

            # Areas of all components in a single pass instead of one
            # full-image scan per component.
            areas = np.bincount(labels.ravel(), minlength=num)
            keep_label = areas >= min_area
            keep_label[0] = False  # label 0 is background
            out[b, c] = torch.from_numpy(keep_label[labels].astype(np.uint8))

    return out.to(device=preds.device, dtype=torch.float32)


def write_metrics_csv(history, out_path, class_names=None):
    """Write per-epoch training history to a CSV file.

    Args:
        history: dict with optional list entries ``train_loss``,
            ``val_loss``, ``val_dice_mean``, ``val_iou_mean`` (one scalar per
            epoch) and ``val_dice_per_class`` / ``val_iou_per_class`` (one
            list/tuple per epoch).
        out_path: destination file; its parent directory is created when the
            path has one. A bare filename writes to the current directory.
        class_names: optional names used for the per-class column headers
            when their count matches the number of classes; otherwise the
            columns are named ``dice_c{i}`` / ``iou_c{i}``.

    Missing values are written as NaN; the epoch column is 1-based.
    """
    def _first_len(lst):
        # number of classes = length of the first list/tuple entry, if any
        for x in lst:
            if isinstance(x, (list, tuple)):
                return len(x)
        return 0

    dice_pc = history.get("val_dice_per_class", [])
    iou_pc  = history.get("val_iou_per_class", [])
    K = max(_first_len(dice_pc), _first_len(iou_pc))

    if class_names and len(class_names) == K:
        dice_cols = [f"dice_{c}" for c in class_names]
        iou_cols  = [f"iou_{c}"  for c in class_names]
    else:
        dice_cols = [f"dice_c{i}" for i in range(K)]
        iou_cols  = [f"iou_c{i}"  for i in range(K)]

    base_cols = ["epoch", "train_loss", "val_loss", "val_dice_mean", "val_iou_mean"]
    header = base_cols + dice_cols + iou_cols

    scalar_keys = ("train_loss", "val_loss", "val_dice_mean", "val_iou_mean")
    scalar_series = [history.get(key, []) for key in scalar_keys]

    # the longest series decides how many rows are written
    epochs = max([len(series) for series in scalar_series] + [len(dice_pc), len(iou_pc)])

    def _get(lst, i, default=None):
        if i < len(lst):
            return lst[i]
        return default

    def _padded(per_class):
        # pad or trim a per-class row to exactly K entries
        values = list(per_class)[:K]
        return values + [float("nan")] * (K - len(values))

    rows = []
    for i in range(epochs):
        scalars = [_get(series, i, float("nan")) for series in scalar_series]
        vd_pc = _get(dice_pc, i, None) or []
        vi_pc = _get(iou_pc,  i, None) or []
        rows.append([i + 1] + scalars + _padded(vd_pc) + _padded(vi_pc))

    # os.path.dirname() is "" for a bare filename and os.makedirs("") raises,
    # so only create the directory when there is one.
    out_dir = os.path.dirname(out_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    with open(out_path, "w", newline="", encoding="utf-8") as f:
        w = csv.writer(f)
        w.writerow(header)
        w.writerows(rows)

def save_metrics_to_drive(history: dict, config: dict, class_names=None, csv_basename: str = "metrics_expanded.csv", test_per_image_csv: str | None = None, test_summary_json: str | None = None):
    """Persist training metrics locally and mirror them to the Drive folder.

    Writes ``csv_basename`` and a timestamped copy plus a JSON snapshot into
    ``config["OUTPUT_DIR"]``, then copies all three into
    ``config["DRIVE_MODEL_DIR"]/metrics``. Optional test artefacts are copied
    to the same Drive folder when the given paths exist. Paths are printed.
    """
    out_dir = config["OUTPUT_DIR"]
    os.makedirs(out_dir, exist_ok=True)

    stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    stem = os.path.splitext(csv_basename)[0]

    # CSV: a "latest" file plus a timestamped copy
    csv_latest = os.path.join(out_dir, csv_basename)
    csv_stamped = os.path.join(out_dir, f"{stem}_{stamp}.csv")
    write_metrics_csv(history, csv_latest, class_names=class_names)
    shutil.copy2(csv_latest, csv_stamped)

    # JSON snapshot; only config values of plain built-in types are included
    plain_types = (int, float, str, bool, list, dict, type(None))
    plain_config = {}
    for key, value in config.items():
        if isinstance(value, plain_types):
            plain_config[key] = value

    snapshot = {
        "timestamp": stamp,
        "config": plain_config,
        "class_names": class_names,
        "history": history,
        "thresholds": config.get("THRESH_PER_CLASS"),
        "post_min_area_per_class": config.get("POST_MIN_AREA_PER_CLASS"),
    }
    json_path = os.path.join(out_dir, f"metrics_{stamp}.json")
    with open(json_path, "w") as f:
        json.dump(snapshot, f, indent=2)

    local_files = [csv_latest, csv_stamped, json_path]

    # mirror to Drive
    drive_dir = os.path.join(config["DRIVE_MODEL_DIR"], "metrics")
    os.makedirs(drive_dir, exist_ok=True)
    for src in local_files:
        shutil.copy2(src, os.path.join(drive_dir, os.path.basename(src)))

    # optional test artefacts
    copied_test = []
    for artefact in (test_per_image_csv, test_summary_json):
        if artefact and os.path.isfile(artefact):
            dst = os.path.join(drive_dir, os.path.basename(artefact))
            shutil.copy2(artefact, dst)
            copied_test.append(dst)

    print("[metrics] saved locally:")
    for path in local_files:
        print(" ", path)
    print("[drive] copied to:", drive_dir)
    if copied_test:
        print("[drive][test] copied:")
        for path in copied_test:
            print(" ", path)

## Evaluation

In [None]:
@torch.no_grad()
def evaluate(model, loader, device, loss_fn, desc="Evaluating", threshold=0.5):
    """Evaluate ``model`` on ``loader`` and return loss plus segmentation metrics.

    Each batch is run through the model, sigmoid probabilities are thresholded
    (per class when a vector threshold is given), small components are removed
    with ``remove_small`` using ``config["POST_MIN_AREA_PER_CLASS"]``, and Dice,
    IoU and Pearson are computed per class. Per-batch metrics are averaged
    with the batch size as weight.

    Args:
        model: Network called as ``model(images, patches)``; when it returns a
            tuple/list only the first element is evaluated.
        loader: Iterable yielding ``(images, patches, masks)`` batches.
        device: ``torch.device`` the batches are moved to.
        loss_fn: Callable ``loss_fn(logits, masks)`` returning a scalar tensor.
        desc: Progress-bar label.
        threshold: Scalar or per-class threshold, normalised by
            ``prepare_threshold``.

    Returns:
        ``(avg_loss, mean_dice, dice_per_class, mean_iou, iou_per_class,
        mean_pearson, pearson_per_class)``; the per-class entries are lists of
        floats. With an empty loader the scalars are NaN and the lists empty.
    """
    model.eval()

    # RUN_MODE is a string (default "dbg"); using it directly as a boolean made
    # every non-empty value truthy and therefore disabled AMP unconditionally.
    # NOTE(review): "dbg"/"debug" are assumed to be the debug spellings - confirm.
    run_mode = config.get("RUN_MODE", "dbg")
    if isinstance(run_mode, bool):
        is_dbg = run_mode
    else:
        is_dbg = str(run_mode).strip().lower() in ("dbg", "debug")
    amp_enabled = (device.type == 'cuda') and (not is_dbg)

    raw_thr = prepare_threshold(threshold, device=device)  # may be float or tensor

    total_loss = 0.0
    total_items = 0

    # Batch-size-weighted running sums of the per-class metrics (CPU tensors).
    dice_sum = None
    iou_sum  = None
    r_sum    = None
    weight_sum = 0

    sd_pc_sum = None  # soft (un-thresholded) Dice per class, progress bar only

    pbar = tqdm(loader, desc=desc, leave=False)
    for images, patches, masks in pbar:
        bsz = images.size(0)
        images  = images.to(device, non_blocking=True)
        patches = patches.to(device, non_blocking=True)
        masks   = masks.to(device, non_blocking=True)

        with amp.autocast(device_type=device.type, enabled=amp_enabled):
            logits = model(images, patches)
            if isinstance(logits, (tuple, list)):
                logits = logits[0]
            loss = loss_fn(logits, masks)

        total_loss  += float(loss.item()) * bsz
        total_items += bsz

        logits_aligned, masks_aligned = _align_logits_targets(logits, masks)
        probs = torch.sigmoid(logits_aligned.float())  # (B,C,H,W)

        # Bring the threshold into a shape that broadcasts against probs.
        C = probs.shape[1]
        if isinstance(raw_thr, torch.Tensor):
            if raw_thr.ndim == 0:
                thr = raw_thr.view(1, 1, 1, 1).expand(1, C, 1, 1)
            elif raw_thr.ndim == 1:
                if raw_thr.numel() != C:
                    # Length mismatch: fall back to 0.5 for every class.
                    thr = torch.full((1, C, 1, 1), 0.5, device=probs.device, dtype=probs.dtype)
                else:
                    thr = raw_thr.view(1, C, 1, 1)
            else:
                thr = raw_thr.to(device=probs.device, dtype=probs.dtype)
        else:
            thr = torch.full((1, C, 1, 1), float(raw_thr), device=probs.device, dtype=probs.dtype)

        preds = (probs >= thr).float()

        preds = remove_small(preds, config["POST_MIN_AREA_PER_CLASS"], binarize=False)

        dpc = dice_per_class(preds, masks_aligned).detach().cpu()  # (C,)
        ipc = iou_per_class(preds,  masks_aligned).detach().cpu()  # (C,)
        rpc = _pearson_per_class(probs, masks_aligned).detach().cpu()  # (C,)

        # Soft Dice on raw probabilities, kept out of autocast (float32).
        with amp.autocast(device_type=device.type, enabled=False):
            p = probs  # already float32
            t = masks_aligned.float()
            dims = (0, 2, 3)
            num = 2 * (p * t).sum(dim=dims)
            den = (p.pow(2) + t.pow(2)).sum(dim=dims)
            sd_pc = ((num + 1e-7) / (den + 1e-7)).detach().cpu()  # (C,)

        if dice_sum is None:
            dice_sum = dpc * bsz
            iou_sum  = ipc * bsz
            r_sum    = rpc * bsz
            sd_pc_sum = sd_pc * bsz
        else:
            dice_sum += dpc * bsz
            iou_sum  += ipc * bsz
            r_sum    += rpc * bsz
            sd_pc_sum += sd_pc * bsz

        weight_sum += bsz

        running_dice = float((dice_sum / max(1, weight_sum)).mean().item())
        running_soft = float((sd_pc_sum / max(1, weight_sum)).mean().item())
        pbar.set_postfix(loss=f"{loss.item():.4f}", dice=f"{running_dice:.4f}", soft=f"{running_soft:.4f}")

        # Drop batch tensors before the next iteration to limit peak GPU memory.
        del images, patches, masks, logits, logits_aligned, masks_aligned, probs, preds, loss, dpc, ipc, rpc
        torch.cuda.empty_cache()

    if weight_sum == 0:
        # Empty loader: nothing to average.
        gc.collect()
        return float("nan"), float("nan"), [], float("nan"), [], float("nan"), []

    avg_loss = total_loss / total_items

    dpc_mean = dice_sum / weight_sum
    ipc_mean = iou_sum  / weight_sum
    rpc_mean = r_sum    / weight_sum

    mean_dice = float(dpc_mean.mean().item())
    mean_iou  = float(ipc_mean.mean().item())
    mean_r    = float(rpc_mean.mean().item())

    gc.collect()
    return (
        avg_loss,
        mean_dice, [float(x) for x in dpc_mean.tolist()],
        mean_iou,  [float(x) for x in ipc_mean.tolist()],
        mean_r,    [float(x) for x in rpc_mean.tolist()]
    )

## Training

In [None]:
def run_training(patch_df, loader_train, loader_val, epochs=None,
                 resume_checkpoint=None, accum_steps=config["ACCUM_STEPS"]):
    """Train a ``PatchSegNet`` and record per-epoch metrics in ``HISTORY``.

    Per epoch: apply the freeze policy, train with gradient accumulation and
    (on CUDA) mixed precision, estimate train Dice/Pearson on the first
    ``TRAIN_METRIC_BATCHES`` batches, evaluate on ``loader_val`` (with EMA
    weights applied when EMA is enabled), periodically re-tune the per-class
    thresholds, step ``ReduceLROnPlateau`` on validation Dice, save
    checkpoints and check the early-exit conditions.

    Args:
        patch_df: Not used inside this function; kept for the call signature.
        loader_train: Training loader yielding ``(images, patches, masks)``.
        loader_val: Validation loader with the same batch structure.
        epochs: Number of epochs; falls back to ``config["EPOCHS"]`` when falsy.
        resume_checkpoint: Optional path to a checkpoint whose ``"model"``
            entry (or the object itself) is a model state dict.
        accum_steps: Number of batches accumulated per optimizer step.

    Returns:
        ``(model, HISTORY)`` where ``HISTORY`` is the module-level history
        dict, appended to in place.

    Raises:
        ValueError: If ``POS_WEIGHT`` or ``CLASS_WEIGHT`` is missing in config.
    """

    epochs = int(epochs or config["EPOCHS"])

    model = PatchSegNet(
        pretrained=True,
        hidden=256,
        num_classes=config["NUM_CLASSES"],
        patch_chunk=config.get("PATCH_CHUNK", 25),
        heads=4,
        use_global_film=bool(config.get("USE_GLOBAL_FILM", False)),
        config=config
    ).to(device)

    # Load resumed weights BEFORE the EMA helper is built so that the EMA is
    # created from the resumed model rather than from the fresh initialisation.
    if resume_checkpoint and os.path.isfile(resume_checkpoint):
        ckpt = torch.load(resume_checkpoint, map_location="cpu")
        model.load_state_dict(ckpt.get("model", ckpt))

    use_ema = bool(config.get("USE_EMA", True))
    ema = EMA(model, decay=float(config.get("EMA_DECAY", 0.999))) if use_ema else None

    thr_cfg = config.get("THRESH_PER_CLASS", 0.9)
    thr = prepare_threshold(thr_cfg, device=device)

    base_lr = float(config["LR"])
    wd      = float(config["WEIGHT_DECAY"])

    # Parameters of the head modules get a 10x learning rate below.
    head_modules = _collect_head_modules(model)
    head_param_ids = {id(p) for m in head_modules for p in m.parameters()}

    NORM_TYPES = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d,
                  nn.GroupNorm, nn.LayerNorm, nn.InstanceNorm1d,
                  nn.InstanceNorm2d, nn.InstanceNorm3d)

    # Normalisation parameters and biases are excluded from weight decay.
    no_decay_ids = set()
    for mod in model.modules():
        if isinstance(mod, NORM_TYPES):
            for p in mod.parameters(recurse=False):
                no_decay_ids.add(id(p))

    base_decay, base_no_decay, head_decay, head_no_decay = [], [], [], []
    for name, p in model.named_parameters():
        if not p.requires_grad:
            continue
        is_head = (id(p) in head_param_ids)
        is_no_decay = (id(p) in no_decay_ids) or name.endswith(".bias")

        if is_head and is_no_decay:
            head_no_decay.append(p)
        elif is_head and not is_no_decay:
            head_decay.append(p)
        elif (not is_head) and is_no_decay:
            base_no_decay.append(p)
        else:
            base_decay.append(p)

    # NOTE(review): the "head_decay" group also uses weight_decay=0.0, so head
    # weights are never decayed - confirm this is intended and not a typo for wd.
    optimizer = torch.optim.AdamW(
        [
            {"params": base_decay,    "lr": base_lr,        "weight_decay": wd},
            {"params": base_no_decay, "lr": base_lr,        "weight_decay": 0.0},
            {"params": head_decay,    "lr": base_lr * 10.0, "weight_decay": 0.0},
            {"params": head_no_decay, "lr": base_lr * 10.0, "weight_decay": 0.0},
        ],
        betas=(0.9, 0.999), eps=1e-8
    )
    # Remember each group's target LR; warm-up scales relative to it.
    for pg in optimizer.param_groups:
        pg.setdefault("init_lr", pg["lr"])

    warmup_steps = int(config.get("WARMUP_STEPS", 500))
    opt_steps = 0

    pos_w_cfg   = config.get("POS_WEIGHT", None)
    class_w_cfg = config.get("CLASS_WEIGHT", None)
    if pos_w_cfg is None or class_w_cfg is None:
        raise ValueError("Expected POS_WEIGHT and CLASS_WEIGHT in config. Compute them before run_training.")
    pos_w_t   = torch.tensor(pos_w_cfg, dtype=torch.float32, device=device)
    class_w_t = torch.tensor(class_w_cfg, dtype=torch.float32, device=device)

    base_loss = build_loss_safe(pos_weight=pos_w_t, class_weights=class_w_t)

    AUX_WEIGHTS = tuple(config.get("AUX_WEIGHTS", (1.0, 0.4, 0.2)))
    USE_AUX_BCE = bool(config.get("AUX_USE_BCE", False))
    loss_aux = (torch.nn.BCEWithLogitsLoss(pos_weight=pos_w_t)
                if USE_AUX_BCE else base_loss)

    # mode="max": the scheduler is stepped with the validation Dice.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode="max",
        factor=float(config.get("LR_FACTOR", 0.5)),
        patience=int(config.get("LR_PATIENCE", 2)),
        threshold=1e-3,
        cooldown=0,
        min_lr=float(config.get("LR_MIN", 1e-6))
    )

    amp_enabled = (device.type == "cuda")
    scaler = amp.GradScaler(enabled=amp_enabled)

    # Prime the warm-up so the very first optimizer step already runs at the
    # reduced rate; previously the scaled LR was only applied after step 1,
    # i.e. the first update used the full learning rate.
    if warmup_steps > 0:
        for pg in optimizer.param_groups:
            pg["lr"] = pg["init_lr"] / float(warmup_steps)

    best_dice = -1.0

    train_metric_batches = int(config.get("TRAIN_METRIC_BATCHES", 5))

    apply_freeze_policy(1)

    for epoch in range(1, epochs + 1):
        apply_freeze_policy(epoch)

        if hasattr(loader_train.dataset, "set_epoch"):
            loader_train.dataset.set_epoch(epoch)

        # ---------------- training pass ----------------
        model.train()
        running, n = 0.0, 0
        optimizer.zero_grad(set_to_none=True)

        pbar = tqdm(loader_train, desc=f"Epoch {epoch}/{epochs} [Training]", leave=False)
        for step, (images, patches, masks) in enumerate(pbar):
            images  = images.to(device, non_blocking=True)
            patches = patches.to(device, non_blocking=True)
            masks   = masks.to(device, non_blocking=True)

            with amp.autocast(device_type=device.type, enabled=amp_enabled):
                logits = model(images, patches)

                if isinstance(logits, (tuple, list)):
                    # Main output plus two auxiliary outputs.
                    lm, tm   = _align_logits_targets(logits[0], masks)
                    la2, ta2 = _align_logits_targets(logits[1], masks)
                    la3, ta3 = _align_logits_targets(logits[2], masks)
                    aligned = (lm, la2, la3)
                    loss, parts = compute_seg_loss(aligned, tm, base_loss, loss_aux, weights=AUX_WEIGHTS)
                else:
                    l_aligned, t_aligned = _align_logits_targets(logits, masks)
                    loss, parts = compute_seg_loss(l_aligned, t_aligned, base_loss, loss_aux, weights=AUX_WEIGHTS)

                # Scale so gradients summed over accum_steps batches average out.
                loss = loss / accum_steps

            if not torch.isfinite(loss):
                # Drop this batch and any gradients accumulated so far.
                print(f"[nan-guard] skipping step; loss={loss.item():.4f}")
                optimizer.zero_grad(set_to_none=True)
                continue

            scaler.scale(loss).backward()

            # NOTE(review): if the number of batches is not a multiple of
            # accum_steps, the trailing gradients are discarded by the
            # zero_grad at the start of the next epoch.
            if (step + 1) % accum_steps == 0:
                # Unscale before clipping so the norm is measured on true grads.
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), config.get("MAX_NORM", 1.0))
                scaler.step(optimizer)
                scaler.update()
                if ema is not None: ema.update(model)

                opt_steps += 1
                # Linear warm-up: set the LR that the NEXT optimizer step uses.
                if warmup_steps > 0 and opt_steps < warmup_steps:
                    scale = (opt_steps + 1) / float(warmup_steps)
                    for pg in optimizer.param_groups:
                        pg["lr"] = pg["init_lr"] * scale

                optimizer.zero_grad(set_to_none=True)


            # Undo the accumulation scaling for logging.
            running += float(loss.item()) * images.size(0) * accum_steps
            n += images.size(0)
            pbar.set_postfix({"batch_loss": f"{loss.item() * accum_steps:.4f}",
                              "Lm": f"{parts['main']:.3f}" if not math.isnan(parts["main"]) else "nan"})

            del images, patches, masks, logits, loss
            torch.cuda.empty_cache(); gc.collect()

        train_loss = running / max(1, n)

        # -------- quick train metrics on the first few batches --------
        model.eval()
        with torch.no_grad():
            dice_sum, r_sum, count = None, None, 0
            for i, (images, patches, masks) in enumerate(loader_train):
                if i >= train_metric_batches: break
                images  = images.to(device, non_blocking=True)
                patches = patches.to(device, non_blocking=True)
                masks   = masks.to(device, non_blocking=True)

                with amp.autocast(device_type=device.type, enabled=amp_enabled):
                    logits = model(images, patches)
                    if isinstance(logits, (tuple, list)):
                        logits = logits[0]

                logits, masks = _align_logits_targets(logits, masks)
                probs = torch.sigmoid(logits)
                # prepare_threshold may hand back a plain float (evaluate()
                # handles that case too); only tensors have ``ndim``/``view``.
                if isinstance(thr, torch.Tensor) and thr.ndim == 1:
                    thr_exp = thr.view(1, -1, 1, 1)
                else:
                    thr_exp = thr
                preds = (probs >= thr_exp).float()

                dpc = dice_per_class(preds, masks).detach().cpu()
                rpc = _pearson_per_class(probs, masks)

                dice_sum = dpc if dice_sum is None else (dice_sum + dpc)
                r_sum    = rpc if r_sum    is None else (r_sum    + rpc)
                count += 1

                del images, patches, masks, logits, probs, preds, dpc, rpc
                torch.cuda.empty_cache(); gc.collect()

            if count > 0:
                train_dice_pc = (dice_sum / count)
                train_dice_mean = float(train_dice_pc.mean().item())
                train_dice_pc_list = [float(x) for x in train_dice_pc.tolist()]
                train_r_pc = (r_sum / count)
                train_r_mean = float(train_r_pc.mean().item())
                train_r_pc_list = [float(x) for x in train_r_pc.tolist()]
            else:
                train_dice_mean, train_dice_pc_list = float("nan"), []
                train_r_mean,   train_r_pc_list    = float("nan"), []

        torch.cuda.empty_cache(); gc.collect()

        # ---------------- validation ----------------
        # EMA is applied only around the evaluate() call and restored after.
        if ema is not None: ema.apply_to(model)
        val_loss, val_dice_mean, val_dice_pc, val_iou_mean, val_iou_pc, val_r_mean, val_r_pc = evaluate(
            model, loader_val, device, base_loss,
            desc=f"Epoch {epoch}/{epochs} [Eval Val]", threshold=thr
        )
        if ema is not None: ema.restore(model)

        # Re-tune per-class thresholds on every third epoch from epoch 6 on
        # (epoch >= 4 and divisible by 3); failures are reported, not fatal.
        THR_TUNE_START = 4
        if epoch >= THR_TUNE_START and (epoch % 3) == 0:
            try:
                tuned_thr, tuned_dpc = _tune_thresholds_per_class(
                    model, loader_val=loader_val, device=device,
                    loss_fn=base_loss, base_thresholds=config.get("THRESH_PER_CLASS"), verbose=True
                )
                config["THRESH_PER_CLASS"] = tuned_thr
                thr = prepare_threshold(tuned_thr, device=device)
            except Exception as e:
                print("[Tuner] Skipped threshold tuning due to error:", e)

        scheduler.step(val_dice_mean)

        print(f"LR after epoch {epoch}: {optimizer.param_groups[0]['lr']:.6f}")
        HISTORY["train_loss"].append(float(train_loss))
        HISTORY["train_dice_mean"].append(float(train_dice_mean))
        HISTORY["train_dice_per_class"].append(train_dice_pc_list)
        HISTORY["train_pearson_mean"].append(float(train_r_mean))
        HISTORY["train_pearson_per_class"].append(train_r_pc_list)

        HISTORY["val_loss"].append(float(val_loss))
        HISTORY["val_dice_mean"].append(float(val_dice_mean))
        HISTORY["val_dice_per_class"].append([float(x) for x in val_dice_pc])
        HISTORY["val_iou_mean"].append(float(val_iou_mean))
        HISTORY["val_iou_per_class"].append([float(x) for x in val_iou_pc])
        HISTORY["val_pearson_mean"].append(float(val_r_mean))
        HISTORY["val_pearson_per_class"].append([float(x) for x in val_r_pc])

        # Smoothed validation Dice; default monitor for early stopping below.
        HISTORY["val_dice_mean_ema"].append(_ema(HISTORY["val_dice_mean_ema"][-1] if HISTORY["val_dice_mean_ema"] else None,
                                         val_dice_mean, a=0.3))

        # NOTE(review): checkpoints are written after ema.restore(), while
        # val_dice_mean was measured with EMA applied - confirm which weights
        # the "best" checkpoint is supposed to contain.
        ckpt_dir = os.path.dirname(config["DRIVE_MODEL_PATH"])
        os.makedirs(ckpt_dir, exist_ok=True)
        torch.save({"model": model.state_dict()}, os.path.join(ckpt_dir, "last.pth"))
        if val_dice_mean > best_dice:
            best_dice = val_dice_mean
            torch.save({"model": model.state_dict(), "config": config}, config["DRIVE_MODEL_PATH"])
            print(f"saved new best model at epoch {epoch} with Dice = {val_dice_mean:.4f} | IoU = {val_iou_mean:.4f}")

        print(
            f"[Epoch {epoch}/{epochs}] "
            f"train_loss={train_loss:.4f} | train_dice={train_dice_mean:.4f}  ||  "
            f"val_loss={val_loss:.4f} | val_dice={val_dice_mean:.4f} | val_iou={val_iou_mean:.4f}"
        )
        if train_dice_pc_list:
            print("  Train Dice per class:", [f"{x:.4f}" for x in train_dice_pc_list])
        print("  Val   Dice per class:", [f"{x:.4f}" for x in val_dice_pc])

        # ---------------- early-exit checks ----------------
        if _bad_loss(train_loss, HISTORY["train_loss"][:-1], config.get("LOSS_SPIKE_FACTOR", 3.0)):
            print(f"early exit: loss spike at epoch {epoch}")
            break

        mon   = config.get("EARLY_STOP_MONITOR", "val_dice_mean_ema")
        mode  = config.get("EARLY_STOP_MODE", "max")
        pat   = int(config.get("EARLY_STOP_PATIENCE", 6))
        delta = float(config.get("EARLY_STOP_MIN_DELTA", 0.002))
        series = HISTORY.get(mon, [])

        if _has_plateaued(series, mode, pat, delta):
            print(f"early exit: {mon} plateaued at epoch {epoch}")
            break

        if not math.isfinite(train_loss) or not math.isfinite(val_loss):
            print(f"early exit: non-finite loss at epoch {epoch}")
            break

        torch.cuda.empty_cache(); gc.collect()

    return model, HISTORY

## Utility Functions `LOADERS`

In [None]:
class AlbuNoOp:
    """Pass-through transform: calling it returns its keyword arguments."""

    def __call__(self, **kwargs):
        # Nothing is modified; the collected keyword dict is handed back as is.
        passthrough = kwargs
        return passthrough

def _seed_worker(worker_id):
    """Seed NumPy's and Python's RNGs for one DataLoader worker.

    The seed is ``config["SEED"]`` offset by the worker id, so workers get
    distinct but reproducible seeds.
    """
    worker_seed = config["SEED"] + worker_id
    # NumPy first, then the stdlib RNG - same seed for both.
    for apply_seed in (np.random.seed, random.seed):
        apply_seed(worker_seed)

def image_weights_from_df(df, boost_ma=BOOSTS["MA"], boost_ex=BOOSTS["EX"], boost_he=BOOSTS["HE"]):
    """Build a per-image sampling weight from each image's label vector.

    The first ``label_vector`` of every ``image_id`` is used (ordered
    [MA, HE, EX, SE]). Starting from 1.0, the weight is multiplied by the
    matching boost for each of MA, EX and HE that is present (> 0).

    Returns:
        dict mapping ``image_id`` to its weight.
    """
    # (label index, multiplier), in the order the factors are applied.
    boost_plan = (
        (0, boost_ma),  # MA
        (2, boost_ex),  # EX
        (1, boost_he),  # HE
    )
    first_labels = df.groupby("image_id")["label_vector"].first()

    weights = {}
    for image_id, label_vec in first_labels.items():
        weight = 1.0
        for class_idx, factor in boost_plan:
            if label_vec[class_idx] > 0:
                weight *= factor
        weights[image_id] = weight
    return weights

def build_loader(dataset, *, shuffle=False, sampler=None, drop_last=False):
    """Create a ``DataLoader`` for ``dataset`` from ``config`` / ``mp_config``.

    Batch size comes from ``config["BATCH_SIZE"]``; worker-related options come
    from ``mp_config``. Persistent workers and the prefetch factor are only
    used when at least one worker is configured. When a sampler is supplied,
    ``shuffle`` is switched off.
    """
    num_workers = int(mp_config["NUM_WORKERS"])
    pin_memory = bool(mp_config["PIN_MEMORY"])
    has_workers = num_workers > 0

    # Worker-only options; the config keys are read only when workers exist.
    persistent = has_workers and bool(mp_config["PERSISTENT_WORKERS"])
    prefetch = int(mp_config["PREFETCH_FACTOR"]) if has_workers else None

    mp_context = mp_config.get("MP_CONTEXT", None)
    timeout = int(mp_config.get("TIMEOUT", 0))

    # An explicit sampler takes over ordering, so shuffling is disabled.
    if sampler is not None and shuffle:
        shuffle = False

    loader_kwargs = {
        "dataset": dataset,
        "batch_size": config["BATCH_SIZE"],
        "shuffle": shuffle,
        "sampler": sampler,
        "num_workers": num_workers,
        "pin_memory": pin_memory,
        "persistent_workers": persistent,
        "prefetch_factor": prefetch,
        "worker_init_fn": _seed_worker,
        "generator": g,
        "multiprocessing_context": mp_context,
        "timeout": timeout,
        "drop_last": drop_last,
    }
    return DataLoader(**loader_kwargs)

## Loaders

In [None]:
# Seeded generator; build_loader passes it to every DataLoader it creates.
g = torch.Generator()
g.manual_seed(config["SEED"])

# splits
# Select rows by image id so all rows of an image land in the same subset.
train_ids = patch_df[patch_df["split"] == "train"]["image_id"].unique().tolist()
val_ids   = patch_df[patch_df["split"] == "val"]["image_id"].unique().tolist()
dtr = patch_df[patch_df["image_id"].isin(train_ids)]
dva = patch_df[patch_df["image_id"].isin(val_ids)]

# Training dataset: train_mode=True with the training augmentation pipeline.
ds_train = PatchSegDataset(
    dtr, config["IMAGE_ROOT"], config["PATCHES_ROOT"], config["MASK_DIR"],
    transform_img=None, transform_patch=transform_patch, transform_mask=None,
    fixed_patches_train=config.get("FIXED_PATCHES_TRAIN", 48),
    fixed_patches_val=config.get("FIXED_PATCHES_VAL", None),
    seed=config["SEED"], train_mode=True, joint_tf=albu_train,
)

# Validation dataset: train_mode=False with the validation pipeline.
ds_val = PatchSegDataset(
    dva, config["IMAGE_ROOT"], config["PATCHES_ROOT"], config["MASK_DIR"],
    transform_img=None, transform_patch=transform_patch, transform_mask=None,
    fixed_patches_train=config.get("FIXED_PATCHES_TRAIN", 48),  # ignored in eval
    fixed_patches_val=config.get("FIXED_PATCHES_VAL", None),
    seed=config["SEED"], train_mode=False, joint_tf=albu_val,
)

# Per-image sampling weights: images containing MA / EX / HE are boosted.
# NOTE(review): these literals override the BOOSTS defaults of
# image_weights_from_df - confirm that is intended.
img_w = image_weights_from_df(dtr, boost_ma=3.0, boost_ex=1.5, boost_he=1.2)
train_sample_weights = [img_w[iid] for iid in ds_train.image_ids]

# Draw len(dataset) samples per epoch, with replacement, following the weights.
sampler = WeightedRandomSampler(
    train_sample_weights, num_samples=len(train_sample_weights), replacement=True
)

# loaders
loader_train = build_loader(ds_train, sampler=sampler)
loader_val   = build_loader(ds_val,   shuffle=False)

if hasattr(loader_val.dataset, "set_epoch"):
    loader_val.dataset.set_epoch(0)

# Image-level prevalence: mean of the per-image label vectors of the train split.
lv_by_img = patch_df[patch_df["split"] == "train"].groupby("image_id")["label_vector"].first()
lv = np.stack(lv_by_img.values, axis=0)  # (N_img, C)
img_prev = lv.mean(axis=0).tolist()

albu_noop = AlbuNoOp()

# Un-augmented view of the training data (no-op transforms, train_mode=False,
# no joint transform), used only to compute the coverage ratios below.
ds_train_stats = PatchSegDataset(
    dtr, config["IMAGE_ROOT"], config["PATCHES_ROOT"], config["MASK_DIR"],
    transform_img=albu_noop, transform_patch=albu_noop, transform_mask=albu_noop,
    fixed_patches_train=config.get("FIXED_PATCHES_TRAIN", 48),
    fixed_patches_val=config.get("FIXED_PATCHES_VAL", None), seed=config["SEED"], train_mode=False, joint_tf=None,
)
# presumably the per-class fraction of positive mask pixels - verify against
# compute_cov_ratios
cov_ratios = compute_cov_ratios(ds_train_stats, thr=0.5)


# Positive-class weights derived from the coverage ratios, capped at POSW_CAP.
POSW_CAP = float(config.get("POSW_CAP", 12.0))
pos_w = build_pos_weight(cov_ratios, cap=POSW_CAP, device=device)

# Class weights derived from image-level prevalence, capped at CLASSW_CAP.
CLASSW_CAP = float(config.get("CLASSW_CAP", 4.0))
CLASSW_POWER = float(config.get("CLASSW_POWER", 1.0))
class_w = build_class_weights_from_image_prev(img_prev, power=CLASSW_POWER, cap=CLASSW_CAP, device=device)

print("\n[weights] img_prev     :", [f"{x:.4f}" for x in img_prev])
print("[weights] cov_ratios   :", [f"{x:.8f}" for x in cov_ratios])
print("[weights] pos_weight   :", [f"{x:.2f}" for x in pos_w.tolist()])
print("[weights] class_weight :", [f"{x:.2f}" for x in class_w.tolist()])

# run_training reads POS_WEIGHT and CLASS_WEIGHT from config and raises if
# they are missing, so this cell must run before training.
config["COV_RATIOS"]   = cov_ratios
config["POS_WEIGHT"]   = pos_w.tolist()
config["CLASS_WEIGHT"] = class_w.tolist()




[weights] img_prev     : ['0.2112', '0.2958', '0.2342', '0.0337']
[weights] cov_ratios   : ['0.00331468', '0.01099265', '0.00756576', '0.00165329']
[weights] pos_weight   : ['12.00', '12.00', '12.00', '12.00']
[weights] class_weight : ['1.04', '0.88', '1.04', '1.04']


## Main Loop

In [None]:
# Train from scratch (no resume checkpoint); run_training returns the model
# and the history dict it filled in.
model, HISTORY = run_training(
    patch_df,
    loader_train=loader_train,
    loader_val=loader_val,
    epochs=config["EPOCHS"],
    resume_checkpoint=None,
)

# Write the metrics CSV/JSON locally and copy them to the Drive metrics folder.
save_metrics_to_drive(
    HISTORY,
    config,
    class_names=config.get("CLASS_NAMES")
)

Downloading: "https://download.pytorch.org/models/resnet34-b627a593.pth" to /root/.cache/torch/hub/checkpoints/resnet34-b627a593.pth


100%|██████████| 83.3M/83.3M [00:00<00:00, 198MB/s]


LR after epoch 1: 0.000061
saved new best model at epoch 1 with Dice = 0.0663 | IoU = 0.0385
[Epoch 1/25] train_loss=1.5084 | train_dice=0.1211  ||  val_loss=0.9869 | val_dice=0.0663 | val_iou=0.0385
  Train Dice per class: ['0.0253', '0.1122', '0.1319', '0.2149']
  Val   Dice per class: ['0.0059', '0.1421', '0.0647', '0.0525']




LR after epoch 2: 0.000100
saved new best model at epoch 2 with Dice = 0.1201 | IoU = 0.0737
[Epoch 2/25] train_loss=1.3795 | train_dice=0.1763  ||  val_loss=0.9520 | val_dice=0.1201 | val_iou=0.0737
  Train Dice per class: ['0.0504', '0.3422', '0.0713', '0.2411']
  Val   Dice per class: ['0.0173', '0.2129', '0.1591', '0.0911']




LR after epoch 3: 0.000100
saved new best model at epoch 3 with Dice = 0.1728 | IoU = 0.1138
[Epoch 3/25] train_loss=1.2809 | train_dice=0.2453  ||  val_loss=0.9149 | val_dice=0.1728 | val_iou=0.1138
  Train Dice per class: ['0.1408', '0.2380', '0.4729', '0.1296']
  Val   Dice per class: ['0.0531', '0.2737', '0.2046', '0.1599']




LR after epoch 4: 0.000100
saved new best model at epoch 4 with Dice = 0.1899 | IoU = 0.1262
[Epoch 4/25] train_loss=1.2310 | train_dice=0.2940  ||  val_loss=0.8672 | val_dice=0.1899 | val_iou=0.1262
  Train Dice per class: ['0.1092', '0.3941', '0.3014', '0.3715']
  Val   Dice per class: ['0.0727', '0.2969', '0.2290', '0.1611']




LR after epoch 5: 0.000100
saved new best model at epoch 5 with Dice = 0.2172 | IoU = 0.1505
[Epoch 5/25] train_loss=1.1933 | train_dice=0.3499  ||  val_loss=0.8231 | val_dice=0.2172 | val_iou=0.1505
  Train Dice per class: ['0.1368', '0.2253', '0.8374', '0.2000']
  Val   Dice per class: ['0.0857', '0.3109', '0.2719', '0.2003']




[Tuner] New per-class thresholds: [0.55, 0.55, 0.55, 0.55]
[Tuner] Dice per class at tuned thresholds: ['0.1236', '0.3621', '0.2958', '0.4802']
LR after epoch 6: 0.000100
saved new best model at epoch 6 with Dice = 0.2564 | IoU = 0.1874
[Epoch 6/25] train_loss=1.1525 | train_dice=0.4415  ||  val_loss=0.7927 | val_dice=0.2564 | val_iou=0.1874
  Train Dice per class: ['0.2396', '0.4073', '0.4415', '0.6776']
  Val   Dice per class: ['0.0948', '0.3321', '0.3108', '0.2877']




LR after epoch 7: 0.000100
saved new best model at epoch 7 with Dice = 0.2828 | IoU = 0.2126
[Epoch 7/25] train_loss=1.1259 | train_dice=0.3356  ||  val_loss=0.7741 | val_dice=0.2828 | val_iou=0.2126
  Train Dice per class: ['0.0948', '0.4473', '0.2358', '0.5645']
  Val   Dice per class: ['0.1013', '0.3603', '0.3415', '0.3281']




LR after epoch 8: 0.000100
saved new best model at epoch 8 with Dice = 0.2945 | IoU = 0.2239
[Epoch 8/25] train_loss=1.0984 | train_dice=0.3343  ||  val_loss=0.7633 | val_dice=0.2945 | val_iou=0.2239
  Train Dice per class: ['0.0233', '0.4002', '0.3952', '0.5186']
  Val   Dice per class: ['0.1141', '0.3622', '0.3657', '0.3360']




[Tuner] New per-class thresholds: [0.55, 0.55, 0.55, 0.55]
[Tuner] Dice per class at tuned thresholds: ['0.1293', '0.3327', '0.3646', '0.3319']
LR after epoch 9: 0.000100
saved new best model at epoch 9 with Dice = 0.3059 | IoU = 0.2352
[Epoch 9/25] train_loss=1.0827 | train_dice=0.3934  ||  val_loss=0.7580 | val_dice=0.3059 | val_iou=0.2352
  Train Dice per class: ['0.3250', '0.3166', '0.3497', '0.5825']
  Val   Dice per class: ['0.1173', '0.3652', '0.3798', '0.3612']




LR after epoch 10: 0.000100
[Epoch 10/25] train_loss=1.0781 | train_dice=0.2583  ||  val_loss=0.7539 | val_dice=0.3041 | val_iou=0.2330
  Train Dice per class: ['0.3011', '0.3164', '0.2520', '0.1635']
  Val   Dice per class: ['0.1206', '0.3619', '0.3783', '0.3556']




LR after epoch 11: 0.000100
saved new best model at epoch 11 with Dice = 0.3088 | IoU = 0.2375
[Epoch 11/25] train_loss=1.0382 | train_dice=0.2623  ||  val_loss=0.7517 | val_dice=0.3088 | val_iou=0.2375
  Train Dice per class: ['0.0214', '0.4979', '0.2719', '0.2579']
  Val   Dice per class: ['0.1262', '0.3648', '0.3831', '0.3611']




[Tuner] New per-class thresholds: [0.45, 0.55, 0.55, 0.55]
[Tuner] Dice per class at tuned thresholds: ['0.1396', '0.3883', '0.4101', '0.4448']
LR after epoch 12: 0.000100
saved new best model at epoch 12 with Dice = 0.3139 | IoU = 0.2428
[Epoch 12/25] train_loss=1.0427 | train_dice=0.3006  ||  val_loss=0.7518 | val_dice=0.3139 | val_iou=0.2428
  Train Dice per class: ['0.1150', '0.2845', '0.2027', '0.6000']
  Val   Dice per class: ['0.1312', '0.3699', '0.3843', '0.3701']




LR after epoch 13: 0.000100
saved new best model at epoch 13 with Dice = 0.3193 | IoU = 0.2483
[Epoch 13/25] train_loss=1.0159 | train_dice=0.3645  ||  val_loss=0.7521 | val_dice=0.3193 | val_iou=0.2483
  Train Dice per class: ['0.2862', '0.3855', '0.3357', '0.4505']
  Val   Dice per class: ['0.1333', '0.3827', '0.3781', '0.3831']




LR after epoch 14: 0.000100
saved new best model at epoch 14 with Dice = 0.3199 | IoU = 0.2491
[Epoch 14/25] train_loss=1.0140 | train_dice=0.3667  ||  val_loss=0.7539 | val_dice=0.3199 | val_iou=0.2491
  Train Dice per class: ['0.0568', '0.4373', '0.6175', '0.3553']
  Val   Dice per class: ['0.1347', '0.3804', '0.3823', '0.3824']




[Tuner] New per-class thresholds: [0.35, 0.55, 0.55, 0.55]
[Tuner] Dice per class at tuned thresholds: ['0.1326', '0.3440', '0.3889', '0.4554']
LR after epoch 15: 0.000100
saved new best model at epoch 15 with Dice = 0.3298 | IoU = 0.2589
[Epoch 15/25] train_loss=1.0030 | train_dice=0.3417  ||  val_loss=0.7550 | val_dice=0.3298 | val_iou=0.2589
  Train Dice per class: ['0.1907', '0.3779', '0.3436', '0.4547']
  Val   Dice per class: ['0.1397', '0.3875', '0.3877', '0.4043']




LR after epoch 16: 0.000100
saved new best model at epoch 16 with Dice = 0.3364 | IoU = 0.2658
[Epoch 16/25] train_loss=0.9744 | train_dice=0.2976  ||  val_loss=0.7570 | val_dice=0.3364 | val_iou=0.2658
  Train Dice per class: ['0.1715', '0.3549', '0.2212', '0.4429']
  Val   Dice per class: ['0.1443', '0.3958', '0.3863', '0.4191']




LR after epoch 17: 0.000100
saved new best model at epoch 17 with Dice = 0.3403 | IoU = 0.2701
[Epoch 17/25] train_loss=0.9675 | train_dice=0.3665  ||  val_loss=0.7580 | val_dice=0.3403 | val_iou=0.2701
  Train Dice per class: ['0.1389', '0.5071', '0.6140', '0.2059']
  Val   Dice per class: ['0.1460', '0.3944', '0.3922', '0.4287']




[Tuner] New per-class thresholds: [0.55, 0.15, 0.55, 0.55]
[Tuner] Dice per class at tuned thresholds: ['0.1427', '0.4043', '0.3745', '0.5111']
LR after epoch 18: 0.000100
saved new best model at epoch 18 with Dice = 0.3554 | IoU = 0.2856
[Epoch 18/25] train_loss=0.9511 | train_dice=0.5065  ||  val_loss=0.7596 | val_dice=0.3554 | val_iou=0.2856
  Train Dice per class: ['0.3080', '0.2180', '0.7962', '0.7038']
  Val   Dice per class: ['0.1502', '0.4052', '0.3926', '0.4735']




LR after epoch 19: 0.000100
[Epoch 19/25] train_loss=0.9464 | train_dice=0.2245  ||  val_loss=0.7611 | val_dice=0.3519 | val_iou=0.2824
  Train Dice per class: ['0.1023', '0.2344', '0.4383', '0.1230']
  Val   Dice per class: ['0.1510', '0.3890', '0.3917', '0.4759']




LR after epoch 20: 0.000100
[Epoch 20/25] train_loss=0.9320 | train_dice=0.3577  ||  val_loss=0.7616 | val_dice=0.3527 | val_iou=0.2831
  Train Dice per class: ['0.0913', '0.4318', '0.3391', '0.5685']
  Val   Dice per class: ['0.1520', '0.3925', '0.3891', '0.4771']




[Tuner] New per-class thresholds: [0.55, 0.55, 0.55, 0.55]
[Tuner] Dice per class at tuned thresholds: ['0.1511', '0.3884', '0.3718', '0.5378']
LR after epoch 21: 0.000050
[Epoch 21/25] train_loss=0.9241 | train_dice=0.4329  ||  val_loss=0.7615 | val_dice=0.3530 | val_iou=0.2833
  Train Dice per class: ['0.2572', '0.2968', '0.4096', '0.7680']
  Val   Dice per class: ['0.1519', '0.3898', '0.3926', '0.4776']




LR after epoch 22: 0.000050
saved new best model at epoch 22 with Dice = 0.3583 | IoU = 0.2887
[Epoch 22/25] train_loss=0.8933 | train_dice=0.4294  ||  val_loss=0.7615 | val_dice=0.3583 | val_iou=0.2887
  Train Dice per class: ['0.3041', '0.3194', '0.4150', '0.6790']
  Val   Dice per class: ['0.1524', '0.4083', '0.3925', '0.4800']




LR after epoch 23: 0.000050
saved new best model at epoch 23 with Dice = 0.3639 | IoU = 0.2941
[Epoch 23/25] train_loss=0.8688 | train_dice=0.4138  ||  val_loss=0.7633 | val_dice=0.3639 | val_iou=0.2941
  Train Dice per class: ['0.2401', '0.3785', '0.5797', '0.4568']
  Val   Dice per class: ['0.1489', '0.4095', '0.3882', '0.5089']




[Tuner] New per-class thresholds: [0.45, 0.55, 0.55, 0.45]
[Tuner] Dice per class at tuned thresholds: ['0.1518', '0.4209', '0.4205', '0.5650']
LR after epoch 24: 0.000050
saved new best model at epoch 24 with Dice = 0.3680 | IoU = 0.2987
[Epoch 24/25] train_loss=0.8571 | train_dice=0.3742  ||  val_loss=0.7653 | val_dice=0.3680 | val_iou=0.2987
  Train Dice per class: ['0.2025', '0.4973', '0.4796', '0.3174']
  Val   Dice per class: ['0.1479', '0.4133', '0.3925', '0.5183']




LR after epoch 25: 0.000050
[Epoch 25/25] train_loss=0.8693 | train_dice=0.4371  ||  val_loss=0.7667 | val_dice=0.3643 | val_iou=0.2949
  Train Dice per class: ['0.1935', '0.5051', '0.4429', '0.6067']
  Val   Dice per class: ['0.1487', '0.4144', '0.3871', '0.5068']
wrote metrics to: /content/output/metrics_expanded.csv


## Utility Functions `TEST`

In [None]:
@torch.no_grad()
def run_test(model, patch_df, config, device, save_preds=True, save_csv=True):
    """Evaluate ``model`` on the test split and report per-class metrics.

    Every image whose ``image_id`` has at least one row with
    ``split == "test"`` in ``patch_df`` is evaluated. Predictions are
    thresholded per class, small components are removed via ``remove_small``,
    and Dice / IoU / Pearson are computed per image and averaged over images.

    Args:
        model: Network called as ``model(images, patches)``; when it returns
            a tuple only the first element is used.
        patch_df: DataFrame with at least ``split`` and ``image_id`` columns.
        config: Dict providing ``IMAGE_ROOT``, ``PATCHES_ROOT``, ``MASK_DIR``,
            ``SEED`` and ``OUTPUT_DIR``; optionally ``FIXED_PATCHES_TRAIN``,
            ``THRESH_PER_CLASS`` and ``POST_MIN_AREA_PER_CLASS``.
        device: Device the batches are moved to.
        save_preds: When true, write one ``<image_id>_<class>.png`` mask per
            class into ``<OUTPUT_DIR>/predictions_test``.
        save_csv: When true, write ``test_per_image_metrics.csv`` and
            ``test_summary.json`` into ``OUTPUT_DIR``.

    Returns:
        Dict with ``dice_mean``, ``iou_mean``, ``pearson_mean`` (floats) and
        ``dice_per_class``, ``iou_per_class``, ``pearson_per_class`` (lists).

    Raises:
        ValueError: If ``patch_df`` contains no test rows.
        RuntimeError: If the test loader yields no batches.
    """
    model.eval()

    # --- dataset and loader: use ALL patches for test ---
    test_ids = patch_df[patch_df["split"] == "test"]["image_id"].unique().tolist()
    if not test_ids:
        # Fail early with a clear message instead of dividing None by a count later.
        raise ValueError('patch_df has no rows with split == "test"; nothing to evaluate')
    dte = patch_df[patch_df["image_id"].isin(test_ids)]

    ds_test = PatchSegDataset(
        dte,
        image_dir=config["IMAGE_ROOT"],
        patches_root=config["PATCHES_ROOT"],
        mask_dir=config["MASK_DIR"],
        transform_img=None,
        transform_patch=transform_patch,
        transform_mask=None,
        fixed_patches_train=config.get("FIXED_PATCHES_TRAIN", 48),
        fixed_patches_val=None,
        seed=config["SEED"],
        train_mode=False,
        joint_tf=albu_val
    )
    loader_test = build_loader(ds_test, shuffle=False)

    class_names = ["MA", "HE", "EX", "SE"]
    thr_cfg = config.get("THRESH_PER_CLASS", [0.5, 0.5, 0.5, 0.5])
    thr = prepare_threshold(thr_cfg, device=device)
    # The thresholds never change during the run, so reshape them once.
    thr_exp = thr.view(1, -1, 1, 1) if thr.ndim == 1 else thr
    per_cls_min = config.get("POST_MIN_AREA_PER_CLASS", [0, 20, 10, 20])

    out_dir = config["OUTPUT_DIR"]
    pred_dir = os.path.join(out_dir, "predictions_test")
    os.makedirs(pred_dir, exist_ok=True)

    per_image_rows = []
    dice_sum = None
    iou_sum = None
    r_sum = None
    processed = 0

    for (images, patches, masks) in tqdm(loader_test, desc="Testing", leave=False):
        bsz = images.size(0)
        images  = images.to(device, non_blocking=True)
        patches = patches.to(device, non_blocking=True)
        masks   = masks.to(device, non_blocking=True)

        out = model(images, patches)
        if isinstance(out, tuple):
            out = out[0]

        logits, masks_al = _align_logits_targets(out, masks)
        probs = torch.sigmoid(logits)

        preds_bin = (probs >= thr_exp).float()
        preds_bin = remove_small(preds_bin, per_cls_min, binarize=False)

        # ---- per-image rows and optional saves, loop over the batch ----
        for bi in range(bsz):
            # metrics for this image only (slice keeps the batch dimension)
            dpc = dice_per_class(preds_bin[bi:bi+1], masks_al[bi:bi+1]).detach().cpu()
            ipc = io_per_class(preds_bin[bi:bi+1], masks_al[bi:bi+1]).detach().cpu()
            # Pearson is computed on the probabilities, not the binarised masks.
            rpc = _pearson_per_class(probs[bi:bi+1], masks_al[bi:bi+1]).detach().cpu()

            dice_sum = dpc if dice_sum is None else (dice_sum + dpc)
            iou_sum  = ipc if iou_sum  is None else (iou_sum  + ipc)
            r_sum    = rpc if r_sum    is None else (r_sum + rpc)

            # NOTE(review): assumes ds_test.image_ids follows the order in which
            # the unshuffled loader yields samples - confirm against PatchSegDataset.
            image_id = ds_test.image_ids[processed + bi]

            # Column order: image_id, dice_*, iou_*, pearson_* (one per class).
            row = {"image_id": image_id}
            for prefix, values in (("dice", dpc), ("iou", ipc), ("pearson", rpc)):
                for ci, cls_name in enumerate(class_names):
                    row[f"{prefix}_{cls_name}"] = float(values[ci])
            per_image_rows.append(row)

            if save_preds:
                pm = preds_bin[bi].detach().cpu().numpy()  # (C,H,W)
                for ci, cls_name in enumerate(class_names[:pm.shape[0]]):
                    out_path = os.path.join(pred_dir, f"{image_id}_{cls_name}.png")
                    cv2.imwrite(out_path, (pm[ci] * 255).astype(np.uint8))

        processed += bsz

        # Drop batch tensors before the next iteration to keep peak memory down.
        del images, patches, masks, logits, probs, preds_bin
        torch.cuda.empty_cache()
        gc.collect()

    if processed == 0:
        raise RuntimeError("test loader produced no batches; no metrics to report")

    dice_mean_pc = (dice_sum / processed).numpy().tolist()
    iou_mean_pc  = (iou_sum  / processed).numpy().tolist()
    pearson_mean_pc = (r_sum / processed).numpy().tolist()
    pearson_mean = float(np.mean(pearson_mean_pc))
    dice_mean = float(np.mean(dice_mean_pc))
    iou_mean  = float(np.mean(iou_mean_pc))

    if save_csv:
        per_img_df = pd.DataFrame(per_image_rows)
        per_img_csv = os.path.join(out_dir, "test_per_image_metrics.csv")
        per_img_df.to_csv(per_img_csv, index=False)

        # Persist everything that is printed/returned, including Pearson.
        summary = {
            "dice_mean": dice_mean,
            "iou_mean": iou_mean,
            "pearson_mean": pearson_mean,
            "dice_per_class": dice_mean_pc,
            "iou_per_class": iou_mean_pc,
            "pearson_per_class": pearson_mean_pc,
            "thresholds": thr_cfg,
            "post_min_area_per_class": per_cls_min
        }
        with open(os.path.join(out_dir, "test_summary.json"), "w") as f:
            json.dump(summary, f, indent=2)

    print("\n=== TEST SUMMARY ===")
    print(f"Dice mean: {dice_mean:.4f} | IoU mean: {iou_mean:.4f}")
    print("Dice per class [MA, HE, EX, SE]:", ["%.4f" % x for x in dice_mean_pc])
    print("IoU per class  [MA, HE, EX, SE]:", ["%.4f" % x for x in iou_mean_pc])
    print("Pearson per class [MA, HE, EX, SE]:", ["%.4f" % x for x in pearson_mean_pc])

    return {
        "dice_mean": dice_mean,
        "iou_mean": iou_mean,
        "dice_per_class": dice_mean_pc,
        "iou_per_class": iou_mean_pc,
        "pearson_mean": pearson_mean,
        "pearson_per_class": pearson_mean_pc,
    }

## Testing

In [None]:
# Restore the saved checkpoint (when present) before evaluating on the test split.
ckpt_path = config["DRIVE_MODEL_PATH"]
if os.path.isfile(ckpt_path):
    ckpt = torch.load(ckpt_path, map_location="cpu")
    # Use the "model" entry when the checkpoint has one, otherwise treat the
    # loaded object itself as the state dict.
    load_info = model.load_state_dict(ckpt.get("model", ckpt), strict=False)
    # strict=False ignores key mismatches silently; surface them so a partially
    # loaded model is not mistaken for the trained one.
    if load_info.missing_keys or load_info.unexpected_keys:
        print(
            f"[test] checkpoint key mismatch: "
            f"{len(load_info.missing_keys)} missing, "
            f"{len(load_info.unexpected_keys)} unexpected"
        )
else:
    print(f"[test] WARNING: no checkpoint at {ckpt_path}; evaluating in-memory weights")

test_result = run_test(model, patch_df, config, device, save_preds=True, save_csv=True)

# Paths of the files run_test writes when save_csv=True.
test_per_image_csv = os.path.join(config["OUTPUT_DIR"], "test_per_image_metrics.csv")
test_summary_json  = os.path.join(config["OUTPUT_DIR"], "test_summary.json")

save_metrics_to_drive(
    HISTORY,
    config,
    class_names=config.get("CLASS_NAMES"),
    test_per_image_csv=test_per_image_csv,
    test_summary_json=test_summary_json
)

                                                          


=== TEST SUMMARY ===
Dice mean: 0.3749 | IoU mean: 0.3060
Dice per class [MA, HE, EX, SE]: ['0.1362', '0.4267', '0.4164', '0.5201']
IoU per class  [MA, HE, EX, SE]: ['0.0881', '0.3237', '0.3286', '0.4835']
Pearson per class [MA, HE, EX, SE]: ['0.1509', '0.3786', '0.2981', '0.1454']
[metrics] saved locally:
  /content/output/metrics_expanded.csv
  /content/output/metrics_expanded_20250829_182811.csv
  /content/output/metrics_20250829_182811.json
[drive] copied to: /content/drive/MyDrive/models/metrics
[drive][test] copied:
  /content/drive/MyDrive/models/metrics/test_per_image_metrics.csv
  /content/drive/MyDrive/models/metrics/test_summary.json




## Utility Functions `PREDS`

In [None]:
def denorm_to_uint8(t):
    """Undo ImageNet normalisation and return an HxWx3 ``uint8`` NumPy image.

    Args:
        t: Normalised image in channel-first ``(3, H, W)`` layout, given as a
            ``torch.Tensor`` or anything ``torch.as_tensor`` accepts (e.g. a
            NumPy array).

    Returns:
        ``numpy.ndarray`` of shape ``(H, W, 3)`` and dtype ``uint8``.
    """
    # Normalise the input type first so the tensor-only calls below
    # (clamp / byte / permute) work for array inputs as well.
    t = torch.as_tensor(t).detach().cpu()
    mean = torch.tensor(IMAGENET_MEAN).view(3, 1, 1)
    std = torch.tensor(IMAGENET_STD).view(3, 1, 1)
    # Invert (x - mean) / std, then clip to the displayable [0, 1] range.
    x = (t * std + mean).clamp(0, 1)
    x = (x * 255.0).byte().permute(1, 2, 0).numpy()  # HxWx3 uint8
    return x

@torch.no_grad()
def visualize_predictions(model, dataset, indices=None, device="cpu", threshold=None):
    """Plot image / ground-truth / prediction panels for selected samples.

    One figure row is drawn per index: the de-normalised input image, the
    ground-truth masks and the thresholded predictions. Multi-channel masks are
    rendered as a colour overlay on the image, anything else as a grayscale map.

    Args:
        model: Network called as ``model(image, patch)`` on a batch of one.
        dataset: Indexable dataset yielding ``(image, patch, mask)`` items.
        indices: ``None`` (all items), a single int, or a sequence of ints.
        device: ``torch.device`` or a string accepted by ``torch.device``.
        threshold: Scalar or per-class thresholds; when ``None`` the global
            ``config["THRESH_PER_CLASS"]`` is used (default ``0.5``).
    """
    if not isinstance(device, torch.device):
        device = torch.device(device)
    model.eval()

    if indices is None:
        indices = list(range(len(dataset)))
    elif isinstance(indices, int):
        indices = [indices]

    # Names / colours can be overridden through module-level globals.
    fallback_palette = {"MA": [1,0,0], "HE": [0,1,0], "EX": [0,0,1], "SE": [1,1,0]}
    module_scope = globals()
    names  = module_scope.get("LESION_NAMES", ["MA","HE","EX","SE"])
    colors = module_scope.get("LESION_COLORS", fallback_palette)

    def _class_thresholds(probs):
        # Build a (1, C, 1, 1) threshold tensor matching probs' device and dtype.
        spec = config.get("THRESH_PER_CLASS", 0.5) if threshold is None else threshold
        n_cls = probs.shape[1]
        like = {"device": probs.device, "dtype": probs.dtype}
        if not isinstance(spec, (list, tuple, np.ndarray, torch.Tensor)):
            return torch.full((1, n_cls, 1, 1), float(spec), **like)
        vec = torch.as_tensor(spec, **like)
        if vec.ndim == 0:
            vec = vec.repeat(n_cls)
        elif vec.numel() != n_cls:
            # Length mismatch: fall back to 0.5 for every class.
            vec = torch.full((n_cls,), 0.5, **like)
        return vec.view(1, n_cls, 1, 1)

    def _draw_panel(ax, base, stack, title):
        # Colour overlay for multi-channel stacks, grayscale image otherwise.
        if stack.ndim == 3 and stack.shape[0] > 1:
            overlay = np.zeros_like(base)
            for name, plane in zip(names[:stack.shape[0]], stack):
                rgb = np.array(colors.get(name, [1,1,1]), dtype=base.dtype)
                overlay += plane[..., np.newaxis] * rgb
            overlay = np.clip(overlay, 0, 1)
            ax.imshow(base, alpha=0.6)
            ax.imshow(overlay, alpha=0.4)
        else:
            ax.imshow(stack.squeeze(), cmap="gray")
        ax.set_title(title)
        ax.axis("off")

    n_rows = len(indices)
    _, axes = plt.subplots(n_rows, 3, figsize=(15, 5*n_rows))
    if n_rows == 1:
        # subplots returns a 1-D array for a single row; restore the row axis.
        axes = np.expand_dims(axes, 0)

    use_amp = device.type == 'cuda'
    for row, idx in enumerate(indices):
        image, patch, mask = dataset[idx]
        full_hw = (image.shape[1], image.shape[2])

        image_in = image.unsqueeze(0).to(device, non_blocking=True)
        patch_in = patch.unsqueeze(0).to(device, non_blocking=True)
        with amp.autocast(device_type=device.type, enabled=use_amp):
            out = model(image_in, patch_in)
        logits = out[0] if isinstance(out, (tuple, list)) else out

        logits, mask_bchw = _align_logits_targets(logits, mask.unsqueeze(0).to(device))
        if logits.shape[-2:] != full_hw:
            # Bring both tensors back to the input resolution for display.
            logits = F.interpolate(logits, size=full_hw, mode="bilinear", align_corners=False)
            mask_bchw = F.interpolate(mask_bchw, size=full_hw, mode="nearest")

        probs = torch.sigmoid(logits)
        preds = (probs >= _class_thresholds(probs)).float()

        img_disp = denorm_to_uint8(image) / 255.0
        gt_np = mask_bchw[0].detach().cpu().numpy()
        pred_np = preds[0].detach().cpu().numpy()

        ax_img, ax_gt, ax_pred = axes[row]

        ax_img.imshow(img_disp)
        ax_img.set_title(f"Original idx={idx}")
        ax_img.axis("off")

        _draw_panel(ax_gt, img_disp, gt_np, "Ground Truth")
        _draw_panel(ax_pred, img_disp, pred_np, "Predicted")

    plt.tight_layout()
    plt.show()

## Predictions


In [None]:
# Show up to 15 random validation samples; clamp the sample size because
# random.sample raises ValueError when asked for more items than exist.
n_show = min(15, len(ds_val))
indices = random.sample(range(len(ds_val)), n_show)
visualize_predictions(model, ds_val, indices, device=device,
                      threshold=[0.30, 0.45, 0.45, 0.45])