# Utils
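This notebook collects the utility functions shared by the different image-stitching pipelines: loading images, building and drawing image graphs, converting estimated states into homographies, and plotting experiment results.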

## Importing libraries

In [4]:
import os
import cv2 as cv
from math import ceil
import networkx as nx
import matplotlib.pyplot as plt
from matplotlib.pyplot import figure
import numpy as np
import matplotlib
import plotly.graph_objects as go
from IPython.display import display
import pandas as pd
from enum import Enum

## Function definitions

In [5]:
def load_images_from_folder(folder):
    """Load every readable image in ``folder`` as an RGB array.

    Entries are visited in sorted filename order, so the returned list (and
    any node indices derived from it) is the same on every run and platform.
    ``os.listdir`` on its own returns entries in arbitrary order.
    Entries for which ``cv.imread`` returns None (e.g. files it cannot
    decode) are skipped.

    Parameters
    ----------
    folder : str
        Directory to scan (not recursive).

    Returns
    -------
    list
        The decoded images, converted from OpenCV's BGR channel order to RGB.
    """
    images = []
    for filename in sorted(os.listdir(folder)):
        img = cv.imread(os.path.join(folder, filename))
        # cv.imread signals failure by returning None. Check before converting,
        # otherwise cvtColor raises on the first unreadable entry.
        if img is None:
            continue
        images.append(cv.cvtColor(img, cv.COLOR_BGR2RGB))
    return images

In [6]:
def build_graph(imgs, adj_matrix, weight_matrix):
    """Build an undirected, weighted ``networkx.Graph`` whose nodes are images.

    Parameters
    ----------
    imgs : sequence
        Image ``k`` is stored as the ``image`` attribute of node ``k``.
    adj_matrix : 2-D array indexable as ``adj_matrix[i, j]``
        An edge i-j is added wherever the entry equals 1. Only the top-left
        ``len(imgs) x len(imgs)`` part is read.
    weight_matrix : 2-D array indexable as ``weight_matrix[i, j]``
        Edge i-j receives weight ``max(weight_matrix[i, j], weight_matrix[j, i])``.

    Returns
    -------
    networkx.Graph
    """
    graph = nx.Graph()
    num_nodes = len(imgs)

    for node_id in range(num_nodes):
        graph.add_node(node_id, image=imgs[node_id])

    for row in range(num_nodes):
        for col in range(num_nodes):
            if adj_matrix[row, col] != 1:
                continue
            # The graph is undirected, so keep the larger of the two directed weights.
            edge_weight = max(weight_matrix[row, col], weight_matrix[col, row])
            graph.add_edge(row, col, weight=edge_weight)
    return graph

In [7]:
def advanced_print_graph(imgs, G, save_images=True, graph_dir="output", graph_name="graph"):
    """Draw ``G`` on a circular layout with each node's image as a thumbnail.

    Parameters
    ----------
    imgs : sequence
        Unused; thumbnails are read from the ``image`` node attribute of ``G``.
        Kept only so existing callers keep working.
    G : networkx.Graph
        Nodes must carry an ``image`` attribute. Edges' ``weight`` attributes
        (if any) are drawn as edge labels.
    save_images : bool
        If True, save the figure as ``<graph_dir>/<graph_name>.png``.
    graph_dir : str
        Output directory; created if it does not exist.
    graph_name : str
        File name of the saved figure, without extension.

    Returns
    -------
    networkx.Graph
        ``G`` itself, unchanged.
    """
    pos = nx.circular_layout(G)
    fig = plt.figure(figsize=(15, 15))
    ax = fig.add_subplot(111)
    ax.set_aspect('equal')
    nx.draw(G, pos, ax=ax, width=3, node_size=900, with_labels=True,
            edgecolors='red', node_color='lightgray', connectionstyle='arc3, rad = 0.1')

    labels = nx.get_edge_attributes(G, 'weight')
    nx.draw_networkx_edge_labels(G, pos, edge_labels=labels, ax=ax)
    ax.set_xlim(-1.5, 1.5)
    ax.set_ylim(-1.5, 1.5)

    # Data coordinates -> display (pixel) coordinates.
    data_to_display = ax.transData.transform
    # Display coordinates -> figure-fraction coordinates, the unit fig.add_axes expects.
    display_to_figure = fig.transFigure.inverted().transform
    thumb_size = 0.08  # thumbnail width and height, in figure-fraction units
    # Offset of the thumbnail's lower-left corner from the node position.
    # Because it is larger than thumb_size, the thumbnail sits below-left of
    # the node rather than centred on it (centring would use thumb_size / 2).
    thumb_offset = thumb_size / 0.9

    for node in G:
        x_disp, y_disp = data_to_display(pos[node])
        x_fig, y_fig = display_to_figure((x_disp, y_disp))
        thumb_ax = fig.add_axes([x_fig - thumb_offset, y_fig - thumb_offset, thumb_size, thumb_size])
        thumb_ax.set_aspect('equal')
        thumb_ax.imshow(G.nodes[node]['image'])
        thumb_ax.axis('off')
    ax.axis('off')

    if save_images:
        # savefig does not create missing directories.
        if graph_dir:
            os.makedirs(graph_dir, exist_ok=True)
        fig.savefig(os.path.join(graph_dir, f'{graph_name}.png'))
    return G

In [8]:
def build_and_print_multi_graph(adj_matrix, save_images=True, graph_dir="output", graph_name="graph"):
    """Build a directed multigraph from an adjacency matrix, draw it and return it.

    Parameters
    ----------
    adj_matrix : 2-D numpy array
        Square adjacency matrix. Because ``parallel_edges=True`` and the graph
        type is a multigraph, an integer entry ``k`` at ``[i, j]`` becomes
        ``k`` parallel edges ``i -> j``.
    save_images : bool
        If True, save the figure as ``<graph_dir>/<graph_name>.png``.
    graph_dir : str
        Output directory; created if it does not exist.
    graph_name : str
        File name of the saved figure, without extension.

    Returns
    -------
    networkx.MultiDiGraph
    """
    # nx.from_numpy_matrix was removed in networkx 3.0. from_numpy_array
    # (available since networkx 2.0) is the drop-in replacement.
    graph = nx.from_numpy_array(adj_matrix, parallel_edges=True, create_using=nx.MultiDiGraph)

    pos = nx.circular_layout(graph)
    fig = plt.figure(figsize=(10, 10))
    # Curved (arc3) edges keep i -> j and j -> i from being drawn on top of each other.
    nx.draw(graph, pos=pos, with_labels=True, connectionstyle='arc3, rad = 0.1')

    # Save before plt.show(): on interactive backends show() blocks until the
    # window is closed.
    if save_images:
        if graph_dir:
            os.makedirs(graph_dir, exist_ok=True)
        fig.savefig(os.path.join(graph_dir, f'{graph_name}.png'))
    plt.show()
    return graph

In [9]:
def print_graph(graph, save_images=True, graph_dir="output", graph_name="graph", seed=None):
    """Draw ``graph`` with a spring layout and optionally save it as a PNG.

    Parameters
    ----------
    graph : networkx.Graph
        Graph to draw.
    save_images : bool
        If True, save the figure as ``<graph_dir>/<graph_name>.png``.
    graph_dir : str
        Output directory; created if it does not exist.
    graph_name : str
        File name of the saved figure, without extension.
    seed : int or None
        Forwarded to ``nx.spring_layout``. Fix it to get the same layout on
        every run; the default ``None`` keeps the previous unseeded behaviour.
    """
    pos = nx.spring_layout(graph, seed=seed)
    fig = plt.figure(figsize=(10, 10))
    ax = fig.add_subplot(111)
    ax.set_aspect('equal')
    nx.draw(graph, pos, ax=ax, width=3, node_size=900, with_labels=True,
            edgecolors='red', node_color='lightgray')

    # Save before plt.show(): on interactive backends show() blocks until the
    # window is closed.
    if save_images:
        if graph_dir:
            os.makedirs(graph_dir, exist_ok=True)
        fig.savefig(os.path.join(graph_dir, f'{graph_name}.png'))
    plt.show()

In [10]:
def get_homographies_from_states(U, idx_ref=0):
    """Express every state relative to the reference state.

    For each state ``U[i]`` the result is ``U[idx_ref] @ inv(U[i])``, so the
    entry at ``idx_ref`` is the identity (up to round-off).

    Parameters
    ----------
    U : sequence of square invertible matrices
        Presumably 3x3 per-image states (e.g. the blocks returned by
        ``split_states``); confirm against callers.
    idx_ref : int
        Index of the reference state.

    Returns
    -------
    list
        One matrix per entry of ``U``, in the same order.
    """
    reference_state = U[idx_ref]
    return [np.dot(reference_state, np.linalg.inv(state)) for state in U]

In [11]:
def get_reference_node(adj_matrix):
    """Return the index of the node whose adjacency column has the largest sum.

    Ties are resolved in favour of the lowest index.
    """
    column_totals = np.asarray(adj_matrix).sum(axis=0)
    return column_totals.argmax()

In [12]:
def split_states(x):
    """Split a vertically stacked 2-D matrix into consecutive 3-row blocks.

    The split is along rows only: a ``(3n, m)`` input yields ``n`` blocks of
    shape ``(3, m)``. In particular, a ``(3n, 3)`` stack yields ``n`` 3x3
    matrices. Trailing rows that do not fill a complete block are dropped.
    """
    block_count = x.shape[0] // 3
    return [x[3 * k:3 * k + 3, :] for k in range(block_count)]

In [13]:
def get_normalization_matrix(imgs):
    """Return the 3x3 scaling matrix ``diag(s, s, 1)`` with ``s = 1 / longest side``.

    Only the first image's size is used. In homogeneous coordinates the
    matrix scales x and y by ``s``; presumably it maps pixel coordinates into
    roughly ``[0, 1]`` (confirm against callers).

    Parameters
    ----------
    imgs : sequence of arrays
        Images of shape ``(H, W)`` or ``(H, W, C)``.

    Returns
    -------
    numpy.ndarray
        ``diag([s, s, 1])`` as a float array.
    """
    # Use only the spatial dimensions. For an (H, W, C) image the channel
    # count must not be mistaken for a side length.
    longest_side = np.max(imgs[0].shape[:2])
    rescale = 1. / longest_side
    return np.diag([rescale, rescale, 1])

In [14]:
def mesh_plot(df, max_error=25):
    """Plot one 3-D mesh per column of ``df``, with a toggle button per column.

    Parameters
    ----------
    df : pandas.DataFrame
        Indexed by a 2-level MultiIndex. Level 1 is plotted on the x axis
        ("noise std"), level 0 on the y axis ("number of matches"), and each
        column's values on the z axis ("error").
    max_error : float
        z values are clipped to this value so that outliers do not flatten
        the plot. Defaults to 25, the previously hard-coded value.
    """
    colors = ['#636EFA',
              '#EF553B',
              '#00CC96',
              '#AB63FA',
              '#FFA15A',
              '#19D3F3',
              '#FF6692',
              '#B6E880',
              '#FF97FF',
              '#FECB52']

    fig = go.Figure()
    for i, col in enumerate(df.columns):
        series = df.loc[:, col]
        fig.add_trace(go.Mesh3d(name=col,
                                visible=True,
                                x=series.index.get_level_values(1).values,
                                y=series.index.get_level_values(0).values,
                                z=np.minimum(series.values, max_error),
                                # Cycle the palette so more than len(colors) columns do not raise IndexError.
                                color=colors[i % len(colors)],
                                opacity=0.9))

    menu_step = 0.15  # vertical distance between stacked buttons
    button_x = -0.1

    # One toggle button per trace, stacked downwards starting at y = 1.
    updatemenus = []
    for i, col in enumerate(df.columns):
        updatemenus.append(dict(
            type='buttons',
            buttons=[dict(method='restyle',
                          label=col,
                          visible=True,
                          args=[{'visible': True}, [i]],
                          args2=[{'visible': False}, [i]])],
            showactive=False,
            x=button_x,
            y=1 - i * menu_step,
        ))

    # A final button below the others toggles all traces at once. Its position
    # is computed from the column count, so a DataFrame with no columns no
    # longer hits a NameError on a leaked loop variable.
    updatemenus.append(dict(
        type='buttons',
        buttons=[dict(method='restyle',
                      label='All',
                      visible=True,
                      args=[{'visible': True}],
                      args2=[{'visible': False}])],
        showactive=True,
        x=button_x,
        y=1 - len(df.columns) * menu_step,
    ))

    scene = dict(
        xaxis=dict(title='noise std'),
        yaxis=dict(title='number of matches'),
        zaxis=dict(title='error'),
    )

    fig.update_layout(title='Experiments',
                      scene=scene,
                      autosize=True,
                      width=1000,
                      height=700,
                      showlegend=True,
                      updatemenus=updatemenus,
                      margin=dict(l=65, r=50, b=65, t=90))

    fig.show()

In [15]:
def scatter_with_slider(slider, df, x_title, slide_title, active=0, max_error=25):
    """Plot per-column error curves, with a slider that selects the first index level.

    Parameters
    ----------
    slider : sequence
        Values of level 0 of ``df``'s index; one slider step per value.
    df : pandas.DataFrame
        2-level MultiIndex (level 0 = slider value, level 1 = x value). Each
        column becomes one trace per slider step.
    x_title : str
        Title of the x axis.
    slide_title : str
        Used in the figure title and as the slider's current-value prefix.
    active : int
        Index of the slider step shown initially. It is clamped to the valid
        range and used both for the slider marker and for the traces made
        visible, so the two always agree.
    max_error : float
        Error values are clipped to this value (default 25).
    """
    layout = go.Layout(
        title=f"Errors sliding {slide_title}",
        xaxis=dict(title=x_title),
        yaxis=dict(title='Error'),
    )
    fig = go.Figure(layout=layout)

    colors = ['#636EFA',
              '#EF553B',
              '#00CC96',
              '#AB63FA',
              '#FFA15A',
              '#19D3F3',
              '#FF6692',
              '#B6E880',
              '#FF97FF',
              '#FECB52']

    num_m = len(df.columns)
    num_steps = len(slider)

    # One group of num_m hidden traces per slider value, in slider order.
    for slide in slider:
        for i, col in enumerate(df.columns):
            series = df.loc[pd.IndexSlice[slide, :], col]
            fig.add_trace(
                go.Scatter(
                    visible=False,
                    line=dict(color=colors[i % len(colors)], width=2),
                    name=col,
                    x=series.index.get_level_values(1).values,
                    y=np.minimum(series.values, max_error),
                )
            )

    # Show the traces of the active step. The marker was previously hard-coded
    # at step 10 while the traces of step 0 were displayed.
    active = min(max(active, 0), num_steps - 1) if num_steps else 0
    for trace in fig.data[active * num_m:(active + 1) * num_m]:
        trace.visible = True

    # Each step shows only its own group of traces.
    steps = []
    for step_idx, value in enumerate(slider):
        visible = [False] * len(fig.data)
        visible[step_idx * num_m:(step_idx + 1) * num_m] = [True] * num_m
        steps.append(dict(
            method="update",
            label=str(value),
            args=[{"visible": visible}],
        ))

    sliders = [dict(
        active=active,
        currentvalue={"prefix": f"{slide_title}: "},
        pad={"t": 50},
        steps=steps,
    )]

    fig.update_layout(
        sliders=sliders,
        width=840,
        height=500,
    )

    fig.show()

In [16]:
def plot_images(results, stitched_img_gt, columns=2, subplot_size=20):
    """Show every stitched result in a grid, followed by the ground-truth image.

    Parameters
    ----------
    results : iterable of dict
        Each entry provides ``"img"`` (the image to show) and ``"name"`` (its title).
    stitched_img_gt : array-like
        Ground-truth image, drawn last with the title "Ground truth".
    columns : int
        Number of grid columns (default 2, the previously hard-coded value).
    subplot_size : float
        Width and height in inches reserved for each grid cell (default 20,
        the previously hard-coded value).
    """
    imgs = [r["img"] for r in results] + [stitched_img_gt]
    labels = [r["name"] for r in results] + ["Ground truth"]

    n = len(imgs)
    rows = ceil(n / columns)
    fig = plt.figure(figsize=(columns * subplot_size, rows * subplot_size))

    for position, (img, label) in enumerate(zip(imgs, labels), start=1):
        ax = fig.add_subplot(rows, columns, position)
        ax.set_title(label)
        ax.imshow(img)

    plt.show()

In [17]:
class NoiseType(Enum):
    """Selects which noise model, if any, an experiment uses.

    Members in declaration order: NO_NOISE = 1, HOMOGRAPHY = 2, POINTS = 3.
    Presumably HOMOGRAPHY perturbs the homographies and POINTS perturbs the
    point matches; confirm against callers.
    """
    # Tuple unpacking assigns each member in turn, giving values 1, 2 and 3.
    NO_NOISE, HOMOGRAPHY, POINTS = range(1, 4)