# Install and import packages

In [1]:
# %%capture captured
%cd /kaggle/working

# General utilities
import os, sys
from tqdm import tqdm
from time import time
from fastprogress import progress_bar
import gc
import numpy as np
import h5py
from IPython.display import clear_output
from collections import defaultdict
from copy import deepcopy
from dataclasses import dataclass, asdict, field
from typing import Optional, Tuple, List
import traceback
import subprocess
from pprint import pprint
import shutil
from pathlib import Path
from pprint import pprint
from functools import lru_cache
from matplotlib import pyplot as plt
from glob import glob


# CV/ML
import cv2
import torch
import torch.nn.functional as F
import kornia as K
import kornia.feature as KF
from PIL import Image
import timm
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform
from torch import multiprocessing as mp


def internet_available(host="8.8.8.8", port=53, timeout=1):
    # https://stackoverflow.com/a/33117579
    import socket
    try:
        socket.setdefaulttimeout(timeout)
        socket.socket(socket.AF_INET, socket.SOCK_STREAM).connect((host, port))
        return True
    except socket.error as ex:
        return False

    
def run_sh(command):
    process = subprocess.Popen(command, shell=True, stdout=subprocess.PIPE)
    out, err = process.communicate()
    print(out.decode('utf-8').strip())
    
    
def install_from_input_directory(dirname):
    src_dir = f'/kaggle/input/{dirname}/'
    install_dir = os.path.realpath(f'./{dirname}/')
    assert os.path.isdir(src_dir)
    if not os.path.isdir(install_dir):
        shutil.copytree(src_dir, install_dir)

    if dirname.endswith('silk'):
        if os.path.isfile(f'{install_dir}/silk'):
            os.remove(f'{install_dir}/silk')
        if not os.path.isdir(f'{install_dir}/silk'):
            shutil.copytree(f'{install_dir}/lib', f'{install_dir}/silk')
            
    os.system(f'pip install -e {install_dir}')
    if install_dir not in sys.path:
        sys.path.append(install_dir)
    
    
def install_hloc_local_features():
    weights = {
        'openibl': {
            'src': '/kaggle/input/hloc-local-feature-weights/vgg16_netvlad.pth',
            'dst': '/root/.cache/torch/hub/checkpoints/vgg16_netvlad.pth',
        },
        'netvlad': {
            'src': '/kaggle/input/hloc-local-feature-weights/Pitts30K_struct.mat',
            'dst': '/root/.cache/torch/hub/netvlad/VGG16-NetVLAD-Pitts30K.mat'
        },
        'openibl_src': {
            'src': '/kaggle/input/openibl-zipball/yxgeee-OpenIBL-f3ef4fb',
            'dst': '/root/.cache/torch/hub/yxgeee_OpenIBL_master'
        }
    }
    for alg in weights.keys():
        print('installing', alg)
        w = weights[alg]
        if os.path.isfile(w['src']):
            os.makedirs(os.path.dirname(w['dst']), exist_ok=True)
            shutil.copyfile(w['src'], w['dst'])
        elif os.path.isdir(w['src']):
            if os.path.isdir(w['dst']):
                shutil.rmtree(w['dst'])
            shutil.copytree(w['src'], w['dst'])
            
def install_pycolmap():
    colmap_input_dir = '/kaggle/input/colmap'
    colmap_install_dir = '/kaggle/working/colmap'

    if not os.path.isdir(colmap_install_dir):
        shutil.copytree(colmap_input_dir, colmap_install_dir)
    os.system(f'''cd {colmap_install_dir}/build && \
        cmake .. -GNinja -DCMAKE_CUDA_ARCHITECTURES=native && \
        ninja && \
        ninja install
        ''')

    pycolmap_input_dir = '/kaggle/input/pycolmap'
    pycolmap_install_dir = '/kaggle/working/pycolmap'
    if not os.path.isdir(pycolmap_install_dir):
        shutil.copytree(pycolmap_input_dir, pycolmap_install_dir)
    os.system(f'pip install -e {pycolmap_install_dir}')
    if pycolmap_install_dir not in sys.path:
        sys.path.append(pycolmap_install_dir)

            
# install custom colmap build
install_pycolmap()
import pycolmap

            
# install dependencies for SiLK
!pip install --no-deps \
    /kaggle/input/omegaconf222py3/antlr4_python3_runtime-4.9.3-py3-none-any.whl \
    /kaggle/input/omegaconf222py3/omegaconf-2.2.2-py3-none-any.whl \
    /kaggle/input/hydracore120py3/hydra_core-1.2.0-py3-none-any.whl \
    /kaggle/input/loguru006py3/loguru-0.6.0-py3-none-any.whl


install_from_input_directory('facebookresearch-silk')

from silk.backbones.superpoint.vgg import ParametricVGG
from silk.backbones.silk.silk import SiLKVGG as SiLK
from silk.backbones.silk.silk import from_feature_coords_to_image_coords
from silk.config.model import load_model_from_checkpoint
from silk.models.silk import matcher


# install hloc and weights
install_from_input_directory('hierarchical-localization-my')
install_hloc_local_features()

from hloc.visualization import plot_images, read_image
from hloc.utils import viz_3d
from hloc import extract_features as hloc_extract_features
from hloc import pairs_from_retrieval as hloc_pairs_from_retrieval

/kaggle/working
-- Found installed version of Eigen: /usr/lib/cmake/eigen3
-- Found required Ceres dependency: Eigen version 3.3.7 in /usr/include/eigen3
-- Found required Ceres dependency: glog
-- Found installed version of gflags: /usr/lib/x86_64-linux-gnu/cmake/gflags
-- Detected gflags version: 2.2.2
-- Found required Ceres dependency: gflags
-- Found Ceres version: 1.14.0 installed in: /usr with components: [EigenSparse, SparseLinearAlgebraLibrary, LAPACK, SuiteSparse, CXSparse, SchurSpecializations, OpenMP, Multithreading]
-- Found Boost: /usr/lib/x86_64-linux-gnu/cmake/Boost-1.71.0/BoostConfig.cmake (found version "1.71.0") found components: program_options filesystem graph system unit_test_framework 
-- Found Eigen
--   Includes : /usr/include/eigen3
-- Found FreeImage
--   Includes : /usr/include
--   Libraries : /usr/lib/x86_64-linux-gnu/libfreeimage.so
-- Found FLANN
--   Includes : /usr/include
--   Libraries : /usr/lib/x86_64-linux-gnu/libflann.so
-- Found LZ4
--   Include



Obtaining file:///kaggle/working/pycolmap
  Preparing metadata (setup.py): started
  Preparing metadata (setup.py): finished with status 'done'
Installing collected packages: pycolmap
  Attempting uninstall: pycolmap
    Found existing installation: pycolmap 0.3.0
    Uninstalling pycolmap-0.3.0:
      Successfully uninstalled pycolmap-0.3.0
  Running setup.py develop for pycolmap
Successfully installed pycolmap-0.4.0

Processing /kaggle/input/omegaconf222py3/antlr4_python3_runtime-4.9.3-py3-none-any.whl
Processing /kaggle/input/omegaconf222py3/omegaconf-2.2.2-py3-none-any.whl
Processing /kaggle/input/hydracore120py3/hydra_core-1.2.0-py3-none-any.whl
Processing /kaggle/input/loguru006py3/loguru-0.6.0-py3-none-any.whl
Installing collected packages: hydra-core, antlr4-python3-runtime, omegaconf, loguru
Successfully installed antlr4-python3-runtime-4.9.3 hydra-core-1.2.0 loguru-0.6.0 omegaconf-2.2.2
[0m

ERROR: Could not find a version that satisfies the requirement torch==1.11.0+cu113 (from silk-keypoint-library) (from versions: 1.0.0, 1.0.1, 1.0.1.post2, 1.1.0, 1.2.0, 1.3.0, 1.3.1, 1.4.0, 1.5.0, 1.5.1, 1.6.0, 1.7.0, 1.7.1, 1.8.0, 1.8.1, 1.9.0, 1.9.1, 1.10.0, 1.10.1, 1.10.2, 1.11.0, 1.12.0, 1.12.1, 1.13.0, 1.13.1)
ERROR: No matching distribution found for torch==1.11.0+cu113


Obtaining file:///kaggle/working/facebookresearch-silk
  Preparing metadata (setup.py): started
  Preparing metadata (setup.py): finished with status 'done'
Collecting autoflake>=1.4
  Downloading autoflake-2.1.1-py3-none-any.whl (31 kB)
Collecting black==21.10b0
  Downloading black-21.10b0-py3-none-any.whl (150 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 150.1/150.1 kB 19.7 MB/s eta 0:00:00
Collecting pdoc3>=0.10
  Downloading pdoc3-0.10.0-py3-none-any.whl (135 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 135.7/135.7 kB 29.0 MB/s eta 0:00:00
Collecting omegaconf==2.2.3
  Downloading omegaconf-2.2.3-py3-none-any.whl (79 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 79.3/79.3 kB 16.3 MB/s eta 0:00:00





Obtaining file:///kaggle/working/hierarchical-localization-my
  Preparing metadata (setup.py): started
  Preparing metadata (setup.py): finished with status 'done'
Collecting gdown
  Downloading gdown-4.7.1-py3-none-any.whl (15 kB)
Installing collected packages: gdown, hloc
  Running setup.py develop for hloc
Successfully installed gdown-4.7.1 hloc-1.3

installing openibl
installing netvlad
installing openibl_src


In [2]:
@dataclass
class Option:
    test_or_train: str = 'train' # Can be train, or test
    device: object = torch.device('cuda')
    src_dirname: str = '/kaggle/input/image-matching-challenge-2023'
    input_filepath: Optional[str] = None
    output_filepath: str = 'submission.csv'
    use_all_imgs: bool = True
    debug: bool = True
    scenes_to_debug: Tuple[str] = None
    
    local_feature: str = 'LoFTR' # ['LoFTR, 'DISK', "SiLK", "DISKLoFTR"]
    matching_alg: str = 'smnn' # ['smnn', 'adalam']
    retrieval_alg: str = 'openibl' # ['exhaustive', netvlad', 'openibl', 'cosplace']
    geometric_verification_alg: str = 'magsac' # ['colmap', 'magsac']
        
    retrieval_per_img: int = 10
    exhaustive_if_less: int = 10
        
    def __init__(self):
        #self.scenes_to_debug = ('kyiv-puppet-theater', )
        #self.scenes_to_debug = ('dioscuri', )
        self.scenes_to_debug = ('bike', )

        if self.test_or_train == 'train':
            self.input_filepath = f'{self.src_dirname}/train/train_labels.csv'
        else:
            self.input_filepath = f'{self.src_dirname}/sample_submission.csv'
            self.scenes_to_debug = None
            self.debug = False
        

run_sh('cat /proc/cpuinfo | egrep -m 1 "^model name"')
run_sh('cat /proc/cpuinfo | egrep -m 1 "^cpu MHz"')
run_sh('cat /proc/cpuinfo | egrep -m 1 "^cpu cores"')
run_sh('cat /proc/meminfo | egrep "^MemTotal"')
    
print('Kornia version:', K.__version__)
print('Pycolmap version:', pycolmap.__version__)
print('Internet access:', internet_available())

opt = Option()
pprint(asdict(opt))

model name	: Intel(R) Core(TM) i9-10980XE CPU @ 3.00GHz
cpu MHz		: 1200.099
cpu cores	: 18
MemTotal:       263818860 kB
Kornia version: 0.6.11
Pycolmap version: 0.4.0
Internet access: True
{'debug': True,
 'device': device(type='cuda'),
 'exhaustive_if_less': 10,
 'geometric_verification_alg': 'magsac',
 'input_filepath': '/kaggle/input/image-matching-challenge-2023/train/train_labels.csv',
 'local_feature': 'LoFTR',
 'matching_alg': 'smnn',
 'output_filepath': 'submission.csv',
 'retrieval_alg': 'openibl',
 'retrieval_per_img': 10,
 'scenes_to_debug': ('bike',),
 'src_dirname': '/kaggle/input/image-matching-challenge-2023',
 'test_or_train': 'train',
 'use_all_imgs': True}


In [3]:
def arr_to_str(a):
    return ';'.join([str(x) for x in a.reshape(-1)])

def load_torch_image(fname, device=torch.device('cpu')):
    img = K.image_to_tensor(cv2.imread(fname), False).float() / 255.
    img = K.color.bgr_to_rgb(img.to(device))
    return img

def train_only(fn):
    def wrapped(*args, **kwargs):
        if opt.test_or_train == 'train':
            return fn(*args, **kwargs)
    return wrapped
def preprocess(img, max_long_side=640, to_gray=False, padding_required=False):
    B, C, H, W = img.shape
    long_side = max(H, W)
    scale = 1.0
    if max_long_side > 0 and long_side > max_long_side:
        img = K.geometry.resize(img, max_long_side, side='long', antialias=True)
        scale = max_long_side / long_side
        long_side = max_long_side
        *_, H, W = img.shape
    if padding_required:
        padded = torch.zeros((B, C, long_side, long_side), dtype=img.dtype, device=img.device)
        padded[..., :H, :W] = img
        img = padded
    if to_gray:
        img = K.color.rgb_to_grayscale(img)
    return img, scale

class ImageDataset(torch.utils.data.Dataset):
    def __init__(self, img_fnames, long_side=640, padding=False, to_gray=False, device=torch.device('cpu')):
        self.img_fnames = np.array(img_fnames)
        self.long_side = long_side
        self.padding = padding
        self.to_gray = to_gray
        self.device = device
        
    def __len__(self):
        return len(self.img_fnames)
    
    def __getitem__(self, idx):
        img = load_torch_image(self.img_fnames[idx], device=self.device)
        img, scale = preprocess(img, self.long_side, self.to_gray, self.padding)
        img = img.squeeze(0)
        return img, scale
    
class ImagePairDataset(torch.utils.data.Dataset):
    def __init__(self, img_ds, index_pairs):
        self.img_ds = img_ds
        self.index_pairs = np.array(index_pairs)
        
    def __len__(self):
        return len(self.index_pairs)
    
    def __getitem__(self, idx):
        idx1, idx2 = self.index_pairs[idx].tolist()
        img1, scale1 = self.img_ds[idx1]
        img2, scale2 = self.img_ds[idx2]
        return idx1, img1, scale1, idx2, img2, scale2
        
def get_image_pair_loader(img_fnames, index_pairs, long_side=640, padding=False, to_gray=False, device=torch.device('cpu')):
    image_ds = ImageDataset(img_fnames, long_side, padding, to_gray)
    pairs_ds = ImagePairDataset(image_ds, index_pairs)
    return torch.utils.data.DataLoader(pairs_ds, num_workers=os.cpu_count())

In [4]:
# We will use ViT global descriptor to get matching shortlists.
def get_global_desc(fnames, model,
                    device =  torch.device('cpu')):
    model = model.eval()
    model= model.to(device)
    config = resolve_data_config({}, model=model)
    transform = create_transform(**config)
    global_descs_convnext=[]
    for i, img_fname_full in tqdm(enumerate(fnames),total= len(fnames)):
        key = os.path.splitext(os.path.basename(img_fname_full))[0]
        img = Image.open(img_fname_full).convert('RGB')
        timg = transform(img).unsqueeze(0).to(device)
        with torch.no_grad():
            desc = model.forward_features(timg.to(device)).mean(dim=(-1,2))#
            #print (desc.shape)
            desc = desc.view(1, -1)
            desc_norm = F.normalize(desc, dim=1, p=2)
        #print (desc_norm)
        global_descs_convnext.append(desc_norm.detach().cpu())
    global_descs_all = torch.cat(global_descs_convnext, dim=0)
    return global_descs_all


def get_img_pairs_exhaustive(img_fnames):
    index_pairs = []
    for i in range(len(img_fnames)):
        for j in range(i+1, len(img_fnames)):
            index_pairs.append((i,j))
    return index_pairs

def hloc_image_pairs_shortlist(img_dir, fnames, alg='netvlad', match_per_img=20, exhaustive_if_less=20):
    if len(fnames) <= 20 or alg == 'exhaustive':
        return get_img_pairs_exhaustive(fnames)

    tmp_filepath = os.path.basename(img_dir)
    retrieval_conf = hloc_extract_features.confs[alg]
    input_path = Path(img_dir)
    tmp_path = Path(f'./{img_dir}_{alg}')
    if os.path.isdir(tmp_path):
        shutil.rmtree(tmp_path)
    global_descriptors = hloc_extract_features.main(retrieval_conf, input_path, tmp_path)
    loc_pairs = tmp_path / "loc_pairs.txt"
    hloc_pairs_from_retrieval.main(global_descriptors, loc_pairs, num_matched=match_per_img)
    img_index = {os.path.basename(f): i for i, f in enumerate(fnames)}
    pairs = []
    with open(loc_pairs, 'r') as f:
        for line in f.readlines():
            img1, img2 = line.strip().split()
            if img1 not in img_index or img2 not in img_index:
                continue
            idx1, idx2 = img_index[img1], img_index[img2]
            pairs.append(tuple(sorted([idx1, idx2])))
    return sorted(set(pairs))

def get_image_pairs_shortlist(fnames,
                              sim_th = 0.6, # should be strict
                              min_pairs = 20,
                              exhaustive_if_less = 20,
                              device=torch.device('cpu')):
    num_imgs = len(fnames)

    if num_imgs <= exhaustive_if_less:
        return get_img_pairs_exhaustive(fnames)

    model = timm.create_model('tf_efficientnet_b7',
                              checkpoint_path='/kaggle/input/tf-efficientnet/pytorch/tf-efficientnet-b7/1/tf_efficientnet_b7_ra-6c08e654.pth')
    model.eval()
    descs = get_global_desc(fnames, model, device=device)
    dm = torch.cdist(descs, descs, p=2).detach().cpu().numpy()
    # removing half
    mask = dm <= sim_th
    total = 0
    matching_list = []
    ar = np.arange(num_imgs)
    already_there_set = []
    for st_idx in range(num_imgs-1):
        mask_idx = mask[st_idx]
        to_match = ar[mask_idx]
        if len(to_match) < min_pairs:
            to_match = np.argsort(dm[st_idx])[:min_pairs]  
        for idx in to_match:
            if st_idx == idx:
                continue
            if dm[st_idx, idx] < 1000:
                matching_list.append(tuple(sorted((st_idx, idx.item()))))
                total+=1
    matching_list = sorted(list(set(matching_list)))
    return matching_list

In [5]:
# Code to manipulate a colmap database.
# Forked from https://github.com/colmap/colmap/blob/dev/scripts/python/database.py

# Copyright (c) 2018, ETH Zurich and UNC Chapel Hill.
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
#     * Redistributions of source code must retain the above copyright
#       notice, this list of conditions and the following disclaimer.
#
#     * Redistributions in binary form must reproduce the above copyright
#       notice, this list of conditions and the following disclaimer in the
#       documentation and/or other materials provided with the distribution.
#
#     * Neither the name of ETH Zurich and UNC Chapel Hill nor the names of
#       its contributors may be used to endorse or promote products derived
#       from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR CONTRIBUTORS BE
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.
#
# Author: Johannes L. Schoenberger (jsch-at-demuc-dot-de)

# This script is based on an original implementation by True Price.

import sys
import sqlite3
import numpy as np


IS_PYTHON3 = sys.version_info[0] >= 3

MAX_IMAGE_ID = 2**31 - 1

CREATE_CAMERAS_TABLE = """CREATE TABLE IF NOT EXISTS cameras (
    camera_id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL,
    model INTEGER NOT NULL,
    width INTEGER NOT NULL,
    height INTEGER NOT NULL,
    params BLOB,
    prior_focal_length INTEGER NOT NULL)"""

CREATE_DESCRIPTORS_TABLE = """CREATE TABLE IF NOT EXISTS descriptors (
    image_id INTEGER PRIMARY KEY NOT NULL,
    rows INTEGER NOT NULL,
    cols INTEGER NOT NULL,
    data BLOB,
    FOREIGN KEY(image_id) REFERENCES images(image_id) ON DELETE CASCADE)"""

CREATE_IMAGES_TABLE = """CREATE TABLE IF NOT EXISTS images (
    image_id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL,
    name TEXT NOT NULL UNIQUE,
    camera_id INTEGER NOT NULL,
    prior_qw REAL,
    prior_qx REAL,
    prior_qy REAL,
    prior_qz REAL,
    prior_tx REAL,
    prior_ty REAL,
    prior_tz REAL,
    CONSTRAINT image_id_check CHECK(image_id >= 0 and image_id < {}),
    FOREIGN KEY(camera_id) REFERENCES cameras(camera_id))
""".format(MAX_IMAGE_ID)

CREATE_TWO_VIEW_GEOMETRIES_TABLE = """
CREATE TABLE IF NOT EXISTS two_view_geometries (
    pair_id INTEGER PRIMARY KEY NOT NULL,
    rows INTEGER NOT NULL,
    cols INTEGER NOT NULL,
    data BLOB,
    config INTEGER NOT NULL,
    F BLOB,
    E BLOB,
    H BLOB)
"""

CREATE_KEYPOINTS_TABLE = """CREATE TABLE IF NOT EXISTS keypoints (
    image_id INTEGER PRIMARY KEY NOT NULL,
    rows INTEGER NOT NULL,
    cols INTEGER NOT NULL,
    data BLOB,
    FOREIGN KEY(image_id) REFERENCES images(image_id) ON DELETE CASCADE)
"""

CREATE_MATCHES_TABLE = """CREATE TABLE IF NOT EXISTS matches (
    pair_id INTEGER PRIMARY KEY NOT NULL,
    rows INTEGER NOT NULL,
    cols INTEGER NOT NULL,
    data BLOB)"""

CREATE_NAME_INDEX = \
    "CREATE UNIQUE INDEX IF NOT EXISTS index_name ON images(name)"

CREATE_ALL = "; ".join([
    CREATE_CAMERAS_TABLE,
    CREATE_IMAGES_TABLE,
    CREATE_KEYPOINTS_TABLE,
    CREATE_DESCRIPTORS_TABLE,
    CREATE_MATCHES_TABLE,
    CREATE_TWO_VIEW_GEOMETRIES_TABLE,
    CREATE_NAME_INDEX
])


def image_ids_to_pair_id(image_id1, image_id2):
    if image_id1 > image_id2:
        image_id1, image_id2 = image_id2, image_id1
    return image_id1 * MAX_IMAGE_ID + image_id2


def pair_id_to_image_ids(pair_id):
    image_id2 = pair_id % MAX_IMAGE_ID
    image_id1 = (pair_id - image_id2) // MAX_IMAGE_ID
    return image_id1, image_id2


def array_to_blob(array):
    if IS_PYTHON3:
        return array.tobytes()
    else:
        return np.getbuffer(array)


def blob_to_array(blob, dtype, shape=(-1,)):
    if IS_PYTHON3:
        return np.frombuffer(blob, dtype=dtype).reshape(*shape)
    else:
        return np.frombuffer(blob, dtype=dtype).reshape(*shape)


class COLMAPDatabase(sqlite3.Connection):

    @staticmethod
    def connect(database_path):
        return sqlite3.connect(database_path, factory=COLMAPDatabase)


    def __init__(self, *args, **kwargs):
        super(COLMAPDatabase, self).__init__(*args, **kwargs)

        self.create_tables = lambda: self.executescript(CREATE_ALL)
        self.create_cameras_table = \
            lambda: self.executescript(CREATE_CAMERAS_TABLE)
        self.create_descriptors_table = \
            lambda: self.executescript(CREATE_DESCRIPTORS_TABLE)
        self.create_images_table = \
            lambda: self.executescript(CREATE_IMAGES_TABLE)
        self.create_two_view_geometries_table = \
            lambda: self.executescript(CREATE_TWO_VIEW_GEOMETRIES_TABLE)
        self.create_keypoints_table = \
            lambda: self.executescript(CREATE_KEYPOINTS_TABLE)
        self.create_matches_table = \
            lambda: self.executescript(CREATE_MATCHES_TABLE)
        self.create_name_index = lambda: self.executescript(CREATE_NAME_INDEX)

    def add_camera(self, model, width, height, params,
                   prior_focal_length=False, camera_id=None):
        params = np.asarray(params, np.float64)
        cursor = self.execute(
            "INSERT INTO cameras VALUES (?, ?, ?, ?, ?, ?)",
            (camera_id, model, width, height, array_to_blob(params),
             prior_focal_length))
        return cursor.lastrowid

    def add_image(self, name, camera_id,
                  prior_q=np.zeros(4), prior_t=np.zeros(3), image_id=None):
        cursor = self.execute(
            "INSERT INTO images VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
            (image_id, name, camera_id, prior_q[0], prior_q[1], prior_q[2],
             prior_q[3], prior_t[0], prior_t[1], prior_t[2]))
        return cursor.lastrowid
    
    def get_image(self, image_id):
        cursor = self.execute(f"SELECT * FROM images WHERE image_id = {image_id}")
        return cursor.fetchone()                         

    def add_keypoints(self, image_id, keypoints):
        assert(len(keypoints.shape) == 2)
        assert(keypoints.shape[1] in [2, 4, 6])

        keypoints = np.asarray(keypoints, np.float32)
        self.execute(
            "INSERT INTO keypoints VALUES (?, ?, ?, ?)",
            (image_id,) + keypoints.shape + (array_to_blob(keypoints),))
    
    def get_keypoints(self, image_id):
        cursor = self.execute(f'SELECT * FROM keypoints WHERE image_id = {image_id}')
        _, rows, cols, blob = cursor.fetchone()
        keypoints = blob_to_array(blob, np.float32, (rows, cols))
        return keypoints

    def add_descriptors(self, image_id, descriptors):
        descriptors = np.ascontiguousarray(descriptors, np.uint8)
        self.execute(
            "INSERT INTO descriptors VALUES (?, ?, ?, ?)",
            (image_id,) + descriptors.shape + (array_to_blob(descriptors),))

    def add_matches(self, image_id1, image_id2, matches):
        assert(len(matches.shape) == 2)
        assert(matches.shape[1] == 2)

        if image_id1 > image_id2:
            matches = matches[:,::-1]

        pair_id = image_ids_to_pair_id(image_id1, image_id2)
        matches = np.asarray(matches, np.uint32)
        self.execute(
            "INSERT INTO matches VALUES (?, ?, ?, ?)",
            (pair_id,) + matches.shape + (array_to_blob(matches),))
        
    def get_all_matches(self):
        for row in self.execute('SELECT * FROM matches'):
            pair_id, rows, cols, matches_blob = row
            if matches_blob is None:
                continue
            idx1, idx2 = pair_id_to_image_ids(pair_id)
            matches = blob_to_array(matches_blob, np.uint32, (rows, cols))
            yield idx1, idx2, matches
    
    def get_matches(self, image_id1, image_id2):
        pair_id = image_ids_to_pair_id(image_id1, image_id2)
        cursor = self.execute(f'SELECT * FROM matches WHERE pair_id = {pair_id}')
        _, rows, cols, blob = cursor.fetchone()
        matches = blob_to_array(blob, np.uint32, (rows, cols))
        return matches
            
    def add_two_view_geometry(self, image_id1, image_id2, matches,
                              F=np.zeros((3, 3)), E=np.zeros((3, 3)), H=np.zeros((3, 3)), config=2):
        assert(len(matches.shape) == 2)
        assert(matches.shape[1] == 2)

        if image_id1 > image_id2:
            matches = matches[:,::-1]

        pair_id = image_ids_to_pair_id(image_id1, image_id2)
        matches = np.asarray(matches, np.uint32)
        F = np.asarray(F, dtype=np.float64)
        E = np.asarray(E, dtype=np.float64)
        H = np.asarray(H, dtype=np.float64)
        self.execute(
            "INSERT INTO two_view_geometries VALUES (?, ?, ?, ?, ?, ?, ?, ?)",
            (pair_id,) + matches.shape + (array_to_blob(matches), config,
             array_to_blob(F), array_to_blob(E), array_to_blob(H)))
    
    def get_two_view_geometry(self, image_id1, image_id2):
        pair_id = image_ids_to_pair_id(image_id1, image_id2)
        cursor = self.execute(f'SELECT * FROM two_view_geometries WHERE pair_id = {pair_id}')
        try:
            (_, matches_rows, matches_cols, matches_blob, config,
              F_blob, E_blob, H_blob, qvec_blob, tvec_blob) = cursor.fetchone()
        except:
            raise KeyError
        
        matches = blob_to_array(matches_blob, np.uint32, (matches_rows, matches_cols))
        F = blob_to_array(F_blob, np.float64, (3, 3))
        return matches, F, config

def draw_matches(img1, kp1, img2, kp2, matches, mask=None):
    match_img = cv2.drawMatchesKnn(
        img1, cv2.KeyPoint.convert(kp1), 
        img2, cv2.KeyPoint.convert(kp2), 
        [(cv2.DMatch(qry, obj, 0, 0), cv2.DMatch(qry, obj, 0, 0)) 
         for qry, obj in matches],
        None,
        singlePointColor=(255, 0, 0),
        matchColor=(0, 255, 0),
        matchesMask=[[i, 0] for i in mask.reshape(-1)] if mask is not None else None,
        flags=0)
    plt.figure(figsize=[12, 12])
    plt.imshow(match_img)
    plt.show()
    plt.close()
        

def geometric_verification_magsac(database_path, min_inliers=15):
    db = COLMAPDatabase.connect(database_path)
    for idx1, idx2, matches in db.get_all_matches():
        kp1 = db.get_keypoints(idx1)[matches[:, 0]]
        kp2 = db.get_keypoints(idx2)[matches[:, 1]]
        _, img1, *_ = db.get_image(idx1)
        _, img2, *_ = db.get_image(idx2)

        F, mask = cv2.findFundamentalMat(kp1, kp2, cv2.USAC_MAGSAC, 2.0, 0.99999, 1000)
        inlier_matches = np.stack([matches[i] 
                                   for i, v in enumerate(mask.flatten()) if v == 1])
        n_inliers = mask.sum()
        print(f'vefiried ({idx1}-{idx2}) {n_inliers} / {len(mask)}')
        if n_inliers < min_inliers:
            continue
            # db.add_two_view_geometry(idx1, idx2, matches=inlier_matches, F=F, config=1)
        else:
            db.add_two_view_geometry(idx1, idx2, matches=inlier_matches, F=F, config=3)
    db.commit()
    db.close()
    del db
        

In [6]:
# Code to interface DISK with Colmap.
# Forked from https://github.com/cvlab-epfl/disk/blob/37f1f7e971cea3055bb5ccfc4cf28bfd643fa339/colmap/h5_to_db.py

#  Copyright [2020] [Michał Tyszkiewicz, Pascal Fua, Eduard Trulls]
#
#   Licensed under the Apache License, Version 2.0 (the "License");
#   you may not use this file except in compliance with the License.
#   You may obtain a copy of the License at
#
#       http://www.apache.org/licenses/LICENSE-2.0
#
#   Unless required by applicable law or agreed to in writing, software
#   distributed under the License is distributed on an "AS IS" BASIS,
#   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#   See the License for the specific language governing permissions and
#   limitations under the License.

import os, argparse, h5py, warnings
import numpy as np
from tqdm import tqdm
from PIL import Image, ExifTags


def get_focal(image_path, err_on_default=False):
    image         = Image.open(image_path)
    max_size      = max(image.size)

    exif = image.getexif()
    focal = None
    if exif is not None:
        focal_35mm = None
        # https://github.com/colmap/colmap/blob/d3a29e203ab69e91eda938d6e56e1c7339d62a99/src/util/bitmap.cc#L299
        for tag, value in exif.items():
            focal_35mm = None
            if ExifTags.TAGS.get(tag, None) == 'FocalLengthIn35mmFilm':
                focal_35mm = float(value)
                break

        if focal_35mm is not None:
            focal = focal_35mm / 35. * max_size
    
    if focal is None:
        if err_on_default:
            raise RuntimeError("Failed to find focal length")

        # failed to find it in exif, use prior
        FOCAL_PRIOR = 1.2
        focal = FOCAL_PRIOR * max_size

    return focal

def create_camera(db, image_path, camera_model):
    image         = Image.open(image_path)
    width, height = image.size

    focal = get_focal(image_path)

    if camera_model == 'simple-pinhole':
        model = 0 # simple pinhole
        param_arr = np.array([focal, width / 2, height / 2])
    if camera_model == 'pinhole':
        model = 1 # pinhole
        param_arr = np.array([focal, focal, width / 2, height / 2])
    elif camera_model == 'simple-radial':
        model = 2 # simple radial
        param_arr = np.array([focal, width / 2, height / 2, 0.1])
    elif camera_model == 'opencv':
        model = 4 # opencv
        param_arr = np.array([focal, focal, width / 2, height / 2, 0., 0., 0., 0.])
         
    return db.add_camera(model, width, height, param_arr)


def add_keypoints(db, h5_path, image_path, img_ext, camera_model, single_camera = True):
    keypoint_f = h5py.File(os.path.join(h5_path, 'keypoints.h5'), 'r')

    camera_id = None
    fname_to_id = {}
    for filename in tqdm(list(keypoint_f.keys())):
        keypoints = keypoint_f[filename][..., :2]

        fname_with_ext = filename# + img_ext
        path = os.path.join(image_path, fname_with_ext)
        if not os.path.isfile(path):
            raise IOError(f'Invalid image path {path}')

        if camera_id is None or not single_camera:
            camera_id = create_camera(db, path, camera_model)
        image_id = db.add_image(fname_with_ext, camera_id)
        fname_to_id[filename] = image_id
        
        db.add_keypoints(image_id, keypoints)

    return fname_to_id

def add_matches(db, h5_path, fname_to_id):
    match_file = h5py.File(os.path.join(h5_path, 'matches.h5'), 'r')
    
    added = set()
    n_keys = len(match_file.keys())
    n_total = (n_keys * (n_keys - 1)) // 2

    with tqdm(total=n_total) as pbar:
        for key_1 in match_file.keys():
            group = match_file[key_1]
            for key_2 in group.keys():
                id_1 = fname_to_id[key_1]
                id_2 = fname_to_id[key_2]

                pair_id = image_ids_to_pair_id(id_1, id_2)
                if pair_id in added:
                    warnings.warn(f'Pair {pair_id} ({id_1}, {id_2}) already added!')
                    continue
            
                matches = group[key_2][()]
                db.add_matches(id_1, id_2, matches)

                added.add(pair_id)

                pbar.update(1)
            db.commit()
            gc.collect()

In [7]:
# Making kornia local features loading w/o internet
class KeyNetAffNetHardNet(KF.LocalFeature):
    """Convenience module, which implements KeyNet detector + AffNet + HardNet descriptor.

    .. image:: _static/img/keynet_affnet.jpg
    """

    def __init__(
        self,
        num_features: int = 5000,
        upright: bool = False,
        device = torch.device('cpu'),
        scale_laf: float = 1.0,
    ):
        ori_module = KF.PassLAF() if upright else KF.LAFOrienter(angle_detector=KF.OriNet(False)).eval()
        if not upright:
            weights = torch.load('/kaggle/input/kornia-local-feature-weights/OriNet.pth')['state_dict']
            ori_module.angle_detector.load_state_dict(weights)
        detector = KF.KeyNetDetector(
            False, num_features=num_features, ori_module=ori_module, aff_module=KF.LAFAffNetShapeEstimator(False).eval()
        ).to(device)
        kn_weights = torch.load('/kaggle/input/kornia-local-feature-weights/keynet_pytorch.pth')['state_dict']
        detector.model.load_state_dict(kn_weights)
        affnet_weights = torch.load('/kaggle/input/kornia-local-feature-weights/AffNet.pth')['state_dict']
        detector.aff.load_state_dict(affnet_weights)
        
        hardnet = KF.HardNet(False).eval()
        hn_weights = torch.load('/kaggle/input/kornia-local-feature-weights/HardNetLib.pth')['state_dict']
        hardnet.load_state_dict(hn_weights)
        descriptor = KF.LAFDescriptor(hardnet, patch_size=32, grayscale_descriptor=True).to(device)
        super().__init__(detector, descriptor, scale_laf)

In [8]:
def detect_features(img_fnames,
                    num_feats = 2048,
                    upright = False,
                    device=torch.device('cpu'),
                    feature_dir = '.featureout',
                    resize_small_edge_to = 600):
    if opt.local_feature == 'DISK':
        # Load DISK from Kaggle models so it can run when the notebook is offline.
        disk = KF.DISK().to(device)
        pretrained_dict = torch.load(
            '/kaggle/input/disk/pytorch/depth-supervision/1/loftr_outdoor.ckpt', 
            map_location=opt.device)
        disk.load_state_dict(pretrained_dict['extractor'])
        disk.eval()
    elif opt.local_feature == 'KeyNetAffNetHardNet':
        feature = KeyNetAffNetHardNet(num_feats, upright, opt.device).to(opt.device).eval()
    else:
        raise NotImplementedError
    if not os.path.isdir(feature_dir):
        os.makedirs(feature_dir)
    with h5py.File(f'{feature_dir}/lafs.h5', mode='w') as f_laf, \
         h5py.File(f'{feature_dir}/keypoints.h5', mode='w') as f_kp, \
         h5py.File(f'{feature_dir}/descriptors.h5', mode='w') as f_desc:
        for img_path in progress_bar(img_fnames):
            img_fname = img_path.split('/')[-1]
            key = img_fname
            with torch.inference_mode():
                timg = load_torch_image(img_path, device=opt.device)
                H, W = timg.shape[2:]
                if resize_small_edge_to is None:
                    timg_resized = timg
                else:
                    timg_resized = K.geometry.resize(
                        timg, resize_small_edge_to, antialias=True)
                    print(f'Resized {timg.shape} to {timg_resized.shape} (resize_small_edge_to={resize_small_edge_to})')
                h, w = timg_resized.shape[2:]
                if opt.local_feature == 'DISK':
                    features = disk(timg_resized, num_feats, pad_if_not_divisible=True)[0]
                    kps1, descs = features.keypoints, features.descriptors
                    
                    lafs = KF.laf_from_center_scale_ori(
                        kps1[None], 
                        torch.ones(1, len(kps1), 1, 1, device=opt.device))
                if opt.local_feature == 'KeyNetAffNetHardNet':
                    lafs, resps, descs = feature(K.color.rgb_to_grayscale(timg_resized))
                lafs[:,:,0,:] *= float(W) / float(w)
                lafs[:,:,1,:] *= float(H) / float(h)
                desc_dim = descs.shape[-1]
                kpts = KF.get_laf_center(lafs).reshape(-1, 2).detach().cpu().numpy()
                descs = descs.reshape(-1, desc_dim).detach().cpu().numpy()
                f_laf[key] = lafs.detach().cpu().numpy()
                f_kp[key] = kpts
                f_desc[key] = descs
    return

def detect_features_disk(img_fnames,
                    num_feats = 8192,
                    device=torch.device('cpu'),
                    feature_dir='.featureout',
                    max_long_side=1600,
                    num_octaves=4):
    # Load DISK from Kaggle models so it can run when the notebook is offline.
    disk = KF.DISK().to(device)
    pretrained_dict = torch.load(
        '/kaggle/input/disk/pytorch/depth-supervision/1/loftr_outdoor.ckpt', 
        map_location=opt.device)
    disk.load_state_dict(pretrained_dict['extractor'])
    disk.eval()
    
    img_ds = ImageDataset(img_fnames, 
                          long_side=max_long_side,
                          padding=False,
                          to_gray=False,
                          device=device)
    img_loader = torch.utils.data.DataLoader(img_ds)

    if not os.path.isdir(feature_dir):
        os.makedirs(feature_dir)
        
    with h5py.File(f'{feature_dir}/lafs.h5', mode='w') as f_laf, \
         h5py.File(f'{feature_dir}/keypoints.h5', mode='w') as f_kp, \
         h5py.File(f'{feature_dir}/descriptors.h5', mode='w') as f_desc:

        for img_path, (img, scale) in zip(tqdm(img_fnames, desc='feat_ext'), img_loader):
            img_fname = img_path.split('/')[-1]
            key = img_fname
            scale = scale.to(img.device)
            h, w = img.shape[-2:]
            num_feats_target = num_feats
            lafs_multiscale = []
            kpts_multiscale = []
            desc_multiscale = []
            with torch.inference_mode():
                for octave in range(num_octaves):
                    if octave > 0:
                        h, w = h // 2, w // 2
                        img = K.geometry.resize(img, (h, w), interpolation='area')
                        num_feats_target //= 4
                        scale /= 2.0
                    features = disk(img, num_feats_target, pad_if_not_divisible=True)[0]
                    kps1, descs = features.keypoints, features.descriptors
                    lafs = KF.laf_from_center_scale_ori(
                        kps1[None], 
                        torch.ones(1, len(kps1), 1, 1, device=opt.device) / scale)
                    lafs[:,:,0,:] /= scale
                    lafs[:,:,1,:] /= scale
                    desc_dim = descs.shape[-1]
                    kpts = KF.get_laf_center(lafs).reshape(-1, 2).detach().cpu().numpy()
                    descs = descs.reshape(-1, desc_dim).detach().cpu().numpy()
                    lafs_multiscale.append(lafs.detach().cpu().numpy())
                    kpts_multiscale.append(kpts)
                    desc_multiscale.append(descs)
            f_laf[key] = np.concatenate(lafs_multiscale, axis=1)
            f_kp[key] = np.concatenate(kpts_multiscale)
            f_desc[key] = np.concatenate(desc_multiscale)
    return

def get_unique_idxs(A, dim=0):
    # https://stackoverflow.com/questions/72001505/how-to-get-unique-elements-and-their-firstly-appeared-indices-of-a-pytorch-tenso
    unique, idx, counts = torch.unique(A, dim=dim, sorted=True, return_inverse=True, return_counts=True)
    _, ind_sorted = torch.sort(idx, stable=True)
    cum_sum = counts.cumsum(0)
    cum_sum = torch.cat((torch.tensor([0],device=cum_sum.device), cum_sum[:-1]))
    first_indices = ind_sorted[cum_sum]
    return first_indices

def match_features(img_fnames,
                   index_pairs,
                   feature_dir = '.featureout',
                   device=torch.device('cpu'),
                   min_matches=15, 
                   force_mutual = True,
                   matching_alg='smnn'
                  ):
    assert matching_alg in ['smnn', 'adalam']
    with h5py.File(f'{feature_dir}/lafs.h5', mode='r') as f_laf, \
         h5py.File(f'{feature_dir}/descriptors.h5', mode='r') as f_desc, \
        h5py.File(f'{feature_dir}/matches.h5', mode='w') as f_match:

        for pair_idx in progress_bar(index_pairs):
            idx1, idx2 = pair_idx
            fname1, fname2 = img_fnames[idx1], img_fnames[idx2]
            key1, key2 = fname1.split('/')[-1], fname2.split('/')[-1]
            lafs1 = torch.from_numpy(f_laf[key1][...]).to(device)
            lafs2 = torch.from_numpy(f_laf[key2][...]).to(device)
            desc1 = torch.from_numpy(f_desc[key1][...]).to(device)
            desc2 = torch.from_numpy(f_desc[key2][...]).to(device)
            if matching_alg == 'adalam':
                img1, img2 = cv2.imread(fname1), cv2.imread(fname2)
                hw1, hw2 = img1.shape[:2], img2.shape[:2]
                adalam_config = KF.adalam.get_adalam_default_config()
                #adalam_config['orientation_difference_threshold'] = None
                #adalam_config['scale_rate_threshold'] = None
                adalam_config['force_seed_mnn']= False
                adalam_config['search_expansion'] = 16
                adalam_config['ransac_iters'] = 128
                adalam_config['device'] = device
                dists, idxs = KF.match_adalam(desc1, desc2,
                                              lafs1, lafs2, # Adalam takes into account also geometric information
                                              hw1=hw1, hw2=hw2,
                                              config=adalam_config) # Adalam also benefits from knowing image size
            else:
                dists, idxs = KF.match_smnn(desc1, desc2, 0.98)
            if len(idxs)  == 0:
                continue
            # Force mutual nearest neighbors
            if force_mutual:
                first_indices = get_unique_idxs(idxs[:,1])
                idxs = idxs[first_indices]
                dists = dists[first_indices]
            n_matches = len(idxs)
            if True:
                print (f'{key1} - {key2}: {n_matches} matches')
            group  = f_match.require_group(key1)
            if n_matches >= min_matches:
                 group.create_dataset(key2, data=idxs.detach().cpu().numpy().reshape(-1, 2))
    return

def match_loftr(img_fnames,
                   index_pairs,
                   feature_dir = '.featureout_loftr',
                   device=torch.device('cpu'),
                   min_matches=15, 
                max_long_side=800):
    matcher = KF.LoFTR(pretrained=None)
    matcher.load_state_dict(torch.load('/kaggle/input/loftr/pytorch/outdoor/1/loftr_outdoor.ckpt')['state_dict'])
    matcher = matcher.to(device).eval()
    
    print(f'{feature_dir}/matches_loftr.h5')
    
    prev_key1, prev_key2 = None, None
    # First we do pairwise matching, and then extract "keypoints" from loftr matches.
    with h5py.File(f'{feature_dir}/matches_loftr.h5', mode='w') as f_match:
        for idx1, idx2 in progress_bar(index_pairs):
            fname1, fname2 = img_fnames[idx1], img_fnames[idx2]
            key1, key2 = fname1.split('/')[-1], fname2.split('/')[-1]
            
            if key1 != prev_key1:
                img1 = load_torch_image(fname1, device=device)
                img1, scale1 = preprocess(img1, max_long_side, True, True)
                prev_key1 = key1
                
            if key2 != prev_key2:
                img2 = load_torch_image(fname2, device=device)
                img2, scale2 = preprocess(img2, max_long_side, True, True)
                prev_key2 = key2 
                
            with torch.inference_mode():
                fwd_matches = matcher({"image0": img1,"image1": img2})
                bwd_matches = matcher({"image0": img2,"image1": img1})

            mkpts1 = torch.cat([fwd_matches['keypoints0'], bwd_matches['keypoints1']])
            mkpts1 = (mkpts1 / scale1).cpu().numpy()
            
            mkpts2 = torch.cat([fwd_matches['keypoints1'], bwd_matches['keypoints0']])
            mkpts2 = (mkpts2 / scale2).cpu().numpy()
            
            n_matches = len(mkpts1)
            group = f_match.require_group(key1)
            if n_matches >= min_matches:
                 group.create_dataset(key2, data=np.concatenate([mkpts1, mkpts2], axis=1))
            gc.collect()

    # Let's find unique loftr pixels and group them together.
    kpts = defaultdict(list)
    match_indexes = defaultdict(dict)
    total_kpts = defaultdict(int)
    with h5py.File(f'{feature_dir}/matches_loftr.h5', mode='r') as f_match:
        for k1 in f_match.keys():
            group = f_match[k1]
            for k2 in group.keys():
                matches = group[k2][...]
                total_kpts[k1]
                kpts[k1].append(matches[:, :2])
                kpts[k2].append(matches[:, 2:])
                current_match = torch.arange(len(matches)).reshape(-1, 1).repeat(1, 2)
                current_match[:, 0] += total_kpts[k1]
                current_match[:, 1] += total_kpts[k2]
                total_kpts[k1] += len(matches)
                total_kpts[k2] += len(matches)
                match_indexes[k1][k2] = current_match

    orig_kpts = deepcopy(kpts)
    for k in kpts.keys():
        kpts[k] = np.round(np.concatenate(kpts[k], axis=0))
        orig_kpts[k] = np.concatenate(orig_kpts[k], axis=0)
    
    unique_kpts = {}
    unique_match_idxs = {}
    out_match = defaultdict(dict)
    
    for k in kpts.keys():
        average_duplicated_points = True
        if average_duplicated_points:
            uniq_kps, uniq_reverse_idxs, uniq_cnts = torch.unique(
                torch.from_numpy(kpts[k]), dim=0, 
                return_inverse=True, return_counts=True)
            mean_kps = torch.zeros_like(uniq_kps, dtype=torch.float32)
            for i in range(len(kpts[k])):
                mean_kps[uniq_reverse_idxs[i]] += orig_kpts[k][i]
            mean_kps /= uniq_cnts.type(mean_kps.dtype).unsqueeze(-1)
            unique_match_idxs[k] = uniq_reverse_idxs
            unique_kpts[k] = mean_kps.numpy()
        else:
            uniq_kps, uniq_reverse_idxs = torch.unique(
                torch.from_numpy(kpts[k]),dim=0, return_inverse=True)
            unique_match_idxs[k] = uniq_reverse_idxs
            unique_kpts[k] = uniq_kps.numpy()
        
    for k1, group in match_indexes.items():
        for k2, m in group.items():
            m2 = deepcopy(m)
            m2[:,0] = unique_match_idxs[k1][m2[:,0]]
            m2[:,1] = unique_match_idxs[k2][m2[:,1]]
            mkpts = np.concatenate([unique_kpts[k1][m2[:,0]],
                                    unique_kpts[k2][m2[:,1]],
                                   ],
                                   axis=1)
            unique_idxs_current = get_unique_idxs(torch.from_numpy(mkpts), dim=0)
            m2_semiclean = m2[unique_idxs_current]
            unique_idxs_current1 = get_unique_idxs(m2_semiclean[:, 0], dim=0)
            m2_semiclean = m2_semiclean[unique_idxs_current1]
            unique_idxs_current2 = get_unique_idxs(m2_semiclean[:, 1], dim=0)
            m2_semiclean2 = m2_semiclean[unique_idxs_current2]
            out_match[k1][k2] = m2_semiclean2.numpy()
            
    with h5py.File(f'{feature_dir}/keypoints.h5', mode='w') as f_kp:
        for k, kpts1 in unique_kpts.items():
            f_kp[k] = kpts1
    
    with h5py.File(f'{feature_dir}/matches.h5', mode='w') as f_match:
        for k1, gr in out_match.items():
            group  = f_match.require_group(k1)
            for k2, match in gr.items():
                group[k2] = match
    return

def import_into_colmap(img_dir,
                       feature_dir ='.featureout',
                       database_path = 'colmap.db',
                       img_ext='.jpg'):
    db = COLMAPDatabase.connect(database_path)
    db.create_tables()
    single_camera = False
    #fname_to_id = add_keypoints(db, feature_dir, img_dir, img_ext, 'opencv', single_camera)
    fname_to_id = add_keypoints(db, feature_dir, img_dir, img_ext, 'simple-radial', single_camera)
    add_matches(
        db,
        feature_dir,
        fname_to_id,
    )

    db.commit()
    db.close()
    del db
    return fname_to_id

In [9]:
def get_matcher(local_feature_type, feature_dir):
    if local_feature_type == 'LoFTR':
        return DetectAndMatchLoFTR(
            feature_dir=feature_dir,
            device=opt.device,
            max_long_side=600)
    elif local_feature_type == 'SiLK':
        return DetectAndMatchSILK(
            feature_dir=feature_dir, 
            device=opt.device, 
            max_long_side=840)
    elif local_feature_type == 'DISKLoFTR':
        return DetectAndMatchDiskLoftrHybrid( 
            num_feats=8192,
            device=opt.device,
            feature_dir=feature_dir,
            max_long_side=1200,
            min_matches=40,
            num_octaves=4,
            matching_alg=opt.matching_alg)
    else:
        return DetectAndMatchDefault(
            num_feats=5000,
            feature_dir=feature_dir,
            device=opt.device,
            upright=True,
            resize_small_edge_to=800,
            matching_alg=opt.matching_alg,
            force_mutual=True,
            min_matches=15)
    

class DetectAndMatch:
    def __init__(self, *args, **kwargs):
        super().__init__()
        self.time_detection = 0
        self.time_matching = 0
    
    def process(self, img_fnames, index_pairs):
        raise NotImplementedError

        
class DetectAndMatchDefault(DetectAndMatch):
    def __init__(self, 
                 num_feats=2048,
                 upright=False,
                 device=torch.device('cpu'),
                 feature_dir='.featureout',
                 resize_small_edge_to=600,
                 min_matches=15, 
                 force_mutual = True,
                 matching_alg='smnn'):
        super().__init__()
        self.num_feats = num_feats
        self.upright = upright
        self.device = device
        self.feature_dir = feature_dir
        self.resize_small_edge_to = resize_small_edge_to
        self.min_matches = 15
        self.force_mutual = force_mutual
        self.matching_alg = matching_alg
    
    def process(self, img_fnames, index_pairs):
        t = time()
        detect_features(img_fnames, 
                        self.num_feats,
                        feature_dir=self.feature_dir,
                        upright=self.upright,
                        device=self.device,
                        resize_small_edge_to=self.resize_small_edge_to,
                       )
        gc.collect()
        self.time_detection = time() - t
        
        t = time()
        match_features(img_fnames, 
                       index_pairs, 
                       feature_dir=self.feature_dir, 
                       device=self.device,
                       matching_alg=self.matching_alg)
        self.time_matching = time() - t
        
        
class DetectAndMatchLoFTR(DetectAndMatch):
    def __init__(self, feature_dir, device=torch.device('cpu'), max_long_side=800):
        super().__init__()
        self.feature_dir = feature_dir
        self.device = device
        self.max_long_side = max_long_side
    
    def process(self, img_fnames, index_pairs):
        t = time()
        match_loftr(img_fnames, index_pairs,
                    feature_dir=self.feature_dir, 
                    device=self.device, 
                    max_long_side=self.max_long_side)
        self.time_matching = time() - t

        
class DetectAndMatchSILK(DetectAndMatch):     
    NMS = 0  # NMS radius, 0 = disabled
    BORDER = 0  # remove detection on border, 0 = disabled
    THRESHOLD = 1.0  # keypoint score thresholding, if # of keypoints is less than provided top-k, then will add keypoints to reach top-k value, 1.0 = disabled
    TOP_K = 10000  # minimum number of best keypoints to output, could be higher if threshold specified above has low value

    MATCHER_RATIO_THRESHOLD = 0.8
    
    DEFAULT_OUTPUTS = ("sparse_positions", "sparse_descriptors")
    SCALE_FACTOR = 1.41  # scaling of descriptor output, do not change
    CKPT_PATH = "/kaggle/input/silk-local-feature-weights/silk/coco-rgb-aug.ckpt"
    
    def __init__(self, 
                 feature_dir, 
                 multi_res = False,
                 device=torch.device('cpu'), 
                 max_long_side=800,
                 min_matches=15):
        super().__init__()
        self.feature_dir = feature_dir
        self.multi_res = multi_res
        self.device = device
        self.max_long_side = max_long_side
        self.min_matches = min_matches
        
    
    def process(self, img_fnames, index_pairs):
        t = time()
        self.detect_features(img_fnames)
        gc.collect()
        self.time_detection = time() - t
        
        t = time()
        self.match_features(img_fnames, index_pairs)
        self.time_matching = time() - t
        
    def detect_features(self, img_fnames):
        model = DetectAndMatchSILK.get_model(device=self.device)
        img_ds = ImageDataset(img_fnames, 
                              long_side=self.max_long_side,
                              padding=False,
                              to_gray=True,
                              device=self.device)
        img_loader = torch.utils.data.DataLoader(img_ds)
        if not os.path.isdir(self.feature_dir):
            os.makedirs(self.feature_dir)
        
        with h5py.File(f'{self.feature_dir}/lafs.h5', mode='w') as f_laf, \
             h5py.File(f'{self.feature_dir}/keypoints.h5', mode='w') as f_kpts, \
             h5py.File(f'{self.feature_dir}/descriptors.h5', mode='w') as f_desc:
            for img_path, (img, scale) in zip(tqdm(img_fnames, desc='feat_ext'), img_loader):
                img_fname = img_path.split('/')[-1]
                key = img_fname
                scale = scale.to(opt.device)
                with torch.inference_mode():
                    h, w = img.shape[-2:]
                    kps, descs = model(img)
                    kps = from_feature_coords_to_image_coords(model, kps)
                    kps, descs = kps[0], descs[0]
                    kps = kps / scale
                    kps = kps[..., [1, 0]]
                    
                    desc_dim = descs.shape[-1]
                    lafs = KF.laf_from_center_scale_ori(
                        kps[None], 
                        torch.ones(1, len(kps), 1, 1, device=opt.device))
                    f_laf[key] = lafs.detach().cpu().numpy()
                    f_kpts[key] = kps.detach().cpu().numpy()
                    f_desc[key] = descs.reshape(-1, desc_dim).detach().cpu().numpy()

    def match_features(self, img_fnames, idx_pairs):
        from silk.models.silk import matcher as Matcher
        matcher = Matcher(postprocessing="ratio-test", threshold=self.MATCHER_RATIO_THRESHOLD)
        with h5py.File(f'{self.feature_dir}/keypoints.h5', mode='r') as f_kpts, \
             h5py.File(f'{self.feature_dir}/descriptors.h5', mode='r') as f_desc, \
             h5py.File(f'{self.feature_dir}/matches.h5', mode='w') as f_match:
            for idx1, idx2 in tqdm(idx_pairs, desc='matching'):
                fname1, fname2 = img_fnames[idx1], img_fnames[idx2]
                key1, key2 = fname1.split('/')[-1], fname2.split('/')[-1]
                desc1 = torch.from_numpy(f_desc[key1][...]).to(self.device)
                desc2 = torch.from_numpy(f_desc[key2][...]).to(self.device)
                matches = matcher(desc1, desc2)
                n_matches = matches.shape[0]
                group  = f_match.require_group(key1)
                if n_matches >= self.min_matches:
                    group.create_dataset(
                        key2, data=matches.detach().cpu().numpy().reshape(-1, 2))

                
    @classmethod
    def get_model(cls, device=torch.device('cpu')):
        backbone = ParametricVGG(
            use_max_pooling=False,
            padding=0,
            normalization_fn=[torch.nn.BatchNorm2d(i) for i in (64, 64, 128, 128)],
        )
        # load model
        model = SiLK(
            in_channels=1,
            backbone=backbone,
            detection_threshold=cls.THRESHOLD,
            detection_top_k=cls.TOP_K,
            nms_dist=cls.NMS,
            border_dist=cls.BORDER,
            default_outputs=cls.DEFAULT_OUTPUTS,
            descriptor_scale_factor=cls.SCALE_FACTOR,
            padding=0,
        )
        model = load_model_from_checkpoint(
            model,
            checkpoint_path=cls.CKPT_PATH,
            state_dict_fn=lambda x: {k[len("_mods.model.") :]: v for k, v in x.items()},
            device=device,
            freeze=True,
            eval=True,
        )
        return model
    
class DetectAndMatchDiskLoftrHybrid(DetectAndMatch):
    def __init__(self, 
                 num_feats=2048,
                 device=torch.device('cpu'),
                 feature_dir='.featureout',
                 max_long_side=800,
                 min_matches=15,
                 num_octaves=1,
                 matching_alg="smnn"):
        super().__init__()
        self.num_feats = num_feats
        self.device = device
        self.feature_dir = feature_dir
        self.max_long_side = max_long_side
        self.min_matches = min_matches
        self.num_octaves = num_octaves
        self.matching_alg = matching_alg
    
    def process(self, img_fnames, index_pairs):
        t = time()
        detect_features_disk(img_fnames, 
                             self.num_feats,
                             feature_dir=self.feature_dir,
                             device=self.device,
                             max_long_side=self.max_long_side,
                             num_octaves=self.num_octaves)
        gc.collect()
        self.time_detection = time() - t
        
        t = time()
        match_features(img_fnames, 
                       index_pairs, 
                       feature_dir=self.feature_dir, 
                       device=self.device,
                       matching_alg=self.matching_alg)
        self.time_matching = time() - t
        
    def match_features_loftr_guided(self, img_fnames, img_pairs):
        match_features(img_fnames, 
                       index_pairs, 
                       feature_dir=self.feature_dir, 
                       device=self.device,
                       matching_alg='smnn')
        gc.collect()

        loftr = KF.LoFTR(pretrained=None)
        loftr.load_state_dict(torch.load('/kaggle/input/loftr/pytorch/outdoor/1/loftr_outdoor.ckpt')['state_dict'])
        loftr = loftr.to(device).eval()
        
        prev_key1, prev_key2 = None, None
        for idx1, idx2 in progress_bar(index_pairs):
            fname1, fname2 = img_fnames[idx1], img_fnames[idx2]
            key1, key2 = fname1.split('/')[-1], fname2.split('/')[-1]

            if key1 != prev_key1:
                img1 = load_torch_image(fname1, device=device)
                img1, scale1 = preprocess(img1, max_long_side, True, True)
                prev_key1 = key1

            if key2 != prev_key2:
                img2 = load_torch_image(fname2, device=device)
                img2, scale2 = preprocess(img2, max_long_side, True, True)
                prev_key2 = key2 

            with torch.inference_mode():
                fwd_matches = matcher({"image0": img1,"image1": img2})
                bwd_matches = matcher({"image0": img2,"image1": img1})

            mkpts1 = torch.cat([fwd_matches['keypoints0'], bwd_matches['keypoints1']])
            mkpts1 = (mkpts1 / scale1).cpu().numpy()

            mkpts2 = torch.cat([fwd_matches['keypoints1'], bwd_matches['keypoints0']])
            mkpts2 = (mkpts2 / scale2).cpu().numpy()

            n_matches = len(mkpts1)
            group  = f_match.require_group(key1)
            if n_matches >= min_matches:
                 group.create_dataset(key2, data=np.concatenate([mkpts1, mkpts2], axis=1))
            gc.collect()
            
        raise NotImplementedError

        

In [10]:
if False and opt.debug:
    img_dir, img_fnames, *_ = load_img_fnames(data_dict, "urban", "kyiv-puppet-theater")
    feature_dir = 'dbg'
    detect_and_match = DetectAndMatchSILK(feature_dir, device=opt.device, max_long_side=800)
    detect_and_match.process(img_fnames, ((0, 1), (1, 3), ))
    with h5py.File(f'{feature_dir}/keypoints.h5', mode='r') as f_kpts, \
         h5py.File(f'{feature_dir}/descriptors.h5', mode='r') as f_desc, \
         h5py.File(f'{feature_dir}/matches.h5', mode='r') as f_match:
        for key1, group in f_match.items():
            for key2, matches in group.items():
                print(key1, key2)
                kp1 = f_kpts[key1][...][..., :2]
                kp2 = f_kpts[key2][...][..., :2]
                img1 = cv2.imread(os.path.join(img_dir, key1))
                img2 = cv2.imread(os.path.join(img_dir, key2))
                draw_matches(img1, kp1, img2, kp2, matches, mask=None)
                

    print(detect_and_match.time_detection)
    print(detect_and_match.time_matching)

In [11]:
# Get data from csv.
def read_submission_file():
    data_dict = {}
    with open(opt.input_filepath, 'r') as f:
        for i, l in enumerate(f):
            # Skip header.
            if l and i > 0:
                if opt.test_or_train == 'train':
                    dataset, scene, image, _, _ = l.strip().split(',')
                else:
                    image, dataset, scene, _, _ = l.strip().split(',')

                if dataset not in data_dict:
                    data_dict[dataset] = {}
                if scene not in data_dict[dataset]:
                    data_dict[dataset][scene] = []
                data_dict[dataset][scene].append(image)
    return data_dict

      
def load_img_fnames(data_dict, dataset, scene):
    img_dir = f'{opt.src_dirname}/{opt.test_or_train}/{dataset}/{scene}/images'
    if opt.test_or_train == 'train':
        assert os.path.exists(img_dir)
    elif not os.path.exists(img_dir):
        return None, None, None, None
            
    # train directories have some extra images. Is it true also for test data?
    img_fnames_ref = [f'{opt.src_dirname}/{opt.test_or_train}/{x}' 
                  for x in data_dict[dataset][scene]]
    img_fnames_all = sorted(glob(f'{img_dir}/*'))
    if len(img_fnames_ref) < len(img_fnames_all):
        print(f'image count mismatch! {len(img_fnames_ref)} < {len(img_fnames_all)}')
    if opt.use_all_imgs:
        img_fnames = img_fnames_all
    else:
        img_fnames = img_fnames_ref
    return img_dir, img_fnames, img_fnames_ref, img_fnames_all

In [12]:
# Function to create a submission file.
def create_submission(out_results, data_dict):
    with open(opt.output_filepath, 'w') as f:
        f.write('image_path,dataset,scene,rotation_matrix,translation_vector\n')
        for dataset in data_dict:
            if dataset in out_results:
                res = out_results[dataset]
            else:
                res = {}
            for scene in data_dict[dataset]:
                if scene in res:
                    scene_res = res[scene]
                else:
                    scene_res = {"R":{}, "t":{}}
                for image in data_dict[dataset][scene]:
                    if image in scene_res:
                        print (image)
                        R = scene_res[image]['R'].reshape(-1)
                        T = scene_res[image]['t'].reshape(-1)
                    else:
                        R = np.eye(3).reshape(-1)
                        T = np.zeros((3))
                    f.write(f'{image},{dataset},{scene},{arr_to_str(R)},{arr_to_str(T)}\n')

In [13]:
@dataclass
class Debug:
    index_pairs: list = field(default_factory=list) 
    img_fnames: list = field(default_factory=list) 
    img_dirname: str = None
    recon_fnames: list = field(default_factory=list) 
    best_idx: int = -1
    database_path: str = None
    fname_to_id: dict = field(default_factory=dict)
    output_path: str = None
        
@dataclass
class Timing:
    shortlisting: float = 0
    feature_detection: float = 0
    feature_matching: float = 0
    geometric_verification: float = 0
    reconstruction: float = 0
  
def process_scene(data_dict, dataset, scene, result_queue):
    try:
        results = {}
        timing = Timing()
        dbg = Debug() if opt.debug else None

        img_dir, img_fnames, img_fnames_ref, img_fnames_all = load_img_fnames(
            data_dict, dataset, scene)
        if img_dir is None:
            return
        if dbg is not None:
            dbg.img_dirname = img_dir
                                     
        print (f"Got {len(img_fnames)} images")
        feature_dir = f'featureout/{dataset}_{scene}'
        os.makedirs(feature_dir, exist_ok=True)

        # step1. SHORTLISTING
        t = time()
        retrieval_per_img = min(opt.retrieval_per_img, max(1, len(img_fnames) - 1))
        index_pairs = hloc_image_pairs_shortlist(img_dir, img_fnames, 
                                                 alg=opt.retrieval_alg,
                                                 match_per_img=retrieval_per_img, 
                                                 exhaustive_if_less=opt.exhaustive_if_less)
        if dbg is not None:
            dbg.index_pairs = index_pairs
            dbg.img_fnames = img_fnames
        timing.shortlisting += time() - t
        print (f'{len(index_pairs)}, pairs to match, {timing.shortlisting:.4f} sec')
        gc.collect()

        # step2. FEATURE DETECTION & MATCHING
        detect_and_match = get_matcher(opt.local_feature, feature_dir)
        detect_and_match.process(img_fnames, index_pairs)
        timing.feature_detection += detect_and_match.time_detection
        timing.feature_matching += detect_and_match.time_matching
        time_detect_and_match = timing.feature_detection + timing.feature_matching 
        print(f'Features matched in  {time_detect_and_match:.4f} sec')
        gc.collect()

        # PREPARE COLMAP
        database_path = f'{feature_dir}/colmap.db'
        if dbg is not None:
            dbg.database_path = database_path
        if os.path.isfile(database_path):
            os.remove(database_path)

        fname_to_id = import_into_colmap(
            img_dir, feature_dir=feature_dir, database_path=database_path)
        output_path = f'{feature_dir}/colmap_rec_{opt.local_feature}'
        if dbg is not None:
            dbg.output_path = output_path
            dbg.fname_to_id = fname_to_id
        gc.collect()

        # step3. GEOMETRIC VERIFICATION
        t = time()
        if opt.geometric_verification_alg == 'colmap':
            pycolmap.match_exhaustive(database_path)
        elif opt.geometric_verification_alg == 'magsac':
            geometric_verification_magsac(database_path)
        timing.geometric_verification += time() - t
        print(f'Geometric Verification in  {timing.geometric_verification:.4f} sec')
        gc.collect()

        # step4. INCREMENTAL SFM
        t=time()
        mapper_options = pycolmap.IncrementalMapperOptions()
        #mapper_options.multiple_models = False
        mapper_options.min_model_size = 3
        mapper_options.ba_refine_principal_point = True
        mapper_options.extract_colors = False
        mapper_options.num_threads = os.cpu_count()

        os.makedirs(output_path, exist_ok=True)
        reconstructions = pycolmap.incremental_mapping(
            database_path=database_path, 
            image_path=img_dir, 
            output_path=output_path, 
            options=mapper_options)
        print(reconstructions)
        if not opt.debug:
            clear_output(wait=False)
        timing.reconstruction += time() - t
        print(f'Reconstruction done in  {timing.reconstruction:.4f} sec')
        gc.collect()

        # step5. GATHER RESULTS
        reconstructions = sorted(reconstructions.values(), key=lambda m: len(m.images))
        for idx, rec in enumerate(reconstructions):
            print(rec.summary())
            for k, im in rec.images.items():
                key1 = f'{dataset}/{scene}/images/{im.name}'
                results[key1] = {}
                results[key1]["R"] = im.rotmat()
                results[key1]["t"] = im.tvec
                

        print(f'Registered: {dataset} / {scene} -> {len(results)} images')
        print(f'Total: {dataset} / {scene} -> {len(img_fnames_ref)} images')
        gc.collect()
        result_queue.put([results, timing, dbg])

    except Exception as e:
        print(e)
        print(traceback.format_exc())

In [None]:
# MAIN
def process_all(data_dict):
    gc.collect()
    out_results = defaultdict(dict)
    debugs = defaultdict(dict)
    timings = defaultdict(dict)
    
    for dataset in data_dict.keys():
        if dataset not in out_results:
            out_results[dataset] = {}
        for scene in data_dict[dataset]:
            if opt.scenes_to_debug is not None and scene not in opt.scenes_to_debug:
                print(f'skipping {scene}')
                continue            
            result_queue = mp.Queue()
            p = mp.Process(target=process_scene, 
                           args=(data_dict, dataset, scene, result_queue))
            p.start()
            p.join()
            if result_queue.empty():
                print(f"failed to process {scene}")
                continue
            print(f'done processing {scene}')
            results, timing, dbg = result_queue.get()
            out_results[dataset][scene] = results
            debugs[dataset][scene] = dbg
            timings[dataset][scene] = timing
                
            create_submission(out_results, data_dict)
            gc.collect()
    pprint(timings)
    return out_results, timings, debugs

data_dict = read_submission_file()
for dataset in data_dict:
    for scene in data_dict[dataset]:
        print(f'{dataset} / {scene} -> {len(data_dict[dataset][scene])} images')
out_results, timings, debugs = process_all(data_dict)
create_submission(out_results, data_dict)
gc.collect()

urban / kyiv-puppet-theater -> 26 images
heritage / dioscuri -> 174 images
heritage / cyprus -> 30 images
heritage / wall -> 43 images
haiper / bike -> 15 images
haiper / chairs -> 16 images
haiper / fountain -> 23 images
skipping kyiv-puppet-theater
skipping dioscuri
skipping cyprus
skipping wall
Got 15 images
105, pairs to match, 0.0002 sec


NVIDIA GeForce RTX 3090 with CUDA capability sm_86 is not compatible with the current PyTorch installation.
The current PyTorch install supports CUDA capabilities sm_37 sm_60 sm_70 sm_75 compute_70 compute_75.
If you want to use the NVIDIA GeForce RTX 3090 GPU with PyTorch, please check the instructions at https://pytorch.org/get-started/locally/



In [43]:
# Evaluation metric.

@dataclass
class Camera:
    rotmat: np.array
    tvec: np.array

def quaternion_from_matrix(matrix):
    M = np.array(matrix, dtype=np.float64, copy=False)[:4, :4]
    m00 = M[0, 0]
    m01 = M[0, 1]
    m02 = M[0, 2]
    m10 = M[1, 0]
    m11 = M[1, 1]
    m12 = M[1, 2]
    m20 = M[2, 0]
    m21 = M[2, 1]
    m22 = M[2, 2]

    # Symmetric matrix K.
    K = np.array([[m00 - m11 - m22, 0.0, 0.0, 0.0],
                  [m01 + m10, m11 - m00 - m22, 0.0, 0.0],
                  [m02 + m20, m12 + m21, m22 - m00 - m11, 0.0],
                  [m21 - m12, m02 - m20, m10 - m01, m00 + m11 + m22]])
    K /= 3.0

    # Quaternion is eigenvector of K that corresponds to largest eigenvalue.
    w, V = np.linalg.eigh(K)
    q = V[[3, 0, 1, 2], np.argmax(w)]

    if q[0] < 0.0:
        np.negative(q, q)
    return q

def evaluate_R_t(R_gt, t_gt, R, t, eps=1e-15):
    t = t.flatten()
    t_gt = t_gt.flatten()

    q_gt = quaternion_from_matrix(R_gt)
    q = quaternion_from_matrix(R)
    q = q / (np.linalg.norm(q) + eps)
    q_gt = q_gt / (np.linalg.norm(q_gt) + eps)
    loss_q = np.maximum(eps, (1.0 - np.sum(q * q_gt)**2))
    err_q = np.arccos(1 - 2 * loss_q)

    GT_SCALE = np.linalg.norm(t_gt)
    t = GT_SCALE * (t / (np.linalg.norm(t) + eps))
    err_t = min(np.linalg.norm(t_gt - t), np.linalg.norm(t_gt + t))
    
    return np.degrees(err_q), err_t

def compute_dR_dT(R1, T1, R2, T2):
    '''Given absolute (R, T) pairs for two cameras, compute the relative pose difference, from the first.'''
    
    dR = np.dot(R2, R1.T)
    dT = T2 - np.dot(dR, T1)
    return dR, dT

def compute_mAA(err_q, err_t, ths_q, ths_t):
    '''Compute the mean average accuracy over a set of thresholds. Additionally returns the metric only over rotation and translation.'''

    acc, acc_q, acc_t = [], [], []
    for th_q, th_t in zip(ths_q, ths_t):
        cur_acc_q = (err_q <= th_q)
        cur_acc_t = (err_t <= th_t)
        cur_acc = cur_acc_q & cur_acc_t
        
        acc.append(cur_acc.astype(np.float32).mean())
        acc_q.append(cur_acc_q.astype(np.float32).mean())
        acc_t.append(cur_acc_t.astype(np.float32).mean())
    return np.array(acc), np.array(acc_q), np.array(acc_t)

def dict_from_csv(csv_path, has_header):
    old_format = csv_path.endswith('train_labels.csv')
    csv_dict = {}
    with open(csv_path, 'r') as f:
        for i, l in enumerate(f):
            if has_header and i == 0:
                continue
            if l:
                if old_format:
                    dataset, scene, image, R_str, T_str = l.strip().split(',')
                else:
                    image, dataset, scene, R_str, T_str = l.strip().split(',')
                R = np.fromstring(R_str.strip(), sep=';').reshape(3, 3)
                T = np.fromstring(T_str.strip(), sep=';')
                if dataset not in csv_dict:
                    csv_dict[dataset] = {}
                if scene not in csv_dict[dataset]:
                    csv_dict[dataset][scene] = {}
                csv_dict[dataset][scene][image] = Camera(rotmat=R, tvec=T)
    return csv_dict

def eval_submission(submission_csv_path, ground_truth_csv_path, rotation_thresholds_degrees_dict, translation_thresholds_meters_dict, verbose=False):
    '''Compute final metric given submission and ground truth files. Thresholds are specified per dataset.'''

    submission_dict = dict_from_csv(submission_csv_path, has_header=True)
    gt_dict = dict_from_csv(ground_truth_csv_path, has_header=True)

    # Check that all necessary keys exist in the submission file
    for dataset in gt_dict:
        assert dataset in submission_dict, f'Unknown dataset: {dataset}'
        for scene in gt_dict[dataset]:
            assert scene in submission_dict[dataset], f'Unknown scene: {dataset}->{scene}'
            for image in gt_dict[dataset][scene]:
                assert image in submission_dict[dataset][scene], f'Unknown image: {dataset}->{scene}->{image}'

    # Iterate over all the scenes
    if verbose:
        t = time()
        print('*** METRICS ***')

    metrics_per_dataset = []
    for dataset in gt_dict:
        metrics_per_scene = []
        for scene in gt_dict[dataset]:
            err_q_all = []
            err_t_all = []
            images = [camera for camera in gt_dict[dataset][scene]]
            # Process all pairs in a scene
            for i in range(len(images)):
                for j in range(i + 1, len(images)):
                    gt_i = gt_dict[dataset][scene][images[i]]
                    gt_j = gt_dict[dataset][scene][images[j]]
                    dR_gt, dT_gt = compute_dR_dT(gt_i.rotmat, gt_i.tvec, gt_j.rotmat, gt_j.tvec)

                    pred_i = submission_dict[dataset][scene][images[i]]
                    pred_j = submission_dict[dataset][scene][images[j]]
                    dR_pred, dT_pred = compute_dR_dT(pred_i.rotmat, pred_i.tvec, pred_j.rotmat, pred_j.tvec)

                    err_q, err_t = evaluate_R_t(dR_gt, dT_gt, dR_pred, dT_pred)
                    err_q_all.append(err_q)
                    err_t_all.append(err_t)

            mAA, mAA_q, mAA_t = compute_mAA(err_q=err_q_all,
                                            err_t=err_t_all,
                                            ths_q=rotation_thresholds_degrees_dict[(dataset, scene)],
                                            ths_t=translation_thresholds_meters_dict[(dataset, scene)])
            if verbose:
                print(f'{dataset} / {scene} ({len(images)} images, {len(err_q_all)} pairs) -> mAA={np.mean(mAA):.06f}, mAA_q={np.mean(mAA_q):.06f}, mAA_t={np.mean(mAA_t):.06f}')
            metrics_per_scene.append(np.mean(mAA))

        metrics_per_dataset.append(np.mean(metrics_per_scene))
        if verbose:
            print(f'{dataset} -> mAA={np.mean(metrics_per_scene):.06f}')
            print()

    if verbose:
        print(f'Final metric -> mAA={np.mean(metrics_per_dataset):.06f} (t: {time() - t} sec.)')
        print()

    return np.mean(metrics_per_dataset)

In [44]:
if opt.test_or_train == 'train':
    # Set rotation thresholds per scene.
    rotation_thresholds_degrees_dict = {
        **{('haiper', scene): np.linspace(1, 10, 10) for scene in ['bike', 'chairs', 'fountain']},
        **{('heritage', scene): np.linspace(1, 10, 10) for scene in ['cyprus', 'dioscuri']},
        **{('heritage', 'wall'): np.linspace(0.2, 10, 10)},
        **{('urban', 'kyiv-puppet-theater'): np.linspace(1, 10, 10)},
    }

    translation_thresholds_meters_dict = {
        **{('haiper', scene): np.geomspace(0.05, 0.5, 10) for scene in ['bike', 'chairs', 'fountain']},
        **{('heritage', scene): np.geomspace(0.1, 2, 10) for scene in ['cyprus', 'dioscuri']},
        **{('heritage', 'wall'): np.geomspace(0.05, 1, 10)},
        **{('urban', 'kyiv-puppet-theater'): np.geomspace(0.5, 5, 10)},
    }

    eval_submission(submission_csv_path='submission.csv',
                    ground_truth_csv_path=opt.input_filepath,
                    rotation_thresholds_degrees_dict=rotation_thresholds_degrees_dict,
                    translation_thresholds_meters_dict=translation_thresholds_meters_dict,
                    verbose=True)

*** METRICS ***
urban / kyiv-puppet-theater (26 images, 325 pairs) -> mAA=0.000000, mAA_q=0.010154, mAA_t=0.036615
urban -> mAA=0.000000

heritage / dioscuri (174 images, 15051 pairs) -> mAA=0.001787, mAA_q=0.003535, mAA_t=0.014006
heritage / cyprus (30 images, 435 pairs) -> mAA=0.000000, mAA_q=0.000460, mAA_t=0.003218
heritage / wall (43 images, 903 pairs) -> mAA=0.000111, mAA_q=0.682281, mAA_t=0.000111
heritage -> mAA=0.000633

haiper / bike (15 images, 105 pairs) -> mAA=0.000000, mAA_q=0.000000, mAA_t=0.000000
haiper / chairs (16 images, 120 pairs) -> mAA=0.000000, mAA_q=0.000000, mAA_t=0.000000
haiper / fountain (23 images, 253 pairs) -> mAA=0.000000, mAA_q=0.000000, mAA_t=0.000000
haiper -> mAA=0.000000

Final metric -> mAA=0.000211 (t: 1.8400566577911377 sec.)



In [45]:
@train_only
def visualize_all(data_dict, out_results, debugs):
    for dataset, scene_dict in data_dict.items():
        for scene in scene_dict.keys():
            if opt.scenes_to_debug is not None and scene not in opt.scenes_to_debug:
                continue
            visualize(data_dict, out_results, dataset, scene, debugs[dataset][scene])
            
def get_points(recon):
    pts = np.stack([p.xyz for p in recon.points3D.values()], axis=-1)
    return pts

def get_viewpoints(recon):
    viewpoints = []
    for img in recon.images.values():
        viewpoints.append((img.name, img.rotmat(), img.tvec))
    viewpoints.sort()
    imgs = [v[0] for v in viewpoints]
    rots = np.stack([v[1] for v in viewpoints], axis=-1)
    trans = np.stack([v[2] for v in viewpoints], axis=-1)
    return imgs, rots, trans
    

def plot_points(vis_objs, points, color='black', size=1):
    import plotly.graph_objects as go
    vis_objs.append(go.Scatter3d(
        x=points[0],
        y=points[1],
        z=points[2],
        mode='markers',
        marker=dict(
            color=color,
            size=size,
        )
    ))
    
def plot_cams(vis_objs, rots, trans, color='black'):
    import plotly.graph_objects as go
    axis_colors = ['red', 'green', 'blue']
    for axis, axis_color in enumerate(axis_colors):
        points = np.concatenate([trans, 
                                 trans + rots[:, axis, :] * 2, 
                                 np.full_like(trans, np.nan)])
        points = points.T.reshape(-1, 3).T
        vis_objs.append(go.Scatter3d(
            x=points[0],
            y=points[1],
            z=points[2],
            mode='lines',
            line=dict(
                color=axis_color,
                width=3,
            )
        ))
    plot_points(vis_objs, trans, color=color, size=2)
    
def align_recon(rec, gt_rec):
    common_fnames = []
    gt_positions = []
    for img in rec.images.values():
        for gt_img in gt_rec.images.values():
            if img.name == gt_img.name:
                common_fnames.append(img.name)
                gt_positions.append(gt_img.projection_center().reshape(3, 1))
                break
    transform = rec.align_robust(common_fnames, gt_positions, 5)
    print(align_recon, transform)
    return transform

def load_reconstructions(output_path):
    recon_fnames = sorted(glob(f'{output_path}/*'))
    reconstructions = []
    for recon_fname in recon_fnames:
        reconstructions.append(pycolmap.Reconstruction(recon_fname))
    reconstructions.sort(key=lambda m: len(m.images))
    return reconstructions
            
def visualize(data_dict, out_results, dataset, scene, dbg):
    print(dbg)

    src = f'/kaggle/input/image-matching-challenge-2023/train/{dataset}/{scene}'
    gt_rec = pycolmap.Reconstruction(f'{src}/sfm')
    fig = viz_3d.init_figure()
    viz_3d.plot_reconstruction(fig, gt_rec, points=False, color='rgba(0,100,0,0.1)', name="GT Reconstruction", cs=5)
    if dbg is not None:
        reconstructions = load_reconstructions(dbg.output_path)
        rec = reconstructions[-1]
        print(rec)
        align_recon(rec, gt_rec)
        viz_3d.plot_reconstruction(fig, rec, color='rgba(0,255,255,0.5)', name="Reconstruction", cs=5)

    fig.show()
    
    
visualize_all(data_dict, out_results, debugs)

KeyError: 'bike'

In [None]:
@train_only
def debug_all(debugs):
    for dataset in debugs.keys():
        for scene, dbg in debugs[dataset].items():
            print(f'{dataset}[{scene}]')
            dbg = debugs[dataset][scene]
            if dbg is not None:
                debug(dbg, dataset, scene)

def debug_feature_matches(db, img_dirname):
    count = 0
    ordered_index_pairs = []
    for idx1, idx2, _ in db.get_all_matches():
        try:
            verified_matches, F, match_config = db.get_two_view_geometry(idx1, idx2)
        except KeyError:
            continue
        ordered_index_pairs.append((len(verified_matches), idx1, idx2))
    ordered_index_pairs.sort(key=lambda x: x[0], reverse=True)
    print(ordered_index_pairs)

    for _, idx1, idx2 in ordered_index_pairs[::10]:
        kp1 = db.get_keypoints(idx1)
        kp2 = db.get_keypoints(idx2)
        _, img1, *_ = db.get_image(idx1)
        _, img2, *_ = db.get_image(idx2)
        img1 = cv2.imread(os.path.join(img_dirname, img1))
        img2 = cv2.imread(os.path.join(img_dirname, img2))
        print(idx1, idx2)
        
        pair_id = image_ids_to_pair_id(idx1, idx2)
        matches = db.get_matches(idx1, idx2)
        verified_matches, F, match_config = db.get_two_view_geometry(idx1, idx2)

        print('verified.shape', verified_matches.shape)
        draw_matches(img1, kp1, img2, kp2, matches, mask=None)
        draw_matches(img1, kp1, img2, kp2, verified_matches, mask=None)

def show_imgs(img_fnames, cols=3):
    for i in range(0, len(img_fnames), cols):
        plt.figure(figsize=(20, 4))
        plt.tight_layout()
        for j in range(cols):
            if i + j < len(img_fnames):
                plt.subplot(1, cols, j + 1)
                plt.imshow(Image.open(img_fnames[i + j]))
        plt.show()
        plt.close()

def debug_unaligned(dbg):
    recon = load_reconstructions(dbg.output_path)[-1]
    image_fnames = dbg.img_fnames
    
    aligned = set(img.name for img_id, img in recon.images.items())
    unaligned = [i for i in image_fnames if os.path.basename(i) not in aligned]
    show_imgs(unaligned)
    
def count_overlap(recon1, recon2):
    return recon1.find_common_reg_image_ids(recon2)
    #imgs1 = set(img.name for img_id, img in recon1.images.items())
    #imgs2 = set(img.name for img_id, img in recon2.images.items())
    #return len(imgs1 & imgs2)

def debug_recon_overlap(dbg):
    reconstructions = load_reconstructions(dbg.output_path)
    img_groups = []
    for recon in reconstructions:
        img_groups.append(set(img.name for img_id, img in recon.images.items()))
        
    for g1 in range(len(img_groups)):
        for g2 in range(len(img_groups)):
            if g1 == g2:
                continue
            print(g1, g2, img_groups[g1] & img_groups[g2])
        
def test_merge(dbg):
    reconstructions = load_reconstructions(dbg.output_path)
    while reconstructions:
        recon1 = reconstructions.pop()
        if not reconstructions:
            break
        reconstructions.sort(key=lambda r: count_overlap(recon1, r), reverse=True)
        for idx2, recon2 in enumerate(reconstructions):
            if recon1.merge(recon2, 1000.0):
                break
            else:
                print('fail!')
        else:
            break
        reconstructions.pop(idx2)
    print(len(recon1.images))
    
    

def debug(dbg, dataset, scene):
    db = COLMAPDatabase(dbg.database_path)
    #test_merge(dbg)
    debug_feature_matches(db, dbg.img_dirname)
    #debug_unaligned(dbg)
    #debug_recon_overlap(dbg)

    
debug_all(debugs)

In [None]:
!cat 'submission.csv'
