## Region-Based Mask Experiment Notebook

This notebook starts fresh with a primitive region (patch) comparison approach for mask generation.

Workflow:
- Select two images (original & edited) from `inputs/`.
- Auto-resize edited to match original if needed (toggle).
- Tile both images with non-overlapping windows (default 11x11; windows at the right/bottom edges are clipped to the image bounds).
- For each window pair, apply a user-replaceable function that returns either a WHITE (255) or BLACK (0) patch in the mask.
- Visualize Original | Edited | Patch-based Mask.
- Save outputs with timestamp.

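You can easily experiment by editing or swapping the patch evaluator that `build_patch_mask` calls. Three are provided (`evaluate_patch_grayscale`, `evaluate_patch_lab_deltaE`, `evaluate_patch_ycrcb_weighted_l1`), and each takes `(orig_patch, edit_patch, threshold)` and returns 255 or 0.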
Add more sliders by following the pattern used for `patch_size` and `diff_threshold`.

In [10]:
# Imports & paths setup
import os, math
from datetime import datetime
import cv2
import numpy as np
from IPython.display import display
import ipywidgets as widgets
from pathlib import Path

# Resolve every path relative to the current working directory.
BASE_DIR = Path.cwd()  # assumes notebook opened from its directory
INPUT_DIR = BASE_DIR / 'inputs'  # candidate images are listed and read from here
OUTPUT_DIR = BASE_DIR / 'outputs'  # saved masks are written here
# Create the output folder up front; no error if it already exists.
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

def list_images():
    """Return the sorted file names of the images found directly in INPUT_DIR.

    Only regular files whose suffix (case-insensitive) is a known image
    extension are listed; sub-directories are ignored even when their name
    ends in an image extension, because they cannot be read as images.

    Returns: list of file names (not full paths); empty when INPUT_DIR does
    not exist or contains no matching files.
    """
    exts = ('.png', '.jpg', '.jpeg', '.bmp', '.tif', '.tiff')
    if not INPUT_DIR.exists():
        return []
    return sorted(
        p.name
        for p in INPUT_DIR.iterdir()
        # is_file() keeps directories such as "backup.png/" out of the dropdowns
        if p.is_file() and p.suffix.lower() in exts
    )

# Discover the available images once and report what was found.
image_files = list_images()
if image_files:
    print(f'Found {len(image_files)} image(s): {image_files}')
else:
    print(f'No images found in {INPUT_DIR}. Add image files to proceed.')

Found 3 image(s): ['me.png', 'pjs_2.jpg', 'tryon_result.jpg']


In [None]:
# Widgets (re-usable pattern)
# Image pickers: fall back to a single empty option when no images were found
# so the widgets can still be constructed.
orig_dropdown = widgets.Dropdown(options=image_files or [''], description='Original:')
edited_dropdown = widgets.Dropdown(options=image_files or [''], description='Edited:')

# Tunable parameters; sliders only fire when released (continuous_update=False).
auto_resize = widgets.Checkbox(value=True, description='Auto-resize edited')
patch_size = widgets.IntSlider(
    value=11, min=3, max=51, step=2,
    description='Patch size:', continuous_update=False,
)
diff_threshold = widgets.IntSlider(
    value=25, min=0, max=255, step=1,
    description='Diff thresh:', continuous_update=False,
)
invert_mask = widgets.Checkbox(value=False, description='Invert mask')
show_grid = widgets.Checkbox(value=False, description='Show patch grid')

# Two rows: image selection on top, mask parameters below.
image_row = widgets.HBox([orig_dropdown, edited_dropdown, auto_resize])
param_row = widgets.HBox([patch_size, diff_threshold, invert_mask, show_grid])
controls = widgets.VBox([image_row, param_row])

# controls

In [17]:
# Patch evaluation functions (color-aware)
def evaluate_patch_grayscale(orig_patch: np.ndarray, edit_patch: np.ndarray, threshold: int) -> int:
    """Classify a patch pair by its mean per-pixel grayscale difference.

    Both patches are converted from BGR to grayscale, the absolute per-pixel
    difference is taken, and the mean of that difference image is compared
    with ``threshold``.

    Args:
      orig_patch: (H,W,3) BGR patch
      edit_patch: (H,W,3) BGR patch (same size)
      threshold: scalar diff threshold (0-255)
    Returns: 255 (white) if the mean difference is >= threshold, else 0 (black)
    """
    gray_orig, gray_edit = (
        cv2.cvtColor(patch, cv2.COLOR_BGR2GRAY) for patch in (orig_patch, edit_patch)
    )
    abs_diff = cv2.absdiff(gray_orig, gray_edit)
    if float(abs_diff.mean()) >= threshold:
        return 255
    return 0

def evaluate_patch_lab_deltaE(orig_patch: np.ndarray, edit_patch: np.ndarray, threshold: float) -> int:
    """Color difference via Lab space (CIE76 ΔE).

    Converts patches to CIELAB and computes the mean ΔE (Euclidean distance
    in Lab) over all pixels of the patch.

    The patches are converted to float32 in [0, 1] before the colour
    conversion. Converting uint8 data directly makes OpenCV return its 8-bit
    Lab encoding (L rescaled to 0-255, a/b offset by 128), which inflates the
    lightness term and yields a distance that is not a true ΔE.

    Args:
      orig_patch: (H,W,3) BGR; integer dtypes are scaled by their max value,
        float dtypes are assumed to be in [0, 1] already
      edit_patch: (H,W,3) BGR (same size and convention as orig_patch)
      threshold: ΔE threshold (typical 5-30+; tune as needed)
    Returns: 255 if mean ΔE >= threshold else 0
    """
    def _to_lab(patch: np.ndarray) -> np.ndarray:
        # Normalise to float32 [0, 1] so cvtColor returns L in 0..100 and
        # unshifted a/b instead of the 8-bit encoding.
        if np.issubdtype(patch.dtype, np.integer):
            scaled = patch.astype(np.float32) / np.float32(np.iinfo(patch.dtype).max)
        else:
            scaled = patch.astype(np.float32)
        return cv2.cvtColor(scaled, cv2.COLOR_BGR2LAB)

    delta = _to_lab(orig_patch) - _to_lab(edit_patch)
    # Per-pixel Euclidean distance across the three Lab channels.
    deltaE = np.sqrt(np.sum(delta * delta, axis=2))
    mean_deltaE = float(np.mean(deltaE))
    return 255 if mean_deltaE >= float(threshold) else 0

def evaluate_patch_ycrcb_weighted_l1(orig_patch: np.ndarray, edit_patch: np.ndarray, threshold: float, y_weight: float = 0.6, cr_weight: float = 0.2, cb_weight: float = 0.2) -> int:
    """Classify a patch pair by a channel-weighted L1 distance in YCrCb.

    YCrCb splits luminance (Y) from chroma (Cr, Cb). Each pixel's score is the
    weighted sum of its absolute channel differences, and the patch score is
    the mean over all pixels. The default weights favour luminance while still
    reacting to colour changes.

    Args:
      orig_patch: (H,W,3) BGR
      edit_patch: (H,W,3) BGR
      threshold: weighted diff threshold (scale ~0-255; tune with patch_size)
      y_weight, cr_weight, cb_weight: channel weights that sum to ~1
    Returns: 255 if the patch score is >= threshold, else 0
    """
    channel_weights = np.array([y_weight, cr_weight, cb_weight], dtype=np.float32)
    orig_ycc = cv2.cvtColor(orig_patch, cv2.COLOR_BGR2YCrCb).astype(np.float32)
    edit_ycc = cv2.cvtColor(edit_patch, cv2.COLOR_BGR2YCrCb).astype(np.float32)
    # The weights broadcast over the trailing channel axis.
    weighted_abs = np.abs(orig_ycc - edit_ycc) * channel_weights
    patch_score = float(weighted_abs.sum(axis=2).mean())
    if patch_score >= float(threshold):
        return 255
    return 0

# # Default color-aware evaluate_patch: alias to Lab ΔE
# def evaluate_patch(orig_patch: np.ndarray, edit_patch: np.ndarray, threshold: float) -> int:
#     """Alias to color-aware method (Lab ΔE CIE76)."""
#     return evaluate_patch_lab_deltaE(orig_patch, edit_patch, threshold)

In [None]:
# Processing and display
import matplotlib.pyplot as plt
# Most recently generated mask; written by run_pipeline and read by the save handler.
last_mask = None
# Output widget that run_pipeline clears and renders its messages/figure into.
out_area = widgets.Output()
# Row displayed below the output; created empty here, its children are assigned in the save cell.
button_row = widgets.HBox([])

def build_patch_mask(orig: np.ndarray, edit: np.ndarray, psize: int, threshold: int, invert: bool, show_grid_flag: bool, evaluate_fn=None) -> np.ndarray:
    """Build a patch-wise change mask by comparing two images tile by tile.

    The images are tiled with non-overlapping ``psize`` x ``psize`` windows
    (tiles at the right/bottom edges are clipped to the image bounds). Each
    pair of tiles is classified by ``evaluate_fn`` and the whole tile in the
    mask is filled with the result.

    Args:
      orig: original image, indexed as (row, col, ...)
      edit: edited image; expected to have the same height/width as ``orig``
      psize: tile edge length in pixels (must be >= 1)
      threshold: forwarded unchanged to ``evaluate_fn``
      invert: when True, each classified tile value ``v`` becomes ``255 - v``
      show_grid_flag: when True, a 1-px outline (128 on white tiles, 64
        otherwise) is drawn inside every classified tile of the returned mask
      evaluate_fn: callable ``(orig_patch, edit_patch, threshold) -> 0 | 255``;
        defaults to ``evaluate_patch_ycrcb_weighted_l1``
    Returns: (H, W) uint8 mask
    Raises: ValueError if ``psize`` is smaller than 1
    """
    psize = int(psize)
    if psize < 1:
        # A negative step would make range() empty and silently yield an all-black mask.
        raise ValueError(f'psize must be a positive integer, got {psize}')
    if evaluate_fn is None:
        # Resolved at call time so the evaluator cell may be (re)run after this one.
        evaluate_fn = evaluate_patch_ycrcb_weighted_l1

    h, w = orig.shape[:2]
    mask = np.zeros((h, w), dtype=np.uint8)
    # iterate top-left corners stepping by psize
    for y in range(0, h, psize):
        y2 = min(y + psize, h)
        for x in range(0, w, psize):
            x2 = min(x + psize, w)
            orig_patch = orig[y:y2, x:x2]
            edit_patch = edit[y:y2, x:x2]
            # skip if patch shape mismatch; such tiles stay 0 even when invert is set
            if orig_patch.shape != edit_patch.shape:
                continue
            val = evaluate_fn(orig_patch, edit_patch, threshold)
            if invert:
                val = 255 - val
            mask[y:y2, x:x2] = val
            if show_grid_flag:
                # outline patches lightly (set border to mid-gray if mask white)
                cv2.rectangle(mask, (x, y), (x2 - 1, y2 - 1), 128 if val == 255 else 64, 1)
    return mask

def _overlay_patch_grid(mask: np.ndarray, psize: int) -> np.ndarray:
    """Return a copy of ``mask`` with a 1-px outline drawn inside every patch."""
    preview = mask.copy()
    h, w = preview.shape[:2]
    for y in range(0, h, psize):
        for x in range(0, w, psize):
            y2 = min(y + psize, h)
            x2 = min(x + psize, w)
            # Patches are filled uniformly, so the top-left pixel of the clean
            # mask gives the patch value (mid-gray outline on white patches).
            cv2.rectangle(preview, (x, y), (x2 - 1, y2 - 1), 128 if mask[y, x] == 255 else 64, 1)
    return preview

def run_pipeline(orig_name, edited_name, psize, threshold, invert_mask_val, show_grid_val, auto_resize_val):
    """Load the two selected images, build the patch mask and render the comparison.

    All messages and the figure are rendered into ``out_area``. The generated
    mask is cached in the module-level ``last_mask`` for the save handler; the
    cache always holds the clean mask, and the optional patch grid is drawn
    only on the displayed preview so it never ends up in a saved file.

    Args:
      orig_name, edited_name: file names inside INPUT_DIR (must differ)
      psize: patch size in pixels
      threshold: difference threshold forwarded to the patch evaluator
      invert_mask_val: invert the mask values
      show_grid_val: outline the patches in the displayed preview
      auto_resize_val: resize the edited image to the original's size when they differ
    Returns: the clean mask, or None when no mask could be produced
    """
    global last_mask
    out_area.clear_output(wait=True)
    with out_area:
        # Invalidate the cache first so every early exit leaves nothing stale to save.
        last_mask = None
        if not orig_name or not edited_name or orig_name == edited_name:
            print('Select two different images.')
            return None
        op = str(INPUT_DIR / orig_name)
        ep = str(INPUT_DIR / edited_name)
        orig = cv2.imread(op)
        edit = cv2.imread(ep)
        if orig is None or edit is None:
            print('Failed to read one or both images.')
            return None
        if orig.shape[:2] != edit.shape[:2]:
            if not auto_resize_val:
                print('Size mismatch; enable auto-resize.')
                return None
            # cv2.resize takes the target size as (width, height)
            edit = cv2.resize(edit, (orig.shape[1], orig.shape[0]), interpolation=cv2.INTER_LINEAR)
        mask = build_patch_mask(orig, edit, int(psize), int(threshold), invert_mask_val, False)
        last_mask = mask
        preview = _overlay_patch_grid(mask, int(psize)) if show_grid_val else mask
        # display side-by-side
        fig, axes = plt.subplots(1, 3, figsize=(12, 12)) # Stop changing this from 12 to 3!
        axes[0].imshow(cv2.cvtColor(orig, cv2.COLOR_BGR2RGB)); axes[0].set_title('Original'); axes[0].axis('off')
        axes[1].imshow(cv2.cvtColor(edit, cv2.COLOR_BGR2RGB)); axes[1].set_title('Edited'); axes[1].axis('off')
        # Fixed gray range: without it autoscaling would stretch an all-black
        # mask with grid lines (values 0 and 64) to full black/white.
        axes[2].imshow(preview, cmap='gray', vmin=0, vmax=255); axes[2].set_title('Patch Mask'); axes[2].axis('off')
        fig.tight_layout(); plt.show()
        # Release the figure so repeated widget updates do not accumulate open figures.
        plt.close(fig)
        return mask

# Bind every control to the matching run_pipeline argument.
ui = widgets.interactive_output(
    run_pipeline,
    dict(
        orig_name=orig_dropdown,
        edited_name=edited_dropdown,
        psize=patch_size,
        threshold=diff_threshold,
        invert_mask_val=invert_mask,
        show_grid_val=show_grid,
        auto_resize_val=auto_resize,
    ),
)
# Render, in order: the controls, the pipeline output, the button row and the
# output widget returned by interactive_output.
display(controls, out_area, button_row, ui)

VBox(children=(HBox(children=(Dropdown(description='Original:', options=('me.png', 'pjs_2.jpg', 'tryon_result.…

Output()

HBox()

Output()

In [9]:
# Save functionality (uses last_mask cache)
# HTML label that save_mask updates with the outcome of the last save attempt.
status_label = widgets.HTML(value='')
# Button whose click handler is attached further down in this cell.
save_btn = widgets.Button(description='Save mask', button_style='success', icon='save')

def save_mask(_=None):
    """Write the cached ``last_mask`` to OUTPUT_DIR as a timestamped PNG.

    Registered as the click handler of the save button; the single argument
    passed by the button is ignored. The outcome (nothing to save, failure or
    the saved file name) is reported through ``status_label``.
    """
    # last_mask is only read here, so no ``global`` declaration is required.
    if last_mask is None:
        status_label.value = '<em>No mask to save.</em>'
        return
    # Re-create the folder in case it was removed after the setup cell ran.
    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
    ts = datetime.now().strftime('%Y%m%d_%H%M%S')
    out_path = OUTPUT_DIR / f'patch_mask_{ts}.png'
    # cv2.imwrite reports an unsuccessful write through its boolean result;
    # ignoring it would announce "Saved" for a file that does not exist.
    if not cv2.imwrite(str(out_path), last_mask):
        status_label.value = f'<em>Failed to save: {out_path.name}</em>'
        return
    status_label.value = f'Saved: {out_path.name}'

# Run save_mask on every click of the save button.
save_btn.on_click(save_mask)
# Fill the row that was created empty in the processing cell.
button_row.children = [save_btn, status_label]
# Trailing expression: renders the row as this cell's output.
button_row

HBox(children=(Button(button_style='success', description='Save mask', icon='save', style=ButtonStyle()), HTML…

In [None]:
# Programmatic quick test (optional)
# Pre-select the first two images (original first, then edited) when available.
if len(image_files) >= 2:
    orig_dropdown.value, edited_dropdown.value = image_files[0], image_files[1]