In [3]:
# Relative path that is appended to sys.path in the next cell so that
# modules located under it can be imported from this notebook.
CODE_PATH = "../trainer"

In [4]:
import sys

# Make modules under CODE_PATH importable. The membership check keeps the
# cell idempotent: re-running it does not stack duplicate sys.path entries.
if CODE_PATH not in sys.path:
    sys.path.append(CODE_PATH)

In [5]:
import os
from PIL import Image

from tqdm.auto import tqdm
import torch
import torchvision as tv
from transformers import AutoTokenizer

from ignite.engine import (
    Engine,
    Events,
)
from ignite.handlers import Checkpoint
from ignite.contrib.handlers import global_step_from_engine
from ignite.contrib.handlers import ProgressBar
from ignite.contrib.handlers.neptune_logger import NeptuneLogger

# Dataset

In [6]:
class SynthDataset(torch.utils.data.Dataset):
    """Dataset of word images listed in an annotation file.

    Every non-blank line of the annotation file contributes one entry: the
    text before the first space, interpreted as an image path relative to
    ``images_dir``. The label of an item is the second underscore-separated
    field of the image's file name.

    Args:
        images_dir: Directory that annotation entries are joined onto.
        annotation_file: Text file with one image entry per line.
        height: Target height in pixels for returned images; the width is
            scaled to keep the aspect ratio.
    """

    def __init__(self, images_dir, annotation_file, height=32):
        self.images_dir = images_dir
        self.annotation_file = annotation_file
        self.height = height
        self.image_files = self._load_data()

    def _load_data(self):
        """Read the annotation file and return the list of image paths.

        Blank lines (and lines whose first field is empty) are skipped so
        they cannot surface later as an ``IndexError`` while parsing labels.
        """
        with open(self.annotation_file, "r") as f:
            lines = f.read().splitlines()

        first_fields = (line.split(" ")[0] for line in lines)
        return [entry for entry in first_fields if entry.strip()]

    def __len__(self):
        """Return the number of annotation entries."""
        return len(self.image_files)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the entry at ``idx``.

        The image is a grayscale (mode "L") PIL image resized to
        ``self.height`` pixels tall; the label is a string.
        """
        image_file = self.image_files[idx]
        # Parse the label from the file name only, so underscores in the
        # directory part of the path cannot shift the field index.
        label = os.path.basename(image_file).split("_")[1]
        image_path = os.path.join(self.images_dir, image_file)

        # convert() returns an independent image, so the source file can be
        # closed as soon as the conversion is done.
        with Image.open(image_path) as source_image:
            image = source_image.convert("L")

        w, h = image.size
        ratio = w / float(h)
        # Clamp to one pixel: rounding can yield 0 for very narrow images,
        # and resize() rejects a zero width.
        nw = max(1, round(self.height * ratio))

        image = image.resize((nw, self.height), Image.BICUBIC)

        return image, label
    

In [7]:
# Root directory that annotation entries are joined onto to build image paths.
IMAGES_DIR = "../data/synth/mnt/90kDICT32px/"
# Annotation files passed to SynthDataset for the train and validation splits.
TRAIN_ANNOTATION_FILE = "../data/synth/mnt/annotation_train.txt"
VAL_ANNOTATION_FILE = "../data/synth/mnt/annotation_val.txt"

In [8]:
# Build one dataset per split; both read images from the same root directory.
train_dataset, val_dataset = (
    SynthDataset(IMAGES_DIR, split_annotation_file)
    for split_annotation_file in (TRAIN_ANNOTATION_FILE, VAL_ANNOTATION_FILE)
)

In [10]:
# Try to decode every training image and sort the entries into files that
# load cleanly and files that fail.
good_image_files = []
bad_image_files = []
for image_file in tqdm(train_dataset.image_files):
    image_path = os.path.join(IMAGES_DIR, image_file)
    try:
        # Image.open is lazy; convert() forces a full decode. The context
        # manager closes the file even when decoding fails.
        with Image.open(image_path) as image:
            image.convert("L")
    except Exception:
        # Catch Exception rather than using a bare except, so that
        # KeyboardInterrupt / SystemExit still stop this long-running scan
        # instead of being recorded as a bad image.
        bad_image_files.append(image_file)
    else:
        good_image_files.append(image_file)

100%|██████████| 7224612/7224612 [23:50<00:00, 5048.81it/s]


In [13]:
# Persist the training entries that decoded successfully, one per line
# (no trailing newline after the last entry).
train_good_listing = "\n".join(good_image_files)
with open("../data/synth/mnt/annotation_train_good.txt", "w") as train_good_out:
    train_good_out.write(train_good_listing)

In [14]:
# Repeat the decode check for the validation split and write out the entries
# that load cleanly. Note: this rebinds good_image_files / bad_image_files,
# replacing the lists built for the training split.
good_image_files = []
bad_image_files = []
for image_file in tqdm(val_dataset.image_files):
    image_path = os.path.join(IMAGES_DIR, image_file)
    try:
        # Image.open is lazy; convert() forces a full decode. The context
        # manager closes the file even when decoding fails.
        with Image.open(image_path) as image:
            image.convert("L")
    except Exception:
        # Catch Exception rather than using a bare except, so that
        # KeyboardInterrupt / SystemExit still stop the scan instead of
        # being recorded as a bad image.
        bad_image_files.append(image_file)
    else:
        good_image_files.append(image_file)

with open("../data/synth/mnt/annotation_val_good.txt", "w") as f:
    f.write("\n".join(good_image_files))

100%|██████████| 802734/802734 [02:51<00:00, 4680.67it/s]


In [15]:
# Display (number of decodable files, number of failing files) for the lists
# most recently built, i.e. the validation split when cells run in order.
(
    len(good_image_files),
    len(bad_image_files),
)

(802731, 3)