In [1]:
import cv2
import tqdm
import torch
import pycolmap
import numpy as np
import matplotlib.pyplot as plt
from hloc import extract_features, extractors
from hloc.utils.base_model import dynamic_load
from pathlib import Path
from torch.utils.tensorboard import SummaryWriter


# hloc feature-extraction configuration registered under the 'netvlad' key.
conf = extract_features.confs['netvlad']

# Prefer the second GPU, as before, but fall back to the first GPU or to the CPU so
# the notebook still runs on machines that do not expose a 'cuda:1' device.
if torch.cuda.is_available():
    device = 'cuda:1' if torch.cuda.device_count() > 1 else 'cuda:0'
else:
    device = 'cpu'


class ImageDataset(torch.utils.data.Dataset):
    """Dataset of (query, positives, negatives) image tuples mined from a reconstruction.

    Positives and negatives are selected from the distance between camera positions
    stored in a pycolmap reconstruction. Distances are expressed in the units of the
    reconstruction (NOTE(review): presumably metres -- confirm against the SfM model).

    Attributes set by ``__init__``:
        image_names: array of the names of all retained images.
        image_tvecs: array with one camera position per retained image.
        positives: dict mapping a query name to its list of positive names.
        valid_names: list of query names that have enough positives; indexes the dataset.
        hard_negatives: dict mapping a query name to a list of negative names that are
            always placed first when negatives are drawn.
        train: when True, ``load_negatives`` samples fresh negatives; when False it
            returns the stored ``hard_negatives`` unchanged.
    """

    def __init__(self, image_dir, reconstruction_dir, resize_max=1024,
                 positive_radius=3, negative_radius=10,
                 min_positives=3, max_positives=15, num_negatives=20):
        """Load the reconstruction and pre-compute the positives of every image.

        Args:
            image_dir: ``pathlib.Path`` of the directory that image names are relative to.
            reconstruction_dir: directory handed to ``pycolmap.Reconstruction``.
            resize_max: length in pixels of the longer image side after resizing.
            positive_radius: images strictly closer than this are positive candidates.
            negative_radius: images strictly farther than this are negative candidates.
            min_positives: queries with fewer positives than this are discarded.
            max_positives: at most this many (closest) positives are kept per query.
            num_negatives: maximum number of negatives returned in training mode.
        """
        print('Loading reconstruction.')
        reconstruction = pycolmap.Reconstruction(reconstruction_dir)

        print('Processing positives and negatives.')
        image_names = []
        image_tvecs = []
        for image in reconstruction.images.values():
            # Images whose name starts with '2015' are left out entirely.
            # NOTE(review): the prefix looks like an acquisition year -- confirm.
            if image.name[:4] != '2015':
                image_names.append(image.name)
                # -R^T t: presumably the camera centre in world coordinates, assuming
                # (rotmat, tvec) is a world-to-camera pose -- TODO confirm.
                image_tvecs.append(-image.rotmat().T @ image.tvec)
        image_names = np.array(image_names)
        image_tvecs = np.stack(image_tvecs)

        positives = {}
        for image_name, image_tvec in tqdm.tqdm(list(zip(image_names, image_tvecs))):
            image_distances = np.linalg.norm(image_tvecs - image_tvec, axis=1)
            image_positives_args = image_distances < positive_radius
            image_positives = image_names[image_positives_args]
            image_distances = image_distances[image_positives_args]
            # Closest candidates first.
            image_positives = image_positives[np.argsort(image_distances)]
            # Only keep candidates whose 4-character name prefix differs from the
            # query's; this also removes the query itself.
            image_positives = [pos for pos in image_positives if pos[:4] != image_name[:4]]
            if len(image_positives) >= min_positives:
                positives[image_name] = image_positives[:max_positives]

        print(f'Kept {len(positives) / len(image_names) * 100:.2f}% images with sufficient positives.')

        self.image_dir = image_dir
        self.resize_max = resize_max
        self.negative_radius = negative_radius
        self.num_negatives = num_negatives
        self.image_names = image_names
        self.image_tvecs = image_tvecs
        self.positives = positives
        self.valid_names = list(positives)
        self.hard_negatives = {image_name: [] for image_name in positives}
        self.train = True

    def load_image(self, image_name):
        """Read an image and return it as a float RGB tensor with channels first.

        The image is rescaled (up or down) so that its longer side measures
        ``resize_max`` pixels, and its values are divided by 255.

        Raises:
            FileNotFoundError: if OpenCV cannot read the file.
        """
        image_path = self.image_dir / image_name
        image = cv2.imread(str(image_path))
        if image is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise FileNotFoundError(f'Could not read image: {image_path}')
        size = np.array(image.shape[1::-1])  # (width, height)
        scaled = size * self.resize_max / size.max()
        # cv2.resize needs plain Python ints.
        new_size = tuple(int(round(side)) for side in scaled)
        image = cv2.resize(image, new_size, interpolation=cv2.INTER_AREA)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = torch.tensor(image).movedim(2, 0) / 255
        return image

    def load_negatives(self, query_name):
        """Return the names of the negatives to use for ``query_name``.

        In training mode the stored hard negatives come first, followed by randomly
        shuffled images farther than ``negative_radius`` from the query, truncated to
        ``num_negatives`` names. Otherwise the stored hard negatives are returned as is.
        """
        if self.train:
            query_tvec = self.image_tvecs[self.image_names == query_name]
            query_distances = np.linalg.norm(self.image_tvecs - query_tvec, axis=1)
            query_negatives = self.image_names[query_distances > self.negative_radius]
            hard_negative_names = self.hard_negatives[query_name]
            # Do not list an image twice when it is already a hard negative.
            new_negative_names = query_negatives[~np.isin(query_negatives, hard_negative_names)]
            np.random.shuffle(new_negative_names)
            query_negative_names = hard_negative_names + new_negative_names.tolist()
            return query_negative_names[:self.num_negatives]
        else:
            return self.hard_negatives[query_name]

    def __getitem__(self, idx):
        """Return ``(query_image, positive_images, negative_images, query_name, negative_names)``."""
        query_name = self.valid_names[idx]
        positive_names = self.positives[query_name]
        negative_names = self.load_negatives(query_name)

        query_image = self.load_image(query_name)
        positive_images = [self.load_image(pos) for pos in positive_names]
        negative_images = [self.load_image(neg) for neg in negative_names]

        return query_image, positive_images, negative_images, query_name, negative_names

    def __len__(self):
        """Number of queries that have enough positives."""
        return len(self.valid_names)

In [2]:
# Images and reconstruction of the Eiffel Tower dataset (absolute paths on the training host).
dataset = ImageDataset(
    image_dir=Path('/workspace/TourEiffelClean/EiffelTower/global/images/'),
    reconstruction_dir=Path('/workspace/TourEiffelClean/EiffelTower/global/sfm/'),
    resize_max=1024
)

SPLIT_SEED = 0                 # seed of NumPy's global RNG for the split below
NUM_TEST_QUERIES = 100         # queries held out for evaluation
TEST_NEGATIVES_PER_QUERY = 10  # fixed negatives attached to every test query
NUM_WORKERS = 32               # DataLoader worker processes

# Without a seed the held-out queries (and their negatives) change on every kernel
# restart, which makes test losses from different runs incomparable.
np.random.seed(SPLIT_SEED)

dataset_idxs = np.arange(len(dataset))
np.random.shuffle(dataset_idxs)
train_idxs = dataset_idxs[NUM_TEST_QUERIES:]
test_idxs = dataset_idxs[:NUM_TEST_QUERIES]
train_dataset = torch.utils.data.Subset(dataset, train_idxs)
test_dataset = torch.utils.data.Subset(dataset, test_idxs)

# Freeze the negatives of each test query: in evaluation mode the dataset returns
# exactly the names stored in `hard_negatives`.
for test_idx in test_idxs:
    tname = dataset.valid_names[test_idx]
    dataset.hard_negatives[tname] = dataset.load_negatives(tname)[:TEST_NEGATIVES_PER_QUERY]

# No batch_size is given, so each loader yields one tuple at a time.
train_loader = torch.utils.data.DataLoader(train_dataset, num_workers=NUM_WORKERS, shuffle=True, pin_memory=True)
test_loader = torch.utils.data.DataLoader(test_dataset, num_workers=NUM_WORKERS, shuffle=False, pin_memory=True)

Loading reconstruction.
Processing positives and negatives.


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 13168/13168 [00:06<00:00, 2181.14it/s]


Kept 97.62% images with sufficient positives.


In [3]:
# Resolve the extractor class named in the configuration and instantiate it.
Model = dynamic_load(extractors, conf['model']['name'])
model = Model(conf['model'])
model.train()
model.requires_grad_(False)  # freeze every parameter ...
model = model.to(device)

# ... then re-enable gradients for the NetVLAD centres, the only tensor handed to Adam.
model.netvlad.centers.requires_grad_(True)
optimizer = torch.optim.Adam([model.netvlad.centers], lr=1e-4)

In [None]:
def compute_tuple_loss(m, d, tup, marg, dev):
    """Margin loss of one (query, positives, negatives) tuple.

    Args:
        m: callable taking ``{'image': tensor}`` and returning a mapping whose
            ``'global_descriptor'`` entry is indexed with ``[0]``.
        d: object exposing a ``train`` flag and a ``hard_negatives`` dict.
        tup: ``(query_image, positive_images, negative_images, query_name, negative_names)``.
        marg: margin added to the closest-positive distance.
        dev: device the images are moved to before calling ``m``.

    Returns:
        Sum over the negatives of ``max(0, d_pos + marg - d_neg)``, where ``d_pos`` is
        the smallest squared distance from the query to a positive and ``d_neg`` the
        squared distance from the query to that negative.

    Side effect:
        When ``d.train`` is truthy, ``d.hard_negatives[query_name[0]]`` is replaced by
        the names of the (at most) 10 negatives closest to the query, closest first.
    """
    query_image, positive_images, negative_images, query_name, negative_names = tup

    def describe(image):
        # Descriptor of the single image contained in the batch.
        return m({'image': image.to(dev, non_blocking=True)})['global_descriptor'][0]

    query_desc = describe(query_image)
    positive_descs = torch.stack([describe(image) for image in positive_images])
    negative_descs = torch.stack([describe(image) for image in negative_images])

    closest_positive = torch.square(positive_descs - query_desc).sum(dim=1).min()
    negative_dists = torch.square(negative_descs - query_desc).sum(dim=1)
    hinge = torch.clip(closest_positive + marg - negative_dists, min=0)
    tup_loss = hinge.sum()

    if d.train:
        # Smallest distance first: these are the negatives most easily confused with the query.
        hardest_first = negative_dists.detach().argsort().cpu().numpy()[:10]
        flat_names = np.array(negative_names).flatten()
        d.hard_negatives[query_name[0]] = flat_names[hardest_first].tolist()

    return tup_loss
    

margin = 0.1          # margin of the tuple loss
batch_size = 4        # tuples whose gradients are accumulated per optimizer step
eval_interval = 200   # optimizer steps between two evaluations/checkpoints
num_epochs = 30
checkpoint_dir = Path('checkpoints')

it = 0       # number of training tuples processed so far
it_loss = 0  # kept for later cells; the training loss is currently not accumulated

writer = SummaryWriter('logs')

# torch.save does not create missing directories.
checkpoint_dir.mkdir(parents=True, exist_ok=True)

optimizer.zero_grad()

for epoch in range(num_epochs):
    for tupl in tqdm.tqdm(train_loader):
        tupl_loss = compute_tuple_loss(model, dataset, tupl, margin, device)

        # Gradients accumulate over `batch_size` tuples before each optimizer step.
        tupl_loss.backward()

        it += 1

        if it % batch_size == 0:
            optimizer.step()
            step = it // batch_size

            if step % eval_interval == 0:
                model.eval()
                dataset.train = False  # serve the fixed test negatives, no re-mining
                try:
                    test_loss = 0
                    with torch.no_grad():
                        # Separate loop variable so the training tuple is not shadowed.
                        for test_tupl in test_loader:
                            test_loss += compute_tuple_loss(model, dataset, test_tupl, margin, device).item()
                finally:
                    # Restore training mode even if the evaluation is interrupted.
                    dataset.train = True
                    model.train()

                writer.add_scalar('test loss', test_loss, step)
                writer.flush()

                # ':06d' zero-pads without the leading blank that ': 06d' put in the file name.
                torch.save({
                    'centers': model.netvlad.centers.detach().cpu(),
                    'optimizer': optimizer.state_dict()
                }, str(checkpoint_dir / f'iter{step:06d}.pt'))

            optimizer.zero_grad()

  8%|████████████▍                                                                                                                                                   | 994/12754 [14:40<2:38:24,  1.24it/s]