## Цель ноутбука: изучение метода Few-Shot Learning

### Ниже код, специфичный для запуска на платформе Google Colab.

In [None]:
#!git clone --branch 11_ShotLearning_Encoder https://github.com/lsd-maddrive/adas_system.git
#!gdown --id 1-K3ee1NbMmx_0T5uwMesStmKnZO_6mWi
#%cd adas_system
#!pip install -r requirements.txt
#!pip install faiss-cpu faiss
#!pip install --upgrade tbb
#%cd SignDetectorAndClassifier/notebooks
#!unzip -q -o /content/R_MERGED.zip -d ./../data/

%cd adas_system/SignDetectorAndClassifier/notebooks

#### В RTSD не хватает 14 знаков:

| Знак | Описание | Источник |
| ------------- | ------------- | ---- |
| 1.6 | Пересечение равнозначных дорог | - |
| 1.31 | Туннель | - |
| 2.4 | Уступите дорогу | GTSRB Recognition |
| 3.21 | Конец запрещения обгона | GTSRB Recognition |
| 3.22 | Обгон грузовым автомобилям запрещен | GTSRB Recognition |
| 3.23 | Конец запрещения обгона грузовым автомобилям | GTSRB Recognition |
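| 3.24-90 | Ограничение максимальной скорости 90 | - |
| 3.24-100 | Ограничение максимальной скорости 100 | GTSRB Recognition |
| 3.24-110 | Ограничение максимальной скорости 110 | - |
| 3.24-120 | Ограничение максимальной скорости 120 | GTSRB Recognition |
| 3.24-130 | Ограничение максимальной скорости 130 | - |
| 3.25 | Конец зоны ограничения максимальной скорости | GTSRB Recognition |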
| 3.31 | Конец всех ограничений | GTSRB Recognition |
| 6.3.2 | Зона для разворота | - |

Инициализация библиотек

In [None]:
import albumentations as A
if A.__version__ != '1.0.3':
    !pip install albumentations==1.0.3
    !pip install opencv-python-headless==4.5.2.52
    assert False, 'restart runtime pls'

import matplotlib.pyplot as plt
import numpy as np
import random
import torch
from torch import nn
import seaborn as sns
import pandas as pd
import os
import pathlib
import shutil
import cv2
import PIL
import cv2
import sys
from datetime import datetime

# Text colour constant (not referenced in the cells visible here).
TEXT_COLOR = 'black'

# Fix the state of every random number generator used by the notebook
# (NumPy, PyTorch, Python's `random`) so that runs are repeatable.
RANDOM_STATE = 42
for _seed_fn in (np.random.seed, torch.manual_seed, random.seed):
    _seed_fn(RANDOM_STATE)
%matplotlib inline
plt.rcParams["figure.figsize"] = (17,10)

# Environment detection: both flags stay False unless `google.colab` imports.
USE_COLAB_GPU = False
IN_COLAB = False

try:
    import google.colab
    IN_COLAB = True
    USE_COLAB_GPU = True
    from google.colab import drive
except ImportError:
    # Only a failed import is expected here; any other exception should
    # surface instead of being swallowed.
    if IN_COLAB:
        # `google.colab` itself imported, but the `drive` helper did not.
        print('[!] YOU ARE IN COLAB, BUT DIDNT MOUND A DRIVE. Model wont be synced[!]')
        IN_COLAB = False
    else:
        print('[!] RUNNING NOT IN COLAB')

# Use the GPU whenever PyTorch can see one.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

Инициализация основных путей и папки src

In [None]:
# The project root is the parent of the current working directory, both
# locally and in Colab, so a single definition covers both cases.
PROJECT_ROOT = pathlib.Path(os.pardir)

DATA_DIR = PROJECT_ROOT / 'data'
NOTEBOOKS_DIR = PROJECT_ROOT / 'notebooks'
SRC_PATH = str(PROJECT_ROOT / 'src')

# Put the `src` directory on the import path once.
if SRC_PATH not in sys.path:
    sys.path.append(SRC_PATH)

In [None]:
# Load the dataset index and prefix every stored file path with DATA_DIR.
RTDS_DF = pd.read_csv(DATA_DIR / 'RTDS_DATASET.csv')
RTDS_DF['filepath'] = [str(DATA_DIR / rel_path) for rel_path in RTDS_DF['filepath']]

# UNFIX TRAIN
# SIMPLE_FIX = True
# JUST_FIX = False
# Keep only the first row for each image file.
RTDS_DF = RTDS_DF.drop_duplicates(subset=['filepath'])
# RTDS_DF.drop_duplicates(subset=['SET', 'SIGN'], inplace=True)

RTDS_DF

In [None]:
# SAMPLE_NUMBER = 13 # min(RTDS_DF.groupby(['SIGN', 'SET']).size())
# RTDS_DF = RTDS_DF.groupby(['SIGN', 'SET']).apply(lambda x: x.sample(frac=0.1))# sample(SAMPLE_NUMBER, random_state=RANDOM_STATE).reset_index(drop=True)
# LEARN_RTDS_DF.groupby(['SIGN']).sample(SAMPLE_NUMBER, random_state=RANDOM_STATE).reset_index(drop=True)
RTDS_DF

In [None]:
class SignDataset(torch.utils.data.Dataset):
    """Dataset of sign images described by a pandas DataFrame.

    Each row must provide the columns ``filepath``, ``SIGN`` and
    ``ENCODED_LABEL``; the ``SET`` column is needed only when ``set_label``
    is given.

    Args:
        df: DataFrame with one row per image.
        set_label: When not ``None``, only rows whose ``SET`` column equals
            this value are kept.
        hyp: Stored as ``self.hyp``; the dataset itself does not read it.
        transform: Optional callable invoked as ``transform(image=img)``;
            the ``'image'`` entry of its result replaces the image.
        le: Accepted for backward compatibility; unused.
    """

    def __init__(self, df, set_label=None, hyp=None, transform=None, le=None):
        self.transform = transform
        self.df = df if set_label is None else df[df['SET'] == set_label]
        self.hyp = hyp

    def __len__(self):
        return len(self.df.index)

    def __getitem__(self, index):
        """Return ``(image / 255, label, (path, sign))`` for row ``index``.

        Raises:
            FileNotFoundError: if OpenCV cannot read the image file.
        """
        # Fetch the row once instead of once per column.
        row = self.df.iloc[index]
        label = int(row['ENCODED_LABEL'])
        path = str(row['filepath'])
        sign = str(row['SIGN'])

        img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
        if img is None:
            # cv2.imread reports unreadable files by returning None.
            raise FileNotFoundError(f'Cannot read image: {path}')

        if img.ndim == 2:
            # Single-channel image: expand to 3 channels so that every
            # sample has the same layout.
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
        elif img.shape[2] == 4:
            # Image has an alpha channel: paint all fully transparent
            # pixels with one random opaque colour, then drop the alpha.
            trans_mask = img[:, :, 3] == 0
            img[trans_mask] = [random.randrange(0, 256),
                               random.randrange(0, 256),
                               random.randrange(0, 256),
                               255]
            img = cv2.cvtColor(img, cv2.COLOR_BGRA2BGR)

        # Augmentation / preprocessing pipeline.
        if self.transform:
            img = self.transform(image=img)['image']

        img = img / 255
        return img, label, (path, sign)

In [None]:
from albumentations.augmentations.geometric.transforms import Perspective, ShiftScaleRotate
from albumentations.core.transforms_interface import ImageOnlyTransform
from albumentations.pytorch.transforms import ToTensorV2
from albumentations.augmentations.transforms import PadIfNeeded
from albumentations.augmentations.geometric.resize import LongestMaxSize

# Side length (in pixels) of the square images produced by both pipelines.
img_size = 40


def _resize_pad_to_tensor(size):
    """Return fresh instances of the common pipeline tail.

    The tail resizes the longest side to ``size``, pads to ``size`` x ``size``
    with a constant zero border and converts the result with ToTensorV2.
    """
    return [
        LongestMaxSize(size),
        PadIfNeeded(
            size,
            size,
            border_mode=cv2.BORDER_CONSTANT,
            value=0,
        ),
        ToTensorV2(),
    ]


# Pipeline without augmentation (used for validation data).
minimal_transform = A.Compose(_resize_pad_to_tensor(img_size))

# Training pipeline: random augmentations followed by the common tail.
transform = A.Compose(
    [
        A.Blur(blur_limit=2),
        A.CLAHE(p=0.5),
        A.Perspective(scale=(0.01, 0.1), p=0.5),
        A.ShiftScaleRotate(
            shift_limit=0.05,
            scale_limit=0.05,
            interpolation=cv2.INTER_LANCZOS4,
            border_mode=cv2.BORDER_CONSTANT,
            value=(0, 0, 0),
            rotate_limit=6,
            p=0.5,
        ),
        A.RandomGamma(gamma_limit=(50, 130), p=1),
        A.ImageCompression(quality_lower=80, p=0.5),
        A.RandomBrightnessContrast(
            brightness_limit=0.5,
            contrast_limit=0.3,
            brightness_by_max=False,
            p=0.5,
        ),
        A.CoarseDropout(
            max_height=3,
            max_width=3,
            min_holes=1,
            max_holes=3,
            p=0.5,
        ),
        *_resize_pad_to_tensor(img_size),
    ]
)

train_dataset = SignDataset(
    RTDS_DF,
    set_label='train',
    transform=transform,
    hyp=None,
)

valid_dataset = SignDataset(
    RTDS_DF,
    set_label='valid',
    transform=minimal_transform,
    hyp=None,
)

In [None]:
def getNSamplesFromDataSet(ds, N):
    """Return a list of ``N`` items of ``ds`` taken at distinct random indices."""
    chosen_positions = random.sample(range(len(ds)), N)
    return [ds[position] for position in chosen_positions]

# Preview a random handful of training samples after augmentation.
IMG_COUNT = 18
nrows, ncols = 70, 6
PLOT_SOFT_LIMIT = 20

fig = plt.figure(figsize=(16, 200))
TEMP_DS = getNSamplesFromDataSet(train_dataset, 20)
# TEMP_DS = train_dataset.sort_values(['SIGN'], axis=1)
# TEMP_DS = train_dataset

for idx, (img, encoded_label, info) in enumerate(TEMP_DS):
    # Move the first tensor axis to the end before handing the array to OpenCV.
    img = img.permute(1, 2, 0).numpy()
    title = str(info[1])

    ax = fig.add_subplot(nrows, ncols, idx + 1)
    ax.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB), aspect=1)
    ax.set_title(title, fontsize=15)

    if not idx <= PLOT_SOFT_LIMIT:
        print('[!] plot soft limit reached. Breaking.')
        break

plt.tight_layout()
plt.show()

In [None]:
from torch.utils.data import DataLoader

# Loader settings: large batches and worker processes in Colab, small
# batches loaded in the main process elsewhere.
batch_size, num_workers = (896, 2) if IN_COLAB else (56, 0)


def getDataLoaderFromDataset(dataset, shuffle=False, drop_last=True):
    """Wrap ``dataset`` in a DataLoader.

    Batch size and worker count come from the module-level ``batch_size``
    and ``num_workers``; pinned memory is always requested.
    """
    return DataLoader(
        dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=True,
        shuffle=shuffle,
        drop_last=drop_last,
    )


train_loader = getDataLoaderFromDataset(train_dataset, shuffle=True)

In [None]:
def saveCheckpoint(model, scheduler, optimizer, epoch, filename):
    """Serialize the training state to ``filename`` with ``torch.save``.

    The stored dict has the keys ``'epoch'``, ``'model'``, ``'optimizer'``
    and ``'scheduler'``. For every component passed as ``None`` the stored
    value is ``None``; otherwise it is the component's ``state_dict()``.

    Args:
        model: Object with ``state_dict()`` or ``None``.
        scheduler: Object with ``state_dict()`` or ``None``.
        optimizer: Object with ``state_dict()`` or ``None``.
        epoch: Epoch number, stored as given (``0`` is kept as ``0``).
        filename: Path or file-like object accepted by ``torch.save``.
    """
    def _state_or_none(component):
        # Test identity rather than truthiness, so that a falsy but valid
        # object (for example an empty container module) is still saved.
        return None if component is None else component.state_dict()

    torch.save({
        'epoch': epoch,
        'model': _state_or_none(model),
        'optimizer': _state_or_none(optimizer),
        'scheduler': _state_or_none(scheduler),
    }, filename)

def loadCheckpoint(model, scheduler, optimizer, filename):
    """Restore training state from a checkpoint dict stored at ``filename``.

    The checkpoint is expected to contain the keys ``'model'`` and
    ``'epoch'`` and may contain ``'optimizer'`` and ``'scheduler'``.
    Optimizer or scheduler state is restored only when both the object
    passed in and its stored state are not ``None``, so a checkpoint saved
    without a scheduler can still be loaded.

    Args:
        model: Module whose ``load_state_dict`` receives ``checkpoint['model']``.
        scheduler: Scheduler to restore, or ``None``.
        optimizer: Optimizer to restore, or ``None``.
        filename: Path or file-like object accepted by ``torch.load``.

    Returns:
        Tuple ``(model, optimizer, scheduler, epoch)``.
    """
    # Load tensors onto the CPU first so that a checkpoint written on a GPU
    # machine can be read anywhere; load_state_dict then copies the values
    # into the already-placed parameters.
    checkpoint = torch.load(filename, map_location='cpu')
    model.load_state_dict(checkpoint['model'])

    for component, key in ((optimizer, 'optimizer'), (scheduler, 'scheduler')):
        saved_state = checkpoint.get(key)
        if component is not None and saved_state is not None:
            component.load_state_dict(saved_state)

    epoch = checkpoint['epoch']
    return model, optimizer, scheduler, epoch

In [None]:
from torchvision.models import resnet

def create_encoder(emb_dim):
    """Build a pretrained ResNet-18 whose last layer outputs ``emb_dim`` features."""
    backbone = resnet.resnet18(pretrained=True)
    # Swap the final fully connected layer for the embedding projection.
    backbone.fc = nn.Linear(512, emb_dim, bias=True)
    return backbone


encoder = create_encoder(1024)
# encoder

In [None]:
from pytorch_metric_learning import distances, losses, miners, reducers, testers
from pytorch_metric_learning.utils.accuracy_calculator import AccuracyCalculator

from pytorch_metric_learning.utils import common_functions as c_f
from tqdm.notebook import trange, tqdm

# Training hyper-parameters.
config = {
    'lr': 0.1,
    'epochs': 500,
    'momentum': 0.937,
    'margin': 0,
}

optimizer = torch.optim.SGD(
    encoder.parameters(),
    lr=config['lr'],
    momentum=config['momentum'],
    nesterov=True,
)
# encoder.to('cpu')
# optimizer = torch.optim.Adam(encoder.parameters(), config['lr'])
# scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer,
#                                              base_lr=0.00001,
#                                              max_lr=config['lr'],
#                                              step_size_up=50,
#                                              step_size_down=20,
#                                              mode="exp_range",
#                                              gamma=0.9,
#                                              cycle_momentum=False
#                                            )

# Metric-learning components: the miner and the loss share one distance
# object and one margin value.
distance = distances.LpDistance()
reducer = reducers.AvgNonZeroReducer()
loss_func = losses.TripletMarginLoss(margin=config['margin'], distance=distance, reducer=reducer)

# mining_func = miners.MultiSimilarityMiner(epsilon=0.1)
mining_func = miners.TripletMarginMiner(margin=config['margin'], distance=distance, type_of_triplets="hard")

accuracy_calculator = AccuracyCalculator(k=5)
    # include=("precision_at_1",
    #          "mean_average_precision_at_r"), k=1)

try:
    # encoder, optimizer, scheduler, started_epoch = loadCheckpoint(encoder, scheduler, optimizer, 'sample')
    # With the load above commented out, the bare name below raises
    # NameError unless `started_epoch` already exists in the session.
    started_epoch
    print('[+] check point loaded')
except (NameError, FileNotFoundError):
    # Catch only "nothing to resume from" (no previous value, or a missing
    # checkpoint file once loading is enabled); other errors propagate.
    started_epoch = 0
    print('[!] check point doesnt exist')

encoder.to(device)


### convenient function from pytorch-metric-learning ###
@torch.no_grad()
def simpleGetAllEmbeddings(model, dataset, batch_size, dsc=''):
    """Run ``model`` over every item of ``dataset`` and collect the outputs.

    Batches come from ``getDataLoaderFromDataset(dataset, shuffle=True,
    drop_last=False)``; each batch is moved to the module-level ``device``
    before the forward pass. The function runs under ``torch.no_grad()``.

    Args:
        model: Callable applied to each batch of images.
        dataset: Dataset yielding ``(data, labels, info)`` items.
        batch_size: Not used in this body; the loader is configured by
            ``getDataLoaderFromDataset``.
        dsc: Text appended to the progress-bar description.

    Returns:
        Tuple ``(all_q, labels_ret, info_arr)``:
        ``all_q`` - tensor with ``len(dataset)`` rows holding the model
        outputs after ``torch.nn.functional.normalize`` (default arguments);
        ``labels_ret`` - tensor with ``len(dataset)`` rows of labels;
        ``info_arr`` - list with one metadata entry per processed item.

    NOTE(review): if the loader yields no batches, ``all_q`` is never
    assigned and the final ``normalize`` call fails.
    """
    
    dataloader = getDataLoaderFromDataset(
        dataset,
        shuffle=True,
        drop_last=False
    )
    
    # s/e are the start/end row offsets of the current batch in the output buffers.
    s, e = 0, 0
    pbar = tqdm(
        enumerate(dataloader), 
        total=len(dataloader),
        position=0,
        leave=False,
        desc='Getting all embeddings...' + dsc)
    info_arr = []
    
    # Number of metadata fields per item; taken from the first batch.
    add_info_len = None
    
    for idx, (data, labels, info) in pbar:
        data = data.to(device)
        
        q = model(data)
        
        # Labels are stored as a column, so 1-D labels get a second axis.
        if labels.dim() == 1:
            labels = labels.unsqueeze(1)
        # Allocate the output buffers once the per-item sizes are known.
        if idx == 0:
            labels_ret = torch.zeros(
                len(dataloader.dataset),
                labels.size(1),
                device=device,
                dtype=labels.dtype,
            )
            all_q = torch.zeros(
                len(dataloader.dataset),
                q.size(1),
                device=device,
                dtype=q.dtype,
            )
        
        # presumably `info` arrives collated as one sequence per metadata
        # field (fields x batch) - TODO confirm against the DataLoader collate
        info = np.array(info)
        if add_info_len == None:
            add_info_len = info.shape[0]
        
        # Transpose so that each appended row belongs to one item.
        info_arr.extend(info.T.reshape((-1, add_info_len)))
        e = s + q.size(0)
        all_q[s:e] = q
        labels_ret[s:e] = labels
        s = e  
    
    all_q = torch.nn.functional.normalize(all_q)
    return all_q, labels_ret, info_arr

### compute accuracy using AccuracyCalculator from pytorch-metric-learning ###
def test(train_set, test_set, model, accuracy_calculator, batch_size):
    """Evaluate ``model`` on ``test_set`` using ``train_set`` as the reference set.

    The model is switched to eval mode, embeddings for both sets are computed
    with ``simpleGetAllEmbeddings``, the accuracy dict is printed and its
    ``"precision_at_1"`` entry is returned.
    """
    model.eval()

    reference_emb, reference_lbl, _ = simpleGetAllEmbeddings(model, train_set, batch_size, ' for train')
    query_emb, query_lbl, _ = simpleGetAllEmbeddings(model, test_set, batch_size, ' for test')

    # Labels come back as single-column tensors; flatten them to 1-D.
    metrics = accuracy_calculator.get_accuracy(
        query_emb,
        reference_emb,
        query_lbl.squeeze(1),
        reference_lbl.squeeze(1),
        False,
    )
    print(metrics)
    return metrics["precision_at_1"]
    
    
### MNIST code originally from https://github.com/pytorch/examples/blob/master/mnist/main.py ###
def train(model, loss_func, mining_func, device, train_loader, optimizer, epoch):
    """Train ``model`` for one pass over ``train_loader``.

    For each batch the embeddings are computed and the loss is evaluated:
    directly on ``(embeddings, labels)`` when ``loss_func`` is a
    ``losses.CentroidTripletLoss``, otherwise on the tuples produced by
    ``mining_func``. The loss is then back-propagated and the optimizer
    takes a step.

    Args:
        model: Encoder being trained.
        loss_func: Metric-learning loss.
        mining_func: Miner called as ``mining_func(embeddings, labels)``.
        device: Device the batches are moved to.
        train_loader: Loader yielding ``(data, labels, info)`` batches.
        optimizer: Optimizer stepping the model parameters.
        epoch: Not used in this body.

    Returns:
        Sum of the per-batch loss values divided by
        ``train_loader.batch_size * len(train_loader)``.
    """
    model.train()
    loss_sum = 0

    pbar = tqdm(
        enumerate(train_loader),
        total=len(train_loader),
        position=0,
        leave=False,
        desc='WAITING...')

    uses_centroid_loss = isinstance(loss_func, losses.CentroidTripletLoss)
    uses_multi_similarity_miner = isinstance(mining_func, miners.MultiSimilarityMiner)

    for batch_idx, (data, labels, _) in pbar:
        data, labels = data.to(device), labels.to(device)
        optimizer.zero_grad()
        embeddings = model(data)

        if uses_centroid_loss:
            # Feed the model output to the loss as-is: wrapping it in a new
            # leaf tensor would cut it out of the autograd graph, so no
            # gradient would reach the model. No mined tuples are passed
            # on this path.
            loss = loss_func(embeddings, labels)
        else:
            indices_tuple = mining_func(embeddings, labels)
            loss = loss_func(embeddings, labels, indices_tuple)

        instant_loss = loss.item()
        loss_sum += instant_loss

        loss.backward()
        optimizer.step()

        mean_loss = round(instant_loss / len(labels), 3)
        if uses_centroid_loss or uses_multi_similarity_miner:
            pbar.set_description("TRAIN: INSTANT MEAN LOSS %f" % (mean_loss,))
        else:
            # The triplet counter is read only on this path.
            pbar.set_description(
                "TRAIN: INSTANT MEAN LOSS %f, MINED TRIPLET: %d"
                % (mean_loss, mining_func.num_triplets)
            )

    return loss_sum / (train_loader.batch_size * len(train_loader))

torch.cuda.empty_cache()

# Epoch-level progress bar; starts from `started_epoch`.
pbar = trange(
    started_epoch,
    config['epochs'],
    initial=started_epoch,
    total=config['epochs'],
    leave=True,
    desc='WAITING FOR FIRST EPOCH END...')

# Shown as -1 until the first validation run has finished.
mean_acc = -1

for epoch in pbar:

    # plotSmth(encoder, train_dataset, device=device, dim3=False, fcn='umap')
    train_loss = train(encoder, loss_func, mining_func, device, train_loader, optimizer, epoch)

    # Validation runs only on every 10th epoch.
    if (epoch + 1) % 10 == 0:
        mean_acc = test(train_dataset, valid_dataset, encoder, accuracy_calculator, batch_size)

    # lr_val = scheduler.get_last_lr()[0]
    # saveCheckpoint(encoder, scheduler, optimizer, epoch, 'sample')
    # plotSmth(encoder, CONST_MINIMAL_DATASET, device=device, dim3=False, fcn='umap')
    # scheduler.step()

    mean_train_acc = mean_valid_acc = 0
    # Report the learning rate the optimizer actually uses, not a placeholder.
    lr_val = optimizer.param_groups[0]['lr']
    saveCheckpoint(encoder, None, optimizer, epoch, 'last_encoder')
    pbar.set_description("PER %d EPOCH: TRAIN LOSS: %.1e; VALID ACCUR: %.4f, LR %.1e" % (
        epoch + 1,
        train_loss,
        mean_acc,
        lr_val)
    )

In [None]:
# Build a DataFrame describing the extra sign images in `additional_sign`.
# The sign name is the part of the file name before the first underscore.
additional_dir = DATA_DIR / 'additional_sign'

# New labels start right after the largest label already used in RTDS_DF.
encode_offset = max(set(RTDS_DF['ENCODED_LABEL'])) + 1
# Sorted so that row order and label numbering are the same on every run.
files = sorted(os.listdir(additional_dir))

sign_list = sorted({x.split('_')[0] for x in files})
sign_to_label = {name: encode_offset + position for position, name in enumerate(sign_list)}

rows = []
for file in files:
    sign = file.split('_')[0]
    encoded_label = sign_to_label[sign]

    row = {'filepath': str(additional_dir / file), 'SIGN': sign, 'ENCODED_LABEL': encoded_label, 'SET': 'valid'}
    rows.append(row)

# Build the frame in one step; DataFrame.append no longer exists in pandas 2.x.
additional_DF = pd.DataFrame(rows, columns=RTDS_DF.columns)

display(additional_DF)
additional_dataset = SignDataset(
    additional_DF,
    transform=minimal_transform
)

add_dataset_dict = dict(zip(additional_DF.ENCODED_LABEL, additional_DF.SIGN))

# Preview the additional images.
IMG_COUNT = 18
nrows, ncols = 70, 6
fig = plt.figure(figsize=(16, 200))

PLOT_SOFT_LIMIT = 20

for idx, (img, encoded_label, info) in enumerate(additional_dataset):

    # Move the first tensor axis to the end before handing the array to OpenCV.
    img = torch.Tensor.permute(img, [1, 2, 0]).numpy()
    ax = fig.add_subplot(nrows, ncols, idx + 1)

    ax.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB), aspect=1)

    title = str(info[1])

    ax.set_title(title, fontsize=15)

    if idx > PLOT_SOFT_LIMIT:
        print('[!] plot soft limit reached. Breaking.')
        break
plt.tight_layout()

# Label -> sign name mapping covering both the main and the additional data.
label_dict = dict(zip(RTDS_DF.ENCODED_LABEL, RTDS_DF.SIGN))
label_dict.update(add_dataset_dict)

In [None]:
@torch.no_grad()
def getInfoForFig(
    model,
    dataset,
    batch_size,
    additional_dataset,
    main_dataset_marker_size=10,
    additional_dataset_marker_size=20,
    dot_limit=1000):
    """Collect embeddings, labels, metadata and marker sizes for plotting.

    ``dataset`` is randomly subsampled to ``dot_limit`` items when it is
    larger than that. When ``additional_dataset`` is truthy, its results are
    appended after those of the main dataset, with their own marker size.

    Returns:
        Tuple ``(embeddings, labels, info, size)`` where ``embeddings``,
        ``labels`` and ``size`` are NumPy arrays (``labels`` and ``size`` as
        single columns) and ``info`` is the list of metadata entries.
    """

    def _collect(source, marker_size, dsc):
        # Embed one dataset and convert everything to NumPy on the CPU.
        emb, lbl, meta = simpleGetAllEmbeddings(model, source, batch_size, dsc=dsc)
        emb = emb.cpu().numpy()
        lbl = lbl.cpu().numpy().flatten()[:, None]
        return emb, lbl, meta, np.ones(lbl.shape) * marker_size

    model.eval()

    if len(dataset) > dot_limit:
        print("[!] Dot limit! Random choice", dot_limit, '\nSrc len', len(dataset))
        indicies = np.random.choice(len(dataset), dot_limit, replace=False)
        dataset = torch.utils.data.Subset(dataset, indicies)

    embeddings, labels, info, size = _collect(
        dataset, main_dataset_marker_size, 'for main dataset')

    if not additional_dataset:
        return embeddings, labels, info, size

    addon_emb, addon_lbl, addon_info, addon_size = _collect(
        additional_dataset, additional_dataset_marker_size, 'for addon')

    info.extend(addon_info)
    return (
        np.concatenate((embeddings, addon_emb)),
        np.concatenate((labels, addon_lbl)),
        info,
        np.concatenate((size, addon_size)),
    )


embeddings, labels, info, size = getInfoForFig(
    encoder,
    train_dataset,
    batch_size,
    additional_dataset,
    dot_limit=3000,
)

In [None]:
import plotly.graph_objects as go
import plotly.express as px
from itertools import cycle

def getFigForModelAndDataset(
    embeddings,
    labels,
    info,
    size,
    reducer,
    dsc='',
    label_dict=None,
    FORCE_USE_WO_FIT=False):
    """Reduce ``embeddings`` with ``reducer`` and build a plotly scatter figure.

    One trace is added per distinct group. Groups whose first marker size
    equals the smallest size in the data are drawn as circles, all others
    as diamonds with a thicker outline.

    Args:
        embeddings: Array passed to ``reducer.fit_transform`` / ``transform``.
        labels: Single-column array of labels, one row per embedding.
        info: Sequence of entries where ``x[0]`` and ``x[1]`` are strings;
            the hover text is ``x[1] + ':' + x[0]``.
        size: Single-column array of marker sizes.
        reducer: Object with ``n_components`` and ``fit_transform`` /
            ``transform``; a 3-D figure is built when ``n_components == 3``,
            a 2-D one otherwise.
        dsc: Not used in this body.
        label_dict: Optional mapping ``int(label) -> group name``. When a
            lookup fails, a message is printed and raw labels are used.
        FORCE_USE_WO_FIT: When true, ``reducer.transform`` is called instead
            of ``reducer.fit_transform``.

    Returns:
        Tuple ``(fig, reducer, plot_df)``.
    """
    palette = cycle(
        [*px.colors.qualitative.Dark24,
         *px.colors.qualitative.Alphabet,
         *px.colors.qualitative.Light24]
    )

    dim3 = reducer.n_components == 3
    axes = ['x', 'y', 'z'] if dim3 else ['x', 'y']

    if FORCE_USE_WO_FIT:
        X_embedded = reducer.transform(embeddings)
    else:
        X_embedded = reducer.fit_transform(embeddings)

    group = labels
    if label_dict:
        try:
            # Convert to an array BEFORE adding the column axis; indexing a
            # plain list with [:, None] raises TypeError.
            group = np.array(
                [label_dict[int(x)] for x in np.asarray(labels).ravel()]
            )[:, None]
        except (KeyError, TypeError, ValueError) as e:
            print('label dict broken', e)
            group = labels
    group = np.array(group)

    hover_data = np.array([x[1] + ':' + x[0] for x in info])[:, None]

    # Coordinates, group, size and hover text become the columns of one frame.
    # Concatenation with the string columns turns the numbers into strings,
    # so the numeric columns are converted back below.
    plot_df_data = np.concatenate([X_embedded, group, size, hover_data], axis=1)
    plot_df = pd.DataFrame(plot_df_data, columns=[*axes, 'group', 'size', 'hover_data'])
    plot_df['size'] = plot_df['size'].apply(pd.to_numeric)
    plot_df[axes] = plot_df[axes].apply(pd.to_numeric)

    fig = go.Figure()

    groups = sorted(plot_df['group'].unique())
    main_dataset_marker_size = min(plot_df['size'])

    trace_cls = go.Scatter3d if dim3 else go.Scatter
    # Only the 3-D traces set an explicit trace-level opacity.
    trace_extra = {'opacity': 1} if dim3 else {}

    for group_name in groups:
        df = plot_df.loc[plot_df['group'] == group_name]
        is_main = df['size'].iloc[0] == main_dataset_marker_size
        symbol = 'circle' if is_main else 'diamond'
        line_width = 1 if is_main else 2
        marker_color = next(palette)

        fig.add_trace(trace_cls(
            **{axis: df[axis] for axis in axes},
            mode='markers',
            marker=dict(
                size=df['size'],
                opacity=1,
                symbol=symbol,
                line=dict(
                    color='black',
                    width=line_width,
                ),
            ),
            text=df['hover_data'],
            name=group_name,
            marker_color=marker_color,
            **trace_extra,
        ))

    return fig, reducer, plot_df

In [None]:
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
from umap import UMAP

# Choose between a 3-D and a 2-D PCA projection of the embeddings.
dim3 = True

reducer = PCA(
    n_components=3 if dim3 else 2,
    # init='random',
    random_state=RANDOM_STATE,
)

fig, reducer, plot_df = getFigForModelAndDataset(
    embeddings, labels, info, size, reducer=reducer, label_dict=label_dict
)

# Turn off plotly's built-in hover box, then fix the figure size.
fig.update_traces(hoverinfo="none", hovertemplate=None)
fig.update_layout(width=900, height=800)

from jupyter_dash import JupyterDash
from dash import dcc, html, Input, Output, no_update
import base64

def b64_image(image_filename):
    """Return the file's bytes as a ``data:image/png;base64,...`` URI string."""
    with open(image_filename, 'rb') as fh:
        encoded = base64.b64encode(fh.read())
    return 'data:image/png;base64,' + encoded.decode('utf-8')

app = JupyterDash(__name__, assets_folder='sdd')


@app.callback(
    Output("graph-tooltip-5", "show"),
    Output("graph-tooltip-5", "bbox"),
    Output("graph-tooltip-5", "children"),
    Input("graph-5", "hoverData"),
)
def display_hover(hoverData):
    """Build the tooltip for the point currently hovered in ``graph-5``.

    Args:
        hoverData: Hover payload of the graph, or ``None`` when nothing
            is hovered.

    Returns:
        Tuple ``(show, bbox, children)`` for the tooltip component. When
        ``hoverData`` is ``None`` the tooltip is hidden and the other two
        outputs are left unchanged.
    """
    if hoverData is None:
        return False, no_update, no_update

    hover_data = hoverData["points"][0]
    bbox = hover_data["bbox"]
    # The point text has the form "<sign>:<image path>". Split on the first
    # colon only, so that a path containing ':' stays intact.
    sign, rel_img_path = hover_data['text'].split(':', 1)
    b64sed_image = b64_image(rel_img_path)

    children = [
        html.Div([
            html.Img(
                src=b64sed_image,
                style={"width": "70px", 'display': 'block', 'margin': '0 auto'},
            ),
            html.P(sign, style={"fontSize": 14, 'text-align':'center'}),
            html.P(rel_img_path, style={"fontSize": 10}),
        ])
    ]
    return True, bbox, children

# Page layout: the scatter figure plus the tooltip filled in by the hover callback.
app.layout = html.Div(
    [
        dcc.Graph(id="graph-5", figure=fig, clear_on_unhover=True),
        dcc.Tooltip(id="graph-tooltip-5", direction='bottom'),
    ],
    className="container",
)

# Serve the app inline in the notebook.
app.run_server(mode='inline', debug=True, port=2002)

In [None]:
# NOTE(review): neither `data` nor `n_components` is assigned at module level
# in the cells visible here, so this cell would raise NameError unless they
# are defined somewhere not shown - confirm before running.
embeddings, labels, info = data[1:]
# Fit a UMAP projection (cosine metric, random initialisation, fixed seed).
reducer = UMAP(n_components=n_components, init='random', metric="cosine", random_state=RANDOM_STATE)
X_embedded = reducer.fit_transform(embeddings)

In [None]:
len(info)

In [None]:
t = ['2.1', '5.5', '3.24.40', '2.5', '3.20', '4.1.1', '5.19.1', '3.24.30', '1.33',
 '3.24.50', '5.16', '3.24.20', '1.22', '3.18']

In [None]:
t.sort()
t

In [None]:
# Rebuild the additional-sign DataFrame. Every image in this version gets
# the same label: one past the largest label used in RTDS_DF.
additional_dir = DATA_DIR / 'additional_sign'
encode_offset = max(set(RTDS_DF['ENCODED_LABEL'])) + 1

files = os.listdir(additional_dir)

rows = []
for file in files:
    # The sign name is the part of the file name before the first underscore.
    sign = file.split('_')[0]
    encoded_label = encode_offset
    row = {'filepath': str(additional_dir / file), 'SIGN': sign, 'ENCODED_LABEL': encoded_label, 'SET': 'valid'}
    rows.append(row)

# Build the frame in one step; DataFrame.append no longer exists in pandas 2.x.
additional_DF = pd.DataFrame(rows, columns=RTDS_DF.columns)
display(additional_DF)

# This dataset uses the augmenting `transform` pipeline.
additional_dataset = SignDataset(
    additional_DF,
    transform=transform
)

In [None]:
# Preview the additional-sign dataset.
nrows, ncols = 70, 6
PLOT_SOFT_LIMIT = 80

fig = plt.figure(figsize=(16, 200))
TEMP_DS = additional_dataset

for idx, (img, encoded_label, info) in enumerate(TEMP_DS):
    # Move the first tensor axis to the end before handing the array to OpenCV.
    img = img.permute(1, 2, 0).numpy()
    title = str(info[1])

    ax = fig.add_subplot(nrows, ncols, idx + 1)
    ax.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB), aspect=1)
    ax.set_title(title, fontsize=15)

    if not idx <= PLOT_SOFT_LIMIT:
        print('[!] plot soft limit reached. Breaking.')
        break

plt.tight_layout()
plt.show()
