In [None]:
import cv2
import tqdm
import torch
import pycolmap
import numpy as np
import matplotlib.pyplot as plt
from hloc import extract_features, extractors
from hloc.utils.base_model import dynamic_load
from pathlib import Path
from torch.utils.tensorboard import SummaryWriter


# Root of the dataset; the image folder and the COLMAP model live side by side below it.
data_root = Path('/workspace/TourEiffelClean/EiffelTower/global')
im_dir = data_root / 'images'
rec_dir = data_root / 'sfm'

# hloc feature-extraction configuration of the global descriptor being fine-tuned.
conf = extract_features.confs['netvlad']

# Use the first GPU when one is visible; otherwise fall back to the CPU so the
# notebook still runs (slowly) instead of failing on the first `.to(device)`.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'


class ImageDataset(torch.utils.data.Dataset):
    """Dataset of (query, positives, negatives) image tuples for descriptor fine-tuning.

    Image positions are derived from a COLMAP reconstruction. For every image,
    positives are other images whose position lies closer than
    ``positives_distance`` and whose name has a different 4-character prefix
    (presumably the acquisition year -- confirm against the dataset naming).
    Negatives are not computed here: they are filled into ``self.negatives``
    either by :meth:`compute_negatives` (hard negatives) or by the caller using
    :meth:`random_negatives`.

    Args:
        image_dir: ``pathlib.Path`` of the directory containing the images.
        reconstruction_dir: Directory of the COLMAP model to load.
        resize_max: Length, in pixels, the longer image side is resized to.
        positives_num: Maximum number of positives kept per image.
        negatives_num: Maximum number of negatives kept per image.
        positives_distance: Positions closer than this are positive candidates.
        negatives_distance: Positions at least this far are negative candidates.
        positives_min: Minimum number of positives an image needs to be kept.

    Raises:
        ValueError: If the reconstruction contains no usable image.
    """

    def __init__(
        self,
        image_dir,
        reconstruction_dir,
        resize_max=1024,
        positives_num=20,
        negatives_num=100,
        positives_distance=3,
        negatives_distance=10,
        positives_min=3
    ):
        print('Loading reconstruction.')
        reconstruction = pycolmap.Reconstruction(reconstruction_dir)

        print('Processing positives.')
        image_names = []
        image_tvecs = []
        for image in reconstruction.images.values():
            # Images whose name starts with '2015' are left out entirely.
            if image.name[:4] != '2015':
                image_names.append(image.name)
                # -R^T t: presumably the camera centre in world coordinates
                # (COLMAP world-to-camera convention) -- confirm.
                image_tvecs.append(-image.rotmat().T @ image.tvec)
        if not image_names:
            raise ValueError(
                f'No usable images found in reconstruction {reconstruction_dir}.'
            )
        image_names = np.array(image_names)
        image_tvecs = np.stack(image_tvecs)

        positives = {}
        for image_name, image_tvec in tqdm.tqdm(list(zip(image_names, image_tvecs))):
            image_distances = np.linalg.norm(image_tvecs - image_tvec, axis=1)
            image_positives_args = image_distances < positives_distance
            image_positives = image_names[image_positives_args]
            image_distances = image_distances[image_positives_args]
            # Closest candidates first, so the truncation below keeps the nearest.
            image_positives = image_positives[np.argsort(image_distances)]
            # Only keep candidates from a different name prefix than the query
            # (this also drops the query itself, which is at distance 0).
            image_positives = [pos for pos in image_positives if pos[:4] != image_name[:4]]
            if len(image_positives) >= positives_min:
                positives[image_name] = image_positives[:positives_num]

        print(f'Kept {len(positives)/len(image_names)*100:.2f}% images with sufficient positives.')

        # Restrict the indexable images to those that have enough positives.
        args_valid = np.isin(image_names, list(positives))
        image_names = image_names[args_valid]
        image_tvecs = image_tvecs[args_valid]

        self.image_dir = image_dir
        self.reconstruction_dir = reconstruction_dir
        self.resize_max = resize_max
        self.positives_num = positives_num
        self.negatives_num = negatives_num
        self.positives_distance = positives_distance
        self.negatives_distance = negatives_distance
        self.positives_min = positives_min
        self.image_names = image_names
        self.image_tvecs = image_tvecs
        self.positives = positives
        # Filled later by compute_negatives() or by the caller via random_negatives().
        self.negatives = {}
        # When True, __getitem__ returns only (index, query image).
        self.query_only = False

    def load_image(self, image_name):
        """Load an image as an RGB float tensor of shape (3, H, W) scaled to [0, 1].

        The image is resized so that its longer side equals ``self.resize_max``
        while keeping the aspect ratio.

        Raises:
            FileNotFoundError: If OpenCV cannot read the file.
        """
        image_path = self.image_dir / image_name
        image = cv2.imread(str(image_path))
        if image is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise FileNotFoundError(f'Could not read image: {image_path}')
        size = np.array(image.shape[1::-1])  # (width, height)
        scaled = size * self.resize_max / size.max()
        # cv2.resize needs plain, strictly positive Python ints.
        new_size = tuple(max(1, int(round(side))) for side in scaled)
        image = cv2.resize(image, new_size, interpolation=cv2.INTER_AREA)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = torch.tensor(image).movedim(2, 0) / 255
        return image

    def random_negatives(self, idx):
        """Return up to ``negatives_num`` random image names far from image ``idx``.

        Candidates are all dataset images whose position is at least
        ``negatives_distance`` away from the query. Uses NumPy's global RNG.
        """
        query_tvec = self.image_tvecs[idx]
        query_dists = np.linalg.norm(self.image_tvecs - query_tvec, axis=1)
        # Boolean indexing yields a copy, so shuffling does not touch self.image_names.
        query_negative_names = self.image_names[query_dists >= self.negatives_distance]
        np.random.shuffle(query_negative_names)
        return query_negative_names[:self.negatives_num].tolist()

    def compute_negatives(self, m, loader, dev):
        """Mine hard negatives for every image yielded by ``loader``.

        Descriptors are computed with ``m`` (in eval mode, without gradients)
        for the images ``loader`` yields while the dataset is in query-only
        mode. For each of those images, the negatives are the images that are
        at least ``negatives_distance`` away in position and closest in
        descriptor space. Images the loader did not yield have no descriptor,
        so they neither receive negatives nor are chosen as negatives.

        Args:
            m: Model called as ``m({'image': batch})['global_descriptor']``.
            loader: Iterable yielding ``(index, image)`` batches of size one
                over this dataset.
            dev: Device on which descriptors are computed and stored.

        Side effects:
            Updates ``self.negatives`` and leaves ``m`` in training mode.
        """
        print('Computing all descriptors.')
        descs = None
        self.query_only = True
        try:
            with torch.no_grad():
                m.eval()
                for qidx, qim in tqdm.tqdm(loader):
                    qdesc = m({'image': qim.to(dev, non_blocking=True)})['global_descriptor'][0]
                    if descs is None:
                        # Allocate lazily so the width follows the model's output.
                        # NaN rows mark images without a descriptor.
                        descs = torch.full((len(self), qdesc.shape[-1]), torch.nan, device=dev)
                    descs[qidx.item()] = qdesc
        finally:
            # Always restore normal tuple loading and training mode, even on failure.
            self.query_only = False
            m.train()

        print('Computing negatives.')
        if descs is None:
            # The loader yielded nothing, so there is nothing to mine.
            return
        for qidx in tqdm.tqdm(range(len(self))):
            qdesc = descs[qidx]
            if qdesc.isnan().any():
                continue
            qname = self.image_names[qidx]
            qtvec = self.image_tvecs[qidx]
            qdists = np.linalg.norm(self.image_tvecs - qtvec, axis=1)
            too_close = torch.from_numpy(qdists < self.negatives_distance).to(descs.device)
            qdesc_dists = torch.square(descs - qdesc).sum(dim=1)
            # Spatially close images must never be used as negatives.
            qdesc_dists[too_close] = torch.nan
            qnegs = qdesc_dists.argsort()
            # Drop NaN entries (too close, or no descriptor); the rest is
            # ordered from hardest (closest descriptor) to easiest.
            qnegs = qnegs[~qdesc_dists[qnegs].isnan()].cpu().numpy()
            self.negatives[qname] = self.image_names[qnegs[:self.negatives_num]].tolist()

    def __getitem__(self, idx):
        """Return ``(idx, query)`` in query-only mode, else ``(query, positives, negatives)``.

        ``positives`` and ``negatives`` are lists of image tensors. The
        negatives of the query must already be present in ``self.negatives``.
        """
        query_name = self.image_names[idx]
        query_image = self.load_image(query_name)

        if self.query_only:
            return idx, query_image

        positive_names = self.positives[query_name]
        positive_images = [self.load_image(pos) for pos in positive_names]

        negative_names = self.negatives[query_name]
        negative_images = [self.load_image(neg) for neg in negative_names]

        return query_image, positive_images, negative_images

    def __len__(self):
        """Number of images that have enough positives."""
        return len(self.image_names)

In [None]:
# Look up the extractor class named in the hloc configuration.
Model = dynamic_load(extractors, conf['model']['name'])

# Build the network, switch it to training mode, make every parameter
# trainable, and only then move it to the target device.
model = Model(conf['model'])
model = model.train()
model = model.requires_grad_(True)
model = model.to(device)

# Adam over all model parameters with a very small step size.
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-6)

In [None]:
# Seed for NumPy's global RNG: fixes the train/test split and the random test
# negatives, so the held-out set is the same after a kernel restart.
split_seed = 0
# Number of images held out for validation.
test_size = 100

dataset = ImageDataset(
    image_dir=im_dir,
    reconstruction_dir=rec_dir,
    resize_max=512,
    positives_num=15,
    negatives_num=20
)

if len(dataset) <= test_size:
    raise ValueError(
        f'Dataset has only {len(dataset)} images; more than {test_size} are '
        'needed to hold out a test set and still have training data.'
    )

np.random.seed(split_seed)
dataset_idxs = np.arange(len(dataset))
np.random.shuffle(dataset_idxs)
train_idxs = dataset_idxs[test_size:]
test_idxs = dataset_idxs[:test_size]
train_dataset = torch.utils.data.Subset(dataset, train_idxs)
test_dataset = torch.utils.data.Subset(dataset, test_idxs)

# Test images get fixed random negatives once; training negatives are mined
# later with dataset.compute_negatives().
for test_idx in test_idxs:
    dataset.negatives[dataset.image_names[test_idx]] = dataset.random_negatives(test_idx)

# No batch_size is given, so each loader yields one tuple per iteration.
train_loader = torch.utils.data.DataLoader(train_dataset, num_workers=16, shuffle=True, pin_memory=True)
test_loader = torch.utils.data.DataLoader(test_dataset, num_workers=16, shuffle=False, pin_memory=True)

In [None]:
def compute_tuple_loss(m, tup, marg, dev):
    """Hinge loss of one (query, positives, negatives) tuple.

    Each image is passed through ``m`` on its own, and the first entry of the
    returned ``'global_descriptor'`` is used as its descriptor. With ``d_pos``
    the smallest squared distance between the query and any positive, the loss
    is ``sum(max(0, d_pos + marg - d_neg))`` over all negatives.

    Args:
        m: Model called as ``m({'image': tensor})``.
        tup: ``(query_image, positive_images, negative_images)``; the last two
            are sequences of image tensors.
        marg: Margin added to the positive distance.
        dev: Device the images are moved to before the forward pass.

    Returns:
        Scalar tensor holding the summed loss.
    """
    def describe(image):
        return m({'image': image.to(dev, non_blocking=True)})['global_descriptor'][0]

    query_image, positive_images, negative_images = tup

    # Forward passes run in the order: query, positives, negatives.
    anchor = describe(query_image)
    positive_descs = torch.stack([describe(image) for image in positive_images])
    negative_descs = torch.stack([describe(image) for image in negative_images])

    closest_positive = (positive_descs - anchor).pow(2).sum(dim=1).min()
    negative_dists = (negative_descs - anchor).pow(2).sum(dim=1)

    return (closest_positive + marg - negative_dists).clamp(min=0).sum()


margin = 0.1         # margin passed to compute_tuple_loss
batch_size = 4       # tuples whose gradients are accumulated per optimizer step
eval_interval = 200  # optimizer steps between validation runs and checkpoints

it = 0       # number of training tuples processed so far
it_loss = 0  # loss summed over the current accumulation window

writer = SummaryWriter('logs/full512/')

# torch.save does not create missing directories, so make sure the target exists.
checkpoint_dir = Path('checkpoints/full512')
checkpoint_dir.mkdir(parents=True, exist_ok=True)

optimizer.zero_grad()

try:
    for epoch in range(50):

        # Re-mine hard negatives with the current weights at the start of every epoch.
        dataset.compute_negatives(model, train_loader, device)

        for tupl in tqdm.tqdm(train_loader):
            tupl_loss = compute_tuple_loss(model, tupl, margin, device)

            # Gradients accumulate across tuples until the next optimizer step.
            tupl_loss.backward()

            it += 1
            it_loss += tupl_loss.item()

            if it % batch_size == 0:
                step = it // batch_size
                optimizer.step()

                writer.add_scalar('train loss', it_loss, step)
                writer.flush()

                if step % eval_interval == 0:
                    with torch.no_grad():
                        model.eval()
                        test_loss = 0
                        # Separate loop variable so the training tuple is not shadowed.
                        for test_tupl in test_loader:
                            test_loss += compute_tuple_loss(model, test_tupl, margin, device).item()
                        writer.add_scalar('val loss', test_loss, step)
                        model.train()

                    torch.save({
                        'netvlad': model.state_dict(),
                        'optimizer': optimizer.state_dict()
                    }, checkpoint_dir / f'iter{step:06d}.pt')

                it_loss = 0
                optimizer.zero_grad()
finally:
    # Flush pending events and release the log file even if training is interrupted.
    writer.close()