In [1]:
import torch
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import ConcatDataset, Subset

import numpy as np
from nltk.metrics.distance import edit_distance

import torch.backends.cudnn as cudnn
# Let cuDNN benchmark and cache the fastest convolution algorithms. The collate step below
# is configured for a fixed imgH=64 x imgW=600 input, so autotuning pays off. The price is
# that runs are not bit-for-bit reproducible.
cudnn.benchmark, cudnn.deterministic = True, False

In [2]:
from dataset import OCRDatasetModified, AlignCollate
from utils import CTCLabelConverter, Averager

In [3]:
# Use CUDA when a GPU is visible, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Longest transcription (in characters) the datasets are meant to accept.
LABEL_MAX_LENGTH = 34

In [4]:
import sys
sys.path.insert(0, "..")

import easyocr
def get_training_convertor(ref_converter:easyocr.utils.CTCLabelConverter):
    """Return a training-side ``CTCLabelConverter`` equivalent to ``ref_converter``.

    If ``ref_converter`` already is a training ``CTCLabelConverter``, it is returned
    unchanged. Otherwise a fresh training converter is built from the reference
    alphabet (``ref_converter.character`` without its entry 0). Then the reference's
    ``separator_list``, ``ignore_idx``, ``dict_list`` and ``dict`` attributes are
    assigned onto it. They are shared with the reference, not copied.
    """
    if not isinstance(ref_converter, CTCLabelConverter):
        alphabet = "".join(ref_converter.character[1:])
        training_converter = CTCLabelConverter(alphabet)
        for attr_name in ("separator_list", "ignore_idx", "dict_list", "dict"):
            setattr(training_converter, attr_name, getattr(ref_converter, attr_name))
        return training_converter
    return ref_converter

# Set up the pretrained recognizer and a training-compatible label converter.
reader = easyocr.Reader(["ch_tra"])
model = reader.recognizer
ref_converter = reader.converter
# The alphabet string (`character`) is built in the dataset cell where it is used,
# so it is not duplicated here.
converter = get_training_convertor(ref_converter)
assert isinstance(converter, CTCLabelConverter)



In [5]:
# Choose which stages of the pretrained recognizer stay fixed during fine-tuning.
freeze_FeatureExtraction = True
freeze_SequenceModeling = False

# `model` may or may not be wrapped in torch.nn.DataParallel. NOTE(review): easyocr
# appears to wrap the recognizer only when it runs on GPU, so confirm this for CPU runs.
# Unwrap when present so this cell also works on a bare module.
core_model = getattr(model, "module", model)

if freeze_FeatureExtraction:
    for param in core_model.FeatureExtraction.parameters():
        param.requires_grad = False
if freeze_SequenceModeling:
    for param in core_model.SequenceModeling.parameters():
        param.requires_grad = False

# NOTE(review): requires_grad=False only stops gradient updates. If FeatureExtraction
# contains BatchNorm layers, their running statistics still change whenever the model
# is in train mode, so the stage is not completely frozen.

In [6]:
# CTC loss. zero_infinity=True zeroes infinite losses and their gradients; these mainly
# occur when a prediction sequence is too short to be aligned with its target.
criterion = torch.nn.CTCLoss(zero_infinity=True)
criterion = criterion.to(DEVICE)

In [7]:
# Adadelta over the parameters that are still trainable (frozen parameters are skipped).
lr, rho, eps = 1., 0.95, 1e-8
filtered_parameters = [param for param in model.parameters() if param.requires_grad]
optimizer = optim.Adadelta(filtered_parameters, lr=lr, rho=rho, eps=eps)

In [8]:
# setup dataset
# Alphabet of the pretrained recognizer with entry 0 dropped, matching get_training_convertor.
character = ''.join(ref_converter.character[1:])

# NOTE(review): the "Ignore data whose label is longer than 34" report recorded under this
# cell is identical (same filenames, same labels) for en_train and en_val. Confirm that the
# validation split is not a copy of the training split. If it is, the validation metrics
# measure memorisation rather than generalisation.
training_set_roots = ["./all_data/en_train"]
train_dataset = ConcatDataset(
    [OCRDatasetModified(root=root, character=character, label_max_length=LABEL_MAX_LENGTH)
     for root in training_set_roots]
)
aligncollate = AlignCollate(imgH=64, imgW=600, keep_ratio_with_pad=False, contrast_adjust=0)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, collate_fn=aligncollate, shuffle=True)

validation_set_roots = ["./all_data/en_val"]
val_dataset = ConcatDataset(
    [OCRDatasetModified(root=root, character=character, label_max_length=LABEL_MAX_LENGTH)
     for root in validation_set_roots]
)
# Evaluation order does not affect the metrics, so the validation set is not shuffled.
# The default prefetch_factor (2 batches per worker) replaces 512, which asked each of the
# 6 workers to buffer up to 512 batches.
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=6, collate_fn=aligncollate)

Ignore data whose label is longer than 34: 
    filename                                words
64    44.jpg  (895261) Greenery {Wemyss-Islamist}
402  454.jpg  Tuktamysheva (resin) Technologies !
427  490.jpg  Fourteenth . Naiads injurious_Issue
498  571.jpg  Equalization LIGURIA carbohydrate [
781  833.jpg  Buys-Horwood misinterpreting Twitch
Ignore data whose label is longer than 34: 
    filename                                words
64    44.jpg  (895261) Greenery {Wemyss-Islamist}
402  454.jpg  Tuktamysheva (resin) Technologies !
427  490.jpg  Fourteenth . Naiads injurious_Issue
498  571.jpg  Equalization LIGURIA carbohydrate [
781  833.jpg  Buys-Horwood misinterpreting Twitch


In [9]:
def training_epoch(model:torch.nn.Module, criterion:torch.nn.CTCLoss, convertor:CTCLabelConverter, optimizer:torch.optim.Optimizer, training_set_loader:torch.utils.data.DataLoader):
    """Run one pass over ``training_set_loader`` and update ``model`` in place.

    Args:
        model: recognizer whose output is treated as (batch, T, num_class) scores.
            Images are moved to the global ``DEVICE``.
        criterion: CTC loss, fed log-probabilities permuted to (T, batch, num_class).
        convertor: its ``encode(labels)`` yields (targets, target_lengths) for the loss.
        optimizer: steps the trainable parameters after gradient-norm clipping at 5.
        training_set_loader: yields ``(image_tensors, labels)`` batches.

    Returns:
        1-D numpy array with the CTC loss of every batch, in iteration order
        (empty if the loader yields nothing).
    """
    # Don't rely on the caller leaving the model in training mode (an earlier
    # evaluation/inference call may have switched it to eval).
    model.train()
    losses = []
    for image_tensors, labels in training_set_loader:
        image = image_tensors.to(DEVICE)
        text, length = convertor.encode(labels)
        batch_size = image.size(0)

        preds = model(image, text).log_softmax(2)
        # One input length per sample (1-D, shape (batch,)), as CTCLoss expects; every
        # sample uses the full prediction length T.
        preds_size = torch.IntTensor([preds.size(1)] * batch_size)
        preds = preds.permute(1, 0, 2)  # (batch, T, C) -> (T, batch, C)

        # cuDNN is switched off only around the loss call. The previous setting is
        # restored even if the loss raises, so later cells are not left without cuDNN.
        cudnn_was_enabled = torch.backends.cudnn.enabled
        torch.backends.cudnn.enabled = False
        try:
            cost = criterion(preds, text.to(DEVICE), preds_size.to(DEVICE), length.to(DEVICE))
        finally:
            torch.backends.cudnn.enabled = cudnn_was_enabled

        optimizer.zero_grad(set_to_none=True)
        cost.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 5.)
        optimizer.step()

        losses.append(cost.item())

    return np.asarray(losses)


In [10]:
def _normalized_edit_similarity(pred, gt):
    """Score one prediction against its label as 1 - edit_distance / max(len(pred), len(gt)).

    Returns 0.0 when either string is empty. This also scores an empty prediction of an
    empty label as 0, which preserves the rule this notebook used before.
    """
    if len(gt) == 0 or len(pred) == 0:
        return 0.
    return 1 - edit_distance(pred, gt) / max(len(gt), len(pred))


def validation(model:torch.nn.Module, 
               criterion:torch.nn.CTCLoss, 
               converter:CTCLabelConverter, 
               validation_set_loader:torch.utils.data.DataLoader,
               *,
               DEVICE= torch.device('cuda' if torch.cuda.is_available() else 'cpu')):
    """Evaluate ``model`` on ``validation_set_loader`` without updating its weights.

    Returns:
        dict with
          - "average CTCLoss": mean of the per-batch CTC losses,
          - "acc": percentage of samples whose greedy decoding equals the label exactly,
          - "norm_ED": mean of 1 - ED(pred, gt) / max(len(pred), len(gt)) over samples
            (0 for a sample where either string is empty).

    The model is put in eval mode for the pass and is always left in training mode
    afterwards, including when evaluation raises.

    Raises:
        ValueError: if the loader yields no samples.
    """
    n_correct = 0
    length_of_data = 0
    norm_ED = 0.
    losses = []

    model.eval()
    try:
        with torch.no_grad():
            for image_tensors, labels in validation_set_loader:
                image = image_tensors.to(DEVICE)
                text, length = converter.encode(labels)
                batch_size = image.size(0)

                preds = model(image, text)  # treated as (batch, T, num_class)
                preds_size = torch.IntTensor([preds.size(1)] * batch_size)

                # Compute the loss the same way training_epoch does (cuDNN off, every
                # argument on DEVICE) so that training and validation losses are comparable.
                cudnn_was_enabled = torch.backends.cudnn.enabled
                torch.backends.cudnn.enabled = False
                try:
                    cost = criterion(preds.log_softmax(2).permute(1, 0, 2),
                                     text.to(DEVICE), preds_size.to(DEVICE), length.to(DEVICE))
                finally:
                    torch.backends.cudnn.enabled = cudnn_was_enabled
                losses.append(cost.item())

                # Greedy decoding of the arg-max class at every timestep.
                _, preds_index = preds.max(2)
                preds_str = converter.decode_greedy(preds_index.view(-1).cpu(), preds_size)

                for gt, pred in zip(labels, preds_str):
                    if pred == gt:
                        n_correct += 1
                    norm_ED += _normalized_edit_similarity(pred, gt)

                length_of_data += batch_size
    finally:
        model.train()

    if length_of_data == 0:
        raise ValueError("validation_set_loader yielded no samples")

    return {"average CTCLoss": np.mean(losses),
            "acc": n_correct / length_of_data * 100,
            "norm_ED": norm_ED / length_of_data}

In [11]:
from pathlib import Path

# Fine-tune for a fixed number of epochs, validating and checkpointing after each one.
# NOTE(review): the output recorded under this cell (no leading epoch number, tuples instead
# of dicts) does not match these print calls. It was presumably produced by an earlier
# version of the cell, so re-run to refresh it.
NUM_EPOCHS = 10
CHECKPOINT_DIR = Path("./saved_models/OvO")
# torch.save does not create missing directories. Create the target up front instead of
# failing after the first (expensive) epoch.
CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)

for epoch in range(NUM_EPOCHS):
    train_losses = training_epoch(model, criterion, converter, optimizer, training_set_loader=train_loader)
    print(epoch, train_losses.mean(), train_losses.std())
    val_result = validation(model, criterion, converter, val_loader)
    print(epoch, val_result)
    torch.save(model.state_dict(), CHECKPOINT_DIR / f'iter_{epoch+1}.pth')



2.3729794 0.66403234
(1.4258851, 28.859060402684566, 0.7128050372282175)
1.47491 0.28819767
(1.1324071, 32.88590604026846, 0.7485286340649494)
1.1652516 0.29703745
(0.9320845, 43.064876957494405, 0.8029135862399721)
0.9866853 0.20075099
(0.7221694, 47.98657718120805, 0.8430779400924748)
0.8300511 0.20905429
(0.59703606, 55.70469798657718, 0.8723255568649328)
0.7169686 0.19908415
(0.5108433, 61.07382550335571, 0.8929523013505967)
0.63380766 0.18276386
(0.40636724, 64.42953020134227, 0.9090372857039414)
0.56100684 0.19664545
(0.37920424, 64.5413870246085, 0.8972572521431676)
0.45578715 0.107737504
(0.3061997, 71.81208053691275, 0.9298603969579857)
0.40657172 0.10174054
(0.23246852, 74.16107382550335, 0.9391848116875982)
