In [1]:
%load_ext autoreload
%autoreload 2

import torch
import torchvision
import torch.nn.functional as F
from torch import nn
from sklearn.metrics import precision_recall_fscore_support
import numpy as np
import pandas as pd

import os
import sys
import matplotlib.pyplot as plt
sys.path.append('..')

import torchvision.models as models
from torchvision.models.feature_extraction import get_graph_node_names
from torchvision.models.feature_extraction import create_feature_extractor

from src.beam import UniversalDataset, Experiment, Algorithm, beam_arguments, PackedFolds

In [2]:
class FeatureNet(nn.Module):

    def __init__(self):

        super().__init__()
        net = models.resnet50(pretrained=True, num_classes=1000)
        # train_nodes, eval_nodes = get_graph_node_names(net)
        return_nodes = {
            'flatten': 'features',
        }
        self.net = create_feature_extractor(net, return_nodes=return_nodes)
        

    def forward(self, x):
        return self.net(x)['features'].view(len(x), -1)

In [3]:
class MiniImageNet(UniversalDataset):

    def __init__(self, hparams):

        path = hparams.path_to_data
        seed = hparams.split_dataset_seed

        super().__init__()
        
        file = os.path.join(path, 'mini_imagenet.pt')
        if not os.path.exists(file):
                        
            dataset_train = [pd.read_pickle(os.path.join(path, f'train_data_batch_{i}')) for i in range(1, 11)]


            data_train = torch.cat([torch.ByteTensor(di['data']) for di in dataset_train]).reshape(-1, 3, 64, 64)

            data_train_f = data_train.float()

            mu = data_train_f.mean(dim=(0, 2, 3), keepdim=True)
            std = data_train_f.std(dim=(0, 2, 3), keepdim=True)

            data_test = torch.ByteTensor(dataset_test['data']).reshape(-1, 3, 64, 64)

            labels_train = torch.cat([torch.LongTensor(di['labels']) for di in dataset_train])
            labels_test = torch.LongTensor(dataset_test['labels'])

            state = {'data_train': data_train, 'data_test': data_test, 
                            'labels_train': labels_train, 
                            'labels_test': labels_test, 'mu': mu,
                            'std': std}
            
            torch.save(state, file)
        else:
            state = torch.load(file)
        
        self.normalize = True
        self.data = PackedFolds({'train': state['data_train'], 'test': state['data_test']})
        self.labels = PackedFolds({'train': state['labels_train'], 'test': state['labels_test']})
        self.mu = state['mu']
        self.std = state['std']
        self.split(validation=.2, test=self.labels['test'].index, seed=seed)

    def getitem(self, index):
        
        x = self.data[index]
        
        if self.normalize:
            mu = self.mu
            std = self.std
            
            if len(x.shape) == 3:
                mu = mu.squeeze(0)
                std = std.squeeze(0)
                
            x = (x.float() - mu) / std
            
        return {'x': x, 'y': self.labels[index]}

In [4]:
class FeaturesExtractor(Algorithm):

    def __init__(self, hparams):

        # choose your network
        net = FeatureNet()
        super().__init__(hparams, networks=net)

    def inference(self, sample=None, results=None, subset=None, predicting=True, **kwargs):

        if predicting:
            x = sample
        else:
            x, y = sample['x'], sample['y']

        net = self.networks['net']
        z = net(x)

        if not predicting:
            return {'z': z, 'y': y}, results

        return z, results

In [5]:
path_to_data = '/home/shared/data/dataset/imagenet'
root_dir = '/home/shared/data/results'

hparams = beam_arguments(
    f"--project-name=similarity --root-dir={root_dir} --algorithm=ImageNet --identifier=dev  --device=0 --override",
    path_to_data=path_to_data)

In [6]:
experiment = Experiment(hparams, print_hyperparameters=False)

[32m2022-07-21 13:50:26[0m | [1mINFO[0m | [1mDeleting old experiment[0m
[32m2022-07-21 13:50:26[0m | [1mINFO[0m | [1mExperiment directory is: /home/shared/data/results/similarity/ImageNet/dev/0002_20220721_135026[0m


In [7]:
%%time

dataset = MiniImageNet(hparams)

CPU times: user 11.4 s, sys: 41.3 s, total: 52.7 s
Wall time: 12.1 s


In [8]:
alg = experiment.algorithm_generator(FeaturesExtractor, dataset)

In [10]:
features = alg.evaluate('test')

test:   1%|          | 1/196 [00:00<?, ?it/s]

In [25]:
z = features.values['z']

In [30]:
z = z.detach().cpu().numpy()

In [22]:
d = 2048

In [23]:
import faiss                   # make faiss available
index = faiss.IndexFlatL2(d)   # build the index

In [27]:
index.add()

In [28]:
print(index.is_trained)

True


In [29]:
print(index.ntotal)

50000


In [38]:
k = 4

In [41]:
res = faiss.StandardGpuResources()

In [49]:
z.astype(np.float64)

array([[0.38854295, 0.80144668, 0.        , ..., 0.01820637, 1.0879941 ,
        1.77755833],
       [0.        , 2.62043905, 3.48512053, ..., 1.50922227, 1.61054075,
        1.27634835],
       [0.        , 0.8495881 , 1.42791975, ..., 0.29668215, 0.47264653,
        0.7178272 ],
       ...,
       [0.67733413, 0.3271088 , 0.01777057, ..., 0.97757649, 0.        ,
        0.02336848],
       [0.18501449, 1.93573213, 0.        , ..., 0.        , 0.94402933,
        1.47809362],
       [0.        , 0.        , 0.        , ..., 0.        , 0.57860261,
        0.        ]])

In [51]:
# build a flat (CPU) index
index_flat = faiss.IndexFlatL2(d)
# make it into a gpu index
gpu_index_flat = faiss.index_cpu_to_gpu(res, 0, index_flat)

In [52]:
gpu_index_flat.add(z)         # add vectors to the index
print(gpu_index_flat.ntotal)

k = 4                          # we want to see 4 nearest neighbors
D, I = gpu_index_flat.search(z[:5], k)  # actual search

50000


In [54]:
%%time

# we want to see 4 nearest neighbors
D, I = gpu_index_flat.search(z, k) # sanity check

CPU times: user 1.63 s, sys: 31.5 ms, total: 1.67 s
Wall time: 1.72 s


In [40]:
%%time

# we want to see 4 nearest neighbors
D, I = index.search(z, k) # sanity check

KeyboardInterrupt: 