 CSE483 Computer Vision - Jigsaw Puzzle Assembly Project : Milestone 1 + Milestone 2 Implementation

 This project implements a classical computer vision pipeline to:
1.   Preprocess puzzle images and extract individual pieces
2.   Extract edge descriptors using gradient analysis and color information
3.   Match puzzle pieces based on edge compatibility
4. Assemble puzzles using frontier-based optimization

 NO machine learning or deep learning techniques are used.
 Only classical CV techniques: CLAHE, Gaussian blur, Sobel gradients,
 color analysis, and optimization algorithms.


In [1]:

from google.colab import files
import zipfile
import os
import cv2
import numpy as np
import matplotlib.pyplot as plt
import time
import re
import heapq
from typing import List, Tuple, Dict, Optional
from dataclasses import dataclass
from skimage.metrics import structural_similarity as ssim


# ==============================================================================
# SECTION 1: DATA LOADING AND SETUP
# ==============================================================================

print("Upload your puzzle dataset zip files...")
uploaded = files.upload()

# Extract all uploaded zip files
os.makedirs("dataset", exist_ok=True)

for zip_name in uploaded.keys():
    print(f"Extracting: {zip_name}")
    folder_name = zip_name.replace(".zip", "")
    extract_path = os.path.join("dataset", folder_name)
    os.makedirs(extract_path, exist_ok=True)

    with zipfile.ZipFile(zip_name, 'r') as zip_ref:
        zip_ref.extractall(extract_path)

print("\nExtraction complete! Folders in dataset/:")
print(os.listdir("dataset"))


Upload your puzzle dataset zip files...


Saving correct.zip to correct.zip
Saving puzzle_2x2.zip to puzzle_2x2.zip
Saving puzzle_4x4.zip to puzzle_4x4.zip
Saving puzzle_8x8.zip to puzzle_8x8.zip
Extracting: correct.zip
Extracting: puzzle_2x2.zip
Extracting: puzzle_4x4.zip
Extracting: puzzle_8x8.zip

Extraction complete! Folders in dataset/:
['puzzle_8x8', 'puzzle_2x2', 'correct', 'puzzle_4x4']


In [2]:

def load_and_verify_image(path: str) -> np.ndarray:
    """
    Load an image and verify it exists.

    Args: Path to the image file

    Returns: Loaded image as numpy array, or None if failed
    """
    img = cv2.imread(path)
    if img is None:
        print(f"Error: Image not found at {path}")
        return None
    print(f"Loaded image: {path} | Shape: {img.shape}")
    return img
#======================================================================================================================================
def split_into_grid(img: np.ndarray, rows: int, cols: int) -> List[Tuple[np.ndarray, int, int]]:
    """
    Split a complete puzzle image into individual pieces based on grid layout.

    Args:
        img: Complete puzzle image
        rows: Number of rows in the grid
        cols: Number of columns in the grid

    Returns:
        List of tuples (piece_image, row_index, col_index)
    """
    h, w = img.shape[:2]
    cell_h = h // rows
    cell_w = w // cols
    pieces = []

    for r in range(rows):
        for c in range(cols):
            y1, y2 = r * cell_h, (r + 1) * cell_h
            x1, x2 = c * cell_w, (c + 1) * cell_w
            piece = img[y1:y2, x1:x2]
            pieces.append((piece, r, c))

    return pieces
#======================================================================================================================================
def enhance_piece(piece: np.ndarray) -> np.ndarray:
    """
    Enhance piece contrast using CLAHE (Contrast Limited Adaptive Histogram Equalization).
    This improves edge detection quality by normalizing local contrast.

    Args:
        piece: Input BGR image

    Returns:
        Enhanced grayscale image
    """
    gray = cv2.cvtColor(piece, cv2.COLOR_BGR2GRAY)
    clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    enhanced = clahe.apply(gray)
    return enhanced
#======================================================================================================================================
def remove_background(piece: np.ndarray) -> np.ndarray:
    """
    Generate binary mask to separate foreground (puzzle piece) from background.
    Uses adaptive thresholding to handle varying lighting conditions.

    Args:
        piece: Input BGR image

    Returns:
        Binary mask (255=foreground, 0=background)
    """
    gray = cv2.cvtColor(piece, cv2.COLOR_BGR2GRAY)
    blurred = cv2.GaussianBlur(gray, (5, 5), 0)

    # Adaptive thresholding accounts for local intensity variations
    binary = cv2.adaptiveThreshold(
        blurred,
        255,
        cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
        cv2.THRESH_BINARY_INV,
        35,
        5
    )

    # Invert to get white foreground
    mask = cv2.bitwise_not(binary)
    return mask

#======================================================================================================================================

def detect_edges(img: np.ndarray, mask: np.ndarray = None) -> np.ndarray:
    """
    Detect edges using Canny edge detector.
    Applies Gaussian blur first to reduce noise sensitivity.

    Args:
        img: Input grayscale image
        mask: Optional binary mask to limit edge detection region

    Returns:
        Binary edge map
    """
    if mask is not None:
        img = cv2.bitwise_and(img, img, mask=mask)

    blurred = cv2.GaussianBlur(img, (5, 5), 1.4)
    edges = cv2.Canny(blurred, threshold1=50, threshold2=150)

    return edges

#======================================================================================================================================

def extract_edge_profile(piece: np.ndarray, side: str, depth: int = 5) -> np.ndarray:
    """
    Extract gradient magnitude profile along one edge of the piece.
    This captures the shape characteristics of the edge for matching.

    Args:
        piece: Input BGR image
        side: Which edge to extract ('top', 'bottom', 'left', 'right')
        depth: How many pixels deep to sample from edge

    Returns:
        1D array of gradient magnitudes along the edge
    """
    gray = cv2.cvtColor(piece, cv2.COLOR_BGR2GRAY) if len(piece.shape) == 3 else piece
    h, w = gray.shape

    # Extract edge region
    if side == 'top':
        edge_region = gray[:depth, :]
    elif side == 'bottom':
        edge_region = gray[-depth:, :]
    elif side == 'left':
        edge_region = gray[:, :depth]
    elif side == 'right':
        edge_region = gray[:, -depth:]

    # Compute gradient magnitude using Sobel operators
    gradient_x = cv2.Sobel(edge_region, cv2.CV_64F, 1, 0, ksize=3)
    gradient_y = cv2.Sobel(edge_region, cv2.CV_64F, 0, 1, ksize=3)
    gradient_mag = np.sqrt(gradient_x**2 + gradient_y**2)

    # Average across depth to get 1D profile
    if side in ['top', 'bottom']:
        profile = np.mean(gradient_mag, axis=0)
    else:
        profile = np.mean(gradient_mag, axis=1)

    return profile

#======================================================================================================================================

def process_single_puzzle(img_path: str, rows: int, cols: int, output_dir: str):
    """
    Complete MS1 pipeline for a single puzzle image.

    Steps:
    1. Split into grid pieces
    2. Enhance each piece with CLAHE
    3. Generate background mask
    4. Detect edges
    5. Extract edge profiles for MS2

    Args:
        img_path: Path to complete puzzle image
        rows: Number of rows in grid
        cols: Number of columns in grid
        output_dir: Directory to save processed outputs
    """
    start_time = time.time()

    img = cv2.imread(img_path)
    if img is None:
        print(f"ERROR: Could not load {img_path}")
        return

    img_name = os.path.splitext(os.path.basename(img_path))[0]

    # Create output folder structure
    folders = ['enhanced', 'binary_masks', 'edges', 'pieces', 'edge_profiles']
    for folder in folders:
        os.makedirs(os.path.join(output_dir, folder), exist_ok=True)

    print(f"\nProcessing: {img_name} ({rows}x{cols} grid)")

    # Step 1: Split into pieces
    pieces = split_into_grid(img, rows, cols)
    print(f"  Step 1: Split into {len(pieces)} pieces")

    # Step 2-5: Process each piece
    for piece_img, r, c in pieces:
        piece_name = f"{img_name}_r{r}c{c}"

        # Save original piece
        cv2.imwrite(os.path.join(output_dir, 'pieces', f'{piece_name}.png'), piece_img)

        # Enhancement
        enhanced = enhance_piece(piece_img)
        cv2.imwrite(os.path.join(output_dir, 'enhanced', f'{piece_name}_enhanced.png'), enhanced)

        # Background removal
        binary_mask = remove_background(piece_img)
        cv2.imwrite(os.path.join(output_dir, 'binary_masks', f'{piece_name}_mask.png'), binary_mask)

        # Edge detection
        edges = detect_edges(enhanced, mask=binary_mask)
        cv2.imwrite(os.path.join(output_dir, 'edges', f'{piece_name}_edges.png'), edges)

        # Extract edge profiles
        profiles = {}
        for side in ['top', 'bottom', 'left', 'right']:
            profiles[side] = extract_edge_profile(piece_img, side, depth=5)

        # Save profiles as numpy compressed file
        np.savez(os.path.join(output_dir, 'edge_profiles', f'{piece_name}_profiles.npz'),
                 top=profiles['top'],
                 bottom=profiles['bottom'],
                 left=profiles['left'],
                 right=profiles['right'])

    elapsed = time.time() - start_time
    print(f"Completed in {elapsed:.2f}s - Processed {len(pieces)} pieces")

**MS1 EXECUTION**

In [3]:

print("\n" + "="*70)
print("RUNNING MILESTONE 1 - IMAGE PREPROCESSING PIPELINE")
print("="*70)

OUTPUT_BASE = "processed_output"
os.makedirs(OUTPUT_BASE, exist_ok=True)

# Process all puzzle types
puzzle_configs = [
    ('dataset/puzzle_2x2/puzzle_2x2', 2, 2, 'puzzle_2x2'),
    ('dataset/puzzle_4x4/puzzle_4x4', 4, 4, 'puzzle_4x4'),
    ('dataset/puzzle_8x8/puzzle_8x8', 8, 8, 'puzzle_8x8'),
]

for dataset_path, rows, cols, puzzle_type in puzzle_configs:
    if not os.path.exists(dataset_path):
        print(f"\nSkipping {puzzle_type}: folder not found")
        continue

    print(f"\n{'='*70}")
    print(f"Processing {puzzle_type} puzzles...")
    print(f"{'='*70}")

    output_dir = os.path.join(OUTPUT_BASE, puzzle_type)

    image_files = [f for f in os.listdir(dataset_path)
                   if f.lower().endswith(('.png', '.jpg', '.jpeg'))]

    for img_file in sorted(image_files):
        img_path = os.path.join(dataset_path, img_file)
        process_single_puzzle(img_path, rows, cols, output_dir)

print("\n" + "="*70)
print("MILESTONE 1 COMPLETE")
print("="*70)
print(f"Output saved to: {OUTPUT_BASE}/")
print("\nFolder structure:")
print("  pieces/         : Original puzzle pieces")
print("  enhanced/       : CLAHE enhanced pieces")
print("  binary_masks/   : Background removal masks")
print("  edges/          : Edge detection results")
print("  edge_profiles/  : Gradient data for MS2")


RUNNING MILESTONE 1 - IMAGE PREPROCESSING PIPELINE

Processing puzzle_2x2 puzzles...

Processing: 0 (2x2 grid)
  Step 1: Split into 4 pieces
Completed in 0.16s - Processed 4 pieces

Processing: 1 (2x2 grid)
  Step 1: Split into 4 pieces
Completed in 0.03s - Processed 4 pieces

Processing: 10 (2x2 grid)
  Step 1: Split into 4 pieces
Completed in 0.03s - Processed 4 pieces

Processing: 100 (2x2 grid)
  Step 1: Split into 4 pieces
Completed in 0.02s - Processed 4 pieces

Processing: 101 (2x2 grid)
  Step 1: Split into 4 pieces
Completed in 0.02s - Processed 4 pieces

Processing: 102 (2x2 grid)
  Step 1: Split into 4 pieces
Completed in 0.02s - Processed 4 pieces

Processing: 103 (2x2 grid)
  Step 1: Split into 4 pieces
Completed in 0.02s - Processed 4 pieces

Processing: 104 (2x2 grid)
  Step 1: Split into 4 pieces
Completed in 0.02s - Processed 4 pieces

Processing: 105 (2x2 grid)
  Step 1: Split into 4 pieces
Completed in 0.02s - Processed 4 pieces

Processing: 106 (2x2 grid)
  Step 1:

**MILESTONE 2 - PUZZLE ASSEMBLY**

**Firstly : the 2x2 solver path**

In [5]:
"""
==================== MS2 â€“ Enhanced Puzzle Solver ====================

Project alignment notes:
- Classical computer vision only (no machine learning / no deep learning).
- Uses edge descriptors (color profiles, Sobel magnitude/direction, small edge regions).
- Uses a beam-search style assembly using pairwise compatibility scores.
- Visualization compares assembled grid with the provided ground-truth image (if available).

Libraries:
- OpenCV (cv2), NumPy, Matplotlib are typical course-lab libraries in CV courses.
- Removed skimage dependency; SSIM is implemented using NumPy only.
"""

import os
import re
import cv2
import numpy as np
import matplotlib.pyplot as plt
from dataclasses import dataclass
from typing import Dict, List, Tuple, Optional


# ======================================================================
# NUMPY-ONLY SSIM (classical metric; not ML)
# ======================================================================

def ssim_gray(a: np.ndarray, b: np.ndarray) -> float:
    """
    Simplified SSIM for grayscale images using global statistics.
    This avoids external libraries while keeping an SSIM-like similarity measure.

    a, b: grayscale arrays with same shape (float32 recommended).
    Returns: SSIM value roughly in [-1, 1], typically [0, 1] for similar patches.
    """
    a = a.astype(np.float32)
    b = b.astype(np.float32)

    if a.shape != b.shape:
        b = cv2.resize(b, (a.shape[1], a.shape[0]), interpolation=cv2.INTER_AREA)

    mu_a = a.mean()
    mu_b = b.mean()

    var_a = ((a - mu_a) ** 2).mean()
    var_b = ((b - mu_b) ** 2).mean()
    cov_ab = ((a - mu_a) * (b - mu_b)).mean()

    # Stabilizers (for 8-bit images; constants follow common SSIM choices)
    C1 = (0.01 * 255) ** 2
    C2 = (0.03 * 255) ** 2

    num = (2 * mu_a * mu_b + C1) * (2 * cov_ab + C2)
    den = (mu_a ** 2 + mu_b ** 2 + C1) * (var_a + var_b + C2)

    if den < 1e-12:
        return 0.0
    return float(num / den)


# ======================================================================
# DATA STRUCTURE
# ======================================================================

@dataclass
class PuzzlePiece:
    name: str
    puzzle_id: str
    true_row: int
    true_col: int
    image: np.ndarray

    # Edge descriptors (computed from a normalized version of the piece)
    edge_colors: Dict[str, np.ndarray]
    edge_gradients: Dict[str, np.ndarray]
    edge_directions: Dict[str, np.ndarray]
    edge_regions: Dict[str, np.ndarray]

    placed_row: Optional[int] = None
    placed_col: Optional[int] = None


# ======================================================================
# EDGE DESCRIPTORS
# ======================================================================

def extract_edge_color(img: np.ndarray, side: str, depth: int = 3):
    """
    Returns a 1D edge color profile (mean across thickness).
    Works on 3-channel images.
    """
    if side == "top":
        region = img[:depth, :]
    elif side == "bottom":
        region = img[-depth:, :]
    elif side == "left":
        region = img[:, :depth]
    else:
        region = img[:, -depth:]

    # For horizontal edges, average over rows -> vector across columns.
    if side in ("top", "bottom"):
        return np.mean(region, axis=0)
    # For vertical edges, average over cols -> vector across rows.
    return np.mean(region, axis=1)


def extract_sobel_descriptors(img: np.ndarray, side: str, depth: int = 5):
    """
    Returns:
      - Sobel magnitude 1D profile along the edge
      - Sobel direction 1D profile along the edge
    """
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) if img.ndim == 3 else img

    if side == "top":
        region = gray[:depth, :]
    elif side == "bottom":
        region = gray[-depth:, :]
    elif side == "left":
        region = gray[:, :depth]
    else:
        region = gray[:, -depth:]

    sobel_x = cv2.Sobel(region, cv2.CV_64F, 1, 0, ksize=3)
    sobel_y = cv2.Sobel(region, cv2.CV_64F, 0, 1, ksize=3)

    mag = np.sqrt(sobel_x**2 + sobel_y**2)
    direction = np.arctan2(sobel_y, sobel_x)

    if side in ("top", "bottom"):
        return np.mean(mag, axis=0), np.mean(direction, axis=0)
    return np.mean(mag, axis=1), np.mean(direction, axis=1)


def extract_edge_region(img: np.ndarray, side: str, depth: int = 10):
    """
    Extract a thin strip region along the edge for local intensity/texture checks.
    """
    if side == "top":
        region = img[:depth, :]
    elif side == "bottom":
        region = img[-depth:, :]
    elif side == "left":
        region = img[:, :depth]
    else:
        region = img[:, -depth:]

    # Pad if extremely small (robustness)
    region = cv2.copyMakeBorder(
        region,
        0, max(0, 3 - region.shape[0]),
        0, max(0, 3 - region.shape[1]),
        cv2.BORDER_CONSTANT, value=0
    )
    return region


# ======================================================================
# LOADING PIECES FROM MS1 OUTPUT
# ======================================================================

# Robust filename parsing: find "<pid>_r<row>c<col>" anywhere in the filename.
# This aligns with common MS1 outputs like: "11_r0c0.png"
NAME_RE = re.compile(r"(\d+)_r(\d+)c(\d+)", re.IGNORECASE)

def load_pieces(ms1_folder: str) -> Dict[str, List[PuzzlePiece]]:
    """
    Load pieces from MS1 output folder and compute edge descriptors.

    Logic:
    - Use the original piece image for reconstruction/visualization.
    - Use a normalized grayscale->BGR version for descriptor extraction:
      Gaussian blur + CLAHE helps stabilize gradients under illumination changes.
    """
    pieces_path = os.path.join(ms1_folder, "pieces")
    if not os.path.exists(pieces_path):
        print(f"MS1 pieces folder not found: {pieces_path}")
        return {}

    puzzles: Dict[str, List[PuzzlePiece]] = {}
    files = sorted(os.listdir(pieces_path))

    for f in files:
        if not f.lower().endswith((".png", ".jpg", ".jpeg")):
            continue

        m = NAME_RE.search(f)
        if not m:
            print(f"Skipping file with unknown format: {f}")
            continue

        pid = m.group(1)
        row = int(m.group(2))
        col = int(m.group(3))

        img_path = os.path.join(pieces_path, f)
        img = cv2.imread(img_path)
        if img is None:
            print(f"Failed to read image: {img_path}")
            continue

        # Normalize for descriptor extraction (classical preprocessing)
        gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
        gray = cv2.GaussianBlur(gray, (3, 3), 0)
        clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
        norm_gray = clahe.apply(gray)
        norm_bgr = cv2.cvtColor(norm_gray, cv2.COLOR_GRAY2BGR)

        edges_c: Dict[str, np.ndarray] = {}
        edges_m: Dict[str, np.ndarray] = {}
        edges_d: Dict[str, np.ndarray] = {}
        edges_r: Dict[str, np.ndarray] = {}

        for side in ("top", "bottom", "left", "right"):
            edges_c[side] = extract_edge_color(norm_bgr, side)
            mag, direc = extract_sobel_descriptors(norm_bgr, side)
            edges_m[side] = mag
            edges_d[side] = direc
            edges_r[side] = extract_edge_region(norm_bgr, side)

        core = f.rsplit(".", 1)[0]
        piece_obj = PuzzlePiece(
            name=core,
            puzzle_id=pid,
            true_row=row,
            true_col=col,
            image=img,
            edge_colors=edges_c,
            edge_gradients=edges_m,
            edge_directions=edges_d,
            edge_regions=edges_r
        )

        puzzles.setdefault(pid, []).append(piece_obj)

    print(f"Loaded {len(puzzles)} puzzles from {pieces_path}")
    return puzzles


# ======================================================================
# MATCH SCORE (pairwise edge compatibility)
# ======================================================================

def compute_match_score(A: PuzzlePiece, B: PuzzlePiece, eA: str, eB: str) -> float:
    """
    Higher score = better match.

    Components:
    1) Sobel magnitude complement similarity (dominant term).
    2) Color normalized cross-correlation along edge profile (helps when texture is weak).
    3) Local region intensity MSE (penalize mismatched intensity).
    4) SSIM-like similarity on local region (NumPy-only implementation).

    Note:
    - This is classical scoring; no learned models.
    """
    def smooth1d(v, k=5):
        v = np.array(v, dtype=np.float32)
        if len(v) < 3:
            return v
        kernel = np.ones(k, dtype=np.float32) / k
        return np.convolve(v, kernel, mode="same")

    def normalize1(v):
        v = np.array(v, dtype=np.float32)
        mn, mx = v.min(), v.max()
        if mx == mn:
            return np.zeros_like(v)
        return (v - mn) / (mx - mn)

    # Sobel magnitude profiles
    magA = np.array(A.edge_gradients[eA], dtype=np.float32)
    magB = np.array(B.edge_gradients[eB], dtype=np.float32)

    magA = smooth1d(magA, k=5)
    magB = smooth1d(magB, k=5)
    magA_n = normalize1(magA)
    magB_n = normalize1(magB)

    L = min(len(magA_n), len(magB_n))
    if L < 3:
        return -1e9

    # Complement matching: A should align with inverted B profile
    sobel_err = np.mean((magA_n[:L] - (1.0 - magB_n[:L])) ** 2)
    sobel_score = -sobel_err

    # Color NCC on edge profile
    color_score = 0.0
    try:
        colA = np.array(A.edge_colors[eA], dtype=np.float32)
        colB = np.array(B.edge_colors[eB], dtype=np.float32)

        if colA.ndim == 2:
            vA = np.mean(colA, axis=1) if colA.shape[0] > colA.shape[1] else np.mean(colA, axis=0)
        else:
            vA = np.ravel(colA)

        if colB.ndim == 2:
            vB = np.mean(colB, axis=1) if colB.shape[0] > colB.shape[1] else np.mean(colB, axis=0)
        else:
            vB = np.ravel(colB)

        Lc = min(len(vA), len(vB))
        if Lc >= 3:
            vA = vA[:Lc] - vA[:Lc].mean()
            vB = vB[:Lc] - vB[:Lc].mean()
            denom = (np.sqrt((vA**2).sum()) * np.sqrt((vB**2).sum()))
            color_score = float((vA * vB).sum() / denom) if denom > 1e-6 else 0.0
    except Exception:
        color_score = 0.0

    # Intensity MSE + SSIM-like similarity on edge regions
    rA = A.edge_regions[eA]
    rB = B.edge_regions[eB]

    gA = cv2.cvtColor(rA, cv2.COLOR_BGR2GRAY) if rA.ndim == 3 else rA
    gB = cv2.cvtColor(rB, cv2.COLOR_BGR2GRAY) if rB.ndim == 3 else rB

    H = min(gA.shape[0], gB.shape[0])
    W = min(gA.shape[1], gB.shape[1])
    if H < 3 or W < 3:
        intensity_score = -1e9
        ssim_score = -1.0
    else:
        a = gA[:H, :W].astype(np.float32)
        b = gB[:H, :W].astype(np.float32)
        mse = np.mean((a - b) ** 2)
        intensity_score = -mse
        ssim_score = ssim_gray(a, b)

    # Weighted sum (tuned empirically)
    total = (sobel_score * 3.5) + (color_score * 0.9) + (ssim_score * 1.6) + (intensity_score * 0.8)
    return float(total)


# ======================================================================
# BUILD SCORE TABLES
# ======================================================================

def build_match_matrix(pieces: List[PuzzlePiece]):
    """
    Build:
    - scores[(i, j, edgeA, edgeB)] = compatibility score (directional)
    - mutual[(i, j)] = best symmetric agreement across edge pairs
    """
    scores = {}
    opposite = {"top": "bottom", "bottom": "top", "left": "right", "right": "left"}
    N = len(pieces)
    total = N * (N - 1) * 4
    count = 0
    print(f"Computing {total} edge comparisons...")

    for i in range(N):
        for j in range(N):
            if i == j:
                continue
            for edgeA in opposite:
                edgeB = opposite[edgeA]
                scores[(i, j, edgeA, edgeB)] = compute_match_score(pieces[i], pieces[j], edgeA, edgeB)
                count += 1
                if count % 500 == 0:
                    print(f"Progress: {count}/{total}", end="\r")

    # Mutual agreement table
    mutual = {}
    for i in range(N):
        for j in range(N):
            if i == j:
                continue
            best = -1e9
            for edgeA in opposite:
                edgeB = opposite[edgeA]
                s1 = scores.get((i, j, edgeA, edgeB), -1e9)
                s2 = scores.get((j, i, edgeB, edgeA), -1e9)
                best = max(best, s1 + s2)
            mutual[(i, j)] = best

    print("\nEdge comparisons complete, mutual table built.")
    return scores, mutual


# ======================================================================
# BEAM SEARCH ASSEMBLY
# ======================================================================

def assemble_beam(pieces, size, beam_width=8):
    """
    Beam search assembly:
    - Places pieces left-to-right, top-to-bottom.
    - Keeps top-N partial grids by score.
    - Uses neighbor scores + small mutual-pair bonus.
    """
    R, C = size
    N = len(pieces)
    scores, mutual = build_match_matrix(pieces)

    directions = {
        (-1, 0): ("bottom", "top"),    # neighbor above -> candidate top
        (1, 0): ("top", "bottom"),     # neighbor below -> candidate bottom
        (0, -1): ("right", "left"),    # neighbor left -> candidate left
        (0, 1): ("left", "right")      # neighbor right -> candidate right
    }

    positions = [(r, c) for r in range(R) for c in range(C)]

    # Seed with pieces that have strong mutual best partner (heuristic)
    avg_mutual = []
    for i in range(N):
        vals = [mutual.get((i, j), -1e9) for j in range(N) if j != i]
        best_partner_strength = np.max(vals) if vals else -1e9
        avg_mutual.append((best_partner_strength, i))
    avg_mutual.sort(reverse=True)

    seed_ids = [t[1] for t in avg_mutual[:min(len(avg_mutual), max(3, beam_width))]]

    beams = []
    for seed in seed_ids:
        grid = [[None] * C for _ in range(R)]
        grid[0][0] = seed
        beams.append((grid, {seed}, 0.0))

    if not beams:
        for idx in range(min(N, beam_width)):
            grid = [[None] * C for _ in range(R)]
            grid[0][0] = idx
            beams.append((grid, {idx}, 0.0))

    # Expand remaining positions
    for pos_i in range(1, len(positions)):
        r, c = positions[pos_i]
        new_beams = []

        for grid, used, base_score in beams:
            neighbors = []
            for (dr, dc), (edgeN, edgeC) in directions.items():
                nr, nc = r + dr, c + dc
                if 0 <= nr < R and 0 <= nc < C and grid[nr][nc] is not None:
                    neighbors.append((nr, nc, dr, dc))

            for cand in range(N):
                if cand in used:
                    continue

                inc = 0.0
                mutual_boost = 0.0

                for nr, nc, dr, dc in neighbors:
                    neigh = grid[nr][nc]
                    eN, eC = directions[(dr, dc)]
                    inc += scores.get((neigh, cand, eN, eC), -1e9)
                    mutual_boost += mutual.get((neigh, cand), -1e9) * 0.12

                ng = [row[:] for row in grid]
                ng[r][c] = cand
                total_score = base_score + inc + mutual_boost
                new_beams.append((ng, used | {cand}, total_score))

        if not new_beams:
            break

        new_beams.sort(key=lambda x: x[2], reverse=True)
        beams = new_beams[:beam_width]

    if not beams:
        raise RuntimeError("Beam search failed to produce any solution.")
    return beams[0][0]


# ======================================================================
# ACCURACY (requires correct true_row/true_col parsing from filenames)
# ======================================================================

def calc_accuracy(pieces, grid):
    R, C = len(grid), len(grid[0])
    total = R * C
    correct = 0

    for r in range(R):
        for c in range(C):
            p = pieces[grid[r][c]]
            if p.true_row == r and p.true_col == c:
                correct += 1

    return correct / total * 100


# ======================================================================
# RECONSTRUCT FULL IMAGE FROM GRID
# ======================================================================

def reconstruct(pieces, grid):
    R, C = len(grid), len(grid[0])
    h, w = pieces[0].image.shape[:2]
    canvas = np.zeros((R * h, C * w, 3), dtype=np.uint8)

    for r in range(R):
        for c in range(C):
            canvas[r*h:(r+1)*h, c*w:(c+1)*w] = pieces[grid[r][c]].image

    return canvas


# ======================================================================
# VISUALIZATION (ASSEMBLED VS GROUND TRUTH)
# ======================================================================

def visualize(pieces, grid, pid, out_dir, size):
    """
    Visualization:
    - left: assembled image
    - middle: ground-truth (full)
    - right: ground-truth cells with colored border based on SSIM similarity to assembled cells

    This uses NumPy-only SSIM (ssim_gray).
    """
    R, C = size
    assembled = reconstruct(pieces, grid)
    H, W = assembled.shape[:2]
    cell_h = H // R
    cell_w = W // C

    correct_root = "dataset/correct/correct"
    possible = [f"{pid}.png", f"{pid}.jpg", f"{pid}.jpeg", f"{pid}.PNG", f"{pid}.JPG", f"{pid}.JPEG"]

    gt = None
    for filename in possible:
        full_path = os.path.join(correct_root, filename)
        if os.path.exists(full_path):
            gt = cv2.imread(full_path)
            break

    if gt is None:
        print(f"Ground truth not found for {pid} in {correct_root}")
        gt = np.zeros_like(assembled)

    if gt.shape[:2] != assembled.shape[:2]:
        gt = cv2.resize(gt, (assembled.shape[1], assembled.shape[0]), interpolation=cv2.INTER_AREA)

    comparison_canvas = np.zeros_like(assembled)
    SSIM_THR = 0.60
    correct_cells = 0
    total_cells = R * C

    for r in range(R):
        for c in range(C):
            y1, y2 = r * cell_h, (r + 1) * cell_h
            x1, x2 = c * cell_w, (c + 1) * cell_w

            assembled_patch = assembled[y1:y2, x1:x2]
            gt_patch = gt[y1:y2, x1:x2]

            a_gray = cv2.cvtColor(assembled_patch, cv2.COLOR_BGR2GRAY)
            g_gray = cv2.cvtColor(gt_patch, cv2.COLOR_BGR2GRAY)

            # Inner crop reduces border artifacts
            ih = max(1, a_gray.shape[0] // 10)
            iw = max(1, a_gray.shape[1] // 10)
            a_in = a_gray[ih:-ih or None, iw:-iw or None]
            g_in = g_gray[ih:-ih or None, iw:-iw or None]
            if a_in.size == 0 or g_in.size == 0:
                a_in, g_in = a_gray, g_gray

            s = ssim_gray(a_in, g_in)

            is_correct = (s >= SSIM_THR)
            if is_correct:
                correct_cells += 1

            color = (0, 255, 0) if is_correct else (0, 0, 255)
            patch_vis = gt_patch.copy()
            cv2.rectangle(patch_vis, (0, 0), (patch_vis.shape[1]-1, patch_vis.shape[0]-1), color, thickness=8)

            txt = f"{s:.2f}"
            cv2.putText(patch_vis, txt, (8, patch_vis.shape[0]-12), cv2.FONT_HERSHEY_SIMPLEX,
                        0.7, (255, 255, 255), 2, cv2.LINE_AA)

            comparison_canvas[y1:y2, x1:x2] = patch_vis

    acc_pct = correct_cells / total_cells * 100

    fig, ax = plt.subplots(1, 3, figsize=(22, 7))

    ax[0].imshow(cv2.cvtColor(assembled, cv2.COLOR_BGR2RGB))
    ax[0].set_title("Our Assembly")
    ax[0].axis("off")

    ax[1].imshow(cv2.cvtColor(gt, cv2.COLOR_BGR2RGB))
    ax[1].set_title("Ground Truth (full image)")
    ax[1].axis("off")

    ax[2].imshow(cv2.cvtColor(comparison_canvas, cv2.COLOR_BGR2RGB))
    ax[2].set_title(f"Correctness = {acc_pct:.1f}% (visual SSIM)")
    ax[2].axis("off")

    os.makedirs(out_dir, exist_ok=True)
    save_path = os.path.join(out_dir, f"{pid}_result.png")
    plt.savefig(save_path, dpi=160, bbox_inches="tight")
    plt.close()

    print(f"Saved: {save_path}")


# ======================================================================
# FULL PIPELINE (MS2 ONLY)
# ======================================================================

def run_ms2(ms1_folder, out_folder, size, beam_width=10):
    puzzles = load_pieces(ms1_folder)

    for pid, pieces in puzzles.items():
        print(f"\nSolving puzzle {pid} ...")

        grid = assemble_beam(pieces, size, beam_width)

        visualize(pieces, grid, pid, out_folder, size)

    print("\nMS2 Complete.")


# ======================================================================
# RUN MS2 ON ALL PUZZLE SIZES (you can add 4x4 and 8x8 similarly)
# ======================================================================

if __name__ == "__main__":
    print("\n" + "=" * 80)
    print("RUNNING MS2 ON ALL PUZZLE SIZES")
    print("=" * 80)

    run_ms2(
        ms1_folder="processed_output/puzzle_2x2",
        out_folder="ms2_results_puzzle_2x2",
        size=(2, 2),
        beam_width=12
    )



RUNNING MS2 ON ALL PUZZLE SIZES
Loaded 110 puzzles from processed_output/puzzle_2x2/pieces

Solving puzzle 0 ...
Computing 48 edge comparisons...

Edge comparisons complete, mutual table built.
Saved: ms2_results_puzzle_2x2/0_result.png

Solving puzzle 100 ...
Computing 48 edge comparisons...

Edge comparisons complete, mutual table built.
Saved: ms2_results_puzzle_2x2/100_result.png

Solving puzzle 101 ...
Computing 48 edge comparisons...

Edge comparisons complete, mutual table built.
Saved: ms2_results_puzzle_2x2/101_result.png

Solving puzzle 102 ...
Computing 48 edge comparisons...

Edge comparisons complete, mutual table built.
Saved: ms2_results_puzzle_2x2/102_result.png

Solving puzzle 103 ...
Computing 48 edge comparisons...

Edge comparisons complete, mutual table built.
Saved: ms2_results_puzzle_2x2/103_result.png

Solving puzzle 104 ...
Computing 48 edge comparisons...

Edge comparisons complete, mutual table built.
Saved: ms2_results_puzzle_2x2/104_result.png

Solving puz

**2nd : 4x4 Solver**

In [7]:
import os
import re
import time
from dataclasses import dataclass
from typing import Dict, List, Tuple, Optional

import cv2
import numpy as np
import matplotlib.pyplot as plt


# ============================================================
# NUMPY-ONLY SSIM (CLASSICAL METRIC, NOT ML)
# ============================================================

def ssim_gray(a: np.ndarray, b: np.ndarray) -> float:
    """
    Simplified SSIM for grayscale images using global statistics.
    This avoids external libraries (e.g., skimage) while providing a robust similarity metric.

    a, b: grayscale arrays (any numeric type). If shapes differ, b is resized to a.
    Returns: SSIM-like score, typically in [0, 1] for similar patches.
    """
    a = a.astype(np.float32)
    b = b.astype(np.float32)

    if a.shape != b.shape:
        b = cv2.resize(b, (a.shape[1], a.shape[0]), interpolation=cv2.INTER_AREA)

    mu_a = a.mean()
    mu_b = b.mean()
    var_a = ((a - mu_a) ** 2).mean()
    var_b = ((b - mu_b) ** 2).mean()
    cov_ab = ((a - mu_a) * (b - mu_b)).mean()

    # SSIM stabilizers for 8-bit images
    C1 = (0.01 * 255) ** 2
    C2 = (0.03 * 255) ** 2

    num = (2 * mu_a * mu_b + C1) * (2 * cov_ab + C2)
    den = (mu_a ** 2 + mu_b ** 2 + C1) * (var_a + var_b + C2)

    if den < 1e-12:
        return 0.0
    return float(num / den)


# ============================================================
# DATA STRUCTURE
# ============================================================

@dataclass
class PuzzlePiece:
    name: str
    puzzle_id: str
    true_row: int
    true_col: int
    image: np.ndarray

    # Edge descriptors (computed from normalized version of the piece)
    edge_colors: Dict[str, np.ndarray]
    edge_gradients: Dict[str, np.ndarray]
    edge_directions: Dict[str, np.ndarray]
    edge_regions: Dict[str, np.ndarray]

    placed_row: Optional[int] = None
    placed_col: Optional[int] = None


# ============================================================
# EDGE DESCRIPTORS
# ============================================================

def extract_edge_color(img: np.ndarray, side: str, depth: int = 4):
    """
    Returns a 1D color profile along the requested edge.
    For top/bottom: mean over thickness -> vector across width.
    For left/right: mean over thickness -> vector across height.
    """
    if side == "top":
        region = img[:depth, :]
    elif side == "bottom":
        region = img[-depth:, :]
    elif side == "left":
        region = img[:, :depth]
    else:
        region = img[:, -depth:]

    if side in ("top", "bottom"):
        return np.mean(region, axis=0)   # (w,3)
    return np.mean(region, axis=1)       # (h,3)


def extract_sobel_descriptors(img: np.ndarray, side: str, depth: int = 6):
    """
    Computes Sobel magnitude and direction profiles near an edge.
    Outputs 1D profiles aligned with the edge length.
    """
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) if img.ndim == 3 else img

    if side == "top":
        region = gray[:depth, :]
    elif side == "bottom":
        region = gray[-depth:, :]
    elif side == "left":
        region = gray[:, :depth]
    else:
        region = gray[:, -depth:]

    sobel_x = cv2.Sobel(region, cv2.CV_64F, 1, 0, ksize=3)
    sobel_y = cv2.Sobel(region, cv2.CV_64F, 0, 1, ksize=3)

    mag = np.sqrt(sobel_x**2 + sobel_y**2)
    direction = np.arctan2(sobel_y, sobel_x)

    if side in ("top", "bottom"):
        return np.mean(mag, axis=0), np.mean(direction, axis=0)
    return np.mean(mag, axis=1), np.mean(direction, axis=1)


def extract_edge_region(img: np.ndarray, side: str, depth: int = 14):
    """
    Extracts a narrow strip along the edge for local similarity checks.
    Uses reflect padding to avoid boundary issues for very small pieces.
    """
    if side == "top":
        region = img[:depth, :]
    elif side == "bottom":
        region = img[-depth:, :]
    elif side == "left":
        region = img[:, :depth]
    else:
        region = img[:, -depth:]

    rh, rw = region.shape[:2]
    region = cv2.copyMakeBorder(
        region,
        0, max(0, 4 - rh),
        0, max(0, 4 - rw),
        cv2.BORDER_REFLECT_101,
    )
    return region


# ============================================================
# LOAD PIECES FROM MS1 (4x4 ONLY)
# ============================================================

NAME_RE = re.compile(r"(\d+)_r(\d+)c(\d+)", re.IGNORECASE)

def load_pieces(ms1_folder: str) -> Dict[str, List[PuzzlePiece]]:
    """
    Expects filenames like:  11_r0c0.png  in  ms1_folder/pieces
    Uses normalized image (blur + CLAHE) for descriptor extraction,
    but stores original BGR image for reconstruction.
    """
    pieces_path = os.path.join(ms1_folder, "pieces")
    if not os.path.exists(pieces_path):
        print(f"MS1 pieces folder not found: {pieces_path}")
        return {}

    puzzles: Dict[str, List[PuzzlePiece]] = {}
    files = sorted(
        f for f in os.listdir(pieces_path)
        if f.lower().endswith((".png", ".jpg", ".jpeg"))
    )

    print(f"Found {len(files)} pieces in {pieces_path}")

    for f in files:
        m = NAME_RE.search(f)
        if not m:
            print(f"Skipping (bad name format): {f}")
            continue

        puzzle_id = m.group(1)
        row = int(m.group(2))
        col = int(m.group(3))

        img_path = os.path.join(pieces_path, f)
        img = cv2.imread(img_path)
        if img is None:
            print(f"Failed to read: {img_path}")
            continue

        # Normalization for descriptor stability (classical preprocessing)
        gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
        gray = cv2.GaussianBlur(gray, (3, 3), 0)
        clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
        norm_gray = clahe.apply(gray)
        norm_bgr = cv2.cvtColor(norm_gray, cv2.COLOR_GRAY2BGR)

        edge_colors = {}
        edge_gradients = {}
        edge_directions = {}
        edge_regions = {}

        for side in ("top", "bottom", "left", "right"):
            edge_colors[side] = extract_edge_color(norm_bgr, side)
            mag, direc = extract_sobel_descriptors(norm_bgr, side)
            edge_gradients[side] = mag
            edge_directions[side] = direc
            edge_regions[side] = extract_edge_region(norm_bgr, side)

        stem = f.rsplit(".", 1)[0]  # base name without extension

        piece = PuzzlePiece(
            name=stem,
            puzzle_id=puzzle_id,
            true_row=row,
            true_col=col,
            image=img,  # keep original for reconstruction
            edge_colors=edge_colors,
            edge_gradients=edge_gradients,
            edge_directions=edge_directions,
            edge_regions=edge_regions,
        )

        puzzles.setdefault(puzzle_id, []).append(piece)

    print(f"Loaded {len(puzzles)} puzzles (4x4)")
    return puzzles


# ============================================================
# LAB SEAM COST (CLASSICAL)
# ============================================================

def calc_pair_cost_lab(tile1_lab: np.ndarray, tile2_lab: np.ndarray, direction: int) -> float:
    """
    direction=0: tile2 is to the RIGHT of tile1 (horizontal seam)
    direction=1: tile2 is BELOW tile1 (vertical seam)

    Uses:
    - Border difference cost (pixel seam error)
    - Gradient consistency cost (penalizes seam discontinuities)
    """
    if direction == 0:
        b1 = tile1_lab[:, -1, :]
        b2 = tile2_lab[:, 0, :]
        d1 = tile1_lab[:, -2, :]
        d2 = tile2_lab[:, 1, :]
    else:
        b1 = tile1_lab[-1, :, :]
        b2 = tile2_lab[0, :, :]
        d1 = tile1_lab[-2, :, :]
        d2 = tile2_lab[1, :, :]

    diff = b1 - b2
    cost_p = float(np.sum(np.sqrt(np.sum(diff**2, axis=1))))

    grad_t1 = b1 - d1
    grad_t2 = d2 - b2
    grad_seam = b2 - b1

    cost_g = float(
        np.sum(np.sqrt(np.sum((grad_seam - grad_t1) ** 2, axis=1))) +
        np.sum(np.sqrt(np.sum((grad_seam - grad_t2) ** 2, axis=1)))
    )

    return cost_p + 1.5 * cost_g


def build_pair_costs_lab(pieces: List[PuzzlePiece]) -> np.ndarray:
    """
    Precomputes directed adjacency costs:
    costs[i,j,0] = cost if j is RIGHT of i
    costs[i,j,1] = cost if j is BELOW i
    """
    nt = len(pieces)

    labs = []
    for p in pieces:
        try:
            lab = cv2.cvtColor(p.image, cv2.COLOR_BGR2LAB).astype(np.float32)
        except cv2.error:
            lab = p.image.astype(np.float32)
        labs.append(lab)

    costs = np.zeros((nt, nt, 2), dtype=np.float32)
    for i in range(nt):
        for j in range(nt):
            if i == j:
                continue
            costs[i, j, 0] = calc_pair_cost_lab(labs[i], labs[j], 0)
            costs[i, j, 1] = calc_pair_cost_lab(labs[i], labs[j], 1)
    return costs


# ============================================================
# GLOBAL FRONTIER GROWING SOLVER (TRY ALL SEEDS)
# ============================================================

def solve_global_frontier_from_costs(costs: np.ndarray, n: int) -> Optional[np.ndarray]:
    """
    Frontier growing:
    - Start with a seed in (0,0)
    - For each frontier empty cell, compute best candidate (and ratio test)
    - Place the most confident candidate, expand frontier
    - Repeat until grid full
    - Try all seeds; keep the grid with minimum total energy
    """
    import heapq

    nt = costs.shape[0]
    if nt != n * n:
        return None

    global_best_grid = None
    global_min_energy = float("inf")

    for seed in range(nt):
        grid = np.full((n, n), -1, dtype=int)
        grid[0, 0] = seed
        used = {seed}
        pq = []  # heap: (ratio, best_cost, r, c, best_cand)

        def push_best_candidate(r: int, c: int):
            # Placed neighbors around (r,c)
            tL = grid[r, c - 1] if c > 0 else -1
            tR = grid[r, c + 1] if c < n - 1 else -1
            tT = grid[r - 1, c] if r > 0 else -1
            tB = grid[r + 1, c] if r < n - 1 else -1

            candidates = []
            for cand in range(nt):
                if cand in used:
                    continue

                cost = 0.0
                if tL != -1:
                    cost += float(costs[tL, cand, 0])
                if tR != -1:
                    cost += float(costs[cand, tR, 0])
                if tT != -1:
                    cost += float(costs[tT, cand, 1])
                if tB != -1:
                    cost += float(costs[cand, tB, 1])

                candidates.append((cost, cand))

            if not candidates:
                return

            candidates.sort(key=lambda x: x[0])
            best_cost, best_cand = candidates[0]

            # Ratio test: best / second-best; smaller means more confident
            ratio = 1.0
            if len(candidates) > 1 and candidates[1][0] > 1e-5:
                ratio = best_cost / candidates[1][0]

            heapq.heappush(pq, (ratio, best_cost, r, c, best_cand))

        def add_neighbors(r: int, c: int):
            for dr, dc in [(0, 1), (1, 0), (0, -1), (-1, 0)]:
                nr, nc = r + dr, c + dc
                if 0 <= nr < n and 0 <= nc < n and grid[nr, nc] == -1:
                    push_best_candidate(nr, nc)

        add_neighbors(0, 0)

        while len(used) < nt and pq:
            ratio, best_cost, r, c, cand = heapq.heappop(pq)

            if grid[r, c] != -1:
                continue

            if cand in used:
                # candidate already taken; recompute best for this cell
                push_best_candidate(r, c)
                continue

            grid[r, c] = cand
            used.add(cand)
            add_neighbors(r, c)

        if len(used) == nt:
            # Total energy over all seams
            energy = 0.0
            for rr in range(n):
                for cc in range(n):
                    idx = grid[rr, cc]
                    if cc < n - 1:
                        energy += float(costs[idx, grid[rr, cc + 1], 0])
                    if rr < n - 1:
                        energy += float(costs[idx, grid[rr + 1, cc], 1])

            if energy < global_min_energy:
                global_min_energy = energy
                global_best_grid = grid.copy()

    return global_best_grid


# ============================================================
# LOCAL REFINEMENT (ADJACENT SWAPS)
# ============================================================

def refine_grid_with_adjacent_swaps(
    grid: np.ndarray,
    costs: np.ndarray,
    n: int,
    max_iterations: int = 10,
    max_no_improvement: int = 3
) -> np.ndarray:
    """
    Adjacent-swap refinement:
    - Tries swapping neighboring tiles (right and bottom).
    - Keeps a swap if it reduces total seam energy.
    - Stops after a few iterations with no improvement.
    """
    current = grid.copy()

    def total_energy(g: np.ndarray) -> float:
        e = 0.0
        for r in range(n):
            for c in range(n):
                idx = g[r, c]
                if c < n - 1:
                    e += float(costs[idx, g[r, c + 1], 0])
                if r < n - 1:
                    e += float(costs[idx, g[r + 1, c], 1])
        return e

    best = current.copy()
    best_e = total_energy(best)

    no_improve = 0
    for _ in range(max_iterations):
        improved_any = False

        for r in range(n):
            for c in range(n):
                for dr, dc in [(0, 1), (1, 0)]:
                    r2, c2 = r + dr, c + dc
                    if r2 >= n or c2 >= n:
                        continue

                    g2 = current.copy()
                    g2[r, c], g2[r2, c2] = g2[r2, c2], g2[r, c]

                    e2 = total_energy(g2)
                    if e2 < best_e:
                        best_e = e2
                        best = g2
                        current = g2
                        improved_any = True

        if improved_any:
            no_improve = 0
        else:
            no_improve += 1
            if no_improve >= max_no_improvement:
                break

    return best


def assemble_frontier_global(
    pieces: List[PuzzlePiece],
    size: Tuple[int, int] = (4, 4),
    do_refine: bool = True
) -> List[List[int]]:
    """
    Main 4x4 assembly:
    - Build LAB seam costs
    - Solve via global frontier growing across all seeds
    - Optionally refine with local adjacent swaps
    Returns: grid as list-of-lists of piece indices.
    """
    R, C = size
    if R != C:
        raise ValueError("Frontier solver assumes square puzzles.")
    n = R

    nt = len(pieces)
    if nt != n * n:
        raise ValueError(f"Expected {n*n} pieces, got {nt}")

    costs = build_pair_costs_lab(pieces)
    grid = solve_global_frontier_from_costs(costs, n)
    if grid is None:
        raise RuntimeError("Frontier solver failed to produce a grid.")

    if do_refine:
        grid = refine_grid_with_adjacent_swaps(grid, costs, n, max_iterations=10)

    return [[int(grid[r, c]) for c in range(n)] for r in range(n)]


# ============================================================
# ACCURACY + RECONSTRUCTION + VISUALIZATION
# ============================================================

def calc_accuracy(pieces: List[PuzzlePiece], grid):
    """
    Tile accuracy depends on correctly parsed (true_row,true_col) from filenames.
    """
    R, C = len(grid), len(grid[0])
    total = R * C
    correct = 0
    for r in range(R):
        for c in range(C):
            idx = grid[r][c]
            if idx is None or idx < 0 or idx >= len(pieces):
                continue
            p = pieces[idx]
            if p.true_row == r and p.true_col == c:
                correct += 1
    return correct / total * 100.0


def reconstruct(pieces: List[PuzzlePiece], grid):
    """
    Stitches the selected tiles into one image.
    """
    R, C = len(grid), len(grid[0])
    h, w = pieces[0].image.shape[:2]
    canvas = np.zeros((R * h, C * w, 3), dtype=np.uint8)

    for r in range(R):
        for c in range(C):
            idx = grid[r][c]
            if idx is None or idx < 0 or idx >= len(pieces):
                continue
            canvas[r * h:(r + 1) * h, c * w:(c + 1) * w] = pieces[idx].image

    return canvas


def visualize(
    pieces: List[PuzzlePiece],
    grid,
    pid: str,
    out_dir: str,
    size=(4, 4)
):
    """
    Visualization:
    - left: assembled image
    - middle: ground truth full image
    - right: ground truth patches with border colored by SSIM-like similarity to assembled patches

    Uses NumPy-only SSIM (ssim_gray) to avoid external dependencies.
    """
    R, C = size
    assembled = reconstruct(pieces, grid)
    H, W = assembled.shape[:2]
    cell_h = H // R
    cell_w = W // C

    # Load ground truth
    correct_root = "dataset/correct/correct"
    gt = None
    for ext in (".png", ".jpg", ".jpeg", ".JPG", ".PNG", ".JPEG"):
        path = os.path.join(correct_root, f"{pid}{ext}")
        if os.path.exists(path):
            gt = cv2.imread(path)
            break

    if gt is None:
        print(f"GT not found for puzzle {pid}, using black image.")
        gt = np.zeros_like(assembled)

    if gt.shape[:2] != (H, W):
        gt = cv2.resize(gt, (W, H), interpolation=cv2.INTER_AREA)

    comparison = np.zeros_like(assembled)
    SSIM_THR = 0.60

    # Compare each cell using SSIM-like metric
    for r in range(R):
        for c in range(C):
            y1, y2 = r * cell_h, (r + 1) * cell_h
            x1, x2 = c * cell_w, (c + 1) * cell_w

            asm_patch = assembled[y1:y2, x1:x2]
            gt_patch = gt[y1:y2, x1:x2]

            a_gray = cv2.cvtColor(asm_patch, cv2.COLOR_BGR2GRAY)
            g_gray = cv2.cvtColor(gt_patch, cv2.COLOR_BGR2GRAY)

            try:
                s = ssim_gray(a_gray, g_gray)
            except Exception:
                # fallback: pseudo similarity from MSE
                mse = np.mean((a_gray.astype(np.float32) - g_gray.astype(np.float32)) ** 2)
                s = 1.0 / (1.0 + mse)

            good = (s >= SSIM_THR)
            color = (0, 255, 0) if good else (0, 0, 255)

            patch_vis = gt_patch.copy()
            cv2.rectangle(
                patch_vis,
                (0, 0),
                (patch_vis.shape[1] - 1, patch_vis.shape[0] - 1),
                color,
                8,
            )
            cv2.putText(
                patch_vis,
                f"{s:.2f}",
                (8, patch_vis.shape[0] - 12),
                cv2.FONT_HERSHEY_SIMPLEX,
                0.7,
                (255, 255, 255),
                2,
                cv2.LINE_AA,
            )

            comparison[y1:y2, x1:x2] = patch_vis

    # Estimate correctness by sampling the colored border area
    correct_cells = 0
    total_cells = R * C
    for r in range(R):
        for c in range(C):
            y1 = r * cell_h
            x1 = c * cell_w
            sample = comparison[y1 + 4:y1 + 10, x1 + 4:x1 + 10]
            avg = np.mean(sample.reshape(-1, 3), axis=0)
            if avg[1] > avg[2] + 5:
                correct_cells += 1
    vis_acc = correct_cells / total_cells * 100

    fig, ax = plt.subplots(1, 3, figsize=(22, 7))
    ax[0].imshow(cv2.cvtColor(assembled, cv2.COLOR_BGR2RGB))
    ax[0].set_title("Our Assembly")
    ax[0].axis("off")

    ax[1].imshow(cv2.cvtColor(gt, cv2.COLOR_BGR2RGB))
    ax[1].set_title("Ground Truth")
    ax[1].axis("off")

    ax[2].imshow(cv2.cvtColor(comparison, cv2.COLOR_BGR2RGB))
    ax[2].set_title(f"Visual correctness approx = {vis_acc:.1f}%")
    ax[2].axis("off")

    os.makedirs(out_dir, exist_ok=True)
    save_path = os.path.join(out_dir, f"{pid}_4x4_result.png")
    plt.savefig(save_path, dpi=160, bbox_inches="tight")
    plt.close()
    print(f"Saved visualization: {save_path}")


# ============================================================
# MAIN: 4x4 ONLY
# ============================================================

def run_ms2_4x4(
    ms1_folder="processed_output/puzzle_4x4",
    out_folder="ms2_results_puzzle_4x4"
):
    puzzles = load_pieces(ms1_folder)
    if not puzzles:
        print("No puzzles found in MS1 folder.")
        return

    os.makedirs(out_folder, exist_ok=True)

    for pid, pieces in sorted(puzzles.items(), key=lambda x: int(x[0]) if x[0].isdigit() else x[0]):
        if len(pieces) != 16:
            print(f"Puzzle {pid}: expected 16 pieces, got {len(pieces)} - skipping.")
            continue

        print(f"\nSolving 4x4 puzzle {pid}")
        t0 = time.time()

        grid = assemble_frontier_global(pieces, size=(4, 4), do_refine=True)

        dt = time.time() - t0
        print(f"Tile time = {dt:.2f}s")

        visualize(pieces, grid, pid, out_folder, size=(4, 4))

    print("\nMS2 4x4 run complete.")


if __name__ == "__main__":
    print("\n" + "=" * 72)
    print("RUNNING MS2 - 4x4 PUZZLES (Global Frontier Growing + refinement)")
    print("=" * 72)
    run_ms2_4x4()



RUNNING MS2 - 4x4 PUZZLES (Global Frontier Growing + refinement)
Found 1760 pieces in processed_output/puzzle_4x4/pieces
Loaded 110 puzzles (4x4)

Solving 4x4 puzzle 0
Tile time = 0.05s
Saved visualization: ms2_results_puzzle_4x4/0_4x4_result.png

Solving 4x4 puzzle 1
Tile time = 0.05s
Saved visualization: ms2_results_puzzle_4x4/1_4x4_result.png

Solving 4x4 puzzle 2
Tile time = 0.05s
Saved visualization: ms2_results_puzzle_4x4/2_4x4_result.png

Solving 4x4 puzzle 3
Tile time = 0.03s
Saved visualization: ms2_results_puzzle_4x4/3_4x4_result.png

Solving 4x4 puzzle 4
Tile time = 0.03s
Saved visualization: ms2_results_puzzle_4x4/4_4x4_result.png

Solving 4x4 puzzle 5
Tile time = 0.03s
Saved visualization: ms2_results_puzzle_4x4/5_4x4_result.png

Solving 4x4 puzzle 6
Tile time = 0.03s
Saved visualization: ms2_results_puzzle_4x4/6_4x4_result.png

Solving 4x4 puzzle 7
Tile time = 0.03s
Saved visualization: ms2_results_puzzle_4x4/7_4x4_result.png

Solving 4x4 puzzle 8
Tile time = 0.03s
Saved

**3rd : 8x8 Solver**

In [8]:
# 1---Refined ////////FULL OUTPUT (cleaned: no icons/emojis, no skimage)

import os
import re
import heapq
import time
from dataclasses import dataclass
from typing import Dict, List, Optional

import cv2
import numpy as np
import matplotlib.pyplot as plt


# ============================================================
# NUMPY-ONLY SSIM (CLASSICAL METRIC, NOT ML)
# ============================================================

def ssim_gray(a: np.ndarray, b: np.ndarray) -> float:
    """
    Simplified SSIM for grayscale images using global statistics.
    Avoids external libraries (e.g., skimage).

    a, b: grayscale arrays. If shapes differ, b is resized to a.
    Returns: SSIM-like score.
    """
    a = a.astype(np.float32)
    b = b.astype(np.float32)

    if a.shape != b.shape:
        b = cv2.resize(b, (a.shape[1], a.shape[0]), interpolation=cv2.INTER_AREA)

    mu_a = a.mean()
    mu_b = b.mean()
    var_a = ((a - mu_a) ** 2).mean()
    var_b = ((b - mu_b) ** 2).mean()
    cov_ab = ((a - mu_a) * (b - mu_b)).mean()

    C1 = (0.01 * 255) ** 2
    C2 = (0.03 * 255) ** 2

    num = (2 * mu_a * mu_b + C1) * (2 * cov_ab + C2)
    den = (mu_a ** 2 + mu_b ** 2 + C1) * (var_a + var_b + C2)

    if den < 1e-12:
        return 0.0
    return float(num / den)


# ============================================================
# DATA STRUCTURE
# ============================================================

@dataclass
class PuzzlePiece:
    name: str
    puzzle_id: str
    true_row: int
    true_col: int
    image: np.ndarray
    edge_colors: Dict[str, np.ndarray] = None
    edge_gradients: Dict[str, np.ndarray] = None
    edge_regions: Dict[str, np.ndarray] = None


# ============================================================
# EDGE DESCRIPTORS (CLASSICAL IMAGE PROCESSING)
# ============================================================

def extract_edge_color(img: np.ndarray, side: str, depth: int = 4):
    """
    Returns a 1D color profile along an edge by averaging a thin strip.
    """
    if side == "top":
        region = img[:depth, :]
    elif side == "bottom":
        region = img[-depth:, :]
    elif side == "left":
        region = img[:, :depth]
    else:
        region = img[:, -depth:]
    return np.mean(region, axis=0) if side in ("top", "bottom") else np.mean(region, axis=1)


def extract_sobel_descriptors(img: np.ndarray, side: str, depth: int = 6):
    """
    Returns a 1D Sobel magnitude profile along the requested edge.
    """
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) if len(img.shape) == 3 else img

    if side == "top":
        region = gray[:depth, :]
    elif side == "bottom":
        region = gray[-depth:, :]
    elif side == "left":
        region = gray[:, :depth]
    else:
        region = gray[:, -depth:]

    sobel_x = cv2.Sobel(region, cv2.CV_64F, 1, 0, ksize=3)
    sobel_y = cv2.Sobel(region, cv2.CV_64F, 0, 1, ksize=3)
    mag = np.sqrt(sobel_x**2 + sobel_y**2)

    return np.mean(mag, axis=0) if side in ("top", "bottom") else np.mean(mag, axis=1)


def extract_edge_region(img: np.ndarray, side: str, depth: int = 12):
    """
    Extracts a small strip along an edge for similarity checks (visual evaluation).
    Adds padding if the strip is too small.
    """
    if side == "top":
        region = img[:depth, :]
    elif side == "bottom":
        region = img[-depth:, :]
    elif side == "left":
        region = img[:, :depth]
    else:
        region = img[:, -depth:]

    rh, rw = region.shape[:2]
    if rh < 4 or rw < 4:
        region = cv2.copyMakeBorder(
            region, 0, max(0, 4 - rh), 0, max(0, 4 - rw),
            cv2.BORDER_REFLECT_101
        )
    return region


# ============================================================
# LOAD PIECES (MS1 OUTPUT)
# ============================================================

NAME_RE = re.compile(r"^(\d+)_r(\d+)c(\d+)\.(png|jpg|jpeg)$", re.IGNORECASE)

def load_pieces(ms1_folder: str) -> Dict[str, List[PuzzlePiece]]:
    """
    Loads pieces from MS1 folder: ms1_folder/pieces.
    Expected filename format: <pid>_r<row>c<col>.(png/jpg/jpeg)
    """
    pieces_path = os.path.join(ms1_folder, "pieces")
    if not os.path.exists(pieces_path):
        print(f"MS1 pieces folder not found: {pieces_path}")
        return {}

    puzzles: Dict[str, List[PuzzlePiece]] = {}
    files = sorted(
        f for f in os.listdir(pieces_path)
        if f.lower().endswith((".png", ".jpg", ".jpeg"))
    )

    print(f"Found {len(files)} pieces in {pieces_path}")

    for f in files:
        m = NAME_RE.match(f)
        if not m:
            continue

        pid = m.group(1)
        r = int(m.group(2))
        c = int(m.group(3))

        img_path = os.path.join(pieces_path, f)
        img = cv2.imread(img_path)
        if img is None:
            continue

        # Normalization for descriptor stability (classical preprocessing)
        gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
        gray = cv2.GaussianBlur(gray, (3, 3), 0)
        clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
        norm_gray = clahe.apply(gray)
        norm_bgr = cv2.cvtColor(norm_gray, cv2.COLOR_GRAY2BGR)

        edge_colors, edge_gradients, edge_regions = {}, {}, {}
        for side in ("top", "bottom", "left", "right"):
            edge_colors[side] = extract_edge_color(norm_bgr, side)
            edge_gradients[side] = extract_sobel_descriptors(norm_bgr, side)
            edge_regions[side] = extract_edge_region(norm_bgr, side)

        stem = f.rsplit(".", 1)[0]
        puzzles.setdefault(pid, []).append(
            PuzzlePiece(stem, pid, r, c, img, edge_colors, edge_gradients, edge_regions)
        )

    print(f"Loaded {len(puzzles)} puzzles")
    return puzzles


# ============================================================
# COST FUNCTION (8x8) - CLASSICAL SEAM + GRADIENT + SOBEL PROFILE
# ============================================================

def calc_hybrid_cost_8x8(pieceA: PuzzlePiece, pieceB: PuzzlePiece, direction: int) -> float:
    """
    direction=0: pieceB to the RIGHT of pieceA
    direction=1: pieceB BELOW pieceA

    Combines:
    - Pixel seam difference in LAB (border mismatch)
    - Gradient consistency in LAB (second-order seam consistency)
    - Sobel profile error (edge structure complement)
    """
    try:
        labA = cv2.cvtColor(pieceA.image, cv2.COLOR_BGR2LAB).astype(np.float32)
        labB = cv2.cvtColor(pieceB.image, cv2.COLOR_BGR2LAB).astype(np.float32)
    except:
        labA = pieceA.image.astype(np.float32)
        labB = pieceB.image.astype(np.float32)

    if direction == 0:
        b1, b2 = labA[:, -1, :], labB[:, 0, :]
        d1, d2 = labA[:, -2, :], labB[:, 1, :]
    else:
        b1, b2 = labA[-1, :, :], labB[0, :, :]
        d1, d2 = labA[-2, :, :], labB[1, :, :]

    diff = b1 - b2
    cost_pixel = float(np.sum(np.sqrt(np.sum(diff**2, axis=1))))

    grad_t1 = b1 - d1
    grad_t2 = d2 - b2
    grad_seam = b2 - b1

    cost_grad_lab = float(
        np.sum(np.sqrt(np.sum((grad_seam - grad_t1)**2, axis=1))) +
        np.sum(np.sqrt(np.sum((grad_seam - grad_t2)**2, axis=1)))
    )

    magA = pieceA.edge_gradients['right'] if direction == 0 else pieceA.edge_gradients['bottom']
    magB = pieceB.edge_gradients['left']  if direction == 0 else pieceB.edge_gradients['top']

    def normalize(v):
        v = np.array(v, dtype=np.float32)
        mn, mx = v.min(), v.max()
        return np.zeros_like(v) if mx == mn else (v - mn) / (mx - mn)

    magA_n, magB_n = normalize(magA), normalize(magB)
    L = min(len(magA_n), len(magB_n))
    if L > 0:
        sobel_err = np.mean((magA_n[:L] - (1.0 - magB_n[:L]))**2)
        cost_sobel = sobel_err * 30.0
    else:
        cost_sobel = 100.0

    # Your tuned weights (unchanged)
    total = cost_pixel * 1.5 + cost_grad_lab * 1.5 + cost_sobel * 0.5
    return total


def build_costs_8x8(pieces: List[PuzzlePiece]) -> np.ndarray:
    """
    costs[i,j,0] = cost if j is RIGHT of i
    costs[i,j,1] = cost if j is BELOW i
    """
    nt = len(pieces)
    costs = np.zeros((nt, nt, 2), dtype=np.float32)
    print(f"Computing {nt*nt*2} edge costs...")

    total_cmp = nt * nt * 2
    count = 0
    for i in range(nt):
        for j in range(nt):
            if i == j:
                costs[i, j, 0] = 1e9
                costs[i, j, 1] = 1e9
                count += 2
            else:
                costs[i, j, 0] = calc_hybrid_cost_8x8(pieces[i], pieces[j], 0)
                costs[i, j, 1] = calc_hybrid_cost_8x8(pieces[i], pieces[j], 1)
                count += 2

            if count % 1000 == 0:
                print(f"Progress: {count/total_cmp*100:.1f}%", end="\r")

    print("\nCost computation complete")
    return costs


# ============================================================
# ENERGY (TOTAL SEAM COST OF A GRID)
# ============================================================

def total_energy(grid: np.ndarray, costs: np.ndarray) -> float:
    n = grid.shape[0]
    e = 0.0
    for r in range(n):
        for c in range(n):
            idx = grid[r, c]
            if c < n - 1:
                e += float(costs[idx, grid[r, c+1], 0])
            if r < n - 1:
                e += float(costs[idx, grid[r+1, c], 1])
    return e


# ============================================================
# IMPROVED FRONTIER (MULTI-SEED + CONFIDENCE HEAP)
# ============================================================

def solve_frontier_improved(costs: np.ndarray, n: int, num_seeds: int = 24) -> Optional[np.ndarray]:
    """
    Frontier growing:
    - Select promising seeds (low average best neighbor cost)
    - For each seed, grow grid by placing most confident frontier candidate
    - Return the best (minimum energy) grid among seeds
    """
    nt = costs.shape[0]
    if nt != n * n:
        return None

    # Score seeds by mean of their 4 best outgoing edges
    seed_scores = []
    for i in range(nt):
        best_costs = []
        for j in range(nt):
            if i != j:
                best_costs.append(min(costs[i, j, 0], costs[i, j, 1]))
        best_costs.sort()
        seed_scores.append((float(np.mean(best_costs[:4])), i))

    seed_scores.sort()
    best_seeds = [idx for _, idx in seed_scores[:num_seeds]]
    print(f"Testing {num_seeds} promising seeds...")

    global_best_grid = None
    global_min_energy = float("inf")

    for seed_num, seed in enumerate(best_seeds):
        grid = np.full((n, n), -1, dtype=int)
        grid[0, 0] = seed
        used = {seed}
        pq = []  # heap: (confidence_ratio, best_cost, r, c, cand)

        def push_best_candidate(r: int, c: int):
            neighbors = []
            for dr, dc in [(-1, 0), (0, -1), (0, 1), (1, 0)]:
                nr, nc = r + dr, c + dc
                if 0 <= nr < n and 0 <= nc < n and grid[nr, nc] != -1:
                    neighbors.append((nr, nc, dr, dc))

            if not neighbors:
                return

            candidates = []
            for cand in range(nt):
                if cand in used:
                    continue
                cost = 0.0
                for nr, nc, dr, dc in neighbors:
                    neigh_idx = grid[nr, nc]
                    if dr == -1:      # neighbor above -> candidate is below neighbor
                        cost += float(costs[neigh_idx, cand, 1])
                    elif dr == 1:     # neighbor below -> candidate is above neighbor
                        cost += float(costs[cand, neigh_idx, 1])
                    elif dc == -1:    # neighbor left -> candidate is right of neighbor
                        cost += float(costs[neigh_idx, cand, 0])
                    else:             # neighbor right -> candidate is left of neighbor
                        cost += float(costs[cand, neigh_idx, 0])

                avg_cost = cost / len(neighbors)
                candidates.append((avg_cost, cand))

            candidates.sort(key=lambda x: x[0])
            best_cost, best_cand = candidates[0]

            # Confidence ratio: best / second-best (smaller is more confident)
            confidence = 1.0
            if len(candidates) > 1 and candidates[1][0] > 1e-5:
                confidence = best_cost / candidates[1][0]

            heapq.heappush(pq, (confidence, best_cost, r, c, best_cand))

        def add_neighbors(r: int, c: int):
            for dr, dc in [(0, 1), (1, 0), (0, -1), (-1, 0)]:
                nr, nc = r + dr, c + dc
                if 0 <= nr < n and 0 <= nc < n and grid[nr, nc] == -1:
                    push_best_candidate(nr, nc)

        add_neighbors(0, 0)
        placed = 1

        while placed < nt and pq:
            conf, cost, r, c, cand = heapq.heappop(pq)

            if grid[r, c] != -1:
                continue
            if cand in used:
                push_best_candidate(r, c)
                continue

            grid[r, c] = cand
            used.add(cand)
            placed += 1
            add_neighbors(r, c)

        if placed != nt:
            continue

        energy = total_energy(grid, costs)
        if energy < global_min_energy:
            global_min_energy = energy
            global_best_grid = grid.copy()
            print(f"Seed {seed_num+1}/{num_seeds}: energy={energy:.0f}")

    return global_best_grid


# ============================================================
# SMART REFINEMENT (ADJACENT SWAPS + 2x2 BLOCK SWAPS)
# ============================================================

def refine_smart(grid: np.ndarray, costs: np.ndarray, max_iterations: int = 10) -> np.ndarray:
    """
    Local improvement:
    - Try improving via adjacent swaps
    - If stuck, try 2x2 block swaps
    """
    n = grid.shape[0]
    g = grid.copy()
    best_e = total_energy(g, costs)
    print(f"Refinement starting energy: {best_e:.0f}")

    for iteration in range(max_iterations):
        improved_adjacent = False

        # Adjacent swaps
        for r in range(n):
            for c in range(n):
                for dr, dc in [(0, 1), (1, 0)]:
                    r2, c2 = r + dr, c + dc
                    if r2 >= n or c2 >= n:
                        continue
                    g2 = g.copy()
                    g2[r, c], g2[r2, c2] = g2[r2, c2], g2[r, c]
                    e2 = total_energy(g2, costs)
                    if e2 < best_e - 1e-3:
                        g, best_e = g2, e2
                        improved_adjacent = True

        if improved_adjacent:
            print(f"Iteration {iteration+1}: energy={best_e:.0f} (adjacent swaps)")
            continue

        # 2x2 block swaps
        improved_blocks = False
        for r1 in range(0, n-1, 2):
            for c1 in range(0, n-1, 2):
                for dr in [2, -2, 0]:
                    for dc in [2, -2, 0]:
                        if dr == 0 and dc == 0:
                            continue
                        r2, c2 = r1 + dr, c1 + dc
                        if not (0 <= r2 < n-1 and 0 <= c2 < n-1):
                            continue

                        g2 = g.copy()
                        tmp = g2[r1:r1+2, c1:c1+2].copy()
                        g2[r1:r1+2, c1:c1+2] = g2[r2:r2+2, c2:c2+2]
                        g2[r2:r2+2, c2:c2+2] = tmp

                        e2 = total_energy(g2, costs)
                        if e2 < best_e - 1e-3:
                            g, best_e = g2, e2
                            improved_blocks = True

        if improved_blocks:
            print(f"Iteration {iteration+1}: energy={best_e:.0f} (block swaps)")
        else:
            print(f"Converged at iteration {iteration+1}")
            break

    print(f"Final energy: {best_e:.0f}")
    return g


# ============================================================
# RECONSTRUCT + VISUALIZATION
# ============================================================

def reconstruct_image(pieces: List[PuzzlePiece], grid: np.ndarray) -> np.ndarray:
    """
    Stitches the placed pieces into a full puzzle image.
    """
    n = grid.shape[0]
    tile_h, tile_w = pieces[0].image.shape[:2]
    out = np.zeros((n * tile_h, n * tile_w, 3), dtype=np.uint8)

    for r in range(n):
        for c in range(n):
            idx = int(grid[r, c])
            out[r*tile_h:(r+1)*tile_h, c*tile_w:(c+1)*tile_w] = pieces[idx].image

    return out


def visualize_result(pieces: List[PuzzlePiece], grid: np.ndarray, pid: str, out_dir: str):
    """
    Visualizes:
    - Our assembled image
    - Ground truth full image
    - Ground truth patches with border colored by SSIM-like score vs our patch
    """
    n = grid.shape[0]
    assembled = reconstruct_image(pieces, grid)
    H, W = assembled.shape[:2]
    cell_h, cell_w = H // n, W // n

    # Use the puzzle_id from loaded pieces (stable)
    actual_puzzle_id = pieces[0].puzzle_id

    # Load ground truth image
    correct_root = "dataset/correct/correct"
    gt = None
    for ext in [".png", ".jpg", ".jpeg", ".PNG", ".JPG", ".JPEG"]:
        path = os.path.join(correct_root, f"{actual_puzzle_id}{ext}")
        if os.path.exists(path):
            gt = cv2.imread(path)
            print(f"Loading GT: {actual_puzzle_id}{ext}")
            break

    if gt is None:
        print(f"GT not found for puzzle {actual_puzzle_id}")
        gt = np.zeros_like(assembled)

    if (gt.shape[0], gt.shape[1]) != (H, W):
        gt = cv2.resize(gt, (W, H), interpolation=cv2.INTER_AREA)

    comparison = np.zeros_like(assembled)
    SSIM_THR = 0.60
    correct_count = 0

    for r in range(n):
        for c in range(n):
            y1, y2 = r * cell_h, (r + 1) * cell_h
            x1, x2 = c * cell_w, (c + 1) * cell_w

            asm_patch = assembled[y1:y2, x1:x2]
            gt_patch = gt[y1:y2, x1:x2]

            a_gray = cv2.cvtColor(asm_patch, cv2.COLOR_BGR2GRAY)
            g_gray = cv2.cvtColor(gt_patch, cv2.COLOR_BGR2GRAY)

            try:
                s = ssim_gray(a_gray, g_gray)
            except:
                s = 0.0

            good = (s >= SSIM_THR)
            if good:
                correct_count += 1

            color = (0, 255, 0) if good else (0, 0, 255)
            patch_vis = gt_patch.copy()
            cv2.rectangle(
                patch_vis, (0, 0),
                (patch_vis.shape[1]-1, patch_vis.shape[0]-1),
                color, thickness=4
            )
            comparison[y1:y2, x1:x2] = patch_vis

    vis_acc = correct_count / (n * n) * 100.0

    fig, ax = plt.subplots(1, 3, figsize=(22, 7))

    ax[0].imshow(cv2.cvtColor(assembled, cv2.COLOR_BGR2RGB))
    ax[0].set_title(f"Our Assembly - Puzzle {actual_puzzle_id}", fontsize=14)
    ax[0].axis("off")

    ax[1].imshow(cv2.cvtColor(gt, cv2.COLOR_BGR2RGB))
    ax[1].set_title(f"Ground Truth - Puzzle {actual_puzzle_id}", fontsize=14)
    ax[1].axis("off")

    ax[2].imshow(cv2.cvtColor(comparison, cv2.COLOR_BGR2RGB))
    ax[2].set_title(f"Visual Accuracy = {vis_acc:.1f}%", fontsize=14)
    ax[2].axis("off")

    os.makedirs(out_dir, exist_ok=True)
    save_path = os.path.join(out_dir, f"{actual_puzzle_id}_8x8_result.png")
    plt.savefig(save_path, dpi=160, bbox_inches="tight")
    plt.close()

    print(f"Saved: {save_path} | Puzzle ID: {actual_puzzle_id} | Visual accuracy: {vis_acc:.1f}%")


# ============================================================
# MAIN
# ============================================================

def run_ms2_8x8_improved(
    ms1_folder="processed_output/puzzle_8x8",
    out_folder="ms2_results_puzzle_8x8"
):
    puzzles = load_pieces(ms1_folder)
    if not puzzles:
        print("No puzzles found.")
        return

    os.makedirs(out_folder, exist_ok=True)

    for pid, pieces in sorted(
        puzzles.items(),
        key=lambda kv: int(kv[0]) if kv[0].isdigit() else kv[0]
    ):
        if len(pieces) != 64:
            print(f"Puzzle {pid}: expected 64 pieces, got {len(pieces)} - skipping.")
            continue

        print(f"\n{'='*72}\nSolving 8x8 puzzle {pid}\n{'='*72}")
        t0 = time.time()

        costs = build_costs_8x8(pieces)
        grid = solve_frontier_improved(costs, n=8, num_seeds=24)
        if grid is None:
            print("Failed to solve.")
            continue

        grid = refine_smart(grid, costs, max_iterations=10)

        print(f"Total time: {time.time() - t0:.1f}s")
        visualize_result(pieces, grid, pid, out_folder)

    print(f"\n{'='*72}\nMS2 8x8 improved run complete.\n{'='*72}")


if __name__ == "__main__":
    run_ms2_8x8_improved()


Found 7040 pieces in processed_output/puzzle_8x8/pieces
Loaded 110 puzzles

Solving 8x8 puzzle 0
Computing 8192 edge costs...
Progress: 97.7%
Cost computation complete
Testing 24 promising seeds...
Seed 1/24: energy=540492
Seed 8/24: energy=493135
Refinement starting energy: 493135
Iteration 1: energy=480993 (adjacent swaps)
Converged at iteration 2
Final energy: 480993
Total time: 2.3s
Loading GT: 0.png
Saved: ms2_results_puzzle_8x8/0_8x8_result.png | Puzzle ID: 0 | Visual accuracy: 100.0%

Solving 8x8 puzzle 1
Computing 8192 edge costs...
Progress: 97.7%
Cost computation complete
Testing 24 promising seeds...
Seed 1/24: energy=286280
Seed 2/24: energy=264303
Seed 5/24: energy=168100
Seed 6/24: energy=167000
Refinement starting energy: 167000
Converged at iteration 1
Final energy: 167000
Total time: 1.9s
Loading GT: 1.png
Saved: ms2_results_puzzle_8x8/1_8x8_result.png | Puzzle ID: 1 | Visual accuracy: 7.8%

Solving 8x8 puzzle 2
Computing 8192 edge costs...
Progress: 97.7%
Cost computa

In [None]:
import shutil

folder = "ms2_results_8x8_assembly_fig_onlyy"   # or ms2_results_8x8
zip_path = shutil.make_archive(folder, 'zip', folder)
print("Created:", zip_path)


Created: /content/ms2_results_8x8_assembly_fig_onlyy.zip
