## Цель ноутбука: изучение метода Few-Shot Learning

#### В RTSD не хватает 14 знаков:

| Знак | Описание | Источник |
| ------------- | ------------- | ---- |
| 1.6 | Пересечение равнозначных дорог | - |
| 1.31 | Туннель | - |
| 2.4 | Уступите дорогу | GTSRB Recognition |
| 3.21 | Конец запрещения обгона | GTSRB Recognition |
| 3.22 | Обгон грузовым автомобилям запрещен | GTSRB Recognition |
| 3.23 | Конец запрещения обгона грузовым автомобилям | GTSRB Recognition |
| 3.24-90 | Огр 90 | - |
| 3.24-100 | Огр 100 | GTSRB Recognition |
| 3.24-110 | Огр 110 | - |
| 3.24-120 | Огр 120 | GTSRB Recognition |
| 3.24-130 | Огр 130 | - |
| 3.25 | Конец огр. максимальной скорости | GTSRB Recognition |
| 3.31 | Конец всех ограничений | GTSRB Recognition |
| 6.3.2 | Зона для разворота | - |

Инициализация библиотек

In [None]:
import albumentations as A
if A.__version__ != '1.0.3':
    !pip install albumentations==1.0.3
    !pip install opencv-python-headless==4.5.2.52
    assert False, 'restart runtime pls'

import matplotlib.pyplot as plt
import numpy as np
import random
import torch
from torch import nn
import seaborn as sns
import pandas as pd
import os
import pathlib
import shutil
import cv2
import PIL
import cv2
import sys
from datetime import datetime

# NOTE(review): TEXT_COLOR is not referenced in the visible cells — presumably a plot text colour; confirm.
TEXT_COLOR = 'black'
# Fix the random-number state of numpy, torch and the stdlib `random` module
RANDOM_STATE = 42
np.random.seed(RANDOM_STATE)
torch.manual_seed(RANDOM_STATE)
random.seed(RANDOM_STATE)
%matplotlib inline
# Default size for matplotlib figures created without an explicit figsize.
plt.rcParams["figure.figsize"] = (17,10)

# Runtime-environment flags; both stay False unless the Colab package imports.
IN_COLAB = False
USE_COLAB_GPU = False

try:
    import google.colab
except ImportError:
    # Not running on Colab: keep the defaults above.
    pass
else:
    IN_COLAB = True
    USE_COLAB_GPU = True
    try:
        from google.colab import drive
    except ImportError:
        # Without the drive helper the notebook falls back to local behaviour.
        # The previous handler also probed an undefined CURRENT_FILE_NAME here,
        # which raised NameError instead of reporting anything; that probe is
        # dropped. USE_COLAB_GPU is left set, as before.
        print('[!]YOU ARE IN COLAB, BUT DIDNT MOUND A DRIVE. Model wont be synced[!]')
        IN_COLAB = False

# Use CUDA when torch reports it as available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Bare expression: shows the selected device as the cell output.
device

Инициализация основных путей и папки src

In [None]:
# The project root is the parent of the working directory. The Colab and the
# local branch used to build this same path in two different ways
# (`./..` normalises to `..`), so it is computed once here.
PROJECT_ROOT = pathlib.Path(os.pardir)

DATA_DIR = PROJECT_ROOT / 'data'
NOTEBOOKS_DIR = PROJECT_ROOT / 'notebooks'
SRC_PATH = str(PROJECT_ROOT / 'src')

# Make modules under src importable, without adding the entry twice when the
# cell is re-run.
if SRC_PATH not in sys.path:
    sys.path.append(SRC_PATH)

In [None]:
from functools import partial
from torchvision.models import resnet

class SplitBatchNorm(torch.nn.BatchNorm2d):
    """BatchNorm2d that normalises ``num_splits`` sub-batches independently.

    In training mode (or when running statistics are not tracked) the batch is
    viewed as ``(N / num_splits, C * num_splits, H, W)``, so every split gets
    its own batch statistics; the per-split running statistics are then
    averaged back into the module's buffers. In evaluation mode with tracked
    statistics it behaves like a regular batch norm.

    Args:
        num_features: number of channels ``C``.
        num_splits: number of sub-batches; the batch size must be divisible
            by it, otherwise the ``view`` in ``forward`` fails.
        **kw: forwarded to ``torch.nn.BatchNorm2d``.
    """

    def __init__(self, num_features, num_splits, **kw):
        super().__init__(num_features, **kw)
        self.num_splits = num_splits

    def _tiled(self, tensor):
        """Repeat a per-channel tensor once per split; pass ``None`` through."""
        return None if tensor is None else tensor.repeat(self.num_splits)

    def forward(self, input):
        N, C, H, W = input.shape
        if self.training or not self.track_running_stats:
            # The running stats are None with track_running_stats=False, and
            # weight/bias are None with affine=False; tiling must tolerate
            # that instead of calling .repeat on None.
            running_mean_split = self._tiled(self.running_mean)
            running_var_split = self._tiled(self.running_var)
            outcome = torch.nn.functional.batch_norm(
                input.view(-1, C * self.num_splits, H, W),
                running_mean_split,
                running_var_split,
                self._tiled(self.weight),
                self._tiled(self.bias),
                True,
                self.momentum,
                self.eps,
            ).view(N, C, H, W)
            if running_mean_split is not None:
                # batch_norm updated the tiled copies in place; fold the
                # per-split statistics back into one value per channel.
                self.running_mean.data.copy_(
                    running_mean_split.view(self.num_splits, C).mean(dim=0)
                )
                self.running_var.data.copy_(
                    running_var_split.view(self.num_splits, C).mean(dim=0)
                )
            return outcome

        return torch.nn.functional.batch_norm(
            input,
            self.running_mean,
            self.running_var,
            self.weight,
            self.bias,
            False,
            self.momentum,
            self.eps,
        )
        
class ModelBase(torch.nn.Module):
    """
    Common CIFAR ResNet recipe.
    Comparing with ImageNet ResNet recipe, it:
    (i) replaces conv1 with kernel=3, str=1
    (ii) removes pool1

    Args:
        feature_dim: passed to the torchvision constructor as ``num_classes``,
            i.e. the size of the final linear layer's output.
        arch: name of a constructor in ``torchvision.models.resnet``.
        bn_splits: when greater than 1, ``SplitBatchNorm`` with that many
            splits is used as the norm layer; otherwise plain ``BatchNorm2d``.
    """

    def __init__(self, feature_dim=128, arch="resnet18", bn_splits=8):
        super().__init__()

        # use split batchnorm
        norm_layer = (
            partial(SplitBatchNorm, num_splits=bn_splits)
            if bn_splits > 1
            else torch.nn.BatchNorm2d
        )
        backbone = getattr(resnet, arch)(
            num_classes=feature_dim, norm_layer=norm_layer
        )

        # Collect the children in a local list so that self.net is only ever
        # bound to a Module (the leftover debug print of child names is gone).
        layers = []
        for name, module in backbone.named_children():
            if name == "conv1":
                module = torch.nn.Conv2d(
                    3, 64, kernel_size=3, stride=1, padding=1, bias=False
                )
            if isinstance(module, torch.nn.MaxPool2d):
                continue
            if isinstance(module, torch.nn.Linear):
                # Sequential has no implicit flatten before the linear head.
                layers.append(torch.nn.Flatten(1))
            layers.append(module)

        self.net = torch.nn.Sequential(*layers)

    def forward(self, x):
        """Return the output of the sequential network for ``x``."""
        x = self.net(x)
        # note: not normalized here
        return x

def create_encoder(emb_dim):
    """Build a ``ModelBase`` whose ``feature_dim`` is ``emb_dim``.

    The model is returned exactly as constructed: it is neither wrapped in
    ``DataParallel`` nor moved to a device here.
    """
    return ModelBase(emb_dim)


# Encoder instance used by the following cells.
encoder = create_encoder(128)

In [None]:
# encoder

In [None]:
# Load the dataset index and prefix every 'filepath' value with DATA_DIR.
RTDS_DF = pd.read_csv(DATA_DIR / 'RTDS_DATASET.csv')
RTDS_DF['filepath'] = RTDS_DF['filepath'].apply(lambda x: str(DATA_DIR / x))

# Each mapping CSV is transposed and its first record is taken as a plain dict.
# presumably each CSV holds one value column keyed by its index column — TODO confirm
SIGN_TO_NUMBER = pd.read_csv(DATA_DIR / 'sign_to_number.csv', index_col=0).T.to_dict('records')[0]
NUMBER_TO_SIGN = pd.read_csv(DATA_DIR / 'number_to_sign.csv', index_col=0).T.to_dict('records')[0]

# Order matters: keep the original 'SIGN' values as 'ENCODED_LABELS' first,
# then overwrite 'SIGN' with the mapped name ('_' becomes '.', every 'n' is removed).
RTDS_DF['ENCODED_LABELS'] = RTDS_DF['SIGN']
RTDS_DF['SIGN'] = RTDS_DF['SIGN'].apply(lambda x: str(NUMBER_TO_SIGN[x]).replace('_', '.').replace('n', ''))

# UNFIX TRAIN
# SIMPLE_FIX = True
# JUST_FIX = False
# Keep a single row per file path.
RTDS_DF.drop_duplicates(subset=['filepath'], inplace=True)
# RTDS_DF.drop_duplicates(subset=['SET', 'SIGN'], inplace=True)

RTDS_DF

In [None]:
# Size of the smallest (SIGN, SET) group.
min(RTDS_DF.groupby(['SIGN', 'SET']).size())

In [None]:
# Draw 13 rows from every (SIGN, SET) group, seeded with RANDOM_STATE.
# NOTE(review): 13 is hard-coded — presumably the minimum group size printed by the previous cell; confirm.
RTDS_DF = RTDS_DF.groupby(['SIGN', 'SET']).sample(13, random_state=RANDOM_STATE).reset_index(drop=True)
RTDS_DF

In [None]:
# Distinct sign names that remain after sampling.
set(RTDS_DF['SIGN'])

In [None]:
class SignDataset(torch.utils.data.Dataset):
    """Image dataset backed by a DataFrame of file paths and encoded labels.

    Every row must provide ``'filepath'`` and ``'ENCODED_LABELS'``; the
    ``'SET'`` column is required only when ``set_label`` is given.

    Args:
        df: table describing the samples.
        set_label: when not None, only rows whose ``'SET'`` equals this value
            are kept.
        hyp: optional dict with the keys ``degrees``, ``translate``, ``scale``,
            ``shear``, ``perspective`` and ``border`` forwarded to
            ``random_perspective``; it is used only when ``transform`` is set.
        transform: optional callable invoked as ``transform(image=img)['image']``.
        le: unused; kept so existing callers keep working.
    """

    def __init__(self, df, set_label=None, hyp=None, transform=None, le=None):
        self.transform = transform
        self.df = df if set_label is None else df[df['SET'] == set_label]
        self.hyp = hyp

    def __len__(self):
        return len(self.df.index)

    def __getitem__(self, index):
        """Return ``(image, label)`` for the row at position ``index``.

        Raises:
            IndexError: when ``index`` is out of range (this is what ends
                plain iteration over the dataset).
            FileNotFoundError: when the image cannot be read.
        """
        row = self.df.iloc[index]
        label = int(row['ENCODED_LABELS'])
        path = str(row['filepath'])

        img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
        if img is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise FileNotFoundError('cannot read image: %s' % path)

        if img.ndim == 2:
            # IMREAD_UNCHANGED keeps single-channel files 2-D; expand them so
            # the rest of the pipeline always sees three channels.
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
        elif img.shape[2] == 4:
            # Fill fully transparent pixels with one random opaque colour,
            # then drop the alpha channel.
            trans_mask = img[:, :, 3] == 0
            img[trans_mask] = [random.randrange(0, 256),
                               random.randrange(0, 256),
                               random.randrange(0, 256),
                               255]
            img = cv2.cvtColor(img, cv2.COLOR_BGRA2BGR)

        # augment
        if self.hyp and self.transform:
            # NOTE(review): random_perspective is not defined in the visible
            # cells — presumably provided by the src package; confirm.
            img, _ = random_perspective(img,
                                        (),
                                        degrees=self.hyp['degrees'],
                                        translate=self.hyp['translate'],
                                        scale=self.hyp['scale'],
                                        shear=self.hyp['shear'],
                                        perspective=self.hyp['perspective'],
                                        border=self.hyp['border'])
        if self.transform:
            img = self.transform(image=img)['image']
        # /augment

        # Scale pixel values by 1/255.
        img = img / 255
        return img, label  # , path

In [None]:
from albumentations.augmentations.geometric.transforms import Perspective, ShiftScaleRotate
from albumentations.core.transforms_interface import ImageOnlyTransform
from albumentations.pytorch.transforms import ToTensorV2
from albumentations.augmentations.transforms import PadIfNeeded
from albumentations.augmentations.geometric.resize import LongestMaxSize

# import torchvision.transforms.functional as F

# Side length (pixels) of the square image produced by the pipeline.
img_size = 40

# True: only resize, pad and convert to a tensor.
# False: run the augmentation stack before those steps.
MINIMAL_TRANSFORM = True


def _build_transform(minimal):
    """Compose the albumentations pipeline; augmentations are built only when used."""
    if minimal:
        steps = []
    else:
        steps = [
            A.Blur(blur_limit=2),
            A.CLAHE(p=1),
            A.Perspective(scale=(0.01, 0.1), p=1),
            A.ShiftScaleRotate(
                shift_limit=0.05,
                scale_limit=0.05,
                interpolation=cv2.INTER_LANCZOS4,
                border_mode=cv2.BORDER_CONSTANT,
                value=(0, 0, 0),
                rotate_limit=6,
                p=1,
            ),
            A.RandomGamma(gamma_limit=(50, 130), p=1),
            A.ImageCompression(quality_lower=80, p=1),
            A.RandomBrightnessContrast(
                brightness_limit=0.5,
                contrast_limit=0.3,
                brightness_by_max=False,
                p=1,
            ),
            A.CoarseDropout(
                max_height=3,
                max_width=3,
                min_holes=1,
                max_holes=3,
                p=0.8,
            ),
        ]

    # Both variants end with the same resize / pad / to-tensor steps.
    steps += [
        LongestMaxSize(img_size),
        PadIfNeeded(
            img_size,
            img_size,
            border_mode=cv2.BORDER_CONSTANT,
            value=0,
        ),
        ToTensorV2(),
    ]
    return A.Compose(steps)


transform = _build_transform(MINIMAL_TRANSFORM)

# Both splits share the same transform and use no perspective hyper-parameters.
train_dataset, valid_dataset = (
    SignDataset(RTDS_DF, set_label=split_name, transform=transform, hyp=None)
    for split_name in ('train', 'valid')
)
# valid_dataset = train_dataset
# valid_dataset = train_dataset

In [None]:
def getNSamplesFromDataSet(ds, N):
    """Return ``N`` items of ``ds`` taken at distinct, randomly chosen positions.

    Positions come from ``random.sample``, so the result depends on the state
    of the stdlib ``random`` module and ``N`` may not exceed ``len(ds)``.
    """
    chosen_positions = random.sample(range(0, len(ds)), N)
    return [ds[position] for position in chosen_positions]

# Preview of training samples drawn on a 70 x 6 grid of subplots.
# NOTE(review): IMG_COUNT is not used anywhere in this cell.
IMG_COUNT = 18
nrows, ncols = 70, 6
fig = plt.figure(figsize = (16,200))

# Index after which the loop stops; the check runs after drawing, so the
# images with idx 0..PLOT_SOFT_LIMIT+1 are plotted.
PLOT_SOFT_LIMIT = 20

# TEMP_DS = getNSamplesFromDataSet(train_dataset, len(train_dataset))
# TEMP_DS = train_dataset.sort_values(['SIGN'], axis=1)
TEMP_DS = train_dataset
for idx, (img, encoded_label) in enumerate(TEMP_DS):
    # print(img.shape)
    # Move the first axis to the end before converting to numpy for imshow.
    img = torch.Tensor.permute(img, [1, 2, 0]).numpy() 
    ax = fig.add_subplot(nrows, ncols, idx+1)
        
    # Images were read with OpenCV; convert BGR to RGB for matplotlib.
    ax.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB), aspect=1)
    
    # Title format: "<sign name>:<encoded label>".
    title = str(NUMBER_TO_SIGN[encoded_label]) + ':' + str(encoded_label)
    # if idx % 2 == 1:
    #     title += "\n" + _
    
    ax.set_title(title, fontsize=15)
    
    if idx > PLOT_SOFT_LIMIT:
        print('[!] plot soft limit reached. Breaking.')
        break
        
plt.tight_layout()
plt.show()

In [None]:
# Loader settings read as globals by getDataLoaderFromDataset.
batch_size = 40
# Two worker processes on Colab, in-process loading otherwise.
num_workers = 2 if IN_COLAB else 0

from torch.utils.data import DataLoader

def getDataLoaderFromDataset(dataset, shuffle=False, drop_last=True):
    """Wrap ``dataset`` in a ``DataLoader``.

    ``batch_size`` and ``num_workers`` are read from the module-level
    variables of the same names; ``pin_memory`` is always enabled.

    Args:
        dataset: dataset to iterate over.
        shuffle: forwarded to ``DataLoader``.
        drop_last: forwarded to ``DataLoader``.
    """
    loader_options = dict(
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=True,
        shuffle=shuffle,
        drop_last=drop_last,
    )
    return DataLoader(dataset, **loader_options)


# Shuffled training loader; drop_last keeps its default (True), so an
# incomplete final batch is discarded.
train_loader = getDataLoaderFromDataset(
    train_dataset,
    shuffle=True
)

In [None]:
# Check whether the encoder's first parameter lives on a CUDA device.
next(encoder.parameters()).is_cuda

In [None]:
# Pull one batch from the training loader for a smoke test.
sample_img, sample_label = next(iter(train_loader))

In [None]:
# print(sample_img)
# Smoke test: run the sample batch through the encoder in eval mode with
# gradient tracking disabled.
with torch.no_grad():
    encoder.eval()
    out = encoder(sample_img).detach().cpu()
    
# Bare expression: shows the encoder output as the cell output.
out

In [None]:
# torch.save(encoder.state_dict(), 'last_encoder')

In [None]:
def saveCheckpoint(model, scheduler, optimizer, epoch, filename):
    """Write a training checkpoint to ``filename`` with ``torch.save``.

    The saved dict has the keys ``'epoch'``, ``'model'``, ``'optimizer'`` and
    ``'scheduler'``; the last three hold the objects' ``state_dict()``.

    Args:
        model: module whose weights are saved.
        scheduler: LR scheduler, or None when training runs without one
            (``'scheduler'`` is then stored as None).
        optimizer: optimizer whose state is saved.
        epoch: epoch number stored alongside the states.
        filename: destination path or file-like object.
    """
    scheduler_state = None if scheduler is None else scheduler.state_dict()
    torch.save({
        'epoch': epoch,
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'scheduler': scheduler_state,
    }, filename)

def loadCheckpoint(model, scheduler, optimizer, filename, map_location=None):
    """Restore model, optimizer and scheduler state from a checkpoint file.

    Args:
        model: module that receives the saved weights.
        scheduler: LR scheduler to restore, or None to skip scheduler state.
        optimizer: optimizer that receives the saved state.
        filename: path or file-like object accepted by ``torch.load``.
        map_location: forwarded to ``torch.load``; pass e.g. ``'cpu'`` to open
            a checkpoint written on a GPU machine where CUDA is unavailable.

    Returns:
        ``(model, optimizer, scheduler, epoch)`` — note that the order differs
        from the parameter order.
    """
    checkpoint = torch.load(filename, map_location=map_location)
    model.load_state_dict(checkpoint['model'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    # A run without a scheduler has nothing to restore, and its checkpoint
    # may hold no scheduler state.
    scheduler_state = checkpoint.get('scheduler')
    if scheduler is not None and scheduler_state is not None:
        scheduler.load_state_dict(scheduler_state)
    epoch = checkpoint['epoch']
    return model, optimizer, scheduler, epoch

In [None]:
# encoder, optimizer, scheduler, started_epoch = loadCheckpoint(encoder, scheduler, optimizer, 'sample')

In [None]:
# saveCheckpoint(encoder, scheduler, optimizer, 0, 'sample')

In [None]:
from pytorch_metric_learning import distances, losses, miners, reducers, testers
from pytorch_metric_learning.utils.accuracy_calculator import AccuracyCalculator
from tqdm.notebook import trange, tqdm

# Training hyper-parameters: lr and momentum feed the SGD optimizer, margin
# feeds both the triplet loss and the miner, epochs bounds the epoch loop.
config = {
    'lr': 0.1,
    'epochs': 120,
    'momentum':  0.937,
    'margin': 0.0
}

optimizer = torch.optim.SGD(encoder.parameters(), lr=config['lr'], momentum=config['momentum'], nesterov=True)
# encoder.to('cpu')
# optimizer = torch.optim.Adam(encoder.parameters(), lr=0.1)
# NOTE: the scheduler below is disabled, so no `scheduler` object exists in this notebook.
# scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer, 
#                                              base_lr=0.00001, 
#                                              max_lr=config['lr'],
#                                              step_size_up=50,
#                                              step_size_down=20,
#                                              mode="exp_range",
#                                              gamma=0.9,
#                                              cycle_momentum=False
#                                            )
# The loss and the miner share the same distance and margin.
distance = distances.CosineSimilarity()
reducer = reducers.AvgNonZeroReducer()
loss_func = losses.TripletMarginLoss(margin=config['margin'], distance=distance, reducer=reducer)
mining_func = miners.TripletMarginMiner(
    margin=config['margin'], distance=distance, type_of_triplets="all"
)

# Only precision@1 is computed, with k=1.
accuracy_calculator = AccuracyCalculator(include=("precision_at_1",), k=1)

# Resume from `started_epoch` when an earlier cell (or the disabled checkpoint
# load below) defined it; otherwise start from epoch 0. Only the two expected
# failures are caught, so interrupts and real errors are no longer swallowed.
try:
    # encoder, optimizer, scheduler, started_epoch = loadCheckpoint(encoder, scheduler, optimizer, 'sample')
    started_epoch
    print('[+] check point loaded')
except (NameError, FileNotFoundError):
    started_epoch = 0
    print('[!] check point doesnt exist')

# Move the encoder to the selected device. The optimizer above was created
# from this same encoder's parameters before the move.
encoder.to(device)    
# assert False, '+'
### convenient function from pytorch-metric-learning ###
def get_all_embeddings(dataset, model):
    """Return what ``BaseTester.get_all_embeddings`` yields for ``dataset`` and ``model``.

    A fresh tester with ``dataloader_num_workers=0`` is created per call.
    """
    return testers.BaseTester(dataloader_num_workers=0).get_all_embeddings(
        dataset, model
    )

### compute accuracy using AccuracyCalculator from pytorch-metric-learning ###
def test(train_set, test_set, model, accuracy_calculator):
    """Return the ``precision_at_1`` entry reported by ``accuracy_calculator``.

    Both datasets are embedded with ``model``; their label tensors are squeezed
    along dim 1 and everything is passed positionally to ``get_accuracy`` as
    (test embeddings, train embeddings, test labels, train labels, False).
    """
    train_emb, train_lbl = get_all_embeddings(train_set, model)
    test_emb, test_lbl = get_all_embeddings(test_set, model)
    metrics = accuracy_calculator.get_accuracy(
        test_emb,
        train_emb,
        test_lbl.squeeze(1),
        train_lbl.squeeze(1),
        False,
    )
    return metrics["precision_at_1"]
    
    
### MNIST code originally from https://github.com/pytorch/examples/blob/master/mnist/main.py ###
def train(model, loss_func, mining_func, device, train_loader, optimizer, epoch):
    """Run one optimisation pass over ``train_loader``.

    Args:
        model: embedding network, switched to train mode.
        loss_func: called as ``loss_func(embeddings, labels, indices_tuple)``
            and expected to return a scalar that is already reduced over the
            batch (this notebook builds it with ``AvgNonZeroReducer``).
        mining_func: called as ``mining_func(embeddings, labels)``; its
            ``num_triplets`` attribute is shown in the progress bar.
        device: device the batches are moved to.
        train_loader: iterable of ``(data, labels)`` batches with a length.
        optimizer: optimizer stepped once per batch.
        epoch: unused; kept for interface compatibility.

    Returns:
        Mean of the per-batch losses, or 0.0 when the loader has no batches.
    """
    model.train()
    loss_sum = 0.0
    batches_seen = 0

    pbar = tqdm(
        train_loader,
        total=len(train_loader),
        position=0,
        leave=False,
        desc='WAITING...')

    for data, labels in pbar:
        data, labels = data.to(device), labels.to(device)
        optimizer.zero_grad()
        embeddings = model(data)
        indices_tuple = mining_func(embeddings, labels)
        loss = loss_func(embeddings, labels, indices_tuple)

        instant_loss = loss.item()
        loss_sum += instant_loss
        batches_seen += 1

        loss.backward()
        optimizer.step()

        # The loss is already an average, so it is reported as is; dividing it
        # by the batch size again made the displayed value batch-size dependent.
        pbar.set_description("TRAIN: INSTANT MEAN LOSS %f, MINED TRIPLET: %d" %
                             (round(instant_loss, 3),
                              mining_func.num_triplets)
                             )

    # Average over the batches actually processed. The dataset length was the
    # wrong divisor: it counts samples (including ones dropped by drop_last),
    # not batches.
    return loss_sum / batches_seen if batches_seen else 0.0

torch.cuda.empty_cache()

# Epoch progress bar, starting at started_epoch so a resumed run shows the
# right position.
pbar = trange(
        started_epoch,
        config['epochs'],
        initial=started_epoch,
        total=config['epochs'],
        leave=True,
        desc='WAITING FOR FIRST EPOCH END...')

for epoch in pbar:

    # plotSmth(encoder, train_dataset, device=device, dim3=False, fcn='umap')
    train_loss = train(encoder, loss_func, mining_func, device, train_loader, optimizer, epoch)
    # test() embeds the validation set and scores it against the training
    # set, so its result belongs in the VALID slot of the description.
    mean_valid_acc = test(train_dataset, valid_dataset, encoder, accuracy_calculator)
    # Training accuracy is not computed in this loop; show NaN rather than
    # presenting the validation figure under the TRAIN label.
    mean_train_acc = float('nan')

    # saveCheckpoint(encoder, scheduler, optimizer, epoch, 'sample')
    # plotSmth(encoder, CONST_MINIMAL_DATASET, device=device, dim3=False, fcn='umap')
    # scheduler.step()

    # Report the learning rate the optimizer is actually using.
    lr_val = optimizer.param_groups[0]['lr']
    pbar.set_description("PER EPOCH: TRAIN LOSS: %.4f; TRAIN ACCUR %.4f; VALID ACCUR: %.4f, LR %.2e" % (train_loss,
                                                                                                           mean_train_acc,
                                                                                                           mean_valid_acc,
                                                                                                           lr_val)
                            )

In [None]:
from tqdm.notebook import trange, tqdm
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
from umap import UMAP
import plotly.express as px

@torch.no_grad()
def plotSmth(model, data, device='cpu', dim3=False, fcn='PCA', reducer_arg=None):
    """Embed ``data`` with ``model``, reduce the embeddings to 2-D/3-D and plot them.

    Args:
        model: network producing one embedding row per sample; it is switched
            to eval mode.
        data: a ``DataLoader``, or anything a ``DataLoader`` can wrap.
        device: device the batches are moved to before the forward pass.
        dim3: plot three components instead of two.
        fcn: reducer to fit when ``reducer_arg`` is None — ``'PCA'``,
            ``'TSNE'`` or ``'UMAP'`` (case-insensitive).
        reducer_arg: already fitted reducer; when given, only its
            ``transform`` is applied and ``fcn`` is ignored.

    Returns:
        The reducer that was used, so it can be passed back as ``reducer_arg``.

    Raises:
        ValueError: for an unknown ``fcn`` or when ``data`` yields no batches.
    """
    model.eval()
    print(type(data))
    if isinstance(data, DataLoader):
        print('data is DataLoader')
        loader = data
    else:
        print('data is not DataLoader, assume its DataSet')
        loader = DataLoader(data)

    # No matplotlib figure is opened here any more: the plot is rendered by
    # plotly, and the figure created on every call was never drawn on.

    n_components = 3 if dim3 else 2

    if reducer_arg is None:
        fcn = fcn.upper()
        if fcn == 'PCA':
            reducer = PCA(n_components=n_components, random_state=RANDOM_STATE)
        elif fcn == 'TSNE':
            reducer = TSNE(n_components=n_components, init='random', random_state=RANDOM_STATE)
        elif fcn == 'UMAP':
            reducer = UMAP(n_components=n_components, init='random', random_state=RANDOM_STATE)
        else:
            raise ValueError("wrong fcn arg")
    else:
        reducer = reducer_arg

    pbar = tqdm(loader,
                total=len(loader),
                position=0,
                leave=False)

    # Gather per-batch arrays and join them once: no quadratic re-copying and
    # no assumption about the embedding width.
    embedding_batches = []
    target_batches = []
    for batch, target in pbar:
        embedding_batches.append(model(batch.to(device)).detach().cpu().numpy())
        target_batches.append(target.detach().cpu().numpy().ravel())

    if not embedding_batches:
        raise ValueError('no batches to plot')

    model_out_arr = np.concatenate(embedding_batches, axis=0)
    target_arr = np.concatenate(target_batches)

    if reducer_arg is None:
        X_embedded = reducer.fit_transform(model_out_arr)
    else:
        X_embedded = reducer.transform(model_out_arr)

    # Colour the points by the NUMBER_TO_SIGN entry of each label.
    target_arr_color = [NUMBER_TO_SIGN[x] for x in target_arr]

    if not dim3:
        fig = px.scatter(X_embedded, x=0, y=1, color=target_arr_color)
    else:
        fig = px.scatter_3d(X_embedded, x=0, y=1, z=2, color=target_arr_color)
    fig.show()

    return reducer

In [None]:
# Fit a UMAP reducer on the validation embeddings, then reuse that fitted
# reducer to project the training embeddings.
r = plotSmth(encoder, valid_dataset, device=device, dim3=False, fcn='umap')
plotSmth(encoder, train_dataset, device=device, dim3=False, reducer_arg=r)

In [None]:


# 

# display(CONST_MINIMAL_DF)
# TEMP_DS = getNSamplesFromDataSet(CONST_MINIMAL_DF, 1)
# print(TEMP_DS)



In [None]:
# NOTE(review): CONST_MINIMAL_DF is not defined in any visible cell; this cell raises NameError unless it is defined elsewhere.
CONST_MINIMAL_DF

In [None]:
# Bare expression: shows the miner configured in the training cell.
mining_func

In [None]:
from torchvision.models import resnet18
# Stock torchvision resnet18 built with default arguments.
model = resnet18()
# Bare expression: prints the module structure as the cell output.
model

In [None]:
from tqdm.notebook import trange, tqdm
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
from umap import UMAP
import plotly.express as px

@torch.no_grad()
def plotSmth(model, loader, device='cpu', dim3=False, fcn='PCA'):
    """Embed every batch of ``loader`` with ``model``, reduce to 2-D/3-D and plot.

    NOTE: this definition replaces the earlier ``plotSmth`` in this notebook,
    which also accepted ``reducer_arg`` and returned the fitted reducer.

    Args:
        model: network producing one output row per sample; it is switched to
            eval mode.
        loader: iterable of ``(data, target)`` batches with a length.
        device: device the batches are moved to before the forward pass.
        dim3: plot three components instead of two.
        fcn: ``'PCA'``, ``'TSNE'`` or ``'UMAP'`` (case-insensitive).

    Raises:
        ValueError: for an unknown ``fcn`` or when ``loader`` yields no batches.
    """
    model.eval()

    # No matplotlib figure is opened here any more: the plot is rendered by
    # plotly, and the figure created on every call was never drawn on.

    n_components = 3 if dim3 else 2

    fcn = fcn.upper()

    if fcn == 'PCA':
        reducer = PCA(n_components=n_components, random_state=RANDOM_STATE)
    elif fcn == 'TSNE':
        reducer = TSNE(n_components=n_components, init='random', random_state=RANDOM_STATE)
    elif fcn == 'UMAP':
        reducer = UMAP(n_components=n_components, init='random', random_state=RANDOM_STATE)
    else:
        raise ValueError("wrong fcn arg")

    pbar = tqdm(loader,
                total=len(loader),
                position=0,
                leave=False)

    # Gather per-batch arrays and join them once. The output width is taken
    # from the model instead of a fixed 128 columns, which failed for models
    # with a different output size.
    output_batches = []
    target_batches = []
    for batch, target in pbar:
        output_batches.append(model(batch.to(device)).detach().cpu().numpy())
        target_batches.append(target.detach().cpu().numpy().ravel())

    if not output_batches:
        raise ValueError('no batches to plot')

    model_out_arr = np.concatenate(output_batches, axis=0)
    target_arr = np.concatenate(target_batches)

    X_embedded = reducer.fit_transform(model_out_arr)
    # Labels become strings so plotly treats them as discrete categories.
    target_arr = np.char.mod('%d', target_arr)

    if not dim3:
        fig = px.scatter(X_embedded, x=0, y=1, color=target_arr)
    else:
        fig = px.scatter_3d(X_embedded, x=0, y=1, z=2, color=target_arr)
    fig.show()
    
    
# NOTE(review): dataset2 is not defined in any visible cell — confirm where it comes from.
test_loader = torch.utils.data.DataLoader(dataset2, batch_size=256)
# model.apply(init_normal)
# Plot the resnet18 outputs for test_loader in 3-D using PCA.
plotSmth(model, test_loader, device=device, dim3=True, fcn='pca')

In [None]:
# Render any pending matplotlib figures.
plt.show()

In [None]:
# enumerate objects are iterators and cannot be indexed; take the first
# (index, batch) pair with next() instead.
next(enumerate(train_loader))

In [None]:
# Run one sample (with a leading batch axis added) through the model and show the output shape.
# NOTE(review): dataset1 is not defined in any visible cell — confirm where it comes from.
model(dataset1[1][0][None, ...].to(device)).shape

In [None]:
# Embeddings and labels of dataset1 as returned by the tester helper.
res = get_all_embeddings(dataset1, model)

In [None]:
# Shape of the first element of the result (the embeddings, per the unpacking used in test()).
res[0].shape

In [None]:
from torchvision.transforms.functional import pad

In [None]:
# NOTE(review): an earlier cell indexes dataset1[1][0], which suggests dataset1[1] is an
# (image, label) pair; pad probably needs dataset1[1][0] — confirm.
pad(dataset1[1], 1)