In [61]:
#!pip install ../input/kornia-loftr/kornia-0.6.4-py2.py3-none-any.whl
#!pip install ../input/kornia-loftr/kornia_moons-0.1.9-py3-none-any.whl
#!pip install ../input/loftrutils/einops-0.4.1-py3-none-any.whl

In [62]:
import itertools
import os
import sys

sys.path.append('../input/loftrutils/LoFTR-master/LoFTR-master/')

In [63]:
!pip install loguru



In [64]:
from collections import defaultdict
import cv2
import pytorch_lightning as pl
from matplotlib import pyplot as plt
import gc
from src.loftr import LoFTR
from src.loftr.utils.supervision import compute_supervision_coarse, compute_supervision_fine
from src.losses.loftr_loss import LoFTRLoss
from src.optimizers import build_optimizer, build_scheduler
from src.utils.metrics import (
    compute_symmetrical_epipolar_errors,
    compute_pose_errors,
    aggregate_metrics
)
from src.utils.plotting import make_matching_figures
from src.utils.comm import gather, all_gather
from src.utils.misc import lower_config, flattenList
from src.utils.profiler import PassThroughProfiler


class PL_LoFTR(pl.LightningModule):
    def __init__(self, config, pretrained_ckpt=None, profiler=None, dump_dir=None):
        """
        TODO:
            - use the new version of PL logging API.
        """
        super().__init__()
        # Misc
        self.config = config  # full config
        _config = lower_config(self.config)
        self.loftr_cfg = lower_config(_config['loftr'])
        self.profiler = profiler or PassThroughProfiler()
        self.n_vals_plot = 1  #max(config.TRAINER.N_VAL_PAIRS_TO_PLOT // config.TRAINER.WORLD_SIZE, 1)

        # Matcher: LoFTR
        self.matcher = LoFTR(config=_config['loftr'])
        self.loss = LoFTRLoss(_config)

        # Pretrained weights
        if pretrained_ckpt:
            state_dict = torch.load(pretrained_ckpt, map_location='cuda')['state_dict']
            self.matcher.load_state_dict(state_dict, strict=True)
            logger.info(f"Load \'{pretrained_ckpt}\' as pretrained checkpoint")

        # Testing
        self.dump_dir = dump_dir

    def configure_optimizers(self):
        # FIXME: The scheduler did not work properly when `--resume_from_checkpoint`
        optimizer = build_optimizer(self, self.config)
        scheduler = build_scheduler(self.config, optimizer)
        return [optimizer], [scheduler]

    def optimizer_step(
            self, epoch, batch_idx, optimizer, optimizer_idx,
            optimizer_closure, on_tpu, using_native_amp, using_lbfgs):
        # learning rate warm up
        warmup_step = self.config.TRAINER.WARMUP_STEP
        if self.trainer.global_step < warmup_step:
            if self.config.TRAINER.WARMUP_TYPE == 'linear':
                base_lr = self.config.TRAINER.WARMUP_RATIO * self.config.TRAINER.TRUE_LR
                lr = base_lr + \
                     (self.trainer.global_step / self.config.TRAINER.WARMUP_STEP) * \
                     abs(self.config.TRAINER.TRUE_LR - base_lr)
                for pg in optimizer.param_groups:
                    pg['lr'] = lr
            elif self.config.TRAINER.WARMUP_TYPE == 'constant':
                pass
            else:
                raise ValueError(f'Unknown lr warm-up strategy: {self.config.TRAINER.WARMUP_TYPE}')

        # update params
        optimizer.step(closure=optimizer_closure)
        optimizer.zero_grad()

    def _trainval_inference(self, batch):
        with self.profiler.profile("Compute coarse supervision"):
            compute_supervision_coarse(batch, self.config)

        with self.profiler.profile("LoFTR"):
            self.matcher(batch)

        with self.profiler.profile("Compute fine supervision"):
            compute_supervision_fine(batch, self.config)

        with self.profiler.profile("Compute losses"):
            self.loss(batch)

    def _compute_metrics(self, batch):
        with self.profiler.profile("Copmute metrics"):
            compute_symmetrical_epipolar_errors(batch)  # compute epi_errs for each match
            compute_pose_errors(batch, self.config)  # compute R_errs, t_errs, pose_errs for each pair

            rel_pair_names = list(zip(*batch['pair_names']))
            bs = batch['image0'].size(0)
            metrics = {
                # to filter duplicate pairs caused by DistributedSampler
                'identifiers': ['#'.join(rel_pair_names[b]) for b in range(bs)],
                'epi_errs': [batch['epi_errs'][batch['m_bids'] == b].cpu().numpy() for b in range(bs)],
                'R_errs': batch['R_errs'],
                't_errs': batch['t_errs'],
                'inliers': batch['inliers']}
            ret_dict = {'metrics': metrics}
        return ret_dict, rel_pair_names

    def training_step(self, batch, batch_idx):
        self._trainval_inference(batch)

        # logging
        if self.trainer.global_rank == 0 and self.global_step % self.trainer.log_every_n_steps == 0:
            # scalars
            for k, v in batch['loss_scalars'].items():
                self.logger.experiment.add_scalar(f'train/{k}', v, self.global_step)

            # net-params
            if self.config.LOFTR.MATCH_COARSE.MATCH_TYPE == 'sinkhorn':
                self.logger.experiment.add_scalar(
                    f'skh_bin_score', self.matcher.coarse_matching.bin_score.clone().detach().cpu().data,
                    self.global_step)

            # figures
            if self.config.TRAINER.ENABLE_PLOTTING:
                compute_symmetrical_epipolar_errors(batch)  # compute epi_errs for each match
                figures = make_matching_figures(batch, self.config, self.config.TRAINER.PLOT_MODE)
                for k, v in figures.items():
                    self.logger.experiment.add_figure(f'train_match/{k}', v, self.global_step)
        gc.collect()
        torch.cuda.empty_cache()
        return {'loss': batch['loss']}

    def training_epoch_end(self, outputs):
        avg_loss = torch.stack([x['loss'] for x in outputs]).mean()
        if self.trainer.global_rank == 0:
            self.logger.experiment.add_scalar(
                'train/avg_loss_on_epoch', avg_loss,
                global_step=self.current_epoch)
        gc.collect()
        torch.cuda.empty_cache()

    def validation_step(self, batch, batch_idx):
        self._trainval_inference(batch)

        ret_dict, _ = self._compute_metrics(batch)

        val_plot_interval = max(self.trainer.num_val_batches[0] // self.n_vals_plot, 1)
        figures = {self.config.TRAINER.PLOT_MODE: []}
        if batch_idx % val_plot_interval == 0:
            figures = make_matching_figures(batch, self.config, mode=self.config.TRAINER.PLOT_MODE)
        gc.collect()
        torch.cuda.empty_cache()
        return {
            **ret_dict,
            'loss_scalars': batch['loss_scalars'],
            'figures': figures,
        }

    def validation_epoch_end(self, outputs):
        # handle multiple validation sets
        multi_outputs = [outputs] if not isinstance(outputs[0], (list, tuple)) else outputs
        multi_val_metrics = defaultdict(list)

        for valset_idx, outputs in enumerate(multi_outputs):
            # since pl performs sanity_check at the very begining of the training
            cur_epoch = self.trainer.current_epoch
            if not self.trainer.resume_from_checkpoint:
                cur_epoch = -1

            # 1. loss_scalars: dict of list, on cpu
            _loss_scalars = [o['loss_scalars'] for o in outputs]
            loss_scalars = {k: flattenList(all_gather([_ls[k] for _ls in _loss_scalars])) for k in _loss_scalars[0]}

            # 2. val metrics: dict of list, numpy
            _metrics = [o['metrics'] for o in outputs]
            metrics = {k: flattenList(all_gather(flattenList([_me[k] for _me in _metrics]))) for k in _metrics[0]}
            # NOTE: all ranks need to `aggregate_merics`, but only log at rank-0 
            val_metrics_4tb = aggregate_metrics(metrics, self.config.TRAINER.EPI_ERR_THR)
            for thr in [5, 10, 20]:
                multi_val_metrics[f'auc@{thr}'].append(val_metrics_4tb[f'auc@{thr}'])

            # 3. figures
            _figures = [o['figures'] for o in outputs]
            figures = {k: flattenList(gather(flattenList([_me[k] for _me in _figures]))) for k in _figures[0]}

            # tensorboard records only on rank 0
            if self.trainer.global_rank == 0:
                for k, v in loss_scalars.items():
                    mean_v = torch.stack(v).mean()
                    self.logger.experiment.add_scalar(f'val_{valset_idx}/avg_{k}', mean_v, global_step=cur_epoch)

                for k, v in val_metrics_4tb.items():
                    self.logger.experiment.add_scalar(f"metrics_{valset_idx}/{k}", v, global_step=cur_epoch)

            gc.collect()
            torch.cuda.empty_cache()
            plt.close('all')

        for thr in [5, 10, 20]:
            # log on all ranks for ModelCheckpoint callback to work properly
            self.log(f'auc@{thr}', torch.tensor(np.mean(multi_val_metrics[f'auc@{thr}'])))  # ckpt monitors on this

    def test_step(self, batch, batch_idx):
        with self.profiler.profile("LoFTR"):
            self.matcher(batch)

        ret_dict, rel_pair_names = self._compute_metrics(batch)

        with self.profiler.profile("dump_results"):
            if self.dump_dir is not None:
                # dump results for further analysis
                keys_to_save = {'mkpts0_f', 'mkpts1_f', 'mconf', 'epi_errs'}
                pair_names = list(zip(*batch['pair_names']))
                bs = batch['image0'].shape[0]
                dumps = []
                for b_id in range(bs):
                    item = {}
                    mask = batch['m_bids'] == b_id
                    item['pair_names'] = pair_names[b_id]
                    item['identifier'] = '#'.join(rel_pair_names[b_id])
                    for key in keys_to_save:
                        item[key] = batch[key][mask].cpu().numpy()
                    for key in ['R_errs', 't_errs', 'inliers']:
                        item[key] = batch[key][b_id]
                    dumps.append(item)
                ret_dict['dumps'] = dumps

        return ret_dict

    def test_epoch_end(self, outputs):
        # metrics: dict of list, numpy
        _metrics = [o['metrics'] for o in outputs]
        metrics = {k: flattenList(gather(flattenList([_me[k] for _me in _metrics]))) for k in _metrics[0]}

        # [{key: [{...}, *#bs]}, *#batch]
        if self.dump_dir is not None:
            Path(self.dump_dir).mkdir(parents=True, exist_ok=True)
            _dumps = flattenList([o['dumps'] for o in outputs])  # [{...}, #bs*#batch]
            dumps = flattenList(gather(_dumps))  # [{...}, #proc*#bs*#batch]
            logger.info(f'Prediction and evaluation results will be saved to: {self.dump_dir}')

        if self.trainer.global_rank == 0:
            print(self.profiler.summary())
            val_metrics_4tb = aggregate_metrics(metrics, self.config.TRAINER.EPI_ERR_THR)
            logger.info('\n' + pprint.pformat(val_metrics_4tb))
            if self.dump_dir is not None:
                np.save(Path(self.dump_dir) / 'LoFTR_pred_eval', dumps)

In [65]:

def read_megadepth_depth(path, pad_to=None):
    depth = cv2.imread(path, 0)
    if pad_to is not None:
        depth, _ = pad_bottom_right(depth, pad_to, ret_mask=False)
    depth = torch.from_numpy(depth).float()  # (h, w)
    gc.collect()
    return depth

In [66]:
import argparse
import math
import pprint
from pathlib import Path



from src.config.default import get_cfg_defaults
from lightning_utilities.core.rank_zero import rank_zero_only
from loguru import logger

from src.utils.plotting import make_matching_figures
from src.utils.comm import gather, all_gather
from src.utils.misc import lower_config, flattenList, setup_gpus, get_rank_zero_only_logger
from src.utils.profiler import PassThroughProfiler, build_profiler
import PIL
from PIL import Image, ImageFilter
import findpeaks
from torch import optim
! pip install rasterio
import rasterio as rio
from rasterio import warp
import cv2
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from collections import defaultdict
import cv2
import pytorch_lightning as pl
from matplotlib import pyplot as plt
import gc
from src.loftr import LoFTR
from src.loftr.utils.supervision import compute_supervision_coarse, compute_supervision_fine
from src.losses.loftr_loss import LoFTRLoss
from src.optimizers import build_optimizer, build_scheduler
from src.utils.metrics import (
    compute_symmetrical_epipolar_errors,
    compute_pose_errors,
    aggregate_metrics
)
import math
import argparse
import pprint
from distutils.util import strtobool
from pathlib import Path
from loguru import logger as loguru_logger

import pytorch_lightning as pl
from pytorch_lightning.utilities import rank_zero_only
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor

from src.config.default import get_cfg_defaults
from src.utils.misc import get_rank_zero_only_logger, setup_gpus
from src.utils.profiler import build_profiler
import pandas as pd

# Need this to plot in HD
# This takes up a lot of memory!
matplotlib.rcParams['figure.dpi'] = 300
def get_correspondence(src_data, dst_data ,src_row, src_col):
  ''' Gets the pixel coordinates (dst_row,dst_col) in dst_data of a given pixel at (src_row,src_col) in src_data.
      Warning: no handling of pixel being outside dst_data! You have to handle that yourself.
      TODO: Should add a check to only get the correspondences with high enough cornerness on both sides
      (because we should only insert meaningful correspondences in the list!).
  '''
  # Get geo coords of pixel from src dataset
  X1, Y1 = src_data.xy(src_row, src_col)
  #print("Geographic coordinates in crs 1: ",X1,Y1)
  if src_data.crs != dst_data.crs:
    # convert coordinates in the crs of the dst dataset
    X2, Y2 = warp.transform(src_data.crs, dst_data.crs, [X1], [Y1])
    X2 = X2[0]
    Y2 = Y2[0]
  else:
    # if the crs is the same, do nothing
    X2, Y2 = X1, Y1

  # Get corresponding px coords in dst dataset
  # It still returns an index even if out of bounds
  dst_row, dst_col = dst_data.index(X2, Y2)
  #print("Corresponding pixel coordinates in image 2: ",dst_row,dst_col)
  return dst_row, dst_col

def get_datapoint(ref_data,query_data):
  ''' TODO: Call this method on a pair of rasterio datasets. It will generate a random datapoint consisting
      of a smaller SAR patch from the query dataset, and a bigger SAR patch from the reference dataset.
      The search patch contains the area of the smaller patch, with a random offset.
  '''
  datapoint = None
  return datapoint

def is_good_patch(patch):
  ''' TODO: Checks if the random query patch is sufficiently texture-rich to be used to train/test the matching model.
      If the patch is too plain (e.g. sea or desert), returns False. This function should also be used at test time:
      if a sensor image is too plain, there's no need to match it.
      At test time, something similar should also be done on the reference side.
  '''
  good = None
  return good

def get_keypoints(patch):
  ''' TODO: Use something like harris corner detection to get the list of ground truth correspondences.
      In fact, you should not train on each pixel corresp, but you should select only meaningful corresp!
      Harris peaks might be just strong speckle noise, but if you take strong ones you should be fine.
  '''
  keypoints = None
  return keypoints

def normalize_image(img):
  ''' Normalize by clipping to 99th percentile and convert to uint8.
      This clips strong speckle outliers and optimizes the brightness range
  '''
  p1 = np.nanpercentile(img,99)
  img = img.clip(0,p1)
  img = (img-np.nanmin(img))/(np.nanmax(img)-np.nanmin(img))*255
  img = img.astype(np.uint8, copy=True)
  return img
import random

def get_point_point_matrix(dataset1,dataset2,point_searching_window,search_window_w,search_window_h):
    print(point_searching_window)
    all_combination = itertools.product(range(int(point_searching_window[0][0]), int(point_searching_window[0][0])+int(search_window_w)), range(int(point_searching_window[0][1]), int(point_searching_window[0][1])+int(search_window_h)))
    results = []
    for row,col in all_combination:
            X1, Y1 = dataset1.xy(row, col)
            X2, Y2 = warp.transform(dataset1.crs, dataset2.crs, [X1], [Y1])
            dst_row, dst_col = dataset2.index(X2, Y2)
            results.append((dst_row,dst_col))
    return all_combination,results

def get_sample(tiff1_path,tiff2_path,search_window,patch_size,random_seed=1,verbose=True,random_rotation=0.03,random_zoom=0.03):
  '''
        1-Reads the first band of the TIFF files using the rio library and normalizes the image data.
        2-Selects a random search window and patch location within the image using random number generators.
        3-Calls the get_correspondence function on the selected locations to find the corresponding locations in the second image.
        4-Converts the grayscale images to RGB format using OpenCV.
        5-Creates centered patches from the RGB images by zero-padding the images and copying a portion of the original images to the patches.
        6-Draws rectangles around the search window and patch in both images.
        7-Saves the patch and search window as JPEG files.
        8-Returns the processed RGB images, the points of the search window and patch, and the original rio datasets.

  '''
  dataset1 = rio.open(tiff1_path)
  dataset2 = rio.open(tiff2_path)
  img1 = dataset1.read(1)
  img1 = normalize_image(img1)
  search_window_w,search_window_h = search_window
  patch_size_w,patch_size_h = patch_size
  lu = ( random.randint(0,img1.shape[1]-search_window_w),random.randint(0,img1.shape[0]-search_window_h))

  lu_patch=( lu[0] + random.randint(0,search_window_w-patch_size_w),lu[1] + random.randint(0,search_window_h-patch_size_h))

  points_patch = [lu_patch]
  points = [lu]
  points_ref = [get_correspondence(dataset1, dataset2, x, y) for x,y in points]
  points_patch_ref = [get_correspondence(dataset1, dataset2, x, y) for x,y in points_patch]
  #T0toT1 = [  get_correspondence(dataset1, dataset2, x, y)  for x in range(img1.shape[0])for y in range(img1.shape[1])]
  #T0toT1 = np.array(T0toT1)
  #print(T0toT1.shape,"print t01")
  img1 = dataset1.read(1)
  img1 = normalize_image(img1)

  img2 = dataset2.read(1)
  img2 = normalize_image(img2)


  rgb_img1 = cv2.cvtColor(img1, cv2.COLOR_GRAY2RGB)
  rgb_img2 = cv2.cvtColor(img2, cv2.COLOR_GRAY2RGB)


  #get centerd patches
  patch_source = np.zeros((search_window_h, search_window_w), dtype = np.uint8)
  patch_source[int(search_window_h/2 - patch_size_h/2):int(search_window_h/2 + patch_size_h/2),int(search_window_w/2 - patch_size_w/2):int(search_window_w/2 + patch_size_w/2)] =img1[points_patch[0][1]:points_patch[0][1]+patch_size_h,points_patch[0][0]:points_patch[0][0]+patch_size_w]


  patch_dest = np.zeros((search_window_h, search_window_w), dtype = np.uint8)
  patch_dest[int(search_window_h/2 - patch_size_h/2):int(search_window_h/2 + patch_size_h/2),int(search_window_w/2 - patch_size_w/2):int(search_window_w/2 + patch_size_w/2)]=img2[points_patch_ref[0][1]:points_patch_ref[0][1]+patch_size_h,points_patch_ref[0][0]:points_patch_ref[0][0]+patch_size_w]



  search_window_source = img1[points[0][1]:points[0][1]+search_window_h,points[0][0]:points[0][0]+search_window_w]
  search_window_dest = img2[points_ref[0][1] : points_ref[0][1]+search_window_h, points_ref[0][0]:points_ref[0][0]+search_window_w]

    #draw searching windows
  rgb_img1  = cv2.rectangle(rgb_img1 , points[0],( points[0][0]+search_window_w,points[0][1]+search_window_h), (0,0,255), 2)
  rgb_img1  = cv2.rectangle(rgb_img1 , points_patch[0], ( points_patch[0][0]+patch_size_w,points_patch[0][1]+patch_size_h), (255,0,0), 2)

  #draw patch windows
  rgb_img2  = cv2.rectangle(rgb_img2 , points_ref[0], ( points_ref[0][0]+search_window_w,points_ref[0][1]+search_window_h), (0,0,255), 1)
  rgb_img2  = cv2.rectangle(rgb_img2 , points_patch_ref[0], ( points_patch_ref[0][0]+patch_size_w,points_patch_ref[0][1]+patch_size_h), (255,0,0), 2)

  if verbose:
      fig, axes = plt.subplots(2, 3)

      axes[0, 0].set_title('source image')
      axes[0, 0].imshow(PIL.ImageOps.invert(  Image.fromarray(rgb_img1)))

      axes[0, 1].set_title('search window source')
      axes[0, 1].imshow(PIL.ImageOps.invert(  Image.fromarray(search_window_source)))

      axes[0, 2].set_title('patch source')
      axes[0, 2].imshow(PIL.ImageOps.invert(  Image.fromarray(patch_source)))

      axes[1, 0].set_title('dest image')
      axes[1, 0].imshow(PIL.ImageOps.invert(  Image.fromarray(rgb_img2)))

      axes[1, 1].set_title('search window dest')
      axes[1, 1].imshow(PIL.ImageOps.invert(  Image.fromarray(search_window_dest)))

      axes[1, 2].set_title('patch dest')
      axes[1, 2].imshow(PIL.ImageOps.invert(  Image.fromarray(patch_dest)))

      plt.show()
  comb,res, = get_point_point_matrix(dataset1,dataset2,points,search_window_w,search_window_h)
  print(comb)
  return img1, img2, points,points_patch_ref,dataset1,dataset2




In [67]:

import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset

from src.utils.dataset import read_megadepth_gray, pad_bottom_right


class MegaDepthDataset(Dataset):
    def __init__(self,
                 data,
                 npz_path,
                 mode='train',
                 min_overlap_score=0.4,
                 img_resize=None,
                 df=None,
                 img_padding=False,
                 depth_padding=False,
                 augment_fn=None,
                 **kwargs):
        """
        Manage one scene(npz_path) of MegaDepth dataset.
        
        Args:
            root_dir (str): megadepth root directory that has `phoenix`.
            npz_path (str): {scene_id}.npz path. This contains image pair information of a scene.
            mode (str): options are ['train', 'val', 'test']
            min_overlap_score (float): how much a pair should have in common. In range of [0, 1]. Set to 0 when testing.
            img_resize (int, optional): the longer edge of resized images. None for no resize. 640 is recommended.
                                        This is useful during training with batches and testing with memory intensive algorithms.
            df (int, optional): image size division factor. NOTE: this will change the final image size after img_resize.
            img_padding (bool): If set to 'True', zero-pad the image to squared size. This is useful during training.
            depth_padding (bool): If set to 'True', zero-pad depthmap to (2000, 2000). This is useful during training.
            augment_fn (callable, optional): augments images with pre-defined visual effects.
        """
        super().__init__()
        # self.root_dir = root_dir
        self.mode = mode

        # prepare scene_info and pair_info
        if mode == 'test' and min_overlap_score != 0:
            logger.warning("You are using `min_overlap_score`!=0 in test mode. Set to 0.")
            min_overlap_score = 0

        # parameters for image resizing, padding and depthmap padding
        if mode == 'train':
            assert img_resize is not None and img_padding and depth_padding
        self.img_resize = img_resize
        self.df = df
        self.img_padding = img_padding
        self.depth_max_size = 2000 if depth_padding else None  # the upperbound of depthmaps size in megadepth.

        # for training LoFTR
        self.augment_fn = augment_fn if mode == 'train' else None
        self.coarse_scale = getattr(kwargs, 'coarse_scale', 0.125)
        import os
        path_of_the_directory= './data/paired_sentinel/'
        print("Files and directories in a specified path:")
        self.path1=[]
        self.path2=[]
        for filename in os.listdir(path_of_the_directory):
            f = os.path.join(path_of_the_directory,filename)
            if os.path.isfile(f):
                print(f)
            else:
                lst = os.listdir(f)
                print(lst)
                if len(lst)>1:
                    self.path1.append(os.path.join(path_of_the_directory,filename,lst[0]))
                    self.path2.append(os.path.join(path_of_the_directory,filename,lst[1]))



        #self.path1 = data["path1"].values
        #self.path2 = data["path2"].values

        gc.collect()

    def __len__(self):
        return len(self.path1)

    def __getitem__(self, idx):
        # read grayscale image and mask. (1, h, w) and (h, w)
        img_name0 = self.path1[idx]
        img_name1 = self.path2[idx]
        size_search = (640, 480)
        verbose=False
        size_patch = (int(180*1.333333), int(180))
        img1, img2, points,points_patch_ref,dataset1,dataset2 = get_sample(img_name0,img_name1,size_search,size_patch,verbose=verbose)

        img1 = cv2.resize(img1, (640, 480))
        img2 = cv2.resize(img2, (640, 480))
        img1 = np.expand_dims(np.swapaxes(img1,0,1),0)
        img2 = np.expand_dims(np.swapaxes(img2,0,1),0)
        print(img1.shape,"madonna")

        #T_0to1 = torch.tensor(np.matmul(T1, np.linalg.inv(T0)), dtype=torch.float)[:4, :4]  # (4, 4)
        #T_1to0 = T_0to1.inverse()

        data = {
            'image0': torch.tensor(img1/255).float(),  # (1, h, w)
           # 'depth0': depth0,  # (h, w)
            'image1': torch.tensor(img2/255).float(),
            #'depth1': depth1,
            #'T_0to1': T_0to1,  # (4, 4)
            #'T_1to0': T_1to0,
            #'K0': K_0,  # (3, 3)
            #'K1': K_1,
            'scale0': 1.,  # [scale_w, scale_h]
            'scale1': 1.,
            'dataset_name': 'MegaDepth',
            'scene_id': idx,
            'pair_id': idx,
            'pair_names': (img_name0, img_name1),
        }

        # for LoFTR training

        gc.collect()
        torch.cuda.empty_cache()
        return data

In [68]:

from loguru import logger

import pytorch_lightning as pl
from torch.utils.data import (
    Dataset,
    DataLoader,
    ConcatDataset,
    DistributedSampler
)

from src.utils.augment import build_augmentor


class MultiSceneDataModule(pl.LightningDataModule):
    """ 
    For distributed training, each training process is assgined
    only a part of the training scenes to reduce memory overhead.
    """

    def __init__(self, args, config, data):
        super().__init__()

        # 1. data config
        # Train and Val should from the same data source
        self.trainval_data_source = config.DATASET.TRAINVAL_DATA_SOURCE
        self.test_data_source = config.DATASET.TEST_DATA_SOURCE
        # training and validating
        self.train_data = data
        self.train_pose_root = config.DATASET.TRAIN_POSE_ROOT  # (optional)
        self.train_npz_root = config.DATASET.TRAIN_NPZ_ROOT
        self.train_list_path = config.DATASET.TRAIN_LIST_PATH
        self.train_intrinsic_path = config.DATASET.TRAIN_INTRINSIC_PATH
        self.val_data = data
        self.val_pose_root = config.DATASET.VAL_POSE_ROOT  # (optional)
        self.val_npz_root = config.DATASET.VAL_NPZ_ROOT
        self.val_list_path = config.DATASET.VAL_LIST_PATH
        self.val_intrinsic_path = config.DATASET.VAL_INTRINSIC_PATH
        # testing
        self.test_data = data
        self.test_pose_root = config.DATASET.TEST_POSE_ROOT  # (optional)
        self.test_npz_root = config.DATASET.TEST_NPZ_ROOT
        self.test_list_path = config.DATASET.TEST_LIST_PATH
        self.test_intrinsic_path = config.DATASET.TEST_INTRINSIC_PATH

        # 2. dataset config
        # general options
        self.min_overlap_score_test = config.DATASET.MIN_OVERLAP_SCORE_TEST  # 0.4, omit data with overlap_score < min_overlap_score
        self.min_overlap_score_train = config.DATASET.MIN_OVERLAP_SCORE_TRAIN
        self.augment_fn = build_augmentor(config.DATASET.AUGMENTATION_TYPE)  # None, options: [None, 'dark', 'mobile']

        # MegaDepth options
        self.mgdpt_img_resize = config.DATASET.MGDPT_IMG_RESIZE  # 840
        self.mgdpt_img_pad = config.DATASET.MGDPT_IMG_PAD  # True
        self.mgdpt_depth_pad = config.DATASET.MGDPT_DEPTH_PAD  # True
        self.mgdpt_df = config.DATASET.MGDPT_DF  # 8
        self.coarse_scale = 1 / config.LOFTR.RESOLUTION[0]  # 0.125. for training loftr.

        # 3.loader parameters
        self.train_loader_params = {
            'batch_size': args.batch_size,
            'num_workers': args.num_workers,
            'pin_memory': getattr(args, 'pin_memory', True)
        }
        self.val_loader_params = {
            'batch_size': 1,
            'shuffle': False,
            'num_workers': args.num_workers,
            'pin_memory': getattr(args, 'pin_memory', True)
        }
        self.test_loader_params = {
            'batch_size': 1,
            'shuffle': False,
            'num_workers': args.num_workers,
            'pin_memory': True
        }

        # 4. sampler
        self.data_sampler = config.TRAINER.DATA_SAMPLER
        self.n_samples_per_subset = config.TRAINER.N_SAMPLES_PER_SUBSET
        self.subset_replacement = config.TRAINER.SB_SUBSET_SAMPLE_REPLACEMENT
        self.shuffle = config.TRAINER.SB_SUBSET_SHUFFLE
        self.repeat = config.TRAINER.SB_REPEAT

        # (optional) RandomSampler for debugging

        # misc configurations
        self.parallel_load_data = getattr(args, 'parallel_load_data', False)
        self.seed = config.TRAINER.SEED  # 66

    def setup(self, stage=None):
        """
        Setup train / val / test dataset. This method will be called by PL automatically.
        Args:
            stage (str): 'fit' in training phase, and 'test' in testing phase.
        """

        assert stage in ['fit', 'test'], "stage must be either fit or test"

        if stage == 'fit':
            self.train_dataset = self._setup_dataset(
                self.train_data,
                self.train_npz_root,
                self.train_list_path,
                self.train_intrinsic_path,
                mode='train',
                min_overlap_score=self.min_overlap_score_train,
                pose_dir=self.train_pose_root)
            # setup multiple (optional) validation subsets

            self.val_dataset = self._setup_dataset(
                self.val_data,
                self.val_npz_root,
                self.val_list_path,
                self.val_intrinsic_path,
                mode='val',
                min_overlap_score=self.min_overlap_score_test,
                pose_dir=self.val_pose_root)

        else:  # stage == 'test
            self.test_dataset = self._setup_dataset(
                self.test_data,
                self.test_npz_root,
                self.test_list_path,
                self.test_intrinsic_path,
                mode='test',
                min_overlap_score=self.min_overlap_score_test,
                pose_dir=self.test_pose_root)

    def _setup_dataset(self,
                       data,
                       split_npz_root,
                       scene_list_path,
                       intri_path,
                       mode='train',
                       min_overlap_score=0.,
                       pose_dir=None):
        """ Setup train / val / test set"""
        local_npz_names = ""
        dataset_builder = self._build_concat_dataset
        return dataset_builder(data, local_npz_names, split_npz_root, intri_path,
                               mode=mode, min_overlap_score=min_overlap_score, pose_dir=pose_dir)

    def _build_concat_dataset(
            self,
            data,
            npz_names,
            npz_dir,
            intrinsic_path,
            mode,
            min_overlap_score=0.,
            pose_dir=None
    ):
        datasets = []
        augment_fn = self.augment_fn if mode == 'train' else None
        data_source = self.trainval_data_source if mode in ['train', 'val'] else self.test_data_source
        npz_path = ""

        datasets.append(
            MegaDepthDataset(data,
                             npz_path,
                             mode=mode,
                             min_overlap_score=min_overlap_score,
                             img_resize=self.mgdpt_img_resize,
                             df=self.mgdpt_df,
                             img_padding=self.mgdpt_img_pad,
                             depth_padding=self.mgdpt_depth_pad,
                             augment_fn=augment_fn,
                             coarse_scale=self.coarse_scale))
        return ConcatDataset(datasets)

    def train_dataloader(self):
        """ Build training dataloader for ScanNet / MegaDepth. """
        #         assert self.data_sampler in ['scene_balance']
        #         #logger.info(f'[rank:{self.rank}/{self.world_size}]: Train Sampler and DataLoader re-init (should not re-init between epochs!).')
        #         if self.data_sampler == 'scene_balance':
        #             sampler = RandomConcatSampler(self.train_dataset,
        #                                           self.n_samples_per_subset,
        #                                           self.subset_replacement,
        #                                           self.shuffle, self.repeat, self.seed)
        #         else:
        #             sampler = None
        dataloader = DataLoader(self.train_dataset, batch_size=1,
                                shuffle=False,
                                num_workers=0, pin_memory=True, drop_last=True)
        return dataloader

    def val_dataloader(self):
        """ Build validation dataloader for ScanNet / MegaDepth. """
        #logger.info(f'[rank:{self.rank}/{self.world_size}]: Val Sampler and DataLoader re-init.')
        dataloader = DataLoader(self.val_dataset, batch_size=1,
                                shuffle=False,
                                num_workers=0, pin_memory=True, drop_last=True)
        return dataloader

    def test_dataloader(self, *args, **kwargs):
        #logger.info(f'[rank:{self.rank}/{self.world_size}]: Test Sampler and DataLoader re-init.')
        sampler = DistributedSampler(self.test_dataset, shuffle=False)
        return DataLoader(self.test_dataset, sampler=sampler, **self.test_loader_params)


def _build_dataset(dataset: Dataset, *args, **kwargs):
    return dataset(*args, **kwargs)

In [69]:
import math
import argparse
import pprint
from distutils.util import strtobool
from pathlib import Path
from loguru import logger as loguru_logger

import pytorch_lightning as pl
from pytorch_lightning.utilities import rank_zero_only
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor

from src.config.default import get_cfg_defaults
from src.utils.misc import get_rank_zero_only_logger, setup_gpus
from src.utils.profiler import build_profiler
import pandas as pd

loguru_logger = get_rank_zero_only_logger(loguru_logger)


def parse_args():
    # init a costum parser which will be added into pl.Trainer parser
    # check documentation: https://pytorch-lightning.readthedocs.io/en/latest/common/trainer.html#trainer-flags
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument(
        'data_cfg_path', type=str, help='data config path')
    parser.add_argument(
        'main_cfg_path', type=str, help='main config path')
    parser.add_argument(
        '--exp_name', type=str, default='default_exp_name')
    parser.add_argument(
        '--batch_size', type=int, default=4, help='batch_size per gpu')
    parser.add_argument(
        '--num_workers', type=int, default=4)
    parser.add_argument(
        '--pin_memory', type=lambda x: bool(strtobool(x)),
        nargs='?', default=True, help='whether loading data to pinned memory or not')
    parser.add_argument(
        '--ckpt_path', type=str, default="./weights/outdoor_ds.ckpt",
        help='pretrained checkpoint path, helpful for using a pre-trained coarse-only LoFTR')
    parser.add_argument(
        '--disable_ckpt', action='store_true',
        help='disable checkpoint saving (useful for debugging).')
    parser.add_argument(
        '--profiler_name', type=str, default=None,
        help='options: [inference, pytorch], or leave it unset')
    parser.add_argument(
        '--parallel_load_data', action='store_true',
        help='load datasets in with multiple processes.')

    parser = pl.Trainer.add_argparse_args(parser)
    return parser.parse_args(
        './configs/data/megadepth_trainval_640.py ./configs/loftr/outdoor/loftr_ds_dense.py --exp_name test --gpus 0 --num_nodes 0 --accelerator gpu --batch_size 1 --check_val_every_n_epoch 1 --log_every_n_steps 1 --limit_val_batches 1 --num_sanity_val_steps 10 --benchmark True --max_epochs 4'.split())


def train():
    # parse arguments
    args = parse_args()
    rank_zero_only(pprint.pprint)(vars(args))

    # init default-cfg and merge it with the main- and data-cfg
    config = get_cfg_defaults()
    config.merge_from_file(args.main_cfg_path)
    config.merge_from_file(args.data_cfg_path)
    pl.seed_everything(config.TRAINER.SEED)  # reproducibility
    # TODO: Use different seeds for each dataloader workers
    # This is needed for data augmentation

    # scale lr and warmup-step automatically
    args.gpus = _n_gpus = setup_gpus(args.gpus)
    config.TRAINER.WORLD_SIZE = _n_gpus * args.num_nodes
    config.TRAINER.TRUE_BATCH_SIZE = config.TRAINER.WORLD_SIZE * args.batch_size
    _scaling = 1  #config.TRAINER.TRUE_BATCH_SIZE / config.TRAINER.CANONICAL_BS
    config.TRAINER.SCALING = _scaling
    config.TRAINER.TRUE_LR = 0.00001 * _scaling
    config.TRAINER.WARMUP_STEP = math.floor(config.TRAINER.WARMUP_STEP / _scaling)

    # lightning module
    profiler = build_profiler(args.profiler_name)
    model = PL_LoFTR(config, pretrained_ckpt=args.ckpt_path, profiler=profiler)
    loguru_logger.info(f"LoFTR LightningModule initialized!")

    # lightning data
    data = None #pd.read_csv("../input/imc-gt/train.csv")
    data_module = MultiSceneDataModule(args, config, data)
    gc.collect()
    loguru_logger.info(f"LoFTR DataModule initialized!")

    # TensorBoard Logger
    logger = TensorBoardLogger(save_dir='logs/tb_logs', name=args.exp_name, default_hp_metric=False)
    ckpt_dir = Path(logger.log_dir) / 'checkpoints'

    # Callbacks
    # TODO: update ModelCheckpoint to monitor multiple metrics
    ckpt_callback = ModelCheckpoint(monitor='auc@10', verbose=True, save_top_k=5, mode='max',
                                    save_last=True,
                                    dirpath=str(ckpt_dir),
                                    filename='{epoch}-{auc@5:.3f}-{auc@10:.3f}-{auc@20:.3f}')
    lr_monitor = LearningRateMonitor(logging_interval='step')
    callbacks = [lr_monitor]
    if not args.disable_ckpt:
        callbacks.append(ckpt_callback)

    # Lightning Trainer
    trainer = pl.Trainer.from_argparse_args(
        args,
        #         plugins=DDPPlugin(find_unused_parameters=False,
        #                           num_nodes=args.num_nodes,
        #                           sync_batchnorm=config.TRAINER.WORLD_SIZE > 0),
        gradient_clip_val=config.TRAINER.GRADIENT_CLIPPING,
        callbacks=callbacks,
        logger=logger,
        #sync_batchnorm=config.TRAINER.WORLD_SIZE > 0,
        replace_sampler_ddp=False,  # use custom sampler
        # avoid repeated samples!
      #  weights_summary='full',
        profiler=profiler)
    loguru_logger.info(f"Trainer initialized!")
    loguru_logger.info(f"Start training!")
    trainer.fit(model, datamodule=data_module)


In [70]:
train()

Global seed set to 66


{'accelerator': 'gpu',
 'accumulate_grad_batches': None,
 'amp_backend': 'native',
 'amp_level': None,
 'auto_lr_find': False,
 'auto_scale_batch_size': False,
 'auto_select_gpus': False,
 'batch_size': 1,
 'benchmark': True,
 'check_val_every_n_epoch': 1,
 'ckpt_path': './weights/outdoor_ds.ckpt',
 'data_cfg_path': './configs/data/megadepth_trainval_640.py',
 'default_root_dir': None,
 'detect_anomaly': False,
 'deterministic': None,
 'devices': None,
 'disable_ckpt': False,
 'enable_checkpointing': True,
 'enable_model_summary': True,
 'enable_progress_bar': True,
 'exp_name': 'test',
 'fast_dev_run': False,
 'gpus': 0,
 'gradient_clip_algorithm': None,
 'gradient_clip_val': None,
 'inference_mode': True,
 'ipus': None,
 'limit_predict_batches': None,
 'limit_test_batches': None,
 'limit_train_batches': None,
 'limit_val_batches': 1,
 'log_every_n_steps': 1,
 'logger': True,
 'main_cfg_path': './configs/loftr/outdoor/loftr_ds_dense.py',
 'max_epochs': 4,
 'max_steps': -1,
 'max_time'

2023-02-16 12:44:04.960 | INFO     | __main__:__init__:43 - Load './weights/outdoor_ds.ckpt' as pretrained checkpoint
2023-02-16 12:44:04.961 | INFO     | __main__:train:81 - LoFTR LightningModule initialized!
2023-02-16 12:44:05.190 | INFO     | __main__:train:87 - LoFTR DataModule initialized!
Missing logger folder: logs/tb_logs/test
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
`Trainer(limit_val_batches=1)` was configured so 1 batch will be used.
2023-02-16 12:44:05.284 | INFO     | __main__:train:118 - Trainer initialized!
2023-02-16 12:44:05.285 | INFO     | __main__:train:119 - Start training!


Files and directories in a specified path:
['S1A_IW_GRDH_1SDV_20220815T023449_20220815T023518_044557_05515E_13BF.tif', 'S1A_IW_GRDH_1SDV_20220804T142223_20220804T142248_044404_054C89_A2BA.tif']
['S1A_IW_GRDH_1SDV_20220804T142158_20220804T142223_044404_054C89_7956.tif', 'S1A_IW_GRDH_1SDV_20220810T022716_20220810T022741_044484_054EF2_2399.tif']
['S1A_IW_GRDH_1SDV_20220812T070928_20220812T070953_044516_054FFB_C359.tif', 'S1A_IW_GRDH_1SDV_20220811T190707_20220811T190732_044509_054FC0_7579.tif']
['S1A_IW_GRDH_1SDV_20220813T043011_20220813T043036_044529_05506C_3AA2.tif', 'S1A_IW_GRDH_1SDV_20220814T161734_20220814T161759_044551_05513A_93D3.tif']
['S1A_IW_GRDH_1SDV_20220810T165117_20220810T165142_044493_054F3B_64D1.tif', 'S1A_IW_GRDH_1SDV_20220814T051044_20220814T051109_044544_0550FA_CDF6.tif']
['S1A_IW_GRDH_1SDV_20220810T054340_20220810T054405_044486_054F01_ABF3.tif', 'S1A_IW_GRDH_1SDV_20220806T172346_20220806T172411_044435_054D76_F342.tif']
['S1A_IW_GRDH_1SDV_20220814T051339_20220814T051404_

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name    | Type      | Params
--------------------------------------
0 | matcher | LoFTR     | 11.6 M
1 | loss    | LoFTRLoss | 0     
--------------------------------------
11.6 M    Trainable params
0         Non-trainable params
11.6 M    Total params
46.246    Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

[(9, 638)]
<itertools.product object at 0x7fae0ef287c0>
(1, 640, 480) madonna


RuntimeError: The size of tensor a (2) must match the size of tensor b (2568) at non-singleton dimension 1

In [None]:
model

In [None]:
% cd checkpoints /

In [None]:
ls