In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import json
from collections import Counter, defaultdict
from pathlib import Path
import time

import cv2
import numpy as np
import pandas as pd
import torch
import tqdm
from Levenshtein import distance
from matplotlib import pyplot as plt
from scipy.special import softmax
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
import torch.nn.functional as fnn

from detection_utils import PlateImageAdjuster, PlateImageExtractor, build_mask, get_rectangular_box
from recognition import CRNN, RecognitionDataset, LanguageModel, beam_search
from recognition_utils import collate_fn_recognition, decode, normalize_text, Resize

%matplotlib inline

In [3]:
normalizer = PlateImageAdjuster()
extractor = PlateImageExtractor()

## Prepare OCR dataset

In [4]:
path_data = Path('data')
path_ocr_dataset = Path('ocr_data')
path_ocr_dataset.mkdir(parents=True, exist_ok=True)

plates_filename = path_data / 'train.json'
with open(plates_filename) as f:
    json_data = json.load(f)
    
for sample in tqdm.tqdm(json_data):
    if sample['file'] == 'train/25632.bmp':
        continue
    file_path = path_data / sample['file']
    image = cv2.imread(str(file_path))
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    
    for plate in sample['nums']:
        box = plate['box']
        text = plate['text']
        mask = build_mask(box, image)
        plate_img = extractor(image, mask, np.array(box))
        plate_img = normalizer(plate_img)
        text = normalize_text(text)
        file_path = path_ocr_dataset / ''.join([text, '.png'])
        cv2.imwrite(str(file_path), plate_img)
        
        # save also bboxes
        file_path = path_ocr_dataset / ''.join([text, '_bbox.png'])
        raw_box = get_rectangular_box(box)
        plate_bbox = image[raw_box[1]:raw_box[3], raw_box[0]:raw_box[2], :]
        plate_bbox = normalizer(plate_bbox)
        cv2.imwrite(str(file_path), plate_bbox)

100%|██████████| 25633/25633 [14:53<00:00, 28.69it/s]


## Build the model

In [4]:
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
crnn = CRNN(rnn_bidirectional=True)

In [5]:
crnn.to(device)
num_epochs = 20
batch_size = 64
num_workers = 4
optimizer = torch.optim.Adam(crnn.parameters(), lr=3e-4, amsgrad=True, weight_decay=1e-5)
lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=1/np.sqrt(10), patience=2,
                                                          verbose=True, threshold=1e-3)

In [6]:
transformations = transforms.Compose([
    Resize(),
    transforms.ToPILImage(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                     std=[0.229, 0.224, 0.225])
                    ])

train_ocr_dataset = RecognitionDataset('ocr_data', transformations, crnn.alphabet, 'train', add_generated=True)
val_ocr_dataset = RecognitionDataset('ocr_data', transformations, crnn.alphabet, 'val')

In [8]:
train_dataloader = torch.utils.data.DataLoader(train_ocr_dataset, 
                                               batch_size=batch_size, shuffle=True,
                                               num_workers=num_workers, pin_memory=True, 
                                               drop_last=False, collate_fn=collate_fn_recognition)
val_dataloader = torch.utils.data.DataLoader(val_ocr_dataset, 
                                             batch_size=1, shuffle=False,
                                             num_workers=num_workers, pin_memory=True, 
                                             drop_last=True, collate_fn=collate_fn_recognition)

In [9]:
experiment_name = 'Recognition_model_with_generated_test'
writer = SummaryWriter(log_dir=f'tb_logs/{experiment_name}')

In [10]:
best_loss = np.inf
prev_lr = optimizer.param_groups[0]['lr']

for i, epoch in enumerate(range(num_epochs)):
    epoch_losses = []
    levensteint_losses = []
    
    # Если поменялась lr - загружаем лучшую модель
    if optimizer.param_groups[0]['lr'] < prev_lr:
        prev_lr = optimizer.param_groups[0]['lr']
        with open(f'{experiment_name}.pth', 'rb') as fp:
            state_dict = torch.load(fp, map_location="cpu")
        crnn.load_state_dict(state_dict)
        crnn.to(device)
    
    crnn.train()
    for j, b in enumerate(tqdm.tqdm(train_dataloader, total=len(train_dataloader))):
        images = b["image"].to(device)
        seqs_gt = b["seq"]
        seq_lens_gt = b["seq_len"]

        seqs_pred = crnn(images).cpu()
        log_probs = fnn.log_softmax(seqs_pred, dim=2)
        seq_lens_pred = torch.Tensor([seqs_pred.size(0)] * seqs_pred.size(1)).int()
        
        texts_pred = decode(seqs_pred, crnn.alphabet)
        texts_gt = b["text"]
        levensteint_losses.extend([distance(pred, gt) for pred, gt in zip(texts_pred, texts_gt)])

        loss = fnn.ctc_loss(log_probs=log_probs,  # (T, N, C)
                            targets=seqs_gt,  # N, S or sum(target_lengths)
                            input_lengths=seq_lens_pred,  # N
                            target_lengths=seq_lens_gt)  # N

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_losses.append(loss.item())
    print(f'Train {i + 1}, {np.mean(epoch_losses)}')
    print(f'Train {i + 1} Levenstein, {np.mean(levensteint_losses)}')
    writer.add_scalar('Recognition/Train/loss', np.mean(epoch_losses), i)
    time.sleep(0.5)
    
    epoch_losses = []
    levensteint_losses = []
    crnn.eval()
    for j, b in enumerate(tqdm.tqdm(val_dataloader, total=len(val_dataloader))):
        images = b["image"].to(device)
        seqs_gt = b["seq"]
        seq_lens_gt = b["seq_len"]

        seqs_pred = crnn(images).cpu()
        log_probs = fnn.log_softmax(seqs_pred, dim=2)
        seq_lens_pred = torch.Tensor([seqs_pred.size(0)] * seqs_pred.size(1)).int()
        
        texts_pred = decode(seqs_pred, crnn.alphabet)
        texts_gt = b["text"]
        levensteint_losses.extend([distance(pred, gt) for pred, gt in zip(texts_pred, texts_gt)])

        loss = fnn.ctc_loss(log_probs=log_probs,  # (T, N, C)
                            targets=seqs_gt,  # N, S or sum(target_lengths)
                            input_lengths=seq_lens_pred,  # N
                            target_lengths=seq_lens_gt)  # N

        epoch_losses.append(loss.item())
        
        if best_loss > epoch_losses[-1]:
            best_loss = epoch_losses[-1]
            with open(f'{experiment_name}.pth', 'wb') as fp:
                torch.save(crnn.state_dict(), fp)
        
    lr_scheduler.step(np.mean(levensteint_losses))
    print(f'Valid {i + 1}, {np.mean(epoch_losses)}')
    print(f'Valid {i + 1} Levenstein, {np.mean(levensteint_losses)}')
    writer.add_scalar('Recognition/Valid/loss', np.mean(epoch_losses), i)
    time.sleep(0.5)

100%|██████████| 1599/1599 [07:30<00:00,  3.55it/s]


Train 1, 0.3902352593793738
Train 1 Levenstein, 0.8286814809021498


100%|██████████| 2370/2370 [00:19<00:00, 121.09it/s]


Valid 1, 0.06148304197738804
Valid 1 Levenstein, 0.11012658227848102


100%|██████████| 1599/1599 [07:30<00:00,  3.55it/s]


Train 2, 0.1106603493974582
Train 2 Levenstein, 0.1400639364936601


100%|██████████| 2370/2370 [00:16<00:00, 145.43it/s]


Valid 2, 0.048997855991496454
Valid 2 Levenstein, 0.08270042194092828


100%|██████████| 1599/1599 [07:30<00:00,  3.55it/s]


Train 3, 0.10215797124947744
Train 3 Levenstein, 0.12761880554116278


100%|██████████| 2370/2370 [00:15<00:00, 150.44it/s]


Valid 3, 0.04481369339382543
Valid 3 Levenstein, 0.07383966244725738


100%|██████████| 1599/1599 [07:30<00:00,  3.55it/s]


Train 4, 0.09646202805536502
Train 4 Levenstein, 0.12163575751058275


100%|██████████| 2370/2370 [00:15<00:00, 148.26it/s]


Valid 4, 0.04216035803525025
Valid 4 Levenstein, 0.07130801687763713


100%|██████████| 1599/1599 [07:31<00:00,  3.55it/s]


Train 5, 0.09259370516483591
Train 5 Levenstein, 0.11800877904760043


100%|██████████| 2370/2370 [00:15<00:00, 154.13it/s]


Valid 5, 0.03749326401540444
Valid 5 Levenstein, 0.06118143459915612


100%|██████████| 1599/1599 [07:31<00:00,  3.54it/s]


Train 6, 0.08972874121959458
Train 6 Levenstein, 0.11593621992589624


100%|██████████| 2370/2370 [00:15<00:00, 153.57it/s]


Valid 6, 0.042026878098067176
Valid 6 Levenstein, 0.0751054852320675


100%|██████████| 1599/1599 [07:31<00:00,  3.55it/s]


Train 7, 0.0875242827693146
Train 7 Levenstein, 0.11485105925368319


100%|██████████| 2370/2370 [00:15<00:00, 152.76it/s]


Valid 7, 0.04159143080556409
Valid 7 Levenstein, 0.07215189873417721


100%|██████████| 1599/1599 [07:31<00:00,  3.54it/s]


Train 8, 0.08511661478709091
Train 8 Levenstein, 0.11354104546920979


100%|██████████| 2370/2370 [00:14<00:00, 160.34it/s]


Epoch     8: reducing learning rate of group 0 to 9.4868e-05.
Valid 8, 0.03880901620359816
Valid 8 Levenstein, 0.06413502109704641


100%|██████████| 1599/1599 [07:31<00:00,  3.54it/s]


Train 9, 0.07566881292366223
Train 9 Levenstein, 0.10700075276911496


100%|██████████| 2370/2370 [00:14<00:00, 160.50it/s]


Valid 9, 0.03134773735425834
Valid 9 Levenstein, 0.0510548523206751


100%|██████████| 1599/1599 [07:31<00:00,  3.54it/s]


Train 10, 0.07210314114944423
Train 10 Levenstein, 0.1057493963182747


100%|██████████| 2370/2370 [00:16<00:00, 147.43it/s]


Valid 10, 0.031436513337845286
Valid 10 Levenstein, 0.05063291139240506


100%|██████████| 1599/1599 [07:31<00:00,  3.54it/s]


Train 11, 0.0696104252833251
Train 11 Levenstein, 0.10551476698374214


100%|██████████| 2370/2370 [00:16<00:00, 146.71it/s]


Valid 11, 0.03275709099780753
Valid 11 Levenstein, 0.05063291139240506


100%|██████████| 1599/1599 [07:31<00:00,  3.54it/s]


Train 12, 0.06751998940516517
Train 12 Levenstein, 0.10612089276461789


100%|██████████| 2370/2370 [00:15<00:00, 150.52it/s]


Valid 12, 0.032377518696509316
Valid 12 Levenstein, 0.05147679324894515


100%|██████████| 1599/1599 [07:31<00:00,  3.54it/s]


Train 13, 0.06518693025217503
Train 13 Levenstein, 0.10663903254504395


100%|██████████| 2370/2370 [00:15<00:00, 150.84it/s]


Epoch    13: reducing learning rate of group 0 to 3.0000e-05.
Valid 13, 0.033466638661017206
Valid 13 Levenstein, 0.0540084388185654


100%|██████████| 1599/1599 [07:31<00:00,  3.54it/s]


Train 14, 0.056936792326884915
Train 14 Levenstein, 0.10500640342558828


100%|██████████| 2370/2370 [00:16<00:00, 146.51it/s]


Valid 14, 0.03274462418699488
Valid 14 Levenstein, 0.049789029535864976


100%|██████████| 1599/1599 [07:31<00:00,  3.54it/s]


Train 15, 0.05406321141533344
Train 15 Levenstein, 0.10554409565055871


100%|██████████| 2370/2370 [00:15<00:00, 153.66it/s]


Valid 15, 0.033340666222697116
Valid 15 Levenstein, 0.05021097046413502


100%|██████████| 1599/1599 [07:30<00:00,  3.55it/s]


Train 16, 0.052186401377142574
Train 16 Levenstein, 0.10587648720781316


100%|██████████| 2370/2370 [00:14<00:00, 158.54it/s]


Valid 16, 0.0338942266156572
Valid 16 Levenstein, 0.049789029535864976


100%|██████████| 1599/1599 [07:31<00:00,  3.54it/s]


Train 17, 0.05027586026082259
Train 17 Levenstein, 0.1069616478800262


100%|██████████| 2370/2370 [00:15<00:00, 157.27it/s]


Epoch    17: reducing learning rate of group 0 to 9.4868e-06.
Valid 17, 0.03399190202613328
Valid 17 Levenstein, 0.05063291139240506


100%|██████████| 1599/1599 [07:30<00:00,  3.55it/s]


Train 18, 0.045147094381227625
Train 18 Levenstein, 0.1036866134188427


100%|██████████| 2370/2370 [00:15<00:00, 148.50it/s]


Valid 18, 0.03448477631350004
Valid 18 Levenstein, 0.05147679324894515


100%|██████████| 1599/1599 [07:30<00:00,  3.55it/s]


Train 19, 0.04387383985046584
Train 19 Levenstein, 0.10398967630928056


100%|██████████| 2370/2370 [00:15<00:00, 154.06it/s]


Valid 19, 0.034593770058170714
Valid 19 Levenstein, 0.052742616033755275


100%|██████████| 1599/1599 [07:31<00:00,  3.54it/s]


Train 20, 0.042898616077305
Train 20 Levenstein, 0.10439050142244034


100%|██████████| 2370/2370 [00:17<00:00, 137.56it/s]


Epoch    20: reducing learning rate of group 0 to 3.0000e-06.
Valid 20, 0.034873423276246596
Valid 20 Levenstein, 0.053164556962025315


In [9]:
with open(f'{experiment_name}.pth', 'rb') as fp:
    state_dict = torch.load(fp, map_location="cpu")
crnn = CRNN(rnn_bidirectional=True)
crnn.load_state_dict(state_dict)
crnn.to(device)
crnn.eval()
print('Model loaded!')

Model loaded!
