In [None]:
import cv2
import tqdm
import torch
import pycolmap
import numpy as np
import matplotlib.pyplot as plt
from hloc import extract_features, extractors
from hloc.utils.base_model import dynamic_load
from pathlib import Path
from torch.utils.tensorboard import SummaryWriter


# Data locations (absolute paths on the training machine; adjust for your setup).
im_dir = Path('/workspace/TourEiffelClean/EiffelTower/global/images/')
rec_dir = Path('/workspace/TourEiffelClean/EiffelTower/global/sfm/')
# Cache file for the per-image representations computed in the next cell.
rep_path = Path('representations.pt')
conf = extract_features.confs['netvlad']
device = 'cuda:1'

# Seed the global RNGs so the random train/test split and the negative
# sampling are reproducible across kernel restarts.
seed = 0
np.random.seed(seed)
torch.manual_seed(seed)


class ImageDataset(torch.utils.data.Dataset):
    """Training tuples ``(query, positives, negatives)`` mined from an SfM reconstruction.

    Camera positions come from a COLMAP reconstruction (``-R^T t`` per image).
    For each query, positives are the nearest images within ``positive_radius``
    whose 4-character name prefix differs from the query's. That prefix is
    presumably the capture year; TODO confirm. Queries with fewer than
    ``min_positives`` such positives are dropped. Negatives are images farther
    than ``negative_radius`` from the query.

    Items are read from ``self.representations`` (image name -> tensor). The
    caller must fill that dict before indexing the dataset.
    """

    def __init__(self, reconstruction_dir, positive_radius=3, min_positives=3,
                 max_positives=15, negative_radius=10, max_negatives=20,
                 excluded_prefix='2015'):
        """Load the reconstruction and precompute positives for every usable query.

        Args:
            reconstruction_dir: COLMAP model directory, passed to
                ``pycolmap.Reconstruction``.
            positive_radius: images strictly closer than this (reconstruction
                units) may be positives.
            min_positives: queries with fewer positives are dropped.
            max_positives: at most this many (closest) positives are kept.
            negative_radius: images strictly farther than this are negatives.
            max_negatives: at most this many negatives are sampled per query.
            excluded_prefix: images whose name starts with this string are
                ignored entirely; ``None`` keeps every image.

        Raises:
            ValueError: if no image survives the ``excluded_prefix`` filter.
        """
        print('Loading reconstruction.')
        reconstruction = pycolmap.Reconstruction(reconstruction_dir)

        print('Processing positives and negatives.')
        image_names = []
        image_tvecs = []
        for image in reconstruction.images.values():
            if excluded_prefix is None or not image.name.startswith(excluded_prefix):
                image_names.append(image.name)
                # Camera position in world coordinates: -R^T t.
                image_tvecs.append(-image.rotmat().T @ image.tvec)
        if not image_names:
            # np.stack([]) and the percentage below would otherwise fail cryptically.
            raise ValueError(f'No usable images found in {reconstruction_dir}.')
        image_names = np.array(image_names)
        image_tvecs = np.stack(image_tvecs)

        positives = {}
        for image_name, image_tvec in tqdm.tqdm(list(zip(image_names, image_tvecs))):
            image_distances = np.linalg.norm(image_tvecs - image_tvec, axis=1)
            nearby = image_distances < positive_radius
            # Nearby images sorted by increasing distance.
            nearby_names = image_names[nearby][np.argsort(image_distances[nearby])]
            # Keep only images with a different 4-char prefix; this also drops the query itself.
            image_positives = [pos for pos in nearby_names if pos[:4] != image_name[:4]]
            if len(image_positives) >= min_positives:
                positives[image_name] = image_positives[:max_positives]

        print(f'Kept {len(positives) / len(image_names) * 100:.2f}% images with sufficient positives.')

        self.image_names = image_names
        self.image_tvecs = image_tvecs
        self.positives = positives
        self.valid_names = list(positives)
        self.negative_radius = negative_radius
        self.max_negatives = max_negatives
        # Image name -> representation tensor; filled in by the caller.
        self.representations = {}
        # Fixed negatives per query, returned when `train` is False.
        self.negatives = {}
        self.train = True

    def load_negatives(self, query_name):
        """Return the negative image names for ``query_name``.

        In training mode, each call draws a fresh random subset: at most
        ``max_negatives`` of the images farther than ``negative_radius`` from
        the query. Otherwise, it returns the fixed list stored in
        ``self.negatives``.
        """
        if not self.train:
            return self.negatives[query_name]
        query_tvec = self.image_tvecs[self.image_names == query_name]
        query_distances = np.linalg.norm(self.image_tvecs - query_tvec, axis=1)
        candidates = self.image_names[query_distances > self.negative_radius]
        np.random.shuffle(candidates)
        return candidates[:self.max_negatives].tolist()

    def __getitem__(self, idx):
        """Return ``(query, [positives], [negatives])`` representations for query ``idx``."""
        query_name = self.valid_names[idx]
        positive_names = self.positives[query_name]
        negative_names = self.load_negatives(query_name)

        query_image = self.representations[query_name]
        positive_images = [self.representations[pos] for pos in positive_names]
        negative_images = [self.representations[neg] for neg in negative_names]

        return query_image, positive_images, negative_images

    def __len__(self):
        """Number of queries that have enough positives."""
        return len(self.valid_names)

In [None]:
Model = dynamic_load(extractors, conf['model']['name'])

dataset = ImageDataset(reconstruction_dir=rec_dir)

# Per-image representations are expensive to compute, so they are cached in `rep_path`.
if rep_path.exists():
    dataset.representations = torch.load(rep_path)
else:
    print('Compute image representations.')
    # Same extractor with whitening disabled and the NetVLAD layer replaced by identity.
    # `global_descriptor` is then presumably the input the NetVLAD layer would see;
    # `inference` applies netvlad/whiten to these cached tensors during training.
    representation_model = Model({'name': 'netvlad', 'whiten': False}).eval().requires_grad_(False).to(device)
    representation_model.netvlad = torch.nn.Identity()
    representations_dataset = extract_features.ImageDataset(im_dir, conf['preprocessing'], dataset.image_names.tolist())
    loader = torch.utils.data.DataLoader(
        representations_dataset,
        num_workers=1, shuffle=False, pin_memory=True
    )
    # Default batch size 1 with shuffle=False, so batch `idx` is image `names[idx]`.
    for idx, data in enumerate(tqdm.tqdm(loader)):
        rname = representations_dataset.names[idx]
        dataset.representations[rname] = representation_model({'image': data['image'].to(device, non_blocking=True)})['global_descriptor'][0].cpu()
    # Write the cache so the next run takes the fast path above.
    torch.save(dataset.representations, rep_path)

In [None]:
# Hold out 100 randomly chosen queries for evaluation and train on the rest.
dataset_idxs = np.random.permutation(len(dataset))
test_idxs, train_idxs = dataset_idxs[:100], dataset_idxs[100:]
train_dataset, test_dataset = (
    torch.utils.data.Subset(dataset, train_idxs),
    torch.utils.data.Subset(dataset, test_idxs),
)

# Sample the evaluation negatives once; with `dataset.train = False` the
# dataset returns these fixed lists, so every evaluation sees the same tuples.
test_names = [dataset.valid_names[i] for i in test_idxs]
dataset.negatives.update((tname, dataset.load_negatives(tname)) for tname in test_names)

train_loader = torch.utils.data.DataLoader(train_dataset, shuffle=True, num_workers=12, pin_memory=True)
test_loader = torch.utils.data.DataLoader(test_dataset, shuffle=False, num_workers=12, pin_memory=True)

In [None]:
model = Model(conf['model'])
# Freeze every parameter, then unfreeze only the NetVLAD `centers`,
# which is the sole tensor handed to the optimizer.
model = model.train().requires_grad_(False).to(device)
model.netvlad.centers.requires_grad_(True)
optimizer = torch.optim.Adam(params=[model.netvlad.centers], lr=1e-4)

In [None]:
def inference(m, rep):
    """Turn cached representations into final descriptors.

    Runs the model's NetVLAD layer, then its whitening, then L2-normalises
    along dim 1.

    Args:
        m: model exposing ``netvlad`` and ``whiten`` callables.
        rep: batch of cached representations, batch along dim 0.

    Returns:
        Descriptor tensor with unit L2 norm along dim 1.
    """
    whitened = m.whiten(m.netvlad(rep))
    return torch.nn.functional.normalize(whitened, dim=1)


def compute_tuple_loss(m, tup, marg, dev):
    """Triplet-ranking loss for one ``(query, positives, negatives)`` tuple.

    The smallest squared L2 distance from the query to any positive is the
    anchor distance. The loss is the sum over negatives of
    ``max(0, d_pos_min + marg - d_neg)``.

    Args:
        m: model passed to ``inference``.
        tup: ``(qim, pims, nims)``: a query representation and lists of
            positive / negative representations. Presumably each has a leading
            batch dimension of 1 (DataLoader default batch size); ``[0]``
            below takes that single element.
        marg: hinge margin.
        dev: device the representations are moved to.

    Returns:
        Scalar loss tensor. With no negatives, the loss is 0 but stays connected
        to the positive distance in autograd, so ``backward()`` still works.

    Raises:
        ValueError: if the tuple has no positives.
    """
    qim, pims, nims = tup
    if len(pims) == 0:
        raise ValueError('compute_tuple_loss needs at least one positive.')
    pred_q = inference(m, qim.to(dev, non_blocking=True))[0]
    pred_ps = [inference(m, pim.to(dev, non_blocking=True))[0] for pim in pims]
    pred_ns = [inference(m, nim.to(dev, non_blocking=True))[0] for nim in nims]

    pred_min = torch.square(torch.stack(pred_ps) - pred_q).sum(dim=1).min()
    if pred_ns:
        distances_ns = torch.square(torch.stack(pred_ns) - pred_q).sum(dim=1)
    else:
        # torch.stack rejects an empty list. An empty distance vector makes the
        # sum below 0 while keeping it attached to pred_min's graph.
        distances_ns = pred_q.new_empty(0)
    tup_loss = torch.clip(pred_min + marg - distances_ns, min=0).sum()

    return tup_loss


margin = 0.1
batch_size = 4  # tuples whose gradients are accumulated per optimizer step
eval_interval = 200  # optimizer steps between evaluations / checkpoints

it = 0
it_loss = 0  # summed training loss over the current accumulation window

writer = SummaryWriter('logs')
checkpoint_dir = Path('checkpoints')
# torch.save does not create missing directories.
checkpoint_dir.mkdir(parents=True, exist_ok=True)

optimizer.zero_grad()

for epoch in range(30):
    for tupl in tqdm.tqdm(train_loader):
        tupl_loss = compute_tuple_loss(model, tupl, margin, device)

        tupl_loss.backward()

        it += 1
        # Accumulate on-device; synced to the host once per optimizer step.
        it_loss += tupl_loss.detach()

        if it % batch_size == 0:
            step = it // batch_size
            optimizer.step()
            writer.add_scalar('train loss', float(it_loss), step)
            it_loss = 0

            if step % eval_interval == 0:
                model.eval()
                # Make the dataset return the fixed evaluation negatives.
                dataset.train = False
                try:
                    with torch.no_grad():
                        test_loss = sum(
                            compute_tuple_loss(model, test_tupl, margin, device).item()
                            for test_tupl in test_loader
                        )
                finally:
                    # Restore training mode even if evaluation fails.
                    dataset.train = True
                    model.train()
                writer.add_scalar('test loss', test_loss, step)
                writer.flush()

                torch.save({
                    'centers': model.netvlad.centers.detach().cpu(),
                    'optimizer': optimizer.state_dict()
                }, checkpoint_dir / f'iter{step:06d}.pt')

            optimizer.zero_grad()