In [23]:
import numpy as np
import kornia as K
import kornia.feature as KF
import torch
import torch.nn as nn
import torchvision
from PIL import Image
import skimage.transform
import PIL.Image as pil
import tqdm
import os
os.sys.path.append("/home/data/workspace/heqi/monogastroendo")
from utils import *

# Template for the per-split file lists; the `{}` placeholder is filled with the split name.
fpath = os.path.join(
    "/home/data/workspace/heqi/monogastroendo/splits",
    "c3vd",
    "{}_files.txt",
)
# Root directory under which the per-folder colour frames are looked up.
data_path = "/home/data/workspace/heqi/monogastroendo/rect_c3vd_data"
# Extension appended to every frame file name.
img_ext = ".png"

def pil_loader(path):
    """Read the image stored at `path` and return it converted to RGB.

    The file is opened explicitly and handed to Pillow as a file object so
    that it is closed deterministically, avoiding a ResourceWarning
    (https://github.com/python-pillow/Pillow/issues/835).
    """
    with open(path, 'rb') as handle, Image.open(handle) as picture:
        return picture.convert('RGB')
        
def get_image_path(folder, frame_index_str):
    """Build the path of a colour frame: `<data_path>/<folder>/<frame_index_str>_color<img_ext>`."""
    file_name = f"{frame_index_str}_color{img_ext}"
    return os.path.join(data_path, folder, file_name)

def get_color(folder, frame_index_str, do_flip):
    """Load one colour frame as an RGB PIL image, optionally mirrored.

    Args:
        folder: sub-directory of `data_path` that holds the frame.
        frame_index_str: frame identifier used to build the file name.
        do_flip: when truthy, the image is flipped left-to-right.

    Returns:
        The loaded (and possibly flipped) PIL image.
    """
    color = pil_loader(get_image_path(folder, frame_index_str))

    if do_flip:
        # Prefer the `Transpose` enum where Pillow provides it: the module-level
        # alias triggers the warning visible in this cell's output. Older Pillow
        # releases only have the module-level constant, hence the fallback.
        flip_method = getattr(pil, "Transpose", pil).FLIP_LEFT_RIGHT
        color = color.transpose(flip_method)
    return color

# utils
# Every frame is resized to 256x320 with Lanczos resampling before being turned into a tensor.
resize = torchvision.transforms.Resize((256, 320), interpolation=torchvision.transforms.InterpolationMode.LANCZOS)
to_tensor = torchvision.transforms.ToTensor()

# Preferred GPU. Selecting it unconditionally raises on hosts with fewer GPUs
# (or without CUDA), so fall back to GPU 0 and finally to the CPU.
GPU_INDEX = 1
if torch.cuda.is_available():
    torch.cuda.set_device(GPU_INDEX if GPU_INDEX < torch.cuda.device_count() else 0)
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

class LoFTR(nn.Module):
    """Thin wrapper around `kornia.feature.LoFTR` that matches a pair of RGB images.

    Both inputs are converted to grayscale before being handed to the matcher,
    and the matcher itself runs with gradient tracking disabled.
    """

    def __init__(self, pretrained='indoor'):
        super().__init__()
        self.matcher = KF.LoFTR(pretrained=pretrained)

    def forward(self, src0, srcx):
        """Return the matcher output for the image pair (`src0`, `srcx`)."""
        # LoFTR works on grayscale images only
        gray0 = K.color.rgb_to_grayscale(src0)
        grayx = K.color.rgb_to_grayscale(srcx)
        pair = {"image0": gray0, "image1": grayx}
        with torch.no_grad():
            return self.matcher(pair)
##########################################################

# load data: one list of split-file lines per dataset split
train_filenames, val_filenames = [
    readlines(fpath.format(split)) for split in ("train", "val")
]

# load matcher
# Re-running this cell reuses a matcher that already lives in the kernel
# instead of building a new one.
if "matcher" in globals():
    print("matcher loaded")
else:
    print("load matcher")
    matcher = LoFTR(pretrained="indoor")
    matcher.to(device)

# processing correspondence
# `import tqdm` by itself does not import the `tqdm.notebook` submodule, so
# `tqdm.notebook.tnrange` only works if something else happened to load it.
# `tqdm.auto` picks the notebook or console bar itself. The alias keeps the
# module-level `tqdm` name bound to the package.
from tqdm.auto import trange as _progress_range

_MATCH_ARRAY_KEYS = ("keypoints0", "keypoints1", "confidence")


def _match_to_numpy(match):
    """Drop `batch_indexes` from a matcher output dict and convert the match arrays to NumPy, in place."""
    del match['batch_indexes']
    for key in _MATCH_ARRAY_KEYS:
        match[key] = match[key].detach().cpu().numpy()
    return match


# One entry per training sample under each key; every entry is a two-element
# list: matches of the middle frame against the first and against the third frame.
matcher_result = {"no_flip": [],
                  "do_flip": []}
for i in _progress_range(len(train_filenames)):
    # line[0] is used as the folder, line[1], line[2], line[3] as frame identifiers.
    line = train_filenames[i].split()

    for do_flip in [False, True]:
        img_tensor = [
            to_tensor(resize(get_color(line[0], line[1 + j], do_flip))).to(device)
            for j in range(3)
        ]
        reference = img_tensor[1][None, ...]  # add a leading batch dimension
        # Call the module itself rather than `.forward` so nn.Module hooks are honoured.
        correspondences = [
            _match_to_numpy(matcher(reference, img_tensor[other][None, ...]))
            for other in (0, 2)
        ]
        matcher_result["do_flip" if do_flip else "no_flip"].append(correspondences)

# np.save appends ".npy"; the dict is stored as a pickled 0-d object array.
np.save("matcher_result", matcher_result)

matcher loaded


  0%|          | 0/4632 [00:00<?, ?it/s]

  color = color.transpose(pil.FLIP_LEFT_RIGHT)


In [None]:
# Inspect the `keypoints0` array of the first stored match of the first un-flipped sample.
matcher_result['no_flip'][0][0]["keypoints0"]

In [26]:
# Forget the cached copy (if any) so the loading cell re-reads the file from disk.
# A plain `del` raises NameError on a fresh kernel, where the name does not exist yet.
globals().pop("matcher_result_load", None)

In [6]:
import numpy as np
try:
    matcher_result_load
except NameError:
    # The file holds a dict wrapped in a 0-d object array; `.item()` unwraps it.
    # (`.all()` only returned the dict through legacy object-array reduction
    # and yields a plain bool on newer NumPy.)
    # NOTE: allow_pickle=True unpickles the file - only load files you produced yourself.
    matcher_result_load = np.load("matcher_result.npy", allow_pickle=True).item()
matcher_result_load['do_flip'][0][0]["keypoints0"]

array([[ 16.,  16.],
       [ 24.,  16.],
       [ 48.,  16.],
       ...,
       [272., 232.],
       [280., 232.],
       [288., 232.]], dtype=float32)

In [9]:
matcher_result_load['do_flip'][0]

[{'keypoints0': array([[ 16.,  16.],
         [ 24.,  16.],
         [ 48.,  16.],
         ...,
         [272., 232.],
         [280., 232.],
         [288., 232.]], dtype=float32),
  'keypoints1': array([[ 24.194153,  16.0074  ],
         [ 32.175068,  15.949041],
         [ 88.34166 ,  48.17555 ],
         ...,
         [249.12816 , 216.30878 ],
         [256.51318 , 216.40216 ],
         [263.7251  , 223.55421 ]], dtype=float32),
  'confidence': array([0.55676436, 0.23806931, 0.5038643 , 0.4946127 , 0.6606673 ,
         0.6858394 , 0.32717958, 0.6863053 , 0.20916359, 0.6350299 ,
         0.99612015, 0.8698917 , 0.5218235 , 0.64635533, 0.20502019,
         0.5753261 , 0.52809733, 0.42726973, 0.41389427, 0.6654696 ,
         0.3709133 , 0.9040525 , 0.31108338, 0.9811368 , 0.9675403 ,
         0.30218193, 0.92043763, 0.97226334, 0.99905944, 0.95934665,
         0.66935086, 0.9463614 , 0.28133914, 0.75114036, 0.99453866,
         0.36838177, 0.4886595 , 0.99756837, 0.9327204 , 0.836789