In [192]:
import cv2
import numpy as np
import os
import shutil
import sys

In [193]:
def get_transform(matches, is_affine, img_index):
    """Estimate the transform for the match set belonging to piece ``img_index``.

    Args:
        matches: Array-like indexed as ``matches[pair, match, side, coord]``.
            Each pair holds 3 or 4 matches; ``side`` 0 is used as the source
            point and ``side`` 1 as the destination point, each an (x, y)
            pair. Row ``k`` belongs to piece ``k + 2``.
        is_affine: When true, fit a 2x3 matrix with
            ``cv2.estimateAffinePartial2D``; otherwise fit a 3x3 homography
            with ``cv2.findHomography`` (RANSAC).
        img_index: Piece number, starting at 2 (piece 2 uses the first row).

    Returns:
        The estimated matrix mapping source points onto destination points.

    Raises:
        IndexError: If ``img_index`` does not correspond to a row of
            ``matches``. Without this check an index below 2 would wrap
            around and silently pick a match set from the end.
        ValueError: If OpenCV could not estimate a transform.
    """
    pair_index = img_index - 2
    if not 0 <= pair_index < len(matches):
        raise IndexError(
            f"img_index={img_index} is out of range: expected a value in "
            f"[2, {len(matches) + 1}]"
        )

    # OpenCV's estimators work on floating-point point sets; convert once and
    # hand over contiguous (N, 2) arrays instead of strided views.
    specific_matches = np.asarray(matches[pair_index], dtype=np.float32)
    src_points = np.ascontiguousarray(specific_matches[:, 0])  # (x, y) on the source side
    dst_points = np.ascontiguousarray(specific_matches[:, 1])  # (x, y) on the destination side

    if is_affine:
        # NOTE(review): estimateAffinePartial2D fits only rotation, uniform
        # scale and translation (4 degrees of freedom). If pieces can be
        # sheared or non-uniformly scaled, cv2.estimateAffine2D or
        # cv2.getAffineTransform would be required - confirm against the data.
        T, _ = cv2.estimateAffinePartial2D(src_points, dst_points)
    else:
        T, _ = cv2.findHomography(src_points, dst_points, method=cv2.RANSAC)

    if T is None:
        kind = "affine" if is_affine else "homography"
        raise ValueError(
            f"could not estimate a {kind} transform for img_index={img_index}"
        )
    return T

def stitch(img1, img2, transform):
    """Overlay ``img2`` onto a copy of ``img1`` after warping it with the inverse of ``transform``.

    The result has the height and width of ``img1``. Wherever the warped
    ``img2`` has a non-zero grayscale value its pixel replaces the one from
    ``img1`` in the first three channels; everywhere else ``img1`` is kept.
    ``img1`` itself is not modified.
    """
    canvas_h, canvas_w = img1.shape[:2]
    warped = inverse_transform_target_image(img2, transform, (canvas_w, canvas_h))

    # True wherever the warped piece contributes a non-black pixel.
    covered = cv2.cvtColor(warped, cv2.COLOR_BGR2GRAY) > 0

    result = img1.copy()
    result[:, :, :3] = np.where(
        covered[:, :, None], warped[:, :, :3], result[:, :, :3]
    )
    return result


def inverse_transform_target_image(target_img, original_transform, output_size):
    """Warp ``target_img`` with the inverse of ``original_transform``.

    Args:
        target_img: Image to warp.
        original_transform: 2x3 affine matrix or 3x3 homography. Its inverse
            is what gets applied to ``target_img``.
        output_size: Size of the result as ``(width, height)``.

    Returns:
        The warped image, produced by ``cv2.warpAffine`` for a 2x3 matrix or
        ``cv2.warpPerspective`` for a 3x3 matrix.

    Raises:
        ValueError: If ``original_transform`` is neither 2x3 nor 3x3 (this
            includes ``None``, e.g. from a failed estimation).
    """
    transform = np.asarray(original_transform, dtype=np.float64)

    if transform.shape == (2, 3):  # Affine transformation
        inverse_transform = cv2.invertAffineTransform(transform)
        return cv2.warpAffine(target_img, inverse_transform, output_size)

    if transform.shape == (3, 3):  # Homography
        inverse_transform = np.linalg.inv(transform)
        return cv2.warpPerspective(target_img, inverse_transform, output_size)

    # Previously this fell through to `return warped_img` with the name unbound.
    raise ValueError(
        "original_transform must be a 2x3 affine matrix or a 3x3 homography, "
        f"got shape {transform.shape}"
    )

def prepare_puzzle(puzzle_dir):
    """Reset the puzzle's output folder and load its keypoint matches.

    Side effect: ``<puzzle_dir>/abs_pieces`` is deleted if it exists and is
    recreated empty.

    Args:
        puzzle_dir: Directory containing a ``pieces`` sub-directory and a
            ``matches.txt`` file of whitespace-separated integers.

    Returns:
        A tuple ``(matches, is_affine, n_images)``:
        ``matches`` is an int64 array of shape ``(n_images - 1, k, 2, 2)``
        where ``k`` is the number of matches per pair; ``is_affine`` is True
        when ``k == 3``; ``n_images`` is the number of non-hidden entries in
        ``pieces``.

    Raises:
        ValueError: From the reshape, if the number of values in
            ``matches.txt`` does not fit the expected layout.
    """
    edited = os.path.join(puzzle_dir, 'abs_pieces')
    if os.path.exists(edited):
        shutil.rmtree(edited)
    os.mkdir(edited)

    # Hidden entries (e.g. .DS_Store, .ipynb_checkpoints) are not pieces and
    # would otherwise inflate the count and break the reshape below.
    pieces_dir = os.path.join(puzzle_dir, 'pieces')
    n_images = sum(1 for name in os.listdir(pieces_dir) if not name.startswith('.'))
    n_pairs = n_images - 1

    matches_data = os.path.join(puzzle_dir, 'matches.txt')
    flat = np.loadtxt(matches_data, dtype=np.int64).reshape(-1)

    # Each match contributes 4 integers (two points of two coordinates), so
    # the number of matches per pair follows from the amount of data. This is
    # more reliable than testing the path for the substring "affine", which
    # also matches unrelated parent folders.
    if n_pairs > 0:
        per_pair, leftover = divmod(flat.size, n_pairs * 4)
    else:
        per_pair, leftover = 0, 1
    if leftover or per_pair not in (3, 4):
        # Data does not settle it: fall back to the naming convention and let
        # the reshape report any mismatch.
        per_pair = 4 - int("affine" in puzzle_dir)

    matches = flat.reshape(n_pairs, per_pair, 2, 2)

    return matches, per_pair == 3, n_images

In [194]:
# puzzle_dir = 'puzzles/puzzle_affine_2'

# img1 = cv2.imread(os.path.join(puzzle_dir, 'pieces/piece_1.jpg'))  # Image with black background and real picture window
# img2 = cv2.imread(os.path.join(puzzle_dir, 'pieces/piece_2.jpg'))  # Image to be added

# matches, is_affine, n = prepare_puzzle(puzzle_dir)
# T = get_transform(matches, is_affine, img_index=2)

# src_points = matches[:, :, 0, :].reshape(-1, 2)  # All (x, y) for img1
# dst_points = matches[:, :, 1, :].reshape(-1, 2)  # All (x, y) for img2

# for point in src_points:
#     cv2.circle(img1, tuple(map(int, point)), 5, (255, 0, 0), -1)
# for point in dst_points:
#     cv2.circle(img2, tuple(map(int, point)), 5, (0, 255, 0), -1)
# cv2.imshow('Image 1 Keypoints', img1)
# cv2.imshow('Image 2 Keypoints', img2)
# cv2.waitKey(0)
# cv2.destroyAllWindows()



In [195]:
# matches, is_affine, n = prepare_puzzle('puzzles/puzzle_affine_1')
# transform = get_transform(matches, is_affine, img_index=2)

# print(transform)

In [196]:
puzzle_dir = 'puzzles/puzzle_affine_2'
pieces_dir = os.path.join(puzzle_dir, 'pieces')

# piece_1 is the canvas: an image with a black background and the real picture
# window. Every other piece is warped onto it, so the result keeps its size.
first_piece_path = os.path.join(pieces_dir, 'piece_1.jpg')
img1 = cv2.imread(first_piece_path)
if img1 is None:
    # cv2.imread does not raise on failure; it returns None.
    raise FileNotFoundError(f'Could not read image: {first_piece_path}')

matches, is_affine, n = prepare_puzzle(puzzle_dir)
stitched_img = img1

for i in range(2, n + 1):  # Images from piece_2.jpg to piece_n.jpg
    piece_path = os.path.join(pieces_dir, f'piece_{i}.jpg')
    img2 = cv2.imread(piece_path)
    if img2 is None:
        raise FileNotFoundError(f'Could not read image: {piece_path}')

    # Get the transformation matrix for the current image pair
    # (piece_1 vs piece_2, piece_1 vs piece_3, etc.)
    transform = get_transform(matches, is_affine, img_index=i)

    # Stitch the current img2 onto the stitched_img
    stitched_img = stitch(stitched_img, img2, transform)

# Save and display the stitched result. cv2.imshow opens a native window and
# cv2.waitKey(0) blocks until a key is pressed in it.
cv2.imwrite('stitched_image.jpg', stitched_img)
cv2.imshow('Stitched Image', stitched_img)
cv2.waitKey(0)
cv2.destroyAllWindows()