# Utils
This notebook contains utility functions used to implement the different image-stitching mechanisms.

## Importing libraries

In [None]:
import os
import cv2 as cv
import networkx as nx
import matplotlib.pyplot as plt
from matplotlib.pyplot import figure
import numpy as np

## Function definitions

In [None]:
#This function loads images from a specific folder
def load_images_from_folder(folder):
    """Load every readable image in ``folder`` and return them as RGB arrays.

    Entries are visited in sorted filename order so that the position of each
    image in the returned list is reproducible; ``os.listdir`` on its own
    returns entries in an arbitrary order. (``build_graph`` uses the list
    position as the node id, so a stable order matters.)

    Args:
        folder: Path of the directory to scan (not recursive).

    Returns:
        A list of images converted from OpenCV's BGR channel order to RGB.
        Entries that ``cv.imread`` cannot decode are skipped.
    """
    images = []
    for filename in sorted(os.listdir(folder)):
        img = cv.imread(os.path.join(folder, filename))
        # cv.imread signals failure by returning None (e.g. non-image files or
        # sub-directories). Skip those *before* converting, because
        # cv.cvtColor raises on a None input.
        if img is None:
            continue
        images.append(cv.cvtColor(img, cv.COLOR_BGR2RGB))
    return images

In [None]:
#This function allows to build a weighted graph starting from the adjacency matrix
def build_graph(imgs, #Images associated to nodes
                adj_matrix, #Adjacency matrix
                weight_matrix #Weight matrix
               ):
    """Build an undirected weighted graph whose nodes carry the input images.

    Node ``k`` stores ``imgs[k]`` under the ``image`` attribute. An edge
    ``(i, j)`` is created whenever ``adj_matrix[i, j] == 1``. Because the
    graph is undirected, a 1 in either ``[i, j]`` or ``[j, i]`` is enough.
    The edge ``weight`` is ``max(weight_matrix[i, j], weight_matrix[j, i])``.

    Returns:
        The constructed ``networkx.Graph``.
    """
    graph = nx.Graph()
    size = len(imgs)

    # One node per image, keyed by its position in ``imgs``.
    for node_id, image in enumerate(imgs):
        graph.add_node(node_id, image=image)

    # Re-adding (j, i) after (i, j) just overwrites the same undirected edge
    # with the same symmetric weight, so scanning the full matrix is safe.
    weighted_edges = (
        (row, col, max(weight_matrix[row, col], weight_matrix[col, row]))
        for row in range(size)
        for col in range(size)
        if adj_matrix[row, col] == 1
    )
    graph.add_weighted_edges_from(weighted_edges, weight="weight")
    return graph

In [None]:
#This function allows to print a graph together with the images
def advanced_print_graph(imgs, #Unused; kept so existing callers keep working
                         G, #Graph to be printed
                         save_images = True, #If True, saves the graph
                         graph_dir = "output", #Folder where to save the graph
                         graph_name = "graph" #Graph name
                        ):
    """Draw ``G`` on a circular layout and show each node's image next to it.

    Every node must carry an ``image`` attribute (as set by ``build_graph``).
    Edges are labelled with their ``weight`` attribute.

    Args:
        imgs: Ignored. The images are read from the node attributes instead.
        G: Graph to draw.
        save_images: If True, write the figure to ``graph_dir/graph_name.png``.
            ``graph_dir`` is created if it does not exist.
        graph_dir: Output directory.
        graph_name: Output file name, without the extension.

    Returns:
        ``G``, unchanged.
    """
    pos = nx.circular_layout(G)
    fig = plt.figure(figsize=(15, 15))
    ax = plt.subplot(111)
    ax.set_aspect('equal')
    nx.draw(G, pos, ax=ax, width=3, node_size=900, with_labels=True,
            edgecolors='red', node_color='lightgray', connectionstyle='arc3, rad = 0.1')

    labels = nx.get_edge_attributes(G, 'weight')
    nx.draw_networkx_edge_labels(G, pos, edge_labels=labels)
    # Fix the data limits before reading ax.transData so the node positions
    # map to stable pixel coordinates.
    plt.xlim(-1.5, 1.5)
    plt.ylim(-1.5, 1.5)

    to_display = ax.transData.transform                # data -> display (pixel) coordinates
    to_figure = fig.transFigure.inverted().transform   # display -> figure-fraction coordinates
    piesize = 0.08  # side of each image inset, as a fraction of the figure
    # Offset of the inset's lower-left corner from the node centre. Because
    # p2 > piesize, each image sits just below and to the left of its node
    # instead of covering the node marker and label.
    p2 = piesize / 0.9

    for g in G:
        xx, yy = to_display(pos[g])
        xa, ya = to_figure((xx, yy))
        a = plt.axes([xa - p2, ya - p2, piesize, piesize])
        a.set_aspect('equal')
        a.imshow(G.nodes[g]['image'])
        a.axis('off')
    ax.axis('off')

    #If required save the resulting graph
    if save_images:
        if graph_dir:
            os.makedirs(graph_dir, exist_ok=True)
        fig.savefig(os.path.join(graph_dir, f'{graph_name}.png'))
    return G

In [None]:
#This function allows to build and print a multi-graph
def build_and_print_multi_graph(adj_matrix,
                                save_images = True, #If True, saves the graph
                                graph_dir = "output", #Folder where to save the graph
                                graph_name = "graph" #Graph name
                               ):
    """Build a directed multigraph from ``adj_matrix``, draw it, and optionally save it.

    With ``parallel_edges=True`` and a ``MultiDiGraph``, NetworkX reads an
    integer entry ``adj_matrix[i, j]`` as the number of parallel edges from
    node ``i`` to node ``j``.

    Args:
        adj_matrix: Square (n x n) adjacency matrix.
        save_images: If True, write the figure to ``graph_dir/graph_name.png``.
            ``graph_dir`` is created if it does not exist.
        graph_dir: Output directory.
        graph_name: Output file name, without the extension.

    Returns:
        The constructed ``networkx.MultiDiGraph``.
    """
    # nx.from_numpy_matrix was removed in NetworkX 3.0. from_numpy_array
    # (available since NetworkX 2.0) accepts the same arguments used here.
    graph = nx.from_numpy_array(adj_matrix, parallel_edges=True, create_using=nx.MultiDiGraph)

    pos = nx.circular_layout(graph)
    fig = plt.figure(figsize=(10, 10))
    nx.draw(graph, pos=pos, with_labels=True, connectionstyle='arc3, rad = 0.1')

    # Save before plt.show(): some backends (e.g. Jupyter's inline backend)
    # close the figure once it has been displayed.
    if save_images:
        if graph_dir:
            os.makedirs(graph_dir, exist_ok=True)
        fig.savefig(os.path.join(graph_dir, f'{graph_name}.png'))
    plt.show()
    return graph

In [None]:
#This function allows to print a graph
def print_graph(graph, #Graph to be printed
                save_images = True, #If True, saves the graph
                graph_dir = "output", #Folder where to save the graph
                graph_name = "graph" #Graph name
               ):
    """Draw ``graph`` with a spring layout and optionally save the figure.

    Args:
        graph: NetworkX graph to draw.
        save_images: If True, write the figure to ``graph_dir/graph_name.png``.
            ``graph_dir`` is created if it does not exist.
        graph_dir: Output directory.
        graph_name: Output file name, without the extension.
    """
    pos = nx.spring_layout(graph)
    fig = plt.figure(figsize=(10, 10))
    ax = plt.subplot(111)
    ax.set_aspect('equal')
    nx.draw(graph, pos, ax=ax, width=3, node_size=900, with_labels=True,
            edgecolors='red', node_color='lightgray')

    # Save before plt.show(): some backends (e.g. Jupyter's inline backend)
    # close the figure once it has been displayed.
    if save_images:
        if graph_dir:
            os.makedirs(graph_dir, exist_ok=True)
        fig.savefig(os.path.join(graph_dir, f'{graph_name}.png'))
    plt.show()

In [None]:
#This function allows to retrieve the homographies w.r.t. the reference index starting from the state matrix
def get_homographies_from_states(U, idx_ref=0):
    """Express every state matrix relative to the reference state.

    For each ``U_i`` in ``U``, the result holds ``U[idx_ref] @ inv(U_i)``.
    The entry for ``idx_ref`` itself is therefore the identity, up to
    floating-point error.

    Args:
        U: Sequence of invertible square matrices (one per image).
        idx_ref: Index of the reference matrix in ``U``.

    Returns:
        A list with one matrix per element of ``U``, in the same order.
    """
    def relative_to_reference(state):
        return np.dot(U[idx_ref], np.linalg.inv(state))

    return list(map(relative_to_reference, U))

In [None]:
#This function allows to compute the reference image/node
def get_reference_node(adj_matrix):
    """Return the index of the column of ``adj_matrix`` with the largest sum.

    If ``adj_matrix[i, j]`` marks an edge into node ``j``, this is the node
    with the most incoming connections. For a symmetric matrix it is simply
    the highest-degree node. Ties resolve to the lowest index.
    """
    column_totals = np.asarray(adj_matrix).sum(axis=0)
    return column_totals.argmax()

In [None]:
#This function splits a stacked (3n x m) array into its n consecutive 3-row blocks
def split_states(x):
    """Split a vertically stacked array into consecutive blocks of 3 rows.

    For a 2-D ``x`` of shape (3n, m), the function returns
    ``[x[0:3, :], x[3:6, :], ...]``, i.e. n arrays of shape (3, m). With
    m == 3, this recovers the individual 3x3 matrices of a (3n x 3) stack.
    Any trailing rows that do not fill a full block of 3 are dropped.
    """
    # x.T.shape[1] is the row count of x (and, like the original, fails for 1-D input).
    n_blocks = x.T.shape[1] // 3
    return [x[3 * k:3 * k + 3, :] for k in range(n_blocks)]

In [None]:
#This function allows to compute the normalization matrix
def get_normalization_matrix(imgs):
    """Return a 3x3 diagonal matrix that scales homogeneous 2-D coordinates.

    Both x and y are divided by the largest entry of ``imgs[0].shape``, and
    the homogeneous coordinate is left at 1. Only the first image is
    inspected. Assuming images are (height, width[, channels]) arrays, this
    maps that image's pixel coordinates into roughly [0, 1].
    """
    scale = 1.0 / max(imgs[0].shape)
    return np.diag((scale, scale, 1))