# ***Install libraries***

In [None]:
%%capture
# Install pinned wheels from the attached Kaggle input datasets; %%capture hides pip's output.
#dry_run = False
# kornia provides LoFTR; kornia_moons provides the match-drawing helper (draw_LAF_matches) used below.
!pip install ../input/kornia-loftr/kornia-0.6.4-py2.py3-none-any.whl
!pip install ../input/kornia-loftr/kornia_moons-0.1.9-py3-none-any.whl

# for depth estimation module
# Copy the MiDaS repo and DPT-Large weights into the torch.hub cache layout,
# presumably so torch.hub.load("intel-isl/MiDaS", ...) can use the local copy
# instead of downloading — NOTE(review): confirm hub picks up this directory name.
!mkdir -p /root/.cache/torch/hub/checkpoints
!cp -r ../input/midasdepthestimation/MiDaS-master  /root/.cache/torch/hub/intel-isl_MiDaS_master
!cp ../input/midasdepthestimation/dpt_large-midas-2f21e586.pt  /root/.cache/torch/hub/checkpoints/
# timm is imported for the depth-estimation model.
!pip install ../input/midasdepthestimation/timm-0.5.4-py3-none-any.whl

# ***Import dependencies***

In [None]:
import os
import sys
import csv
from glob import glob
import gc
import random
from tqdm import tqdm
import matplotlib.pyplot as plt

import numpy as np
import cv2
import torch

# for LoFTR
import kornia
from kornia_moons.feature import *
import kornia as K
import kornia.feature as KF

# for depth-estimation
import timm

In [None]:
sys.path.append("../input/imcutils")
from imc_metric import EvaluateSubmission, ReadCovisibilityData, FlattenMatrix, LoadCalibration

sys.path.append("../input/super-glue-pretrained-network")
from models.matching import Matching as SuperGlue
from models.utils import (compute_pose_error, compute_epipolar_error,
                          estimate_pose, make_matching_plot,
                          error_colormap, AverageTimer, pose_auc, read_image,
                          rotate_intrinsics, rotate_pose_inplane,
                          scale_intrinsics)

In [None]:
# for pytorch3d
sys.path.append("../input/pytorch3ddependencies/pytorch3d_dependencies")
os.environ["CUB_HOME"] = "../input/pytorch3ddependencies/pytorch3d_dependencies/cub-1.10.0"
import pytorch3d

In [None]:
from pytorch3d.renderer.cameras import (
    PerspectiveCameras,
)
from pytorch3d.transforms.so3 import (
    so3_exp_map
)

# ***Models (LoFTR, SuperGlue, MiDaS)***

In [None]:
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# LoFTR with the outdoor weights shipped in the offline kornia-loftr dataset.
matcher = KF.LoFTR(pretrained=None)
# map_location lets a checkpoint saved from GPU tensors also load on a CPU-only host.
matcher.load_state_dict(
    torch.load("../input/kornia-loftr/loftr_outdoor.ckpt", map_location=device)['state_dict']
)
matcher = matcher.to(device)
matcher.eval();

In [None]:
# SuperPoint keypoint detector + SuperGlue matcher (outdoor weights),
# put in inference mode and moved to `device`.
config = dict(
    superpoint=dict(
        nms_radius=4,
        keypoint_threshold=0.005,
        max_keypoints=1024,
    ),
    superglue=dict(
        weights="outdoor",
        sinkhorn_iterations=20,
        match_threshold=0.2,
    ),
)
superglue = SuperGlue(config)
superglue.eval()
superglue = superglue.to(device)

In [None]:
def match(img_path0, img_path1, matcher, device=device):
    """Run LoFTR on two image files and return the matched keypoint coordinates.

    Each image is loaded with `load_torch_image` and converted to a
    single-channel grayscale tensor on `device` before matching.

    Returns:
        (mkpts0, mkpts1): numpy arrays of matched keypoints, expressed in the
        resized frames produced by `load_torch_image`. Row i of mkpts0 pairs
        with row i of mkpts1.
    """
    batch = {
        key: K.color.rgb_to_grayscale(load_torch_image(path)).to(device)
        for key, path in (("image0", img_path0), ("image1", img_path1))
    }

    # Pure inference: no gradients needed.
    with torch.no_grad():
        result = matcher(batch)

    return tuple(result[name].cpu().numpy() for name in ("keypoints0", "keypoints1"))

def superglue_match(img_path0, img_path1, matcher, device=device):
    """Match two image files with SuperPoint + SuperGlue.

    Images are read with the SuperGlue repo's `read_image` using
    ``resize=[-1]`` and no rotation. NOTE(review): ``[-1]`` presumably keeps
    the original resolution, so these keypoints are likely NOT in the same
    frame as the LoFTR keypoints from `match` — confirm before mixing them.

    Returns:
        (mkpts0, mkpts1): numpy arrays of matched keypoint coordinates; row i
        of mkpts0 pairs with row i of mkpts1.
    """
    resize = [-1, ]
    resize_float = True
    _, inp0, _ = read_image(img_path0, device, resize, 0, resize_float)
    _, inp1, _ = read_image(img_path1, device, resize, 0, resize_float)

    # Inference only: avoid building (and holding) an autograd graph.
    with torch.no_grad():
        pred = matcher({"image0": inp0, "image1": inp1})
    pred = {k: v[0].detach().cpu().numpy() for k, v in pred.items()}
    kpts0, kpts1 = pred["keypoints0"], pred["keypoints1"]
    matches = pred["matches0"]

    # matches0[i] is the index into keypoints1 matched to keypoints0[i]; -1 marks "no match".
    valid = matches > -1
    return kpts0[valid], kpts1[matches[valid]]

def get_F_matrix(mkpts0, mkpts1, ransac_thresh=0.200, confidence=0.9999, max_iters=250000):
    """Estimate the fundamental matrix from two arrays of matched keypoints.

    Uses OpenCV's USAC_MAGSAC robust estimator. An all-zero 3x3 matrix is
    returned when there are 8 or fewer matches, or when OpenCV does not
    produce a single 3x3 solution. Callers therefore always get a (3, 3)
    array they can write to the submission.

    Args:
        mkpts0, mkpts1: matched point coordinates; row i of each forms a pair.
        ransac_thresh: inlier reprojection threshold passed to OpenCV.
        confidence: desired RANSAC confidence passed to OpenCV.
        max_iters: maximum RANSAC iterations passed to OpenCV.

    Returns:
        np.ndarray of shape (3, 3).
    """
    # Too few correspondences: don't call OpenCV at all.
    if len(mkpts0) <= 8:
        return np.zeros((3, 3))

    F, _inliers = cv2.findFundamentalMat(
        mkpts0, mkpts1, cv2.USAC_MAGSAC, ransac_thresh, confidence, max_iters
    )
    # findFundamentalMat can return None when estimation fails, and may stack
    # several candidate solutions vertically; neither is a usable 3x3 F.
    if F is None or F.shape != (3, 3):
        return np.zeros((3, 3))
    return F

In [None]:
# MiDaS DPT-Large monocular depth network (on `device`, inference mode)
# plus the input transform that matches it, both obtained via torch.hub.
model_type = "DPT_Large"
depth_estimator = torch.hub.load("intel-isl/MiDaS", model_type)
depth_estimator.to(device).eval()

midas_transforms = torch.hub.load("intel-isl/MiDaS", "transforms")
transform = midas_transforms.dpt_transform

def estimate_depth(filepath, depth_estimator, transform, max_side=640):
    """Predict a MiDaS depth map for an image file.

    The image is converted to RGB and downscaled so its longer side is
    `max_side` pixels. It is then passed through `transform` and
    `depth_estimator` on the global `device`. The prediction is bicubically
    resized back to the original image resolution.

    Args:
        filepath: path to an image readable by cv2.imread.
        depth_estimator: depth model; assumed to return a (1, h, w) tensor — TODO confirm.
        transform: input transform matching the model (e.g. MiDaS dpt_transform).
        max_side: length in pixels of the longer side fed to the transform.

    Returns:
        2-D numpy array (H, W) at the original image resolution.
        NOTE(review): MiDaS predicts relative (not metric) depth.

    Raises:
        FileNotFoundError: if OpenCV cannot read `filepath`.
    """
    img = cv2.imread(filepath)
    if img is None:
        raise FileNotFoundError(f"Could not read image: {filepath}")
    # cv2 loads BGR, but the MiDaS transforms expect RGB input.
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    scale = max_side / max(img.shape[0], img.shape[1])
    w = int(img.shape[1] * scale)
    h = int(img.shape[0] * scale)
    resized_img = cv2.resize(img, (w, h))
    input_batch = transform(resized_img).to(device)

    with torch.no_grad():
        prediction = depth_estimator(input_batch)
        # Upsample back to the ORIGINAL image size (not the resized one).
        prediction = torch.nn.functional.interpolate(
            prediction.unsqueeze(1),
            size=img.shape[:2],
            mode="bicubic",
            align_corners=False,
        ).squeeze()

    return prediction.cpu().numpy()

# ***Utilities***

In [None]:
src = '/kaggle/input/image-matching-challenge-2022/'

# Each data row of test.csv is unpacked later as
# (sample_id, batch_id, image_1_id, image_2_id); the header row is dropped.
with open(f'{src}/test.csv') as f:
    reader = csv.reader(f, delimiter=',')
    next(reader, None)  # skip header (no-op for an empty file)
    test_samples = list(reader)


def FlattenMatrix(M, num_digits=8):
    """Serialise a matrix as space-separated scientific-notation values for the CSV.

    Shadows the `FlattenMatrix` imported from imc_metric.

    Args:
        M: array-like exposing `.flatten()` (e.g. a numpy array).
        num_digits: digits after the decimal point in each value.
    """
    spec = f'.{num_digits}e'
    return ' '.join(format(value, spec) for value in M.flatten())


def load_torch_image(fname, long_side=840):
    """Read an image file into a normalised RGB float tensor for LoFTR.

    The image is resized (up or down) so its longer side is `long_side`
    pixels, keeping the aspect ratio up to integer truncation.

    Args:
        fname: path to an image readable by cv2.imread.
        long_side: target length in pixels of the longer image side.

    Returns:
        Float tensor of shape (1, 3, H, W), RGB order, values scaled by 1/255.

    Raises:
        FileNotFoundError: if OpenCV cannot read `fname`.
    """
    img = cv2.imread(fname)
    if img is None:
        # cv2.imread signals failure by returning None rather than raising.
        raise FileNotFoundError(f"Could not read image: {fname}")
    scale = long_side / max(img.shape[0], img.shape[1])
    w = int(img.shape[1] * scale)
    h = int(img.shape[0] * scale)
    img = cv2.resize(img, (w, h))
    img = K.image_to_tensor(img, False).float() / 255.
    img = K.color.bgr_to_rgb(img)
    return img

# ***Inference***

In [None]:
import time

F_dict = {}
depth_points_dict = {}

for i, row in enumerate(test_samples):
    sample_id, batch_id, image_1_id, image_2_id = row
    st = time.time()

    # Load both images (resized by load_torch_image).
    image1_filepath = f'{src}/test_images/{batch_id}/{image_1_id}.png'
    image2_filepath = f'{src}/test_images/{batch_id}/{image_2_id}.png'
    image_1 = load_torch_image(image1_filepath).to(device)
    image_2 = load_torch_image(image2_filepath).to(device)

    # Image tensors are (1, C, H, W): dim 2 is height, dim 3 is width.
    img1_h, img1_w = image_1.shape[2], image_1.shape[3]
    img2_h, img2_w = image_2.shape[2], image_2.shape[3]

    # estimate_depth returns maps at the ORIGINAL image resolution, while the
    # LoFTR keypoints below are in the resized frame of load_torch_image.
    # Resample the depth maps into that frame so they can be indexed directly
    # with keypoint coordinates (cv2.resize takes (width, height)).
    depth1 = cv2.resize(estimate_depth(image1_filepath, depth_estimator, transform), (img1_w, img1_h))
    depth2 = cv2.resize(estimate_depth(image2_filepath, depth_estimator, transform), (img2_w, img2_h))

    # LoFTR matches between the two resized images.
    mkpts0, mkpts1 = match(image1_filepath, image2_filepath, matcher, device)

    depth_points_dict[sample_id] = {
        "depth1": depth1,
        "depth2": depth2,
        "image1_filepath": image1_filepath,
        "image2_filepath": image2_filepath,
        # NOTE(review): stored as (width, height) in pixels; PyTorch3D cameras
        # read principal_point in NDC by default — confirm this is intended.
        "principal_point1": torch.Tensor((img1_w, img1_h)),
        "principal_point2": torch.Tensor((img2_w, img2_h)),
        "points1": mkpts0,
        "points2": mkpts1,
    }

    F, inliers = None, None
    if len(mkpts0) > 7:
        F, inliers = cv2.findFundamentalMat(mkpts0, mkpts1, cv2.USAC_MAGSAC, 0.200, 0.9999, 250000)
    if F is None or F.shape != (3, 3):
        # Too few matches, or OpenCV could not estimate a single F:
        # submit zeros for this pair instead of crashing the whole run.
        F_dict[sample_id] = np.zeros((3, 3))
        gc.collect()
        continue
    F_dict[sample_id] = F
    inliers = inliers > 0
    gc.collect()

    nd = time.time()
    # Visual sanity check for the first few pairs only.
    if i < 3:
        print("Running time: ", nd - st, " s")
        draw_LAF_matches(
        KF.laf_from_center_scale_ori(torch.from_numpy(mkpts0).view(1,-1, 2),
                                    torch.ones(mkpts0.shape[0]).view(1,-1, 1, 1),
                                    torch.ones(mkpts0.shape[0]).view(1,-1, 1)),

        KF.laf_from_center_scale_ori(torch.from_numpy(mkpts1).view(1,-1, 2),
                                    torch.ones(mkpts1.shape[0]).view(1,-1, 1, 1),
                                    torch.ones(mkpts1.shape[0]).view(1,-1, 1)),
        torch.arange(mkpts0.shape[0]).view(-1,1).repeat(1,2),
        K.tensor_to_image(image_1),
        K.tensor_to_image(image_2),
        inliers,
        draw_dict={'inlier_color': (0.2, 1, 0.2),
                   'tentative_color': None, 
                   'feature_color': (0.2, 0.5, 1), 'vertical': False})

with open('submission.csv', 'w') as f:
    f.write('sample_id,fundamental_matrix\n')
    for sample_id, F in F_dict.items():
        f.write(f'{sample_id},{FlattenMatrix(F)}\n')

In [None]:
def init_values():
    """Create the trainable parameters for a two-camera rig.

    Rows of the rotation (axis-angle / log map) and translation tensors are
    drawn from a standard normal, except row 0 (the reference camera), which
    is zeroed. Focal lengths start at 1. All three are fresh leaf tensors on
    `device` with requires_grad=True.

    Returns:
        (log_R_absolute (2, 3), T_absolute (2, 3), focal_length (2, 2))
    """
    n_cameras = 2
    log_rot = torch.randn(n_cameras, 3, dtype=torch.float32, device=device)
    trans = torch.randn(n_cameras, 3, dtype=torch.float32, device=device)
    focal = torch.ones((n_cameras, 2), dtype=torch.float32, device=device)

    # Camera 0 is the reference frame: zero rotation and translation.
    log_rot[0].zero_()
    trans[0].zero_()

    return tuple(t.detach().clone().requires_grad_(True) for t in (log_rot, trans, focal))

def loss_function(xyz_unproj_world, loss_fn=torch.nn.L1Loss()):
    """Discrepancy between the points unprojected by camera 0 and by camera 1.

    Args:
        xyz_unproj_world: two-entry indexable (e.g. a (2, N, 3) tensor); entry 0
            is compared against entry 1.
        loss_fn: criterion called as ``loss_fn(first, second)``; the default is
            a single L1Loss instance created when the function is defined.
    """
    first_view, second_view = xyz_unproj_world[0], xyz_unproj_world[1]
    return loss_fn(first_view, second_view)
    
def optimization(log_R_absolute, T_absolute, focal_length, principal_points, xy_depth, n_iter = 2000,
                 lr=0.1, momentum=0.9, verbose=False):
    """Jointly optimise the pose and focal lengths of a two-camera rig with SGD.

    Camera 0 is held fixed as the reference: its log-rotation and translation
    are multiplied by a zero mask, so only camera 1's pose receives gradients.
    Both focal lengths are optimised. Each iteration unprojects `xy_depth`
    into world space with both cameras and minimises `loss_function` between
    the two point sets.

    Args:
        log_R_absolute, T_absolute, focal_length: trainable tensors (see init_values).
        principal_points: principal point per camera, passed to PerspectiveCameras.
        xy_depth: per-camera (x, y, depth) points (see make_xy_depth).
        n_iter: number of optimiser steps.
        lr, momentum: SGD hyper-parameters.
        verbose: if True, print the rotation matrices at every iteration.

    Returns:
        PerspectiveCameras built from the parameters AFTER the final step.
    """
    optimizer = torch.optim.SGD([log_R_absolute, T_absolute, focal_length], lr=lr, momentum=momentum)

    # Zero row for camera 0 pins it at identity rotation / zero translation.
    camera_mask = torch.ones(2, 1, dtype=torch.float32, device=device)
    camera_mask[0] = 0.

    def _build_cameras():
        # Rebuild from the current parameter values; they change every step.
        return PerspectiveCameras(
            R=so3_exp_map(log_R_absolute * camera_mask),
            T=T_absolute * camera_mask,
            focal_length=focal_length,
            principal_point=principal_points,
            device=device,
        )

    for it in range(n_iter):
        cameras_absolute = _build_cameras()
        if verbose:
            print(cameras_absolute.R)
        optim_one_iter(cameras_absolute, xy_depth, optimizer)

    # Cameras used inside the loop were built BEFORE their optimiser step, so
    # rebuild once more to return the final poses; this also makes n_iter=0
    # return the initial cameras instead of raising UnboundLocalError.
    return _build_cameras()
    
    
def optim_one_iter(cameras, xy_depth, optimizer):
    """Perform one optimiser step on the unprojection-consistency loss.

    Unprojects `xy_depth` to world coordinates with `cameras`, evaluates
    `loss_function`, backpropagates, and steps `optimizer`. Returns None.
    """
    optimizer.zero_grad()
    world_points = cameras.unproject_points(xy_depth, world_coordinates=True)
    loss_function(world_points).backward()
    optimizer.step()

In [None]:
def make_xy_depth(dic, device=device):
    """Build per-camera (x, y, depth) point sets for the matched keypoints.

    Keypoint coordinates are truncated to integer pixel indices and used to
    sample each depth map at ``[y, x]`` (row, column).

    NOTE(review): this assumes each depth map is in the same pixel frame as
    its keypoints — confirm against how `dic` was built.

    Args:
        dic: dict with "depth1"/"depth2" (2-D arrays) and "points1"/"points2"
            (N x 2 keypoint arrays of equal length).
        device: torch device for the returned tensor.

    Returns:
        Float tensor of shape (2, N, 3): index 0 for image 1, index 1 for
        image 2; each row is (x, y, depth) with integer-valued x and y.
    """
    def _with_depth(depth_map, points):
        # Integer pixel coordinates double as indices into the depth map.
        pix = points.astype(int)
        sampled = torch.Tensor(depth_map[pix[:, 1], pix[:, 0]])
        return torch.cat((torch.Tensor(pix), sampled.unsqueeze(1)), dim=1)

    xy_depth = torch.stack(
        (_with_depth(dic["depth1"], dic["points1"]),
         _with_depth(dic["depth2"], dic["points2"])),
        dim=0,
    ).to(device)
    xy_depth.requires_grad = False
    return xy_depth

In [None]:
# Debug run: optimise the camera pair for the first sample only
# (the [:1] slice limits the loop to a single entry).
for pair_id, depth_points_dic in list(depth_points_dict.items())[:1]:
    xy_depth = make_xy_depth(depth_points_dic)

    principal_point1, principal_point2 = (
        depth_points_dic[key].unsqueeze(0) for key in ("principal_point1", "principal_point2")
    )
    principal_points = torch.cat((principal_point1, principal_point2), dim=0).to(device)
    principal_points.requires_grad = False

    log_R_absolute, T_absolute, focal_length = init_values()
    optimized_cameras = optimization(
        log_R_absolute, T_absolute, focal_length, principal_points, xy_depth
    )

In [None]:
# Inspect the optimised rotation of camera 1 (camera 0 is the masked reference).
optimized_cameras[1].R

In [None]:
# Debug: list the attributes available on a single camera object.
dir(optimized_cameras[0])