## Welcome back, Kagglers

Credit: https://www.kaggle.com/code/ammarali32/imc-2022-kornia-loftr-from-0-533-to-0-721

# ***Import dependencies***

In [2]:
# Install kornia and kornia_moons from wheel files shipped in the
# `kornia-loftr` input dataset instead of from PyPI.
# NOTE(review): `dry_run` is not read anywhere else in the visible notebook.
dry_run = False
!pip install ../input/kornia-loftr/kornia-0.6.4-py2.py3-none-any.whl
!pip install ../input/kornia-loftr/kornia_moons-0.1.9-py3-none-any.whl


In [3]:
import os
import numpy as np
import cv2
import csv
from glob import glob
import torch
import matplotlib.pyplot as plt
import kornia
from kornia_moons.feature import *
import kornia as K
import kornia.feature as KF
import gc
import time
import typing
import copy
#from pydegensac import  findFundamentalMatrix
# Copy SuperGlue's `superglue.py` from the super-glue-pretrained-network input
# dataset into the installed kornia package's feature/loftr/utils directory.
# NOTE(review): the destination hard-codes a python3.7 site-packages path, and
# nothing in the visible notebook imports this module — confirm it is still needed.
!cp ../input/super-glue-pretrained-network/models/superglue.py /opt/conda/lib/python3.7/site-packages/kornia/feature/loftr/utils/superglue.py 




# ***Model***

In [4]:
# Build LoFTR without fetching pretrained weights, then load the outdoor
# checkpoint that ships in the `kornia-loftr` input dataset.
matcher = KF.LoFTR(pretrained=None)
# map_location="cpu" lets the checkpoint load even if it was saved from GPU
# tensors and no GPU is available; the model is moved to `device` below.
matcher.load_state_dict(
    torch.load("../input/kornia-loftr/loftr_outdoor.ckpt", map_location="cpu")['state_dict']
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# eval() puts the network in inference mode.
matcher = matcher.to(device).eval()

## *Utils*

In [8]:
src = '../input/image-matching-challenge-2022/'

# Read every data row of test.csv (header skipped). The inference loop
# unpacks each row as [sample_id, batch_id, image_1_id, image_2_id].
# newline='' is what the csv module requires for correct handling of
# quoted fields and line endings.
with open(f'{src}/test.csv', newline='') as f:
    reader = csv.reader(f, delimiter=',')
    next(reader, None)  # skip the header row; tolerate an empty file
    test_samples = list(reader)


def FlattenMatrix(M, num_digits=8):
    '''Serialize all entries of M as one space-separated string for the CSV.

    Each value is written in scientific notation with ``num_digits`` digits
    after the decimal point, in the order produced by ``M.flatten()``.
    '''
    spec = f'.{num_digits}e'
    return ' '.join(format(value, spec) for value in M.flatten())

def cv2_Torch(img: np.ndarray, device: torch.device) -> torch.Tensor:
    '''Convert an RGB NumPy image into a single-channel float tensor on `device`.

    Pixel values are divided by 255 (presumably 8-bit input, giving [0, 1] —
    TODO confirm for other dtypes). The channels are then reduced with
    ``K.color.rgb_to_grayscale``, so ``img`` is expected in RGB order.
    '''
    # keepdim=False makes kornia add a leading batch dimension.
    as_tensor = K.image_to_tensor(img, False)
    normalised = as_tensor.float().div(255.)
    on_device = normalised.to(device)
    return K.color.rgb_to_grayscale(on_device)


def load_torch_image(fname: str, device: torch.device, *, max_dim: int = 840
                     ) -> typing.Tuple[typing.Dict[str, torch.Tensor], float]:
    '''Load an image, resize it, and build four flipped grayscale tensors.

    The image is resized so its longer side is about ``max_dim`` pixels.
    It is converted from OpenCV's BGR order to RGB, and four variants are
    produced: the original plus vertical, horizontal and both-axis flips.
    Each variant is passed through ``cv2_Torch``.

    Args:
        fname: Path of the image file.
        device: Torch device the tensors are placed on.
        max_dim: Target length of the longer side after resizing (default 840).

    Returns:
        ``(images, scale)``: ``images`` maps ``'base'``, ``'vFlipped'``,
        ``'hFlipped'`` and ``'bFlipped'`` to tensors, and ``scale`` is the
        resize factor (divide keypoints by it to return to original pixels).

    Raises:
        FileNotFoundError: if OpenCV cannot read ``fname``.
    '''
    img = cv2.imread(fname)
    if img is None:
        # cv2.imread signals failure by returning None rather than raising.
        raise FileNotFoundError(f'Could not read image: {fname}')

    scale = max_dim / max(img.shape[0], img.shape[1])
    w = int(img.shape[1] * scale)
    h = int(img.shape[0] * scale)
    img = cv2.resize(img, (w, h))  # OpenCV takes (width, height)

    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    imgV = cv2.flip(img, 0)   # rows reversed (vertical flip)
    imgH = cv2.flip(img, 1)   # columns reversed (horizontal flip)
    imgB = cv2.flip(imgV, 1)  # both axes
    images = {'base': img,
              'vFlipped': imgV,
              'hFlipped': imgH,
              'bFlipped': imgB}
    images = {k: cv2_Torch(v, device) for k, v in images.items()}

    return images, scale
                        
def reverseMirrorPoints(points: np.ndarray, length: typing.List[int], axis: typing.List[int]) -> np.ndarray:
    '''Map keypoints found on a mirrored image back to the un-mirrored frame.

    For each ``(size, col)`` pair taken from ``zip(length, axis)``, column
    ``col`` of ``points`` is reflected as ``size - x - 1``. This is the
    pixel-index inverse of flipping an image of extent ``size`` along that
    coordinate.

    Args:
        points: 2-D array of keypoint coordinates, one row per keypoint.
        length: Image extent (pixels) for each coordinate being reflected.
        axis: Column of ``points`` to reflect, paired element-wise with ``length``.

    Returns:
        A new array with the reflected coordinates. ``points`` itself is
        left unchanged.
    '''
    # Copy so callers never see their array mutated as a side effect.
    out = np.array(points, copy=True)
    for size, col in zip(length, axis):
        out[:, col] = size - out[:, col] - 1
    return out
                        
def getMatches(matcher: 'KF.LoFTR', input_dict: dict, th: float = 0.4, foo=None,
               fooParams1: typing.Optional[dict] = None,
               fooParams2: typing.Optional[dict] = None
               ) -> typing.Tuple[np.ndarray, np.ndarray]:
    '''Run ``matcher`` on one image pair and return its confident matches.

    Args:
        matcher: Callable (LoFTR here) whose output dict holds
            ``'keypoints0'``, ``'keypoints1'`` and ``'confidence'`` tensors.
        input_dict: Matcher input, e.g. ``{"image0": ..., "image1": ...}``.
        th: Matches with confidence strictly greater than ``th`` are kept.
        foo: Optional post-processing applied to each kept keypoint array
            (e.g. ``reverseMirrorPoints``). ``None`` skips it.
        fooParams1, fooParams2: Extra keyword arguments passed to ``foo`` for
            the first and second image's keypoints. ``None`` means none.

    Returns:
        ``(mkpts0, mkpts1)``: NumPy arrays of kept keypoints. Rows are
        aligned, because both arrays are filtered by the same mask.
    '''
    # None defaults avoid shared mutable {} defaults.
    fooParams1 = {} if fooParams1 is None else fooParams1
    fooParams2 = {} if fooParams2 is None else fooParams2

    with torch.no_grad():
        correspondences = matcher(input_dict)

    select = correspondences['confidence'].cpu().numpy() > th
    mkpts0 = correspondences['keypoints0'].cpu().numpy()[select, :]
    mkpts1 = correspondences['keypoints1'].cpu().numpy()[select, :]
    if foo is not None:
        mkpts0 = foo(mkpts0, **fooParams1)
        mkpts1 = foo(mkpts1, **fooParams2)
    return mkpts0, mkpts1

# ***Inference***

In [10]:
F_dict = {}

for i, row in enumerate(test_samples):
    print(i)
    sample_id, batch_id, image_1_id, image_2_id = row
    st = time.time()

    # Load both images of the pair. Each comes back as a dict of four flip
    # variants plus the resize factor applied to it.
    images1, scale1 = load_torch_image(f'{src}/test_images/{batch_id}/{image_1_id}.png', device)
    # Bug fix: the second image is image_2_id. Previously image_1_id was
    # loaded twice, so every image was matched against itself.
    images2, scale2 = load_torch_image(f'{src}/test_images/{batch_id}/{image_2_id}.png', device)

    # Match like with like: base/base, vFlipped/vFlipped, and so on.
    input_dict = {key: {"image0": images1[key], "image1": images2[key]} for key in images1}
    # The tensors end in (H, W); these are the resized image sizes.
    h1, w1 = images1['base'].shape[-2:]
    h2, w2 = images2['base'].shape[-2:]

    # Undo each flip on the matched keypoints. Column 0 is reflected over the
    # width, column 1 over the height.
    reverseFoo = {'base': {'foo': None},
                  'vFlipped': {'foo': reverseMirrorPoints,
                               'fooParams1': {'length': [h1], 'axis': [1]},
                               'fooParams2': {'length': [h2], 'axis': [1]}},
                  'hFlipped': {'foo': reverseMirrorPoints,
                               'fooParams1': {'length': [w1], 'axis': [0]},
                               'fooParams2': {'length': [w2], 'axis': [0]}},
                  'bFlipped': {'foo': reverseMirrorPoints,
                               'fooParams1': {'length': [w1, h1], 'axis': [0, 1]},
                               'fooParams2': {'length': [w2, h2], 'axis': [0, 1]}}}
    mkpts = [getMatches(matcher, v, **reverseFoo[k]) for k, v in input_dict.items()]

    # Pool the matches from all views, then undo the resize so the points are
    # in original-image pixels. The empty float64 seed keeps the result float64
    # even when no view produced matches.
    mkpts0 = np.concatenate([np.empty((0, 2))] + [p0 for p0, _ in mkpts], axis=0) / scale1
    mkpts1 = np.concatenate([np.empty((0, 2))] + [p1 for _, p1 in mkpts], axis=0) / scale2

    F = None
    if len(mkpts0) > 7:
        F, _ = cv2.findFundamentalMat(mkpts0, mkpts1, cv2.USAC_MAGSAC, 0.1823, 0.999999, 220000)
        # Alternative (disabled): pydegensac findFundamentalMatrix(mkpts0, mkpts1, 3.0)
    if F is None or F.size == 0:
        # Too few matches, or no model was found (findFundamentalMat can
        # return None). Submit a zero matrix rather than crash the whole run.
        F = np.zeros((3, 3))
    assert F.shape == (3, 3), 'Malformed F?'
    F_dict[sample_id] = F

    gc.collect()
    nd = time.time()
    if i < 3:
        print("Running time: ", nd - st, " s")

with open('submission.csv', 'w') as f:
    f.write('sample_id,fundamental_matrix\n')
    for sample_id, F in F_dict.items():
        f.write(f'{sample_id},{FlattenMatrix(F)}\n')

dry_run = False
!pip install ../input/kornia-loftr/kornia-0.6.4-py2.py3-none-any.whl
!pip install ../input/kornia-loftr/kornia_moons-0.1.9-py3-none-any.whl<center>
    <h2 style="color: #022047"> Thanks for reading 🤗  </h2>
</center>