
# Mask Uniformization & Slice-wise GIF Pipeline

This notebook demonstrates **post-processing of fetal ultrasound slices**, including:

- Tail extrapolation of skeletonized masks  
- Uniform mask reconstruction (dilating the skeleton by the estimated mask half-thickness)  
- Overlay visualization of original vs. uniformized masks  
- Slice-wise GIF generation across time frames  
- Batch mode for multiple case folders  

> **Tip:** This notebook is GitHub-friendly. Users only need to set `ROOT_DIR` and `OUTPUT_DIR` to run.



## Expected Folder Structure

**Input:**

```
Root_Dataset/
├── Case001/
│   ├── time001/
│   │   ├── image/         # original ultrasound images
│   │   └── mask/          # raw binary masks (single-channel PNGs: 0/255)
│   ├── time002/
│   │   ├── image/
│   │   └── mask/
│   ...
├── Case002/
│   ├── time001/
│   └── ...
...
```

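**Output:**

```
Processed_Dataset/
├── Case001/
│   ├── time001/
│   │   ├── image/         # copied from original
│   │   └── mask/          # uniformized masks
│   └── ...
...
```

**Intermediate results (written into the input case folders):**

```
Root_Dataset/
├── Case001/
│   ├── time001/
│   │   ├── mask_uniform/      # uniformized masks
│   │   ├── mask_skel/         # skeletons after tail extrapolation
│   │   └── vis_mask_overlay/  # overlays: red = original, green = uniformized, yellow = skeleton
│   ├── Uni_GIF/               # slice-wise GIF animations across time frames
│   └── ...
...
```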


In [None]:

# 📦 Imports
import os, re, shutil
from pathlib import Path

import cv2
import imageio
import numpy as np
from scipy.ndimage import distance_transform_edt, label
from skimage.morphology import skeletonize, disk, dilation, binary_opening


In [None]:

# Utility Functions

def natural_key(s):
    """Sort helper that treats numeric parts as integers.

    ``str(s)`` is split on runs of decimal digits. Because the split pattern
    has one capturing group, the resulting list always alternates
    text, digits, text, ... with text at even positions. Digit runs become
    ``int`` and text is lower-cased, so ``"slice2"`` sorts before
    ``"slice10"``, and keys of different strings always compare
    str-to-str / int-to-int at each position.

    Args:
        s: anything with a meaningful ``str()`` (e.g. a file name or Path).

    Returns:
        A list usable as a ``sorted(..., key=natural_key)`` key.
    """
    # Decide by position rather than by ``str.isdigit()``: characters such as
    # '²' satisfy isdigit() but are not matched by ``\d`` and make int() raise.
    return [int(tok) if idx % 2 else tok.lower()
            for idx, tok in enumerate(re.split(r'(\d+)', str(s)))]

def keep_largest_region(mask: np.ndarray):
    """Return a boolean mask holding only the biggest connected component.

    Components come from ``scipy.ndimage.label`` with its default structuring
    element; on a size tie the component with the lowest label wins. A mask
    without any foreground is handed back unchanged (same object, same dtype).
    """
    components, n_components = label(mask)
    if not n_components:
        return mask
    sizes = np.bincount(components.ravel())
    sizes[0] = 0  # label 0 is background; never select it
    return components == int(np.argmax(sizes))


In [None]:

# Tail Extrapolation for Skeletons

def extrapolate_tail(skel: np.ndarray, tail_frac=0.1):
    """
    Straighten the bottom tail of each half of a skeleton with a fitted line.

    The skeleton pixels are split at their median column into a left and a
    right half. For each half covering rows ``y_min..y_max`` (``span`` rows):

    * a line ``x = a*y + b`` is least-squares fitted to the half's pixels whose
      rows lie in ``[y_max - 3*tail_frac*span, y_max - tail_frac*span]``;
    * that line is rasterized and added to the skeleton from row
      ``y_max - 2*tail_frac*span`` down to ``y_max``.

    The default ``tail_frac=0.1`` corresponds to a 30%..10% fitting window and
    a 20% anchor row. A half is left untouched when it has fewer than 6 pixels,
    spans fewer than 5 rows, or has fewer than 3 pixels in the fitting window.
    Intuition: stabilize noisy endpoints by replacing them with a short
    straight segment estimated just above the tail.

    Args:
        skel: 2-D boolean skeleton image.
        tail_frac: fraction of each half's vertical span that places the
            fitting window and the anchor row.

    Returns:
        The re-skeletonized result; ``skel`` itself is not modified.
    """
    h, w = skel.shape
    pts = np.array(np.nonzero(skel)).T  # (N, 2) as (row, col)
    if len(pts) == 0:
        return skel.copy()
    x_mid = np.median(pts[:, 1])
    sides = [pts[pts[:, 1] <= x_mid], pts[pts[:, 1] > x_mid]]
    out = skel.copy()
    for side_pts in sides:
        if len(side_pts) < 6:
            continue
        rows = side_pts[:, 0]
        y_min, y_max = rows.min(), rows.max()
        span = y_max - y_min
        if span < 5:
            continue
        y_fit_low = y_max - 3 * tail_frac * span
        y_fit_high = y_max - tail_frac * span
        y_anchor = y_max - 2 * tail_frac * span
        fit_pts = side_pts[(rows >= y_fit_low) & (rows <= y_fit_high)]
        if len(fit_pts) < 3:
            continue
        # Regress column on row (x as a function of y).
        y_fit = fit_pts[:, 0].astype(float)
        x_fit = fit_pts[:, 1].astype(float)
        A = np.vstack([y_fit, np.ones_like(y_fit)]).T
        a, b = np.linalg.lstsq(A, x_fit, rcond=None)[0]
        # Clamp to row 0 so a large tail_frac cannot yield negative
        # (wrap-around) row indices.
        y_start = max(0, int(round(y_anchor)))
        y_new = np.arange(y_start, int(y_max) + 1, dtype=int)
        if y_new.size == 0:
            continue
        x_new = np.clip(np.rint(a * y_new + b).astype(int), 0, w - 1)
        out[y_new, x_new] = True
        if abs(a) > 1:
            # The line moves more than one column per row, so one pixel per
            # row leaves horizontal gaps; also rasterize one pixel per column
            # along the same segment to keep it 8-connected.
            x_ends = a * y_new[[0, -1]].astype(float) + b
            x_fill = np.arange(int(np.rint(x_ends.min())), int(np.rint(x_ends.max())) + 1)
            y_fill = np.rint((x_fill - b) / a).astype(int)
            keep = ((x_fill >= 0) & (x_fill < w)
                    & (y_fill >= y_new[0]) & (y_fill <= y_new[-1]))
            out[y_fill[keep], x_fill[keep]] = True
    return skeletonize(out)


In [None]:

# Mask Uniformization from Skeleton

def process_one(mask_bool: np.ndarray, open_r: int = 1):
    """
    Pipeline for a single binary mask:
    1) Keep the largest connected component
    2) Skeletonize it
    3) Straighten/extrapolate the skeleton tails (``extrapolate_tail``)
    4) Dilate the extrapolated skeleton by a disk whose radius is the mean
       distance-transform value over the *original* skeleton (at least 1)
    5) Optional binary opening with ``disk(open_r)`` to smooth
       (skipped when ``open_r <= 0``)

    Args:
        mask_bool: 2-D boolean mask.
        open_r: radius of the smoothing opening.

    Returns:
        (uniform_mask, skeleton) where ``skeleton`` is the extrapolated one.
    """
    mask_bool = keep_largest_region(mask_bool)
    skel_raw = skeletonize(mask_bool)
    skel = extrapolate_tail(skel_raw, tail_frac=0.1)
    dist = distance_transform_edt(mask_bool)
    # Estimate half-thickness on the skeleton *before* extrapolation: the
    # extrapolated segment can leave the mask, where the distance is 0, and
    # would drag the mean (and thus the reconstructed width) down.
    vals = dist[skel_raw]
    half = int(round(vals.mean())) if vals.size > 0 else 1
    half = max(1, half)
    uni = dilation(skel, disk(half))
    if open_r > 0:
        uni = binary_opening(uni, disk(open_r))
    return uni, skel


In [None]:

# Overlay Visualization

def draw_contours(base_img, mask, color):
    """Draw the outer contours of a binary mask onto an image.

    Args:
        base_img: BGR image (as loaded by ``cv2.imread``); ``cv2.drawContours``
            draws into it.
        mask: 2-D array; every non-zero pixel counts as foreground.
        color: BGR tuple for the 1-pixel-wide contour line.

    Returns:
        The image returned by ``cv2.drawContours``.
    """
    # findContours returns (image, contours, hierarchy) on OpenCV 3.x and
    # (contours, hierarchy) on 4.x; taking [-2] works on both.
    contours = cv2.findContours(mask.astype(np.uint8), cv2.RETR_EXTERNAL,
                                cv2.CHAIN_APPROX_SIMPLE)[-2]
    return cv2.drawContours(base_img, contours, -1, color, thickness=1)

def overlay_visuals(time_dir: Path):
    """
    Write contour/skeleton overlays for each image of one time folder.

    The i-th image of ``time_dir/image`` is paired with the i-th PNG of
    ``mask`` (red contour), ``mask_uniform`` (green contour) and
    ``mask_skel`` (yellow pixels); all folders are sorted in the same natural
    filename order. Missing folders, missing indices and unreadable files are
    skipped. Results go to ``time_dir/vis_mask_overlay`` (created if needed)
    under the image's file name.
    """
    img_dir = time_dir / 'image'
    out_dir = time_dir / 'vis_mask_overlay'
    out_dir.mkdir(parents=True, exist_ok=True)

    def _sorted_pngs(folder: Path):
        # Natural order everywhere, so index i refers to the same slice in
        # every folder ("s2" before "s10"); plain sorted() would misalign.
        if not folder.exists():
            return []
        return sorted(folder.glob("*.png"), key=lambda p: natural_key(p.name))

    def _read_bool(paths, i):
        # i-th file as a boolean mask, or None when absent/unreadable.
        if i >= len(paths):
            return None
        m = cv2.imread(str(paths[i]), cv2.IMREAD_GRAYSCALE)
        return None if m is None else m > 0

    images = _sorted_pngs(img_dir)
    orig_masks = _sorted_pngs(time_dir / 'mask')
    new_masks = _sorted_pngs(time_dir / 'mask_uniform')
    skels = _sorted_pngs(time_dir / 'mask_skel')

    for i, img_path in enumerate(images):
        img = cv2.imread(str(img_path))
        if img is None:
            continue
        orig = _read_bool(orig_masks, i)
        if orig is not None:
            img = draw_contours(img, orig, (0, 0, 255))  # red
        new = _read_bool(new_masks, i)
        if new is not None:
            img = draw_contours(img, new, (0, 255, 0))  # green
        skel = _read_bool(skels, i)
        if skel is not None:
            # Paint skeleton pixels, cropped to the overlapping extent.
            h = min(skel.shape[0], img.shape[0])
            w = min(skel.shape[1], img.shape[1])
            region = img[:h, :w]  # view into img
            region[skel[:h, :w]] = (0, 255, 255)  # yellow skeleton pixels
        cv2.imwrite(str(out_dir / img_path.name), img)


In [None]:

# Slice-wise GIF Generation Across Time

def generate_gifs(case_dir: Path, alpha: float = 0.3):
    """
    Build one animated GIF per slice index, stepping through the time folders.

    For every ``time*`` folder (natural order) that has both ``image/`` and
    ``mask_uniform/``, the i-th image and i-th uniform mask (natural filename
    order) are paired. Mask pixels are tinted green, blending the image with
    weight ``1 - alpha`` and pure green with weight ``alpha``. Frame i of every
    time folder is collected into ``case_dir/Uni_GIF/slice{i:04d}.gif``.

    Args:
        case_dir: case folder containing ``time*`` sub-folders.
        alpha: opacity of the green tint. The default 0.3 reproduces the
            0.7/0.3 blend this function has always produced.
    """
    gif_dir = case_dir / 'Uni_GIF'
    gif_dir.mkdir(parents=True, exist_ok=True)
    time_dirs = sorted([d for d in case_dir.iterdir() if d.is_dir() and d.name.lower().startswith("time")],
                       key=lambda p: natural_key(p.name))
    slice_dict = {}
    for tdir in time_dirs:
        img_dir = tdir / 'image'
        mask_dir = tdir / 'mask_uniform'
        if not img_dir.exists() or not mask_dir.exists():
            continue
        img_paths = sorted(img_dir.glob("*.png"), key=lambda p: natural_key(p.name))
        mask_paths = sorted(mask_dir.glob("*.png"), key=lambda p: natural_key(p.name))
        for i, (img_path, mask_path) in enumerate(zip(img_paths, mask_paths)):
            img = cv2.imread(str(img_path))
            mask = cv2.imread(str(mask_path), cv2.IMREAD_GRAYSCALE)
            if img is None or mask is None:
                continue
            overlay = img.copy()
            sel = mask > 0
            if sel.any():  # nothing to tint for an empty mask
                green = np.zeros_like(img[sel])
                green[:, 1] = 255  # BGR green
                overlay[sel] = cv2.addWeighted(img[sel], 1.0 - alpha, green, alpha, 0)
            slice_dict.setdefault(i, []).append(cv2.cvtColor(overlay, cv2.COLOR_BGR2RGB))
    for idx, frames in slice_dict.items():
        gif_path = gif_dir / f'slice{idx:04d}.gif'
        # NOTE(review): recent imageio GIF writers (pillow-based, ~v2.28+)
        # read `duration` in milliseconds, older ones in seconds — confirm the
        # installed version gives the intended 0.2 s per frame.
        imageio.mimsave(gif_path, frames, duration=0.2)


In [None]:

# Main Processing

def _uniformize_time_dir(tdir: Path, open_radius: int, overwrite: bool):
    """Uniformize every mask PNG of one time folder, then write its overlays."""
    mask_dir = tdir / 'mask'
    if not mask_dir.is_dir():
        print(f"Skip {tdir.name}: no mask folder")
        return
    masks = sorted(mask_dir.glob("*.png"), key=lambda p: natural_key(p.name))
    uni_dir = tdir / 'mask_uniform'
    sk_dir = tdir / 'mask_skel'
    if overwrite:
        shutil.rmtree(uni_dir, ignore_errors=True)
        shutil.rmtree(sk_dir, ignore_errors=True)
    uni_dir.mkdir(parents=True, exist_ok=True)
    sk_dir.mkdir(parents=True, exist_ok=True)
    for m_path in masks:
        raw = cv2.imread(str(m_path), cv2.IMREAD_GRAYSCALE)
        if raw is None:
            # cv2.imread returns None (no exception) for unreadable files;
            # `None > 0` would otherwise abort the whole batch.
            print(f"Skip {m_path.name}: unreadable mask")
            continue
        uni, sk = process_one(raw > 0, open_r=open_radius)
        cv2.imwrite(str(uni_dir / m_path.name), uni.astype(np.uint8) * 255)
        cv2.imwrite(str(sk_dir / m_path.name), sk.astype(np.uint8) * 255)
    overlay_visuals(tdir)


def _export_case(time_dirs, dst_case: Path):
    """Rebuild dst_case with each time folder's image/ and mask_uniform/ (as mask/)."""
    shutil.rmtree(dst_case, ignore_errors=True)
    dst_case.mkdir(parents=True, exist_ok=True)
    for tdir in time_dirs:
        image_src = tdir / 'image'
        mask_src = tdir / 'mask_uniform'
        if not image_src.exists() or not mask_src.exists():
            continue
        dst_time = dst_case / tdir.name
        # mask_uniform/ is renamed to mask/ in the output tree.
        for src, sub in ((image_src, 'image'), (mask_src, 'mask')):
            dst = dst_time / sub
            dst.mkdir(parents=True, exist_ok=True)
            for p in src.glob("*.*"):
                if p.is_file():  # "*.*" can also match dotted directories
                    shutil.copy2(p, dst / p.name)


def process_case(case_dir: Path, output_root: Path, open_radius: int = 1, alpha: float = 0.3, overwrite: bool = True):
    """
    Run the full pipeline for one case folder.

    1) For each ``time*`` sub-folder (natural order) that has a ``mask/``
       folder, write uniformized masks to ``mask_uniform/`` and skeletons to
       ``mask_skel/`` inside the input tree, then write overlays.
    2) Build slice-wise GIFs into ``case_dir/Uni_GIF``.
    3) Delete and rebuild ``output_root/<case name>`` with ``image/`` and the
       uniformized masks (as ``mask/``) of every time folder.

    Args:
        case_dir: input case folder containing ``time*`` sub-folders.
        output_root: root folder of the processed dataset.
        open_radius: radius of the smoothing opening (0 disables it).
        alpha: forwarded to ``generate_gifs``.
        overwrite: delete existing ``mask_uniform``/``mask_skel`` folders
            before writing. The output case folder is always rebuilt.

    Raises:
        ValueError: if the output case folder is ``case_dir`` itself or one of
            its ancestors — rebuilding it would delete the input data.
    """
    dst_case = output_root / case_dir.name
    src_resolved = case_dir.resolve()
    if dst_case.resolve() in (src_resolved, *src_resolved.parents):
        raise ValueError(
            f"Output folder {dst_case} would overwrite input case {case_dir}; "
            "choose a different output_root."
        )

    time_dirs = sorted([d for d in case_dir.iterdir() if d.is_dir() and d.name.lower().startswith('time')],
                       key=lambda p: natural_key(p.name))
    print(f"📂 Processing case: {case_dir.name} ({len(time_dirs)} time folders)")
    for tdir in time_dirs:
        _uniformize_time_dir(tdir, open_radius, overwrite)

    generate_gifs(case_dir, alpha=alpha)

    # Copy processed images & masks to output_root
    _export_case(time_dirs, dst_case)

    print(f"✅ Saved processed results to: {dst_case}")



## Run: Configure Paths and Start
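Set `ROOT_DIR` and `OUTPUT_DIR` below, then run the batch cell. Each run deletes and rebuilds `OUTPUT_DIR/<case>` for every case, and writes intermediate results (`mask_uniform/`, `mask_skel/`, `vis_mask_overlay/`, `Uni_GIF/`) into the input case folders.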


In [None]:

# 🔧 Configure your paths here
ROOT_DIR = Path('/path/to/Root_Dataset')         # <-- change me
OUTPUT_DIR = Path('/path/to/Processed_Dataset')  # <-- change me

# 🏃 Batch run
if ROOT_DIR.exists():
    # Create the output root only once the input is known to exist, so that
    # running with the placeholder paths prints the warning below instead of
    # failing with a PermissionError while creating '/path/...'.
    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
    # Natural order (Case2 before Case10), matching the time/slice ordering.
    for case in sorted(ROOT_DIR.iterdir(), key=lambda p: natural_key(p.name)):
        if case.is_dir():
            process_case(case, OUTPUT_DIR, open_radius=1, alpha=0.3, overwrite=True)
else:
    print('⚠️ ROOT_DIR does not exist. Please set the correct path.')
