## Imports

In [1]:
import cv2
import matplotlib.pyplot as plt
import numpy as np
import os

In [3]:
# Root folder holding the scrambled puzzles and the unscrambled reference images.
DATASET_PATH = "dataset"

path_2x2 = os.path.join(DATASET_PATH, "puzzle_2x2")
path_4x4 = os.path.join(DATASET_PATH, "puzzle_4x4")
path_8x8 = os.path.join(DATASET_PATH, "puzzle_8x8")
correct_path = os.path.join(DATASET_PATH, "correct")

# Full path of every entry in each folder. os.listdir returns entries in an
# arbitrary order, so sort them to keep the lists stable between runs.
images_2x2 = [os.path.join(path_2x2, f) for f in sorted(os.listdir(path_2x2))]
images_4x4 = [os.path.join(path_4x4, f) for f in sorted(os.listdir(path_4x4))]
images_8x8 = [os.path.join(path_8x8, f) for f in sorted(os.listdir(path_8x8))]
# Names listed from correct_path must be joined onto correct_path itself.
correct = [os.path.join(correct_path, f) for f in sorted(os.listdir(correct_path))]

In [4]:
# Report how many puzzle images were found for each grid size.
for _label, _paths in (("2x2", images_2x2), ("4x4", images_4x4), ("8x8", images_8x8)):
    print(f"{_label} images:", len(_paths))


2x2 images: 110
4x4 images: 110
8x8 images: 110


# Matching Algorithms

In [None]:
import os
import cv2
import json
import glob
import numpy as np
import random
import matplotlib.pyplot as plt
import itertools
from tqdm import tqdm

# ==========================================
# 1. SETUP PATHS
# ==========================================
DATASET_PATH = "dataset"
correct_path = os.path.join(DATASET_PATH, "correct")
# Root of the per-piece edge descriptor JSON files (see load_descriptors).
DRIVE_DESC_PATH = "/content/drive/MyDrive/Jigsaw_Milestone1/descriptors"

# Only files with one of these (case-insensitive) extensions count as images.
valid_exts = ('.jpg', '.jpeg', '.png', '.bmp')

# Sorted reference images. The order matters: an image's position in this list
# is the index used to look up its descriptors.
if os.path.exists(correct_path):
    ground_truth_files = sorted(
        os.path.join(correct_path, f) for f in os.listdir(correct_path)
        if f.lower().endswith(valid_exts)
    )
else:
    print("❌ Error: Dataset 'correct' folder not found.")
    ground_truth_files = []

def load_descriptors(grid_size, img_index, desc_root=None):
    """Load the precomputed edge descriptors for every piece of one image.

    Reads ``<desc_root>/puzzle_<g>x<g>/<img_index>_p<p>.json`` for each piece
    ``p`` in ``0 .. grid_size**2 - 1``. Each file must contain the keys
    ``"top"``, ``"right"``, ``"bottom"`` and ``"left"``. Each of those maps to
    an object with ``"mean_profile"``, ``"color_variance"`` and, optionally,
    ``"grad_hist"``.

    Args:
        grid_size: Number of pieces per side.
        img_index: Image index, used as the file-name prefix.
        desc_root: Root folder of the descriptor tree. When ``None``, the
            module-level ``DRIVE_DESC_PATH`` is used.

    Returns:
        ``{"piece_<p>": [top, right, bottom, left]}``. Each edge is a dict with
        ``"mean_profile"`` (float32 array), ``"color_variance"`` (float) and
        ``"grad_hist"`` (float32 array, empty when absent in the file).
        Returns ``None`` if any piece file, or any of its four edges, is
        missing.
    """
    if desc_root is None:
        desc_root = DRIVE_DESC_PATH
    folder = os.path.join(desc_root, f"puzzle_{grid_size}x{grid_size}")

    # Position in this tuple is the edge index the solvers use:
    # 0=top, 1=right, 2=bottom, 3=left.
    edge_order = ("top", "right", "bottom", "left")

    descriptors = {}
    for p in range(grid_size * grid_size):
        json_path = os.path.join(folder, f"{img_index}_p{p}.json")
        if not os.path.exists(json_path):
            return None  # missing descriptors for this image

        # JSON is UTF-8 by specification; don't rely on the platform default.
        with open(json_path, "r", encoding="utf-8") as f:
            data = json.load(f)

        if any(name not in data for name in edge_order):
            return None

        # Key must match the piece ids produced by cut_grid_dynamic.
        descriptors[f"piece_{p}"] = [
            {
                "mean_profile": np.array(data[name]["mean_profile"], dtype=np.float32),
                "color_variance": float(data[name]["color_variance"]),
                "grad_hist": np.array(data[name].get("grad_hist", []), dtype=np.float32),
            }
            for name in edge_order
        ]

    return descriptors

def calculate_descriptor_match(d1, d2):
    """Return a dissimilarity score for two facing edges (lower = better fit).

    The score is the mean, over the rows of ``mean_profile``, of the Euclidean
    distance between corresponding rows of the two profiles, minus 0.8 times
    the average ``color_variance`` of the two edges. If either descriptor is
    ``None``, the score is ``inf``.
    """
    if d1 is None or d2 is None:
        return float('inf')
    row_distances = np.linalg.norm(d1['mean_profile'] - d2['mean_profile'], axis=1)
    average_variance = (d1['color_variance'] + d2['color_variance']) / 2.0
    # NOTE(review): subtracting variance makes textured edges score better
    # against everything; confirm this bonus is intended.
    return row_distances.mean() - average_variance * 0.8

# ==========================================
# 3. SOLVER 1: 2x2 BRUTE FORCE (DESCRIPTOR BASED)
# ==========================================
def solve_2x2_descriptors(piece_images, descriptors):
    """Solve a 2x2 puzzle by scoring every arrangement of its pieces.

    Slots are numbered TL=0, TR=1, BL=2, BR=3. The cost of an arrangement is
    the sum of ``calculate_descriptor_match`` over its four internal seams.
    Arrangements that use a piece without descriptors are skipped. The
    cheapest arrangement is stitched into one image. If nothing could be
    scored, the result is a black 3-channel image with twice the height and
    width of the first piece.

    Args:
        piece_images: ``{piece_id: image}`` for the puzzle's pieces.
        descriptors: ``{piece_id: [top, right, bottom, left]}`` edge descriptors.
    """
    # (slot_a, edge_a, slot_b, edge_b); edges are 0=top 1=right 2=bottom 3=left.
    seams = (
        (0, 1, 1, 3),  # TL right  | TR left
        (2, 1, 3, 3),  # BL right  | BR left
        (0, 2, 2, 0),  # TL bottom | BL top
        (1, 2, 3, 0),  # TR bottom | BR top
    )

    winner = None
    lowest = float('inf')
    for layout in itertools.permutations(piece_images.keys()):
        if not all(pid in descriptors for pid in layout):
            continue
        cost = 0
        for slot_a, edge_a, slot_b, edge_b in seams:
            cost += calculate_descriptor_match(
                descriptors[layout[slot_a]][edge_a],
                descriptors[layout[slot_b]][edge_b],
            )
        if cost < lowest:
            lowest, winner = cost, layout

    if not winner:
        # No arrangement could be scored: return a black canvas of the full size.
        tile_h, tile_w = next(iter(piece_images.values())).shape[:2]
        return np.zeros((tile_h*2, tile_w*2, 3), dtype=np.uint8)

    tiles = [piece_images[pid] for pid in winner]
    return np.vstack([np.hstack(tiles[:2]), np.hstack(tiles[2:4])])

# ==========================================
# 4. SOLVER 2: 4x4 / 8x8 ROBUST SOLVER
# ==========================================
class PuzzleSolverDescriptor:
    """Greedy solver for grids larger than 2x2, driven by precomputed edge descriptors.

    Every piece starts as its own one-piece "island". Candidate pairings from
    ``get_matches_by_margin`` are merged most-confident first, as long as a
    merge causes no overlap and keeps the island within ``max_h`` x ``max_w``
    cells. Pieces outside the largest island are then placed greedily into its
    empty cells by ``force_fill_holes``, and the grid is rendered.

    Pieces are handled internally by integer index ``0..n-1``; ``idx_to_id``
    maps an index back to its key in ``piece_images`` / ``descriptors_map``.
    Edge indices follow the descriptor layout 0=top, 1=right, 2=bottom, 3=left.
    """
    def __init__(self, piece_images, descriptors_map, max_h, max_w):
        """Set up one island per piece.

        Args:
            piece_images: ``{piece_id: image array}``. All pieces are expected
                to share one shape, because ``render`` sizes every cell from
                the first piece.
            descriptors_map: ``{piece_id: [top, right, bottom, left]}`` edge
                descriptors. Pieces missing from it are never matched by score.
            max_h: Grid height in pieces.
            max_w: Grid width in pieces.
        """
        self.pieces = piece_images
        self.descriptors = descriptors_map
        self.ids = list(piece_images.keys())
        self.idx_to_id = {i: pid for i, pid in enumerate(self.ids)}
        self.max_h = max_h
        self.max_w = max_w
        self.n = len(self.ids)
        # Island root -> {(row, col) in that island's own frame: piece index}.
        self.islands = {i: {(0,0): i} for i in range(self.n)}
        # Piece index -> root key of the island that currently holds it.
        self.piece_to_island = {i: i for i in range(self.n)}

    def get_matches_by_margin(self):
        """Return each piece's best partner per edge pairing, most confident first.

        For every piece ``i`` and each pairing (i's top vs j's bottom, i's right
        vs j's left), every other piece ``j`` with descriptors is scored with
        ``calculate_descriptor_match``, and the lowest score is kept. ``margin``
        is the second-best score minus the best (``inf`` with a single
        candidate); a larger margin means a less ambiguous match.

        Returns:
            List of dicts with keys ``margin``, ``p1``, ``e1``, ``p2``, ``e2``
            and ``score``, sorted by ``margin`` descending.
        """
        all_matches = []
        # (edge of piece i, facing edge of candidate j): top<->bottom, right<->left.
        pairs = [(0, 2), (1, 3)]
        for i in range(self.n):
            pid1 = self.idx_to_id[i]
            if pid1 not in self.descriptors: continue
            for e1, e2 in pairs:
                candidates = []
                for j in range(self.n):
                    if i == j: continue
                    pid2 = self.idx_to_id[j]
                    if pid2 not in self.descriptors: continue
                    score = calculate_descriptor_match(self.descriptors[pid1][e1], self.descriptors[pid2][e2])
                    candidates.append((score, j))
                if not candidates: continue
                candidates.sort(key=lambda x: x[0])
                best_s, best_j = candidates[0]
                second_s = candidates[1][0] if len(candidates) > 1 else float('inf')
                all_matches.append({'margin': second_s - best_s, 'p1': i, 'e1': e1, 'p2': best_j, 'e2': e2, 'score': best_s})
        all_matches.sort(key=lambda x: x['margin'], reverse=True)
        return all_matches

    def merge_islands(self, p1, e1, p2, e2):
        """Try to attach p2's island to p1's so that p2 sits across edge ``e1`` of p1.

        Returns True if the islands were merged. Returns False, leaving all
        state unchanged, when both pieces are already in one island, when a
        shifted cell would overlap an occupied one, or when the combined
        bounding box would exceed ``max_h`` x ``max_w``. ``e2`` is not read;
        the placement is derived from ``e1`` alone.
        """
        root1, root2 = self.piece_to_island[p1], self.piece_to_island[p2]
        if root1 == root2: return False
        island1, island2 = self.islands[root1], self.islands[root2]
        # Coordinates of p1 / p2 within their own islands.
        r1, c1 = [k for k, v in island1.items() if v == p1][0]
        # Step from p1 to the neighbouring cell across edge e1.
        dr, dc = {0:(-1,0), 1:(0,1), 2:(1,0), 3:(0,-1)}[e1]
        r2, c2 = [k for k, v in island2.items() if v == p2][0]
        # Translation that lands p2 on that neighbouring cell.
        shift_r, shift_c = (r1 + dr) - r2, (c1 + dc) - c2

        # Translate island2 into island1's frame; abort on any overlap.
        new_coords = {}
        all_coords = list(island1.keys())
        for (r, c), pid in island2.items():
            nr, nc = r + shift_r, c + shift_c
            if (nr, nc) in island1: return False
            new_coords[(nr, nc)] = pid
            all_coords.append((nr, nc))
        # Reject merges whose bounding box would not fit the target grid.
        rs, cs = [x[0] for x in all_coords], [x[1] for x in all_coords]
        if (max(rs)-min(rs)+1) > self.max_h or (max(cs)-min(cs)+1) > self.max_w: return False

        # Commit: move island2's pieces into island1 and drop island2.
        for coord, pid in new_coords.items():
            island1[coord] = pid
            self.piece_to_island[pid] = root1
        del self.islands[root2]
        return True

    def force_fill_holes(self):
        """Collapse the state into one complete grid stored as ``self.islands[0]``.

        The largest island is shifted so its bounding box starts at (0, 0) of a
        ``max_h`` x ``max_w`` grid. Every piece not in it is then placed one at a
        time. Each step picks the (empty cell, piece) pair with the lowest
        average ``mean_profile`` distance to the cell's already-placed
        neighbours. If no pair can be scored, the first leftover piece goes
        into the first empty cell.
        """
        if not self.islands: return
        best_island_key = max(self.islands.keys(), key=lambda k: len(self.islands[k]))
        main_island = self.islands[best_island_key]
        # -1 marks an empty cell; otherwise the cell holds a piece index.
        grid_matrix = np.full((self.max_h, self.max_w), -1, dtype=int)
        rs, cs = [r for r,c in main_island.keys()], [c for r,c in main_island.keys()]
        min_r, min_c = min(rs), min(cs)

        # Anchor the island's bounding box at the grid's top-left corner.
        for (r,c), pid in main_island.items():
            nr, nc = r - min_r, c - min_c
            if 0 <= nr < self.max_h and 0 <= nc < self.max_w: grid_matrix[nr, nc] = pid

        current_placed = set(grid_matrix[grid_matrix != -1])
        # Pieces from every smaller island are re-placed one by one.
        orphans = [i for i in range(self.n) if i not in current_placed]

        while orphans:
            best_move = None
            best_val = float('inf')
            empty_spots = list(zip(*np.where(grid_matrix == -1)))
            if not empty_spots: break

            for r, c in empty_spots:
                # Entries: (neighbour piece, neighbour's facing edge, candidate's facing edge).
                neighbors = []
                if r>0 and grid_matrix[r-1,c]!=-1: neighbors.append((grid_matrix[r-1,c], 2, 0))
                if r<self.max_h-1 and grid_matrix[r+1,c]!=-1: neighbors.append((grid_matrix[r+1,c], 0, 2))
                if c>0 and grid_matrix[r,c-1]!=-1: neighbors.append((grid_matrix[r,c-1], 1, 3))
                if c<self.max_w-1 and grid_matrix[r,c+1]!=-1: neighbors.append((grid_matrix[r,c+1], 3, 1))
                # Only cells that touch a placed piece are candidates.
                if not neighbors: continue

                for pid in orphans:
                    pid_real = self.idx_to_id[pid]
                    if pid_real not in self.descriptors: continue
                    err = 0
                    for n_pid, n_edge, p_edge in neighbors:
                        n_real = self.idx_to_id[n_pid]
                        if n_real not in self.descriptors: continue
                        d1, d2 = self.descriptors[n_real][n_edge], self.descriptors[pid_real][p_edge]
                        # Raw profile distance only; unlike calculate_descriptor_match,
                        # no colour-variance term is applied here.
                        err += np.mean(np.linalg.norm(d1['mean_profile'] - d2['mean_profile'], axis=1))
                    # Divides by all neighbours, including any skipped for lacking descriptors.
                    avg_err = err / len(neighbors)
                    if avg_err < best_val:
                        best_val, best_move = avg_err, (r, c, pid)

            if best_move:
                r, c, pid = best_move
                grid_matrix[r,c] = pid
                orphans.remove(pid)
            else:
                # Nothing could be scored: first orphan into the first empty cell.
                pid = orphans.pop(0)
                r, c = empty_spots[0]
                grid_matrix[r,c] = pid

        final_island = {}
        for r in range(self.max_h):
            for c in range(self.max_w):
                pid = grid_matrix[r,c]
                if pid != -1: final_island[(r,c)] = pid
        self.islands = {0: final_island}

    def solve(self):
        """Merge islands greedily, fill the grid and return the rendered image.

        Returns None only when there are no islands (no pieces).
        """
        matches = self.get_matches_by_margin()
        for m in matches:
            self.merge_islands(m['p1'], m['e1'], m['p2'], m['e2'])
            # Stop early once a single island holds every piece.
            if len(self.islands) == 1 and len(list(self.islands.values())[0]) == self.n: break
        self.force_fill_holes()
        if not self.islands: return None
        return self.render(max(self.islands.values(), key=len))

    def render(self, island):
        """Paste the pieces of ``island`` onto a black ``max_h*ph`` x ``max_w*pw`` uint8 canvas.

        ``ph``/``pw`` are taken from the first piece image, and the canvas has 3
        channels (assumes 3-channel pieces). Coordinates are shifted so the
        island's minimum row/column maps to cell (0, 0). Cells outside the grid
        are skipped and empty cells stay black. An empty island yields a black
        canvas sized with a placeholder of 100 px per cell.
        """
        rs, cs = [k[0] for k in island.keys()], [k[1] for k in island.keys()]
        if not rs: return np.zeros((self.max_h*100, self.max_w*100, 3), dtype=np.uint8)
        min_r, min_c = min(rs), min(cs)
        sample = next(iter(self.pieces.values()))
        ph, pw = sample.shape[:2]
        canvas = np.zeros((self.max_h * ph, self.max_w * pw, 3), dtype=np.uint8)
        for (r, c), idx in island.items():
            norm_r, norm_c = r - min_r, c - min_c
            if 0 <= norm_r < self.max_h and 0 <= norm_c < self.max_w:
                canvas[norm_r*ph:(norm_r+1)*ph, norm_c*pw:(norm_c+1)*pw] = self.pieces[self.idx_to_id[idx]]
        return canvas

# ==========================================
# 5. PUZZLE CUTTING & EVALUATION (ALL SIZES)
# ==========================================
def cut_grid_dynamic(img, rows, cols):
    """Split ``img`` into a ``rows`` x ``cols`` grid of equally sized tiles.

    Each tile is ``(h // rows, w // cols)`` pixels, so leftover pixels along
    the bottom and right edges are dropped. Tiles are copies, not views. They
    are keyed ``"piece_<k>"`` in row-major order, with ``k = r * cols + c``.
    """
    tile_h = img.shape[0] // rows
    tile_w = img.shape[1] // cols
    return {
        f"piece_{r * cols + c}": img[r*tile_h:(r+1)*tile_h, c*tile_w:(c+1)*tile_w].copy()
        for r in range(rows)
        for c in range(cols)
    }

def evaluate_final(grid_size, threshold=95.0):
    """Solve every ground-truth image at ``grid_size`` and print the success rate.

    Each readable image in ``ground_truth_files`` that has descriptors
    (``load_descriptors``) is processed as follows:
    - it is cut into a ``grid_size`` x ``grid_size`` grid and shuffled;
    - the pieces are solved with ``solve_2x2_descriptors`` (2x2) or
      ``PuzzleSolverDescriptor`` (larger grids).

    Images without a readable file or descriptors are skipped and excluded
    from the rate. Accuracy is ``max(0, 100 - mean absolute pixel
    difference)`` against the original. A puzzle counts as solved when
    accuracy is at least ``threshold``. That is the complement of the failure
    test in ``visualize_descriptor_failures``.

    Returns:
        Success rate in percent over the evaluated images, or ``None`` if
        there are no ground-truth files or nothing could be evaluated.
    """
    print(f"\n📊 FINAL SMART SOLVER: {grid_size}x{grid_size}")
    if not ground_truth_files: return None

    success_count = 0
    evaluated = 0
    skipped = 0
    solver_errors = 0

    # The list position is the descriptor index, so keep it via enumerate.
    for idx, img_path in enumerate(tqdm(ground_truth_files)):
        original = cv2.imread(img_path)
        if original is None:
            skipped += 1
            continue

        descs = load_descriptors(grid_size, idx)
        if descs is None:
            skipped += 1
            continue

        pieces = cut_grid_dynamic(original, grid_size, grid_size)
        keys = list(pieces.keys())
        random.shuffle(keys)
        shuffled = {k: pieces[k] for k in keys}

        # Same solver dispatch as visualize_descriptor_failures.
        try:
            if grid_size == 2:
                res = solve_2x2_descriptors(shuffled, descs)
            else:
                solver = PuzzleSolverDescriptor(shuffled, descs, max_h=grid_size, max_w=grid_size)
                res = solver.solve()
        except Exception as exc:
            # Keep evaluating the batch, but make solver crashes visible.
            solver_errors += 1
            tqdm.write(f"⚠️ Solver error on {os.path.basename(img_path)}: {exc!r}")
            res = None
        if res is None:
            res = np.zeros_like(original)

        if res.shape != original.shape:
            res = cv2.resize(res, (original.shape[1], original.shape[0]))

        d_img = cv2.absdiff(res, original)
        acc = max(0, 100 - np.mean(d_img))

        evaluated += 1
        if acc >= threshold: success_count += 1

    if skipped:
        print(f"ℹ️ Skipped {skipped} image(s) without a readable file or descriptors.")
    if solver_errors:
        print(f"⚠️ Solver raised on {solver_errors} image(s).")
    if not evaluated:
        print("❌ Error: No images could be evaluated.")
        return None

    rate = (success_count/evaluated)*100
    print(f"✅ Result: {success_count} / {evaluated} correct")
    print(f"🏆 Success Rate: {rate:.2f}%")
    return rate
# ==========================================
# 6. FAILURE VISUALIZATION (DRIVE DESCRIPTORS)
# ==========================================

def visualize_descriptor_failures(grid_size, num_to_show=3, threshold=95.0):
    """Find and plot up to ``num_to_show`` puzzles the descriptor solver gets wrong.

    Ground-truth images are visited in random order. Images with no
    descriptors (``load_descriptors`` returns ``None``) or that cannot be read
    are skipped. The rest are cut into a ``grid_size`` x ``grid_size`` grid,
    shuffled and solved. A puzzle is a failure when its accuracy,
    ``max(0, 100 - mean absolute pixel difference)``, is below ``threshold``.
    The failures are plotted as original / reconstruction pairs.
    """
    print(f"\n🔍 HUNTING FAILURES (Using Drive Descriptors): {grid_size}x{grid_size}...")

    if not ground_truth_files: return

    failed_examples = []
    # Use index to match with Drive structure
    indices = list(range(len(ground_truth_files)))
    random.shuffle(indices)

    pbar = tqdm(total=num_to_show)

    for idx in indices:
        if len(failed_examples) >= num_to_show: break

        img_path = ground_truth_files[idx]
        original = cv2.imread(img_path)
        if original is None: continue

        # 1. Load Descriptors from Drive
        descs = load_descriptors(grid_size, idx)
        if descs is None: continue # Skip if missing

        # 2. Cut & Scramble
        pieces = cut_grid_dynamic(original, grid_size, grid_size)
        keys = list(pieces.keys())
        random.shuffle(keys)
        shuffled = {k: pieces[k] for k in keys}

        # 3. Solve
        try:
            if grid_size == 2:
                res = solve_2x2_descriptors(shuffled, descs)
            else:
                solver = PuzzleSolverDescriptor(shuffled, descs, max_h=grid_size, max_w=grid_size)
                res = solver.solve()

            if res is None: res = np.zeros_like(original)
        except Exception as exc:
            # Score a crash as a black image, but report it so bugs stay visible.
            tqdm.write(f"⚠️ Solver error on {os.path.basename(img_path)}: {exc!r}")
            res = np.zeros_like(original)

        # 4. Check Accuracy
        if res.shape != original.shape:
            res = cv2.resize(res, (original.shape[1], original.shape[0]))
        diff = cv2.absdiff(res, original)
        acc = max(0, 100 - np.mean(diff))

        if acc < threshold:
            failed_examples.append((original, res, acc, os.path.basename(img_path)))
            pbar.update(1)

    pbar.close()

    if failed_examples:
        print(f"\n⚠️ Found {len(failed_examples)} failed examples:")
        rows = len(failed_examples)
        plt.figure(figsize=(10, 5 * rows))
        for i, (orig, bad_result, acc, fname) in enumerate(failed_examples):
            # Original
            plt.subplot(rows, 2, i*2 + 1)
            plt.imshow(cv2.cvtColor(orig, cv2.COLOR_BGR2RGB))
            plt.title(f"Original: {fname}", fontsize=12)
            plt.axis("off")
            # Failed reconstruction
            plt.subplot(rows, 2, i*2 + 2)
            plt.imshow(cv2.cvtColor(bad_result, cv2.COLOR_BGR2RGB))
            plt.title(f"FAILED (Acc: {acc:.2f}%)", fontsize=12, color='red')
            plt.axis("off")
        plt.tight_layout()
        plt.show()
    else:
        print("🎉 No failures found in this batch!")


if ground_truth_files:
    # Score every puzzle size (8x8 gets a looser cutoff), then plot a few 4x4 failures.
    for _size, _cutoff in ((2, 95.0), (4, 95.0), (8, 90.0)):
        evaluate_final(grid_size=_size, threshold=_cutoff)
    visualize_descriptor_failures(grid_size=4, num_to_show=3, threshold=95.0)


