In [1]:
%load_ext autoreload
%autoreload 2

In [5]:
import json
from collections import Counter, defaultdict
from pathlib import Path
import time

import cv2
import numpy as np
import pandas as pd
import torch
import tqdm
from Levenshtein import distance
from matplotlib import pyplot as plt
from scipy.special import softmax
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
import torch.nn.functional as fnn

from detection_utils import PlateImageAdjuster, PlateImageExtractor, build_mask, get_rectangular_box
from recognition import CRNN, RecognitionDataset, LanguageModel, beam_search
from recognition_utils import collate_fn_recognition, decode, normalize_text, Resize

%matplotlib inline

In [6]:
normalizer = PlateImageAdjuster()
extractor = PlateImageExtractor()

## Prepare OCR dataset

In [10]:
path_data = Path('data')
path_ocr_dataset = Path('ocr_data')
path_ocr_dataset.mkdir(parents=True, exist_ok=True)

plates_filename = path_data / 'train.json'
with open(plates_filename) as f:
    json_data = json.load(f)
    
for sample in tqdm.tqdm(json_data):
    if sample['file'] == 'train/25632.bmp':
        continue
    file_path = path_data / sample['file']
    image = cv2.imread(str(file_path))
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    
    for plate in sample['nums']:
        box = plate['box']
        text = plate['text']
        mask = build_mask(box, image)
        plate_img = extractor(image, mask, np.array(box))
        plate_img = normalizer(plate_img)
        text = normalize_text(text)
        file_path = path_ocr_dataset / ''.join([text, '.png'])
        cv2.imwrite(str(file_path), plate_img)
        
        # save also bboxes
        file_path = path_ocr_dataset / ''.join([text, '_bbox.png'])
        raw_box = get_rectangular_box(box)
        plate_bbox = image[raw_box[1]:raw_box[3], raw_box[0]:raw_box[2], :]
        plate_bbox = normalizer(plate_bbox)
        cv2.imwrite(str(file_path), plate_bbox)

100%|██████████| 25633/25633 [39:56<00:00, 10.70it/s]  


## Build the model

In [11]:
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
crnn = CRNN(rnn_bidirectional=True)

In [12]:
crnn.to(device)
num_epochs = 20
batch_size = 64
num_workers = 4
optimizer = torch.optim.Adam(crnn.parameters(), lr=3e-4, amsgrad=True, weight_decay=1e-5)
lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=1/np.sqrt(10), patience=2,
                                                          verbose=True, threshold=1e-3)

In [45]:
transformations = transforms.Compose([
    Resize(),
    transforms.ToPILImage(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                     std=[0.229, 0.224, 0.225])
                    ])

train_ocr_dataset = RecognitionDataset('ocr_data', transformations, crnn.alphabet, 'train', add_generated=True)
val_ocr_dataset = RecognitionDataset('ocr_data', transformations, crnn.alphabet, 'val')

In [46]:
train_dataloader = torch.utils.data.DataLoader(train_ocr_dataset, 
                                               batch_size=batch_size, shuffle=True,
                                               num_workers=num_workers, pin_memory=True, 
                                               drop_last=False, collate_fn=collate_fn_recognition)
val_dataloader = torch.utils.data.DataLoader(val_ocr_dataset, 
                                             batch_size=1, shuffle=False,
                                             num_workers=num_workers, pin_memory=True, 
                                             drop_last=True, collate_fn=collate_fn_recognition)

In [47]:
experiment_name = 'Recognition_model_with_generated_test'
writer = SummaryWriter(log_dir=f'tb_logs/{experiment_name}')

In [48]:
best_loss = np.inf
prev_lr = optimizer.param_groups[0]['lr']

for i, epoch in enumerate(range(num_epochs)):
    epoch_losses = []
    levensteint_losses = []
    
    # Если поменялась lr - загружаем лучшую модель
    if optimizer.param_groups[0]['lr'] < prev_lr:
        prev_lr = optimizer.param_groups[0]['lr']
        with open(f'{experiment_name}.pth', 'rb') as fp:
            state_dict = torch.load(fp, map_location="cpu")
        crnn.load_state_dict(state_dict)
        crnn.to(device)
    
    crnn.train()
    for j, b in enumerate(tqdm.tqdm(train_dataloader, total=len(train_dataloader))):
        images = b["image"].to(device)
        seqs_gt = b["seq"]
        seq_lens_gt = b["seq_len"]

        seqs_pred = crnn(images).cpu()
        log_probs = fnn.log_softmax(seqs_pred, dim=2)
        seq_lens_pred = torch.Tensor([seqs_pred.size(0)] * seqs_pred.size(1)).int()
        
        texts_pred = decode(seqs_pred, crnn.alphabet)
        texts_gt = b["text"]
        levensteint_losses.extend([distance(pred, gt) for pred, gt in zip(texts_pred, texts_gt)])

        loss = fnn.ctc_loss(log_probs=log_probs,  # (T, N, C)
                            targets=seqs_gt,  # N, S or sum(target_lengths)
                            input_lengths=seq_lens_pred,  # N
                            target_lengths=seq_lens_gt)  # N

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_losses.append(loss.item())
    print(f'Train {i + 1}, {np.mean(epoch_losses)}')
    print(f'Train {i + 1} Levenstein, {np.mean(levensteint_losses)}')
    writer.add_scalar('Recognition/Train/loss', np.mean(epoch_losses), i)
    time.sleep(0.5)
    
    epoch_losses = []
    levensteint_losses = []
    crnn.eval()
    for j, b in enumerate(tqdm.tqdm(val_dataloader, total=len(val_dataloader))):
        images = b["image"].to(device)
        seqs_gt = b["seq"]
        seq_lens_gt = b["seq_len"]

        seqs_pred = crnn(images).cpu()
        log_probs = fnn.log_softmax(seqs_pred, dim=2)
        seq_lens_pred = torch.Tensor([seqs_pred.size(0)] * seqs_pred.size(1)).int()
        
        texts_pred = decode(seqs_pred, crnn.alphabet)
        texts_gt = b["text"]
        levensteint_losses.extend([distance(pred, gt) for pred, gt in zip(texts_pred, texts_gt)])

        loss = fnn.ctc_loss(log_probs=log_probs,  # (T, N, C)
                            targets=seqs_gt,  # N, S or sum(target_lengths)
                            input_lengths=seq_lens_pred,  # N
                            target_lengths=seq_lens_gt)  # N

        epoch_losses.append(loss.item())
        
        if best_loss > epoch_losses[-1]:
            best_loss = epoch_losses[-1]
            with open(f'{experiment_name}.pth', 'wb') as fp:
                torch.save(crnn.state_dict(), fp)
        
    lr_scheduler.step(np.mean(levensteint_losses))
    print(f'Valid {i + 1}, {np.mean(epoch_losses)}')
    print(f'Valid {i + 1} Levenstein, {np.mean(levensteint_losses)}')
    writer.add_scalar('Recognition/Valid/loss', np.mean(epoch_losses), i)
    time.sleep(0.5)

100%|██████████| 667/667 [07:45<00:00,  1.43it/s]


Train 1, 0.8760174544670548
Train 1 Levenstein, 2.211356392950881


100%|██████████| 2370/2370 [00:48<00:00, 48.70it/s]


Valid 1, 0.06269977078596248
Valid 1 Levenstein, 0.11181434599156118


100%|██████████| 667/667 [06:35<00:00,  1.69it/s]


Train 2, 0.05210205477577211
Train 2 Levenstein, 0.08593457067866517


100%|██████████| 2370/2370 [00:52<00:00, 44.86it/s]


Valid 2, 0.046893362505120745
Valid 2 Levenstein, 0.08860759493670886


100%|██████████| 667/667 [04:38<00:00,  2.40it/s]


Train 3, 0.03412311075630876
Train 3 Levenstein, 0.05865673040869891


100%|██████████| 2370/2370 [00:42<00:00, 56.21it/s] 


Valid 3, 0.047217273825221055
Valid 3 Levenstein, 0.09113924050632911


100%|██████████| 667/667 [04:29<00:00,  2.48it/s]


Train 4, 0.024278267700367537
Train 4 Levenstein, 0.042955568053993254


100%|██████████| 2370/2370 [00:25<00:00, 93.85it/s] 


Valid 4, 0.042848612364766656
Valid 4 Levenstein, 0.07679324894514768


100%|██████████| 667/667 [04:33<00:00,  2.44it/s]


Train 5, 0.016438511716002233
Train 5 Levenstein, 0.029902512185976754


100%|██████████| 2370/2370 [00:25<00:00, 93.81it/s] 


Valid 5, 0.03956691569156005
Valid 5 Levenstein, 0.0670886075949367


100%|██████████| 667/667 [04:32<00:00,  2.45it/s]


Train 6, 0.012232415945509329
Train 6 Levenstein, 0.02462973378327709


100%|██████████| 2370/2370 [00:27<00:00, 87.05it/s] 


Valid 6, 0.037110443833617605
Valid 6 Levenstein, 0.06497890295358649


100%|██████████| 667/667 [04:37<00:00,  2.40it/s]


Train 7, 0.010503782826158842
Train 7 Levenstein, 0.021606674165729284


100%|██████████| 2370/2370 [00:23<00:00, 100.59it/s]


Valid 7, 0.03732635374819537
Valid 7 Levenstein, 0.0679324894514768


100%|██████████| 667/667 [04:42<00:00,  2.36it/s]


Train 8, 0.007906998581594384
Train 8 Levenstein, 0.017529058867641546


100%|██████████| 2370/2370 [00:24<00:00, 98.36it/s] 


Valid 8, 0.039501472685619564
Valid 8 Levenstein, 0.06413502109704641


100%|██████████| 667/667 [04:43<00:00,  2.35it/s]


Train 9, 0.005343272933154192
Train 9 Levenstein, 0.011553243344581927


100%|██████████| 2370/2370 [00:22<00:00, 103.86it/s]


Valid 9, 0.03856268619063116
Valid 9 Levenstein, 0.06666666666666667


100%|██████████| 667/667 [04:41<00:00,  2.37it/s]


Train 10, 0.0036404026336088286
Train 10 Levenstein, 0.00775684289463817


100%|██████████| 2370/2370 [00:25<00:00, 92.47it/s] 


Valid 10, 0.0375503298618061
Valid 10 Levenstein, 0.05907172995780591


100%|██████████| 667/667 [04:32<00:00,  2.44it/s]


Train 11, 0.001943504961934974
Train 11 Levenstein, 0.0037495313085864268


100%|██████████| 2370/2370 [00:25<00:00, 91.25it/s] 


Valid 11, 0.035227503563875585
Valid 11 Levenstein, 0.05780590717299578


100%|██████████| 667/667 [04:33<00:00,  2.44it/s]


Train 12, 0.001803808126383394
Train 12 Levenstein, 0.0035620547431571056


100%|██████████| 2370/2370 [00:24<00:00, 97.34it/s] 


Valid 12, 0.036538671776844826
Valid 12 Levenstein, 0.056118143459915615


100%|██████████| 667/667 [04:33<00:00,  2.44it/s]


Train 13, 0.003193577966959819
Train 13 Levenstein, 0.007006936632920885


100%|██████████| 2370/2370 [00:22<00:00, 106.20it/s]


Valid 13, 0.03624013424636252
Valid 13 Levenstein, 0.05822784810126582


100%|██████████| 667/667 [04:39<00:00,  2.39it/s]


Train 14, 0.005667311568893534
Train 14 Levenstein, 0.013498312710911136


100%|██████████| 2370/2370 [00:25<00:00, 92.32it/s] 


Valid 14, 0.04919872837146406
Valid 14 Levenstein, 0.08227848101265822


100%|██████████| 667/667 [04:33<00:00,  2.44it/s]


Train 15, 0.008320959635805673
Train 15 Levenstein, 0.018724221972253468


100%|██████████| 2370/2370 [00:23<00:00, 99.17it/s] 


Epoch    15: reducing learning rate of group 0 to 9.4868e-05.
Valid 15, 0.04338452480025447
Valid 15 Levenstein, 0.07383966244725738


100%|██████████| 667/667 [04:30<00:00,  2.46it/s]


Train 16, 0.0008811721582123814
Train 16 Levenstein, 0.0011482939632545932


100%|██████████| 2370/2370 [00:27<00:00, 87.18it/s] 


Valid 16, 0.03433466004446962
Valid 16 Levenstein, 0.05063291139240506


100%|██████████| 667/667 [04:38<00:00,  2.40it/s]


Train 17, 0.00042470726229053556
Train 17 Levenstein, 0.0002109111361079865


100%|██████████| 2370/2370 [00:23<00:00, 100.62it/s]


Valid 17, 0.03477054677993466
Valid 17 Levenstein, 0.05021097046413502


100%|██████████| 667/667 [04:42<00:00,  2.36it/s]


Train 18, 0.00035168123890096364
Train 18 Levenstein, 0.00011717285339332584


100%|██████████| 2370/2370 [00:24<00:00, 97.17it/s] 


Valid 18, 0.03533918200871396
Valid 18 Levenstein, 0.049367088607594936


100%|██████████| 667/667 [04:30<00:00,  2.46it/s]


Train 19, 0.0002930440269178478
Train 19 Levenstein, 2.3434570678665168e-05


100%|██████████| 2370/2370 [00:23<00:00, 101.73it/s]


Valid 19, 0.035319822273530255
Valid 19 Levenstein, 0.049367088607594936


100%|██████████| 667/667 [04:26<00:00,  2.50it/s]


Train 20, 0.0002680685745714151
Train 20 Levenstein, 4.6869141357330336e-05


100%|██████████| 2370/2370 [00:22<00:00, 105.06it/s]


Valid 20, 0.035387376012621326
Valid 20 Levenstein, 0.049367088607594936


In [49]:
with open(f'{experiment_name}.pth', 'rb') as fp:
    state_dict = torch.load(fp, map_location="cpu")
crnn = CRNN(rnn_bidirectional=True)
crnn.load_state_dict(state_dict)
crnn.to(device)
crnn.eval()
print('Model loaded!')

Model loaded!
