In [1]:
%load_ext autoreload
%autoreload 2

import torch
import torchvision
import torch.nn.functional as F
from torch import nn
from sklearn.metrics import precision_recall_fscore_support
import numpy as np
import pandas as pd
import faiss                   # make faiss available
import umap
import seaborn as sns
%matplotlib inline
from byol_pytorch import BYOL

import os
import sys
import matplotlib.pyplot as plt
from sklearn import svm
sys.path.append('..')

import torchvision.models as models
from torchvision.models.feature_extraction import get_graph_node_names
from torchvision.models.feature_extraction import create_feature_extractor

from sklearn.metrics import precision_recall_fscore_support, accuracy_score
from sklearn.ensemble import BaggingClassifier, RandomForestClassifier
from sklearn.linear_model import LogisticRegression

from src.beam import UniversalDataset, Experiment, Algorithm, beam_arguments, PackedFolds

2022-07-25 11:40:47.090076: I tensorflow/core/util/util.cc:169] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.


In [2]:
class FeatureNet(nn.Module):
    """Pretrained torchvision ResNet-50 wrapped so that it returns flat feature vectors.

    The wrapped extractor exposes the graph node named ``'flatten'`` under the
    key ``'features'``; ``forward`` returns that tensor reshaped to
    ``(len(x), -1)``.
    """

    def __init__(self):
        super().__init__()

        backbone = models.resnet50(pretrained=True, num_classes=1000)
        # Available node names can be listed with get_graph_node_names(backbone).
        node_map = {'flatten': 'features'}
        self.net = create_feature_extractor(backbone, return_nodes=node_map)

    def forward(self, x):
        """Return one flattened feature row per element of the batch ``x``."""
        features = self.net(x)['features']
        batch_size = len(x)
        return features.view(batch_size, -1)

In [3]:
class MiniImageNet(UniversalDataset):
    """Dataset of 3x64x64 byte images served as 224x224 tensors.

    On first use the pickled batches found under ``hparams.path_to_data`` are
    merged into a single ``mini_imagenet.pt`` cache file in the same folder;
    later instantiations load that cache directly.
    """

    def __init__(self, hparams):
        """Load (or build and cache) the train/test tensors and split the data.

        Args:
            hparams: object providing ``path_to_data`` (folder holding the raw
                batches and the cache file) and ``split_dataset_seed`` (seed
                forwarded to ``self.split``).
        """
        path = hparams.path_to_data
        seed = hparams.split_dataset_seed

        super().__init__()

        file = os.path.join(path, 'mini_imagenet.pt')
        if not os.path.exists(file):

            # NOTE: pd.read_pickle unpickles arbitrary objects - only point
            # path_to_data at trusted files.
            dataset_train = [pd.read_pickle(os.path.join(path, f'train_data_batch_{i}')) for i in range(1, 11)]

            # Fix: `dataset_test` was used below without ever being loaded, so
            # building the cache raised NameError.
            # NOTE(review): 'val_data' is the validation file name used by the
            # downsampled-ImageNet release that ships `train_data_batch_N`
            # files - confirm it matches the files in `path`.
            dataset_test = pd.read_pickle(os.path.join(path, 'val_data'))

            data_train = torch.cat([torch.ByteTensor(di['data']) for di in dataset_train]).reshape(-1, 3, 64, 64)

            # Per-channel statistics over the whole training set (stored in the
            # cache; getitem currently does not use them).
            data_train_f = data_train.float()

            mu = data_train_f.mean(dim=(0, 2, 3), keepdim=True)
            std = data_train_f.std(dim=(0, 2, 3), keepdim=True)

            data_test = torch.ByteTensor(dataset_test['data']).reshape(-1, 3, 64, 64)

            labels_train = torch.cat([torch.LongTensor(di['labels']) for di in dataset_train])
            labels_test = torch.LongTensor(dataset_test['labels'])

            state = {'data_train': data_train, 'data_test': data_test,
                     'labels_train': labels_train,
                     'labels_test': labels_test, 'mu': mu,
                     'std': std}

            torch.save(state, file)
        else:
            state = torch.load(file)

        # When True, getitem rescales the byte images to floats in [0, 1].
        self.normalize = True
        self.data = PackedFolds({'train': state['data_train'], 'test': state['data_test']})
        self.labels = PackedFolds({'train': state['labels_train'], 'test': state['labels_test']})
        self.mu = state['mu']
        self.std = state['std']
        # The 'test' fold is used as the test subset; 20% is requested for validation.
        self.split(validation=.2, test=self.labels['test'].index, seed=seed)
        self.transform = torchvision.transforms.Resize((224, 224))

    def getitem(self, index):
        """Return ``{'x': resized image(s), 'y': label(s)}`` for ``index``.

        With ``self.normalize`` set, the byte data is converted to float and
        divided by 255 before resizing; otherwise the raw byte tensor is
        resized as is.
        """
        x = self.data[index]

        if self.normalize:
            # Standardisation with self.mu / self.std is intentionally disabled;
            # the images are only rescaled to [0, 1].
            x = x.float() / 255

        x = self.transform(x)

        return {'x': x, 'y': self.labels[index]}

In [4]:
class BeamBYOL(Algorithm):
    """BYOL training of a pretrained ResNet-50 with a linear-probe validation metric.

    Training steps optimise the BYOL loss; non-training steps collect features
    of the ``'avgpool'`` node, which ``postprocess_epoch`` uses to fit a
    logistic-regression classifier and score it on the ``'test'`` subset.
    """

    def __init__(self, hparams):
        """Build the BYOL learner and the feature extractor on one shared backbone."""
        # choose your network
        # net = FeatureNet()
        backbone = models.resnet50(pretrained=True)
        hidden_layer = 'avgpool'

        byol_learner = BYOL(backbone,
                            image_size=224,
                            hidden_layer=hidden_layer)

        extractor = create_feature_extractor(backbone, return_nodes={hidden_layer: 'features'})

        def flat_features(x):
            # One flattened row of `hidden_layer` activations per input element.
            return extractor(x)['features'].view(len(x), -1)

        self.features = flat_features

        super().__init__(hparams, networks={'learner': byol_learner})

    def preprocess_epoch(self, results=None, **kwargs):
        """Re-enable dataset normalisation, then hand ``results`` back unchanged."""
        self.dataset.normalize = True
        return results

    def postprocess_epoch(self, results=None, training=None, **kwargs):
        """After a non-training epoch, add the ``'downstream'`` accuracy scalar.

        The classifier is fitted on the features gathered in
        ``results['transforms']`` and scored on ``self.evaluate('test', head=2000)``.
        """
        print('postprocess')

        if training:
            return results

        print('validation')

        train_z = np.concatenate(results['transforms']['z'])
        train_y = np.concatenate(results['transforms']['y'])

        probe = LogisticRegression(n_jobs=-1)
        probe.fit(train_z, train_y)

        held_out = self.evaluate('test', head=2000)

        test_z = held_out.values['z'].detach().cpu().numpy()
        test_y = held_out.values['y'].detach().cpu().numpy()

        predictions = probe.predict(test_z)
        results['scalar']['downstream'] = float(accuracy_score(test_y, y_pred=predictions))

        return results

    def iteration(self, sample=None, results=None, counter=None, subset=None, training=True, **kwargs):
        """Run one BYOL optimisation step, or record features when not training."""
        images, targets = sample['x'], sample['y']

        byol_learner = self.networks['learner']
        optimizer = self.optimizers['learner']

        if not training:
            embedding = self.features(images)
            results['transforms']['z'].append(embedding.detach().cpu().numpy())
            results['transforms']['y'].append(targets.detach().cpu().numpy())
            return results

        loss = byol_learner(images)
        optimizer.apply(loss, training=training)
        byol_learner.update_moving_average()

        # add scalar measurements
        results['scalar']['loss'].append(float(loss))

        return results

    def inference(self, sample=None, results=None, subset=None, predicting=True, **kwargs):
        """Return features for ``sample``.

        When ``predicting`` is true ``sample`` is the input itself and the bare
        feature tensor is returned; otherwise ``sample`` is a dict with ``'x'``
        and ``'y'`` and a ``{'z', 'y'}`` dict is returned. ``results`` is
        passed through as the second element in both cases.
        """
        if predicting:
            return self.features(sample), results

        images, targets = sample['x'], sample['y']
        embedding = self.features(images)
        return {'z': embedding, 'y': targets}, results

## Set hyperparameters

In [5]:
# Folder that MiniImageNet reads its data files from, and the root folder for
# experiment outputs.
# NOTE(review): both are absolute, machine-specific paths - adjust per environment.
path_to_data = '/home/shared/data/dataset/imagenet'
root_dir = '/home/shared/data/results'

# Hyperparameters are given as CLI-style flag strings plus keyword overrides.
# NOTE(review): MiniImageNet also reads `hparams.split_dataset_seed`, which is not
# set here - presumably beam_arguments supplies a default; confirm.
hparams = beam_arguments(
    f"--project-name=similarity --root-dir={root_dir} --algorithm=ImageNet --identifier=dev  --device=1 --amp",
    "--epoch-length-train=50 --epoch-length-eval=20 --no-scale-epoch-by-batch-size --batch-size=128",
    path_to_data=path_to_data)

## Build a dataset

In [6]:
%%time

# Builds the dataset; loads the cached `mini_imagenet.pt` when it already exists.
dataset = MiniImageNet(hparams)

CPU times: user 27 s, sys: 44.8 s, total: 1min 11s
Wall time: 15 s


### Plot an image from the data

In [10]:
# Show one sample with the /255 rescaling switched off, so the raw byte values
# are plotted. The flag is restored in `finally` so a failure while indexing or
# plotting cannot leave the dataset un-normalised for the training cells below.
dataset.normalize = False
try:
    # dataset[...] is indexed with [1] to reach the sample dict; channels are
    # moved last for imshow.
    im = np.array(dataset[10210][1]['x'].permute(1, 2, 0))

    plt.imshow(im)
finally:
    dataset.normalize = True

## Build Beam Experiment with BYOL trainer

In [7]:
# Create the experiment from the hyperparameters (the log below shows a new
# experiment directory being created under root_dir).
experiment = Experiment(hparams, print_hyperparameters=False)

[32m2022-07-25 11:41:06[0m | [1mINFO[0m | [1mCreating new experiment[0m
[32m2022-07-25 11:41:06[0m | [1mINFO[0m | [1mExperiment directory is: /home/shared/data/results/similarity/ImageNet/dev/0012_20220725_114106[0m


In [8]:
# alg = experiment.algorithm_generator(BeamBYOL, dataset)

In [9]:
# Train BeamBYOL on the dataset. The recorded run below was stopped manually
# (KeyboardInterrupt) and the last checkpoint was reloaded.
alg = experiment.fit(BeamBYOL, dataset)

[32m2022-07-25 11:41:06[0m | [1mINFO[0m | [1mSingle worker mode[0m
[32m2022-07-25 11:41:06[0m | [1mINFO[0m | [1mWorker: 1/1 is running...[0m


train:   2%|2         | 1/50 [00:00<?, ?it/s]

postprocess
postprocess
validation


test:   6%|6         | 1/16 [00:00<?, ?it/s]

[32m2022-07-25 11:43:57[0m | [1mINFO[0m | [1m[0m
[32m2022-07-25 11:43:57[0m | [1mINFO[0m | [1mFinished epoch 1/20000 (Total trained epochs 1).[0m
[32m2022-07-25 11:43:57[0m | [1mINFO[0m | [1mtrain:[0m
[32m2022-07-25 11:43:57[0m | [1mINFO[0m | [1mseconds:  42.03 | batches: 50 | samples:  6.4e+03 | batch_rate:  1.19 [iter/sec] | sample_rate:  152.3 [iter/sec] [0m
[32m2022-07-25 11:43:57[0m | [1mINFO[0m | [1mloss:        | avg: 1.905     | std: 0.3327    | min: 1.476     | 25%: 1.812     | 50%: 1.857     | 75%: 1.975     | max: 3.966     [0m
[32m2022-07-25 11:43:57[0m | [1mINFO[0m | [1mvalidation:[0m
[32m2022-07-25 11:43:57[0m | [1mINFO[0m | [1mseconds:  118.9 | batches: 20 | samples:  2.56e+03 | batch_rate:  5.947 [sec/iter] | sample_rate:  21.52 [iter/sec] [0m
[32m2022-07-25 11:43:57[0m | [1mINFO[0m | [1mdownstream:  | avg: 0.02393   | std: nan       | min: 0.02393   | 25%: 0.02393   | 50%: 0.02393   | 75%: 0.02393   | max: 0.02393   [0m


train:   2%|2         | 1/50 [00:00<?, ?it/s]

postprocess
postprocess
validation
[32m2022-07-25 11:46:15[0m | [1mINFO[0m | [1m[0m
[32m2022-07-25 11:46:15[0m | [1mINFO[0m | [1mFinished epoch 2/20000 (Total trained epochs 2).[0m
[32m2022-07-25 11:46:15[0m | [1mINFO[0m | [1mtrain:[0m
[32m2022-07-25 11:46:15[0m | [1mINFO[0m | [1mseconds:  35.82 | batches: 50 | samples:  6.4e+03 | batch_rate:  1.396 [iter/sec] | sample_rate:  178.7 [iter/sec] [0m
[32m2022-07-25 11:46:15[0m | [1mINFO[0m | [1mloss:        | avg: 1.11      | std: 0.3184    | min: 0.4448    | 25%: 0.8775    | 50%: 1.075     | 75%: 1.33      | max: 1.864     [0m
[32m2022-07-25 11:46:15[0m | [1mINFO[0m | [1mvalidation:[0m
[32m2022-07-25 11:46:15[0m | [1mINFO[0m | [1mseconds:  88.24 | batches: 20 | samples:  2.56e+03 | batch_rate:  4.412 [sec/iter] | sample_rate:  29.01 [iter/sec] [0m
[32m2022-07-25 11:46:15[0m | [1mINFO[0m | [1mdownstream:  | avg: 0.01465   | std: nan       | min: 0.01465   | 25%: 0.01465   | 50%: 0.01465   | 75

train:   4%|4         | 2/50 [00:00<?, ?it/s]

postprocess
postprocess
validation
[32m2022-07-25 11:48:36[0m | [1mINFO[0m | [1m[0m
[32m2022-07-25 11:48:36[0m | [1mINFO[0m | [1mFinished epoch 3/20000 (Total trained epochs 3).[0m
[32m2022-07-25 11:48:36[0m | [1mINFO[0m | [1mtrain:[0m
[32m2022-07-25 11:48:36[0m | [1mINFO[0m | [1mseconds:  35.74 | batches: 50 | samples:  6.4e+03 | batch_rate:  1.399 [iter/sec] | sample_rate:  179.1 [iter/sec] [0m
[32m2022-07-25 11:48:36[0m | [1mINFO[0m | [1mloss:        | avg: 0.8005    | std: 0.3218    | min: 0.3123    | 25%: 0.5266    | 50%: 0.7466    | 75%: 1.018     | max: 1.555     [0m
[32m2022-07-25 11:48:36[0m | [1mINFO[0m | [1mvalidation:[0m
[32m2022-07-25 11:48:36[0m | [1mINFO[0m | [1mseconds:  91.27 | batches: 20 | samples:  2.56e+03 | batch_rate:  4.563 [sec/iter] | sample_rate:  28.05 [iter/sec] [0m
[32m2022-07-25 11:48:36[0m | [1mINFO[0m | [1mdownstream:  | avg: 0.02002   | std: nan       | min: 0.02002   | 25%: 0.02002   | 50%: 0.02002   | 75

train:   4%|4         | 2/50 [00:00<?, ?it/s]

STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(


postprocess
postprocess
validation
[32m2022-07-25 11:50:54[0m | [1mINFO[0m | [1m[0m
[32m2022-07-25 11:50:54[0m | [1mINFO[0m | [1mFinished epoch 4/20000 (Total trained epochs 4).[0m
[32m2022-07-25 11:50:54[0m | [1mINFO[0m | [1mtrain:[0m
[32m2022-07-25 11:50:54[0m | [1mINFO[0m | [1mseconds:  36.64 | batches: 50 | samples:  6.4e+03 | batch_rate:  1.364 [iter/sec] | sample_rate:  174.7 [iter/sec] [0m
[32m2022-07-25 11:50:54[0m | [1mINFO[0m | [1mloss:        | avg: 0.7394    | std: 0.4062    | min: 0.3271    | 25%: 0.4669    | 50%: 0.6207    | 75%: 0.895     | max: 2.076     [0m
[32m2022-07-25 11:50:54[0m | [1mINFO[0m | [1mvalidation:[0m
[32m2022-07-25 11:50:54[0m | [1mINFO[0m | [1mseconds:  88.8 | batches: 20 | samples:  2.56e+03 | batch_rate:  4.44 [sec/iter] | sample_rate:  28.83 [iter/sec] [0m
[32m2022-07-25 11:50:54[0m | [1mINFO[0m | [1mdownstream:  | avg: 0.02295   | std: nan       | min: 0.02295   | 25%: 0.02295   | 50%: 0.02295   | 75%:

train:   4%|4         | 2/50 [00:00<?, ?it/s]

postprocess
postprocess
validation
[32m2022-07-25 11:53:08[0m | [1mINFO[0m | [1m[0m
[32m2022-07-25 11:53:08[0m | [1mINFO[0m | [1mFinished epoch 5/20000 (Total trained epochs 5).[0m
[32m2022-07-25 11:53:08[0m | [1mINFO[0m | [1mtrain:[0m
[32m2022-07-25 11:53:08[0m | [1mINFO[0m | [1mseconds:  36.05 | batches: 50 | samples:  6.4e+03 | batch_rate:  1.387 [iter/sec] | sample_rate:  177.6 [iter/sec] [0m
[32m2022-07-25 11:53:08[0m | [1mINFO[0m | [1mloss:        | avg: 0.6892    | std: 0.4125    | min: 0.1898    | 25%: 0.4194    | 50%: 0.561     | 75%: 0.9232    | max: 2.229     [0m
[32m2022-07-25 11:53:08[0m | [1mINFO[0m | [1mvalidation:[0m
[32m2022-07-25 11:53:08[0m | [1mINFO[0m | [1mseconds:  87.11 | batches: 20 | samples:  2.56e+03 | batch_rate:  4.356 [sec/iter] | sample_rate:  29.39 [iter/sec] [0m
[32m2022-07-25 11:53:08[0m | [1mINFO[0m | [1mdownstream:  | avg: 0.01807   | std: nan       | min: 0.01807   | 25%: 0.01807   | 50%: 0.01807   | 75

train:   4%|4         | 2/50 [00:00<?, ?it/s]

STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(


postprocess
postprocess
validation
[32m2022-07-25 11:55:08[0m | [1mINFO[0m | [1m[0m
[32m2022-07-25 11:55:08[0m | [1mINFO[0m | [1mFinished epoch 6/20000 (Total trained epochs 6).[0m
[32m2022-07-25 11:55:08[0m | [1mINFO[0m | [1mtrain:[0m
[32m2022-07-25 11:55:08[0m | [1mINFO[0m | [1mseconds:  36.11 | batches: 50 | samples:  6.4e+03 | batch_rate:  1.385 [iter/sec] | sample_rate:  177.2 [iter/sec] [0m
[32m2022-07-25 11:55:08[0m | [1mINFO[0m | [1mloss:        | avg: 0.5413    | std: 0.3264    | min: 0.09171   | 25%: 0.2888    | 50%: 0.4769    | 75%: 0.6814    | max: 1.517     [0m
[32m2022-07-25 11:55:08[0m | [1mINFO[0m | [1mvalidation:[0m
[32m2022-07-25 11:55:08[0m | [1mINFO[0m | [1mseconds:  72.47 | batches: 20 | samples:  2.56e+03 | batch_rate:  3.624 [sec/iter] | sample_rate:  35.32 [iter/sec] [0m
[32m2022-07-25 11:55:08[0m | [1mINFO[0m | [1mdownstream:  | avg: 0.02246   | std: nan       | min: 0.02246   | 25%: 0.02246   | 50%: 0.02246   | 75

train:   4%|4         | 2/50 [00:00<?, ?it/s]

STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(
STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(


postprocess
postprocess
validation
[32m2022-07-25 11:57:19[0m | [1mINFO[0m | [1m[0m
[32m2022-07-25 11:57:19[0m | [1mINFO[0m | [1mFinished epoch 7/20000 (Total trained epochs 7).[0m
[32m2022-07-25 11:57:19[0m | [1mINFO[0m | [1mtrain:[0m
[32m2022-07-25 11:57:19[0m | [1mINFO[0m | [1mseconds:  35.5 | batches: 50 | samples:  6.4e+03 | batch_rate:  1.408 [iter/sec] | sample_rate:  180.3 [iter/sec] [0m
[32m2022-07-25 11:57:19[0m | [1mINFO[0m | [1mloss:        | avg: 0.5357    | std: 0.3424    | min: 0.1033    | 25%: 0.3003    | 50%: 0.4672    | 75%: 0.6374    | max: 1.731     [0m
[32m2022-07-25 11:57:19[0m | [1mINFO[0m | [1mvalidation:[0m
[32m2022-07-25 11:57:19[0m | [1mINFO[0m | [1mseconds:  84.04 | batches: 20 | samples:  2.56e+03 | batch_rate:  4.202 [sec/iter] | sample_rate:  30.46 [iter/sec] [0m
[32m2022-07-25 11:57:19[0m | [1mINFO[0m | [1mdownstream:  | avg: 0.01807   | std: nan       | min: 0.01807   | 25%: 0.01807   | 50%: 0.01807   | 75%

train:   4%|4         | 2/50 [00:00<?, ?it/s]

postprocess
postprocess
validation
[32m2022-07-25 11:59:16[0m | [1mINFO[0m | [1m[0m
[32m2022-07-25 11:59:16[0m | [1mINFO[0m | [1mFinished epoch 8/20000 (Total trained epochs 8).[0m
[32m2022-07-25 11:59:16[0m | [1mINFO[0m | [1mtrain:[0m
[32m2022-07-25 11:59:16[0m | [1mINFO[0m | [1mseconds:  34.41 | batches: 50 | samples:  6.4e+03 | batch_rate:  1.453 [iter/sec] | sample_rate:  186.0 [iter/sec] [0m
[32m2022-07-25 11:59:16[0m | [1mINFO[0m | [1mloss:        | avg: 0.539     | std: 0.3224    | min: 0.1402    | 25%: 0.3297    | 50%: 0.4297    | 75%: 0.6866    | max: 1.506     [0m
[32m2022-07-25 11:59:16[0m | [1mINFO[0m | [1mvalidation:[0m
[32m2022-07-25 11:59:16[0m | [1mINFO[0m | [1mseconds:  72.0 | batches: 20 | samples:  2.56e+03 | batch_rate:  3.6 [sec/iter] | sample_rate:  35.56 [iter/sec] [0m
[32m2022-07-25 11:59:16[0m | [1mINFO[0m | [1mdownstream:  | avg: 0.01709   | std: nan       | min: 0.01709   | 25%: 0.01709   | 50%: 0.01709   | 75%: 

train:   4%|4         | 2/50 [00:00<?, ?it/s]

postprocess
postprocess
validation


STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(


[32m2022-07-25 12:01:27[0m | [1mINFO[0m | [1m[0m
[32m2022-07-25 12:01:27[0m | [1mINFO[0m | [1mFinished epoch 9/20000 (Total trained epochs 9).[0m
[32m2022-07-25 12:01:27[0m | [1mINFO[0m | [1mtrain:[0m
[32m2022-07-25 12:01:27[0m | [1mINFO[0m | [1mseconds:  35.7 | batches: 50 | samples:  6.4e+03 | batch_rate:  1.4 [iter/sec] | sample_rate:  179.2 [iter/sec] [0m
[32m2022-07-25 12:01:27[0m | [1mINFO[0m | [1mloss:        | avg: 0.5778    | std: 0.372     | min: 0.09351   | 25%: 0.2877    | 50%: 0.4996    | 75%: 0.7337    | max: 1.912     [0m
[32m2022-07-25 12:01:27[0m | [1mINFO[0m | [1mvalidation:[0m
[32m2022-07-25 12:01:27[0m | [1mINFO[0m | [1mseconds:  84.17 | batches: 20 | samples:  2.56e+03 | batch_rate:  4.209 [sec/iter] | sample_rate:  30.41 [iter/sec] [0m
[32m2022-07-25 12:01:27[0m | [1mINFO[0m | [1mdownstream:  | avg: 0.02051   | std: nan       | min: 0.02051   | 25%: 0.02051   | 50%: 0.02051   | 75%: 0.02051   | max: 0.02051   [0m


train:   2%|2         | 1/50 [00:00<?, ?it/s]

postprocess


STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(
STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(


postprocess
validation
[32m2022-07-25 12:03:31[0m | [1mINFO[0m | [1m[0m
[32m2022-07-25 12:03:31[0m | [1mINFO[0m | [1mFinished epoch 10/20000 (Total trained epochs 10).[0m
[32m2022-07-25 12:03:31[0m | [1mINFO[0m | [1mtrain:[0m
[32m2022-07-25 12:03:31[0m | [1mINFO[0m | [1mseconds:  35.19 | batches: 50 | samples:  6.4e+03 | batch_rate:  1.421 [iter/sec] | sample_rate:  181.9 [iter/sec] [0m
[32m2022-07-25 12:03:31[0m | [1mINFO[0m | [1mloss:        | avg: 0.5203    | std: 0.2832    | min: 0.1445    | 25%: 0.2973    | 50%: 0.4354    | 75%: 0.7378    | max: 1.282     [0m
[32m2022-07-25 12:03:31[0m | [1mINFO[0m | [1mvalidation:[0m
[32m2022-07-25 12:03:31[0m | [1mINFO[0m | [1mseconds:  75.29 | batches: 20 | samples:  2.56e+03 | batch_rate:  3.764 [sec/iter] | sample_rate:  34.0 [iter/sec] [0m
[32m2022-07-25 12:03:31[0m | [1mINFO[0m | [1mdownstream:  | avg: 0.02783   | std: nan       | min: 0.02783   | 25%: 0.02783   | 50%: 0.02783   | 75%: 0.02783 

train:   4%|4         | 2/50 [00:00<?, ?it/s]

postprocess
postprocess
validation
[32m2022-07-25 12:05:34[0m | [1mINFO[0m | [1m[0m
[32m2022-07-25 12:05:34[0m | [1mINFO[0m | [1mFinished epoch 11/20000 (Total trained epochs 11).[0m
[32m2022-07-25 12:05:34[0m | [1mINFO[0m | [1mtrain:[0m
[32m2022-07-25 12:05:34[0m | [1mINFO[0m | [1mseconds:  35.77 | batches: 50 | samples:  6.4e+03 | batch_rate:  1.398 [iter/sec] | sample_rate:  178.9 [iter/sec] [0m
[32m2022-07-25 12:05:34[0m | [1mINFO[0m | [1mloss:        | avg: 0.474     | std: 0.2952    | min: 0.06605   | 25%: 0.2799    | 50%: 0.4216    | 75%: 0.6041    | max: 1.404     [0m
[32m2022-07-25 12:05:34[0m | [1mINFO[0m | [1mvalidation:[0m
[32m2022-07-25 12:05:34[0m | [1mINFO[0m | [1mseconds:  75.6 | batches: 20 | samples:  2.56e+03 | batch_rate:  3.78 [sec/iter] | sample_rate:  33.86 [iter/sec] [0m
[32m2022-07-25 12:05:34[0m | [1mINFO[0m | [1mdownstream:  | avg: 0.02441   | std: nan       | min: 0.02441   | 25%: 0.02441   | 50%: 0.02441   | 75

STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(
STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(
STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver opt

train:   4%|4         | 2/50 [00:00<?, ?it/s]

[32m2022-07-25 12:05:50[0m | [31m[1mERROR[0m | [31m[1mKeyboardInterrupt: Training was interrupted, Worker terminates[0m
[32m2022-07-25 12:05:50[0m | [31m[1mERROR[0m | [31m[1mKeyboardInterrupt: Training was interrupted, reloads last checkpoint[0m
[32m2022-07-25 12:05:50[0m | [1mINFO[0m | [1mReload experiment from checkpoint: /home/shared/data/results/similarity/ImageNet/dev/0012_20220725_114106/checkpoints/checkpoint_000011[0m
[32m2022-07-25 12:05:50[0m | [1mINFO[0m | [1mLoading network state from: /home/shared/data/results/similarity/ImageNet/dev/0012_20220725_114106/checkpoints/checkpoint_000011[0m


In [17]:
# Extract features for the whole 'test' subset (the progress bar below shows 782 batches).
features = alg.evaluate('test')

test:   0%|          | 1/782 [00:00<?, ?it/s]

In [36]:
# Same evaluation capped with max_iterations=100; overwrites `features` from the cell above.
features = alg.evaluate('test', max_iterations=100)

In [37]:
# Detach the evaluated features / labels and move them to CPU NumPy arrays.
z = features.values['z'].detach().cpu().numpy()
y = features.values['y'].detach().cpu().numpy()

## Classifier

In [39]:
# NOTE(review): max_iter=10 is a very small cap - the fit below reports that the
# iteration limit was reached, so the model has not converged.
clf = LogisticRegression(max_iter=10)

In [40]:
%%time
# Fit the classifier on the extracted test-subset features.
clf.fit(z, y)

CPU times: user 1min 53s, sys: 6min 34s, total: 8min 27s
Wall time: 14.1 s


STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(


LogisticRegression(max_iter=10)

In [41]:
# NOTE(review): predictions are made on the same `z` the classifier was fitted on.
y_hat = clf.predict(z)

In [43]:
# NOTE(review): because `clf` was fitted on (z, y), this is training accuracy,
# not a held-out estimate.
accuracy_score(y, y_pred=y_hat)

0.79421875

In [10]:
# clf = RandomForestClassifier(min_samples_leaf=20)
# clf.fit(z, y)

In [None]:
# clf = svm.LinearSVC()

# clf.fit(z, y)

## Faiss

In [58]:
# Feature dimensionality, used to size the faiss index below.
d = z.shape[-1]

In [59]:
# index = faiss.IndexFlatL2(d)   # build the index

In [60]:
# GPU resources object handed to faiss.index_cpu_to_gpu below.
res = faiss.StandardGpuResources()

In [61]:
# build a flat (CPU) index
index_flat = faiss.IndexFlatL2(d)
# make it into a gpu index
gpu_index_flat = faiss.index_cpu_to_gpu(res, 0, index_flat)

In [62]:
# NOTE(review): the recorded output (50000) suggests `z` came from a full test
# evaluation here, not the max_iterations=100 run - cells were executed out of order.
gpu_index_flat.add(z)         # add vectors to the index
print(gpu_index_flat.ntotal)

50000


In [63]:
# NOTE(review): this inspects the CPU `index_flat`; the vectors above were added
# to `gpu_index_flat`.
index_flat.is_trained

True

In [64]:
# k = 4                          # we want to see 4 nearest neighbors
# D, I = gpu_index_flat.search(z[:5], k)  # actual search

# %%time

# # we want to see 4 nearest neighbors
# D, I = gpu_index_flat.search(z, k) # sanity check

In [71]:
# Inspect the 100 nearest neighbours of a single query vector.
i = 1000

# NOTE(review): other cells read labels via `features.values['y']`; this one uses
# `features.data['y']` - confirm which accessor is current.
y = features.data['y']

D, I = gpu_index_flat.search(z[[i]], 100) # sanity check

# Previously this count was a bare mid-cell expression, so it was computed but
# never displayed; keep it and show it in the plot title instead.
neighbor_labels = y[I[0]]
n_unique_labels = len(np.unique(neighbor_labels))

zvi = z[I[0]]

# Project the neighbours' feature vectors to 2-D for plotting.
reducer = umap.UMAP()

embedding = reducer.fit_transform(zvi)

plt.scatter(
    embedding[:, 0],
    embedding[:, 1],
    c=neighbor_labels)
plt.xlabel('UMAP 1')
plt.ylabel('UMAP 2')
plt.title(f'100 nearest neighbours of sample {i} ({n_unique_labels} distinct labels)');

In [40]:
%%time

# we want to see 4 nearest neighbors
D, I = index.search(z, k) # sanity check

KeyboardInterrupt: 