In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
import string
import numpy as np
from pathlib import Path
import matplotlib.pyplot as plt

from typing import * 


import torch
from torch.utils.data import Dataset, ConcatDataset, Subset, DataLoader

import torchvision.transforms as VT


from ocrnune.data import dataset
from ocrnune.models import crnn

import ocrnune.transforms as NT
from ocrnune.data.dataset import LMDBDataset, BalanceDatasetConcatenator, AlignCollate


In [3]:
# enc = crnn.EncoderOCR()
# test_data = torch.rand(3,1,224,224)
# out = enc(test_data)
# print(out.shape)


In [4]:
batch_size = 4
batch_max_length = 25
character = string.printable[:-6]
img_size = (32,100)

trn_transform = VT.Compose([
    NT.ResizeRatioWithRightPad(size=img_size),
    VT.ToTensor(),
    VT.Normalize(mean=(0.5), std=(0.5))  
])

val_transform = VT.Compose([
    NT.ResizeRatioWithRightPad(size=img_size),
    VT.ToTensor(),
    VT.Normalize(mean=(0.5), std=(0.5))  
])


trn_path = '/run/user/1000/gvfs/sftp:host=tigapilar.fn/data/lmdb/data_lmdb_release/training'
val_path = '/run/user/1000/gvfs/sftp:host=tigapilar.fn/data/lmdb/data_lmdb_release/validation'



train_bdc = BalanceDatasetConcatenator(trn_path, dataset_class=LMDBDataset, 
                                       transform=trn_transform,
                                       subdir=('ST', 'MJ'), usage_ratio=(0.5, 0.5),
                                       im_size=img_size, is_sensitive=False)
trainset = train_bdc.get_dataset()


valid_bdc = BalanceDatasetConcatenator(val_path, dataset_class=LMDBDataset, 
                                       transform=val_transform,
                                       im_size=img_size, is_sensitive=False)
validset = valid_bdc.get_dataset()


In [5]:
len(trainset), len(validset)

(7221024, 6992)

In [6]:
from ocrnune.utils import AttnLabelConverter
converter = AttnLabelConverter(character)

train_loader = DataLoader(trainset, batch_size=4)
imgs, texts =  next(iter(train_loader))



In [7]:
texts_encode, length = converter.encode(texts)
print(texts)
print(texts_encode)

('also', 'easy', 'the', 'world')
tensor([[ 0, 12, 23, 30, 26,  1,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  0,  0,  0,  0,  0,  0,  0],
        [ 0, 16, 12, 30, 36,  1,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  0,  0,  0,  0,  0,  0,  0],
        [ 0, 31, 19, 16,  1,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  0,  0,  0,  0,  0,  0,  0],
        [ 0, 34, 26, 29, 23, 15,  1,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  0,  0,  0,  0,  0,  0,  0]])


In [8]:
converter.decode(texts_encode, length)

['[GO]also[s][GO][GO][GO][GO][GO][GO][GO][GO][GO][GO][GO][GO][GO][GO][GO][GO][GO][GO][GO][GO][GO]',
 '[GO]easy[s][GO][GO][GO][GO][GO][GO][GO][GO][GO][GO][GO][GO][GO][GO][GO][GO][GO][GO][GO][GO][GO]',
 '[GO]the[s][GO][GO][GO][GO][GO][GO][GO][GO][GO][GO][GO][GO][GO][GO][GO][GO][GO][GO][GO][GO][GO][GO]',
 '[GO]world[s][GO][GO][GO][GO][GO][GO][GO][GO][GO][GO][GO][GO][GO][GO][GO][GO][GO][GO][GO][GO]']

In [9]:
num_class = len(converter.character)
num_class

96

In [10]:
model = crnn.OCR(num_class=num_class, im_size=img_size)

In [12]:
out = model(imgs, texts_encode[:, :-1])

In [15]:
out[0].shape

torch.Size([26, 96])

In [40]:
import torch.nn.functional as NNF

batch_size = 4 
batch_max_length = 25
# pred_length = batch_size * batch_max_length
pred_length = torch.IntTensor([batch_max_length] * batch_size)
# preds_prob = NNF.softmax(out, dim=2)

preds_prob = NNF.softmax(out, dim=2)
_, preds_index = preds_prob.max(dim=2)
# preds_max_prob[0]

preds_str = converter.decode(preds_index[:, 1:], pred_length)
# preds_index.shape
preds_str

['LL#x@<<<KKLLLLLLLLLLLLLLL',
 'LLxL<<<<LLLLLLLLLLLLLLLLL',
 'LLL@@<LLLLLLLLLLLLLLLLLLL',
 'LLLxx@<<LLLLLLLLLLLLLLLLL']

In [1]:
valid_bdc.base_datamap

NameError: name 'valid_bdc' is not defined

In [41]:
bdc.base_datamap
# img, text = lmdb_dataset[0]
# text

{'/': {'ratio': 1,
  'dirpath': ['/run/user/1000/gvfs/sftp:host=tigapilar.fn/data/lmdb/data_lmdb_release/validation'],
  'total_length': 6992}}

In [54]:

subdir = ('MJ','ST')
batch_ratio = (0.5, 0.3)
batch_size = 64
total_data_usage_ratio = 1.0


dataloader_list = []
dataloader_iter_list = []

batch_size_list = []
total_batch_size = 0

num_worker = 1


align_collate_fn = AlignCollate(im_size=(32,100), keep_ratio_with_pad=True)


for sdir, bratio in zip(subdir, batch_ratio):
    dconcat = DatasetConcatenator(path, subdir=(sdir), 
                                  dataset_class=LMDBDataset, 
                                  im_size=(32,100), is_sensitive=False)
    dataset = dconcat.get_concat_dataset()
    dataset_length = len(dataset)

    number_dataset = int(dataset_length * float(total_data_usage_ratio))
    dataset_split = [number_dataset, dataset_length - number_dataset]
    
    indices = range(dataset_length)

    nbatch_size = max(round(batch_size * float(bratio)), 1)

    dataset, _ = [Subset(dataset, indices[offset - length:offset])
                           for offset, length in zip(_accumulate(dataset_split), dataset_split)]

    batch_size_list.append(str(batch_size))
    total_batch_size += nbatch_size
            
    dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=nbatch_size,
        shuffle=True,
        num_workers=num_worker,
        collate_fn=align_collate_fn, pin_memory=True
    )
    dataloader_list.append(dataloader)
    dataloader_iter_list.append(iter(dataloader))

    batch_size_sum = '+'.join(batch_size_list)
    batch_size = total_batch_size

    # _dataset, _ = [
    #     Subset(_dataset, indices[offset - length:offset]) for offset, length in zip(_accumulate(dataset_split), dataset_split)
    # ]

    print(f'nbatch_size: {nbatch_size} \n'
          f'number_dataset: {number_dataset} \n'
          f'dataset_split: {dataset_split} \n'
          f'indices: {len(indices)} \n'
          f'dataset last length: {len(dataset)}')





    
    # print(nbatch_size)

nbatch_size: 32 
number_dataset: 8919241 
dataset_split: [8919241, 0] 
indices: 8919241 
dataset last length: 8919241
nbatch_size: 10 
number_dataset: 5522808 
dataset_split: [5522808, 0] 
indices: 5522808 
dataset last length: 5522808


In [50]:
dataloader_list

[<torch.utils.data.dataloader.DataLoader at 0x7fc2ed8ac1d0>,
 <torch.utils.data.dataloader.DataLoader at 0x7fc2ed8a74e0>]

In [37]:
dataset_subsets[0]

<torch.utils.data.dataset.Subset at 0x7fc2ed926748>

In [72]:
ratio = (1,0)
len(ratio)

2