In [1]:
import cv2 as cv
import os
import shutil
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.pyplot import figure
from sklearn.cluster import AgglomerativeClustering
from scipy.cluster.hierarchy import dendrogram, linkage, fcluster, inconsistent
import ipynb.fs.defs.Utils as Utils
import ipynb.fs.defs.ImageStitcher as ImageStitcher
import ipynb.fs.defs.GraphBuilding as GraphBuilding
import ipynb.fs.defs.SimpleGraphStitching as SimpleGraphStitching
import ipynb.fs.defs.MultiGraphStitching as MultiGraphStitching

In [3]:
def get_linkage_matrix(model):
    """Build a SciPy-style linkage matrix from a fitted agglomerative model.

    Reads ``model.children_``, ``model.labels_`` and ``model.distances_``.
    Row ``k`` of the result is ``[left_child, right_child, merge_distance,
    n_leaves_under_node]``, cast to float so it can be fed to
    ``scipy.cluster.hierarchy`` (``dendrogram``, ``fcluster``).
    """
    n_leaves = len(model.labels_)
    merge_sizes = np.zeros(model.children_.shape[0])

    for step, children in enumerate(model.children_):
        # A child index below n_leaves is an original sample (size 1); any other
        # index refers to the node created at merge step (child - n_leaves),
        # whose size has already been filled in.
        merge_sizes[step] = sum(
            1 if child < n_leaves else merge_sizes[child - n_leaves]
            for child in children
        )

    columns = [model.children_, model.distances_, merge_sizes]
    return np.column_stack(columns).astype(float)

In [2]:
def compute_clustering(adj_matrix, weight_matrix, verbose = False):
    """Split the image graph into patches with connectivity-constrained clustering.

    Edge weights are turned into distances ``max(w) / w`` (non-zero entries
    only; zero entries stay 0), so heavily weighted edges become short
    distances. A full average-linkage tree is then built with
    ``AgglomerativeClustering`` restricted to ``adj_matrix`` as connectivity.
    The tree is cut into ``k`` flat clusters. ``k`` is read off SciPy's
    dendrogram colouring: one cluster per coloured sub-tree plus one singleton
    per leaf drawn in the above-threshold colour ``"C0"``.

    Parameters
    ----------
    adj_matrix : array-like (n, n)
        Connectivity passed to ``AgglomerativeClustering``.
    weight_matrix : array-like (n, n)
        Edge weights (presumably match counts -- TODO confirm); may be int.
    verbose : bool
        If True, the dendrogram is plotted.

    Returns
    -------
    numpy.ndarray
        0-based cluster label of every node.
    """
    # Copy as float: with an integer weight matrix, assigning max/w back into
    # an int array would silently truncate the distances.
    distances = np.array(weight_matrix, dtype=float)
    maxx = np.max(distances)
    mask = distances != 0
    distances[mask] = maxx / distances[mask]

    model_kwargs = dict(distance_threshold=0, n_clusters=None,
                        connectivity=adj_matrix, linkage="average")
    try:
        # scikit-learn >= 1.2 names this parameter `metric` (`affinity` was removed in 1.4).
        model = AgglomerativeClustering(metric="precomputed", **model_kwargs)
    except TypeError:
        # Older scikit-learn only accepts `affinity`.
        model = AgglomerativeClustering(affinity="precomputed", **model_kwargs)
    model.fit(distances)

    linkage_matrix = get_linkage_matrix(model)
    den = dendrogram(linkage_matrix, truncate_mode=None, no_plot = not verbose)

    colors = den["leaves_color_list"]
    above_threshold = "C0"  # SciPy's default above_threshold_color
    n_singletons = colors.count(above_threshold)
    # Leaves of one coloured sub-tree are contiguous, and successive sub-trees
    # get successive palette colours. The palette cycles, so counting distinct
    # colours undercounts beyond its length; counting runs of one colour does not.
    n_groups = sum(
        1 for pos, col in enumerate(colors)
        if col != above_threshold and (pos == 0 or colors[pos - 1] != col)
    )
    k = max(n_groups + n_singletons, 1)

    labels = fcluster(linkage_matrix, k, criterion='maxclust')
    # fcluster labels start at 1
    return labels - 1

In [4]:
def build_clusters_matrices(Z, adj_matrix, clusters, idx_ref):
    """Slice the global graph into one sub-problem per cluster.

    Parameters
    ----------
    Z : numpy.ndarray (3n, 3n)
        Block matrix with one 3x3 block per ordered image pair.
    adj_matrix : numpy.ndarray (n, n)
        Image adjacency matrix.
    clusters : array-like (n,)
        Cluster label of every image.
    idx_ref : int
        Global index of the reference image.

    Returns
    -------
    list of dict
        One dict per cluster, in ``np.unique(clusters)`` order, with keys:
        ``"cluster"`` (label), ``"Z"`` (sub-block of Z), ``"adj_matrix"``
        (induced sub-adjacency) and ``"endpoint"``. The endpoint is the local
        position of ``idx_ref`` if it belongs to the cluster, otherwise the
        local node with the largest row sum in the sub-adjacency.
    """
    clusters = np.asarray(clusters)
    clusters_matrices = list()
    for c in np.unique(clusters):
        indexes = clusters == c
        members = np.flatnonzero(indexes)
        # Every image owns 3 consecutive rows/columns of Z.
        expanded_indexes = np.repeat(indexes, 3)
        subgraph_adj_matrix = adj_matrix[np.ix_(indexes, indexes)]
        if idx_ref in members:
            endpoint = np.flatnonzero(members == idx_ref)[0]
        else:
            endpoint = np.argmax(np.sum(subgraph_adj_matrix, axis=1))
        clusters_matrices.append({
            "cluster": c,
            "Z": Z[np.ix_(expanded_indexes, expanded_indexes)],
            "adj_matrix": subgraph_adj_matrix,
            "endpoint": endpoint
        })
    return clusters_matrices

In [5]:
def intra_patches_synchronization(dataset_name, clusters_matrices, clustering, imgs, T_norm, beautify, warp_shape):
    """Stitch every patch independently, filling each cluster dict in place.

    For every entry of ``clusters_matrices`` this stores:
      - ``"M"``: the result of ``GraphBuilding.compute_M_matrix`` on the
        cluster's ``"adj_matrix"`` and ``"Z"``;
      - ``"label"``, ``"img"``, ``"Ht"``: returned by
        ``SimpleGraphStitching.simple_graph_stitching`` on the cluster's images.
        Single-image clusters are detected by ``M`` having shape (3, 3); they
        get identity homographies and the raw image instead.

    Returns None; results are written into ``clusters_matrices``.
    """
    for patch in clusters_matrices:
        patch["M"] = GraphBuilding.compute_M_matrix(patch["adj_matrix"], patch["Z"])
        members = np.where(clustering == patch["cluster"])[0]

        if patch["M"].shape == (3, 3):
            # NOTE(review): assumes M is 3k x 3k for a k-image cluster, so
            # (3, 3) means a lone image -- nothing to stitch.
            patch["img"] = imgs[members[0]]
            patch["label"] = [np.eye(3)]
            patch["Ht"] = np.eye(3)
            continue

        patch_imgs = [imgs[k] for k in members]
        patch["label"], patch["img"], patch["Ht"] = SimpleGraphStitching.simple_graph_stitching(
            dataset_name,
            patch_imgs,
            T_norm,
            patch["M"],
            patch["endpoint"],
            save_output = False,
            verbose = False,
            beautify = beautify,
            warp_shape = warp_shape
        )

In [6]:
def compute_inter_patches_dict(clusters_matrices, clustering, Z, adj_matrix):
    """Collect the relative homographies between different patches.

    For every graph edge ``(i, j)`` whose endpoints lie in different clusters
    ``c_i != c_j``, the pairwise block ``z_ij = Z[3j:3j+3, 3i:3i+3]`` is moved
    into patch coordinates as ``w_ij = x_j @ z_ij @ inv(x_i)``. Here ``x_i`` is
    image ``i``'s intra-patch homography taken from its cluster's ``"label"``
    list. ``w_ij`` is then divided by the cube root of its determinant, so its
    determinant becomes 1. A singular ``w_ij`` (det == 0) yields non-finite
    entries.

    Returns
    -------
    inter_patches_dict : dict
        ``(c_i, c_j) -> list of normalised 3x3 matrices``, one per crossing edge.
    inter_patches_adj_matrix : numpy.ndarray (k, k), int
        Number of crossing edges per ordered cluster pair.

    Assumes the cluster labels are ``0 .. k-1`` and that
    ``clusters_matrices[c]`` describes cluster ``c``.
    """
    n_clusters = len(np.unique(clustering))
    inter_patches_dict = dict()
    inter_patches_adj_matrix = np.zeros([n_clusters, n_clusters], dtype=int)
    n = clustering.shape[0]

    # Intra-patch homography of every image, looked up once instead of per edge.
    # Images are visited in increasing order, so an image's position inside its
    # cluster equals the number of earlier images with the same label.
    seen_in_cluster = dict()
    x = list()
    for i in range(n):
        c = clustering[i]
        pos = seen_in_cluster.get(c, 0)
        seen_in_cluster[c] = pos + 1
        x.append(clusters_matrices[c]["label"][pos])

    for i in range(n):
        c_i = clustering[i]
        for j in np.flatnonzero(adj_matrix[i]):
            c_j = clustering[j]
            if c_i == c_j:
                continue  # same patch: already handled by intra-patch stitching

            inter_patches_adj_matrix[c_i, c_j] += 1
            z_i_j = Z[3*j:3*(j+1), 3*i:3*(i+1)]
            w_i_j = x[j] @ z_i_j @ np.linalg.inv(x[i])
            # Scale to unit determinant (cbrt keeps the sign for negative det).
            w_i_j = w_i_j / np.cbrt(np.linalg.det(w_i_j))
            inter_patches_dict.setdefault((c_i, c_j), list()).append(w_i_j)

    return inter_patches_dict, inter_patches_adj_matrix

In [7]:
def multi_patch_stitching(dataset_name,
                        imgs,
                        T_norm,
                        Z, #Matrix containing the homographies between images
                        adj_matrix, #Adjacency matrix of the graph built according to the matches between the images
                        weight_matrix, #Weight matrix of the graph built according to the matches between the images
                        verbose = True,
                        save_output = True,
                        idx_ref = 0,
                        beautify = True,
                        stitching_dir = "stitched",
                        graph_stitching_dir = "patch_graph_stitching",
                        warp_shape = None ):
    """Two-level stitching: stitch each patch of images, then stitch the patches.

    Pipeline:
      1. ``compute_clustering`` splits the image graph into patches.
      2. ``intra_patches_synchronization`` stitches every patch on its own.
      3. ``compute_inter_patches_dict`` gathers the homographies between patches.
      4. ``GraphBuilding.build_multi_graph_matrices`` and
         ``MultiGraphStitching.multi_graph_stitching`` stitch the patch mosaics.
      5. Each image's homography is its patch homography composed with its
         intra-patch homography.

    Parameters
    ----------
    dataset_name : str
        Dataset name; also a sub-folder of ``stitching_dir``.
    imgs : sequence
        Input images.
    T_norm
        Forwarded unchanged to the stitching routines.
    Z, adj_matrix, weight_matrix
        See the comments in the signature.
    verbose : bool
        Plot the clustering dendrogram and forward verbosity to the final stitching.
    save_output : bool
        If True, ``stitching_dir/dataset_name/graph_stitching_dir`` is wiped and
        recreated, and outputs are saved.
    idx_ref : int
        Reference image; its patch is the reference of the inter-patch stitching.
    beautify : bool
        Forwarded to the stitching routines.
    stitching_dir, graph_stitching_dir : str
        Output folder components.
    warp_shape : list, optional
        Warping canvas size; ``[10000, 10000]`` when omitted.

    Returns
    -------
    H : list
        One homography per image of ``imgs``.
    stitched_image
        Final mosaic returned by ``multi_graph_stitching``.
    """
    if warp_shape is None:
        # Fresh list per call: a mutable default would be shared by every call
        # and by every callee it is handed to.
        warp_shape = [10000, 10000]

    output_dir = os.path.join(stitching_dir, dataset_name, graph_stitching_dir)
    if save_output:
        # Always start from an empty output directory.
        if os.path.isdir(output_dir):
            shutil.rmtree(output_dir)
        os.makedirs(output_dir, exist_ok=True)

    clustering = compute_clustering(adj_matrix, weight_matrix, verbose)
    clusters_matrices = build_clusters_matrices(Z, adj_matrix, clustering, idx_ref)

    intra_patches_synchronization(dataset_name,
                                  clusters_matrices,
                                  clustering,
                                  imgs,
                                  T_norm,
                                  beautify,
                                  warp_shape)

    cluster_stitched_imgs = [cm["img"] for cm in clusters_matrices]
    cluster_imgs_transl = [cm["Ht"] for cm in clusters_matrices]

    inter_patches_dict, _ = compute_inter_patches_dict(clusters_matrices,
                                                       clustering,
                                                       Z,
                                                       adj_matrix)

    M, _, C, _ = GraphBuilding.build_multi_graph_matrices(dataset_name,
                cluster_stitched_imgs,
                inter_patches_dict,
                output_dir = output_dir,
                verbose = False,
                save_output = save_output
               )

    H_inter_clusters, stitched_image, _ = MultiGraphStitching.multi_graph_stitching(dataset_name,
                            cluster_stitched_imgs,
                            T_norm,
                            M,
                            C,
                            imgs_translations = cluster_imgs_transl,
                            idx_ref = clustering[idx_ref],
                            verbose = verbose,
                            save_output = save_output,
                            beautify = beautify,
                            stitching_dir = stitching_dir,
                            graph_stitching_dir = graph_stitching_dir,
                            warp_shape = warp_shape )

    # Compose patch homography with intra-patch homography. An image's position
    # inside its cluster is the number of earlier images with the same label.
    H = list()
    seen_in_cluster = dict()
    for i in range(len(imgs)):
        c = clustering[i]
        pos = seen_in_cluster.get(c, 0)
        seen_in_cluster[c] = pos + 1
        x_i = clusters_matrices[c]["label"][pos]
        H.append(H_inter_clusters[c] @ x_i)

    return H, stitched_image

In [1]:
def average_patch_stitching(dataset_name,
                        imgs,
                        T_norm,
                        Z, #Matrix containing the homographies between images
                        adj_matrix, #Adjacency matrix of the graph built according to the matches between the images
                        weight_matrix, #Weight matrix of the graph built according to the matches between the images
                        verbose = True,
                        save_output = True,
                        idx_ref = 0,
                        beautify = True,
                        stitching_dir = "stitched",
                        graph_stitching_dir = "patch_graph_stitching",
                        warp_shape = None ):
    """Two-level stitching where the edges between patches are averaged.

    Same pipeline as the multi-graph variant, except for the inter-patch step.
    There, ``GraphBuilding.build_edge_averaging_matrices`` builds the patch
    graph from the inter-patch homographies, and
    ``SimpleGraphStitching.simple_graph_stitching`` stitches the patch mosaics.
    Each image's homography is its patch homography composed with its
    intra-patch homography.

    Parameters
    ----------
    dataset_name : str
        Dataset name; also a sub-folder of ``stitching_dir``.
    imgs : sequence
        Input images.
    T_norm
        Forwarded unchanged to the stitching routines.
    Z, adj_matrix, weight_matrix
        See the comments in the signature.
    verbose : bool
        Plot the clustering dendrogram and forward verbosity to the final stitching.
    save_output : bool
        If True, ``stitching_dir/dataset_name/graph_stitching_dir`` is wiped and
        recreated, and outputs are saved.
    idx_ref : int
        Reference image; its patch is the reference of the inter-patch stitching.
    beautify : bool
        Forwarded to the stitching routines.
    stitching_dir, graph_stitching_dir : str
        Output folder components.
    warp_shape : list, optional
        Warping canvas size; ``[10000, 10000]`` when omitted.

    Returns
    -------
    H : list
        One homography per image of ``imgs``.
    stitched_image
        Final mosaic returned by ``simple_graph_stitching``.
    """
    if warp_shape is None:
        # Fresh list per call: a mutable default would be shared by every call
        # and by every callee it is handed to.
        warp_shape = [10000, 10000]

    output_dir = os.path.join(stitching_dir, dataset_name, graph_stitching_dir)
    if save_output:
        # Always start from an empty output directory.
        if os.path.isdir(output_dir):
            shutil.rmtree(output_dir)
        os.makedirs(output_dir, exist_ok=True)

    clustering = compute_clustering(adj_matrix, weight_matrix, verbose)
    clusters_matrices = build_clusters_matrices(Z, adj_matrix, clustering, idx_ref)

    intra_patches_synchronization(dataset_name,
                                  clusters_matrices,
                                  clustering,
                                  imgs,
                                  T_norm,
                                  beautify,
                                  warp_shape)

    cluster_stitched_imgs = [cm["img"] for cm in clusters_matrices]
    cluster_imgs_transl = [cm["Ht"] for cm in clusters_matrices]

    inter_patches_dict, _ = compute_inter_patches_dict(clusters_matrices,
                                                       clustering,
                                                       Z,
                                                       adj_matrix)

    M, _, _ = GraphBuilding.build_edge_averaging_matrices(dataset_name,
                cluster_stitched_imgs,
                inter_patches_dict,
                output_dir = output_dir,
                verbose = False,
                save_output = save_output
               )

    H_inter_clusters, stitched_image, _ = SimpleGraphStitching.simple_graph_stitching(dataset_name,
                            cluster_stitched_imgs,
                            T_norm,
                            M,
                            imgs_translations = cluster_imgs_transl,
                            idx_ref = clustering[idx_ref],
                            verbose = verbose,
                            save_output = save_output,
                            beautify = beautify,
                            stitching_dir = stitching_dir,
                            graph_stitching_dir = graph_stitching_dir,
                            warp_shape = warp_shape )

    # Compose patch homography with intra-patch homography. An image's position
    # inside its cluster is the number of earlier images with the same label.
    H = list()
    seen_in_cluster = dict()
    for i in range(len(imgs)):
        c = clustering[i]
        pos = seen_in_cluster.get(c, 0)
        seen_in_cluster[c] = pos + 1
        x_i = clusters_matrices[c]["label"][pos]
        H.append(H_inter_clusters[c] @ x_i)

    return H, stitched_image