
# Model to mask Vastus Lateralis

1) Provide input video file 
2) Let the model run 
3) Once it has finished the predicted video will be in the same folder as the input 




## How to Use 
1. **Install dependencies** run the next cell    
2. **Set hyperparameters** choose to keep defaults or change the parameters.  
3. **Provide your input video file** insert the path to your video. 
4. **Run the rest of the cells** click on "execute cell and below"
5. **Outputs are saved alongside the input file**


## Step 1: Dependencies

In [1]:

# Install libs (run once)
%pip install -q numpy tifffile pynrrd ipywidgets torch torchvision torchaudio transformers pytorch_lightning pillow albumentations


Note: you may need to restart the kernel to use updated packages.


In [2]:

import os
from pathlib import Path
from typing import Dict, Any, Optional, Tuple, List

import numpy as np
import tifffile as tiff
import nrrd

import torch
import torch.nn.functional as F
import albumentations as A


from transformers import SegformerForSemanticSegmentation, SegformerImageProcessor, SegformerFeatureExtractor

from ipywidgets import VBox, HBox, Text, Button, Checkbox, Dropdown, IntText, FloatText, HTML, Accordion, Layout
from IPython.display import display

def info(msg): print(f"[INFO] {msg}")
def warn(msg): print(f"[WARN] {msg}")
def ok(msg):   print(f"[OK]   {msg}")



A module that was compiled using NumPy 1.x cannot be run in
NumPy 2.2.6 as it may crash. To support both 1.x and 2.x
versions of NumPy, modules must be compiled with NumPy 2.0.
Some module may need to rebuild instead e.g. with 'pybind11>=2.12'.

If you are a user of the module, the easiest solution will be to
downgrade to 'numpy<2' or try to upgrade the affected module.
We expect that some modules will need time to support NumPy 2.

Traceback (most recent call last):  File "c:\ProgramData\anaconda3\envs\segformer\lib\runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "c:\ProgramData\anaconda3\envs\segformer\lib\runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "C:\Users\DSBG-Public\AppData\Roaming\Python\Python310\site-packages\ipykernel_launcher.py", line 18, in <module>
    app.launch_new_instance()
  File "C:\Users\DSBG-Public\AppData\Roaming\Python\Python310\site-packages\traitlets\config\application.py", line 1075, 

## Step 2: Hyperparameters

In [3]:

HP: Dict[str, Any] = {
    # Converter + IO
    "use_internal_nrrd_to_tif": True,
    "tif_subdir": "tif_slices",
    "mask_subdir": "pred_masks",
    "overwrite_outputs": True,
    "save_compressed_nrrd": True,
    "make_single_nrrd": True,
    "also_save_separate": False,

    # Model
    "inference_backend": "hf",   # "hf" (Hugging Face) or "pl" (PyTorch Lightning)
    "hf_model_dir": "C:/Users/DSBG-Public/segformer/segformer-1/segformer_export_hf",          # e.g., a local dir with config+weights OR a HF repo id
    "pl_ckpt_path": "",          # e.g., /path/to/lightning/checkpoints/best.ckpt
    "target_class_idx": 1,       # used if model has multiple classes; this is the positive class to extract
    "threshold": 0.5,            # used if model produces single-channel logits (sigmoid) or after softmax prob for target_class
    "batch_size": 4,
    "resize_to": 512,            # if >0, resize (H,W) to this square size for model; output is resized back to original

    # Preprocessing
    "normalize": True,           # simple 0-1 normalization
}
# Quick widget
w_backend  = Dropdown(options=["hf","pl"], value=HP["inference_backend"], description="backend")
w_hf_dir   = Text(value=HP["hf_model_dir"], description="hf_model_dir", layout=Layout(width="60%"))
w_pl_ckpt  = Text(value=HP["pl_ckpt_path"], description="pl_ckpt_path", layout=Layout(width="60%"))
w_tclass   = IntText(value=HP["target_class_idx"], description="target_class_idx")
w_thr      = FloatText(value=HP["threshold"], description="threshold")
w_bs       = IntText(value=HP["batch_size"], description="batch_size")
w_resize   = IntText(value=HP["resize_to"], description="resize_to")
w_overwr   = Checkbox(value=HP["overwrite_outputs"], description="overwrite_outputs")
w_single   = Checkbox(value=HP["make_single_nrrd"], description="make_single_nrrd")
w_sep      = Checkbox(value=HP["also_save_separate"], description="also_save_separate")

def _apply(_):
    HP.update({
        "inference_backend": w_backend.value,
        "hf_model_dir": w_hf_dir.value.strip(),
        "pl_ckpt_path": w_pl_ckpt.value.strip(),
        "target_class_idx": int(w_tclass.value),
        "threshold": float(w_thr.value),
        "batch_size": int(w_bs.value),
        "resize_to": int(w_resize.value),
        "overwrite_outputs": w_overwr.value,
        "make_single_nrrd": w_single.value,
        "also_save_separate": w_sep.value,
        "max_epochs": 50,          
        "aug_preset": "basic",     
        "aug_count": 2,            
        "aug_seed": 42,            
    })
    ok("Hyperparameters updated.")
btn = Button(description="Apply", button_style="primary")
btn.on_click(_apply)
display(Accordion(children=[
    VBox([
        HBox([w_backend, w_bs, w_resize]),
        HBox([w_tclass, w_thr]),
        w_hf_dir, w_pl_ckpt,
        HBox([w_overwr, w_single, w_sep]),
        btn
    ])
], titles=("Hyperparameters",)))


Accordion(children=(VBox(children=(HBox(children=(Dropdown(description='backend', options=('hf', 'pl'), value=…

## Step 3: Input file

In [4]:
in_path = Text(description="Input .nrrd", placeholder="/path/to/volume.nrrd", layout=Layout(width="80%"))
btn_val = Button(description="Validate Path", button_style="primary")
lbl_val = HTML()
def _val(_):
    p = Path(in_path.value).expanduser().resolve()
    lbl_val.value = f"<b style='color:green'>OK</b>: {p}" if p.exists() and p.suffix.lower()==".nrrd" else f"<b style='color:#b00'>Invalid</b>: {p}"
btn_val.on_click(_val)
display(VBox([HBox([in_path, btn_val]), lbl_val]))
mask_path = Text(description="Mask .nrrd", placeholder="/path/to/mask.nrrd", layout=Layout(width="80%"))
btn_val_mask = Button(description="Validate Mask", button_style="primary")
lbl_val_mask = HTML()
def _val_mask(_):
    p = Path(mask_path.value).expanduser().resolve()
    lbl_val_mask.value = f"<b style='color:green'>OK</b>: {p}" if p.exists() and p.suffix.lower()==".nrrd" else f"<b style='color:#b00'>Invalid</b>: {p}"
btn_val_mask.on_click(_val_mask)
display(VBox([HBox([mask_path, btn_val_mask]), lbl_val_mask]))

VBox(children=(HBox(children=(Text(value='', description='Input .nrrd', layout=Layout(width='80%'), placeholde…

VBox(children=(HBox(children=(Text(value='', description='Mask .nrrd', layout=Layout(width='80%'), placeholder…

## Step 3.1: Create masks

In [5]:
from PIL import Image

def crop_image_get_bounds(image, threshold):
    rows = np.any(image > threshold, axis=1)
    cols = np.any(image > threshold, axis=0)
    if not np.any(rows) or not np.any(cols):
        return None
    rmin, rmax = np.where(rows)[0][[0, -1]]
    cmin, cmax = np.where(cols)[0][[0, -1]]
    return rmin, rmax, cmin, cmax

def apply_crop(image, bounds):
    if bounds is None:
        return None
    rmin, rmax, cmin, cmax = bounds
    return image[rmin:rmax+1, cmin:cmax+1]

def resize_with_padding(image, target_size=(512, 512)):
    old_h, old_w = image.shape[:2]
    tgt_h, tgt_w = target_size
    delta_w = max(tgt_w - old_w, 0)
    delta_h = max(tgt_h - old_h, 0)
    top, bottom = delta_h // 2, delta_h - (delta_h // 2)
    left, right = delta_w // 2, delta_w - (delta_w // 2)
    return np.pad(image, ((top, bottom), (left, right)), mode='constant', constant_values=0)

def export_to_tif_with_threshold(volume_data, mask_data, out_vol_dir, out_mask_dir, threshold=10, pad_size=1500):
    exported_count = 0
    if volume_data.shape[0] != mask_data.shape[0]:
        # Try to detect if the data is (Y,X,Z)
        if volume_data.ndim == 3 and volume_data.shape[2] == mask_data.shape[2]:
            volume_data = np.transpose(volume_data, (2,0,1))
            mask_data = np.transpose(mask_data, (2,0,1))
    assert volume_data.shape == mask_data.shape, f"Volume/mask shapes differ: {volume_data.shape} vs {mask_data.shape}"
    Z = volume_data.shape[0]

    out_vol_dir.mkdir(parents=True, exist_ok=True)
    out_mask_dir.mkdir(parents=True, exist_ok=True)

    for i in range(Z):
        slice_volume = volume_data[i]
        slice_mask = mask_data[i]

        masked_volume = slice_volume * slice_mask
        if not np.any(masked_volume > threshold):
            continue

        bounds = crop_image_get_bounds(slice_volume, threshold)
        cropped_volume = apply_crop(slice_volume, bounds)
        cropped_mask = apply_crop(slice_mask, bounds)
        if cropped_volume is None or cropped_mask is None:
            continue

        volume_resized = resize_with_padding(cropped_volume, target_size=(pad_size, pad_size))
        mask_resized = resize_with_padding(cropped_mask, target_size=(pad_size, pad_size))

        vpath = out_vol_dir / f"frame_{exported_count:05d}.tif"
        mpath = out_mask_dir / f"frame_{exported_count:05d}.tif"
        Image.fromarray(volume_resized.astype(np.uint16 if volume_resized.max()>255 else np.uint8)).save(vpath)
        Image.fromarray((mask_resized>0).astype(np.uint8)*255).save(mpath)
        exported_count += 1
    return exported_count

def _infer_volume_and_mask(vol):
    # Return (vol3d, mask3d).
    if vol.ndim == 4 and vol.shape[0] == 2:
        img = vol[0]
        msk = vol[1]
        return img, msk
    if vol.ndim == 3:
        return vol, None
    raise ValueError(f"Unsupported NRRD shape: {vol.shape}")

def internal_nrrd_to_tifs_masked(nrrd_vol_path: Path, nrrd_msk_path: Optional[Path], out_dir_images: Path, out_dir_masks: Path) -> int:
    vol_np, _ = nrrd.read(str(nrrd_vol_path))
    vol_np = np.asarray(vol_np)

    msk_np = None
    if nrrd_msk_path and nrrd_msk_path.exists():
        msk_np, _ = nrrd.read(str(nrrd_msk_path))
        msk_np = np.asarray(msk_np)

    # Fallback to sibling files only if mask not provided
    if msk_np is None:
        sibs = [nrrd_vol_path.with_name(nrrd_vol_path.stem + "_masks.nrrd"),
                nrrd_vol_path.with_name(nrrd_vol_path.stem + "_mask.nrrd")]
        for s in sibs:
            if s.exists():
                msk_np, _ = nrrd.read(str(s))
                msk_np = np.asarray(msk_np)
                break

    if msk_np is None:
        raise RuntimeError("Mask .nrrd not found. Provide mask_path or add *_mask(s).nrrd next to the volume.")

    # Ensure 3D arrays; drop channels if present
    if vol_np.ndim == 4 and vol_np.shape[0] <= 4: vol_np = vol_np[0]
    if msk_np.ndim == 4 and msk_np.shape[0] <= 4: msk_np = msk_np[0]

    # If needed, align shapes to (Y,X,Z) expected by user's exporter
    if not HP.get("assume_yxz", True):
        # our previous default was (Z,Y,X) -> transpose to (Y,X,Z)
        if vol_np.ndim == 3 and vol_np.shape[0] == msk_np.shape[0]:
            vol_np = np.transpose(vol_np, (1,2,0))
            msk_np = np.transpose(msk_np, (1,2,0))

    out_dir_images.mkdir(parents=True, exist_ok=True)
    out_dir_masks.mkdir(parents=True, exist_ok=True)

    # --- User's exporter logic on (Y,X,Z) ---
    exported_count = 0
    Y, X, Z = vol_np.shape  # assumes (Y,X,Z)
    for i in range(Z):
        slice_volume = vol_np[:, :, i]
        slice_mask   = msk_np[:, :, i]

        masked_volume = slice_volume * slice_mask
        if not np.any(masked_volume > int(HP.get("crop_threshold", 10))):
            continue

        bounds = crop_image_get_bounds(slice_volume, int(HP.get("crop_threshold", 10)))
        cropped_volume = apply_crop(slice_volume, bounds)
        cropped_mask   = apply_crop(slice_mask,   bounds)
        if cropped_volume is None or cropped_mask is None:
            continue

        pad_sz = int(HP.get("pad_size", 1500))
        volume_resized = resize_with_padding(cropped_volume, target_size=(pad_sz, pad_sz))
        mask_resized   = resize_with_padding(cropped_mask,   target_size=(pad_sz, pad_sz))

        vpath = out_dir_images / f"frame_{exported_count:05d}.tif"
        mpath = out_dir_masks  / f"frame_{exported_count:05d}.tif"
        Image.fromarray(volume_resized.astype(np.uint16 if volume_resized.max()>255 else np.uint8)).save(vpath)
        Image.fromarray((mask_resized>0).astype(np.uint8)*255).save(mpath)
        exported_count += 1

    return exported_count


def do_step3():
    p_vol = Path(in_path.value).expanduser().resolve()
    assert p_vol.exists() and p_vol.suffix.lower()==".nrrd", "Set a valid .nrrd path."

    p_msk = None
    if mask_path.value.strip():
        cand = Path(mask_path.value).expanduser().resolve()
        if cand.exists() and cand.suffix.lower()==".nrrd":
            p_msk = cand
    
    out_root = p_vol.parent
    tif_dir  = out_root / HP["tif_subdir"]
    gt_dir   = out_root / HP["mask_subdir"]
    if not HP["overwrite_outputs"] and (any(tif_dir.glob("*.tif")) or any(gt_dir.glob("*.tif"))):
        warn("Output dirs not empty; skipping due to overwrite_outputs=False")
        return tif_dir
    n = internal_nrrd_to_tifs_masked(p_vol, p_msk, tif_dir, gt_dir)
    ok(f"Wrote {n} TIF pairs to {tif_dir} and {gt_dir}")
    return tif_dir

b3 = Button(description="Run Step 3 (masked+cropped)", button_style="primary")
l3 = HTML()
def _b3(_):
    try:
        out = do_step3()
        l3.value = f"Done: <code>{out}</code>"
    except Exception as e:
        l3.value = f"<span style='color:#b00'>Error: {e}</span>"
b3.on_click(_b3)
display(VBox([b3, l3]))


VBox(children=(Button(button_style='primary', description='Run Step 3 (masked+cropped)', style=ButtonStyle()),…

# Step 3.2: Dataset Augmentation

In [6]:

from pathlib import Path
import random

def _build_augs(preset: str) -> A.Compose:
    if preset == "none":
        return A.Compose([], is_check_shapes=False)
    if preset == "strong":
        return A.Compose([
            A.HorizontalFlip(p=0.5),
            A.VerticalFlip(p=0.2),
            A.RandomRotate90(p=0.5),
            A.ShiftScaleRotate(shift_limit=0.06, scale_limit=0.15, rotate_limit=25, border_mode=0, p=0.7),
            A.ElasticTransform(alpha=20, sigma=4, alpha_affine=10, border_mode=0, p=0.3),
            A.GaussNoise(var_limit=(5.0, 20.0), p=0.3),
            A.RandomBrightnessContrast(p=0.4),
        ], is_check_shapes=False)
    # basic
    return A.Compose([
        A.HorizontalFlip(p=0.5),
        A.RandomRotate90(p=0.3),
        A.ShiftScaleRotate(shift_limit=0.03, scale_limit=0.1, rotate_limit=15, border_mode=0, p=0.5),
        A.RandomBrightnessContrast(p=0.25),
    ], is_check_shapes=False)

def run_augmentation():
    p = Path(in_path.value).expanduser().resolve()
    img_dir = p.parent / HP["tif_subdir"]       
    msk_dir = p.parent / HP["mask_subdir"]      
    assert img_dir.exists() and msk_dir.exists(), "Run Step 3 first to create paired TIFs."

    out_img = p.parent / f"{HP['tif_subdir']}_aug"
    out_msk = p.parent / f"{HP['mask_subdir']}_aug"
    out_img.mkdir(parents=True, exist_ok=True)
    out_msk.mkdir(parents=True, exist_ok=True)

    preset = HP.get("aug_preset", "basic")
    copies = int(HP.get("aug_count", 2))
    seed = int(HP.get("aug_seed", 42))
    random.seed(seed)
    np.random.seed(seed)

    t = _build_augs(preset)
    pairs = sorted(img_dir.glob("frame_*.tif"))
    assert pairs, f"No frames found in {img_dir}"

    written = 0
    for img_path in pairs:
        m_path = (msk_dir / img_path.name)
        if not m_path.exists():
            continue
        img = tiff.imread(img_path)
        msk = tiff.imread(m_path)

        # ensure 2D
        if img.ndim == 3:
            img = img.squeeze()
        if msk.ndim == 3:
            msk = msk.squeeze()

        for k in range(copies):
            aug = t(image=img, mask=msk)
            ai, am = aug["image"], aug["mask"]

            # keep dtype-friendly saves
            ai_save = ai.astype(np.uint16 if ai.max() > 255 else np.uint8)
            am_save = (am > 0).astype(np.uint8) * 255

            stem = img_path.stem  # frame_00000
            tiff.imwrite(out_img / f"{stem}_aug{k:02d}.tif", ai_save)
            tiff.imwrite(out_msk / f"{stem}_aug{k:02d}.tif", am_save)
            written += 1
    ok(f"Augmented pairs written: {written} → {out_img} / {out_msk}")

# UI button
_btn_aug = Button(description="Run Step 3.1: Augment Dataset", button_style="primary")
_lbl_aug = HTML()
def _on_aug(_):
    try:
        run_augmentation()
        _lbl_aug.value = "<b>Augmentation done.</b>"
    except Exception as e:
        _lbl_aug.value = f"<span style='color:#b00'>Error: {e}</span>"
_btn_aug.on_click(_on_aug)
display(VBox([_btn_aug, _lbl_aug]))


VBox(children=(Button(button_style='primary', description='Run Step 3.1: Augment Dataset', style=ButtonStyle()…

## Step 4: Let the model run

In [7]:
# ## Optional — Fine-tune SegFormer with max_epochs
# Requires your Lightning module (e.g., SegformerFinetuner) to be importable.
try:
    import pytorch_lightning as pl
    from torch.utils.data import Dataset, DataLoader
    from PIL import Image

    class PairedTifDataset(Dataset):
        def __init__(self, img_dir: Path, msk_dir: Path, aug: A.Compose|None=None):
            self.imgs = sorted(img_dir.glob("*.tif"))
            self.msk_dir = msk_dir
            self.aug = aug
        def __len__(self): return len(self.imgs)
        def __getitem__(self, i):
            ip = self.imgs[i]
            mp = self.msk_dir / ip.name
            x = tiff.imread(ip).astype(np.float32)
            y = (tiff.imread(mp) > 0).astype(np.uint8)
            if HP.get("normalize", True) and x.max() > 0:
                x = x / x.max()
            if self.aug is not None:
                out = self.aug(image=x, mask=y)
                x, y = out["image"], out["mask"]
            x = torch.from_numpy(x).unsqueeze(0).repeat(3,1,1)  
            y = torch.from_numpy(y).long()
            return x, y

    # Choose which dirs to use 
    _p = Path(in_path.value).expanduser().resolve()
    _img_dir = _p.parent / (f"{HP['tif_subdir']}_aug" if (_p.parent / f"{HP['tif_subdir']}_aug").exists() else HP["tif_subdir"])
    _msk_dir = _p.parent / (f"{HP['mask_subdir']}_aug" if (_p.parent / f"{HP['mask_subdir']}_aug").exists() else HP["mask_subdir"])

    _train_aug = _build_augs(HP.get("aug_preset","basic")) if HP.get("aug_preset","basic")!="none" else None
    ds = PairedTifDataset(_img_dir, _msk_dir, _train_aug)
    dl = DataLoader(ds, batch_size=int(HP.get("batch_size",4)), shuffle=True, num_workers=2)

    
    from SegformerFinetuner import SegformerFinetuner  
    model = SegformerFinetuner()  

    trainer = pl.Trainer(
        max_epochs=int(HP.get("max_epochs", 50)),
        accelerator="gpu" if torch.cuda.is_available() else "auto",
        devices=1 if torch.cuda.is_available() else "auto",
        precision=16 if torch.cuda.is_available() else 32,
        log_every_n_steps=10
    )
    # trainer.fit(model, dl)  # Uncomment if model ready
    info("Trainer prepared. Uncomment trainer.fit(...) to start training.")
except Exception as e:
    warn(f"Fine-tune cell skipped: {e}")


[WARN] Fine-tune cell skipped: num_samples should be a positive integer value, but got num_samples=0


  original_init(self, **validated_kwargs)


In [None]:
from PIL import Image
def _load_hf(model_dir: str):
    processor = None
    # Prefer new API
    try:
        processor = SegformerImageProcessor.from_pretrained(model_dir)
    except Exception:
        processor = SegformerFeatureExtractor.from_pretrained(model_dir)
    model = SegformerForSemanticSegmentation.from_pretrained(model_dir)
    model.eval()
    return model, processor

def _load_pl(ckpt_path: str):
    # Attempt to import SegformerFinetuner from the provided notebook context
    try:
        from importlib import import_module
        from SegformerFinetuner import SegformerFinetuner 
    except Exception as e:
        raise RuntimeError("PyTorch Lightning backend selected, but SegformerFinetuner class is not importable. "
                           "Place its module on PYTHONPATH or switch to 'hf' backend.") from e

def _read_tif_batch(paths: List[Path]) -> Tuple[torch.Tensor, List[Tuple[int,int]]]:
    imgs: List[Any] = []
    shapes: List[Tuple[int,int]] = []
    for p in paths:
        arr = tiff.imread(p)  # typically a NumPy array
        # keep float32 for the net
        try:
            arr = arr.astype("float32", copy=False)  # if it's a NumPy array
        except Exception:
            pass  # not a NumPy array; leave as-is
        shapes.append((int(arr.shape[-2]), int(arr.shape[-1])) if hasattr(arr, "shape") else (None, None))
        imgs.append(arr)

    # Try fast path with NumPy; otherwise fall back to pure-Python tensors
    try:
        import numpy as np  # ensure available name
        ten = torch.from_numpy(np.stack(imgs, 0)).float()  # (B,H,W)
    except Exception:
        # Fallback: avoid torch.from_numpy by converting to lists
        ten = torch.stack([torch.tensor(img.tolist(), dtype=torch.float32) for img in imgs], dim=0)

    ten = ten.unsqueeze(1)  # (B,1,H,W)
    return ten, shapes

def _save_mask_tif(path: Path, tensor: torch.Tensor) -> None:
    t = tensor.detach().to("cpu", dtype=torch.uint8).contiguous()  # (H,W)
    try:
        import numpy as np  # fast path if NumPy is available
        import tifffile as tiff
        tiff.imwrite(path, t.numpy())
        return
    except Exception:
        # NumPy unavailable → fallback to PIL from raw bytes
        h, w = int(t.shape[-2]), int(t.shape[-1])
        raw = bytes(t.view(-1).tolist())  # slower but works without NumPy
        img = Image.frombytes("L", (w, h), raw)
        img.save(path, format="TIFF")


def _resize_if_needed(x: torch.Tensor, size: int) -> torch.Tensor:
    if size and size>0 and (x.shape[-1]!=size or x.shape[-2]!=size):
        x = F.interpolate(x, size=(size,size), mode="bilinear", align_corners=False)
    return x

def predict_segformer(images_dir: Path, masks_out_dir: Path) -> int:
    masks_out_dir.mkdir(parents=True, exist_ok=True)
    slice_paths = sorted(images_dir.glob("*.tif"))
    assert slice_paths, f"No TIFs found in {images_dir}"

    backend = HP.get("inference_backend","hf")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    if backend == "hf":
        model_dir = HP.get("hf_model_dir") or "nvidia/segformer-b0-finetuned-ade-512-512"
        model, processor = _load_hf(model_dir)
        model.to(device)
    else:
        _load_pl(HP.get("pl_ckpt_path",""))
        raise RuntimeError("Lightning backend is stubbed. Provide a HF model dir for immediate use.")

    bs = int(HP.get("batch_size",4))
    resize_to = int(HP.get("resize_to",512))
    target_class = int(HP.get("target_class_idx",1))
    threshold = float(HP.get("threshold",0.5))

    written = 0
    for i in range(0, len(slice_paths), bs):
        batch_paths = slice_paths[i:i+bs]
        x, orig_shapes = _read_tif_batch(batch_paths)  # (B,1,H,W)
        x = _resize_if_needed(x, resize_to)
        x3 = x.repeat(1,3,1,1)  # (B,3,H,W)

        # Use processor to build pixel_values if available
        inputs = {"pixel_values": x3.to(device)}
        with torch.no_grad():
            out = model(**inputs)
            logits = out.logits  # (B,num_labels,h,w)

        # Upsample logits to original (or current) size
        up = F.interpolate(logits, size=x.shape[-2:], mode="bilinear", align_corners=False)  # (B,C,H,W)

        # Convert to masks
        if up.shape[1] == 1:
            prob = torch.sigmoid(up[:,0])
            pred = (prob >= threshold).to(torch.uint8) * 255
        else:
            prob = torch.softmax(up, dim=1)[:, target_class]
            pred = (prob >= threshold).to(torch.uint8) * 255

        # Resize back to original per-slice size if we resized
        if resize_to and resize_to>0:
            out_resized = []
            for j in range(pred.shape[0]):
                pj = pred[j:j+1].float().unsqueeze(0)  # (1,1,H,W)
                pj = F.interpolate(pj, size=orig_shapes[j], mode="nearest")
                out_resized.append(pj[0,0].to(torch.uint8))
            pred = torch.stack(out_resized, 0)

        # Save
        for pth, pmask in zip(batch_paths, pred):
            _save_mask_tif(masks_out_dir / pth.name, pmask)
            written += 1
    return written

# UI
b4 = Button(description="Run Step 4 (SegFormer)", button_style="primary")
l4 = HTML()
def _b4(_):
    try:
        p = Path(in_path.value).expanduser().resolve()
        images_dir = p.parent / "tif_slices"
        n = predict_segformer(images_dir, p.parent / HP["mask_subdir"])
        l4.value = f"Wrote {n} mask TIFs."
    except Exception as e:
        import traceback, sys
        err_type, err_val, err_tb = sys.exc_info()
        tb_str = "".join(traceback.format_exception(err_type, err_val, err_tb))
        print("==== ERROR TRACEBACK ====")
        print(tb_str)
        l4.value = (
            "<b style='color:#b00'>Error during Step 4.</b><br>"
            f"<pre style='white-space:pre-wrap'>{tb_str}</pre>"
        )
b4.on_click(_b4)
display(VBox([b4, l4]))

VBox(children=(Button(button_style='primary', description='Run Step 4 (SegFormer)', style=ButtonStyle()), HTML…

In [9]:

def load_tif_stack(folder: Path) -> np.ndarray:
    files = sorted(folder.glob("*.tif"))
    assert files, f"No .tif files in {folder}"
    return np.stack([tiff.imread(f).astype(np.float32) for f in files], axis=0)

def save_nrrd(path: Path, array: np.ndarray, compressed: bool=True, header: Optional[dict]=None):
    header = dict(header or {})
    if compressed:
        header["encoding"] = "gzip"
    nrrd.write(str(path), array, header=header)

def do_step5():
    p = Path(in_path.value).expanduser().resolve()
    img_dir = p.parent / "tif_slices"
    msk_dir = p.parent / HP["mask_subdir"]
    imgs = load_tif_stack(img_dir)
    msks = load_tif_stack(msk_dir)
    assert imgs.shape == msks.shape, f"Shape mismatch {imgs.shape} vs {msks.shape}"
    if HP["make_single_nrrd"]:
        comb = np.stack([imgs, msks], axis=0)
        out_single = p.parent / f"{p.stem}_imgmask.nrrd"
        if out_single.exists() and not HP["overwrite_outputs"]:
            warn(f"Exists, not overwriting: {out_single}")
        else:
            save_nrrd(out_single, comb, compressed=HP["save_compressed_nrrd"],
                      header={"kinds": ["list","domain","domain","domain"]})
            ok(f"Wrote: {out_single}")
    if HP["also_save_separate"]:
        out_i = p.parent / f"{p.stem}_images.nrrd"
        out_m = p.parent / f"{p.stem}_masks.nrrd"
        save_nrrd(out_i, imgs, compressed=HP["save_compressed_nrrd"])
        save_nrrd(out_m, msks, compressed=HP["save_compressed_nrrd"])
        ok(f"Wrote: {out_i} and {out_m}")

b5 = Button(description="Run Step 5", button_style="primary")
l5 = HTML()
def _b5(_):
    try:
        do_step5()
        l5.value = "Done."
    except Exception as e:
        l5.value = f"<span style='color:#b00'>Error: {e}</span>"
b5.on_click(_b5)
display(VBox([b5, l5]))


VBox(children=(Button(button_style='primary', description='Run Step 5', style=ButtonStyle()), HTML(value='')))