# In [5]:
# =========================
# Segmentation-only AOI PREVIEW + SAVE (nose/eyes/mouth scaled; no eyebrows)
# =========================
# Requires in your 'aoi' env:
#   pip install face_recognition pillow numpy scipy matplotlib

import os, csv
import numpy as np
from PIL import Image, ImageDraw
from scipy.io import loadmat
from scipy.spatial import ConvexHull
import face_recognition

# ---------- CONFIG ----------
# IMG_PATH is the single image this script processes. Its prediction file is
# looked up in PRED_DIR as "<image stem><PRED_SUFFIX>". OUTPUT_DIR is created
# if missing and receives the preview PNG and the AOI coordinate CSV.
IMG_PATH   = "/PATH/BRAVE_adultF_adultF_1.jpg"
PRED_DIR   = "/PATH"
OUTPUT_DIR = "/PATH/aoi_outputs"
PRED_SUFFIX = "-pred.mat"       # produced by pred_folder.py
# "pil" opens the preview with PIL's image.show(); any other value tries
# matplotlib first and falls back to PIL if that fails.
SHOW_MODE   = "pil"             # "pil" or "matplotlib"

# Fraction of the detected face box's width/height added on every side.
EXPAND_FACE_BOX = 0.15          # expand face rectangle before clipping AOIs

# Per-AOI scaling (use your current numbers)
# Each polygon is scaled about the mean of its vertices by this factor.
AOI_SCALE_NOSE  = 1.15          # nose scale (1.0 = no enlarge)
AOI_SCALE_EYES  = 1.15          # eyes scale
AOI_SCALE_MOUTH = 1.15          # mouth scale

# When True, pixels labelled LBL_BROW are merged into the eye mask.
INCLUDE_EYEBROWS_IN_EYES = False  # exclude eyebrows from eye AOI

# segmentation labels (per MLT repo)
# Integer class ids compared against the prediction array.
LBL_EYE = 4
LBL_BROW = 5    # only consulted when INCLUDE_EYEBROWS_IN_EYES is True
LBL_MOUTH = 8
LBL_NOSE = 10

# colors
# Used for AOI outlines and their text labels. Note the face rectangle is
# drawn with a hard-coded "orange" outline, the same value as COLOR_MOUTH.
COLOR_NOSE  = "purple"
COLOR_EYE   = "red"
COLOR_MOUTH = "orange"
TEXT_COLOR  = "white"             # colour of the "Face N" caption

# ---------- HELPERS ----------
def clamp(v, lo, hi):
    """Limit ``v`` to the interval ``[lo, hi]``.

    The upper bound is applied first and the lower bound second, so when the
    bounds are inverted (``lo > hi``) the result is ``lo``.
    """
    capped = min(hi, v)
    return max(lo, capped)

def expand_box(box, w, h, ratio=0.15):
    """Grow a ``(top, right, bottom, left)`` box by ``ratio`` on every side.

    The horizontal padding is ``ratio`` times the box width and the vertical
    padding is ``ratio`` times the box height, each rounded to an integer.
    Every edge is then clamped to the valid pixel index range of a ``w`` x
    ``h`` image, i.e. ``[0, w-1]`` for x and ``[0, h-1]`` for y.

    Returns:
        The expanded box in the same ``(top, right, bottom, left)`` order.
    """
    top, right, bottom, left = box
    pad_x = int(round((right - left) * ratio))
    pad_y = int(round((bottom - top) * ratio))
    last_col, last_row = w - 1, h - 1
    return (
        clamp(top - pad_y, 0, last_row),
        clamp(right + pad_x, 0, last_col),
        clamp(bottom + pad_y, 0, last_row),
        clamp(left - pad_x, 0, last_col),
    )

def all_cc_masks(mask_bool, min_pixels=80):
    """Return list of boolean masks for each 4-connected component with >= min_pixels.

    Args:
        mask_bool: 2-D array; truthy (non-zero) entries are foreground.
        min_pixels: Components with fewer pixels than this are discarded.

    Returns:
        A list of boolean arrays shaped like ``mask_bool``, one per kept
        component, ordered by each component's first pixel in row-major
        (top-to-bottom, left-to-right) scan order.
    """
    h, w = mask_bool.shape
    seen = np.zeros_like(mask_bool, dtype=bool)
    comps = []
    nbrs = ((1, 0), (-1, 0), (0, 1), (0, -1))
    # Seed the flood fill from foreground pixels only. np.nonzero reports
    # coordinates in row-major order, so components are discovered in the same
    # order as a full raster scan, without visiting background pixels in Python.
    seed_ys, seed_xs = np.nonzero(mask_bool)
    for y, x in zip(seed_ys.tolist(), seed_xs.tolist()):
        if seen[y, x]:
            continue  # already absorbed into an earlier component
        stack = [(y, x)]
        seen[y, x] = True
        ys, xs = [], []
        while stack:
            cy, cx = stack.pop()
            ys.append(cy)
            xs.append(cx)
            for dy, dx in nbrs:
                ny, nx = cy + dy, cx + dx
                if 0 <= ny < h and 0 <= nx < w and mask_bool[ny, nx] and not seen[ny, nx]:
                    # Mark on push so a pixel is never queued twice.
                    seen[ny, nx] = True
                    stack.append((ny, nx))
        if len(xs) >= min_pixels:
            m = np.zeros_like(mask_bool, dtype=bool)
            m[np.array(ys), np.array(xs)] = True
            comps.append(m)
    return comps

def polygon_from_mask(mask_bool, max_points=200):
    """Return the convex-hull outline of a mask as a list of ``(x, y)`` ints.

    Args:
        mask_bool: 2-D array; truthy entries are the pixels to enclose.
        max_points: Upper bound on the number of vertices returned. When the
            outline has more vertices, exactly ``max_points`` evenly spaced
            ones are kept (the first vertex is always retained). Values <= 0
            disable the reduction.

    Returns:
        ``[]`` when the mask has fewer than three set pixels; otherwise the
        hull vertices in the order SciPy reports them. If the hull cannot be
        built, every set pixel is returned in row-major order instead.
    """
    ys, xs = np.nonzero(mask_bool)
    if len(xs) < 3:
        return []
    pts = np.column_stack([xs, ys])
    try:
        hull = ConvexHull(pts)
    except Exception:
        # Deliberate best effort: Qhull rejects degenerate input (e.g. all
        # pixels on one line); fall back to the raw pixel list so the caller
        # still gets something rather than a crash.
        poly = [(int(x), int(y)) for x, y in zip(xs, ys)]
    else:
        poly = [(int(pts[i, 0]), int(pts[i, 1])) for i in hull.vertices]
    if 0 < max_points < len(poly):
        # Evenly spaced indices. The spacing len/max_points is > 1, so the
        # truncated indices are distinct and the cap is met exactly. (A plain
        # integer stride of len // max_points is 1 for any length below
        # 2 * max_points and would drop nothing.)
        keep = np.linspace(0, len(poly), num=max_points, endpoint=False).astype(int)
        poly = [poly[i] for i in keep.tolist()]
    return poly

def enlarge_polygon(poly, scale):
    """Scale a polygon about the mean of its vertices.

    Args:
        poly: Sequence of ``(x, y)`` vertices.
        scale: Multiplier applied to each vertex's offset from the vertex
            mean (``1.0`` leaves the shape unchanged).

    Returns:
        A new list of integer ``(x, y)`` vertices, or ``poly`` itself when it
        has fewer than three vertices.
    """
    count = len(poly)
    if count < 3:
        return poly
    cx = sum(x for x, _ in poly) / count
    cy = sum(y for _, y in poly) / count

    def _scaled(value, center):
        # Scale the offset from the center, then snap to the pixel grid.
        return int(round((value - center) * scale + center))

    return [(_scaled(x, cx), _scaled(y, cy)) for x, y in poly]

# Sutherland–Hodgman clip against axis-aligned rectangle
def clip_polygon_with_rect(poly, x_min, y_min, x_max, y_max):
    """Clip a polygon to the rectangle ``[x_min, x_max] x [y_min, y_max]``.

    The polygon is clipped against the left, right, top and bottom edges in
    that order; points lying exactly on an edge count as inside. Interpolated
    crossing coordinates are rounded to integers as they are produced, and
    every output vertex is rounded to an integer ``(x, y)`` pair.

    Returns:
        The clipped vertex list; empty if nothing of the polygon remains.
        Consecutive duplicate vertices may occur.
    """
    def cross_vertical(a, b, x_edge):
        # Point where segment a->b meets the vertical line x = x_edge.
        ax, ay = a
        bx, by = b
        if bx == ax:
            return (x_edge, ay)
        frac = (x_edge - ax) / (bx - ax)
        y_hit = ay + frac * (by - ay)
        return (x_edge, int(round(y_hit)))

    def cross_horizontal(a, b, y_edge):
        # Point where segment a->b meets the horizontal line y = y_edge.
        ax, ay = a
        bx, by = b
        if by == ay:
            return (ax, y_edge)
        frac = (y_edge - ay) / (by - ay)
        x_hit = ax + frac * (bx - ax)
        return (int(round(x_hit)), y_edge)

    # One (keep-test, crossing-function) pair per rectangle edge.
    passes = (
        (lambda p: p[0] >= x_min, lambda a, b: cross_vertical(a, b, x_min)),
        (lambda p: p[0] <= x_max, lambda a, b: cross_vertical(a, b, x_max)),
        (lambda p: p[1] >= y_min, lambda a, b: cross_horizontal(a, b, y_min)),
        (lambda p: p[1] <= y_max, lambda a, b: cross_horizontal(a, b, y_max)),
    )

    ring = [(float(x), float(y)) for x, y in poly]
    for keeps, crossing in passes:
        if not ring:
            break  # nothing left to clip
        clipped = []
        prev = ring[-1]
        for cur in ring:
            if keeps(cur):
                if not keeps(prev):
                    # Entering the half-plane: add the entry point first.
                    clipped.append(crossing(prev, cur))
                clipped.append(cur)
            elif keeps(prev):
                # Leaving the half-plane: keep only the exit point.
                clipped.append(crossing(prev, cur))
            prev = cur
        ring = clipped
    return [(int(round(x)), int(round(y))) for x, y in ring]

def load_pred_mat_for(img_path, pred_dir, suffix):
    """Load the prediction array stored alongside an image's name.

    The file is expected at ``pred_dir/<image stem><suffix>``. Inside the
    ``.mat`` file the first of the keys ``prediction``, ``pred``, ``label``
    that is present is used.

    Returns:
        ``(array, pred_path)``, where ``array`` has singleton dimensions
        squeezed out.

    Raises:
        FileNotFoundError: The expected ``.mat`` file does not exist.
        KeyError: None of the accepted keys is present in the file.
    """
    image_name = os.path.basename(img_path)
    stem, _ext = os.path.splitext(image_name)
    pred_path = os.path.join(pred_dir, stem + suffix)
    if not os.path.exists(pred_path):
        raise FileNotFoundError(f"Missing pred.mat for image: {pred_path}")

    contents = loadmat(pred_path)
    found = next(
        (key for key in ("prediction", "pred", "label") if key in contents),
        None,
    )
    if found is None:
        raise KeyError(f"No 'prediction' (or 'pred'/'label') key in {pred_path}")
    return np.array(contents[found]).squeeze(), pred_path

# ---------- LOAD ----------
os.makedirs(OUTPUT_DIR, exist_ok=True)
image = Image.open(IMG_PATH).convert("RGB")
img_w, img_h = image.size
pred, pred_src = load_pred_mat_for(IMG_PATH, PRED_DIR, PRED_SUFFIX)

# handle potential transpose: the prediction must end up as (rows, cols) of the image
if pred.shape != (img_h, img_w):
    if pred.T.shape == (img_h, img_w):
        pred = pred.T
    else:
        raise ValueError(f"pred shape {pred.shape} does not match image {(img_h,img_w)} for {pred_src}")
pred = pred.astype(np.int32)

np_img = np.array(image)
face_boxes = face_recognition.face_locations(np_img)
face_boxes = sorted(face_boxes, key=lambda b: b[3])  # left->right
if not face_boxes:
    raise RuntimeError("No faces detected in the image.")

# All overlays are drawn directly onto `image`, which is saved as the preview.
draw = ImageDraw.Draw(image)

# Output file paths
stem = os.path.splitext(os.path.basename(IMG_PATH))[0]
out_png = os.path.join(OUTPUT_DIR, f"{stem}_aoi_preview.png")
out_csv = os.path.join(OUTPUT_DIR, f"{stem}_aoi_coordinates.csv")

rows = []  # CSV rows, one per polygon vertex

# Whole-image label masks do not depend on the face, so build them once
# instead of once per detected face.
eye_pixels = (pred == LBL_EYE)
if INCLUDE_EYEBROWS_IN_EYES:
    eye_pixels |= (pred == LBL_BROW)
mouth_pixels = (pred == LBL_MOUTH)
nose_pixels = (pred == LBL_NOSE)


def _components_by_size(mask, min_pixels=30):
    """Return the mask's connected components, largest first (ties keep scan order)."""
    return sorted(all_cc_masks(mask, min_pixels=min_pixels),
                  key=lambda m: m.sum(), reverse=True)


def _draw_and_record_aoi(draw, rows, component_mask, *, scale, clip_rect,
                         color, line_width, label_offset,
                         aoi, aoi_id, face_no, component_no):
    """Outline one AOI component on the preview and append its vertices to ``rows``.

    The component's hull polygon is scaled by ``scale`` and clipped to
    ``clip_rect`` (``(x_min, y_min, x_max, y_max)``). Nothing is drawn or
    recorded when fewer than three vertices remain at either stage.
    """
    poly = polygon_from_mask(component_mask)
    if len(poly) < 3:
        return
    poly = enlarge_polygon(poly, scale)
    poly = clip_polygon_with_rect(poly, *clip_rect)
    if len(poly) < 3:
        return
    # Repeat the first vertex so the outline is closed.
    draw.line(poly + [poly[0]], fill=color, width=line_width)
    # Anchor the label just inside the polygon's top-left bounding corner.
    label_x = min(x for x, _ in poly)
    label_y = min(y for _, y in poly)
    draw.text((label_x + label_offset, label_y + label_offset), aoi_id, fill=color)
    # save CSV rows
    for pi, (x, y) in enumerate(poly, start=1):
        rows.append({"AOI": aoi, "AOI_ID": aoi_id, "Face": face_no,
                     "Component": component_no, "PointIndex": pi, "X": x, "Y": y})


# ---------- PER-FACE AOIs ----------
for i, box in enumerate(face_boxes, start=1):
    # expanded face rect to keep AOIs contained
    t, r, b, l = expand_box(box, img_w, img_h, ratio=EXPAND_FACE_BOX)
    face_mask = np.zeros((img_h, img_w), dtype=bool)
    face_mask[t:b, l:r] = True  # half-open slice: row b and column r are excluded
    clip_rect = (l, t, r, b)

    # draw face rect + label (Face 1, Face 2, ...)
    draw.rectangle([l,t,r,b], outline="orange", width=3)
    draw.text((l+6, max(0, t-18)), f"Face {i}", fill=TEXT_COLOR)

    # ----- NOSE (scale AOI_SCALE_NOSE): largest component only -----
    nose_ccs = _components_by_size(nose_pixels & face_mask)
    if nose_ccs:
        _draw_and_record_aoi(draw, rows, nose_ccs[0],
                             scale=AOI_SCALE_NOSE, clip_rect=clip_rect,
                             color=COLOR_NOSE, line_width=3, label_offset=5,
                             aoi="Nose", aoi_id=f"Nose_F{i}",
                             face_no=i, component_no=1)
    else:
        draw.text((l+8, t+8), "NOSE: none", fill=COLOR_NOSE)

    # ----- EYES (scale AOI_SCALE_EYES): every component, biggest first -----
    eye_ccs = _components_by_size(eye_pixels & face_mask)
    if eye_ccs:
        for k, m in enumerate(eye_ccs, start=1):
            _draw_and_record_aoi(draw, rows, m,
                                 scale=AOI_SCALE_EYES, clip_rect=clip_rect,
                                 color=COLOR_EYE, line_width=2, label_offset=4,
                                 aoi="Eye", aoi_id=f"Eye_F{i}_{k}",
                                 face_no=i, component_no=k)
    else:
        draw.text((l+8, t+28), "EYES: none", fill=COLOR_EYE)

    # ----- MOUTH (scale AOI_SCALE_MOUTH): largest component only -----
    mouth_ccs = _components_by_size(mouth_pixels & face_mask)
    if mouth_ccs:
        _draw_and_record_aoi(draw, rows, mouth_ccs[0],
                             scale=AOI_SCALE_MOUTH, clip_rect=clip_rect,
                             color=COLOR_MOUTH, line_width=2, label_offset=4,
                             aoi="Mouth", aoi_id=f"Mouth_F{i}",
                             face_no=i, component_no=1)
    else:
        draw.text((l+8, t+48), "MOUTH: none", fill=COLOR_MOUTH)

# ---------- SAVE PREVIEW & CSV ----------
image.save(out_png)

with open(out_csv, "w", newline="", encoding="utf-8") as f:
    writer = csv.DictWriter(f, fieldnames=["AOI","AOI_ID","Face","Component","PointIndex","X","Y"])
    writer.writeheader()
    writer.writerows(rows)

print(f"✅ Preview saved: {out_png}")
print(f"✅ AOI CSV saved: {out_csv}")

# ---------- Also show preview ----------
if SHOW_MODE.lower() == "pil":
    image.show()
else:
    try:
        import matplotlib.pyplot as plt
        plt.figure(); plt.imshow(image); plt.axis("off"); plt.show()
    except Exception as e:
        # Best-effort preview: any matplotlib problem falls back to PIL's viewer.
        print("Matplotlib preview failed, falling back to PIL:", e); image.show()