# Micropillar Dual-Mode Image Analysis Pipeline (Pillar U-Net + C2 Quantification)

This single notebook is intended for **public reproducibility** (e.g., a GitHub repo for reviewers).
It:

1. **Downloads** the trained pillar-segmentation model from Box (if missing)
2. Runs **U-Net inference** on **C1** images to create pillar masks
3. Uses pillar masks to quantify **birefringent CaCO₃** signal in **C2** images
4. Exports **per-image metrics** and **condition-aggregated summaries** as CSV

**Expected naming convention (recommended):**  
`<channel>.<section>.<diameter>.<trial>.c1.png` and `...c2.png`  
Example: `t1.1.8.a.c1.png` and `t1.1.8.a.c2.png`

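If your filenames differ, update `parse_filename()` and `pair_c1_c2()`. Each C1/C2 pair must have identical pixel dimensions, because the pillar mask predicted from C1 is applied pixel-for-pixel to C2.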


### Optional: install dependencies (Colab)

If you're running in Google Colab or another fresh environment, run the install cell below; skip it if the dependencies are already installed.


In [None]:
# Install dependencies (Colab / fresh environments).
# The leading `!` is IPython/Jupyter shell-escape syntax (not plain Python), so this
# cell only runs inside a notebook kernel. Skip it if the packages are already installed.
# NOTE(review): albumentations is not imported in any visible cell of this notebook.
!pip -q install segmentation-models-pytorch albumentations opencv-python tqdm pandas matplotlib requests


In [None]:
import os
import re
import json
import math
import random
import shutil
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import numpy as np
import pandas as pd
import cv2
from tqdm.auto import tqdm

import torch
import segmentation_models_pytorch as smp

import matplotlib.pyplot as plt

In [None]:
def set_seed(seed: int = 0) -> None:
    """Seed every RNG this notebook touches with the same value.

    Covers Python's ``random``, NumPy's global generator, PyTorch's CPU generator
    and the generators of all visible CUDA devices.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)

def ensure_dir(p: Path) -> Path:
    """Create directory *p* (including missing parents) if needed and hand it back.

    Already-existing directories are left untouched, so the call is safe to repeat.
    """
    os.makedirs(p, exist_ok=True)
    return p

def download_box_file(url: str, out_path: Path, timeout: float = 60.0) -> None:
    """
    Download a file from a Box shared link to ``out_path``.

    Note: most Box share links can be downloaded by appending '?download=1'.

    The response is streamed into a temporary ``<name>.part`` file next to
    ``out_path`` and only moved into place once the transfer has completed.
    An interrupted or failed download therefore never leaves a truncated file
    at ``out_path`` (callers that treat an existing file as "already
    downloaded" would otherwise never retry).

    Args:
        url: Direct-download URL.
        out_path: Destination path; missing parent directories are created.
        timeout: Seconds passed to ``requests`` as its connect/read timeout
            (a per-operation limit, not a cap on total transfer time).

    Raises:
        requests.HTTPError: If the server answers with an error status.
        requests.RequestException: On network errors or timeouts.
    """
    import requests

    out_path.parent.mkdir(parents=True, exist_ok=True)
    tmp_path = out_path.with_name(out_path.name + ".part")
    try:
        with requests.get(url, stream=True, timeout=timeout) as r:
            r.raise_for_status()
            total = int(r.headers.get("content-length", 0))
            # `total or None` -> tqdm shows an open-ended bar when the size is unknown.
            with open(tmp_path, "wb") as f, tqdm(
                total=total or None,
                unit="B",
                unit_scale=True,
                desc=f"Downloading {out_path.name}",
            ) as pbar:
                for chunk in r.iter_content(chunk_size=1024 * 1024):
                    if chunk:  # skip keep-alive chunks
                        f.write(chunk)
                        pbar.update(len(chunk))
        # Atomic rename on the same filesystem: out_path is either complete or absent.
        tmp_path.replace(out_path)
    finally:
        # No-op after a successful rename; removes the partial file otherwise.
        tmp_path.unlink(missing_ok=True)

def read_grayscale_uint8(img_path: Path) -> np.ndarray:
    """
    Load an image from disk as a single-channel uint8 array (0..255).

    The file is read unchanged (keeping its bit depth). Images with a channel
    axis are converted with OpenCV's BGR->gray conversion. Then:
      * uint8 data is returned as-is;
      * uint16 data is divided by 256 (keeps the top 8 bits), so thresholds
        behave like on 8-bit exports;
      * any other dtype is min-max stretched to 0..255.

    NOTE: 12-bit data stored in a 16-bit container maxes out at 4095 // 256 == 15
    after this scaling, i.e. the result is very dark.

    Raises:
        FileNotFoundError: If OpenCV cannot read/decode the file.
    """
    raw = cv2.imread(str(img_path), cv2.IMREAD_UNCHANGED)
    if raw is None:
        raise FileNotFoundError(f"Could not read image: {img_path}")

    gray = cv2.cvtColor(raw, cv2.COLOR_BGR2GRAY) if raw.ndim == 3 else raw

    if gray.dtype == np.uint8:
        return gray
    if gray.dtype == np.uint16:
        return (gray / 256).astype(np.uint8)
    return cv2.normalize(gray, None, 0, 255, cv2.NORM_MINMAX).astype(np.uint8)

def save_mask_png(mask01: np.ndarray, out_path: Path) -> None:
    """
    Save a binary mask as an 8-bit PNG with values 0/255.

    Any nonzero pixel counts as foreground. Inputs that are 0/1, boolean, or
    already 0/255 therefore all produce the same file. With a plain
    ``astype(uint8) * 255``, a 255 would wrap around to 1.

    Args:
        mask01: 2-D mask; nonzero = foreground.
        out_path: Destination file; missing parent directories are created.

    Raises:
        IOError: If OpenCV reports that the file could not be written
            (``cv2.imwrite`` signals many failures by returning False
            instead of raising).
    """
    out_path.parent.mkdir(parents=True, exist_ok=True)
    m = np.where(mask01 > 0, 255, 0).astype(np.uint8)
    if not cv2.imwrite(str(out_path), m):
        raise IOError(f"Could not write mask: {out_path}")

def parse_filename(fname: str) -> Dict[str, Optional[str]]:
    """
    Split a name like ``t1.1.8.a.c1.png`` into its metadata tokens.

    Any directory part and a trailing image extension (png/jpg/jpeg/tif/tiff,
    case-insensitive) are ignored. The result always has the keys
    ``channel``, ``section``, ``diameter``, ``trial`` and ``modality``, in that
    order; anything that cannot be determined is None.

    * ``modality`` is the last dot-token equal to c1/c2 (case-insensitive),
      returned lower-cased.
    * The first four dot-tokens fill channel/section/diameter/trial, but only
      when the name has at least five tokens.

    If your naming differs, edit this function.
    """
    base = re.sub(r"\.(png|jpg|jpeg|tif|tiff)$", "", Path(fname).name, flags=re.IGNORECASE)
    tokens = base.split(".")

    fields = ("channel", "section", "diameter", "trial")
    result: Dict[str, Optional[str]] = dict.fromkeys(fields + ("modality",))

    result["modality"] = next(
        (tok.lower() for tok in reversed(tokens) if tok.lower() in ("c1", "c2")),
        None,
    )

    if len(tokens) >= 5:
        result.update(zip(fields, tokens[:4]))
    return result

def pair_c1_c2(c1_files: List[Path], c2_files: List[Path]) -> List[Tuple[Path, Optional[Path]]]:
    """
    Pair each C1 image with its matching C2 image.

    Two files match when their names are equal after lower-casing and removing
    the modality token (".c1." for C1, ".c2." for C2). This is case-insensitive,
    like the image-discovery step.

    Matching order:
      1. a C2 in the same directory as the C1 (so identically named files in
         different sub-folders are never cross-paired);
      2. otherwise, a C2 anywhere with the same normalised name, but only if it
         is the unique such file (e.g. C1 and C2 kept in separate folders).
    Ambiguous or missing matches yield None.

    Returns:
        List of (c1, c2_or_None), in the same order as ``c1_files``.
    """
    def _key(p: Path, token: str) -> str:
        # Normalised name used for matching, e.g. "T1.1.8.a.C1.png" -> "t1.1.8.a.png".
        return p.name.lower().replace(token, ".")

    by_dir: Dict[Tuple[Path, str], Path] = {}
    by_name: Dict[str, List[Path]] = {}
    for c2 in c2_files:
        k = _key(c2, ".c2.")
        by_dir[(c2.parent, k)] = c2
        by_name.setdefault(k, []).append(c2)

    pairs: List[Tuple[Path, Optional[Path]]] = []
    for c1 in c1_files:
        k = _key(c1, ".c1.")
        match = by_dir.get((c1.parent, k))
        if match is None:
            candidates = by_name.get(k, [])
            match = candidates[0] if len(candidates) == 1 else None
        pairs.append((c1, match))
    return pairs

In [None]:
# =======================
# CONFIG (edit these)
# =======================

# Seed Python, NumPy and PyTorch RNGs for repeatable runs.
set_seed(0)

# Use the GPU when PyTorch can see one; otherwise fall back to CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print("Using device:", DEVICE)

# Folder containing your exported images (searched recursively). Relative paths
# are resolved against the current working directory.
INPUT_ROOT = Path("data/raw_images")     # <-- change this
OUT_ROOT   = ensure_dir(Path("results"))  # created immediately if missing

# Model download (Box). "?download=1" requests the raw file rather than the Box web page.
BOX_MODEL_URL = "https://cornell.box.com/s/4dyu78bhtpabm98jgz40gdp5wmoe71xn?download=1"
MODEL_PATH    = Path("models/pillar_unet.pt")  # local copy; downloaded only if this file is missing

# U-Net input size control
SHORT_SIDE = 768  # resize shorter side to this for inference (keeps aspect ratio)

# C2 quantification threshold
THRESH_METHOD = "otsu"   # "otsu" (threshold chosen automatically per image) or "fixed"
FIXED_THRESH  = 25       # used only if THRESH_METHOD == "fixed" (0..255); pixels >= this count as crystal

# Morphology options for counting/size
MIN_COMPONENT_AREA_PX = 10     # components smaller than this are dropped from component counts/sizes (not from crystal area)
MORPH_KERNEL = 3               # side length (px) of the square kernel used for morphological opening

print("INPUT_ROOT:", INPUT_ROOT.resolve())
print("OUT_ROOT:", OUT_ROOT.resolve())

In [None]:
# =======================
# Download + load model
# =======================

# Fetch the checkpoint if needed, build the U-Net, and load the trained weights.
if not MODEL_PATH.exists():
    print(f"Model not found at {MODEL_PATH}. Downloading from Box...")
    try:
        download_box_file(BOX_MODEL_URL, MODEL_PATH)
    except Exception as e:
        # Don't leave a partial/truncated checkpoint behind: the exists() check
        # above would skip the download on the next run and torch.load would fail.
        MODEL_PATH.unlink(missing_ok=True)
        raise RuntimeError(
            "Model download failed. If Box requires login or blocks requests, "
            "manually download the model from the Box link and place it at:\n"
            f"  {MODEL_PATH}\n\nOriginal error:\n{e}"
        ) from e

print("Model file:", MODEL_PATH.resolve())

# Define the exact architecture used in training (edit if needed)
model = smp.Unet(
    encoder_name="resnet18",
    encoder_weights=None,
    in_channels=1,
    classes=1,
)

# SECURITY NOTE: torch.load unpickles the file, which can execute arbitrary code.
# Only load checkpoints from a source you trust; on PyTorch versions that support
# it, consider passing weights_only=True.
ckpt = torch.load(MODEL_PATH, map_location="cpu")

# Accept a pickled nn.Module, a training checkpoint that wraps the weights
# under "state_dict" / "model_state_dict", or a bare state_dict.
if isinstance(ckpt, torch.nn.Module):
    state = ckpt.state_dict()
elif isinstance(ckpt, dict) and "state_dict" in ckpt:
    state = ckpt["state_dict"]
elif isinstance(ckpt, dict) and "model_state_dict" in ckpt:
    state = ckpt["model_state_dict"]
else:
    state = ckpt

# Strip wrapper prefixes so keys match the bare smp.Unet: "model." (a training
# wrapper that stores the network as `.model`) or "module." (DataParallel/DDP).
new_state = {}
for k, v in state.items():
    for prefix in ("model.", "module."):
        if k.startswith(prefix):
            k = k[len(prefix):]
            break
    new_state[k] = v

missing, unexpected = model.load_state_dict(new_state, strict=False)
print("Missing keys:", missing)
print("Unexpected keys:", unexpected)
if missing:
    # Parameters absent from the checkpoint keep their random initialisation,
    # so inference would silently produce meaningless masks.
    raise RuntimeError(
        f"{len(missing)} model parameters were not found in the checkpoint "
        f"(e.g. {missing[:5]}). Check that the architecture above matches the "
        "one used in training."
    )

model = model.to(DEVICE)
model.eval()

In [None]:
# =======================
# Discover images
# =======================

# Find every image under INPUT_ROOT, split it into C1/C2 by the ".c1."/".c2."
# filename token, and pair them up.
IMAGE_SUFFIXES = (".png", ".jpg", ".jpeg", ".tif", ".tiff")

if not INPUT_ROOT.is_dir():
    raise FileNotFoundError(
        f"INPUT_ROOT does not exist or is not a directory: {INPUT_ROOT.resolve()}\n"
        "Edit INPUT_ROOT in the CONFIG cell."
    )

_out_root_resolved = OUT_ROOT.resolve()


def _is_input_image(p: Path) -> bool:
    """True for image files under INPUT_ROOT that are not outputs of this pipeline."""
    return (
        p.is_file()
        and p.suffix.lower() in IMAGE_SUFFIXES
        # If OUT_ROOT sits inside INPUT_ROOT, skip generated files such as
        # "<name>.c1.mask.png", which would otherwise be picked up as C1 images.
        and _out_root_resolved not in p.resolve().parents
    )


all_imgs = sorted(p for p in INPUT_ROOT.rglob("*") if _is_input_image(p))

# Modality is identified case-insensitively by the ".c1." / ".c2." token.
c1_files = [p for p in all_imgs if ".c1." in p.name.lower()]
c2_files = [p for p in all_imgs if ".c2." in p.name.lower()]

print(f"Found {len(all_imgs)} total images")
print(f"Found {len(c1_files)} C1 images")
print(f"Found {len(c2_files)} C2 images")

if not c1_files:
    # Fail here with a clear message instead of a KeyError several cells later.
    raise RuntimeError(
        f"No C1 images (names containing '.c1.') found under {INPUT_ROOT.resolve()}."
    )

pairs = pair_c1_c2(c1_files, c2_files)
print("Pairs:", len(pairs))
print("Pairs with a matching C2:", sum(c2 is not None for _, c2 in pairs))

MASK_DIR = ensure_dir(OUT_ROOT / "pillar_masks")

In [None]:
# =======================
# Run pillar inference on C1
# =======================

def resize_short_side(img: np.ndarray, short_side: int) -> Tuple[np.ndarray, float]:
    """
    Rescale *img* so that its shorter side becomes *short_side* pixels.

    The aspect ratio is preserved (both sides are multiplied by the same factor
    and rounded to the nearest integer). Resampling uses ``cv2.INTER_AREA``.

    Returns:
        (resized_image, factor) where ``factor = short_side / min(height, width)``.
    """
    height, width = img.shape[:2]
    factor = short_side / min(height, width)
    target_wh = (int(round(width * factor)), int(round(height * factor)))  # cv2 takes (w, h)
    resized = cv2.resize(img, target_wh, interpolation=cv2.INTER_AREA)
    return resized, factor

def infer_pillar_mask(
    c1_uint8: np.ndarray,
    threshold: float = 0.5,
    pad_multiple: int = 32,
) -> np.ndarray:
    """
    Predict a binary pillar mask for one C1 image.

    Steps:
      1. Resize so the shorter side is SHORT_SIDE.
      2. Reflect-pad the bottom/right edges up to a multiple of ``pad_multiple``.
         smp U-Net with the default encoder_depth=5 requires H and W divisible
         by 32; resizing only fixes the short side, so the long side often
         isn't.
      3. Run the global ``model`` on DEVICE and apply a sigmoid.
      4. Crop the padding away and resize the probabilities back to the input
         size. Pixels with probability >= ``threshold`` become pillar.

    Args:
        c1_uint8: 2-D uint8 image (0..255).
        threshold: Probability cut-off for the pillar class.
        pad_multiple: Spatial size multiple required by the network.

    Returns:
        uint8 mask (0/1) with the same height/width as ``c1_uint8``.
    """
    img_rs, _ = resize_short_side(c1_uint8, SHORT_SIDE)
    h_rs, w_rs = img_rs.shape[:2]

    pad_h = (-h_rs) % pad_multiple
    pad_w = (-w_rs) % pad_multiple
    if pad_h or pad_w:
        img_rs = cv2.copyMakeBorder(img_rs, 0, pad_h, 0, pad_w, cv2.BORDER_REFLECT_101)

    x = (img_rs.astype(np.float32) / 255.0)[None, None, :, :]  # (1,1,H,W)
    x_t = torch.from_numpy(x).to(DEVICE)

    with torch.no_grad():
        logits = model(x_t)
        prob = torch.sigmoid(logits)[0, 0].cpu().numpy()

    prob = np.ascontiguousarray(prob[:h_rs, :w_rs])  # drop the padded border
    prob_up = cv2.resize(prob, (c1_uint8.shape[1], c1_uint8.shape[0]), interpolation=cv2.INTER_LINEAR)
    return (prob_up >= threshold).astype(np.uint8)

# Segment pillars in every C1 image and save each mask as a 0/255 PNG under
# MASK_DIR, mirroring INPUT_ROOT's sub-folder layout. Record one index row per
# C1 image (paths + filename metadata) for the quantification step.
mask_index = []

for c1_path, c2_path in tqdm(pairs, desc="Pillar inference"):
    c1 = read_grayscale_uint8(c1_path)
    mask01 = infer_pillar_mask(c1)

    # e.g. "<sub>/x.c1.png" -> MASK_DIR/"<sub>/x.c1.mask.png".
    # NOTE(review): "x.c1.png" and "x.c1.tif" in the same folder would map to the same mask file.
    rel = c1_path.relative_to(INPUT_ROOT)
    out_mask_path = (MASK_DIR / rel).with_suffix(".mask.png")
    ensure_dir(out_mask_path.parent)
    save_mask_png(mask01, out_mask_path)

    meta = parse_filename(c1_path.name)
    mask_index.append({
        "c1_path": str(c1_path),
        # None when no matching C2 was found; the quantification cell skips such rows.
        "c2_path": str(c2_path) if c2_path is not None else None,
        "mask_path": str(out_mask_path),
        **meta
    })

mask_index_df = pd.DataFrame(mask_index)
mask_index_csv = OUT_ROOT / "mask_index.csv"
mask_index_df.to_csv(mask_index_csv, index=False)
print("Saved:", mask_index_csv)
mask_index_df.head()  # last expression -> rendered as the cell output in Jupyter

In [None]:
# =======================
# Quick preview montage (random subset)
# =======================

# Show up to 12 randomly chosen C1 images with their predicted pillar mask
# painted red, as a quick visual sanity check of the segmentation.
N_SHOW = min(12, len(mask_index_df))
subset = mask_index_df.sample(N_SHOW, random_state=0)  # fixed random_state -> same subset every run

plt.figure(figsize=(12, 8))
for i, row in enumerate(subset.itertuples(index=False), start=1):
    c1 = read_grayscale_uint8(Path(row.c1_path))
    m  = read_grayscale_uint8(Path(row.mask_path))
    m01 = (m > 0).astype(np.uint8)

    overlay = cv2.cvtColor(c1, cv2.COLOR_GRAY2BGR)
    overlay[m01 == 1] = (0, 0, 255)  # (B, G, R) -> opaque red on pillar pixels

    plt.subplot(3, 4, i)  # 3x4 grid = the N_SHOW cap of 12
    plt.imshow(cv2.cvtColor(overlay, cv2.COLOR_BGR2RGB))  # matplotlib expects RGB order
    plt.title(Path(row.c1_path).name, fontsize=8)
    plt.axis("off")
plt.tight_layout()
plt.show()

In [None]:
# =======================
# C2 quantification (mask out pillars)
# =======================

def threshold_c2(c2_uint8: np.ndarray, method: str = "otsu", fixed_thr: int = 25) -> np.ndarray:
    """
    Binarise a C2 image into a uint8 array of 0/1.

    Args:
        c2_uint8: 2-D uint8 image.
        method: "otsu" lets OpenCV choose the threshold from this image's
            histogram; "fixed" marks pixels >= ``fixed_thr`` as foreground.
        fixed_thr: Threshold used only when ``method == "fixed"``.

    Raises:
        ValueError: For any other ``method``.
    """
    if method == "otsu":
        # maxval=1 makes OpenCV emit 0/1 directly instead of 0/255.
        _, binary = cv2.threshold(c2_uint8, 0, 1, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
        return binary.astype(np.uint8)
    if method == "fixed":
        return (c2_uint8 >= fixed_thr).astype(np.uint8)
    raise ValueError("method must be 'otsu' or 'fixed'")

def component_stats(bw01: np.ndarray) -> Dict[str, float]:
    """
    Summarise the 8-connected foreground components of a 0/1 binary image.

    Components whose pixel area is below the module-level MIN_COMPONENT_AREA_PX
    are discarded before any statistic is computed. Empty results give 0 / 0.0.
    The sample standard deviation (ddof=1) is reported as 0.0 when fewer than two
    components remain.

    Returns:
        Dict with keys n_components, mean_component_area_px,
        std_component_area_px and median_component_area_px.
    """
    label_stats = cv2.connectedComponentsWithStats(bw01.astype(np.uint8), connectivity=8)[2]
    # Row 0 of the stats table is the background label; keep foreground rows only.
    fg_areas = label_stats[1:, cv2.CC_STAT_AREA].astype(np.float32)
    kept = fg_areas[fg_areas >= MIN_COMPONENT_AREA_PX]

    n = int(kept.size)
    return {
        "n_components": n,
        "mean_component_area_px": float(kept.mean()) if n else 0.0,
        "std_component_area_px": float(kept.std(ddof=1)) if n > 1 else 0.0,
        "median_component_area_px": float(np.median(kept)) if n else 0.0,
    }

# For every C1/C2 pair: remove pillar pixels (from the predicted mask) from the
# thresholded C2 image and measure the remaining crystal signal.
Q_DIR = ensure_dir(OUT_ROOT / "quant")
rows = []

# Square structuring element for the morphological opening below.
kernel = np.ones((MORPH_KERNEL, MORPH_KERNEL), np.uint8)

for row in tqdm(mask_index_df.itertuples(index=False), total=len(mask_index_df), desc="Quantifying C2"):
    # Skip C1 images without a matching C2. pd.isna covers both None and NaN
    # (NaN appears e.g. if mask_index_df is reloaded from mask_index.csv).
    if pd.isna(row.c2_path):
        continue

    c2 = read_grayscale_uint8(Path(row.c2_path))
    mask = read_grayscale_uint8(Path(row.mask_path))
    pillar01 = (mask > 0).astype(np.uint8)

    # The mask was predicted on C1, so C2 must share C1's pixel grid. Fail with a
    # clear message instead of an opaque NumPy broadcasting error further down.
    if c2.shape != pillar01.shape:
        raise ValueError(
            f"C2 image {row.c2_path} has shape {c2.shape}, but its pillar mask "
            f"{row.mask_path} has shape {pillar01.shape}; C1 and C2 must be the same size."
        )

    # Non-pillar pixels define the area in which crystals are measured.
    nonpillar01 = (1 - pillar01).astype(np.uint8)
    nonpillar_area_px = int(nonpillar01.sum())

    bw01 = threshold_c2(c2, method=THRESH_METHOD, fixed_thr=FIXED_THRESH)
    crystal01 = (bw01 * nonpillar01).astype(np.uint8)

    # Opening removes specks smaller than the kernel before measuring.
    crystal01_clean = cv2.morphologyEx(crystal01, cv2.MORPH_OPEN, kernel)

    # NOTE: crystal_area_px counts every pixel that survives the opening, including
    # components smaller than MIN_COMPONENT_AREA_PX. component_stats excludes those
    # from n_components and the size statistics.
    crystal_area_px = int(crystal01_clean.sum())
    coverage = (crystal_area_px / nonpillar_area_px) if nonpillar_area_px > 0 else 0.0

    comp = component_stats(crystal01_clean)
    meta = parse_filename(Path(row.c1_path).name)

    rows.append({
        "image_id": Path(row.c1_path).name.replace(".c1.", "."),
        "c1_path": row.c1_path,
        "c2_path": row.c2_path,
        "mask_path": row.mask_path,
        "nonpillar_area_px": nonpillar_area_px,
        "crystal_area_px": crystal_area_px,
        "coverage_frac": coverage,
        **comp,
        **meta
    })

quant_df = pd.DataFrame(rows)
quant_csv = Q_DIR / "per_image_metrics.csv"
quant_df.to_csv(quant_csv, index=False)
print("Saved:", quant_csv)
quant_df.head()

In [None]:
# =======================
# Aggregate by condition
# =======================

# An experimental condition is (channel, section, diameter); "trial" is the
# replicate being summarised over. Grouping by trial as well would make each
# group (under the naming convention) a single image, so count == 1 and std is
# NaN. The *_count columns give the number of images per condition.
group_cols = ["channel", "section", "diameter"]
metric_cols = ["nonpillar_area_px", "crystal_area_px", "coverage_frac", "n_components",
               "mean_component_area_px", "median_component_area_px"]

if quant_df.empty:
    # Without rows, quant_df has no columns and groupby would raise a bare KeyError.
    raise RuntimeError(
        "quant_df is empty: no images were quantified "
        "(e.g. no C1 image had a matching C2), so there is nothing to aggregate."
    )

agg = (quant_df
       .groupby(group_cols, dropna=False)[metric_cols]
       .agg(["mean", "std", "count"])
       .reset_index())

# Flatten the (metric, statistic) column MultiIndex into e.g. "coverage_frac_mean";
# the group columns have an empty second level and keep their plain names.
agg.columns = ["_".join([c for c in col if c]) for col in agg.columns.to_flat_index()]

agg_csv = Q_DIR / "aggregated_metrics.csv"
agg.to_csv(agg_csv, index=False)
print("Saved:", agg_csv)
agg.head()

In [None]:
# =======================
# Simple plot
# =======================

# Scatter per-image coverage against the diameter token parsed from the filename,
# with one colour / legend entry per (channel, section) combination.
plt.figure(figsize=(8, 5))
for key, sub in quant_df.groupby(["channel", "section"], dropna=False):
    label = f"{key[0]} sec{key[1]}"
    # Diameter is a string token; non-numeric values become NaN and are not drawn.
    x = pd.to_numeric(sub["diameter"], errors="coerce")
    plt.scatter(x, sub["coverage_frac"], alpha=0.6, label=label)

plt.xlabel("Diameter (parsed from filename)")
plt.ylabel("Coverage fraction (crystal_area / nonpillar_area)")
plt.title("Per-image CaCO₃ coverage (masked)")
plt.legend()
plt.grid(True, alpha=0.3)
plt.show()

In [None]:
# =======================
# Save environment versions
# =======================

# Record the software environment next to the results, for reproducibility.
import platform
import sys

import matplotlib

env = {
    "python": sys.version,
    "platform": platform.platform(),
    "torch": torch.__version__,
    "torch_cuda": torch.version.cuda,  # None for CPU-only PyTorch builds
    "device": DEVICE,                  # device actually used for inference
    "segmentation_models_pytorch": smp.__version__,
    "opencv": cv2.__version__,
    "numpy": np.__version__,
    "pandas": pd.__version__,
    "matplotlib": matplotlib.__version__,
}

env_path = OUT_ROOT / "environment_versions.json"
with open(env_path, "w", encoding="utf-8") as f:
    json.dump(env, f, indent=2)

print("Saved:", env_path)
env