In [65]:
from __future__ import annotations

from pathlib import Path
import traceback
import numpy as np
from PIL import Image
import cairosvg
import xml.etree.ElementTree as ET
from skimage.morphology import dilation, disk


In [66]:
# Class id written into the label map for each annotation layer; 0 is background.
LABELS = {"nv": 1, "nh": 2, "hy": 3, "st/zo": 4}
# Where masks overlap, the class listed first wins because it is painted last.
PRIORITY_HIGH_TO_LOW = ["st/zo", "hy", "nh", "nv"]  # high wins (written last)
assert sorted(PRIORITY_HIGH_TO_LOW) == sorted(LABELS), "priority list must name every class exactly once"

# Layer labels (compared case-insensitively) that identify the background image layer.
IMAGE_LAYER_LABELS = {"image", "Image"}

# Inkscape attribute names, in ElementTree "{namespace}name" form.
INK_LABEL = "{http://www.inkscape.org/namespaces/inkscape}label"
INK_GROUPMODE = "{http://www.inkscape.org/namespaces/inkscape}groupmode"

# Class labels that contain characters unsafe in filenames, and their file-safe form.
FILENAME_SAFE = {"st/zo": "st_zo"}
REVERSE_FILENAME_SAFE = {v: k for k, v in FILENAME_SAFE.items()}

# Visualization colors (RGB) per class label, used for overlays.
OVERLAY_COLORS = {
    "nv": (0, 255, 0),       # green
    "nh": (255, 0, 0),       # red
    "hy": (0, 0, 255),       # blue
    "st/zo": (255, 255, 0),  # yellow
}

# Seg-color (by class id) - derived from LABELS + OVERLAY_COLORS so the
# color-coded segmentation can never drift from the overlay palette.
SEG_COLORS = {
    0: (0, 0, 0),  # background
    **{LABELS[name]: rgb for name, rgb in OVERLAY_COLORS.items()},
}

In [67]:
def parse_svg(svg_path: Path) -> ET.ElementTree:
    """Read ``svg_path`` from disk into a new ElementTree (no layer processing)."""
    tree = ET.ElementTree()
    tree.parse(svg_path)
    return tree

def is_layer_group(el: ET.Element) -> bool:
    """Return True if ``el`` is an Inkscape layer: a ``<g>`` with inkscape:groupmode="layer".

    The tag is compared by its local name (namespace stripped), so only ``g``
    matches; a plain suffix test would also accept ``<svg>``. Nodes whose tag
    is not a string (comments, processing instructions) are never layers.
    """
    tag = el.tag
    if not isinstance(tag, str):
        return False
    local_name = tag.rpartition("}")[2]
    return local_name == "g" and el.get(INK_GROUPMODE) == "layer"

def layer_label(el: ET.Element) -> str | None:
    """Best-effort display name for a layer element.

    Tries the inkscape:label attribute, then ``id``, then ``class``, and returns
    the first non-empty one. If none is non-empty, the ``class`` lookup result is
    returned as-is (``None`` when absent, ``""`` when present but empty).
    """
    candidates = [el.attrib.get(attr) for attr in (INK_LABEL, "id", "class")]
    return next((value for value in candidates if value), candidates[-1])

def find_layer_groups(tree: ET.ElementTree) -> list[ET.Element]:
    """Return every element of ``tree`` accepted by is_layer_group, in depth-first document order.

    The whole tree is walked, so nested sub-layers are returned alongside their parents.
    """
    return list(filter(is_layer_group, tree.getroot().iter()))

def find_image_and_class_layers(layers: list[ET.Element]) -> tuple[ET.Element, dict[str, ET.Element]]:
    """Split layer groups into the background-image layer and the per-class layers.

    Each layer's label (from layer_label) is stripped and lower-cased. A label
    found in IMAGE_LAYER_LABELS (compared case-insensitively) marks the image
    layer. A label that is a key of LABELS marks a class layer. Layers with no
    label or any other label are ignored. If several layers match the same role,
    the last one in ``layers`` wins.

    Args:
        layers: layer groups, e.g. from find_layer_groups.

    Returns:
        ``(image_layer, {class_key: layer})``; class keys are the lower-cased labels.

    Raises:
        ValueError: if no image layer is present.
    """
    # Loop-invariant: build the case-folded image-label set once, not per layer.
    image_names = {name.lower() for name in IMAGE_LAYER_LABELS}
    image_layer: ET.Element | None = None
    class_layers: dict[str, ET.Element] = {}

    for layer in layers:
        lab = layer_label(layer)
        if not lab:
            continue
        key = lab.strip().lower()
        if key in image_names:
            image_layer = layer
        elif key in LABELS:
            class_layers[key] = layer

    if image_layer is None:
        raise ValueError(f"No image layer found. Layer labels present: {[layer_label(layer) for layer in layers]}")

    return image_layer, class_layers

def set_layer_visibility(layers: list[ET.Element], visible: set[ET.Element]) -> None:
    """Show the layers in ``visible`` and hide every other layer in ``layers``.

    Each layer's whole ``style`` attribute is replaced (any previous style is
    discarded). ``!important`` is included in the declaration.
    """
    shown_style = "display:inline !important"
    hidden_style = "display:none !important"
    for layer in layers:
        layer.set("style", shown_style if layer in visible else hidden_style)

def write_svg(tree: ET.ElementTree, out_path: Path) -> None:
    """Serialize ``tree`` to ``out_path`` as UTF-8 with an XML declaration, creating parent dirs."""
    parent_dir = out_path.parent
    parent_dir.mkdir(exist_ok=True, parents=True)
    tree.write(out_path, xml_declaration=True, encoding="utf-8")

def svg_to_png(svg_path: Path, png_path: Path) -> None:
    """Rasterize ``svg_path`` to ``png_path`` with cairosvg, creating parent dirs first."""
    source_url = str(svg_path)
    target_file = str(png_path)
    png_path.parent.mkdir(exist_ok=True, parents=True)
    cairosvg.svg2png(url=source_url, write_to=target_file)

def png_to_rgb_inplace(png_path: Path) -> None:
    """Rewrite ``png_path`` in place as a 3-channel RGB image.

    Conversion uses PIL's ``convert("RGB")``, so any alpha band is dropped
    rather than composited. The source is fully read and its handle closed
    before the same path is overwritten.
    """
    with Image.open(png_path) as im:
        rgb = im.convert("RGB")
    rgb.save(png_path)

def load_binary_mask(png_path: Path) -> np.ndarray:
    """Load a rendered mask image as a boolean "pixel was painted" array.

    - Single-band images (2-D array): any non-zero value counts as painted.
    - Images with an alpha band (PIL "LA" -> 2 bands, "RGBA" -> 4 bands):
      alpha > 0 counts as painted, so fully transparent pixels are background
      whatever their color values.
    - Other multi-band images (e.g. RGB): a non-zero sum over the first three
      bands counts as painted.

    Returns:
        Boolean array with the image's height and width.
    """
    with Image.open(png_path) as im:
        arr = np.array(im)
    if arr.ndim == 2:
        return arr > 0
    if arr.shape[-1] in (2, 4):
        # The last band is alpha for both LA and RGBA images.
        return arr[..., -1] > 0
    return arr[..., :3].sum(axis=-1) > 0

In [68]:
def combine_per_class_masks(mask_paths: dict[str, Path], radius: int) -> np.ndarray:
    """Merge per-class mask images into a single uint8 label map of shape (H, W).

    Each class mask is loaded with load_binary_mask. It is optionally dilated by a
    disk of ``radius`` pixels and then painted with its id from LABELS. Classes are
    painted from lowest to highest priority (reverse of PRIORITY_HIGH_TO_LOW), so
    where masks overlap the higher-priority class wins. Pixels covered by no mask
    stay 0 (background). Classes missing from ``mask_paths``, or whose file does
    not exist, are skipped.

    Args:
        mask_paths: class label (a key of LABELS, e.g. "st/zo") -> mask PNG path.
        radius: dilation radius in pixels; 0 or None disables dilation.

    Returns:
        uint8 label map sized like the first existing file in ``mask_paths``.

    Raises:
        ValueError: if no path in ``mask_paths`` exists.
    """
    # Only the canvas size is needed here; reading ``.size`` avoids decoding pixels.
    size_source = next((p for p in mask_paths.values() if p and p.exists()), None)
    if size_source is None:
        raise ValueError("combine_per_class_masks: no existing mask file in mask_paths")
    with Image.open(size_source) as im:
        width, height = im.size
    seg = np.zeros((height, width), dtype=np.uint8)

    # Loop-invariant: build the structuring element once, not once per class.
    footprint = disk(radius) if radius and radius > 0 else None

    # low -> high so higher-priority classes overwrite overlaps
    for key in reversed(PRIORITY_HIGH_TO_LOW):  # nv, nh, hy, st/zo
        path = mask_paths.get(key)
        if not path or not path.exists():
            continue
        binary = load_binary_mask(path)
        if footprint is not None:
            binary = dilation(binary, footprint)
        # Index with a boolean array so this is element-wise masking even if
        # dilation hands back a non-bool dtype (an int array would fancy-index).
        seg[np.asarray(binary, dtype=bool)] = LABELS[key]

    return seg

def save_color_segmentation(seg: np.ndarray, out_png: Path) -> None:
    """Save a color-coded preview of a label map.

    Every pixel whose class id is a key of SEG_COLORS gets that RGB color. Ids
    not in SEG_COLORS are left black. The parent directory of ``out_png`` is
    created if needed, matching write_svg / svg_to_png.

    Args:
        seg: integer label map (typically (H, W) uint8 from combine_per_class_masks).
        out_png: destination image path.
    """
    rgb = np.zeros(seg.shape + (3,), dtype=np.uint8)
    for class_id, color in SEG_COLORS.items():
        rgb[seg == class_id] = color
    out_png.parent.mkdir(parents=True, exist_ok=True)
    Image.fromarray(rgb).save(out_png)

def overlay_masks_on_image(
    image_png: Path,
    mask_paths: dict[str, Path],
    out_png: Path,
    radius: int = 0,
    alpha: float = 0.5,
) -> None:
    """Alpha-blend each class mask's color over the image and save the result.

    Classes are blended one after another in OVERLAY_COLORS order, so where
    masks overlap the colors compound. Classes missing from ``mask_paths``, or
    whose file does not exist, are skipped.

    Args:
        image_png: background image (converted to RGB).
        mask_paths: class label -> mask PNG path (masks must match the image size).
        out_png: destination path; its parent directory is created if needed.
        radius: dilation radius in pixels applied to each mask; 0 or None disables it.
        alpha: weight of the class color in the blend (0 = image only, 1 = color only).
    """
    with Image.open(image_png) as im:
        overlay = np.array(im.convert("RGB"))

    # Loop-invariant: build the structuring element once, not once per class.
    footprint = disk(radius) if radius and radius > 0 else None

    for key, color in OVERLAY_COLORS.items():
        path = mask_paths.get(key)
        if not path or not path.exists():
            continue

        mask = load_binary_mask(path)
        if footprint is not None:
            mask = dilation(mask, footprint)
        # Boolean mask -> element-wise selection even if dilation changes dtype.
        mask = np.asarray(mask, dtype=bool)

        blended = (1 - alpha) * overlay[mask] + alpha * np.asarray(color)
        # Clip so an alpha outside [0, 1] cannot wrap around in uint8.
        overlay[mask] = np.clip(blended, 0, 255).astype(np.uint8)

    out_png.parent.mkdir(parents=True, exist_ok=True)
    Image.fromarray(overlay).save(out_png)

In [None]:
def _render_single_layer(svg_path: Path, layer_key: str | None, out_svg: Path, out_png: Path) -> bool:
    """Render a fresh parse of ``svg_path`` with exactly one layer group visible.

    ``layer_key=None`` selects the image layer. Otherwise it selects the class
    layer whose lower-cased label is ``layer_key``. Every other layer group is
    hidden. Writes ``out_svg`` and rasterizes it to ``out_png``.

    Returns:
        True if the layer was found and rendered; False if the requested class
        layer is absent (nothing is written then).

    Raises:
        ValueError: from find_image_and_class_layers when there is no image layer.
    """
    # Re-parse for every render so visibility edits never leak between outputs.
    tree = parse_svg(svg_path)
    layers = find_layer_groups(tree)
    image_layer, class_layers = find_image_and_class_layers(layers)
    target = image_layer if layer_key is None else class_layers.get(layer_key)
    if target is None:
        return False
    set_layer_visibility(layers, visible={target})
    write_svg(tree, out_svg)
    svg_to_png(out_svg, out_png)
    return True


def _save_segmentation(mask_paths: dict[str, Path], radius: int, label_png: Path, color_png: Path) -> None:
    """Combine masks dilated by ``radius`` and save the raw label map plus its color preview."""
    seg = combine_per_class_masks(mask_paths, radius=radius)
    Image.fromarray(seg).save(label_png)
    save_color_segmentation(seg, color_png)


def process_svg(svg_path: Path, out_root: Path, r_main: int = 4, r_extra: int = 10) -> Path:
    """Split one annotated SVG into an image, per-class masks, label maps and overlays.

    Writes into ``out_root/<svg_stem>/``:
      - ``image_only.svg`` / ``image.png``: only the image layer (PNG converted to RGB)
      - ``masks/<class>.svg`` / ``masks/<class>.png``: one render per class layer
        (``st/zo`` is stored under its file-safe name ``st_zo``)
      - ``seg.png`` / ``seg_color.png``: label map with masks dilated by ``r_main``
      - ``seg_r{r_extra}.png`` / ``seg_color_r{r_extra}.png``: same, dilated by ``r_extra``
      - ``overlay_raw.png`` / ``overlay_r{r_extra}.png``: masks blended over the image,
        undilated and dilated by ``r_extra``

    NOTE(review): only layer groups are hidden/shown. Content outside any layer
    is rendered into every output, and a class layer nested inside a hidden
    parent layer would render empty. Verify against the source files.

    Args:
        svg_path: annotated SVG with an image layer and at least one class layer.
        out_root: root directory for all outputs.
        r_main: dilation radius (pixels) for the main segmentation.
        r_extra: dilation radius (pixels) for the extra segmentation and overlay.

    Returns:
        The per-file output directory.

    Raises:
        ValueError: if the SVG has no image layer, no class layers, or no mask
            PNG could be produced.
    """
    out_dir = out_root / svg_path.stem
    masks_dir = out_dir / "masks"
    masks_dir.mkdir(parents=True, exist_ok=True)  # also creates out_dir

    # Detect layers up front so malformed files fail before any rendering.
    layers0 = find_layer_groups(parse_svg(svg_path))
    _, class_layers0 = find_image_and_class_layers(layers0)
    if not class_layers0:
        raise ValueError(
            f"No class layers (nv/nh/hy/st/zo) found. "
            f"Layer labels present: {[layer_label(layer) for layer in layers0]}"
        )

    # --- image layer only -> image.png (RGB) ---
    image_png = out_dir / "image.png"
    _render_single_layer(svg_path, None, out_dir / "image_only.svg", image_png)
    png_to_rgb_inplace(image_png)

    # --- one mask per class layer (SVG + PNG), filename-safe ---
    mask_paths: dict[str, Path] = {}
    for key in class_layers0:
        safe = FILENAME_SAFE.get(key, key)  # st/zo -> st_zo
        mask_png = masks_dir / f"{safe}.png"
        if _render_single_layer(svg_path, key, masks_dir / f"{safe}.svg", mask_png):
            mask_paths[key] = mask_png  # keep the semantic key (st/zo), not the file-safe one

    if not mask_paths:
        raise ValueError("No mask PNGs produced.")

    # --- label maps at the main and the extra dilation radius ---
    _save_segmentation(mask_paths, r_main, out_dir / "seg.png", out_dir / "seg_color.png")
    _save_segmentation(
        mask_paths, r_extra, out_dir / f"seg_r{r_extra}.png", out_dir / f"seg_color_r{r_extra}.png"
    )

    # --- overlays: undilated, and dilated by r_extra ---
    overlay_masks_on_image(image_png, mask_paths, out_dir / "overlay_raw.png", radius=0, alpha=0.5)
    overlay_masks_on_image(image_png, mask_paths, out_dir / f"overlay_r{r_extra}.png", radius=r_extra, alpha=0.5)

    return out_dir

In [70]:
def batch_process(
    svg_dir: Path,
    out_root: Path,
    r_main: int = 4,
    r_extra: int = 10,
    skip_existing: bool = False,
    verbose: bool = False,
) -> None:
    """Run process_svg on every ``*.svg`` found recursively under ``svg_dir``.

    Files whose lower-cased name contains ``_gr.svg`` are excluded. Each SVG
    writes to ``out_root/<stem>/``. A failing file is reported and counted, and
    the batch continues. The summary line is printed even if the batch is
    interrupted (e.g. Ctrl-C).

    Args:
        svg_dir: directory searched recursively for SVGs.
        out_root: root directory for outputs.
        r_main: dilation radius for the main segmentation.
        r_extra: dilation radius for the extra segmentation and overlay.
        skip_existing: if True, skip SVGs whose last output
            (``overlay_r{r_extra}.png``) already exists, so an interrupted run can resume.
        verbose: if True, print the full traceback for each failure.
    """
    svgs = sorted(p for p in svg_dir.rglob("*.svg") if "_gr.svg" not in p.name.lower())
    print("Found:", len(svgs))

    # Outputs are keyed by stem only, so same-named SVGs in different
    # subdirectories would silently overwrite each other's results.
    first_by_stem: dict[str, Path] = {}
    for svg in svgs:
        if svg.stem in first_by_stem:
            print(f"WARNING: {svg} and {first_by_stem[svg.stem]} share output dir '{svg.stem}'; results will be overwritten")
        else:
            first_by_stem[svg.stem] = svg

    ok = 0
    failed = 0
    skipped = 0
    completed = False
    try:
        for svg in svgs:
            if skip_existing and (out_root / svg.stem / f"overlay_r{r_extra}.png").exists():
                skipped += 1
                continue
            try:
                process_svg(svg, out_root=out_root, r_main=r_main, r_extra=r_extra)
            except Exception as e:
                failed += 1
                print(f"FAILED: {svg}")
                print(" ", repr(e))
                if verbose:
                    print(traceback.format_exc())
            else:
                ok += 1
        completed = True
    finally:
        # Runs on KeyboardInterrupt too, so an interrupted batch still reports progress.
        status = "Done." if completed else "Interrupted."
        summary = f"{status} OK={ok}  FAILED={failed}"
        if skipped:
            summary += f"  SKIPPED={skipped}"
        print(summary)
        print("Output root:", out_root)

In [71]:
if __name__ == "__main__":
    import os

    # Defaults are the original author's cluster paths. Set CORAL_SEG_SVG_DIR /
    # CORAL_SEG_OUT_ROOT to run elsewhere without editing the notebook.
    svg_dir = Path(os.environ.get(
        "CORAL_SEG_SVG_DIR",
        "/home/jrhowell/benthic_ecology_group/Jack/coral_seg/four_layer_images6",
    ))
    out_root = Path(os.environ.get(
        "CORAL_SEG_OUT_ROOT",
        "/home/jrhowell/benthic_ecology_group/Jack/coral_seg/processed_images_real",
    ))

    batch_process(svg_dir, out_root, r_main=4, r_extra=10)

Found: 199


KeyboardInterrupt: 