In [1]:
import cv2 as cv
import os
import shutil
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.pyplot import figure
import ipynb.fs.defs.Utils as Utils
import ipynb.fs.defs.ImageStitcher as ImageStitcher

In [None]:
def get_clusters(adj_matrix, weight_matrix):
    """Group the graph nodes into clusters with agglomerative clustering.

    The weights are converted to distances (``max_weight / weight``; entries
    equal to zero are left at zero), clustered with average linkage restricted
    by ``adj_matrix``, and the resulting tree is cut into flat clusters.

    Args:
        adj_matrix: Connectivity matrix handed to ``AgglomerativeClustering``.
        weight_matrix: Matrix of pairwise weights between nodes. It is not
            modified; any array-like of numbers is accepted.

    Returns:
        Array with one zero-based cluster label per node.
    """
    # NOTE(review): AgglomerativeClustering, dendrogram and fcluster are not
    # imported in the notebook's visible import cell -- confirm they are in scope.

    # Work on a float copy: with an integer weight matrix the in-place
    # assignment below would silently truncate every ratio to an integer.
    distances = np.array(weight_matrix, dtype=float)
    maxx = np.max(distances)
    mask = distances != 0
    # Strong links (large weight) become short distances; zeros are untouched
    # to avoid dividing by zero.
    distances[mask] = maxx / distances[mask]

    # distance_threshold=0 with n_clusters=None builds the full merge tree,
    # which is what populates `distances_` for the linkage matrix.
    # NOTE(review): newer scikit-learn releases rename `affinity` to `metric`;
    # confirm against the installed version.
    clustering = AgglomerativeClustering(
        distance_threshold=0,
        n_clusters=None,
        affinity="precomputed",
        connectivity=adj_matrix,
        linkage="average",
    ).fit(distances)
    linkage_matrix = get_linkage_matrix(clustering)

    # dendrogram() also draws the tree on the current matplotlib axes.
    den = dendrogram(linkage_matrix, truncate_mode=None)

    # Number of flat clusters = number of distinct leaf colours, plus one extra
    # cluster for every additional 'C0' leaf ('C0' is dendrogram's default
    # colour for links above the colour threshold, so those leaves are
    # presumably meant to stand alone -- TODO confirm).
    leaf_colors = den["leaves_color_list"]
    k = len(list(dict.fromkeys(leaf_colors))) + max(leaf_colors.count('C0') - 1, 0)
    cluster = fcluster(linkage_matrix, k, criterion='maxclust')

    # fcluster labels start at 1; shift them to start at 0.
    return cluster - 1

In [None]:
def get_linkage_matrix(model):
    """Build a SciPy-style linkage matrix from a fitted clustering model.

    Args:
        model: Fitted object exposing ``children_`` (one row of child node ids
            per merge), ``distances_`` (one value per merge) and ``labels_``
            (one entry per original sample).

    Returns:
        Float array with one row per merge: the child ids, the merge distance
        and the number of original samples under the merged node.
    """
    merges = model.children_
    leaf_total = len(model.labels_)

    # Number of original samples contained in the node created by each merge.
    subtree_sizes = np.zeros(merges.shape[0])
    for row, pair in enumerate(merges):
        # Ids below `leaf_total` are leaves (one sample each); larger ids refer
        # to the node produced by merge number `id - leaf_total`.
        subtree_sizes[row] = sum(
            1 if node < leaf_total else subtree_sizes[node - leaf_total]
            for node in pair
        )

    columns = (merges, model.distances_, subtree_sizes)
    return np.column_stack(columns).astype(float)

In [None]:
def split_Matrices(Z,adj_matrix, clusters):
    """Split the global matrices into one set of sub-matrices per cluster.

    Args:
        Z: Block matrix in which every node owns a 3x3 block, i.e. node ``i``
            covers rows/columns ``3*i`` to ``3*i + 2``.
        adj_matrix: Adjacency matrix of the whole graph (one row/column per node).
        clusters: Cluster label of every node (array or list).

    Returns:
        List with one dict per distinct label, in ascending label order:
            "c": the cluster label,
            "H": the blocks of ``Z`` restricted to the cluster's nodes,
            "adj_matrix": the adjacency matrix restricted to the cluster,
            "endpoint": index, within the cluster, of the node whose row of
                the restricted adjacency matrix has the largest sum.
    """
    # Accept plain lists too; `clusters == c` must yield a per-node mask.
    clusters = np.asarray(clusters)
    new_matrices = []
    for c in np.unique(clusters):
        indexes = clusters == c
        # Each node spans three consecutive rows/columns of Z, so every mask
        # entry is repeated three times.
        expanded_indexes = np.repeat(indexes, 3)
        subgraph_adj_matrix = adj_matrix[np.ix_(indexes, indexes)]
        new_matrices.append({"c": c,
                             "H": Z[np.ix_(expanded_indexes, expanded_indexes)],
                             "adj_matrix": subgraph_adj_matrix,
                             "endpoint": np.argmax(np.sum(subgraph_adj_matrix, axis=1))})
    return new_matrices

In [None]:
def multi_graph_stitching(dataset_name,
                        imgs,
                        T_norm,
                        Z, #Matrix containing the homographies between images
                        adj_matrix, #Adjacency matrix of the graph built according to the matches between the images
                        weight_matrix, #Weight matrix of the graph built according to the matches between the images
                        verbose = True,
                        save_output = True,
                        beautify = True,
                        stitching_dir = "stitched",
                        multi_patch_stitching_dir = "multi_graph_stitching",
                        warp_shape = [10000,10000] ):
    """Cluster the image graph and stitch every cluster separately.

    Args:
        dataset_name: Name of the dataset; used as a sub-directory of
            ``stitching_dir`` and forwarded to ``simple_graph_stitching``.
        imgs: Sequence of images, indexed consistently with the graph nodes.
        T_norm: Forwarded unchanged to ``simple_graph_stitching``.
        Z: Matrix containing the homographies between images (one 3x3 block
            per pair of nodes).
        adj_matrix: Adjacency matrix of the match graph.
        weight_matrix: Weight matrix of the match graph.
        verbose: Currently unused.
        save_output: When true, the output directory is wiped and recreated.
        beautify: Forwarded to ``stitch_images``.
        stitching_dir: Root output directory.
        multi_patch_stitching_dir: Sub-directory for this stitching mode.
        warp_shape: Currently unused.

    Returns:
        The list of per-cluster dicts produced by ``split_Matrices``, each
        extended with "M", "img" and "label". (Previously nothing was returned
        and the results were discarded.)
    """
    # NOTE(review): get_states, stitch_images and simple_graph_stitching are
    # not defined in the visible cells -- confirm they are in scope.
    output_dir = os.path.join(stitching_dir, dataset_name, multi_patch_stitching_dir)
    if save_output:
        # Start from an empty output directory.
        if os.path.isdir(output_dir):
            shutil.rmtree(output_dir)
        os.makedirs(output_dir)

    clusters = get_clusters(adj_matrix, weight_matrix)
    new_matrices = split_Matrices(Z, adj_matrix, clusters)

    for matrices in new_matrices:
        # Positions (in `imgs`) of the images that belong to this cluster.
        member_idxs = np.where(clusters == matrices["c"])[0]

        # Column sums of the cluster's adjacency matrix (one value per node).
        zeta = np.sum(matrices["adj_matrix"], axis=0)
        M = np.copy(matrices["H"])
        for i in range(zeta.shape[0]):
            # Overwrite the i-th diagonal 3x3 block with (1 - zeta[i]) * I.
            M[3*i:3*(i+1), 3*i:3*(i+1)] = np.eye(3, 3) * (1 - zeta[i])
        matrices["M"] = M

        if M.shape != (3, 3):
            u, d, vh = np.linalg.svd(M)  # Application of SVD to the M matrix
            v = vh.transpose()  # Right singular vectors as columns
            # Keep the three right singular vectors associated with the
            # 3 smallest singular values (so the last 3 columns).
            u_hat = v[:, [-1, -2, -3]]
            U = get_states(u_hat)  # Get the state of each node
            imgs_slice = [imgs[i] for i in member_idxs]
            # Stitch all the images of the cluster
            matrices["label"], matrices["img"] = stitch_images(
                U, imgs_slice, matrices["endpoint"], len(imgs_slice), beautify=beautify)

            simple_graph_stitching(dataset_name,
                        imgs_slice,
                        T_norm,
                        M,
                        matrices["endpoint"],
                        save_output=False
                        )
        else:
            # A 3x3 matrix means the cluster holds a single image: nothing to
            # stitch, keep the image with an identity transform.
            matrices["img"] = imgs[member_idxs[0]]
            matrices["label"] = [np.eye(3)]

    return new_matrices
    

In [None]:
# NOTE(review): this cell is a bare call to simple_graph_stitching, which is
# not defined in the visible cells; it reads like a copy of that function's
# signature kept for reference. None of dataset_name, imgs, T_norm or M is
# assigned at notebook level in the visible cells, so confirm they are in
# scope before running this cell, or remove it.
simple_graph_stitching(dataset_name,
                            imgs,
                            T_norm,
                            M, 
                            idx_ref = 0,
                            idxs = None,
                            verbose = True,
                            save_output = True,
                            beautify = True,
                            stitching_dir = "stitched",
                            graph_stitching_dir = "simple_graph_stitching",
                            warp_shape = [10000,10000] )