# Multi-Patch Stitching and Multi-Patch Edge-Averaging Stitching
This notebook provides the code that solves the image stitching problem with two methods: Multi-Patch graph synchronization and Multi-Patch Edge-Averaging.
The first method solves a partitioned synchronization problem. The graph is first partitioned into several clusters, and simple-graph synchronization is applied to each cluster independently. Then each cluster is condensed into a single node to build a multi-edge graph (the patch graph) whose edges are the inter-cluster homographies. Finally, multi-graph synchronization is applied to the patch graph.
The second method follows the same pipeline as Multi-Patch synchronization, except that on the patch graph edge averaging (followed by simple-graph synchronization) is applied instead of multi-graph synchronization.

## Importing libraries

In [1]:
import cv2 as cv
import os
import shutil
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.pyplot import figure
from sklearn.cluster import AgglomerativeClustering
from scipy.cluster.hierarchy import dendrogram, linkage, fcluster, inconsistent
import ipynb.fs.defs.Utils as Utils
import ipynb.fs.defs.ImageStitcher as ImageStitcher
import ipynb.fs.defs.GraphBuilding as GraphBuilding
import ipynb.fs.defs.SimpleGraphStitching as SimpleGraphStitching
import ipynb.fs.defs.MultiGraphStitching as MultiGraphStitching

## Function definitions

In [None]:
#Convert a fitted AgglomerativeClustering model into a SciPy linkage matrix
def get_linkage_matrix(model):
    """Build the SciPy-style linkage matrix of a fitted hierarchical clustering.

    Each row of the result is ``[left_child, right_child, merge_distance,
    n_leaves_under_merge]``. Node ids below ``len(model.labels_)`` are leaves;
    id ``n_leaves + r`` refers to the cluster created by merge row ``r``.

    Parameters
    ----------
    model : fitted ``AgglomerativeClustering``
        Must expose ``children_``, ``labels_`` and ``distances_`` (the latter
        is only populated when the model was fitted with a distance threshold
        or ``compute_distances=True``).

    Returns
    -------
    numpy.ndarray of float
        The linkage matrix, usable with ``scipy.cluster.hierarchy``.
    """
    n_leaves = len(model.labels_)
    sizes = np.zeros(model.children_.shape[0])
    # Merges are listed bottom-up, so the size of any internal child has
    # already been filled in when it is referenced.
    for row, pair in enumerate(model.children_):
        sizes[row] = sum(1 if node < n_leaves else sizes[node - n_leaves] for node in pair)

    return np.column_stack([model.children_, model.distances_, sizes]).astype(float)

In [None]:
#This function allows to perform agglomerative clustering on the provided graph
def compute_clustering(adj_matrix, #Adjacency matrix of the graph to be clustered
                       weight_matrix, #Weight matrix of the graph to be clustered
                       verbose = False #If True allows to plot the dendrogram
                      ):
    """Partition the nodes of a graph into clusters.

    Edge weights are converted into distances: every non-zero weight ``w``
    becomes ``max_weight / w`` (the strongest edge gets distance 1, weaker
    edges get larger distances); zero entries stay 0. A full average-linkage
    hierarchy is then fitted under the graph's connectivity constraint, the
    number of clusters is read off the dendrogram colouring, and the tree is
    cut into that many flat clusters.

    Parameters
    ----------
    adj_matrix : array-like, shape (n, n)
        Adjacency matrix, used as the connectivity constraint of the clustering.
    weight_matrix : array-like, shape (n, n)
        Edge weights (similarities); converted to distances as described above.
    verbose : bool
        If True the dendrogram is drawn.

    Returns
    -------
    numpy.ndarray
        Zero-based cluster label of every node (``fcluster`` labels minus 1).
    """
    # Work in float: with an integer weight matrix the in-place assignment
    # below would silently truncate the max/weight ratios.
    distances = np.array(weight_matrix, dtype=float)
    max_weight = np.max(distances)
    mask = distances != 0
    distances[mask] = max_weight / distances[mask]
    # NOTE(review): non-adjacent pairs keep distance 0 and are still averaged
    # in by average linkage on a precomputed matrix - confirm this is intended.

    common_params = dict(distance_threshold=0, n_clusters=None,
                         connectivity=adj_matrix, linkage="average")
    try:
        # scikit-learn >= 1.2 names the distance argument ``metric``
        # (``affinity`` was deprecated in 1.2 and removed in 1.4).
        model = AgglomerativeClustering(metric="precomputed", **common_params)
    except TypeError:
        # Older scikit-learn versions only accept ``affinity``.
        model = AgglomerativeClustering(affinity="precomputed", **common_params)
    model.fit(distances)

    linkage_matrix = get_linkage_matrix(model)
    den = dendrogram(linkage_matrix, truncate_mode=None, no_plot=not verbose)

    # Number of clusters: one per distinct leaf colour, plus one extra for every
    # additional leaf coloured 'C0' (presumably SciPy's above-threshold colour,
    # i.e. leaves not grouped below the colour threshold, each a singleton).
    # NOTE(review): SciPy reuses palette colours when there are many clusters,
    # so distinct colours may undercount them - confirm on large graphs.
    leaf_colors = den["leaves_color_list"]
    k = len(dict.fromkeys(leaf_colors)) + max(leaf_colors.count("C0") - 1, 0)

    labels = fcluster(linkage_matrix, k, criterion="maxclust")
    return labels - 1

In [None]:
#This function allows to compute the matrices needed for graph synchronization for each cluster
def build_clusters_matrices(Z, #Matrix containing the homographies between the images
                            adj_matrix, #Adjacency matrix of the graph
                            clusters, #Clustering
                            idx_ref #Index of the reference image
                           ):
    """Slice the per-cluster sub-problems out of the full synchronization data.

    Parameters
    ----------
    Z : numpy.ndarray, shape (3n, 3n)
        Block matrix of pairwise homographies; image ``i`` owns rows/columns
        ``3i:3i+3``.
    adj_matrix : numpy.ndarray, shape (n, n)
        Adjacency matrix of the image graph.
    clusters : array-like of int, length n
        Cluster label of every image.
    idx_ref : int
        Index of the global reference image.

    Returns
    -------
    list of dict
        One dict per cluster, in ``np.unique(clusters)`` order, with keys
        ``"cluster"`` (label), ``"Z"`` (the cluster's sub-block of Z),
        ``"adj_matrix"`` (the induced sub-graph adjacency) and ``"endpoint"``
        (reference index *within* the cluster: the position of ``idx_ref`` if
        it belongs to the cluster, otherwise the highest-degree node of the
        sub-graph).
    """
    clusters = np.asarray(clusters)
    clusters_matrices = list()
    for c in np.unique(clusters):
        indexes = clusters == c
        members = np.flatnonzero(indexes)
        # Every image owns a 3x3 block of Z, so each node flag is repeated 3 times
        expanded_indexes = np.repeat(indexes, 3)

        subgraph_adj_matrix = adj_matrix[np.ix_(indexes, indexes)]

        # The cluster containing the global reference keeps it as its own
        # reference, so that cluster's frame coincides with the global one.
        if idx_ref in members:
            endpoint = np.flatnonzero(members == idx_ref)[0]
        else:
            endpoint = np.argmax(np.sum(subgraph_adj_matrix, axis=1))

        clusters_matrices.append({
            "cluster": c,
            "Z": Z[np.ix_(expanded_indexes, expanded_indexes)],
            "adj_matrix": subgraph_adj_matrix,
            "endpoint": endpoint
        })
    return clusters_matrices

In [None]:
#Run simple-graph synchronization independently inside every cluster (patch)
def intra_patches_synchronization(dataset_name, #Name of the dataset to be used
                                  clusters_matrices, #Matrices needed for synchronization of each cluster
                                  clustering, #Clustering of the graph
                                  imgs, #Images to be stitched
                                  T_norm, #Normalization matrix
                                  beautify, #If True allows to print the stitched image in a better way
                                  warp_shape #Shape of the stitched image
                                 ):
    """Synchronize and stitch each cluster on its own, updating the dicts in place.

    For every entry of ``clusters_matrices`` the keys ``"M"``, ``"label"``,
    ``"img"`` and ``"Ht"`` are filled in. Clusters whose M matrix is 3x3 (a
    single image) are not synchronized: the image is taken as-is, its label is
    ``[identity]`` and its translation is the identity. Nothing is returned.
    """
    for patch in clusters_matrices:
        members = np.flatnonzero(clustering == patch["cluster"])
        patch["M"] = GraphBuilding.compute_M_matrix(patch["adj_matrix"], patch["Z"])

        # Single-image cluster: nothing to synchronize
        if patch["M"].shape == (3, 3):
            patch["img"] = imgs[members[0]]
            patch["label"] = [np.eye(3)]
            patch["Ht"] = np.eye(3)
            continue

        patch_imgs = [imgs[k] for k in members]
        patch["label"], patch["img"], patch["Ht"] = SimpleGraphStitching.simple_graph_stitching(
            dataset_name,
            patch_imgs,
            T_norm,
            patch["M"],
            patch["endpoint"],
            save_output=False,
            verbose=False,
            beautify=beautify,
            warp_shape=warp_shape,
        )

In [None]:
#This function allows to compute the inter-cluster homographies 
def compute_inter_patches_dict(clusters_matrices, #Matrices related to each cluster
                               clustering, #Clustering of the graph
                               Z, #Matrix containing the homographies between the images
                               adj_matrix #Adjacency matrix of the graph
                              ):
    """Collect the homographies of every edge that crosses two clusters.

    For each ordered pair of images ``(i, j)`` adjacent in ``adj_matrix`` whose
    clusters differ (``c_i != c_j``), the block ``Z[3j:3j+3, 3i:3i+3]`` is
    re-expressed through the intra-cluster labels as ``x_j @ z @ inv(x_i)`` and
    scaled to unit determinant.

    Parameters
    ----------
    clusters_matrices : list of dict
        Indexed by cluster label; ``["label"]`` holds the intra-cluster states
        of the cluster's images in increasing image-index order.
    clustering : numpy.ndarray of int, length n
        Cluster label of every image.
    Z : numpy.ndarray, shape (3n, 3n)
        Block matrix of pairwise homographies.
    adj_matrix : numpy.ndarray, shape (n, n)
        Adjacency matrix of the image graph.

    Returns
    -------
    inter_patches_dict : dict
        ``(c_i, c_j)`` -> list of the normalized 3x3 homographies of all edges
        from cluster ``c_i`` to cluster ``c_j``, in increasing ``(i, j)`` order.
    inter_patches_adj_matrix : numpy.ndarray of int, shape (k, k)
        Entry ``[c_i, c_j]`` counts those edges (multi-edge degree).
    """
    clustering = np.asarray(clustering)
    n = clustering.shape[0]
    n_clusters = len(np.unique(clustering))
    inter_patches_dict = dict()
    inter_patches_adj_matrix = np.zeros([n_clusters, n_clusters], dtype=int)

    # Look up every image's intra-cluster state once. Labels of a cluster are
    # ordered by image index, so image i sits at position "number of earlier
    # images of the same cluster".
    states = list()
    seen_per_cluster = dict()
    for c in clustering:
        position = seen_per_cluster.get(c, 0)
        states.append(clusters_matrices[c]["label"][position])
        seen_per_cluster[c] = position + 1

    for i in range(n):
        c_i = clustering[i]
        # Neighbours of i living in another cluster, in increasing index order
        crossing = [j for j in np.flatnonzero(adj_matrix[i] != 0) if clustering[j] != c_i]
        if not crossing:
            continue
        # Shared by every edge leaving i, so invert only once
        x_i_inv = np.linalg.inv(states[i])

        for j in crossing:
            c_j = clustering[j]
            inter_patches_adj_matrix[c_i, c_j] += 1

            z_i_j = Z[3*j:3*(j+1), 3*i:3*(i+1)]
            w_i_j = states[j] @ z_i_j @ x_i_inv
            # Normalize so that the determinant is 1 (cbrt keeps the sign of a negative determinant)
            w_i_j = w_i_j / np.cbrt(np.linalg.det(w_i_j))

            inter_patches_dict.setdefault((c_i, c_j), list()).append(w_i_j)

    return inter_patches_dict, inter_patches_adj_matrix

In [None]:
#This function allows to apply the Multi-Patch graph synchronization method
def multi_patch_stitching(dataset_name, #Name of the dataset to be used
                        imgs, #Images to be stitched
                        T_norm, #Normalization matrix
                        Z, #Matrix containing the homographies between images
                        adj_matrix, #Adjacency matrix of the graph built according to the matches between the images
                        weight_matrix, #Weight matrix of the graph built according to the matches between the images
                        verbose = True, #If True allows to print intermediate results
                        save_output = True, #If True allows to save the output
                        idx_ref = 0, #Index of the reference image
                        beautify = True, #If True allows to print the stitched image in a better way
                        stitching_dir = "stitched", #Directory where to save the stitched image
                        graph_stitching_dir = "patch_graph_stitching", #Directory where to save the patch graph
                        warp_shape = None #Shape of the stitched image (None -> [10000,10000])
                        ):
    """Stitch ``imgs`` with the Multi-Patch graph synchronization method.

    Pipeline: cluster the image graph, run simple-graph synchronization inside
    each cluster, collapse each cluster into one node of a multi-edge patch
    graph (edges = inter-cluster homographies), then run multi-graph
    synchronization on the patch graph. The cluster containing ``idx_ref`` is
    used as the reference patch.

    If ``save_output`` is True, ``stitching_dir/dataset_name/graph_stitching_dir``
    is deleted and recreated before anything is written.

    Returns
    -------
    H : list of numpy.ndarray
        One matrix per image: its cluster's patch-graph state composed with its
        intra-cluster label (``H_inter_clusters[c] @ x_i``).
    stitched_image
        The stitched image returned by ``MultiGraphStitching.multi_graph_stitching``.
    """
    # A fresh list per call: a mutable default would be shared between calls
    # and it is handed to downstream functions that could modify it.
    if warp_shape is None:
        warp_shape = [10000, 10000]

    output_dir = os.path.join(stitching_dir, dataset_name, graph_stitching_dir)
    if save_output:
        # Always start from an empty output directory
        if os.path.isdir(output_dir):
            shutil.rmtree(output_dir)
        os.makedirs(output_dir)

    #Cluster the graph and synchronize each cluster on its own
    clustering = compute_clustering(adj_matrix, weight_matrix, verbose)
    clusters_matrices = build_clusters_matrices(Z, adj_matrix, clustering, idx_ref)
    intra_patches_synchronization(dataset_name,
                                  clusters_matrices,
                                  clustering,
                                  imgs,
                                  T_norm,
                                  beautify,
                                  warp_shape)

    #Stitched image and translation matrix of each cluster
    cluster_stitched_imgs = [cm["img"] for cm in clusters_matrices]
    cluster_imgs_transl = [cm["Ht"] for cm in clusters_matrices]

    #Inter-cluster homographies (the multi-edges of the patch graph)
    inter_patches_dict, _ = compute_inter_patches_dict(clusters_matrices,
                                                       clustering,
                                                       Z,
                                                       adj_matrix)

    #Matrices needed to apply graph synchronization on the patch graph
    M, _, C, _ = GraphBuilding.build_multi_graph_matrices(dataset_name,
                cluster_stitched_imgs,
                inter_patches_dict,
                output_dir = output_dir,
                verbose = False,
                save_output = save_output
               )

    #Multi-graph synchronization of the patch graph
    H_inter_clusters, stitched_image, _ = MultiGraphStitching.multi_graph_stitching(dataset_name,
                            cluster_stitched_imgs,
                            T_norm,
                            M,
                            C,
                            imgs_translations = cluster_imgs_transl,
                            idx_ref = clustering[idx_ref],
                            verbose = verbose,
                            save_output = save_output,
                            beautify = beautify,
                            stitching_dir = stitching_dir,
                            graph_stitching_dir = graph_stitching_dir,
                            warp_shape = warp_shape )

    #Global state of each image = patch state composed with its intra-cluster label
    H = list()
    for i in range(len(imgs)):
        c = clustering[i]
        # Labels are ordered by image index inside the cluster, so image i is at
        # the position given by the number of earlier images of the same cluster
        position = np.count_nonzero(clustering[:i] == c)
        H.append(H_inter_clusters[c] @ clusters_matrices[c]["label"][position])

    return H, stitched_image

In [None]:
#This function allows to apply the Multi-Patch Edge-Averaging method
def average_patch_stitching(dataset_name, #Name of the dataset to be used
                        imgs, #Images to be stitched
                        T_norm, #Normalization matrix
                        Z, #Matrix containing the homographies between images
                        adj_matrix, #Adjacency matrix of the graph built according to the matches between the images
                        weight_matrix, #Weight matrix of the graph built according to the matches between the images
                        verbose = True, #If True prints intermediate results
                        save_output = True, #If True saves output
                        idx_ref = 0, #Index of the reference image
                        beautify = True, #If True allows to print the stitched image in a better way
                        stitching_dir = "stitched",#Directory where to save the stitched image
                        graph_stitching_dir = "edge_avg_patch_graph_stitching", #Directory where to save the patch graph
                        warp_shape = None #Shape of the stitched image (None -> [10000,10000])
                        ):
    """Stitch ``imgs`` with the Multi-Patch Edge-Averaging method.

    Same pipeline as Multi-Patch synchronization (cluster the image graph,
    synchronize inside each cluster, collapse clusters into a patch graph whose
    multi-edges are the inter-cluster homographies), except that the patch
    graph's multi-edges are reduced with edge averaging
    (``GraphBuilding.build_edge_averaging_matrices``) and the result is solved
    with simple-graph synchronization. The cluster containing ``idx_ref`` is
    used as the reference patch.

    If ``save_output`` is True, ``stitching_dir/dataset_name/graph_stitching_dir``
    is deleted and recreated before anything is written.

    Returns
    -------
    H : list of numpy.ndarray
        One matrix per image: its cluster's patch-graph state composed with its
        intra-cluster label (``H_inter_clusters[c] @ x_i``).
    stitched_image
        The stitched image returned by ``SimpleGraphStitching.simple_graph_stitching``.
    """
    # A fresh list per call: a mutable default would be shared between calls
    # and it is handed to downstream functions that could modify it.
    if warp_shape is None:
        warp_shape = [10000, 10000]

    output_dir = os.path.join(stitching_dir, dataset_name, graph_stitching_dir)
    if save_output:
        # Always start from an empty output directory
        if os.path.isdir(output_dir):
            shutil.rmtree(output_dir)
        os.makedirs(output_dir)

    #Cluster the graph and synchronize each cluster on its own
    clustering = compute_clustering(adj_matrix, weight_matrix, verbose)
    clusters_matrices = build_clusters_matrices(Z, adj_matrix, clustering, idx_ref)
    intra_patches_synchronization(dataset_name,
                                  clusters_matrices,
                                  clustering,
                                  imgs,
                                  T_norm,
                                  beautify,
                                  warp_shape)

    #Stitched image and translation matrix of each cluster
    cluster_stitched_imgs = [cm["img"] for cm in clusters_matrices]
    cluster_imgs_transl = [cm["Ht"] for cm in clusters_matrices]

    #Inter-cluster homographies (the multi-edges of the patch graph)
    inter_patches_dict, _ = compute_inter_patches_dict(clusters_matrices,
                                                       clustering,
                                                       Z,
                                                       adj_matrix)

    #Matrices needed to apply edge averaging on the patch graph
    M, _ , _ = GraphBuilding.build_edge_averaging_matrices(dataset_name,
                cluster_stitched_imgs,
                inter_patches_dict,
                output_dir = output_dir,
                verbose = False,
                save_output = save_output
               )

    #Simple-graph synchronization of the edge-averaged patch graph
    H_inter_clusters, stitched_image, _ = SimpleGraphStitching.simple_graph_stitching(dataset_name,
                            cluster_stitched_imgs,
                            T_norm,
                            M,
                            imgs_translations = cluster_imgs_transl,
                            idx_ref = clustering[idx_ref],
                            verbose = verbose,
                            save_output = save_output,
                            beautify = beautify,
                            stitching_dir = stitching_dir,
                            graph_stitching_dir = graph_stitching_dir,
                            warp_shape = warp_shape )

    #Global state of each image = patch state composed with its intra-cluster label
    H = list()
    for i in range(len(imgs)):
        c = clustering[i]
        # Labels are ordered by image index inside the cluster, so image i is at
        # the position given by the number of earlier images of the same cluster
        position = np.count_nonzero(clustering[:i] == c)
        H.append(H_inter_clusters[c] @ clusters_matrices[c]["label"][position])

    return H, stitched_image