In [31]:
# Switch into the project's code directory, presumably so the local `dataset`
# module imported in the next cell can be found.
%cd ../code/

/opt/ml/team/gj/code


In [66]:
import csv
import math

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from attrdict import AttrDict
from IPython.display import display, HTML
from PIL import Image, ImageOps
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

from dataset import (
    dataset_loader, SizeBatchSampler, split_gt, load_levels, load_sources, load_vocab, encode_truth
)

In [46]:
# List the dataset folders available under the input data root.
!ls /opt/ml/input/data

dummy  dummy_2	eval_dataset  out_stuff  saving_model  train_dataset


In [47]:
root = '/opt/ml/input/data/train_dataset'
image_root = f'{root}/images'
# Read tab-separated (path, source) rows; `with` closes the file once it has been read.
with open(f'{root}/source.txt', encoding='utf-8') as source_f:
    sources = [line.strip().split('\t') for line in source_f]
sources = [(x, int(y)) for x, y in sources]

In [48]:
# Read tab-separated (path, level) rows; `with` closes the file once it has been read.
with open(f'{root}/level.txt', encoding='utf-8') as level_f:
    levels = [line.strip().split('\t') for line in level_f]
levels = [(x, int(y)) for x, y in levels]

In [49]:
# Read tab-separated (path, ground-truth) rows; `with` closes the file once read.
with open(f'{root}/gt.txt', encoding='utf-8') as gt_f:
    # maxsplit=1 keeps any tab inside the label in the text column instead of
    # producing a third field that breaks the 2-column DataFrame below.
    gts = [line.strip().split('\t', 1) for line in gt_f]

In [50]:
# One two-column frame per annotation file, all keyed by image path.
sources_df, gts_df, levels_df = (
    pd.DataFrame(rows, columns=['path', label])
    for rows, label in ((sources, 'source'), (gts, 'gt'), (levels, 'level'))
)

len(sources_df), len(gts_df), len(levels_df)

(100000, 100000, 100000)

In [51]:
# Join the three per-image tables on the image path. The key is explicit so a
# future shared column cannot silently become part of the join. `validate`
# rejects duplicate paths. The assert catches paths missing from any file,
# which an inner join would otherwise drop without warning.
merged = (
    sources_df
    .merge(levels_df, on='path', how='inner', validate='one_to_one')
    .merge(gts_df, on='path', how='inner', validate='one_to_one')
)
assert len(merged) == len(sources_df) == len(levels_df) == len(gts_df), \
    'image paths differ between source/level/gt files'

In [68]:
class CustomDataset(Dataset):
    """Dataset of (image, ground-truth) pairs with auxiliary labels.

    Besides the encoded ground truth, every sample carries a source id, a
    level id and deterministic rotation / horizontal-flip labels derived
    from its position in ``groundtruth``.
    """

    def __init__(
        self,
        groundtruth,
        tokens_file,
        levels,
        sources,
        crop=False,
        transform=None,
        rgb=3,
        max_resolution=128*128,
        is_flexible=False,
    ):
        """
        Args:
            groundtruth (iterable of (str, str)): ``(image_path, truth_text)``
                pairs, e.g. the rows returned by ``split_gt``.
            tokens_file: Passed unchanged to ``load_vocab`` to build the
                token <-> id maps.
            levels (mapping): Image file name -> level. The stored label is
                ``level - 1``; names missing from the mapping get -100.
            sources (mapping): Image file name -> source id. Names missing
                from the mapping get -100.
            crop (bool, optional): Crop images to the bounding box of their
                dark pixels [Default: False]
            transform (callable, optional): Applied to the PIL image after
                colour conversion and cropping.
            rgb (int): 3 converts images to "RGB", 1 to "L"; any other value
                makes ``__getitem__`` raise NotImplementedError.
            max_resolution (int): Pixel budget (h * w) used by ``get_shape``
                when ``is_flexible`` is True.
            is_flexible (bool): Resize each sample to the roughly
                aspect-preserving shape computed by ``get_shape``.
        """
        super(CustomDataset, self).__init__()
        self.crop = crop
        self.transform = transform
        self.rgb = rgb
        self.token_to_id, self.id_to_token = load_vocab(tokens_file)
        # NOTE(review): `flipped` (idx % 2) is fully determined by `rotated`
        # (idx % 4). 0/180-degree samples are never flipped and 90/270-degree
        # samples always are, so the two labels are perfectly correlated. Use
        # e.g. (idx // 4) % 2 if an independent flip label is intended.
        self.data = [
            {
                "path": p,
                "truth": {
                    "text": truth,
                    # START / END are the notebook's module-level token strings.
                    "encoded": [
                        self.token_to_id[START],
                        *encode_truth(truth, self.token_to_id),
                        self.token_to_id[END],
                    ],
                    'rotated': idx%4,
                    'flipped': idx%2,
                },
            }
            for idx, (p, truth) in enumerate(groundtruth)
        ]

        for datum in self.data:
            file_name = datum['path'].split('/')[-1]
            # -100 is the default `ignore_index` of nn.CrossEntropyLoss, so
            # images absent from the lookup are skipped by a loss using that default.
            datum['source'] = sources.get(file_name, -100)
            # Known levels are shifted down by one; unknown ones become -99 - 1 = -100.
            datum['level'] = levels.get(file_name, -99) - 1

        self.is_flexible = is_flexible
        if self.is_flexible:
            # Per-sample (h, w); (0, 0) marks "not computed yet" for get_shape.
            self.shape_cache = np.zeros((len(self), 2), dtype=int)
            self.max_resolution = max_resolution


    def __len__(self):
        return len(self.data)

    def __getitem__(self, i):
        """Load, preprocess and augment sample ``i``.

        Returns a dict with keys "path", "truth", "image", "source", "level".
        """
        item = self.data[i]
        # The context manager closes the file handle; convert() always returns
        # a new, fully loaded image, so `image` stays valid after closing.
        with Image.open(item["path"]) as raw:
            if self.rgb == 3:
                image = raw.convert("RGB")
            elif self.rgb == 1:
                image = raw.convert("L")
            else:
                raise NotImplementedError

        if self.crop:
            # Image needs to be inverted because the bounding box cuts off black pixels,
            # not white ones.
            bounding_box = ImageOps.invert(image).getbbox()
            image = image.crop(bounding_box)

        if self.transform:
            image = self.transform(image)

        if self.is_flexible:
            image = transforms.Resize(self.get_shape(i))(image)

        # Rotation / flip augmentation applies to every sample, whether or not
        # flexible resizing is enabled.
        rot_idx = item['truth']['rotated']
        flip_idx = item['truth']['flipped']

        # NOTE(review): rotate() keeps the input canvas size by default
        # (expand=False), so 90/270-degree rotations of non-square images clip
        # content. Confirm this is intended.
        angle = rot_idx * 90
        image = transforms.functional.rotate(image, angle)

        if flip_idx == 1:
            image = transforms.functional.hflip(image)

        return {
            "path": item["path"],
            "truth": item["truth"],
            "image": image,
            'source': item['source'],
            'level': item['level'],
        }

    @staticmethod
    def _fit_shape(width, height, max_resolution):
        """Return an (h, w) target whose area is about ``max_resolution``.

        The width is scaled to keep the aspect ratio, rounded to a multiple of
        32 (at least 32, so very narrow images cannot yield w == 0), and the
        height fills the remaining budget. An alternative that also rounds the
        height would be ``(max_resolution // w) // 32 * 32``.
        """
        scale = math.sqrt(width * height / max_resolution)
        w = round(width / scale)
        w = max(32, round(w / 32) * 32)
        h = max_resolution // w
        return h, w

    def get_shape(self, i):
        """Return the (h, w) resize target for sample ``i``, computing and caching it on first use."""
        h, w = self.shape_cache[i]
        if h == 0 and w == 0:
            # Only the header is needed for .size; close the file right away.
            with Image.open(self.data[i]["path"]) as image:
                rw, rh = image.size
            h, w = self._fit_shape(rw, rh, self.max_resolution)

            self.shape_cache[i][0] = h
            self.shape_cache[i][1] = w
        return h, w

In [71]:
def collate_batch(data):
    """Merge a list of dataset samples into a single batch dict.

    Encoded ground-truth sequences are right-padded with -1 to the longest
    sequence in the batch (a later step is expected to replace -1 with the PAD
    token id). Images are stacked along a new leading batch dimension.
    """
    encoded_seqs = [sample["truth"]["encoded"] for sample in data]
    longest = max(len(seq) for seq in encoded_seqs)
    padded = [seq + [-1] * (longest - len(seq)) for seq in encoded_seqs]

    def field(key):
        return [sample[key] for sample in data]

    def truth_field(key):
        return [sample["truth"][key] for sample in data]

    return {
        "path": field("path"),
        "image": torch.stack(field("image"), dim=0),
        "truth": {
            "text": truth_field("text"),
            "encoded": torch.tensor(padded),
            'rotated': torch.tensor(truth_field('rotated')),
            'flipped': torch.tensor(truth_field('flipped')),
        },
        'level': torch.tensor(field('level'), dtype=torch.long),
        'source': torch.tensor(field('source'), dtype=torch.long),
    }

In [74]:
options = AttrDict(
    input_size=AttrDict(
        height=128,
        width=128
    ),
    data=AttrDict(
        flexible_image_size=True,
        # Used only as an on/off flag by the data-loading cell; the held-out
        # fraction itself is `test_proportions`.
        random_split=0.2,
        train=["/opt/ml/input/data/train_dataset/gt.txt"],
        test_proportions=0.2,
        dataset_proportions=[1],
        use_small_data=False,
        token_paths=["/opt/ml/input/data/train_dataset/tokens.txt"],
        source_paths=["/opt/ml/input/data/train_dataset/source.txt"],
        level_paths=["/opt/ml/input/data/train_dataset/level.txt"],
        crop= True,
        rgb=1,
    ),
    batch_size=16,
    num_workers=8,
)

# Per-sample transform handed to CustomDataset. ToTensor is required because
# collate_batch combines images with torch.stack.
transformed = transforms.Compose(
    [
        transforms.ToTensor(),
    ]
)

In [75]:
START = "<SOS>"
END = "<EOS>"
PAD = "<PAD>"
SPECIAL_TOKENS = [START, END, PAD]


def _dataset_proportion(i):
    """Sampling proportion for the i-th ground-truth file (1.0 when not configured)."""
    proportions = options.data.dataset_proportions
    return proportions[i] if len(proportions) > i else 1.0


def _build_dataset(groundtruth):
    """Wrap (path, truth) rows in a CustomDataset configured from `options`."""
    return CustomDataset(
        groundtruth, options.data.token_paths, sources=sources,
        levels=levels, crop=options.data.crop,
        transform=transformed, rgb=options.data.rgb,
        max_resolution=options.input_size.height * options.input_size.width,
        is_flexible=options.data.flexible_image_size,
    )


train_data, valid_data = [], []
if options.data.random_split:
    print('Train-Test Data Loading')
    print(f'Random Split {options.data.test_proportions}')
    for i, path in enumerate(options.data.train):
        prop = _dataset_proportion(i)
        train, valid = split_gt(path, prop, options.data.test_proportions)
        train_data += train
        valid_data += valid
        print(f'From {path}')
        print(f'Prop: {prop}\tTrain +: {len(train)}\tVal +: {len(valid)}')
else:
    print('Train Data Loading')
    for i, path in enumerate(options.data.train):
        prop = _dataset_proportion(i)
        train = split_gt(path, prop)
        train_data += train
        print(f'From {path}')
        # These are training rows, so they are reported under "Train +".
        print(f'Prop: {prop}\tTrain +: {len(train)}')

    print()
    print('Test Data Loading')
    for i, path in enumerate(options.data.test):
        valid = split_gt(path)
        valid_data += valid
        print(f'From {path}')
        print(f'Val +:\t{len(valid)}')

# Load data
if options.data.use_small_data:
    old_train_len = len(train_data)
    old_valid_len = len(valid_data)
    train_data = train_data[:100]
    valid_data = valid_data[:10]
    print("Using Small Data")
    print(f"Train: {old_train_len} -> {len(train_data)}")
    print(f'Valid: {old_valid_len} -> {len(valid_data)}')

# NOTE(review): this rebinds `levels` / `sources`, which earlier cells hold as
# lists of (path, value) tuples, to the loader results (presumably mappings,
# since CustomDataset calls `.get` on them). Re-running the earlier DataFrame
# cells after this one will not reproduce their original output.
levels = load_levels(options.data.level_paths)
sources = load_sources(options.data.source_paths)

train_dataset = _build_dataset(train_data)
valid_dataset = _build_dataset(valid_data)

if options.data.flexible_image_size:
    # Batches are grouped by target image shape so torch.stack in collate_batch
    # sees equally sized images (assumes SizeBatchSampler groups via get_shape).
    train_sampler = SizeBatchSampler(train_dataset, options.batch_size, is_random=True)
    train_data_loader = DataLoader(
        train_dataset,
        batch_sampler=train_sampler,
        num_workers=options.num_workers,
        collate_fn=collate_batch,
    )

    valid_sampler = SizeBatchSampler(valid_dataset, options.batch_size, is_random=False)
    valid_data_loader = DataLoader(
        valid_dataset,
        batch_sampler=valid_sampler,
        num_workers=options.num_workers,
        collate_fn=collate_batch,
    )
else:
    train_data_loader = DataLoader(
        train_dataset,
        batch_size=options.batch_size,
        shuffle=True,
        num_workers=options.num_workers,
        collate_fn=collate_batch,
    )

    valid_data_loader = DataLoader(
        valid_dataset,
        batch_size=options.batch_size,
        shuffle=False,
        num_workers=options.num_workers,
        collate_fn=collate_batch,
    )

Train-Test Data Loading
Random Split 0.2
From /opt/ml/input/data/train_dataset/gt.txt
Prop: 1	Train +: 80000	Val +: 20000


  0%|          | 0/80000 [00:00<?, ?it/s]

  0%|          | 0/20000 [00:00<?, ?it/s]

In [76]:
# Group the merged per-image table by (source, level) for the per-group export.
groups = merged.groupby(by=['source', 'level'])

In [77]:
# Write one headerless, tab-separated (path, gt) file per (source, level) pair
# into the current working directory; report pairs that have no rows.
for source in range(2):
    for level in range(1, 5 + 1):
        target = f'gt_s:{source}_l:{level}.txt'
        try:
            g = groups.get_group((source, level))
            subset = g.loc[:, ['path', 'gt']]
            subset.to_csv(
                target,
                sep='\t',
                quoting=csv.QUOTE_NONE,
                index=False,
                header=False,
            )
        except KeyError:
            print(f'Empty source: {source} level: {level}')

Empty source: 1 level: 5


In [101]:
# Non-null row counts for every (source, level) pair present in the data.
per_group_counts = groups.count()
per_group_counts

Unnamed: 0_level_0,Unnamed: 1_level_0,path,gt
source,level,Unnamed: 2_level_1,Unnamed: 3_level_1
0,1,3748,3748
0,2,32288,32288
0,3,9236,9236
0,4,4156,4156
0,5,572,572
1,1,9753,9753
1,2,11242,11242
1,3,24548,24548
1,4,4457,4457
