# Домашняя работа 5: И снова распознавание текста

## Грачев Денис Вадимович

In [1]:
import matplotlib.pyplot as plt
import numpy as np

from tqdm.notebook import tqdm
import tqdm
import json
import cv2
import os

import torch
from torch.nn            import Module, Sequential, Conv2d, AvgPool2d, GRU, Linear, ReLU, Dropout, LayerNorm
from torch.utils.data    import Dataset, DataLoader
from torch.nn.functional import ctc_loss, softmax
from torchvision         import models
from torchvision import transforms as T
from torch.optim import lr_scheduler

from string import digits, ascii_uppercase

from glob import glob
import pandas as pd
import wandb
import gc



  warn(


# Reading the data

In [2]:
# Location of the seminar dataset (a config.json file plus an images/ folder).
PATH_TO_DATA = "./seminar/data/seminar_crnn_data/"  # Change to your path with unzipped data
config_path = os.path.join(PATH_TO_DATA, "config.json")
images_path = os.path.join(PATH_TO_DATA, "images")

# Fail fast if the expected layout is missing.
assert os.path.isfile(config_path)
assert os.path.isdir(images_path)

# config.json presumably holds a list of {"file": <name relative to images/>, "text": <label>} dicts
# (inferred from the keys read below — TODO confirm).
with open(config_path, "rt") as fp:
    config = json.load(fp)

# Turn relative image names into paths under images_path, keeping the text labels.
config_full_paths = []
for item in config:
    config_full_paths.append({"file": os.path.join(images_path, item["file"]),
                              "text": item["text"]})
seminar_config = config_full_paths
abc = "0123456789ABEKMHOPCTYX"  # this is our alphabet for predictions.

In [3]:
def compute_mask(text):
    """Return the letter/digit pattern of ``text``, e.g. 'E506EC152' -> 'LDDDLLDDD'.

    Args:
        - text: String to classify character by character.

    Returns:
        String of the same length where every ASCII uppercase letter becomes 'L' and every
        digit becomes 'D', or None as soon as any other character is encountered.
    """
    classes = []
    for symbol in text:
        label = "D" if symbol in digits else ("L" if symbol in ascii_uppercase else None)
        if label is None:
            # Anything that is neither a digit nor an uppercase letter invalidates the whole text.
            return None
        classes.append(label)
    return "".join(classes)

assert compute_mask("E506EC152") == "LDDDLLDDD"
assert compute_mask("E123KX99") == "LDDDLLDD"
assert compute_mask("P@@@KA@@") is None

def check_in_alphabet(text, alphabet=abc):
    """Tell whether every character of ``text`` belongs to ``alphabet``.

    Args:
        - text: String of text.
        - alphabet: String of allowed characters.

    Returns:
        True when no character of text falls outside alphabet (also for empty text), otherwise False.
    """
    return all(symbol in alphabet for symbol in text)

assert check_in_alphabet("E506EC152") is True
assert check_in_alphabet("A123GG999") is False

def filter_data(config):
    """Keep only config items whose text looks like a valid plate number.

    A text is kept when all its characters are in the alphabet and its letter/digit
    pattern is either 'LDDDLLDD' or 'LDDDLLDDD'.

    Args:
        - config: List of dicts, each dict having keys "file" and "text".

    Returns:
        New list of {"file", "text"} dicts (a subset of config, same order).
    """
    allowed_masks = ("LDDDLLDD", "LDDDLLDDD")
    kept = []
    for entry in tqdm.tqdm(config):
        plate = entry["text"]
        mask = compute_mask(plate)
        if not check_in_alphabet(plate):
            continue
        if mask not in allowed_masks:
            continue
        kept.append({"file": entry["file"], "text": entry["text"]})
    return kept


seminar_config = filter_data(seminar_config)
print("Total items in data after filtering:", len(seminar_config))

100%|██████████| 41141/41141 [00:00<00:00, 384736.67it/s]

Total items in data after filtering: 31345





In [4]:
class RecognitionDataset(Dataset):
    """Image-to-text dataset for training with CTC loss.

    Items are dicts with keys "image", "seq", "seq_len" and "text" (see ``__getitem__``).
    """

    def __init__(self, config, alphabet=abc, transforms=None):
        """Constructor for class.

        Args:
            - config: List of items, each of which is a dict with keys "file" & "text".
            - alphabet: String of chars required for predicting.
            - transforms: Callable applied to each item; should accept and return a dict with
              keys "image", "seq", "seq_len" & "text".
        """
        super(RecognitionDataset, self).__init__()
        self.config = config
        self.alphabet = alphabet
        self.image_names, self.texts = self._parse_root_()
        self.transforms = transforms

    def _parse_root_(self):
        """Split the config into parallel lists of image paths and texts."""
        image_names = [entry["file"] for entry in self.config]
        texts = [entry["text"] for entry in self.config]
        return image_names, texts

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, item):
        """Return dict with keys "image", "seq", "seq_len" & "text".

        Image is a float32 numpy array scaled to [0, 1], as loaded by cv2.imread (BGR channel order).
        Seq is a list of integers, seq_len its length, text the original string.

        Raises:
            FileNotFoundError: if the image file cannot be read.
        """
        path = self.image_names[item]
        image = cv2.imread(path)
        if image is None:
            # cv2.imread returns None instead of raising for missing/unreadable files; without this
            # check the failure surfaced as an opaque AttributeError on ``.astype``.
            raise FileNotFoundError(f"Could not read image: {path}")
        image = image.astype(np.float32) / 255.
        text = self.texts[item]
        seq = self.text_to_seq(text)
        seq_len = len(seq)
        output = dict(image=image, seq=seq, seq_len=seq_len, text=text)
        if self.transforms is not None:
            output = self.transforms(output)
        return output

    def text_to_seq(self, text):
        """Encode text to a list of integers: index of each char in the alphabet + 1.

        Index 0 is reserved (pred_to_string treats label 0 as the blank). A character missing
        from the alphabet also maps to 0, because ``str.find`` returns -1 for it.

        Args:
            - text: String of text.

        Returns:
            List of integers, one per character.
        """
        return [self.alphabet.find(c) + 1 for c in text]


In [5]:
class Resize(object):
    """Resize item["image"] to a fixed (width, height) using OpenCV."""

    def __init__(self, size=(320, 64)):
        # (width, height), the order cv2.resize expects for dsize.
        self.size = size

    def __call__(self, item):
        """Apply resizing.

        Args:
            - item: Dict with keys "image", "seq", "seq_len", "text".

        Returns:
            The same dict, with item["image"] replaced by the resized image.
        """
        image = item["image"]
        target_width = self.size[0]
        # Area averaging when the image gets narrower, bilinear interpolation otherwise.
        if target_width < image.shape[1]:
            interpolation = cv2.INTER_AREA
        else:
            interpolation = cv2.INTER_LINEAR
        item["image"] = cv2.resize(image, self.size, interpolation=interpolation)
        return item

In [6]:
class RandomRotation:
    """Randomly rotate item['image'] about its centre.

    With probability ``prob`` the image is rotated by an angle drawn uniformly from
    [-max_angle, max_angle) degrees; the output keeps the original width and height.
    """

    def __init__(self, max_angle, prob):
        self.max_angle = max_angle
        self.prob = prob

    def __call__(self, item):
        if np.random.random() < self.prob:
            angle = (np.random.random() * 2 - 1) * self.max_angle

            (h, w) = item['image'].shape[:2]
            (cX, cY) = (w // 2, h // 2)
            # Rotate by the sampled angle (previously a hard-coded 45 degrees was used and
            # ``angle`` was ignored).
            M = cv2.getRotationMatrix2D((cX, cY), angle, 1.0)
            item['image'] = cv2.warpAffine(item['image'], M, (w, h))

        return item


In [7]:
def collate_fn(batch):
    """Collate dataset items into one batch for torch.utils.data.DataLoader.

    Args:
        - batch: List of dataset __getitem__ return values (dicts with "image" as an HWC numpy
          array, "seq" as a list of ints, "seq_len" and "text").

    Returns:
        Dict with the same keys:
        - "image": float tensor (B, C, H, W);
        - "seq": all sequences concatenated into one int32 tensor;
        - "seq_len": int32 tensor of per-item lengths;
        - "text": list of strings.
    """
    images = torch.stack(
        [torch.from_numpy(sample["image"]).permute(2, 0, 1).float() for sample in batch]
    )
    flat_tokens = [token for sample in batch for token in sample["seq"]]
    lengths = [sample["seq_len"] for sample in batch]
    return {
        "image": images,
        "seq": torch.Tensor(flat_tokens).int(),
        "seq_len": torch.Tensor(lengths).int(),
        "text": [sample["text"] for sample in batch],
    }

In [8]:
def pred_to_string(pred, abc):
    """Greedy CTC decoding of one sequence.

    Takes the arg-max class per time step, collapses consecutive repeats and drops the blank
    (class 0). Class c > 0 maps to abc[c - 1].

    Args:
        pred: 2-D array (time_steps, num_classes) of scores.
        abc: Alphabet string.

    Returns:
        Decoded string.
    """
    decoded = []
    previous = None
    for step in pred:
        label = np.argmax(step) - 1  # -1 is the blank
        if label != -1 and label != previous:
            decoded.append(abc[label])
        previous = label
    return ''.join(decoded)

def decode(pred, abc):
    """Decode a (time, batch, classes) tensor into a list of strings, one per batch item."""
    batch_major = pred.permute(1, 0, 2).cpu().data.numpy()
    return [pred_to_string(sample, abc) for sample in batch_major]

In [9]:
# Run on the first CUDA device when one is available, otherwise on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [10]:
def read_config(path_folder, name_as_text=False):
    """Collect ``{'file': path, 'text': label}`` entries for every ``*.png`` in ``path_folder``.

    Args:
        path_folder: Directory to scan (non-recursive).
        name_as_text: If True the label is the file name without its extension; otherwise it
            is the part of the file name after the last '_'.

    Returns:
        List of dicts, sorted by file path so the order is deterministic.

    Raises:
        NotADirectoryError: if ``path_folder`` is not an existing directory.
    """
    if not os.path.isdir(path_folder):
        raise NotADirectoryError(path_folder)

    config = []
    for path in sorted(glob(os.path.join(path_folder, '*.png'))):
        # Parse the file name only: splitting the whole path broke when a directory name
        # contained '_' (or when the path used '\\' separators).
        stem = os.path.splitext(os.path.basename(path))[0]
        text = stem if name_as_text else stem.split('_')[-1]
        config.append({'file': path, 'text': text})

    return config

In [11]:
# Build lists of {'file': path, 'text': label} from the PNG files in each folder (see read_config):
# train labels come from the file name after the last '_', test labels from the file name itself.
simple_config = read_config('kaggle/train/train/simple/')
complex_config = read_config('kaggle/train/train/complex/')
test_config = read_config('kaggle/test/result/', name_as_text=True)

In [12]:
def save_checkpoint(model, filename):
    """Write ``model.state_dict()`` to ``filename`` with torch.save."""
    torch.save(model.state_dict(), filename)


def load_checkpoint(model, filename):
    """Load a state dict saved by save_checkpoint into ``model``.

    Tensors are first mapped to the CPU; load_state_dict then copies them into the model's
    existing parameters (wherever those live).
    """
    state_dict = torch.load(filename, map_location="cpu")
    model.load_state_dict(state_dict)

## Описание задачи

Разработать каждую из компонент и объединить их в модель `Transformer`. 



## Модель трансформер

### Механизм внимания на несколько голов (Multi-Head Attention)
<img src="figures/multi-head-attention.jpg" width="500">

In [13]:
class MultiHeadAttention(Module):
    """Multi-head scaled dot-product attention.

    Inputs are batch-first: Q is (batch, seq_q, d_model); K and V are (batch, seq_k, d_model).
    The output has the shape of Q.
    """

    def __init__(self, d_model, num_heads):
        super(MultiHeadAttention, self).__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.d_model   = d_model
        self.num_heads = num_heads
        self.d_k       = d_model // num_heads  # per-head dimension

        self.W_q = Linear(in_features=d_model, out_features=self.d_model, bias=False)
        self.W_k = Linear(in_features=d_model, out_features=self.d_model, bias=False)
        self.W_v = Linear(in_features=d_model, out_features=self.d_model, bias=False)
        self.W_o = Linear(in_features=d_model, out_features=self.d_model, bias=False)

        # 1/sqrt(d_k): keeps dot products from growing with d_k and saturating the softmax.
        self.scale = self.d_k ** -0.5

    def scaled_dot_product_attention(self, Q, K, V, mask=None):
        """Compute softmax(Q K^T / sqrt(d_k)) V over the key (sequence) axis.

        Args:
            Q: (batch, heads, seq_q, d_k); K, V: (batch, heads, seq_k, d_k).
            mask: Optional tensor broadcastable to (batch, heads, seq_q, seq_k);
                positions where mask == 0 are hidden.

        Returns:
            Tensor (batch, heads, seq_q, d_k).
        """
        # Compare every query position with every key position -> (..., seq_q, seq_k).
        # (Previously Q^T @ K mixed feature dimensions instead of sequence positions.)
        attn_scores = (Q @ K.transpose(-2, -1)) * self.scale
        if mask is not None:
            attn_scores = attn_scores.masked_fill(mask == 0, -1e9)
        attn_probs = softmax(attn_scores, dim=-1)
        # Each query position takes a convex combination of the value vectors.
        return attn_probs @ V

    def split_heads(self, x):
        """(batch, seq, d_model) -> (batch, heads, seq, d_k)."""
        batch_size, seq_length, d_model = x.size()
        return x.view(batch_size, seq_length, self.num_heads, self.d_k).transpose(1, 2)

    def combine_heads(self, x):
        """(batch, heads, seq, d_k) -> (batch, seq, d_model)."""
        batch_size, _, seq_length, d_k = x.size()
        return x.transpose(1, 2).contiguous().view(batch_size, seq_length, self.d_model)

    def forward(self, Q, K, V, mask=None):
        Q = self.split_heads(self.W_q(Q))
        K = self.split_heads(self.W_k(K))
        V = self.split_heads(self.W_v(V))
        attn_output = self.scaled_dot_product_attention(Q, K, V, mask)
        output      = self.W_o(self.combine_heads(attn_output))
        return output

In [14]:
# Smoke test for MultiHeadAttention: self-attention must keep the (batch, seq, d_model) shape.
# NOTE(review): a shape check alone cannot tell whether attention is computed over the sequence axis.
batch_size = 3
seq_length = 11
d_model = 10
num_heads = 2

mha = MultiHeadAttention(d_model=d_model, num_heads=num_heads)
x = torch.rand(batch_size, seq_length, d_model)

mha(x, x, x).shape == x.shape

True

### Position-wise Feed-Forward Networks
Состоит из двух линейных слоёв, которые применяются к последнему измерению, то есть для каждой позиции в последовательности используются одни и те же линейные слои (так называемые `position-wise`).

In [15]:
class PositionWiseFeedForward(Module):
    """Two linear layers with a ReLU between them, applied along the last dimension.

    The same weights are used at every sequence position (hence "position-wise").
    """

    def __init__(self, d_model, d_ff):
        super(PositionWiseFeedForward, self).__init__()
        self.fc1  = Linear(d_model, d_ff)   # expand d_model -> d_ff
        self.fc2  = Linear(d_ff, d_model)   # project back d_ff -> d_model
        self.relu = ReLU()

    def forward(self, x):
        return self.fc2(self.relu(self.fc1(x)))

### Позиционное кодирование (Positional Encoding)

Так как в архитектуре трансформера обработка последовательности заменяется на обработку множества, мы теряем информацию о порядке элементов последовательности. Чтобы передать информацию о позиции элемента в исходной последовательности, мы используем позиционное кодирование.

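$$
p(i,s)=
\begin{cases}
\sin \Big(i \cdot 10000^{-\frac{2k}{d_{model}}}\Big), & s = 2k\\
\cos \Big(i \cdot 10000^{-\frac{2k}{d_{model}}}\Big), & s = 2k + 1
\end{cases}
$$

где $i$ — позиция элемента в последовательности, $s$ — номер компоненты вектора, $d_{model}$ — размерность модели.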

In [16]:
# Toy sizes for experimenting with the positional-encoding table.
batch_size = 3
max_len = 11
d_model = 5
seq_length = 7

In [17]:
arange = torch.arange(max_len * d_model).view(d_model, max_len)

i = arange % max_len   # position index (varies along columns)
k = arange // max_len  # embedding dimension index (varies along rows)

# Dimensions 2j and 2j+1 share the frequency 10000^(-2j/d_model); even dims use sin, odd dims cos.
# (Previously the angle was i * 10000 * (-2k) / d_model — a product instead of a power — and sin/cos were swapped.)
angle = i / 10_000 ** (2 * (k // 2) / d_model)
pe = torch.zeros(max_len, d_model)
pe += (torch.sin(angle) * ((k + 1) % 2)).T
pe += (torch.cos(angle) * (k % 2)).T

In [18]:
# Random input of shape (batch_size, seq_length, d_model) for trying out positional encoding.
x = torch.rand(batch_size, seq_length, d_model)

In [53]:
class PositionalEncoding(torch.nn.Module):
    """Add fixed sinusoidal positional encodings to embeddings, then apply dropout.

        pe[pos, 2k]   = sin(pos / 10000^(2k / d_model))
        pe[pos, 2k+1] = cos(pos / 10000^(2k / d_model))

    Args:
        d_model: Embedding size.
        dropout: Dropout probability applied after the addition.
        max_len: Number of precomputed positions.
        batch_first: If True (default) inputs are (batch, seq, d_model);
            if False they are (seq, batch, d_model).

    NOTE: ``pe`` is a persistent buffer, so loading a checkpoint saved before the formula was
    fixed restores the old (incorrect) table.
    """

    def __init__(self, d_model, dropout=0.1, max_len=5000, batch_first=True):
        super(PositionalEncoding, self).__init__()
        self.dropout = Dropout(p=dropout)
        self.scale   = d_model  # unused; kept for backward compatibility of the attribute
        self.batch_first = batch_first

        positions = torch.arange(max_len, dtype=torch.float32).unsqueeze(1)  # (max_len, 1)
        dims = torch.arange(d_model)
        # Both members of a (sin, cos) pair share the frequency 10000^(-2k/d_model), k = dim // 2.
        inv_freq = 10_000.0 ** (-2.0 * (dims // 2).to(torch.float32) / d_model)  # (d_model,)
        angles = positions * inv_freq                                              # (max_len, d_model)
        pe = torch.where(dims % 2 == 0, torch.sin(angles), torch.cos(angles))

        # Registered as a buffer so it follows .to(device) and is saved with the state dict.
        self.register_buffer('pe', pe)

    def forward(self, x):
        if self.batch_first:
            # (batch, seq, d_model): pe[:seq] broadcasts over the batch dimension.
            return self.dropout(x + self.pe[: x.shape[1]])
        # (seq, batch, d_model): add a singleton batch axis to pe[:seq].
        return self.dropout(x + self.pe[: x.shape[0]].unsqueeze(1))

### Слой кодировщика (Encoder Layer)

<img src="figures/encoder-layer.jpg" height="200">

In [20]:
class EncoderLayer(Module):
    """Post-norm transformer encoder layer: self-attention and feed-forward sublayers.

    Each sublayer output is passed through dropout, added to its input (residual) and then
    layer-normalized.
    """

    def __init__(self, d_model, num_heads, d_ff, dropout):
        super(EncoderLayer, self).__init__()
        self.self_attn    = MultiHeadAttention(d_model=d_model, num_heads=num_heads)
        self.feed_forward = PositionWiseFeedForward(d_model=d_model, d_ff=d_ff)
        # LayerNorm needs the normalized size; LayerNorm() without arguments raises TypeError.
        self.norm1 = LayerNorm(d_model)
        self.norm2 = LayerNorm(d_model)
        self.dropout = Dropout(dropout)

    def forward(self, x, mask):
        """Args: x — input sequence; mask — attention mask passed to self-attention (0 = hidden)."""
        attn = self.self_attn(x, x, x, mask)
        x = self.norm1(x + self.dropout(attn))

        ff = self.feed_forward(x)
        x = self.norm2(x + self.dropout(ff))

        return x

### Слой декодировщика (Decoder Layer)

<img src="figures/decoder-layer.jpg" height="600">

In [21]:
class DecoderLayer(Module):
    """Post-norm transformer decoder layer.

    It applies three sublayers in order: masked self-attention, cross-attention over the
    encoder output, and feed-forward. Each sublayer is dropout -> residual add -> LayerNorm.
    """

    def __init__(self, d_model, num_heads, d_ff, dropout):
        super(DecoderLayer, self).__init__()
        self.self_attn    = MultiHeadAttention(d_model=d_model, num_heads=num_heads)
        self.cross_attn   = MultiHeadAttention(d_model=d_model, num_heads=num_heads)
        self.feed_forward = PositionWiseFeedForward(d_model=d_model, d_ff=d_ff)
        # LayerNorm needs the normalized size; LayerNorm() without arguments raises TypeError.
        self.norm1 = LayerNorm(d_model)
        self.norm2 = LayerNorm(d_model)
        self.norm3 = LayerNorm(d_model)
        self.dropout = Dropout(dropout)

    def forward(self, x, enc_output, src_mask, tgt_mask):
        """Args: x — decoder input; enc_output — encoder memory; src_mask / tgt_mask — attention masks (0 = hidden)."""
        attn = self.self_attn(x, x, x, tgt_mask)
        x = self.norm1(x + self.dropout(attn))

        # Queries come from the decoder, keys/values from the encoder output. Previously the
        # arguments were swapped (Q=enc_output, V=x) and this result was discarded.
        attn = self.cross_attn(x, enc_output, enc_output, src_mask)
        x = self.norm2(x + self.dropout(attn))

        ff = self.feed_forward(x)
        x = self.norm3(x + self.dropout(ff))

        return x

### Модель трансформер

<img src="figures/transformer-model.jpg" height="400">

In [22]:
class Transformer(Module):
    """Encoder-decoder transformer built from EncoderLayer / DecoderLayer.

    ``forward`` takes already-embedded source and target sequences. These are assumed to be
    batch-first (batch, seq, d_model), matching MultiHeadAttention.split_heads — TODO confirm
    with callers. ``self.fc`` maps d_model -> d_model; there is no vocabulary projection here.
    """

    def __init__(self, d_model, num_heads, num_layers, d_ff, max_seq_length, dropout):
        super(Transformer, self).__init__()
        self.positional_encoding = PositionalEncoding(d_model=d_model, dropout=dropout, max_len=max_seq_length)
        # ModuleList (not a plain list) registers the layers, so their parameters show up in
        # .parameters(), follow .to(device) and are saved in state_dict. range() fixes iterating over an int.
        self.encoder_layers = torch.nn.ModuleList(
            [EncoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)]
        )
        self.decoder_layers = torch.nn.ModuleList(
            [DecoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)]
        )
        self.fc = Linear(d_model, d_model)
        self.dropout = Dropout(dropout)

    @staticmethod
    def _keep_positions(x):
        """Return a boolean (batch, seq) mask that is False at padding positions.

        For token ids (batch, seq), id 0 is treated as padding. For embeddings
        (batch, seq, d_model), an all-zero vector is treated as padding.
        """
        keep = x != 0
        if keep.dim() == 3:
            keep = keep.any(dim=-1)
        return keep

    def generate_mask(self, src, tgt):
        """Build attention masks where True means 'may attend'.

        Returns:
            src_mask of shape (B, 1, 1, S_src) and tgt_mask of shape (B, 1, S_tgt, S_tgt)
            (padding combined with a causal no-peek mask).
        """
        src_mask = self._keep_positions(src).unsqueeze(1).unsqueeze(2)
        tgt_mask = self._keep_positions(tgt).unsqueeze(1).unsqueeze(3)
        seq_length = tgt.size(1)
        # Created on tgt's device so combining it with tgt_mask cannot hit a device mismatch.
        nopeak_mask = (1 - torch.triu(torch.ones(1, seq_length, seq_length, device=tgt.device), diagonal=1)).bool()
        tgt_mask = tgt_mask & nopeak_mask
        return src_mask, tgt_mask

    def forward(self, src_embeddings, tgt_embeddings):
        src_mask, tgt_mask = self.generate_mask(src_embeddings, tgt_embeddings)
        encoder = self.positional_encoding(src_embeddings)
        decoder = self.positional_encoding(tgt_embeddings)

        for layer in self.encoder_layers:
            encoder = layer(encoder, src_mask)

        for layer in self.decoder_layers:
            decoder = layer(decoder, encoder, src_mask, tgt_mask)

        output = self.fc(decoder)
        # NOTE(review): dropout on the final projection is unusual; kept to preserve behavior.
        output = self.dropout(output)

        return output

## Описание задачи

Ранее мы решили задачу распознавания текста на изображениях — распознавания регистрационных знаков. Мы воспользовались достаточно сильной сверточной моделью для извлечения признаков `(ResNet18)` и рекуррентной моделью для обработки последовательности `(RNN, LSTM, GRU)`. Модель была достаточно простой.

Попробуем воспользоваться текущими представлениями о нейронных сетях и разработаем модель, которая будет состоять также из сверточной модели, однако заменим рекуррентную модель и будем использовать архитектуру типа `Transformer`. Мы воспользуемся некоторой вариацией современного подхода к распознаванию текста на изображениях [TrOCR](https://huggingface.co/docs/transformers/model_doc/trocr)

Отметим, что мы используем трансформер в режиме авторегрессии — с этим вы сталкивались на семинаре по рекуррентным нейронным сетям, когда разрабатывали модель, генерирующую символы или слова. Поэтому результирующую строку мы будем `генерировать` итеративно. В итоге наша модель имеет следующий вид.

![picture](figures/ModelArchitecture.png) 

Ваша задача состоит в том, чтобы натренировать модель с использованием `nn.Transformer` и с собственной реализацией и получить близкие результаты. Также требуется сравнить её с предыдущим подходом, где использовались рекуррентные нейронные сети. За первую часть даётся 10 баллов, за вторую — оставшиеся 10 баллов. За красивое оформление и графики можно получить бонусные 5 баллов.

In [23]:
def read_config(path_folder, name_as_text=False):
    """Collect ``{'file': path, 'text': label}`` entries for every ``*.png`` in ``path_folder``.

    Args:
        path_folder: Directory to scan (non-recursive).
        name_as_text: If True the label is the file name without its extension; otherwise it
            is the part of the file name after the last '_'.

    Returns:
        List of dicts, sorted by file path so the order is deterministic.

    Raises:
        NotADirectoryError: if ``path_folder`` is not an existing directory.
    """
    if not os.path.isdir(path_folder):
        raise NotADirectoryError(path_folder)

    config = []
    for path in sorted(glob(os.path.join(path_folder, '*.png'))):
        # Parse the file name only: splitting the whole path broke when a directory name
        # contained '_' (or when the path used '\\' separators).
        stem = os.path.splitext(os.path.basename(path))[0]
        text = stem if name_as_text else stem.split('_')[-1]
        config.append({'file': path, 'text': text})

    return config

In [24]:
# Build lists of {'file': path, 'text': label} from the PNG files in each folder (see read_config):
# train labels come from the file name after the last '_', test labels from the file name itself.
simple_config = read_config('kaggle/train/train/simple/')
complex_config = read_config('kaggle/train/train/complex/')
test_config = read_config('kaggle/test/result/', name_as_text=True)

In [25]:
from torch.nn    import Conv2d, MaxPool2d, BatchNorm2d, LeakyReLU
from torchvision import transforms
from collections import Counter
from time        import time
from torch       import nn
from PIL         import Image
from tqdm        import tqdm

from rapidfuzz.distance.Levenshtein import distance

import matplotlib.pyplot as plt
import numpy  as np
import pandas as pd

import random
import string
import torch
import math
import os

## Const

In [62]:
DIR               = './' # work directory
PATH_TRAIN_DIR    =  'kaggle/train'
PATH_TEST_DIR     =  'kaggle/test/result/'
PREDICT_PATH      =  'kaggle/test/result'
CHECKPOINT_PATH   = DIR
WEIGHTS_PATH      = ''
# os.path.join avoids the doubled separator that DIR + '/...' produced ('.//test_result.tsv').
PATH_TEST_RESULTS = os.path.join(DIR, 'test_result.tsv')
TRAIN_LOG         = os.path.join(DIR, 'train_log.tsv')

## Config

In [63]:
### MODEL ###
MODEL = 'model'
HIDDEN     = 512
ENC_LAYERS = 2
DEC_LAYERS = 2
N_HEADS    = 4
LENGTH     = 15  # padded target length in tokens, including SOS/EOS (DataCollate pads to this)

# Special tokens: PAD is index 0, SOS index 1, EOS the last index.
ALPHABET = ['PAD', 'SOS', '0','1','2','3','4','5','6','7','8','9','A','B','E','K','M','H','O','P','C','T','Y','X', 'EOS']

### TRAINING ###
BATCH_SIZE = 16
DROPOUT    = 0.2
N_EPOCHS   = 16
CHECKPOINT_FREQ = 10 # save checkpoint every 10 epochs (NOTE(review): not used by fit)
# Fall back to the CPU instead of failing when no CUDA device is present.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
SCHUDULER_ON = True # "ReduceLROnPlateau"
PATIENCE = 5 # for ReduceLROnPlateau
OPTIMIZER_NAME = 'Adam' # or "SGD"
#LR = 2e-6
LR = 2e-5

### INPUT IMAGE PARAMETERS ###
WIDTH, HEIGHT, CHANNELS = 320, 64, 3

RANDOM_SEED = 42

random.seed           (RANDOM_SEED)
# numpy's global RNG (np.random.*) is used by the augmentation code, so seed it too.
np.random.seed        (RANDOM_SEED)
torch.manual_seed     (RANDOM_SEED)
torch.cuda.manual_seed(RANDOM_SEED)

## Utils

In [64]:
  
def log_metrics(metrics):
    """Print one row of the training table; the header is printed before epoch 0.

    Args:
        metrics: Dict with keys 'epoch', 'train_loss', 'eval_loss', 'levenstein', 'time', 'lr'.
    """
    if metrics['epoch'] == 0:
        print('Epoch   Train_loss   Valid_loss   LEV  Time    LR')
        print('-----   -----------  ----------   ---  ----    ---')
    row = (
        f"{metrics['epoch']:02d}    "
        f"{metrics['train_loss']:.3f}       "
        f"{metrics['eval_loss']:.3f}       "
        f"{metrics['levenstein']:.3f}   "
        f"{metrics['time']:.3f}   "
        f"{metrics['lr']:.7f}"
    )
    print(row)

## Загрузчик данных

In [65]:
def indicies_to_text(indexes, idx2char):
    """Map token ids to their strings, concatenate them, and strip the special tokens.

    'EOS', 'PAD' and 'SOS' are removed (in that order) from the joined text.
    """
    text = "".join(idx2char[index] for index in indexes)
    for special in ('EOS', 'PAD', 'SOS'):
        text = text.replace(special, '')
    return text

# Lookup tables between ALPHABET tokens and their integer ids (id = position in ALPHABET).
char2idx = {token: token_id for token_id, token in enumerate(ALPHABET)}
idx2char = dict(enumerate(ALPHABET))

class RecognitionDataset(torch.utils.data.Dataset):
    """Dataset over every file in ``root_dir`` for the autoregressive (seq2seq) model.

    The label is taken from the file name: the '_'-separated parts from the third onward,
    concatenated. Each item is (transformed image, LongTensor [SOS, ...chars..., EOS]).
    Characters missing from char2idx are skipped.
    """

    def __init__(self, root_dir, transform):
        self.root_dir   = root_dir
        self.transform  = transform
        self.total_imgs = os.listdir(root_dir)

    def __len__(self):
        return len(self.total_imgs)

    def __getitem__(self, idx):
        img_loc = os.path.join(self.root_dir, self.total_imgs[idx])
        stem = os.path.splitext(os.path.basename(img_loc))[0]
        text = ''.join(stem.split('_')[2:])

        image_tensor = self.transform(Image.open(img_loc).convert("RGB"))

        tokens = [char2idx['SOS']]
        tokens.extend(char2idx[ch] for ch in text if ch in char2idx)
        tokens.append(char2idx['EOS'])

        return (image_tensor, torch.LongTensor(tokens))

class DataCollate():
    """Collate (image, token ids) pairs into a batch.

    Images are stacked to [B, C, H, W]. Token ids are right-padded with PAD to LENGTH and
    returned transposed as [LENGTH, B].
    """

    def __call__(self, batch):
        images, targets = [], []
        for sample in batch:
            image_tensor, index_tensor = sample[0], sample[1]
            padded = torch.full((LENGTH,), fill_value=char2idx['PAD']).long()
            padded[:index_tensor.shape[0]] = index_tensor
            images.append(image_tensor)
            targets.append(padded)
        return torch.stack(images), torch.stack(targets).T
                                                                                                  
                                                                                                  
# The same preprocessing is used for train, test and inference: resize to 64x320 (H x W),
# convert to a tensor and normalize with ImageNet mean/std. Each name gets its own Compose instance.
train_trans, test_trans, infer_trans = (
    transforms.Compose([
        transforms.Resize((64, 320)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    for _ in range(3)
)

In [66]:
# Datasets read labels from file names; the test folder's names carry no '_'-separated label.
train_dataset = RecognitionDataset(PATH_TRAIN_DIR, train_trans)
test_dataset  = RecognitionDataset(PATH_TEST_DIR , test_trans)

# NOTE(review): drop_last=True on test_loader silently discards the final partial batch
# (len(test_dataset) % BATCH_SIZE samples) from evaluation and prediction.
train_loader = torch.utils.data.DataLoader(train_dataset, shuffle=True , batch_size=BATCH_SIZE, pin_memory=True, drop_last=True , collate_fn=DataCollate())
test_loader  = torch.utils.data.DataLoader(test_dataset , shuffle=False, batch_size=BATCH_SIZE, pin_memory=True, drop_last=True, collate_fn=DataCollate())

In [67]:
# Number of samples in the train and test datasets.
len(train_dataset), len(test_dataset)

(24546, 10518)

## Наша модель

In [48]:
from torchvision.models import ResNet34_Weights
from torchvision        import models

class TransformerModel(nn.Module):
    """ResNet-34 feature extractor followed by ``nn.Transformer`` for autoregressive text recognition.

    Tensors inside the transformer are sequence-first ([L, B, E]), the ``nn.Transformer`` default.
    """

    def __init__(self, outtoken, hidden, enc_layers=1, dec_layers=1, nhead=1, dropout=0.1, pretrained=True):
        '''
        params
        ---
        outtoken : int — vocabulary size (number of output tokens)
        hidden : int — transformer model dimension
        enc_layers, dec_layers : int — number of encoder / decoder layers
        nhead : int — number of attention heads
        dropout : float
        pretrained : bool — load ImageNet weights for the backbone
        '''
        super(TransformerModel, self).__init__()

        self.enc_layers = enc_layers
        self.dec_layers = dec_layers

        self.backbone_name = 'resnet34'
        # ``pretrained`` used to be ignored (ImageNet weights were always loaded).
        weights = ResNet34_Weights.DEFAULT if pretrained else None
        self.backbone = models.resnet34(weights=weights)
        # A 1x1 conv replaces the classifier and maps 512 channels to hidden/2. Flattening
        # (channels, height) per column then yields ``hidden`` features when layer4 has height 2
        # (presumably 64-pixel-high inputs, as produced by the Resize((64, 320)) transforms).
        self.backbone.fc = nn.Conv2d(512, int(hidden/2), 1)

        self.pos_encoder = PositionalEncoding(hidden, dropout)
        self.decoder     = nn.Embedding(outtoken, hidden)
        self.pos_decoder = PositionalEncoding(hidden, dropout)
        self.transformer = nn.Transformer(
         d_model=hidden, nhead=nhead, num_encoder_layers=enc_layers,
         num_decoder_layers=dec_layers, dim_feedforward=hidden * 4, dropout=dropout,
         activation='relu'
        )

        self.fc_out = nn.Linear(hidden, outtoken)
        self.src_mask = None
        self.trg_mask = None
        self.memory_mask = None

    def generate_square_subsequent_mask(self, sz, device=None):
        """Return an (sz, sz) float causal mask: 0 on/below the diagonal, -inf above it."""
        device = DEVICE if device is None else device
        mask = torch.triu(torch.ones(sz, sz, device=device), 1)
        mask = mask.masked_fill(mask == 1, float('-inf'))
        return mask

    def make_len_mask(self, inp):
        """Key-padding mask [B, L] that is True where the token id is 0 (PAD); inp is [L, B]."""
        return (inp == 0).transpose(0, 1)

    def _add_positions(self, encoding, x):
        """Apply a PositionalEncoding module to a sequence-first tensor x [L, B, E].

        PositionalEncoding indexes its table by ``x.shape[1]`` (batch-first layout). Without the
        transposes, rows would be selected by batch index instead of sequence position.
        """
        return encoding(x.transpose(0, 1)).transpose(0, 1)

    def _get_features(self, src):
        '''
        params
        ---
        src : Tensor [B, C, H, W]
            B - batch, C - channel, H - height, W - width

        returns
        ---
        x : Tensor [W', B, (hidden/2) * H'] with H' = H/32, W' = W/32
            (for 64x320 inputs: [10, B, hidden])
        '''
        x = self.backbone.conv1(src)

        x = self.backbone.bn1    (x)
        x = self.backbone.relu   (x)
        x = self.backbone.maxpool(x)
        x = self.backbone.layer1 (x)
        x = self.backbone.layer2 (x)
        x = self.backbone.layer3 (x)
        x = self.backbone.layer4 (x) # [B, 512, H', W']

        x = self.backbone.fc(x)      # [B, hidden/2, H', W']
        x = x.permute(0, 3, 1, 2)    # [B, W', hidden/2, H']
        x = x.flatten(2)             # [B, W', (hidden/2) * H']
        x = x.permute(1, 0, 2)       # [W', B, (hidden/2) * H']
        return x

    def predict(self, batch):
        '''
        Greedy autoregressive decoding, one image at a time.

        params
        ---
        batch : Tensor [B, C, H, W]

        returns
        ---
        result : List [B, -1]
            predicted sequences of token indexes, starting with SOS and ending with EOS
            unless LENGTH steps were reached first
        '''
        sos, eos = ALPHABET.index('SOS'), ALPHABET.index('EOS')
        result = []
        # Decoding needs no gradients; this avoids building autograd graphs during inference.
        with torch.no_grad():
            for item in batch:
                x = self._get_features(item.unsqueeze(0))
                memory = self.transformer.encoder(self._add_positions(self.pos_encoder, x))
                out_indexes = [sos]
                for _ in range(LENGTH):
                    trg_tensor = torch.LongTensor(out_indexes).unsqueeze(1).to(item.device)
                    # Use the same causal mask as in training; without it earlier positions
                    # attend to later ones, which the model never saw during training.
                    causal = self.generate_square_subsequent_mask(len(out_indexes), device=item.device)
                    decoded = self.transformer.decoder(
                        self._add_positions(self.pos_decoder, self.decoder(trg_tensor)), memory, tgt_mask=causal
                    )
                    output = self.fc_out(decoded)

                    out_token = output.argmax(2)[-1].item()
                    out_indexes.append(out_token)
                    if out_token == eos:
                        break
                result.append(out_indexes)
        return result

    def forward(self, src, trg):
        '''
        params
        ---
        src : Tensor : [B,C,H,W]
            B - batch, C - channel, H - height, W - width
        trg : Tensor : [L,B]
            L - max length of label, B - batch

        returns
        ---
        logits : Tensor [L, B, outtoken]
        '''
        x = self._get_features(src)
        # Image feature columns are never padding, so no key-padding mask is applied to them.
        # (The previous mask ``x[:, :, 0] == 0`` hid any column whose first feature was exactly zero.)
        src_pad_mask = None
        src = self._add_positions(self.pos_encoder, x)

        if self.trg_mask is None or self.trg_mask.size(0) != len(trg) or self.trg_mask.device != trg.device:
            self.trg_mask = self.generate_square_subsequent_mask(len(trg), device=trg.device)
        trg_pad_mask = self.make_len_mask(trg)
        trg = self.decoder(trg)
        trg = self._add_positions(self.pos_decoder, trg)

        output = self.transformer(
         src, trg, src_mask=self.src_mask, tgt_mask=self.trg_mask,
         memory_mask=self.memory_mask,
         src_key_padding_mask=src_pad_mask, tgt_key_padding_mask=trg_pad_mask,
         memory_key_padding_mask=src_pad_mask
        ) # [L, B, hidden]

        logits = self.fc_out(output) # [L, B, outtoken]
        return logits

## Train

In [68]:
from tqdm import tqdm

def train(model, optimizer, criterion, train_loader):
    """Run one training epoch with teacher forcing.

    params
    ---
    model : nn.Module
    optimizer : torch optimizer
    criterion : loss module
    train_loader : torch.utils.data.DataLoader yielding (images, targets [L, B])

    returns
    ---
    float : mean loss over batches
    """
    model.train()
    running_loss = 0
    for images, targets in train_loader:
        images = images.to(DEVICE)
        targets = targets.to(DEVICE)
        # The decoder sees tokens [0..L-2] and must predict tokens [1..L-1].
        decoder_input, decoder_target = targets[:-1, :], targets[1:, :]
        logits = model(images, decoder_input)

        loss = criterion(logits.view(-1, logits.shape[-1]), torch.reshape(decoder_target, (-1,)))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    return running_loss / len(train_loader)

def evaluate(model, criterion, loader, case=True, punct=True):
    """Compute the teacher-forced loss and greedy predictions on ``loader``.

    ``case`` and ``punct`` are accepted for interface compatibility but unused.

    Returns:
        (mean loss over batches, {'true': [...], 'pred': [...]}) with decoded strings.
    """
    result = {'true': [], 'pred': []}
    epoch_loss = 0

    model.eval()
    with torch.no_grad():
        for src, trg in loader:
            src, trg = src.to(DEVICE), trg.to(DEVICE)
            logits = model(src, trg[:-1, :])
            loss = criterion(logits.reshape(-1, logits.shape[-1]), trg[1:, :].reshape(-1))
            out_indexes = model.predict(src)

            # Iterate over the actual batch; range(BATCH_SIZE) failed on a final partial batch.
            # trg is [L, B], so the transposed slice yields one target sequence per sample.
            for target, predicted in zip(trg[1:, :].T.tolist(), out_indexes):
                result['true'].append(indicies_to_text(target, ALPHABET))
                result['pred'].append(indicies_to_text(predicted, ALPHABET))
            epoch_loss += loss.item()

    return epoch_loss / len(loader), result

def fit(model, optimizer, scheduler, criterion, train_loader, val_loader):
    """Train for N_EPOCHS epochs, evaluating on ``val_loader`` after each one.

    Returns:
        List of per-epoch dicts with keys 'epoch', 'lr', 'levenstein', 'train_loss',
        'eval_loss' and 'time'.
    """
    metrics = []
    for epoch in tqdm(range(0, N_EPOCHS)):
        start_time = time()
        train_loss = train(model, optimizer, criterion, train_loader)
        eval_loss, result_phrases = evaluate(model, criterion, val_loader)
        end_time = time()

        distances = [distance(pred_phrase, true_phrase)
                     for true_phrase, pred_phrase in zip(result_phrases['true'], result_phrases['pred'])]
        # Mean edit distance; NaN instead of ZeroDivisionError when validation produced no phrases.
        levenstein = sum(distances) / len(distances) if distances else float('nan')

        epoch_metrics = {
            'epoch': epoch,
            'lr': optimizer.param_groups[0]["lr"],  # LR used for this epoch (read before stepping)
            'levenstein': levenstein,
            'train_loss': train_loss,
            'eval_loss': eval_loss,
            'time': end_time - start_time,
        }
        metrics.append(epoch_metrics)
        log_metrics(epoch_metrics)

        # Honour the SCHUDULER_ON config flag (previously the scheduler always stepped).
        if SCHUDULER_ON:
            scheduler.step(eval_loss)
    return metrics
model = TransformerModel(
    len(ALPHABET), hidden=HIDDEN, enc_layers=ENC_LAYERS, dec_layers=DEC_LAYERS,
    nhead=N_HEADS, dropout=DROPOUT,
).to(DEVICE)

# PAD positions do not contribute to the loss.
criterion = torch.nn.CrossEntropyLoss(ignore_index=char2idx['PAD'])
# The optimizer class is looked up by name (e.g. 'Adam' or 'SGD') in torch.optim.
optimizer = getattr(torch.optim, OPTIMIZER_NAME)(model.parameters(), lr=LR)

scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=PATIENCE)

fit(model, optimizer, scheduler, criterion, train_loader, test_loader)


In [70]:
# NOTE(review): evaluate returns (mean loss, {'true': [...], 'pred': [...]}); the first value is a
# loss, not a list of metrics, despite the variable name.
metrics, result = evaluate(model, criterion, test_loader)

In [81]:
# Greedy-decode every test batch. eval() disables dropout. no_grad() stops autograd graphs from
# being built, which otherwise accumulate memory over the whole loop. Inputs go to the device
# the model's parameters are on.
# NOTE(review): test_loader uses drop_last=True, so the last partial batch is not predicted.
model.eval()
inference_device = next(model.parameters()).device
prediction = []
with torch.no_grad():
    for batch in test_loader:
        prediction.append(model.predict(batch[0].to(inference_device)))

KeyboardInterrupt: 

# Выводы

К сожалению, я запустил валидацию на test_loader, где нет правильных ответов.  
Поэтому отследить валидационное качество не удалось.  
Но если сравнить loss на трейне с предыдущей домашкой на resnet18, то трансформеру удалось достичь более высокого качества, сравнимого с рекуррентной моделью с resnet34 в качестве бэкбона.  
Также я сделал late submission в Kaggle, и качество оказалось выше, чем у моей лучшей рекуррентной модели.  
Стоит, правда, отметить, что это улучшение совсем не бесплатное: время обучения увеличилось примерно в 4 раза.  
Такое увеличение времени обучения довольно критично и не позволяет тестировать разные параметры так же быстро, как это получалось с рекуррентными моделями.  
Судя по графику лосса, складывается ощущение, что первые итерации шли с неправильным lr; возможно, если аккуратно подобрать lr и scheduler, можно добиться гораздо более быстрого обучения (что как раз получилось сделать с рекуррентной сетью в прошлой домашке).  
В целом трансформеры — однозначно очень мощная архитектура, показавшая более высокие результаты, чем рекуррентная сеть, несмотря на более слабый бэкбон и практически полное отсутствие экспериментов с параметрами.  

UPD:
Я заметил, что в трансформере тоже используется resnet34, так что бэкбон у него не более слабый.