In [None]:
# Mount Google Drive at /content/drive. The relative `drive/MyDrive/...` paths used
# below for datasets, caches and checkpoints presumably assume the working directory
# is /content (Colab's default) — confirm if running elsewhere.
from google.colab import drive
drive.mount("/content/drive")

In [None]:
# Fetch the project code (it provides the `core` package imported below and,
# presumably, the `tokenizers/` model files) into the working directory.
# Remove any stale clone first so this cell can be re-run.
!rm -rf transformer && git clone -q https://github.com/n1teshy/transformer
# `transformer/*` does not match dotfiles, so `.git` is left behind and a plain
# `rmdir` fails with "Directory not empty"; delete the clone with `rm -rf` instead.
# `cp -r` merges into already-existing directories, so re-running does not fail
# the way `mv` would on a non-empty `core/`.
!cp -r transformer/* . && rm -rf transformer
# List the checkpoints available on Drive.
!ls drive/MyDrive/checkpoints/en-hi

In [None]:
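# !cp drive/MyDrive/checkpoints/en-hi/<checkpoint-name>.pth params.pth  # to resume: fill in a checkpoint file name, uncomment, then uncomment the load_state_dict line in the model cell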

In [2]:
import torch
from pathlib import Path
from torch.optim import AdamW
from core.utils.bpe import Tokenizer
from core.data.seq_to_seq import SeqToSeqDataset
from core.utils.configs import SeqToSeqDataConfig, EncoderConfig, DecoderConfig
from core.models import Transformer
from core.utils.loss import LossMonitor
from core.globals import DEVICE
from core.constants import TOKEN_PAD, TOKEN_SOS, TOKEN_EOS

In [3]:
# --- Data locations on the mounted Drive -----------------------------------
# Source and target point at the same split directory for each split.
TRAIN_SOURCE = Path("./drive/MyDrive/datasets/en-hi/train/")
TRAIN_TARGET = Path("./drive/MyDrive/datasets/en-hi/train/")
VAL_SOURCE = Path("./drive/MyDrive/datasets/en-hi/val/")
VAL_TARGET = Path("./drive/MyDrive/datasets/en-hi/val/")

# Per-split cache directories handed to SeqToSeqDataset as `cache_dir`.
TRAIN_CACHE = Path("drive/MyDrive/datasets/en-hi/cached-train-6.9")
VAL_CACHE = Path("drive/MyDrive/datasets/en-hi/cached-val-0.9")

# --- Sequence / batch sizes ---------------------------------------------------
# Passed as `encoder_context` / `decoder_context` to the data config and as
# `context` to the encoder/decoder configs (presumably max tokens per sequence).
ENCODER_CONTEXT = 1024
DECODER_CONTEXT = 512
BATCH_SIZE = 64

# --- Architecture -------------------------------------------------------------
MODEL_DIM = 512
ENCODER_BLOCKS, ENCODER_HEADS = 2, 4
DECODER_BLOCKS, DECODER_HEADS = 2, 4

# The model dimension must split evenly across the attention heads of both stacks.
assert MODEL_DIM % ENCODER_HEADS == 0
assert MODEL_DIM % DECODER_HEADS == 0

In [4]:
def _load_tokenizer(model_path):
    """Create a BPE Tokenizer and populate it from the saved model at `model_path`."""
    tokenizer = Tokenizer()
    tokenizer.load(model_path)
    return tokenizer


# English is the source side and Hindi the target side of the translation task.
en_tokenizer = _load_tokenizer("tokenizers/en.model")
hi_tokenizer = _load_tokenizer("tokenizers/hi.model")

In [13]:
# Settings shared by the train and validation datasets. `source` and `target`
# are placeholders here; they (and `cache_dir`) are supplied per split below.
base_data_config = {
    "source": None,
    "target": None,
    "encoder_context": ENCODER_CONTEXT,
    "decoder_context": DECODER_CONTEXT,
    "encode_source": en_tokenizer.encode,
    "encode_target": hi_tokenizer.encode,
    "source_pad_id": en_tokenizer.specials[TOKEN_PAD],
    "target_pad_id": hi_tokenizer.specials[TOKEN_PAD],
    "sos_id": hi_tokenizer.specials[TOKEN_SOS],
    "eos_id": hi_tokenizer.specials[TOKEN_EOS],
    "batch_size": BATCH_SIZE,
    "shuffle_shards": True,
    "shuffle_samples": True,
}


def _build_dataset(source, target, cache_dir):
    """Return a SeqToSeqDataset for one split, overriding the per-split fields of base_data_config."""
    split_config = {**base_data_config, "source": source, "target": target, "cache_dir": cache_dir}
    return SeqToSeqDataset(SeqToSeqDataConfig(**split_config))


train_dataset = _build_dataset(TRAIN_SOURCE, TRAIN_TARGET, TRAIN_CACHE)
val_dataset = _build_dataset(VAL_SOURCE, VAL_TARGET, VAL_CACHE)

In [None]:
# The encoder reads English (source) token ids; the decoder works over Hindi
# (target) token ids, so each side takes its vocabulary and special ids from the
# matching tokenizer.
encoder_config = EncoderConfig(
    context=ENCODER_CONTEXT,
    model_dim=MODEL_DIM,
    no_blocks=ENCODER_BLOCKS,
    no_heads=ENCODER_HEADS,
    pad_id=en_tokenizer.specials[TOKEN_PAD],
    vocab_size=en_tokenizer.size,
)
decoder_config = DecoderConfig(
    context=DECODER_CONTEXT,
    model_dim=MODEL_DIM,
    no_blocks=DECODER_BLOCKS,
    no_heads=DECODER_HEADS,
    pad_id=hi_tokenizer.specials[TOKEN_PAD],
    sos_id=hi_tokenizer.specials[TOKEN_SOS],
    eos_id=hi_tokenizer.specials[TOKEN_EOS],
    vocab_size=hi_tokenizer.size,
)
model = Transformer(encoder_config, decoder_config).to(DEVICE)
# Uncomment to resume from a checkpoint previously copied to params.pth:
# model.load_state_dict(torch.load("params.pth", map_location=DEVICE))

# Count only parameters the optimizer will update.
no_params = sum(param.numel() for param in model.parameters() if param.requires_grad)
print(f"model has {no_params / 1000 ** 2:.4f} million trainable parameters")


def save_model(t_loss, v_loss, checkpoint_dir="drive/MyDrive/checkpoints/en-hi"):
    """Save the current `model` weights to a file named after the architecture and losses.

    The file name is
    ``{ENCODER_BLOCKS}_{DECODER_BLOCKS}__{ENCODER_HEADS}_{DECODER_HEADS}__{MODEL_DIM}__{t:.2f}_{v:.2f}.pth``.
    Two saves whose losses round to the same values overwrite each other.

    Args:
        t_loss: train loss to record (a float or a one-element tensor).
        v_loss: validation loss to record (a float or a one-element tensor).
        checkpoint_dir: directory to write into; created if it does not exist.

    Returns:
        The `Path` of the written checkpoint.
    """
    # Normalize tensors (e.g. an un-`.item()`-ed loss) to plain floats for formatting.
    t_loss, v_loss = float(t_loss), float(v_loss)
    name = f"{ENCODER_BLOCKS}_{DECODER_BLOCKS}__{ENCODER_HEADS}_{DECODER_HEADS}__{MODEL_DIM}__{t_loss:.2f}_{v_loss:.2f}.pth"
    target_dir = Path(checkpoint_dir)
    # Avoid a failed save (and lost progress) if the Drive folder is missing.
    target_dir.mkdir(parents=True, exist_ok=True)
    path = target_dir / name
    torch.save(model.state_dict(), path)
    print(f"saved with t_loss: {t_loss:.2f}, v_loss: {v_loss:.2f}")
    return path


@torch.no_grad()
def calc_val_loss():
    """Return the model's loss on the next validation batch, without tracking gradients.

    When `val_dataset.next_batch()` returns None (end of the split), the dataset
    is reset and the first batch is used. The model's previous train/eval mode
    is restored before returning, so callers get back the mode they were in.

    Returns:
        The loss tensor produced by `model(x, y)`.

    Raises:
        RuntimeError: if the validation dataset yields no batch even after a reset.
    """
    was_training = model.training
    model.eval()
    try:
        batch = val_dataset.next_batch()
        if batch is None:
            val_dataset.reset()
            batch = val_dataset.next_batch()
        if batch is None:
            # Otherwise the unpacking below fails with an opaque TypeError.
            raise RuntimeError("validation dataset produced no batches")
        x, y = batch
        _, loss = model(x, y)
        return loss
    finally:
        model.train(was_training)

In [18]:
# AdamW over all model parameters; 5e-4 is the same learning rate as 0.0005.
LEARNING_RATE = 5e-4
optimizer = AdamW(model.parameters(), lr=LEARNING_RATE)

In [19]:
# Tracks train/val losses; `window=200` presumably smooths them over the last
# 200 updates (see core.utils.loss.LossMonitor).
loss_monitor = LossMonitor("train", "val", window=200)

# The training loops write a checkpoint only when the monitored train loss is at
# least `good_delta` below `best_t_loss` AND the monitored val loss is below
# `best_v_loss`. 0.01 is a starting value; raise it to checkpoint less often.
good_delta = 0.01

# Start from +inf so a fresh run can checkpoint. The previous `None` values made
# the assert below fail unconditionally, so Run All always stopped here. When
# resuming from a checkpoint, set these to the losses encoded in its filename.
best_t_loss, best_v_loss = float("inf"), float("inf")
assert None not in (good_delta, best_t_loss, best_v_loss)

In [20]:
# Progress counters, read and advanced by the training loops below.
epochs_trained = batches_trained = 0

In [None]:
# Plain training loop: one optimizer step per batch, followed by the loss on one
# validation batch. Runs until interrupted (Kernel -> Interrupt).
while True:
    batch = train_dataset.next_batch()
    if batch is None:
        # `next_batch()` returned None: the epoch is finished, rewind the dataset.
        epochs_trained, batches_trained = epochs_trained + 1, 0
        train_dataset.reset()
        batch = train_dataset.next_batch()
    x, y = batch
    model.train()
    _, t_loss = model(x, y)
    optimizer.zero_grad()
    t_loss.backward()
    optimizer.step()
    t_loss, v_loss = t_loss.item(), calc_val_loss().item()
    losses = loss_monitor.update(train=t_loss, val=v_loss)
    batches_trained += 1
    mt_loss, mv_loss = losses["train"], losses["val"]
    print(f"{epochs_trained}:{batches_trained} -> {mt_loss:.4f}, {mv_loss:.4f}")
    if best_t_loss - mt_loss >= good_delta and best_v_loss > mv_loss:
        save_model(mt_loss, mv_loss)
        # Record the same monitored losses that were compared and saved. Storing
        # the raw per-batch losses here made later comparisons inconsistent with
        # the checkpoint that was actually written.
        best_t_loss, best_v_loss = mt_loss, mv_loss

In [None]:
# Gradient-accumulation training loop: gradients of `accumulation_steps`
# consecutive batches are summed before each optimizer step, giving an effective
# batch of accumulation_steps * BATCH_SIZE. Runs until interrupted.
acc_t_loss = 0
accumulation_steps = 3
# Counts micro-batches independently of `batches_trained`, which is reset at every
# epoch boundary. This way each optimizer step covers exactly `accumulation_steps`
# batches, even across epochs.
micro_steps = 0
# Discard gradients left from earlier cells. The plain loop above zeroes them
# *before* backward, so its last batch's gradients are still present here.
optimizer.zero_grad()
while True:
    batch = train_dataset.next_batch()
    if batch is None:
        epochs_trained, batches_trained = epochs_trained + 1, 0
        train_dataset.reset()
        batch = train_dataset.next_batch()
    x, y = batch
    model.train()
    _, batch_loss = model(x, y)
    acc_t_loss += batch_loss.item()
    # Divide so the summed gradients average over the micro-batches. A separate
    # name keeps `t_loss` a plain float for reporting/saving.
    (batch_loss / accumulation_steps).backward()
    batches_trained += 1
    micro_steps += 1
    if micro_steps % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
        t_loss, v_loss = acc_t_loss / accumulation_steps, calc_val_loss().item()
        acc_t_loss = 0
        losses = loss_monitor.update(train=t_loss, val=v_loss)
        mt_loss, mv_loss = losses["train"], losses["val"]
        print(f"{epochs_trained}:{batches_trained // accumulation_steps} -> {mt_loss:.4f}, {mv_loss:.4f}")
        if best_t_loss - mt_loss >= good_delta and mv_loss < best_v_loss:
            save_model(mt_loss, mv_loss)
            # Without updating the bests, every later step that beats the stale
            # values would write yet another checkpoint.
            best_t_loss, best_v_loss = mt_loss, mv_loss

In [None]:
# Manually save the current weights (e.g. after interrupting training), labelled
# with the monitored losses like the automatic checkpoints. The raw `t_loss` may
# be the scaled per-micro-batch tensor if the accumulation loop was interrupted
# mid-step.
save_model(mt_loss, mv_loss)