In [1]:
!pip -q install timm==0.9.2 segmentation-models-pytorch==0.3.3 unzip



[notice] A new release of pip is available: 23.2.1 -> 25.3
[notice] To update, run: python.exe -m pip install --upgrade pip


In [4]:
!file /content/train.zip
!unzip -q /content/train.zip

/content/train.zip: Zip archive data, at least v2.0 to extract, compression method=store


In [1]:
from src.mask_generation import process_all_parallel
from src.detect_bad_masks import get_bad_masks_parallel,regenerate_missing


train_path="/content/train"

process_all_parallel(train_path,workers=12)
get_bad_masks_parallel(train_path,delete=True,workers=12)
regenerate_missing(train_path)

Processing ECGs: 100%|██████████| 977/977 [06:34<00:00,  2.48ecg/s]

DONE | ok=977 bad=0 total=977



Scanning npz files: 100%|██████████| 977/977 [00:13<00:00, 74.45it/s]


Total npz: 977
Bad npz: 0
The corrupted masks will be deleted? True
Missing masks: 0


Regenerating: 0it [00:00, ?it/s]


In [1]:
import os, glob, random
import numpy as np
from PIL import Image
import cv2
import timm, segmentation_models_pytorch as smp

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

import segmentation_models_pytorch as smp

SEED = 42
random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED)

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

DATA_ROOT = "/content/train"   # <-- change
# res of source images  is (1700,2200) , to use Unet/Imagenet must be multiple of 32
#H_T, W_T = 1696, 2176
H_T, W_T = 864,1120
K = 13

BATCH_SIZE = 32                      # try 4; if OOM, drop to 2 or 1
NUM_WORKERS = 8
LR = 3e-4
EPOCHS = 20                          # debug run first


In [2]:

class ECGSegDataset(Dataset):
    def __init__(self, folders, H=H_T, W=W_T):
        self.folders = folders
        self.H, self.W = H, W

    def __len__(self):
        return len(self.folders)

    def __getitem__(self, idx):
        d = self.folders[idx]
        png_path = sorted(glob.glob(os.path.join(d, "*0001.png")))[0]
        npz_path = sorted(glob.glob(os.path.join(d, "mask-*.npz")))[0]

        # image
        img = np.array(Image.open(png_path).convert("L"), dtype=np.float32) / 255.0  # (H0,W0)
        img_r = cv2.resize(img, (self.W, self.H), interpolation=cv2.INTER_AREA)      # (H,W)
        x = torch.from_numpy(img_r[None, ...]).float()                               # (1,H,W)

        # masks
        z = np.load(npz_path, allow_pickle=True)
        masks = z["masks"]  # expected (H0,W0,13) uint8

        if masks.ndim != 3:
            raise ValueError(f"Unexpected masks.ndim={masks.ndim} in {npz_path}")
        if masks.shape[-1] != K:
            raise ValueError(f"Expected last dim {K}, got {masks.shape} in {npz_path}")

        masks_r = np.zeros((self.H, self.W, K), dtype=np.uint8)
        for k in range(K):
            masks_r[..., k] = cv2.resize(masks[..., k], (self.W, self.H), interpolation=cv2.INTER_NEAREST)

        y = torch.from_numpy(np.transpose(masks_r, (2,0,1))).float()  # (13,H,W)
        return x, y



In [3]:
folders = sorted([p for p in glob.glob(os.path.join(DATA_ROOT, "*")) if os.path.isdir(p)])
random.shuffle(folders)

val_n = max(200, int(0.1 * len(folders)))
val_f = folders[:val_n]
tr_f  = folders[val_n:]

train_loader = DataLoader(ECGSegDataset(tr_f), batch_size=BATCH_SIZE, shuffle=True,
                          num_workers=NUM_WORKERS, pin_memory=True)
val_loader   = DataLoader(ECGSegDataset(val_f), batch_size=1, shuffle=False,
                          num_workers=NUM_WORKERS, pin_memory=True)

len(tr_f), len(val_f)


(777, 200)

In [4]:
model = smp.Unet(
    encoder_name="resnet34",
    encoder_weights="imagenet",
    in_channels=1,
    classes=K,
).to(DEVICE)


In [5]:
bce = nn.BCEWithLogitsLoss()

def dice_loss(logits, y, eps=1e-6):
    p = torch.sigmoid(logits)
    num = 2*(p*y).sum((2,3))
    den = (p+y).sum((2,3)) + eps
    return (1 - num/den).mean()

def loss_fn(logits, y):
    return bce(logits, y) + dice_loss(logits, y)

opt = torch.optim.AdamW(model.parameters(), lr=LR)
scaler = torch.cuda.amp.GradScaler(enabled=(DEVICE=="cuda"))


  scaler = torch.cuda.amp.GradScaler(enabled=(DEVICE=="cuda"))


In [6]:
from tqdm.auto import tqdm

def run_epoch(loader, train=True, log_every=50):
    model.train(train)
    tot, n = 0.0, 0

    pbar = tqdm(loader, desc="train" if train else "val", leave=False)
    for step, (x, y) in enumerate(pbar, start=1):
        x = x.to(DEVICE, non_blocking=True)
        y = y.to(DEVICE, non_blocking=True)

        with torch.cuda.amp.autocast(enabled=(DEVICE=="cuda")):
            logits = model(x)
            loss = loss_fn(logits, y)

        if train:
            opt.zero_grad(set_to_none=True)
            scaler.scale(loss).backward()
            scaler.step(opt)
            scaler.update()

        tot += loss.item()
        n += 1
        avg = tot / n

        # Update progress bar postfix every step (lightweight)
        pbar.set_postfix(loss=f"{loss.item():.4f}", avg=f"{avg:.4f}")

        # Optional: also print a line every log_every steps
        if log_every and (step % log_every == 0):
            tqdm.write(f"[{'train' if train else 'val'}] step {step}/{len(loader)}  loss={loss.item():.4f}  avg={avg:.4f}")

    return tot / max(n, 1)




In [7]:
from torch.utils.data import DataLoader

debug_loader = DataLoader(
    val_loader.dataset,   # reuse the same dataset
    batch_size=1,
    shuffle=False,
    num_workers=0,         # IMPORTANT: avoid worker/process issues
    pin_memory=False
)

debug_x, debug_y = next(iter(debug_loader))
debug_x = debug_x.contiguous()
debug_y = debug_y.contiguous()


In [None]:
from src.train import save_pred_debug_epoch
DEBUG_DIR = "debug_train_evolution"

best_val = float("inf")
for ep in range(1, EPOCHS):
    print(f"\n=== Epoch {ep}/{EPOCHS} ===")
    tr = run_epoch(train_loader, train=True, log_every=50)
    va = run_epoch(val_loader, train=False, log_every=0)

    print(f"epoch {ep}/{EPOCHS} | train {tr:.4f} | val {va:.4f}")


    # ✅ Save debug prediction snapshot each epoch (cheap)
    snap = save_pred_debug_epoch(
        model,
        debug_x, debug_y,
        epoch=ep,
        out_dir=DEBUG_DIR,
        ch_idxs=(1, 12),
        thr=0.3,
        device=DEVICE
    )
    print("saved debug snapshot:", snap)

    if va < best_val:
        best_val = va
        torch.save({"model": model.state_dict(), "val_loss": va},
                   "best_unet_resnet34_halfres.pt")
        print("saved best checkpoint")



=== Epoch 1/20 ===


train:   0%|          | 0/49 [00:00<?, ?it/s]

  with torch.cuda.amp.autocast(enabled=(DEVICE=="cuda")):


val:   0%|          | 0/200 [00:00<?, ?it/s]

epoch 1/20 | train 1.4684 | val 1.3224
saved debug snapshot: debug_train_evolution/epoch_001.png
saved best checkpoint

=== Epoch 2/20 ===


train:   0%|          | 0/49 [00:00<?, ?it/s]

val:   0%|          | 0/200 [00:00<?, ?it/s]

epoch 2/20 | train 1.2184 | val 1.1449
saved debug snapshot: debug_train_evolution/epoch_002.png
saved best checkpoint

=== Epoch 3/20 ===


train:   0%|          | 0/49 [00:00<?, ?it/s]

val:   0%|          | 0/200 [00:00<?, ?it/s]

epoch 3/20 | train 1.0970 | val 1.0619
saved debug snapshot: debug_train_evolution/epoch_003.png
saved best checkpoint

=== Epoch 4/20 ===


train:   0%|          | 0/49 [00:00<?, ?it/s]

val:   0%|          | 0/200 [00:00<?, ?it/s]

epoch 4/20 | train 1.0288 | val 1.0023
saved debug snapshot: debug_train_evolution/epoch_004.png
saved best checkpoint

=== Epoch 5/20 ===


train:   0%|          | 0/49 [00:00<?, ?it/s]

val:   0%|          | 0/200 [00:00<?, ?it/s]

epoch 5/20 | train 0.9703 | val 0.9435
saved debug snapshot: debug_train_evolution/epoch_005.png
saved best checkpoint

=== Epoch 6/20 ===


train:   0%|          | 0/49 [00:00<?, ?it/s]

val:   0%|          | 0/200 [00:00<?, ?it/s]

epoch 6/20 | train 0.8938 | val 0.8592
saved debug snapshot: debug_train_evolution/epoch_006.png
saved best checkpoint

=== Epoch 7/20 ===


train:   0%|          | 0/49 [00:00<?, ?it/s]

val:   0%|          | 0/200 [00:00<?, ?it/s]

epoch 7/20 | train 0.7706 | val 0.7153
saved debug snapshot: debug_train_evolution/epoch_007.png
saved best checkpoint

=== Epoch 8/20 ===


train:   0%|          | 0/49 [00:00<?, ?it/s]

val:   0%|          | 0/200 [00:00<?, ?it/s]

epoch 8/20 | train 0.6071 | val 0.5407
saved debug snapshot: debug_train_evolution/epoch_008.png
saved best checkpoint

=== Epoch 9/20 ===


train:   0%|          | 0/49 [00:00<?, ?it/s]

val:   0%|          | 0/200 [00:00<?, ?it/s]

epoch 9/20 | train 0.4451 | val 0.3919
saved debug snapshot: debug_train_evolution/epoch_009.png
saved best checkpoint

=== Epoch 10/20 ===


train:   0%|          | 0/49 [00:00<?, ?it/s]

val:   0%|          | 0/200 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7a12039f2e80>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1654, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1637, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^
  File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: can only test a child process


epoch 10/20 | train 0.3643 | val 0.6022
saved debug snapshot: debug_train_evolution/epoch_010.png

=== Epoch 11/20 ===


train:   0%|          | 0/49 [00:00<?, ?it/s]

val:   0%|          | 0/200 [00:00<?, ?it/s]

epoch 11/20 | train 0.2834 | val 0.2603
saved debug snapshot: debug_train_evolution/epoch_011.png
saved best checkpoint

=== Epoch 12/20 ===


train:   0%|          | 0/49 [00:00<?, ?it/s]

val:   0%|          | 0/200 [00:00<?, ?it/s]

epoch 12/20 | train 0.2374 | val 0.2240


In [None]:
import os, glob
import numpy as np
from PIL import Image
import cv2
import torch

def save_pred_overlays(val_f, out_dir="/content/overlays", n=50, thr=0.5):
    os.makedirs(out_dir, exist_ok=True)
    model.eval()

    n = min(n, len(val_f))
    for i in range(n):
        d = val_f[i]
        ecg_id = os.path.basename(d)

        png_path = sorted(glob.glob(os.path.join(d, "*0001.png")))[0]
        npz_path = sorted(glob.glob(os.path.join(d, "mask-*.npz")))[0]

        img0 = np.array(Image.open(png_path).convert("L"), dtype=np.uint8)
        img = cv2.resize(img0.astype(np.float32)/255.0, (W_T, H_T), interpolation=cv2.INTER_AREA)
        x = torch.from_numpy(img[None,None,...]).float().to(DEVICE)

        z = np.load(npz_path, allow_pickle=True)
        gt0 = z["masks"]
        gt = np.zeros((H_T, W_T, K), dtype=np.uint8)
        for k in range(K):
            gt[..., k] = cv2.resize(gt0[..., k], (W_T, H_T), interpolation=cv2.INTER_NEAREST)

        with torch.no_grad():
            logits = model(x)[0].detach().cpu().numpy()  # (K,H,W)
            prob = 1/(1+np.exp(-logits))

        pred_union = (prob.max(axis=0) > thr).astype(np.uint8)
        gt_union   = (gt.max(axis=2) > 0).astype(np.uint8)

        base = (img*255).astype(np.uint8)
        base_rgb = np.stack([base]*3, axis=-1)

        # Pred overlay in red
        out_pred = base_rgb.copy()
        m = pred_union.astype(bool)
        out_pred[m] = (0.5*out_pred[m] + 0.5*np.array([255,0,0])).astype(np.uint8)

        # GT overlay in green (optional but very useful)
        out_gt = base_rgb.copy()
        g = gt_union.astype(bool)
        out_gt[g] = (0.5*out_gt[g] + 0.5*np.array([0,255,0])).astype(np.uint8)

        Image.fromarray(out_pred).save(os.path.join(out_dir, f"{i:04d}_{ecg_id}_pred.png"))
        Image.fromarray(out_gt).save(os.path.join(out_dir, f"{i:04d}_{ecg_id}_gt.png"))

    return out_dir

out_dir = save_pred_overlays(val_f, out_dir="/content/overlays", n=1, thr=0.1)
print("Saved to:", out_dir)
