# Датасет должен быть либо скачан, либо создан с помощью ноутбука RTSD-R_MERGED
Объединенный датасет доступен по [ссылке](https://drive.google.com/drive/folders/1jmxG2zfi-Fs3m2KrMGmjD347aYiT8YFM?usp=sharing).

Положить в папку data содержимое так, чтобы были следующие пути:  
* \$(ROOT_DIR)/data/merged-rtsd/...
* \$(ROOT_DIR)/data/gt.csv

> *gt_Set_NaN.csv - содержит тот же датасет, но значения колонки Set обнулены*

gt - датафрейм, содержащий:  
* имена файлов - поле filename
* класс знака - поле sign_class
* флаг присутствия знака при работе с датасетом - IsPresent. Предполагается, что вместо удаления записи будет устанавливаться этот флаг, включающий/не включающий знак в выборку
* в какой набор включен знак - поле Set $\in$ $\{train, valid, test\}$

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import random
import torch
from torch import nn
import seaborn as sns
import pandas as pd
import os
import pathlib
import shutil
import cv2
import PIL
import cv2
from datetime import datetime

# Environment bootstrap: switch into the notebooks directory and, when running
# in Google Colab, mount Drive and fetch the repository and dataset archive.
%cd adas_system/notebooks

# Both flags stay False unless the google.colab import below succeeds.
IN_COLAB = False
USE_COLAB_GPU = False
try:
    # Importing google.colab only works inside Colab; elsewhere it raises and
    # control jumps to the except branch with IN_COLAB still False.
    import google.colab
    IN_COLAB = True
    USE_COLAB_GPU = True
    from google.colab import drive
    drive.mount('/content/drive')
    # Clone the project only if this notebook is not already in the cwd.
    if not os.path.isfile('1_ClassifierResearch.ipynb'):
        !git clone --branch 9_SignDetector https://github.com/lsd-maddrive/adas_system.git

    # Download the dataset archive, move into the notebooks directory and
    # unpack the archive into ../data/.
    !gdown --id 1-K3ee1NbMmx_0T5uwMesStmKnZO_6mWi
    %cd adas_system/notebooks
    !mkdir ../data/R_MERGED
    !unzip -q -o /content/R_MERGED.zip -d ./../data/

# NOTE(review): bare except catches every exception (including
# KeyboardInterrupt); consider narrowing it once the expected failures are known.
except:
    # IN_COLAB is True here only when the import succeeded but a later step
    # in the try body raised (presumably drive.mount - confirm).
    if IN_COLAB:
        print('[!]YOU ARE IN COLAB, BUT DIDNT MOUND A DRIVE. Model wont be synced[!]')

        # Same fetch sequence as above, but here the download and unpacking
        # are also skipped when the notebook file is already present.
        if not os.path.isfile('1_ClassifierResearch.ipynb'):
            !git clone --branch 9_SignDetector https://github.com/lsd-maddrive/adas_system.git
            !gdown --id 1-K3ee1NbMmx_0T5uwMesStmKnZO_6mWi
            %cd adas_system/notebooks
            !mkdir ../data/R_MERGED
            !unzip -q -o /content/R_MERGED.zip -d ./../data/

        # From here on the notebook behaves as if it ran outside Colab;
        # USE_COLAB_GPU is left at True.
        IN_COLAB = False

    else:
        # Not in Colab: nothing to set up.
        pass

###
import nt_helper
from nt_helper.helper_utils import *
###

TEXT_COLOR = 'black'

# Зафиксируем состояние случайных чисел
RANDOM_STATE = 42
np.random.seed(RANDOM_STATE)
torch.manual_seed(RANDOM_STATE)
random.seed(RANDOM_STATE)
%matplotlib inline
plt.rcParams["figure.figsize"] = (17,10)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

Init dirs, init main vars

In [None]:
# Locate the project directories relative to the notebooks directory.
# Both branches resolve to the parent directory ('..').
if not IN_COLAB:
    PROJECT_ROOT = pathlib.Path(os.path.join(os.curdir, os.pardir))
else:
    PROJECT_ROOT = pathlib.Path('..')

DATA_DIR = PROJECT_ROOT / 'data'
NOTEBOOKS_DIR = PROJECT_ROOT / 'notebooks'

# Ground-truth table; the code below relies on its 'filepath', 'SIGN' and
# 'SET' columns.
gt = pd.read_csv(DATA_DIR / 'RTDS_DATASET.csv')

# Lookup dictionaries built from the first record of each transposed CSV.
SIGN_TO_NUMBER = pd.read_csv(DATA_DIR / 'sign_to_number.csv', index_col=0).T.to_dict('records')[0]
NUMBER_TO_SIGN = pd.read_csv(DATA_DIR / 'number_to_sign.csv', index_col=0).T.to_dict('records')[0]

# Turn the stored relative paths into paths under DATA_DIR.
gt['filepath'] = gt['filepath'].apply(lambda x: DATA_DIR / x)
GT_SRC_LEN = len(gt.index)
display(gt)

# One count plot of sign classes per dataset split.
LABELS = ['train', 'valid', 'test']
_, ax = plt.subplots(nrows=len(LABELS), ncols=1, figsize=(21, 8))

# The category order is the same for every split, so compute it once; this
# keeps the x axes of the three plots aligned.
sign_order = sorted(gt['SIGN'].value_counts().index.tolist())

for split_ax, split_label in zip(ax, LABELS):
    sns.countplot(x='SIGN',
                  data=gt[gt['SET'] == split_label],
                  ax=split_ax,
                  order=sign_order)
    split_ax.tick_params(labelrotation=90)
    split_ax.set_title(split_label)

# Lay the figure out once, after all subplots are drawn.
plt.tight_layout()

Тестим обучалку: возьмем из трейна по N представителей каждого класса

In [None]:
from sklearn import preprocessing

# Number of rows sampled per sign class for the small training subset.
N = 3

# Keep only the training split. The explicit copy lets the ENCODED_LABELS
# column be added below without pandas' SettingWithCopyWarning.
gt = gt[gt["SET"] == 'train'].copy()
SIGN_SET = set(gt['SIGN'])

LE_LOCATION = DATA_DIR / 'le.npy'
le = preprocessing.LabelEncoder()

# Reuse the classes saved by an earlier run when the file exists; otherwise
# fit the encoder on the training signs and save its classes.
if os.path.isfile(LE_LOCATION):
    le.classes_ = np.load(LE_LOCATION)
else:
    le.fit(gt['SIGN'])
    np.save(LE_LOCATION, le.classes_)

gt['ENCODED_LABELS'] = le.transform(gt['SIGN'])

# 7 x 6 grid as before; extra rows are added only if there are more classes
# than the grid can hold, so add_subplot does not fail.
ncols = 6
nrows = max(7, -(-len(SIGN_SET) // ncols))
fig = plt.figure(figsize=(16, 16))

# Per-class samples are collected here and concatenated once at the end
# (DataFrame.append was removed in pandas 2.0).
sampled_frames = []

for idx, sign_class in enumerate(SIGN_SET):
    class_rows = gt[gt['SIGN'] == sign_class]
    # Never ask for more rows than the class has; sample() would raise.
    instances = class_rows.sample(min(N, len(class_rows)))
    sampled_frames.append(instances)

    # Show one randomly chosen image of this class.
    path = str(instances['filepath'].sample(1).values[0])
    img = cv2.imread(path)
    if img is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise FileNotFoundError(f'Could not read image: {path}')

    ax = fig.add_subplot(nrows, ncols, idx + 1)
    ax.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB), aspect=1)
    ax.set_title(
        f'ENCODED: {le.transform([sign_class])[0]}'
        f'\nDECODED: {sign_class}'
        f'\nSIGN: {NUMBER_TO_SIGN[sign_class]}'
    )

if sampled_frames:
    new_mini_df = pd.concat(sampled_frames)
else:
    new_mini_df = pd.DataFrame(columns=gt.columns)

plt.tight_layout()

In [None]:
# Display the sampled subset (N rows per sign class) as the cell output.
new_mini_df

new_mini_df хранит только по N представителей каждого знака. Задача: обучить и проверить на этих данных

Создадим загрузчик

In [None]:
class SignDataset(torch.utils.data.Dataset):
    """Dataset of sign images listed in a dataframe, filtered by split.

    Args:
        df: dataframe with at least the columns ``SET``, ``ENCODED_LABELS``
            and ``filepath``.
        set_label: value of the ``SET`` column to keep (e.g. ``'train'``).
        img_size: target size passed to ``cv2.resize``; an int ``s`` is
            expanded to ``(s, s)``.
        transform: optional callable applied to the image tensor (channels
            first, values divided by 255) before it is returned.
        le: optional label encoder, stored as ``self.le`` for callers; it is
            not used by the dataset itself.
    """

    def __init__(self, df, set_label, img_size=64, transform=None, le=None):
        super().__init__()

        if isinstance(img_size, int):
            img_size = (img_size, img_size)

        self.img_size = tuple(img_size)
        self.transform = transform
        self.le = le
        self.df = df[df['SET'] == set_label]

    def __len__(self):
        """Return the number of rows in the selected split."""
        return len(self.df.index)

    def __getitem__(self, index):
        """Return ``(image_tensor, label)`` for the row at position ``index``.

        Raises:
            FileNotFoundError: if the image file cannot be read.
        """
        row = self.df.iloc[index]
        label = int(row['ENCODED_LABELS'])
        path = str(row['filepath'])

        img = cv2.imread(path)
        if img is None:
            # cv2.imread returns None on failure; fail with a clear message
            # here rather than inside cv2.resize.
            raise FileNotFoundError(f'Could not read image: {path}')

        img = cv2.resize(img, self.img_size, interpolation=cv2.INTER_LANCZOS4)
        # Height-width-channel uint8 array -> channel-first float tensor in [0, 1].
        img_tnsr = torch.from_numpy(img).permute(2, 0, 1).float().div(255)

        if self.transform is not None:
            img_tnsr = self.transform(img_tnsr)

        return img_tnsr, label

In [None]:
from sklearn.metrics import accuracy_score
from tqdm.notebook import tqdm

def train_epoch(model, loader, loss_op, optim, device, it_limit=99000):
    """Run one training pass over ``loader`` and return the mean batch loss.

    Args:
        model: module to train; switched to train mode and moved to ``device``.
        loader: iterable of ``(data, target)`` batches that supports ``len()``.
        loss_op: callable ``loss_op(pred, target)`` returning a scalar loss.
        optim: optimizer stepped once per batch.
        device: device the model and batches are moved to.
        it_limit: when truthy, the loop stops at the first batch whose index
            exceeds it (batches ``0..it_limit`` are processed); ``None`` or
            ``0`` disables the limit.

    Returns:
        Mean of the per-batch loss values, or ``nan`` if no batch was processed.
    """
    model.train()
    model.to(device)

    batch_losses = []

    pbar = tqdm(enumerate(loader),
                total=len(loader),
                position=0,
                leave=False)

    for idx, (data, target) in pbar:

        if it_limit and idx > it_limit:
            break

        data = data.to(device)
        target = target.to(device)

        optim.zero_grad()
        pred = model(data)

        # Batch accuracy is not computed here: its result was never used and
        # it forced a device-to-CPU copy on every step.
        loss = loss_op(pred, target)
        batch_losses.append(loss.item())

        # Backpropagate and update the weights.
        loss.backward()
        optim.step()

    if not batch_losses:
        # Avoid np.mean([]) and its RuntimeWarning.
        return float('nan')

    return np.mean(batch_losses)

def evaluate_batch_accuracy(y_pred, y_true):
    """Return the fraction of samples whose top-scoring class equals the label.

    Args:
        y_pred: tensor of per-class scores, one row per sample.
        y_true: tensor of integer class labels, one per sample.

    Returns:
        Accuracy as a float in ``[0, 1]``; ``0.0`` for an empty batch.

    Raises:
        ValueError: if the number of labels differs from the number of rows
            in ``y_pred``.
    """
    scores = y_pred.detach().cpu().numpy()
    # Flatten so labels shaped (batch, 1) compare element-wise, not by broadcast.
    labels = y_true.detach().cpu().numpy().reshape(-1)

    if len(labels) != len(scores):
        raise ValueError(
            f'y_pred has {len(scores)} rows but y_true has {len(labels)} labels'
        )

    if len(scores) == 0:
        # Nothing to score; avoid dividing by zero.
        return 0.0

    # argmax returns the first maximum, matching the previous tie-breaking.
    predicted = scores.argmax(axis=1)
    hits = int(np.count_nonzero(predicted == labels))
    return hits / len(scores)

def valid_epoch(model, loader, device, it_limit=9999):
    """Evaluate ``model`` on ``loader`` and return the mean batch accuracy.

    Args:
        model: module to evaluate; switched to eval mode and moved to ``device``.
        loader: iterable of ``(images, labels)`` batches that supports ``len()``.
        device: device the model and batches are moved to.
        it_limit: when truthy, the loop stops at the first batch whose index
            exceeds it (batches ``0..it_limit`` are processed); ``None`` or
            ``0`` disables the limit.

    Returns:
        Mean of the per-batch accuracies, or ``nan`` if no batch was processed.
    """
    model.eval()
    model.to(device)

    batch_accuracies = []
    pbar = tqdm(enumerate(loader),
                total=len(loader),
                position=0,
                leave=False)

    # torch.no_grad() only takes effect as a context manager (or decorator);
    # calling it as a bare statement leaves gradient tracking on.
    with torch.no_grad():
        for idx, (imgs_batch, labels_batch) in pbar:

            if it_limit and idx > it_limit:
                break

            imgs_batch = imgs_batch.to(device)
            labels_batch = labels_batch.to(device)

            pred = model(imgs_batch)
            batch_accuracies.append(
                evaluate_batch_accuracy(pred.cpu(), labels_batch.cpu())
            )

    if not batch_accuracies:
        # Avoid np.mean([]) and its RuntimeWarning.
        return float('nan')

    return np.mean(batch_accuracies)

In [None]:
from torchvision import models

# Training hyper-parameters; it_limit=None means no per-epoch batch cap.
config = {
    'lr': 0.1,
    'epochs': 15,
    'it_limit': None
}

DEFAULT_MODEL_LOCATION = DATA_DIR / 'resnet18_classifier'

# ResNet-18 with its final layer replaced by one sized for the sampled classes.
model = models.resnet18(pretrained=True)
MODEL_CLASSES = len(set(new_mini_df['SIGN']))
# Read the input width from the existing head instead of hard-coding 512.
model.fc = nn.Linear(model.fc.in_features, MODEL_CLASSES)

if os.path.isfile(DEFAULT_MODEL_LOCATION):
    # NOTE(review): the message says "restored", but loading is commented
    # out, so the weights on disk are not actually used.
    print('[+] Model restored from', DEFAULT_MODEL_LOCATION)
    # model.load_state_dict(torch.load(DEFAULT_MODEL_LOCATION))

# Follow `device` rather than calling .cuda(), which fails without a GPU.
loss_op = nn.CrossEntropyLoss().to(device)
optim = torch.optim.Adadelta(model.parameters(), lr=config['lr'])

model.to(device)

SHOULD_I_TRAIN = True

img_size = 64
train_dataset = SignDataset(new_mini_df, 'train', img_size)

train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=80,
        # Pinned host memory is only useful for transfers to a CUDA device.
        pin_memory=(torch.device(device).type == 'cuda'),
        shuffle=True)

if SHOULD_I_TRAIN:
    pbar = tqdm(range(config['epochs']),
                total=config['epochs'],
                position=0,
                leave=True,
                desc='per epoch valid accuracy 0.0')

    for epoch in pbar:

        train_res = train_epoch(model, train_loader, loss_op, optim, device, config['it_limit'])

        # Accuracy is measured on the training loader itself: this cell
        # checks that the model can fit the small sampled subset.
        valid_res = valid_epoch(model, train_loader, device, config['it_limit'])

        # Overwrite the checkpoint after every epoch.
        torch.save(model.state_dict(), DEFAULT_MODEL_LOCATION)

        pbar.set_description("per epoch valid accuracy %f" % valid_res)


In [None]:
# Plot every sample of the training subset with its true and predicted sign;
# the subplot border is green on a match and red on a mismatch.
gt_ = gt[gt["SET"]=='train']
SIGN_SET = set(gt_['SIGN'])

# 70 x 6 grid as before; extra rows are added only if the dataset has more
# items than the grid can hold, so add_subplot does not fail.
ncols = 6
nrows = max(70, -(-len(train_dataset) // ncols))
fig = plt.figure(figsize = (16,200))

# Use the notebook-wide device instead of a hard-coded 'cuda', and run the
# model in eval mode so its layers behave as at inference time.
model.to(device)
model.eval()

for idx, (img, encoded_label) in enumerate(train_dataset):

    # No gradients are needed for prediction.
    with torch.no_grad():
        pred = model(img[None, ...].to(device)).cpu()

    argmax = int(np.argmax(pred.numpy()))
    model_pred_decoded = le.inverse_transform([argmax])[0]
    model_pred_sign = NUMBER_TO_SIGN[model_pred_decoded]

    # Channel-first tensor -> height-width-channel array for plotting.
    img = torch.Tensor.permute(img, [1, 2, 0]).numpy()

    # Encoded label -> original label -> sign name.
    decoded_label = le.inverse_transform([encoded_label])[0]
    sign = NUMBER_TO_SIGN[decoded_label]

    ax = fig.add_subplot(nrows, ncols, idx+1)
    ax.patch.set_linewidth(15)

    MATCH = None
    if argmax == encoded_label:
        MATCH = '!+++++!\n'
        ax.patch.set_edgecolor('green')
    else:
        # Report the 1-based [row, column] of the subplot in the grid.
        print('mismatch for', [idx // ncols + 1, idx % ncols + 1])
        ax.patch.set_edgecolor('red')

    ax.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB), aspect=1)
    ax.set_title((MATCH if MATCH else '') + 'FACT:' + str(sign) + '\nPRED:' + str(model_pred_sign))

plt.tight_layout()