In [1]:
import os, sys, random, gc
import numpy as np
import pandas as pd
import torch

# Extend the import path with the parent directory. This has to run before
# the two imports below — presumably that is where `playdict_ocr` and/or the
# local `datasets` module live (TODO confirm).
sys.path.append('../')

from playdict_ocr.tokenization import Tokenizer
from datasets import PartitionedTrainDataset, TrainDataset, TestDataset

# Single tokenizer instance shared by later cells: it is handed to the model
# and dataset constructors and read by the `ce_loss` / `acc` functions.
tokenizer = Tokenizer()

In [2]:
class CFG:
    """Central hyper-parameter namespace read by the other cells of this notebook."""

    # Handed to EncoderDecoderModel via CFG; presumably the maximum decoded
    # sequence length (TODO confirm against the model code).
    max_dec_len = 25
    # Passed to the dataset constructors as their size argument.
    size = (128, 32)

    # Training schedule.
    # NOTE(review): the training log saved in this notebook reads "Epoch 1/3",
    # which suggests it was produced with a different `epochs` value than the
    # one below — config and recorded output look out of sync.
    epochs = 1
    batch_size = 128

    # Learning rates of the two sub-optimizers built in CombinedOpt.
    encoder_lr = 1e-4
    decoder_lr = 4e-4

    # Regularisation. weight_decay is applied to the encoder optimizer only;
    # dropout is consumed outside this notebook (presumably by the model).
    weight_decay = 1e-5
    dropout = 0.25

    # Gradient-norm clipping threshold used in MyLoopConfig.
    max_grad_norm = 4

    # Layer widths consumed by EncoderDecoderModel.
    embed_dim = 128
    attention_dim = 256
    encoder_dim = 384
    decoder_dim = 512

# MODEL

In [3]:
from models import EncoderDecoderModel
import keras4torch as k4t

# Build the raw network from the shared config and tokenizer.
# NOTE: the name `model` is rebound to a `k4t.Model` wrapper in a later cell
# (`model = k4t.Model(model)`), so after that cell runs `model` no longer
# refers to this object.
model = EncoderDecoderModel(CFG, tokenizer)

In [4]:
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence

class CollateWrapper:
    """Collate function turning a list of samples into batched tensors.

    Each sample is indexed as ``sample[0]`` (a tensor that can be stacked),
    ``sample[1]`` (a NumPy array of target indices) and ``sample[2]`` (the
    target length).
    """

    # run on cpu
    def __call__(self, batch):
        """Return ``(src, tgt, tgt_lens, dummy)`` for one batch.

        ``src`` is the stacked ``sample[0]`` tensors, ``tgt`` the targets
        right-padded with 0 to the longest target in the batch (batch first),
        ``tgt_lens`` an int64 tensor of the lengths, and ``dummy`` a scalar
        zero tensor (presumably a placeholder expected by the training loop).
        """
        src_items = [sample[0] for sample in batch]
        tgt_items = [torch.from_numpy(sample[1]) for sample in batch]
        len_items = [sample[2] for sample in batch]

        stacked_src = torch.stack(src_items)
        padded_tgt = pad_sequence(tgt_items, batch_first=True, padding_value=0)
        lengths = torch.tensor(len_items, dtype=torch.int64)

        return stacked_src, padded_tgt, lengths, torch.tensor(0)

In [5]:
class MyLoopConfig(k4t.configs.TrainerLoopConfig):
    """Training-loop hooks: batch unpacking and gradient clipping."""

    # run on gpu
    def process_batch(self, batch):
        """Convert a collated batch into ``(model_inputs, target)``.

        In training mode the batch is reordered by descending target length
        and the model receives ``(src, tgt, tgt_lens)``; otherwise it receives
        only ``(src,)``. The fourth batch element is ignored.
        """
        src_batch, tgt_batch, lengths, _unused = batch

        if self.training:
            # Longest targets first; apply the same permutation to inputs and targets.
            lengths, order = lengths.sort(dim=0, descending=True)
            src_batch = src_batch[order]
            tgt_batch = tgt_batch[order]
            return (src_batch, tgt_batch, lengths), tgt_batch

        return (src_batch,), tgt_batch

    def prepare_for_optimizer_step(self, model):
        """Clip the gradient norm of encoder and decoder independently.

        Each sub-module gets its own ``CFG.max_grad_norm`` budget rather than
        sharing a single global norm.
        """
        wrapped = model.model
        for submodule in (wrapped.encoder, wrapped.decoder):
            torch.nn.utils.clip_grad_norm_(submodule.parameters(), CFG.max_grad_norm)

In [6]:
from torch.optim.lr_scheduler import OneCycleLR
import torch.nn as nn
import torch.nn.functional as F

from torch_optimizer import AdaBelief

class CombinedOpt(torch.optim.Optimizer):
    """Optimizer facade that drives separate encoder and decoder optimizers.

    The encoder is updated with AdaBelief (``CFG.encoder_lr``,
    ``CFG.weight_decay``) and the decoder with Adam (``CFG.decoder_lr``, no
    weight decay). Only ``step`` is delegated to the sub-optimizers; every
    other ``Optimizer`` method is inherited from the base class.
    """

    def __init__(self, model):
        """Create both sub-optimizers for ``model.encoder`` and ``model.decoder``.

        Args:
            model: module exposing ``encoder`` and ``decoder`` sub-modules.
        """
        # The base class is initialised over all parameters of `model`; its
        # 'lr' default is an -inf sentinel that no update ever uses (the real
        # learning rates live in the two sub-optimizers). Anything that reads
        # this optimizer's own param_groups will therefore see lr = -inf.
        super().__init__(model.parameters(), {'lr': float('-inf')})
        self.encoder_opt = AdaBelief(
            model.encoder.parameters(), lr=CFG.encoder_lr, weight_decay=CFG.weight_decay)
        self.decoder_opt = torch.optim.Adam(
            model.decoder.parameters(), lr=CFG.decoder_lr)

    def step(self, closure=None):
        """Perform one update with both sub-optimizers.

        Args:
            closure: optional callable that re-evaluates the model and returns
                the loss, as in the standard ``Optimizer.step`` contract. It is
                evaluated once, before either sub-optimizer steps.

        Returns:
            The value returned by ``closure``, or ``None`` when no closure is
            given (the previous behaviour).
        """
        loss = None
        if closure is not None:
            # Callers may invoke step() under no_grad; the closure needs grads.
            with torch.enable_grad():
                loss = closure()

        self.encoder_opt.step()
        self.decoder_opt.step()
        return loss

# The optimizer is built from the raw network because CombinedOpt reads
# `model.encoder` / `model.decoder` directly.
opt = CombinedOpt(model)

# NOTE(review): rebinding `model` to the wrapper makes this cell non-idempotent.
# Re-running it without first re-running the cell that creates the
# EncoderDecoderModel would pass the wrapper to CombinedOpt and wrap it again.
model = k4t.Model(model)

def ce_loss(y_pred, y_true):
    """Cross-entropy over all target positions except padding.

    Args:
        y_pred: logits whose last dimension indexes the vocabulary; all
            leading dimensions are flattened.
        y_true: integer target indices with as many elements as ``y_pred``
            has rows after flattening. Index 0 is treated as padding (the
            collate function pads targets with 0).

    Returns:
        Scalar tensor: mean cross-entropy over the non-padding positions.
    """
    # Flatten using the logits' own class dimension rather than a global
    # tokenizer, so the function only depends on its arguments.
    logits = y_pred.reshape(-1, y_pred.shape[-1])
    labels = y_true.reshape(-1)
    # ignore_index drops the padding positions from both the sum and the
    # mean's denominator — the same value as indexing the non-zero rows by
    # hand, without torch.nonzero (which forces a host sync on CUDA).
    return F.cross_entropy(logits, labels, ignore_index=0)

def acc(y_pred, y_true):
    """Fraction of samples whose decoded prediction equals the decoded target.

    The prediction is the argmax over the last dimension of ``y_pred``; both
    index sequences are converted with ``tokenizer.indices_to_string`` and
    compared as whole strings, one comparison per row.

    Returns:
        Scalar float64 tensor holding the mean of the per-row matches.
    """
    predicted_ids = y_pred.argmax(-1).cpu().numpy()
    target_ids = y_true.cpu().numpy()

    matches = []
    for pred_row, true_row in zip(predicted_ids, target_ids):
        pred_text = tokenizer.indices_to_string(pred_row)
        true_text = tokenizer.indices_to_string(true_row)
        matches.append(pred_text == true_text)

    return torch.tensor(matches, dtype=float).mean()

# Register optimizer, loss, metric and the custom loop hooks with the trainer.
# Validation loss is disabled. NOTE(review): presumably because in eval mode
# `process_batch` feeds only `src` to the model, so its output need not line up
# with `tgt` the way `ce_loss` requires — confirm.
model.compile(optimizer=opt, loss=ce_loss, metrics=[acc], loop_config=MyLoopConfig(), disable_val_loss=True)

# Train loop

In [7]:
from torch.utils.data import DataLoader
from keras4torch.callbacks import LRScheduler
import pickle
from keras4torch.utils.data import RestrictedRandomSampler

# Let cuDNN auto-tune its kernels.
torch.backends.cudnn.benchmark = True

# Create the checkpoint directory *before* training: the run takes hours, and
# a missing directory would otherwise only surface when the weights are saved
# at the very end.
weights_path = 'saved_model/best.pt'
os.makedirs(os.path.dirname(weights_path), exist_ok=True)

# Training data is split over four pickle partitions; cnt_list holds the
# number of samples in each partition (7,224,600 in total).
file_list = [f"../preprocessed/train_data_{i}.pkl" for i in range(4)]
cnt_list = [2000000] * 3 + [1224600]

# NOTE(review): read_pickle executes arbitrary code embedded in the file —
# only load pickles produced by a trusted preprocessing step.
val_data = pd.read_pickle("../preprocessed/val_data.pkl")

train_set = PartitionedTrainDataset(file_list, cnt_list, tokenizer, CFG.size)
val_set = TrainDataset(val_data, tokenizer, CFG.size)

# The sampler is built from the same per-partition counts as the dataset.
model.fit(train_set,
            validation_data=val_set,
            epochs=CFG.epochs,
            batch_size=CFG.batch_size,
            validation_batch_size=CFG.batch_size*2,
            collate_fn=CollateWrapper(),
            sampler=RestrictedRandomSampler(cnt_list),
)

# Saved once after the final epoch (not a best-epoch checkpoint, despite the name).
model.save_weights(weights_path)

Train on 7224600 samples, validate on 802733 samples:
Epoch 1/3
56443/56443 - 6554s - loss: 0.6195 - acc: 0.6648 - val_acc: 0.8024 - lr: -inf
Epoch 2/3
56443/56443 - 6648s - loss: 0.4355 - acc: 0.8175 - val_acc: 0.8413 - lr: -inf
Epoch 3/3
56443/56443 - 6881s - loss: 0.4123 - acc: 0.8460 - val_acc: 0.8591 - lr: -inf
