# Run tools and rasterizers 

This notebook shows how to:

1) **Run a tool** (e.g., `HoverNetTool`) to produce instance outputs (`.dat`).
2) **Rasterize** those instance outputs to pixel‑aligned feature tensors via `HoverNetRasterizer`.
3) **Inspect / visualize** a few channels to sanity‑check the pipeline.

> The classes used here come from the repo's `tbm` package:
>
>- `tbm/tools.py`: `Tool`, `HoverNetTool`
>- `tbm/rasterizers.py`: `Rasterizer`, `HoverNetRasterizer` + common rasterization functions

**Requirements** (install in your environment):
- `tiatoolbox` (for HoVer‑Net)
- `torch`, `numpy`, `Pillow`, `joblib`, `tqdm`, `scikit-image`, `scipy`, `matplotlib`

You can adapt this notebook to other tools by swapping out the `Tool`/`Rasterizer` classes.

## Point this notebook to your repo
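`REPO_ROOT` must be the folder that contains the `tbm` package (which holds `tools.py` and `rasterizers.py`). By default the cell below uses the parent of the current working directory (or the parent of this file's folder when run as a script); set `REPO_ROOT` by hand if your layout differs.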

In [None]:
import sys
from pathlib import Path

# Locate the repository root: the parent of this file's directory when run as a
# script, otherwise (e.g. in a notebook, where __file__ is undefined) the parent
# of the current working directory.
REPO_ROOT = Path(__file__).resolve().parent.parent if "__file__" in globals() else Path.cwd().parent
# Put the repo root first on sys.path; this must happen before the `tbm`
# imports below so that they resolve.
sys.path.insert(0, str(REPO_ROOT))

from tbm.tools import HoverNetTool
from tbm.rasterizers import HoverNetRasterizer

import os, torch
# setdefault: only applies this allocator option if the variable is not already
# set in the environment.
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")
# cuDNN stays enabled, with autotuning (benchmark) off and deterministic mode on
# -- presumably for reproducible outputs; TODO confirm intent.
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True


## 1) Imports and configuration

In [None]:
# Where your input images live (TIFF/PNG/JPG). These can be tiles or small patches.
# Defaults to the repo's sample images; point IMAGE_DIR at your own dataset folder to change it.
IMAGE_DIR = (REPO_ROOT / "sample_images").resolve()
# Accepted file suffixes (lower-case, with leading dot).
IMAGE_EXTS = ('.tif', '.tiff', '.png', '.jpg', '.jpeg')

# Where to write tool outputs (.dat) and rasterized features (.pt).
# OUT_DIR is a relative path, so it is created under the current working directory.
OUT_DIR = Path('outputs')
TOOL_OUT_DIR = OUT_DIR / 'tool_outputs'
FEAT_OUT_DIR = OUT_DIR / 'features'
TOOL_OUT_DIR.mkdir(parents=True, exist_ok=True)
FEAT_OUT_DIR.mkdir(parents=True, exist_ok=True)

# Device for HoVerNet. 'cpu' works without a GPU; switch to 'cuda' if you have one
# (assumes HoverNetTool accepts a torch-style device string -- TODO confirm).
DEVICE = 'cpu'  # or 'cuda'

## Collect a small list of images
We recommend starting with a handful of images to validate the pipeline end‑to‑end. You can scale up later.

In [None]:
def list_images(folder: Path, exts=IMAGE_EXTS, limit=None):
    """Return the sorted image files found recursively under ``folder``.

    Args:
        folder: Directory to search; sub-directories are searched too.
        exts: Accepted file suffixes, lower-case and including the leading
            dot. A file's suffix is lower-cased before comparison, so the
            match is case-insensitive.
        limit: Maximum number of paths to return; ``None`` returns all.

    Returns:
        A sorted list of ``Path`` objects, truncated to ``limit`` entries.
    """
    # rglob('*') also yields directories; keep regular files only so that a
    # directory named e.g. "scans.png" is not mistaken for an image.
    paths = sorted(
        p for p in folder.rglob('*')
        if p.suffix.lower() in exts and p.is_file()
    )
    return paths if limit is None else paths[:limit]

# LIMIT is optional; set to None to process everything found.
# Starting with a few images keeps the end-to-end check fast.
LIMIT = 4
image_paths = list_images(IMAGE_DIR, limit=LIMIT)
# Echo what was picked up so a wrong IMAGE_DIR (zero images) is spotted early.
print(f'Found {len(image_paths)} images:')
for p in image_paths:
    print('  -', p)

In [None]:
from pathlib import Path
from PIL import Image
from tqdm import tqdm

# Optionally resize images to a consistent input size for HoVerNet.
# TARGET_SIZE is handed to PIL's Image.resize as `size`, i.e. (width, height).
TARGET_SIZE = (256, 256)   # choose 256 or 512 etc.
# Resized copies are written here (created if missing).
RESIZED_DIR = OUT_DIR / "resized_inputs"
RESIZED_DIR.mkdir(parents=True, exist_ok=True)

def resize_to_dir(src_paths, dst_dir, size=(256,256)):
    """Write RGB copies of ``src_paths``, resized to ``size``, into ``dst_dir``.

    A destination file that already exists is left untouched (it is not
    re-encoded), but its source image is still opened to record the original
    dimensions.

    Args:
        src_paths: Iterable of source image paths (``str`` or ``Path``).
        dst_dir: Existing output directory (``str`` or ``Path``).
        size: Target ``(width, height)`` as expected by ``Image.resize``.

    Returns:
        A tuple ``(dst_paths, shape_map)``: ``dst_paths`` is the list of
        destination paths as strings, in input order; ``shape_map`` maps each
        destination path string to
        ``{"orig_hw": (H0, W0), "resized_hw": (size[1], size[0])}``.
    """
    dst_dir = Path(dst_dir)
    resized_hw = (size[1], size[0])  # PIL sizes are (W, H); record as (H, W)
    dst_paths = []
    shape_map = {}  # original (H,W) and resized (H,W), by dst path
    for p in tqdm(src_paths, desc="Resizing"):
        p = Path(p)
        # NOTE: flat copy -- sources sharing a file name in different
        # sub-folders map to the same destination (customize if you want
        # subfolders).
        dst = dst_dir / p.name
        # The context manager closes the file handle; reading .size does not
        # require decoding the pixel data.
        with Image.open(p) as im:
            W0, H0 = im.size
            if not dst.exists():
                im.convert("RGB").resize(size, Image.BILINEAR).save(dst)
        shape_map[str(dst)] = {"orig_hw": (H0, W0), "resized_hw": resized_hw}
        dst_paths.append(str(dst))
    return dst_paths, shape_map

# Resize the collected images; size_info keeps the original and resized (H, W)
# of every output, keyed by destination path.
image_paths_resized, size_info = resize_to_dir(image_paths, RESIZED_DIR, size=TARGET_SIZE)
print(f"Prepared {len(image_paths_resized)} resized images in {RESIZED_DIR}")


## Run a tool to produce instance outputs (.dat)
Here we use `HoverNetTool` as an example. The output files follow a standard schema, with one Python dict per image (serialized via `joblib`).

In [None]:
# Configure the tool. The batch/worker counts are kept small for a quick run
# on a handful of images.
hovernet = HoverNetTool(
    name='hovernet',
    model='hovernet_fast-pannuke',
    device=DEVICE,
    batch_size=4,
    num_loader_workers=2,
    num_postproc_workers=2,
)

# Run on the resized copies; outputs are written under TOOL_OUT_DIR.
# `results` is a sequence with one entry per image; later cells read the
# 'input_path' and 'output_path' keys of each entry.
results = hovernet.process(image_paths=image_paths_resized, save_dir=str(TOOL_OUT_DIR))
print('Tool results (first 2):')
for r in results[:2]:
    print(r)
print(f"Saved {len(results)} .dat files to {TOOL_OUT_DIR}")

## Rasterize the tool outputs into feature maps
We convert instance outputs (boxes, centroids, contours, types) into pixel‑wise maps. The `HoverNetRasterizer` stacks channels in a stable order (`type[...]` first, then `box`, `centroid`, `contour`).

In [None]:
# Enable every feature group so the stacked tensor contains type, box,
# centroid and contour channels.
rasterizer = HoverNetRasterizer(
    name='hovernet_features',
    num_types=6,
    include_box=True,
    include_centroid=True,
    include_contour=True,
    include_types=True,
    type_mode='points',   # 'points' or 'gaussian'
    contour_mode='filled' # 'filled' or 'edge'
)

# Rasterize each tool output in the same order as `results`, so
# saved_feature_paths[i] corresponds to results[i].
saved_feature_paths = []
for r in results:
    saved = rasterizer.process_and_save(
        tool_output_path=str(Path(r['output_path'])),
        save_dir=FEAT_OUT_DIR,
        image_path=Path(r['input_path']),  # enables exact (H,W)
        save_individual=False,
    )
    # Only the path stored under the 'stacked' key is kept.
    saved_feature_paths.append(saved['stacked'])

# NOTE: indexing [0] raises IndexError when `results` is empty.
print('Example stacked feature file:', saved_feature_paths[0])

## Inspect shapes and visualize a few channels (optional)
We load one of the saved `[C, H, W]` tensors and overlay its channels (types, contours, boxes, centroids) on the source image. This is just a sanity check.

In [None]:
import numpy as np, torch
from PIL import Image
import matplotlib.pyplot as plt
from scipy.ndimage import binary_dilation
from skimage import measure

def _disk_structure(radius):
    """Return a boolean disk footprint of ``radius`` pixels (1x1 if ``radius <= 0``)."""
    if radius <= 0:
        return np.ones((1, 1), dtype=bool)
    yy, xx = np.ogrid[-radius:radius+1, -radius:radius+1]
    return (xx*xx + yy*yy) <= (radius*radius)


def overlay_from_pt(
    pt_path,
    image_path,
    num_types=6,
    type_radius=2,       # dilate type points
    centroid_radius=2,   # dilate centroid points
    edge_thickness=2,    # dilate contour edges
    box_outline=True,    # compute box outlines from mask
    contour_outline=True # outline contours instead of filled
):
    """Overlay a stacked ``[C, H, W]`` feature tensor on its source image.

    Channels are split positionally: when ``C >= num_types`` the first
    ``num_types`` channels are read as per-class type maps, and the next three
    channels (as far as they exist) are read as box, centroid and contour
    maps. This matches a rasterizer configured with every feature group
    enabled; a tensor saved with a subset of groups would be mis-assigned
    (NOTE(review): confirm against the rasterizer's stacking order).

    Args:
        pt_path: Path to the saved feature tensor (loaded with ``torch.load``).
        image_path: Path to the image to draw on; it is resized to ``(W, H)``
            when its size differs from the tensor's.
        num_types: Number of leading channels to treat as type maps.
        type_radius: Pixel radius used to dilate type points.
        centroid_radius: Pixel radius used to dilate centroid points.
        edge_thickness: Contours are over-plotted at offsets
            ``1 .. edge_thickness - 1`` pixels to thicken them visually.
        box_outline: Draw a rectangle around each connected component of the
            box mask.
        contour_outline: Draw contour outlines; when ``False`` the contour
            channel is not drawn at all.

    Returns:
        None. The figure is displayed with ``plt.show()``.
    """
    # --- load image + features ---
    # map_location="cpu" lets tensors saved from a GPU load on CPU-only machines.
    stk = torch.load(pt_path, map_location="cpu").float()  # [C,H,W]
    C, H, W = stk.shape
    print(H, W, stk.shape)
    # convert() returns a fully loaded copy, so the file can be closed right away.
    with Image.open(image_path) as im:
        img = im.convert("RGB")
    if img.size != (W, H):
        img = img.resize((W, H), Image.BILINEAR)
    img = np.asarray(img)

    # --- channel split (HoverNetRasterizer order) ---
    idx = 0
    type_maps = None
    if C >= num_types:
        type_maps = stk[:num_types].numpy()  # [T,H,W]
        idx = num_types
    # Take up to three further channels; missing ones become None.
    extra_maps = [stk[i].numpy() for i in range(idx, min(C, idx + 3))]
    extra_maps += [None] * (3 - len(extra_maps))
    box_map, centroid_map, contour_map = extra_maps

    fig, ax = plt.subplots(figsize=(7,7))
    ax.imshow(img)
    ax.set_axis_off()

    # --- TYPE: show dilated colored dots per class ---
    if type_maps is not None:
        # small palette (extend if needed)
        base_colors = np.array([
            [1.00, 0.65, 0.00],  # 0
            [1.00, 0.00, 0.00],  # 1
            [1.00, 1.00, 0.00],  # 2
            [0.00, 1.00, 0.00],  # 3
            [0.00, 0.00, 0.00],  # 4
            [0.00, 0.00, 1.00],  # 5
        ], dtype=np.float32)
        if base_colors.shape[0] < num_types:
            # Fixed seed so the extra colors are the same on every call.
            extra = np.random.default_rng(0).random((num_types - base_colors.shape[0], 3))
            base_colors = np.vstack([base_colors, extra])

        se = _disk_structure(type_radius)
        for cls_id in range(min(num_types, type_maps.shape[0])):
            pts = type_maps[cls_id] > 0.5
            if not pts.any():
                continue
            dots = binary_dilation(pts, structure=se)
            # RGBA layer: transparent everywhere except the dots.
            layer = np.zeros((H,W,4), dtype=np.float32)
            rgb = base_colors[cls_id]
            layer[dots] = [rgb[0], rgb[1], rgb[2], 1.0]  # solid dots
            ax.imshow(layer)

    # --- CONTOUR: edge outline from mask (thickened) ---
    if contour_map is not None and contour_outline:
        filled = contour_map > 0.5
        if filled.any():
            # outline via marching squares; find_contours yields (row, col) points
            for c in measure.find_contours(filled.astype(float), level=0.5):
                ax.plot(c[:,1], c[:,0], color="white", linewidth=1.0)
                # thicken visually by over-plotting nearby offsets
                for off in range(1, edge_thickness):
                    ax.plot(c[:,1]+off, c[:,0], color="white", linewidth=1.0, alpha=0.7)
                    ax.plot(c[:,1]-off, c[:,0], color="white", linewidth=1.0, alpha=0.7)
                    ax.plot(c[:,1], c[:,0]+off, color="white", linewidth=1.0, alpha=0.7)
                    ax.plot(c[:,1], c[:,0]-off, color="white", linewidth=1.0, alpha=0.7)

    # --- BOX: outline rectangles from mask blobs (connected components) ---
    if box_map is not None and box_outline:
        bm = box_map > 0.5
        if bm.any():
            # label components then draw min/max rectangles
            lab = measure.label(bm, connectivity=2)
            for region in measure.regionprops(lab):
                minr, minc, maxr, maxc = region.bbox  # (row/col)
                ax.plot([minc, maxc, maxc, minc, minc],
                        [minr, minr, maxr, maxr, minr],
                        color="cyan", linewidth=1.5, alpha=0.9)

    # --- CENTROIDS: dilated dots ---
    if centroid_map is not None:
        pts = centroid_map > 0.5
        if pts.any():
            if centroid_radius > 0:
                dots = binary_dilation(pts, structure=_disk_structure(centroid_radius))
            else:
                dots = pts
            layer = np.zeros((H,W,4), dtype=np.float32)
            layer[dots] = [1, 1, 1, 1.0]  # white
            ax.imshow(layer)

    plt.tight_layout()
    plt.show()

# ---- find the paired image for this feature file (by stem) ----
from pathlib import Path
pt_path = saved_feature_paths[0]
# Derive the image stem by dropping "_features" from the feature file's stem.
# NOTE(review): str.replace removes every occurrence, not just a trailing
# suffix, and this assumes the rasterizer names files "<image stem>_features"
# -- confirm against the rasterizer's naming.
stem = Path(pt_path).stem.replace("_features","")

# Prefer the exact input the tool processed; if no result has a matching stem,
# fall back to a guessed "<stem>.png" under IMAGE_DIR.
match = next((r for r in results if Path(r["input_path"]).stem == stem), None)
image_path = match["input_path"] if match else Path(IMAGE_DIR) / f"{stem}.png"

overlay_from_pt(pt_path, image_path, num_types=6)


## Tips: swapping tools or adding your own
- To use a different tool, subclass `Tool` in `tools.py` and implement `process(...)`.
- If your tool writes a custom output format, add a matching `Rasterizer` subclass in `rasterizers.py` implementing:
  - `load_tool_output(path)`
  - `rasterize(tool_output, H, W, ...) -> Dict[str, Tensor]`
  - `stack_features(features) -> Tensor[C,H,W]`
  - `get_num_channels()`
- You can also re‑use the common rasterization primitives (boxes, centroids, contours, types) for many instance‑style outputs.