<h1>Image Matching Challenge 2022</h1>
<h2>Team towGeeses</h2>
 <ul>
  <li>Wilmer David Garzón Cáceres</li>
  <li>Juan Sebastian Santcoloma Barrera</li>
</ul>

<p>Solution ranked 45/642 <b>(top 8%)</b></p>

<h3>Explanation</h3>
<p>This solution combines the models LoFTR [1] and SuperGlue [2] to find keypoint matches in a pair of images.
Each image is preprocessed with geometrical transformations in order to augment the data.
The point coordinates found on the transformed images must be mapped back to the original image coordinates.</p>
<p>The fundamental matrix is found using the robust estimator MAGSAC++ [3] available in OpenCV.</p>

<h3>References</h3>
[1] J. Sun, Z. Shen, Y. Wang, H. Bao, and X. Zhou, ‘LoFTR: Detector-Free Local Feature Matching with Transformers’, CVPR, 2021.
<br/>
[2] P.-E. Sarlin, D. DeTone, T. Malisiewicz, and A. Rabinovich, ‘SuperGlue: Learning Feature Matching with Graph Neural Networks’, CVPR, 2020.
<br/>
[3]	D. Barath, J. Noskova, M. Ivashechkin, and J. Matas, ‘MAGSAC++, a fast, reliable and accurate robust estimator’. arXiv, 2019.

# ***Import dependencies***

In [2]:
import numpy as np
import cv2
import csv
import torch
import matplotlib.pyplot as plt
from kornia_moons.feature import *
import kornia as K
import kornia.feature as KF
import gc
import time
from typing import Dict, List, Tuple, Callable, Optional


import sys
sys.path.append("../SuperGluePretrainedNetwork")
from models.matching import Matching

# ***Model***

In [3]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# LoFTR model: built without pretrained weights, then initialised from a
# local checkpoint file.
matcher = KF.LoFTR(pretrained=None)
# map_location remaps the stored tensors onto `device`, so the checkpoint can
# also be loaded on a CPU-only machine even if it was saved from a GPU.
matcher.load_state_dict(
    torch.load("../input/kornia-loftr/loftr_outdoor.ckpt",
               map_location=device)['state_dict'])
matcher = matcher.to(device).eval()

# SuperGlue model (SuperPoint keypoints + SuperGlue matching).
# NOTE(review): the option meanings are defined by the SuperGlue
# `models.matching.Matching` class, which is not part of this file.
config = {
    "superpoint": {
        "nms_radius": 4,
        "keypoint_threshold": 0.005,
        "max_keypoints": 1024
    },
    "superglue": {
        "weights": "outdoor",
        "sinkhorn_iterations": 20,
        "match_threshold": 0.2,
    }
}
matching = Matching(config).eval().to(device)

## *Utils*

In [6]:
# Root folder of the competition data.
src = '../image-matching-challenge-2022/'

# Every data row of test.csv is kept as a list of strings; the first row
# (the header) is discarded.
test_samples = []
with open(f'{src}/test.csv') as f:
    rows = csv.reader(f, delimiter=',')
    next(rows, None)  # Skip header.
    test_samples.extend(rows)


def FlattenMatrix(M: np.ndarray, num_digits: Optional[int] = 8):
    """Serialise a matrix as one line of text for the submission CSV.

    The entries of ``M`` are visited in flattened order, each formatted in
    scientific notation with ``num_digits`` digits after the decimal point,
    and joined with single spaces.
    """
    spec = f'.{num_digits}e'
    entries = (format(value, spec) for value in M.ravel())
    return ' '.join(entries)

def cv2_Torch(img: np.ndarray, device: torch.device)->torch.Tensor:
    """Convert a NumPy image into a float tensor on ``device``.

    The image is converted with ``K.image_to_tensor`` (``keepdim=False``),
    cast to float, divided by 255 and finally moved to ``device``.
    """
    tensor = K.image_to_tensor(img, keepdim=False)
    normalized = tensor.float() / 255.
    return normalized.to(device)

def preprocessing_image(fname: str, device: torch.device,
                        max_size: int = 840) -> Tuple[Dict[str, torch.Tensor], float]:
    """
    Scale, normalize and apply geometrical transformations to an image given its path.

    Args:
        fname: Path of the image file to read.
        device: Device the resulting tensors are moved to.
        max_size: Length, in pixels, that the longest image side is resized to.

    Returns:
        A tuple ``(images, scale)`` where ``images`` maps the keys ``'base'``,
        ``'vFlipped'`` and ``'hFlipped'`` to the resized RGB image and its two
        flipped versions (converted with ``cv2_Torch``), and ``scale`` is the
        resize factor ``max_size / longest_side``.

    Raises:
        FileNotFoundError: If the image cannot be read.
    """
    img = cv2.imread(fname)
    if img is None:
        # cv2.imread reports failure by returning None instead of raising.
        raise FileNotFoundError(f"Could not read image: {fname}")

    scale = max_size / max(img.shape[0], img.shape[1])
    # Keep every side at least one pixel long for extremely elongated images.
    w = max(1, int(img.shape[1] * scale))
    h = max(1, int(img.shape[0] * scale))
    img = cv2.resize(img, (w, h))
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    images = {
        'base': img,
        'vFlipped': cv2.flip(img, 0),  # flipped around the x-axis
        'hFlipped': cv2.flip(img, 1),  # flipped around the y-axis
    }
    images = {k: cv2_Torch(v, device) for k, v in images.items()}

    return images, scale
                        
def reverseMirrorPoints(points: np.ndarray, length: List[int], axis: List[int]) -> np.ndarray:
    """Map point coordinates measured on a mirrored image back to the original.

    For every pair ``(l, a)`` taken from ``length`` and ``axis``, column ``a``
    of ``points`` is replaced by ``l - value - 1``.

    Args:
        points: Array whose columns hold the point coordinates.
        length: Size of the image along each mirrored axis.
        axis: Column index of ``points`` to un-mirror, paired with ``length``.

    Returns:
        A new array with the selected columns un-mirrored; ``points`` itself
        is left untouched.
    """
    # Work on a copy so the caller's array is never modified in place.
    restored = np.array(points, copy=True)
    for size, column in zip(length, axis):
        restored[:, column] = size - restored[:, column] - 1
    return restored

def reverseRotation(points: np.ndarray, imgShape: Tuple[int,int] ,clockWise: bool) -> np.ndarray:
    """Apply a quarter-turn to ``points`` about the image centre.

    ``imgShape`` is unpacked as ``(h, w)``. The points are shifted by
    ``(w // 2, h // 2)``, multiplied by the matrix ``[[0, 1], [-1, 0]]``
    (negated when ``clockWise`` is False) and shifted by ``(h // 2, w // 2)``.
    """
    height, width = imgShape
    centre_x = width // 2
    centre_y = height // 2
    sign = 1 if clockWise else -1
    quarter_turn = sign * np.array([[0, 1], [-1, 0]])
    centred = points - [centre_x, centre_y]
    return centred @ quarter_turn + [centre_y, centre_x]
                        
def getLoFTRMatches(
        matcher: KF.LoFTR,
        input_dict: dict, th: Optional[float] = 0.4,
        foo: Optional[Callable] = None,
        fooParams1: Optional[dict] = None,
        fooParams2: Optional[dict] = None) -> Tuple[np.ndarray, np.ndarray]:
    """Run LoFTR on an image pair and return the confident matches.

    Args:
        matcher: LoFTR model, called as ``matcher(input_dict)``.
        input_dict: Input passed unchanged to ``matcher``.
        th: Matches whose confidence is not strictly greater than this value
            are discarded.
        foo: Optional function applied to each set of matched points, for
            example to undo a geometrical transformation.
        fooParams1: Keyword arguments for ``foo`` on the first image's points.
        fooParams2: Keyword arguments for ``foo`` on the second image's points.

    Returns:
        ``(mkpts0, mkpts1)``: the matched points of the first and second
        image, row ``i`` of one corresponding to row ``i`` of the other.
    """
    with torch.no_grad():
        correspondences = matcher(input_dict)

    keep = correspondences['confidence'].cpu().numpy() > th
    mkpts0 = correspondences['keypoints0'].cpu().numpy()[keep, :]
    mkpts1 = correspondences['keypoints1'].cpu().numpy()[keep, :]

    if foo is not None:
        # None defaults avoid sharing one mutable dict between calls and
        # make an explicit None argument behave like "no extra arguments".
        mkpts0 = foo(mkpts0, **(fooParams1 or {}))
        mkpts1 = foo(mkpts1, **(fooParams2 or {}))

    return mkpts0, mkpts1
                        
def getSuperGLUEMatches(
        matcher: Matching,
        input_dict: dict,
        foo: Optional[Callable] = None,
        fooParams1: Optional[dict] = None,
        fooParams2: Optional[dict] = None) -> Tuple[np.ndarray, np.ndarray]:
    """Run SuperPoint + SuperGlue on an image pair and return the matches.

    Args:
        matcher: SuperGlue ``Matching`` model, called as ``matcher(input_dict)``.
        input_dict: Input passed unchanged to ``matcher``.
        foo: Optional function applied to each set of matched points, for
            example to undo a geometrical transformation.
        fooParams1: Keyword arguments for ``foo`` on the first image's points.
        fooParams2: Keyword arguments for ``foo`` on the second image's points.

    Returns:
        ``(mkpts0, mkpts1)``: the matched points of the first and second
        image, row ``i`` of one corresponding to row ``i`` of the other.
    """
    # Inference only: skip building the autograd graph to save memory.
    with torch.no_grad():
        pred = matcher(input_dict)

    # Only the first batch element of the outputs actually needed is converted.
    kpts0, kpts1, matches = (
        pred[key][0].detach().cpu().numpy()
        for key in ("keypoints0", "keypoints1", "matches0")
    )

    # matches[i] is the index into kpts1 paired with kpts0[i]; entries that
    # are not greater than -1 are dropped.
    valid = matches > -1
    mkpts0 = kpts0[valid]
    mkpts1 = kpts1[matches[valid]]

    if foo is not None:
        # None defaults avoid sharing one mutable dict between calls and
        # make an explicit None argument behave like "no extra arguments".
        mkpts0 = foo(mkpts0, **(fooParams1 or {}))
        mkpts1 = foo(mkpts1, **(fooParams2 or {}))

    return mkpts0, mkpts1
                        
def plotMatches(
        image_1: torch.Tensor,
        image_2: torch.Tensor,
        mkpts0: np.ndarray,
        mkpts1: np.ndarray,
        inliers: np.ndarray):
    """Draw the matches between two images with ``draw_LAF_matches``.

    Each point set is turned into local affine frames with unit scale and
    unit orientation; match ``i`` pairs point ``i`` of ``mkpts0`` with point
    ``i`` of ``mkpts1``, and ``inliers`` is forwarded as the inlier mask.
    """
    def _as_lafs(pts: np.ndarray) -> torch.Tensor:
        # One frame per point, centred on the point.
        count = pts.shape[0]
        centers = torch.from_numpy(pts).view(1, -1, 2)
        scales = torch.ones(count).view(1, -1, 1, 1)
        orientations = torch.ones(count).view(1, -1, 1)
        return KF.laf_from_center_scale_ori(centers, scales, orientations)

    lafs_0 = _as_lafs(mkpts0)
    lafs_1 = _as_lafs(mkpts1)
    # Row i is (i, i): the two point sets are already aligned.
    pair_indices = torch.arange(mkpts0.shape[0]).view(-1, 1).repeat(1, 2)
    style = {'inlier_color': (0.2, 1, 0.2),
             'tentative_color': None,
             'feature_color': (0.2, 0.5, 1), 'vertical': False}

    draw_LAF_matches(
        lafs_0,
        lafs_1,
        pair_indices,
        K.tensor_to_image(image_1),
        K.tensor_to_image(image_2),
        inliers,
        draw_dict=style)

# ***Inference***

In [7]:
# Maps each sample id to its estimated 3x3 fundamental matrix.
F_dict = {}

# Only the first samples are plotted and reported; drawing every pair would
# dominate the running time on a large test set.
NUM_DEBUG_SAMPLES = 3
# Fundamental-matrix estimation is only attempted with at least 8 matches.
MIN_MATCHES = 8

for i, row in enumerate(test_samples):
    sample_id, batch_id, image_1_id, image_2_id = row
    # Load the images.
    st = time.time()
    images1, scale1 = preprocessing_image(f'{src}/test_images/{batch_id}/{image_1_id}.png', device)
    images2, scale2 = preprocessing_image(f'{src}/test_images/{batch_id}/{image_2_id}.png', device)

    # One grayscale image pair per augmentation ('base', 'vFlipped', 'hFlipped').
    input_dict = {key:{"image0": K.color.rgb_to_grayscale(images1[key]),
                       "image1": K.color.rgb_to_grayscale(images2[key])} for key in images1.keys()}
    # Get image shapes
    h1, w1 = images1['base'].shape[-2:]
    h2, w2 = images2['base'].shape[-2:]
    # Functions to invert geometrical transformations
    reverseFoo = {'base': {'foo': None},
                  'vFlipped': {'foo': reverseMirrorPoints,
                               'fooParams1': {'length':[h1],'axis':[1]},
                               'fooParams2': {'length':[h2],'axis':[1]}},
                  'hFlipped': {'foo': reverseMirrorPoints,
                               'fooParams1': {'length':[w1],'axis':[0]},
                               'fooParams2': {'length':[w2],'axis':[0]}},
                 }
    # Get matching pairs. Inference only, so no autograd graph is needed.
    with torch.no_grad():
        mkptsGlue = [getSuperGLUEMatches(matching, v, **reverseFoo[k]) for k,v in input_dict.items()]
        mkpts = [getLoFTRMatches(matcher, v,th=0.3, **reverseFoo[k]) for k,v in input_dict.items()]
    mkpts.extend(mkptsGlue)

    # Stack the matches of all models/augmentations and map them from the
    # resized images back to the original image resolution.
    mkpts0 = np.concatenate([points[0] for points in mkpts], axis=0).astype(np.float64) / scale1
    mkpts1 = np.concatenate([points[1] for points in mkpts], axis=0).astype(np.float64) / scale2

    F, inliers = None, None
    if len(mkpts0) >= MIN_MATCHES:
        F, inliers = cv2.findFundamentalMat(
            mkpts0,
            mkpts1,
            cv2.USAC_MAGSAC,
            0.2,
            0.999999,
            180_000
        )

    if F is None or F.shape != (3, 3):
        # Too few matches, or the estimator did not return a single 3x3
        # matrix: fall back to an all-zero matrix for this sample.
        F_dict[sample_id] = np.zeros((3, 3))
        gc.collect()
        continue

    F_dict[sample_id] = F
    nd = time.time()

    if i < NUM_DEBUG_SAMPLES:
        plotMatches(images1['base'],
                    images2['base'],
                    mkpts0 * scale1,
                    mkpts1 * scale2,
                    inliers)
        # `inliers` holds one mask entry per match; summing the whole mask
        # (not just its first row) gives the number of inliers.
        print(f"Total key points pair: {len(inliers)}, Inliers pairs:{int(inliers.sum())}")
        print("Running time: ", nd - st, " s")
        print("Fundamental Matrix:", F)
    gc.collect()

# Write the submission file: a header line, then one
# "sample_id,<flattened matrix>" line per entry of F_dict.
with open('submission.csv', 'w') as f:
    f.write('sample_id,fundamental_matrix\n')
    f.writelines(
        f'{sample_id},{FlattenMatrix(F)}\n'
        for sample_id, F in F_dict.items()
    )