# Importing the required packages

In [15]:
# Data preprocessing packages
import pickle
import random
from collections import namedtuple
from typing import Tuple

import cv2
import lmdb
import numpy as np
import pandas as pd
from path import Path
import os
from os import walk

In [16]:
# Directory that holds the pooled word images (trailing slash is relied on
# later, when file names are appended by plain string concatenation).
images_dir = "alexuw/pooledImages/"

# Only the first os.walk() entry is consumed, i.e. the files directly inside
# images_dir (sub-directories are not descended into).  When walk() yields
# nothing, the default triple supplies an empty file list.
filenames = next(
    walk(images_dir),
    (None, None, []),
)[-1]

In [17]:
# One row per image file; at this point "path" holds the bare file name.
df = pd.DataFrame(data=filenames, columns=["path"])

# CSV is read without a header row, so pandas assigns integer column labels.
# NOTE(review): presumably lists samples to exclude - it is not used in the
# cells visible here; confirm where it is consumed.
bad_samples_reference = pd.read_csv(
    "alexuw/bad_samples_reference.csv",
    header=None,
)

In [18]:
# Maps 1-based line number in the word list -> the text on that line
# (trailing newline removed).  Undecodable bytes are silently dropped
# because of errors='ignore'.
words_text = {}
with open("alexuw/alexuw_wordList.txt", encoding="utf-16", errors='ignore') as word_file:
    stripped_lines = (entry.rstrip("\n") for entry in word_file)
    # enumerate() yields (line_number, text) pairs, which dict.update accepts.
    words_text.update(enumerate(stripped_lines, start=1))

In [19]:
# Extract the numeric id stored in the second '-'-separated field of the file
# name (extension removed first).
# os.path.splitext is used instead of str.rstrip(".jpg"): rstrip removes any
# trailing run of the characters '.', 'j', 'p', 'g' rather than the literal
# suffix, and leaves other spellings of the extension (".JPG", ".jpeg")
# attached, which would make int() fail.
df["img_id"] = df["path"].apply(
    lambda file_name: int(os.path.splitext(file_name)[0].split("-")[1])
)

In [20]:
# Turn bare file names into paths by prefixing the image directory
# (element-wise string concatenation over the whole column).
df["path"] = images_dir + df["path"]

In [21]:
# Look up the ground-truth word for every image id.  Plain dict indexing is
# used on purpose: an id without an entry in words_text raises KeyError.
df["word"] = df["img_id"].apply(words_text.__getitem__)

In [22]:
# Set of every distinct character occurring in any ground-truth word.
chars = {character for word in df["word"] for character in word}

In [23]:
# Sample: one labelled image - its ground-truth text and the image file path.
Sample = namedtuple('Sample', ['gt_text', 'file_path'])

# Batch: what DataLoader.get_next() returns - the loaded images, their
# ground-truth texts, and the number of images in the batch.
Batch = namedtuple('Batch', ['imgs', 'gt_texts', 'batch_size'])

In [24]:
# Build one Sample per DataFrame row, in row order.
samples = [
    Sample(row["word"], row["path"])
    for row in df.to_dict('records')
]

In [27]:
class DataLoader:
    """Batch iterator over labelled image samples with a train/validation split.

    The sample list is split once, in the order given, into a leading
    training part and a trailing validation part.  Exactly one of the two is
    active at a time (see ``train_set`` / ``validation_set``); ``has_next`` and
    ``get_next`` step through the active part ``batch_size`` items at a time.
    """

    def __init__(self,
                 data_dir: Path,
                 batch_size: int,
                 chars: set,
                 samples: list,
                 data_split: float = 0.95) -> None:
        """Split the samples and activate the training set.

        Args:
            data_dir: Directory of the dataset. Only its existence is checked
                here; images are opened via each sample's ``file_path``.
            batch_size: Number of samples per batch; must be positive.
            chars: Characters of the dataset; stored and exposed sorted as
                ``char_list``.
            samples: Items exposing ``gt_text`` and ``file_path`` attributes.
                The list itself is not modified (the splits are slice copies).
            data_split: Fraction of ``samples`` (taken from the front) that
                goes into the training set; the rest is the validation set.

        Raises:
            FileNotFoundError: If ``data_dir`` does not exist.
            ValueError: If ``batch_size`` is not positive.
        """
        # Explicit raises instead of ``assert``: assertions are removed when
        # Python runs with -O, which would silently skip the validation.
        if not os.path.exists(data_dir):
            raise FileNotFoundError(
                "Data directory does not exist: {}".format(data_dir))
        # A non-positive batch size would divide by zero in
        # get_iterator_info() and never advance curr_idx in get_next().
        if batch_size <= 0:
            raise ValueError(
                "batch_size must be positive, got {}".format(batch_size))

        self.data_augmentation = False
        self.curr_idx = 0
        self.batch_size = batch_size
        self.samples = samples
        self.chars = chars

        # split into training and validation set (default: 95% - 5%)
        split_idx = int(data_split * len(self.samples))
        self.train_samples = self.samples[:split_idx]
        self.validation_samples = self.samples[split_idx:]

        # put words into lists
        self.train_words = [x.gt_text for x in self.train_samples]
        self.validation_words = [x.gt_text for x in self.validation_samples]

        # start with train set
        self.train_set()

        # list of all chars in dataset
        self.char_list = sorted(self.chars)

    def train_set(self) -> None:
        """Switch to the training set, reshuffled, and restart iteration."""
        self.data_augmentation = True
        self.curr_idx = 0
        # Shuffles the loader's own slice copy in place; train_words keeps the
        # pre-shuffle order.
        random.shuffle(self.train_samples)
        self.samples = self.train_samples
        self.curr_set = 'train'

    def validation_set(self) -> None:
        """Switch to the validation set (original order) and restart iteration."""
        self.data_augmentation = False
        self.curr_idx = 0
        self.samples = self.validation_samples
        self.curr_set = 'val'

    def get_iterator_info(self) -> Tuple[int, int]:
        """Return (1-based index of the current batch, total number of batches)."""
        num_samples = len(self.samples)
        if self.curr_set == 'train':
            # train set: only full-sized batches (floor division)
            num_batches = num_samples // self.batch_size
        else:
            # val set: allow last batch to be smaller (ceiling division)
            num_batches = -(-num_samples // self.batch_size)
        curr_batch = self.curr_idx // self.batch_size + 1
        return curr_batch, num_batches

    def has_next(self) -> bool:
        """Return True if the active set can still supply a batch."""
        if self.curr_set == 'train':
            # train set: only full-sized batches
            return self.curr_idx + self.batch_size <= len(self.samples)
        # val set: allow last batch to be smaller
        return self.curr_idx < len(self.samples)

    def _get_img(self, i: int) -> np.ndarray:
        """Load sample ``i`` of the active set as grayscale and invert it.

        Raises:
            OSError: If OpenCV cannot read the image file.
        """
        file_path = self.samples[i].file_path
        img = cv2.imread(file_path, cv2.IMREAD_GRAYSCALE)
        # cv2.imread signals failure by returning None instead of raising;
        # fail here with the offending path rather than inside bitwise_not.
        if img is None:
            raise OSError("Could not read image file: {}".format(file_path))
        # Bitwise NOT flips every pixel value (for 8-bit data: v -> 255 - v).
        return cv2.bitwise_not(img)

    def get_next(self) -> "Batch":
        """Load and return the next batch, then advance by ``batch_size``.

        The batch is clipped at the end of the active set, so it may hold
        fewer than ``batch_size`` items (or none); callers are expected to
        check ``has_next()`` first.  The third field of the returned ``Batch``
        is the actual number of images loaded.
        """
        start = self.curr_idx
        stop = min(start + self.batch_size, len(self.samples))

        imgs = [self._get_img(i) for i in range(start, stop)]
        gt_texts = [sample.gt_text for sample in self.samples[start:stop]]

        self.curr_idx += self.batch_size
        return Batch(imgs, gt_texts, len(imgs))

In [28]:
# Loader over all samples with batches of 64 and the default 95% / 5%
# train/validation split; arguments are passed in signature order
# (data_dir, batch_size, chars, samples).
loader = DataLoader(images_dir, 64, chars, samples)