# Phase 2 Batch Processing: Enhanced Solver

This notebook implements the full Phase 2 pipeline for reassembling the puzzle pieces.

## Key Features
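1.  **Constrained Greedy Solver**: Runs a "Best-First Grid Search". Every piece/rotation is tried as the top-left seed, the remaining cells from (0,0) to (N-1,N-1) are filled greedily in row-major order using left/top neighbor costs, and the lowest-cost complete fill is kept. Because every cell receives exactly one piece, the output always has the expected $N \times N$ size, solving the "Size Mismatch" issue.
2.  **Rotation Bias**: Adds a large penalty (`ROTATION_PENALTY`, 10000 per piece) to every piece placed at 90, 180, or 270 degrees. This strongly biases the solver toward the 0-degree "Upright" orientation, addressing the "Flipped Image" issue.
3.  **Simulated Annealing (SA)**: A post-processing refinement step that randomly swaps pairs of pieces or rotates single pieces, accepting or rejecting each move by the Metropolis rule, to minimize the total edge error of the grid. It is aimed at improving accuracy on large puzzles (8x8).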
4.  **Dual Accuracy Metrics**:
    *   **Piece Accuracy**: Percentage of cells that hold the expected piece ID in the upright orientation (useful for debugging).
    *   **Visual Accuracy**: Percentage of pixels in the reconstruction that match the Ground Truth image from the dataset within a small per-channel tolerance (useful for visual verification).
5.  **Batch Architecture**: Iterates through all 330 processed puzzles from Phase 1.
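To make the rotation bias concrete, the sketch below scores two candidates for one grid cell the same way the constrained solver does (sum of neighbor edge costs, plus `ROTATION_PENALTY` for any non-zero rotation). The helper name `candidate_cost` is hypothetical and the edge-cost numbers are made up purely for illustration.

```python
ROTATION_PENALTY = 10000.0  # same value as the notebook's constant

def candidate_cost(edge_costs, rotation):
    """Hypothetical helper: neighbor edge costs plus the upright bias (rotation index 0-3)."""
    cost = sum(edge_costs)
    if rotation != 0:
        cost += ROTATION_PENALTY
    return cost

# Illustrative (made-up) edge costs against the left and top neighbors
upright = candidate_cost([1800.0, 1700.0], rotation=0)  # mediocre fit, upright
rotated = candidate_cost([100.0, 100.0], rotation=1)    # near-perfect fit, rotated 90 deg

print(upright, rotated)  # 3500.0 10200.0
best = "upright" if upright < rotated else "rotated"
print(best)  # upright
```

A rotated piece is therefore only chosen when the best upright option costs more than the rotated piece's edge cost plus `ROTATION_PENALTY`, so the constant effectively sets how much edge evidence is needed before a rotation is accepted.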

In [1]:
import cv2
import numpy as np
import os
import glob
import sys
import math
import random
import copy

# High penalty to force pieces to be upright (0 degrees) if possible
ROTATION_PENALTY = 10000.0 

In [2]:
class PuzzlePiece:
    def __init__(self, piece_id, image):
        self.id = piece_id
        # Keep original image for reconstruction
        self.original_image = image
        self.h, self.w = image.shape[:2]
        
        # Pre-compute rotations and features
        # 0: 0 deg, 1: 90 deg, 2: 180 deg, 3: 270 deg (Clockwise)
        self.rotations = []
        
        img = image
        for r in range(4):
            features = self._extract_features(img)
            self.rotations.append({
                "image": img,
                "features": features
            })
            img = cv2.rotate(img, cv2.ROTATE_90_CLOCKWISE)
            
    def _extract_features(self, img):
        # Convert to Lab for better color matching
        lab = cv2.cvtColor(img, cv2.COLOR_BGR2Lab).astype(np.float32)
        return {
            'TOP': lab[0, :, :],
            'BOTTOM': lab[-1, :, :],
            'LEFT': lab[:, 0, :],
            'RIGHT': lab[:, -1, :]
        }
    
    def get_features(self, rotation_idx):
        return self.rotations[rotation_idx]['features']
    
    def get_image(self, rotation_idx):
        return self.rotations[rotation_idx]['image']

class PlacedPiece:
    def __init__(self, piece, x, y, rotation):
        self.piece = piece
        self.x = x
        self.y = y
        self.rotation = rotation
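`PuzzlePiece` re-extracts the edges from each rotated image, so it never needs an explicit edge-remapping table, but knowing the mapping helps when debugging orientation errors. The numpy-only sketch below uses `np.rot90(..., k=-1)` as a stand-in for `cv2.rotate(..., cv2.ROTATE_90_CLOCKWISE)` (both turn the array clockwise) and shows where each edge of a tiny 3x3 piece ends up after one turn.

```python
import numpy as np

# A tiny single-channel "piece" with distinct values so edges are easy to track
piece = np.arange(9).reshape(3, 3)
# [[0 1 2]
#  [3 4 5]
#  [6 7 8]]

# k=-1 rotates clockwise, matching rotation index 1 (90 deg) in PuzzlePiece
rotated = np.rot90(piece, k=-1)

new_top = rotated[0, :]   # TOP edge after rotation
old_left = piece[:, 0]    # LEFT edge before rotation

print(new_top)         # [6 3 0]
print(old_left[::-1])  # [6 3 0]  -> new TOP is the old LEFT read bottom-to-top

# The other three edges after one clockwise turn:
# new RIGHT  == old TOP
# new BOTTOM == old RIGHT read bottom-to-top
# new LEFT   == old BOTTOM
print(rotated[:, -1], rotated[-1, :], rotated[:, 0])  # [0 1 2] [8 5 2] [6 7 8]
```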

In [3]:
def calculate_ssd(edge1, edge2):
    """Sum of Squared Differences between two edges."""
    len1 = len(edge1)
    len2 = len(edge2)
    min_len = min(len1, len2)
    
    e1 = edge1[:min_len]
    e2 = edge2[:min_len]
    
    diff = e1 - e2
    ssd = np.sum(diff ** 2)
    return ssd / min_len
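Because `calculate_ssd` divides by the edge length, it behaves like a per-pixel average rather than a raw sum, so costs stay comparable between small and large pieces. The worked example below repeats the same logic on two synthetic Lab edges (length x 3 channels) that differ by 2 in every channel: each pixel contributes $3 \times 2^2 = 12$, and dividing by the length leaves 12 regardless of how long the edge is.

```python
import numpy as np

def calculate_ssd(edge1, edge2):
    # Same logic as the notebook's helper: truncate to the shorter edge,
    # sum squared differences over all pixels and channels, divide by edge length
    min_len = min(len(edge1), len(edge2))
    diff = edge1[:min_len] - edge2[:min_len]
    return np.sum(diff ** 2) / min_len

# Synthetic Lab edges that differ by 2 in every channel
short_a = np.full((10, 3), 50.0, dtype=np.float32)
short_b = short_a + 2.0
long_a = np.full((100, 3), 50.0, dtype=np.float32)
long_b = long_a + 2.0

print(calculate_ssd(short_a, short_b))  # 12.0  (3 channels * 2**2)
print(calculate_ssd(long_a, long_b))    # 12.0  (same per-pixel error, longer edge)
print(calculate_ssd(short_a, long_b))   # 12.0  (longer edge truncated to 10 pixels)
```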

## Solvers

In [4]:
def solve_puzzle_constrained(pieces, N):
    """
    Solves the puzzle by iterating through all possible valid placements ensuring an NxN grid.
    Uses a 'Best-First Grid Search' strategy:
    1. Try every piece/rotation as the top-left seed.
    2. Greedily fill the rest of the grid based on neighbor constraints.
    3. Keep the solution with the minimum total cost.
    """
    if not pieces: return []
    
    best_solution = None
    min_total_cost = float('inf')
    
    # Pre-compute piece indices for fast access
    piece_indices = list(range(len(pieces)))
    
    # Try every piece at every rotation as the start (0,0)
    # This prevents getting stuck if the "best looking corner" is actually a middle piece.
    for start_idx in piece_indices:
        for start_rot in range(4):
            current_grid = [[None] * N for _ in range(N)]
            used = {start_idx}
            
            # Place (0,0)
            seed_piece = pieces[start_idx]
            current_grid[0][0] = PlacedPiece(seed_piece, 0, 0, start_rot)
            
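            # Charge the seed's rotation penalty too, so the total matches calculate_grid_energy
            current_cost = ROTATION_PENALTY if start_rot != 0 else 0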
            valid_fill = True
            
            # Fill the rest of the grid linearly
            # (0,1), (0,2)... (1,0), (1,1)...
            for row in range(N):
                for col in range(N):
                    if row == 0 and col == 0: continue
                    
                    # Find best match for (row, col)
                    best_match_cost = float('inf')
                    best_match_idx = -1
                    best_match_rot = -1
                    
                    # Constraints
                    left_placed = current_grid[row][col-1] if col > 0 else None
                    top_placed = current_grid[row-1][col] if row > 0 else None
                    
                    # Iterate all unused pieces
                    for cand_idx in piece_indices:
                        if cand_idx in used: continue
                        
                        cand_piece = pieces[cand_idx]
                        
                        for rot in range(4):
                            cand_feats = cand_piece.get_features(rot)
                            local_cost = 0
                            
                            # Check Left Neighbor (Match Left's Right to My Left)
                            if left_placed:
                                left_feats = left_placed.piece.get_features(left_placed.rotation)
                                local_cost += calculate_ssd(left_feats['RIGHT'], cand_feats['LEFT'])
                                
                            # Check Top Neighbor (Match Top's Bottom to My Top)
                            if top_placed:
                                top_feats = top_placed.piece.get_features(top_placed.rotation)
                                local_cost += calculate_ssd(top_feats['BOTTOM'], cand_feats['TOP'])
                            
                            # Apply Rotation Penalty
                            if rot != 0:
                                local_cost += ROTATION_PENALTY

                            if local_cost < best_match_cost:
                                best_match_cost = local_cost
                                best_match_idx = cand_idx
                                best_match_rot = rot
                    
                    # Place the best candidate
                    if best_match_idx != -1:
                        current_grid[row][col] = PlacedPiece(pieces[best_match_idx], col, row, best_match_rot)
                        used.add(best_match_idx)
                        current_cost += best_match_cost
                        
                        # Optimization: Prune if already worse than best global
                        if current_cost >= min_total_cost:
                            valid_fill = False
                            break
                    else:
                        valid_fill = False
                        break
                
                if not valid_fill: break
            
            if valid_fill:
                if current_cost < min_total_cost:
                    min_total_cost = current_cost
                    # Flatten to list
                    best_solution = []
                    for r in range(N):
                        for c in range(N):
                            best_solution.append(current_grid[r][c])
                            
    return best_solution if best_solution else []
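The seeded greedy fill is exhaustive over seeds, so its cost grows quickly with N. The sketch below counts the worst-case number of (piece, rotation) candidates the solver scores, assuming no pruning: $4N^2$ seeds, and for the k-th filled cell every one of the $N^2 - k$ unused pieces in all 4 rotations. Each candidate then needs up to two `calculate_ssd` calls. The function name is hypothetical.

```python
def candidate_evaluations(n):
    """Worst-case (piece, rotation) candidates scored by the constrained solver.

    Assumes no pruning: 4*n*n seeds, and for each seed the k-th filled cell
    (k = 1 .. n*n - 1) scores every unused piece in all 4 rotations.
    """
    cells = n * n
    seeds = 4 * cells
    per_seed = sum(4 * (cells - k) for k in range(1, cells))
    return seeds * per_seed

for n in (2, 4, 8):
    print(n, candidate_evaluations(n))
# 2 384
# 4 30720
# 8 2064384
```

The sum simplifies to $8N^4(N^2 - 1)$, roughly $8N^6$, which is why the early-exit check `current_cost >= min_total_cost` matters most for the 8x8 category.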

In [5]:
def calculate_grid_energy(grid, N):
    """
    Calculates total energy (sum of SSDs) of the grid.
    Lower is better.
    """
    energy = 0
    # Horizontal edges
    for r in range(N):
        for c in range(N-1):
            p1 = grid[r][c]
            p2 = grid[r][c+1]
            if p1 and p2:
                feats1 = p1.piece.get_features(p1.rotation)
                feats2 = p2.piece.get_features(p2.rotation)
                energy += calculate_ssd(feats1['RIGHT'], feats2['LEFT'])
                
    # Vertical edges
    for r in range(N-1):
        for c in range(N):
            p1 = grid[r][c]
            p2 = grid[r+1][c]
            if p1 and p2:
                feats1 = p1.piece.get_features(p1.rotation)
                feats2 = p2.piece.get_features(p2.rotation)
                energy += calculate_ssd(feats1['BOTTOM'], feats2['TOP'])
    
    # Rotation Penalty (Bias towards 0)
    for r in range(N):
        for c in range(N):
            p = grid[r][c]
            if p and p.rotation != 0:
                energy += ROTATION_PENALTY

    return energy

def refine_solution_sa(initial_solution, N, iterations=10000, initial_temp=1.0, cooling_rate=0.9995):
    """
    Refines a grid solution using Simulated Annealing.
    Moves: Swap two pieces, Rotate a piece.
    Enhanced with Reheating/Restart logic for 8x8.
    """
    
    # Reconstruct grid
    grid = [[None]*N for _ in range(N)]
    for p in initial_solution:
        if 0 <= p.y < N and 0 <= p.x < N:
            grid[p.y][p.x] = p
            
    # Validation check
    for r in range(N):
        for c in range(N):
            if grid[r][c] is None:
                return initial_solution

    current_energy = calculate_grid_energy(grid, N)
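    # Snapshot of the best grid: copy the PlacedPiece wrappers only; the
    # PuzzlePiece objects (and their images) are shared instead of deep-copied
    best_grid = [[copy.copy(p) for p in row] for row in grid]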
    best_energy = current_energy
    
    temp = initial_temp
    
    # Reheat schedule: at 30% and 60% of the run, raise the temperature back
    # to at least half of initial_temp to help escape local minima.
    reheat_points = {int(iterations*0.3), int(iterations*0.6)}
    
    for i in range(iterations):
        # Reheating
        if i in reheat_points:
            temp = max(temp, initial_temp * 0.5)
            # print(f"Reheating at step {i} to temp {temp}")

        # Create a candidate neighbor
        move_type = random.choice(['swap', 'rotate'])
        revert_move = None
        
        if move_type == 'swap':
            r1, c1 = random.randint(0, N-1), random.randint(0, N-1)
            r2, c2 = random.randint(0, N-1), random.randint(0, N-1)
            while r1 == r2 and c1 == c2:
                r2, c2 = random.randint(0, N-1), random.randint(0, N-1)
                
            # Perform swap
            temp_p1 = grid[r1][c1]
            temp_p2 = grid[r2][c2]
            
            # Update internal coordinates
            temp_p1.x, temp_p1.y = c2, r2
            temp_p2.x, temp_p2.y = c1, r1
            
            grid[r1][c1] = temp_p2
            grid[r2][c2] = temp_p1
            
            # Define revert
            def revert_swap(g, _r1, _c1, _r2, _c2, _p1, _p2):
                g[_r1][_c1] = _p1
                g[_r2][_c2] = _p2
                _p1.x, _p1.y = _c1, _r1
                _p2.x, _p2.y = _c2, _r2
            
            revert_move = lambda: revert_swap(grid, r1, c1, r2, c2, temp_p1, temp_p2)
            
        elif move_type == 'rotate':
            r, c = random.randint(0, N-1), random.randint(0, N-1)
            p = grid[r][c]
            old_rot = p.rotation
            new_rot = (old_rot + random.randint(1, 3)) % 4
            p.rotation = new_rot
            
            def revert_rot(g, _r, _c, _old):
                g[_r][_c].rotation = _old
            
            revert_move = lambda: revert_rot(grid, r, c, old_rot)

        # Calculate new energy
        new_energy = calculate_grid_energy(grid, N)
        
        # Acceptance logic
        delta = new_energy - current_energy
        accept = False
        if delta < 0:
            accept = True
        else:
            try:
                prob = math.exp(-delta / temp)
            except OverflowError:
                prob = 0
            if random.random() < prob:
                accept = True
                
        if accept:
            current_energy = new_energy
            if current_energy < best_energy:
                best_energy = current_energy
                best_grid = [[copy.copy(p) for p in row] for row in grid]  # shallow snapshot, shares piece images
        else:
            # Revert move
            revert_move()
            
        # Cool down
        temp *= cooling_rate
        
    # Flatten best_grid back to solution list
    final_solution = []
    for r in range(N):
        for c in range(N):
            p = best_grid[r][c]
            # Ensure coordinates are synced
            p.x, p.y = c, r
            final_solution.append(p)
            
    return final_solution
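The temperature schedule and the acceptance rule inside `refine_solution_sa` can be isolated into two small functions, which makes it easy to check how often expensive moves are accepted. This sketch reproduces the notebook's schedule (geometric cooling with reheats at 30% and 60% of the run) and its Metropolis acceptance rule; the helper names `temperature_schedule` and `acceptance_probability` are hypothetical.

```python
import math

def temperature_schedule(iterations, initial_temp=1000.0, cooling_rate=0.9995):
    """Temperature used at each step, mirroring refine_solution_sa's reheating."""
    reheat_points = {int(iterations * 0.3), int(iterations * 0.6)}
    temps = []
    temp = initial_temp
    for i in range(iterations):
        if i in reheat_points:
            temp = max(temp, initial_temp * 0.5)  # reheat to at least half of T0
        temps.append(temp)  # temperature used for the acceptance test at step i
        temp *= cooling_rate  # geometric cooling
    return temps

def acceptance_probability(delta, temp):
    """Metropolis rule: always accept improvements, else accept with exp(-delta/T)."""
    if delta < 0:
        return 1.0
    # -delta / temp <= 0 here, so math.exp underflows to 0.0 instead of overflowing
    return math.exp(-delta / temp)

temps = temperature_schedule(20000)
reheat = int(20000 * 0.3)
print(temps[0], temps[reheat - 1] < 500.0, temps[reheat])  # 1000.0 True 500.0
# A move that adds one rotation penalty (delta = 10000) at T = 1000:
print(acceptance_probability(10000.0, 1000.0))  # exp(-10), about 4.5e-05
```

With the `initial_temp=1000.0` passed by `process_single_puzzle`, a move that turns an upright piece is accepted with probability of about $e^{-10}$ even at the hottest point, so in practice the refinement is driven mostly by swaps.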

In [6]:
def reconstruct_image(solution):
    """
    Returns (image, (grid_w, grid_h))
    """
    if not solution: return None, (0, 0)
    
    xs = [p.x for p in solution]
    ys = [p.y for p in solution]
    min_x, max_x = min(xs), max(xs)
    min_y, max_y = min(ys), max(ys)
    
    piece_h, piece_w = solution[0].piece.h, solution[0].piece.w
    
    grid_w = max_x - min_x + 1
    grid_h = max_y - min_y + 1
    
    canvas_h = grid_h * piece_h
    canvas_w = grid_w * piece_w
    
    canvas = np.zeros((canvas_h, canvas_w, 3), dtype=np.uint8)
    
    for p in solution:
        img = p.piece.get_image(p.rotation)
        ph, pw = img.shape[:2]
        
        col = p.x - min_x
        row = p.y - min_y
        
        y_start = row * piece_h
        x_start = col * piece_w
        
        # Clip if necessary
        h_place = min(ph, canvas_h - y_start)
        w_place = min(pw, canvas_w - x_start)
        
        canvas[y_start:y_start+h_place, x_start:x_start+w_place] = img[:h_place, :w_place]
        
    return canvas, (grid_w, grid_h)
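Because every piece is assumed to share the size of the first piece, the canvas layout reduces to arithmetic on grid coordinates: the piece at grid cell (x, y) starts at row `y * piece_h` and column `x * piece_w`, measured from the minimum (x, y) present. The numpy-only sketch below shows that arithmetic on a 2x2 grid of solid-colour tiles; the function name `paste_tiles` is hypothetical and, unlike `reconstruct_image`, it assumes equal-sized tiles and skips clipping.

```python
import numpy as np

def paste_tiles(tiles, piece_h, piece_w):
    """Place equal-sized tiles keyed by (x, y) grid position onto a blank canvas."""
    xs = [x for x, _ in tiles]
    ys = [y for _, y in tiles]
    min_x, min_y = min(xs), min(ys)
    grid_w = max(xs) - min_x + 1
    grid_h = max(ys) - min_y + 1
    canvas = np.zeros((grid_h * piece_h, grid_w * piece_w, 3), dtype=np.uint8)
    for (x, y), img in tiles.items():
        r0 = (y - min_y) * piece_h  # top row of this tile on the canvas
        c0 = (x - min_x) * piece_w  # left column of this tile on the canvas
        canvas[r0:r0 + piece_h, c0:c0 + piece_w] = img
    return canvas, (grid_w, grid_h)

# 2x2 grid of 4x6 tiles; each tile's value encodes its row-major index * 10
tiles = {(x, y): np.full((4, 6, 3), 10 * (y * 2 + x), dtype=np.uint8)
         for y in range(2) for x in range(2)}
canvas, size = paste_tiles(tiles, piece_h=4, piece_w=6)
print(canvas.shape, size)  # (8, 12, 3) (2, 2)
print(canvas[5, 7, 0])     # 30 -> tile (x=1, y=1)
```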

In [7]:
def calculate_accuracy(solution, N):
    """
    Calculates accuracy based on piece ID matching grid position.
    Assumes piece_ID for (row, col) is "piece_{row*N + col}"
    Returns (correct_count, total_count, accuracy_percentage)
    """
    correct = 0
    total = N * N
    
    for p in solution:
        # Expected ID based on current grid position (p.y, p.x)
        expected_id_suffix = p.y * N + p.x
        expected_id = f"piece_{expected_id_suffix}"
        
        # Check ID match and Rotation match (must be 0)
        if p.piece.id == expected_id and p.rotation == 0:
            correct += 1
            
    return (correct / total) * 100.0

def calculate_visual_accuracy(reconstructed, original_path):
    """
    Compares the reconstructed image against the Ground Truth image.
    Returns percentage of pixels that match (within tolerance).
    """
    if reconstructed is None:
        return 0.0
        
    original = cv2.imread(original_path)
    if original is None:
        return 0.0
        
    # Resize original to match reconstruction if slight mismatch
    # (Phase 1 slicing might skip last pixel row/col if division not exact)
    if original.shape != reconstructed.shape:
        original = cv2.resize(original, (reconstructed.shape[1], reconstructed.shape[0]))
    
    diff = cv2.absdiff(reconstructed, original)
    
    # Tolerance: a pixel matches only if all three channels differ by less than 10 (out of 255)
    match_mask = np.all(diff < 10, axis=2)
    
    accuracy = np.count_nonzero(match_mask) / match_mask.size
    return accuracy * 100.0

def process_single_puzzle(puzzle_dir, output_root, category, dataset_root):
    """
    Solves one puzzle folder containing piece_*.jpg files and saves the reconstruction.
    Returns (success, message, piece_accuracy, visual_accuracy).
    """
    try:
        parts = category.split('_')
        dim_part = parts[1] # '4x4'
        expected_N = int(dim_part.split('x')[0])
    except (IndexError, ValueError):
        expected_N = 0

    image_files = glob.glob(os.path.join(puzzle_dir, "piece_*.jpg"))
    if not image_files:
        return False, "No images", 0.0, 0.0
        
    # Sort files
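    try:
        image_files.sort(key=lambda x: int(os.path.splitext(os.path.basename(x))[0].split('_')[1]))
    except (IndexError, ValueError):
        pass  # Keep the unsorted glob order if a file name is not of the form piece_<int>.jpg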
    
    pieces = []
    for f in image_files:
        img = cv2.imread(f)
        if img is not None:
            pid = os.path.splitext(os.path.basename(f))[0]
            pieces.append(PuzzlePiece(pid, img))
            
    if not pieces:
        return False, "No valid pieces loaded", 0.0, 0.0

    # Solve
    solution = None
    if expected_N > 0 and len(pieces) == expected_N * expected_N:
        # 1. Constrained Greedy
        solution = solve_puzzle_constrained(pieces, expected_N)
        
        # 2. SA Refinement
        if solution:
            # Enhanced parameters for 8x8
            iter_count = 20000
            if expected_N >= 8: iter_count = 100000 # Boosted for better accuracy
            
            solution = refine_solution_sa(solution, expected_N, iterations=iter_count, initial_temp=1000.0)
            
    else:
        print(f"Warning: {category} expected {expected_N}x{expected_N} but found {len(pieces)}.")

    piece_acc = 0.0
    visual_acc = 0.0
    
    if solution:
        final_img, (grid_w, grid_h) = reconstruct_image(solution)
        
        # Calculate Piece Accuracy
        piece_acc = calculate_accuracy(solution, expected_N)
        
        # Calculate Visual Accuracy (vs Ground Truth)
        # input puzzle_dir is like "phase1_batch_output/puzzle_2x2/0"
        # original image is "Jigsaw Puzzle Dataset/Gravity Falls/puzzle_2x2/0.jpg"
        puzzle_name = os.path.basename(puzzle_dir)
        original_path = os.path.join(dataset_root, category, f"{puzzle_name}.jpg")
        
        visual_acc = calculate_visual_accuracy(final_img, original_path)
    
        # Output
        rel_path = os.path.relpath(puzzle_dir, "phase1_batch_output")
        out_dir = os.path.join(output_root, rel_path)
        os.makedirs(out_dir, exist_ok=True)

        is_valid = (grid_w == expected_N and grid_h == expected_N)
        
        if is_valid:
            cv2.imwrite(os.path.join(out_dir, "solved.jpg"), final_img)
            return True, f"{rel_path}: OK | Piece Acc: {piece_acc:.1f}% | Visual Acc: {visual_acc:.1f}%", piece_acc, visual_acc
        else:
            cv2.imwrite(os.path.join(out_dir, f"solved_mismatch_{grid_w}x{grid_h}.jpg"), final_img)
            return False, f"{rel_path}: SIZE MISMATCH", piece_acc, visual_acc
            
    return False, "No solution found", 0.0, 0.0

In [8]:
def process_dataset(input_root, output_root, dataset_root):
    print(f"Starting Phase 2 Batch Processing from {input_root}...")
    
    categories = ["puzzle_2x2", "puzzle_4x4", "puzzle_8x8"]
    
    for cat in categories:
        cat_path = os.path.join(input_root, cat)
        if not os.path.exists(cat_path):
            continue
            
        print(f"\nProcessing category: {cat}")
        puzzle_folders = [f.path for f in os.scandir(cat_path) if f.is_dir()]
        
        total_piece_acc = 0.0
        total_visual_acc = 0.0
        count = 0
        failures = 0
        
        for p_dir in puzzle_folders:
            try:
                success, msg, p_acc, v_acc = process_single_puzzle(p_dir, output_root, cat, dataset_root)
                count += 1
                total_piece_acc += p_acc
                total_visual_acc += v_acc
                
                if not success:
                    failures += 1
                    print(f"  [FAIL] {msg}")
                
                if count % 10 == 0:
                    print(f"  Progress: {count}/{len(puzzle_folders)} - Avg Piece Acc: {total_piece_acc/count:.1f}% | Avg Visual Acc: {total_visual_acc/count:.1f}%")
            except Exception as e:
                failures += 1
                print(f"Failed to solve {p_dir}: {e}")
                
        print(f"Category {cat} Complete.")
        if count > 0:
            print(f"Average Piece Accuracy: {total_piece_acc/count:.2f}%")
            print(f"Average Visual Accuracy: {total_visual_acc/count:.2f}%")
        print(f"Total Failures: {failures}")

In [9]:
if __name__ == "__main__":
    input_root = "phase1_batch_output"
    output_root = "phase2_batch_output"
    dataset_root = r"Jigsaw Puzzle Dataset/Gravity Falls" # Path to original images
    
    if os.path.exists(input_root):
        process_dataset(input_root, output_root, dataset_root)
    else:
        print(f"Input directory not found: {input_root}")

Starting Phase 2 Batch Processing from phase1_batch_output...

Processing category: puzzle_2x2
  Progress: 10/110 - Avg Piece Acc: 27.5% | Avg Visual Acc: 28.2%
  Progress: 20/110 - Avg Piece Acc: 32.5% | Avg Visual Acc: 33.3%
  Progress: 30/110 - Avg Piece Acc: 27.5% | Avg Visual Acc: 30.7%
  Progress: 40/110 - Avg Piece Acc: 28.8% | Avg Visual Acc: 33.0%
  Progress: 50/110 - Avg Piece Acc: 26.0% | Avg Visual Acc: 30.4%
  Progress: 60/110 - Avg Piece Acc: 26.2% | Avg Visual Acc: 30.3%
  Progress: 70/110 - Avg Piece Acc: 25.4% | Avg Visual Acc: 29.3%
  Progress: 80/110 - Avg Piece Acc: 25.3% | Avg Visual Acc: 28.9%
  Progress: 90/110 - Avg Piece Acc: 24.7% | Avg Visual Acc: 28.4%
  Progress: 100/110 - Avg Piece Acc: 24.2% | Avg Visual Acc: 27.7%
  Progress: 110/110 - Avg Piece Acc: 24.8% | Avg Visual Acc: 28.1%
Category puzzle_2x2 Complete.
Average Piece Accuracy: 24.77%
Average Visual Accuracy: 28.09%
Total Failures: 0

Processing category: puzzle_4x4
  Progress: 10/110 - Avg Piece Ac

## Validation Cell
Run this cell to check accuracy on a small subset of puzzles or to debug low-accuracy cases.

In [10]:
# Validation Sandbox
def validate_subsection(category, limit=5):
    input_root = "phase1_batch_output"
    output_root = "phase2_validation_output"
    dataset_root = r"Jigsaw Puzzle Dataset/Gravity Falls"
    cat_path = os.path.join(input_root, category)
    
    # Sort for a reproducible subset (os.scandir order is arbitrary)
    puzzle_folders = sorted(f.path for f in os.scandir(cat_path) if f.is_dir())
    puzzle_folders = puzzle_folders[:limit]
    
    print(f"Validating {len(puzzle_folders)} puzzles from {category}...")
    for p_dir in puzzle_folders:
        success, msg, p_acc, v_acc = process_single_puzzle(p_dir, output_root, category, dataset_root)
        print(f"{os.path.basename(p_dir)}: {msg}")

# Uncomment to run validation test:
# validate_subsection("puzzle_8x8", limit=2)