This is a standalone notebook that solves a jigsaw puzzle from scanned images of its pieces.

# Import dependencies

In [None]:
%pip install matplotlib opencv-python scipy tqdm;

In [None]:
import matplotlib.pyplot as plt
import glob
import numpy as np
import scipy.stats
import cv2
import math
import random
import itertools
import json
import tqdm.notebook as tqdm
from functools import cache
from scipy.signal import savgol_filter
from scipy.signal import find_peaks
import scipy.optimize
from collections import Counter

# Add utilities

In [None]:
def imshow(img, title=None, max_height=1680, max_width=1700):
    """Display *img* with matplotlib, cropped to its top-left corner.

    The defaults reproduce the historical fixed 1680 x 1700 pixel window; pass
    ``None`` for a limit to show the full extent along that axis.

    Args:
        img: image array indexed as ``img[row, col]``.
        title: optional figure title.
        max_height: number of rows to show, or ``None`` for all rows.
        max_width: number of columns to show, or ``None`` for all columns.
    """
    plt.title(title)
    plt.imshow(img[:max_height, :max_width])
    plt.show()

class Item():
    """Lightweight attribute bag: every keyword argument becomes an attribute."""

    def __init__(self, **kwargs):
        self.update(**kwargs)

    def update(self, **kwargs):
        """Set, or overwrite, one instance attribute per keyword argument."""
        vars(self).update(kwargs)

class LoopingList(list):
    """List whose integer indices wrap around modulo its length.

    ``ll[len(ll)]`` is ``ll[0]`` and ``ll[-1]`` is the last element, which
    makes cyclic sequences (contour corners, the 4 edges of a piece) easy to
    walk. Every integer-like index wraps, including NumPy integer scalars such
    as ``np.int64``; slices keep the normal ``list`` semantics and do not wrap.
    """

    def __getitem__(self, i):
        if isinstance(i, slice):
            return super().__getitem__(i)
        if not self:
            # ``i % 0`` would raise ZeroDivisionError; report it like a list does.
            raise IndexError("LoopingList index out of range (list is empty)")
        # Non-integer indices still fail with TypeError inside list.__getitem__.
        return super().__getitem__(i % len(self))

def plot_contour(contour, **kwargs):
    """Draw an (N, 1, 2) contour, as returned by cv2.findContours, as a matplotlib line.

    Extra keyword arguments are forwarded to ``plt.plot``.
    """
    xs = contour[:, :, 0]
    ys = contour[:, :, 1]
    plt.plot(xs, ys, **kwargs)

def fill_contour(contour, **kwargs):
    """Draw an (N, 1, 2) contour as a filled matplotlib polygon.

    Extra keyword arguments are forwarded to ``plt.fill``.
    """
    xs, ys = contour[:, :, 0], contour[:, :, 1]
    plt.fill(xs, ys, **kwargs)

def transform_contour(contour, idx, params):
    """Rotate *contour* about its point ``contour[idx]``, then move that point to (x, y).

    Args:
        contour: (N, 1, 2) contour; callers in this notebook pass float64 contours.
        idx: index of the pivot point.
        params: ``(x, y, degrees)``: target position of the pivot and the
            rotation angle, as interpreted by ``cv2.getRotationMatrix2D``.

    Returns:
        The transformed contour, with the same layout as the input.
    """
    x, y, degrees = params
    pivot = contour[idx][0]
    rotation = cv2.getRotationMatrix2D(pivot, degrees, 1)
    rotated = cv2.transform(contour, rotation)
    return rotated + ((x, y) - pivot)

def sub_contour(c, idx0, idx1):
    """Return the points of closed contour *c* from idx0 (inclusive) to idx1 (exclusive).

    The walk goes forward and wraps past the end of the array when
    ``idx1 <= idx0``. In particular, ``idx0 == idx1`` yields the whole contour
    starting at idx0.
    """
    wraps = idx1 <= idx0
    if wraps:
        head, tail = c[idx0:], c[:idx1]
        return np.concatenate([head, tail])
    return c[idx0:idx1]

# Detect pieces

In [None]:
# Detect every piece in every scan: threshold, keep large contours, clean the
# mask with a morphological opening, then crop each piece to its bounding box.
MIN_PIECE_AREA = 100e3  # contours below this area (in pixels) are discarded as dirt

images = []
pieces = []
for filename in tqdm.tqdm(sorted(glob.glob("scans/kaggle/*.jpg"))):
    img_original = cv2.imread(filename)
    if img_original is None:
        # cv2.imread signals an unreadable or corrupt file by returning None
        print(f"Skipping unreadable image {filename}")
        continue
    img_h, img_w = img_original.shape[:2]
    img_gray = cv2.cvtColor(img_original, cv2.COLOR_BGR2GRAY)
    _, img_binary = cv2.threshold(img_gray, 30, 255, cv2.THRESH_BINARY)
    # no blur, it causes more harm than good for these images
    raw_contours, _ = cv2.findContours(img_binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
    raw_contours = [c for c in raw_contours if cv2.contourArea(c) > MIN_PIECE_AREA]
    # draw fully filled contours
    img_contours = np.zeros((img_h, img_w), dtype=np.uint8)
    for contour in raw_contours:
        cv2.drawContours(img_contours, [contour], 0, (255, 255, 255), -1)
    # remove small connected dirt
    img_corrected = cv2.morphologyEx(img_contours, cv2.MORPH_OPEN, np.ones((9, 9), dtype=np.uint8))
    # get the clean contour
    contours, _ = cv2.findContours(img_corrected, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
    contours = [c for c in contours if cv2.contourArea(c) > MIN_PIECE_AREA]
    # keep only the gray pixels inside the cleaned piece masks
    img_masked = img_gray & img_corrected
    img_pieces = []
    for contour in contours:
        # crop the piece to its bounding box; its contour becomes crop-relative
        bx, by, bw, bh = cv2.boundingRect(contour)
        img_pieces.append(Item(
            contour=contour - (bx, by),
            img=img_masked[by:by+bh, bx:bx+bw],
            filename=filename,
            filepos=(bx, by),
        ))

    pieces.extend(img_pieces)
    images.append(Item(
        filename=filename,
        img_gray=img_gray,
        pieces=img_pieces,
    ))

# ids follow detection order (files sorted by name)
for piece_id, piece in enumerate(pieces):
    piece.update(id=piece_id)

print(f"Detected {len(pieces)} pieces")

In [None]:
# show the smallest and biggest pieces by area
# sorted() builds a new list, so displaying does not reorder `pieces` as a side effect
pieces_by_area = sorted(pieces, key=lambda piece: cv2.contourArea(piece.contour))

for piece in pieces_by_area[:1] + pieces_by_area[-1:]:
    plt.title(f"Area={int(cv2.contourArea(piece.contour))}")
    plt.imshow(piece.img)
    plot_contour(piece.contour, c="red", ls="--")
    plt.show()

# Detect piece corners

## Find corners via peak distance from center

In [None]:
# Corner candidates: local maxima of the distance from the minimum enclosing
# circle's centre along the contour.
for piece in tqdm.tqdm(pieces):
    (cx, cy), _radius = cv2.minEnclosingCircle(piece.contour)
    # squared distance of every contour point to the circle centre
    squared_radii = np.sum((piece.contour - np.array([cx, cy])) ** 2, axis=2)[:, 0]
    # restart the signal at its global minimum so that no peak sits on the
    # array boundary, where find_peaks cannot report it
    shift = np.argmin(squared_radii)
    rolled = np.roll(squared_radii, -shift)
    peaks, _properties = find_peaks(rolled, prominence=10000)
    n_points = len(rolled)
    # map peak positions back to indices of the original contour
    piece.update(peak_indices=LoopingList(sorted((p + shift) % n_points for p in peaks)))

In [None]:
# Show the pieces having the smallest / highest number of peak indices
# sorted() builds a new list, so displaying does not reorder `pieces` as a side effect
pieces_by_peaks = sorted(pieces, key=lambda piece: len(piece.peak_indices))

for piece in pieces_by_peaks[:1] + pieces_by_peaks[-1:]:
    img = np.copy(piece.img)  # draw on a copy so piece.img stays untouched
    plt.title(f"Number of peaks={len(piece.peak_indices)}")
    for idx in piece.peak_indices:
        # plain int tuple: the point type cv2.circle documents for its center
        center = tuple(int(v) for v in piece.contour[idx, 0])
        cv2.circle(img, center, 10, (255, 0, 0), 3)
    plt.imshow(img)
    plot_contour(piece.contour, c="red", ls="--")
    plt.show()

## Filter corners by angle geometry

In [None]:
for piece in tqdm.tqdm(pieces):
    def is_right_angle(idx):
        """Tell whether the contour of `piece` makes a right angle at point idx.

        Two arms of contour points are taken on each side of idx: skip the
        nearest 0.5% of the contour and keep up to 4%. The corner is accepted
        when both arms are nearly straight (circular std of their point
        directions < 0.1 rad). The angle between the arms' mean directions,
        modulo pi, must also be within 0.3 rad of pi/2.
        """
        n_points = len(piece.contour)
        near, far = math.ceil(n_points * 0.005), int(n_points * 0.04)
        # contour restarted at idx and expressed relative to that point
        relative = np.roll(piece.contour, -idx, axis=0) - piece.contour[idx]
        arm_angles = [
            np.arctan2(arm[:, 0, 1], arm[:, 0, 0]) % (2 * np.pi)
            for arm in (relative[near:far], relative[-far:-near])
        ]
        mean_out, mean_back = (scipy.stats.circmean(a) for a in arm_angles)
        std_out, std_back = (scipy.stats.circstd(a) for a in arm_angles)
        opening = (mean_out - mean_back) % np.pi
        return -0.3 < opening - np.pi / 2 < 0.3 and std_out < 0.1 and std_back < 0.1

    # keep only the distance peaks that look like right angles
    piece.update(angle_indices=LoopingList(filter(is_right_angle, piece.peak_indices)))

In [None]:
# Show the pieces having the smallest / highest number of angle indices
# sorted() builds a new list, so displaying does not reorder `pieces` as a side effect
pieces_by_angles = sorted(pieces, key=lambda piece: len(piece.angle_indices))

for piece in pieces_by_angles[:1] + pieces_by_angles[-1:]:
    img = np.copy(piece.img)  # draw on a copy so piece.img stays untouched
    plt.title(f"Number of right angles={len(piece.angle_indices)}")
    for idx in piece.angle_indices:
        # plain int tuple: the point type cv2.circle documents for its center
        center = tuple(int(v) for v in piece.contour[idx, 0])
        cv2.circle(img, center, 10, (255, 0, 0), 3)
    plt.imshow(img)
    plot_contour(piece.contour, c="red", ls="--")
    plt.show()

## Filter corners by rectangle geometry

In [None]:
for piece in tqdm.tqdm(pieces):
    def compute_rectangle_error(indices):
        """Score how far the 4 contour points at *indices* are from a rectangle.

        Sums the relative mismatch |b - a| / (a + b) of both pairs of opposite
        sides and of the two diagonals; 0 means a perfect rectangle.
        """
        corners = LoopingList(np.take(piece.contour, sorted(list(indices)), axis=0)[:, 0, :])

        def dist(i0, i1):
            return math.sqrt(np.sum((corners[i0] - corners[i1]) ** 2))

        opposite_sides = [(dist(0, 1), dist(2, 3)), (dist(1, 2), dist(3, 0))]
        diagonals = (dist(0, 2), dist(1, 3))
        return sum(abs(b - a) / (a + b) for a, b in opposite_sides + [diagonals])

    # Try every 4-subset of the distance peaks. When there are at most 4
    # right-angle corners, keep only subsets that contain all of them.
    required = piece.angle_indices
    accept_any = len(required) > 4
    candidates = [
        (compute_rectangle_error(quad), quad)
        for quad in itertools.combinations(piece.peak_indices, 4)
        if set(quad).issuperset(required) or accept_any
    ]

    if candidates:
        error, indices = min(candidates)
    else:
        # no candidate rectangle: fall back to the right angles with a large error
        error, indices = 100, piece.angle_indices
    piece.update(rectangle_error=error)
    piece.update(corner_indices=LoopingList(indices))


In [None]:
# Show the pieces having the best / worst rectangle
# sorted() builds a new list, so displaying does not reorder `pieces` as a side effect
pieces_by_error = sorted(pieces, key=lambda piece: piece.rectangle_error)

for piece in pieces_by_error[:1] + pieces_by_error[-1:]:
    img = np.copy(piece.img)  # draw on a copy so piece.img stays untouched
    plt.title(f"Rectangle error={piece.rectangle_error}")
    for idx in piece.corner_indices:
        # plain int tuple: the point type cv2.circle documents for its center
        center = tuple(int(v) for v in piece.contour[idx, 0])
        cv2.circle(img, center, 10, (255, 0, 0), 3)
    plt.imshow(img)
    plot_contour(piece.contour, c="red", ls="--")
    plt.show()

# Save/Restore snapshot

In [None]:
assert len(pieces) == 1998

# Arrays go to compressed .npz archives keyed by list position; the rest goes to JSON.
np.savez_compressed("kaggle_piece.contour.npz", **{str(idx): piece.contour for idx, piece in enumerate(pieces)})
np.savez_compressed("kaggle_piece.img.npz", **{str(idx): piece.img for idx, piece in enumerate(pieces)})

jpieces = [
    {
        "id": piece.id,
        "filename": piece.filename,
        "filepos": [int(x) for x in piece.filepos],
        "corner_indices": [int(x) for x in piece.corner_indices],
    }
    for piece in pieces
]

with open("kaggle_piece.json", "w") as f:
    json.dump(jpieces, f)

print(f"Dumped {len(jpieces)} pieces to files")

In [None]:
pieces = []

# np.load on an .npz file returns a lazily-read NpzFile that keeps the file
# open; using it as a context manager closes the handles once every array
# has been read (indexing an NpzFile reads the whole array into memory).
with np.load("kaggle_piece.contour.npz") as piece_contours, \
        np.load("kaggle_piece.img.npz") as piece_imgs, \
        open("kaggle_piece.json", "r") as f:
    jpieces = json.load(f)
    for idx, jpiece in enumerate(jpieces):
        key = str(idx)  # arrays are stored under the piece's position in the list
        pieces.append(Item(
            id=jpiece['id'],
            filename=jpiece['filename'],
            filepos=tuple(jpiece['filepos']),
            corner_indices=LoopingList([np.int64(x) for x in jpiece['corner_indices']]),
            img=piece_imgs[key],
            contour=piece_contours[key],
        ))

print(f"Loaded {len(pieces)} pieces from files")

# Analyze edges

## Extract edges & detect sign

In [None]:
def _build_edge(contour, idx0, idx1):
    """Describe the piece edge running from corner idx0 to corner idx1.

    The whole piece contour is rotated about corner idx0 and shifted so that
    the edge lies on the x axis (idx0 at (0, 0), idx1 at (length, 0)). The
    edge's sign is 0 when no edge point is more than 20 px from that axis.
    Otherwise it is +1 or -1, depending on which side has the larger excursion.

    Args:
        contour: float64 piece contour (the rotation needs float coordinates).
        idx0, idx1: contour indices of the start and end corners.

    Returns:
        Item with idx0, idx1, normalized_piece_contour, sign and length.
    """
    start = contour[idx0][0]
    dx, dy = contour[idx1][0] - start
    rotation = cv2.getRotationMatrix2D(start, math.degrees(math.atan2(dy, dx)), 1)
    normalized = cv2.transform(contour, rotation) - start
    # y coordinates of the edge points, both corners included
    heights = sub_contour(normalized, idx0, idx1 + 1)[:, 0, 1]
    if np.max(np.abs(heights)) > 20:
        sign = 1 if np.max(heights) > -np.min(heights) else -1
    else:
        sign = 0
    return Item(
        idx0=idx0,
        idx1=idx1,
        normalized_piece_contour=normalized,
        sign=sign,
        length=math.sqrt(dx**2 + dy**2),
    )


for piece in pieces:
    float_contour = piece.contour.astype(np.float64)
    corners = piece.corner_indices
    edges = LoopingList(_build_edge(float_contour, corners[q], corners[q + 1]) for q in range(4))
    # link the 4 edges into a cycle
    for position, edge in enumerate(edges):
        edge.update(prev=edges[position - 1], next=edges[position + 1])
    piece.update(edges=edges, nb_flats=sum(edge.sign == 0 for edge in edges))

print("edge sign:", Counter([edge.sign for piece in pieces for edge in piece.edges]))
print("nb of flats:", Counter([piece.nb_flats for piece in pieces]))

In [None]:
# Show the pieces having the smallest / highest number of flats
# sorted() builds a new list, so displaying does not reorder `pieces` as a side effect
pieces_by_flats = sorted(pieces, key=lambda piece: piece.nb_flats)

sign2color = {-1: "red", 0: "green", 1: "yellow"}

for piece in pieces_by_flats[:1] + pieces_by_flats[-1:]:
    plt.title(f"Nb of flats={piece.nb_flats}")
    for edge in piece.edges:
        plot_contour(sub_contour(piece.contour, edge.idx0, edge.idx1), c=sign2color[edge.sign])
    plt.imshow(piece.img)
    plt.show()

In [None]:
# Show the pieces having the smallest / longest edge
edge_pieces = sorted(
    ((edge, piece) for piece in pieces for edge in piece.edges),
    key=lambda pair: pair[0].length,
)

for edge, piece in edge_pieces[:1] + edge_pieces[-1:]:
    plt.title(f"Edge length={edge.length}")
    plot_contour(sub_contour(piece.contour, edge.idx0, edge.idx1), c='red')
    plt.imshow(np.copy(piece.img))
    plt.show()

## Compute puzzle size

In [None]:
def compute_size(area, perimeter):
    """Return the grid dimensions ``(short_side, long_side)`` of a puzzle.

    Solves ``H * W = area`` and ``2 * (H + W) = perimeter``: H and W are the
    two roots of ``X**2 - (perimeter / 2) * X + area = 0``.

    Args:
        area: total number of grid cells.
        perimeter: ``2 * (H + W)``.

    Returns:
        ``(min(H, W), max(H, W))`` as ints. The roots are rounded to the
        nearest integer rather than truncated, so float noise such as
        39.9999 still gives 40.

    Raises:
        ValueError: if no real solution exists (negative discriminant).
    """
    half_perimeter = perimeter / 2
    delta = half_perimeter**2 - 4 * area
    if delta < 0:
        raise ValueError(f"no grid with area={area} and perimeter={perimeter}")
    root = math.sqrt(delta)
    h = round((half_perimeter - root) / 2)
    w = round((half_perimeter + root) / 2)
    return (min(h, w), max(h, w))

nb_flats = Counter(piece.nb_flats for piece in pieces)
assert nb_flats[2] == 4  # exactly four pieces with two flat edges (the corners)
area = len(pieces) + 2   # there are 2 missing pieces in the kaggle dataset
# border pieces count once and corners twice, which gives 2 * (w + h)
perimeter = nb_flats[1] + 2 * nb_flats[2]
w, h = compute_size(area, perimeter)
print(f"Size of puzzle grid: {w} x {h}")
assert w * h == area
assert 2 * (w + h) == perimeter

grid_size = (w, h)

## Sample edges

In [None]:
NB_SAMPLES = 19

# For every edge, pick NB_SAMPLES contour indices equally spaced by arc length.
for piece in tqdm.tqdm(pieces):
    for edge in piece.edges:
        # Walk the edge from corner idx0 to corner idx1 *inclusive*, so that
        # the first and last samples are exactly the two corners. Symmetric
        # sampling matters because before-flat samples are later compared in
        # reverse order.
        edge_contour = sub_contour(edge.normalized_piece_contour, edge.idx0, edge.idx1 + 1)
        deltas = edge_contour[1:] - edge_contour[:-1]
        # arc length from the first point to every point; the leading 0 keeps
        # distances[k] aligned with edge_contour[k]
        distances = np.concatenate([[0.0], np.cumsum(np.sqrt(np.sum(deltas**2, axis=2)))])
        step = distances[-1] / (NB_SAMPLES - 1)  # arc length between 2 sample points
        # first point at or just past each of the N equidistant arc lengths
        offsets = np.array([np.argmax(distances >= i * step - 0.0001) for i in range(NB_SAMPLES)])
        sample_indices = (offsets + edge.idx0) % len(piece.contour)

        edge.update(
            sample_indices=sample_indices,
        )

In [None]:
# show some problematic pieces
# In-place on purpose: restores id order, so pieces[i] is the piece with id i.
pieces.sort(key=lambda piece: piece.id)

selected = [pieces[i] for i in (514, 818, 998, 946)]
for piece in selected:
    plt.axis('equal')
    plt.title(f"id {piece.id} {piece.filename}")
    plt.imshow(piece.img)
    plot_contour(piece.contour, c="yellow", ls=':')
    markers = 'x^+v'  # one marker per edge
    for marker, edge in zip(markers, piece.edges):
        plot_contour(piece.contour[edge.sample_indices], marker=marker, ls='', c="red")
    plt.axvline(x=0, c="gray", ls=":")
    plt.axhline(y=0, c="gray", ls=":")
    plt.show()

# Start the solution

In [None]:
# grid maps an (i, j) grid cell to the piece placed there
solution = Item(grid={}, grid_size=grid_size)

def plot_solution(solution):
    """Plot the placed contour of every piece currently in ``solution.grid``."""
    for placed_piece in solution.grid.values():
        plot_contour(placed_piece.placed_contour)

def img_solution(solution):
    """Rasterize the placed pieces of *solution* into a uint8 image.

    The canvas is at least 1000 x 1000 pixels and grows to contain the
    bounding box of every placed contour. Pieces are filled with 255 on a
    0 background.
    """
    boxes = [cv2.boundingRect(p.placed_contour.astype(int)) for p in solution.grid.values()]
    width = max([1000] + [bx + bw for bx, _, bw, _ in boxes])
    height = max([1000] + [by + bh for _, by, _, bh in boxes])

    canvas = np.zeros((height, width), dtype=np.uint8)
    for p in solution.grid.values():
        cv2.fillPoly(canvas, [p.placed_contour.astype(int)], 255)
    return canvas

# Compute the border

In [None]:
flat_pieces = []

for piece in pieces:
    if piece.nb_flats == 0:
        continue
    flat_edges = [edge for edge in piece.edges if edge.sign == 0]
    # A corner's two flat edges are consecutive around the piece, but the run
    # may wrap past the end of the edge list (e.g. flats at positions 3 and 0).
    # In that case the list order is not the walking order. The first flat is
    # the one whose predecessor is not flat; the last flat is the one whose
    # successor is not flat. Fall back to list order when every edge is flat.
    first_flat = next((edge for edge in flat_edges if edge.prev.sign != 0), flat_edges[0])
    last_flat = next((edge for edge in reversed(flat_edges) if edge.next.sign != 0), flat_edges[-1])
    piece.update(
        first_flat=first_flat,
        last_flat=last_flat,
        before_flat=first_flat.prev,
        after_flat=last_flat.next,
    )
    flat_pieces.append(piece)

after_flat_features = {}  # key=piece, value=features
before_flat_features = {}

for piece in flat_pieces:
    before_flat_features[piece] = Item(
        sign=piece.before_flat.sign,
        length=piece.before_flat.length,
        points=sub_contour(piece.first_flat.normalized_piece_contour, piece.before_flat.idx0, piece.before_flat.idx1),
        # reversed so that, like the after-flat samples, they start at the
        # corner shared with the flat edge
        sample_points=piece.first_flat.normalized_piece_contour[piece.before_flat.sample_indices][::-1],
    )
    after_flat_features[piece] = Item(
        sign=piece.after_flat.sign,
        length=piece.after_flat.length,
        points=sub_contour(piece.last_flat.normalized_piece_contour, piece.after_flat.idx0, piece.after_flat.idx1),
        sample_points=piece.last_flat.normalized_piece_contour[piece.after_flat.sample_indices],
    )


def plot_border_pieces(ordered_border):
    """Plot a sequence of border pieces end to end and show the figure.

    The first piece is drawn in the frame of its last flat edge, and each
    following piece in the frame of its first flat edge. After each piece,
    the drawing offset advances by that piece's last-flat end point (taken in
    the frame the piece was drawn in).
    """
    plt.axis('equal')
    offset = np.array([0., 0.])
    for position, piece in enumerate(ordered_border):
        frame_edge = piece.first_flat if position else piece.last_flat
        contour = frame_edge.normalized_piece_contour
        plot_contour(contour + offset)
        offset += contour[piece.last_flat.idx1, 0]
    plt.show()


class BorderMatcher():
    """Base class scoring how well two border pieces fit next to each other.

    A match ``(piece0, piece1)`` means piece1 comes right after piece0 along the
    border: piece0's after-flat edge meets piece1's before-flat edge.
    Subclasses implement :meth:`eval_features`; lower errors are better, and
    pairs whose error is not below ``max_error`` are discarded.

    Reads the notebook globals ``flat_pieces``, ``after_flat_features`` and
    ``before_flat_features``.
    """

    def __init__(self, max_error=np.inf):
        self.max_error = max_error

    def eval_features(self, features0, features1):
        """Return the error of joining after-flat features0 to before-flat features1.

        Must be implemented by subclasses; lower is better.
        """
        # `NotImplemented` is a comparison sentinel, not an exception class:
        # `raise NotImplemented()` would itself fail with a TypeError.
        raise NotImplementedError(f"{type(self).__name__} must implement eval_features")

    def eval_pieces(self, piece0, piece1):
        """Return the memoized error of placing piece1 right after piece0."""
        # Per-instance memo instead of functools.cache on the method, which
        # keeps every matcher instance alive for the lifetime of the cache.
        # It is created lazily because subclasses may override __init__
        # without calling super().__init__().
        memo = self.__dict__.setdefault("_eval_cache", {})
        key = (piece0, piece1)
        if key not in memo:
            memo[key] = self._eval_pieces_uncached(piece0, piece1)
        return memo[key]

    def _eval_pieces_uncached(self, piece0, piece1):
        """Compute eval_pieces without memoization (np.inf = impossible pair)."""
        if piece0 is piece1:
            return np.inf  # a piece cannot be its own neighbour
        features0 = after_flat_features[piece0]
        features1 = before_flat_features[piece1]
        # only edges of opposite sign and similar length can fit together
        if features0.sign == -features1.sign and abs(features0.length - features1.length) < 25:
            return self.eval_features(features0, features1)
        return np.inf

    def next_matches(self, piece0):
        """Return ``[(error, piece1), ...]`` for pieces that can follow piece0, best first."""
        matches = []
        for piece1 in flat_pieces:
            error = self.eval_pieces(piece0, piece1)
            if error < self.max_error:
                matches.append((error, piece1))
        return sorted(matches, key=lambda result: (result[0], result[1].id))

    def next_candidates(self, piece0):
        """Return the pieces that can follow piece0, best first."""
        return [piece1 for error, piece1 in self.next_matches(piece0)]

    def prev_matches(self, piece1):
        """Return ``[(error, piece0), ...]`` for pieces that can precede piece1, best first."""
        matches = []
        for piece0 in flat_pieces:
            error = self.eval_pieces(piece0, piece1)
            if error < self.max_error:
                matches.append((error, piece0))
        return sorted(matches, key=lambda result: (result[0], result[1].id))

    def prev_candidates(self, piece1):
        """Return the pieces that can precede piece1, best first."""
        return [piece0 for error, piece0 in self.prev_matches(piece1)]

    def evaluate(self):
        """Print a quality report and plot the best / worst best-matches.

        A piece is a mismatch when its best next piece does not choose it back
        as its own best previous piece. Pieces without a previous or next
        candidate are reported separately.
        """
        mismatches = []  # list of piece
        unmatched = []  # pieces lacking a previous or a next candidate
        best_matches = set()  # set of (error, piece0, piece1)
        for piece in tqdm.tqdm(flat_pieces):
            prev_matches = self.prev_matches(piece)
            if prev_matches:
                error, prev_piece = prev_matches[0]
                best_matches.add((error, prev_piece, piece))
            next_matches = self.next_matches(piece)
            if not prev_matches or not next_matches:
                unmatched.append(piece)
            if not next_matches:
                continue
            error, next_piece = next_matches[0]
            best_matches.add((error, piece, next_piece))
            # non-empty: `piece` itself is a valid previous match of next_piece
            _, prev_next_piece = self.prev_matches(next_piece)[0]
            if piece != prev_next_piece:
                mismatches.append(piece)
        print(f"{len(mismatches)} mismatches: {[piece.id for piece in mismatches]}")
        if unmatched:
            print(f"{len(unmatched)} pieces without candidates: {[piece.id for piece in unmatched]}")
        best_matches = sorted(best_matches, key=lambda match: (match[0], match[1].id, match[2].id))
        print("Show 1 best match and 3 worst matches")
        for error, piece, next_piece in best_matches[:1] + best_matches[-3:]:
            plt.title(f"Candidate {piece.id}-{next_piece.id}, error={error}")
            plot_border_pieces([piece, next_piece])
            plt.show()

    def plot_candidates(self, piece=None):
        """Plot *piece* (a random flat piece by default) next to each of its next candidates."""
        if piece is None:
            # random.sample(..., 1) would return a list, not a piece
            piece = random.choice(flat_pieces)
        for error, candidate in self.next_matches(piece):
            plt.title(f"Candidate {piece.id}-{candidate.id}, error={error}")
            plot_border_pieces([piece, candidate])
            plt.show()

    def plot_histograms(self):
        """Plot histograms of candidate counts and best errors before / after the flats."""
        nb_candidates = [len(self.prev_candidates(piece)) for piece in flat_pieces]
        plt.title("Number of candidates before flat")
        plt.hist(nb_candidates, bins=20)
        plt.show()

        nb_candidates = [len(self.next_candidates(piece)) for piece in flat_pieces]
        plt.title("Number of candidates after flat")
        plt.hist(nb_candidates, bins=20)
        plt.show()

        # pieces without any candidate have no minimum error and are skipped
        min_error = [matches[0][0] for matches in map(self.prev_matches, flat_pieces) if matches]
        plt.title("Min error before flat")
        plt.hist(min_error, bins=30)
        plt.show()

        min_error = [matches[0][0] for matches in map(self.next_matches, flat_pieces) if matches]
        plt.title("Min error after flat")
        plt.hist(min_error, bins=30)
        plt.show()

    def get_issues(self):
        """Count pieces without candidates and pieces whose best next piece does not pick them back."""
        issues = Counter()
        for piece in flat_pieces:
            best_prev_candidates = self.prev_candidates(piece)
            if len(best_prev_candidates) == 0:
                issues['no_prev_candidate'] += 1
            best_next_candidates = self.next_candidates(piece)
            if len(best_next_candidates) == 0:
                issues['no_next_candidate'] += 1
                continue
            best_next_piece = best_next_candidates[0]
            best_prev_next_candidates = self.prev_candidates(best_next_piece) + [None]
            best_prev_next_piece = best_prev_next_candidates[0]
            if piece != best_prev_next_piece:
                issues['mismatch_piece'] += 1
        return issues

## Place the first corner

In [None]:
# Anchor the solution on the first piece (in the current `pieces` order) that
# has two flat edges, i.e. a corner piece.
piece0 = [piece for piece in pieces if piece.nb_flats == 2][0]

# Lengths of the corner's two flat edges, used to size the mask strips below.
# NOTE(review): the h/w naming presumes last_flat ends up vertical and
# first_flat horizontal once placed — confirm against eval_first_corner.
h = int(piece0.last_flat.length)
w = int(piece0.first_flat.length)

# Binary mask of active edges to match: an L-shaped region 2*PAD wide along the
# left (height h) and top (width w) borders of a 600x600 canvas.
PAD = 50
img_mask = np.zeros((600, 600), dtype=np.uint8)
cv2.rectangle(img_mask, (0, 0), (2*PAD, h), 255, -1)
cv2.rectangle(img_mask, (0, 0), (w, 2*PAD), 255, -1)

# Reference image: PAD-wide strips along the top and left borders of the
# canvas, presumably standing for the area outside the puzzle frame.
img_edges = np.zeros_like(img_mask)
cv2.rectangle(img_edges, (0, 0), (600, PAD), 255, -1)
cv2.rectangle(img_edges, (0, 0), (PAD, 600), 255, -1)
# plt.title('mask')
# plt.imshow(img_mask)
# plt.show()
# plt.title('edges')
# plt.imshow(img_edges)
# plt.show()

def eval_first_corner(params, debug=False):
    """Return the placement error of piece0 for ``params = (x, y, degrees)``.

    piece0's contour, in the frame of its first flat edge, is rotated by
    ``degrees`` about the end point (idx1) of that flat edge, and that point is
    moved to (x, y). Inside ``img_mask``, the error is the fraction of pixels
    where the filled piece and ``img_edges`` agree (both set or both empty);
    a perfect placement makes them complementary. With ``debug=True``, the
    error and the intermediate images are shown.
    """
    placed = transform_contour(piece0.first_flat.normalized_piece_contour, piece0.first_flat.idx1, params)
    img_piece = np.zeros_like(img_mask)
    cv2.fillPoly(img_piece, [placed.astype(int)], 255)
    img_xor = cv2.bitwise_xor(img_edges, img_piece)
    img_and = cv2.bitwise_and(cv2.bitwise_not(img_xor), img_mask)
    error = np.sum(img_and > 0) / np.sum(img_mask > 0)
    if debug:
        print(error)
        for title, image in (('piece', img_piece), ('xor', img_xor), ('and', img_and)):
            plt.title(title)
            plt.imshow(image)
            plt.show()
    return error

# start with the corner at (PAD, PAD) and the piece rotated by 180 degrees
initial_guess = [PAD, PAD, 180]
res = scipy.optimize.minimize(eval_first_corner, initial_guess, method='Powell', tol=0.1)
print(res.fun, res.x)

placed_contour = transform_contour(piece0.first_flat.normalized_piece_contour, piece0.first_flat.idx1, res.x)
piece0.update(placed_contour=placed_contour, top_edge=piece0.first_flat)

solution.grid.clear()
solution.grid[(0, 0)] = piece0

In [None]:
solution_img = img_solution(solution)  # rasterized placed pieces
plt.imshow(solution_img)
plt.show()

## Border match via sample points

In [None]:
class SampleBorderMatcher(BorderMatcher):
    """Compare the equidistant sample points of both edges.

    The error is the sum of squared residuals between corresponding sample
    points, after removing their mean displacement (the best translation).
    """

    def eval_features(self, features0, features1):
        diff = features1.sample_points - features0.sample_points
        residual = diff - diff.mean(axis=0)
        return np.sum(residual ** 2)

sampleBorderMatcher = SampleBorderMatcher()

In [None]:
# Print the mismatch report and plot the best / worst matches of the sample-point matcher
sampleBorderMatcher.evaluate()

## Border match via points & sample points

In [None]:
class PointBorderMatcher(BorderMatcher):
    """Match edges by the distance from edge1's samples to edge0's full contour.

    edge1's sample points are shifted by edge0's first sample point. The error
    is the sum, over those samples, of the squared distance to the closest
    point of edge0.
    """

    def eval_features(self, features0, features1):
        shift = features0.sample_points[0]
        samples = features1.sample_points[::-1]
        # squared distance between each sample point1 and the closest point0
        closest = [np.min(np.sum((features0.points - s - shift) ** 2, axis=2)) for s in samples]
        return sum(closest)

pointBorderMatcher = PointBorderMatcher()

In [None]:
# Print the mismatch report and plot the best / worst matches of the point matcher
pointBorderMatcher.evaluate()

## Border match via optimized offset on points & sample points

In [None]:
class OffsetPointBorderMatcher(BorderMatcher):
    """Like point matching, but the translation between the edges is optimized."""

    def eval_offset(self, features0, features1, offset):
        """Sum of squared distances from each shifted sample point1 to the closest point0."""
        total = 0
        for sample in features1.sample_points[::-1]:
            total += np.min(np.sum((features0.points - sample - offset) ** 2, axis=2))
        return total

    def eval_features(self, features0, features1):
        """Return the lowest eval_offset Powell finds, starting at edge0's first sample point."""
        start = features0.sample_points[0][0]

        def objective(offset):
            return self.eval_offset(features0, features1, offset)

        return scipy.optimize.minimize(objective, start, method='Powell', tol=1).fun

offsetPointBorderMatcher = OffsetPointBorderMatcher()

In [None]:
# Print the mismatch report and plot the best / worst matches of the offset-optimizing matcher
offsetPointBorderMatcher.evaluate()

## Composite border matcher

In [None]:
class CompositeBorderMatcher(BorderMatcher):
    """Cascade of matchers: cheap point matching first, optimized offset last.

    Each stage computes its error. As soon as a stage's error exceeds that
    stage's ``max_error``, the cascade stops and returns that error, so
    clearly wrong pairs never reach the expensive optimization. Pairs are
    kept when the returned error is below 1000.
    """

    def __init__(self):
        super().__init__(max_error=1000)
        # A looser PointBorderMatcher stage (e.g. 10000) in front would compute
        # the very same error as the 2000 stage, so it could never change the
        # outcome and would only double the work.
        self.matchers = [PointBorderMatcher(2000), OffsetPointBorderMatcher(1000)]

    def eval_features(self, features0, features1):
        error = np.inf
        for matcher in self.matchers:
            error = matcher.eval_features(features0, features1)
            if error > matcher.max_error:
                break
        return error

compositeBorderMatcher = CompositeBorderMatcher()

## Match border pieces

In [None]:
# Record, for every border piece, the neighbours accepted by the composite
# matcher (its max_error is 1000). Lists are sorted by increasing error,
# best first.
for piece in flat_pieces:
    next_candidates = compositeBorderMatcher.next_candidates(piece)
    piece.after_flat.update(
        piece_matches=next_candidates,
        edge_matches=[candidate.before_flat for candidate in next_candidates],
    )
    prev_candidates = compositeBorderMatcher.prev_candidates(piece)
    piece.before_flat.update(
        piece_matches=prev_candidates,
        edge_matches=[candidate.after_flat for candidate in prev_candidates],
    )

In [None]:
# Plot a random border piece between its best previous and best next candidate.
# piece_matches lists are expected best-first, the order in which
# BorderMatcher.prev_candidates / next_candidates return them.
for piece in random.sample(flat_pieces, 1):
    prev_matches = piece.before_flat.piece_matches
    next_matches = piece.after_flat.piece_matches
    if not prev_matches or not next_matches:
        print(f"Piece {piece.id} has no previous or no next candidate")
        continue
    plot_border_pieces([prev_matches[0], piece, next_matches[0]])

In [None]:
piece_after = {}  # key=piece0, value=piece1 that follows piece0 along the border
piece_before = {}  # key=piece1, value=piece0 that precedes piece1 along the border

def add_border_match(piece0, piece1):
    """Record that piece1 follows piece0 along the border.

    Raises:
        ValueError: if piece0 already has a different successor, or piece1
            already has a different predecessor.
    """
    if piece_after.get(piece0, piece1) is not piece1 or piece_before.get(piece1, piece0) is not piece0:
        raise ValueError(f"Cannot overwrite border match {piece0.id}-{piece1.id}")
    piece_after[piece0] = piece1
    piece_before[piece1] = piece0

# Only unambiguous matches (exactly one candidate) are linked.
for piece in flat_pieces:
    best_next_candidates = compositeBorderMatcher.next_candidates(piece)
    if len(best_next_candidates) == 1:
        add_border_match(piece, best_next_candidates[0])
    best_prev_candidates = compositeBorderMatcher.prev_candidates(piece)
    if len(best_prev_candidates) == 1:
        add_border_match(best_prev_candidates[0], piece)

def compute_group(piece0):
    """Return the chain of linked border pieces containing piece0, in border order.

    Stops at the first already-visited piece, so a chain that closes into a
    cycle (the complete border) terminates instead of looping forever.
    """
    group = [piece0]
    seen = {piece0}
    piece = piece0
    while piece in piece_after and piece_after[piece] not in seen:
        piece = piece_after[piece]
        group.append(piece)
        seen.add(piece)
    piece = piece0
    while piece in piece_before and piece_before[piece] not in seen:
        piece = piece_before[piece]
        group.insert(0, piece)
        seen.add(piece)
    return group

def compute_groups():
    """Partition flat_pieces into chains of linked pieces."""
    used_pieces = set()
    groups = []  # list of groups, each a list of pieces in border order
    for piece in flat_pieces:
        if piece not in used_pieces:
            group = compute_group(piece)
            used_pieces.update(group)
            groups.append(group)
    return groups

groups = compute_groups()
groups.sort(key=len, reverse=True)
print(f"Computed {len(groups)} groups, best group has {len(groups[0])} pieces. {sum(len(group)-1 for group in groups)}/{len(flat_pieces)} matched edges")
for group in groups[:3]:
    print([piece.id for piece in group])
    plot_border_pieces(group)

# TODO: place the next border piece after the first corner (work in progress)

In [None]:
# Anchor: the corner piece placed in grid cell (0, 0).
piece0 = solution.grid[(0, 0)]
# NOTE(review): contour points are (x, y), but this unpacks the point as
# (y0, x0), so `y0` actually holds the x coordinate of first_flat.idx0.
# That may be why the rectangles below need the unexplained "+PAD"; verify
# before relying on it.
y0, x0 = piece0.placed_contour[piece0.first_flat.idx0][0]

# Binary mask of active edges to match: a 2*PAD-wide strip on the left border
# starting below y0, plus a PAD-thick band along piece0's after-flat edge
# (trimmed by PAD contour points at each end).
img_mask = np.zeros((1000, 1000), dtype=np.uint8)
cv2.rectangle(img_mask, (0, int(y0)+PAD), (2*PAD, int(y0)+PAD+200), 255, -1)  # Why y0+PAD ??? (see NOTE on the (y0, x0) unpacking)
cv2.polylines(img_mask, [sub_contour(piece0.placed_contour, piece0.after_flat.idx0, piece0.after_flat.idx1)[PAD:-PAD].astype(int)], False, 255, PAD)
plt.title('mask')
plt.imshow(img_mask)
plt.show()

# Reference image: the placed corner piece filled, plus a PAD-wide strip on
# the left border. `h` is the global set from piece0.last_flat.length in the
# first-corner cell.
img_edges = np.zeros_like(img_mask)
cv2.fillPoly(img_edges, [piece0.placed_contour.astype(int)], 255)
cv2.rectangle(img_edges, (0, int(y0)+PAD), (PAD, int(y0)+h), 255, -1)
plt.title('edges')
plt.imshow(img_edges)
plt.show()

def eval_border_piece(piece, params, debug=False):
    """Return the placement error of a border *piece* for ``params = (x, y, degrees)``.

    The piece's contour, in the frame of its first flat edge, is rotated by
    ``degrees`` about the start point (idx0) of that flat edge, and that point
    is moved to (x, y). Inside ``img_mask``, the error is the fraction of
    pixels where the filled piece and ``img_edges`` agree (both set or both
    empty), so lower is better. With ``debug=True``, the error and the
    intermediate images are shown.
    """
    x, y, degrees = params
    img_piece = np.zeros_like(img_mask)
    # transform_contour takes the placement as one (x, y, degrees) sequence;
    # passing the position and angle separately raised a TypeError
    contour = transform_contour(piece.first_flat.normalized_piece_contour, piece.first_flat.idx0, (x, y, degrees))
    cv2.fillPoly(img_piece, [contour.astype(int)], 255)
    img_xor = cv2.bitwise_xor(img_edges, img_piece)
    img_and = cv2.bitwise_and(cv2.bitwise_not(img_xor), img_mask)
    error = np.sum(img_and > 0) / np.sum(img_mask > 0)
    if debug:
        print(error)
        for title, image in (('piece', img_piece), ('xor', img_xor), ('and', img_and)):
            plt.title(title)
            plt.imshow(image)
            plt.show()
    return error

# Experiment notes:
# sample matching: after 129 [(572, 140), (1534, 1379), (2014, 1862)]  # (score, piece)
# xor matching with default params: slow and inaccurate
# xor matching with optimized params after 129: slow and very accurate (0.0339, 140), (0.0521, 1379), (0.0667, 1862)  # (score, piece)
# xor matching at the initial guess only, best 5 over all flat pieces (score, piece):
# [(0.09865159254807693, 1705), (0.10021033653846154, 874), (0.10109299879807693, 1379), (0.10372220552884616, 295), (0.10691481370192307, 1158)]

# Optimize the xor placement of every composite-matcher candidate that can
# follow piece0, then rank the candidates by their optimized error.
initial_guess = [PAD, PAD + y0, -90]
matches = []  # list of (error, piece id, scipy OptimizeResult)
for piece1 in compositeBorderMatcher.next_candidates(piece0):
    result = scipy.optimize.minimize(
        # bind piece1 explicitly instead of relying on the loop variable
        lambda params, candidate=piece1: eval_border_piece(candidate, params),
        initial_guess,
        method='Powell',
        tol=0.1,
    )
    matches.append((result.fun, piece1.id, result))
    print(piece1.id, result.fun)

# sort on (error, id) only: OptimizeResult objects are not comparable
matches.sort(key=lambda match: (match[0], match[1]))
print([(error, piece_id) for error, piece_id, _ in matches])

In [None]:
# matrix = cv2.getRotationMatrix2D(p0, math.degrees(angle_radians), 1)
# normalized_piece_contour = cv2.transform(contour, matrix) - p0

# img = np.pad(p0.img, PAD)
# plt.imshow(img)
# plot_contour(sub_contour(p0.contour, e0.idx0, e0.idx1) + (PAD, PAD), c='red')
# plt.show()

# Thick binary mask following the edge
# img_mask = np.zeros_like(img)
# cv2.polylines(img_mask, [sub_contour(p0.contour, e0.idx0+10, e0.idx1-10)], False, 255, 20)
# plt.imshow(img_mask)
# plt.show()

# cv2.fillPoly(img, [contour0.astype(int) + PAD], 255)
# plt.imshow(img)
# plt.show()
# 
# plt.axis('equal')
# plot_contour(contour0)
# plot_contour(contour0[[flat0.idx1]], ls='', marker='x', c='red')
# 
# point0 = contour0[flat0.idx1]
# plot_contour(contour1 + point0)
# plt.show()
# 
# # img = np.pad(p1.img, PAD)
# 
# img_xor = cv2.bitwise_xor(img_p1, img_p0)
# plt.imshow(img_xor)
# plt.show()

# cv2.polylines(img_mask, [sub_contour(p0.contour, e0.idx0+10, e0.idx1-10)], False, 255, 20)
# plt.imshow(img_mask)
# plt.show()
# 
# cv2.bitwise_and(img_mask, img)

# 
# plt.imshow(img)
# plot_contour(sub_contour(p1.contour, e1.idx0, e1.idx1) + (PAD, PAD), c='red')
# plt.show()

# def draw_piece(img, piece, idx, xy, degrees):
#     h, w = piece.img.shape()
#     piece_img = cv2.copyMakeBorder(piece.img, PAD, PAD, PAD, PAD, cv2.BORDER_CONSTANT)
#     matrix = cv2.getRotationMatrix2D((PAD + w/2, PAD + h/2), degrees, 1)  # rotate around the center
#     pt = cv2.transform(piece.contour + PAD, matrix)
#     cv2.warpAffine(piece_img, matrix, (w+2*PAD, h+2*PAD), img)
# 
# piece0 = pieces[7]
# flat0 = piece0.first_flat
# contour0 = flat0.normalized_piece_contour
# ref_contour = sub_contour(contour0.astype(int), flat0.next.idx0, flat0.next.idx1)
# piece1 = pieces[1579]
# flat1 = piece1.last_flat
# contour1 = flat1.normalized_piece_contour
# 
# # plt.axis('equal')
# # plot_contour(contourX, c='red')
# # plot_contour(contourY, c='green')
# # plot_contour(sub_contour(contourX, flat0.next.idx0, flat0.next.idx1), marker='*')
# # plt.show()
# 
# # Binary mask of the edge
# pad = 20  # add padding for visualization
# x, y, w, h = cv2.boundingRect(ref_contour)
# img0 = np.zeros((h+2*pad, w+2*pad), dtype=np.uint8)
# cv2.fillPoly(img0, [contour0.astype(int) + (-x + pad, -y + pad)], 255)
# imgMask = np.zeros_like(img0)
# cv2.polylines(imgMask, [ref_contour[10:-10] + (-x + pad, -y + pad)], False, 255, 10)
# img1 = np.zeros_like(img0)
# # x, y, w, h = cv2.boundingRect(sub_contour(contourY.astype(int), flat1.prev.idx0, flat1.prev.idx1))
# # contourY = transform_contour(contour1, flat1.idx0, (0, 0), 0)
# # cv2.fillPoly(img1, [contourY.astype(int) + (-x + 1 + pad, -y + -4 + pad)], 255)
# # imgXor = img0.copy()
# # cv2.bitwise_xor(img0, img1, imgXor)
# # imgAnd = img0.copy()
# # cv2.bitwise_and(cv2.bitwise_not(imgXor), imgMask, imgAnd)
# # print(f"{np.sum(imgAnd>0)}/{np.sum(imgMask>0)} = {np.sum(imgAnd>0)/np.sum(imgMask>0)}")
# 
# plt.imshow(img0)
# plt.show()
# plt.imshow(imgMask)
# plt.show()
# plt.imshow(img1)
# plt.show()
# plt.imshow(imgXor)
# plt.show()
# plt.imshow(imgAnd)
# plt.show()