Merges three matchers (SuperGlue, LoFTR, and ASLFeat) and adds some original ideas.

The three matchers are adapted from:

SuperGlue: https://www.kaggle.com/code/losveria/superglue-baseline
ASLFeat: https://www.kaggle.com/code/rsmits/tensorflow-aslfeat-inference
LoFTR: https://www.kaggle.com/code/mcwema/imc-2022-kornia-loftr-score-plateau-0-726

Thanks for these notebooks!

In [None]:
%%capture
#dry_run = False
!pip install ../input/kornia-loftr/kornia-0.6.4-py2.py3-none-any.whl
!pip install ../input/kornia-loftr/kornia_moons-0.1.9-py3-none-any.whl

In [None]:
# Import Modules
import numpy as np 
import pandas as pd
import csv
import cv2
import gc
import tensorflow as tf
import sys
import yaml
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm

# Import ASLFeat
sys.path.append('../input/aslfeat')
from models import get_model

# Disable Eager Execution
tf.compat.v1.disable_eager_execution()

In [None]:
src = '/kaggle/input/image-matching-challenge-2022/'

# Load every data row of test.csv as a list of strings; the header line is dropped.
with open(f'{src}/test.csv') as f:
    reader = csv.reader(f, delimiter=',')
    next(reader, None)  # skip header (no-op for an empty file)
    test_samples = list(reader)

In [None]:
# Import Config
# NOTE(review): yaml.FullLoader is kept as-is; the config ships with the ASLFeat
# input dataset and is treated as trusted.
aslfeat_config_path = '../input/aslfeat/configs/matching_eval2.yaml'
with open(aslfeat_config_path, 'r') as f:
    config = yaml.load(f, Loader=yaml.FullLoader)

# There are 2 checkpoints in the pretrained folder; this one should be the best.
aslfeat_model_path = '../input/aslfeat/pretrained/aslfeatv2/model.ckpt-60000'
config['model_path'] = aslfeat_model_path
net_config = config['net']['config']
net_config['kpt_n'] = 8000  # Same as the original config; set explicitly for convenience.

# Summary config
print(config)

In [None]:
# Create Model: instantiate ASLFeat's 'feat_model' from the checkpoint path and the 'net' config section.
net_kwargs = config['net']
model = get_model('feat_model')(aslfeat_model_path, **net_kwargs)

In [None]:
# ASLFeat Functions
def load_imgs(img_paths, max_dim=840):
    """Load images and resize them so the longer side is `max_dim` pixels.

    Args:
        img_paths: Iterable of image file paths readable by cv2.imread.
        max_dim: Target length of the longer image side (840 by default, the
            same size load_torch_image uses for LoFTR).
    Returns:
        (rgb_list, gray_list): per-image arrays in input order. The RGB arrays
        have OpenCV's BGR channel order reversed. The gray arrays carry a
        trailing channel axis, i.e. (H, W, 1).
    Raises:
        FileNotFoundError: If cv2.imread cannot read one of the paths.
    """
    rgb_list = []
    gray_list = []

    for img_path in img_paths:
        img = cv2.imread(img_path)
        if img is None:
            # cv2.imread signals failure by returning None; fail loudly here
            # instead of with an AttributeError on img.shape below.
            raise FileNotFoundError(f'Could not read image: {img_path}')
        scale = max_dim / max(img.shape[0], img.shape[1])
        w = int(img.shape[1] * scale)
        h = int(img.shape[0] * scale)
        img = cv2.resize(img, (w, h))
        gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)[..., np.newaxis]
        rgb_list.append(img[..., ::-1])
        gray_list.append(gray)

    return rgb_list, gray_list

def extract_local_features(gray_list, feat_model=None):
    """Run the ASLFeat model on each grayscale image.

    Args:
        gray_list: Grayscale images, e.g. as returned by load_imgs.
        feat_model: Object exposing run_test_data(img) -> (desc, kpt, _).
            Defaults to the module-level `model` created earlier in the notebook.
    Returns:
        (descs, kpts): per-image descriptor and keypoint lists, in input order.
    """
    if feat_model is None:
        feat_model = model

    descs = []
    kpts = []
    for gray_img in gray_list:
        # The third return value of run_test_data is not needed here.
        desc, kpt, _ = feat_model.run_test_data(gray_img)
        descs.append(desc)
        kpts.append(kpt)

    return descs, kpts

class MatcherWrapper(object):
    """OpenCV FLANN matcher wrapper used to match ASLFeat descriptors."""

    def __init__(self):
        # Swapped BFMatcher to FlannBasedMatcher.
        # FLANN parameters: KD-tree index (algorithm id 0) with 6 trees, 125 checks per search.
        FLANN_INDEX_KDTREE = 0
        index_params = dict(algorithm=FLANN_INDEX_KDTREE, trees=6)
        search_params = dict(checks=125)  # or pass an empty dictionary
        self.matcher = cv2.FlannBasedMatcher(index_params, search_params)

    def get_matches(self, feat1, feat2, cv_kpts1, cv_kpts2, ratio=0.8, cross_check=True, err_thld=0.5):
        """Compute putative matches and estimate a fundamental matrix from them.

        Args:
            feat1, feat2: Local descriptors, one row per keypoint
                (documented upstream as (n_kpts, 128)).
            cv_kpts1, cv_kpts2: Keypoints, either lists of cv2.KeyPoint or
                ndarrays with one keypoint per row.
            ratio: Ratio-test threshold; None or a value >= 1 disables the test.
            cross_check: Keep only mutual nearest neighbours (True by default).
            err_thld: Epipolar error threshold passed to get_fundamental_matrix.
        Returns:
            F: Fundamental matrix from get_fundamental_matrix.
            (good_kpts1, good_kpts2): Matched keypoint coordinates. When there
                are no matches these are empty arrays of shape (0, 2).
            mask: Inlier mask from get_fundamental_matrix (may be None).
        """
        has_feats = all(d is not None and len(d) > 0 for d in (feat1, feat2))
        init_matches1 = self.matcher.knnMatch(feat1, feat2, k=2) if has_feats else []
        # The reverse search is only needed for the mutual-nearest-neighbour check.
        init_matches2 = (self.matcher.knnMatch(feat2, feat1, k=2)
                         if has_feats and cross_check else [])
        use_ratio = ratio is not None and ratio < 1

        good_matches = []
        for i, knn in enumerate(init_matches1):
            if len(knn) == 0:
                continue
            best = knn[0]
            if cross_check:
                reverse = init_matches2[best.trainIdx]
                if len(reverse) == 0 or reverse[0].trainIdx != i:
                    continue
            if use_ratio:
                # The ratio test needs a second neighbour; without one the
                # original code raised IndexError, so such queries are dropped.
                if len(knn) < 2 or best.distance > ratio * knn[1].distance:
                    continue
            good_matches.append(best)

        if good_matches and isinstance(cv_kpts1, list) and isinstance(cv_kpts2, list):
            good_kpts1 = np.array([cv_kpts1[m.queryIdx].pt for m in good_matches])
            good_kpts2 = np.array([cv_kpts2[m.trainIdx].pt for m in good_matches])
        elif good_matches and isinstance(cv_kpts1, np.ndarray) and isinstance(cv_kpts2, np.ndarray):
            good_kpts1 = np.array([cv_kpts1[m.queryIdx] for m in good_matches])
            good_kpts2 = np.array([cv_kpts2[m.trainIdx] for m in good_matches])
        else:
            # Use (0, 2) rather than (0,) so callers can np.concatenate this
            # with other (N, 2) match sets.
            good_kpts1 = np.empty((0, 2))
            good_kpts2 = np.empty((0, 2))

        # Calculate Fundamental Mask and inliers
        F, mask = get_fundamental_matrix(good_kpts1, good_kpts2, err_thld)
        return F, (good_kpts1, good_kpts2), mask

    def draw_matches(self, img1, cv_kpts1, img2, cv_kpts2, good_matches, mask, match_color=(0, 255, 0), pt_color=(0, 0, 255)):
        """Draw matches between two images with cv2.drawMatches.

        `good_matches` is passed straight to cv2.drawMatches. `mask`, when not
        None, selects which matches are drawn. ndarray keypoints are converted
        to cv2.KeyPoint with size 1.
        """
        if isinstance(cv_kpts1, np.ndarray) and isinstance(cv_kpts2, np.ndarray):
            cv_kpts1 = [cv2.KeyPoint(cv_kpts1[i][0], cv_kpts1[i][1], 1) for i in range(cv_kpts1.shape[0])]
            cv_kpts2 = [cv2.KeyPoint(cv_kpts2[i][0], cv_kpts2[i][1], 1) for i in range(cv_kpts2.shape[0])]

        # get_fundamental_matrix returns a None mask when it had too few points.
        mask_kwargs = {} if mask is None else {'matchesMask': mask.ravel().tolist()}
        display = cv2.drawMatches(img1, cv_kpts1, img2, cv_kpts2, good_matches,
                                  None,
                                  matchColor=match_color,
                                  singlePointColor=pt_color,
                                  flags=4,
                                  **mask_kwargs)
        return display

In [None]:
def FlattenMatrix(M, num_digits = 8):
    """Serialize `M` row-major as space-separated scientific-notation numbers (CSV helper)."""
    spec = f'.{num_digits}e'
    return ' '.join(format(value, spec) for value in M.flatten())

def get_fundamental_matrix(kpts1, kpts2, err_thld, confidence=0.99999, max_iters=100000):
    """Estimate a fundamental matrix from matched points with cv2.USAC_MAGSAC.

    Args:
        kpts1, kpts2: Matched point coordinates, one (x, y) row per match.
        err_thld: ransacReprojThreshold for cv2.findFundamentalMat.
        confidence: RANSAC confidence.
        max_iters: Iteration cap; lower it to increase speed at the cost of accuracy.
    Returns:
        (F, inliers): F is a 3x3 matrix. inliers is the mask returned by
        cv2.findFundamentalMat. When there are 7 or fewer matches, or
        estimation fails, a deterministic zero matrix and None are returned
        (previously a random matrix was returned).
    """
    if len(kpts1) > 7:
        F, inliers = cv2.findFundamentalMat(kpts1,
                                            kpts2,
                                            cv2.USAC_MAGSAC,
                                            ransacReprojThreshold=err_thld,
                                            confidence=confidence,
                                            maxIters=max_iters)
        if F is not None and F.shape == (3, 3):
            return F, inliers
    # Same zero fallback the submission loop uses for pairs with too few matches.
    return np.zeros((3, 3)), None

In [None]:
def get_aslfeat_fmatrix(batch_id, img_id1, img_id2, plot = False):
    """Match one test image pair with ASLFeat + FLANN and return the matched keypoints.

    Despite the name, the fundamental matrix estimated inside
    MatcherWrapper.get_matches is discarded. Only the (kpts1, kpts2) match
    tuple is returned, so it can be pooled with the other matchers later.
    `plot` is accepted for compatibility but unused.
    """
    image_fpath_1 = f'{src}/test_images/{batch_id}/{img_id1}.png'
    image_fpath_2 = f'{src}/test_images/{batch_id}/{img_id2}.png'

    # Load Test Image Pair (only the grayscale versions are fed to ASLFeat)
    _, gray_list = load_imgs([image_fpath_1, image_fpath_2])

    # Extract Local Features
    descs, kpts = extract_local_features(gray_list)

    # Feature matching: cross-check only, ratio test disabled.
    matcher = MatcherWrapper()
    _, match, _ = matcher.get_matches(descs[0], descs[1], kpts[0], kpts[1],
                                      ratio=None,
                                      cross_check=True,
                                      err_thld=0.2)

    return match

In [None]:
# ASLFeat + FLANN matches for every test pair, keyed by sample_id.
match_dict_asl = {}
# Iterate the list directly so tqdm knows the total (tqdm(enumerate(...)) does not).
for sample_id, batch_id, img_id1, img_id2 in tqdm(test_samples):
    match_dict_asl[sample_id] = get_aslfeat_fmatrix(batch_id, img_id1, img_id2)

    # Mem Cleanup
    gc.collect()

In [None]:
# Unload ASLFeat's `models` package so SuperGlue's package of the same name
# (imported in the next cell) is resolved from its own directory.
del get_model
# Drop the package and all of its submodules. Popping only 'models' would leave
# cached 'models.<x>' entries that could shadow same-named SuperGlue submodules.
for _stale in [name for name in sys.modules if name == 'models' or name.startswith('models.')]:
    sys.modules.pop(_stale, None)
sys.path.remove('../input/aslfeat')

In [None]:
import os
import csv
import random
from glob import glob
from tqdm import tqdm
from collections import namedtuple

import cv2
import numpy as np
import pandas as pd
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import torch

import sys
sys.path.append("../input/super-glue-pretrained-network")
from models.matching import Matching
from models.utils import (compute_pose_error, compute_epipolar_error,
                          estimate_pose, make_matching_plot,
                          error_colormap, AverageTimer, pose_auc, read_image,
                          rotate_intrinsics, rotate_pose_inplane,
                          scale_intrinsics)

In [None]:
import kornia
from kornia_moons.feature import *
import kornia as K
import kornia.feature as KF
import gc

# Toggle for CLAHE contrast equalization of loaded images (used by load_torch_image and the SuperGlue loop).
apply_eq: bool = False

def CLAHE_Convert(origin_input):
    """Equalize a BGR image with CLAHE on its HSV value channel.

    The image is converted BGR -> HSV. CLAHE (clip limit 2.0, 8x8 tiles) is
    applied to the last channel, and the result is converted back to BGR.
    """
    hsv = cv2.cvtColor(np.asarray(origin_input), cv2.COLOR_BGR2HSV)
    equalizer = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    hsv[:, :, -1] = equalizer.apply(hsv[:, :, -1])
    return cv2.cvtColor(hsv, cv2.COLOR_HSV2BGR)

def CLAHE_Convert2(origin_input):
    """Cast a single-channel image to uint8 and apply CLAHE (clip limit 2.0, 8x8 tiles)."""
    return cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8)).apply(origin_input.astype(np.uint8))

In [None]:
device = torch.device('cuda')

# Construct LoFTR with pretrained=None and load the outdoor checkpoint bundled with the input dataset.
loftr_state = torch.load("../input/kornia-loftr/loftr_outdoor.ckpt")['state_dict']
loftr_matcher = KF.LoFTR(pretrained=None)
loftr_matcher.load_state_dict(loftr_state)
loftr_matcher = loftr_matcher.to(device).eval()
del loftr_state

In [None]:
def load_torch_image(fname, device, max_dim=840):
    """Load an image as an RGB float tensor for LoFTR.

    The image is resized so its longer side is `max_dim` pixels. CLAHE_Convert
    is applied when the module-level `apply_eq` flag is set. Pixel values are
    divided by 255.

    Args:
        fname: Image path readable by cv2.imread.
        device: torch device to move the tensor to.
        max_dim: Target length of the longer side (840 by default).
    Returns:
        Float tensor on `device`; presumably (1, 3, H, W) because
        K.image_to_tensor is called with keepdim=False. Confirm against the
        installed kornia version.
    Raises:
        FileNotFoundError: If cv2.imread cannot read `fname`.
    """
    img = cv2.imread(fname)
    if img is None:
        # cv2.imread returns None on failure; surface a clear error instead of
        # an AttributeError on img.shape.
        raise FileNotFoundError(f'Could not read image: {fname}')

    scale = max_dim / max(img.shape[0], img.shape[1])
    w = int(img.shape[1] * scale)
    h = int(img.shape[0] * scale)
    img = cv2.resize(img, (w, h))

    if apply_eq:
        img = CLAHE_Convert(img)

    img = K.image_to_tensor(img, False).float() / 255.
    img = K.color.bgr_to_rgb(img)
    return img.to(device)

In [None]:
src = '/kaggle/input/image-matching-challenge-2022/'

# Read test.csv in one go; the first row is the header and is dropped.
with open(f'{src}/test.csv') as f:
    rows = list(csv.reader(f, delimiter=','))
test_samples = rows[1:]
test_samples_df = pd.DataFrame(test_samples, columns=["sample_id", "batch_id", "image_1_id", "image_2_id"])
test_samples_df

In [None]:
# read_image settings: longest side resized to 840, resizing done in float.
resize = [840, ]
resize_float = True

superpoint_config = {"nms_radius": 4, "keypoint_threshold": 0.005, "max_keypoints": 1024}
superglue_config = {"weights": "outdoor", "sinkhorn_iterations": 20, "match_threshold": 0.2}
config = {"superpoint": superpoint_config, "superglue": superglue_config}

sg_matcher = Matching(config)
sg_matcher = sg_matcher.eval().to(device)

In [None]:
# SuperGlue and LoFTR matches for every test pair, keyed by sample_id.
match_dict_sg = {}
match_dict_loftr = {}

# NOTE(review): both matchers run on images resized to a longest side of ~840 px,
# and the keypoints stay in that resized frame (scales_1/scales_2 are not applied).
# Confirm this is the coordinate frame expected for the submitted F.
for sample_id, batch_id, image_1_id, image_2_id in tqdm(test_samples):
    image_fpath_1 = f'{src}/test_images/{batch_id}/{image_1_id}.png'
    image_fpath_2 = f'{src}/test_images/{batch_id}/{image_2_id}.png'

    # --- SuperGlue ---
    image_1, inp_1, scales_1 = read_image(image_fpath_1, device, resize, 0, resize_float)
    image_2, inp_2, scales_2 = read_image(image_fpath_2, device, resize, 0, resize_float)

    if apply_eq:
        image_1 = CLAHE_Convert2(image_1)
        image_2 = CLAHE_Convert2(image_2)
        # Rebuild the network inputs from the equalized arrays. Previously the
        # equalization was computed but the original inp_* were matched.
        # Assumes read_image builds inputs as frame / 255. (frame2tensor) — TODO confirm.
        inp_1 = torch.from_numpy(image_1 / 255.).float()[None, None].to(device)
        inp_2 = torch.from_numpy(image_2 / 255.).float()[None, None].to(device)

    # Inference only: avoid building an autograd graph for SuperGlue.
    with torch.no_grad():
        sg_pred = sg_matcher({"image0": inp_1, "image1": inp_2})
    sg_pred = {k: v[0].detach().cpu().numpy() for k, v in sg_pred.items()}
    sg_kpts1, sg_kpts2 = sg_pred["keypoints0"], sg_pred["keypoints1"]
    sg_matches = sg_pred["matches0"]

    # matches0 holds, per keypoint in image 1, an index into keypoints1 or -1 if unmatched.
    valid = sg_matches > -1
    sg_mkpts1 = sg_kpts1[valid]
    sg_mkpts2 = sg_kpts2[sg_matches[valid]]

    match_dict_sg[sample_id] = (sg_mkpts1, sg_mkpts2)

    # --- LoFTR ---
    image_1 = load_torch_image(image_fpath_1, device)
    image_2 = load_torch_image(image_fpath_2, device)
    input_dict = {"image0": K.color.rgb_to_grayscale(image_1),
                  "image1": K.color.rgb_to_grayscale(image_2)}

    with torch.no_grad():
        correspondences = loftr_matcher(input_dict)

    loftr_mkpts1 = correspondences['keypoints0'].cpu().numpy()
    loftr_mkpts2 = correspondences['keypoints1'].cpu().numpy()

    match_dict_loftr[sample_id] = (loftr_mkpts1, loftr_mkpts2)

    



In [None]:
# Pool matches from all three matchers per pair and estimate one fundamental matrix.
F_dict = {}
for i, (sample_id, batch_id, image_1_id, image_2_id) in enumerate(tqdm(test_samples)):
    # Reshape each match set to (N, 2) so an empty set (possibly shape (0,))
    # still concatenates with the others.
    match_sets = (match_dict_sg[sample_id], match_dict_loftr[sample_id], match_dict_asl[sample_id])
    mkpts1 = np.concatenate([np.asarray(m[0]).reshape(-1, 2) for m in match_sets])
    mkpts2 = np.concatenate([np.asarray(m[1]).reshape(-1, 2) for m in match_sets])

    print(len(mkpts1), len(mkpts2))

    # Reset per sample so a skipped estimation never reuses the previous pair's results.
    F, inlier_mask = None, None
    if len(mkpts1) > 8:
        # Positional args: method, ransacReprojThreshold, confidence, maxIters.
        F, inlier_mask = cv2.findFundamentalMat(mkpts1, mkpts2, cv2.USAC_MAGSAC, 0.200, 0.9999, 250000)
    if F is None or F.shape != (3, 3):
        # Too few points or estimation failed: submit a zero matrix.
        F = np.zeros((3, 3))
    F_dict[sample_id] = F

    gc.collect()

    if i < 3 and inlier_mask is not None:
        # Report and draw this pair's matches (not the last pair from the matching loop).
        print(match_dict_loftr[sample_id][0].shape)
        print(mkpts1.shape)
        print(match_dict_sg[sample_id][0].shape)

        vis_img1 = load_torch_image(f'{src}/test_images/{batch_id}/{image_1_id}.png', device)
        vis_img2 = load_torch_image(f'{src}/test_images/{batch_id}/{image_2_id}.png', device)
        # Cast to float32 so the centers match the float32 scale/orientation tensors.
        pts1 = torch.from_numpy(mkpts1.astype(np.float32))
        pts2 = torch.from_numpy(mkpts2.astype(np.float32))

        draw_LAF_matches(
            KF.laf_from_center_scale_ori(pts1.view(1, -1, 2),
                                         torch.ones(mkpts1.shape[0]).view(1, -1, 1, 1),
                                         torch.ones(mkpts1.shape[0]).view(1, -1, 1)),
            KF.laf_from_center_scale_ori(pts2.view(1, -1, 2),
                                         torch.ones(mkpts2.shape[0]).view(1, -1, 1, 1),
                                         torch.ones(mkpts2.shape[0]).view(1, -1, 1)),
            torch.arange(mkpts1.shape[0]).view(-1, 1).repeat(1, 2),
            K.tensor_to_image(vis_img1),
            K.tensor_to_image(vis_img2),
            inlier_mask,
            draw_dict={'inlier_color': (0.2, 1, 0.2),
                       'tentative_color': None,
                       'feature_color': (0.2, 0.5, 1), 'vertical': False})

In [None]:
def FlattenMatrix(M, num_digits=8):
    """Return `M`'s entries in row-major order as scientific-notation strings joined by spaces."""
    tokens = []
    for value in M.ravel():
        tokens.append(f'{value:.{num_digits}e}')
    return ' '.join(tokens)

# Write one row per sample: sample_id and the 9 entries of F, row-major.
with open('submission.csv', 'w') as f:
    f.write('sample_id,fundamental_matrix\n')
    for sample_id, F in F_dict.items():
        # A failed estimation (None) must not abort the whole submission file.
        if F is None:
            F = np.zeros((3, 3))
        f.write(f'{sample_id},{FlattenMatrix(F)}\n')