In [None]:
import cv2
import tqdm
import torch
import pycolmap
import numpy as np
import matplotlib.pyplot as plt
from hloc import extract_features, extractors
from hloc.utils.base_model import dynamic_load
from pathlib import Path


# Configuration of the 'netvlad' entry in hloc's `extract_features.confs`;
# its 'model' sub-dict is used below to pick and build the extractor.
conf = extract_features.confs['netvlad']
# Prefer the GPU, but fall back to the CPU so the notebook still runs
# (slowly) on machines without CUDA instead of failing at `.to(device)`.
device = 'cuda' if torch.cuda.is_available() else 'cpu'


class ImageDataset(torch.utils.data.Dataset):
    """Query / positives / negatives tuples mined from a COLMAP reconstruction.

    Every image of the reconstruction whose name does not start with '2015'
    gets a 3D position computed as ``-R^T @ t`` (presumably the camera centre
    under COLMAP's world-to-camera convention -- TODO confirm). For each query:

    * positives are images closer than ``positive_radius`` whose first four
      name characters differ from the query's (presumably a different
      acquisition session/year -- verify against the naming scheme);
    * negatives are images farther than ``negative_radius``.

    Only queries with at least ``min_positives`` positives are kept as items.
    ``hard_negatives`` maps each valid query name to a list of image names that
    are returned first by ``load_negatives``; it starts empty and is meant to
    be filled by the training code.
    """

    def __init__(self, image_dir, reconstruction_dir, resize_max=360,
                 positive_radius=3, negative_radius=10,
                 min_positives=3, max_positives=15, num_negatives=20):
        """Load the reconstruction and pre-compute the positives of each image.

        Args:
            image_dir: directory that image names are resolved against.
            reconstruction_dir: directory of the COLMAP reconstruction to load.
            resize_max: length, in pixels, of the longest image side after resizing.
            positive_radius: maximum distance between a query and its positives.
            negative_radius: minimum distance between a query and its negatives.
            min_positives: queries with fewer positives are discarded.
            max_positives: at most this many (closest) positives are kept per query.
            num_negatives: maximum number of negatives returned per query.

        Raises:
            ValueError: if no image of the reconstruction passes the name filter.
        """
        print('Loading reconstruction.')
        # Use the directory given by the caller rather than a fixed absolute path.
        reconstruction = pycolmap.Reconstruction(Path(reconstruction_dir))

        print('Processing positives and negatives.')
        image_names = []
        image_tvecs = []
        for image in reconstruction.images.values():
            # Images whose name starts with '2015' are left out entirely.
            if image.name[:4] != '2015':
                image_names.append(image.name)
                image_tvecs.append(-image.rotmat().T @ image.tvec)
        if not image_names:
            raise ValueError(f'No usable image found in reconstruction {reconstruction_dir}.')
        image_names = np.array(image_names)
        image_tvecs = np.stack(image_tvecs)

        positives = {}
        for image_name, image_tvec in tqdm.tqdm(list(zip(image_names, image_tvecs))):
            image_distances = np.linalg.norm(image_tvecs - image_tvec, axis=1)
            image_positives_args = image_distances < positive_radius
            image_positives = image_names[image_positives_args]
            image_distances = image_distances[image_positives_args]
            # Closest candidates first, so the truncation below keeps the nearest ones.
            image_positives = image_positives[np.argsort(image_distances)]
            # Drop candidates sharing the query's 4-character prefix (this also removes the query itself).
            image_positives = [pos for pos in image_positives if pos[:4] != image_name[:4]]
            if len(image_positives) >= min_positives:
                positives[image_name] = image_positives[:max_positives]

        print(f'Kept {len(positives) / len(image_names) * 100:.2f}% images with sufficient positives.')

        self.image_dir = Path(image_dir)
        self.resize_max = resize_max
        self.negative_radius = negative_radius
        self.num_negatives = num_negatives
        self.image_names = image_names
        self.image_tvecs = image_tvecs
        self.positives = positives
        self.valid_names = list(positives)
        self.hard_negatives = {image_name: [] for image_name in positives}

    def load_image(self, image_name):
        """Read an image and return it as a float RGB tensor of shape (3, H, W) in [0, 1].

        The image is resized so that its longest side equals ``resize_max``.

        Raises:
            FileNotFoundError: if OpenCV cannot read the file.
        """
        image_path = self.image_dir / image_name
        image = cv2.imread(str(image_path))
        if image is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise FileNotFoundError(f'Could not read image: {image_path}')
        size = np.array(image.shape[1::-1])  # (width, height), the order cv2.resize expects
        # Explicit int cast: cv2.resize needs plain integers for the target size.
        new_size = tuple(int(round(float(side))) for side in size * self.resize_max / size.max())
        image = cv2.resize(image, new_size, interpolation=cv2.INTER_AREA)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = torch.tensor(image).movedim(2, 0) / 255
        return image

    def load_negatives(self, query_name):
        """Return up to ``num_negatives`` negative image names for ``query_name``.

        The stored hard negatives of the query come first; the remaining slots
        are filled with a random selection of the other images lying farther
        than ``negative_radius`` from the query.
        """
        query_tvec = self.image_tvecs[self.image_names == query_name]
        query_distances = np.linalg.norm(self.image_tvecs - query_tvec, axis=1)
        # Boolean indexing yields a copy, so the in-place shuffle below never touches self.image_names.
        query_negatives = self.image_names[query_distances > self.negative_radius]
        hard_negative_names = list(self.hard_negatives[query_name])
        if hard_negative_names:
            new_negative_names = query_negatives[~np.isin(query_negatives, hard_negative_names)]
        else:
            new_negative_names = query_negatives
        np.random.shuffle(new_negative_names)
        query_negative_names = hard_negative_names + new_negative_names.tolist()
        return query_negative_names[:self.num_negatives]

    def __getitem__(self, idx):
        """Return ``(query_image, positive_images, negative_images, query_name, negative_names)``.

        The image entries are tensors (or lists of tensors) produced by
        ``load_image``; the names are returned alongside so that the caller can
        update ``hard_negatives``.
        """
        query_name = self.valid_names[idx]
        positive_names = self.positives[query_name]
        negative_names = self.load_negatives(query_name)

        query_image = self.load_image(query_name)
        positive_images = [self.load_image(pos) for pos in positive_names]
        negative_images = [self.load_image(neg) for neg in negative_names]

        return query_image, positive_images, negative_images, query_name, negative_names

    def __len__(self):
        """Number of queries that have enough positives."""
        return len(self.valid_names)

In [None]:
# Build the tuple dataset from the reconstruction and its image folder.
dataset = ImageDataset(
    reconstruction_dir=Path('/media/clementin/data/EiffelTower/global/sfm'),
    image_dir=Path('/media/clementin/data/EiffelTower/global/images/'),
)
# No batch_size is given, so the loader uses its default of one tuple per batch;
# tuples are drawn in random order by six worker processes.
loader = torch.utils.data.DataLoader(
    dataset,
    shuffle=True,
    num_workers=6,
    pin_memory=True,
)

In [None]:
# Resolve the extractor class named in the configuration and instantiate it.
Model = dynamic_load(extractors, conf['model']['name'])
model = Model(conf['model'])
model.train()
# Freeze every parameter first ...
model.requires_grad_(False)
model = model.to(device)
# ... then re-enable gradients for the `netvlad.centers` parameter only,
# which is the single tensor handed to the optimizer.
model.netvlad.centers.requires_grad_(True)
optimizer = torch.optim.Adam(params=[model.netvlad.centers], lr=1e-2)

In [None]:
margin = 0.1
batch_size = 4  # number of tuples whose gradients are accumulated before each optimizer step

it = 0        # tuples processed so far, across all epochs
it_loss = 0   # summed loss of the tuples accumulated since the last optimizer step

with open('log.txt', 'w') as f:
    f.write('epoch,iter,loss\n')


def _describe(image):
    """Return the global descriptor of a single-image batch as a 1-D tensor."""
    return model({'image': image.to(device, non_blocking=True)})['global_descriptor'][0]


optimizer.zero_grad()

for epoch in range(5):
    for idx, (qim, pims, nims, qname, nnames) in enumerate(tqdm.tqdm(loader)):
        pred_q = _describe(qim)
        pred_ps = [_describe(pim) for pim in pims]
        pred_ns = [_describe(nim) for nim in nims]

        # Squared distance from the query to its closest positive and to every negative.
        pred_min = torch.square(torch.stack(pred_ps) - pred_q).sum(dim=1).min()
        distances_ns = torch.square(torch.stack(pred_ns) - pred_q).sum(dim=1)
        # Hinge: each negative should be at least `margin` farther than the best positive.
        tuple_loss = torch.clip(pred_min + margin - distances_ns, min=0).sum()

        # The loader's default collate wraps every name in a 1-element batch
        # container, so unwrap them; otherwise the stored hard negatives would be
        # nested lists, which cannot be joined to the image directory later on.
        negative_names = np.array([names[0] for names in nnames])
        # Keep the 10 negatives closest to the query in descriptor space.
        hardest = distances_ns.detach().argsort()[:10].cpu().numpy()
        # NOTE(review): loader workers hold their own copy of `dataset`, so this
        # update is presumably only seen once workers are restarted (next epoch) -- confirm.
        dataset.hard_negatives[qname[0]] = negative_names[hardest].tolist()

        tuple_loss.backward()

        it += 1
        it_loss += tuple_loss.item()

        if it % batch_size == 0:
            optimizer.step()
            optimizer.zero_grad()

            with open('log.txt', 'a') as f:
                f.write(f'{epoch},{it // batch_size},{it_loss}\n')

            it_loss = 0

# Apply the gradients of a trailing, incomplete group of tuples instead of dropping them.
if it % batch_size != 0:
    optimizer.step()
    optimizer.zero_grad()

    with open('log.txt', 'a') as f:
        f.write(f'{epoch},{it // batch_size + 1},{it_loss}\n')

    it_loss = 0