In [1]:
# Reload edited project modules automatically before each cell runs,
# so changes to the local `iqra` package are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

In [2]:
import os
import string
import numpy as np
from pathlib import Path
import matplotlib.pyplot as plt

from typing import * 


import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, ConcatDataset, Subset, DataLoader

import torchvision.transforms as VT


from iqra.data import dataset
from iqra.models import crnn

import iqra.transforms as NT
from iqra.data.dataset import LMDBDataset, BalanceDatasetConcatenator
from iqra.utils import AttnLabelConverter

In [3]:
import pytorch_lightning as pl
from pytorch_lightning import loggers as pl_loggers
from pytorch_lightning.metrics import Accuracy

class OCRTaks(pl.LightningModule):
    """Lightning task that trains and validates a text-recognition model.

    Each batch is an ``(images, texts)`` pair. The texts are encoded with
    ``converter``; the model receives the encoded text without its last
    position and is scored against the encoded text without its first position.

    Args:
        model: Network invoked as ``model(images, encoded_text)``.
        optimizer: Already-constructed optimizer; returned unchanged from
            ``configure_optimizers``.
        criterion: Loss invoked as ``criterion(flat_predictions, flat_targets)``.
        converter: Object whose ``encode(texts)`` returns a pair whose first
            item is the encoded text tensor.
        grad_clip: Maximum gradient norm applied after each backward pass.
    """

    def __init__(self, model, optimizer, criterion, converter, grad_clip=5.0):
        super().__init__()
        # Device placement is left to the Lightning Trainer, which moves the
        # whole module (including ``self.model``) before fitting starts.
        self.model = model

        self.optimizer = optimizer
        self.criterion = criterion
        self.converter = converter
        self.grad_clip = grad_clip

    def forward(self, imgs, text):
        """Run the wrapped model on ``imgs`` conditioned on ``text``."""
        # Previously this passed `texts`, a name that is not a parameter here
        # and only resolved if a notebook global of that name happened to exist.
        return self.model(imgs, text)

    def backward(self, trainer, loss, optimizer, optimizer_idx):
        """Back-propagate ``loss`` and clip the model's gradient norm.

        NOTE(review): the arguments Lightning passes to this hook differ between
        releases -- confirm this signature matches the installed version.
        The Trainer's ``gradient_clip_val`` option may be a simpler alternative.
        """
        loss.backward()
        nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_clip)

    def shared_step(self, batch, batch_idx):
        """Compute the loss for one batch; shared by training and validation.

        Args:
            batch: ``(images, texts)`` pair produced by the data loader.
            batch_idx: Index of the batch (unused).

        Returns:
            The value returned by ``criterion`` for this batch.
        """
        images, texts = batch
        # NOTE(review): the recorded run failed with a cuda/cpu device mismatch
        # under the 'dp' backend -- confirm `self.device` is the right target
        # inside data-parallel replicas.
        images = images.to(self.device)

        # The lengths returned by the converter are not needed for the loss.
        # The old code overwrote them with a second copy of the encoded text.
        texts_encoded, _ = self.converter.encode(texts)
        texts_encoded = texts_encoded.to(self.device)

        # Input drops the last position, target drops the first one.
        decoder_input = texts_encoded[:, :-1]
        targets = texts_encoded[:, 1:]

        preds = self.model(images, decoder_input)

        # Collapse every leading dimension so the criterion sees
        # (N, num_outputs) predictions against (N,) targets.
        flat_preds = preds.view(-1, preds.shape[-1])
        flat_targets = targets.contiguous().view(-1)
        return self.criterion(flat_preds, flat_targets)

    def training_step(self, batch, batch_idx):
        """Log and return the training loss for one batch."""
        loss = self.shared_step(batch, batch_idx)
        self.log('trn_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        # Return the loss itself; the former `return result` referenced an
        # undefined name left over from the removed Result API.
        return loss

    def validation_step(self, batch, batch_idx):
        """Log and return the validation loss for one batch."""
        loss = self.shared_step(batch, batch_idx)
        self.log('val_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        return loss

    def configure_optimizers(self):
        """Return the optimizer supplied at construction time."""
        return self.optimizer
    

In [4]:
# Use the GPU when CUDA is usable from this process, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')


In [5]:
# --- Data loading ---------------------------------------------------------
BATCH_SIZE = 2
NUM_WORKERS = 4

# --- Text and image settings ----------------------------------------------
BATCH_MAX_LENGTH = 25
# Digits, ASCII letters and punctuation: identical to string.printable
# without its six trailing whitespace characters.
CHARACTER = string.digits + string.ascii_letters + string.punctuation
IMG_SIZE = (32, 100)  # presumably (height, width) -- TODO confirm

# --- Optimisation ---------------------------------------------------------
# NOTE(review): LRATE is later passed to optim.Adam, whose documented default
# learning rate is 1e-3 -- confirm that 1.0 is intended.
LRATE = 1.0
BETA1, BETA2 = 0.9, 0.999
GRAD_CLIP = 5.0

In [6]:
# Training-time preprocessing: resize/pad to IMG_SIZE, convert to a tensor,
# then normalise with mean 0.5 and std 0.5.
# NOTE(review): ResizeRatioWithRightPad presumably keeps the aspect ratio and
# pads on the right -- confirm in iqra.transforms.
trn_transform = VT.Compose([
    NT.ResizeRatioWithRightPad(size=IMG_SIZE),
    VT.ToTensor(),
    # One-element tuples: `(0.5)` is just the float 0.5, whereas Normalize
    # documents `mean` and `std` as per-channel sequences.
    VT.Normalize(mean=(0.5,), std=(0.5,)),
])

# Validation-time preprocessing; currently the same pipeline as training.
val_transform = VT.Compose([
    NT.ResizeRatioWithRightPad(size=IMG_SIZE),
    VT.ToTensor(),
    VT.Normalize(mean=(0.5,), std=(0.5,)),
])


# Root folders of the training and validation data.
trn_path = '/data/clova_deeptext/training'
val_path = '/data/clova_deeptext/validation'

# Training split, read through LMDBDataset from the 'ST' and 'MJ' sub-folders.
# NOTE(review): usage_ratio presumably sets how much of each sub-folder is
# used -- confirm in iqra.data.dataset.
train_bdc = BalanceDatasetConcatenator(
    trn_path,
    dataset_class=LMDBDataset,
    subdir=('ST', 'MJ'),
    usage_ratio=(0.5, 0.5),
    transform=trn_transform,
    im_size=IMG_SIZE,
    is_sensitive=True,
)
trainset = train_bdc.get_dataset()

# Validation split; subdir and usage_ratio are left at the concatenator's defaults.
valid_bdc = BalanceDatasetConcatenator(
    val_path,
    dataset_class=LMDBDataset,
    transform=val_transform,
    im_size=IMG_SIZE,
    is_sensitive=True,
)
validset = valid_bdc.get_dataset()

In [7]:
# Sanity check: number of samples in the training and validation datasets.
len(trainset), len(validset)

(7221024, 6992)

In [8]:
# AttnLabelConverter is already imported in the imports cell; this repeats it.
from iqra.utils import AttnLabelConverter
# Build the label converter over the configured character set.
converter = AttnLabelConverter(CHARACTER)
# Number of entries in the converter's `character` collection.
# NOTE(review): a later cell rebuilds the converter and stores the same count
# as NUM_CLASS, which is the name the model cell uses.
num_class = len(converter.character)

In [9]:
# Training batches are reshuffled every epoch.
train_loader = DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
# Validation batches keep a fixed order: shuffling brings no benefit for
# evaluation and makes per-batch validation logs harder to compare across runs.
valid_loader = DataLoader(validset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)

In [10]:
# Pull a single batch to smoke-test the loading pipeline.
# This leaves `imgs` and `texts` defined as notebook-level globals.
imgs, texts =  next(iter(train_loader))

In [11]:
# Rebuild the label converter over CHARACTER (an earlier cell creates the same object).
converter = AttnLabelConverter(CHARACTER)
# Size of the converter's `character` collection; passed to the model as `num_class`.
NUM_CLASS = len(converter.character)

In [12]:
# Recognition network sized for the configured image size and class count.
model = crnn.OCRNet(
    im_size=IMG_SIZE,
    num_class=NUM_CLASS,
)

# Targets equal to index 0 do not contribute to the loss.
criterion = nn.CrossEntropyLoss(ignore_index=0)

# NOTE(review): LRATE is 1.0 in the config cell, while Adam's documented
# default learning rate is 1e-3 -- confirm this value is intended.
optimizer = optim.Adam(
    model.parameters(),
    betas=(BETA1, BETA2),
    lr=LRATE,
)

In [13]:
# Keep only the single best checkpoint, ranked by the lowest logged 'val_loss'.
checkpoint_path = 'checkpoints/'
checkpoint_callback = pl.callbacks.ModelCheckpoint(
    monitor='val_loss',
    mode='min',
    save_top_k=1,
    filepath=checkpoint_path,
    prefix='ocr_net_',
    verbose=True,
)

In [16]:
# TensorBoard event files are written under logs/ocr_net.
tb_logger = pl_loggers.TensorBoardLogger('logs/ocr_net')

# Pass the configured clipping threshold explicitly so that editing GRAD_CLIP
# in the config cell actually takes effect (it was previously ignored).
task = OCRTaks(model, optimizer, criterion, converter, grad_clip=GRAD_CLIP)

# Two GPUs with the DataParallel ('dp') backend.
trainer = pl.Trainer(gpus=2, logger=tb_logger, checkpoint_callback=checkpoint_callback, distributed_backend='dp')
trainer.fit(task, train_loader, valid_loader)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name      | Type             | Params
-----------------------------------------------
0 | model     | OCRNet           | 856 M 
1 | criterion | CrossEntropyLoss | 0     


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…

RuntimeError: All input tensors must be on the same device. Received cuda:0 and cpu

In [14]:
# !nvidia-smi

In [15]:
# !pip install pytorch-lightning==1.0