In [1]:
# Enable IPython's autoreload extension in mode 2: imported modules are
# re-imported before each cell executes, so edits to on-disk source files
# (e.g. the project's own package) take effect without a kernel restart.
%load_ext autoreload
%autoreload 2

In [27]:
import os
import string
import numpy as np
from pathlib import Path
import matplotlib.pyplot as plt

from typing import * 


import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, ConcatDataset, Subset, DataLoader

import torchvision.transforms as VT


from ocrnune.data import dataset
from ocrnune.models import crnn

import ocrnune.transforms as NT
from ocrnune.data.dataset import LMDBDataset, BalanceDatasetConcatenator
from ocrnune.utils import AttnLabelConverter


In [28]:
# --- Data loading -----------------------------------------------------------
BATCH_SIZE = 4    # samples per batch for the DataLoaders
NUM_WORKERS = 4   # DataLoader worker processes

# --- Text / image geometry --------------------------------------------------
# NOTE(review): BATCH_MAX_LENGTH is not referenced in the visible cells —
# presumably the maximum label length; confirm where it is consumed.
BATCH_MAX_LENGTH = 25
# All printable ASCII characters except whitespace: digits, letters, punctuation.
CHARACTER = string.digits + string.ascii_letters + string.punctuation
IMG_SIZE = (32, 100)

# --- Optimisation -----------------------------------------------------------
LRATE = 1.0       # learning rate handed to the optimizer
BETA1 = 0.9       # first Adam beta
BETA2 = 0.999     # second Adam beta
GRAD_CLIP = 5.0   # gradient clipping threshold intended for the training task

In [19]:
def _build_transform():
    """Return the preprocessing pipeline shared by training and validation.

    Steps: resize to IMG_SIZE via the project's ResizeRatioWithRightPad,
    convert to a tensor, then normalize with mean 0.5 / std 0.5.
    """
    return VT.Compose([
        NT.ResizeRatioWithRightPad(size=IMG_SIZE),
        VT.ToTensor(),
        # `(0.5)` is just the float 0.5; Normalize documents a per-channel
        # sequence, so pass a real one-element tuple.
        # NOTE(review): a single value assumes one-channel images (or relies on
        # broadcasting for more channels) — confirm against the dataset.
        VT.Normalize(mean=(0.5,), std=(0.5,)),
    ])


# Separate instances so either pipeline can later diverge (e.g. augmentation).
trn_transform = _build_transform()
val_transform = _build_transform()


trn_path = '/data/lmdb/data_lmdb_release/training'
val_path = '/data/lmdb/data_lmdb_release/validation'


# Training set: concatenation of the 'ST' and 'MJ' sub-directories with the
# given usage ratios.
train_bdc = BalanceDatasetConcatenator(trn_path, dataset_class=LMDBDataset,
                                       transform=trn_transform,
                                       subdir=('ST', 'MJ'), usage_ratio=(0.5, 0.5),
                                       im_size=IMG_SIZE, is_sensitive=True)
trainset = train_bdc.get_dataset()


# Validation set: built with the concatenator's default subdir / usage_ratio.
valid_bdc = BalanceDatasetConcatenator(val_path, dataset_class=LMDBDataset,
                                       transform=val_transform,
                                       im_size=IMG_SIZE, is_sensitive=True)
validset = valid_bdc.get_dataset()

In [20]:
# Display the number of samples in the training and validation datasets.
len(trainset), len(validset)

(7221024, 6992)

In [21]:
# NOTE(review): AttnLabelConverter is already imported in the imports cell, and
# a later cell rebuilds `converter` and defines NUM_CLASS (the name the model
# cell uses); this cell looks redundant and is a candidate for removal.
from ocrnune.utils import AttnLabelConverter
# Label converter built over the CHARACTER set.
converter = AttnLabelConverter(CHARACTER)
# Number of entries in the converter's character table.
num_class = len(converter.character)

In [22]:
# Training batches are drawn in a new random order each epoch.
train_loader = DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
# Validation iterates in a fixed order: shuffling adds nothing to evaluation,
# makes per-batch validation logs non-comparable between runs, and triggers a
# Lightning warning for validation dataloaders.
valid_loader = DataLoader(validset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)

In [23]:
# Smoke test of the data pipeline: pull a single (images, texts) batch.
# `imgs` and `texts` become notebook-level names; code defined in later cells
# should take its inputs as arguments rather than rely on these globals.
imgs, texts =  next(iter(train_loader))

In [31]:
# Label converter over CHARACTER; it is also handed to the training task,
# which calls its `encode` on each batch of texts.
converter = AttnLabelConverter(CHARACTER)
# Number of entries in the converter's character table; used as the model's
# `num_class`.
NUM_CLASS = len(converter.character)

In [32]:
# OCR network configured with NUM_CLASS output classes and IMG_SIZE inputs.
model = crnn.OCR(num_class=NUM_CLASS, im_size=IMG_SIZE)
# Target positions whose class index is 0 do not contribute to the loss.
# NOTE(review): presumably index 0 is the converter's padding/start token —
# confirm in AttnLabelConverter.
criterion = nn.CrossEntropyLoss(ignore_index=0)
# NOTE(review): LRATE is 1.0, far above Adam's default of 1e-3; confirm this is
# intentional and not a value carried over from a different optimizer.
optimizer = optim.Adam(model.parameters(), lr=LRATE, betas=(BETA1, BETA2))

In [42]:
import pytorch_lightning as pl
from pytorch_lightning import loggers as pl_loggers
from pytorch_lightning.metrics import Accuracy

class OCRTaks(pl.LightningModule):
    """Lightning task that trains and validates a text-conditioned OCR model.

    Each batch is a pair ``(images, texts)``. The texts are encoded with
    ``converter.encode``; the model receives the encoded sequence without its
    last position and is scored against the same sequence shifted left by one.

    Args:
        model: network called as ``model(images, encoded_text)``.
        optimizer: already-constructed optimizer over ``model``'s parameters;
            returned to Lightning from ``configure_optimizers``.
        criterion: loss called as ``criterion(flat_predictions, flat_targets)``.
        converter: object whose ``encode(texts)`` returns
            ``(encoded_tensor, lengths)``.
        grad_clip: maximum gradient norm applied after every backward pass.
    """

    def __init__(self, model, optimizer, criterion, converter, grad_clip=5.0):
        super().__init__()
        self.model = model
        self.optimizer = optimizer
        self.criterion = criterion
        self.converter = converter
        self.grad_clip = grad_clip

    def forward(self, imgs, text):
        """Run the wrapped model on ``imgs`` conditioned on ``text``."""
        # Use the argument, not a notebook-level global of a similar name.
        return self.model(imgs, text)

    def configure_optimizers(self):
        """Hand the optimizer supplied at construction time to Lightning."""
        return self.optimizer

    def backward(self, trainer, loss, optimizer, optimizer_idx):
        """Back-propagate ``loss`` and clip the model's gradient norm."""
        loss.backward()
        nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_clip)

    def shared_step(self, batch, batch_idx):
        """Compute the loss for one batch; shared by training and validation.

        Returns:
            The scalar loss produced by ``self.criterion``.
        """
        images, texts = batch
        texts_encoded, _ = self.converter.encode(texts)
        # Lightning moves the image tensor to the training device, but the
        # labels are built here from raw strings; align them explicitly
        # (a no-op when the converter already places them on that device).
        texts_encoded = texts_encoded.to(images.device)

        # Input drops the final position; target drops the first one.
        decoder_input = texts_encoded[:, :-1]
        targets = texts_encoded[:, 1:]
        preds = self.model(images, decoder_input)

        # Flatten all leading dimensions so the criterion sees one row of
        # class scores per target position.
        flat_preds = preds.view(-1, preds.shape[-1])
        flat_targets = targets.contiguous().view(-1)
        return self.criterion(flat_preds, flat_targets)

    def training_step(self, batch, batch_idx):
        """Return a TrainResult that minimizes and logs the batch loss."""
        loss = self.shared_step(batch, batch_idx)
        result = pl.TrainResult(loss)
        result.log_dict({'trn_loss': loss})

        return result

    def validation_step(self, batch, batch_idx):
        """Return an EvalResult that logs the batch loss and checkpoints on it."""
        loss = self.shared_step(batch, batch_idx)
        result = pl.EvalResult(checkpoint_on=loss)
        result.log_dict({'val_loss': loss})

        return result
    

In [43]:
# Location handed to ModelCheckpoint as `filepath`.
checkpoint_path = '../saved_model'

# Keep only the single best checkpoint, where "best" is the smallest value of
# the 'checkpoint_on' quantity; saved files are prefixed with 'ocr_net_'.
checkpoint_callback = pl.callbacks.ModelCheckpoint(
    filepath=checkpoint_path,
    prefix='ocr_net_',
    monitor='checkpoint_on',
    mode='min',
    save_top_k=1,
    verbose=True,
)

In [48]:
# TensorBoard event files are written under ../logs/ocr_net.
tb_logger = pl_loggers.TensorBoardLogger('../logs/ocr_net')
# Pass the configured GRAD_CLIP explicitly so the clipping threshold is
# controlled from the config cell instead of silently using the class default.
task = OCRTaks(model, optimizer, criterion, converter, grad_clip=GRAD_CLIP)
# Single-GPU training with the TensorBoard logger and checkpoint callback
# defined in the cells above.
trainer = pl.Trainer(gpus=1, logger=tb_logger, checkpoint_callback=checkpoint_callback)
trainer.fit(task, train_loader, valid_loader)

In [45]:
# Shell out to the NVIDIA CLI to print GPU / driver status.
!nvidia-smi

Failed to initialize NVML: Driver/library version mismatch
