In [3]:
import os
import json
import cv2
import numpy as np
from typing import List, Optional, Tuple

# SAM (segment-anything)
# from segment_anything import SamPredictor, sam_model_registry

# Diffusers / Hugging Face
import torch
# from diffusers import (
#     StableDiffusionControlNetInpaintPipeline,
#     ControlNetModel,
#     UniPCMultistepScheduler,
# )
from PIL import Image

In [4]:
def load_image(path: str) -> np.ndarray:
    """Read an image from disk and return it in RGB channel order.

    Args:
        path: Image file path readable by ``cv2.imread``.

    Returns:
        The decoded image converted from OpenCV's BGR order to RGB.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``path`` (missing file,
            unreadable file, or unsupported format).
    """
    img_bgr = cv2.imread(path)
    # cv2.imread signals failure by returning None instead of raising; without
    # this check the problem surfaces later as an opaque cvtColor assertion.
    if img_bgr is None:
        raise FileNotFoundError(f"Could not read image: {path}")
    return cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)


def save_image(img: np.ndarray, path: str) -> None:
    """Write an RGB image array to ``path``.

    The array is converted to OpenCV's BGR order before writing; the output
    format is chosen by OpenCV from the file extension.

    Raises:
        IOError: if ``cv2.imwrite`` reports that the write failed.
    """
    img_bgr = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
    # cv2.imwrite returns False (rather than raising) on some failures, e.g.
    # when the target directory does not exist.
    if not cv2.imwrite(path, img_bgr):
        raise IOError(f"Could not write image: {path}")

In [5]:
def init_sam(model_type: str = "vit_h", checkpoint: Optional[str] = None, device="cuda"):
    """Build a SAM model from a checkpoint and wrap it in a ``SamPredictor``.

    Args:
        model_type: Key into ``segment_anything.sam_model_registry`` (e.g. ``"vit_h"``).
        checkpoint: Path to the SAM weights file. Required.
        device: Torch device the model is moved to.

    Returns:
        A ``segment_anything.SamPredictor`` wrapping the loaded model.

    Raises:
        ValueError: if ``checkpoint`` is missing or ``model_type`` is not a
            registered SAM architecture.
    """
    # Validate before importing segment_anything so a missing checkpoint is
    # reported clearly even where the package is not installed.
    if not checkpoint:
        raise ValueError("Please provide a SAM checkpoint path or set checkpoint argument.")

    from segment_anything import SamPredictor, sam_model_registry

    if model_type not in sam_model_registry:
        raise ValueError(
            f"Unknown SAM model_type {model_type!r}; expected one of {sorted(sam_model_registry)}"
        )
    sam = sam_model_registry[model_type](checkpoint=checkpoint)
    sam.to(device=device)
    return SamPredictor(sam)


def get_sam_masks(predictor, image: np.ndarray, grid_size: int = 32,
                  batch_size: int = 64, min_area_frac: float = 0.001) -> List[np.ndarray]:
    """Prompt SAM with a regular grid of single foreground points and collect masks.

    Every grid point is an *independent* one-point prompt. With
    ``multimask_output=True``, each prompt yields 3 candidate masks.
    Candidates covering fewer than ``min_area_frac * H * W`` pixels are
    dropped. The rest are de-duplicated by bounding box (the first mask seen
    for a given box wins).

    Args:
        predictor: A ``segment_anything.SamPredictor``. Uses its ``set_image``,
            ``transform``, ``device`` and ``predict_torch`` members.
        image: Image passed to ``set_image`` after ``astype(np.uint8)``. SAM
            expects RGB (H, W, 3) by default — TODO confirm callers pass RGB.
        grid_size: Number of points along each axis (``grid_size**2`` prompts).
        batch_size: Prompts decoded per ``predict_torch`` call. Only affects
            memory use and speed, not the result.
        min_area_frac: Minimum mask area as a fraction of the image area.

    Returns:
        List of (H, W) boolean masks.
    """
    predictor.set_image(image.astype(np.uint8))
    h, w = image.shape[:2]
    min_area = min_area_frac * h * w

    grid_y = np.linspace(0, h - 1, grid_size, dtype=int)
    grid_x = np.linspace(0, w - 1, grid_size, dtype=int)
    # (N, 2) array of (x, y) pixel coordinates, row-major over the grid.
    points = np.array([[x, y] for y in grid_y for x in grid_x])

    unique: List[np.ndarray] = []
    seen_boxes = set()
    for start in range(0, len(points), batch_size):
        batch = points[start:start + batch_size]
        # SamPredictor.predict() would merge all points into ONE prompt and
        # return a single (3, H, W) stack. To get one prompt per point, give
        # predict_torch a batch axis, as SamAutomaticMaskGenerator does.
        coords = predictor.transform.apply_coords(batch, image.shape[:2])
        coords_t = torch.as_tensor(coords, dtype=torch.float, device=predictor.device)
        labels_t = torch.ones(len(batch), dtype=torch.int, device=predictor.device)
        with torch.no_grad():
            masks_t, _, _ = predictor.predict_torch(
                coords_t[:, None, :],
                labels_t[:, None],
                multimask_output=True,
            )
        # masks_t is (B, C, H, W); flatten the prompt and candidate axes.
        for mask in masks_t.reshape(-1, h, w).cpu().numpy():
            mask = mask.astype(bool)
            if mask.sum() < min_area:
                continue
            rows, cols = np.nonzero(mask)
            if cols.size == 0:
                continue
            box = (int(cols.min()), int(rows.min()), int(cols.max()), int(rows.max()))
            if box in seen_boxes:
                continue
            seen_boxes.add(box)
            unique.append(mask)
    return unique


def select_mask_by_pose(masks: List[np.ndarray], bbox: Tuple[int, int, int, int]) -> Optional[np.ndarray]:
    """Return the candidate mask that best matches a pose bounding box.

    The box is rasterised as a filled rectangle. Each mask is compared to it
    by intersection-over-union (IoU).

    Args:
        masks: Candidate boolean masks, all with the same (H, W) shape as ``masks[0]``.
        bbox: ``(x0, y0, x1, y1)`` in pixels, with ``x1``/``y1`` exclusive.
            Values are truncated to int and clipped to the mask bounds, so a
            box extending past the image edge selects only its visible part.

    Returns:
        The mask with the highest IoU. If no mask overlaps the box, the mask
        with the most pixels. ``None`` if ``masks`` is empty.
    """
    if len(masks) == 0:
        return None

    h, w = masks[0].shape[:2]

    def _clip(value, upper):
        # Prevents negative coordinates from wrapping around via Python slicing.
        return min(max(int(value), 0), upper)

    x0, y0, x1, y1 = bbox
    x0, x1 = _clip(x0, w), _clip(x1, w)
    y0, y1 = _clip(y0, h), _clip(y1, h)
    box_mask = np.zeros((h, w), dtype=bool)
    box_mask[y0:y1, x0:x1] = True

    best_mask, best_iou = None, 0.0
    for m in masks:
        union = np.logical_or(m, box_mask).sum()
        if union == 0:
            continue
        iou = np.logical_and(m, box_mask).sum() / union
        if iou > best_iou:
            best_iou, best_mask = iou, m
    if best_mask is None:
        # No candidate overlaps the box at all: fall back to the largest one.
        best_mask = max(masks, key=lambda m: m.sum())
    return best_mask


In [21]:
# --------------------------- Pose JSON utilities & mask generation ---------------------------

def load_pose_json(path: str) -> dict:
    """Decode a UTF-8 pose/keypoint JSON file and return its contents.

    Despite the ``dict`` annotation, the top-level JSON value is returned
    as-is, whatever its type (a list is also possible).
    """
    with open(path, encoding="utf-8") as handle:
        payload = json.load(handle)
    return payload


def extract_keypoints_from_dict(pose_dict: dict, image_size: Tuple[int,int]) -> list:
    """Extract ``(x, y, visibility)`` keypoint tuples from a decoded pose JSON object.

    Supported layouts, checked in this order:

    * ``{"bodies": [[x, y], [x, y], ...]}`` — a single body given directly as point pairs;
    * ``{"bodies": [[[x, y], ...], ...]}`` — first body, one ``[x, y]`` per keypoint;
    * ``{"bodies": [[x1, y1, x2, y2, ...], ...]}`` — first body as a flat, even-length list;
    * a bare list ``[[x, y], ...]`` ("aligned" pixel coords) — the only layout that is
      rescaled, by dividing x by W and y by H;
    * ``{"x": [...], "y": [...], "v": [...]}`` — parallel arrays (``v`` defaults to 1.0);
    * fallback: the first dict value that is a flat ``[x, y, v, x, y, v, ...]`` list.

    Layouts other than the bare list return coordinates exactly as stored.
    NOTE(review): ``pose_jsons_to_mask`` multiplies coordinates by the image
    size, i.e. it expects normalised values. Confirm those layouts are already
    normalised in the JSON files being used.

    Args:
        pose_dict: Decoded JSON (a dict, or a list for the aligned layout).
        image_size: ``(H, W)``, used only to rescale the aligned layout.

    Returns:
        List of ``(x, y, v)`` tuples (``v = 1.0`` where the layout has no
        visibility). Returns ``[]`` if no layout matches.
    """
    H, W = image_size

    def _is_number(val):
        return isinstance(val, (int, float))

    def _is_xy_pair(pt):
        # A single numeric [x, y] point, as opposed to a whole body of points.
        return isinstance(pt, (list, tuple)) and len(pt) == 2 and all(_is_number(c) for c in pt)

    if isinstance(pose_dict, dict):
        bodies = pose_dict.get('bodies')
        if isinstance(bodies, list) and len(bodies) > 0:
            # {"bodies": [[x, y], ...]}: the list itself is one body. This must be
            # checked before taking bodies[0], which would yield only the first point.
            if all(_is_xy_pair(pt) for pt in bodies):
                return [(pt[0], pt[1], 1.0) for pt in bodies]
            arr = bodies[0]  # only the first body is used
            if isinstance(arr, (list, tuple)):
                # each element is [x, y]
                if all(isinstance(pt, (list, tuple)) and len(pt) == 2 for pt in arr):
                    return [(pt[0], pt[1], 1.0) for pt in arr]
                # flat [x1, y1, x2, y2, ...]; an odd length is malformed and skipped
                if len(arr) % 2 == 0 and all(_is_number(val) for val in arr):
                    return [(arr[i], arr[i + 1], 1.0) for i in range(0, len(arr), 2)]

    # Aligned pose: bare list of [x, y] pixel coordinates, divided by (W, H).
    if isinstance(pose_dict, list) and all(isinstance(pt, list) and len(pt) == 2 for pt in pose_dict):
        return [(x / W, y / H, 1.0) for x, y in pose_dict]

    # The remaining layouts are dict-only.
    if not isinstance(pose_dict, dict):
        return []

    # direct x, y arrays
    if 'x' in pose_dict and 'y' in pose_dict:
        xs, ys = pose_dict['x'], pose_dict['y']
        vs = pose_dict.get('v', [1.0] * len(xs))
        return list(zip(xs, ys, vs))

    # fallback: first flat [x, y, v, ...] list among the values
    for value in pose_dict.values():
        if isinstance(value, list) and len(value) >= 3 and len(value) % 3 == 0:
            return [(value[i], value[i + 1], value[i + 2]) for i in range(0, len(value), 3)]
    return []
    
def pose_jsons_to_mask(orig_pose: dict, aligned_pose: dict, image_size: Tuple[int,int], thresh: float=0.05,
                       radius: int = 15, min_visibility: float = 0.3) -> np.ndarray:
    """Build a boolean inpainting mask around keypoints that moved between two poses.

    Keypoints are paired by index; extra points in the longer list are ignored.
    A pair is skipped if either visibility is below ``min_visibility``. If the
    point moved by more than ``thresh`` along x or y, a filled circle of
    ``radius`` pixels is drawn at its *aligned* position.

    Args:
        orig_pose: Dict with a ``'kps'`` list of ``(x, y, v)`` tuples.
        aligned_pose: Dict with a ``'kps'`` list of ``(x, y, v)`` tuples.
            Coordinates are multiplied by (W, H) to get pixels, so they are
            expected to be normalised; ``thresh`` is in the same units.
        image_size: ``(H, W)`` of the output mask.
        thresh: Per-axis displacement above which a keypoint counts as moved.
        radius: Circle radius in pixels.
        min_visibility: Minimum visibility score required in both poses.

    Returns:
        ``(H, W)`` bool array, True inside the region to inpaint.
    """
    H, W = image_size
    # OpenCV's Python bindings reject bool arrays, so draw on uint8 and convert at the end.
    canvas = np.zeros((H, W), dtype=np.uint8)

    for (x0, y0, v0), (x1, y1, v1) in zip(orig_pose['kps'], aligned_pose['kps']):
        if v0 < min_visibility or v1 < min_visibility:
            continue
        if abs(x0 - x1) > thresh or abs(y0 - y1) > thresh:
            center = (int(x1 * W), int(y1 * H))
            cv2.circle(canvas, center, radius, 255, -1)  # thickness -1 => filled
    return canvas.astype(bool)

In [17]:
# --------------------------- Example main flow ---------------------------

def main_example(image_path: str, orig_pose_json: str, aligned_pose_json: str, sam_checkpoint: Optional[str] = None,
                 hf_token: Optional[str] = None, device: Optional[str] = None):
    """Inpaint the image at ``image_path`` so the subject follows the aligned pose.

    Steps:
      1. Load the image and both pose JSON files.
      2. Build a mask around keypoints that moved.
      3. Render a pose map from the aligned keypoints.
      4. Run ``run_controlnet_inpaint``.
      5. Save the result next to the input as ``<name>_inpainted.png``.

    Args:
        image_path: Input image.
        orig_pose_json: JSON file with the original keypoints.
        aligned_pose_json: JSON file with the aligned (target) keypoints.
        sam_checkpoint: If given, a ``vit_h`` SAM predictor is initialised.
            NOTE(review): the predictor is not used by this flow yet.
        hf_token: If given, exported as the ``HF_TOKEN`` environment variable.
        device: Torch device for SAM. Defaults to ``"cuda"`` when available,
            otherwise ``"cpu"``.

    Returns:
        Path of the saved inpainted image.

    Raises:
        NameError: if ``render_pose_map`` or ``run_controlnet_inpaint`` has not
            been defined in the notebook.
    """
    # Fail fast, before loading the (multi-GB) SAM checkpoint, if the pipeline
    # helpers this flow calls have not been defined yet.
    missing = [name for name in ("render_pose_map", "run_controlnet_inpaint") if name not in globals()]
    if missing:
        raise NameError(
            f"main_example needs {', '.join(missing)}; run the cell(s) that define them first."
        )

    if hf_token:
        os.environ['HF_TOKEN'] = hf_token

    img_np = load_image(image_path)
    H, W = img_np.shape[:2]

    sam_predictor = None
    if sam_checkpoint is not None:
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        # NOTE(review): loaded but not consumed below — wire it in (e.g. mask
        # refinement) or stop loading it.
        sam_predictor = init_sam(model_type='vit_h', checkpoint=sam_checkpoint, device=device)

    orig = load_pose_json(orig_pose_json)
    aligned = load_pose_json(aligned_pose_json)

    kps_orig = extract_keypoints_from_dict(orig, (H, W))
    kps_aligned = extract_keypoints_from_dict(aligned, (H, W))

    mask_bool = pose_jsons_to_mask({'kps': kps_orig}, {'kps': kps_aligned}, (H, W))
    if not mask_bool.any():
        print('Warning: inpainting mask is empty (no keypoint pair passed the visibility/'
              'displacement checks); check the pose JSON formats and coordinate scales.')
    mask_np = mask_bool.astype(np.uint8) * 255

    pose_map = render_pose_map(kps_aligned, (W, H))

    prompt = "photo of a person, realistic clothing and anatomical correctness"
    inpaint_result = run_controlnet_inpaint(Image.fromarray(img_np), mask_np, pose_map, prompt)

    out_path = os.path.splitext(image_path)[0] + '_inpainted.png'
    inpaint_result.save(out_path)
    print('Saved inpainted to', out_path)
    return out_path


In [22]:
# Inputs for the end-to-end example; paths are relative to the notebook's working directory.
image_path = "../flask_app/static/uploads/groundimg.jpeg"
orig_pose_json = "keypoints.json"             # keypoints of the original pose
aligned_pose_json = "keypoints_aligned.json"  # keypoints of the target (aligned) pose
sam_checkpoint = "checkpoints/sam_vit_h_4b8939.pth"

main_example(
    image_path=image_path,
    orig_pose_json=orig_pose_json,
    aligned_pose_json=aligned_pose_json,
    sam_checkpoint=sam_checkpoint,
)

NameError: name 'render_pose_map' is not defined