## Цель ноутбука: изучение метода Few-Shot Learning

**Датасет доступен по [ссылке](https://drive.google.com/file/d/1-3_g0ZvGMJ8VBiCyRsqGM81Vt2d40sUh/view?usp=sharing).**

Проблемы со знаками решены так:

| Знак | Описание | Источник |
| ------------- | ------------- | ---- |
| 1.6 | Пересечение равнозначных дорог | Надеемся на удачу |
| 1.31 | Туннель | Надеемся на удачу |
| 2.4 | Уступите дорогу | GTSRB Recognition |
| 3.21 | Конец запрещения обгона | GTSRB Recognition |
| 3.22 | Обгон грузовым автомобилям запрещен | GTSRB Recognition |
| 3.23 | Конец запрещения обгона грузовым автомобилям | GTSRB Recognition |
| 3.24-90 | Огр 90 | Объединили |
| 3.24-100 | Огр 100 | Объединили |
| 3.24-110 | Огр 110 | Объединили |
| 3.24-120 | Огр 120 | Объединили |
| 3.24-130 | Огр 130 | Объединили |
| 3.25 | Конец огр. максимальной скорости | GTSRB Recognition |
| 3.31 | Конец всех ограничений | GTSRB Recognition |
| 6.3.2 | Зона для разворота | Надеемся на удачу |

Инициализация библиотек

In [None]:
# autoreload 
%load_ext autoreload
%autoreload 2

# core imports
import os
import sys
import random
from datetime import datetime
from pathlib import Path

import albumentations as A
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch import nn
import seaborn as sns
import pandas as pd
import cv2


# PROJECT_ROOT
PROJECT_ROOT = Path(os.readlink(f'/proc/{os.environ["JPY_PARENT_PID"]}/cwd'))
DATA_DIR = PROJECT_ROOT / 'SignDetectorAndClassifier' / 'data'

# Зафиксируем состояние случайных чисел
RANDOM_STATE = 42
np.random.seed(RANDOM_STATE)
torch.manual_seed(RANDOM_STATE)
random.seed(RANDOM_STATE)
%matplotlib inline
plt.rcParams["figure.figsize"] = (17,10)

device = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu')
device

Инициализация путей к датасету и загрузка таблицы разметки

In [None]:
DATASET_PREFIX = DATA_DIR / 'ENCODER_DATASET'

# Load the encoder annotation table, turn the relative image paths into
# absolute ones, drop repeated images and renumber the rows 0..n-1.
RTDS_DF = (
    pd.read_csv(DATASET_PREFIX / 'WIDE_DATASET_4_ENCODER.csv')
    .assign(filepath=lambda frame: frame['filepath'].map(
        lambda rel_path: str(DATASET_PREFIX / rel_path)))
    .drop_duplicates(subset=['filepath'])
    .reset_index(drop=True)
)

# Disabled experiments, kept for reference:
#   half-sample every (sign, set) group:
#     RTDS_DF = RTDS_DF.groupby(['sign', 'set']).apply(lambda x: x.sample(frac=0.5))
#   drop the extra (non-RTSD) signs:
#     RTDS_DF = RTDS_DF[RTDS_DF['filepath'].str.contains('rtsd')]
#   keep only TARGET_SIGNS:
#     RTDS_DF = RTDS_DF[RTDS_DF['sign'].isin(TARGET_SIGNS)]

# Signs of interest; only referenced by the disabled filter above.
TARGET_SIGNS = [
    '1.1', '1.6', '1.8', '1.22', '1.31', '1.33',
    '2.1', '2.2', '2.3', '2.4', '2.5',
    '3.1', '3.18', '3.20', '3.21', '3.22', '3.23', '3.24',
    '3.25', '3.27', '3.28', '3.31',
    '4.1.1', '4.3',
    '5.5', '5.6', '5.16',
    '5.19.1', '5.20',
    '6.3.2', '6.4',
    '7.3', '7.4',
]

RTDS_DF

In [None]:
from maddrive_adas.utils.models import get_model_and_img_size
from maddrive_adas.utils.transforms import get_minimal_and_augment_transforms
from maddrive_adas.utils.datasets import SignDataset

# Build the encoder described by the JSON config and move it to the device.
encoder, img_size = get_model_and_img_size(DATA_DIR / 'encoder_config.json')
encoder = encoder.to(device)

minimal_transform, augment_transform = get_minimal_and_augment_transforms(img_size)

# Both splits come from the same table and share identical settings (the
# "minimal" transform; augment_transform is not used here). Only the
# set_label differs; train is constructed first, then valid.
train_dataset, valid_dataset = (
    SignDataset(
        RTDS_DF,
        set_label=split_name,
        transform=minimal_transform,
        hyp=None,
        alpha_color=144,
    )
    for split_name in ('train', 'valid')
)

In [None]:
def getNSamplesFromDataSet(ds, N):
    """Return ``N`` items of ``ds`` taken at distinct random indices.

    Uses the module-level ``random`` generator, so the picks are reproducible
    after ``random.seed``. ``random.sample`` raises ``ValueError`` when ``N``
    is negative or larger than ``len(ds)``.
    """
    picked_indices = random.sample(range(len(ds)), N)
    return [ds[i] for i in picked_indices]

IMG_COUNT = 18  # NOTE(review): not referenced anywhere in this notebook
nrows, ncols = 70, 6  # maximum preview grid (rows x columns)

# Number of random training samples to preview; 0 disables the preview.
PLOT_SOFT_LIMIT = 0  # skip

# Never ask for more samples than the grid can hold or the dataset contains:
# add_subplot / random.sample would raise otherwise.
_n_preview = max(0, min(PLOT_SOFT_LIMIT, nrows * ncols, len(train_dataset)))
TEMP_DS = getNSamplesFromDataSet(train_dataset, _n_preview)

if TEMP_DS:
    # Allocate only the grid rows actually needed while keeping the original
    # per-row height (200 / 70 inches), instead of a 16x200-inch canvas
    # regardless of how many images are shown.
    _rows_used = (len(TEMP_DS) + ncols - 1) // ncols
    fig = plt.figure(figsize=(16, 200 * _rows_used / nrows))
    for idx, (img, encoded_label, info) in enumerate(TEMP_DS):
        # assumes img is a CHW tensor; imshow wants HWC
        ax = fig.add_subplot(_rows_used, ncols, idx + 1)
        ax.imshow(img.permute(1, 2, 0).numpy(), aspect=1)
        ax.set_title(str(info[1]), fontsize=15)
    plt.tight_layout()
    plt.show()

In [None]:
from maddrive_adas.utils.datasets import get_dataloader_from_dataset


batch_size = 512  # HANDLE PARAM
num_workers = 16  # HANDLE PARAM

# Shuffled training loader; drop_last=True (presumably forwarded to the
# underlying DataLoader) discards the incomplete final batch.
train_loader = get_dataloader_from_dataset(
    train_dataset,
    batch_size=batch_size,
    num_workers=num_workers,
    shuffle=True,
    drop_last=True,
)

In [None]:
@torch.no_grad()
def simpleGetAllEmbeddings(model, dataset, batch_size, num_workers, dsc=''):
    """Embed every item of ``dataset`` with ``model`` and L2-normalise the rows.

    The dataset is read through a non-shuffled, non-dropping dataloader, so
    (presumably) row ``i`` of the result corresponds to item ``i``.

    Args:
        model: Encoder applied to each batch; its output is treated as 2-D
            ``(batch, embedding_dim)`` (``q.size(1)`` is read).
        dataset: Dataset whose items unpack as ``(data, label, info)``.
        batch_size: Batch size of the temporary dataloader.
        num_workers: Worker count of the temporary dataloader.
        dsc: Suffix appended to the progress-bar description and error text.

    Returns:
        ``(embeddings, labels, info_arr)``. ``embeddings`` has one
        L2-normalised row per item; ``labels`` is 1-D when per-item labels
        are scalars. Both tensors live on the global ``device``.
        ``info_arr`` is always an empty list (kept for interface compatibility).

    Raises:
        ValueError: If the dataloader yields no batches (e.g. empty dataset).
    """
    dataloader = get_dataloader_from_dataset(
        dataset,
        shuffle=False,
        drop_last=False,
        batch_size=batch_size,
        num_workers=num_workers
    )
    n_items = len(dataloader.dataset)

    pbar = tqdm(
        dataloader,
        total=len(dataloader),
        position=0,
        leave=False,
        desc='Getting all embeddings...' + dsc)
    info_arr = []  # never filled; returned only for backward compatibility

    all_q = labels_ret = None
    start = 0
    for data, label, _info in pbar:
        q = model(data.to(device))

        # Store labels as (batch, k) so scalar and vector labels share one path.
        if label.dim() == 1:
            label = label.unsqueeze(1)
        if all_q is None:
            # Allocate the output buffers once, sized from the first batch.
            labels_ret = torch.zeros(
                n_items,
                label.size(1),
                device=device,
                dtype=label.dtype,
            )
            all_q = torch.zeros(
                n_items,
                q.size(1),
                device=device,
                dtype=q.dtype,
            )
        end = start + q.size(0)
        all_q[start:end] = q
        labels_ret[start:end] = label
        start = end

    if all_q is None:
        # Without this guard the code below fails with UnboundLocalError.
        raise ValueError('no samples to embed' + dsc)

    all_q = torch.nn.functional.normalize(all_q)
    return all_q, labels_ret.squeeze(1), info_arr
    
### compute accuracy using AccuracyCalculator from pytorch-metric-learning ###
@torch.no_grad()
def test(train_set, test_set, model, accuracy_calculator, batch_size, num_workers):
    """Return precision@1 of ``test_set`` queries retrieved against ``train_set``.

    ``model`` is switched to eval mode and used to embed both sets; the
    validation embeddings act as queries and the training embeddings as the
    reference set. The full metrics dict is printed before returning.
    """
    model.eval()
    loader_args = (batch_size, num_workers)
    ref_emb, ref_labels, _ = simpleGetAllEmbeddings(
        model, train_set, *loader_args, ' for train')
    query_emb, query_labels, _ = simpleGetAllEmbeddings(
        model, test_set, *loader_args, ' for valid')

    # NOTE(review): this positional order (query, reference, query_labels,
    # reference_labels, embeddings_come_from_same_source) follows the 1.x
    # get_accuracy signature; 2.x reordered it — confirm installed version.
    metrics = accuracy_calculator.get_accuracy(
        query_emb, ref_emb, query_labels, ref_labels, False
    )
    print(metrics)
    return metrics["precision_at_1"]

### MNIST code originally from https://github.com/pytorch/examples/blob/master/mnist/main.py ###
def train(model, loss_func, mining_func, device, train_loader, optimizer, epoch):
    """Run one training epoch of metric learning with online triplet mining.

    Args:
        model: Encoder producing embeddings for a batch.
        loss_func: Called as ``loss_func(embeddings, labels, indices_tuple)``;
            expected to return a scalar already reduced over the batch (this
            notebook uses ``AvgNonZeroReducer``).
        mining_func: Called as ``mining_func(embeddings, labels)``; its
            ``num_triplets`` attribute is shown in the progress bar.
        device: Device the batches are moved to.
        train_loader: Iterable of ``(data, labels, info)`` batches.
        optimizer: Optimizer stepped once per batch.
        epoch: Current epoch index (unused; kept for interface compatibility).

    Returns:
        Mean of the per-batch losses over the epoch, or 0.0 when the loader
        yields no batches.
    """
    model.train()
    loss_sum = 0.0
    n_batches = 0

    pbar = tqdm(
        train_loader,
        total=len(train_loader),
        position=0,
        leave=False,
        desc='WAITING...')

    for data, labels, _ in pbar:
        data, labels = data.to(device), labels.to(device)
        optimizer.zero_grad(set_to_none=True)
        embeddings = model(data)

        indices_tuple = mining_func(embeddings, labels)
        loss = loss_func(embeddings, labels, indices_tuple)

        loss.backward()
        optimizer.step()

        # The loss is already a batch mean, so it is accumulated per batch and
        # averaged over batches (previously it was divided by the batch size /
        # dataset length, under-reporting it by roughly the batch size).
        instant_loss = loss.item()
        loss_sum += instant_loss
        n_batches += 1

        pbar.set_description(
            f'TRAIN: INSTANT MEAN LOSS \t{round(instant_loss, 5)}, MINED TRIPLET: \t{mining_func.num_triplets} \t'
        )

    return loss_sum / n_batches if n_batches else 0.0

In [None]:
from torch.utils.tensorboard import SummaryWriter
# TensorBoard event files for this run are written to <DATA_DIR>/runs/encoder4.
writer = SummaryWriter(log_dir=str(DATA_DIR / 'runs' / 'encoder4'))

In [None]:
from pytorch_metric_learning import distances, losses, miners, reducers
from pytorch_metric_learning.utils.accuracy_calculator import AccuracyCalculator

from torch.optim import lr_scheduler
from tqdm.notebook import trange, tqdm

config = {
    'lr': 0.1,
    'epochs': 100,
    'momentum': 0.937,
    'margin': 0.05,
}

# NOTE(review): per the PyTorch CyclicLR docs, the scheduler below resets the
# LR to base_lr on construction and, with cycle_momentum=True, drives momentum
# between its own base_momentum/max_momentum, so config['lr'] and
# config['momentum'] are effectively overridden — confirm this is intended.
optimizer = torch.optim.SGD(
    encoder.parameters(),
    lr=config['lr'],
    momentum=config['momentum'],
    nesterov=True
)

# Cyclic LR schedule; it is stepped once per epoch at the end of the loop.
scheduler = lr_scheduler.CyclicLR(
    optimizer,
    base_lr=0.001,
    max_lr=0.2,
    step_size_up=3,
    step_size_down=4,
    mode="exp_range",
    gamma=0.95,
    cycle_momentum=True
)

from maddrive_adas.utils.checkpoint import save_checkpoint, load_checkpoint

# Best-effort resume from the rolling 'last_encoder' checkpoint; any failure
# (e.g. the file does not exist yet) falls back to training from epoch 0.
try:
    # raise Exception  # uncomment to force training from scratch
    encoder, optimizer, scheduler, started_epoch = load_checkpoint(
        encoder, scheduler, optimizer, str(DATA_DIR / 'last_encoder'))
    print('[+] Checkpoint loaded')

except Exception as exc_obj:
    started_epoch = 0
    print(f'[!] cannot load checkpoint: {exc_obj}')
# NOTE(review): save_checkpoint below stores the epoch that just finished;
# confirm whether load_checkpoint returns that value or value + 1, otherwise
# the last completed epoch is trained again after a resume.

# Triplet loss with "hard" online mining, both using the same Lp distance and margin.
distance = distances.LpDistance()
reducer = reducers.AvgNonZeroReducer()
loss_func = losses.TripletMarginLoss(
    margin=config['margin'],
    distance=distance,
    reducer=reducer
)
mining_func = miners.TripletMarginMiner(
    margin=config['margin'],
    distance=distance,
    type_of_triplets="hard"
)

accuracy_calculator = AccuracyCalculator(
    device=torch.device('cpu'),
    include=("precision_at_1",), k=1
)

pbar = trange(
    started_epoch,
    config['epochs'],
    initial=started_epoch,
    total=config['epochs'],
    leave=True,
    desc='WAITING FOR FIRST EPOCH END...')

MODEL_PREFIX = 'ALL_AVAILABLE_SIGNS'
for epoch in pbar:
    train_loss = train(
        encoder,
        loss_func,
        mining_func,
        device,
        train_loader,
        optimizer,
        epoch
    )
    mean_acc = test(
        train_dataset,
        valid_dataset,
        encoder,
        accuracy_calculator,
        batch_size,
        num_workers
    )

    # Per-epoch snapshot (name embeds loss/accuracy/epoch) plus the rolling
    # 'last_encoder' checkpoint used for resuming.
    iter_checkpoint_filename = str(DATA_DIR / (
        f'{MODEL_PREFIX}encoder_loss_{round(train_loss, 5)}'
        f'_acc_{round(mean_acc, 5)}epoch_{epoch}.encoder'
    ))
    save_checkpoint(encoder, scheduler, optimizer, epoch, iter_checkpoint_filename)
    save_checkpoint(encoder, scheduler, optimizer, epoch, str(DATA_DIR / 'last_encoder'))

    # Record the LR used during this epoch, then advance the schedule.
    lr_val = scheduler.get_last_lr()[0]
    scheduler.step()

    writer.add_scalar('mean valid accuracy', mean_acc, epoch)
    # Tag spelling kept as-is so the curve stays continuous with earlier runs.
    writer.add_scalar('traineng loss', train_loss, epoch)
    writer.add_scalar('learning rate', lr_val, epoch)

    pbar.set_description(
        f'PER {epoch} EPOCH: TRAIN LOSS: {train_loss:.4f}; '
        f'VALID ACCUR: {mean_acc:.4f}, LR {lr_val:.5f}'
    )

# Push any buffered scalars to disk now instead of waiting for the writer's
# periodic flush (the writer stays open for later use).
writer.flush()

In [None]:
# Deliberate stop: raises AssertionError('END') so "Run All" halts here.
# Note: assert statements are skipped when Python runs with -O.
assert False, 'END'