In [None]:
import sys
# Append the relative entry '..' (the parent directory) to the module search path.
# Presumably this is what makes the project-local `data` and `utils` imports in
# the next cell resolvable when the notebook runs from its own folder — confirm.
sys.path.append('..')

In [None]:
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader
from nvidia.dali.plugin.pytorch import LastBatchPolicy

from data import OCRDataset, OCRCollator, data_transforms, data_transforms_2, process_tgt, Vocab, LightningWrapper, ExternalEncodeCallable
from utils import AttnLabelConverter, CTCLabelConverter
from loguru import logger
import os
from nvidia.dali.pipeline import pipeline_def
import nvidia.dali.types as types
import nvidia.dali.fn as fn
import numpy as np
import random

In [None]:
class ExternalInputCallable(object):
    """Per-sample source callable returning raw encoded image bytes and a label.

    An instance is called with a sample-info object and reads only its
    ``idx_in_epoch`` attribute. The (image name, label) pairs are shuffled once,
    at construction time, so the order is the same for every later call.

    Args:
        steps_per_epoch: Stored on the instance; not used by this class itself.
        data_path: Directory that every image name is joined onto.
        converter: Stored on the instance; not used by this class itself.
        images_names: Iterable of image file names relative to ``data_path``.
        labels: Iterable of labels, positionally aligned with ``images_names``.
        batch_size: Stored on the instance; not used by this class itself.

    Raises:
        ValueError: If ``images_names`` and ``labels`` differ in length.
    """

    def __init__(self, steps_per_epoch, data_path, converter, images_names, labels, batch_size=32):
        self.data_path = data_path
        self.steps_per_epoch = steps_per_epoch
        self.converter = converter
        self.batch_size = batch_size

        self.images_names = images_names
        self.labels = labels

        # Materialise both inputs so their sizes can be compared: a bare zip()
        # would silently drop the unmatched tail of the longer one and hide a
        # corrupted annotation list.
        names = list(images_names)
        targets = list(labels)
        if len(names) != len(targets):
            raise ValueError(
                f"images_names and labels must have the same length, "
                f"got {len(names)} and {len(targets)}"
            )

        self.data = list(zip(names, targets))
        random.shuffle(self.data)

    def __call__(self, sample_info):
        """Return ``(image_bytes, label)`` for ``sample_info.idx_in_epoch``.

        Returns:
            A 1-D ``np.uint8`` array holding the undecoded file contents, and the
            label exactly as it was supplied to the constructor.

        Raises:
            StopIteration: When the requested index is past the last sample.
        """
        idx = sample_info.idx_in_epoch
        if idx >= len(self.data):
            logger.debug(f"Trigger skip with {idx=} and {len(self.data)=}")
            # Indicate end of the epoch
            raise StopIteration()

        # The guard above already guarantees idx < len(self.data), so no
        # wrap-around (modulo) is needed when indexing.
        image_name, label = self.data[idx]
        image_path = os.path.join(self.data_path, image_name)

        with open(image_path, 'rb') as f:
            file_bytes = f.read()

        # Zero-copy, read-only view over the encoded bytes; decoding happens
        # downstream of this callable.
        image = np.frombuffer(file_bytes, dtype=np.uint8)
        return image, label

In [None]:
# Build an AttnLabelConverter with its default arguments and bind it to `encoder`.
# NOTE(review): `encoder` is not referenced by any other visible cell — confirm
# whether it is meant to be the `converter` handed to the pipeline below.
encoder = AttnLabelConverter() 

In [None]:
# NOTE(review): the decorator sets batch_size=32 while the source callable below is
# given `self.batch_size`; confirm the two are meant to agree.
@pipeline_def(num_threads=8, batch_size=32, device_id=0, py_start_method="spawn", exec_dynamic=True)
def get_dali_train_pipeline_webdataset(self):
    """Define the training graph: read samples, augment the images, pad the outputs.

    Samples come from a per-sample (``batch=False``), parallel
    ``fn.external_source`` backed by ``ExternalInputCallable`` and declared with
    two outputs, bound here to ``images`` and ``labels``. The images are decoded
    to RGB, rotated, resized, colour-jittered, affinely scaled, noised,
    normalized and padded, all with randomly drawn parameters.

    Args:
        self: Object supplying ``steps_per_epoch``, ``train_data_path``,
            ``converter``, ``train_images_names``, ``train_labels`` and
            ``batch_size``. NOTE(review): in the visible cells this is a
            module-level function, so ``self`` is an ordinary argument —
            presumably lifted from a class method; confirm.

    Returns:
        ``images``, ``indices`` and ``length``, each after ``.gpu()``.

    NOTE(review): ``indices`` and ``length`` are never assigned in this function
    and are not defined in any visible cell, while ``labels`` is never used.
    Unless those names are bound elsewhere, the ``fn.pad(indices, ...)`` line
    raises NameError. The external source is declared with two outputs, so the
    three-value return looks left over from a different source callable —
    confirm the intended outputs before relying on this pipeline.
    """
    images, labels = fn.external_source(
        source=ExternalInputCallable(
            steps_per_epoch = self.steps_per_epoch,
            data_path = self.train_data_path,
            converter = self.converter,
            images_names = self.train_images_names,
            labels = self.train_labels,
            batch_size=self.batch_size
        ),
        num_outputs=2,
        batch=False,
        parallel=True,
        # NOTE(review): the second output is declared as STRING — confirm that the
        # label objects returned by the callable are accepted under this dtype.
        dtype=[types.UINT8, types.STRING],
        prefetch_queue_depth=8,
    )
    
    # Decode the encoded file bytes into RGB images on the CPU.
    images = fn.decoders.image(images, device="cpu", output_type=types.RGB)
    # Rotate by an angle drawn uniformly from [-1, 1]; the output is FLOAT.
    images = fn.rotate(images, device="cpu", angle=fn.random.uniform(range=[-1, 1]), dtype=types.FLOAT)
    # Only resize_y is given, so just the target height (100) is pinned here.
    images = fn.resize(images, device="cpu", resize_y=100)
    # Colour jitter: brightness/contrast/saturation drawn from [0.8, 1.2], hue from [0, 0.3].
    images = fn.color_twist(images, brightness=fn.random.uniform(range=[0.8, 1.2]), contrast=fn.random.uniform(range=[0.8, 1.2]), saturation=fn.random.uniform(range=[0.8, 1.2]), hue=fn.random.uniform(range=[0, 0.3]))
    # Affine scaling with two independent factors drawn from [0.9, 1]; uncovered pixels are filled with 0.
    images = fn.warp_affine(images, matrix=fn.transforms.scale(scale=fn.random.uniform(range=[0.9, 1], shape=[2])), fill_value=0, inverse_map=False)
    # NOTE(review): stddev is drawn from [-10, 10], which includes negative values —
    # confirm a negative standard deviation is valid/intended for fn.noise.gaussian.
    images = fn.noise.gaussian(images, mean=0.0, stddev=fn.random.uniform(range=[-10, 10])) 
    images = fn.normalize(images, device="cpu", dtype=types.FLOAT)
    # Pad with zeros (fill_value=0).
    images = fn.pad(images, fill_value=0)
    # NOTE(review): `indices` and `length` are unbound at this point (see docstring).
    indices = fn.pad(indices, fill_value=0)
    length = fn.pad(length, fill_value=0)
    return images.gpu(), indices.gpu(), length.gpu()