# Base Imports

In [1]:
import matplotlib
import matplotlib.pyplot as plt
from tqdm import tqdm
import cv2
from cv2 import resize
import skimage as sk
from skimage import transform
import skimage.io as skio
import numpy as np


# Utils

In [2]:
def get_image(path):
    """
    Read the image file at `path` with matplotlib and return its pixels as a float array.
    """
    return plt.imread(path).astype(float)

def show(im, figsize = 10, cmap=None):
    """
    Draw `im` on a new square figure.

    figsize - side length of the figure (used for both width and height)
    cmap - optional color map handed to imshow
    """
    square = (figsize, figsize)
    plt.figure(figsize=square)
    plt.imshow(im, cmap=cmap)
    
def resize_preserve(img, shape):
    """
    Resize `img` to exactly `shape` without warping it.

    The image is scaled uniformly until it covers the target in both directions,
    then the excess is cropped away symmetrically around the center.

    img - image array of shape (h, w) or (h, w, c)
    shape - target size as (width, height), i.e. the order cv2 uses

    Returns an array whose first two dimensions are (height, width) of the target.
    """
    target_w, target_h = int(shape[0]), int(shape[1])
    src_h, src_w = img.shape[:2]

    # Uniform scale that makes the image cover the target in both directions
    scale = max(target_w / src_w, target_h / src_h)

    # Never let rounding leave us smaller than the target, and hand cv2 a tuple of
    # plain ints in (width, height) order
    new_w = max(target_w, int(round(src_w * scale)))
    new_h = max(target_h, int(round(src_h * scale)))
    img = resize(img, (new_w, new_h))

    # Center crop. Slicing by start + length (rather than trimming diff//2 from both
    # ends) gives exactly the requested size even when the difference is odd.
    x_start = (img.shape[1] - target_w) // 2
    y_start = (img.shape[0] - target_h) // 2
    return img[y_start:y_start + target_h, x_start:x_start + target_w]


# Correspondences

## Imports

In [3]:
import matplotlib.cm as cm
from scipy.spatial import Delaunay

## Code

In [4]:
def get_correspondences(images, dims = (1, 1), timeout = -1, include_edges = False, cmap = None, mesh = False, rainbow = True):
    """
    Interactively collects matching keypoints across a list of images.

    Every round first shows the keypoints gathered so far, then asks for one click on
    each image, in order. The points of a round are only stored once every image has
    received its click, so all images always hold the same number of keypoints.

    Returns a list of keypoints in the shape (images, num_keypoints, 2 (since x and y) )

    Side effect: grayscale (2D) entries of `images` are replaced in place by
    3-channel copies.

    Optional Args:
    dims - the subplot grid handed to show_correspondences for the overview display
    timeout - passed to ginput as the time to wait for each click
    include_edges - boolean for whether or not to start with the four image corners as keypoints
    cmap - color map option is nice if you're dealing with grayscale images
    mesh - whether or not to display a mesh from the keypoints
    rainbow - whether or not to color the points with a rainbow
    """

    # Make sure all of our images have a channel axis
    for i, image in enumerate(images):
        if len(image.shape) < 3:
            images[i] = np.stack((image, image, image), axis = -1)

    # Start with no keypoints, or with the four corners of each image
    images_keypoints = [[] for image in images]
    if include_edges:
        images_keypoints = [[(0,0), (0, image.shape[0]-1), (image.shape[1]-1, 0), 
                             (image.shape[1]-1, image.shape[0]-1)]
                            for image in images]

    while True:
        # Display all the images that we have, with their keypoints and triangulation
        plt.clf()
        show_correspondences(images, images_keypoints, dims, cmap, mesh, rainbow, 
                             title = "Click to add another round of points, or press enter")

        # A truthy result means a key (not a mouse button) was pressed: we are done
        if plt.waitforbuttonpress():
            plt.close()
            return images_keypoints
        plt.clf()

        # Collect one new point per image
        next_points = []
        for image, keypoints in zip(images, images_keypoints):
            plt.clf()
            plt.imshow(image, cmap = cmap)
            plt.scatter([point[0] for point in keypoints], [point[1] for point in keypoints])
            # The prompt has to be set after clf(); setting it before the loop meant it
            # was cleared again before it could ever be drawn
            plt.title("Left click for the next point or hit enter to escape")
            plt.draw()
            user_input = plt.ginput(1, timeout=timeout)
            if not len(user_input):
                # No click: abandon this partial round and return what we already have
                return images_keypoints
            next_points.append(user_input[0])  # ginput returns a list of coords... want the first one
        plt.close()

        # We made it through all the images, so commit this round of points
        for next_point, keypoints in zip(next_points, images_keypoints):
            keypoints.append(next_point)

def get_and_show_correspondences(images, dims, timeout = -1, include_edges = False, cmap = None, mesh = False, rainbow = True, title = None):
    """
    A wrapper for get_correspondences that also shows the images and keypoints after the fact.

    Takes the same args as get_correspondences, plus:
    title - optional title for the final display

    Returns the collected keypoints in the shape (images, num_keypoints, 2).
    """
    # Forward rainbow as well, so the interactive display honors the caller's choice
    # instead of silently falling back to get_correspondences' default
    images_keypoints = get_correspondences(images, dims, timeout, include_edges,
                                           cmap = cmap, mesh = mesh, rainbow = rainbow)
    show_correspondences(images, images_keypoints, dims, cmap, mesh, rainbow, title)
    return images_keypoints


def show_correspondences(images, images_keypoints, dims, cmap = None, 
                         mesh = False, rainbow = False, title = None, 
                         figsize = (10, 5)):
    """
    Plots each image in its own subplot with its keypoints drawn on top.

    Inputs:
    images - list of images, each of shape (h, w, c) or (h, w)
    images_keypoints - list whose i th entry holds the (n_keypoints, 2) keypoints
                        of images[i]
    dims - the subplot grid; it is unpacked straight into plt.subplots, so it is
           read as (rows, columns)

    Optional arg:
    cmap - a cmap for displaying images... useful if the images are black and white
    mesh - whether or not to overlay a triangle mesh (needs at least 3 keypoints)
    rainbow - whether or not to color the points with a rainbow
    title - a title for the whole figure
    figsize - the size of the figure to show

    Only as many subplots are filled as there are images and keypoint lists.
    The mesh comes from get_avg_triangulation, which is not defined in the visible
    part of this notebook.
    """
    images_keypoints = np.array(images_keypoints)

    # Decide once whether a mesh gets drawn; the triangulation is shared by all images
    draw_mesh = mesh and len(images_keypoints[0]) >= 3
    if draw_mesh:
        tri = get_avg_triangulation(images_keypoints)

    fig, axs = plt.subplots(*dims, figsize = figsize)
    # np.array(...).flatten() copes with both a single Axes and a grid of them
    for ax, image, keypoints in zip(np.array(axs).flatten(), images, images_keypoints):
        ax.imshow(image, cmap = cmap)

        xs = [point[0] for point in keypoints]
        ys = [point[1] for point in keypoints]
        color = cm.rainbow(np.linspace(0, 1, len(keypoints))) if rainbow else None
        ax.scatter(xs, ys, color = color)

        if draw_mesh:
            ax.triplot(keypoints[:,0], keypoints[:,1], tri.simplices)
        ax.axis("off")

    if title:
        fig.suptitle(title)
    plt.show()


# Homographies

## Imports

In [5]:
from scipy.ndimage import map_coordinates

## Code

In [6]:
def computeH(im1_pts, im2_pts):
    """
    Computes the homography H that carries im1_pts onto im2_pts

    [         ]       [         ]
    [ im2_pts ] ~ H @ [ im1_pts ]
    [ 1 1 1 1 ]       [ 1 1 1 1 ]

    im1_pts = list of shape (pts, 2)
    im2_pts = list of shape (pts, 2)

    With h_33 pinned to 1, every correspondence (x, y) -> (x_hat, y_hat) contributes two
    linear equations in the remaining eight entries of H. They are stacked as A h = b
    (all x_hat equations first, then all y_hat equations) and solved by least squares.

    Returns H as a (3, 3) array.
    """
    # b holds every target x first, then every target y, as a column vector
    targets = [[pt[0]] for pt in im2_pts] + [[pt[1]] for pt in im2_pts]
    b = np.array(targets)

    x_rows = []  # x, y, 1, 0, 0, 0, -x*x_hat, -y*x_hat
    y_rows = []  # 0, 0, 0, x, y, 1, -x*y_hat, -y*y_hat
    for i, source in enumerate(im1_pts):
        x, y = source[0], source[1]
        x_hat, y_hat = im2_pts[i][0], im2_pts[i][1]
        x_rows.append([x, y, 1, 0, 0, 0, -x*x_hat, -y*x_hat])
        y_rows.append([0, 0, 0, x, y, 1, -x*y_hat, -y*y_hat])
    A = np.array(x_rows + y_rows)

    first_eight = np.linalg.lstsq(A, b, rcond=None)[0].flatten()

    # Append h_33 = 1 and fold the nine values into the 3x3 matrix
    return np.concatenate((first_eight, [1,])).reshape(3, 3)

"""
def computeH(im1_pts, im2_pts):
    #Computes the homography from im1 to im2
    
    #[         ]       [         ]
    #[ im2_pts ] = H @ [ im1_pts ]
    #[ 1 1 1 1 ]       [ 1 1 1 1 ]
    
    #im1_pts = list of shape (pts, 2)
    #im2_pts = list of shape (pts, 2)

    #These get im1_pts, im2_pts into the form above
    im1_pts = np.array(im1_pts).T
    im1_pts = np.concatenate((im1_pts, np.ones((1,im1_pts.shape[1]))))
    
    im2_pts = np.array(im2_pts).T
    im2_pts = np.concatenate((im2_pts, np.ones((1,im2_pts.shape[1]))))

    
    # Need to use np.solve to solve for x in b = Ax, 
    #   so need to transpose both sides, giving
    #   im2_pts.T = im1_pts.T @ H.T
    #   Now, im2_pts is b and im1_pts is a
    #   Transposing resulting H.T gives answer
    H = np.linalg.lstsq(im1_pts.T, im2_pts.T, rcond=None)[0].T
    
    # Rescale to make sure that bottom right entry is 1
    
    H = H/H[2,2]
    
    return H
"""

    
def warp_image(im, H, x_offset=0, y_offset=0, outsize = None):
    """
    Applies the homography H to im and returns the warped image

    Every output pixel is traced back through the inverse of H and sampled from im.
    Output pixels whose source lies outside of im are filled with -1, which makes
    it easy to build a mask for each image afterwards.

    x_offset, y_offset - added to the output pixel coordinates before they are mapped
                         back, which slides the visible window over the warped plane

    outsize - the size of the output (x,y), which defaults to the shape of the input image

    A grayscale (2D) input is replicated into three channels, and only the first
    three channels of the output are filled in.
    """
    if im.ndim == 2:
        im = np.stack([im, im, im], axis = -1)

    # Integer pixel grid of the output, plus the buffer the samples are written to
    if outsize:
        n_cols, n_rows = outsize[0], outsize[1]
        y_coords, x_coords = np.indices((n_rows, n_cols))
        out = np.zeros((n_rows, n_cols, 3))
    else:
        x_coords, y_coords = np.indices((im.shape[1], im.shape[0]))
        out = np.zeros(im.shape)

    x_coords = x_coords.ravel()
    y_coords = y_coords.ravel()

    # Homogeneous (x, y, 1) columns of the shifted output coordinates
    homogeneous = np.stack([x_coords+x_offset, y_coords+y_offset,
                            np.ones(y_coords.shape)], axis = 0)

    # Pull the output coordinates back into the source image and dehomogenize
    source = np.linalg.inv(H) @ homogeneous
    source = source / source[-1,:]

    # map_coordinates wants (row, column) order, i.e. y before x
    sample_at = [source[1], source[0]]
    for channel in range(3):
        out[y_coords, x_coords, channel] = map_coordinates(im[:,:,channel], sample_at,
                                                           mode = 'constant', cval=-1)

    return out


def mosaic_from_left(images, keypoints, transforms):
    """
    Stitches a list of images into one mosaic, using images[0] as the reference frame.

    Going image by image, each new transform is composed onto a running transform, so
    image i is warped by transforms[1] @ ... @ transforms[i] into the frame of
    images[0]. The warped images are then summed and averaged where they overlap.

    images - list of images of shape (h, w) or (h, w, 3)
    keypoints - kept for interface compatibility; the canvas is derived from the
                transforms, so the keypoints are not read
    transforms - transforms[i] (for i >= 1) is the homography taking images[i]
                 coordinates to images[i-1] coordinates; transforms[0] is ignored

    Returns a float image covering the bounding box of all warped images. Pixels that
    no image covers are 0.

    Raises ValueError if `images` is empty.
    """
    if len(images) == 0:
        raise ValueError("mosaic_from_left needs at least one image")

    # running[i] maps pixel coordinates of images[i] into the frame of images[0]
    running = [np.identity(3)]
    for i in range(1, len(images)):
        running.append(running[-1] @ np.asarray(transforms[i], dtype = float))

    # Size the canvas from where the corners of every image land in the reference frame
    corner_xs, corner_ys = [], []
    for image, total_transform in zip(images, running):
        h, w = image.shape[:2]
        corners = np.array([[0, w - 1, 0, w - 1],
                            [0, 0, h - 1, h - 1],
                            [1, 1, 1, 1]], dtype = float)
        mapped = total_transform @ corners
        mapped = mapped[:2] / mapped[2]
        corner_xs.extend(mapped[0])
        corner_ys.extend(mapped[1])

    x_min = int(np.floor(min(corner_xs)))
    y_min = int(np.floor(min(corner_ys)))
    x_max = int(np.ceil(max(corner_xs)))
    y_max = int(np.ceil(max(corner_ys)))
    outsize = (x_max - x_min + 1, y_max - y_min + 1)

    color_sum = np.zeros((outsize[1], outsize[0], 3))
    coverage = np.zeros((outsize[1], outsize[0], 3))
    for image, total_transform in zip(images, running):
        # The offsets put output pixel (0, 0) at (x_min, y_min) of the reference frame
        warped = warp_image(image, total_transform, x_offset = x_min,
                            y_offset = y_min, outsize = outsize)
        # warp_image marks pixels without source data with -1
        color_sum += np.maximum(warped, 0)
        coverage += warped >= 0

    # Average by the number of images that actually cover each pixel
    return color_sum / np.maximum(coverage, 1)


    
    

## Results

### Basic Homography

In [None]:
# Load the three input photos as float arrays and divide by 256
# (presumably to bring 8-bit pixel values into [0, 1) - confirm)
images = [get_image(path)/256 for path in ('im1.jpg', 'im2.jpg', 'im3.jpg')] 

In [None]:
%matplotlib
# Interactively pick matching keypoints on the first two photos, shown side by side
# (1 row, 2 columns); pts_12[0] belongs to images[0] and pts_12[1] to images[1]
pts_12 = get_and_show_correspondences(images[:2], (1, 2))

In [None]:
# Homography taking the images[0] keypoints onto the images[1] keypoints
H = computeH(pts_12[0], pts_12[1])
# x_offset = 1000 shifts the sampled output window by 1000 px along x
show(warp_image(images[0], H, x_offset = 1000))

### Image Rectification

In [None]:
%matplotlib
# Click the points on the Mona Lisa photo that should be rectified; with a single
# image the result is a one-element list, so the points are mona_lisa_keypoints[0]
mona_lisa_keypoints = get_and_show_correspondences([get_image('mona_lisa.jpeg')/256], (1, 1))

In [None]:
# Corners of the 600 x 900 target rectangle as (x, y). computeH pairs points by index,
# so this order has to match the order in which the keypoints were clicked.
rectify_mona_lisa_keypoints = [(0, 0), (600, 0), (0, 900), (600, 900)]

# NOTE(review): this rebinds H, replacing the homography computed in the
# Basic Homography section
H = computeH(mona_lisa_keypoints[0], rectify_mona_lisa_keypoints)

# Render the rectified view on a canvas of the same 600 x 900 (x, y) size
show(warp_image(get_image('mona_lisa.jpeg')/256, H, outsize = (600, 900)))

### Testing Merging Images (Mosaic)

In [None]:
# `H` is reassigned by the rectification section, so recompute the images[0] -> images[1]
# homography here instead of relying on whichever `H` happened to be assigned last
H_12 = computeH(pts_12[0], pts_12[1])

# Mean keypoint displacement between the two photos, rounded to whole pixels
offsets = np.mean(np.array(pts_12)[0] - np.array(pts_12)[1], axis = 0)
offsets = np.round(offsets).astype(int)

# (before, after) padding per axis: rows use the y offset, columns the x offset, and
# the channel axis is left alone. Both photos get the same padding.
padding = [(max(offsets[1], 0), -min(offsets[1], 0)),
           (max(offsets[0], 0), -min(offsets[0], 0)),
           (0, 0)]

# Pad with -1 so that padded pixels read as "no data" in the masks below.
# Note that image2 is built from images[0] and image1 from images[1].
image2 = np.pad(images[0], padding, mode = 'constant', constant_values = -1)
image2 = warp_image(image2, H_12)

image1 = np.pad(images[1], padding, mode = 'constant', constant_values = -1)

# Zero out the -1 markers before summing
im1 = np.maximum(image1, 0)
im2 = np.maximum(image2, 0)

# Divide by 2 where both images have data and by 1 everywhere else
show((im1 + im2) / (np.ones(im1.shape) + np.logical_and(image1 >= 0, image2 >= 0)))


### Bells and Whistles: Nyan Lisa

In [10]:
%matplotlib
# Click the points on the Mona Lisa photo that the Nyan Cat corners should land on.
# This rebinds mona_lisa_keypoints from the Image Rectification section.
mona_lisa_keypoints = get_and_show_correspondences([get_image('mona_lisa.jpeg')/256], (1, 1))

Using matplotlib backend: MacOSX


In [11]:
# Source points as (x, y) - presumably the four corners of the Nyan Cat image; confirm
# against its size. computeH pairs points by index, so the order has to match the
# order in which the Mona Lisa keypoints were clicked.
nyan_cat_keypoints = [(0, 0), (199, 0), (0, 250), (199, 250)]

# Homography carrying the Nyan Cat points onto the clicked Mona Lisa keypoints
H = computeH(nyan_cat_keypoints, mona_lisa_keypoints[0])

# Warp onto a 1000 x 533 (x, y) canvas - presumably the size of the Mona Lisa photo; confirm
nyan_cat_warp = warp_image(get_image('nyan_cat.jpg')/256, H, outsize = (1000, 533))

# True wherever the warp produced image data (warp_image fills uncovered pixels with -1)
mask = (nyan_cat_warp >=0)

In [12]:
# Composite: the warped Nyan Cat where the mask is set, the Mona Lisa everywhere else
show(nyan_cat_warp * mask + get_image('mona_lisa.jpeg')/256 * (1 - mask))

Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


In [None]:
# Re-display the Mona Lisa photo with the clicked keypoints drawn on top
show_correspondences([get_image('mona_lisa.jpeg')/256], mona_lisa_keypoints, (1,1))

# Test

In [None]:
# Visualize the averaging weights of the mosaic test, halved for display:
# 1 where image1 and image2 both have data, 0.5 everywhere else
show((np.ones(im1.shape) + np.logical_and(image1 >= 0, image2 >= 0))/2)

In [None]:
# Show the second input photo on its own
show(images[1])

In [None]:
# Clicked Mona Lisa keypoints as a 3 x n matrix of homogeneous columns (x, y, 1)
pts1 = np.array(mona_lisa_keypoints[0]).T
pts1 = np.concatenate((pts1, np.ones((1,pts1.shape[1]))))

# The rectification target points in the same 3 x n homogeneous layout
pts2 = np.array(rectify_mona_lisa_keypoints).T
pts2 = np.concatenate((pts2, np.ones((1,pts2.shape[1]))))



# Code Graveyard

In [None]:
def computeH(im1_pts, im2_pts):
    """
    Debugging variant of computeH: computes the homography from im1 to im2 and prints
    the raw least-squares output along the way.

    [         ]       [         ]
    [ im2_pts ] ~ H @ [ im1_pts ]
    [ 1 1 1 1 ]       [ 1 1 1 1 ]

    im1_pts = list of shape (pts, 2)
    im2_pts = list of shape (pts, 2)

    Returns H as a (3, 3) array with h_33 fixed to 1.

    NOTE(review): running this cell rebinds the name computeH and so replaces any
    definition of computeH that was executed earlier in the notebook.
    """
    b = np.array([[im2_pts[i][0]] for i in range(len(im2_pts))] + 
                  [[im2_pts[i][1]] for i in range(len(im2_pts))])

    # Get the first half of the rows of A
    # x, y, 1, 0, 0, 0, -x*x_hat, -y*x_hat
    A_1 = [[im1_pts[i][0], im1_pts[i][1], 1, 0, 0, 0,
            -im1_pts[i][0]*im2_pts[i][0], -im1_pts[i][1]*im2_pts[i][0]] 
           for i in range(len(im1_pts))]
    # Get the second half of the rows of A
    # 0, 0, 0, x, y, 1, -x*y_hat, -y*y_hat
    A_2 = [[0, 0, 0, im1_pts[i][0], im1_pts[i][1], 1,
            -im1_pts[i][0]*im2_pts[i][1], -im1_pts[i][1]*im2_pts[i][1]] 
           for i in range(len(im1_pts))]

    A = np.array(A_1 + A_2)

    # Solve once and reuse the result for both printing and building H. rcond=None
    # matches the other computeH in this notebook and avoids the FutureWarning that
    # older numpy versions emit when rcond is omitted.
    solution = np.linalg.lstsq(A, b, rcond=None)
    H_flat = solution[0]
    print(solution)
    print(H_flat)
    H_flat = H_flat.flatten()
    # Add in h_33 which is set to 1
    H_flat = np.concatenate((H_flat, [1,]))

    H = H_flat.reshape(3, 3)

    return H