In [1]:
import cv2 as cv
import os
import shutil
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.pyplot import figure
import networkx as nx
import ipynb.fs.defs.Utils as Utils

In [2]:
def transform_image(ref_index, 
                    transf_index, 
                    image, 
                    Z, 
                    warp_shape = (2000, 1000),
                    scaling_factor = 1):
    """Warp ``image`` with the homography that ``Z`` stores for the pair (transf_index, ref_index).

    ``Z`` is read as a grid of 3x3 blocks. The block at block-row ``transf_index`` and
    block-column ``ref_index`` is taken as the homography ``h``. It is conjugated by
    ``T_norm = diag(1/scaling_factor, 1/scaling_factor, 1)`` to give
    ``h_pixel = T_norm^-1 @ h @ T_norm``. ``inv(h_pixel)`` is then handed to
    ``cv.warpPerspective``. Without ``WARP_INVERSE_MAP``, OpenCV samples
    ``dst(p) = src(M^-1 p)``, so ``h_pixel`` acts as the map from output-canvas
    coordinates to ``image`` coordinates.

    Parameters
    ----------
    ref_index : int
        Block-column index into ``Z`` (the reference image). Must be >= 0.
    transf_index : int
        Block-row index into ``Z`` (the image being transformed). Must be >= 0.
    image : np.ndarray
        Image (or already-composited canvas) to warp.
    Z : np.ndarray
        2-D array made of stacked 3x3 homography blocks.
    warp_shape : sequence of two ints
        Output size passed to OpenCV as ``dsize``, i.e. (width, height).
    scaling_factor : float
        Coordinate scale used to move ``h`` to pixel units.
        NOTE(review): presumably ``Z`` was estimated on coordinates divided by this
        factor. Confirm against the code that builds ``Z``.

    Returns
    -------
    np.ndarray
        ``image`` warped onto a canvas of size ``warp_shape``.

    Raises
    ------
    ValueError
        If an index is negative or the selected block of ``Z`` is not 3x3.
    """
    # Negative indices would silently select a wrong (or empty) block of Z.
    if ref_index < 0 or transf_index < 0:
        raise ValueError(
            f"image indices must be non-negative, got ref_index={ref_index}, "
            f"transf_index={transf_index}"
        )
    h = Z[transf_index * 3:(transf_index + 1) * 3, ref_index * 3:(ref_index + 1) * 3]
    # Out-of-range indices yield a truncated/empty slice rather than an IndexError.
    if h.shape != (3, 3):
        raise ValueError(
            f"Z has no 3x3 block at (transf_index={transf_index}, ref_index={ref_index}); "
            f"got slice of shape {h.shape} from Z of shape {Z.shape}"
        )

    rescale = 1. / scaling_factor
    T_norm = np.diag([rescale, rescale, 1])
    # Express the normalized homography in pixel coordinates.
    h_pixel = np.linalg.inv(T_norm) @ h @ T_norm
    return cv.warpPerspective(image, np.linalg.inv(h_pixel), tuple(warp_shape))

In [3]:
def stitch_images(imgs, 
                  graph, 
                  root_idx, 
                  father_idx, 
                  Z,  
                  warp_shape = (2000, 1000),
                  save_output=True,
                  output_dir = "output",
                  verbose=False):
    """Recursively composite the subtree of ``graph`` rooted at ``root_idx`` onto one canvas.

    The root image is placed on a ``warp_shape`` canvas with the identity transform.
    For every neighbour other than ``father_idx``, the neighbour's subtree is stitched
    recursively. The resulting canvas is warped into the root's frame with
    ``transform_image(root_idx, neighbour, ...)`` and merged by pixel-wise maximum.

    Parameters
    ----------
    imgs : sequence of np.ndarray
        Input images, indexed by graph node.
    graph : networkx-like graph
        Anything exposing ``neighbors(node)``. It should be a tree (e.g. a spanning
        tree); with cycles the recursion would not terminate.
    root_idx : int
        Node whose frame the subtree is composited into.
    father_idx : int
        Node we came from, which is skipped. Pass ``root_idx`` itself for the top-level call.
    Z : np.ndarray
        Stacked pairwise homographies, as consumed by ``transform_image``.
    warp_shape : sequence of two ints
        Canvas size (width, height) passed to OpenCV.
    save_output : bool
        If True, write ``<output_dir>/<root_idx>.jpg`` for this partial composite.
        The directory is created if needed.
    output_dir : str
        Directory for the partial composites.
    verbose : bool
        If True, display each partial composite with matplotlib.

    Returns
    -------
    np.ndarray
        The composited canvas for this subtree, in the frame of ``root_idx``.
    """
    base_image = cv.warpPerspective(imgs[root_idx], np.eye(3), tuple(warp_shape))
    # NOTE(review): the normalization scale always comes from imgs[0]'s largest
    # dimension (for a colour image the channel count is included in the max).
    # Presumably this matches how Z was normalized; confirm if image sizes differ.
    scaling_factor = np.max(imgs[0].shape)
    for neighbour in graph.neighbors(root_idx):
        if neighbour == father_idx:
            continue  # do not walk back up the tree
        child = stitch_images(imgs, graph, neighbour, root_idx, Z,
                              warp_shape, save_output, output_dir, verbose)
        tr_child = transform_image(root_idx, neighbour, child, Z, warp_shape, scaling_factor)
        # Pixel-wise maximum as the blending rule: overlapping regions keep the brighter value.
        base_image = np.maximum(base_image, tr_child)

    if verbose:
        figure(figsize=(40, 40), dpi=80)
        plt.imshow(base_image)
        plt.show()
    if save_output:
        import warnings
        # cv.imwrite does not raise on a missing directory; it just returns False.
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)
        out_path = os.path.join(output_dir, f"{root_idx}.jpg")
        # Images are handled as RGB in this notebook; OpenCV writes BGR.
        if not cv.imwrite(out_path, cv.cvtColor(base_image, cv.COLOR_RGB2BGR)):
            warnings.warn(f"cv.imwrite failed to write {out_path}")
    return base_image

In [5]:
def basic_stitching(dataset_name,
                    imgs,
                    Z, 
                    adj_matrix, 
                    weight_matrix,
                    idx_ref = 0,
                    verbose = True,
                    save_output = True,
                    stitching_dir = "stitched",
                    partial_results_dir = "partial",
                    basic_stitching_dir = "basic_stitching",
                    warp_shape = (10000, 10000)
                   ):
    """Stitch ``imgs`` along the maximum spanning tree of their pairwise graph.

    Steps:
    1. Build and plot the image graph with ``Utils.build_and_print_graph``.
    2. Take its maximum spanning tree (``networkx.maximum_spanning_tree``).
    3. Plot the spanning tree.
    4. Recursively composite every image into the frame of ``idx_ref``.

    Parameters
    ----------
    dataset_name : str
        Used as a sub-directory of ``stitching_dir`` for all outputs.
    imgs : sequence of np.ndarray
        Input images, indexed by graph node.
    Z : np.ndarray
        Stacked pairwise homographies.
    adj_matrix, weight_matrix :
        Forwarded to ``Utils.build_and_print_graph``.
        NOTE(review): presumably the adjacency and edge weights of the image graph.
        The spanning tree uses networkx's default ``'weight'`` edge attribute.
    idx_ref : int
        Reference image; the output is expressed in its frame.
    verbose : bool
        Forwarded to ``stitch_images`` to display partial composites. The final
        composite is always displayed.
    save_output : bool
        If True, write outputs under ``<stitching_dir>/<dataset_name>/<basic_stitching_dir>``.
        That directory is DELETED first, then recreated. Partial composites go to its
        ``partial_results_dir`` sub-directory.
    warp_shape : sequence of two ints
        Canvas size (width, height) passed to OpenCV.

    Returns
    -------
    np.ndarray
        The final stitched canvas.
    """
    import warnings

    output_dir = os.path.join(stitching_dir, dataset_name, basic_stitching_dir)
    partial_output_dir = os.path.join(output_dir, partial_results_dir)
    if save_output:
        # Start from a clean directory so stale results from a previous run don't linger.
        if os.path.isdir(output_dir):
            shutil.rmtree(output_dir)
        # Also creates output_dir, since partial_output_dir lives inside it.
        os.makedirs(partial_output_dir, exist_ok=True)

    graph = Utils.build_and_print_graph(imgs, 
                                        adj_matrix, 
                                        weight_matrix, 
                                        save_output, 
                                        output_dir,
                                        "graph")

    spanning_tree = nx.maximum_spanning_tree(graph)

    Utils.print_graph(spanning_tree, save_output, output_dir, "spanning-tree")

    # The root is passed as its own "father" so that all of its neighbours are visited.
    # Partial results go to the directory created above, not the bare relative name.
    stitched_image = stitch_images(imgs, 
                                   spanning_tree, 
                                   idx_ref, 
                                   idx_ref, 
                                   Z, 
                                   warp_shape, 
                                   save_output, 
                                   partial_output_dir, 
                                   verbose)

    if save_output:
        out_path = os.path.join(output_dir, "stitched.jpg")
        # cv.imwrite signals failure via its return value rather than an exception.
        if not cv.imwrite(out_path, cv.cvtColor(stitched_image, cv.COLOR_RGB2BGR)):
            warnings.warn(f"cv.imwrite failed to write {out_path}")

    figure(figsize=(40, 40), dpi=80)
    plt.imshow(stitched_image)
    plt.show()

    return stitched_image