In [1]:
# Relative path of the directory that the next cell appends to sys.path.
CODE_PATH = "../trainer"

In [2]:
import sys
# Make modules under CODE_PATH importable. The membership guard keeps the cell
# idempotent: re-running it does not pile up duplicate sys.path entries.
if CODE_PATH not in sys.path:
    sys.path.append(CODE_PATH)

In [3]:
import os
from PIL import Image

from tqdm.auto import tqdm
import torch
import torchvision as tv
from transformers import AutoTokenizer

from ignite.engine import (
    Engine,
    Events,
)
from ignite.handlers import Checkpoint
from ignite.contrib.handlers import global_step_from_engine
from ignite.contrib.handlers import ProgressBar
from ignite.contrib.handlers.neptune_logger import NeptuneLogger

  from .autonotebook import tqdm as notebook_tqdm


# Dataset

In [4]:
class SynthDataset(torch.utils.data.Dataset):
    """Dataset of word images whose labels are encoded in the file names.

    The annotation file lists one image per line; the first space-separated
    token of each line is the image path relative to ``images_dir``. The label
    of a sample is the second ``_``-separated field of the image file name
    (e.g. ``67_pulchritudinous_61162.jpg`` -> ``pulchritudinous``).

    Args:
        images_dir: Directory that the annotation paths are joined onto.
        annotation_file: Text file with one image entry per line.
        height: Target height in pixels. Images are resized to this height
            and their width is scaled to keep the aspect ratio.
    """

    def __init__(self, images_dir, annotation_file, height=32):
        self.images_dir = images_dir
        self.annotation_file = annotation_file
        self.image_files = self._load_data()
        self.height = height

    def _load_data(self):
        """Return the image paths listed in the annotation file.

        Blank and whitespace-only lines are skipped so they do not become
        empty entries that would fail later in ``__getitem__``.
        """
        image_files = []
        with open(self.annotation_file, "r") as f:
            for line in f:
                entry = line.strip()
                if not entry:
                    continue
                # Only the first space-separated token is the image path.
                image_files.append(entry.split(" ", 1)[0])
        return image_files

    def __len__(self):
        """Number of image entries read from the annotation file."""
        return len(self.image_files)

    def __getitem__(self, idx):
        """Load sample ``idx``.

        Returns:
            tuple: ``(image, label)`` where ``image`` is a grayscale PIL image
            resized to ``self.height`` pixels high (aspect ratio preserved)
            and ``label`` is the string parsed from the file name.
        """
        image_file = self.image_files[idx]
        # Parse the label from the file name only, so an underscore in a
        # directory component cannot shift the field index.
        label = os.path.basename(image_file).split("_")[1]
        image_path = os.path.join(self.images_dir, image_file)

        # convert() returns an independent image, so the underlying file can
        # be closed as soon as the context manager exits.
        with Image.open(image_path) as source:
            image = source.convert("L")

        w, h = image.size
        ratio = w / float(h)
        # Clamp to 1 px: resize() rejects a zero width for very narrow images.
        nw = max(1, round(self.height * ratio))

        image = image.resize((nw, self.height), Image.BICUBIC)

        return image, label
    

In [6]:
# Directory that the image paths listed in the annotation files are joined onto.
IMAGES_DIR = "../data/synth/mnt/90kDICT32px/"
# Annotation lists used to build the training and validation datasets.
TRAIN_ANNOTATION_FILE = "../data/synth/mnt/annotation_train_good.txt"
VAL_ANNOTATION_FILE = "../data/synth/mnt/annotation_val_good.txt"

In [7]:
# Build the training and validation datasets from their annotation lists,
# using the dataset's default target height.
train_dataset = SynthDataset(
    images_dir=IMAGES_DIR,
    annotation_file=TRAIN_ANNOTATION_FILE,
)
val_dataset = SynthDataset(
    images_dir=IMAGES_DIR,
    annotation_file=VAL_ANNOTATION_FILE,
)

In [8]:
# Scan every training image to find the largest width and height, and which
# files they belong to.
max_width = 0
max_heigth = 0  # spelling kept as-is: a later cell reads this exact name
max_w_image = None
max_h_image = None
unreadable_files = []  # images that could not be opened/decoded and were skipped
for image_file in tqdm(train_dataset.image_files):
    image_path = os.path.join(IMAGES_DIR, image_file)
    try:
        # convert() forces a full decode, so corrupt files are detected here.
        with Image.open(image_path) as source:
            image = source.convert("L")
    except Exception:
        # Best-effort scan: skip unreadable images but keep a record of them.
        # Catching Exception (not a bare except) lets Ctrl-C / kernel
        # interrupts still stop this long-running loop.
        unreadable_files.append(image_file)
        continue

    w, h = image.size
    if w > max_width:
        max_width = w
        max_w_image = image_file
    if h > max_heigth:
        max_heigth = h
        max_h_image = image_file

100%|██████████| 7224379/7224379 [25:43<00:00, 4679.56it/s]  


In [9]:
# Show the largest width and height found by the scan and the files they came from.
max_width, max_heigth, max_w_image, max_h_image

(799,
 32,
 './1619/5/67_pulchritudinous_61162.jpg',
 './2425/1/104_SYSTEMICALLY_77086.jpg')