https://www.kaggle.com/code/eduardtrulls/imc-2023-submission-example

In [1]:
# Offline installs: every wheel comes from an attached Kaggle dataset (--no-index).
# NOTE(review): the einops wheel path (einops-whl) and --find-links dir (einops-v0-8-0) differ — confirm both datasets are attached.
!pip install -q /kaggle/input/einops-whl/einops-0.8.1-py3-none-any.whl --no-index --find-links /kaggle/input/einops-v0-8-0
# Installs every wheel in this dataset; --no-deps keeps pip from trying to resolve dependencies online.
!pip install /kaggle/input/imc2023-vggt-whl/* --no-deps --no-index --find-links /kaggle/input/imc2023-vggt-whl
!pip install /kaggle/input/roma-whl/roma-1.5.2.1-py3-none-any.whl --no-index --find-links /kaggle/input/roma-1.5.2.1
# !python -m pip install --no-index --find-links=/kaggle/input/pkg-check-orientation/ check_orientation==0.0.5 > /dev/null

Looking in links: /kaggle/input/imc2023-vggt-whl
Processing /kaggle/input/imc2023-vggt-whl/hydra_core-1.3.2-py3-none-any.whl
Processing /kaggle/input/imc2023-vggt-whl/lightglue-0.0-py3-none-any.whl
Processing /kaggle/input/imc2023-vggt-whl/pyceres-2.3-cp311-cp311-manylinux_2_28_x86_64.whl
Processing /kaggle/input/imc2023-vggt-whl/pycolmap-3.10.0-cp311-cp311-manylinux_2_28_x86_64.whl
Processing /kaggle/input/imc2023-vggt-whl/trimesh-4.6.10-py3-none-any.whl
Installing collected packages: hydra-core, trimesh, pycolmap, pyceres, lightglue
Successfully installed hydra-core-1.3.2 lightglue-0.0 pyceres-2.3 pycolmap-3.10.0 trimesh-4.6.10
Looking in links: /kaggle/input/roma-1.5.2.1
Processing /kaggle/input/roma-whl/roma-1.5.2.1-py3-none-any.whl
Installing collected packages: roma
Successfully installed roma-1.5.2.1


In [2]:
# General utilities
import os
from tqdm import tqdm
from time import time
from fastprogress import progress_bar
import gc
import math
import numpy as np
from IPython.display import clear_output
from collections import defaultdict
from copy import deepcopy
import matplotlib.pyplot as plt
import concurrent.futures

# CV/ML
import cv2
import torch
import torch.nn.functional as F
import kornia as K
import kornia.feature as KF
from PIL import Image
import timm
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform
from transformers import AutoImageProcessor, AutoModel

# 3D reconstruction
import pycolmap

# Record library versions in the notebook output for reproducibility.
print("Kornia version", K.__version__)
print("Pycolmap version", pycolmap.__version__)

2025-06-02 08:23:22.184253: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1748852602.366879      19 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1748852602.420607      19 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


Kornia version 0.8.1
Pycolmap version 3.10.0


In [3]:
import sys
from pathlib import Path
import torch

# Root of the MASt3R checkout: this directory is expected to directly contain the
# `mast3r/` package folder.
MAST3R_ROOT = Path("/kaggle/input/mast3r-code/mast3r")

# Register the checkout on the import search path only once, so re-running this
# cell does not keep appending duplicates to sys.path.
if not any(entry == str(MAST3R_ROOT) for entry in sys.path):
    sys.path.append(str(MAST3R_ROOT))

# With the path registered, the imports below resolve normally.

from mast3r.cloud_opt.sparse_ga import sparse_global_alignment
from mast3r.model import AsymmetricMASt3R
from mast3r.image_pairs import make_pairs
from dust3r.utils.image import load_images
from dust3r.utils.device import to_numpy
from mast3r.model import load_model
from mast3r.fast_nn import extract_correspondences_nonsym, bruteforce_reciprocal_nns
from dust3r.inference import inference  # Import here to avoid circular imports
from dust3r.utils.geometry import find_reciprocal_matches, xy_grid, geotrf  # noqa

  @torch.cuda.amp.autocast(enabled=False)




# Global Configs

In [4]:
# Mode can only be "train" or "test"; it selects the image directory and the CSV
# that drives the run. Use "test" for submission.
# MODE = "train"
MODE = "test"

# Option to change paths for local testing.
# is_local = True
is_local = False

if is_local:
    NUM_CORES = 2
    SRC = "./kaggle/input/image-matching-challenge-2023"
    MODEL_DIR = "./kaggle/input/kornia-local-feature-weights/"
    DISK_PATH = "./loftr_disk.ckpt"
    HARDNET_PT = "./kaggle/input/kornia-local-feature-weights/hardnet8v2.pt"
    # Defined in both branches so later cells never hit a NameError in local mode.
    MAST3R_PATH = "./kaggle/input/mast3r/pytorch/default/1"
else:
    NUM_CORES = 2
    SRC = "/kaggle/input/image-matching-challenge-2023"
    MODEL_DIR = "/kaggle/input/kornia-local-feature-weights/"
    DISK_PATH = "/kaggle/input/disk/pytorch/depth-supervision/1/loftr_outdoor.ckpt"
    HARDNET_PT = "/kaggle/input/hardnet8v2/hardnet8v2.pt"
    MAST3R_PATH = "/kaggle/input/mast3r/pytorch/default/1"

LOG_MESSAGE = "Final submission"
MATCHES_CAP = None

# DEBUG = True
DEBUG = False
# Submissions never run in debug mode. Apply this override BEFORE DEBUG is logged
# below, so LOG_DICT records the value actually used.
if MODE == "test":
    DEBUG = False

# DEBUG_SCENE = ["cyprus", "kyiv-puppet-theater"]
# DEBUG_SCENE = ["cyprus"]
DEBUG_SCENE = ["kyiv-puppet-theater"]
# DEBUG_SCENE = ["kyiv-puppet-theater", "cyprus", "wall", "chairs"]
# DEBUG_SCENE = ["bike"]
# DEBUG_SCENE = ["wall"]

# Longer edge limit of the input image
hardnet_res = 1600

MODEL_DICT = {
    "Keynet": {"enable": True, "resize_long_edge_to": hardnet_res, "pair_only": False},
    "GFTT": {"enable": True, "resize_long_edge_to": hardnet_res},
    "DoG": {"enable": True, "resize_long_edge_to": hardnet_res},
    "Harris": {"enable": True, "resize_long_edge_to": hardnet_res},
    "MASt3R": {"enable": True, "resize_long_edge_to": 512}  # MASt3R input: longer edge resized to 512 px
}

# Find fundamental matrix parameters
FM_PARAMS = {"ransacReprojThreshold": 5, "confidence": 0.9999, "maxIters": 50000, "removeOutliers": True}

# Drop a pair's matches if their count is below MATCH_FILTER_RATIO * (the image's best pair count).
# e.g. if img1's best pair has 10000 matches but (img1, img2) has only 99 (< 0.01 * 10000),
# the (img1, img2) matches are not selected.
MATCH_FILTER_RATIO = 0.01

# for logging
LOG_DICT = dict()
LOG_DICT["mode"] = MODE
LOG_DICT["log_message"] = LOG_MESSAGE
LOG_DICT["matches_cap"] = MATCHES_CAP
LOG_DICT["debug"] = DEBUG
LOG_DICT["debug_scene"] = DEBUG_SCENE

# Fall back to CPU when no GPU is visible; on Kaggle GPU runtimes this stays "cuda".
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(torch.cuda.is_available())


True


# Get datadict from submission file

In [5]:
import csv

# Get datadict from csv: data_dict[dataset][scene] -> list of image paths.
if MODE == "train":
    sample_path = f"{SRC}/train/train_labels.csv"
else:
    sample_path = f"{SRC}/sample_submission.csv"

data_dict = {}
with open(sample_path, "r", newline="") as f:
    reader = csv.reader(f)
    next(reader, None)  # Skip header.
    for row in reader:
        row = [field.strip() for field in row]
        # Blank / whitespace-only lines (e.g. a trailing newline) are skipped instead
        # of crashing the 5-field unpack below.
        if not any(row):
            continue
        if MODE == "train":
            dataset, scene, image, _, _ = row
        else:
            image, dataset, scene, _, _ = row
        data_dict.setdefault(dataset, {}).setdefault(scene, []).append(image)

all_scenes = []
scene_len = []
for dataset, scenes in data_dict.items():
    for scene, images in scenes.items():
        print(f"{dataset} / {scene} -> {len(images)} images")
        if DEBUG and (scene not in DEBUG_SCENE):
            continue
        all_scenes.append((dataset, scene))
        scene_len.append(len(images))

# Sort scenes by image count, LARGEST first (reverse=True); ties fall back to
# (dataset, scene) in descending order.
all_scenes = [x for _, x in sorted(zip(scene_len, all_scenes), reverse=True)]

# Print reconst order
print("Reconstruction order: ")
for scene in all_scenes:
    print(f" --{scene[0]} / {scene[1]}")

2cfa01ab573141e4 / 2fa124afd1f74f38 -> 3 images
Reconstruction order: 
 --2cfa01ab573141e4 / 2fa124afd1f74f38


# Submission Utils

In [6]:
def arr_to_str(a):
    """Flatten ``a`` and join the ``str()`` of each element with ";" separators."""
    flat = a.reshape(-1)
    return ";".join(map(str, flat))


# Function to create a submission file.
def create_submission(out_results, data_dict, mode="test"):
    """Write the pose CSV for every image listed in ``data_dict``.

    Args:
        out_results: nested mapping ``out_results[dataset][scene][image]`` ->
            ``{"R": array, "t": array}`` for each image that received a pose.
        data_dict: ``data_dict[dataset][scene]`` -> list of image names; one CSV
            row is written per listed image.
        mode: "train" writes ``submission_train.csv``; anything else writes
            ``submission.csv`` in the current directory.

    Images without a pose get the identity rotation and a zero translation.
    """
    file_name = "submission_train.csv" if mode == "train" else "submission.csv"

    with open(file_name, "w") as f:
        f.write("image_path,dataset,scene,rotation_matrix,translation_vector\n")
        for dataset, scenes in data_dict.items():
            dataset_res = out_results.get(dataset, {})
            for scene, images in scenes.items():
                # An empty mapping as the fallback (the old {"R": {}, "t": {}} would
                # mis-handle images literally named "R" or "t").
                scene_res = dataset_res.get(scene, {})
                for image in images:
                    pose = scene_res.get(image)
                    if pose is not None:
                        R = pose["R"].reshape(-1)
                        T = pose["t"].reshape(-1)
                    else:
                        R = np.eye(3).reshape(-1)
                        T = np.zeros((3))
                    f.write(
                        f"{image},{dataset},{scene},{arr_to_str(R)},{arr_to_str(T)}\n"
                    )

# Image Loading and Resize

In [7]:
def load_torch_image(fname, device=torch.device("cpu")):
    """Load an image file as a float RGB tensor scaled to [0, 1].

    Args:
        fname: image path (str or path-like), read with ``cv2.imread`` in BGR order.
        device: device the tensor is moved to before the BGR->RGB conversion.

    Returns:
        Tensor with a leading batch dim (``keepdim=False``), i.e. (1, C, H, W).

    Raises:
        FileNotFoundError: if OpenCV cannot read ``fname``.
    """
    bgr = cv2.imread(str(fname))
    if bgr is None:
        # cv2.imread reports failure by returning None rather than raising.
        raise FileNotFoundError(f"Could not read image: {fname}")
    img = K.image_to_tensor(bgr, False).float() / 255.0
    return K.color.bgr_to_rgb(img.to(device))


def resize_torch_image(
    timg, resize_long_edge_to=None, align=None, disable_enlarge=True
):
    """Resize a (B, C, H, W) tensor so its longer edge becomes ``resize_long_edge_to``.

    Args:
        timg: image batch; only its last two dims (H, W) are read.
        resize_long_edge_to: target longer-edge length; ``None`` keeps scale 1.
        align: if given (must be > 0), each output side is rounded down to a multiple of it.
        disable_enlarge: cap the scale at 1 so images are never upsampled.

    Returns:
        ``(resized, scale_h, scale_w)``. The per-axis scales are the real ratios
        after int truncation and alignment, not the nominal factor.
    """
    height, width = timg.shape[2:]

    if resize_long_edge_to is None:
        factor = 1
    else:
        factor = float(resize_long_edge_to) / float(max(height, width))
    if disable_enlarge:
        factor = min(factor, 1)

    new_h = int(height * factor)
    new_w = int(width * factor)
    if align is not None:
        assert align > 0
        new_h -= new_h % align
        new_w -= new_w % align

    resized = K.geometry.resize(timg, (new_h, new_w), antialias=True)
    return resized, new_h / height, new_w / width


def get_roi_image(timg, roi):
    """Crop a region of interest from a (B, C, H, W) image tensor.

    Args:
        timg: image batch, sliced along its last two dims.
        roi: dict with "roi_min_h", "roi_max_h", "roi_min_w", "roi_max_w";
            values are truncated with ``int`` and the max bounds are exclusive.

    Returns:
        ``(crop, min_h, min_w)``: the cropped view plus the crop's top-left offset.
    """
    min_h, max_h = int(roi["roi_min_h"]), int(roi["roi_max_h"])
    min_w, max_w = int(roi["roi_min_w"]), int(roi["roi_max_w"])
    return timg[:, :, min_h:max_h, min_w:max_w], min_h, min_w

## Rotation detection

In [8]:
# from torchvision.io import read_image as T_read_image
# from torchvision.io import ImageReadMode
# from torchvision import transforms as T
# from check_orientation.pre_trained_models import create_model

# def convert_rot_k(index):
#     if index == 0:
#         return 0
#     elif index == 1:
#         return 3
#     elif index == 2:
#         return 2
#     else:
#         return 1

# class CheckRotationDataset(Dataset):
#     def __init__(self, files, transform=None):
#         self.transform = transform
#         self.files = files

#     def __len__(self):
#         return len(self.files)

#     def __getitem__(self, idx):
#         imgPath = self.files[idx]
#         image = T_read_image(imgPath, mode=ImageReadMode.RGB)
#         if self.transform:
#             image = self.transform(image)
#         return image

# def get_CheckRotation_dataloader(images, batch_size=1):
#     transform = T.Compose([
#         T.Resize((224, 224)),
#         T.ConvertImageDtype(torch.float),
#         T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
#     ])

#     dataset = CheckRotationDataset(images, transform=transform)
#     dataloader = DataLoader(
#         dataset=dataset,
#         shuffle=False,
#         batch_size=batch_size,
#         pin_memory=True,
#         num_workers=2,
#         drop_last=False
#     )
#     return dataloader

# def exec_rotation_detection(img_files, device):
#     model = create_model("swsl_resnext50_32x4d")
#     model.eval().to(device);
    
#     dataloader = get_CheckRotation_dataloader(img_files)
    
#     rots = []
#     for idx, image in enumerate(dataloader):
#         image = image.to(torch.float32).to(device)
#         with torch.no_grad():
#             prediction = model(image).detach().cpu().numpy()
#             detected_rot = prediction[0].argmax()
#             rot_k = convert_rot_k(detected_rot)
#             rots.append(rot_k)
#             print(f"{os.path.basename(img_files[idx])} > rot_k={rot_k}")
#     return rots

# Visualization Utils

In [9]:
# Visualization block
def draw_keypoints(img, keypoints, color=(0, 255, 0), radius=4):
    """Draw each keypoint as a filled circle on ``img`` (modified in place).

    Args:
        img: image array passed to ``cv2.circle``.
        keypoints: iterable of (x, y) pixel coordinates; truncated to int.
        color: circle colour in the image's channel order.
        radius: circle radius in pixels (default 4).
    """
    for x, y in keypoints:
        cv2.circle(img, (int(x), int(y)), color=color, radius=radius, thickness=-1)


def draw_roi(img, roi, color=(0, 255, 255)):
    """Outline the ROI on ``img`` in place with a 2-px rectangle.

    ``roi`` supplies "roi_min_w"/"roi_min_h" (top-left corner) and
    "roi_max_w"/"roi_max_h" (bottom-right corner); values are truncated to int.
    """
    top_left = (int(roi["roi_min_w"]), int(roi["roi_min_h"]))
    bottom_right = (int(roi["roi_max_w"]), int(roi["roi_max_h"]))
    cv2.rectangle(img, top_left, bottom_right, color=color, thickness=2)


def plot_images_with_keypoints(fname1, fname2, kpts1, kpts2, matches, rois=None, draw_kpts=False):
    """Show two images side by side with a red line for every match.

    Args:
        fname1, fname2: image paths, read with ``cv2.imread``.
        kpts1, kpts2: per-image keypoints indexed by match index; each entry is an
            (x, y) pair multiplied by the display scale (presumably (N, 2) numpy
            arrays — confirm with callers).
        matches: sequence of (idx1, idx2) pairs indexing kpts1 / kpts2.
        rois: optional pair of ROI dicts outlined on each image.
        draw_kpts: also draw every keypoint as a dot (off by default).

    Raises:
        FileNotFoundError: if either image cannot be read.
    """
    print(fname1, fname2)
    image1 = cv2.imread(fname1)
    image2 = cv2.imread(fname2)
    for fname, image in ((fname1, image1), (fname2, image2)):
        if image is None:
            raise FileNotFoundError(f"Could not read image: {fname}")
    print(image1.shape, image2.shape)

    if draw_kpts:
        draw_keypoints(image1, kpts1)
        draw_keypoints(image2, kpts2)
    if rois is not None:
        draw_roi(image1, rois[0])
        draw_roi(image2, rois[1])
    print("Number of matches:", len(matches))
    print("Number of keypoints:", len(kpts1), len(kpts2))
    if len(matches) > 0:
        # Print one sample match; guarded because indexing an empty sequence raises.
        print(matches[0])

    # Resize both images to a common display height; cv2.hconcat needs equal heights.
    display_h = 840
    h1, w1 = image1.shape[:2]
    h2, w2 = image2.shape[:2]
    scale1 = display_h / h1
    scale2 = display_h / h2
    image1 = cv2.resize(image1, (int(w1 * scale1), display_h))
    image2 = cv2.resize(image2, (int(w2 * scale2), display_h))

    concatenated_img = cv2.hconcat([image1, image2])

    # Right-image points are shifted by the resized width of the left image.
    x_offset = image1.shape[1]
    for match in matches:
        (x1, y1) = kpts1[match[0]] * scale1
        (x2, y2) = kpts2[match[1]] * scale2
        pt1 = (int(x1), int(y1))
        pt2 = (int(x2) + x_offset, int(y2))
        cv2.line(concatenated_img, pt1, pt2, (0, 0, 255), 2)

    plt.figure(figsize=(20, 12))
    plt.imshow(cv2.cvtColor(concatenated_img, cv2.COLOR_BGR2RGB))
    plt.axis("off")
    plt.show()

# Colmap database

In [10]:
# Code to manipulate a colmap database.
# Forked from https://github.com/colmap/colmap/blob/dev/scripts/python/database.py

# Copyright (c) 2018, ETH Zurich and UNC Chapel Hill.
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
#     * Redistributions of source code must retain the above copyright
#       notice, this list of conditions and the following disclaimer.
#
#     * Redistributions in binary form must reproduce the above copyright
#       notice, this list of conditions and the following disclaimer in the
#       documentation and/or other materials provided with the distribution.
#
#     * Neither the name of ETH Zurich and UNC Chapel Hill nor the names of
#       its contributors may be used to endorse or promote products derived
#       from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR CONTRIBUTORS BE
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.
#
# Author: Johannes L. Schoenberger (jsch-at-demuc-dot-de)

# This script is based on an original implementation by True Price.

import sys
import sqlite3
import numpy as np


# True on Python 3; array_to_blob / blob_to_array branch on it.
IS_PYTHON3 = sys.version_info[0] >= 3

# Image ids must satisfy 0 <= image_id < MAX_IMAGE_ID (CHECK constraint on the
# images table); the same constant is the multiplier that packs two ids into a pair_id.
MAX_IMAGE_ID = 2**31 - 1

# One row per camera: model id, image size, intrinsics blob, and the
# prior_focal_length flag.
CREATE_CAMERAS_TABLE = """CREATE TABLE IF NOT EXISTS cameras (
    camera_id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL,
    model INTEGER NOT NULL,
    width INTEGER NOT NULL,
    height INTEGER NOT NULL,
    params BLOB,
    prior_focal_length INTEGER NOT NULL)"""

# Per-image descriptor matrix: raw blob plus its (rows, cols) shape.
CREATE_DESCRIPTORS_TABLE = """CREATE TABLE IF NOT EXISTS descriptors (
    image_id INTEGER PRIMARY KEY NOT NULL,
    rows INTEGER NOT NULL,
    cols INTEGER NOT NULL,
    data BLOB,
    FOREIGN KEY(image_id) REFERENCES images(image_id) ON DELETE CASCADE)"""

# One row per image: unique name, owning camera, and optional pose prior
# (prior_q* / prior_t* columns). The id upper bound is filled in from MAX_IMAGE_ID.
CREATE_IMAGES_TABLE = """CREATE TABLE IF NOT EXISTS images (
    image_id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL,
    name TEXT NOT NULL UNIQUE,
    camera_id INTEGER NOT NULL,
    prior_qw REAL,
    prior_qx REAL,
    prior_qy REAL,
    prior_qz REAL,
    prior_tx REAL,
    prior_ty REAL,
    prior_tz REAL,
    CONSTRAINT image_id_check CHECK(image_id >= 0 and image_id < {}),
    FOREIGN KEY(camera_id) REFERENCES cameras(camera_id))
""".format(
    MAX_IMAGE_ID
)

# Per image pair (keyed by pair_id): match-index blob, a configuration code and
# F / E / H matrix blobs.
CREATE_TWO_VIEW_GEOMETRIES_TABLE = """
CREATE TABLE IF NOT EXISTS two_view_geometries (
    pair_id INTEGER PRIMARY KEY NOT NULL,
    rows INTEGER NOT NULL,
    cols INTEGER NOT NULL,
    data BLOB,
    config INTEGER NOT NULL,
    F BLOB,
    E BLOB,
    H BLOB)
"""

# Per-image keypoint matrix: raw blob plus its (rows, cols) shape.
CREATE_KEYPOINTS_TABLE = """CREATE TABLE IF NOT EXISTS keypoints (
    image_id INTEGER PRIMARY KEY NOT NULL,
    rows INTEGER NOT NULL,
    cols INTEGER NOT NULL,
    data BLOB,
    FOREIGN KEY(image_id) REFERENCES images(image_id) ON DELETE CASCADE)
"""

# Raw (unverified) matches per image pair, keyed by pair_id.
CREATE_MATCHES_TABLE = """CREATE TABLE IF NOT EXISTS matches (
    pair_id INTEGER PRIMARY KEY NOT NULL,
    rows INTEGER NOT NULL,
    cols INTEGER NOT NULL,
    data BLOB)"""

# Unique index on image names.
CREATE_NAME_INDEX = "CREATE UNIQUE INDEX IF NOT EXISTS index_name ON images(name)"

# Every schema statement joined into one script, so a single executescript() call
# creates the whole database.
CREATE_ALL = "; ".join(
    [
        CREATE_CAMERAS_TABLE,
        CREATE_IMAGES_TABLE,
        CREATE_KEYPOINTS_TABLE,
        CREATE_DESCRIPTORS_TABLE,
        CREATE_MATCHES_TABLE,
        CREATE_TWO_VIEW_GEOMETRIES_TABLE,
        CREATE_NAME_INDEX,
    ]
)


def image_ids_to_pair_id(image_id1, image_id2):
    """Pack an unordered pair of image ids into a single integer key.

    The smaller id is multiplied by MAX_IMAGE_ID and the larger one added, so
    (a, b) and (b, a) produce the same pair_id.
    """
    low, high = sorted((image_id1, image_id2))
    return low * MAX_IMAGE_ID + high


def pair_id_to_image_ids(pair_id):
    """Unpack a pair_id into ``(image_id1, image_id2)``; inverse of image_ids_to_pair_id.

    Both ids are returned as integers. For pair_ids built by image_ids_to_pair_id,
    image_id1 <= image_id2.
    """
    # divmod gives integer quotient and remainder; the previous true division `/`
    # returned image_id1 as a float.
    image_id1, image_id2 = divmod(pair_id, MAX_IMAGE_ID)
    return image_id1, image_id2


def array_to_blob(array):
    """Serialize an ndarray's data (C order, no dtype/shape header) to ``bytes`` for a BLOB column.

    Readers must already know the dtype and shape to decode it (see blob_to_array).
    """
    # tobytes() replaces ndarray.tostring(), which is deprecated since NumPy 1.19.
    # The Python 2 np.getbuffer branch was dead: this notebook requires Python 3.
    return array.tobytes()


def blob_to_array(blob, dtype, shape=(-1,)):
    """Decode a raw BLOB (as written by array_to_blob) back into an ndarray.

    Args:
        blob: bytes-like buffer.
        dtype: element dtype the data was written with.
        shape: shape passed to ``reshape`` (default: flat 1-D).

    Returns:
        A writable array owning its own copy of the data.
    """
    # np.fromstring's binary mode is deprecated; frombuffer replaces it but returns a
    # read-only view of `blob`, so copy to keep the previous writable result.
    return np.frombuffer(blob, dtype=dtype).reshape(*shape).copy()


class COLMAPDatabase(sqlite3.Connection):
    """``sqlite3.Connection`` subclass with helpers to build a COLMAP database.

    Create instances with ``COLMAPDatabase.connect(path)``. The ``add_*`` methods only
    execute INSERTs; call ``commit()`` to persist them.
    """

    @staticmethod
    def connect(database_path):
        """Open (or create) the SQLite file at ``database_path`` as a COLMAPDatabase."""
        return sqlite3.connect(database_path, factory=COLMAPDatabase)

    # Schema creation; every script uses CREATE ... IF NOT EXISTS.
    def create_tables(self):
        return self.executescript(CREATE_ALL)

    def create_cameras_table(self):
        return self.executescript(CREATE_CAMERAS_TABLE)

    def create_descriptors_table(self):
        return self.executescript(CREATE_DESCRIPTORS_TABLE)

    def create_images_table(self):
        return self.executescript(CREATE_IMAGES_TABLE)

    def create_two_view_geometries_table(self):
        return self.executescript(CREATE_TWO_VIEW_GEOMETRIES_TABLE)

    def create_keypoints_table(self):
        return self.executescript(CREATE_KEYPOINTS_TABLE)

    def create_matches_table(self):
        return self.executescript(CREATE_MATCHES_TABLE)

    def create_name_index(self):
        return self.executescript(CREATE_NAME_INDEX)

    def add_camera(
        self, model, width, height, params, prior_focal_length=False, camera_id=None
    ):
        """Insert a camera row and return its camera_id.

        Args:
            model: integer camera-model id (e.g. 2 = simple radial in create_camera).
            width, height: image size in pixels.
            params: intrinsics, stored as a float64 blob.
            prior_focal_length: truthy when the focal length in ``params`` is a known prior.
            camera_id: explicit id, or None to let SQLite assign one.
        """
        params = np.asarray(params, np.float64)
        cursor = self.execute(
            "INSERT INTO cameras VALUES (?, ?, ?, ?, ?, ?)",
            (camera_id, model, width, height, array_to_blob(params), prior_focal_length),
        )
        return cursor.lastrowid

    def add_image(
        self, name, camera_id, prior_q=None, prior_t=None, image_id=None
    ):
        """Insert an image row and return its image_id.

        Args:
            name: image name (unique in the table).
            camera_id: id returned by ``add_camera``.
            prior_q: optional 4 values for prior_qw..prior_qz; zeros when None.
            prior_t: optional 3 values for prior_tx..prior_tz; zeros when None.
            image_id: explicit id, or None to let SQLite assign one.
        """
        # None defaults avoid sharing one ndarray instance across calls.
        if prior_q is None:
            prior_q = np.zeros(4)
        if prior_t is None:
            prior_t = np.zeros(3)
        # float() makes any numeric scalar type (e.g. np.float32) bindable by sqlite3.
        q = [float(prior_q[i]) for i in range(4)]
        t = [float(prior_t[i]) for i in range(3)]
        cursor = self.execute(
            "INSERT INTO images VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
            (image_id, name, camera_id, *q, *t),
        )
        return cursor.lastrowid

    def add_keypoints(self, image_id, keypoints):
        """Store an (N, 2|4|6) keypoint matrix for ``image_id`` as float32."""
        # Convert first so lists / CPU tensors are accepted, then validate the shape.
        keypoints = np.asarray(keypoints, np.float32)
        assert keypoints.ndim == 2
        assert keypoints.shape[1] in [2, 4, 6]
        self.execute(
            "INSERT INTO keypoints VALUES (?, ?, ?, ?)",
            (image_id,) + keypoints.shape + (array_to_blob(keypoints),),
        )

    def add_descriptors(self, image_id, descriptors):
        """Store a descriptor matrix for ``image_id`` as contiguous uint8."""
        descriptors = np.ascontiguousarray(descriptors, np.uint8)
        self.execute(
            "INSERT INTO descriptors VALUES (?, ?, ?, ?)",
            (image_id,) + descriptors.shape + (array_to_blob(descriptors),),
        )

    @staticmethod
    def _pair_id_and_matches(image_id1, image_id2, matches):
        """Validate an (M, 2) index array and orient it to (smaller id, larger id) order.

        Returns ``(pair_id, matches_as_uint32)``.
        """
        matches = np.asarray(matches)
        assert matches.ndim == 2
        assert matches.shape[1] == 2
        # pair_id always puts the smaller image id first, so column 0 must refer to it.
        if image_id1 > image_id2:
            matches = matches[:, ::-1]
        pair_id = image_ids_to_pair_id(image_id1, image_id2)
        return pair_id, np.asarray(matches, np.uint32)

    def add_matches(self, image_id1, image_id2, matches):
        """Store raw (M, 2) keypoint-index matches between two images."""
        pair_id, matches = self._pair_id_and_matches(image_id1, image_id2, matches)
        self.execute(
            "INSERT INTO matches VALUES (?, ?, ?, ?)",
            (pair_id,) + matches.shape + (array_to_blob(matches),),
        )

    def add_two_view_geometry(
        self,
        image_id1,
        image_id2,
        matches,
        F=None,
        E=None,
        H=None,
        config=2,
    ):
        """Store verified matches plus F / E / H (identity when None) for an image pair.

        ``config`` is written as-is to the ``config`` column
        (NOTE(review): 2 is presumably COLMAP's CALIBRATED configuration — confirm).
        """
        pair_id, matches = self._pair_id_and_matches(image_id1, image_id2, matches)
        F = np.asarray(np.eye(3) if F is None else F, dtype=np.float64)
        E = np.asarray(np.eye(3) if E is None else E, dtype=np.float64)
        H = np.asarray(np.eye(3) if H is None else H, dtype=np.float64)
        self.execute(
            "INSERT INTO two_view_geometries VALUES (?, ?, ?, ?, ?, ?, ?, ?)",
            (pair_id,)
            + matches.shape
            + (
                array_to_blob(matches),
                config,
                array_to_blob(F),
                array_to_blob(E),
                array_to_blob(H),
            ),
        )

# DB operation

In [11]:
# Modified from https://github.com/cvlab-epfl/disk/blob/37f1f7e971cea3055bb5ccfc4cf28bfd643fa339/colmap/h5_to_db.py

#  Copyright [2020] [Michał Tyszkiewicz, Pascal Fua, Eduard Trulls]
#
#   Licensed under the Apache License, Version 2.0 (the "License");
#   you may not use this file except in compliance with the License.
#   You may obtain a copy of the License at
#
#       http://www.apache.org/licenses/LICENSE-2.0
#
#   Unless required by applicable law or agreed to in writing, software
#   distributed under the License is distributed on an "AS IS" BASIS,
#   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#   See the License for the specific language governing permissions and
#   limitations under the License.

import os, argparse, h5py, warnings
import numpy as np
from tqdm import tqdm
from PIL import Image, ExifTags


def get_focal(image_path, err_on_default=False):
    """Estimate an image's focal length in pixels.

    Uses EXIF ``FocalLengthIn35mmFilm`` (the Exif sub-IFD is merged in first) scaled
    by the image's longer side, following COLMAP's bitmap.cc heuristic; otherwise
    falls back to ``1.2 * max(width, height)``.

    Args:
        image_path: path to the image file.
        err_on_default: raise instead of using the fallback prior.

    Returns:
        ``(focal, is_from_exif)``: focal length in pixels and whether it came from EXIF.

    Raises:
        RuntimeError: if ``err_on_default`` is set and no usable EXIF focal exists.
    """
    # The context manager closes the file handle that Image.open keeps open lazily.
    # EXIF items are materialized before the file is closed.
    with Image.open(image_path) as image:
        max_size = max(image.size)
        exif = image.getexif()
        # Modified to merge the Exif sub-IFD (tag 0x8769) into the EXIF dict.
        exif.update(exif.get_ifd(0x8769))
        exif_items = list(exif.items())

    focal = None
    is_from_exif = False
    # https://github.com/colmap/colmap/blob/d3a29e203ab69e91eda938d6e56e1c7339d62a99/src/util/bitmap.cc#L299
    for tag, value in exif_items:
        if ExifTags.TAGS.get(tag, None) == "FocalLengthIn35mmFilm":
            focal_35mm = float(value)
            # 0 means "unknown" in EXIF; treat it as missing rather than returning a
            # zero focal length (COLMAP also requires a positive value).
            if focal_35mm > 0:
                focal = focal_35mm / 35.0 * max_size
                is_from_exif = True
            break

    if focal is None:
        if err_on_default:
            raise RuntimeError("Failed to find focal length")

        # failed to find it in exif, use prior
        FOCAL_PRIOR = 1.2
        focal = FOCAL_PRIOR * max_size

    return focal, is_from_exif


def create_camera(db, image_path, camera_model):
    """Add one camera describing ``image_path`` to ``db`` and return its camera_id.

    The focal length comes from ``get_focal`` and the principal point is the image centre.

    Args:
        db: open COLMAPDatabase.
        image_path: image whose size and EXIF define the camera.
        camera_model: "simple-pinhole", "pinhole", "simple-radial" or "opencv".

    Raises:
        ValueError: for any other ``camera_model``.
    """
    with Image.open(image_path) as image:
        width, height = image.size

    focal, is_from_exif = get_focal(image_path)
    cx, cy = width / 2, height / 2

    # (model id, parameter vector) for each supported model.
    if camera_model == "simple-pinhole":
        model, params = 0, [focal, cx, cy]
    elif camera_model == "pinhole":
        model, params = 1, [focal, focal, cx, cy]
    elif camera_model == "simple-radial":
        model, params = 2, [focal, cx, cy, 0.1]
    elif camera_model == "opencv":
        model, params = 4, [focal, focal, cx, cy, 0.0, 0.0, 0.0, 0.0]
    else:
        raise ValueError(f"Unsupported camera model: {camera_model!r}")

    # Modified to set prior_focal_length when the focal length is from EXIF.
    return db.add_camera(
        model, width, height, np.array(params), prior_focal_length=is_from_exif
    )


def _dedupe_one_to_one(matches_array):
    """Return unique (idx1, idx2) rows where every idx1 and every idx2 appears once.

    Exact duplicate rows are dropped first. Then, for each index in column 0 and
    afterwards in column 1, only the row with the smallest partner index is kept,
    which makes the result deterministic. Empty input gives an empty (0, 2) array.
    """
    rows = np.asarray(matches_array).reshape(-1, 2)
    if len(rows) == 0:
        return rows
    # np.unique(axis=0) drops duplicate rows and sorts them lexicographically.
    rows = np.unique(rows, axis=0)
    # return_index yields each value's first occurrence, i.e. its smallest partner.
    _, first = np.unique(rows[:, 0], return_index=True)
    rows = rows[first]
    _, first = np.unique(rows[:, 1], return_index=True)
    return rows[first]


def add_kpts_matches(db, img_dir, kpts, matches, fms=None, camera_model="simple-radial", min_matches=15):
    """Register images, keypoints and pairwise matches in a COLMAP database, then commit.

    Args:
        db: open COLMAPDatabase.
        img_dir: directory containing the files named by ``kpts``' keys.
        kpts: ``kpts[filename]`` -> keypoint array accepted by ``db.add_keypoints``.
        matches: ``matches[fname1][fname2]`` -> (M, 2) keypoint-index pairs.
        fms: optional ``fms[fname1][fname2]`` -> fundamental matrix; when given, a
            two-view geometry row is also written for every stored pair.
        camera_model: model passed to ``create_camera`` (one camera per image).
        min_matches: pairs with fewer one-to-one matches after deduplication are skipped.
    """
    fname_to_id = {}

    # Add keypoints (one camera per image).
    for filename in tqdm(kpts):
        path = os.path.join(img_dir, filename)
        camera_id = create_camera(db, path, camera_model)
        image_id = db.add_image(filename, camera_id)
        fname_to_id[filename] = image_id
        db.add_keypoints(image_id, kpts[filename])

    n_keys = len(matches)
    # Estimated pair count; only used to size the progress bar.
    n_total = (n_keys * (n_keys - 1)) // 2
    # Add matches
    added = set()
    with tqdm(total=n_total) as pbar:
        for key1, partners in matches.items():
            for key2, matches_array in partners.items():
                id_1 = fname_to_id[key1]
                id_2 = fname_to_id[key2]
                pair_id = image_ids_to_pair_id(id_1, id_2)
                if pair_id in added:
                    warnings.warn(f"Pair {pair_id} ({id_1}, {id_2}) already added!")
                    continue

                unique_matches = _dedupe_one_to_one(matches_array)
                if len(unique_matches) >= min_matches:
                    db.add_matches(id_1, id_2, unique_matches)
                    added.add(pair_id)
                    if fms is not None:
                        db.add_two_view_geometry(id_1, id_2, unique_matches, fms[key1][key2], np.eye(3), np.eye(3))
                pbar.update(1)

    db.commit()


In [12]:
def get_unique_idxs(A, dim=0):
    """Index of the first occurrence of each unique slice of ``A`` along ``dim``.

    The output follows the sorted order of the unique values (``torch.unique(sorted=True)``).
    For example, ``get_unique_idxs(torch.tensor([3, 1, 3, 2, 1]))`` gives ``tensor([1, 3, 0])``.
    An empty input returns an empty int64 tensor instead of raising.
    """
    # https://stackoverflow.com/questions/72001505/how-to-get-unique-elements-and-their-firstly-appeared-indices-of-a-pytorch-tenso
    if A.shape[dim] == 0:
        return torch.empty(0, dtype=torch.long, device=A.device)
    _, inverse, counts = torch.unique(
        A, dim=dim, sorted=True, return_inverse=True, return_counts=True
    )
    # A stable sort groups positions by unique value, keeping original order within each group.
    order = torch.sort(inverse, stable=True).indices
    # Start offset of each group inside `order`: the exclusive prefix sum of the counts.
    starts = torch.zeros_like(counts)
    starts[1:] = counts.cumsum(0)[:-1]
    return order[starts]

# AffNetHardNet Models

In [13]:
# Making kornia local features loading w/o internet
class AffNetHardNet(KF.LocalFeature):
    """Detector + AffNet shape + (optional) OriNet orientation + HardNet8 descriptor.

    All weights are loaded from local files (``MODEL_DIR`` / ``HARDNET_PT``), so no
    internet access is needed.

    .. image:: _static/img/keynet_affnet.jpg

    Args:
        num_features: maximum number of keypoints passed to the detector.
        upright: if True, use ``KF.PassLAF`` (no orientation) and skip loading OriNet.
        device: device the detector and descriptor modules are moved to.
        scale_laf: forwarded to ``KF.LocalFeature``.
        detector: "keynet" (KeyNet CNN) or one of the response functions "GFTT",
            "Hessian", "Harris", "DoG" run through ``KF.MultiResolutionDetector``.

    Raises:
        ValueError: if ``detector`` is not one of the options above.
    """

    @staticmethod
    def _load_weights(path):
        # map_location="cpu" lets GPU-saved checkpoints load on any machine;
        # load_state_dict then copies tensors onto each module's own device.
        return torch.load(path, map_location="cpu")

    def __init__(
        self,
        num_features: int = 5000,
        upright: bool = False,
        device=torch.device("cpu"),
        scale_laf: float = 1.0,
        detector="keynet",
    ):
        detector_options = ["keynet", "GFTT", "Hessian", "Harris", "DoG"]
        if detector not in detector_options:
            raise ValueError("Detector must be one of {}".format(detector_options))

        ori_module = (
            KF.PassLAF()
            if upright
            else KF.LAFOrienter(angle_detector=KF.OriNet(False)).eval()
        )
        if not upright:
            weights = self._load_weights(os.path.join(MODEL_DIR, "OriNet.pth"))["state_dict"]
            ori_module.angle_detector.load_state_dict(weights)

        if detector == "keynet":
            det = KF.KeyNetDetector(
                False,
                num_features=num_features,
                ori_module=ori_module,
                aff_module=KF.LAFAffNetShapeEstimator(False).eval(),
            ).to(device)
            kn_weights = self._load_weights(os.path.join(MODEL_DIR, "keynet_pytorch.pth"))["state_dict"]
            det.model.load_state_dict(kn_weights)
        else:
            # Response functions sharing one multi-resolution pipeline. "Hessian" is
            # listed in detector_options, so it needs an entry here too.
            response_factories = {
                "GFTT": KF.CornerGFTT,
                "Harris": lambda: KF.CornerHarris(0.04),
                "Hessian": KF.BlobHessian,
                "DoG": KF.BlobDoGSingle,
            }
            config = {
                # Extraction Parameters
                "nms_size": 15,
                "pyramid_levels": 4,
                "up_levels": 1,
                "scale_factor_levels": math.sqrt(2),
                "s_mult": 22.0,
            }
            det = KF.MultiResolutionDetector(
                response_factories[detector](),
                num_features=num_features,
                config=config,
                ori_module=ori_module,
                aff_module=KF.LAFAffNetShapeEstimator(False).eval(),
            ).to(device)

        affnet_weights = self._load_weights(os.path.join(MODEL_DIR, "AffNet.pth"))["state_dict"]
        det.aff.load_state_dict(affnet_weights)

        hardnet8 = KF.HardNet8(False).eval()
        hardnet8.load_state_dict(self._load_weights(HARDNET_PT))
        descriptor = KF.LAFDescriptor(
            hardnet8, patch_size=32, grayscale_descriptor=True
        ).to(device)
        super().__init__(det, descriptor, scale_laf)


In [14]:
def get_unique_matches(f_match_kpts):
    """Deduplicate per-pair coordinate matches into shared per-image keypoints.

    Args:
        f_match_kpts: nested mapping ``key1 -> key2 -> matches`` where each
            ``matches`` row is ``[x1, y1, x2, y2]`` (columns ``:2`` belong to
            image ``key1``, columns ``2:`` to image ``key2``).

    Returns:
        ``(unique_kpts, out_match)``:
          * ``unique_kpts[key]`` is a numpy array of the distinct keypoints of
            image ``key`` after rounding coordinates with ``np.round``;
          * ``out_match[key1][key2]`` is a numpy ``(M, 2)`` array of indices
            into ``unique_kpts[key1]`` / ``unique_kpts[key2]``, filtered with
            ``get_unique_idxs`` so that no coordinate pair and no single index
            on either side appears twice.
    """
    per_image_kpts = defaultdict(list)
    raw_match_idxs = defaultdict(dict)
    kpt_offset = defaultdict(int)

    # Pass 1: stack every pair's keypoints per image and remember, for each
    # pair, which rows of those stacks belong together.
    for key1, partners in f_match_kpts.items():
        for key2, pair_matches in partners.items():
            n_rows = len(pair_matches)
            per_image_kpts[key1].append(pair_matches[:, :2])
            per_image_kpts[key2].append(pair_matches[:, 2:])
            local_idx = torch.arange(n_rows).reshape(-1, 1).repeat(1, 2)
            local_idx[:, 0] += kpt_offset[key1]
            local_idx[:, 1] += kpt_offset[key2]
            kpt_offset[key1] += n_rows
            kpt_offset[key2] += n_rows
            raw_match_idxs[key1][key2] = local_idx

    rounded_kpts = {
        img_key: np.round(np.concatenate(chunks, axis=0))
        for img_key, chunks in per_image_kpts.items()
    }

    # Pass 2: collapse identical (rounded) keypoints; keep the inverse map so
    # stacked row indices can be translated to unique-keypoint indices.
    unique_kpts = {}
    inverse_idxs = {}
    for img_key, pts in rounded_kpts.items():
        uniq_pts, inverse = torch.unique(
            torch.from_numpy(pts), dim=0, return_inverse=True
        )
        inverse_idxs[img_key] = inverse
        unique_kpts[img_key] = uniq_pts.numpy()

    # Pass 3: remap every pair and drop duplicate correspondences.
    out_match = defaultdict(dict)
    for key1, partners in raw_match_idxs.items():
        for key2, local_idx in partners.items():
            remapped = deepcopy(local_idx)
            remapped[:, 0] = inverse_idxs[key1][remapped[:, 0]]
            remapped[:, 1] = inverse_idxs[key2][remapped[:, 1]]
            pair_coords = np.concatenate(
                [unique_kpts[key1][remapped[:, 0]], unique_kpts[key2][remapped[:, 1]]],
                axis=1,
            )
            filtered = remapped[get_unique_idxs(torch.from_numpy(pair_coords), dim=0)]
            filtered = filtered[get_unique_idxs(filtered[:, 0], dim=0)]
            filtered = filtered[get_unique_idxs(filtered[:, 1], dim=0)]
            out_match[key1][key2] = filtered.numpy()
    return unique_kpts, out_match

# Scene feature detector

In [15]:
class AffNetHardNetDetector:
    """Run an AffNet/HardNet local-feature model over a list of image files.

    The wrapped ``model`` is called on each (optionally grayscale) resized
    image and must return ``(lafs, responses, descriptors)``. LAFs are mapped
    back to original-image coordinates before they are returned.
    """

    def __init__(
        self,
        model,
        device=torch.device("cuda"),
        resize_long_edge_to=600,
        matcher="adalam",
        min_matches=15,
        rgb_input = False
    ):
        """
        Args:
            model: callable feature model returning ``(lafs, resps, descs)``.
            device: device images are loaded onto before being fed to ``model``.
            resize_long_edge_to: target length of the longer image edge
                (images are never enlarged, see ``detect_features``).
            matcher, min_matches: accepted for signature compatibility; not
                used by this class.
            rgb_input: if True feed RGB images to ``model``, otherwise convert
                to grayscale first.
        """
        self.rgb_input = rgb_input
        print("Init AffNetHardNetDetector")
        self.model = model
        self.device = device
        self.resize_long_edge_to = resize_long_edge_to
        print("Longer edge will be resized to", self.resize_long_edge_to)

    def detect_features(self, img_fnames):
        """Detect LAFs, keypoints and descriptors for every image.

        Args:
            img_fnames: list of image paths; results are keyed by the file name
                (the part after the last ``/``).

        Returns:
            ``(f_lafs, f_kpts, f_descs, f_raw_size)`` dicts keyed by file name:
              * ``f_lafs``: LAF tensor rescaled to original-image coordinates;
              * ``f_kpts``: ``(N, 2)`` numpy array of LAF centres on the CPU;
              * ``f_descs``: ``(N, D)`` descriptor tensor (left on the device);
              * ``f_raw_size``: ``torch.tensor(timg.shape[2:])`` of the loaded,
                un-resized image (presumably ``(H, W)`` — assumes NCHW input).
        """
        f_lafs = dict()
        f_descs = dict()
        f_kpts = dict()
        f_raw_size = dict()
        print("Detecting AffNetHardNet features")
        for img_path in tqdm(img_fnames):
            key = img_path.split("/")[-1]
            with torch.inference_mode():
                # Use the device this detector was configured with, not a global.
                timg = load_torch_image(img_path, device=self.device)
                raw_size = torch.tensor(timg.shape[2:])
                timg_resized, h_scale, w_scale = resize_torch_image(
                    timg, self.resize_long_edge_to, disable_enlarge=True
                )
                model_input = (
                    timg_resized
                    if self.rgb_input
                    else K.color.rgb_to_grayscale(timg_resized)
                )
                lafs, resps, descs = self.model(model_input)

                # Undo the resize: row 0 of each LAF holds x terms, row 1 y terms.
                lafs[:, :, 0, :] *= 1 / w_scale
                lafs[:, :, 1, :] *= 1 / h_scale
                desc_dim = descs.shape[-1]
                # Keypoints go to the CPU for the later COLMAP/numpy stages.
                f_kpts[key] = KF.get_laf_center(lafs).reshape(-1, 2).detach().cpu().numpy()
                f_descs[key] = descs.reshape(-1, desc_dim).detach()
                f_lafs[key] = lafs.detach()
                f_raw_size[key] = raw_size
        gc.collect()
        torch.cuda.empty_cache()
        return f_lafs, f_kpts, f_descs, f_raw_size



# Scene LAF matcher

In [16]:
class LafMatcher:
    """Select image pairs with DINOv2 global embeddings and match LAF features.

    Pairs are proposed by thresholding the Euclidean distance between
    L2-normalised DINOv2 embeddings. Each proposed pair is then matched with
    kornia AdaLAM (``matcher="adalam"``) or ``match_smnn`` (anything else).
    """

    # Local DINOv2 checkpoint used for pair selection.
    DINO_PATH = '/kaggle/input/dinov2/pytorch/base/1/'

    def __init__(self, min_matches=15, device="cuda", matcher="adalam", min_pairs=50, distances_threshold=0.3):
        """
        Args:
            min_matches: minimum number of matches for a pair to be kept.
            device: device for DINOv2 and AdaLAM.
            matcher: ``"adalam"`` or any other value for SMNN matching.
            min_pairs: images with no neighbour under the threshold are paired
                with their ``min_pairs - 1`` nearest images instead.
            distances_threshold: max embedding distance for a candidate pair.
        """
        self.adalam_config = KF.adalam.get_adalam_default_config()
        self.adalam_config["force_seed_mnn"] = True
        self.adalam_config["search_expansion"] = 16
        self.adalam_config["ransac_iters"] = 256
        self.adalam_config["device"] = device
        self.min_matches = min_matches
        self.matcher = matcher
        self.min_pairs = min_pairs
        self.distances_threshold = distances_threshold
        # Hard upper bound on embedding distance. Embeddings are unit-norm, so
        # distances are <= 2 and the default of 500 never excludes a pair.
        self.tolerance = 500

    def _pairs_from_distances(self, distances):
        """Turn an (N, N) embedding-distance matrix into sorted ``[i, j]`` pairs, i < j."""
        dist = np.asarray(distances)
        adjacency = dist <= self.distances_threshold
        np.fill_diagonal(adjacency, False)

        # Images with no neighbour under the threshold fall back to their
        # nearest other images. Both directions are set so the i < j scan
        # below also sees neighbours with a lower index.
        n_fallback = max(self.min_pairs - 1, 0)
        for i in np.flatnonzero(~adjacency.any(axis=1)):
            order = np.argsort(dist[i], kind="stable")
            nearest = order[order != i][:n_fallback]
            adjacency[i, nearest] = True
            adjacency[nearest, i] = True

        adjacency[dist >= self.tolerance] = False
        rows, cols = np.nonzero(np.triu(adjacency, k=1))
        return [[int(i), int(j)] for i, j in zip(rows, cols)]

    @staticmethod
    def _roi_from_kpts(mkpts):
        """5th-95th percentile bounding box (and its area) of ``(N, 2)`` keypoints."""
        min_w, max_w = np.percentile(mkpts[:, 0], [5, 95])
        min_h, max_h = np.percentile(mkpts[:, 1], [5, 95])
        return {
            "roi_min_w": min_w,
            "roi_min_h": min_h,
            "roi_max_w": max_w,
            "roi_max_h": max_h,
            "area": (max_w - min_w) * (max_h - min_h),
        }

    def get_pairs(self, img_fnames):
        """Get image pairs using DINOv2 embeddings for similarity.

        Returns:
            List of ``[i, j]`` index pairs into ``img_fnames`` with ``i < j``.
        """
        print("Getting image pairs using DINOv2...")
        if len(img_fnames) == 0:
            print("Found 0 pairs")
            return []
        device = self.adalam_config["device"]
        processor = AutoImageProcessor.from_pretrained(self.DINO_PATH, use_fast=True)
        model = AutoModel.from_pretrained(self.DINO_PATH).eval().to(device)
        embeddings = []

        for img_path in tqdm(img_fnames):
            image = K.io.load_image(img_path, K.io.ImageLoadType.RGB32, device=device)[None, ...]
            with torch.inference_mode():
                inputs = processor(images=image, return_tensors="pt", do_rescale=False, do_resize=True,
                                   do_center_crop=False, size=224).to(device)
                outputs = model(**inputs)
                # Max-pool the token features into one global descriptor per image.
                embedding = F.normalize(outputs.last_hidden_state.max(dim=1)[0])
            embeddings.append(embedding)
        del model

        embeddings = torch.cat(embeddings, dim=0)
        distances = torch.cdist(embeddings, embeddings).cpu()
        pairs = self._pairs_from_distances(distances)
        print(f"Found {len(pairs)} pairs")
        return pairs

    def match(self, img_fnames, f_lafs, f_kpts, f_descs, f_raw_size, get_roi=False):
        """Match features between the image pairs selected by ``get_pairs``.

        Args:
            img_fnames: image paths; feature dicts are keyed by file name.
            f_lafs, f_kpts, f_descs, f_raw_size: per-image outputs of a detector
                (e.g. ``AffNetHardNetDetector.detect_features``).
            get_roi: also compute per-pair regions of interest from the matches.

        Returns:
            ``(index_pairs, f_kpts, f_matches, f_rois)``. ``f_matches[key1][key2]``
            is an ``(M, 2)`` numpy array of keypoint indices. Pairs are kept only
            if the mean match distance is below 0.5, and only after matches are
            deduplicated on the second image's index and at least
            ``min_matches`` remain.
        """
        index_pairs = self.get_pairs(img_fnames)

        f_matches = defaultdict(dict)
        f_rois = defaultdict(dict)
        print("Matching features for selected pairs")

        for idx1, idx2 in tqdm(index_pairs):
            key1 = img_fnames[idx1].split("/")[-1]
            key2 = img_fnames[idx2].split("/")[-1]
            desc1, desc2 = f_descs[key1], f_descs[key2]

            if self.matcher == "adalam":
                dists, idxs = KF.match_adalam(
                    desc1,
                    desc2,
                    f_lafs[key1],
                    f_lafs[key2],
                    hw1=f_raw_size[key1],
                    hw2=f_raw_size[key2],
                    config=self.adalam_config,
                )
            else:
                dists, idxs = KF.match_smnn(desc1, desc2, 0.98)

            # Reject pairs whose mean descriptor distance is high (or NaN,
            # which is what an empty match set produces).
            if not dists.mean().item() < 0.5:
                continue
            idxs = idxs[get_unique_idxs(idxs[:, 1])]
            if len(idxs) < self.min_matches:
                continue

            idxs_np = idxs.detach().cpu().numpy().reshape(-1, 2)
            f_matches[key1][key2] = idxs_np
            if get_roi:
                f_rois[key1][key2] = [
                    self._roi_from_kpts(f_kpts[key1][idxs_np[:, 0]]),
                    self._roi_from_kpts(f_kpts[key2][idxs_np[:, 1]]),
                ]

        n_matched_pairs = sum(len(partners) for partners in f_matches.values())
        print(f"Successfully matched {n_matched_pairs} pairs")
        torch.cuda.empty_cache()
        gc.collect()
        return index_pairs, f_kpts, f_matches, f_rois
    

# MASt3RDetector

In [17]:
class MASt3RDetector:
    """Dense MASt3R correspondences between DINOv2-selected image pairs.

    Every image gets a dense keypoint grid: one keypoint per pixel of the
    resized image, rescaled to original-image coordinates. For each pair chosen
    by ``get_pairs``, MASt3R is run symmetrically and correspondences are
    extracted and thresholded on confidence. Each matched pixel is stored as
    its row-major index into that grid.
    """

    # Local DINOv2 checkpoint used for pair selection.
    DINO_PATH = '/kaggle/input/dinov2/pytorch/base/1/'

    def __init__(self, model, device=torch.device("cuda"), resize_long_edge_to=512, min_pairs=50,
                 distances_threshold=0.3, min_matches=15, conf_threshold=2.5, subsample=8):
        """
        Args:
            model: loaded MASt3R network.
            device: device for MASt3R and DINOv2.
            resize_long_edge_to: long-edge size of the MASt3R input (16-aligned).
            min_pairs: isolated images are paired with their ``min_pairs - 1``
                nearest images by DINOv2 distance.
            distances_threshold: max DINOv2 embedding distance for a pair.
            min_matches: minimum surviving correspondences to keep a pair.
            conf_threshold: minimum correspondence confidence (default 2.5).
            subsample: subsampling factor passed to ``extract_correspondences``.
        """
        # Imported here rather than at module level, so these packages are only
        # needed once a detector is actually constructed.
        from dust3r.inference import inference
        from mast3r.cloud_opt.sparse_ga import extract_correspondences, symmetric_inference
        from mast3r.utils.misc import hash_md5
        from dust3r.utils.device import to_cpu

        self.inference = inference
        self.symmetric_inference = symmetric_inference
        self.extract_correspondences = extract_correspondences
        self.hash_md5 = hash_md5
        self.to_cpu = to_cpu
        self.model = model
        self.device = device
        self.resize_long_edge_to = resize_long_edge_to
        self.min_pairs = min_pairs
        self.distances_threshold = distances_threshold
        self.min_matches = min_matches
        self.conf_threshold = conf_threshold
        self.subsample = subsample
        # Embeddings are unit-norm (distances <= 2); 500 never excludes a pair.
        self.tolerance = 500
        print("Init MASt3RDetector")
        print("Longer edge will be resized to", self.resize_long_edge_to)

    def remove_duplicate_matches(self, matches):
        """Return a copy of ``matches`` with duplicate index rows removed.

        Each ``matches[key1][key2]`` is treated as an ``(M, 2)`` index array.
        The result rows are unique and sorted; an empty input keeps shape (0, 2).
        """
        unique_matches = {}
        for key1, partners in matches.items():
            unique_matches[key1] = {}
            for key2, pair_matches in partners.items():
                arr = np.asarray(pair_matches).reshape(-1, 2)
                unique_matches[key1][key2] = np.unique(arr, axis=0) if len(arr) else arr
        return unique_matches

    def _pairs_from_distances(self, distances):
        """Turn an (N, N) embedding-distance matrix into sorted ``[i, j]`` pairs, i < j."""
        dist = np.asarray(distances)
        adjacency = dist <= self.distances_threshold
        np.fill_diagonal(adjacency, False)

        # Isolated images fall back to their nearest others. Set both
        # directions so the i < j scan also sees lower-index neighbours.
        n_fallback = max(self.min_pairs - 1, 0)
        for i in np.flatnonzero(~adjacency.any(axis=1)):
            order = np.argsort(dist[i], kind="stable")
            nearest = order[order != i][:n_fallback]
            adjacency[i, nearest] = True
            adjacency[nearest, i] = True

        adjacency[dist >= self.tolerance] = False
        rows, cols = np.nonzero(np.triu(adjacency, k=1))
        return [[int(i), int(j)] for i, j in zip(rows, cols)]

    def get_pairs(self, img_fnames):
        """Get image pairs using DINOv2 embeddings for similarity.

        Returns:
            List of ``[i, j]`` index pairs into ``img_fnames`` with ``i < j``.
        """
        print("Getting image pairs using DINOv2...")
        if len(img_fnames) == 0:
            print("Found 0 pairs using DINOv2")
            return []
        processor = AutoImageProcessor.from_pretrained(self.DINO_PATH)
        model = AutoModel.from_pretrained(self.DINO_PATH).eval().to(self.device)
        embeddings = []

        for img_path in tqdm(img_fnames):
            image = K.io.load_image(img_path, K.io.ImageLoadType.RGB32, device=self.device)[None, ...]
            with torch.inference_mode():
                inputs = processor(images=image, return_tensors="pt", do_rescale=False, do_resize=True,
                                   do_center_crop=False, size=224).to(self.device)
                outputs = model(**inputs)
                embedding = F.normalize(outputs.last_hidden_state.max(dim=1)[0])
            embeddings.append(embedding)
        del model

        embeddings = torch.cat(embeddings, dim=0)
        distances = torch.cdist(embeddings, embeddings).cpu()
        pairs = self._pairs_from_distances(distances)
        print(f"Found {len(pairs)} pairs using DINOv2")
        return pairs

    def convert_matches_to_colmap_format(self, img0, img1, matches_im0, matches_im1):
        """Convert 2D matches to COLMAP format using ravel indices."""
        matches = [matches_im0.astype(np.float64), matches_im1.astype(np.float64)]
        ravel_matches = []

        for j in range(2):
            H, W = [img0, img1][j]['true_shape'][0]
            qx, qy = matches[j].round().astype(np.int32).T
            # Convert 2D coordinates to ravel indices
            ravel_matches_j = qx.clip(min=0, max=W - 1, out=qx) + W * qy.clip(min=0, max=H - 1, out=qy)
            ravel_matches.append(ravel_matches_j)

        colmap_matches = np.stack([ravel_matches[0], ravel_matches[1]], axis=-1)
        colmap_matches = np.unique(colmap_matches, axis=0)
        return colmap_matches

    @staticmethod
    def _to_grid_index(xy, hw):
        """Map ``(N, 2)`` pixel coords to row-major indices of an ``hw=(h, w)`` grid (clipped)."""
        h, w = int(hw[0]), int(hw[1])
        col = np.clip(np.round(xy[:, 0]).astype(int), 0, w - 1)
        row = np.clip(np.round(xy[:, 1]).astype(int), 0, h - 1)
        return col + w * row

    def _load_resized(self, path):
        """Load ``path`` and resize it to the MASt3R input size (16-aligned)."""
        img = load_torch_image(path, device=self.device)
        img_resized, _, _ = resize_torch_image(img, self.resize_long_edge_to, align=16)
        return img_resized

    @torch.no_grad()
    def detect_features(self, img_fnames):
        """Compute dense keypoints and MASt3R matches for all selected pairs.

        Returns:
            ``(f_kpts, f_descs, f_matches)`` keyed by file name. ``f_kpts[key]``
            is the dense grid in original-image coordinates. ``f_descs`` is
            always empty. ``f_matches[key1][key2]`` is a deduplicated ``(M, 2)``
            array of grid indices. The model is kept, so this method can be
            called again for another scene.
        """
        f_kpts = dict()
        f_descs = dict()
        f_matches = defaultdict(dict)

        # Per-image metadata in MASt3R's view format; imgs[i]['idx'] == i.
        imgs = []
        for i, fname in enumerate(img_fnames):
            img = load_torch_image(fname, device=self.device)
            orig_h, orig_w = img.shape[2:]
            img_resized, h_scale, w_scale = resize_torch_image(img, self.resize_long_edge_to, align=16)

            # Dense grid of keypoints over the resized image, row-major (y, x),
            # then scaled back to original-image coordinates.
            h, w = img_resized.shape[2:]
            y, x = np.mgrid[0:h, 0:w]
            pts = np.stack([x, y], axis=-1).reshape(-1, 2)
            pts = pts * np.array([1.0/w_scale, 1.0/h_scale])
            f_kpts[fname.split("/")[-1]] = pts

            imgs.append({
                'instance': fname,
                'img': None,  # filled per pair to keep GPU memory low
                'idx': i,
                'true_shape': np.int32([img_resized.shape[2:]]),
                'h_scale': float(h_scale),
                'w_scale': float(w_scale),
                'orig_h': orig_h,
                'orig_w': orig_w
            })
            del img, img_resized

        pairs = self.get_pairs(img_fnames)
        print(f"Processing {len(pairs)} pairs...")

        for i1, i2 in tqdm(pairs):
            key1 = img_fnames[i1].split("/")[-1]
            key2 = img_fnames[i2].split("/")[-1]
            view1, view2 = imgs[i1], imgs[i2]
            view1['img'] = self._load_resized(view1['instance'])
            view2['img'] = self._load_resized(view2['instance'])

            res = self.symmetric_inference(self.model, view1, view2, device=self.device)
            descs = [r['desc'][0] for r in res]
            qonfs = [r['desc_conf'][0] for r in res]
            xy1, xy2, confs = self.extract_correspondences(
                descs, qonfs, device=self.device, subsample=self.subsample
            )

            keep = confs >= self.conf_threshold
            xy1 = xy1[keep].cpu().numpy()
            xy2 = xy2[keep].cpu().numpy()

            if len(xy1) >= self.min_matches:
                grid_idx1 = self._to_grid_index(xy1, view1['true_shape'][0])
                grid_idx2 = self._to_grid_index(xy2, view2['true_shape'][0])
                f_matches[key1][key2] = np.column_stack([grid_idx1, grid_idx2])

            del view1['img']
            del view2['img']
            torch.cuda.empty_cache()

        f_matches = self.remove_duplicate_matches(f_matches)
        # The model is intentionally NOT deleted: the detector is reused for
        # every scene, and the caller still holds its own reference to it.
        torch.cuda.empty_cache()
        return f_kpts, f_descs, f_matches


# Matches and pairs operation

In [18]:
def merge_kpts_matches(kpts, matches, new_kpts, new_matches, cap = None):
    """Append one detector's keypoints/matches onto an accumulated set.

    ``kpts`` and ``matches`` are updated in place and also returned. New
    keypoints are appended after the existing ones for each image, and new
    match indices are shifted by that image's previous keypoint count so they
    keep pointing at the right rows.

    Args:
        kpts: dict image key -> keypoint array (accumulator).
        matches: dict key1 -> dict key2 -> (M, 2) index array (accumulator).
        new_kpts: keypoints to append, same layout as ``kpts``.
        new_matches: matches indexing into ``new_kpts``.
        cap: if set, randomly keep at most ``cap`` new matches per pair
            (uses the global ``np.random`` state).

    Returns:
        ``(kpts, matches)``.
    """
    # Keypoint count of each image before this merge = index offset for its matches.
    offsets = {}
    for img_key, extra in new_kpts.items():
        had_kpts = img_key in kpts
        offsets[img_key] = len(kpts[img_key]) if had_kpts else 0
        kpts[img_key] = (
            np.concatenate([kpts[img_key], extra], axis=0) if had_kpts else extra
        )

    for key_a, partners in new_matches.items():
        for key_b, pair_idx in partners.items():
            shifted = pair_idx + [offsets[key_a], offsets[key_b]]
            if cap is not None and len(shifted) > cap:
                chosen = np.random.choice(len(shifted), cap, replace=False)
                shifted = shifted[chosen, :]
            bucket = matches.setdefault(key_a, dict())
            if key_b in bucket:
                bucket[key_b] = np.concatenate([bucket[key_b], shifted], axis=0)
            else:
                bucket[key_b] = shifted
    return kpts, matches


def keep_matches(matches, max_num=None):
    """Randomly subsample matches to at most ``max_num`` rows.

    Args:
        matches: array-like of matches, e.g. an ``(N, 2)`` index array.
        max_num: maximum number of rows to keep; ``None`` keeps everything.

    Returns:
        ``matches`` itself when no subsampling is needed, otherwise a numpy
        array with ``max_num`` distinct rows chosen without replacement
        (uses the global ``np.random`` state).
    """
    if max_num is None or len(matches) <= max_num:
        return matches
    # np.random.choice only samples 1-D arrays, so sample row indices instead.
    # This also works for (N, 2) match arrays.
    keep = np.random.choice(len(matches), max_num, replace=False)
    return np.asarray(matches)[keep]


def keep_pairs(index_pairs, max_num_pairs=20):
    """Keep at most ``max_num_pairs`` top-ranked pairs per image.

    Args:
        index_pairs: dict image index -> list of pairs. ``pair[0]`` is the
            partner image and ``pair[2]`` the ranking key (higher first;
            presumably a match count — confirm with the producer). Not modified.
        max_num_pairs: a pair of image ``key1`` is kept only while fewer than
            this many kept pairs already involve ``key1`` (as either side).

    Returns:
        defaultdict(list) image index -> kept pairs, in descending ``pair[2]``.
    """
    new_count = 0
    old_count = 0
    new_idx_count = defaultdict(int)
    new_pairs = defaultdict(list)
    for key1, pairs in index_pairs.items():
        # Rank a copy so the caller's lists are left untouched.
        for pair in sorted(pairs, key=lambda x: x[2], reverse=True):
            old_count += 1
            if new_idx_count[key1] >= max_num_pairs:
                continue
            new_pairs[key1].append(pair)
            new_count += 1
            new_idx_count[key1] += 1
            new_idx_count[pair[0]] += 1

    if DEBUG:
        print(f"origin pairs: {old_count}, kept pairs: {new_count}")
    # Return the filtered pairs (previously the unfiltered input was returned).
    return new_pairs

def select_matches(matches, keep_ratio = 0.01):
    """Drop image pairs whose match count is tiny relative to both images' best pair.

    A pair ``(a, b)`` survives when its number of matches exceeds
    ``keep_ratio`` times the largest match count seen for image ``a``, OR
    exceeds it for image ``b``.

    Args:
        matches: dict key1 -> dict key2 -> array of matches.
        keep_ratio: fraction of an image's best pair a pair must exceed.

    Returns:
        defaultdict(dict) holding the surviving pairs (arrays are not copied).
    """
    best = defaultdict(int)
    old_matches_count = 0
    for img_a, partners in matches.items():
        for img_b, pair in partners.items():
            size = len(pair)
            best[img_a] = max(best[img_a], size)
            best[img_b] = max(best[img_b], size)
            old_matches_count += 1

    new_matches = defaultdict(dict)
    new_matches_count = 0
    for img_a, partners in matches.items():
        for img_b, pair in partners.items():
            size = len(pair)
            weak_for_a = size <= best[img_a] * keep_ratio
            weak_for_b = size <= best[img_b] * keep_ratio
            if weak_for_a and weak_for_b:
                continue
            new_matches[img_a][img_b] = pair
            new_matches_count += 1
    if DEBUG:
        print(f"origin matches: {old_matches_count}, kept matches: {new_matches_count}")
    return new_matches

# Get Fundamental matrices from matches and keypoints

In [19]:
def get_fms(kpts, matches):
    """Estimate a fundamental matrix for every matched image pair.

    Uses ``cv2.findFundamentalMat`` with ``cv2.USAC_MAGSAC`` and the
    threshold, confidence and iteration limit from the global ``FM_PARAMS``.
    When ``FM_PARAMS["removeOutliers"]`` is truthy, each pair's matches are
    replaced in place by the inliers reported by OpenCV.

    Args:
        kpts: dict image key -> keypoint array indexed by match columns.
        matches: dict key1 -> dict key2 -> ``(M, 2)`` index array into
            ``kpts[key1]`` / ``kpts[key2]``. Mutated in place when outliers are
            removed.

    Returns:
        ``(kpts, matches, fms)``. ``fms[key1][key2]`` is OpenCV's estimate, or
        ``None`` when OpenCV returns no model; that pair's matches are then
        left unfiltered instead of crashing.
    """
    fms = defaultdict(dict)
    print("Get Fundamental Matrix")
    for key1 in tqdm(matches):
        # Replacing values of existing keys while iterating is safe (no resize).
        for key2, match in matches[key1].items():
            mkpts1 = kpts[key1][match[:, 0]]
            mkpts2 = kpts[key2][match[:, 1]]
            Fm, inliers = cv2.findFundamentalMat(
                mkpts1,
                mkpts2,
                cv2.USAC_MAGSAC,
                FM_PARAMS["ransacReprojThreshold"],
                FM_PARAMS["confidence"],
                FM_PARAMS["maxIters"],
            )
            # OpenCV yields no mask when estimation fails (e.g. too few points).
            if FM_PARAMS["removeOutliers"] and inliers is not None:
                matches[key1][key2] = match[inliers.ravel() == 1]
            fms[key1][key2] = Fm
    return kpts, matches, fms

# Model Setup

In [20]:
def _build_affnet_pipeline(cfg_key, detector_name):
    """Build an AffNetHardNet model plus its AffNetHardNetDetector for ``MODEL_DICT[cfg_key]``.

    ``num_features`` defaults to 8000 and can be overridden per detector with
    ``MODEL_DICT[cfg_key]["num_features"]``. ``detector_name`` is the name
    ``AffNetHardNet`` expects (e.g. ``"keynet"`` for the ``"Keynet"`` entry).
    """
    cfg = MODEL_DICT[cfg_key]
    model = (
        AffNetHardNet(
            num_features=cfg.get("num_features", 8000),
            upright=False,
            device=device,
            detector=detector_name,
        )
        .to(device)
        .eval()
    )
    return model, AffNetHardNetDetector(model, resize_long_edge_to=cfg["resize_long_edge_to"])


if MODEL_DICT["Keynet"]["enable"]:
    keynet_model, keynet_detector = _build_affnet_pipeline("Keynet", "keynet")

if MODEL_DICT["GFTT"]["enable"]:
    gftt_model, gftt_detector = _build_affnet_pipeline("GFTT", "GFTT")

if MODEL_DICT["DoG"]["enable"]:
    DoG_model, DoG_detector = _build_affnet_pipeline("DoG", "DoG")

if MODEL_DICT["Harris"]["enable"]:
    harris_model, harris_detector = _build_affnet_pipeline("Harris", "Harris")

# One LafMatcher serves every AffNetHardNet detector: after __init__ it only
# holds configuration, and match() does not modify it.
if any(MODEL_DICT[k]["enable"] for k in ("Keynet", "GFTT", "DoG", "Harris")):
    laf_matcher = LafMatcher(device=device)


if MODEL_DICT["MASt3R"]["enable"]:
    def load_model(model_path, device, verbose=True):
        """Load a MASt3R checkpoint with ``landscape_only=False`` forced on.

        The architecture is rebuilt by ``eval``-ing the constructor string stored
        in ``ckpt['args'].model``, after swapping the patch embedding class.
        """
        if verbose:
            print('... loading model from', model_path)
        # SECURITY: weights_only=False unpickles arbitrary objects, and the
        # architecture string below is passed to eval(). Only load trusted
        # checkpoints.
        ckpt = torch.load(model_path, map_location='cpu', weights_only=False )
        args = ckpt['args'].model.replace("ManyAR_PatchEmbed", "PatchEmbedDust3R")
        if 'landscape_only' not in args:
            args = args[:-1] + ', landscape_only=False)'
        else:
            args = args.replace(" ", "").replace('landscape_only=True', 'landscape_only=False')
        assert "landscape_only=False" in args
        if verbose:
            print(f"instantiating : {args}")
        inf = float('inf')  # referenced by the eval'd string, e.g. ('exp', -inf, inf)
        net = eval(args)
        s = net.load_state_dict(ckpt['model'], strict=False)
        if verbose:
            print(s)
        return net.to(device)
    model_path = "/kaggle/input/mast3r-vitlarge-basedecoder-512-catmlpdpt-metric/MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric.pth"
    mast3r_model = load_model(model_path, device='cpu').to(device)
    mast3r_detector = MASt3RDetector(
        mast3r_model,
        device=device,
        resize_long_edge_to=MODEL_DICT["MASt3R"]["resize_long_edge_to"]
    )

  aff_module=KF.LAFAffNetShapeEstimator(False).eval(),


Init AffNetHardNetDetector
Longer edge will be resized to 1600


  aff_module=KF.LAFAffNetShapeEstimator(False).eval(),


Init AffNetHardNetDetector
Longer edge will be resized to 1600


  aff_module=KF.LAFAffNetShapeEstimator(False).eval(),


Init AffNetHardNetDetector
Longer edge will be resized to 1600


  aff_module=KF.LAFAffNetShapeEstimator(False).eval(),


Init AffNetHardNetDetector
Longer edge will be resized to 1600
... loading model from /kaggle/input/mast3r-vitlarge-basedecoder-512-catmlpdpt-metric/MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric.pth
instantiating : AsymmetricMASt3R(enc_depth=24, dec_depth=12, enc_embed_dim=1024, dec_embed_dim=768, enc_num_heads=16, dec_num_heads=12, pos_embed='RoPE100',img_size=(512, 512), head_type='catmlp+dpt', output_mode='pts3d+desc24', depth_mode=('exp', -inf, inf), conf_mode=('exp', 1, inf), patch_embed_cls='PatchEmbedDust3R', two_confs=True, desc_conf_mode=('exp', 0, inf), landscape_only=False)
<All keys matched successfully>
Init MASt3RDetector
Longer edge will be resized to 512


In [21]:
# Utility to check whether two pair structures are identical
def compare_pairs(pairs1, pairs2):
    """Report ``(image, partner)`` entries present in one pair structure but not the other.

    Each argument is either a sequence indexed by image index or a dict keyed
    by image. Each entry is an iterable of pairs whose first element is the
    partner image. Every entry is compared, including the last one.

    Returns:
        True when both structures contain exactly the same ``(image, partner)``
        keys, False otherwise. Each missing key is also printed.
    """
    def _pair_keys(pairs):
        # Ordered dict so differences are printed in a stable order.
        entries = pairs.items() if isinstance(pairs, dict) else enumerate(pairs)
        return {(idx, pair[0]): 0 for idx, partners in entries for pair in partners}

    pair1_dict = _pair_keys(pairs1)
    pair2_dict = _pair_keys(pairs2)

    identical = True
    for key in pair1_dict:
        if key not in pair2_dict:
            print(f"Key{key} not in pair2_dict")
            identical = False
    for key in pair2_dict:
        if key not in pair1_dict:
            print(f"Key{key} not in pair1_dict")
            identical = False
    return identical

# Function to generate scene db for reconstruction

In [22]:
def generate_scene_db(dataset, scene):
    """Detect and match features for one scene and write them to a COLMAP database.

    Runs each detector enabled in ``MODEL_DICT`` in a fixed order: Keynet,
    GFTT, DoG and Harris through ``laf_matcher``, then MASt3R through
    ``mast3r_detector``. Their keypoints and matches are merged with
    ``merge_kpts_matches``, fundamental matrices are estimated with
    ``get_fms``, and weak pairs are dropped with ``select_matches``. The result
    is written to ``featureout/{dataset}_{scene}/colmap.db``, which is
    recreated from scratch.

    Args:
        dataset: dataset name (part of the image directory and output path).
        scene: scene name within ``dataset``.

    Returns:
        Seconds spent on detection, matching and DB writing, or ``None`` when
        the scene's image directory does not exist.
    """
    feature_det_start = time()
    img_dir = f"{SRC}/{MODE}/{dataset}/{scene}/images"
    if not os.path.exists(img_dir):
        print("Image dir does not exist:", img_dir)
        return

    img_fnames = [f"{SRC}/{MODE}/{x}" for x in data_dict[dataset][scene]]
    print(f"Got {len(img_fnames)} images")

    matches = dict()
    kpts = dict()

    def _laf_kpts_matches(detector):
        """Detect with an AffNetHardNetDetector and match through the shared laf_matcher."""
        lafs, det_kpts, descs, raw_size = detector.detect_features(img_fnames)
        _, det_kpts, det_matches, _ = laf_matcher.match(
            img_fnames, lafs, det_kpts, descs, raw_size
        )
        return det_kpts, det_matches

    # Merge order matters: it fixes each detector's keypoint index offsets.
    if MODEL_DICT["Keynet"]["enable"]:
        keynet_kpts, keynet_matches = _laf_kpts_matches(keynet_detector)
        # With pair_only set, Keynet output is computed but not merged.
        if not MODEL_DICT["Keynet"]["pair_only"]:
            kpts, matches = merge_kpts_matches(kpts, matches, keynet_kpts, keynet_matches, MATCHES_CAP)

    if MODEL_DICT["GFTT"]["enable"]:
        gftt_kpts, gftt_matches = _laf_kpts_matches(gftt_detector)
        kpts, matches = merge_kpts_matches(kpts, matches, gftt_kpts, gftt_matches, MATCHES_CAP)

    if MODEL_DICT["DoG"]["enable"]:
        DoG_kpts, DoG_matches = _laf_kpts_matches(DoG_detector)
        kpts, matches = merge_kpts_matches(kpts, matches, DoG_kpts, DoG_matches, MATCHES_CAP)

    if MODEL_DICT["Harris"]["enable"]:
        harris_kpts, harris_matches = _laf_kpts_matches(harris_detector)
        kpts, matches = merge_kpts_matches(kpts, matches, harris_kpts, harris_matches, MATCHES_CAP)

    if MODEL_DICT["MASt3R"]["enable"]:
        mast3r_kpts, mast3r_descs, mast3r_matches = mast3r_detector.detect_features(img_fnames)
        kpts, matches = merge_kpts_matches(kpts, matches, mast3r_kpts, mast3r_matches, MATCHES_CAP)

    kpts, matches, fms = get_fms(kpts, matches)
    matches = select_matches(matches, MATCH_FILTER_RATIO)

    if DEBUG and matches:
        import random
        random.seed(0)
        print(matches.keys())
        for _ in range(5):
            key1 = random.choice(list(matches.keys()))
            key2 = random.choice(list(matches[key1].keys()))
            print(key1, key2)
            fname1, fname2 = os.path.join(img_dir, key1), os.path.join(img_dir, key2)

            print("Plot Combined matches")
            plot_images_with_keypoints(
                fname1, fname2, kpts[key1], kpts[key2], matches[key1][key2]
            )

    feature_dir = f"featureout/{dataset}_{scene}"
    os.makedirs(feature_dir, exist_ok=True)
    database_path = f"{feature_dir}/colmap.db"
    if os.path.isfile(database_path):
        os.remove(database_path)

    # Commit and close explicitly so the data is persisted and the file is
    # unlocked before reconstruction reopens it. This assumes COLMAPDatabase
    # is an sqlite3.Connection subclass, as in COLMAP's database.py.
    db = COLMAPDatabase.connect(database_path)
    try:
        db.create_tables()
        print("Add kpts and matches to database")
        add_kpts_matches(db, img_dir, kpts, matches, fms)
        db.commit()
    finally:
        db.close()

    matching_time = time() - feature_det_start
    torch.cuda.empty_cache()
    gc.collect()
    return matching_time

# Function to reconstruct a scene from its COLMAP database

In [23]:
def _make_pipeline_options():
    """Build the pycolmap incremental-mapping options used by reconstruct_from_db.

    Every option is assigned exactly once with its effective value.

    NOTE(review): in pycolmap the registration / triangulation / filtering
    thresholds (abs_pose_*, filter_*, local_ba_min_tri_angle, ...) are
    documented on ``options.mapper`` and ``options.triangulation`` rather than
    on ``IncrementalPipelineOptions`` itself. Confirm that the top-level
    attributes below are accepted by the installed pycolmap build and actually
    reach the mapper.
    """
    options = pycolmap.IncrementalPipelineOptions()

    # Model size / camera sanity limits.
    options.min_model_size = 3
    options.max_model_overlap = 20
    options.min_focal_length_ratio = 0.1
    options.max_focal_length_ratio = 10
    options.max_extra_param = 1.0

    # Triangulation.
    options.min_tri_angle = 2.0
    options.max_reproj_error = 3.0
    options.min_track_length = 3
    options.max_track_length = 100

    # Absolute-pose registration and observation filtering.
    options.abs_pose_max_error = 12.0
    options.abs_pose_min_num_inliers = 30
    options.abs_pose_min_inlier_ratio = 0.25
    options.filter_max_reproj_error = 3.0
    options.filter_min_tri_angle = 2.0

    # Local bundle adjustment.
    options.local_ba_min_tri_angle = 8.0
    return options


def _largest_reconstruction(maps):
    """Return the reconstruction in ``maps`` with the most registered images.

    Returns None when ``maps`` is not a dict or no reconstruction registered
    any image. Prints a summary of every candidate.
    """
    print("Looking for the best reconstruction")
    best_rec, best_count = None, 0
    if isinstance(maps, dict):
        for idx, rec in maps.items():
            print(idx, rec.summary())
            if len(rec.images) > best_count:
                best_count = len(rec.images)
                best_rec = rec
    return best_rec


def reconstruct_from_db(dataset, scene):
    """Run COLMAP incremental mapping on a scene's prebuilt database.

    Reads ``featureout/{dataset}_{scene}/colmap.db`` (written by the matching
    step), runs ``pycolmap.incremental_mapping`` into
    ``featureout/{dataset}_{scene}/colmap_rec`` and keeps the reconstruction
    that registered the most images.

    Args:
        dataset: Dataset directory name under ``{SRC}/{MODE}``.
        scene: Scene directory name under the dataset.

    Returns:
        None if the scene's image directory does not exist. Otherwise a tuple
        ``(scene_result, reconst_time)``:
        - ``scene_result`` maps ``"{dataset}/{scene}/images/{name}"`` to
          ``{"R": rotation matrix, "t": translation}`` taken from each
          image's ``cam_from_world``. It is empty if no model was built.
        - ``reconst_time`` is the wall-clock duration in seconds.
    """
    reconst_start = time()

    img_dir = f"{SRC}/{MODE}/{dataset}/{scene}/images"
    if not os.path.exists(img_dir):
        print("Image dir does not exist:", img_dir)
        return None

    feature_dir = f"featureout/{dataset}_{scene}"
    database_path = f"{feature_dir}/colmap.db"
    output_path = f"{feature_dir}/colmap_rec"
    # incremental_mapping opens the database itself, so no extra connection
    # (which would never be closed) is opened here.
    gc.collect()

    t = time()
    pipeline_options = _make_pipeline_options()
    os.makedirs(output_path, exist_ok=True)
    maps = pycolmap.incremental_mapping(
        database_path=database_path,
        image_path=img_dir,
        output_path=output_path,
        options=pipeline_options,
    )
    print(maps)
    print(f"Reconstruction done in  {time() - t:.4f} sec")

    scene_result = {}
    best_rec = _largest_reconstruction(maps)
    if best_rec is not None:
        print(best_rec.summary())
        for im in best_rec.images.values():
            key = f"{dataset}/{scene}/images/{im.name}"
            # np.array copies, so the result does not alias pycolmap-owned memory.
            scene_result[key] = {
                "R": np.array(im.cam_from_world.rotation.matrix()),
                "t": np.array(im.cam_from_world.translation),
            }

    gc.collect()
    reconst_time = time() - reconst_start
    return scene_result, reconst_time

# Main loop: feature matching (and, outside DEBUG, reconstruction)

In [24]:
# Main loop: build a COLMAP database (keypoints + matches) for every scene.
# With DEBUG set, only matching runs here and reconstruction happens in the
# separate reconstruction cell. Otherwise each scene's reconstruction is
# handed to a process pool right after its database is written, so mapping of
# one scene overlaps with matching of the next.
datasets = list(data_dict)
time_dict = {}

if not DEBUG:
    out_results = defaultdict(dict)
    total_start = time()
    with concurrent.futures.ProcessPoolExecutor(max_workers=NUM_CORES) as pool:
        pending = defaultdict(dict)
        for dataset, scene in all_scenes:
            print(dataset, scene)
            time_dict["matching-" + scene] = generate_scene_db(dataset, scene)
            pending[dataset][scene] = pool.submit(reconstruct_from_db, dataset, scene)

        # Wait for the workers in submission order; reconstruct_from_db
        # returns None for scenes whose image directory is missing.
        for dataset, scene in all_scenes:
            result = pending[dataset][scene].result()
            if result is None:
                continue
            out_results[dataset][scene], time_dict["reconst-" + scene] = result
    time_dict["TOTAL"] = time() - total_start
else:
    matching_start = time()
    for dataset, scene in all_scenes:
        print(dataset, scene)
        time_dict["matching-" + scene] = generate_scene_db(dataset, scene)
    time_dict["matching-TOTAL"] = time() - matching_start

2cfa01ab573141e4 2fa124afd1f74f38
Image dir does not exist: /kaggle/input/image-matching-challenge-2023/test/2cfa01ab573141e4/2fa124afd1f74f38/images
Image dir does not exist: /kaggle/input/image-matching-challenge-2023/test/2cfa01ab573141e4/2fa124afd1f74f38/images


# Reconstruction loop (DEBUG mode only)

In [25]:
# DEBUG-mode reconstruction: when DEBUG is set the main-loop cell only builds
# the databases, so every scene is reconstructed here in a process pool.
if DEBUG:
    reconst_start = time()
    out_results = defaultdict(dict)
    with concurrent.futures.ProcessPoolExecutor(max_workers=NUM_CORES) as pool:
        scene_futures = {}
        for dataset, scene in all_scenes:
            scene_futures[dataset, scene] = pool.submit(reconstruct_from_db, dataset, scene)
            print("Submitted", dataset, scene)
        for dataset, scene in all_scenes:
            result = scene_futures[dataset, scene].result()
            # None means the scene's image directory was missing.
            if result is not None:
                out_results[dataset][scene], time_dict["reconst-" + scene] = result
    time_dict["reconst-TOTAL"] = time() - reconst_start

In [26]:
# Write the submission CSV from the poses collected in out_results.
# NOTE(review): the evaluation cell reads "submission_train.csv" when
# MODE == "train", so create_submission presumably writes that file in train
# mode (and submission.csv otherwise) — confirm against its definition.
create_submission(out_results, data_dict, MODE)

# Evaluation Code
Adapted from https://www.kaggle.com/code/eduardtrulls/imc2023-evaluation

In [27]:
import numpy as np
from dataclasses import dataclass
from time import time


def arr_to_str(a):
    """Flatten array ``a`` and join the ``str()`` of each element with ``';'``."""
    return ";".join(map(str, a.reshape(-1)))


# Evaluation metric.


@dataclass
class Camera:
    """Pose of one image, as parsed from a pose CSV by ``dict_from_csv``.

    The fields are annotated as ``np.ndarray``. ``np.array`` is a factory
    function, not a type, so it is not a meaningful annotation.
    """

    rotmat: np.ndarray  # rotation matrix; dict_from_csv reshapes it to (3, 3)
    tvec: np.ndarray  # translation vector


def quaternion_from_matrix(matrix):
    """Return the quaternion ``[w, x, y, z]`` of a rotation matrix.

    Only the upper-left 3x3 block of ``matrix`` is read, so a 3x3 rotation
    or a 4x4 homogeneous transform both work. The quaternion is the
    eigenvector of a symmetric 4x4 matrix built from the rotation entries that
    belongs to the largest eigenvalue. Its sign is chosen so that ``w >= 0``.

    Args:
        matrix: array-like rotation (any numeric dtype, list or ndarray).

    Returns:
        ndarray of shape (4,), ordered (w, x, y, z).
    """
    # np.asarray copies only when needed. np.array(..., copy=False) raises
    # under NumPy 2 whenever a copy is required (lists, float32/int input).
    M = np.asarray(matrix, dtype=np.float64)[:4, :4]
    m00 = M[0, 0]
    m01 = M[0, 1]
    m02 = M[0, 2]
    m10 = M[1, 0]
    m11 = M[1, 1]
    m12 = M[1, 2]
    m20 = M[2, 0]
    m21 = M[2, 1]
    m22 = M[2, 2]

    # Symmetric matrix K. Only the lower triangle is filled because
    # np.linalg.eigh reads the lower triangle by default.
    K = np.array(
        [
            [m00 - m11 - m22, 0.0, 0.0, 0.0],
            [m01 + m10, m11 - m00 - m22, 0.0, 0.0],
            [m02 + m20, m12 + m21, m22 - m00 - m11, 0.0],
            [m21 - m12, m02 - m20, m10 - m01, m00 + m11 + m22],
        ]
    )
    K /= 3.0

    # Quaternion is the eigenvector of K for the largest eigenvalue. Rows are
    # stored as (x, y, z, w) and reordered to (w, x, y, z).
    w, V = np.linalg.eigh(K)
    q = V[[3, 0, 1, 2], np.argmax(w)]

    # q and -q encode the same rotation; pick the one with w >= 0.
    if q[0] < 0.0:
        np.negative(q, q)
    return q


def evaluate_R_t(R_gt, t_gt, R, t, eps=1e-15):
    """Rotation and translation error between a ground-truth and a predicted pose.

    Args:
        R_gt, t_gt: ground-truth rotation matrix and translation.
        R, t: predicted rotation matrix and translation.
        eps: small constant added to norms and used as the floor of the
            quaternion loss.

    Returns:
        ``(err_q, err_t)``:
        - ``err_q`` is the rotation error in degrees, computed from unit
          quaternions in a sign-invariant way.
        - ``err_t`` is the translation error after rescaling the predicted
          translation to the ground-truth norm, taking the better of the two
          signs.
    """

    def _unit(vec):
        # eps in the denominator keeps a zero vector finite.
        return vec / (np.linalg.norm(vec) + eps)

    q_true = _unit(quaternion_from_matrix(R_gt))
    q_pred = _unit(quaternion_from_matrix(R))
    # 1 - <q, q_gt>^2 is identical for q and -q. Flooring it at eps keeps the
    # arccos argument strictly below 1.
    loss_q = np.maximum(eps, (1.0 - np.sum(q_pred * q_true) ** 2))
    err_q = np.arccos(1 - 2 * loss_q)

    t_true = t_gt.flatten()
    # Predicted translation is only defined up to scale: match the GT norm.
    t_pred = np.linalg.norm(t_true) * _unit(t.flatten())
    err_t = min(np.linalg.norm(t_true - t_pred), np.linalg.norm(t_true + t_pred))

    return np.degrees(err_q), err_t


def compute_dR_dT(R1, T1, R2, T2):
    """Relative pose of camera 2 with respect to camera 1.

    Given absolute ``(R, T)`` for two cameras, returns ``(dR, dT)`` with
    ``dR = R2 @ R1.T`` and ``dT = T2 - dR @ T1``.
    """
    rel_rot = np.dot(R2, R1.T)
    return rel_rot, T2 - np.dot(rel_rot, T1)


def compute_mAA(err_q, err_t, ths_q, ths_t):
    """Compute the mean average accuracy over a set of thresholds. Additionally returns the metric only over rotation and translation.

    Args:
        err_q: per-pair rotation errors (sequence or array).
        err_t: per-pair translation errors, aligned with ``err_q``.
        ths_q, ths_t: rotation / translation thresholds, consumed pairwise
            via ``zip``.

    Returns:
        Three arrays ``(acc, acc_q, acc_t)``, each with one accuracy per
        threshold pair:
        - ``acc``: fraction of pairs passing both thresholds.
        - ``acc_q``: fraction passing the rotation threshold.
        - ``acc_t``: fraction passing the translation threshold.
        Empty error inputs yield NaN accuracies (mean of an empty array).
    """
    # Convert once so plain Python lists compare element-wise even when the
    # thresholds are plain floats (list <= float raises TypeError).
    err_q = np.asarray(err_q)
    err_t = np.asarray(err_t)

    acc, acc_q, acc_t = [], [], []
    for th_q, th_t in zip(ths_q, ths_t):
        cur_acc_q = err_q <= th_q
        cur_acc_t = err_t <= th_t
        cur_acc = cur_acc_q & cur_acc_t

        acc.append(cur_acc.astype(np.float32).mean())
        acc_q.append(cur_acc_q.astype(np.float32).mean())
        acc_t.append(cur_acc_t.astype(np.float32).mean())
    return np.array(acc), np.array(acc_q), np.array(acc_t)


def dict_from_csv(csv_path, has_header):
    """Parse a pose CSV into ``{dataset: {scene: {image: Camera}}}``.

    Each data row has five comma-separated fields
    ``image,dataset,scene,R,T``:
    - ``R`` holds 9 ``;``-separated values, reshaped to 3x3.
    - ``T`` holds ``;``-separated values.
    Blank (or whitespace-only) lines are skipped.

    Args:
        csv_path: path of the CSV file.
        has_header: if True, the first line is skipped.

    Returns:
        Nested dict of ``Camera`` objects.

    Raises:
        ValueError: if a non-blank row does not have exactly five fields or
        ``R`` does not hold 9 values.
    """
    csv_dict = {}
    with open(csv_path, "r") as f:
        for i, raw_line in enumerate(f):
            if has_header and i == 0:
                continue
            line = raw_line.strip()
            # Lines read from a file always contain at least "\n", so the
            # emptiness check must be made after stripping.
            if not line:
                continue
            image, dataset, scene, R_str, T_str = line.split(",")
            R = np.fromstring(R_str.strip(), sep=";").reshape(3, 3)
            T = np.fromstring(T_str.strip(), sep=";")
            scene_dict = csv_dict.setdefault(dataset, {}).setdefault(scene, {})
            scene_dict[image] = Camera(rotmat=R, tvec=T)
    return csv_dict


def eval_submission(
    submission_csv_path,
    ground_truth_csv_path,
    rotation_thresholds_degrees_dict,
    translation_thresholds_meters_dict,
    verbose=False,
):
    """Compute final metric given submission and ground truth files. Thresholds are specified per dataset.

    Every (dataset, scene, image) in the ground truth must appear in the
    submission; this is checked with ``assert``. Scoring proceeds as follows:
    - Within each scene, all unordered image pairs are compared through their
      relative poses.
    - A scene's score is the mean accuracy over its threshold pairs, looked
      up by ``(dataset, scene)``.
    - Scenes are averaged per dataset, and datasets are averaged into the
      returned value.

    Args:
        submission_csv_path: submission CSV with a header row.
        ground_truth_csv_path: ground-truth CSV with a header row.
        rotation_thresholds_degrees_dict: ``(dataset, scene)`` -> rotation thresholds.
        translation_thresholds_meters_dict: ``(dataset, scene)`` -> translation thresholds.
        verbose: print per-scene, per-dataset and final metrics.

    Returns:
        Mean mAA over datasets.
    """
    predictions = dict_from_csv(submission_csv_path, has_header=True)
    ground_truth = dict_from_csv(ground_truth_csv_path, has_header=True)

    # Fail fast if any ground-truth dataset / scene / image has no prediction.
    for dataset, gt_scenes in ground_truth.items():
        assert dataset in predictions, f"Unknown dataset: {dataset}"
        for scene, gt_images in gt_scenes.items():
            assert (
                scene in predictions[dataset]
            ), f"Unknown scene: {dataset}->{scene}"
            for image in gt_images:
                assert (
                    image in predictions[dataset][scene]
                ), f"Unknown image: {dataset}->{scene}->{image}"

    if verbose:
        started = time()
        print("*** METRICS ***")

    dataset_scores = []
    for dataset, gt_scenes in ground_truth.items():
        scene_scores = []
        for scene, gt_cameras in gt_scenes.items():
            pred_cameras = predictions[dataset][scene]
            names = list(gt_cameras)
            rot_errors, trans_errors = [], []
            # Every unordered pair (a, b) with a listed before b.
            for idx, name_a in enumerate(names):
                for name_b in names[idx + 1:]:
                    gt_a, gt_b = gt_cameras[name_a], gt_cameras[name_b]
                    dR_gt, dT_gt = compute_dR_dT(
                        gt_a.rotmat, gt_a.tvec, gt_b.rotmat, gt_b.tvec
                    )
                    pred_a, pred_b = pred_cameras[name_a], pred_cameras[name_b]
                    dR_pred, dT_pred = compute_dR_dT(
                        pred_a.rotmat, pred_a.tvec, pred_b.rotmat, pred_b.tvec
                    )
                    rot_err, trans_err = evaluate_R_t(dR_gt, dT_gt, dR_pred, dT_pred)
                    rot_errors.append(rot_err)
                    trans_errors.append(trans_err)

            key = (dataset, scene)
            mAA, mAA_q, mAA_t = compute_mAA(
                err_q=rot_errors,
                err_t=trans_errors,
                ths_q=rotation_thresholds_degrees_dict[key],
                ths_t=translation_thresholds_meters_dict[key],
            )
            if verbose:
                print(
                    f"{dataset} / {scene} ({len(names)} images, {len(rot_errors)} pairs) -> mAA={np.mean(mAA):.06f}, mAA_q={np.mean(mAA_q):.06f}, mAA_t={np.mean(mAA_t):.06f}"
                )
            scene_scores.append(np.mean(mAA))

        dataset_scores.append(np.mean(scene_scores))
        if verbose:
            print(f"{dataset} -> mAA={np.mean(scene_scores):.06f}")
            print()

    if verbose:
        print(
            f"Final metric -> mAA={np.mean(dataset_scores):.06f} (t: {time() - started} sec.)"
        )
        print()

    return np.mean(dataset_scores)


# Per-scene mAA thresholds for the IMC 2023 train scenes, keyed by
# (dataset, scene):
# - 10 rotation thresholds in degrees (linearly spaced);
# - 10 translation thresholds in meters (geometrically spaced).
# Each entry holds ((rot_lo, rot_hi), (trans_lo, trans_hi)).
_THRESHOLD_RANGES = {
    ("haiper", "bike"): ((1, 10), (0.05, 0.5)),
    ("haiper", "chairs"): ((1, 10), (0.05, 0.5)),
    ("haiper", "fountain"): ((1, 10), (0.05, 0.5)),
    ("heritage", "cyprus"): ((1, 10), (0.1, 2)),
    ("heritage", "dioscuri"): ((1, 10), (0.1, 2)),
    ("heritage", "wall"): ((0.2, 10), (0.05, 1)),
    ("urban", "kyiv-puppet-theater"): ((1, 10), (0.5, 5)),
}

rotation_thresholds_degrees_dict = {
    key: np.linspace(*rot_range, 10)
    for key, (rot_range, _) in _THRESHOLD_RANGES.items()
}

translation_thresholds_meters_dict = {
    key: np.geomspace(*trans_range, 10)
    for key, (_, trans_range) in _THRESHOLD_RANGES.items()
}

# Run evaluation and log outputs 

In [28]:
%%capture cap --no-stderr

import datetime

# Everything printed below is captured into `cap` by the %%capture magic and
# becomes the run log. It contains the configuration, registration counts,
# timings and, in train mode, the local mAA evaluation.
current_time = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
print("=========================================")
print(current_time)
print(FM_PARAMS)
print(LOG_DICT)
print(MODEL_DICT)
print("Match filter ratio = ", MATCH_FILTER_RATIO)
for dataset, scene_poses in out_results.items():
    for scene, poses in scene_poses.items():
        n_expected = len(data_dict[dataset][scene])
        print(f"Registered: {dataset} / {scene} -> {len(poses)}/{n_expected} images")

print(time_dict)

if MODE == "train":
    # train_labels.csv rows are dataset,scene,image,R,T, but dict_from_csv
    # expects the image column first, so the rows are rewritten in that order.
    gt_header = "image_path,dataset,scene,rotation_matrix,translation_vector\n"
    with open(f"{SRC}/train/train_labels.csv", "r") as fr, open("ground_truth.csv", "w") as fw:
        for row_idx, row in enumerate(fr):
            if row_idx == 0:
                fw.write(gt_header)
                continue
            dataset, scene, image, R, T = row.strip().split(",")
            fw.write(f"{image},{dataset},{scene},{R},{T}\n")

    eval_submission(
        submission_csv_path="submission_train.csv",
        ground_truth_csv_path="ground_truth.csv",
        rotation_thresholds_degrees_dict=rotation_thresholds_degrees_dict,
        translation_thresholds_meters_dict=translation_thresholds_meters_dict,
        verbose=True,
    )



Print the captured log and, on local runs, append it to a text file

In [29]:
# Echo the captured run log and, when running locally, append it to log.txt.
run_log = str(cap)
print(run_log)
if is_local:
    with open("log.txt", "a") as log_file:
        log_file.write(run_log)

2025-06-02_08-24-12
{'ransacReprojThreshold': 5, 'confidence': 0.9999, 'maxIters': 50000, 'removeOutliers': True}
{'mode': 'test', 'log_message': 'Final submission', 'matches_cap': None, 'debug': False, 'debug_scene': ['kyiv-puppet-theater']}
{'Keynet': {'enable': True, 'resize_long_edge_to': 1600, 'pair_only': False}, 'GFTT': {'enable': True, 'resize_long_edge_to': 1600}, 'DoG': {'enable': True, 'resize_long_edge_to': 1600}, 'Harris': {'enable': True, 'resize_long_edge_to': 1600}, 'MASt3R': {'enable': True, 'resize_long_edge_to': 512}}
Match filter ratio =  0.01
{'matching-2fa124afd1f74f38': None, 'TOTAL': 0.16366934776306152}



# mast3r

In [30]:
# def run_mast3r_inference(image_paths, save_dir, model, device, 
#                         scenegraph_type="swin", winsize=1, win_cyclic=False,
#                         optim_level="refine+depth", lr1=0.07, niter1=1000, lr2=0.014, niter2=400,
#                         min_conf_thr=1.5, matching_conf_thr=5, shared_intrinsics=True):
#     """Run MASt3R inference and save results in COLMAP format
    
#     Args:
#         image_paths: List of paths to input images
#         save_dir: Directory to save results
#         model: MASt3R model
#         device: Device to run inference on
#         scenegraph_type: Type of scene graph to use ("swin", "logwin", "oneref", "retrieval")
#         winsize: Window size for scene graph
#         win_cyclic: Whether to use cyclic window
#         optim_level: Optimization level ("coarse", "fine", "coarse+fine")
#         lr1: Learning rate for first optimization stage
#         niter1: Number of iterations for first optimization stage
#         lr2: Learning rate for second optimization stage
#         niter2: Number of iterations for second optimization stage
#         min_conf_thr: Minimum confidence threshold for points
#         matching_conf_thr: Matching confidence threshold
#         shared_intrinsics: Whether to use shared intrinsics
#     """
#     # Load images
#     imgs = load_images(image_paths, size=512, verbose=True)
    
#     # Create scene graph
#     scene_graph_params = [scenegraph_type]
#     if scenegraph_type in ["swin", "logwin"]:
#         scene_graph_params.append(str(winsize))
#         if not win_cyclic:
#             scene_graph_params.append('noncyclic')
#     elif scenegraph_type == "oneref":
#         scene_graph_params.append("0")  # Use first image as reference
    
#     scene_graph = '-'.join(scene_graph_params)
    
#     # Create image pairs
#     pairs = make_pairs(imgs, scene_graph=scene_graph, prefilter=None, symmetrize=True)
    
#     # Run sparse global alignment
#     if optim_level == 'coarse':
#         niter2 = 0
    
#     scene = sparse_global_alignment(
#         image_paths, pairs, save_dir,
#         model, lr1=lr1, niter1=niter1, lr2=lr2, niter2=niter2, device=device,
#         opt_depth='depth' in optim_level, shared_intrinsics=shared_intrinsics,
#         matching_conf_thr=matching_conf_thr
#     )
    
#     return scene




In [31]:
# def run_mast3r_reconstruction():
#     """Run MASt3R reconstruction and create submission.csv"""
#     print("Running MASt3R reconstruction...")
    
#     # Initialize output results dictionary
#     out_results = defaultdict(dict)
    
#     # Process each scene
#     for dataset, scene in all_scenes:
#         print(f"\nProcessing {dataset}/{scene}")
        
#         # Get image paths
#         img_dir = f"{SRC}/{MODE}/{dataset}/{scene}/images"
#         if not os.path.exists(img_dir):
#             print("Image dir does not exist:", img_dir)
#             continue
            
#         img_fnames = [f"{SRC}/{MODE}/{x}" for x in data_dict[dataset][scene]]
#         print(f"Got {len(img_fnames)} images")
        
#         # Create save directory
#         save_dir = f"mast3r_out/{dataset}_{scene}"
#         os.makedirs(save_dir, exist_ok=True)
        
#         try:
#             # Run MASt3R inference
#             scene_obj = run_mast3r_inference(
#                 img_fnames,
#                 save_dir,
#                 model,
#                 device='cuda',
#                 scenegraph_type='swin',
#                 winsize=2,  # 增加窗口大小
#                 optim_level='refine+depth',
#                 lr1=0.07,
#                 niter1=1000,  # 增加迭代次數
#                 lr2=0.014,   # 調整學習率
#                 niter2=500,  # 增加迭代次數
#                 shared_intrinsics=False,
#                 matching_conf_thr=5.0  # 增加匹配置信度閾值
#             )
            
#             c2w_poses = scene_obj.get_im_poses()
            
#             for i, img_path in enumerate(img_fnames):
#                 img_name = img_path.split("/")[-1]
                
#                 w2c = np.linalg.inv(c2w_poses[i])

#                 R = w2c[:3, :3]
#                 t = w2c[:3, 3]
                
#                 out_results[dataset][scene][img_name] = {
#                     "R": R,  # world2cam rotation
#                     "t": t   # world2cam translation
#                 }
                
#             print(f"Registered: {dataset}/{scene} -> {len(out_results[dataset][scene])}/{len(img_fnames)} images")
            
#         except Exception as e:
#             print(f"Error processing {dataset}/{scene}:")
#             print(e)
#             continue
            
#         gc.collect()
#         torch.cuda.empty_cache()
    
#     # Create submission file
#     create_submission(out_results, data_dict, MODE)
    
#     # Evaluate if in training mode
#     if MODE == "train":
#         with open(f"{SRC}/train/train_labels.csv", "r") as fr, open("ground_truth.csv", "w") as fw:
#             for i, l in enumerate(fr):
#                 if i == 0:
#                     fw.write(
#                         "image_path,dataset,scene,rotation_matrix,translation_vector\n"
#                     )
#                 else:
#                     dataset, scene, image, R, T = l.strip().split(",")
#                     fw.write(f"{image},{dataset},{scene},{R},{T}\n")

#         eval_submission(
#             submission_csv_path="submission_train.csv",
#             ground_truth_csv_path="ground_truth.csv",
#             rotation_thresholds_degrees_dict=rotation_thresholds_degrees_dict,
#             translation_thresholds_meters_dict=translation_thresholds_meters_dict,
#             verbose=True
#         )

In [32]:
# run_mast3r_reconstruction()