# Micropillar analysis pipeline (C1 pillar masking + C2 precipitation quantification)


This notebook runs the full pipeline:
1) **C1 brightfield → pillar mask** using a trained U-Net checkpoint (downloaded from Box)
2) **C2 TL-POL → precipitation quantification** restricted to pore space (non-pillar region)

**Inputs:** place images in `data/raw_images/` (or edit paths in the config cell).  
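**Outputs:** masks in `outputs/pillar_masks/` and metrics/figures in `outputs/` (the inference and storyboard cells read/write masks via `OUT_MASKS` / `C1_MASK_DIR` — check those paths point where you expect).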


# Part 1 — Pillar mask inference (C1 brightfield)


In [None]:
# === CONFIG (edit paths if needed) ===
from pathlib import Path
import urllib.request

# Input folder containing C1 brightfield PNG images
IMAGE_DIR = Path("data/raw_images")  # <-- put your C1 PNG images here

# Output folder for predicted pillar masks
OUTPUT_DIR = Path("outputs/pillar_masks")
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# Model checkpoint (download from Box if missing)
MODEL_DIR = Path("models")
MODEL_DIR.mkdir(parents=True, exist_ok=True)
MODEL_PATH = MODEL_DIR / "unet_pillars_finetuned.pt"

BOX_SHARE_LINK = "https://cornell.box.com/s/4dyu78bhtpabm98jgz40gdp5wmoe71xn"
DIRECT_DOWNLOAD = BOX_SHARE_LINK + "?download=1"


def _download_checkpoint(url, dest):
    """Download `url` to `dest` without ever leaving a partial or bogus file at `dest`.

    The payload is written to ``<dest>.part``. It is renamed into place only after the
    transfer finishes and the content does not look like an HTML page, which is what a
    share link returns when it does not resolve to a direct download. Without this, the
    ``MODEL_PATH.exists()`` check below would accept a broken file on every later run.

    Raises:
        RuntimeError: the server answered with HTML instead of a checkpoint.
        OSError (incl. urllib.error.URLError): network or filesystem failure.
    """
    tmp = dest.with_name(dest.name + ".part")
    try:
        urllib.request.urlretrieve(url, tmp)
        with open(tmp, "rb") as fh:
            head = fh.read(64).lstrip().lower()
        if head.startswith((b"<!doctype html", b"<html")):
            raise RuntimeError(
                f"{url} returned an HTML page, not a checkpoint; "
                f"download the file manually and save it as {dest}"
            )
        tmp.replace(dest)
    finally:
        # Remove leftovers from a failed/aborted download (no-op after a successful rename).
        if tmp.exists():
            tmp.unlink()


if not MODEL_PATH.exists():
    print("Downloading model checkpoint from Box...")
    _download_checkpoint(DIRECT_DOWNLOAD, MODEL_PATH)
    print("Saved:", MODEL_PATH)
else:
    print("Model already exists:", MODEL_PATH)

# Inference settings
PROB_THRESHOLD = 0.5


In [None]:
# --- config for new dataset (no retraining) ---
import os, glob, re, pandas as pd

ROOT = "data/raw_images"   # <- your new data root
SETS = {                               # short = short1+short2, tall = tall1+tall2
    "short": ["short1", "short2"],
    "tall":  ["tall1",  "tall2"],
}

# Existing model checkpoint (a .pt FILE). Reuse the path set by the CONFIG cell
# (the checkpoint downloaded from Box) rather than overwriting it with the image
# folder, which torch.load cannot read. Edit here to use a different checkpoint.
try:
    MODEL_PATH
except NameError:
    MODEL_PATH = "models/unet_pillars_finetuned.pt"
MODEL_PATH = str(MODEL_PATH)


In [None]:
NAME_RE = re.compile(
    r".*channel\s*(?P<ch>\d+)[^0-9]*(?P<diam>\d+(?:\.\d+)?)"
    r"[^0-9a-zA-Z]*(?P<trial>[a-zA-Z0-9]+)", re.IGNORECASE
)

def parse_name(fname):
    """Extract channel / diameter / trial metadata from an image filename.

    Only the basename of `fname` is matched against NAME_RE. Porosity is derived
    from channel parity: odd channels → 0.35, even channels → 0.45.

    Returns:
        dict with keys "channel" (int), "diameter" (float), "trial" (str) and
        "porosity" (float), or None when the name does not match NAME_RE.
    """
    match = NAME_RE.match(os.path.basename(fname))
    if match is None:
        return None
    channel = int(match.group("ch"))
    return {
        "channel": channel,
        "diameter": float(match.group("diam")),
        "trial": str(match.group("trial")),
        "porosity": 0.45 if channel % 2 == 0 else 0.35,
    }


In [None]:
# === Inference on ALL set4 C1 raw images, save ONLY masks (uses set4-finetuned + THR=0.47) ===
!pip -q install segmentation-models-pytorch albumentations opencv-python

import os, glob, cv2, numpy as np, torch
from albumentations.pytorch import ToTensorV2
import segmentation_models_pytorch as smp

# ---------- CONFIG ----------
try:
    ROOT
except NameError:
    ROOT = "data/raw_images"

MODEL_PATH = "data/raw_images"   # updated model
OUT_MASKS  = "data/raw_images"                          # mirrors set4 structure
SHORT_SIDE = 768
OVERWRITE  = False                                                            # set True to re-write existing masks

# Use tuned threshold if present; else default to 0.47 (from tuner)
try:
    THR = float(best_thr)
except NameError:
    THR = 0.47

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# ----------------------------

os.makedirs(OUT_MASKS, exist_ok=True)

def resize_short(im, short=SHORT_SIDE):
    """Rescale `im` (INTER_AREA) so its shorter side is `short` px, keeping the aspect ratio."""
    height, width = im.shape[:2]
    scale = short / min(height, width)
    new_size = (int(round(width * scale)), int(round(height * scale)))  # cv2 expects (w, h)
    return cv2.resize(im, new_size, interpolation=cv2.INTER_AREA)

IMG_EXTS = (".png", ".tif", ".tiff", ".jpg", ".jpeg")

def is_img(p):
    """True if path `p` ends in one of IMG_EXTS (case-insensitive)."""
    extension = os.path.splitext(p)[1]
    return extension.lower() in IMG_EXTS

print("Loading model:", MODEL_PATH)
model = smp.Unet(encoder_name="resnet18", encoder_weights=None, in_channels=3, classes=1)
# weights_only=True: the checkpoint is a plain state_dict fetched from a share link,
# so refuse to unpickle arbitrary Python objects.
model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE, weights_only=True))
model.to(DEVICE).eval()

# --- gather all *_C1 raw images anywhere under ROOT ---
# Excluded: hand labels (*_answer, *_testanswer) and masks written by this cell
# ("<stem>_mask_<thr>.png"). Mask names still contain "_c1", and OUT_MASKS may be
# ROOT itself, so without the "_mask_" check a re-run would feed earlier masks
# back into the model as if they were raw images.
def _is_c1_raw(path):
    """True for C1 raw-image filenames (not labels, not generated masks)."""
    name = os.path.basename(path).lower()
    return ("_c1" in name
            and "_answer" not in name
            and "_testanswer" not in name
            and "_mask_" not in name)

candidates = [p for p in glob.glob(os.path.join(ROOT, "**", "*.*"), recursive=True) if is_img(p)]
raws_all = sorted(p for p in candidates if _is_c1_raw(p))
print(f"Found {len(raws_all)} C1 raw images in {ROOT}")
if len(raws_all) == 0:
    raise FileNotFoundError("‚ùå No *_C1 images found under set4. Double-check filenames and extensions.")

k = np.ones((5,5), np.uint8)   # open/close kernel for mask clean-up
saved = 0; skipped = 0

for rp in raws_all:
    rel_path = os.path.relpath(rp, ROOT)     # e.g., "short1/3_1_a_C1.png"
    rel_dir  = os.path.dirname(rel_path)     # e.g., "short1"
    out_dir  = os.path.join(OUT_MASKS, rel_dir)
    os.makedirs(out_dir, exist_ok=True)

    base = os.path.splitext(os.path.basename(rp))[0]   # e.g., "3_1_a_C1"
    out_mask = os.path.join(out_dir, f"{base}_mask_{THR:.2f}.png")
    if (not OVERWRITE) and os.path.isfile(out_mask):
        skipped += 1
        continue

    bgr = cv2.imread(rp)
    if bgr is None:
        print("‚ö†Ô∏è Skip unreadable:", rp); continue
    H, W = bgr.shape[:2]

    # inference at SHORT_SIDE resolution (same un-normalized 0..255 input as training)
    rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    rs  = resize_short(rgb)
    ten = ToTensorV2()(image=rs)["image"].unsqueeze(0).to(DEVICE).float()
    with torch.no_grad():
        prob = torch.sigmoid(model(ten))[0,0].cpu().numpy()

    # resize probabilities back to full resolution, threshold, then open+close to clean
    prob_full = cv2.resize(prob, (W, H), interpolation=cv2.INTER_LINEAR)
    mask_bin  = (prob_full > THR).astype(np.uint8) * 255
    mask_clean = cv2.morphologyEx(cv2.morphologyEx(mask_bin, cv2.MORPH_OPEN, k), cv2.MORPH_CLOSE, k)

    cv2.imwrite(out_mask, mask_clean)
    saved += 1
    print("Saved:", os.path.relpath(out_mask, OUT_MASKS))

print(f"\n‚úÖ Done! Saved {saved} masks to {OUT_MASKS} (skipped existing: {skipped}, THR={THR:.2f})")


In [None]:
# ===== Fine-tune U-Net on labeled set4 C1 images (uses short2 + tall2 *_answer) =====
!pip -q install segmentation-models-pytorch albumentations opencv-python

import os, glob, random, cv2, numpy as np, torch
from torch.utils.data import Dataset, DataLoader
import albumentations as A
from albumentations.pytorch import ToTensorV2
import segmentation_models_pytorch as smp

# ------------ CONFIG ------------
# Try to reuse ROOT from earlier; otherwise fall back to common path.
try:
    ROOT
except NameError:
    ROOT = "data/raw_images"

# Only C1 images are used for pillar masks.
FILTER_TAG = "_C1"

# Previous (base) model and where to save the new fine-tuned weights
BASE_MODEL = "data/raw_images"
SAVE_TO    = "data/raw_images"

SHORT_SIDE  = 768
BATCH_TRAIN, BATCH_VAL = 2, 1
EPOCHS_MAX  = 40
LR          = 1e-4
PATIENCE    = 6
SEED        = 42
# ---------------------------------

os.makedirs(os.path.dirname(SAVE_TO), exist_ok=True)
random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

IMG_EXTS = (".png", ".tif", ".tiff", ".jpg", ".jpeg")
def is_img(p): return os.path.splitext(p)[1].lower() in IMG_EXTS

# --- helper: robust label → binary mask (supports red overlay OR white-on-black) ---
def label_to_binary(mask_bgr):
    """Convert a hand-label image (BGR, as read by cv2.imread) to a uint8 {0,1} mask.

    Red-overlay labels are extracted in HSV (red hue wraps around 0/180, hence two
    ranges) and cleaned with a 5x5 open + close. If red pixels cover less than 0.01%
    of the image, the label is treated as white-on-black and binarized with Otsu on
    the min-max-stretched grayscale instead.

    Raises:
        ValueError: if `mask_bgr` is None (e.g. cv2.imread could not read the file).
    """
    if mask_bgr is None:
        raise ValueError("label_to_binary: mask_bgr is None")
    # Try red in HSV
    hsv = cv2.cvtColor(mask_bgr, cv2.COLOR_BGR2HSV)
    lower1, upper1 = np.array([0,70,40]),  np.array([12,255,255])
    lower2, upper2 = np.array([170,70,40]), np.array([180,255,255])
    m_red = cv2.inRange(hsv, lower1, upper1) | cv2.inRange(hsv, lower2, upper2)
    # Fraction of red PIXELS. inRange yields 0/255, so summing the array instead of
    # counting non-zeros over-counts by 255x and the fallback below almost never fires.
    frac_red = np.count_nonzero(m_red) / max(1, m_red.size)
    if frac_red < 1e-4:  # red overlay absent -> treat label as white-on-black
        gray = cv2.cvtColor(mask_bgr, cv2.COLOR_BGR2GRAY)
        # Stretch contrast to 0..255, then Otsu
        g = cv2.normalize(gray, None, 0, 255, cv2.NORM_MINMAX)
        _, m_bw = cv2.threshold(g, 0, 255, cv2.THRESH_BINARY+cv2.THRESH_OTSU)
        m = (m_bw > 0).astype(np.uint8)
    else:
        # Clean up the red mask a bit
        k = np.ones((5,5), np.uint8)
        m = cv2.morphologyEx(m_red, cv2.MORPH_OPEN, k, iterations=1)
        m = cv2.morphologyEx(m, cv2.MORPH_CLOSE, k, iterations=1)
        m = (m > 0).astype(np.uint8)
    return m

# --- collect labeled pairs (C1 raw + *_answer) recursively under ROOT ---
def find_labeled_pairs(root, filter_tag=FILTER_TAG):
    """Pair every "<stem>_answer.*" label under `root` with its raw image "<stem>.*".

    Only labels whose stem contains `filter_tag` (case-insensitive) are used. The raw
    image is looked up first in the label's own folder (in IMG_EXTS order), then
    anywhere under `root` as a last resort. Labels without a raw image are skipped.

    Labels are visited in sorted path order so the seeded train/val shuffle does not
    depend on the filesystem's directory-listing order.

    Returns:
        list of (raw_image_path, label_path) tuples.
    """
    tag = filter_tag.lower()
    label_paths = sorted(glob.glob(os.path.join(glob.escape(root), "**", "*_answer.*"), recursive=True))
    pairs = []
    for mp in label_paths:
        stem = os.path.splitext(os.path.basename(mp))[0]
        if tag not in stem.lower():
            continue
        raw_stem = stem.replace("_answer", "")
        # Look for raw in the *same directory* with any allowed image extension
        dirpath = os.path.dirname(mp)
        candidates = [os.path.join(dirpath, raw_stem + ext) for ext in IMG_EXTS
                      if os.path.isfile(os.path.join(dirpath, raw_stem + ext))]
        if not candidates:
            # Last resort: a same-named image anywhere under root. glob.escape keeps
            # names containing [, ], *, ? literal; sort so the pick is deterministic.
            fallback = glob.glob(os.path.join(glob.escape(root), "**", glob.escape(raw_stem) + ".*"),
                                 recursive=True)
            candidates = sorted(p for p in fallback if is_img(p))
        if candidates:
            pairs.append((candidates[0], mp))
    return pairs

pairs = find_labeled_pairs(ROOT, FILTER_TAG)
# Sort so the seeded shuffle below yields the same split regardless of glob order.
pairs = sorted(pairs)
print(f"Using labeled C1 pairs: {len(pairs)} (expecting ~10 from short2/tall2)")
assert len(pairs) >= 1, "‚ùå No labeled pairs found in set4. Ensure *_C1.* raw and *_answer.* live together."
# At least one pair always goes to validation, so two are needed for a non-empty train set.
assert len(pairs) >= 2, f"Need at least 2 labeled pairs (1 val + >=1 train), found {len(pairs)}."

# Optional: print a tiny sample and per-subdir counts
from collections import Counter
cnt = Counter([os.path.basename(os.path.dirname(p[0])) for p in pairs])
print("Pairs per folder:", dict(cnt))
for i, (ip, mp) in enumerate(pairs[:5]):
    print(f"  {i+1}. RAW={os.path.basename(ip)}  |  ANS={os.path.basename(mp)}")

# --- split 80/20 train/val (at least one validation pair) ---
random.shuffle(pairs)
n_val = max(1, len(pairs)//5)
val_pairs   = pairs[:n_val]
train_pairs = pairs[n_val:]
print(f"Train: {len(train_pairs)} | Val: {len(val_pairs)}")

# --- Dataset class ---
class PillarSeg(Dataset):
    """Segmentation dataset over (raw C1 image path, *_answer label path) pairs.

    Both image and label are rescaled so their short side equals `short_side`. When
    `augment` is set (an albumentations pipeline ending in ToTensorV2), it is applied
    jointly to image and mask; otherwise both are converted to tensors directly.

    Items are (image float tensor C×H×W, mask float tensor 1×H×W with values {0,1}).
    """

    def __init__(self, pairs, augment=None, short_side=SHORT_SIDE):
        self.pairs = pairs
        self.aug = augment
        self.short = short_side

    def _resize_short(self, im):
        """Scale `im` (INTER_AREA) so that min(h, w) == self.short."""
        h, w = im.shape[:2]
        factor = self.short / min(h, w)
        target = (int(round(w * factor)), int(round(h * factor)))
        return cv2.resize(im, target, interpolation=cv2.INTER_AREA)

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, i):
        raw_path, label_path = self.pairs[i]

        raw_bgr = cv2.imread(raw_path)
        if raw_bgr is None:
            raise FileNotFoundError(f"Could not read raw image: {raw_path}")
        label_bgr = cv2.imread(label_path)
        if label_bgr is None:
            raise FileNotFoundError(f"Could not read mask image: {label_path}")

        image = self._resize_short(cv2.cvtColor(raw_bgr, cv2.COLOR_BGR2RGB))
        mask = self._resize_short(label_to_binary(label_bgr))

        if not self.aug:
            image_t = ToTensorV2()(image=image)["image"].float()
            mask_t = torch.from_numpy(mask).unsqueeze(0).float()
            return image_t, mask_t

        sample = self.aug(image=image, mask=mask)
        image_t = sample["image"].float()
        mask_t = torch.as_tensor(sample["mask"], dtype=torch.float32).unsqueeze(0)
        return image_t, mask_t

# Training augmentation: flips / 90° rotations, mild shift-scale-rotate (reflect border),
# light brightness/contrast jitter, then HWC numpy -> CHW tensor. Validation only converts.
train_aug = A.Compose([
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.5),
    A.RandomRotate90(p=0.5),
    A.ShiftScaleRotate(shift_limit=0.02, scale_limit=0.08, rotate_limit=10,
                       border_mode=cv2.BORDER_REFLECT, p=0.6),
    A.RandomBrightnessContrast(0.1, 0.1, p=0.3),
    ToTensorV2(),
])
val_aug = A.Compose([ToTensorV2()])

train_dl = DataLoader(PillarSeg(train_pairs, train_aug),
                      batch_size=BATCH_TRAIN, shuffle=True, num_workers=0)
val_dl = DataLoader(PillarSeg(val_pairs, val_aug),
                    batch_size=BATCH_VAL, shuffle=False, num_workers=0)

# --- pos_weight for BCE (handle class imbalance) ---
# neg/pos pixel ratio over the full-resolution training labels, clipped to [1, 10].
pos_pix = tot_pix = 0
for _, label_path in train_pairs:
    label01 = label_to_binary(cv2.imread(label_path))
    pos_pix += int(label01.sum())
    tot_pix += int(label01.size)
pos_frac = max(1e-6, pos_pix / max(1, tot_pix))
neg_frac = 1.0 - pos_frac
pos_weight_val = float(np.clip(neg_frac / pos_frac, 1.0, 10.0))
print(f"Estimated pos_weight ‚âà {pos_weight_val:.2f} (pos_frac={pos_frac:.4f})")

# --- load model (fine-tune) ---
model = smp.Unet(encoder_name="resnet18", encoder_weights=None, in_channels=3, classes=1)
# weights_only=True: BASE_MODEL is a plain state_dict (possibly downloaded from a share
# link), so refuse to unpickle arbitrary Python objects.
model.load_state_dict(torch.load(BASE_MODEL, map_location=device, weights_only=True))  # load old weights
model.to(device)

# --- losses & optimizer ---
# pos_weight (neg/pos label-pixel ratio, clipped to [1, 10]) re-weights positive pixels in BCE.
bce = torch.nn.BCEWithLogitsLoss(pos_weight=torch.tensor([pos_weight_val], device=device))
def dice_loss(logits, targets, eps=1e-7):
    """Soft Dice loss (1 - mean per-sample Dice) computed on sigmoid(logits).

    Per-sample sums run over dims (1, 2, 3), so both tensors must be 4-D, e.g.
    (N, 1, H, W) as produced by the data loaders. `eps` is added to numerator and
    denominator to keep empty masks finite.
    """
    probs = torch.sigmoid(logits)
    reduce_dims = (1, 2, 3)
    overlap = (probs * targets).sum(dim=reduce_dims)
    denominator = probs.sum(dim=reduce_dims) + targets.sum(dim=reduce_dims) + eps
    dice = (2 * overlap + eps) / denominator
    return 1 - dice.mean()

opt = torch.optim.AdamW(params=model.parameters(), lr=LR, weight_decay=1e-4)

# --- early stopping ---
class EarlyStopping:
    """Track the best validation loss and signal a stop after `patience` bad epochs.

    An epoch counts as an improvement only if val_loss < best - delta. On improvement
    a CPU copy of the model's state_dict is kept in `best_state`.
    """

    def __init__(self, patience=6, delta=1e-4):
        self.patience = patience
        self.delta = delta
        self.best = float("inf")
        self.count = 0
        self.best_state = None

    def step(self, val_loss, model):
        """Record one epoch. Returns (should_stop, improved)."""
        if not (val_loss < self.best - self.delta):   # also true for NaN losses
            self.count += 1
            return self.count >= self.patience, False
        self.best = val_loss
        self.count = 0
        self.best_state = {name: tensor.detach().cpu().clone()
                           for name, tensor in model.state_dict().items()}
        return False, True

early = EarlyStopping(patience=PATIENCE, delta=1e-4)
# Gradient scaling is only enabled when CUDA is available (matches autocast in the loop).
scaler = torch.amp.GradScaler("cuda", enabled=torch.cuda.is_available())

# --- training loop ---
# Loss = 0.5*BCE + 0.5*Dice; autocast/GradScaler are only active when CUDA is available.
for ep in range(1, EPOCHS_MAX+1):
    # train
    model.train(); tr=0.0
    for x,y in train_dl:
        x=x.to(device).float(); y=y.to(device).float()
        opt.zero_grad(set_to_none=True)
        with torch.amp.autocast("cuda", enabled=torch.cuda.is_available()):
            logits = model(x)
            loss = 0.5*bce(logits,y) + 0.5*dice_loss(logits,y)
        scaler.scale(loss).backward(); scaler.step(opt); scaler.update()
        tr += loss.item()
    tr /= max(1,len(train_dl))

    # validate
    model.eval(); vl=0.0; n=0
    with torch.no_grad(), torch.amp.autocast("cuda", enabled=torch.cuda.is_available()):
        for x,y in val_dl:
            x=x.to(device).float(); y=y.to(device).float()
            logits = model(x)
            vl += (0.5*bce(logits,y) + 0.5*dice_loss(logits,y)).item()
            n += 1
    vl /= max(1,n)

    stop, improved = early.step(vl, model)
    # Save on every improvement; epoch 1 always writes so SAVE_TO exists even if the
    # first validation loss was not an improvement (e.g. NaN).
    if improved or ep==1:
        torch.save(early.best_state if early.best_state is not None else model.state_dict(), SAVE_TO)
    print(f"Epoch {ep:02d} | train_loss {tr:.4f} | val_loss {vl:.4f} | saved_best={'yes' if improved or ep==1 else 'no'}")
    if stop:
        print("Early stopping."); break

# The in-memory model holds the LAST epoch's weights; restore the best ones so it
# matches the checkpoint written to SAVE_TO.
if early.best_state is not None:
    model.load_state_dict(early.best_state)

print("‚úÖ Best fine-tuned model saved at:", SAVE_TO)


In [None]:
# === Storyboard for random set4 C2 image with its matched C1 mask + metrics
#     (robust mask matching + conservative C2 thresholding) ===
import os, glob, cv2, random, re
import numpy as np
import matplotlib.pyplot as plt

# ---------------- Paths ----------------
if "ROOT" not in globals():
    ROOT = "data/raw_images"  # set4 root with short1/short2/tall1/tall2
C2_FOLDER = ROOT
C1_MASK_DIR = "data/raw_images"  # AI masks (mirrors set4 structure)

# ---------------- Imaging scale ----------------
FIELD_UM = 1331.2                         # field of view (µm)
IMG_PX = 2048                             # pixels per side
um_per_px = FIELD_UM / IMG_PX
mm2_per_px = (um_per_px / 1000.0) ** 2    # pixel area in mm²; not used, kept for completeness

# ---------------- Scale bar style ----------------
SCALE_UM, BAR_THICK_PX, MARGIN_PX = 500, 35, 100
FONT = cv2.FONT_HERSHEY_SIMPLEX
FONT_SCALE, FONT_THICK = 4, 13

# ---------------- C2 binarization (more conservative) ----------------
# Mode is one of "percentile" (recommended), "fixed", or "otsu+".
C2_THRESH_MODE = "fixed"
C2_FIXED = 140        # "fixed": absolute 0–255 cut (raise to pick less)
C2_PERCENTILE = 94    # "percentile": percentile of pore-space pixels (98–99 picks less)
C2_DELTA = 15         # "otsu+": Otsu threshold + delta (stricter)

# ---------------- Helpers ----------------
IMG_EXTS = (".png", ".tif", ".tiff", ".jpg", ".jpeg")

def is_img(p):
    """True if `p` ends in one of IMG_EXTS (case-insensitive)."""
    return os.path.splitext(p)[1].lower() in IMG_EXTS

def add_scale_bar(gray_img, um_per_px, scale_um=SCALE_UM,
                  bar_thick_px=BAR_THICK_PX, margin_px=MARGIN_PX):
    """Return an RGB copy of `gray_img` with a labelled scale bar in the bottom-right.

    The bar length is `scale_um` converted with `um_per_px`, capped so it fits inside
    the margins. It is drawn white on a 2-px black outline, with the text
    "<scale_um> um" centred above it (white over a thicker black stroke).
    """
    canvas = cv2.cvtColor(gray_img, cv2.COLOR_GRAY2RGB)
    height, width = gray_img.shape[:2]

    length = min(int(round(scale_um / um_per_px)), width - 2 * margin_px - 1)
    right, bottom = width - margin_px, height - margin_px
    left, top = right - length, bottom - bar_thick_px

    pad = 2  # black outline thickness around the white bar
    cv2.rectangle(canvas, (left - pad, top - pad), (right + pad, bottom + pad), (0, 0, 0), -1)
    cv2.rectangle(canvas, (left, top), (right, bottom), (255, 255, 255), -1)

    text = f"{scale_um:.0f} um"
    (text_w, _), _ = cv2.getTextSize(text, FONT, FONT_SCALE, FONT_THICK)
    origin = (left + (length - text_w) // 2, top - 8)
    for color, thickness in (((0, 0, 0), FONT_THICK + 2), ((255, 255, 255), FONT_THICK)):
        cv2.putText(canvas, text, origin, FONT, FONT_SCALE, color, thickness, cv2.LINE_AA)
    return canvas

# --- robust mask selection (prefer same folder; then same top-level; then global) ---
def parse_thr_from_name(path):
    """Return the threshold encoded as "_mask_<thr>" in the basename of `path`, or None."""
    m = re.search(r"_mask_(\d+(?:\.\d+)?)", os.path.basename(path))
    return float(m.group(1)) if m else None

# If you ran the tuner, best_thr may exist; prefer masks closest to it
THR_TARGET = None
try:
    THR_TARGET = float(best_thr)  # e.g., 0.47
except (NameError, TypeError, ValueError):
    pass  # best_thr undefined or not numeric: fall back to newest-file selection

def pick_best_mask(candidates):
    """Pick one mask path from `candidates` (None if empty).

    With THR_TARGET set: the mask whose encoded threshold is closest to it, ties
    broken by newest mtime; names without a threshold rank last. Without THR_TARGET:
    the newest file.
    """
    if not candidates:
        return None
    if THR_TARGET is None:
        return min(candidates, key=lambda p: -os.path.getmtime(p))

    def rank(p):
        thr = parse_thr_from_name(p)
        # `is None`, not `or`: a parsed threshold of 0.0 is falsy but valid.
        thr = 999.0 if thr is None else thr
        return (abs(thr - THR_TARGET), -os.path.getmtime(p))

    return min(candidates, key=rank)

def find_mask_for_c2(c2_path):
    """Find the C1 pillar mask that belongs to a C2 image.

    The base key is the C2 stem cut at its last "_C2" (case-insensitive, matching how
    C2 files are discovered), e.g. "3_1_a_C2" -> "3_1_a". Masks named
    "<base>_C1_mask_*.png" (C1 in either case) are searched under C1_MASK_DIR:
    same relative folder, then same top-level folder, then anywhere. Within each tier,
    pick_best_mask chooses.

    Returns:
        (base_key, mask_path or None, match_mode), where match_mode is
        "same_dir", "same_top", "global" or "not_found".
    """
    stem = os.path.splitext(os.path.basename(c2_path))[0]   # e.g., "3_1_a_C2"
    cut = stem.lower().rfind("_c2")
    base_key = stem[:cut] if cut >= 0 else stem             # -> "3_1_a"
    rel_dir = os.path.dirname(os.path.relpath(c2_path, C2_FOLDER))   # e.g., "short2" or nested
    # glob.escape: keys/folders containing [, ], *, ? must match literally.
    pattern = f"{glob.escape(base_key)}_[Cc]1_mask_*.png"

    # 1) Same relative directory under C1_MASK_DIR
    mpath = pick_best_mask(glob.glob(os.path.join(C1_MASK_DIR, glob.escape(rel_dir), pattern)))
    if mpath:
        return base_key, mpath, "same_dir"

    # 2) Same top-level (short1/short2/tall1/tall2)
    top = rel_dir.split(os.sep)[0] if rel_dir else ""
    if top:
        mpath = pick_best_mask(glob.glob(
            os.path.join(C1_MASK_DIR, glob.escape(top), "**", pattern), recursive=True))
        if mpath:
            return base_key, mpath, "same_top"

    # 3) Anywhere under C1_MASK_DIR
    mpath = pick_best_mask(glob.glob(os.path.join(C1_MASK_DIR, "**", pattern), recursive=True))
    if mpath:
        return base_key, mpath, "global"

    return base_key, None, "not_found"

# --- C2 binarization (conservative) ---
def to_u8(img):
    """Return `img` as uint8; non-uint8 input is min-max stretched to 0..255."""
    if img.dtype == np.uint8:
        return img
    return cv2.normalize(img, None, 0, 255, cv2.NORM_MINMAX).astype(np.uint8)

def binarize_c2(C2_raw, nonpillar_mask):
    """Return a 0/255 uint8 crystal mask restricted to the non-pillar region.

    `nonpillar_mask` is a uint8 {0,1} mask (1 = pore space). The C2 image is masked
    and blurred (3x3 Gaussian), then thresholded per C2_THRESH_MODE:
      - "fixed":      cut at C2_FIXED;
      - "percentile": cut at the C2_PERCENTILE-th percentile of pore-space pixels;
      - "otsu+":      Otsu threshold of pore-space pixels + C2_DELTA.
    The result is zeroed outside the pore space and cleaned with a 3x3 open + close.

    Raises:
        ValueError: unknown C2_THRESH_MODE.
    """
    img = to_u8(C2_raw)
    roi = cv2.bitwise_and(img, img, mask=(nonpillar_mask * 255))
    blur = cv2.GaussianBlur(roi, (3,3), 0)
    inside = nonpillar_mask > 0

    if C2_THRESH_MODE == "fixed":
        thr_val = int(C2_FIXED)
    elif C2_THRESH_MODE == "percentile":
        vals = blur[inside]
        thr_val = int(np.percentile(vals, C2_PERCENTILE)) if vals.size else 255
    elif C2_THRESH_MODE == "otsu+":
        vals = blur[inside]
        if vals.size:
            # cv2.threshold returns (threshold, image). Otsu runs on pore-space pixels only;
            # the zeroed pillar pixels would otherwise form a large spurious dark class.
            t0, _ = cv2.threshold(vals.reshape(-1, 1), 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
        else:
            t0 = 255
        thr_val = int(max(0, min(255, t0 + C2_DELTA)))
    else:
        raise ValueError(f"Unknown C2_THRESH_MODE {C2_THRESH_MODE!r}; use 'fixed', 'percentile' or 'otsu+'")

    _, thr = cv2.threshold(blur, thr_val, 255, cv2.THRESH_BINARY)
    thr[nonpillar_mask == 0] = 0
    k = np.ones((3,3), np.uint8)
    thr = cv2.morphologyEx(thr, cv2.MORPH_OPEN,  k, iterations=1)
    thr = cv2.morphologyEx(thr, cv2.MORPH_CLOSE, k, iterations=1)
    return thr

# ---------------- Find a C2 + matching C1 mask ----------------
c2_files = sorted([p for p in glob.glob(os.path.join(C2_FOLDER, "**", "*.*"), recursive=True)
                   if is_img(p)
                   and ("_c2" in os.path.basename(p).lower())
                   and ("_answer" not in os.path.basename(p).lower())])
if not c2_files:
    raise FileNotFoundError("No *_C2 images found under set4.")

c2_path = random.choice(c2_files)
base_key, mpath, match_mode = find_mask_for_c2(c2_path)
if not mpath:
    raise FileNotFoundError(f"No C1 mask found for base '{base_key}'")

# find raw C1 (same dir first, then first sorted match anywhere under C2_FOLDER)
c2_dir = os.path.dirname(c2_path)
c1_path = None
for ext in IMG_EXTS:
    cand = os.path.join(c2_dir, f"{base_key}_C1{ext}")
    if os.path.isfile(cand):
        c1_path = cand; break
if c1_path is None:
    cands = []
    for ext in IMG_EXTS:
        # glob.escape: a key containing [, ], *, ? must match literally
        cands += glob.glob(os.path.join(C2_FOLDER, "**", f"{glob.escape(base_key)}_C1{ext}"), recursive=True)
    if cands: c1_path = sorted(cands)[0]
if not c1_path:
    raise FileNotFoundError(f"Missing raw C1 for base '{base_key}'")

print(f"Storyboard for: {base_key}")
print(f"Mask match mode: {match_mode} ‚Üí {os.path.relpath(mpath, C1_MASK_DIR)}")

# ---------------- Load images ----------------
def _read_gray(path):
    """cv2.imread in grayscale, raising instead of silently returning None."""
    img = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
    if img is None:
        raise FileNotFoundError(f"Could not read image: {path}")
    return img

C1_raw  = _read_gray(c1_path)
C2_raw  = _read_gray(c2_path)
C1_mask = _read_gray(mpath)
if C1_mask.shape != C2_raw.shape:
    # nearest-neighbour keeps the mask binary
    C1_mask = cv2.resize(C1_mask, (C2_raw.shape[1], C2_raw.shape[0]), interpolation=cv2.INTER_NEAREST)

# ---------------- Metrics: pillar vs non-pillar ----------------
total_pixels     = C1_mask.size
pillar_pixels    = int((C1_mask > 0).sum())
nonpillar_pixels = total_pixels - pillar_pixels
pillar_pct       = 100.0 * pillar_pixels / total_pixels
nonpillar_pct    = 100.0 * nonpillar_pixels / total_pixels

# ---------------- C2 processing (conservative) ----------------
channel   = (C1_mask == 0).astype(np.uint8)   # 1 = non-pillar (pore space)
precip    = binarize_c2(C2_raw, channel)

# ---------------- Crystal metrics ----------------
crystal_pixels           = int((precip > 0).sum())
crystal_pct              = 100.0 * crystal_pixels / total_pixels
crystal_in_nonpillar_pct = 100.0 * crystal_pixels / max(1, nonpillar_pixels)

# ---------------- Overlays ----------------
proc_rgb = cv2.cvtColor(C2_raw, cv2.COLOR_GRAY2RGB)
proc_rgb[precip > 0] = (255,0,0)             # crystals in red
proc_rgb[C1_mask > 0] = (0,0,0)              # pillars black for contrast

overlay = np.zeros_like(proc_rgb)
contours, _ = cv2.findContours((C1_mask>0).astype(np.uint8),
                               cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
cv2.drawContours(overlay, contours, -1, (255,255,0), 2)   # pillar outlines in yellow
proc_with_outline = cv2.addWeighted(proc_rgb, 1.0, overlay, 1.0, 0)

final_rgb = np.zeros_like(proc_rgb)
final_rgb[C1_mask > 0] = (0,0,0)             # pillars
final_rgb[channel > 0] = (255,255,255)       # non-pillar background
final_rgb[precip > 0]  = (255,0,0)           # crystals

# ---------------- Add scale bars ----------------
# um_per_px assumes IMG_PX pixels across FIELD_UM; warn when the image disagrees.
H, W = C2_raw.shape[:2]
if (H, W) != (IMG_PX, IMG_PX):
    print(f"Warning: image is {W}x{H} px but um_per_px assumes {IMG_PX} px across "
          f"{FIELD_UM} um; scale bars may be off.")
C1_with_bar = add_scale_bar(C1_raw, um_per_px, SCALE_UM)
C2_with_bar = add_scale_bar(C2_raw, um_per_px, SCALE_UM)

# ---------------- Plot storyboard ----------------
fig, axes = plt.subplots(1,5, figsize=(22,5))

axes[0].imshow(C1_with_bar); axes[0].set_title("C1 Brightfield (raw + scale bar)"); axes[0].axis("off")
axes[1].imshow(C1_mask, cmap="gray"); axes[1].set_title(f"Pillar Mask\nNon-pillar: {nonpillar_pct:.1f}%"); axes[1].axis("off")
axes[2].imshow(C2_with_bar); axes[2].set_title("C2 TL-POL (raw + scale bar)"); axes[2].axis("off")
axes[3].imshow(proc_with_outline); axes[3].set_title("C2 Processed\n+ Pillar Outlines"); axes[3].axis("off")
axes[4].imshow(final_rgb); axes[4].set_title(
    f"Final Composite\n"
    f"Pillar: {pillar_pct:.1f}% | Non-pillar: {nonpillar_pct:.1f}%\n"
    f"Crystal: {crystal_pct:.1f}% | Crystal/Non-pillar: {crystal_in_nonpillar_pct:.1f}%"
); axes[4].axis("off")

# Report the actual image size rather than a hard-coded 2048×2048.
fig.suptitle(f"{base_key} | {W}×{H} px | {FIELD_UM:.1f} µm FOV (~{um_per_px:.2f} µm/px)",
             fontsize=14, fontweight="bold")
plt.tight_layout()
plt.show()


In [None]:
# === SET4 — 3-Panel: C1 Raw | C2 Raw | Composite+IDs + 0.25 mm Grid (no save)
# - Filters super-small crystals (area & eq. diameter)
# - Offsets labels to avoid features and other labels
# - Draws leader lines if label is far from the crystal
# - Uses set4 robust mask matching + conservative TL-POL thresholding
import os, glob, cv2, random, re
import numpy as np
import matplotlib.pyplot as plt

# ---------------- Paths ----------------
if "ROOT" not in globals():
    ROOT = "data/raw_images"  # set4 root with short1/short2/tall1/tall2
C2_FOLDER   = ROOT
C1_MASK_DIR = "data/raw_images"  # AI masks (mirrors set4 structure)

# ---------------- Imaging scale (set4) ----------------
FIELD_UM  = 1331.2              # field of view (µm) for 2048×2048 TL-POL
IMG_PX    = 2048                # expected px per side for this FOV
um_per_px = FIELD_UM / IMG_PX
FIELD_MM  = FIELD_UM / 1000.0   # ≈ 1.3312 mm FOV (set4)

# ---------------- Grid ----------------
GRID_MM = 0.25  # grid spacing in mm

# ---------------- Crystal filters ----------------
MIN_CRYSTAL_AREA_PX = 25     # area threshold (raise to be stricter)
MIN_EQDIAM_PX       = 4.0    # eq. diameter threshold in px = sqrt(4A/π)

# ---------------- Label/search tunables ----------------
LABEL_THICKNESS   = 2
# Candidate label directions, tried in this order: NE, NW, SE, SW, E, N, W, S
SEARCH_DIRECTIONS = [(1,-1), (-1,-1), (1,1), (-1,1), (2,0), (0,-2), (-2,0), (0,2)]
EXTRA_OFFSETS     = [1.0, 1.5, 2.0, 2.5, 3.0]   # multipliers of the base label offset
LEADER_IF_FAR_MUL = 0.035  # leader if centroid→label distance > this*min(H,W)
LABEL_MARGIN_PX   = 3      # spacing between label rectangles

# ---------------- C2 binarization (conservative) ----------------
C2_THRESH_MODE = "fixed"   # one of "percentile" (recommended), "fixed", "otsu+"
C2_FIXED       = 140
C2_PERCENTILE  = 94
C2_DELTA       = 15

# ---------------- Helpers ----------------
IMG_EXTS = (".png", ".tif", ".tiff", ".jpg", ".jpeg")

def is_img(p):
    """True if `p` ends in one of IMG_EXTS (case-insensitive)."""
    return os.path.splitext(p)[1].lower() in IMG_EXTS

def to_u8(img):
    """Return `img` as uint8; non-uint8 input is min-max stretched to 0..255."""
    if img.dtype == np.uint8:
        return img
    return cv2.normalize(img, None, 0, 255, cv2.NORM_MINMAX).astype(np.uint8)

def binarize_c2(C2_raw, nonpillar_mask):
    """Return a 0/255 uint8 crystal mask restricted to the non-pillar region.

    `nonpillar_mask` is a uint8 {0,1} mask (1 = pore space). The C2 image is masked
    and blurred (3x3 Gaussian), then thresholded per C2_THRESH_MODE:
      - "fixed":      cut at C2_FIXED;
      - "percentile": cut at the C2_PERCENTILE-th percentile of pore-space pixels;
      - "otsu+":      Otsu threshold of pore-space pixels + C2_DELTA.
    The result is zeroed outside the pore space and cleaned with a 3x3 open + close.

    Raises:
        ValueError: unknown C2_THRESH_MODE.
    """
    img = to_u8(C2_raw)
    roi = cv2.bitwise_and(img, img, mask=(nonpillar_mask * 255))
    blur = cv2.GaussianBlur(roi, (3,3), 0)
    inside = nonpillar_mask > 0

    if C2_THRESH_MODE == "fixed":
        thr_val = int(C2_FIXED)
    elif C2_THRESH_MODE == "percentile":
        vals = blur[inside]
        thr_val = int(np.percentile(vals, C2_PERCENTILE)) if vals.size else 255
    elif C2_THRESH_MODE == "otsu+":
        vals = blur[inside]
        if vals.size:
            # cv2.threshold returns (threshold, image). Otsu runs on pore-space pixels only;
            # the zeroed pillar pixels would otherwise form a large spurious dark class.
            t0, _ = cv2.threshold(vals.reshape(-1, 1), 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
        else:
            t0 = 255
        thr_val = int(max(0, min(255, t0 + C2_DELTA)))
    else:
        raise ValueError(f"Unknown C2_THRESH_MODE {C2_THRESH_MODE!r}; use 'fixed', 'percentile' or 'otsu+'")

    _, thr = cv2.threshold(blur, thr_val, 255, cv2.THRESH_BINARY)
    thr[nonpillar_mask == 0] = 0
    k = np.ones((3,3), np.uint8)
    thr = cv2.morphologyEx(thr, cv2.MORPH_OPEN,  k, iterations=1)
    thr = cv2.morphologyEx(thr, cv2.MORPH_CLOSE, k, iterations=1)
    return thr

# --- robust mask selection (prefer same folder; then same top-level; then global) ---
def parse_thr_from_name(path):
    """Return the threshold encoded as "_mask_<thr>" in the basename of `path`, or None."""
    m = re.search(r"_mask_(\d+(?:\.\d+)?)", os.path.basename(path))
    return float(m.group(1)) if m else None

THR_TARGET = None
try:
    THR_TARGET = float(best_thr)  # if you set best_thr elsewhere, we bias toward it
except (NameError, TypeError, ValueError):
    pass  # best_thr undefined or not numeric: fall back to newest-file selection

def pick_best_mask(candidates):
    """Pick one mask path from `candidates` (None if empty).

    With THR_TARGET set: the mask whose encoded threshold is closest to it, ties
    broken by newest mtime; names without a threshold rank last. Without THR_TARGET:
    the newest file.
    """
    if not candidates:
        return None
    if THR_TARGET is None:
        return min(candidates, key=lambda p: -os.path.getmtime(p))

    def rank(p):
        thr = parse_thr_from_name(p)
        # `is None`, not `or`: a parsed threshold of 0.0 is falsy but valid.
        thr = 999.0 if thr is None else thr
        return (abs(thr - THR_TARGET), -os.path.getmtime(p))

    return min(candidates, key=rank)

def find_mask_for_c2(c2_path):
    """Find the C1 pillar mask that belongs to a C2 image, preferring same dir/top-level.

    The base key is the C2 stem cut at its last "_C2" (case-insensitive, matching how
    C2 files are discovered), e.g. "3_1_a_C2" -> "3_1_a". Masks named
    "<base>_C1_mask_*.png" (C1 in either case) are searched under C1_MASK_DIR:
    same relative folder, then same top-level folder, then anywhere. Within each tier,
    pick_best_mask chooses.

    Returns:
        (base_key, mask_path or None, match_mode), where match_mode is
        "same_dir", "same_top", "global" or "not_found".
    """
    stem     = os.path.splitext(os.path.basename(c2_path))[0]   # e.g., "3_1_a_C2"
    cut      = stem.lower().rfind("_c2")
    base_key = stem[:cut] if cut >= 0 else stem                 # -> "3_1_a"
    rel_dir  = os.path.dirname(os.path.relpath(c2_path, C2_FOLDER))   # e.g., "short2"
    # glob.escape: keys/folders containing [, ], *, ? must match literally.
    pattern  = f"{glob.escape(base_key)}_[Cc]1_mask_*.png"

    # 1) Same relative directory
    mpath = pick_best_mask(glob.glob(os.path.join(C1_MASK_DIR, glob.escape(rel_dir), pattern)))
    if mpath:
        return base_key, mpath, "same_dir"

    # 2) Same top-level (short1/short2/tall1/tall2)
    top = rel_dir.split(os.sep)[0] if rel_dir else ""
    if top:
        mpath = pick_best_mask(glob.glob(
            os.path.join(C1_MASK_DIR, glob.escape(top), "**", pattern), recursive=True))
        if mpath:
            return base_key, mpath, "same_top"

    # 3) Anywhere
    mpath = pick_best_mask(glob.glob(os.path.join(C1_MASK_DIR, "**", pattern), recursive=True))
    if mpath:
        return base_key, mpath, "global"

    return base_key, None, "not_found"

# ---------------- Pick a set4 C2 + matching C1 mask ----------------
c2_files = sorted([p for p in glob.glob(os.path.join(C2_FOLDER, "**", "*.*"), recursive=True)
                   if is_img(p) and ("_c2" in os.path.basename(p).lower())
                   and ("_answer" not in os.path.basename(p).lower())])
if not c2_files:
    raise FileNotFoundError("No *_C2 images found under set4.")

c2_path = random.choice(c2_files)
base_key, mpath, match_mode = find_mask_for_c2(c2_path)
if not mpath:
    raise FileNotFoundError(f"No C1 mask found for base '{base_key}'")

# Find raw C1 (prefer same dir, else newest anywhere under C2_FOLDER)
c2_dir = os.path.dirname(c2_path)
c1_path = None
for ext in IMG_EXTS:
    cand = os.path.join(c2_dir, f"{base_key}_C1{ext}")
    if os.path.isfile(cand):
        c1_path = cand
        break
if c1_path is None:
    cands = []
    for ext in IMG_EXTS:
        # glob.escape: a key containing [, ], *, ? must match literally
        cands += glob.glob(os.path.join(C2_FOLDER, "**", f"{glob.escape(base_key)}_C1{ext}"), recursive=True)
    if cands:
        c1_path = max(cands, key=os.path.getmtime)
if not c1_path:
    raise FileNotFoundError(f"Missing raw C1 for base '{base_key}'")

print(f"[set4] Base: {base_key} | Mask match: {match_mode} ‚Üí {os.path.relpath(mpath, C1_MASK_DIR)}")

# ---------------- Load images ----------------
def _read_gray(path):
    """cv2.imread in grayscale, raising instead of silently returning None."""
    img = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
    if img is None:
        raise FileNotFoundError(f"Could not read image: {path}")
    return img

C1_raw  = _read_gray(c1_path)
C2_raw  = _read_gray(c2_path)
C1_mask = _read_gray(mpath)
if C1_mask.shape != C2_raw.shape:
    # nearest-neighbour keeps the mask binary
    C1_mask = cv2.resize(C1_mask, (C2_raw.shape[1], C2_raw.shape[0]), interpolation=cv2.INTER_NEAREST)

H, W   = C2_raw.shape[:2]
minHW  = min(H, W)

# ---------------- Size-scaled label params ----------------
# Scaled from the short image side so labels look alike across resolutions;
# the max() floors keep them usable on small images.
FONT_SCALE     = max(0.45, 0.0008 * minHW)
BASE_OFFSET_PX = max(8, int(0.018 * minHW))
SAFE_PAD_PX    = max(4, int(0.006 * minHW))
LEADER_IF_FAR  = int(LEADER_IF_FAR_MUL * minHW)

# ---------------- Segment crystals in non-pillar channel ----------------
channel = (C1_mask == 0).astype(np.uint8)  # 1 = non-pillar (pore space)
precip  = binarize_c2(C2_raw, channel)

# ---------------- Build composite (panel 3 base) ----------------
# Painted in order, so later layers win: pillars, channel, then crystals on top.
final_rgb = np.zeros((H, W, 3), dtype=np.uint8)
for _region, _color in ((C1_mask > 0, (0,0,0)),         # black pillars
                        (channel > 0, (255,255,255)),   # white channel
                        (precip > 0,  (255,0,0))):      # red crystals
    final_rgb[_region] = _color

# ---------------- Outlines + OFFSET black IDs on composite ----------------
composite = final_rgb.copy()

# contours -> drop super-small blobs by area AND by equivalent diameter sqrt(4A/π)
raw_ctrs, _ = cv2.findContours((precip > 0).astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
crystal_ctrs = []
for c in raw_ctrs:
    # Named `area`, not `A`: a global `A` would clobber the `albumentations as A` alias.
    area = cv2.contourArea(c)
    if area < MIN_CRYSTAL_AREA_PX:
        continue
    if np.sqrt(4.0 * area / np.pi) < MIN_EQDIAM_PX:
        continue
    crystal_ctrs.append(c)

# Thin yellow outline for context. `composite` is RGB (crystals are (255,0,0) = red),
# so yellow is (255,255,0); (0,255,255) would render cyan.
cv2.drawContours(composite, crystal_ctrs, -1, (255,255,0), 1)

precip_mask = (precip > 0).astype(np.uint8)
pillar_mask = (C1_mask > 0).astype(np.uint8)
font        = cv2.FONT_HERSHEY_SIMPLEX

def rect_is_clear(x, y, w, h, precip_mask, pillar_mask, pad=0):
    """Check whether a (padded) label rectangle sits on empty channel.

    The rectangle (x, y, w, h) is grown by ``pad`` pixels on every side and clipped
    to the image bounds given by the globals W and H. Returns a truthy value only if
    no pixel inside it is set in either ``precip_mask`` or ``pillar_mask``; a
    rectangle that clips to nothing is reported as not clear.
    """
    left, top = max(0, x - pad), max(0, y - pad)
    right, bottom = min(W, x + w + pad), min(H, y + h + pad)
    if not (left < right and top < bottom):
        return False
    rows, cols = slice(top, bottom), slice(left, right)
    crystal_free = precip_mask[rows, cols].max() == 0
    return crystal_free and (pillar_mask[rows, cols].max() == 0)

def rects_overlap(a, b, margin=0):
    """True if rectangles ``a`` and ``b`` (x, y, w, h) are closer than ``margin`` px (or intersect)."""
    (ax, ay, aw, ah), (bx, by, bw, bh) = a, b
    apart_horizontally = ax + aw + margin <= bx or bx + bw + margin <= ax
    apart_vertically = ay + ah + margin <= by or by + bh + margin <= ay
    return not (apart_horizontally or apart_vertically)

def find_label_pos(cx, cy, text, placed_rects):
    """Choose where to draw ``text`` near the crystal centroid (cx, cy).

    Tries each radius multiplier in EXTRA_OFFSETS (times BASE_OFFSET_PX) and, for each,
    every direction in SEARCH_DIRECTIONS. The first candidate rectangle that fits inside
    the W x H image, is clear of crystals/pillars (with SAFE_PAD_PX padding) and keeps
    LABEL_MARGIN_PX away from every rectangle in ``placed_rects`` is returned.
    If none qualifies, falls back to a position up and to the right of the centroid,
    clamped so the whole text box stays on the image (it may overlap features).

    Returns (rx, ry, tw, th) as plain ints: top-left corner, text width and height.
    """
    (tw, th), _ = cv2.getTextSize(text, font, FONT_SCALE, LABEL_THICKNESS)
    tw = max(tw, 8); th = max(th, 10)
    for mul in EXTRA_OFFSETS:
        step = int(BASE_OFFSET_PX * mul)
        for dx, dy in SEARCH_DIRECTIONS:
            x = int(cx + dx * step); y = int(cy + dy * step)
            rx = int(np.clip(x, 0, W-1)); ry = int(np.clip(y - th, 0, H-1))
            if rx + tw >= W: rx = W - tw - 1
            if ry + th >= H: ry = H - th - 1
            if rx < 0 or ry < 0:
                continue
            rect = (rx, ry, tw, th)
            # avoid crystals/pillars and other labels
            if not rect_is_clear(rx, ry, tw, th, precip_mask, pillar_mask, pad=SAFE_PAD_PX):
                continue
            if any(rects_overlap(rect, pr, margin=LABEL_MARGIN_PX) for pr in placed_rects):
                continue
            return rect
    # Fallback: NE of the centroid. Clamp using the real text width/height (previously
    # clamped to W - 8, letting multi-digit labels run off the right edge).
    rx = int(min(max(cx + BASE_OFFSET_PX, 0), max(0, W - tw - 1)))
    ry = int(min(max(cy - BASE_OFFSET_PX - th, 0), max(0, H - th - 1)))
    return (rx, ry, tw, th)

placed = []                 # label rectangles already drawn (overlap avoidance)
centroids_and_labels = []   # ((cx, cy), (rx, ry, tw, th)) for every labelled crystal

# Number every kept crystal (1-based) and draw its ID next to the centroid.
for idx, c in enumerate(crystal_ctrs, start=1):
    M = cv2.moments(c)
    if M["m00"] > 0:
        cx = int(M["m10"] / M["m00"]); cy = int(M["m01"] / M["m00"])
    else:
        # zero-area moment: anchor on the contour's first point
        cx, cy = int(c[0,0,0]), int(c[0,0,1])
    label = str(idx)
    rx, ry, tw, th = find_label_pos(cx, cy, label, placed)
    cv2.putText(composite, label, (rx, ry + th - 2), font, FONT_SCALE, (0,0,0), LABEL_THICKNESS, cv2.LINE_AA)
    placed.append((rx, ry, tw, th))
    centroids_and_labels.append(((cx, cy), (rx, ry, tw, th)))

# leader lines for labels placed far from their crystal
# (LEADER_IF_FAR was already derived from the image size with the other label params)
for (cx, cy), (rx, ry, tw, th) in centroids_and_labels:
    lx = rx + tw // 2; ly = ry + th // 2
    dist = int(np.hypot(lx - cx, ly - cy))
    if dist > LEADER_IF_FAR:
        cv2.line(composite, (cx, cy), (lx, ly), (0,0,0), 1, lineType=cv2.LINE_AA)

n_crystals_labeled = len(crystal_ctrs)

# ---------------- GRID_MM grid on composite ----------------
grid_img = composite.copy()
# Use set4 optical scale so grid spacing is metric-true
px_per_mm = 1.0 / (um_per_px / 1000.0)   # px per mm
step      = GRID_MM * px_per_mm

GRID_COLOR   = (110, 110, 110)  # darker gray for visibility
GRID_THICK   = 2
BORDER_COLOR = (0, 0, 0)
BORDER_THICK = 2

# Line positions are i*step (not an accumulated `pos += step`, which drifts with
# floating-point error); lines at 0..W / 0..H inclusive.
for i in range(int(W // step) + 1):
    x = int(round(i * step))
    cv2.line(grid_img, (x, 0), (x, H-1), GRID_COLOR, GRID_THICK, lineType=cv2.LINE_AA)
for i in range(int(H // step) + 1):
    y = int(round(i * step))
    cv2.line(grid_img, (0, y), (W-1, y), GRID_COLOR, GRID_THICK, lineType=cv2.LINE_AA)
# border box
cv2.rectangle(grid_img, (0,0), (W-1, H-1), BORDER_COLOR, BORDER_THICK, lineType=cv2.LINE_AA)

# ---------------- High-res display (no saving) ----------------
SCALE_UP = 2.0
grid_img_hr = cv2.resize(grid_img, (int(W*SCALE_UP), int(H*SCALE_UP)), interpolation=cv2.INTER_CUBIC)

fig, axes = plt.subplots(1, 3, figsize=(16,5), dpi=300)

axes[0].imshow(C1_raw, cmap="gray", interpolation="nearest")
axes[0].set_title("C1 Raw"); axes[0].axis("off")

axes[1].imshow(C2_raw, cmap="gray", interpolation="nearest")
axes[1].set_title("C2 Raw"); axes[1].axis("off")

axes[2].imshow(grid_img_hr, interpolation="nearest")
axes[2].set_title(f"Composite + IDs + {GRID_MM:g} mm Grid  |  N = {n_crystals_labeled}")
axes[2].axis("off")

plt.suptitle(f"{base_key}  |  {W}×{H}px  |  set4 FOV: {FIELD_MM:.3f} mm (~{um_per_px:.3f} µm/px)",
             fontsize=14, fontweight="bold")
plt.tight_layout()
plt.show()


In [None]:
# === SET4 ‚Äî 3-Panel: C1 Raw | C2 Raw | Composite + IDs-every-20 + 0.25 mm Grid (no save)
# - Thicker crystal outlines
# - Print ID labels only for every Nth crystal (default 20)
# - Offsets labels (avoids features & other labels); draws leader lines if far
# - Scale bar EXACTLY like the 5-photo storyboard (white 500 um with black outline)

import os, glob, cv2, random, re
import numpy as np
import matplotlib.pyplot as plt

# ---------------- Paths ----------------
# Keep ROOT if an earlier cell defined it; otherwise use the set4 root
# (the folder holding short1/short2/tall1/tall2).
ROOT = globals().get("ROOT", "data/raw_images")
C2_FOLDER, C1_MASK_DIR = ROOT, "data/raw_images"

# ---------------- Imaging scale (set4) ----------------
# A 2048 px frame spans 1331.2 µm of sample.
FIELD_UM, IMG_PX = 1331.2, 2048
um_per_px = FIELD_UM / IMG_PX
FIELD_MM  = FIELD_UM / 1000.0

# ---------------- Grid ----------------
GRID_MM = 0.25          # grid pitch (mm)

# ---------------- Crystal filters ----------------
# Contours below either limit are dropped as noise.
MIN_CRYSTAL_AREA_PX, MIN_EQDIAM_PX = 25, 4.0

# ---------------- Labeling / outlines ----------------
LABEL_EVERY_N     = 20     # print IDs only for crystals 20, 40, 60, ...
LABEL_THICKNESS   = 2
OUTLINE_THICK     = 3      # crystal outline thickness (px)
# Label search: directions (dx, dy) tried in order, at these radius multipliers.
SEARCH_DIRECTIONS = [(1, -1), (-1, -1), (1, 1), (-1, 1), (2, 0), (0, -2), (-2, 0), (0, 2)]
EXTRA_OFFSETS     = [1.0, 1.5, 2.0, 2.5, 3.0]
LEADER_IF_FAR_MUL = 0.035  # leader line when the label is farther than this * min(H, W)
LABEL_MARGIN_PX   = 3

# ---------------- C2 binarization ----------------
# C2_THRESH_MODE is one of "percentile" | "fixed" | "otsu+"
C2_THRESH_MODE = "fixed"
C2_FIXED, C2_PERCENTILE, C2_DELTA = 140, 94, 15

# ---------------- Helpers ----------------
IMG_EXTS = (".png", ".tif", ".tiff", ".jpg", ".jpeg")   # accepted extensions (lower-case)
def is_img(p):
    """Return True when path ``p`` has one of IMG_EXTS as its extension, ignoring case."""
    _root, extension = os.path.splitext(p)
    return extension.lower() in IMG_EXTS

def to_u8(img):
    """Return ``img`` as uint8: the same object if already uint8, otherwise min-max stretched to 0..255."""
    if img.dtype != np.uint8:
        return cv2.normalize(img, None, 0, 255, cv2.NORM_MINMAX).astype(np.uint8)
    return img

def binarize_c2(C2_raw, nonpillar_mask):
    """Segment bright C2 (TL-POL) precipitate inside the non-pillar channel.

    Steps: convert to uint8 (to_u8), zero everything outside ``nonpillar_mask``, apply a
    3x3 Gaussian blur, then threshold at a level chosen by C2_THRESH_MODE:
      - "fixed":      C2_FIXED
      - "percentile": C2_PERCENTILE-th percentile of blurred channel pixels (255 if none)
      - otherwise ("otsu+"): Otsu level of the blurred channel pixels + C2_DELTA, clipped
        to 0..255 (255 if the channel is empty)
    The result is re-masked to the channel and cleaned with a 3x3 open then close.

    Args:
        C2_raw: grayscale C2 image.
        nonpillar_mask: uint8 array, 1 inside the channel and 0 on pillars.
    Returns:
        uint8 mask, 255 where precipitate was detected.
    """
    img = to_u8(C2_raw)
    roi = cv2.bitwise_and(img, img, mask=(nonpillar_mask * 255))
    blur = cv2.GaussianBlur(roi, (3,3), 0)
    inside = blur[nonpillar_mask > 0]

    if C2_THRESH_MODE == "fixed":
        thr_val = int(C2_FIXED)
    elif C2_THRESH_MODE == "percentile":
        thr_val = int(np.percentile(inside, C2_PERCENTILE)) if inside.size else 255
    else:  # "otsu+"
        if inside.size:
            # cv2.threshold returns (retval, image); with THRESH_OTSU retval is the chosen
            # level. (The old `_, t0 = ...` took the image and crashed in min()/max().)
            # Restricting to channel pixels keeps the zeroed pillars out of the histogram.
            t0, _ = cv2.threshold(inside.reshape(-1, 1), 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
            thr_val = int(max(0, min(255, t0 + C2_DELTA)))
        else:
            thr_val = 255
    _, thr = cv2.threshold(blur, thr_val, 255, cv2.THRESH_BINARY)

    thr[nonpillar_mask == 0] = 0
    k = np.ones((3,3), np.uint8)
    thr = cv2.morphologyEx(thr, cv2.MORPH_OPEN,  k, iterations=1)
    thr = cv2.morphologyEx(thr, cv2.MORPH_CLOSE, k, iterations=1)
    return thr

# --- robust mask selection ---
def parse_thr_from_name(path):
    """Return the threshold encoded as ``_mask_<number>`` in the file name, or None."""
    m = re.search(r"_mask_(\d+(?:\.\d+)?)", os.path.basename(path))
    return float(m.group(1)) if m else None

# Preferred mask threshold: `best_thr` from an earlier cell, if it exists and is numeric.
THR_TARGET = None
try:
    THR_TARGET = float(best_thr)
except (NameError, TypeError, ValueError):
    pass  # no usable best_thr -> pick_best_mask falls back to "newest file"

def pick_best_mask(candidates, thr_target=None):
    """Pick one mask path from ``candidates``.

    With a target threshold (``thr_target``, or the module-level THR_TARGET when not
    given), prefer the mask whose ``_mask_<thr>`` value is closest to it, breaking ties
    by newest modification time. Without a target, return the newest file.
    Names with no parsable threshold rank last. Returns None for an empty list.
    """
    if not candidates:
        return None
    if thr_target is None:
        thr_target = THR_TARGET
    if thr_target is not None:
        def _rank(p):
            t = parse_thr_from_name(p)
            # `is None` (not `or`): a legitimate 0.0 threshold is falsy and was
            # previously treated as "unparsable" (999).
            return (abs((999.0 if t is None else t) - thr_target), -os.path.getmtime(p))
        candidates = sorted(candidates, key=_rank)
    else:
        candidates = sorted(candidates, key=lambda p: -os.path.getmtime(p))
    return candidates[0]

def find_mask_for_c2(c2_path):
    """Locate the predicted C1 pillar mask belonging to a C2 image.

    base_key is the file stem with its last "_C2" removed. Matching is case-insensitive,
    because the C2 file list is selected with a case-insensitive "_c2" test; the old
    case-sensitive rsplit left e.g. "x_c2" unchanged, so its mask was never found.
    Masks named "<base_key>_C1_mask_*.png" are searched under C1_MASK_DIR in:
      1. the same relative directory as the C2 image under C2_FOLDER ("same_dir"),
      2. anywhere below that directory's top-level folder ("same_top"),
      3. anywhere ("global").
    The first search yielding candidates wins; pick_best_mask chooses among them.

    Returns (base_key, mask_path or None, match_mode), match_mode "not_found" on failure.
    """
    stem = os.path.splitext(os.path.basename(c2_path))[0]  # e.g., "3_1_a_C2"
    cut = stem.lower().rfind("_c2")
    base_key = stem[:cut] if cut >= 0 else stem
    rel_dir = os.path.dirname(os.path.relpath(c2_path, C2_FOLDER))
    pattern = f"{base_key}_C1_mask_*.png"

    searches = [("same_dir", os.path.join(C1_MASK_DIR, rel_dir, pattern), False)]
    top = rel_dir.split(os.sep)[0] if rel_dir else ""
    if top:
        searches.append(("same_top", os.path.join(C1_MASK_DIR, top, "**", pattern), True))
    searches.append(("global", os.path.join(C1_MASK_DIR, "**", pattern), True))

    # Wider searches only run if the narrower ones found nothing.
    for mode, glob_pattern, recursive in searches:
        mpath = pick_best_mask(glob.glob(glob_pattern, recursive=recursive))
        if mpath:
            return base_key, mpath, mode

    return base_key, None, "not_found"

# ---------------- Pick a set4 C2 + matching C1 mask ----------------
# Any image whose name contains "_c2" (case-insensitive), skipping "_answer" files.
c2_files = sorted([p for p in glob.glob(os.path.join(C2_FOLDER, "**", "*.*"), recursive=True)
                   if is_img(p) and ("_c2" in os.path.basename(p).lower())
                   and ("_answer" not in os.path.basename(p).lower())])
if not c2_files:
    raise FileNotFoundError("No *_C2 images found under set4.")

c2_path = random.choice(c2_files)   # unseeded: a different image on each run
base_key, mpath, match_mode = find_mask_for_c2(c2_path)
if not mpath:
    raise FileNotFoundError(f"No C1 mask found for base '{base_key}'")

# Find raw C1 (prefer same dir, else newest anywhere)
c2_dir = os.path.dirname(c2_path)
c1_path = None
for ext in IMG_EXTS:
    cand = os.path.join(c2_dir, f"{base_key}_C1{ext}")
    if os.path.isfile(cand):
        c1_path = cand; break
if c1_path is None:
    cands = []
    for ext in IMG_EXTS:
        cands += glob.glob(os.path.join(C2_FOLDER, "**", f"{base_key}_C1{ext}"), recursive=True)
    if cands:
        c1_path = max(cands, key=os.path.getmtime)
if not c1_path:
    raise FileNotFoundError(f"Missing raw C1 for base '{base_key}'")

print(f"[set4] Base: {base_key} | Mask match: {match_mode} → {os.path.relpath(mpath, C1_MASK_DIR)}")

# ---------------- Load images ----------------
C1_raw  = cv2.imread(c1_path, cv2.IMREAD_GRAYSCALE)
C2_raw  = cv2.imread(c2_path, cv2.IMREAD_GRAYSCALE)
C1_mask = cv2.imread(mpath,   cv2.IMREAD_GRAYSCALE)
# cv2.imread returns None (no exception) when a file cannot be decoded.
for _what, _img, _src in (("raw C1", C1_raw, c1_path),
                          ("raw C2", C2_raw, c2_path),
                          ("C1 mask", C1_mask, mpath)):
    if _img is None:
        raise OSError(f"Could not read {_what} image: {_src}")
if C1_mask.shape != C2_raw.shape:
    C1_mask = cv2.resize(C1_mask, (C2_raw.shape[1], C2_raw.shape[0]), interpolation=cv2.INTER_NEAREST)

H, W   = C2_raw.shape[:2]
minHW  = min(H, W)

# ---------------- Size-scaled label params ----------------
FONT_SCALE     = max(0.45, minHW * 0.0008)
BASE_OFFSET_PX = max(8,    int(minHW * 0.018))
SAFE_PAD_PX    = max(4,    int(minHW * 0.006))
LEADER_IF_FAR  = int(minHW * LEADER_IF_FAR_MUL)
font           = cv2.FONT_HERSHEY_SIMPLEX

# ---------------- Segment crystals in non-pillar channel ----------------
channel = (C1_mask == 0).astype(np.uint8)  # 1 = non-pillar
precip  = binarize_c2(C2_raw, channel)

# ---------------- Build composite (panel 3 base, RGB channel order) ----------------
final_rgb = np.zeros((H, W, 3), dtype=np.uint8)
final_rgb[C1_mask > 0] = (0,0,0)        # black pillars
final_rgb[channel > 0] = (255,255,255)  # white channel
final_rgb[precip > 0]  = (255,0,0)      # red crystals

# ---------------- Outlines + IDs-every-N (offset) ----------------
composite = final_rgb.copy()

# contours -> filter super-small by area and equivalent diameter
raw_ctrs, _ = cv2.findContours((precip > 0).astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
crystal_ctrs = []
for c in raw_ctrs:
    A = cv2.contourArea(c)
    if A < MIN_CRYSTAL_AREA_PX:
        continue
    eqdiam = np.sqrt(4.0 * A / np.pi)   # diameter of the circle with the same area
    if eqdiam < MIN_EQDIAM_PX:
        continue
    crystal_ctrs.append(c)

# thicker yellow outlines for ALL crystals. The composite is RGB (shown by matplotlib),
# so yellow is (255,255,0); the former (0,255,255) rendered as cyan.
cv2.drawContours(composite, crystal_ctrs, -1, (255,255,0), OUTLINE_THICK)

precip_mask = (precip > 0).astype(np.uint8)
pillar_mask = (C1_mask > 0).astype(np.uint8)

def rect_is_clear(x, y, w, h, precip_mask, pillar_mask, pad=0):
    """Return truthy if the padded box (x, y, w, h), clipped to the global W x H image,
    contains no crystal pixel and no pillar pixel. Empty (fully clipped) boxes are not clear."""
    y0, y1 = max(0, y - pad), min(H, y + h + pad)
    x0, x1 = max(0, x - pad), min(W, x + w + pad)
    if x1 <= x0 or y1 <= y0:
        return False
    window = (slice(y0, y1), slice(x0, x1))
    return (precip_mask[window].max() == 0) and (pillar_mask[window].max() == 0)

def rects_overlap(a, b, margin=0):
    """Return True unless boxes ``a`` and ``b`` (x, y, w, h) are separated by at least ``margin`` px."""
    ax, ay, aw, ah = a
    bx, by, bw, bh = b
    separated = (
        ax + aw + margin <= bx
        or bx + bw + margin <= ax
        or ay + ah + margin <= by
        or by + bh + margin <= ay
    )
    return not separated

def find_label_pos(cx, cy, text, placed_rects):
    """Choose where to draw ``text`` near the crystal centroid (cx, cy).

    Tries each radius multiplier in EXTRA_OFFSETS (times BASE_OFFSET_PX) and, for each,
    every direction in SEARCH_DIRECTIONS. The first candidate rectangle that fits inside
    the W x H image, is clear of crystals/pillars (with SAFE_PAD_PX padding) and keeps
    LABEL_MARGIN_PX away from every rectangle in ``placed_rects`` is returned.
    If none qualifies, falls back to a position up and to the right of the centroid,
    clamped so the whole text box stays on the image (it may overlap features).

    Returns (rx, ry, tw, th) as plain ints: top-left corner, text width and height.
    """
    (tw, th), _ = cv2.getTextSize(text, font, FONT_SCALE, LABEL_THICKNESS)
    tw = max(tw, 8); th = max(th, 10)
    for mul in EXTRA_OFFSETS:
        step = int(BASE_OFFSET_PX * mul)
        for dx, dy in SEARCH_DIRECTIONS:
            x = int(cx + dx * step); y = int(cy + dy * step)
            rx = int(np.clip(x, 0, W-1)); ry = int(np.clip(y - th, 0, H-1))
            if rx + tw >= W: rx = W - tw - 1
            if ry + th >= H: ry = H - th - 1
            if rx < 0 or ry < 0:
                continue
            rect = (rx, ry, tw, th)
            # avoid crystals/pillars and other labels
            if not rect_is_clear(rx, ry, tw, th, precip_mask, pillar_mask, pad=SAFE_PAD_PX):
                continue
            if any(rects_overlap(rect, pr, margin=LABEL_MARGIN_PX) for pr in placed_rects):
                continue
            return rect
    # Fallback: NE of the centroid. Clamp using the real text width/height (previously
    # clamped to W - 8, letting multi-digit labels run off the right edge).
    rx = int(min(max(cx + BASE_OFFSET_PX, 0), max(0, W - tw - 1)))
    ry = int(min(max(cy - BASE_OFFSET_PX - th, 0), max(0, H - th - 1)))
    return (rx, ry, tw, th)

placed = []
centroids_and_labels = []  # ((cx, cy), (rx, ry, tw, th)) for the labelled subset only

# label ONLY every LABEL_EVERY_N-th crystal (IDs N, 2N, 3N, ...); every crystal keeps
# its outline and its position in the numbering.
for idx, c in enumerate(crystal_ctrs, start=1):
    if idx % LABEL_EVERY_N != 0:
        continue  # skip labeling this crystal
    M = cv2.moments(c)
    if M["m00"] > 0:
        cx = int(M["m10"] / M["m00"]); cy = int(M["m01"] / M["m00"])
    else:
        # zero-area moment: anchor on the contour's first point
        cx, cy = int(c[0,0,0]), int(c[0,0,1])
    label = str(idx)
    rx, ry, tw, th = find_label_pos(cx, cy, label, placed)
    cv2.putText(composite, label, (rx, ry + th - 2), font, FONT_SCALE, (0,0,0), LABEL_THICKNESS, cv2.LINE_AA)
    placed.append((rx, ry, tw, th))
    centroids_and_labels.append(((cx, cy), (rx, ry, tw, th)))

# leader lines for far labels (LEADER_IF_FAR already derived from the image size above)
for (cx, cy), (rx, ry, tw, th) in centroids_and_labels:
    lx = rx + tw // 2; ly = ry + th // 2
    dist = int(np.hypot(lx - cx, ly - cy))
    if dist > LEADER_IF_FAR:
        cv2.line(composite, (cx, cy), (lx, ly), (0,0,0), 1, lineType=cv2.LINE_AA)

n_crystals_total   = len(crystal_ctrs)
n_crystals_labeled = len(centroids_and_labels)

# ---------------- GRID_MM grid on composite ----------------
grid_img = composite.copy()
px_per_mm = 1.0 / (um_per_px / 1000.0)
step      = GRID_MM * px_per_mm

GRID_COLOR   = (110, 110, 110)
GRID_THICK   = 2
BORDER_COLOR = (0, 0, 0)
BORDER_THICK = 2

# Positions are i*step rather than an accumulated `pos += step` (no float drift).
for i in range(int(W // step) + 1):
    x = int(round(i * step))
    cv2.line(grid_img, (x, 0), (x, H-1), GRID_COLOR, GRID_THICK, lineType=cv2.LINE_AA)
for i in range(int(H // step) + 1):
    y = int(round(i * step))
    cv2.line(grid_img, (0, y), (W-1, y), GRID_COLOR, GRID_THICK, lineType=cv2.LINE_AA)
cv2.rectangle(grid_img, (0,0), (W-1, H-1), BORDER_COLOR, BORDER_THICK, lineType=cv2.LINE_AA)

# ---------------- Scale bar (matches the 5-panel storyboard style) ----------------
SCALE_UM, BAR_THICK_PX, MARGIN_PX = 500, 35, 100
FONT_SB       = cv2.FONT_HERSHEY_SIMPLEX
FONT_SCALE_SB = 4
FONT_THICK_SB = 13

def add_scale_bar(gray_img, um_per_px, scale_um=SCALE_UM,
                  bar_thick_px=BAR_THICK_PX, margin_px=MARGIN_PX):
    """Return an RGB copy of ``gray_img`` with a white ``scale_um`` bar and label.

    The bar (2 px black border) sits ``margin_px`` in from the bottom-right corner; its
    length is ``scale_um / um_per_px`` pixels, shortened if it would not fit. The label
    "<scale_um> um" is white with a black drop shadow, centred above the bar.
    """
    h, w = gray_img.shape[:2]
    canvas = cv2.cvtColor(gray_img, cv2.COLOR_GRAY2RGB)
    bar_len_px = min(int(round(scale_um / um_per_px)), w - 2 * margin_px - 1)

    right, bottom = w - margin_px, h - margin_px
    left, top = right - bar_len_px, bottom - bar_thick_px

    black, white = (0, 0, 0), (255, 255, 255)
    border = 2
    cv2.rectangle(canvas, (left - border, top - border), (right + border, bottom + border), black, -1)
    cv2.rectangle(canvas, (left, top), (right, bottom), white, -1)

    text = f"{scale_um:.0f} um"
    (text_w, _text_h), _baseline = cv2.getTextSize(text, FONT_SB, FONT_SCALE_SB, FONT_THICK_SB)
    tx, ty = left + (bar_len_px - text_w) // 2, top - 8
    cv2.putText(canvas, text, (tx + 2, ty + 2), FONT_SB, FONT_SCALE_SB, black, FONT_THICK_SB + 2, cv2.LINE_AA)
    cv2.putText(canvas, text, (tx, ty), FONT_SB, FONT_SCALE_SB, white, FONT_THICK_SB, cv2.LINE_AA)
    return canvas

# ---------------- High-res display (no saving) ----------------
SCALE_UP = 2.0
grid_img_hr = cv2.resize(grid_img, (int(W*SCALE_UP), int(H*SCALE_UP)), interpolation=cv2.INTER_CUBIC)

# Scale-bar versions of the raw images (RGB)
C1_with_bar = add_scale_bar(C1_raw, um_per_px, SCALE_UM)
C2_with_bar = add_scale_bar(C2_raw, um_per_px, SCALE_UM)

fig, axes = plt.subplots(1, 3, figsize=(16,5), dpi=300)

axes[0].imshow(C1_with_bar)              # already RGB
axes[0].set_title("C1 Raw"); axes[0].axis("off")

axes[1].imshow(C2_with_bar)              # already RGB
axes[1].set_title("C2 Raw"); axes[1].axis("off")

axes[2].imshow(grid_img_hr, interpolation="nearest")
axes[2].set_title(
    f"Composite + {GRID_MM:g} mm Grid | N_total={n_crystals_total} | labeled every {LABEL_EVERY_N}th → {n_crystals_labeled}"
)
axes[2].axis("off")

plt.suptitle(f"{base_key}  |  {W}×{H}px  |  set4 FOV: {FIELD_MM:.3f} mm (~{um_per_px:.3f} µm/px)",
             fontsize=14, fontweight="bold")
plt.tight_layout()
plt.show()


In [None]:
!pip install xlsxwriter


In [None]:
!pip install xlsxwriter


In [None]:
# ====================== Set4 ‚Üí crystal metrics ‚Üí Excel workbook ======================
# Outputs:
#   /content/drive/MyDrive/set4_crystal_metrics.csv
#   /content/drive/MyDrive/set4_metrics.xlsx
# ====================================================================================
import os, glob, re, cv2, numpy as np, pandas as pd
from skimage.measure import label, regionprops_table

# ---------------- PATHS ----------------
ROOT         = "data/raw_images"    # parent folder with short1/short2/tall1/tall2
C1_MASK_DIR  = "data/raw_images"    # AI-generated C1 masks
# Output *files*. These previously pointed at the input image directory, so
# df.to_csv / pd.ExcelWriter failed on a directory path.
CSV_OUT      = "outputs/set4_crystal_metrics.csv"
XLSX_OUT     = "outputs/set4_metrics.xlsx"

# ---------------- Imaging calibration -------------------
# 2048 px span 1331.2 µm; mm2_per_px converts pixel counts to mm².
FIELD_UM   = 1331.2
IMG_PX     = 2048
um_per_px  = FIELD_UM / IMG_PX
mm2_per_px = (um_per_px / 1000.0) ** 2

# ---------------- Thresholding --------------------------
C2_THRESH_MODE = "fixed"   # "fixed", "percentile", or "otsu+"
# NOTE(review): the figure/storyboard cells binarize with C2_FIXED = 140, but these
# metrics use 94, so reported crystal areas will not match the composites.
# Confirm which cutoff is intended.
C2_FIXED       = 94        # fixed cutoff
C2_PERCENTILE  = 94
C2_DELTA       = 15
BLUR_KSIZE     = (3,3)
MORPH_KERNEL   = (3,3)

# ---------------- Helpers ----------------
def parse_thr_from_name(path):
    """Return the threshold encoded as ``_mask_<number>`` in the file name, or None."""
    m = re.search(r"_mask_(\d+(?:\.\d+)?)", os.path.basename(path))
    return float(m.group(1)) if m else None

def pick_best_mask(candidates, thr_target=None):
    """Pick one mask path from ``candidates``.

    With ``thr_target``, prefer the mask whose ``_mask_<thr>`` value is closest to it
    (ties: newest modification time); names without a parsable threshold rank last.
    Without a target, return the newest file. Returns None for an empty list.
    """
    if not candidates:
        return None
    if thr_target is not None:
        def _rank(p):
            t = parse_thr_from_name(p)
            # `is None` (not `or`): a legitimate 0.0 threshold is falsy.
            return (abs((999.0 if t is None else t) - thr_target), -os.path.getmtime(p))
        candidates = sorted(candidates, key=_rank)
    else:
        candidates = sorted(candidates, key=lambda p: -os.path.getmtime(p))
    return candidates[0]

def find_mask_for_c2(c2_path, thr_target=None):
    """Return (base_key, mask_path) for a C2 image; mask_path is None if nothing matches.

    base_key is the C2 stem with its trailing "_C2" removed ("2_1_c_C2" -> "2_1_c").
    Masks named "<base_key>_C1_mask_*.png" are looked up under C1_MASK_DIR in three
    progressively wider places: the C2 image's relative folder (relative to ROOT),
    anywhere below that folder's top-level directory, then anywhere. The first place
    that yields a mask wins; pick_best_mask(candidates, thr_target) chooses within it.
    """
    stem = os.path.splitext(os.path.basename(c2_path))[0]
    base_key = stem.rsplit("_C2", 1)[0]
    rel_dir = os.path.dirname(os.path.relpath(c2_path, ROOT))   # e.g. "short2"
    mask_glob = f"{base_key}_C1_mask_*.png"
    top = rel_dir.split(os.sep)[0] if rel_dir else ""

    def _candidate_sets():
        # Lazy: a wider search only runs when the narrower ones produced nothing.
        yield glob.glob(os.path.join(C1_MASK_DIR, rel_dir, mask_glob))
        if top:
            yield glob.glob(os.path.join(C1_MASK_DIR, top, "**", mask_glob), recursive=True)
        yield glob.glob(os.path.join(C1_MASK_DIR, "**", mask_glob), recursive=True)

    for candidates in _candidate_sets():
        mpath = pick_best_mask(candidates, thr_target)
        if mpath:
            return base_key, mpath
    return base_key, None

def binarize_c2(C2_raw, nonpillar_mask):
    """Threshold bright C2 precipitate and keep only pixels inside the non-pillar channel.

    The raw image is blurred with BLUR_KSIZE, then thresholded by C2_THRESH_MODE:
      - "fixed":      C2_FIXED
      - "percentile": C2_PERCENTILE-th percentile of blurred channel pixels (255 if none)
      - otherwise ("otsu+"): Otsu level of blurred channel pixels + C2_DELTA, clipped to
        0..255 (255 if the channel is empty)
    Pillar pixels are then zeroed and the mask is cleaned with an open then close
    (MORPH_KERNEL).

    NOTE(review): this variant blurs *before* masking pillars and does not min-max
    normalise, unlike the figure cells' binarize_c2, so metrics may differ slightly from
    the composites even at the same threshold. Confirm which variant is intended.

    Returns a uint8 mask, 255 = precipitate.
    """
    blur = cv2.GaussianBlur(C2_raw, BLUR_KSIZE, 0)
    if C2_THRESH_MODE == "fixed":
        _, thr = cv2.threshold(blur, C2_FIXED, 255, cv2.THRESH_BINARY)
    elif C2_THRESH_MODE == "percentile":
        vals = blur[nonpillar_mask > 0]
        thr_val = int(np.percentile(vals, C2_PERCENTILE)) if vals.size else 255
        _, thr = cv2.threshold(blur, thr_val, 255, cv2.THRESH_BINARY)
    else:  # "otsu+"
        vals = blur[nonpillar_mask > 0]
        if vals.size:
            # cv2.threshold returns (retval, image); with THRESH_OTSU retval is the level.
            # The old `_, t0 = ...` took the image and crashed in min()/max().
            t0, _ = cv2.threshold(vals.reshape(-1, 1), 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
            thr_val = int(max(0, min(255, t0 + C2_DELTA)))
        else:
            thr_val = 255
        _, thr = cv2.threshold(blur, thr_val, 255, cv2.THRESH_BINARY)

    thr[nonpillar_mask == 0] = 0
    k = np.ones(MORPH_KERNEL, np.uint8)
    thr = cv2.morphologyEx(thr, cv2.MORPH_OPEN,  k, iterations=1)
    thr = cv2.morphologyEx(thr, cv2.MORPH_CLOSE, k, iterations=1)
    return thr

# ---------------- Collect all C2 images ----------------
c2_files = sorted(glob.glob(os.path.join(ROOT, "**", "*_C2.png"), recursive=True))
print(f"Found {len(c2_files)} C2 images")

# ---------------- Process all images -------------------
# One row per C2 image: pillar / pore-space / crystal coverage plus crystal-area stats.
# NOTE(review): crystals here are raw connected components with no MIN_CRYSTAL_AREA_PX /
# MIN_EQDIAM_PX filter (the figure cells apply one), so n_crystals and mean areas are
# not directly comparable to the N printed on the composites.
rows = []
for c2_path in c2_files:
    base_key, mpath = find_mask_for_c2(c2_path)
    if not mpath:
        print(f"⚠️ No mask found for {base_key}")
        continue

    C2 = cv2.imread(c2_path, cv2.IMREAD_GRAYSCALE)
    C1_mask = cv2.imread(mpath, cv2.IMREAD_GRAYSCALE)
    if C2 is None or C1_mask is None:
        print(f"⚠️ Unreadable for {base_key}")
        continue
    if C1_mask.shape != C2.shape:
        C1_mask = cv2.resize(C1_mask, (C2.shape[1], C2.shape[0]), interpolation=cv2.INTER_NEAREST)

    total_pixels     = C1_mask.size
    pillar_pixels    = int((C1_mask > 0).sum())
    nonpillar_pixels = total_pixels - pillar_pixels
    channel          = (C1_mask == 0).astype(np.uint8)

    precip = binarize_c2(C2, channel)

    crystal_pixels           = int((precip > 0).sum())
    pct_pillar               = 100.0 * pillar_pixels / total_pixels
    pct_nonpillar            = 100.0 * nonpillar_pixels / total_pixels
    pct_crystal_total        = 100.0 * crystal_pixels / total_pixels
    pct_crystal_nonpillar    = 100.0 * crystal_pixels / max(1, nonpillar_pixels)  # guard all-pillar masks

    # Connected components (connectivity=2: diagonal neighbours connect in 2-D).
    # `label` / `regionprops_table` come from the skimage import at the top of this cell.
    lab = label(precip > 0, connectivity=2)
    props = regionprops_table(lab, properties=("area",))
    areas_px = np.asarray(props["area"], dtype=float)
    n = areas_px.size
    mean_px = areas_px.mean() if n > 0 else 0.0
    std_px  = areas_px.std(ddof=1) if n > 1 else 0.0
    total_px = areas_px.sum()

    rows.append({
        "file": os.path.basename(c2_path),
        "folder": os.path.basename(os.path.dirname(c2_path)),
        "base": base_key,
        "n_crystals": n,
        "mean_area_px": mean_px,
        "std_area_px": std_px,
        "total_area_px": total_px,
        "mean_area_mm2": mean_px * mm2_per_px,
        "std_area_mm2": std_px * mm2_per_px,
        "total_area_mm2": total_px * mm2_per_px,
        "pillar_area_px": pillar_pixels,
        "nonpillar_area_px": nonpillar_pixels,
        "percent_pillar_area": pct_pillar,
        "percent_nonpillar_area": pct_nonpillar,
        "percent_crystal_total": pct_crystal_total,
        "percent_crystal_in_nonpillar": pct_crystal_nonpillar
    })

# ---------------- Save CSV + Excel ----------------------
df = pd.DataFrame(rows)
for out_path in (CSV_OUT, XLSX_OUT):
    os.makedirs(os.path.dirname(out_path) or ".", exist_ok=True)
df.to_csv(CSV_OUT, index=False)
print(f"✅ Saved CSV: {CSV_OUT}")

with pd.ExcelWriter(XLSX_OUT, engine="xlsxwriter") as writer:
    df.to_excel(writer, sheet_name="metrics", index=False)
print(f"📗 Excel workbook written to: {XLSX_OUT}")


In [None]:
# === Save set4 crystal metrics → Excel with porosity, pillar diameter, and channel label ===
import os, re, pandas as pd

# Must be the FILES written by the metrics cell (not the image directory).
CSV_OUT  = "outputs/set4_crystal_metrics.csv"
XLSX_OUT = "outputs/set4_metrics.xlsx"

# Load CSV from previous step
df = pd.read_csv(CSV_OUT)

# --- Parse filename fields ---
# "file" holds the basename, e.g. "2_1_c_C2.png" or "3_0.5_a_C2.png"
# -> section=2, pillar diameter=1, trial="c"
pat = re.compile(r"(?P<section>\d+)_(?P<pillar>[\d.]+)_(?P<trial>[a-z])", re.IGNORECASE)

sections, pillars, trials = [], [], []
for f in df["file"]:
    m = pat.search(f)
    if m:
        sections.append(int(m.group("section")))
        pillars.append(float(m.group("pillar")))
        trials.append(m.group("trial").lower())
    else:
        sections.append(None); pillars.append(None); trials.append(None)

df["section"] = sections
df["pillar_diameter_mm"] = pillars
df["trial"] = trials

# --- Porosity mapping: odd = 0.35, even = 0.45 ---
def section_to_porosity(sec):
    """Map a section number to porosity (odd -> 0.35, even -> 0.45); None if missing."""
    if pd.isna(sec): return None
    return 0.35 if sec % 2 == 1 else 0.45

df["porosity"] = df["section"].apply(section_to_porosity)

# --- Channel label (short1, tall2, etc.) ---
# The metrics cell stores only the basename in "file" (never containing os.sep), so the
# old path-split always produced "". The subfolder name is in the "folder" column.
if "folder" in df.columns:
    df["channel_type"] = df["folder"].fillna("").astype(str)
else:
    df["channel_type"] = df["file"].apply(
        lambda s: str(s).split(os.sep)[0] if os.sep in str(s) else ""
    )

# --- Save updated CSV and Excel ---
for out_path in (CSV_OUT, XLSX_OUT):
    os.makedirs(os.path.dirname(out_path) or ".", exist_ok=True)
df.to_csv(CSV_OUT, index=False)
print(f"✅ Updated CSV with porosity + pillar_diameter + channel_type: {CSV_OUT}")

with pd.ExcelWriter(XLSX_OUT, engine="openpyxl") as writer:   # use "xlsxwriter" if installed
    df.to_excel(writer, sheet_name="metrics", index=False)

print(f"📗 Excel workbook written to: {XLSX_OUT}")


C2_THRESH_MODE = "fixed"
C2_FIXED       = 140          # if mode == "fixed" (raise to pick less)
C2_PERCENTILE  = 94           # if mode == "percentile" (98‚Äì99 picks less)
C2_DELTA       = 15           # if mode == "otsu+" (Otsu + delta => stricter)

In [None]:
# === Storyboard generator for set4 with FORM + DIAM filter ===
import os, glob, re, cv2, random
import numpy as np
import matplotlib.pyplot as plt

# ---------------- User controls ----------------
# FORM selects "short" or "tall" folders; DIAM is the pillar diameter to keep;
# N_SHOW caps how many storyboards are drawn.
FORM, DIAM, N_SHOW = "tall", 0.9, 12
# -----------------------------------------------

ROOT = "data/raw_images"
C2_FOLDER, C1_MASK_DIR = ROOT, "data/raw_images"

# --- imaging scale: 2048 px span 1331.2 µm ---
FIELD_UM, IMG_PX = 1331.2, 2048
um_per_px = FIELD_UM / IMG_PX

# --- C2 binarization params ("fixed", "percentile", or "otsu+") ---
C2_THRESH_MODE = "fixed"
C2_FIXED, C2_PERCENTILE, C2_DELTA = 140, 94, 15

# --- filename parser: "<section>_<pillar>_<trial>_C2.png", e.g. 3_1_a_C2.png ---
pat = re.compile(r"^(?P<section>\d+)_(?P<pillar>[\d.]+)_(?P<trial>[a-z])_C2\.png$", re.IGNORECASE)

def parse_pillar(fname):
    """Return the pillar diameter from a C2 file name, or None if the name doesn't match."""
    match = pat.match(fname)
    return float(match.group("pillar")) if match else None

# --- helpers ---
IMG_EXTS = (".png",".jpg",".jpeg",".tif",".tiff")   # accepted extensions (lower-case)
def is_img(p):
    """True when ``p``'s extension, lower-cased, is in IMG_EXTS."""
    ext = os.path.splitext(p)[-1]
    return ext.lower() in IMG_EXTS

def add_scale_bar(gray_img, um_per_px, scale_um=500):
    """Return an RGB copy of ``gray_img`` with an unlabeled white ``scale_um`` bar.

    The bar is 35 px tall with a 2 px black border, and its bottom-right corner sits
    100 px in from the image's bottom-right corner.
    """
    h, w = gray_img.shape[:2]
    length = int(round(scale_um / um_per_px))
    right, bottom = w - 100, h - 100
    left, top = right - length, bottom - 35
    canvas = cv2.cvtColor(gray_img, cv2.COLOR_GRAY2RGB)
    cv2.rectangle(canvas, (left - 2, top - 2), (right + 2, bottom + 2), (0, 0, 0), -1)
    cv2.rectangle(canvas, (left, top), (right, bottom), (255, 255, 255), -1)
    return canvas

def binarize_c2(C2_raw, nonpillar_mask):
    """Segment bright C2 precipitate inside the non-pillar channel (storyboard variant).

    C2 is always min-max stretched to 0..255 first (unlike the metrics cell), so the same
    C2_FIXED is applied to a contrast-stretched image. Pillars are zeroed, a 3x3 blur is
    applied, and the level comes from C2_THRESH_MODE:
      - "fixed":      C2_FIXED
      - "percentile": C2_PERCENTILE-th percentile of blurred channel pixels (255 if none)
      - otherwise ("otsu+"): Otsu level of blurred channel pixels + C2_DELTA, clipped to
        0..255 (255 if the channel is empty)
    The result is re-masked and cleaned with a 3x3 open then close.
    Returns a uint8 mask, 255 = precipitate.
    """
    img = cv2.normalize(C2_raw, None, 0, 255, cv2.NORM_MINMAX).astype(np.uint8)
    roi = cv2.bitwise_and(img, img, mask=(nonpillar_mask*255))
    blur = cv2.GaussianBlur(roi, (3,3), 0)
    if C2_THRESH_MODE=="fixed":
        _, thr = cv2.threshold(blur, C2_FIXED, 255, cv2.THRESH_BINARY)
    elif C2_THRESH_MODE=="percentile":
        vals = blur[nonpillar_mask>0]
        thr_val = int(np.percentile(vals, C2_PERCENTILE)) if vals.size else 255
        _, thr = cv2.threshold(blur, thr_val, 255, cv2.THRESH_BINARY)
    else: # otsu+
        vals = blur[nonpillar_mask>0]
        if vals.size:
            # cv2.threshold returns (retval, image); with THRESH_OTSU retval is the level.
            # The old `_, t0 = ...` took the image and crashed in min()/max().
            t0, _ = cv2.threshold(vals.reshape(-1, 1), 0, 255, cv2.THRESH_BINARY+cv2.THRESH_OTSU)
            thr_val = int(max(0, min(255, t0+C2_DELTA)))
        else:
            thr_val = 255
        _, thr = cv2.threshold(blur, thr_val, 255, cv2.THRESH_BINARY)
    thr[nonpillar_mask==0] = 0
    k = np.ones((3,3),np.uint8)
    thr = cv2.morphologyEx(thr, cv2.MORPH_OPEN,k)
    thr = cv2.morphologyEx(thr, cv2.MORPH_CLOSE,k)
    return thr

# --- gather all matching C2 images ---
all_c2 = glob.glob(os.path.join(C2_FOLDER, FORM+"*", "*_C2.png"))
c2_files = []
for f in all_c2:
    pill = parse_pillar(os.path.basename(f))
    if pill is not None and abs(pill-DIAM)<1e-6:
        c2_files.append(f)
print(f"Found {len(c2_files)} {FORM} images with {DIAM} mm pillars")

if not c2_files:
    raise FileNotFoundError("No matching C2 images found!")

random.shuffle(c2_files)   # unseeded: a different subset each run
for c2_path in c2_files[:N_SHOW]:
    base_key = os.path.splitext(os.path.basename(c2_path))[0].rsplit("_C2",1)[0]
    c2_dir = os.path.dirname(c2_path)

    # find C1 mask (mirrored relative dir under C1_MASK_DIR)
    rel_dir = os.path.relpath(c2_dir, C2_FOLDER)
    mask_cands = glob.glob(os.path.join(C1_MASK_DIR, rel_dir, f"{base_key}_C1_mask_*.png"))
    if not mask_cands: continue
    mpath = sorted(mask_cands)[0]

    # raw C1 brightfield sits next to the C2 image as <base_key>_C1.<ext>
    c1_path = None
    for ext in IMG_EXTS:
        cand = os.path.join(c2_dir, f"{base_key}_C1{ext}")
        if os.path.isfile(cand):
            c1_path = cand; break

    # load images (cv2.imread returns None on failure instead of raising)
    C2_raw  = cv2.imread(c2_path, cv2.IMREAD_GRAYSCALE)
    C1_mask = cv2.imread(mpath, cv2.IMREAD_GRAYSCALE)
    if C2_raw is None or C1_mask is None:
        print(f"Skipping {base_key}: could not read C2 image or C1 mask")
        continue
    C1_raw = cv2.imread(c1_path, cv2.IMREAD_GRAYSCALE) if c1_path else None
    if C1_raw is None:
        print(f"No raw C1 brightfield for {base_key}; showing the mask in panel 1")
        C1_raw = C1_mask
    if C1_mask.shape!=C2_raw.shape:
        C1_mask = cv2.resize(C1_mask, (C2_raw.shape[1],C2_raw.shape[0]), interpolation=cv2.INTER_NEAREST)

    # metrics
    total_px     = C1_mask.size
    pillar_px    = int((C1_mask>0).sum())
    nonpillar_px = total_px - pillar_px
    channel = (C1_mask==0).astype(np.uint8)
    precip = binarize_c2(C2_raw, channel)
    crystal_px = int((precip>0).sum())
    pillar_pct = 100*pillar_px/total_px
    nonpillar_pct = 100*nonpillar_px/total_px
    crystal_pct = 100*crystal_px/total_px
    crystal_nonpillar_pct = 100*crystal_px/max(1,nonpillar_px)

    # overlays (RGB): red crystals, black pillars, yellow pillar outlines
    proc_rgb = cv2.cvtColor(C2_raw, cv2.COLOR_GRAY2RGB)
    proc_rgb[precip>0]=(255,0,0)
    proc_rgb[C1_mask>0]=(0,0,0)
    overlay=np.zeros_like(proc_rgb)
    contours,_=cv2.findContours((C1_mask>0).astype(np.uint8),cv2.RETR_EXTERNAL,cv2.CHAIN_APPROX_SIMPLE)
    cv2.drawContours(overlay,contours,-1,(255,255,0),2)
    proc_with_outline=cv2.addWeighted(proc_rgb,1.0,overlay,1.0,0)
    final_rgb=np.zeros_like(proc_rgb)
    final_rgb[C1_mask>0]=(0,0,0)
    final_rgb[channel>0]=(255,255,255)
    final_rgb[precip>0]=(255,0,0)

    # scale bars: panel 1 is now the brightfield image (previously the mask was shown twice)
    C1_with_bar=add_scale_bar(C1_raw, um_per_px)
    C2_with_bar=add_scale_bar(C2_raw, um_per_px)

    # plot
    fig,axes=plt.subplots(1,5,figsize=(22,5))
    axes[0].imshow(C1_with_bar); axes[0].set_title("C1 Brightfield"); axes[0].axis("off")
    axes[1].imshow(C1_mask,cmap="gray"); axes[1].set_title(f"Pillar Mask\nNon-pillar:{nonpillar_pct:.1f}%"); axes[1].axis("off")
    axes[2].imshow(C2_with_bar); axes[2].set_title("C2 TL-POL"); axes[2].axis("off")
    axes[3].imshow(proc_with_outline); axes[3].set_title("C2 Processed"); axes[3].axis("off")
    axes[4].imshow(final_rgb); axes[4].set_title(
        f"Pillar:{pillar_pct:.1f}% | Non-pillar:{nonpillar_pct:.1f}%\n"
        f"Crystal:{crystal_pct:.1f}% | Crystal/Non-pillar:{crystal_nonpillar_pct:.1f}%"
    ); axes[4].axis("off")
    plt.suptitle(f"{FORM}, {DIAM} mm pillars | {os.path.basename(c2_path)}", fontsize=14, fontweight="bold")
    plt.show()
    plt.close(fig)  # free memory when rendering many storyboards in one cell


In [None]:
# === Storyboard for random set4 C2 image with its matched C1 mask + metrics
#     (robust mask matching + conservative C2 thresholding) ===
import os, glob, cv2, random, re
import numpy as np
import matplotlib.pyplot as plt

# ---------------- Paths ----------------
# Reuse ROOT when an earlier cell already defined it; otherwise use the default image root.
ROOT = globals().get("ROOT", "data/raw_images")  # root with short1/short2/tall1/tall2
C2_FOLDER = ROOT                     # searched recursively for *_C2 images
C1_MASK_DIR = "data/raw_images"      # predicted C1 masks, mirroring the C2 folder layout

# ---------------- Imaging scale ----------------
FIELD_UM = 1331.2                    # field of view along one side (µm)
IMG_PX = 2048                        # pixels along one side
um_per_px = FIELD_UM / IMG_PX        # µm per pixel
mm2_per_px = (um_per_px / 1000.0) ** 2  # area of one pixel in mm²; not used, kept for completeness

# ---------------- Scale bar style ----------------
SCALE_UM = 500                       # requested bar length (µm)
BAR_THICK_PX, MARGIN_PX = 35, 100    # bar thickness; gap to the bottom/right edges (px)
FONT = cv2.FONT_HERSHEY_SIMPLEX
FONT_SCALE, FONT_THICK = 4, 13       # large text, sized for ~2048 px images

# ---------------- C2 binarization (more conservative) ----------------
# Choose ONE mode: "fixed", "percentile", or "otsu+".
# NOTE(review): "percentile" is the adaptive option, but "fixed" is the one active here.
C2_THRESH_MODE = "fixed"
C2_FIXED = 140        # "fixed": absolute grey level; raise to pick less
C2_PERCENTILE = 94    # "percentile": percentile of pore-space intensities; 98–99 picks less
C2_DELTA = 15         # "otsu+": added to the Otsu threshold; larger = stricter

# ---------------- Helpers ----------------
IMG_EXTS = (".png", ".tif", ".tiff", ".jpg", ".jpeg")


def is_img(p):
    """Return True when path ``p`` ends in one of IMG_EXTS (extension compared case-insensitively)."""
    _, ext = os.path.splitext(p)
    return ext.lower() in IMG_EXTS

def add_scale_bar(gray_img, um_per_px, scale_um=SCALE_UM,
                  bar_thick_px=BAR_THICK_PX, margin_px=MARGIN_PX):
    """Return an RGB copy of ``gray_img`` with a labelled scale bar in the bottom-right corner.

    Args:
        gray_img: single-channel image; a 3-channel image is also accepted and copied as-is.
        um_per_px: physical size of one pixel in µm.
        scale_um: requested bar length in µm.
        bar_thick_px: bar thickness in pixels.
        margin_px: gap between the bar and the right/bottom image edges, in pixels.

    Returns:
        A new 3-channel image; the input is not modified.

    If the requested bar does not fit between the margins it is shortened, and the
    label then states the length actually drawn, so the bar never misstates the scale.
    """
    if gray_img.ndim == 3 and gray_img.shape[2] == 3:
        rgb = gray_img.copy()
    else:
        rgb = cv2.cvtColor(gray_img, cv2.COLOR_GRAY2RGB)
    h, w = gray_img.shape[:2]

    bar_len_px = int(round(scale_um / um_per_px))
    max_len_px = max(1, w - 2*margin_px - 1)
    label_um = scale_um
    if bar_len_px > max_len_px:
        # Bar had to be shortened: the label must describe the drawn length, not scale_um.
        bar_len_px = max_len_px
        label_um = bar_len_px * um_per_px

    x2 = w - margin_px; x1 = x2 - bar_len_px
    y2 = h - margin_px; y1 = y2 - bar_thick_px
    outline_pad = 2
    # Black outline first, then the white bar on top, so the bar reads on any background.
    cv2.rectangle(rgb, (x1 - outline_pad, y1 - outline_pad),
                  (x2 + outline_pad, y2 + outline_pad), (0,0,0), -1)
    cv2.rectangle(rgb, (x1, y1), (x2, y2), (255,255,255), -1)
    label = f"{label_um:.0f} um"
    (tw, _th), _ = cv2.getTextSize(label, FONT, FONT_SCALE, FONT_THICK)
    tx = max(0, x1 + (bar_len_px - tw) // 2)   # centred over the bar, never off the left edge
    ty = y1 - 8
    # Thicker black pass, then a white pass on top, gives outlined text.
    cv2.putText(rgb, label, (tx, ty), FONT, FONT_SCALE, (0,0,0), FONT_THICK+2, cv2.LINE_AA)
    cv2.putText(rgb, label, (tx, ty), FONT, FONT_SCALE, (255,255,255), FONT_THICK, cv2.LINE_AA)
    return rgb

# --- robust mask selection (prefer same folder; then same top-level; then global) ---
def parse_thr_from_name(path):
    """Return the threshold encoded as ``_mask_<thr>`` in the file name, or None if absent.

    Example: ".../3_1_a_C1_mask_0.47.png" -> 0.47
    """
    m = re.search(r"_mask_(\d+(?:\.\d+)?)", os.path.basename(path))
    return float(m.group(1)) if m else None

# If the threshold-sweep cell has already run, `best_thr` exists and masks whose encoded
# threshold is closest to it are preferred. That cell appears later in the notebook, so
# on a fresh top-to-bottom run THR_TARGET stays None and the newest mask wins.
THR_TARGET = None
try:
    THR_TARGET = float(best_thr)  # e.g., 0.47
except (NameError, TypeError, ValueError):
    pass  # best_thr undefined or not a single number: no threshold preference

def pick_best_mask(candidates, thr_target=None):
    """Choose one mask file from ``candidates``.

    Args:
        candidates: mask file paths (each must exist; its modification time is read).
        thr_target: threshold to match; None means use the module-level THR_TARGET.

    Returns:
        None when there are no candidates. Otherwise, with a known target: the path
        whose encoded threshold is closest to it. The newest file breaks ties, and
        names without a threshold rank last. With no target: the newest file.
    """
    candidates = list(candidates)
    if not candidates:
        return None
    target = THR_TARGET if thr_target is None else thr_target
    if target is None:
        return max(candidates, key=os.path.getmtime)

    def _rank(path):
        thr = parse_thr_from_name(path)
        # `is None`, not `or`: a parsed threshold of 0.0 is valid, not "missing".
        distance = float("inf") if thr is None else abs(thr - target)
        return (distance, -os.path.getmtime(path))

    return min(candidates, key=_rank)

def find_mask_for_c2(c2_path):
    """Locate the predicted C1 pillar mask that belongs to a C2 image.

    The base key is the C2 file stem up to its last "_C2" tag (matched
    case-insensitively, like the C2 file selection), e.g. "3_1_a_C2.png" -> "3_1_a".
    Masks named "<base_key>_C1_mask_*.png" are looked up under C1_MASK_DIR in this
    order:
      1) the same relative sub-directory as the C2 image;
      2) anywhere below the same top-level folder;
      3) anywhere below C1_MASK_DIR.

    Returns:
        (base_key, mask_path_or_None, match_mode), where match_mode is
        "same_dir", "same_top", "global" or "not_found".
    """
    stem = os.path.splitext(os.path.basename(c2_path))[0]   # e.g., "3_1_a_C2"
    cut = stem.lower().rfind("_c2")
    base_key = stem[:cut] if cut >= 0 else stem              # -> "3_1_a"
    rel_dir = os.path.dirname(os.path.relpath(c2_path, C2_FOLDER))  # e.g., "short2" or nested
    # Escape the key so characters such as "[" in file names are not read as glob syntax.
    pattern = f"{glob.escape(base_key)}_C1_mask_*.png"

    # 1) Same relative directory under C1_MASK_DIR
    mpath = pick_best_mask(glob.glob(os.path.join(C1_MASK_DIR, rel_dir, pattern)))
    if mpath:
        return base_key, mpath, "same_dir"

    # 2) Same top-level folder (e.g. short1/short2/tall1/tall2)
    top = rel_dir.split(os.sep)[0] if rel_dir else ""
    if top:
        mpath = pick_best_mask(
            glob.glob(os.path.join(C1_MASK_DIR, top, "**", pattern), recursive=True))
        if mpath:
            return base_key, mpath, "same_top"

    # 3) Anywhere under C1_MASK_DIR
    mpath = pick_best_mask(glob.glob(os.path.join(C1_MASK_DIR, "**", pattern), recursive=True))
    if mpath:
        return base_key, mpath, "global"

    return base_key, None, "not_found"

# --- C2 binarization (conservative) ---
def to_u8(img):
    """Return ``img`` as uint8: unchanged if it already is, else min-max stretched to 0..255."""
    if img.dtype != np.uint8:
        img = cv2.normalize(img, None, 0, 255, cv2.NORM_MINMAX).astype(np.uint8)
    return img

def binarize_c2(C2_raw, nonpillar_mask):
    """Return a 0/255 uint8 crystal mask from a C2 TL-POL image, restricted to pore space.

    Args:
        C2_raw: C2 image (any dtype; converted with ``to_u8``).
        nonpillar_mask: array with the same height/width, nonzero where the pixel is
            pore space (non-pillar). 0/1, 0/255 and boolean masks are all accepted.

    The threshold follows the module-level ``C2_THRESH_MODE``:
      * "fixed":      grey level ``C2_FIXED``;
      * "percentile": ``C2_PERCENTILE``-th percentile of the pore-space pixels;
      * "otsu+":      Otsu threshold of the pore-space pixels plus ``C2_DELTA``.
    A 3x3 opening then closing removes isolated pixels and fills pinholes; the result
    is then re-masked so no crystal pixel lies on a pillar.

    Raises:
        ValueError: if ``C2_THRESH_MODE`` is not one of the modes above.
    """
    img = to_u8(C2_raw)
    # Normalise the mask so bool or 0/255 inputs neither break cv2 nor overflow when scaled.
    pore = (np.asarray(nonpillar_mask) > 0).astype(np.uint8)
    roi = cv2.bitwise_and(img, img, mask=pore * 255)
    blur = cv2.GaussianBlur(roi, (3,3), 0)
    # Intensity statistics use pore-space pixels only; the zeroed pillar pixels would skew them.
    vals = blur[pore > 0]

    if C2_THRESH_MODE == "fixed":
        thr_val = int(C2_FIXED)
    elif C2_THRESH_MODE == "percentile":
        thr_val = int(np.percentile(vals, C2_PERCENTILE)) if vals.size else 255
    elif C2_THRESH_MODE == "otsu+":
        if vals.size:
            # Otsu on pore pixels only. On the whole ROI, the zero-filled pillar area
            # forms its own histogram mode and pulls the threshold down.
            t0, _ = cv2.threshold(vals.reshape(-1, 1), 0, 255,
                                  cv2.THRESH_BINARY + cv2.THRESH_OTSU)
            thr_val = int(max(0, min(255, t0 + C2_DELTA)))
        else:
            thr_val = 255
    else:
        raise ValueError(f"Unknown C2_THRESH_MODE {C2_THRESH_MODE!r}; "
                         "expected 'fixed', 'percentile' or 'otsu+'")

    _, thr = cv2.threshold(blur, thr_val, 255, cv2.THRESH_BINARY)
    thr[pore == 0] = 0
    k = np.ones((3,3), np.uint8)
    thr = cv2.morphologyEx(thr, cv2.MORPH_OPEN,  k, iterations=1)
    thr = cv2.morphologyEx(thr, cv2.MORPH_CLOSE, k, iterations=1)
    # Closing (dilate then erode) can fill pillar pixels enclosed by crystals; re-mask.
    thr[pore == 0] = 0
    return thr

# ---------------- Find a C2 + matching C1 mask ----------------
# Candidate C2 images: any image whose file name contains "_c2" (case-insensitive),
# excluding "_answer" reference images.
c2_files = sorted([p for p in glob.glob(os.path.join(C2_FOLDER, "**", "*.*"), recursive=True)
                   if is_img(p)
                   and ("_c2" in os.path.basename(p).lower())
                   and ("_answer" not in os.path.basename(p).lower())])
if not c2_files:
    raise FileNotFoundError(f"No *_C2 images found under {C2_FOLDER!r}.")

c2_path = random.choice(c2_files)   # unseeded: a different image on every run
base_key, mpath, match_mode = find_mask_for_c2(c2_path)
if not mpath:
    raise FileNotFoundError(f"No C1 mask found for base '{base_key}' under {C1_MASK_DIR!r}")

# Raw C1 image: look next to the C2 image first, then anywhere under C2_FOLDER.
c2_dir = os.path.dirname(c2_path)
c1_path = None
for ext in IMG_EXTS:
    cand = os.path.join(c2_dir, f"{base_key}_C1{ext}")
    if os.path.isfile(cand):
        c1_path = cand
        break
if c1_path is None:
    # Escape the key so characters such as "[" in file names are not treated as glob syntax.
    key_pattern = glob.escape(base_key)
    cands = []
    for ext in IMG_EXTS:
        cands += glob.glob(os.path.join(C2_FOLDER, "**", f"{key_pattern}_C1{ext}"), recursive=True)
    if cands:
        c1_path = sorted(cands)[0]
if not c1_path:
    raise FileNotFoundError(f"Missing raw C1 for base '{base_key}'")

print(f"Storyboard for: {base_key}")
print(f"Mask match mode: {match_mode} → {os.path.relpath(mpath, C1_MASK_DIR)}")

# ---------------- Load images ----------------
C1_raw  = cv2.imread(c1_path, cv2.IMREAD_GRAYSCALE)
C2_raw  = cv2.imread(c2_path, cv2.IMREAD_GRAYSCALE)
C1_mask = cv2.imread(mpath, cv2.IMREAD_GRAYSCALE)
# cv2.imread returns None (no exception) for missing/unreadable files; report the path
# here instead of failing later with an AttributeError on `.shape`.
for _what, _path, _img in (("C1", c1_path, C1_raw), ("C2", c2_path, C2_raw), ("C1 mask", mpath, C1_mask)):
    if _img is None:
        raise OSError(f"Could not read {_what} image: {_path}")
# Nearest-neighbour keeps the mask binary; it is compared with > 0 below.
if C1_mask.shape != C2_raw.shape:
    C1_mask = cv2.resize(C1_mask, (C2_raw.shape[1], C2_raw.shape[0]), interpolation=cv2.INTER_NEAREST)

# ---------------- Metrics: pillar vs non-pillar ----------------
total_pixels = C1_mask.size
pillar_pixels = int(np.count_nonzero(C1_mask > 0))
nonpillar_pixels = total_pixels - pillar_pixels
pillar_pct = 100.0 * pillar_pixels / total_pixels
nonpillar_pct = 100.0 * nonpillar_pixels / total_pixels

# ---------------- C2 processing (conservative) ----------------
# Pore space ("channel") is everything the mask does not mark as pillar;
# crystals are only searched there.
channel = (C1_mask == 0).astype(np.uint8)
precip = binarize_c2(C2_raw, channel)

# ---------------- Crystal metrics ----------------
crystal_pixels = int(np.count_nonzero(precip > 0))
crystal_pct = 100.0 * crystal_pixels / total_pixels
crystal_in_nonpillar_pct = 100.0 * crystal_pixels / max(1, nonpillar_pixels)  # max(1, .) guards an all-pillar mask

# ---------------- Overlays (RGB colour order for matplotlib) ----------------
# Panel 4: raw C2, crystals red, pillars black, pillar outlines yellow.
proc_rgb = cv2.cvtColor(C2_raw, cv2.COLOR_GRAY2RGB)
proc_rgb[precip > 0] = (255, 0, 0)
proc_rgb[C1_mask > 0] = (0, 0, 0)   # painted after the crystals, so pillars always win

overlay = np.zeros_like(proc_rgb)
contours, _ = cv2.findContours((C1_mask > 0).astype(np.uint8),
                               cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
cv2.drawContours(overlay, contours, -1, (255, 255, 0), 2)
proc_with_outline = cv2.addWeighted(proc_rgb, 1.0, overlay, 1.0, 0)

# Panel 5: schematic composite, black pillars / white pore space / red crystals.
final_rgb = np.zeros_like(proc_rgb)     # starts all black, which already paints the pillars
final_rgb[channel > 0] = (255, 255, 255)
final_rgb[precip > 0] = (255, 0, 0)

# ---------------- Add scale bars ----------------
C1_with_bar = add_scale_bar(C1_raw, um_per_px, SCALE_UM)
C2_with_bar = add_scale_bar(C2_raw, um_per_px, SCALE_UM)

# um_per_px assumes the FIELD_UM field of view spans IMG_PX pixels; flag images that differ.
img_h, img_w = C2_raw.shape[:2]
if (img_w, img_h) != (IMG_PX, IMG_PX):
    print(f"Warning: C2 image is {img_w}×{img_h} px, but um_per_px assumes {IMG_PX} px per side.")

# ---------------- Plot storyboard ----------------
fig, axes = plt.subplots(1,5, figsize=(22,5))

axes[0].imshow(C1_with_bar); axes[0].set_title("C1 Brightfield (raw + scale bar)"); axes[0].axis("off")
axes[1].imshow(C1_mask, cmap="gray"); axes[1].set_title(f"Pillar Mask\nNon-pillar: {nonpillar_pct:.1f}%"); axes[1].axis("off")
axes[2].imshow(C2_with_bar); axes[2].set_title("C2 TL-POL (raw + scale bar)"); axes[2].axis("off")
axes[3].imshow(proc_with_outline); axes[3].set_title("C2 Processed\n+ Pillar Outlines"); axes[3].axis("off")
axes[4].imshow(final_rgb); axes[4].set_title(
    f"Final Composite\n"
    f"Pillar: {pillar_pct:.1f}% | Non-pillar: {nonpillar_pct:.1f}%\n"
    f"Crystal: {crystal_pct:.1f}% | Crystal/Non-pillar: {crystal_in_nonpillar_pct:.1f}%"
); axes[4].axis("off")

# The title reports the actual image size rather than a fixed 2048×2048.
fig.suptitle(f"{base_key} | {img_w}×{img_h} px | {FIELD_UM:.1f} µm FOV (~{um_per_px:.2f} µm/px)",
             fontsize=14, fontweight="bold")
plt.tight_layout()
plt.show()


In [None]:
# === Storyboard for random set4 C2 image with its matched C1 mask + metrics
#     (robust mask matching + conservative C2 thresholding) ===
import os, glob, cv2, random, re
import numpy as np
import matplotlib.pyplot as plt

# ---------------- Paths ----------------
# NOTE(review): this storyboard cell appears to repeat the previous one verbatim; re-running
# it only draws another random image. Consider keeping a single copy.
# Reuse ROOT when an earlier cell already defined it; otherwise use the default image root.
ROOT = globals().get("ROOT", "data/raw_images")  # root with short1/short2/tall1/tall2
C2_FOLDER = ROOT                     # searched recursively for *_C2 images
C1_MASK_DIR = "data/raw_images"      # predicted C1 masks, mirroring the C2 folder layout

# ---------------- Imaging scale ----------------
FIELD_UM = 1331.2                    # field of view along one side (µm)
IMG_PX = 2048                        # pixels along one side
um_per_px = FIELD_UM / IMG_PX        # µm per pixel
mm2_per_px = (um_per_px / 1000.0) ** 2  # area of one pixel in mm²; not used, kept for completeness

# ---------------- Scale bar style ----------------
SCALE_UM = 500                       # requested bar length (µm)
BAR_THICK_PX, MARGIN_PX = 35, 100    # bar thickness; gap to the bottom/right edges (px)
FONT = cv2.FONT_HERSHEY_SIMPLEX
FONT_SCALE, FONT_THICK = 4, 13       # large text, sized for ~2048 px images

# ---------------- C2 binarization (more conservative) ----------------
# Choose ONE mode: "fixed", "percentile", or "otsu+".
# NOTE(review): "percentile" is the adaptive option, but "fixed" is the one active here.
C2_THRESH_MODE = "fixed"
C2_FIXED = 140        # "fixed": absolute grey level; raise to pick less
C2_PERCENTILE = 94    # "percentile": percentile of pore-space intensities; 98–99 picks less
C2_DELTA = 15         # "otsu+": added to the Otsu threshold; larger = stricter

# ---------------- Helpers ----------------
IMG_EXTS = (".png", ".tif", ".tiff", ".jpg", ".jpeg")


def is_img(p):
    """Return True when path ``p`` ends in one of IMG_EXTS (extension compared case-insensitively)."""
    _, ext = os.path.splitext(p)
    return ext.lower() in IMG_EXTS

def add_scale_bar(gray_img, um_per_px, scale_um=SCALE_UM,
                  bar_thick_px=BAR_THICK_PX, margin_px=MARGIN_PX):
    """Return an RGB copy of ``gray_img`` with a labelled scale bar in the bottom-right corner.

    Args:
        gray_img: single-channel image; a 3-channel image is also accepted and copied as-is.
        um_per_px: physical size of one pixel in µm.
        scale_um: requested bar length in µm.
        bar_thick_px: bar thickness in pixels.
        margin_px: gap between the bar and the right/bottom image edges, in pixels.

    Returns:
        A new 3-channel image; the input is not modified.

    If the requested bar does not fit between the margins it is shortened, and the
    label then states the length actually drawn, so the bar never misstates the scale.
    """
    if gray_img.ndim == 3 and gray_img.shape[2] == 3:
        rgb = gray_img.copy()
    else:
        rgb = cv2.cvtColor(gray_img, cv2.COLOR_GRAY2RGB)
    h, w = gray_img.shape[:2]

    bar_len_px = int(round(scale_um / um_per_px))
    max_len_px = max(1, w - 2*margin_px - 1)
    label_um = scale_um
    if bar_len_px > max_len_px:
        # Bar had to be shortened: the label must describe the drawn length, not scale_um.
        bar_len_px = max_len_px
        label_um = bar_len_px * um_per_px

    x2 = w - margin_px; x1 = x2 - bar_len_px
    y2 = h - margin_px; y1 = y2 - bar_thick_px
    outline_pad = 2
    # Black outline first, then the white bar on top, so the bar reads on any background.
    cv2.rectangle(rgb, (x1 - outline_pad, y1 - outline_pad),
                  (x2 + outline_pad, y2 + outline_pad), (0,0,0), -1)
    cv2.rectangle(rgb, (x1, y1), (x2, y2), (255,255,255), -1)
    label = f"{label_um:.0f} um"
    (tw, _th), _ = cv2.getTextSize(label, FONT, FONT_SCALE, FONT_THICK)
    tx = max(0, x1 + (bar_len_px - tw) // 2)   # centred over the bar, never off the left edge
    ty = y1 - 8
    # Thicker black pass, then a white pass on top, gives outlined text.
    cv2.putText(rgb, label, (tx, ty), FONT, FONT_SCALE, (0,0,0), FONT_THICK+2, cv2.LINE_AA)
    cv2.putText(rgb, label, (tx, ty), FONT, FONT_SCALE, (255,255,255), FONT_THICK, cv2.LINE_AA)
    return rgb

# --- robust mask selection (prefer same folder; then same top-level; then global) ---
def parse_thr_from_name(path):
    """Return the threshold encoded as ``_mask_<thr>`` in the file name, or None if absent.

    Example: ".../3_1_a_C1_mask_0.47.png" -> 0.47
    """
    m = re.search(r"_mask_(\d+(?:\.\d+)?)", os.path.basename(path))
    return float(m.group(1)) if m else None

# If the threshold-sweep cell has already run, `best_thr` exists and masks whose encoded
# threshold is closest to it are preferred. That cell appears later in the notebook, so
# on a fresh top-to-bottom run THR_TARGET stays None and the newest mask wins.
THR_TARGET = None
try:
    THR_TARGET = float(best_thr)  # e.g., 0.47
except (NameError, TypeError, ValueError):
    pass  # best_thr undefined or not a single number: no threshold preference

def pick_best_mask(candidates, thr_target=None):
    """Choose one mask file from ``candidates``.

    Args:
        candidates: mask file paths (each must exist; its modification time is read).
        thr_target: threshold to match; None means use the module-level THR_TARGET.

    Returns:
        None when there are no candidates. Otherwise, with a known target: the path
        whose encoded threshold is closest to it. The newest file breaks ties, and
        names without a threshold rank last. With no target: the newest file.
    """
    candidates = list(candidates)
    if not candidates:
        return None
    target = THR_TARGET if thr_target is None else thr_target
    if target is None:
        return max(candidates, key=os.path.getmtime)

    def _rank(path):
        thr = parse_thr_from_name(path)
        # `is None`, not `or`: a parsed threshold of 0.0 is valid, not "missing".
        distance = float("inf") if thr is None else abs(thr - target)
        return (distance, -os.path.getmtime(path))

    return min(candidates, key=_rank)

def find_mask_for_c2(c2_path):
    """Locate the predicted C1 pillar mask that belongs to a C2 image.

    The base key is the C2 file stem up to its last "_C2" tag (matched
    case-insensitively, like the C2 file selection), e.g. "3_1_a_C2.png" -> "3_1_a".
    Masks named "<base_key>_C1_mask_*.png" are looked up under C1_MASK_DIR in this
    order:
      1) the same relative sub-directory as the C2 image;
      2) anywhere below the same top-level folder;
      3) anywhere below C1_MASK_DIR.

    Returns:
        (base_key, mask_path_or_None, match_mode), where match_mode is
        "same_dir", "same_top", "global" or "not_found".
    """
    stem = os.path.splitext(os.path.basename(c2_path))[0]   # e.g., "3_1_a_C2"
    cut = stem.lower().rfind("_c2")
    base_key = stem[:cut] if cut >= 0 else stem              # -> "3_1_a"
    rel_dir = os.path.dirname(os.path.relpath(c2_path, C2_FOLDER))  # e.g., "short2" or nested
    # Escape the key so characters such as "[" in file names are not read as glob syntax.
    pattern = f"{glob.escape(base_key)}_C1_mask_*.png"

    # 1) Same relative directory under C1_MASK_DIR
    mpath = pick_best_mask(glob.glob(os.path.join(C1_MASK_DIR, rel_dir, pattern)))
    if mpath:
        return base_key, mpath, "same_dir"

    # 2) Same top-level folder (e.g. short1/short2/tall1/tall2)
    top = rel_dir.split(os.sep)[0] if rel_dir else ""
    if top:
        mpath = pick_best_mask(
            glob.glob(os.path.join(C1_MASK_DIR, top, "**", pattern), recursive=True))
        if mpath:
            return base_key, mpath, "same_top"

    # 3) Anywhere under C1_MASK_DIR
    mpath = pick_best_mask(glob.glob(os.path.join(C1_MASK_DIR, "**", pattern), recursive=True))
    if mpath:
        return base_key, mpath, "global"

    return base_key, None, "not_found"

# --- C2 binarization (conservative) ---
def to_u8(img):
    """Return ``img`` as uint8: unchanged if it already is, else min-max stretched to 0..255."""
    if img.dtype != np.uint8:
        img = cv2.normalize(img, None, 0, 255, cv2.NORM_MINMAX).astype(np.uint8)
    return img

def binarize_c2(C2_raw, nonpillar_mask):
    """Return a 0/255 uint8 crystal mask from a C2 TL-POL image, restricted to pore space.

    Args:
        C2_raw: C2 image (any dtype; converted with ``to_u8``).
        nonpillar_mask: array with the same height/width, nonzero where the pixel is
            pore space (non-pillar). 0/1, 0/255 and boolean masks are all accepted.

    The threshold follows the module-level ``C2_THRESH_MODE``:
      * "fixed":      grey level ``C2_FIXED``;
      * "percentile": ``C2_PERCENTILE``-th percentile of the pore-space pixels;
      * "otsu+":      Otsu threshold of the pore-space pixels plus ``C2_DELTA``.
    A 3x3 opening then closing removes isolated pixels and fills pinholes; the result
    is then re-masked so no crystal pixel lies on a pillar.

    Raises:
        ValueError: if ``C2_THRESH_MODE`` is not one of the modes above.
    """
    img = to_u8(C2_raw)
    # Normalise the mask so bool or 0/255 inputs neither break cv2 nor overflow when scaled.
    pore = (np.asarray(nonpillar_mask) > 0).astype(np.uint8)
    roi = cv2.bitwise_and(img, img, mask=pore * 255)
    blur = cv2.GaussianBlur(roi, (3,3), 0)
    # Intensity statistics use pore-space pixels only; the zeroed pillar pixels would skew them.
    vals = blur[pore > 0]

    if C2_THRESH_MODE == "fixed":
        thr_val = int(C2_FIXED)
    elif C2_THRESH_MODE == "percentile":
        thr_val = int(np.percentile(vals, C2_PERCENTILE)) if vals.size else 255
    elif C2_THRESH_MODE == "otsu+":
        if vals.size:
            # Otsu on pore pixels only. On the whole ROI, the zero-filled pillar area
            # forms its own histogram mode and pulls the threshold down.
            t0, _ = cv2.threshold(vals.reshape(-1, 1), 0, 255,
                                  cv2.THRESH_BINARY + cv2.THRESH_OTSU)
            thr_val = int(max(0, min(255, t0 + C2_DELTA)))
        else:
            thr_val = 255
    else:
        raise ValueError(f"Unknown C2_THRESH_MODE {C2_THRESH_MODE!r}; "
                         "expected 'fixed', 'percentile' or 'otsu+'")

    _, thr = cv2.threshold(blur, thr_val, 255, cv2.THRESH_BINARY)
    thr[pore == 0] = 0
    k = np.ones((3,3), np.uint8)
    thr = cv2.morphologyEx(thr, cv2.MORPH_OPEN,  k, iterations=1)
    thr = cv2.morphologyEx(thr, cv2.MORPH_CLOSE, k, iterations=1)
    # Closing (dilate then erode) can fill pillar pixels enclosed by crystals; re-mask.
    thr[pore == 0] = 0
    return thr

# ---------------- Find a C2 + matching C1 mask ----------------
# Candidate C2 images: any image whose file name contains "_c2" (case-insensitive),
# excluding "_answer" reference images.
c2_files = sorted([p for p in glob.glob(os.path.join(C2_FOLDER, "**", "*.*"), recursive=True)
                   if is_img(p)
                   and ("_c2" in os.path.basename(p).lower())
                   and ("_answer" not in os.path.basename(p).lower())])
if not c2_files:
    raise FileNotFoundError(f"No *_C2 images found under {C2_FOLDER!r}.")

c2_path = random.choice(c2_files)   # unseeded: a different image on every run
base_key, mpath, match_mode = find_mask_for_c2(c2_path)
if not mpath:
    raise FileNotFoundError(f"No C1 mask found for base '{base_key}' under {C1_MASK_DIR!r}")

# Raw C1 image: look next to the C2 image first, then anywhere under C2_FOLDER.
c2_dir = os.path.dirname(c2_path)
c1_path = None
for ext in IMG_EXTS:
    cand = os.path.join(c2_dir, f"{base_key}_C1{ext}")
    if os.path.isfile(cand):
        c1_path = cand
        break
if c1_path is None:
    # Escape the key so characters such as "[" in file names are not treated as glob syntax.
    key_pattern = glob.escape(base_key)
    cands = []
    for ext in IMG_EXTS:
        cands += glob.glob(os.path.join(C2_FOLDER, "**", f"{key_pattern}_C1{ext}"), recursive=True)
    if cands:
        c1_path = sorted(cands)[0]
if not c1_path:
    raise FileNotFoundError(f"Missing raw C1 for base '{base_key}'")

print(f"Storyboard for: {base_key}")
print(f"Mask match mode: {match_mode} → {os.path.relpath(mpath, C1_MASK_DIR)}")

# ---------------- Load images ----------------
C1_raw  = cv2.imread(c1_path, cv2.IMREAD_GRAYSCALE)
C2_raw  = cv2.imread(c2_path, cv2.IMREAD_GRAYSCALE)
C1_mask = cv2.imread(mpath, cv2.IMREAD_GRAYSCALE)
# cv2.imread returns None (no exception) for missing/unreadable files; report the path
# here instead of failing later with an AttributeError on `.shape`.
for _what, _path, _img in (("C1", c1_path, C1_raw), ("C2", c2_path, C2_raw), ("C1 mask", mpath, C1_mask)):
    if _img is None:
        raise OSError(f"Could not read {_what} image: {_path}")
# Nearest-neighbour keeps the mask binary; it is compared with > 0 below.
if C1_mask.shape != C2_raw.shape:
    C1_mask = cv2.resize(C1_mask, (C2_raw.shape[1], C2_raw.shape[0]), interpolation=cv2.INTER_NEAREST)

# ---------------- Metrics: pillar vs non-pillar ----------------
total_pixels = C1_mask.size
pillar_pixels = int(np.count_nonzero(C1_mask > 0))
nonpillar_pixels = total_pixels - pillar_pixels
pillar_pct = 100.0 * pillar_pixels / total_pixels
nonpillar_pct = 100.0 * nonpillar_pixels / total_pixels

# ---------------- C2 processing (conservative) ----------------
# Pore space ("channel") is everything the mask does not mark as pillar;
# crystals are only searched there.
channel = (C1_mask == 0).astype(np.uint8)
precip = binarize_c2(C2_raw, channel)

# ---------------- Crystal metrics ----------------
crystal_pixels = int(np.count_nonzero(precip > 0))
crystal_pct = 100.0 * crystal_pixels / total_pixels
crystal_in_nonpillar_pct = 100.0 * crystal_pixels / max(1, nonpillar_pixels)  # max(1, .) guards an all-pillar mask

# ---------------- Overlays (RGB colour order for matplotlib) ----------------
# Panel 4: raw C2, crystals red, pillars black, pillar outlines yellow.
proc_rgb = cv2.cvtColor(C2_raw, cv2.COLOR_GRAY2RGB)
proc_rgb[precip > 0] = (255, 0, 0)
proc_rgb[C1_mask > 0] = (0, 0, 0)   # painted after the crystals, so pillars always win

overlay = np.zeros_like(proc_rgb)
contours, _ = cv2.findContours((C1_mask > 0).astype(np.uint8),
                               cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
cv2.drawContours(overlay, contours, -1, (255, 255, 0), 2)
proc_with_outline = cv2.addWeighted(proc_rgb, 1.0, overlay, 1.0, 0)

# Panel 5: schematic composite, black pillars / white pore space / red crystals.
final_rgb = np.zeros_like(proc_rgb)     # starts all black, which already paints the pillars
final_rgb[channel > 0] = (255, 255, 255)
final_rgb[precip > 0] = (255, 0, 0)

# ---------------- Add scale bars ----------------
C1_with_bar = add_scale_bar(C1_raw, um_per_px, SCALE_UM)
C2_with_bar = add_scale_bar(C2_raw, um_per_px, SCALE_UM)

# um_per_px assumes the FIELD_UM field of view spans IMG_PX pixels; flag images that differ.
img_h, img_w = C2_raw.shape[:2]
if (img_w, img_h) != (IMG_PX, IMG_PX):
    print(f"Warning: C2 image is {img_w}×{img_h} px, but um_per_px assumes {IMG_PX} px per side.")

# ---------------- Plot storyboard ----------------
fig, axes = plt.subplots(1,5, figsize=(22,5))

axes[0].imshow(C1_with_bar); axes[0].set_title("C1 Brightfield (raw + scale bar)"); axes[0].axis("off")
axes[1].imshow(C1_mask, cmap="gray"); axes[1].set_title(f"Pillar Mask\nNon-pillar: {nonpillar_pct:.1f}%"); axes[1].axis("off")
axes[2].imshow(C2_with_bar); axes[2].set_title("C2 TL-POL (raw + scale bar)"); axes[2].axis("off")
axes[3].imshow(proc_with_outline); axes[3].set_title("C2 Processed\n+ Pillar Outlines"); axes[3].axis("off")
axes[4].imshow(final_rgb); axes[4].set_title(
    f"Final Composite\n"
    f"Pillar: {pillar_pct:.1f}% | Non-pillar: {nonpillar_pct:.1f}%\n"
    f"Crystal: {crystal_pct:.1f}% | Crystal/Non-pillar: {crystal_in_nonpillar_pct:.1f}%"
); axes[4].axis("off")

# The title reports the actual image size rather than a fixed 2048×2048.
fig.suptitle(f"{base_key} | {img_w}×{img_h} px | {FIELD_UM:.1f} µm FOV (~{um_per_px:.2f} µm/px)",
             fontsize=14, fontweight="bold")
plt.tight_layout()
plt.show()


# Part 2 — Precipitation quantification (C2 TL-POL)


In [None]:
import re
from pathlib import Path

BASE = Path("data/raw_images")

# Folders (relative to BASE) whose PNG files are counted
folders = ["t1t2", "t3t4", "s2s3", "t5t6", "d"]

# Patterns
# Standard: any file name ending in .png, in any letter case
png_pat = re.compile(r".*\.png$", re.IGNORECASE)

# For folder "d": only count files whose *base name* starts with d#.f#
# Examples counted:
#   d3.f7.c1.png
#   d10.f12.c2.png
# Examples ignored:
#   d3.p7.c1.png
#   d3.p7.c2.png
d_full_pat = re.compile(r"^d\d+\.f\d+(\.|$)", re.IGNORECASE)

def _iter_pngs(folder_path: Path):
    """Yield regular files below ``folder_path`` whose name matches png_pat."""
    # rglob("*.png") is case-sensitive on POSIX and would silently skip ".PNG" files,
    # so every regular file is filtered through the case-insensitive png_pat instead.
    return (p for p in folder_path.rglob("*") if p.is_file() and png_pat.match(p.name))

def count_pngs_any(folder_path: Path) -> int:
    """Return the number of PNG files anywhere below ``folder_path``."""
    return sum(1 for _ in _iter_pngs(folder_path))

def count_d_full_only(folder_path: Path) -> int:
    """Return the number of PNG files below ``folder_path`` whose stem starts with d#.f#."""
    return sum(1 for p in _iter_pngs(folder_path) if d_full_pat.match(p.stem))

print(f"Base: {BASE}\n")

# One line per folder; "d" only counts full-frame d#.f# images.
for name in folders:
    p = BASE / name
    if not p.is_dir():
        print(f"{name}: ❌ folder not found at {p}")
        continue

    if name == "d":
        n = count_d_full_only(p)
        print(f"{name}: {n} PNG files (counting ONLY d#.f#, excluding d#.p#)")
    else:
        n = count_pngs_any(p)
        print(f"{name}: {n} PNG files")


In [None]:
import random, re
from pathlib import Path

BASE = Path("data/raw_images")

# How many C1 images to sample per folder
PLAN = {
    "t1t2": 7,
    "t3t4": 7,
    "t5t6": 7,
    "s2s3": 4,
    "d":    5,   # only d#.f# (full), not d#.p#
}

SEED = 42  # change if you want a different random draw
random.seed(SEED)

c1_pat = re.compile(r".*\.c1\.png$", re.IGNORECASE)
d_full_pat = re.compile(r"^d\d+\.f\d+(\.|$)", re.IGNORECASE)

def get_c1_files(folder: str):
    """Return the sorted C1 PNG files below BASE/<folder>.

    A file qualifies when its name ends in ".c1.png" (any letter case). For the "d"
    folder only full-frame images whose stem starts with d#.f# are kept.
    """
    p = BASE / folder
    # rglob("*.png") is case-sensitive on POSIX and would skip ".PNG" files even though
    # c1_pat ignores case; filter every regular file through c1_pat instead.
    files = [f for f in p.rglob("*") if f.is_file() and c1_pat.match(f.name)]
    if folder == "d":
        files = [f for f in files if d_full_pat.match(f.stem)]  # keep only d#.f#
    return sorted(files)

# Draw the planned number of C1 images from each folder (global `random`, seeded above).
picked = {}       # folder -> sorted list of drawn paths
all_picks = []    # every drawn path, in draw order

for folder, n_wanted in PLAN.items():
    pool = get_c1_files(folder)
    if len(pool) < n_wanted:
        raise RuntimeError(f"Not enough C1 files in {folder}: have {len(pool)}, need {n_wanted}")
    drawn = random.sample(pool, n_wanted)
    all_picks.extend(drawn)
    picked[folder] = sorted(drawn)

# Total computed from the draw, so the message stays right if PLAN changes.
n_total = sum(len(paths) for paths in picked.values())
print(f"Selected C1 images to annotate ({n_total} total):\n")
for folder in PLAN:
    print(f"--- {folder} ({len(picked[folder])}) ---")
    for f in picked[folder]:
        print(f.relative_to(BASE))

# Save the list (and the seed that produced it) next to the images for reference.
# The file name keeps its original "30C1" tag.
out_txt = BASE / "mask_annotation_list_30C1.txt"
with open(out_txt, "w", encoding="utf-8") as fp:
    fp.write(f"SEED={SEED}\n\n")
    for folder in PLAN:
        fp.write(f"--- {folder} ({len(picked[folder])}) ---\n")
        fp.write("".join(f"{f}\n" for f in picked[folder]))
        fp.write("\n")

print(f"\nSaved list to: {out_txt}")


In [None]:
import csv
from pathlib import Path
import cv2

BASE = Path("data/raw_images")
GT_ROOT = BASE / "gt_masks_30C1"   # where the manually drawn GT masks are saved
GT_ROOT.mkdir(parents=True, exist_ok=True)

# ----- chosen C1 images (relative to BASE), listed per folder -----
# Each (folder, name) becomes "<folder>/<name>.c1.png", in the order listed.
selected_rel = [
    f"{folder}/{name}.c1.png"
    for folder, names in (
        ("t1t2", ("t1.1.8.a", "t1.1.9.a", "t1.2.5.b", "t1.2.5.c",
                  "t2.1.9.c", "t2.2.5.c", "t2.2.8.a")),                  # 7
        ("t3t4", ("t3.1.8.c", "t3.1.9.a", "t3.1.9.b", "t4.1.8.c",
                  "t4.2.5.a", "t4.2.5.c", "t4.2.9.b")),                  # 7
        ("t5t6", ("t5.1.8.a", "t5.1.8.b", "t5.1.8.c", "t5.2.5.a",
                  "t5.2.5.b", "t5.2.9.b", "t6.1.9.a")),                  # 7
        ("s2s3", ("s2.1.8.a", "s2.2.9.a", "s3.1.8.c", "s3.2.9.b")),     # 4
        ("d",    ("d3.f2", "d6.f2", "d6.f3", "d7.f3", "d8.f3")),         # 5
    )
    for name in names
]

def expected_gt_path(rel_path: str) -> Path:
    """Map a raw C1 path (relative to BASE) to the location of its manual GT mask.

    Naming rule:
      raw: <folder>/<name>.c1.png
      gt : gt_masks_30C1/<folder>/<name>.c1_gt.png
    Only the last suffix (".png") is replaced, so "<name>.c1" stays intact.
    """
    rel = Path(rel_path)
    gt_name = rel.stem + "_gt.png"
    return GT_ROOT.joinpath(rel.parent, gt_name)

# Make GT subfolders and write the raw -> GT mapping CSV
gt_rows = []
for rel in selected_rel:
    raw = BASE / rel
    gt  = expected_gt_path(rel)
    gt.parent.mkdir(parents=True, exist_ok=True)
    gt_rows.append((str(raw), str(gt)))
# `rows` is still set for backward compatibility. Later cells re-bind `rows`
# (threshold sweep, per-image metrics), so check_gt_masks() captures gt_rows instead.
rows = gt_rows

csv_path = GT_ROOT / "gt_mapping_30C1.csv"
with open(csv_path, "w", newline="", encoding="utf-8") as f:
    w = csv.writer(f)
    w.writerow(["raw_c1_path", "gt_mask_path"])
    w.writerows(gt_rows)

print(f"GT root created: {GT_ROOT}")
print(f"Mapping CSV saved: {csv_path}")
if gt_rows:
    print("\nExample:")
    print(" RAW:", gt_rows[0][0])
    print("  GT:", gt_rows[0][1])

# ---- checker you can rerun anytime ----
def check_gt_masks(mapping=gt_rows):
    """Print how many GT masks are present, readable, and the same size as their raw image.

    Args:
        mapping: (raw_c1_path, gt_mask_path) pairs. Defaults to this cell's mapping,
            bound when the function is defined, so the check keeps working after
            later cells re-bind the global `rows`.

    Missing files, unreadable images and height/width mismatches are listed;
    nothing is raised.
    """
    missing = []
    wrong_size = []
    ok = 0

    for raw_str, gt_str in mapping:
        raw_p = Path(raw_str)
        gt_p  = Path(gt_str)

        if not raw_p.exists():
            missing.append((raw_p, gt_p, "RAW_MISSING"))
            continue

        if not gt_p.exists():
            missing.append((raw_p, gt_p, "GT_MISSING"))
            continue

        raw = cv2.imread(str(raw_p), cv2.IMREAD_COLOR)
        gt  = cv2.imread(str(gt_p), cv2.IMREAD_GRAYSCALE)

        if raw is None or gt is None:
            missing.append((raw_p, gt_p, "UNREADABLE"))
            continue

        if gt.shape[:2] != raw.shape[:2]:
            wrong_size.append((raw_p, gt_p, raw.shape[:2], gt.shape[:2]))
            continue

        ok += 1

    print(f"\nStatus: {ok}/{len(mapping)} GT masks present + readable + correct size")
    if missing:
        print("\nMissing / unreadable:")
        for raw_p, gt_p, reason in missing:
            print(f" - {reason}: {gt_p.name}   (raw: {raw_p.name})")
    if wrong_size:
        print("\nWrong size:")
        for raw_p, gt_p, raw_hw, gt_hw in wrong_size:
            print(f" - {gt_p.name}: raw={raw_hw} vs gt={gt_hw}")

# run once now:
check_gt_masks()


In [None]:
# Pair every annotated image "<X>_gt.png" in C20_DIR with its raw image "<X>.png".
# NOTE(review): C20_DIR is not defined in the visible cells; it must be set before this runs.
all_png = sorted(p for p in glob.glob(os.path.join(C20_DIR, "*"))
                 if os.path.isfile(p) and p.lower().endswith(".png"))   # also picks up ".PNG"
gt_paths  = [p for p in all_png if os.path.basename(p).lower().endswith("_gt.png")]
raw_paths = [p for p in all_png if not os.path.basename(p).lower().endswith("_gt.png")]

raw_map = {os.path.splitext(os.path.basename(p))[0]: p for p in raw_paths}

pairs = []
missing = []
for gp in gt_paths:
    gt_stem = os.path.splitext(os.path.basename(gp))[0]
    # Drop only the trailing "_gt" tag (any case). str.replace("_gt", "") would also
    # delete "_gt" inside the name and leave an upper-case "_GT" suffix in place.
    base = gt_stem[:-len("_gt")]
    rp = raw_map.get(base)
    if rp is None:
        missing.append(os.path.basename(gp))
    else:
        pairs.append((rp, gp))

pairs = sorted(pairs, key=lambda x: os.path.basename(x[0]).lower())

print("Raw:", len(raw_paths))
print("GT :", len(gt_paths))
print("Pairs:", len(pairs))
if missing:
    print("GT without matching raw (first 10):", missing[:10])

assert len(pairs) > 0, "No pairs found. Check that filenames are X.png and X_gt.png."


In [None]:
def red_to_binary(mask_bgr):
    """Extract the red annotation strokes of a BGR image as a 0/1 uint8 mask.

    Red wraps around hue 0, so two HSV hue bands (0-12 and 170-180) are combined
    (assumes 8-bit input, where OpenCV hue spans 0-180). Saturation >= 70 and
    value >= 40 exclude grey and dark pixels. A 5x5 opening then closing removes
    specks and bridges small gaps.
    """
    hsv = cv2.cvtColor(mask_bgr, cv2.COLOR_BGR2HSV)
    low_band = cv2.inRange(hsv, np.array([0,70,40]), np.array([12,255,255]))
    high_band = cv2.inRange(hsv, np.array([170,70,40]), np.array([180,255,255]))
    red = low_band | high_band
    kernel = np.ones((5,5), np.uint8)
    for op in (cv2.MORPH_OPEN, cv2.MORPH_CLOSE):
        red = cv2.morphologyEx(red, op, kernel, iterations=1)
    return (red > 0).astype(np.uint8)

def resize_short(im, short=SHORT_SIDE):
    """Resize ``im`` so its shorter side is ``short`` pixels, keeping the aspect ratio.

    The default is SHORT_SIDE as it was when this cell ran (defaults bind at
    definition time). INTER_AREA is used whether the image shrinks or grows.
    """
    height, width = im.shape[:2]
    scale = short / min(height, width)
    new_size = (int(round(width * scale)), int(round(height * scale)))  # cv2 takes (width, height)
    return cv2.resize(im, new_size, interpolation=cv2.INTER_AREA)

def overlay_green(bgr, mask01, alpha=0.30):
    """Return a uint8 copy of ``bgr`` with the pixels where ``mask01`` is nonzero tinted green.

    Each selected pixel becomes (1 - alpha) * pixel + alpha * (0, 255, 0); the result
    is truncated (not rounded) back to uint8. Pure green is the same triple in BGR
    and RGB order. The input image is not modified.
    """
    tint = np.array([0,255,0], dtype=np.float32)
    blended = bgr.astype(np.float32)      # astype returns a new array
    sel = mask01.astype(bool)
    blended[sel] = (1-alpha)*blended[sel] + alpha*tint
    return blended.astype(np.uint8)


In [None]:
# NOTE(review): an earlier config cell re-assigns MODEL_PATH = "data/raw_images" (a folder);
# make sure it points at the downloaded checkpoint before running this cell.
if not os.path.isfile(MODEL_PATH):
    raise FileNotFoundError(f"MODEL_PATH is not a checkpoint file: {MODEL_PATH}")

model = smp.Unet(encoder_name="resnet18", encoder_weights=None, in_channels=3, classes=1)
# The checkpoint comes from an external download. weights_only=True restricts
# torch.load to tensors and plain containers instead of unpickling arbitrary objects.
try:
    state = torch.load(MODEL_PATH, map_location=DEVICE, weights_only=True)
except TypeError:
    # torch < 1.13 has no weights_only argument; only load checkpoints you trust here.
    state = torch.load(MODEL_PATH, map_location=DEVICE)
model.load_state_dict(state)
model.to(DEVICE).eval()
print("Loaded model:", MODEL_PATH)


In [None]:
def infer_prob_rs(raw_path):
    """Predict a pillar probability map for one raw image at the resize_short() resolution.

    Args:
        raw_path: path of the raw image to segment.

    Returns:
        (bgr_rs, prob): the resized BGR image, and a 2-D float32 array holding
        sigmoid(model output) for the first output channel of the single batch item.

    Raises:
        ValueError: if OpenCV cannot read ``raw_path``.
    """
    image = cv2.imread(raw_path)
    if image is None:
        raise ValueError(f"Unreadable raw: {raw_path}")
    bgr_small = resize_short(image)
    rgb_small = cv2.cvtColor(bgr_small, cv2.COLOR_BGR2RGB)
    # NOTE(review): pixel values enter the model as float 0-255 with no normalization;
    # confirm this matches the preprocessing used when the checkpoint was trained.
    batch = ToTensorV2()(image=rgb_small)["image"][None].to(DEVICE).float()
    with torch.no_grad():
        prob_map = torch.sigmoid(model(batch))[0,0].cpu().numpy().astype(np.float32)
    return bgr_small, prob_map

# Run the model once per (raw, annotated) pair and keep everything the sweep and montage
# cells need, all at the model's working resolution.
cache = {}
for raw_fp, gt_fp in pairs:
    key = os.path.splitext(os.path.basename(raw_fp))[0]
    bgr_small, prob_map = infer_prob_rs(raw_fp)

    gt_img = cv2.imread(gt_fp)
    if gt_img is None:
        raise ValueError(f"Unreadable GT: {gt_fp}")
    # Nearest-neighbour keeps the extracted GT strictly 0/1 on the prediction grid.
    gt_small = cv2.resize(red_to_binary(gt_img), (prob_map.shape[1], prob_map.shape[0]),
                          interpolation=cv2.INTER_NEAREST).astype(np.uint8)

    cache[key] = {"bgr_rs": bgr_small, "prob": prob_map, "gt": gt_small}

print("Cached images:", len(cache))


In [None]:
def metrics_np(pred01, gt01, eps=1e-9):
    """Overlap scores between a binary prediction and a binary ground truth.

    Both arrays are compared against the values 1 and 0. ``eps`` keeps every ratio
    finite when its denominator is zero (e.g. both masks empty gives all scores 0).

    Returns:
        (dice, iou, precision, recall)
    """
    pred_on = pred01 == 1
    gt_on = gt01 == 1
    tp = np.logical_and(pred_on, gt_on).sum()
    fp = np.logical_and(pred_on, gt01 == 0).sum()
    fn = np.logical_and(pred01 == 0, gt_on).sum()
    return (
        (2*tp) / (2*tp + fp + fn + eps),   # dice
        tp / (tp + fp + fn + eps),         # iou
        tp / (tp + fp + eps),              # precision
        tp / (tp + fn + eps),              # recall
    )

# Sweep range (adjust if you want): 26 thresholds from 0.20 to 0.70 in 0.02 steps
thresholds = np.linspace(0.20, 0.70, 26)

if not cache:
    raise RuntimeError("cache is empty; run the inference cell first")
n_images = len(cache)

rows = []          # per threshold: [thr, mean dice, mean iou, mean precision, mean recall]
per_image_iou = {} # thr -> dict(base -> iou) for later "best/worst" examples

for thr in thresholds:
    ms = []
    iou_map = {}
    for base, d in cache.items():
        pred01 = (d["prob"] > thr).astype(np.uint8)
        dice, iou, prec, rec = metrics_np(pred01, d["gt"])
        ms.append([dice, iou, prec, rec])
        iou_map[base] = iou
    per_image_iou[float(thr)] = iou_map
    m = np.mean(ms, axis=0)   # mean over images, one value per metric
    rows.append([float(thr), float(m[0]), float(m[1]), float(m[2]), float(m[3])])

rows = np.array(rows, dtype=float)
thr = rows[:,0]; dice = rows[:,1]; iou = rows[:,2]; prec = rows[:,3]; rec = rows[:,4]

# Choose selection criterion.
# NOTE: the threshold is selected and scored on the same annotated images, so these
# means are optimistic for unseen data.
CRITERION = "iou"  # "iou" or "dice"
if CRITERION not in ("iou", "dice"):
    raise ValueError(f"CRITERION must be 'iou' or 'dice', got {CRITERION!r}")
best_idx = int(np.argmax(iou if CRITERION=="iou" else dice))
best_thr = float(thr[best_idx])

print(f"Best threshold by mean {CRITERION.upper()}: {best_thr:.2f}")
print(f"Mean Dice={dice[best_idx]:.3f} | Mean IoU={iou[best_idx]:.3f} | Prec={prec[best_idx]:.3f} | Rec={rec[best_idx]:.3f}")

# Save CSV
csv_path = os.path.join(OUT_DIR, "threshold_sweep_metrics.csv")
with open(csv_path, "w") as f:
    f.write("thr,mean_dice,mean_iou,mean_precision,mean_recall\n")
    for r in rows:
        f.write(f"{r[0]:.4f},{r[1]:.6f},{r[2]:.6f},{r[3]:.6f},{r[4]:.6f}\n")
print("Saved:", csv_path)

# Save plot (n is the actual number of cached images, not a fixed count)
fig, ax = plt.subplots()
ax.plot(thr, iou,  label="IoU")
ax.plot(thr, dice, label="Dice")
ax.plot(thr, prec, label="Precision")
ax.plot(thr, rec,  label="Recall")
ax.axvline(best_thr, linestyle="--", label=f"Selected thr={best_thr:.2f}")
ax.set_xlabel("Threshold")
ax.set_ylabel(f"Mean metric (n={n_images})")
ax.set_title("Threshold sweep (GT extracted from annotated images)")
ax.legend()
fig_path = os.path.join(OUT_DIR, "threshold_sweep_plot.png")
fig.savefig(fig_path, dpi=300, bbox_inches="tight")
plt.show()
print("Saved:", fig_path)


In [None]:
import os, random, cv2, time
import numpy as np
import matplotlib.pyplot as plt

# -------- settings --------
THR_SHOW, N_SHOW, ALPHA = 0.50, 8, 0.30   # display threshold, rows in the montage, overlay opacity
# --------------------------

# Time-based seed: a different selection on every run. It is printed and embedded in the
# saved file name so a given montage can be reproduced. This re-binds SEED and re-seeds
# the global `random` module.
SEED = int(time.time())
random.seed(SEED)
print("Using SEED:", SEED)

# image base name -> annotated GT file
gt_by_base = {os.path.splitext(os.path.basename(raw_fp))[0]: gt_fp for raw_fp, gt_fp in pairs}

# Pick random examples
bases_all = sorted(cache)
show_bases = random.sample(bases_all, k=min(N_SHOW, len(bases_all)))
print("Showing:", show_bases)

cols = 4
rowsN = len(show_bases)
if rowsN == 0:
    raise ValueError("No cached images to show; run the inference cell first.")

fig = plt.figure(figsize=(cols*4.2, rowsN*4.2))

for r, base in enumerate(show_bases):
    d = cache[base]
    bgr = d["bgr_rs"]
    gt_mask01 = d["gt"]
    prob = d["prob"]

    raw_rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

    # Column 2: the annotated image (raw + red circles) at the model resolution.
    # Read it once; fall back to the raw image if the file is missing or unreadable.
    gt_path = gt_by_base.get(base)
    gt_bgr_full = cv2.imread(gt_path) if gt_path is not None else None
    if gt_bgr_full is None:
        gtvis_rgb = raw_rgb.copy()
    else:
        gt_bgr_rs = cv2.resize(gt_bgr_full, (bgr.shape[1], bgr.shape[0]), interpolation=cv2.INTER_AREA)
        gtvis_rgb = cv2.cvtColor(gt_bgr_rs, cv2.COLOR_BGR2RGB)

    # Columns 3-4: extracted GT mask vs. thresholded prediction, both as green tints.
    # THR_SHOW is fixed in the settings above; it is not taken from best_thr.
    gt_overlay_rgb = cv2.cvtColor(overlay_green(bgr, gt_mask01, alpha=ALPHA), cv2.COLOR_BGR2RGB)

    pred01 = (prob > THR_SHOW).astype(np.uint8)
    pred_overlay_rgb = cv2.cvtColor(overlay_green(bgr, pred01, alpha=ALPHA), cv2.COLOR_BGR2RGB)

    ax = fig.add_subplot(rowsN, cols, r*cols + 1)
    ax.imshow(raw_rgb); ax.set_title(f"{base}\nRaw"); ax.axis("off")

    ax = fig.add_subplot(rowsN, cols, r*cols + 2)
    ax.imshow(gtvis_rgb); ax.set_title("Raw + red circles"); ax.axis("off")

    ax = fig.add_subplot(rowsN, cols, r*cols + 3)
    ax.imshow(gt_overlay_rgb); ax.set_title("GT mask overlay"); ax.axis("off")

    ax = fig.add_subplot(rowsN, cols, r*cols + 4)
    ax.imshow(pred_overlay_rgb); ax.set_title(f"Prediction overlay (thr={THR_SHOW:.2f})"); ax.axis("off")

fig.tight_layout()

out_path = os.path.join(OUT_DIR, f"montage_examples_thr_{THR_SHOW:.2f}_n{rowsN}_seed{SEED}.png")
fig.savefig(out_path, dpi=300, bbox_inches="tight")
plt.show()
plt.close(fig)   # free the large multi-panel figure
print("Saved:", out_path)


In [None]:
import os, numpy as np, pandas as pd
import matplotlib.pyplot as plt

THR_EVAL = 0.50  # lock this (your selected threshold)


def metrics_np(pred01, gt01, eps=1e-9):
    """Pixel-wise overlap metrics between a predicted and a ground-truth binary mask.

    Pixels equal to 1 are foreground and pixels equal to 0 are background.
    ``eps`` keeps every ratio finite when its denominator would be zero.

    Returns:
        (dice, iou, precision, recall)
    """
    pred_on, pred_off = pred01 == 1, pred01 == 0
    gt_on, gt_off = gt01 == 1, gt01 == 0

    true_pos = (pred_on & gt_on).sum()
    false_pos = (pred_on & gt_off).sum()
    false_neg = (pred_off & gt_on).sum()

    dice = (2 * true_pos) / (2 * true_pos + false_pos + false_neg + eps)
    iou = true_pos / (true_pos + false_pos + false_neg + eps)
    precision = true_pos / (true_pos + false_pos + eps)
    recall = true_pos / (true_pos + false_neg + eps)
    return dice, iou, precision, recall

# Per-image overlap metrics on the cached calibration predictions at THR_EVAL.
rows = []
for base, d in cache.items():
    pred01 = (d["prob"] > THR_EVAL).astype(np.uint8)
    dice, iou, prec, rec = metrics_np(pred01, d["gt"])
    rows.append([base, dice, iou, prec, rec])

df_img = pd.DataFrame(rows, columns=["image", "dice", "iou", "precision", "recall"]).sort_values("image")

# Summarise only the numeric metric columns ("image" holds strings).
print(df_img[["dice","iou","precision","recall"]].describe())

# Save per-image metrics table (optional)
csv_path = os.path.join(OUT_DIR, f"per_image_metrics_thr_{THR_EVAL:.2f}.csv")
df_img.to_csv(csv_path, index=False)
print("Saved:", csv_path)

# Histogram of IoU
plt.figure(figsize=(6.8, 4.8))
plt.hist(df_img["iou"], bins=10)
plt.xlabel("IoU (Jaccard) per image")
plt.ylabel("Count")
plt.title(f"Distribution of IoU across annotated calibration images (thr={THR_EVAL:.2f}, n={len(df_img)})")
plt.tight_layout()
fig1 = os.path.join(OUT_DIR, f"iou_hist_thr_{THR_EVAL:.2f}.png")
plt.savefig(fig1, dpi=300, bbox_inches="tight")
plt.show()
print("Saved:", fig1)

# Boxplot of IoU + Dice. Tick labels are set with xticks (boxes sit at x=1, 2)
# instead of boxplot's `labels=` keyword, which Matplotlib 3.9 deprecated in favour
# of `tick_labels`; this form works on both older and newer versions.
plt.figure(figsize=(6.8, 4.8))
plt.boxplot([df_img["iou"].values, df_img["dice"].values], showfliers=True)
plt.xticks([1, 2], ["IoU", "Dice"])
plt.ylabel("Metric value")
plt.title(f"Per-image overlap metrics (thr={THR_EVAL:.2f}, n={len(df_img)})")
plt.tight_layout()
fig2 = os.path.join(OUT_DIR, f"iou_dice_box_thr_{THR_EVAL:.2f}.png")
plt.savefig(fig2, dpi=300, bbox_inches="tight")
plt.show()
print("Saved:", fig2)


In [None]:
import os, glob, random, cv2
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from albumentations.pytorch import ToTensorV2
import segmentation_models_pytorch as smp
import torch

# ========= CONFIG =========
C10_DIR    = "data/raw_images"   # folder holding X.png raw images and X_gt.png red-circle annotations
# U-Net checkpoint *file*: torch.load cannot open a directory. This is the path the
# download cell at the top of the notebook writes to.
MODEL_PATH = os.path.join("models", "unet_pillars_finetuned.pt")
THR_EVAL   = 0.50   # LOCKED from calibration
OUT_DIR    = "data/raw_images"
SHORT_SIDE = 768
N_MONTAGE  = 6
SEED       = 2026
ALPHA      = 0.30
# ==========================

if not os.path.isfile(MODEL_PATH):
    raise FileNotFoundError(f"Model checkpoint not found: {MODEL_PATH} (run the download cell first)")

os.makedirs(OUT_DIR, exist_ok=True)
random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED)
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("DEVICE:", DEVICE)

def resize_short(im, short=SHORT_SIDE):
    """Rescale ``im`` so its shorter side becomes ``short`` pixels, keeping the aspect ratio.

    Uses INTER_AREA interpolation; the new width/height are rounded to whole pixels.
    """
    height, width = im.shape[:2]
    scale = short / min(height, width)
    new_size = (int(round(width * scale)), int(round(height * scale)))  # cv2 expects (w, h)
    return cv2.resize(im, new_size, interpolation=cv2.INTER_AREA)

def red_to_binary(mask_bgr):
    """Turn a BGR annotation image with red circles into a 0/1 uint8 mask.

    Red is selected in HSV using two hue bands (0-12 and 170-180, because red
    sits at both ends of OpenCV's hue range), requiring S >= 70 and V >= 40.
    A 5x5 opening followed by a 5x5 closing removes specks and fills small gaps.
    """
    hsv = cv2.cvtColor(mask_bgr, cv2.COLOR_BGR2HSV)
    red_bands = (
        (np.array([0, 70, 40]), np.array([12, 255, 255])),
        (np.array([170, 70, 40]), np.array([180, 255, 255])),
    )
    low_hue, high_hue = (cv2.inRange(hsv, lo, hi) for lo, hi in red_bands)
    red = low_hue | high_hue
    kernel = np.ones((5, 5), np.uint8)
    for op in (cv2.MORPH_OPEN, cv2.MORPH_CLOSE):
        red = cv2.morphologyEx(red, op, kernel, iterations=1)
    return (red > 0).astype(np.uint8)

def overlay_green(bgr, mask01, alpha=ALPHA):
    """Tint pixels where ``mask01`` is nonzero toward pure green (BGR (0, 255, 0)).

    Masked pixels become ``(1 - alpha) * pixel + alpha * green``; other pixels are
    unchanged. The result is truncated back to uint8 and the input is not modified.
    """
    blended = bgr.copy().astype(np.float32)
    selected = mask01.astype(bool)
    green_bgr = np.array([0, 255, 0], dtype=np.float32)
    blended[selected] = blended[selected] * (1 - alpha) + green_bgr * alpha
    return blended.astype(np.uint8)

def metrics_np(pred01, gt01, eps=1e-9):
    """Overlap metrics between binary prediction and ground-truth masks.

    Pixels equal to 1 are foreground and pixels equal to 0 are background;
    ``eps`` avoids division by zero.

    Returns:
        (dice, iou, precision, recall) as Python floats, followed by the raw
        (tp, fp, fn) pixel counts as Python ints.
    """
    fg_pred, bg_pred = pred01 == 1, pred01 == 0
    fg_gt, bg_gt = gt01 == 1, gt01 == 0
    tp = (fg_pred & fg_gt).sum()
    fp = (fg_pred & bg_gt).sum()
    fn = (bg_pred & fg_gt).sum()
    scores = (
        (2 * tp) / (2 * tp + fp + fn + eps),  # dice
        tp / (tp + fp + fn + eps),            # iou
        tp / (tp + fp + eps),                 # precision
        tp / (tp + fn + eps),                 # recall
    )
    return (*map(float, scores), int(tp), int(fp), int(fn))

# ---- pair: X.png with X_gt.png ----
all_png = sorted(glob.glob(os.path.join(C10_DIR, "*.png")))
gt_paths  = [p for p in all_png if os.path.basename(p).lower().endswith("_gt.png")]
raw_paths = [p for p in all_png if not os.path.basename(p).lower().endswith("_gt.png")]
raw_map = {os.path.splitext(os.path.basename(p))[0]: p for p in raw_paths}

pairs = []
missing = []
for gp in gt_paths:
    # Strip only the trailing "_gt" tag, in any case, to match the case-insensitive
    # filter above. str.replace("_gt", "") would also delete "_gt" occurring inside
    # the base name, and would leave "_GT" suffixes in place.
    stem = os.path.splitext(os.path.basename(gp))[0]
    base = stem[:-len("_gt")]
    rp = raw_map.get(base, None)
    if rp is None:
        missing.append(os.path.basename(gp))
    else:
        pairs.append((rp, gp))

pairs = sorted(pairs, key=lambda x: os.path.basename(x[0]).lower())

print("raw:", len(raw_paths), "| gt:", len(gt_paths), "| paired:", len(pairs))
if missing:
    print("GT without matching raw (first 10):", missing[:10])

assert len(pairs) > 0, "No pairs found. Check naming: X.png and X_gt.png in the same folder."

# ---- load model ----
# NOTE(review): torch.load unpickles the checkpoint, which can execute arbitrary
# code. The notebook fetches its checkpoint from a Box share link, so only load
# files you trust (PyTorch >= 1.13 offers weights_only=True for plain state dicts).
print("Loading model:", MODEL_PATH)
model = smp.Unet(encoder_name="resnet18", encoder_weights=None, in_channels=3, classes=1)
checkpoint_state = torch.load(MODEL_PATH, map_location=DEVICE)
model.load_state_dict(checkpoint_state)
del checkpoint_state
model = model.to(DEVICE)  # Module.to returns the same module
model.eval()

# ---- run evaluation ----
# For each pair: resize the raw image to SHORT_SIDE, binarise the red-circle GT and
# resize it to match (nearest-neighbour keeps it 0/1), predict, threshold, score.
rows = []
for rp, gp in pairs:
    base = os.path.splitext(os.path.basename(rp))[0]

    bgr = cv2.imread(rp)
    if bgr is None:
        print("Skip unreadable raw:", rp)
        continue
    bgr_rs = resize_short(bgr)

    # GT -> binary, resize to match
    gt_bgr = cv2.imread(gp)
    if gt_bgr is None:
        print("Skip unreadable GT:", gp)
        continue
    gt01 = red_to_binary(gt_bgr)
    gt01_rs = cv2.resize(gt01, (bgr_rs.shape[1], bgr_rs.shape[0]), interpolation=cv2.INTER_NEAREST).astype(np.uint8)

    # inference prob
    rgb_rs = cv2.cvtColor(bgr_rs, cv2.COLOR_BGR2RGB)
    ten = ToTensorV2()(image=rgb_rs)["image"].unsqueeze(0).to(DEVICE).float()
    with torch.no_grad():
        prob = torch.sigmoid(model(ten))[0,0].cpu().numpy()

    pred01 = (prob > THR_EVAL).astype(np.uint8)

    dice, iou, prec, rec, tp, fp, fn = metrics_np(pred01, gt01_rs)
    rows.append([base, dice, iou, prec, rec, tp, fp, fn])

assert rows, "No image pair could be evaluated (all raw/GT files unreadable?)."

metric_cols = ["dice", "iou", "precision", "recall"]
df = pd.DataFrame(rows, columns=["image", *metric_cols, "tp", "fp", "fn"]).sort_values("image")
print("\nPer-image metrics (first 5):")
print(df.head())

# Summary: mean ± sample SD (pandas .std() uses ddof=1) across images.
mean_vals = df[metric_cols].mean()
std_vals  = df[metric_cols].std()

# Header and file name are derived from THR_EVAL so they cannot drift from the
# threshold actually used (for 0.50 the file name is unchanged).
print(f"\n=== C10 TEST SUMMARY (thr={THR_EVAL:.2f} locked) ===")
for col, label in [("dice", "Dice"), ("iou", "IoU"), ("precision", "Precision"), ("recall", "Recall")]:
    print(f"Mean {label:<9} = {mean_vals[col]:.3f} ± {std_vals[col]:.3f}")

# Save table
csv_out = os.path.join(OUT_DIR, f"c10_test_metrics_thr_{THR_EVAL:.2f}.csv")
df.to_csv(csv_out, index=False)
print("\nSaved CSV:", csv_out)

# ---- montage: random examples ----
# Re-seed so the same pairs are drawn on every run regardless of earlier RNG use.
N = min(N_MONTAGE, len(pairs))
random.seed(SEED)
pick = random.sample(pairs, k=N)

cols = 3  # raw | GT overlay | pred overlay
rowsN = N
plt.figure(figsize=(cols*4.2, rowsN*4.2))

for r, (rp, gp) in enumerate(pick):
    base = os.path.splitext(os.path.basename(rp))[0]

    bgr = cv2.imread(rp)
    gt_bgr = cv2.imread(gp)
    if bgr is None or gt_bgr is None:
        # Same rule as the evaluation loop: skip unreadable files (the row is left
        # blank) instead of crashing inside resize_short / red_to_binary.
        print("Skip unreadable pair:", rp, gp)
        continue
    bgr_rs = resize_short(bgr)
    gt01 = red_to_binary(gt_bgr)
    gt01_rs = cv2.resize(gt01, (bgr_rs.shape[1], bgr_rs.shape[0]), interpolation=cv2.INTER_NEAREST).astype(np.uint8)

    rgb_rs = cv2.cvtColor(bgr_rs, cv2.COLOR_BGR2RGB)
    ten = ToTensorV2()(image=rgb_rs)["image"].unsqueeze(0).to(DEVICE).float()
    with torch.no_grad():
        prob = torch.sigmoid(model(ten))[0,0].cpu().numpy()
    pred01 = (prob > THR_EVAL).astype(np.uint8)

    gt_overlay_rgb   = cv2.cvtColor(overlay_green(bgr_rs, gt01_rs, alpha=ALPHA), cv2.COLOR_BGR2RGB)
    pred_overlay_rgb = cv2.cvtColor(overlay_green(bgr_rs, pred01, alpha=ALPHA), cv2.COLOR_BGR2RGB)

    ax = plt.subplot(rowsN, cols, r*cols + 1)
    ax.imshow(rgb_rs); ax.set_title(f"{base}\nRaw"); ax.axis("off")

    ax = plt.subplot(rowsN, cols, r*cols + 2)
    ax.imshow(gt_overlay_rgb); ax.set_title("GT mask overlay"); ax.axis("off")

    ax = plt.subplot(rowsN, cols, r*cols + 3)
    ax.imshow(pred_overlay_rgb); ax.set_title(f"Pred (thr={THR_EVAL:.2f})"); ax.axis("off")

plt.tight_layout()
mont_path = os.path.join(OUT_DIR, f"c10_montage_thr_{THR_EVAL:.2f}.png")
plt.savefig(mont_path, dpi=300, bbox_inches="tight")
plt.show()
print("Saved montage:", mont_path)


In [None]:
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

OUT_DIR = "data/raw_images"
# Per-image table written by the C10 evaluation cell (same OUT_DIR and file name).
# read_csv needs the CSV file itself; a folder path cannot be read.
CSV_IN  = os.path.join(OUT_DIR, "c10_test_metrics_thr_0.50.csv")
os.makedirs(OUT_DIR, exist_ok=True)

df = pd.read_csv(CSV_IN).sort_values("image").reset_index(drop=True)
n_images = len(df)

metrics = ["dice", "iou", "precision", "recall"]
titles  = ["Dice", "IoU (Jaccard)", "Precision", "Recall"]

fig = plt.figure(figsize=(10.5, 7.5))

# One panel per metric: per-image values, the mean (dashed) and a ±1 sample-SD band.
for i, (m, t) in enumerate(zip(metrics, titles), start=1):
    ax = plt.subplot(2, 2, i)
    y = df[m].values
    x = np.arange(len(y))

    mu = float(np.mean(y))
    sd = float(np.std(y, ddof=1)) if len(y) > 1 else 0.0

    ax.scatter(x, y)
    ax.axhline(mu, linestyle="--", label=f"Mean = {mu:.3f}")
    ax.axhspan(mu - sd, mu + sd, alpha=0.2, label=f"± SD = {sd:.3f}")

    ax.set_title(t)
    ax.set_xlabel("Test image")
    ax.set_ylabel("Metric")
    ax.set_xticks(x)
    ax.set_xticklabels(df["image"].values, rotation=60, ha="right", fontsize=8)
    ax.set_ylim(0.0, 1.02)
    ax.grid(True, axis="y", alpha=0.3)

    if i == 1:
        ax.legend(loc="lower right")

# n comes from the table rather than being hard-coded, so the title stays correct
# if the test set changes.
plt.suptitle(f"Held-out test performance (threshold locked at 0.50; n={n_images} images)", y=0.99)
plt.tight_layout()

png_out = os.path.join(OUT_DIR, "c10_test_metrics_per_image.png")
pdf_out = os.path.join(OUT_DIR, "c10_test_metrics_per_image.pdf")
plt.savefig(png_out, dpi=300, bbox_inches="tight")
plt.savefig(pdf_out, bbox_inches="tight")
plt.show()

print("Saved:", png_out)
print("Saved:", pdf_out)


In [None]:
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

OUT_DIR = "data/raw_images"
# Per-image table written by the C10 evaluation cell (same OUT_DIR and file name).
CSV_IN  = os.path.join(OUT_DIR, "c10_test_metrics_thr_0.50.csv")
os.makedirs(OUT_DIR, exist_ok=True)

df = pd.read_csv(CSV_IN).sort_values("image").reset_index(drop=True)
x = np.arange(len(df))
iou_mean = df["iou"].mean()
dice_mean = df["dice"].mean()

plt.figure(figsize=(9.5, 4.8))
# Pin colours so each dashed mean line matches its scatter series. axhline does not
# advance the colour cycle, so without this both mean lines get the same default colour.
plt.scatter(x, df["iou"], color="C0", label="IoU")
plt.scatter(x, df["dice"], color="C1", label="Dice")

plt.axhline(iou_mean, color="C0", linestyle="--", label=f"Mean IoU = {iou_mean:.3f}")
plt.axhline(dice_mean, color="C1", linestyle="--", label=f"Mean Dice = {dice_mean:.3f}")

plt.xticks(x, df["image"], rotation=60, ha="right", fontsize=8)
plt.ylim(0.0, 1.02)
plt.ylabel("Metric")
plt.title(f"Held-out test overlap metrics (thr=0.50; n={len(df)})")
plt.grid(True, axis="y", alpha=0.3)
plt.legend()
plt.tight_layout()

out = os.path.join(OUT_DIR, "c10_test_iou_dice.png")
plt.savefig(out, dpi=300, bbox_inches="tight")
plt.show()
print("Saved:", out)


In [None]:
# === Storyboards (finalexp folders): C1→AI pillar mask + C2 precipitation @ fixed thr=70 ===
import os, glob, re, random, cv2
import numpy as np
import matplotlib.pyplot as plt
import torch
from albumentations.pytorch import ToTensorV2
import segmentation_models_pytorch as smp

# ---------------- Paths ----------------
FINAL_ROOT = "data/raw_images"
FOLDERS = ["t1t2", "t3t4", "t5t6", "s2s3", "d"]

# U-Net checkpoint *file*: torch.load cannot open a directory. This is the path the
# download cell at the top of the notebook writes to.
MODEL_PATH = os.path.join("models", "unet_pillars_finetuned.pt")
if not os.path.isfile(MODEL_PATH):
    raise FileNotFoundError(f"Model checkpoint not found: {MODEL_PATH} (run the download cell first)")

# ---------------- Pillar mask (C1) ----------------
THR_PILLAR = 0.50
SHORT_SIDE = 768

# ---------------- Imaging scale ----------------
FIELD_UM   = 1331.2                 # FIELD_UM micrometres span IMG_PX pixels
IMG_PX     = 2048
um_per_px  = FIELD_UM / IMG_PX

# ---------------- Scale bar style ----------------
SCALE_UM      = 500
BAR_THICK_PX  = 35
MARGIN_PX     = 100
FONT          = cv2.FONT_HERSHEY_SIMPLEX
FONT_SCALE    = 4
FONT_THICK    = 13

# ---------------- C2 precipitation thresholding ----------------
C2_THRESH_MODE = "fixed"   # not read in this cell; segmentation below always uses C2_FIXED
C2_FIXED = 70   # <<< final chosen threshold

# ---------------- Misc ----------------
IMG_EXTS = (".png", ".tif", ".tiff", ".jpg", ".jpeg")
SEED = 2026
random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED)

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("DEVICE:", DEVICE)

# ---------------- Helpers ----------------
def is_img(p):
    """True when path ``p`` ends in one of IMG_EXTS (extension compared case-insensitively)."""
    extension = os.path.splitext(p)[1]
    return extension.lower() in IMG_EXTS

def resize_short(im, short=SHORT_SIDE):
    """Downscale/upscale ``im`` so that min(height, width) == ``short`` (aspect ratio kept, INTER_AREA)."""
    rows, cols = im.shape[:2]
    factor = short / min(rows, cols)
    target_w = int(round(cols * factor))
    target_h = int(round(rows * factor))
    return cv2.resize(im, (target_w, target_h), interpolation=cv2.INTER_AREA)

def add_scale_bar(gray_img, um_per_px, scale_um=SCALE_UM,
                  bar_thick_px=BAR_THICK_PX, margin_px=MARGIN_PX):
    """Return an RGB copy of a grayscale image with a labelled scale bar at the bottom right.

    The bar is ``scale_um / um_per_px`` pixels long, clipped so it fits inside the
    margins. It is drawn white over a black rectangle 2 px larger on every side,
    and the label "<scale_um> um" is centred above it as white text with a black outline.
    """
    canvas = cv2.cvtColor(gray_img, cv2.COLOR_GRAY2RGB)
    img_h, img_w = gray_img.shape[:2]

    bar_len = min(int(round(scale_um / um_per_px)), img_w - 2 * margin_px - 1)
    right, bottom = img_w - margin_px, img_h - margin_px
    left, top = right - bar_len, bottom - bar_thick_px

    pad = 2  # outline width around the bar
    cv2.rectangle(canvas, (left - pad, top - pad), (right + pad, bottom + pad), (0, 0, 0), -1)
    cv2.rectangle(canvas, (left, top), (right, bottom), (255, 255, 255), -1)

    text = f"{scale_um:.0f} um"
    (text_w, _text_h), _baseline = cv2.getTextSize(text, FONT, FONT_SCALE, FONT_THICK)
    origin = (left + (bar_len - text_w) // 2, top - 8)  # text baseline 8 px above the bar
    # A thicker black pass first, then white on top, gives an outlined label.
    for colour, thickness in (((0, 0, 0), FONT_THICK + 2), ((255, 255, 255), FONT_THICK)):
        cv2.putText(canvas, text, origin, FONT, FONT_SCALE, colour, thickness, cv2.LINE_AA)
    return canvas

def to_u8(img):
    """Return ``img`` as uint8.

    uint8 input is returned as-is (the same object, not a copy). Any other dtype is
    min-max stretched to 0..255 with cv2.normalize and then cast to uint8.
    """
    if img.dtype != np.uint8:
        stretched = cv2.normalize(img, None, 0, 255, cv2.NORM_MINMAX)
        return stretched.astype(np.uint8)
    return img

def binarize_c2_fixed(C2_raw, nonpillar_mask01, thr_val=C2_FIXED):
    """Binary crystal mask with a fixed threshold (conservative), restricted to the non-pillar ROI.

    Args:
        C2_raw: single-channel C2 image; converted with to_u8.
        nonpillar_mask01: mask with the same (H, W); nonzero where pixels are pore space.
        thr_val: intensity threshold applied after a 3x3 Gaussian blur.

    Returns:
        uint8 mask, 255 for precipitate and 0 elsewhere (always 0 on pillars),
        cleaned with a 3x3 opening followed by a 3x3 closing.

    ``C2_raw`` itself is never modified.
    """
    # Copy: to_u8 hands back uint8 input as the same object, so zeroing pillar pixels
    # in place would also black out the caller's C2_raw, which the storyboard later
    # displays as the "raw" C2 panel.
    img = to_u8(C2_raw).copy()
    img[nonpillar_mask01 == 0] = 0
    blur = cv2.GaussianBlur(img, (3,3), 0)
    _, thr = cv2.threshold(blur, int(thr_val), 255, cv2.THRESH_BINARY)
    thr[nonpillar_mask01 == 0] = 0
    k = np.ones((3,3), np.uint8)
    thr = cv2.morphologyEx(thr, cv2.MORPH_OPEN,  k, iterations=1)
    thr = cv2.morphologyEx(thr, cv2.MORPH_CLOSE, k, iterations=1)
    return thr

def base_key_from_c2(fname):
    """Strip the extension and a trailing ".c2"/"_c2" tag (any case) from a file name.

    Returns the remaining base key, or None when the stem does not end in that tag,
    e.g. "d6.f3.c2.png" -> "d6.f3", "x_C2.tif" -> "x", "x.png" -> None.
    """
    stem, _ext = os.path.splitext(fname)
    found = re.match(r"^(.*?)(?:[._]c2)$", stem, flags=re.IGNORECASE)
    if found is None:
        return None
    return found.group(1)

def find_pair_in_folder(folder_dir):
    """Pick one random C2 image in ``folder_dir`` and look up its matching C1 image.

    C2 files are images whose stem ends in ".c2" or "_c2" (any case). The C1 partner
    is "<base>.c1<ext>" or "<base>_c1<ext>" in the same folder, searched in IMG_EXTS
    order, with "." tried before "_" for each extension.

    Returns:
        None if there is no C2 image (or its base key cannot be parsed), otherwise
        ``(base, c1_path_or_None, c2_path)``.
    """
    c2_candidates = sorted(
        path
        for path in glob.glob(os.path.join(folder_dir, "*"))
        if is_img(path)
        and re.search(r"(?:[._]c2)$", os.path.splitext(os.path.basename(path))[0], flags=re.IGNORECASE)
    )
    if not c2_candidates:
        return None

    c2_path = random.choice(c2_candidates)
    base = base_key_from_c2(os.path.basename(c2_path))
    if base is None:
        return None

    # The first existing "<base><sep>c1<ext>" wins.
    c1_path = next(
        (
            candidate
            for ext in IMG_EXTS
            for sep in (".", "_")
            for candidate in (os.path.join(folder_dir, f"{base}{sep}c1{ext}"),)
            if os.path.isfile(candidate)
        ),
        None,
    )
    return (base, c1_path, c2_path)

# ---------------- Load model ----------------
# NOTE(review): torch.load unpickles the checkpoint (arbitrary code execution risk).
# The notebook downloads its checkpoint from a Box share link; only load trusted files.
print("Loading model:", MODEL_PATH)
model = smp.Unet(encoder_name="resnet18", encoder_weights=None, in_channels=3, classes=1)
weights = torch.load(MODEL_PATH, map_location=DEVICE)
model.load_state_dict(weights)
del weights
model = model.to(DEVICE)
model.eval()

def infer_pillar_mask_from_c1(c1_path):
    """Predict the pillar mask for one C1 brightfield image.

    The image is resized so its short side is SHORT_SIDE and run through the global
    U-Net ``model``. The sigmoid map is resized back to full resolution (bilinear),
    thresholded at THR_PILLAR, and cleaned with a 5x5 opening followed by closing.

    Returns:
        (C1_gray_u8, pillar_mask01): full-resolution grayscale image and a uint8
        0/1 mask of the same (H, W), where 1 = pillar.

    Raises:
        FileNotFoundError: if cv2.imread cannot read ``c1_path``.
    """
    bgr = cv2.imread(c1_path)
    if bgr is None:
        raise FileNotFoundError(f"Unreadable C1: {c1_path}")

    if bgr.ndim == 3:
        C1_gray = cv2.cvtColor(bgr, cv2.COLOR_BGR2GRAY)
        rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    else:
        # Not reached with cv2.imread's default colour flag (always 3 channels);
        # kept for single-channel input.
        C1_gray = bgr
        rgb = np.stack([bgr] * 3, axis=-1)

    full_h, full_w = C1_gray.shape[:2]
    batch = ToTensorV2()(image=resize_short(rgb, short=SHORT_SIDE))["image"]
    batch = batch.unsqueeze(0).to(DEVICE).float()
    with torch.no_grad():
        prob_small = torch.sigmoid(model(batch))[0, 0].cpu().numpy()

    prob_full = cv2.resize(prob_small, (full_w, full_h), interpolation=cv2.INTER_LINEAR)
    mask255 = (prob_full > THR_PILLAR).astype(np.uint8) * 255

    # tidy edges a bit
    kernel = np.ones((5, 5), np.uint8)
    for op in (cv2.MORPH_OPEN, cv2.MORPH_CLOSE):
        mask255 = cv2.morphologyEx(mask255, op, kernel, iterations=1)

    return C1_gray.astype(np.uint8), (mask255 > 0).astype(np.uint8)

# ---------------- Run storyboards ----------------
# One randomly chosen C1/C2 pair per folder, shown as five panels:
# C1 raw | AI pillar mask | C2 raw | C2 precipitate overlay | composite with area fractions.
for folder in FOLDERS:
    folder_dir = os.path.join(FINAL_ROOT, folder)
    if not os.path.isdir(folder_dir):
        print("Skip missing:", folder_dir)
        continue

    picked = find_pair_in_folder(folder_dir)
    if picked is None:
        print(f"[{folder}] No C2 files found (need *c2.png).")
        continue

    base_key, c1_path, c2_path = picked
    if c1_path is None:
        print(f"[{folder}] Found C2 but missing matching C1 for base '{base_key}'.")
        continue

    print(f"\nStoryboard: {folder} | {base_key}")

    # Load + infer pillars
    C1_raw, C1_mask01 = infer_pillar_mask_from_c1(c1_path)

    # Load C2 and bring it onto the C1 pixel grid
    C2_raw = cv2.imread(c2_path, cv2.IMREAD_GRAYSCALE)
    if C2_raw is None:
        print("⚠️ Skip unreadable C2:", c2_path)
        continue
    if C2_raw.shape != C1_raw.shape:
        C2_raw = cv2.resize(C2_raw, (C1_raw.shape[1], C1_raw.shape[0]), interpolation=cv2.INTER_AREA)

    # Metrics: pillar vs non-pillar (percent of all pixels)
    total_pixels     = C1_mask01.size
    pillar_pixels    = int(C1_mask01.sum())
    nonpillar_pixels = total_pixels - pillar_pixels
    pillar_pct       = 100.0 * pillar_pixels / total_pixels
    nonpillar_pct    = 100.0 * nonpillar_pixels / total_pixels

    nonpillar01 = (C1_mask01 == 0).astype(np.uint8)

    # C2 processing (fixed threshold)
    precip = binarize_c2_fixed(C2_raw, nonpillar01, thr_val=C2_FIXED)

    # Crystal metrics: share of the whole image and share of the pore space
    crystal_pixels           = int((precip > 0).sum())
    crystal_pct              = 100.0 * crystal_pixels / total_pixels
    crystal_in_nonpillar_pct = 100.0 * crystal_pixels / max(1, nonpillar_pixels)

    # Overlays (processed + outlines)
    proc_rgb = cv2.cvtColor(to_u8(C2_raw), cv2.COLOR_GRAY2RGB)
    proc_rgb[precip > 0] = (255,0,0)        # crystals red
    proc_rgb[C1_mask01 > 0] = (0,0,0)       # pillars black

    overlay = np.zeros_like(proc_rgb)
    contours, _ = cv2.findContours(C1_mask01.astype(np.uint8),
                                   cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    cv2.drawContours(overlay, contours, -1, (255,255,0), 2)  # outlines yellow
    proc_with_outline = cv2.addWeighted(proc_rgb, 1.0, overlay, 1.0, 0)

    # Composite: pillars black, pore space white, crystals red
    final_rgb = np.zeros_like(proc_rgb)
    final_rgb[C1_mask01 > 0] = (0,0,0)
    final_rgb[nonpillar01 > 0] = (255,255,255)
    final_rgb[precip > 0] = (255,0,0)

    # Add scale bars on raw panels
    C1_with_bar = add_scale_bar(to_u8(C1_raw), um_per_px, SCALE_UM)
    C2_with_bar = add_scale_bar(to_u8(C2_raw), um_per_px, SCALE_UM)

    # Plot storyboard
    fig, axes = plt.subplots(1, 5, figsize=(22, 5))
    axes[0].imshow(C1_with_bar); axes[0].set_title("C1 (raw + scale bar)"); axes[0].axis("off")
    axes[1].imshow(C1_mask01, cmap="gray"); axes[1].set_title(f"Pillar mask (AI)\nNon-pillar: {nonpillar_pct:.1f}%"); axes[1].axis("off")
    axes[2].imshow(C2_with_bar); axes[2].set_title("C2 TL-POL (raw + scale bar)"); axes[2].axis("off")
    axes[3].imshow(proc_with_outline); axes[3].set_title(f"C2 processed (thr={C2_FIXED})\n+ pillar outlines"); axes[3].axis("off")
    axes[4].imshow(final_rgb); axes[4].set_title(
        f"Composite\nPillar: {pillar_pct:.1f}% | Non-pillar: {nonpillar_pct:.1f}%\n"
        f"Crystal: {crystal_pct:.1f}% | Crystal/Non-pillar: {crystal_in_nonpillar_pct:.1f}%"
    ); axes[4].axis("off")

    plt.suptitle(f"{folder} | {base_key} | Pillar thr={THR_PILLAR:.2f} | C2 fixed thr={C2_FIXED}",
                 fontsize=14, fontweight="bold")
    plt.tight_layout()
    plt.show()
    plt.close(fig)  # release the figure; one is created per folder


In [None]:
import os, glob, re, random, cv2
import numpy as np
import matplotlib.pyplot as plt
import torch
from albumentations.pytorch import ToTensorV2
import segmentation_models_pytorch as smp

# ================== CONFIG ==================
FINAL_ROOT = "data/raw_images"
FOLDERS = ["t1t2", "t3t4", "t5t6", "s2s3", "d"]

# U-Net checkpoint *file*: torch.load cannot open a directory. This is the path the
# download cell at the top of the notebook writes to.
MODEL_PATH = os.path.join("models", "unet_pillars_finetuned.pt")
THR_PILLAR = 0.50           # locked pillar threshold (from your calibration)
SHORT_SIDE = 768

# >>> YOU TUNE THESE <<<
THR_LIST = [50, 60, 70, 80]   # C2 thresholds to preview (lower = more sensitive)
N_PER_FOLDER = 1                # how many random examples per folder (increase to 2 for nicer SI)
SEED = 2027

# Optional pre-normalization of C2 to increase sensitivity/consistency
# "none" or "clahe"
C2_PRENORM = "none"

# Output (small figure only)
OUT_DIR = "data/raw_images"
OUT_NAME = "c2_threshold_tuning_montage.png"
# ============================================

if not os.path.isfile(MODEL_PATH):
    raise FileNotFoundError(f"Model checkpoint not found: {MODEL_PATH} (run the download cell first)")

os.makedirs(OUT_DIR, exist_ok=True)
random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED)
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("DEVICE:", DEVICE)

IMG_EXTS = (".png", ".tif", ".tiff", ".jpg", ".jpeg")


def is_img(p):
    """True if path ``p`` has one of the IMG_EXTS extensions (compared case-insensitively)."""
    _root, ext = os.path.splitext(p)
    return ext.lower() in IMG_EXTS

def resize_short(im, short=SHORT_SIDE):
    """Resize ``im`` so that its shorter side equals ``short`` px, preserving aspect ratio (INTER_AREA)."""
    h_px, w_px = im.shape[:2]
    ratio = short / min(h_px, w_px)
    dsize = (int(round(w_px * ratio)), int(round(h_px * ratio)))
    return cv2.resize(im, dsize, interpolation=cv2.INTER_AREA)

def base_key_from_c2(fname):
    """Return ``fname`` without its extension and trailing ".c2"/"_c2" tag (any case), or None."""
    match = re.match(r"^(.*?)(?:[._]c2)$", os.path.splitext(fname)[0], flags=re.IGNORECASE)
    return None if match is None else match.group(1)

def find_pair_in_folder(folder_dir):
    """Randomly choose a C2 image (X.c2.* or X_c2.*) in ``folder_dir`` and find its C1 partner.

    The C1 partner is "<base>.c1<ext>" or "<base>_c1<ext>" in the same folder; extensions
    are tried in IMG_EXTS order and "." before "_" for each.

    Returns:
        None when there is no C2 image or its base key cannot be parsed; otherwise
        ``(base, c1_path_or_None, c2_path)``.
    """
    c2_pattern = r"(?:[._]c2)$"
    c2_files = sorted(
        p for p in glob.glob(os.path.join(folder_dir, "*"))
        if is_img(p) and re.search(c2_pattern, os.path.splitext(os.path.basename(p))[0], flags=re.IGNORECASE)
    )
    if not c2_files:
        return None

    c2_path = random.choice(c2_files)
    base = base_key_from_c2(os.path.basename(c2_path))
    if base is None:
        return None

    search_order = [(ext, sep) for ext in IMG_EXTS for sep in (".", "_")]
    c1_path = None
    for ext, sep in search_order:
        candidate = os.path.join(folder_dir, f"{base}{sep}c1{ext}")
        if os.path.isfile(candidate):
            c1_path = candidate
            break

    return (base, c1_path, c2_path)

def to_u8(img):
    """Pass uint8 arrays through unchanged (same object); min-max stretch anything else to uint8 0..255."""
    if img.dtype == np.uint8:
        return img
    scaled = cv2.normalize(img, None, 0, 255, cv2.NORM_MINMAX)
    return scaled.astype(np.uint8)

def prenorm_c2(img_u8, mode="none"):
    """Optionally contrast-normalise a uint8 C2 image before thresholding.

    mode="clahe" applies CLAHE (clipLimit=2.0, 8x8 tiles). "none" and any
    unrecognised mode return the input object unchanged.
    """
    if mode != "clahe":
        return img_u8
    clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    return clahe.apply(img_u8)

def binarize_c2_fixed(C2_raw_u8, nonpillar01, thr_val):
    """Fixed-threshold C2 segmentation inside the non-pillar ROI only.

    Pillar pixels (nonpillar01 == 0) are zeroed on a copy of the input. The copy is
    blurred (3x3 Gaussian) and thresholded at ``thr_val``, pillars are forced to 0
    again, and the result is cleaned with a 3x3 opening then closing.
    Returns a 0/255 mask with the input's dtype; ``C2_raw_u8`` is not modified.
    """
    outside_roi = nonpillar01 == 0
    roi_img = C2_raw_u8.copy()
    roi_img[outside_roi] = 0
    smoothed = cv2.GaussianBlur(roi_img, (3, 3), 0)
    _, binary = cv2.threshold(smoothed, int(thr_val), 255, cv2.THRESH_BINARY)
    binary[outside_roi] = 0
    kernel = np.ones((3, 3), np.uint8)
    for op in (cv2.MORPH_OPEN, cv2.MORPH_CLOSE):
        binary = cv2.morphologyEx(binary, op, kernel, iterations=1)
    return binary

def overlay_precip_on_c2(C2_u8, pillar01, precip255):
    """Return RGB: pillars black, precip red, background grayscale.

    Precipitate is painted after the pillars, so red wins where both masks are set.
    """
    black, red = (0, 0, 0), (255, 0, 0)
    canvas = cv2.cvtColor(C2_u8, cv2.COLOR_GRAY2RGB)
    canvas[pillar01 > 0] = black
    canvas[precip255 > 0] = red
    return canvas

# ---- load model once ----
# NOTE(review): torch.load unpickles the checkpoint (arbitrary code execution risk);
# only load checkpoints from a trusted source.
print("Loading model:", MODEL_PATH)
model = smp.Unet(encoder_name="resnet18", encoder_weights=None, in_channels=3, classes=1)
state_dict = torch.load(MODEL_PATH, map_location=DEVICE)
model.load_state_dict(state_dict)
del state_dict
model = model.to(DEVICE)
model.eval()

def infer_pillar_mask_from_c1(c1_path):
    """Run the U-Net on a C1 image and return ``(gray_u8, pillar_mask01)`` at full resolution.

    Inference runs at SHORT_SIDE resolution. The probability map is resized back with
    bilinear interpolation, thresholded at THR_PILLAR, and lightly cleaned with a
    5x5 opening then closing. Raises FileNotFoundError if the image cannot be read.
    """
    bgr = cv2.imread(c1_path)
    if bgr is None:
        raise FileNotFoundError(f"Unreadable C1: {c1_path}")

    is_color = bgr.ndim == 3
    gray = cv2.cvtColor(bgr, cv2.COLOR_BGR2GRAY) if is_color else bgr
    rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB) if is_color else np.stack([bgr, bgr, bgr], axis=-1)

    out_h, out_w = gray.shape[:2]
    tensor = ToTensorV2()(image=resize_short(rgb, short=SHORT_SIDE))["image"]
    with torch.no_grad():
        logits = model(tensor.unsqueeze(0).to(DEVICE).float())
        prob_small = torch.sigmoid(logits)[0, 0].cpu().numpy()

    prob_full = cv2.resize(prob_small, (out_w, out_h), interpolation=cv2.INTER_LINEAR)
    binary = (prob_full > THR_PILLAR).astype(np.uint8) * 255

    # light cleanup
    kernel = np.ones((5, 5), np.uint8)
    binary = cv2.morphologyEx(binary, cv2.MORPH_OPEN, kernel, iterations=1)
    binary = cv2.morphologyEx(binary, cv2.MORPH_CLOSE, kernel, iterations=1)

    return gray.astype(np.uint8), (binary > 0).astype(np.uint8)

# ---- pick examples (folder-balanced) ----
# Draw up to N_PER_FOLDER *distinct* C1/C2 pairs per folder. find_pair_in_folder
# samples with replacement, so repeated draws of the same C2 file are skipped. Draws
# whose C2 has no C1 partner are retried, up to MAX_TRIES attempts per folder.
MAX_TRIES = 50
examples = []
for folder in FOLDERS:
    folder_dir = os.path.join(FINAL_ROOT, folder)
    if not os.path.isdir(folder_dir):
        print("Missing folder:", folder_dir)
        continue

    picked_c2 = set()
    tries = 0
    while len(picked_c2) < N_PER_FOLDER and tries < MAX_TRIES:
        tries += 1
        out = find_pair_in_folder(folder_dir)
        if out is None:
            break
        base, c1_path, c2_path = out
        if c1_path is None or c2_path in picked_c2:
            continue
        examples.append((folder, base, c1_path, c2_path))
        picked_c2.add(c2_path)

    if len(picked_c2) < N_PER_FOLDER:
        print(f"[{folder}] found {len(picked_c2)} of {N_PER_FOLDER} requested examples")

print("Examples:", len(examples))
for e in examples:
    print(e[0], os.path.basename(e[3]))

# ---- build montage ----
# One row per example: the (optionally pre-normalised) C2 image, then one overlay per
# threshold in THR_LIST (pillars black, precipitate red).
if not examples:
    # Avoid building a zero-row figure when no C1/C2 pair was found.
    print("No C1/C2 example pairs found; nothing to plot. Check FINAL_ROOT / FOLDERS.")
else:
    cols = 1 + len(THR_LIST)          # raw + each threshold
    rowsN = len(examples)
    plt.figure(figsize=(cols*4.0, rowsN*4.0))

    for r, (folder, base, c1_path, c2_path) in enumerate(examples):
        # Read C2 first so an unreadable file is reported and skipped before the U-Net runs.
        C2 = cv2.imread(c2_path, cv2.IMREAD_GRAYSCALE)
        if C2 is None:
            print("Skip unreadable C2:", c2_path)
            continue

        # pillar mask
        C1_gray, pillar01 = infer_pillar_mask_from_c1(c1_path)
        nonpillar01 = (pillar01 == 0).astype(np.uint8)

        if C2.shape != C1_gray.shape:
            C2 = cv2.resize(C2, (C1_gray.shape[1], C1_gray.shape[0]), interpolation=cv2.INTER_AREA)

        C2_u8 = prenorm_c2(to_u8(C2), mode=C2_PRENORM)

        # col 1: raw C2
        ax = plt.subplot(rowsN, cols, r*cols + 1)
        ax.imshow(C2_u8, cmap="gray")
        ax.set_title(f"{folder} | {base}\nC2 raw")
        ax.axis("off")

        # threshold columns
        for j, thr_val in enumerate(THR_LIST):
            precip = binarize_c2_fixed(C2_u8, nonpillar01, thr_val)
            ov = overlay_precip_on_c2(C2_u8, pillar01, precip)

            ax = plt.subplot(rowsN, cols, r*cols + 2 + j)
            ax.imshow(ov)
            ax.set_title(f"thr={thr_val}")
            ax.axis("off")

    plt.suptitle(f"C2 threshold tuning (pillars via U-Net @ {THR_PILLAR:.2f}; C2_PRENORM={C2_PRENORM})",
                 fontsize=14, fontweight="bold")
    plt.tight_layout()

    out_path = os.path.join(OUT_DIR, OUT_NAME)
    plt.savefig(out_path, dpi=300, bbox_inches="tight")
    plt.show()

    print("Saved montage:", out_path)

print("Tip: adjust THR_LIST (lower = more sensitive) and re-run.")


In [None]:
import os, glob, re, random, cv2
import numpy as np
import matplotlib.pyplot as plt
import torch
from albumentations.pytorch import ToTensorV2
import segmentation_models_pytorch as smp

# ================== CONFIG ==================
D_DIR = "data/raw_images"

# U-Net checkpoint *file*: torch.load cannot open a directory. This is the path the
# download cell at the top of the notebook writes to.
MODEL_PATH = os.path.join("models", "unet_pillars_finetuned.pt")
THR_PILLAR = 0.50
SHORT_SIDE = 768

# C2 precipitation threshold (your final choice)
C2_FIXED = 70

# How many d# groups to show
N_GROUPS_TO_SHOW = 4
SEED = 2026

# Imaging scale for scale bar
FIELD_UM   = 1331.2                 # FIELD_UM micrometres span IMG_PX pixels
IMG_PX     = 2048
um_per_px  = FIELD_UM / IMG_PX

# Scale bar style
SCALE_UM      = 500
BAR_THICK_PX  = 35
MARGIN_PX     = 100
FONT          = cv2.FONT_HERSHEY_SIMPLEX
FONT_SCALE    = 4
FONT_THICK    = 13
# ============================================

if not os.path.isfile(MODEL_PATH):
    raise FileNotFoundError(f"Model checkpoint not found: {MODEL_PATH} (run the download cell first)")

random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED)
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("DEVICE:", DEVICE)

IMG_EXTS = (".png", ".tif", ".tiff", ".jpg", ".jpeg")


def is_img(p):
    """Return True when ``p``'s extension, lower-cased, is listed in IMG_EXTS."""
    suffix = os.path.splitext(p)[1].lower()
    return suffix in IMG_EXTS

def resize_short(im, short=SHORT_SIDE):
    """Scale ``im`` (INTER_AREA) so its smaller dimension is ``short`` px; aspect ratio is kept."""
    h, w = im.shape[:2]
    k = short / min(h, w)
    new_w, new_h = int(round(w * k)), int(round(h * k))
    return cv2.resize(im, (new_w, new_h), interpolation=cv2.INTER_AREA)

def add_scale_bar(gray_img, um_per_px, scale_um=SCALE_UM,
                  bar_thick_px=BAR_THICK_PX, margin_px=MARGIN_PX):
    """Convert a grayscale image to RGB and draw a labelled scale bar in the bottom-right corner.

    Bar length is ``scale_um / um_per_px`` px (clipped to fit within the margins).
    The bar is white on a black rectangle padded by 2 px, and the label
    "<scale_um> um" is centred 8 px above the bar as white text outlined in black.
    """
    out = cv2.cvtColor(gray_img, cv2.COLOR_GRAY2RGB)
    height, width = gray_img.shape[:2]

    length = int(round(scale_um / um_per_px))
    length = min(length, width - 2 * margin_px - 1)
    x_right = width - margin_px
    x_left = x_right - length
    y_bottom = height - margin_px
    y_top = y_bottom - bar_thick_px

    border = 2
    cv2.rectangle(out, (x_left - border, y_top - border),
                  (x_right + border, y_bottom + border), (0, 0, 0), -1)
    cv2.rectangle(out, (x_left, y_top), (x_right, y_bottom), (255, 255, 255), -1)

    caption = f"{scale_um:.0f} um"
    text_size, _ = cv2.getTextSize(caption, FONT, FONT_SCALE, FONT_THICK)
    anchor = (x_left + (length - text_size[0]) // 2, y_top - 8)
    cv2.putText(out, caption, anchor, FONT, FONT_SCALE, (0, 0, 0), FONT_THICK + 2, cv2.LINE_AA)
    cv2.putText(out, caption, anchor, FONT, FONT_SCALE, (255, 255, 255), FONT_THICK, cv2.LINE_AA)
    return out

def to_u8(img):
    """Ensure uint8: return uint8 input untouched (same object), otherwise min-max scale to 0..255."""
    return (
        img
        if img.dtype == np.uint8
        else cv2.normalize(img, None, 0, 255, cv2.NORM_MINMAX).astype(np.uint8)
    )

def binarize_c2_fixed(C2_raw, nonpillar_mask01, thr_val=C2_FIXED):
    """Segment C2 precipitate with a fixed threshold, restricted to the non-pillar ROI.

    Works on a uint8 copy (to_u8 + copy), so ``C2_raw`` is never modified. Pillar
    pixels are zeroed, then a 3x3 Gaussian blur and threshold at ``thr_val`` are
    applied. Pillars are zeroed again and the mask is cleaned by a 3x3 opening then closing.
    Returns a 0/255 uint8 mask.
    """
    on_pillar = nonpillar_mask01 == 0
    work = to_u8(C2_raw).copy()
    work[on_pillar] = 0
    _, mask = cv2.threshold(cv2.GaussianBlur(work, (3, 3), 0), int(thr_val), 255, cv2.THRESH_BINARY)
    mask[on_pillar] = 0
    se = np.ones((3, 3), np.uint8)
    mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, se, iterations=1)
    mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, se, iterations=1)
    return mask

# ---------- load model ----------
# NOTE(review): torch.load unpickles the checkpoint (arbitrary code execution risk);
# only load checkpoints from a trusted source.
print("Loading model:", MODEL_PATH)
model = smp.Unet(encoder_name="resnet18", encoder_weights=None, in_channels=3, classes=1)
loaded_state = torch.load(MODEL_PATH, map_location=DEVICE)
model.load_state_dict(loaded_state)
del loaded_state
model = model.to(DEVICE)
model.eval()

def infer_pillar_mask_from_c1(c1_path):
    """Predict a full-resolution pillar mask from a C1 image.

    Returns ``(C1_gray_u8, mask01)``, where mask01 is uint8 0/1 (1 = pillar) with
    the image's (H, W). The U-Net runs at SHORT_SIDE resolution. The probability
    map is resized back bilinearly, thresholded at THR_PILLAR, and tidied with a
    5x5 opening then closing. Raises FileNotFoundError if the file cannot be read.
    """
    bgr = cv2.imread(c1_path)
    if bgr is None:
        raise FileNotFoundError(f"Unreadable C1: {c1_path}")

    if bgr.ndim != 3:
        C1_gray = bgr
        rgb = np.stack([bgr, bgr, bgr], axis=-1)
    else:
        C1_gray = cv2.cvtColor(bgr, cv2.COLOR_BGR2GRAY)
        rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

    rows, cols = C1_gray.shape[:2]
    net_input = ToTensorV2()(image=resize_short(rgb, short=SHORT_SIDE))["image"].unsqueeze(0)
    net_input = net_input.to(DEVICE).float()

    with torch.no_grad():
        prob_small = torch.sigmoid(model(net_input))[0, 0].cpu().numpy()

    prob_full = cv2.resize(prob_small, (cols, rows), interpolation=cv2.INTER_LINEAR)
    pillar255 = (prob_full > THR_PILLAR).astype(np.uint8) * 255

    # tidy
    se = np.ones((5, 5), np.uint8)
    pillar255 = cv2.morphologyEx(pillar255, cv2.MORPH_OPEN, se, iterations=1)
    pillar255 = cv2.morphologyEx(pillar255, cv2.MORPH_CLOSE, se, iterations=1)

    return C1_gray.astype(np.uint8), (pillar255 > 0).astype(np.uint8)

# ---------- find all d*.f*.c2 / d*.p*.c2 ----------
# Stems look like "d6.f3.c2" or "d6.p2.c2" (matched case-insensitively).
pat = re.compile(r"^(d\d+)\.(f|p)(\d+)\.(c2)$", flags=re.IGNORECASE)

c2_paths = [
    p for p in glob.glob(os.path.join(D_DIR, "*"))
    if is_img(p) and pat.match(os.path.splitext(os.path.basename(p))[0])
]

if not c2_paths:
    raise FileNotFoundError(f"No d#.f#/p#.c2 images found in {D_DIR}")

# groups[d_id][kind] -> list of (index, path); kind is "f" or "p" (lower-cased).
groups = {}
for c2p in c2_paths:
    stem = os.path.splitext(os.path.basename(c2p))[0]
    m = pat.match(stem)
    d_id, fp, idx = m.group(1).lower(), m.group(2).lower(), int(m.group(3))
    groups.setdefault(d_id, {}).setdefault(fp, []).append((idx, c2p))

# Order each (d_id, kind) list by its numeric index.
for d_id, by_kind in groups.items():
    for fp in by_kind:
        by_kind[fp] = sorted(by_kind[fp], key=lambda t: t[0])

# Sort first so the seeded shuffle is deterministic, then keep the first N groups.
d_ids = sorted(groups)
random.shuffle(d_ids)
show_d = d_ids[:N_GROUPS_TO_SHOW]
print("Showing d groups:", show_d)

def c1_from_c2(c2_path):
    """Find the C1 brightfield image paired with a C2 image.

    The C2 stem (e.g. ``d6.f3.c2``) has its trailing ``.c2`` (any case) replaced
    by ``.c1``. Each extension in IMG_EXTS is then tried, in order, inside
    D_DIR (not the C2 file's own directory).

    Returns the first existing candidate path, or None when none exists.
    """
    c2_stem = os.path.splitext(os.path.basename(c2_path))[0]
    c1_stem = re.sub(r"\.c2$", ".c1", c2_stem, flags=re.IGNORECASE)
    candidates = (os.path.join(D_DIR, c1_stem + ext) for ext in IMG_EXTS)
    return next((cand for cand in candidates if os.path.isfile(cand)), None)

def storyboard_one(base_label, c1_path, c2_path):
    """Render a 5-panel storyboard for one C1/C2 image pair and show it.

    Panels: raw C1 with scale bar, AI pillar mask, raw C2 with scale bar,
    thresholded C2 (red precipitate, black pillars, yellow pillar outlines),
    and a composite (black pillars / white pore space / red precipitate)
    annotated with area percentages.

    Parameters
    ----------
    base_label : str
        Label used in the figure's suptitle.
    c1_path, c2_path : str
        Paths to the C1 brightfield and C2 TL-POL images.

    Raises
    ------
    FileNotFoundError
        If the C2 image cannot be read (C1 read errors come from
        infer_pillar_mask_from_c1).

    Relies on names defined elsewhere in the notebook: infer_pillar_mask_from_c1,
    binarize_c2_fixed, to_u8, add_scale_bar, um_per_px, SCALE_UM, C2_FIXED,
    THR_PILLAR. Returns None; the figure is displayed via plt.show().
    """
    # load / infer pillars (pillar01 is a 0/1 uint8 mask at C1 resolution)
    C1_raw, pillar01 = infer_pillar_mask_from_c1(c1_path)
    nonpillar01 = (pillar01 == 0).astype(np.uint8)

    # load C2
    C2_raw = cv2.imread(c2_path, cv2.IMREAD_GRAYSCALE)
    if C2_raw is None:
        raise FileNotFoundError("Unreadable C2: " + c2_path)
    # Bring C2 onto the C1 pixel grid so masks can index it directly
    # (cv2.resize takes (width, height)).
    if C2_raw.shape != C1_raw.shape:
        C2_raw = cv2.resize(C2_raw, (C1_raw.shape[1], C1_raw.shape[0]), interpolation=cv2.INTER_AREA)

    # metrics (all percentages are of the full image unless noted)
    total = pillar01.size
    pillar_px = int(pillar01.sum())
    nonpillar_px = total - pillar_px
    pillar_pct = 100.0 * pillar_px / total
    nonpillar_pct = 100.0 * nonpillar_px / total

    # NOTE(review): the binarize_c2_fixed defined in a later cell takes `thr=`,
    # not `thr_val=`; if that cell has run, this call raises TypeError unless
    # that definition also accepts `thr_val` — confirm which definition is live.
    precip = binarize_c2_fixed(C2_raw, nonpillar01, thr_val=C2_FIXED)

    crystal_px = int((precip > 0).sum())
    crystal_pct = 100.0 * crystal_px / total
    # max(1, ...) avoids division by zero when the whole frame is pillar
    crystal_nonpillar_pct = 100.0 * crystal_px / max(1, nonpillar_px)

    # overlays: painting order matters, pillars are painted last so they win over precip
    proc_rgb = cv2.cvtColor(to_u8(C2_raw), cv2.COLOR_GRAY2RGB)
    proc_rgb[precip > 0] = (255,0,0)           # red precip
    proc_rgb[pillar01 > 0] = (0,0,0)           # black pillars

    overlay = np.zeros_like(proc_rgb)
    # Two return values: this unpacking matches the OpenCV >= 4 findContours signature.
    contours, _ = cv2.findContours(pillar01.astype(np.uint8),
                                   cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    cv2.drawContours(overlay, contours, -1, (255,255,0), 2)  # yellow outlines
    # Additive blend (uint8 saturates), so outlines show on top of the processed image.
    proc_with_outline = cv2.addWeighted(proc_rgb, 1.0, overlay, 1.0, 0)

    # Composite: black pillars, white pore space, then red precipitate on top.
    final_rgb = np.zeros_like(proc_rgb)
    final_rgb[pillar01 > 0] = (0,0,0)
    final_rgb[nonpillar01 > 0] = (255,255,255)
    final_rgb[precip > 0] = (255,0,0)

    # NOTE(review): panels 0 and 2 are shown without cmap; if add_scale_bar
    # returns a 2-D grayscale array, matplotlib will apply its default colormap — confirm.
    C1_bar = add_scale_bar(to_u8(C1_raw), um_per_px, SCALE_UM)
    C2_bar = add_scale_bar(to_u8(C2_raw), um_per_px, SCALE_UM)

    fig, axes = plt.subplots(1,5, figsize=(22,5))
    axes[0].imshow(C1_bar); axes[0].set_title("C1 (raw + scale bar)"); axes[0].axis("off")
    axes[1].imshow(pillar01, cmap="gray"); axes[1].set_title(f"Pillar mask (AI)\nNon-pillar: {nonpillar_pct:.1f}%"); axes[1].axis("off")
    axes[2].imshow(C2_bar); axes[2].set_title("C2 TL-POL (raw + scale bar)"); axes[2].axis("off")
    axes[3].imshow(proc_with_outline); axes[3].set_title(f"C2 processed (thr={C2_FIXED})\n+ pillar outlines"); axes[3].axis("off")
    axes[4].imshow(final_rgb); axes[4].set_title(
        f"Composite\n"
        f"Pillar: {pillar_pct:.1f}% | Non-pillar: {nonpillar_pct:.1f}%\n"
        f"Crystal: {crystal_pct:.1f}% | Crystal/Non-pillar: {crystal_nonpillar_pct:.1f}%"
    ); axes[4].axis("off")

    plt.suptitle(f"{base_label} | Pillar thr={THR_PILLAR:.2f} | C2 fixed thr={C2_FIXED}",
                 fontsize=14, fontweight="bold")
    plt.tight_layout()
    plt.show()

# ---------- show storyboards ----------
# For each selected device, show at most one full ("f") and one partial ("p")
# replicate: the lowest-indexed entry of each kind, full first.
for d_id in show_d:
    picks = [groups[d_id][kind][0] for kind in ("f", "p") if groups[d_id].get(kind)]

    for idx, c2p in picks:
        c1p = c1_from_c2(c2p)
        if c1p is None:
            print("Missing C1 for:", os.path.basename(c2p))
            continue
        c2_stem = os.path.splitext(os.path.basename(c2p))[0]
        base_label = c2_stem.replace(".c2", "").replace("_c2", "")
        storyboard_one(base_label, c1p, c2p)


In [None]:
import os, glob, re, random, cv2
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# ================== CONFIG ==================
D_DIR = "data/raw_images"

C2_THR = 70          # fixed C2 threshold: pixels strictly above it count as precipitate
MIN_AREA_OPEN = 1    # > 0 enables the 3x3 open+close cleanup; 0 turns it off
SEED = 2026

# display settings
N_SHOW = 9           # number of random examples in the montage
ALPHA = 0.35         # blend weight of the green overlay
# ============================================

# Seed Python's and NumPy's global RNGs.
random.seed(SEED)
np.random.seed(SEED)

IMG_EXTS = (".png", ".tif", ".tiff", ".jpg", ".jpeg")


def is_img(p):
    """True when the path's extension (compared lower-cased) is one of IMG_EXTS."""
    ext = os.path.splitext(p)[1]
    return ext.lower() in IMG_EXTS

# filename pattern: d6.f3.c2.png  OR d6.p2.c2.png
pat = re.compile(r"^(d\d+)\.(f|p)(\d+)\.(c2)$", flags=re.IGNORECASE)

# --- gather C2 files ---
# Pair every image path with its filename-pattern match (None when it doesn't match).
matches = (
    (path, pat.match(os.path.splitext(os.path.basename(path))[0]))
    for path in glob.glob(os.path.join(D_DIR, "*"))
    if is_img(path)
)
# One record per matching file: (device id, f/p, replicate number, path).
rows = [
    (match.group(1).lower(), match.group(2).lower(), int(match.group(3)), path)
    for path, match in matches
    if match
]

if len(rows) == 0:
    raise FileNotFoundError("No files matching d#.f#/p#.c2.* found in: " + D_DIR)

df = (
    pd.DataFrame(rows, columns=["d_id", "fp", "rep", "path"])
    .sort_values(["d_id", "fp", "rep"])
    .reset_index(drop=True)
)
print("Found C2 images:", len(df))
display(df.head(10))

def to_u8(img):
    """Return `img` as uint8 for display/thresholding.

    None passes straight through and uint8 input is returned unchanged (the
    same object). Any other dtype is min-max stretched to [0, 255] with
    cv2.normalize and then cast (truncating) to uint8.
    """
    if img is not None and img.dtype != np.uint8:
        stretched = cv2.normalize(img, None, 0, 255, cv2.NORM_MINMAX)
        return stretched.astype(np.uint8)
    return img

def c2_to_binary_count(C2_gray_u8, thr=C2_THR, min_area_open=MIN_AREA_OPEN):
    """
    Binarize a uint8 C2 image with a fixed threshold and count foreground pixels.

    Pixels strictly greater than `thr` become foreground. When `min_area_open`
    is > 0 (it acts only as an on/off switch, not as an area), a 3x3 opening
    and then a 3x3 closing are applied to drop specks and fill pinholes.
    Defaults are bound when the function is defined.

    Returns
    -------
    (bin01, count_on, frac_on)
        0/1 uint8 mask, number of foreground pixels, and that number divided
        by the total pixel count.
    """
    _, mask255 = cv2.threshold(C2_gray_u8, int(thr), 255, cv2.THRESH_BINARY)

    if min_area_open > 0:
        kernel = np.ones((3, 3), np.uint8)
        for operation in (cv2.MORPH_OPEN, cv2.MORPH_CLOSE):
            mask255 = cv2.morphologyEx(mask255, operation, kernel, iterations=1)

    bin01 = (mask255 > 0).astype(np.uint8)
    count_on = int(bin01.sum())
    return bin01, count_on, float(count_on / bin01.size)

# --- compute per-image fluorescent pixel count ---
# Threshold every C2 image (no C1 mask). Unreadable files are reported and skipped.
metrics = []
for _, r in df.iterrows():
    c2 = cv2.imread(r["path"], cv2.IMREAD_GRAYSCALE)
    if c2 is None:
        print("⚠️ Skip unreadable:", r["path"])
        continue
    c2 = to_u8(c2)

    # Pass both knobs explicitly so edits to the config values take effect
    # even though the function's defaults were bound at definition time.
    bin01, count_on, frac_on = c2_to_binary_count(c2, thr=C2_THR, min_area_open=MIN_AREA_OPEN)
    metrics.append({
        "d_id": r["d_id"], "fp": r["fp"], "rep": r["rep"],
        "thr": C2_THR,
        "fluor_pixels": count_on,
        "fluor_frac": frac_on,
        "total_pixels": int(bin01.size),
        "path": r["path"],
    })

# An empty list would otherwise fail below with a confusing KeyError from
# sort_values (the empty frame has no columns).
if not metrics:
    raise RuntimeError(f"None of the {len(df)} C2 images could be read; nothing to measure.")

dfm = pd.DataFrame(metrics).sort_values(["d_id","fp","rep"]).reset_index(drop=True)

print("\nPer-image thresholded C2 counts (first 8):")
display(dfm.head(8))

# --- summary statistics of the per-image counts ---
# The same five statistics are computed for two groupings:
#   agg_fp  -> per device (d#) and replicate type (f = full, p = partial)
#   agg_all -> per device, pooling full and partial replicates
# SD is NaN for groups with a single image.
summary_spec = dict(
    n=("fluor_pixels", "size"),
    mean_fluor_pixels=("fluor_pixels", "mean"),
    sd_fluor_pixels=("fluor_pixels", "std"),
    mean_fluor_frac=("fluor_frac", "mean"),
    sd_fluor_frac=("fluor_frac", "std"),
)

agg_fp = (
    dfm.groupby(["d_id", "fp"])
    .agg(**summary_spec)
    .reset_index()
    .sort_values(["d_id", "fp"])
)

agg_all = (
    dfm.groupby(["d_id"])
    .agg(**summary_spec)
    .reset_index()
    .sort_values("d_id")
)

print("\n=== Averages by d# (FULL vs PARTIAL) ===")
display(agg_fp)

print("\n=== Averages by d# (FULL+PARTIAL combined) ===")
display(agg_all)

# --- quick plot: combined mean fluorescent pixels per d# with SD ---
x = np.arange(len(agg_all))
bar_means = agg_all["mean_fluor_pixels"]
bar_sds = agg_all["sd_fluor_pixels"]

plt.figure(figsize=(10, 4))
plt.bar(x, bar_means)
plt.xticks(x, agg_all["d_id"])
plt.ylabel(f"Thresholded C2 pixels (thr={C2_THR})")
plt.title("d# precipitation area proxy (C2 thresholded pixel count; no C1 mask)")
plt.errorbar(x, bar_means, yerr=bar_sds, fmt="none", capsize=4)
plt.tight_layout()
plt.show()

# --- random montage: raw C2 vs threshold mask overlay ---
showN = min(N_SHOW, len(dfm))
sample = dfm.sample(n=showN, random_state=SEED).reset_index(drop=True)

cols = 2          # left: raw C2, right: thresholded overlay
rowsN = showN     # one row per sampled image
plt.figure(figsize=(cols*6, rowsN*4))

for i, p in enumerate(sample["path"]):
    c2 = to_u8(cv2.imread(p, cv2.IMREAD_GRAYSCALE))
    bin01, count_on, frac_on = c2_to_binary_count(c2, thr=C2_THR)

    left = plt.subplot(rowsN, cols, i*cols + 1)
    left.imshow(c2, cmap="gray")
    left.set_title(f"{os.path.basename(p)}\nC2 raw")
    left.axis("off")

    # Blend a green layer, present only on thresholded pixels, over the raw image.
    c2_rgb = cv2.cvtColor(c2, cv2.COLOR_GRAY2RGB)
    green_layer = np.zeros_like(c2_rgb)
    green_layer[bin01 > 0] = (0, 255, 0)
    vis = cv2.addWeighted(c2_rgb, 1.0, green_layer, ALPHA, 0)

    right = plt.subplot(rowsN, cols, i*cols + 2)
    right.imshow(vis)
    right.set_title(f"thr={C2_THR} ‚Üí pixels={count_on:,} ({100*frac_on:.2f}%)")
    right.axis("off")

plt.suptitle("Random examples: C2 thresholded precipitation pixels (no C1 mask)", y=1.01,
             fontsize=14, fontweight="bold")
plt.tight_layout()
plt.show()

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

# agg_fp columns expected:
# ["d_id","fp","n","mean_fluor_pixels","sd_fluor_pixels","mean_fluor_frac","sd_fluor_frac"]

# ---- prepare wide table for grouped bars ----
# Rows = d#, columns = (statistic, replicate type f/p).
wide = (
    agg_fp
    .pivot(index="d_id", columns="fp", values=["mean_fluor_pixels", "sd_fluor_pixels"])
    .sort_index()
)

d_ids = wide.index.tolist()

# A replicate type can be missing for every device; its column then does not
# exist in `wide`, so substitute an all-NaN array to keep the plotting below working.
fp_arrays = {}
for stat in ("mean_fluor_pixels", "sd_fluor_pixels"):
    for kind in ("f", "p"):
        col = wide[stat].get(kind)
        fp_arrays[stat, kind] = np.full(len(d_ids), np.nan) if col is None else col.values

mean_f, sd_f = fp_arrays["mean_fluor_pixels", "f"], fp_arrays["sd_fluor_pixels", "f"]
mean_p, sd_p = fp_arrays["mean_fluor_pixels", "p"], fp_arrays["sd_fluor_pixels", "p"]

# ---- plot grouped bars ----
x = np.arange(len(d_ids))
w = 0.38   # bar width; each f/p pair is centred on its tick

plt.figure(figsize=(12, 4.5))
bar_series = (
    (-w/2, mean_f, sd_f, "Full (f)"),
    (w/2, mean_p, sd_p, "Partial (p)"),
)
for offset, heights, errs, series_label in bar_series:
    plt.bar(x + offset, heights, width=w, yerr=errs, capsize=4, label=series_label)

plt.xticks(x, d_ids)
plt.ylabel(f"Thresholded C2 pixels (thr={C2_THR})")
plt.title("Precipitation area proxy by device (C2 thresholded pixel count)\nFull vs Partial replicates")
plt.legend()
plt.tight_layout()
plt.show()



In [None]:
import os, glob, re, cv2
import numpy as np
import pandas as pd
import torch
from albumentations.pytorch import ToTensorV2
import segmentation_models_pytorch as smp

# ---------------- CONFIG ----------------
FINAL_ROOT = "data/raw_images"
FOLDERS = ["t1t2", "t3t4", "t5t6", "s2s3"]   # exclude "d"

# Trained U-Net checkpoint (a state_dict). This must be the .pt file fetched by
# the download cell at the top of the notebook (models/unet_pillars_finetuned.pt),
# not an image folder: torch.load() on a directory fails.
MODEL_PATH = "models/unet_pillars_finetuned.pt"
if not os.path.isfile(MODEL_PATH):
    raise FileNotFoundError(f"Model checkpoint not found: {MODEL_PATH} (run the download cell first)")

THR_PILLAR = 0.50    # sigmoid probability above which a pixel is labelled pillar
SHORT_SIDE = 768     # C1 images are resized so their short side is this long before inference

C2_FIXED_THR = 70   # your final precipitation threshold
MORPH_C2 = True     # open+close to reduce specks

# quick test mode
LIMIT_PER_FOLDER = None   # set to e.g. 6 to test quickly; or None for full run
# ---------------------------------------

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("DEVICE:", DEVICE)

IMG_EXTS = (".png", ".tif", ".tiff", ".jpg", ".jpeg")
def is_img(p): return os.path.splitext(p)[1].lower() in IMG_EXTS

def resize_short(im, short=SHORT_SIDE):
    """Rescale `im` so its shorter side is `short` pixels, preserving aspect ratio.

    The new width and height are rounded to the nearest integer. Resampling
    always uses cv2.INTER_AREA, whether the image shrinks or grows.
    """
    height, width = im.shape[:2]
    scale = short / min(height, width)
    new_size = (int(round(width * scale)), int(round(height * scale)))  # cv2 expects (w, h)
    return cv2.resize(im, new_size, interpolation=cv2.INTER_AREA)

def to_u8(img):
    """Coerce an image array to uint8.

    - None  -> None
    - uint8 -> returned as-is (no copy)
    - other -> min-max normalized into 0..255 by cv2.normalize, then cast to uint8
    """
    if img is None or img.dtype == np.uint8:
        return img
    scaled = cv2.normalize(img, None, 0, 255, cv2.NORM_MINMAX)
    return scaled.astype(np.uint8)

def infer_pillar_mask_from_c1_rgb(model, c1_path, pad_multiple=32):
    """
    Segment pillars in a C1 brightfield image with the U-Net.

    The image is read as 3-channel and resized so its short side is SHORT_SIDE.
    It is then reflect-padded on the bottom/right up to a multiple of
    `pad_multiple` and run through `model`. The padding is cropped off the
    sigmoid probability map, which is resized back to the original resolution
    and thresholded at THR_PILLAR. A 5x5 opening followed by a closing tidies
    the mask.

    Why pad: resize_short fixes only the short side, so the long side can be
    any integer. The segmentation_models_pytorch U-Net (default encoder depth 5)
    needs both spatial sizes divisible by 32 and otherwise raises. When the
    size is already aligned, no padding is added and results are unchanged.

    Parameters
    ----------
    model : torch.nn.Module
        Segmentation net in eval mode; output channel 0 is treated as the pillar logit.
    c1_path : str
        Path to the C1 image.
    pad_multiple : int, default 32
        Size multiple required by the network; 1 disables padding.

    Returns
    -------
    (c1_gray, mask01)
        (H, W) uint8 grayscale C1 and (H, W) uint8 pillar mask with values 0/1.

    Raises
    ------
    FileNotFoundError
        If OpenCV cannot read `c1_path`.
    """
    bgr = cv2.imread(c1_path)
    if bgr is None:
        raise FileNotFoundError(f"Unreadable C1: {c1_path}")

    # model expects 3ch (cv2.imread defaults to IMREAD_COLOR, so the 2-D branch is defensive)
    if bgr.ndim == 2:
        rgb = np.stack([bgr, bgr, bgr], axis=-1)
        c1_gray = bgr
    else:
        rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
        c1_gray = cv2.cvtColor(bgr, cv2.COLOR_BGR2GRAY)

    H, W = c1_gray.shape[:2]
    rgb_rs = resize_short(rgb, SHORT_SIDE)
    h_rs, w_rs = rgb_rs.shape[:2]

    # Pad bottom/right so both sides are divisible by pad_multiple; cropped off below.
    mult = max(1, int(pad_multiple))
    pad_h = (-h_rs) % mult
    pad_w = (-w_rs) % mult
    if pad_h or pad_w:
        rgb_rs = np.pad(rgb_rs, ((0, pad_h), (0, pad_w), (0, 0)), mode="reflect")

    ten = ToTensorV2()(image=rgb_rs)["image"].unsqueeze(0).to(DEVICE).float()

    with torch.no_grad():
        prob_rs = torch.sigmoid(model(ten))[0, 0].cpu().numpy()
    # Drop the padded border; make the slice contiguous for OpenCV.
    prob_rs = np.ascontiguousarray(prob_rs[:h_rs, :w_rs])

    prob_full = cv2.resize(prob_rs, (W, H), interpolation=cv2.INTER_LINEAR)
    mask01 = (prob_full > THR_PILLAR).astype(np.uint8)

    # tidy mask edges: opening removes specks, closing fills small holes
    k = np.ones((5, 5), np.uint8)
    mask01 = cv2.morphologyEx(mask01 * 255, cv2.MORPH_OPEN,  k, iterations=1)
    mask01 = cv2.morphologyEx(mask01,       cv2.MORPH_CLOSE, k, iterations=1)
    mask01 = (mask01 > 0).astype(np.uint8)

    return to_u8(c1_gray), mask01

def binarize_c2_fixed(c2_gray, nonpillar01, thr=C2_FIXED_THR, morph=MORPH_C2, thr_val=None):
    """
    Fixed-threshold precipitation mask for a C2 image, restricted to pore space.

    Steps: zero the pillar pixels, apply a 3x3 Gaussian blur, keep pixels above
    the threshold, optionally apply a 3x3 opening then closing, and finally
    re-apply the pore-space mask. The final re-mask matters because
    MORPH_CLOSE dilates before it erodes, so it can switch on pixels inside
    pillars. Without it, crystal pixel counts would be inflated.

    Parameters
    ----------
    c2_gray : ndarray
        C2 image on the same pixel grid as `nonpillar01`.
    nonpillar01 : ndarray
        0/1 mask; nonzero marks pore (non-pillar) pixels.
    thr : int
        Intensity threshold (strictly greater passes).
    morph : bool
        Whether to run the open+close speck cleanup.
    thr_val : int, optional
        Alias for `thr`, accepted because the storyboard cell calls
        binarize_c2_fixed(..., thr_val=...). Overrides `thr` when given.

    Returns
    -------
    ndarray
        uint8 0/1 precipitation mask, zero everywhere outside pore space.
    """
    threshold = thr if thr_val is None else thr_val
    outside_pores = nonpillar01 == 0

    img = to_u8(c2_gray).copy()
    # Zero pillars before blurring so pillar intensities can't leak into pore pixels.
    img[outside_pores] = 0
    blur = cv2.GaussianBlur(img, (3, 3), 0)
    _, bw = cv2.threshold(blur, int(threshold), 255, cv2.THRESH_BINARY)
    bw[outside_pores] = 0
    if morph:
        k = np.ones((3, 3), np.uint8)
        bw = cv2.morphologyEx(bw, cv2.MORPH_OPEN,  k, iterations=1)
        bw = cv2.morphologyEx(bw, cv2.MORPH_CLOSE, k, iterations=1)
        # Closing may have grown into pillars; restrict to pore space again.
        bw[outside_pores] = 0
    return (bw > 0).astype(np.uint8)

# ---------------- Load model once ----------------
print("Loading model:", MODEL_PATH)
# Architecture must match the checkpoint: ResNet-18 encoder U-Net, 3 input channels, 1 output map.
model = smp.Unet(encoder_name="resnet18", encoder_weights=None, in_channels=3, classes=1)
# SECURITY NOTE(review): torch.load unpickles the file, and the checkpoint is fetched
# from a shared link. Only load trusted files; consider weights_only=True (torch >= 1.13).
model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))
model = model.to(DEVICE)
model.eval()

# ---------------- Find and pair images ----------------
# Stem layout: <channel>.<section>.<diam>.<trial>.<c1|c2>
# e.g. t1.1.8.a.c1 / t1.1.8.a.c2  (also works for s2/s3)
pat = re.compile(r"^([ts]\d+)\.(\d+)\.(\d+)\.([abc])\.(c1|c2)$", flags=re.IGNORECASE)

def scan_folder(folder_dir):
    """Index the t/s channel images in one folder.

    Each image whose stem matches `pat` becomes one row with columns
    base, channel, section, diam, trial, cc, path. `base` (e.g. "t1.1.8.a")
    names a field of view and is shared by its c1 and c2 files. channel,
    trial and cc are lower-cased; section and diam are ints. A folder with
    no matches (or a missing folder) yields an empty frame that still has
    these columns.
    """
    columns = ["base", "channel", "section", "diam", "trial", "cc", "path"]
    records = []
    for path in filter(is_img, glob.glob(os.path.join(folder_dir, "*"))):
        match = pat.match(os.path.splitext(os.path.basename(path))[0])
        if match is None:
            continue
        channel, trial, cc = (match.group(k).lower() for k in (1, 4, 5))
        section, diam = int(match.group(2)), int(match.group(3))
        base = f"{channel}.{section}.{diam}.{trial}"
        records.append((base, channel, section, diam, trial, cc, path))
    return pd.DataFrame(records, columns=columns)

# Collect, per folder, one row per field of view that has BOTH a c1 and a c2 image.
all_rows = []
for fold in FOLDERS:
    folder_dir = os.path.join(FINAL_ROOT, fold)
    df = scan_folder(folder_dir)
    if df.empty:
        print(f"⚠️ No matching t/s images found in {folder_dir}")
        continue

    # pair by base: c1/c2 paths side by side; duplicate base+cc keeps the first path found
    g = df.pivot_table(index=["base","channel","section","diam","trial"],
                       columns="cc", values="path", aggfunc="first").reset_index()
    g.columns.name = None
    if "c1" not in g.columns or "c2" not in g.columns:
        print(f"⚠️ Missing c1 or c2 in {fold}; columns={g.columns.tolist()}")
        continue
    g = g.dropna(subset=["c1","c2"]).copy()
    g["folder"] = fold

    if LIMIT_PER_FOLDER is not None:
        g = g.head(int(LIMIT_PER_FOLDER))

    print(f"{fold}: paired images = {len(g)}")
    all_rows.append(g)

# pd.concat([]) raises an unhelpful "No objects to concatenate"; explain instead.
if not all_rows:
    raise FileNotFoundError(
        f"No paired C1/C2 images found under {FINAL_ROOT} in folders {FOLDERS}"
    )

pairs_df = pd.concat(all_rows, ignore_index=True)
print("\nTOTAL paired across folders:", len(pairs_df))
display(pairs_df.head(8))

# ---------------- Compute per-image storyboard stats ----------------
# For each C1/C2 pair: AI pillar mask from C1, fixed-threshold precipitation
# mask from C2 restricted to pore space, then area percentages.
out = []
for _, r in pairs_df.iterrows():
    c1_path = r["c1"]; c2_path = r["c2"]

    # C1 pillars (raises FileNotFoundError if the C1 image is unreadable)
    c1_gray, pillar01 = infer_pillar_mask_from_c1_rgb(model, c1_path)
    nonpillar01 = (pillar01 == 0).astype(np.uint8)

    # C2 precipitation (thresholded inside nonpillar)
    c2_gray = cv2.imread(c2_path, cv2.IMREAD_GRAYSCALE)
    if c2_gray is None:
        print("⚠️ Skip unreadable C2:", c2_path)
        continue
    if c2_gray.shape != c1_gray.shape:
        # bring C2 onto the C1 grid; cv2.resize takes (width, height)
        c2_gray = cv2.resize(c2_gray, (c1_gray.shape[1], c1_gray.shape[0]), interpolation=cv2.INTER_AREA)

    precip01 = binarize_c2_fixed(c2_gray, nonpillar01, thr=C2_FIXED_THR, morph=MORPH_C2)

    total_px = int(pillar01.size)
    pillar_px = int(pillar01.sum())
    nonpillar_px = total_px - pillar_px
    crystal_px = int(precip01.sum())

    out.append({
        "folder": r["folder"],
        "channel": r["channel"],      # t1,t2,... or s2,s3
        "section": r["section"],      # 1/2
        "diam": r["diam"],            # 5/8/9...
        "trial": r["trial"],          # a/b/c
        "base": r["base"],
        "pillar_pct": 100.0 * pillar_px / total_px,
        "nonpillar_pct": 100.0 * nonpillar_px / total_px,
        "crystal_px": crystal_px,
        "crystal_pct_total": 100.0 * crystal_px / total_px,
        # max(1, ...) guards against an all-pillar frame
        "crystal_pct_nonpillar": 100.0 * crystal_px / max(1, nonpillar_px),
    })

# An empty list would make sort_values fail with a KeyError on missing columns.
if not out:
    raise RuntimeError("No C1/C2 pairs could be processed (all C2 images unreadable?).")

stats = pd.DataFrame(out).sort_values(["channel","section","diam","trial"]).reset_index(drop=True)
print("\nPer-image stats (first 10):")
display(stats.head(10))

# Aggregate across trials a/b/c: one row per (channel, section, diam) holding the
# number of trials plus mean and SD of each coverage metric (SD is NaN when n == 1).
agg_spec = {"n": ("crystal_pct_nonpillar", "size")}
for metric in ("crystal_pct_nonpillar", "crystal_pct_total", "pillar_pct"):
    agg_spec[f"mean_{metric}"] = (metric, "mean")
    agg_spec[f"sd_{metric}"] = (metric, "std")

agg = (
    stats.groupby(["channel", "section", "diam"])
    .agg(**agg_spec)
    .reset_index()
    .sort_values(["channel", "section", "diam"])
)

print("\nAggregated across trials (a/b/c):")
display(agg)


In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import re
import math

# ---------- helpers ----------
def channel_sort_key(ch):
    """
    Natural-sort key for channel labels: t1, t2, ..., t10, then s2, s3, ...

    Labels shaped <letters><digits> sort by prefix rank (t=0, s=1, any other
    letter prefix=9), then the lower-cased prefix, then the integer suffix.
    Labels of any other shape sort after all of those, by their text.

    Every key is an (int, str, int) tuple. A fallback key starting with a str
    would raise TypeError as soon as it is compared with a regular
    channel key.
    """
    text = str(ch)
    m = re.match(r"^([a-zA-Z]+)(\d+)$", text)
    if not m:
        return (10, text, 999)
    prefix = m.group(1).lower()
    num = int(m.group(2))
    # put t before s (edit if you want different ordering)
    prefix_order = {"t": 0, "s": 1}
    return (prefix_order.get(prefix, 9), prefix, num)

# ---------- choose channels (include t and s) ----------
channels = sorted(
    [c for c in agg["channel"].unique() if str(c).lower().startswith(("t", "s"))],
    key=channel_sort_key
)
# With no channels the layout below would divide by zero; fail with a clear message.
if not channels:
    raise ValueError("No t*/s* channels found in `agg`; run the per-image stats cell first.")

# ---------- layout ----------
n = len(channels)
ncols = min(3, n)                  # at most 3 panels per row
nrows = int(math.ceil(n / ncols))

fig, axes = plt.subplots(nrows, ncols, figsize=(5.2*ncols, 4.3*nrows), sharey=True)
axes = np.array(axes).reshape(-1)  # flatten even if 1 row / 1 panel

# ---------- plot each channel ----------
for i, ch in enumerate(channels):
    ax = axes[i]
    sub = agg[agg["channel"] == ch].copy()
    if sub.empty:
        ax.axis("off")
        continue

    # rows = diameter code, columns = section (1/2)
    wide_mean = sub.pivot(index="diam", columns="section",
                          values="mean_crystal_pct_nonpillar").sort_index()
    wide_sd   = sub.pivot(index="diam", columns="section",
                          values="sd_crystal_pct_nonpillar").reindex(wide_mean.index)

    diams = wide_mean.index.values
    x = np.arange(len(diams))
    w = 0.38

    # handle missing section(s)
    m1 = wide_mean.get(1, pd.Series(index=diams, data=np.nan)).values
    m2 = wide_mean.get(2, pd.Series(index=diams, data=np.nan)).values
    s1 = wide_sd.get(1, pd.Series(index=diams, data=np.nan)).values
    s2 = wide_sd.get(2, pd.Series(index=diams, data=np.nan)).values

    ax.bar(x - w/2, m1, width=w, yerr=s1, capsize=3, label="Section 1")
    ax.bar(x + w/2, m2, width=w, yerr=s2, capsize=3, label="Section 2")

    ax.set_xticks(x)
    ax.set_xticklabels(diams)
    ax.set_xlabel("Diameter code")
    ax.set_title(str(ch))

    if i % ncols == 0:
        ax.set_ylabel("Crystal area (% of non-pillar region)")

# ---------- legend + title ----------
# Put legend on first subplot only (cleaner)
handles, labels = axes[0].get_legend_handles_labels()
fig.legend(handles, labels, loc="upper center", ncol=2, frameon=False)

fig.suptitle(f"C2 precipitation @ thr={C2_FIXED_THR} — mean ± SD across trials (a/b/c)\nGrouped by channel, diameter, and section",
             y=1.02, fontweight="bold")

# turn off any unused axes (grid cells beyond the n channel panels).
# Use n, not the loop variable: that would be stale or undefined if the loop never ran.
for j in range(n, len(axes)):
    axes[j].axis("off")

plt.tight_layout()
plt.show()

# Optional save:
# out_path = "data/raw_images"
# fig.savefig(out_path, dpi=300, bbox_inches="tight")
# print("Saved:", out_path)


In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# --------------------------
# 0) Map your channels -> conditions
# --------------------------
channel_info = {
    "t1": {"bacteria": "high", "calcium": "high", "height": "tall"},
    "t2": {"bacteria": "low",  "calcium": "low",  "height": "tall"},
    "t3": {"bacteria": "low",  "calcium": "high", "height": "tall"},
    "t4": {"bacteria": "low",  "calcium": "high", "height": "tall"},
    "t5": {"bacteria": "high", "calcium": "low",  "height": "tall"},
    "t6": {"bacteria": "high", "calcium": "low",  "height": "tall"},
    "s2": {"bacteria": "high", "calcium": "low",  "height": "short"},
    "s3": {"bacteria": "high", "calcium": "low",  "height": "short"},
}

# --------------------------
# 1) Add factor columns to per-image table
# --------------------------
df = stats.copy()
df["channel"] = df["channel"].str.lower()

# Look up each experimental factor per channel; channels absent from
# channel_info get "unknown".
for factor in ("bacteria", "calcium", "height"):
    df[factor] = df["channel"].map(
        lambda c, key=factor: channel_info.get(c, {}).get(key, "unknown")
    )

# section -> porosity (other section numbers become NaN)
df["porosity"] = df["section"].map({1: 0.35, 2: 0.45})

# diameter code -> mm (5->0.5 mm, 8->0.8 mm, etc.)
df["diam_mm"] = df["diam"] / 10.0

# convenience condition label
df["condition"] = df["height"] + " | bact=" + df["bacteria"] + " | Ca=" + df["calcium"]

# Response variable (what we've been plotting)
YCOL = "crystal_pct_nonpillar"   # % of non-pillar region occupied by precip (C2 thr=70)

print("Rows:", len(df))
print(df[["channel","height","bacteria","calcium","porosity","diam_mm","trial",YCOL]].head())

# --------------------------
# 2) Helper: box + jitter plot for categorical factors
# --------------------------
def box_jitter(ax, data, cat_col, y_col, title):
    """Draw a box per category of `cat_col` for `y_col`, overlaid with jittered points.

    Categories are sorted and placed at x = 0..k-1; NaN responses are dropped.
    Boxes are drawn with showfliers=False and every observation is plotted as a
    point instead. Horizontal jitter is Gaussian (sd 0.06) from NumPy's global
    RNG, so point positions depend on the current np.random state.
    """
    categories = sorted(data[cat_col].unique())
    positions = np.arange(len(categories))
    per_category = []
    for cat in categories:
        per_category.append(data.loc[data[cat_col] == cat, y_col].dropna().values)

    ax.boxplot(per_category, positions=positions, widths=0.55, showfliers=False)

    for pos, ys in zip(range(len(per_category)), per_category):
        if len(ys) > 0:
            xs = np.random.normal(loc=pos, scale=0.06, size=len(ys))
            ax.plot(xs, ys, "o", markersize=4, alpha=0.6)

    ax.set_xticks(positions)
    ax.set_xticklabels(categories, rotation=0)
    ax.set_ylabel(y_col)
    ax.set_title(title)

# --------------------------
# FIGURE 1: Main effects (porosity, height, bacteria, calcium)
# --------------------------
main_effects = [
    ("porosity", "Porosity effect"),
    ("height",   "Height effect"),
    ("bacteria", "Bacteria effect"),
    ("calcium",  "Calcium effect"),
]
fig, axes = plt.subplots(1, 4, figsize=(18, 4), sharey=True)
for panel_ax, (factor_col, panel_title) in zip(axes, main_effects):
    box_jitter(panel_ax, df, factor_col, YCOL, panel_title)

fig.suptitle("Main effects on precipitation coverage (C2 thresholded; no C1 save-to-disk)", y=1.05, fontweight="bold")
plt.tight_layout()
plt.show()

# --------------------------
# FIGURE 2: Bacteria × Calcium interaction (TALL ONLY)
# --------------------------
tall = df.loc[df["height"] == "tall"].copy()
tall["bact_ca"] = "bact=" + tall["bacteria"] + ", Ca=" + tall["calcium"]

# count / mean / SD of the response for each chemistry combination
summary_bc = (
    tall.groupby("bact_ca")[YCOL]
    .agg(["count", "mean", "std"])
    .reset_index()
    .sort_values("bact_ca")
)

x = np.arange(len(summary_bc))
plt.figure(figsize=(8, 4))
plt.bar(x, summary_bc["mean"], yerr=summary_bc["std"], capsize=4)
plt.xticks(x, summary_bc["bact_ca"], rotation=25, ha="right")
plt.ylabel(YCOL)
plt.title("Interaction: Bacteria √ó Calcium (tall channels only)")
plt.tight_layout()
plt.show()

# --------------------------
# FIGURE 3: Diameter trends (lines) faceted by porosity
#   Each line = condition (height+bacteria+calcium)
# --------------------------
# mean±SD across trials for each (condition, porosity, diam_mm)
trend = (df.groupby(["condition","porosity","diam_mm"])[YCOL]
         .agg(["count","mean","std"])
         .reset_index()
         .sort_values(["porosity","condition","diam_mm"]))

poros = sorted(trend["porosity"].unique())
fig, axes = plt.subplots(1, len(poros), figsize=(14,4), sharey=True)

if len(poros) == 1:
    axes = [axes]

# One fixed colour per condition, so a condition looks the same in every facet
# and matches the single figure-level legend. Error bars get the colour explicitly:
# without it, errorbar() takes the next colour from the cycle and no longer
# matches its line.
conditions = sorted(trend["condition"].unique())
cond_color = {cond: f"C{k % 10}" for k, cond in enumerate(conditions)}
legend_handles = {}   # condition -> first line drawn for it (dedupes the legend)

for ax, p in zip(axes, poros):
    sub = trend[trend["porosity"] == p]
    for cond in sorted(sub["condition"].unique()):
        s = sub[sub["condition"] == cond]
        color = cond_color[cond]
        (line,) = ax.plot(s["diam_mm"], s["mean"], marker="o", linewidth=2, label=cond, color=color)
        # light error bars (SD is NaN, i.e. not drawn, for single-trial points)
        ax.errorbar(s["diam_mm"], s["mean"], yerr=s["std"], fmt="none", capsize=3, color=color)
        legend_handles.setdefault(cond, line)

    ax.set_title(f"Porosity = {p:.2f}")
    ax.set_xlabel("Diameter (mm)")
    ax.set_ylabel(YCOL)

fig.suptitle("Diameter dependence (mean±SD across trials a/b/c)", y=1.05, fontweight="bold")
fig.legend(list(legend_handles.values()), list(legend_handles.keys()),
           loc="upper center", ncol=2, frameon=False, bbox_to_anchor=(0.5, 1.15))
plt.tight_layout()
plt.show()

# --------------------------
# FIGURE 4: Height effect with chemistry held constant
#   Compare tall vs short when bact=high and Ca=low
# --------------------------
fixed = df[(df["bacteria"]=="high") & (df["calcium"]=="low")].copy()

height_trend = (fixed.groupby(["height","porosity","diam_mm"])[YCOL]
                .agg(["count","mean","std"])
                .reset_index()
                .sort_values(["porosity","height","diam_mm"]))

poros = sorted(height_trend["porosity"].unique())
fig, axes = plt.subplots(1, len(poros), figsize=(12,4), sharey=True)
if len(poros) == 1:
    axes = [axes]

# Fixed colour per height so it matches across facets and the figure legend;
# errorbar() is given the same colour (otherwise it advances the colour cycle).
height_color = {"tall": "C0", "short": "C1"}
legend_handles = {}   # height -> first line drawn for it (dedupes the legend)

for ax, p in zip(axes, poros):
    sub = height_trend[height_trend["porosity"] == p]
    for h in ["tall","short"]:
        s = sub[sub["height"] == h]
        if s.empty:
            # no data for this height: avoid an empty line and a phantom legend entry
            continue
        color = height_color[h]
        (line,) = ax.plot(s["diam_mm"], s["mean"], marker="o", linewidth=2, label=h, color=color)
        ax.errorbar(s["diam_mm"], s["mean"], yerr=s["std"], fmt="none", capsize=3, color=color)
        legend_handles.setdefault(h, line)
    ax.set_title(f"Porosity = {p:.2f}")
    ax.set_xlabel("Diameter (mm)")
    ax.set_ylabel(YCOL)

fig.suptitle("Height effect (chemistry fixed: high bacteria + low calcium)", y=1.05, fontweight="bold")
fig.legend(list(legend_handles.values()), list(legend_handles.keys()),
           loc="upper center", ncol=2, frameon=False, bbox_to_anchor=(0.5, 1.12))
plt.tight_layout()
plt.show()


In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# --------------------------
# 0) UPDATED channel -> condition mapping
#    (s2/s3 treated as TALL equivalents; same chemistry as t5/t6)
# --------------------------
channel_info = {
    "t1": {"bacteria": "high", "calcium": "high", "height": "tall"},
    "t2": {"bacteria": "low",  "calcium": "low",  "height": "tall"},
    "t3": {"bacteria": "low",  "calcium": "high", "height": "tall"},
    "t4": {"bacteria": "low",  "calcium": "high", "height": "tall"},
    "t5": {"bacteria": "high", "calcium": "low",  "height": "tall"},
    "t6": {"bacteria": "high", "calcium": "low",  "height": "tall"},

    # corrected: treat these as tall equivalents (height was mislabeled)
    "s2": {"bacteria": "high", "calcium": "low",  "height": "tall"},
    "s3": {"bacteria": "high", "calcium": "low",  "height": "tall"},
}

# --------------------------
# 1) Add factor columns to per-image table
# --------------------------
df = stats.copy()
df["channel"] = df["channel"].str.lower()

# Per-channel experimental factors; channels missing from channel_info -> "unknown".
factor_columns = {
    key: df["channel"].map(lambda c, k=key: channel_info.get(c, {}).get(k, "unknown"))
    for key in ("bacteria", "calcium", "height")
}
df = df.assign(**factor_columns)

# section -> porosity (other section numbers become NaN)
df["porosity"] = df["section"].map({1: 0.35, 2: 0.45})

# diameter code -> mm
df["diam_mm"] = df["diam"] / 10.0

# condition label (s2/s3 now read tall|high|low, the same as t5/t6)
df["condition"] = df["height"] + " | bact=" + df["bacteria"] + " | Ca=" + df["calcium"]

YCOL = "crystal_pct_nonpillar"  # response variable

print("Rows:", len(df))
print(df[["channel","height","bacteria","calcium","porosity","diam_mm","trial",YCOL]].head())

# --------------------------
# 2) Helper: box + jitter plot for categorical factors
# --------------------------
def box_jitter(ax, data, cat_col, y_col, title, jitter=0.06, rng=None):
    """
    Box plot of `y_col` for each category of `cat_col`, overlaid with jittered points.

    Parameters
    ----------
    ax : matplotlib Axes
        Target axes.
    data : pandas.DataFrame
        Table holding `cat_col` and `y_col`.
    cat_col, y_col : str
        Grouping column and response column.
    title : str
        Axes title.
    jitter : float, default 0.06
        SD of the Gaussian horizontal jitter applied to the points.
    rng : numpy.random.Generator, optional
        Source of jitter. The default None draws from NumPy's global RNG
        (np.random.normal). Pass e.g. np.random.default_rng(SEED) for a
        layout that does not depend on what ran earlier in the session.

    Notes
    -----
    Missing category values (NaN/None) are excluded. They never match via
    `==`, so they would otherwise show as an empty box labelled "nan". Mixed
    with string categories they would also make `sorted` raise TypeError.
    NaN responses are dropped. Boxes use showfliers=False because every
    point is drawn anyway.
    """
    cats = sorted(data[cat_col].dropna().unique())
    vals = [data.loc[data[cat_col] == c, y_col].dropna().values for c in cats]
    ax.boxplot(vals, positions=np.arange(len(cats)), widths=0.55, showfliers=False)

    draw_normal = np.random.normal if rng is None else rng.normal
    for i, v in enumerate(vals):
        if len(v) == 0:
            continue
        x = draw_normal(loc=i, scale=jitter, size=len(v))
        ax.plot(x, v, "o", markersize=4, alpha=0.6)

    ax.set_xticks(np.arange(len(cats)))
    ax.set_xticklabels(cats, rotation=0)
    ax.set_ylabel(y_col)
    ax.set_title(title)

# --------------------------
# FIGURE 1: Main effects (porosity, height, bacteria, calcium)
#   One box+jitter panel per factor, all sharing the response axis.
#   Note: height effect may now shrink because everything is labeled tall.
# --------------------------
main_effect_panels = [
    ("porosity", "Porosity effect"),
    ("height",   "Height effect (after relabel)"),
    ("bacteria", "Bacteria effect"),
    ("calcium",  "Calcium effect"),
]

fig, axes = plt.subplots(1, len(main_effect_panels), figsize=(18, 4), sharey=True)
for panel_ax, (panel_col, panel_title) in zip(axes, main_effect_panels):
    box_jitter(panel_ax, df, panel_col, YCOL, panel_title)

fig.suptitle("Main effects on precipitation coverage (C2 thresholded)", y=1.05, fontweight="bold")
plt.tight_layout()
plt.show()

# --------------------------
# FIGURE 2: Bacteria × Calcium interaction (now includes s2/s3 as tall)
# --------------------------
# since everything is tall now, just use all rows where height == tall
# (rows whose height is "unknown" are excluded here)
tall = df[df["height"] == "tall"].copy()
tall["bact_ca"] = "bact=" + tall["bacteria"] + ", Ca=" + tall["calcium"]

summary_bc = (tall.groupby("bact_ca")[YCOL]
              .agg(["count","mean","std"])
              .reset_index()
              .sort_values("bact_ca"))

# Bar = group mean, error bar = group SD (NaN when a group has a single row).
fig, ax = plt.subplots(figsize=(8, 4))
x = np.arange(len(summary_bc))
ax.bar(x, summary_bc["mean"], yerr=summary_bc["std"], capsize=4)
ax.set_xticks(x)
ax.set_xticklabels(summary_bc["bact_ca"], rotation=25, ha="right")
ax.set_ylabel(YCOL)
ax.set_title("Interaction: Bacteria × Calcium (tall-labeled set)")
plt.tight_layout()
plt.show()

# --------------------------
# FIGURE 3: Diameter trends (lines) faceted by porosity
#   Each line = condition (height+bacteria+calcium)
# --------------------------
trend = (df.groupby(["condition","porosity","diam_mm"])[YCOL]
         .agg(["count","mean","std"])
         .reset_index()
         .sort_values(["porosity","condition","diam_mm"]))

poros = sorted(trend["porosity"].unique())

# Pin one colour per condition so a condition looks the same in every panel,
# and draw its error bars in that same colour (the fmt="none" errorbar call
# is a separate artist and does not inherit the line's colour on its own).
cond_colors = {cond: f"C{i % 10}" for i, cond in enumerate(sorted(trend["condition"].unique()))}

if not poros:
    # plt.subplots(1, 0) would raise; report instead.
    print("Diameter trend skipped: no rows with a mapped porosity.")
else:
    fig, axes = plt.subplots(1, len(poros), figsize=(14,4), sharey=True)
    if len(poros) == 1:
        axes = [axes]

    for ax, p in zip(axes, poros):
        sub = trend[trend["porosity"] == p]
        for cond in sorted(sub["condition"].unique()):
            s = sub[sub["condition"] == cond]
            color = cond_colors[cond]
            ax.plot(s["diam_mm"], s["mean"], marker="o", linewidth=2, color=color, label=cond)
            ax.errorbar(s["diam_mm"], s["mean"], yerr=s["std"], fmt="none", capsize=3, ecolor=color)
        ax.set_title(f"Porosity = {p:.2f}")
        ax.set_xlabel("Diameter (mm)")
        ax.set_ylabel(YCOL)

    # fig.legend() without handles gathers labelled lines from every axes, so
    # each condition would appear once per panel; keep the first handle per label.
    legend_entries = {}
    for ax in axes:
        for handle, label in zip(*ax.get_legend_handles_labels()):
            legend_entries.setdefault(label, handle)

    fig.suptitle("Diameter dependence (mean±SD across trials a/b/c)", y=1.05, fontweight="bold")
    fig.legend(list(legend_entries.values()), list(legend_entries.keys()),
               loc="upper center", ncol=2, frameon=False, bbox_to_anchor=(0.5, 1.15))
    plt.tight_layout()
    plt.show()

# --------------------------
# FIGURE 4: Height effect with chemistry held constant
#   Warning: with s2/s3 relabeled as tall, there may be NO 'short' left -> plot conditionally.
# --------------------------
fixed = df[(df["bacteria"]=="high") & (df["calcium"]=="low")].copy()

if fixed["height"].nunique() < 2:
    print("Height comparison skipped: only one height level present after relabel (all tall).")
else:
    height_trend = (fixed.groupby(["height","porosity","diam_mm"])[YCOL]
                    .agg(["count","mean","std"])
                    .reset_index()
                    .sort_values(["porosity","height","diam_mm"]))

    poros = sorted(height_trend["porosity"].unique())

    if not poros:
        # plt.subplots(1, 0) would raise; report instead.
        print("Height comparison skipped: no rows with a mapped porosity.")
    else:
        # One colour per height level, shared across panels; error bars use the
        # same colour as their line.
        height_colors = {h: f"C{i % 10}" for i, h in enumerate(sorted(height_trend["height"].unique()))}

        fig, axes = plt.subplots(1, len(poros), figsize=(12,4), sharey=True)
        if len(poros) == 1:
            axes = [axes]

        for ax, p in zip(axes, poros):
            sub = height_trend[height_trend["porosity"] == p]
            for h in sorted(sub["height"].unique()):
                s = sub[sub["height"] == h]
                color = height_colors[h]
                ax.plot(s["diam_mm"], s["mean"], marker="o", linewidth=2, color=color, label=h)
                ax.errorbar(s["diam_mm"], s["mean"], yerr=s["std"], fmt="none", capsize=3, ecolor=color)
            ax.set_title(f"Porosity = {p:.2f}")
            ax.set_xlabel("Diameter (mm)")
            ax.set_ylabel(YCOL)

        # Deduplicate legend entries that would otherwise repeat once per panel.
        legend_entries = {}
        for ax in axes:
            for handle, label in zip(*ax.get_legend_handles_labels()):
                legend_entries.setdefault(label, handle)

        fig.suptitle("Height effect (chemistry fixed: high bacteria + low calcium)", y=1.05, fontweight="bold")
        fig.legend(list(legend_entries.values()), list(legend_entries.keys()),
                   loc="upper center", ncol=2, frameon=False, bbox_to_anchor=(0.5, 1.12))
        plt.tight_layout()
        plt.show()
