### Метрики энкодера на основе Resnet18.  
#### Выходной слой: *nn.Linear(in_features=512, out_features=1024, bias=True)*

### Визуализация в 3 ГК не только не дает количественных оценок точности энкодера, но и несет в себе в лучшем случае около 40% информации от выходного вектора длиной 1024. 

### Необходимо ознакомиться с метриками и оценками модели энкодера. Используем:
* kMeans
* OneClass SVM
* Gaussian Mixture

### Конечная цель: оценка целесообразности применения энкодера в рамках *данной* задачи.

In [None]:
import albumentations as A
import matplotlib.pyplot as plt
import numpy as np
import random
import torch
from torch import nn
import seaborn as sns
import pandas as pd
import os
import pathlib
import shutil
import cv2
import PIL
import cv2
import sys
from datetime import datetime

TEXT_COLOR = 'black'
# Зафиксируем состояние случайных чисел
RANDOM_STATE = 42
np.random.seed(RANDOM_STATE)
torch.manual_seed(RANDOM_STATE)
random.seed(RANDOM_STATE)
%matplotlib inline
plt.rcParams["figure.figsize"] = (17,10)

USE_COLAB_GPU = False
IN_COLAB = False

try:
    import google.colab
    IN_COLAB = True
    USE_COLAB_GPU = True
    from google.colab import drive
except:
    if IN_COLAB:
        print('[!] YOU ARE IN COLAB, BUT DIDNT MOUND A DRIVE. Model wont be synced[!]')

        if not os.path.isfile(CURRENT_FILE_NAME):
            print("FIX ME")
        IN_COLAB = False

    else:
        print('[!] RUNNING NOT IN COLAB')

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

In [None]:
from torchvision.models import resnet

# ResNet-18 backbone whose final fully connected layer is replaced by a
# 512 -> 1024 projection, so the encoder outputs 1024-dimensional vectors.
# NOTE(review): pretrained=True fetches ImageNet weights that the strict
# load_state_dict call below overwrites -- confirm whether it is still needed.
encoder = resnet.resnet18(pretrained=True)
encoder.fc = nn.Linear(in_features=512, out_features=1024, bias=True)

# map_location='cpu' lets a checkpoint that was saved from a GPU be loaded on
# a CPU-only machine; the encoder itself has not been moved off the CPU here.
r = encoder.load_state_dict(
    torch.load('last_encoder_1024_98', map_location='cpu')['model']
)
encoder.eval()

# load_state_dict returns a (missing_keys, unexpected_keys) named tuple, which
# is always truthy, so a bare `assert r` can never fail. Check the key lists
# explicitly and raise, so the check also survives `python -O`.
if r.missing_keys or r.unexpected_keys:
    raise RuntimeError(
        'Encoder checkpoint does not match the model: '
        f'missing keys={r.missing_keys}, unexpected keys={r.unexpected_keys}'
    )

In [None]:
@torch.no_grad()
def simpleGetAllEmbeddings(model, dataset, batch_size, dsc=''):
    """Run ``model`` over ``dataset`` and collect normalized embeddings.

    Args:
        model: Callable applied to every batch after the batch has been moved
            to the module-level ``device``.
        dataset: Passed to ``getDataLoaderFromDataset``; every batch of the
            resulting loader must unpack as ``(data, labels, info)``.
        batch_size: Accepted for interface compatibility only.
            NOTE(review): it is never forwarded to
            ``getDataLoaderFromDataset`` -- confirm the helper chooses the
            batch size on its own.
        dsc: Suffix appended to the progress-bar description.

    Returns:
        A tuple ``(embeddings, labels, info_rows)``:
        ``embeddings`` holds one row per dataset item, normalized with
        ``torch.nn.functional.normalize`` (default arguments);
        ``labels`` is a 2-D tensor with one row per dataset item;
        ``info_rows`` is a list of NumPy rows built from the ``info`` part of
        every batch, in the same (shuffled) order as the embeddings.

    Raises:
        ValueError: If the data loader yields no batches, because the output
            width cannot be determined without at least one batch.
    """
    dataloader = getDataLoaderFromDataset(
        dataset,
        shuffle=True,
        drop_last=False
    )

    pbar = tqdm(
        enumerate(dataloader),
        total=len(dataloader),
        position=0,
        leave=False,
        desc='Getting all embeddings...' + dsc)

    # Output buffers are allocated lazily: their widths and dtypes are only
    # known once the first batch has passed through the model.
    all_q = None
    labels_ret = None
    info_arr = []
    add_info_len = None
    s = 0  # write offset of the current batch inside the output buffers

    for _, (data, labels, info) in pbar:
        data = data.to(device)

        q = model(data)

        # Always store labels as a column so 1-D and 2-D labels share a path.
        if labels.dim() == 1:
            labels = labels.unsqueeze(1)
        if all_q is None:
            total_items = len(dataloader.dataset)
            labels_ret = torch.zeros(
                total_items,
                labels.size(1),
                device=device,
                dtype=labels.dtype,
            )
            all_q = torch.zeros(
                total_items,
                q.size(1),
                device=device,
                dtype=q.dtype,
            )

        # presumably `info` arrives as one sequence per extra field, each of
        # batch length, so the transpose yields one row per sample -- TODO
        # confirm against the dataset's collate behaviour.
        info = np.array(info)
        if add_info_len is None:
            add_info_len = info.shape[0]

        info_arr.extend(info.T.reshape((-1, add_info_len)))

        e = s + q.size(0)
        all_q[s:e] = q
        labels_ret[s:e] = labels
        s = e

    if all_q is None:
        # Previously this fell through to an UnboundLocalError on `all_q`.
        raise ValueError(
            'simpleGetAllEmbeddings: the data loader produced no batches'
        )

    all_q = torch.nn.functional.normalize(all_q)
    return all_q, labels_ret, info_arr

### compute accuracy using AccuracyCalculator from pytorch-metric-learning ###
def test(train_set, test_set, model, accuracy_calculator, batch_size):
    """Embed both splits and report the calculator's ``precision_at_1``.

    The model is switched to eval mode, embeddings are computed for
    ``train_set`` first and ``test_set`` second, and the test embeddings are
    scored against the train embeddings via
    ``accuracy_calculator.get_accuracy``. The full accuracy dictionary is
    printed; only its ``"precision_at_1"`` entry is returned.

    NOTE(review): the positional argument order passed to ``get_accuracy``
    (query, reference, query labels, reference labels, flag) depends on the
    installed pytorch-metric-learning version -- confirm it matches.
    """
    model.eval()

    def _embed(split, tag):
        # The third return value (per-sample info rows) is not needed here.
        embeddings, labels, _ = simpleGetAllEmbeddings(model, split, batch_size, tag)
        return embeddings, labels.squeeze(1)

    reference_emb, reference_lbl = _embed(train_set, ' for train')
    query_emb, query_lbl = _embed(test_set, ' for test')

    accuracies = accuracy_calculator.get_accuracy(
        query_emb, reference_emb, query_lbl, reference_lbl, False
    )
    print(accuracies)
    return accuracies["precision_at_1"]
