In [None]:
import numpy as np
import pandas as pd
import cv2
import os
import gc
import json
import glob
import re
import PIL
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import seaborn as sns

from tqdm import tqdm
from collections import Counter
from matplotlib.path import Path
from torchvision import models
from torchvision import transforms
from PIL import Image, ImageDraw
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
from torchmetrics import CharErrorRate as CER
from IPython.display import Image as Img

In [None]:
def load_json(file):
    """Read the JSON document stored at ``file`` and return the parsed object."""
    with open(file, 'r') as handle:
        parsed = json.load(handle)
    return parsed

In [None]:
# Location of the dataset and of the checkpoint used for inference further below.
PATH = '../input/nomeroff-russian-license-plates/autoriaNumberplateOcrRu-2021-09-01'
OCR_MODEL_PATH = './models/model-7-0.9156.ckpt'
# Characters the model can output; the tokenizer maps any other character to OOV.
ALPHABET = '0123456789ABEKMHOPCTYX'
# NOTE(review): TRAIN_SIZE is not referenced in the cells shown (the dataset ships train/val splits).
TRAIN_SIZE = 0.9
BATCH_SIZE_OCR = 16

# Prefer the GPU when one is available.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

# Загружаем метки к изображениям

In [None]:
def get_annot(file, dir):
    """Return the ``description`` field of the annotation that belongs to image ``file``.

    Annotations live in ``PATH/<dir>/ann/`` and share the image's base name with a
    ``.json`` extension (``<dir>/img/x.png`` -> ``<dir>/ann/x.json``). The value is
    used as the recognition label by the cells below.
    """
    # splitext handles any extension length (.png, .jpg, .jpeg, none); slicing off a
    # fixed 3 characters produced names like "x.jjson" for ".jpeg" files.
    stem, _ = os.path.splitext(file)
    ann_path = os.path.join(PATH, dir, 'ann', stem + '.json')
    return load_json(ann_path)['description']

In [None]:
def _collect_labels(split):
    """Return ``[{'filename': ..., 'label': ...}, ...]`` for every image of ``split``.

    Files are listed in sorted order so the sample order (and therefore any printed
    predictions) is reproducible; ``os.listdir`` order is arbitrary.
    """
    img_dir = os.path.join(PATH, split, 'img')
    return [
        {'filename': name, 'label': get_annot(name, split)}
        for name in sorted(os.listdir(img_dir))
    ]


train_labels = _collect_labels('train')
val_labels = _collect_labels('val')
train_labels[:10]

# Определяем класс датасета и трансформации изображений

In [None]:
class OCRDataset(Dataset):
    """Dataset of ``(image, text, enc_text)`` samples for one split under ``PATH``.

    Args:
        marks: iterable of dicts with ``'filename'`` and ``'label'`` keys.
        img_folder: split name (e.g. ``'train'`` / ``'val'``); images are read from
            ``PATH/<img_folder>/img/``.
        tokenizer: object whose ``encode(list_of_str)`` returns one id list per string.
        transforms: optional callable applied to the image returned by ``cv2.imread``
            (BGR channel order, OpenCV's default).
    """

    def __init__(self, marks, img_folder, tokenizer, transforms=None):
        self.img_folder = PATH + f'/{img_folder}/img/'
        self.img_paths = []
        self.texts = []
        for item in marks:
            self.img_paths.append(os.path.join(self.img_folder, item['filename']))
            self.texts.append(item['label'])
        # Encode all labels once up front rather than on every __getitem__.
        self.enc_texts = tokenizer.encode(self.texts)
        self.transforms = transforms

    def __getitem__(self, idx):
        img_path = self.img_paths[idx]
        text = self.texts[idx]
        enc_text = torch.LongTensor(self.enc_texts[idx])
        image = cv2.imread(img_path)
        # cv2.imread returns None instead of raising on a missing/unreadable file;
        # fail here with the offending path rather than deep inside the transforms.
        if image is None:
            raise FileNotFoundError(f'Could not read image: {img_path}')
        if self.transforms is not None:
            image = self.transforms(image)
        return image, text, enc_text

    def __len__(self):
        return len(self.texts)
    
    
class Resize(object):
    """Resize an OpenCV image (``H x W [x C]`` ndarray) to a fixed ``(width, height)``.

    The interpolation mode depends only on the width change: cubic when the target
    is wider than the source (enlarging), area-averaging otherwise (shrinking).
    """

    def __init__(self, size=(250, 50)):
        # cv2.resize takes dsize as (width, height).
        self.size = size

    def __call__(self, img):
        src_width = img.shape[1]
        target_width, _ = self.size
        if target_width > src_width:
            mode = cv2.INTER_CUBIC
        else:
            mode = cv2.INTER_AREA
        return cv2.resize(img, dsize=self.size, interpolation=mode)


class Normalize:
    """Cast an image to float32 and divide by 255 (uint8 input ends up in [0, 1])."""

    def __call__(self, img):
        scaled = img.astype(np.float32)
        scaled /= 255
        return scaled


def collate_fn(batch):
    """Merge ``(image, text, enc_text)`` samples into one batch.

    Returns:
        images: image tensors stacked along a new leading batch dimension.
        texts: tuple of the raw label strings.
        enc_pad_texts: ``(batch, max_len)`` tensor, right-padded with 0 (the blank id).
        text_lens: LongTensor holding each label's character count.
    """
    images, texts, enc_texts = zip(*batch)
    lengths = list(map(len, texts))
    return (
        torch.stack(images, 0),
        texts,
        pad_sequence(enc_texts, batch_first=True, padding_value=0),
        torch.LongTensor(lengths),
    )

# Определяем токенайзер

In [None]:
OOV_TOKEN = '<OOV>'
CTC_BLANK = '<BLANK>'


def get_char_map(alphabet):
    """Map every alphabet character and the two special tokens to an integer id.

    Alphabet characters get ids starting at 2 in alphabet order; the CTC blank is
    id 0 (the ``blank=0`` passed to ``CTCLoss``) and the OOV token is id 1.
    """
    char_map = dict((symbol, code) for code, symbol in enumerate(alphabet, start=2))
    char_map[CTC_BLANK] = 0
    char_map[OOV_TOKEN] = 1
    return char_map


class Tokenizer:
    """Convert label strings to id sequences and greedy CTC output back to strings."""

    def __init__(self, alphabet):
        self.char_map = get_char_map(alphabet)
        # id -> character, used when decoding model output.
        self.rev_char_map = dict((code, symbol) for symbol, code in self.char_map.items())

    def encode(self, word_list):
        """Encode each string as a list of ids; unknown characters get the OOV id."""
        oov_id = self.char_map[OOV_TOKEN]
        return [[self.char_map.get(symbol, oov_id) for symbol in word] for word in word_list]

    def get_num_chars(self):
        """Vocabulary size: alphabet characters plus the blank and OOV tokens."""
        return len(self.char_map)

    def decode(self, enc_word_list):
        """Greedy CTC decoding of per-step id sequences.

        A step is dropped when it is the OOV id, the blank id, or equal to the
        immediately preceding step (collapsing repeats); the rest are mapped back
        to characters and joined.
        """
        skip_ids = (self.char_map[OOV_TOKEN], self.char_map[CTC_BLANK])
        decoded = []
        for seq in enc_word_list:
            kept = []
            for pos, token in enumerate(seq):
                if token in skip_ids:
                    continue
                if pos > 0 and token == seq[pos - 1]:
                    continue
                kept.append(self.rev_char_map[token])
            decoded.append(''.join(kept))
        return decoded

# Инициализируем токенайзер, трансформации и датасет

In [None]:
# Preprocessing shared by training, validation and inference: resize to 250x50
# (width x height), divide by 255, then convert the HWC ndarray to a CHW tensor.
ocr_transforms = transforms.Compose([
    Resize(size=(250, 50)),
    Normalize(),
    transforms.ToTensor(),
])

tokenizer = Tokenizer(ALPHABET)

train_ocr_dataset, val_ocr_dataset = (
    OCRDataset(marks=split_marks, img_folder=split_name, tokenizer=tokenizer,
               transforms=ocr_transforms)
    for split_marks, split_name in ((train_labels, 'train'), (val_labels, 'val'))
)

# Loader settings common to both splits; only training shuffles and drops the last partial batch.
loader_kwargs = dict(batch_size=BATCH_SIZE_OCR, num_workers=2, collate_fn=collate_fn, timeout=0)
train_loader = DataLoader(train_ocr_dataset, drop_last=True, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_ocr_dataset, drop_last=False, **loader_kwargs)

gc.collect()

In [None]:
# Samples are (image, text, enc_text) tuples, so index the image by position.
# The image is a CHW float tensor in OpenCV's BGR order: move channels last and
# reverse them so matplotlib shows true colours.
sample_image, sample_text, _ = train_ocr_dataset[0]
plt.imshow(sample_image.permute(1, 2, 0).numpy()[..., ::-1])
plt.title(sample_text);

# Определяем класс модели

In [None]:
def get_resnet34_backbone():
    """Build a CNN feature extractor from a pretrained torchvision ResNet-34.

    The stem convolution is a new 7x7 conv with stride 1 and padding 3, so its
    weights are freshly initialised rather than pretrained (presumably chosen to
    keep more spatial resolution for the sequence model). The pretrained
    bn1/relu/maxpool and layer1-layer3 follow; layers after layer3 are not used.
    """
    resnet = models.resnet34(pretrained=True)
    stem = nn.Conv2d(3, 64, 7, 1, 3)
    return nn.Sequential(
        stem,
        resnet.bn1,
        resnet.relu,
        resnet.maxpool,
        resnet.layer1,
        resnet.layer2,
        resnet.layer3,
    )


class BiLSTM(nn.Module):
    """Stacked bidirectional, batch-first LSTM that returns only the output sequence.

    Input ``(batch, seq_len, input_size)`` -> output ``(batch, seq_len, 2 * hidden_size)``;
    the final hidden and cell states are discarded.
    """

    def __init__(self, input_size, hidden_size, num_layers, dropout=0.1):
        super().__init__()
        self.lstm = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            dropout=dropout,
            batch_first=True,
            bidirectional=True,
        )

    def forward(self, x):
        return self.lstm(x)[0]


class CRNN(nn.Module):
    """CNN + BiLSTM text recogniser whose output is laid out for ``nn.CTCLoss``.

    Args:
        number_class_symbols: size of the output vocabulary (blank and OOV included).
        time_feature_count: side of the square the flattened CNN map is pooled to;
            it is both the number of time steps and the per-step feature size.
        lstm_hidden: hidden size of each LSTM direction.
        lstm_len: number of stacked LSTM layers.

    ``forward`` returns log-probabilities shaped ``(time, batch, classes)``.
    """

    def __init__(
        self, number_class_symbols, time_feature_count=256, lstm_hidden=256,
        lstm_len=2,
    ):
        super().__init__()
        self.feature_extractor = get_resnet34_backbone()
        self.avg_pool = nn.AdaptiveAvgPool2d((time_feature_count, time_feature_count))
        self.bilstm = BiLSTM(time_feature_count, lstm_hidden, lstm_len)
        self.classifier = nn.Sequential(
            nn.Linear(lstm_hidden * 2, time_feature_count),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(time_feature_count, number_class_symbols),
        )

    def forward(self, x):
        feats = self.feature_extractor(x)
        batch, channels, height, width = feats.size()
        # Fold channels and height together, keeping width as the sequence axis.
        seq = feats.view(batch, channels * height, width)
        # AdaptiveAvgPool2d treats a 3-D tensor as (C, H, W), so this pools each
        # sample's (channels*height, width) plane to a fixed square.
        seq = self.avg_pool(seq)
        # (batch, features, time) -> (batch, time, features) for the batch-first LSTM.
        seq = self.bilstm(seq.transpose(1, 2))
        logits = self.classifier(seq)
        return nn.functional.log_softmax(logits, dim=2).permute(1, 0, 2)

# Подготовка к обучению

In [None]:
model = CRNN(number_class_symbols=tokenizer.get_num_chars())
model.to(device)

# blank=0 matches the id get_char_map reserves for CTC_BLANK; zero_infinity turns
# infinite losses (impossible alignments) into zero instead of NaN gradients.
criterion = torch.nn.CTCLoss(blank=0, reduction='mean', zero_infinity=True)
optimizer = torch.optim.AdamW(model.parameters(), lr=0.001, weight_decay=0.01)
# Halve the LR when validation accuracy (mode='max') plateaus.
# NOTE(review): patience=15 exceeds the 10 epochs passed to train(), so this never fires there.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer=optimizer,
    mode='max',
    factor=0.5,
    patience=15,
)

In [None]:
class AverageMeter:
    """Running weighted mean: accumulates ``val * n`` and the total weight ``n``."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear the accumulated average, sum and count."""
        self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Add ``val`` with weight ``n`` (e.g. a batch mean and the batch size)."""
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

In [None]:
def get_accuracy(y_true, y_pred):
    """Exact-match accuracy: mean of ``truth == guess`` over the zipped pairs."""
    return np.mean([truth == guess for truth, guess in zip(y_true, y_pred)])


def predict(images, model, tokenizer, device):
    """Greedy-decode a batch of images into strings.

    Puts ``model`` in eval mode, runs it on ``device`` without gradients, takes the
    arg-max class at every time step and hands the ``(batch, time)`` id array to
    ``tokenizer.decode``. Assumes the model output is ``(time, batch, classes)``,
    as CRNN produces — TODO confirm for other models.
    """
    model.eval()
    with torch.no_grad():
        log_probs = model(images.to(device))
    # (time, batch) -> (batch, time): one id row per image.
    best_ids = log_probs.detach().cpu().argmax(dim=-1).transpose(0, 1).numpy()
    return tokenizer.decode(best_ids)


def val_loop(data_loader, model, tokenizer, device):
    """Compute exact-match accuracy of ``model`` over ``data_loader``, print it and return it.

    Each batch's accuracy is weighted by its size, so the result is the per-sample
    accuracy over the whole loader.
    """
    acc_meter = AverageMeter()
    for images, texts, _enc, _lens in tqdm(data_loader):
        predictions = predict(images, model, tokenizer, device)
        acc_meter.update(get_accuracy(texts, predictions), len(texts))
    print(f'Validation, acc: {acc_meter.avg:.4f}\n')
    return acc_meter.avg


def train_loop(data_loader, model, criterion, optimizer, epoch, device=None):
    """Run one training epoch with CTC loss and return the sample-weighted mean loss.

    Args:
        data_loader: yields ``(images, texts, enc_pad_texts, text_lens)`` batches,
            as produced by ``collate_fn``.
        model: network returning log-probabilities shaped ``(time, batch, classes)``.
        criterion: CTC-style loss called as ``(log_probs, targets, input_lengths, target_lengths)``.
        optimizer: optimizer over ``model``'s parameters.
        epoch: epoch index, used only in the progress message.
        device: device to move images to; defaults to the device of ``model``'s
            parameters, so the function no longer depends on a notebook global.
    """
    if device is None:
        device = next(model.parameters()).device
    loss_meter = AverageMeter()
    model.train()
    for images, texts, enc_pad_texts, text_lens in tqdm(data_loader):
        model.zero_grad()
        images = images.to(device)
        output = model(images)
        # Every sample uses the full output length: dim 0 is time, dim 1 is batch.
        output_lengths = torch.full(
            size=(output.size(1),),
            fill_value=output.size(0),
            dtype=torch.long,
        )
        loss = criterion(output, enc_pad_texts, output_lengths, text_lens)
        loss_meter.update(loss.item(), len(texts))
        loss.backward()
        # Clip the global gradient norm to 2 for training stability.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 2)
        optimizer.step()
    # Same value the original loop picked up: the LR of the last param group.
    lr = optimizer.param_groups[-1]['lr']
    print(f'\nEpoch {epoch}, Loss: {loss_meter.avg:.5f}, LR: {lr:.7f}')
    return loss_meter.avg

In [None]:
def train(model, optimizer, scheduler, train_loader, val_loader, epochs=10):
    """Train for ``epochs`` epochs, validating after each and checkpointing on improvement.

    Relies on the notebook-level ``criterion``, ``tokenizer`` and ``device``. One
    validation pass runs before training starts. ``scheduler.step`` receives each
    epoch's validation accuracy, and weights are saved to
    ``models/model-<epoch>-<acc>.ckpt`` whenever accuracy beats the best so far.
    """
    os.makedirs('models', exist_ok=True)
    best_acc = -np.inf
    # Accuracy before any training step (printed by val_loop, not used as baseline).
    val_loop(val_loader, model, tokenizer, device)
    for epoch in range(epochs):
        print(f'Epoch {epoch} started')
        train_loop(train_loader, model, criterion, optimizer, epoch)
        epoch_acc = val_loop(val_loader, model, tokenizer, device)
        scheduler.step(epoch_acc)
        if not epoch_acc > best_acc:
            continue
        best_acc = epoch_acc
        ckpt_path = os.path.join('models', f'model-{epoch}-{epoch_acc:.4f}.ckpt')
        torch.save(model.state_dict(), ckpt_path)
        print('Model weights saved')

# Запуск обучения

In [None]:
# Full training run; every epoch that improves validation accuracy writes a checkpoint to ./models.
train(
    model=model,
    optimizer=optimizer,
    scheduler=scheduler,
    train_loader=train_loader,
    val_loader=val_loader,
    epochs=10,
)

# Определяем класс для предсказания

In [None]:
class InferenceTransform:
    """Apply the training-time ``ocr_transforms`` to each image and stack them into one batch."""

    def __init__(self):
        self.transforms = ocr_transforms

    def __call__(self, images):
        return torch.stack([self.transforms(img) for img in images], 0)


class Predictor:
    """Load a trained CRNN checkpoint and recognise text in images.

    Images are preprocessed with the training transforms, so they are expected in
    the same channel order used for training (BGR, as returned by ``cv2.imread``).

    Args:
        model_path: path to a ``state_dict`` saved by ``train``.
        device: device string for inference (``'cuda'`` by default).
    """

    def __init__(self, model_path, device='cuda'):
        self.tokenizer = Tokenizer(ALPHABET)
        self.device = torch.device(device)
        self.model = CRNN(number_class_symbols=self.tokenizer.get_num_chars())
        # map_location lets a checkpoint saved on GPU load on a CPU-only machine
        # (and vice versa) instead of failing inside torch.load.
        state_dict = torch.load(model_path, map_location=self.device)
        self.model.load_state_dict(state_dict)
        self.model.to(self.device)
        self.model.eval()
        self.transforms = InferenceTransform()

    def __call__(self, images):
        """Predict text for one ndarray image (returns str) or a list/tuple of them (returns list).

        Raises:
            TypeError: if ``images`` is neither an ndarray nor a list/tuple.
        """
        if isinstance(images, (list, tuple)):
            one_image = False
        elif isinstance(images, np.ndarray):
            images = [images]
            one_image = True
        else:
            # TypeError subclasses Exception, so existing ``except Exception`` handlers still catch it.
            raise TypeError(f"Input must contain np.ndarray, "
                            f"tuple or list, found {type(images)}.")

        batch = self.transforms(images)
        pred = predict(batch, self.model, self.tokenizer, self.device)
        return pred[0] if one_image else pred

# Подсчёт метрики Char Error Rate (CER) на валидационной выборке
### Результат: CER ≈ 0.01

In [None]:
# Evaluate the checkpoint on the validation split and report the Char Error Rate.
SHOW_PREDICTIONS = True  # set to False to skip plotting every validation image

# Use the notebook's device instead of Predictor's hard-coded 'cuda' default.
predictor = Predictor(OCR_MODEL_PATH, device=str(device))
pred_json = {}
y_true, y_pred = [], []

for val_img in val_labels:
    img_path = os.path.join(PATH, 'val', 'img', val_img['filename'])
    img = cv2.imread(img_path)
    if img is None:
        raise FileNotFoundError(f'Could not read image: {img_path}')
    # Predict on the BGR image, matching how training images were read.
    pred = predictor(img)
    pred_json[val_img['filename']] = pred
    y_true.append(val_img['label'])
    y_pred.append(pred)
    if SHOW_PREDICTIONS:
        fig, ax = plt.subplots()
        ax.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))  # RGB only for display
        ax.axis('off')
        plt.show()
        plt.close(fig)  # avoid accumulating one open figure per image
        # Print the stored prediction; re-predicting on the RGB copy gave a different input.
        print('Prediction: ', pred)
        print('True: ', val_img['label'])
        print()

CER()(y_pred, y_true)