In [1]:
#!pip install kornia --upgrade
#!pip install pytorch_metric_learning

In [2]:
%matplotlib inline 
%load_ext autoreload
%autoreload 2
import random
import numpy as np
from fastprogress.fastprogress import master_bar, progress_bar
from fastai2.basics import *
from fastcore import *
from fastai2.vision.all import *
from fastai2.callback.all import *
from fastprogress import fastprogress
from fastai2.callback.mixup import *
from fastscript import *
import torchvision as tv
import kornia as K
import gc
from pytorch_metric_learning import losses, miners

def imshow_torch(tensor, *args, **kwargs):
    """Show a tensor as an image in a new matplotlib figure.

    Args:
        tensor: image tensor, converted with ``K.tensor_to_image`` before display.
        *args: extra positional arguments forwarded to ``plt.imshow``.
        **kwargs: extra keyword arguments forwarded to ``plt.imshow``
            (e.g. ``cmap='gray'``).

    Returns:
        None. The figure is created as a side effect.
    """
    # The previous signature (`*kwargs`) only collected positional arguments, so
    # keyword options such as `cmap=` raised TypeError; both kinds are forwarded now.
    plt.figure()
    plt.imshow(K.tensor_to_image(tensor), *args, **kwargs)
    return

In [3]:
# Names of the Brown PhotoTour subsets. Only val_ds_names is read by the
# evaluation cell below; train_ds_name is not referenced in the cells visible here.
train_ds_name = 'liberty'
val_ds_names = ['notredame', 'yosemite']

# NOTE(review): hard-coded absolute path to a local dataset folder. The visible call
# to eval_descriptor_on_dataset does not pass it and relies on that function's own
# default instead, so changing this value alone has no effect there.
ds_root = '/home/old-ufo/datasets/Brown/'

In [4]:
import os
import numpy as np
from PIL import Image
from typing import Any, Callable, List, Optional, Tuple, Union

import torch
import torchvision
from torchvision.datasets import VisionDataset

from torchvision.datasets.utils import download_url

from fastai2  import *


class PhotoTourRevisited(torchvision.datasets.VisionDataset):
    """`Learning Local Image Descriptors Data <http://phototour.cs.washington.edu/patches/default.htm>`_ Dataset.

    Unlike a pair-based loader, indexing returns single patches together with
    their label and the id read from ``interest.txt``.

    Args:
        root (string): Root directory where images are.
        name (string): Name of the dataset to load. Must be a key of ``urls``.
        train (bool, optional): If True, ``__getitem__`` returns only the patch;
            otherwise it returns ``(patch, label, img_idx)``.
        transform (callable, optional): A function/transform applied to each patch
            as stored in the cached data file (the cache written by ``download``
            holds a uint8 tensor of 64x64 patches).
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.

    Raises:
        RuntimeError: if the cached ``<name>.pt`` file does not exist after the
            optional download step.
    """
    # name -> [url, archive file name, md5 checksum]
    urls = {
        'notredame_harris': [
            'http://matthewalunbrown.com/patchdata/notredame_harris.zip',
            'notredame_harris.zip',
            '69f8c90f78e171349abdf0307afefe4d'
        ],
        'yosemite_harris': [
            'http://matthewalunbrown.com/patchdata/yosemite_harris.zip',
            'yosemite_harris.zip',
            'a73253d1c6fbd3ba2613c45065c00d46'
        ],
        'liberty_harris': [
            'http://matthewalunbrown.com/patchdata/liberty_harris.zip',
            'liberty_harris.zip',
            'c731fcfb3abb4091110d0ae8c7ba182c'
        ],
        'notredame': [
            'http://icvl.ee.ic.ac.uk/vbalnt/notredame.zip',
            'notredame.zip',
            '509eda8535847b8c0a90bbb210c83484'
        ],
        'yosemite': [
            'http://icvl.ee.ic.ac.uk/vbalnt/yosemite.zip',
            'yosemite.zip',
            '533b2e8eb7ede31be40abc317b2fd4f0'
        ],
        'liberty': [
            'http://icvl.ee.ic.ac.uk/vbalnt/liberty.zip',
            'liberty.zip',
            'fdd9152f138ea5ef2091746689176414'
        ],
    }
    # Per-subset statistics, stored on the instance as self.mean / self.std.
    means = {'notredame': 0.4854, 'yosemite': 0.4844, 'liberty': 0.4437,
             'notredame_harris': 0.4854, 'yosemite_harris': 0.4844, 'liberty_harris': 0.4437}
    stds = {'notredame': 0.1864, 'yosemite': 0.1818, 'liberty': 0.2019,
            'notredame_harris': 0.1864, 'yosemite_harris': 0.1818, 'liberty_harris': 0.2019}
    # Number of patches per subset; used both to truncate the extracted patches
    # and as the dataset length.
    lens = {'notredame': 468159, 'yosemite': 633587, 'liberty': 450092,
            'liberty_harris': 379587, 'yosemite_harris': 450912, 'notredame_harris': 325295}
    image_ext = 'bmp'
    info_file = 'info.txt'
    matches_files = 'm50_100000_100000_0.txt'
    img_info_files = 'interest.txt'

    def __init__(
            self, root: str, name: str, train: bool = False,
            transform: Optional[Callable] = None, download: bool = False
    ) -> None:
        # Forward `transform` to VisionDataset so that self.transform is actually set.
        # Previously only `root` was passed and the transform was silently ignored.
        super(PhotoTourRevisited, self).__init__(root, transform=transform)
        self.name = name
        self.data_dir = os.path.join(self.root, name)
        self.data_down = os.path.join(self.root, '{}.zip'.format(name))
        self.data_file = os.path.join(self.root, '{}.pt'.format(name))

        self.train = train
        self.mean = self.means[name]
        self.std = self.stds[name]

        if download:
            self.download()

        if not self._check_datafile_exists():
            raise RuntimeError('Dataset not found.' +
                               ' You can use download=True to download it')

        # load the serialized data
        # NOTE(review): torch.load unpickles the file; only load cache files that
        # were produced locally by download().
        self.data, self.labels, self.matches, self.img_idxs = torch.load(self.data_file)

    def __getitem__(self, index: int) -> Union[torch.Tensor, Tuple[Any, Any, torch.Tensor]]:
        """Return one patch, optionally with its annotations.

        Args:
            index (int): Index
        Returns:
            The (optionally transformed) patch alone when ``train`` is True,
            otherwise the tuple ``(patch, label, img_idx)``.
        """
        data = self.data[index]
        if self.transform is not None:
            data = self.transform(data)
        if self.train:
            return data
        return data, self.labels[index], self.img_idxs[index]

    def __len__(self) -> int:
        """Return the patch count recorded in ``lens`` for this subset."""
        return self.lens[self.name]


    def _check_datafile_exists(self) -> bool:
        """True if the cached ``.pt`` file is present."""
        return os.path.exists(self.data_file)

    def _check_downloaded(self) -> bool:
        """True if the extracted dataset directory is present."""
        return os.path.exists(self.data_dir)

    def download(self) -> None:
        """Download and extract the archive if needed, then build the ``.pt`` cache.

        Does nothing (besides printing a message) when the cache already exists.
        """
        if self._check_datafile_exists():
            print('# Found cached data {}'.format(self.data_file))
            return

        if not self._check_downloaded():
            # download files
            url = self.urls[self.name][0]
            filename = self.urls[self.name][1]
            md5 = self.urls[self.name][2]
            fpath = os.path.join(self.root, filename)

            download_url(url, self.root, filename, md5)

            print('# Extracting data {}\n'.format(self.data_down))

            import zipfile
            with zipfile.ZipFile(fpath, 'r') as z:
                z.extractall(self.data_dir)

            # The archive is no longer needed once extracted.
            os.unlink(fpath)

        # process and save as torch files
        print('# Caching data {}'.format(self.data_file))

        # Order matters: __init__ unpacks the cache as (data, labels, matches, img_idxs).
        dataset = (
            read_image_file(self.data_dir, self.image_ext, self.lens[self.name]),
            read_info_file(self.data_dir, self.info_file),
            read_matches_files(self.data_dir, self.matches_files),
            read_interest_file(self.data_dir, self.img_info_files)
        )

        with open(self.data_file, 'wb') as f:
            torch.save(dataset, f)

    def extra_repr(self) -> str:
        """Extra line for the dataset repr, showing which split flag is set."""
        return "Split: {}".format("Train" if self.train is True else "Test")


def read_image_file(data_dir: str, image_ext: str, n: int) -> torch.Tensor:
    """Return a Tensor containing the patches.

    Every file in ``data_dir`` whose name ends with ``image_ext`` is scanned, in
    sorted file-name order, as a 16x16 grid of 64x64 patches (row by row).
    NOTE(review): this assumes each bitmap is 1024x1024 and single-channel.

    Args:
        data_dir: directory holding the patch bitmaps.
        image_ext: file-name suffix of the bitmaps (e.g. ``'bmp'``).
        n: number of patches to keep; extracted patches beyond ``n`` are dropped.

    Returns:
        uint8 tensor of shape ``(k, 64, 64)`` where ``k`` is the number of kept
        patches (``0`` when no matching file exists).
    """

    def PIL2array(_img: Image.Image) -> np.ndarray:
        """Convert PIL image type to numpy 2D array
        """
        return np.array(_img.getdata(), dtype=np.uint8).reshape(64, 64)

    def find_files(_data_dir: str, _image_ext: str) -> List[str]:
        """Return a list with the file names of the images containing the patches
        """
        files = []
        # find those files with the specified extension
        for file_dir in os.listdir(_data_dir):
            if file_dir.endswith(_image_ext):
                files.append(os.path.join(_data_dir, file_dir))
        return sorted(files)  # sort files in ascend order to keep relations

    patches = []
    list_files = find_files(data_dir, image_ext)

    for fpath in list_files:
        # Open through a context manager so the file handle of every bitmap is
        # released as soon as its patches are extracted.
        with Image.open(fpath) as img:
            for y in range(0, 1024, 64):
                for x in range(0, 1024, 64):
                    patch = img.crop((x, y, x + 64, y + 64))
                    patches.append(PIL2array(patch))

    selected = patches[:n]
    # The explicit dtype/shape keeps the result well-formed even with no patches.
    stacked = np.array(selected, dtype=np.uint8).reshape(len(selected), 64, 64)
    return torch.from_numpy(stacked)


def _read_first_int_column(path: str) -> torch.Tensor:
    """Parse the first whitespace-separated field of every non-blank line as an int."""
    with open(path, 'r') as f:
        # Blank lines (e.g. a trailing newline at end of file) are skipped instead of
        # raising IndexError on `split()[0]`.
        values = [int(fields[0]) for fields in map(str.split, f) if fields]
    return torch.LongTensor(values)


def read_info_file(data_dir: str, info_file: str) -> torch.Tensor:
    """Return a Tensor containing the list of labels
       Read the file and keep only the ID of the 3D point.

    Args:
        data_dir: directory containing ``info_file``.
        info_file: file name; the first field of each line is kept.

    Returns:
        1-D int64 tensor with one entry per non-blank line.
    """
    return _read_first_int_column(os.path.join(data_dir, info_file))

def read_interest_file(data_dir: str, info_file: str) -> torch.Tensor:
    """Return a Tensor containing the list of image ids
       Read the file and keep only the ID of the image point.

    Args:
        data_dir: directory containing ``info_file``.
        info_file: file name; the first field of each line is kept.

    Returns:
        1-D int64 tensor with one entry per non-blank line.
    """
    return _read_first_int_column(os.path.join(data_dir, info_file))


def read_matches_files(data_dir: str, matches_file: str) -> torch.Tensor:
    """Return a Tensor containing the ground truth matches
       Read the file and keep only 3D point ID.
       Matches are represented with a 1, non matches with a 0.

    Each non-blank line contributes one row ``[field0, field3, flag]`` where
    ``flag`` is 1 when field 1 equals field 4 (compared as text) and 0 otherwise.

    Args:
        data_dir: directory containing ``matches_file``.
        matches_file: file name of the whitespace-separated matches list.

    Returns:
        int64 tensor of shape ``(num_rows, 3)``; ``(0, 3)`` for an empty file.
    """
    matches = []
    with open(os.path.join(data_dir, matches_file), 'r') as f:
        for line in f:
            line_split = line.split()
            if not line_split:
                # Tolerate blank lines (e.g. trailing newline) instead of crashing.
                continue
            matches.append([int(line_split[0]), int(line_split[3]),
                            int(line_split[1] == line_split[4])])
    # Explicit shape so that an empty file still yields a 2-D (0, 3) tensor.
    return torch.LongTensor(matches).reshape(len(matches), 3)

class TupleAug(ItemTransform):
    """Item transform that augments the leading elements of a tuple.

    Every element except the last two is cast to float, given an extra
    dimension at position 1 and passed through ``tfm``; the last two elements
    are returned untouched. Gradients are disabled while doing so.
    """
    def __init__(self, tfm):
        self.tfm = tfm
    def encodes(self, o): 
        # Elements at positions >= n_aug are passed through unchanged.
        n_aug = len(o) - 2
        with torch.no_grad():
            return [
                self.tfm(item.float().unsqueeze(1)) if pos < n_aug else item
                for pos, item in enumerate(o)
            ]
#def average_acc_per_th(snn_ratio, is_correct, ths= np.linspace(0,1.0,20) ):
#    out = []
#    for prev_th, th in zip(ths[:-1], ths[1:]):
#        mask = snn_ratio <= th
#        #print (mask.sum())
#        AA = is_correct[mask].float().mean()
#        #print (mask.sum().item(), AA.item())
#        out.append(AA.item())
#    return out
from sklearn.metrics import average_precision_score, precision_recall_curve


In [5]:
def _average_precision_for_group(cur_descs, cur_labels, miner):
    """Average precision for one group of descriptors (all patches of one image id).

    For every ordered pair (anchor, positive) with equal labels, the positive
    distance is compared with the distance from the anchor to the negative
    returned by ``miner``. A pair counts as correct when the positive is not
    farther than the negative; ``1 - min/max`` of the two distances is used as
    the confidence passed to ``average_precision_score``.

    Args:
        cur_descs: CPU tensor of descriptors, one row per patch.
        cur_labels: CPU tensor of labels, one per row of ``cur_descs``.
        miner: callable returning ``(anchors, positives, negatives)`` index tensors.

    Returns:
        The value returned by ``sklearn.metrics.average_precision_score``.
    """
    # NOTE(review): `neg` is indexed by anchor row below, i.e. this assumes the miner
    # returns one negative index per input row, in row order — confirm for the
    # installed pytorch_metric_learning version.
    _, _, neg = miner(cur_descs, cur_labels)
    NN = cur_labels.size(0)
    # True where two *different* rows share a label (the diagonal is excluded).
    pos_matrix = (cur_labels[None] == cur_labels[:, None]) & ~torch.eye(NN, dtype=torch.bool)
    pairs = torch.nonzero(pos_matrix)
    del pos_matrix  # NN x NN mask, free it before computing distances
    anc_idxs, pos_idxs = pairs[:, 0], pairs[:, 1]
    neg_idxs = neg[anc_idxs]
    pos_dists = F.pairwise_distance(cur_descs[anc_idxs], cur_descs[pos_idxs])
    neg_dists = F.pairwise_distance(cur_descs[anc_idxs], cur_descs[neg_idxs])
    correct = pos_dists <= neg_dists
    snn = torch.min(pos_dists, neg_dists) / torch.max(pos_dists, neg_dists)
    # 0/0 (both distances zero) gives NaN; treat it as the least confident ratio.
    snn[torch.isnan(snn)] = 1.0
    return average_precision_score(correct, 1 - snn)


def eval_descriptor_on_dataset( desc,
    ds_name='notredame',
    resol=100,
    device=torch.device('cuda:0'),
    ds_root = '/home/old-ufo/datasets/Brown'):
    """Compute one average-precision value per image id for a patch descriptor.

    Patches of ``PhotoTourRevisited(ds_root, ds_name)`` are resized to 32x32,
    described with ``desc`` and grouped by image id; each group is scored with
    ``_average_precision_for_group``.

    Args:
        desc: descriptor module; put in eval mode and moved to ``device``.
        ds_name: dataset subset name passed to ``PhotoTourRevisited``.
        resol: unused; kept so existing callers keep working.
        device: device used for the data loader and the descriptor.
        ds_root: dataset root directory.

    Returns:
        list of average-precision values, one per image id, in processing order.
    """
    desc.eval()
    desc = desc.to(device)
    dataset = PhotoTourRevisited(ds_root,
                      ds_name,
                       train=False, 
                       download=True)
    out_size = 32
    test_aug = nn.Sequential( 
        K.Resize((out_size,out_size), interpolation='bicubic'))
    BS = 1024
    N_WORKERS = 4

    dl = TfmdDL(dataset,
                 device=device,
                 after_item=[ToTensor], 
                 after_batch=[TupleAug(test_aug)], #two patches -> single tensor
                 bs=BS, num_workers=N_WORKERS,
                 shuffle = False)
    miner = miners.BatchHardMiner()
    # Accumulators for patches whose image id may continue in the next batch.
    descriptors = []
    all_labels = []
    all_img_labels = []
    aps = []
    print ('Extracting descriptors and calculating AP')
    for patches, labels, img_labels in progress_bar(dl):
        with torch.no_grad():
            descriptors.append(desc(patches))
            all_labels.append(labels)
            all_img_labels.append(img_labels)
            img_labels_unique = torch.unique(torch.cat(all_img_labels), sorted=True).long()
            if img_labels_unique.numel() < 2:
                # Still inside a single image: keep accumulating.
                continue
            # More than one image id is buffered: score every id except the largest
            # one, which may continue in the next batch.
            # NOTE(review): this relies on the loader yielding image ids in
            # non-decreasing order — confirm for the cached data.
            cat_img_labels = torch.cat(all_img_labels)
            cat_descriptors = torch.cat(descriptors)
            cat_labels = torch.cat(all_labels)
            for ii in img_labels_unique[:-1]:
                current = cat_img_labels == ii
                aps.append(_average_precision_for_group(
                    cat_descriptors[current].cpu(), cat_labels[current].cpu(), miner))
            keep = cat_img_labels == img_labels_unique[-1].item()
            descriptors = [cat_descriptors[keep]]
            all_img_labels = [cat_img_labels[keep]]
            all_labels = [cat_labels[keep]]
            gc.collect()

    if descriptors:
        with torch.no_grad():
            cat_img_labels = torch.cat(all_img_labels)
            cat_descriptors = torch.cat(descriptors)
            cat_labels = torch.cat(all_labels)
            # Recompute the ids from what is actually left in the buffers. The ids seen
            # in the last batch may already have been scored by the flush above, and
            # re-scoring them would select empty groups.
            for ii in torch.unique(cat_img_labels, sorted=True):
                current = cat_img_labels == ii
                aps.append(_average_precision_for_group(
                    cat_descriptors[current].cpu(), cat_labels[current].cpu(), miner))
            del cat_img_labels, cat_descriptors, cat_labels

    # Drop references to the large buffers before returning.
    descriptors = None
    all_labels = None
    all_img_labels = None
    dataset = None
    dl = None
    gc.collect()
    return aps

    

    

from collections import defaultdict
# Nested mapping for evaluation output: results[dataset_name][descriptor_name].
# NOTE(review): the next cell repeats this import and re-creates `results`.
results = defaultdict(dict) 

In [None]:
from collections import defaultdict

# Evaluation output, addressed as results[dataset_name][descriptor_name].
results = defaultdict(dict)
for ds_name in reversed(val_ds_names):
    print (ds_name)
    # The descriptors are rebuilt for every dataset and evaluated in insertion order.
    for desc_name, desc in {
        'RootSIFT': K.feature.SIFTDescriptor(32, rootsift=True).to(torch.device('cuda:0')),
        'SIFT': K.feature.SIFTDescriptor(32, rootsift=False).to(torch.device('cuda:0')),
        'HardNet': K.feature.HardNet(True),
        'SoSNet': K.feature.SOSNet(True),
    }.items():
        print (desc_name)
        results[ds_name][desc_name] = eval_descriptor_on_dataset(
            desc, ds_name, resol=50)
        print (f'{desc_name} mAP = {np.array(results[ds_name][desc_name]).mean():.3f}')


yosemite
RootSIFT
# Found cached data /home/old-ufo/datasets/Brown/yosemite.pt
Extracting descriptors and calculating AP


RootSIFT mAP = 0.479
SIFT
# Found cached data /home/old-ufo/datasets/Brown/yosemite.pt
Extracting descriptors and calculating AP


SIFT mAP = 0.487
HardNet
# Found cached data /home/old-ufo/datasets/Brown/yosemite.pt
Extracting descriptors and calculating AP


HardNet mAP = 0.658
SoSNet
# Found cached data /home/old-ufo/datasets/Brown/yosemite.pt
Extracting descriptors and calculating AP
