## Цель ноутбука: изучение метода Few-Shot Learning

In [None]:
# One-time environment bootstrap, kept commented out: clone the repository
# branch, download a file with gdown, install the requirements and unpack
# R_MERGED.zip into the data folder.
# NOTE(review): the /content/ path suggests these target Google Colab - confirm.
#!git clone --branch 11_ShotLearning_Encoder https://github.com/lsd-maddrive/adas_system.git
#!gdown --id 1-K3ee1NbMmx_0T5uwMesStmKnZO_6mWi
#%cd adas_system
#!pip install -r requirements.txt
#!pip install faiss-cpu faiss
#!pip install --upgrade tbb
#%cd SignDetectorAndClassifier/notebooks
#!unzip -q -o /content/R_MERGED.zip -d ./../data/

# Switch the working directory to the notebooks folder of the cloned repository.
%cd adas_system/SignDetectorAndClassifier/notebooks

#### В RTSD не хватает 14 знаков:

| Знак | Описание | Источник |
| ------------- | ------------- | ---- |
| 1.6 | Пересечение равнозначных дорог | - |
| 1.31 | Туннель | - |
| 2.4 | Уступите дорогу | GTSRB Recognition |
| 3.21 | Конец запрещения обгона | GTSRB Recognition |
| 3.22 | Обгон грузовым автомобилям запрещен | GTSRB Recognition |
| 3.23 | Конец запрещения обгона грузовым автомобилям | GTSRB Recognition |
| 3.24-90 | Огр 90 | - |
| 3.24-100 | Огр 100 | GTSRB Recognition |
| 3.24-110 | Огр 110 | - |
| 3.24-120 | Огр 120 | GTSRB Recognition |
| 3.24-130 | Огр 130 | - |
| 3.25 | Конец огр. максимальной скорости | GTSRB Recognition |
| 3.31 | Конец всех ограничений | GTSRB Recognition |
| 6.3.2 | Зона для разворота | - |

Инициализация библиотек

In [None]:
import albumentations as A
# The notebook is pinned to albumentations 1.0.3. When another version is
# installed, install the pinned packages and abort the cell with a failing
# assert so the runtime gets restarted before continuing.
if A.__version__ != '1.0.3':
    !pip install albumentations==1.0.3
    !pip install opencv-python-headless==4.5.2.52
    assert False, 'restart runtime pls'

import matplotlib.pyplot as plt
import numpy as np
import random
import torch
from torch import nn
import seaborn as sns
import pandas as pd
import os
import pathlib
import shutil
import cv2
import PIL
import cv2
import sys
from datetime import datetime

TEXT_COLOR = 'black'
# Fix the state of the random number generators (NumPy, PyTorch, stdlib random)
RANDOM_STATE = 42
np.random.seed(RANDOM_STATE)
torch.manual_seed(RANDOM_STATE)
random.seed(RANDOM_STATE)
%matplotlib inline
# Default size for matplotlib figures.
plt.rcParams["figure.figsize"] = (17,10)


# Detect whether the notebook runs inside Google Colab and pick the device.
#
# IN_COLAB has to exist before the try block: outside Colab the very first
# import fails, and the handler below reads IN_COLAB before anything in the
# try body had a chance to assign it.
USE_COLAB_GPU = False
IN_COLAB = False

try:
    import google.colab
    IN_COLAB = True
    USE_COLAB_GPU = True
    from google.colab import drive
except ImportError:
    if IN_COLAB:
        # google.colab itself imported, only the drive helper is unavailable.
        # The former check against CURRENT_FILE_NAME was dropped here: that
        # name is not defined anywhere, so the check could only raise.
        print('[!]YOU ARE IN COLAB, BUT DIDNT MOUND A DRIVE. Model wont be synced[!]')
        IN_COLAB = False

# Use CUDA when PyTorch reports it as available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

Инициализация основных путей и папки src

In [None]:
# Project root is the parent of the current working directory in both
# environments; only the way the path is spelled differs.
PROJECT_ROOT = (
    pathlib.Path('..')
    if IN_COLAB
    else pathlib.Path(os.path.join(os.curdir, os.pardir))
)

DATA_DIR = PROJECT_ROOT.joinpath('data')
NOTEBOOKS_DIR = PROJECT_ROOT.joinpath('notebooks')
SRC_PATH = str(PROJECT_ROOT.joinpath('src'))

# Make modules under <project root>/src importable (added once).
if sys.path.count(SRC_PATH) == 0:
    sys.path.append(SRC_PATH)

In [None]:
from functools import partial
from torchvision.models import resnet

class SplitBatchNorm(torch.nn.BatchNorm2d):
    """BatchNorm2d that normalises the batch in ``num_splits`` separate groups.

    When batch statistics are in use (training mode, or running statistics
    are not tracked) the batch is folded into the channel dimension so that
    each group gets its own statistics; the running statistics are then set
    to the average over the groups. Otherwise the stored running statistics
    are applied to the whole batch.
    """

    def __init__(self, num_features, num_splits, **kw):
        super().__init__(num_features, **kw)
        self.num_splits = num_splits

    def forward(self, input):
        batch, channels, height, width = input.shape
        batch_norm = torch.nn.functional.batch_norm

        use_batch_stats = self.training or not self.track_running_stats
        if not use_batch_stats:
            # plain inference path: stored statistics, no update
            return batch_norm(
                input,
                self.running_mean,
                self.running_var,
                self.weight,
                self.bias,
                False,
                self.momentum,
                self.eps,
            )

        splits = self.num_splits
        # one copy of the running statistics per group; batch_norm updates
        # these copies in place
        mean_per_split = self.running_mean.repeat(splits)
        var_per_split = self.running_var.repeat(splits)

        # fold groups into channels: (N, C, H, W) -> (N / splits, C * splits, H, W)
        folded = input.view(-1, channels * splits, height, width)
        normalised = batch_norm(
            folded,
            mean_per_split,
            var_per_split,
            self.weight.repeat(splits),
            self.bias.repeat(splits),
            True,
            self.momentum,
            self.eps,
        )

        # collapse the per-group statistics back into the module buffers
        self.running_mean.data.copy_(
            mean_per_split.view(splits, channels).mean(dim=0)
        )
        self.running_var.data.copy_(
            var_per_split.view(splits, channels).mean(dim=0)
        )
        return normalised.view(batch, channels, height, width)
        
class ModelBase(torch.nn.Module):
    """
    ResNet backbone assembled with the common CIFAR recipe.

    Compared with the ImageNet ResNet layout it:
    (i) swaps conv1 for a 3x3 convolution with stride 1
    (ii) leaves out the max-pooling layer

    Args:
        feature_dim: size of the output vector (passed as ``num_classes``).
        arch: name of the constructor looked up in ``torchvision.models.resnet``.
        bn_splits: number of SplitBatchNorm groups; plain BatchNorm2d is used
            when it is not greater than 1.
    """

    def __init__(self, feature_dim=128, arch="resnet18", bn_splits=8):
        super().__init__()

        # pick the normalisation layer: split batch norm only makes sense
        # with more than one split
        if bn_splits > 1:
            norm_layer = partial(SplitBatchNorm, num_splits=bn_splits)
        else:
            norm_layer = torch.nn.BatchNorm2d

        backbone = getattr(resnet, arch)(
            num_classes=feature_dim, norm_layer=norm_layer
        )

        layers = []
        for child_name, child in backbone.named_children():
            print(child_name)
            if child_name == "conv1":
                child = torch.nn.Conv2d(
                    3, 64, kernel_size=3, stride=1, padding=1, bias=False
                )
            if isinstance(child, torch.nn.MaxPool2d):
                continue
            if isinstance(child, torch.nn.Linear):
                # the linear head expects a flat vector
                layers.append(torch.nn.Flatten(1))
            layers.append(child)

        self.net = torch.nn.Sequential(*layers)

    def forward(self, x):
        # the output is returned as is, without normalisation
        return self.net(x)

def create_encoder(emb_dim):
    """Build a ModelBase whose output vector has ``emb_dim`` components.

    The model is returned exactly as constructed: it is neither wrapped in
    DataParallel nor moved to a device here.
    """
    return ModelBase(emb_dim)


# Encoder shared by the rest of the notebook (256-dimensional output).
encoder = create_encoder(256)

In [None]:
# Read the dataset index and prefix every stored file path with DATA_DIR.
RTDS_DF = pd.read_csv(DATA_DIR / 'RTDS_DATASET.csv')
RTDS_DF['filepath'] = [str(DATA_DIR / relative) for relative in RTDS_DF['filepath']]

# Keep only the first row for every image file.
RTDS_DF.drop_duplicates(subset='filepath', inplace=True)

RTDS_DF

In [None]:
# Count the rows for every (SIGN, SET) combination.
RTDS_DF.groupby(['SIGN', 'SET']).size()

In [None]:
# Per-group sample count referenced only by the disabled sub-sampling below.
SAMPLE_NUMBER = 13
# RTDS_DF = RTDS_DF.groupby(['SIGN', 'SET']).sample(SAMPLE_NUMBER, random_state=RANDOM_STATE).reset_index(drop=True)
# LEARN_RTDS_DF.groupby(['SIGN']).sample(SAMPLE_NUMBER, random_state=RANDOM_STATE).reset_index(drop=True)
# RTDS_DF

In [None]:
# set(RTDS_DF['SIGN'])

In [None]:
class SignDataset(torch.utils.data.Dataset):
    """Image dataset backed by a DataFrame index.

    Every row must provide the columns 'filepath', 'SIGN' and
    'ENCODED_LABEL'; 'SET' is needed only when ``set_label`` is given.

    Args:
        df: DataFrame describing the samples.
        set_label: when not None, only rows whose 'SET' equals this value
            are kept.
        hyp: optional mapping with the keys 'degrees', 'translate', 'scale',
            'shear', 'perspective' and 'border', forwarded to
            ``random_perspective``; only used together with ``transform``.
        transform: optional callable invoked as ``transform(image=img)`` whose
            result is indexed with ``'image'``.
        le: unused; kept so existing callers stay valid.

    Each item is ``(img, label, (path, sign))`` with ``img`` divided by 255.
    """

    def __init__(self, df, set_label=None, hyp=None, transform=None, le=None):
        self.transform = transform
        self.df = df if set_label is None else df[df['SET'] == set_label]
        self.hyp = hyp

    def __len__(self):
        return len(self.df.index)

    def __getitem__(self, index):
        row = self.df.iloc[index]
        label = int(row['ENCODED_LABEL'])
        path = str(row['filepath'])
        sign = str(row['SIGN'])

        img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
        if img is None:
            # cv2.imread signals an unreadable/missing file by returning None
            raise FileNotFoundError('cannot read image: ' + path)

        if img.ndim == 2:
            # single-channel file: expand to three channels so the rest of
            # the pipeline always sees a colour image
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
        elif img.shape[2] == 4:
            # image has an alpha channel: paint fully transparent pixels
            # with one random opaque colour, then drop the alpha channel
            trans_mask = img[:, :, 3] == 0
            img[trans_mask] = [random.randrange(0, 256),
                               random.randrange(0, 256),
                               random.randrange(0, 256),
                               255]
            img = cv2.cvtColor(img, cv2.COLOR_BGRA2BGR)

        # augment
        if self.hyp and self.transform:
            # NOTE(review): random_perspective is not defined in the visible
            # part of the notebook - presumably provided by the project's
            # src package; confirm before passing hyp.
            img, _ = random_perspective(img,
                                        (),
                                        degrees=self.hyp['degrees'],
                                        translate=self.hyp['translate'],
                                        scale=self.hyp['scale'],
                                        shear=self.hyp['shear'],
                                        perspective=self.hyp['perspective'],
                                        border=self.hyp['border'])
        if self.transform:
            img = self.transform(image=img)['image']
        # /augment

        img = img / 255
        return img, label, (path, sign)

In [None]:
from albumentations.augmentations.geometric.transforms import Perspective, ShiftScaleRotate
from albumentations.core.transforms_interface import ImageOnlyTransform
from albumentations.pytorch.transforms import ToTensorV2
from albumentations.augmentations.transforms import PadIfNeeded
from albumentations.augmentations.geometric.resize import LongestMaxSize

# import torchvision.transforms.functional as F

img_size = 40

MINIMAL_TRANSFORM = False

# Augmentations applied before the common resize/pad/tensor stage; the
# minimal pipeline has none.
if MINIMAL_TRANSFORM:
    _augmentation_steps = []
else:
    _augmentation_steps = [
        A.Blur(blur_limit=2),
        A.CLAHE(p=0.5),
        A.Perspective(scale=(0.01, 0.1), p=0.5),
        A.ShiftScaleRotate(
            shift_limit=0.05,
            scale_limit=0.05,
            interpolation=cv2.INTER_LANCZOS4,
            border_mode=cv2.BORDER_CONSTANT,
            value=(0, 0, 0),
            rotate_limit=6,
            p=0.5,
        ),
        A.RandomGamma(gamma_limit=(50, 130), p=1),
        A.ImageCompression(quality_lower=80, p=0.5),
        A.RandomBrightnessContrast(
            brightness_limit=0.5,
            contrast_limit=0.3,
            brightness_by_max=False,
            p=0.5,
        ),
        A.CoarseDropout(
            max_height=3,
            max_width=3,
            min_holes=1,
            max_holes=3,
            p=0.5,
        ),
    ]

# Common tail of both pipelines: scale the longest side to img_size, pad to
# an img_size x img_size canvas with a constant zero border, convert to tensor.
_resize_pad_tensor_steps = [
    LongestMaxSize(img_size),
    PadIfNeeded(
        img_size,
        img_size,
        border_mode=cv2.BORDER_CONSTANT,
        value=0,
    ),
    ToTensorV2(),
]

transform = A.Compose(_augmentation_steps + _resize_pad_tensor_steps)

# Both splits share the same transform object.
train_dataset = SignDataset(RTDS_DF, set_label='train', transform=transform, hyp=None)
valid_dataset = SignDataset(RTDS_DF, set_label='valid', transform=transform, hyp=None)

# valid_dataset = train_dataset

In [None]:
# train_dataset[0]

In [None]:
def getNSamplesFromDataSet(ds, N):
    """Return a list with N items of ``ds`` taken at distinct random indices.

    The indices are drawn with ``random.sample``, so N must not exceed
    ``len(ds)``.
    """
    picked_positions = random.sample(range(0, len(ds)), N)
    return [ds[position] for position in picked_positions]

IMG_COUNT = 18
nrows, ncols = 70, 6
fig = plt.figure(figsize=(16, 200))

PLOT_SOFT_LIMIT = 20

# Preview 20 randomly picked training samples, titled with their sign code.
TEMP_DS = getNSamplesFromDataSet(train_dataset, 20)
for idx, (img, encoded_label, info) in enumerate(TEMP_DS):
    # channels-first tensor -> channels-last array for plotting
    img = img.permute(1, 2, 0).numpy()
    rgb_view = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    ax = fig.add_subplot(nrows, ncols, idx + 1)
    ax.imshow(rgb_view, aspect=1)
    ax.set_title(str(info[1]), fontsize=15)

    # the limit is checked after drawing the current sample
    if idx > PLOT_SOFT_LIMIT:
        print('[!] plot soft limit reached. Breaking.')
        break

plt.tight_layout()
plt.show()

In [None]:
# Larger batches and worker processes are used only when running in Colab.
batch_size, num_workers = (896, 2) if IN_COLAB else (56, 0)

from torch.utils.data import DataLoader


def getDataLoaderFromDataset(dataset, shuffle=False, drop_last=True):
    """Wrap ``dataset`` in a DataLoader using the notebook-wide settings.

    ``batch_size`` and ``num_workers`` are read from the module-level
    variables of the same name; memory pinning is always enabled.
    """
    loader_options = dict(
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=True,
        shuffle=shuffle,
        drop_last=drop_last,
    )
    return DataLoader(dataset, **loader_options)


train_loader = getDataLoaderFromDataset(train_dataset, shuffle=True)

In [None]:
# Check whether the encoder's first parameter tensor lives on a CUDA device.
next(encoder.parameters()).is_cuda

In [None]:
# sample_img, sample_label, _ = next(iter(train_loader))

In [None]:
# print(sample_img)
# with torch.no_grad():
#     encoder.eval()
#     out = encoder(sample_img).detach().cpu()
    
# out

In [None]:
# torch.save(encoder.state_dict(), 'last_encoder')

In [None]:
def saveCheckpoint(model, scheduler, optimizer, epoch, filename):
    """Store the training state in ``filename`` via ``torch.save``.

    The saved dict has the keys 'epoch', 'model', 'optimizer' and
    'scheduler'; the last three hold the respective ``state_dict()``.
    """
    snapshot = {'epoch': epoch}
    stateful_parts = (
        ('model', model),
        ('optimizer', optimizer),
        ('scheduler', scheduler),
    )
    for key, part in stateful_parts:
        snapshot[key] = part.state_dict()
    torch.save(snapshot, filename)

def loadCheckpoint(model, scheduler, optimizer, filename, map_location=None):
    """Restore a training state written by ``saveCheckpoint``.

    Args:
        model, scheduler, optimizer: objects whose ``load_state_dict`` is
            called with the matching entry of the checkpoint.
        filename: path of the checkpoint file.
        map_location: forwarded to ``torch.load``; pass e.g. ``'cpu'`` to
            open a checkpoint saved on a GPU on a host without one. The
            default ``None`` keeps torch's own device placement.

    Returns:
        ``(model, optimizer, scheduler, epoch)`` - note that the order of
        optimizer and scheduler differs from the parameter order.
    """
    checkpoint = torch.load(filename, map_location=map_location)
    model.load_state_dict(checkpoint['model'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    scheduler.load_state_dict(checkpoint['scheduler'])
    return model, optimizer, scheduler, checkpoint['epoch']

In [None]:
# encoder, optimizer, scheduler, started_epoch = loadCheckpoint(encoder, scheduler, optimizer, 'sample')

In [None]:
# Disabled checkpoint experiments.
# saveCheckpoint(encoder, encoder, optimizer, 0, 'sample')
# encoder.load_state_dict(torch.load('sample')['model'])
# encoder.eval()
# No-op statement; it has no effect.
assert True

In [None]:
from pytorch_metric_learning import distances, losses, miners, reducers, testers
from pytorch_metric_learning.utils.accuracy_calculator import AccuracyCalculator

from pytorch_metric_learning.utils import common_functions as c_f
from tqdm.notebook import trange, tqdm

# Training hyper-parameters.
config = {
    'lr': 0.1,
    'epochs': 50,
    'momentum':  0.937,
    'margin': 0
}

# optimizer = torch.optim.SGD(encoder.parameters(), lr=config['lr'], momentum=config['momentum'], nesterov=True)
# encoder.to('cpu')
# The learning rate is taken from config so the value is defined in one place.
optimizer = torch.optim.Adam(encoder.parameters(), lr=config['lr'])
# scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer,
#                                              base_lr=0.00001,
#                                              max_lr=config['lr'],
#                                              step_size_up=50,
#                                              step_size_down=20,
#                                              mode="exp_range",
#                                              gamma=0.9,
#                                              cycle_momentum=False
#                                            )

# Triplet loss and triplet miner share the same distance and margin.
distance = distances.LpDistance()
reducer = reducers.AvgNonZeroReducer()
loss_func = losses.TripletMarginLoss(margin=config['margin'], distance=distance, reducer=reducer)

# mining_func = miners.MultiSimilarityMiner(epsilon=0.1)
mining_func = miners.TripletMarginMiner(margin=config['margin'], distance=distance, type_of_triplets="all")

accuracy_calculator = AccuracyCalculator(
    include=("precision_at_1",
             "mean_average_precision_at_r"), k=1)

# Resume from started_epoch when an earlier cell (or run of this cell)
# defined it; otherwise start from epoch 0. Only the "name is missing" case
# is handled so that unrelated errors and interrupts are not swallowed.
try:
    # encoder, optimizer, scheduler, started_epoch = loadCheckpoint(encoder, scheduler, optimizer, 'sample')
    started_epoch
    print('[+] check point loaded')
except NameError:
    started_epoch = 0
    print('[!] check point doesnt exist')

encoder.to(device)

def get_all_embeddings(dataset, model):
    """Return what ``BaseTester.get_all_embeddings`` yields for the dataset.

    A fresh pytorch-metric-learning ``BaseTester`` configured with zero
    dataloader workers is created for every call.
    """
    return testers.BaseTester(dataloader_num_workers=0).get_all_embeddings(dataset, model)

### convenient function from pytorch-metric-learning ###
@torch.no_grad()
def simpleGetAllEmbeddings(model, dataset, batch_size, dsc=''):
    """Run ``model`` over ``dataset`` and collect outputs, labels and info.

    Args:
        model: callable applied to each batch after it is moved to the
            module-level ``device``.
        dataset: dataset whose items unpack as ``(data, label, info)``.
        batch_size: NOTE(review): not used in the body - the loader comes
            from getDataLoaderFromDataset, which reads the module-level
            batch size.
        dsc: suffix appended to the progress-bar description.

    Returns:
        ``(all_q, labels_ret, info_arr)``: the model outputs normalised with
        ``torch.nn.functional.normalize``, the labels as a 2-D tensor, and a
        list with one info entry per sample. All three follow the (shuffled)
        loader order, so they stay aligned with each other.
    """

    # NOTE(review): shuffle=True means the output order differs from the
    # dataset order; drop_last=False keeps every sample.
    dataloader = getDataLoaderFromDataset(
        dataset,
        shuffle=True,
        drop_last=False
    )

    # s:e is the slice of the output buffers filled by the current batch
    s, e = 0, 0

    pbar = tqdm(
        enumerate(dataloader),
        total=len(dataloader),
        position=0,
        leave=False,
        desc='Getting all embeddings...' + dsc)
    info_arr = []

    # number of info fields per sample, taken from the first batch
    add_info_len = None

    for idx, (data, labels, info) in pbar:
        data = data.to(device)

        q = model(data)

        if labels.dim() == 1:
            labels = labels.unsqueeze(1)
        if idx == 0:
            # buffers are allocated lazily because the output width and the
            # dtypes are only known after the first forward pass
            # NOTE(review): with an empty loader these names are never
            # bound and the code after the loop fails.
            labels_ret = torch.zeros(
                len(dataloader.dataset),
                labels.size(1),
                device=device,
                dtype=labels.dtype,
            )
            all_q = torch.zeros(
                len(dataloader.dataset),
                q.size(1),
                device=device,
                dtype=q.dtype,
            )

        info = np.array(info)
        if add_info_len == None:
            add_info_len = info.shape[0]

        # transpose so that each appended entry groups the fields of one sample
        # presumably info arrives as (fields, batch) after collation - TODO confirm
        info_arr.extend(info.T.reshape((-1, add_info_len)))
        # return info_arr
        # print(info)
        # input()
        # print(len(info))
        e = s + q.size(0)
        all_q[s:e] = q
        labels_ret[s:e] = labels
        s = e

    all_q = torch.nn.functional.normalize(all_q)
    # print(info)
    return all_q, labels_ret, info_arr

### compute accuracy using AccuracyCalculator from pytorch-metric-learning ###
def test(train_set, test_set, model, accuracy_calculator, batch_size):
    """Evaluate ``model`` on ``test_set`` using ``train_set`` as reference.

    Embeddings of both sets are computed with simpleGetAllEmbeddings, passed
    to ``accuracy_calculator.get_accuracy`` and the resulting dict is
    printed. Returns its ``"precision_at_1"`` entry.
    """
    model.eval()
    reference_emb, reference_lbl, _ = simpleGetAllEmbeddings(model, train_set, batch_size, ' for train')
    query_emb, query_lbl, _ = simpleGetAllEmbeddings(model, test_set, batch_size, ' for test')

    # labels come back as (N, 1); the calculator receives them as 1-D
    scores = accuracy_calculator.get_accuracy(
        query_emb,
        reference_emb,
        query_lbl.squeeze(1),
        reference_lbl.squeeze(1),
        False,
    )
    print(scores)
    return scores["precision_at_1"]
    
    
### MNIST code originally from https://github.com/pytorch/examples/blob/master/mnist/main.py ###
def train(model, loss_func, mining_func, device, train_loader, optimizer, epoch):
    """Train ``model`` for one pass over ``train_loader``.

    Args:
        model: network producing embeddings for a batch.
        loss_func: metric-learning loss; called with mined indices unless it
            is a ``losses.CentroidTripletLoss``.
        mining_func: miner called as ``mining_func(embeddings, labels)``.
        device: device the batches are moved to.
        train_loader: loader yielding ``(data, labels, info)`` batches.
        optimizer: optimizer stepped once per batch.
        epoch: not used in the body; kept for the caller's signature.

    Returns:
        Sum of the batch losses divided by
        ``train_loader.batch_size * len(train_loader)``.
    """
    model.train()
    loss_sum = 0

    pbar = tqdm(
        enumerate(train_loader),
        total=len(train_loader),
        position=0,
        leave=False,
        desc='WAITING...')

    USING_CentroidTripletLoss_FLAG = isinstance(loss_func, losses.CentroidTripletLoss)
    USING_MultiSimilarityMiner_FLAG = isinstance(mining_func, miners.MultiSimilarityMiner)

    for batch_idx, (data, labels, _) in pbar:

        data, labels = data.to(device), labels.to(device)
        optimizer.zero_grad()
        embeddings = model(data)

        if USING_CentroidTripletLoss_FLAG:
            # Fixed: the dtype used to come from an undefined name `dtype`;
            # it is now taken from the model output.
            # NOTE(review): this branch rebuilds the embeddings with
            # torch.tensor(...) - confirm it is still wanted.
            embeddings = torch.tensor(
                [c_f.angle_to_coord(a) for a in embeddings],
                requires_grad=True,
                dtype=embeddings.dtype,
            ).to(
                device
            )
            print(embeddings.shape)
            print(labels.shape)
            loss = loss_func(embeddings, labels)
        else:
            indices_tuple = mining_func(embeddings, labels)
            loss = loss_func(embeddings, labels, indices_tuple)

        instant_loss = loss.item()
        loss_sum += instant_loss

        loss.backward()
        optimizer.step()

        # the number of mined triplets is reported only for miners/losses
        # outside the two special cases
        if USING_CentroidTripletLoss_FLAG or USING_MultiSimilarityMiner_FLAG:
            pbar.set_description("TRAIN: INSTANT MEAN LOSS %f" %
                             (round(instant_loss / len(labels), 3))
                            )
        else:
            pbar.set_description("TRAIN: INSTANT MEAN LOSS %f, MINED TRIPLET: %d" %
                             (round(instant_loss / len(labels), 3),
                             mining_func.num_triplets)
                            )

    return loss_sum / (train_loader.batch_size * len(train_loader))

torch.cuda.empty_cache()

# Epoch progress bar, starting at started_epoch when resuming.
pbar = trange(
        started_epoch,
        config['epochs'],
        initial=started_epoch,
        total=config['epochs'],
        leave=True,
        desc='WAITING FOR FIRST EPOCH END...')

# -1 is shown until the first validation run has happened
mean_acc = -1

# assert False
for epoch in pbar:

    # plotSmth(encoder, train_dataset, device=device, dim3=False, fcn='umap')
    train_loss = train(encoder, loss_func, mining_func, device, train_loader, optimizer, epoch)

    # validation runs on every 8th epoch only
    if (epoch + 1) % 8 == 0:
        mean_acc = test(train_dataset, valid_dataset, encoder, accuracy_calculator, batch_size)

    # saveCheckpoint(encoder, scheduler, optimizer, epoch, 'sample')
    # plotSmth(encoder, CONST_MINIMAL_DATASET, device=device, dim3=False, fcn='umap')
    # scheduler.step()

    mean_train_acc = mean_valid_acc = 0
    # Report the learning rate the optimizer actually uses instead of a
    # constant placeholder.
    lr_val = optimizer.param_groups[0]['lr']
    pbar.set_description("PER %d EPOCH: TRAIN LOSS: %.4f; VALID ACCUR: %.4f, LR %.2e" % (
        epoch,
        train_loss,
        mean_acc,
        lr_val)
    )

In [None]:
# Build an index of the images in DATA_DIR / 'additional_sign'. The sign
# code is the part of the file name before the first underscore; every code
# gets a label that continues after the largest label of RTDS_DF.
encode_offset = max(set(RTDS_DF['ENCODED_LABEL'])) + 1

files = os.listdir(DATA_DIR / 'additional_sign')

# sorted(): the iteration order of a set of strings is not stable between
# interpreter runs, which made the label assignment non-reproducible.
sign_list = sorted({x.split('_')[0] for x in files})
sign_to_label = {sign: encode_offset + position for position, sign in enumerate(sign_list)}

rows = []
for file in files:
    sign = file.split('_')[0]
    rows.append({
        'filepath': str(DATA_DIR / 'additional_sign' / file),
        'SIGN': sign,
        'ENCODED_LABEL': sign_to_label[sign],
        'SET': 'valid',
    })

# DataFrame.append is no longer available in pandas 2.x, so the frame is
# built in one step and aligned to the columns of RTDS_DF.
additional_DF = pd.DataFrame(rows).reindex(columns=RTDS_DF.columns)
display(additional_DF)

# Augmentation-free pipeline: scale the longest side to img_size, pad to an
# img_size x img_size canvas with a constant zero border, convert to tensor.
_resize = LongestMaxSize(img_size)
_pad = PadIfNeeded(img_size, img_size, border_mode=cv2.BORDER_CONSTANT, value=0)
minimal_transform = A.Compose([_resize, _pad, ToTensorV2()])

additional_dataset = SignDataset(additional_DF, transform=minimal_transform)

# encoded label -> sign code for the additional signs
add_dataset_dict = {
    encoded: sign_code
    for encoded, sign_code in zip(additional_DF.ENCODED_LABEL, additional_DF.SIGN)
}

IMG_COUNT = 18
nrows, ncols = 70, 6
fig = plt.figure(figsize=(16, 200))

PLOT_SOFT_LIMIT = 20

# Preview the additional signs, titled with their sign code.
for idx, (img, encoded_label, info) in enumerate(additional_dataset):
    # channels-first tensor -> channels-last array for plotting
    img = img.permute(1, 2, 0).numpy()
    rgb_view = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    ax = fig.add_subplot(nrows, ncols, idx + 1)
    ax.imshow(rgb_view, aspect=1)
    ax.set_title(str(info[1]), fontsize=15)

    # the limit is checked after drawing the current sample
    if idx > PLOT_SOFT_LIMIT:
        print('[!] plot soft limit reached. Breaking.')
        break
plt.tight_layout()

add_dataset_dict

In [None]:
from tqdm.notebook import trange, tqdm
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
from umap import UMAP
import plotly.graph_objects as go
from ipywidgets import Output, VBox
    

@torch.no_grad()
def plotSmth(model,
             dataset,
             batch_size,
             device='cpu',
             dim3=False,
             fcn='PCA',
             reducer_arg=None,
             dsc='',
             label_dict=None,
             dot_limit=5000,
             additional_dataset=None,
             main_dataset_marker_size=10,
             additional_dataset_marker_size=20,
             color_dict=None
            ):
    """Project the model's embeddings to 2-D/3-D and show them with plotly.

    Args:
        model: network put into eval mode and used to embed the datasets.
        dataset: main dataset; when it has more than ``dot_limit`` items a
            random subset of that size is used.
        batch_size: forwarded to simpleGetAllEmbeddings.
        device: NOTE(review): not used in the body.
        dim3: plot three components instead of two.
        fcn: 'PCA', 'TSNE' or 'UMAP' (case-insensitive); ignored when
            ``reducer_arg`` is given.
        reducer_arg: already fitted reducer; when given, its ``transform``
            is used instead of fitting a new one.
        dsc: progress-bar suffix for the main dataset.
        label_dict: optional mapping from encoded label to display name.
        dot_limit: maximum number of points taken from ``dataset``.
        additional_dataset: optional second dataset drawn with its own
            marker size.
        main_dataset_marker_size, additional_dataset_marker_size: marker
            sizes of the two datasets.
        color_dict: NOTE(review): not used in the body.

    Returns:
        ``(reducer, plot_df)`` - the reducer that was used and the DataFrame
        the figure was built from.
    """

    model.eval()

    if len(dataset) > dot_limit:
        print("[!] Dot limit! Random choice", dot_limit, '\nSrc len', len(dataset))
        indicies = np.random.choice(len(dataset), dot_limit, replace=False)
        dataset = torch.utils.data.Subset(dataset, indicies)

    # main dataset data
    embeddings, labels, info = simpleGetAllEmbeddings(model, dataset, batch_size, dsc)
    embeddings = embeddings.cpu().numpy()
    # labels and sizes are kept as column vectors so they can be
    # concatenated with the projected coordinates below
    labels = labels.cpu().numpy().flatten()[:, None]
    size = np.ones(labels.shape) * main_dataset_marker_size

    # additional dataset data
    if additional_dataset:
        embeddings_addon, labels_addon, info_addon = simpleGetAllEmbeddings(
            model,
            additional_dataset,
            batch_size,
            dsc='for addon')

        embeddings_addon = embeddings_addon.cpu().numpy()
        labels_addon = labels_addon.cpu().numpy().flatten()[:, None]

        size_addon = np.ones(labels_addon.shape) * additional_dataset_marker_size

        size = np.concatenate((size, size_addon))
        embeddings = np.concatenate((embeddings, embeddings_addon))
        labels = np.concatenate((labels, labels_addon))
        info.extend(info_addon)

        del embeddings_addon, labels_addon, size_addon, info_addon

    # plt.figure(figsize=(32, 32))
    # plt.clf()

    n_components = 3 if dim3 else 2

    fcn = fcn.upper()

    # build a new reducer only when the caller did not supply one
    if reducer_arg == None:
        if fcn == 'PCA':
            reducer = PCA(n_components=n_components, random_state=RANDOM_STATE)
        elif fcn == 'TSNE':
            reducer = TSNE(n_components=n_components, init='random', random_state=RANDOM_STATE)
        elif fcn == 'UMAP':
            reducer = UMAP(n_components=n_components, init='random', random_state=RANDOM_STATE)
        else:
            assert False, "wrong fcn arg"
    else:
        reducer = reducer_arg


    # a supplied reducer is assumed to be fitted already and is only applied
    if reducer_arg == None:
        X_embedded = reducer.fit_transform(embeddings)
    else:
        X_embedded = reducer.transform(embeddings)

    # print(int(x) for x in labels)
    # translate encoded labels to display names; fall back to the raw
    # labels when the lookup fails
    # NOTE(review): the bare except below hides every kind of error.
    if label_dict:
        try:
            target_arr_color = [label_dict[int(x)] for x in labels]
        except:
            print('label dict broken')
            target_arr_color = labels
    else:
        target_arr_color = labels

    target_arr_color = np.array(target_arr_color)[:, None]
    # print(target_arr_color.shape)
    # hover text is built as "<info[1]>:<info[0]>" for every sample
    hover_data = np.array([x[1] + ':' + x[0] for x in info])[:, None]

    # now embeddings, labels, info, size are concatenated. Let's build dataframe from it

    # the concatenation mixes numbers with text, so numeric columns are
    # converted back with pd.to_numeric further down
    plot_df_data = np.concatenate([X_embedded, target_arr_color, size, hover_data], axis=1)
    # NOTE(review): "columns = columns=[...]" is a chained assignment of
    # the same name; it works but looks like a typo.
    if dim3:
        columns = columns=['x', 'y', 'z', 'group', 'size', 'hover_data']
    else:
        columns = columns=['x', 'y', 'group', 'size', 'hover_data']

    plot_df = pd.DataFrame(plot_df_data, columns=columns)
    plot_df['size'] = plot_df['size'].apply(pd.to_numeric)

    fig = go.FigureWidget()

    # one trace per group so that every group gets its own legend entry
    groups = plot_df['group'].unique()

    if dim3:
        plot_df[['x', 'y', 'z']] = plot_df[['x', 'y', 'z']].apply(pd.to_numeric)

        for group in groups:
            df = plot_df.loc[plot_df['group'] == group]
            # the marker symbol is chosen from the size of the group's first row
            group_size = df['size'].iloc[0]
            symbol = 'circle' if group_size == main_dataset_marker_size else 'diamond'

            fig.add_trace(go.Scatter3d(
                x=df['x'],
                y=df['y'],
                z=df['z'],
                mode='markers',
                marker=dict(
                    size=df['size'],
                    opacity=1,
                    symbol=symbol
                ),
                text=df['hover_data'],
                hovertemplate='Sign: %{text}<extra></extra>',
                name=group,
            ))
    else:
        plot_df[['x', 'y']] = plot_df[['x', 'y']].apply(pd.to_numeric)

        for group in groups:
            df = plot_df.loc[plot_df['group'] == group]
            # the marker symbol is chosen from the size of the group's first row
            group_size = df['size'].iloc[0]
            symbol = 'circle' if group_size == main_dataset_marker_size else 'diamond'

            fig.add_trace(go.Scatter(
                x=df['x'],
                y=df['y'],
                mode='markers',
                marker=dict(
                    size=df['size'],
                    opacity=1,
                    symbol=symbol
                ),
                text=df['hover_data'],
                name=group,
            ))

    fig.update_layout(
        hoverlabel=dict(
            bgcolor="white",
            font_size=12,
            font_family="Rockwell"
        ),
        width=800,
        height=900
    )
    fig.show()

    return reducer, plot_df

# encoded label -> sign code, used for the legend of the plots
label_dict = dict(zip(RTDS_DF.ENCODED_LABEL, RTDS_DF.SIGN))
# label_dict.update(add_dataset_dict)

dim3=False
# plotSmth returns (reducer, plot_df); only the fitted reducer is kept so
# that it can be handed to the second call as reducer_arg. Previously the
# whole tuple was stored and passed on.
r = plotSmth(encoder,
             train_dataset,
             batch_size=batch_size,
             device=device,
             dim3=dim3,
             fcn='umap',
             label_dict=label_dict,
             dot_limit=9000,
             # additional_dataset=additional_dataset
            )[0]
# Deliberate stop: everything below only runs when this line is removed.
assert False
plotSmth(encoder,
         valid_dataset,
         batch_size=batch_size,
         device=device,
         dim3=dim3,
         reducer_arg=r,
         label_dict=label_dict,
         dot_limit=6000,
         # additional_dataset=additional_dataset
        )


In [None]:
# Upgrade the tbb package from inside the notebook.
pip install --upgrade tbb

In [None]:
# Show the rows whose SET is 'train'.
RTDS_DF[RTDS_DF['SET'] == 'train']

In [None]:
# Show the rows whose SET is 'test'.
RTDS_DF[RTDS_DF['SET'] == 'test']

In [None]:
# Show the whole dataset index.
RTDS_DF

In [None]:
# Show the composed transform pipeline.
transform

In [None]:
# Index the images in DATA_DIR / 'additional_sign' once more; here every
# file receives the same label (the first value after RTDS_DF's labels).
encode_offset = max(set(RTDS_DF['ENCODED_LABEL'])) + 1

files = os.listdir(DATA_DIR / 'additional_sign')

rows = []
for file in files:
    # the sign code is the part of the file name before the first underscore
    sign = file.split('_')[0]
    encoded_label = encode_offset
    rows.append({
        'filepath': str(DATA_DIR / 'additional_sign' / file),
        'SIGN': sign,
        'ENCODED_LABEL': encoded_label,
        'SET': 'valid',
    })

# DataFrame.append is no longer available in pandas 2.x, so the frame is
# built in one step and aligned to the columns of RTDS_DF.
additional_DF = pd.DataFrame(rows).reindex(columns=RTDS_DF.columns)
display(additional_DF)

# Unlike the earlier cell, this dataset uses the augmenting `transform`.
additional_dataset = SignDataset(
    additional_DF,
    transform=transform
)

In [None]:
nrows, ncols = 70, 6
fig = plt.figure(figsize=(16, 200))

PLOT_SOFT_LIMIT = 80

# Preview the additional dataset, titled with the sign code of each sample.
TEMP_DS = additional_dataset
for idx, (img, encoded_label, info) in enumerate(TEMP_DS):
    # channels-first tensor -> channels-last array for plotting
    img = img.permute(1, 2, 0).numpy()
    rgb_view = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    ax = fig.add_subplot(nrows, ncols, idx + 1)
    ax.imshow(rgb_view, aspect=1)
    ax.set_title(str(info[1]), fontsize=15)

    # the limit is checked after drawing the current sample
    if idx > PLOT_SOFT_LIMIT:
        print('[!] plot soft limit reached. Breaking.')
        break

plt.tight_layout()
plt.show()
