In [None]:
# Mount Google Drive at /content/drive; later cells read datasets and checkpoints
# through the relative path drive/MyDrive/... (assumes the working directory is /content).
from google.colab import drive
drive.mount("/content/drive")

In [None]:
!git clone https://github.com/n1teshy/transformer > /dev/null
!mv transformer/* . && rmdir transformer > /dev/null
!ls drive/MyDrive/checkpoints/en-hi

In [None]:
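# Optional: to resume training, copy a saved checkpoint file from Drive to ./params.pth
# (replace CHECKPOINT_FILE with the file's name) and enable the load_state_dict line below.
# !cp drive/MyDrive/checkpoints/en_hi/CHECKPOINT_FILE params.pth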

In [None]:
import torch
from torch.utils.data import DataLoader
from torch.optim import AdamW
from core.utils.bpe import Tokenizer
# from core.data.seq2seq import Dataset, Config as Seq2SeqConfig
from core.components import EncoderConfig, DecoderConfig
from core.models import Transformer
from core.utils.loss import LossMonitor
from core.globals import DEVICE
from core.constants import TOKEN_PAD, TOKEN_SOS, TOKEN_EOS

In [2]:
# Raw parallel-text locations, kept for reference (currently disabled).
# TRAIN_SOURCE_FILE = "./datasets/en-hi/train/en.txt"
# TRAIN_TARGET_FILE = "./datasets/en-hi/train/hi.txt"
# VAL_SOURCE_FILE = "./datasets/en-hi/val/en.txt"
# VAL_TARGET_FILE = "./datasets/en-hi/val/hi.txt"

# Context sizes handed to the encoder / decoder configs.
ENCODER_CONTEXT, DECODER_CONTEXT = 1024, 512
# Batch size used by both data loaders.
BATCH_SIZE = 8
# Encoder and decoder share the same depth and head count.
ENCODER_BLOCKS = DECODER_BLOCKS = 2
ENCODER_HEADS = DECODER_HEADS = 4
# Model width shared by encoder and decoder.
MODEL_DIM = 256

# The model width has to be an exact multiple of each head count.
assert all(MODEL_DIM % heads == 0 for heads in (ENCODER_HEADS, DECODER_HEADS))

In [4]:
def _load_tokenizer(model_path):
    """Create a fresh Tokenizer and populate it from the model file at `model_path`."""
    tokenizer = Tokenizer()
    tokenizer.load(model_path)
    return tokenizer


# Source-side (en) and target-side (hi) tokenizers, loaded from their saved model files.
en_tokenizer = _load_tokenizer("tokenizers/en.model")
hi_tokenizer = _load_tokenizer("tokenizers/hi.model")

In [8]:
from core.data.circular_loader import CircularDataloader
import pickle

# base_data_config = dict(
#     # source=<source-file>,
#     # target=<target-file>
#     source_context=ENCODER_CONTEXT,
#     target_context=DECODER_CONTEXT,
#     encode_source=en_tokenizer.encode,
#     encode_target=hi_tokenizer.encode,
#     source_pad_id=en_tokenizer.specials[TOKEN_PAD],
#     target_pad_id=hi_tokenizer.specials[TOKEN_PAD],
#     sos_id=hi_tokenizer.specials[TOKEN_SOS],
#     eos_id=hi_tokenizer.specials[TOKEN_EOS]
# )

# train_dataset = Dataset(
#     Seq2SeqConfig(**dict(base_data_config, source=TRAIN_SOURCE_FILE, target=TRAIN_TARGET_FILE))
# )
# NOTE(security): pickle.load can execute arbitrary code while unpickling; only load
# dataset files that you produced yourself.
# The context managers make sure the file handles are closed once each dataset is loaded.
with open("drive/MyDrive/datasets/en-hi/train.pkl", "rb") as train_file:
    train_dataset = pickle.load(train_file)
# presumably read by the dataset's collate function to place batches on DEVICE — confirm
train_dataset.device = DEVICE
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=train_dataset.collate)

with open("drive/MyDrive/datasets/en-hi/val.pkl", "rb") as val_file:
    val_dataset = pickle.load(val_file)
val_dataset.device = DEVICE
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, collate_fn=val_dataset.collate)
# Rebound to an iterator so validation batches can be pulled one at a time with next();
# CircularDataloader presumably restarts the loader when it is exhausted — confirm in
# core.data.circular_loader.
val_loader = iter(CircularDataloader(val_loader))

In [None]:
# Encoder side: vocabulary size and padding id come from the `en` tokenizer.
encoder_config = EncoderConfig(
    vocab_size=en_tokenizer.size,
    pad_id=en_tokenizer.specials[TOKEN_PAD],
    context=ENCODER_CONTEXT,
    model_dim=MODEL_DIM,
    no_heads=ENCODER_HEADS,
    no_blocks=ENCODER_BLOCKS,
)
# Decoder side: vocabulary size and the pad / start / end ids come from the `hi` tokenizer.
decoder_config = DecoderConfig(
    vocab_size=hi_tokenizer.size,
    pad_id=hi_tokenizer.specials[TOKEN_PAD],
    sos_id=hi_tokenizer.specials[TOKEN_SOS],
    eos_id=hi_tokenizer.specials[TOKEN_EOS],
    context=DECODER_CONTEXT,
    model_dim=MODEL_DIM,
    no_heads=DECODER_HEADS,
    no_blocks=DECODER_BLOCKS,
)
model = Transformer(encoder_config, decoder_config)
# Uncomment to resume from previously saved weights.
# model.load_state_dict(torch.load("params.pth"))

# Count only the parameters that will receive gradient updates.
no_params = sum(param.numel() for param in model.parameters() if param.requires_grad)
print(f"model has {no_params / 1000 ** 2:.4f} million trainable parameters")


@torch.no_grad()
def calc_val_loss():
    """Return the model's loss on the next validation batch, with gradients disabled.

    Puts the module-level `model` into eval mode (and leaves it there) and consumes
    one batch from the module-level `val_loader` iterator.
    """
    model.eval()
    val_inputs, val_targets = next(val_loader)
    _, val_loss = model(val_inputs, val_targets)
    return val_loss

In [86]:
# AdamW over every model parameter with a fixed learning rate of 5e-4.
optimizer = AdamW(params=model.parameters(), lr=5e-4)

In [87]:
# Tracks the "train" and "val" loss series; update() is later called with one value per
# series and its result is indexed by the same names.
# NOTE(review): window=200 presumably is the length of a running average — confirm in core.utils.loss.
loss_monitor = LossMonitor("train", "val", window=200)

In [88]:
# Running count of optimisation steps; kept in its own cell so re-running the training
# loop cell continues the count instead of resetting it.
batches_trained = 0

In [None]:
# One pass over the training data: a single optimisation step per batch, followed by a
# loss evaluation on one validation batch.
for src_batch, tgt_batch in train_loader:
    # calc_val_loss() leaves the model in eval mode, so training mode is re-enabled every step.
    model.train()
    logits, train_loss = model(src_batch, tgt_batch)

    optimizer.zero_grad()
    train_loss.backward()
    optimizer.step()

    t_loss = train_loss.item()
    v_loss = calc_val_loss().item()
    losses = loss_monitor.update(train=t_loss, val=v_loss)
    batches_trained += 1
    # Columns: step -> train loss, monitored train loss; val loss, monitored val loss.
    print(f"{batches_trained} -> {t_loss:.4f}, {losses['train']:.4f}; {v_loss:.4f}, {losses['val']:.4f}")

In [None]:
# Translate one sample sentence with the current weights.
model.eval()
prompt_ids = en_tokenizer.encode("""Hello, how are you?""")
# Build the single-item batch on the same device as the model's weights; a tensor created
# with the default (CPU) placement would not match a model that lives on an accelerator.
model_device = next(model.parameters()).device
prompt = torch.tensor([prompt_ids], device=model_device)
# No gradients are needed for decoding; the generate() output is consumed inside the
# no_grad block so the whole decoding loop runs without autograd bookkeeping.
with torch.no_grad():
    output_ids = list(model.generate(prompt))
print(hi_tokenizer.decode(output_ids))

In [106]:
import os

# torch.save needs a file path, not a directory: write the weights to params.pth inside
# the checkpoint directory (the file name the commented-out torch.load call expects
# once copied back), creating the directory first if it does not exist yet.
CHECKPOINT_DIR = "drive/MyDrive/checkpoints/en_hi"
os.makedirs(CHECKPOINT_DIR, exist_ok=True)
checkpoint_path = os.path.join(CHECKPOINT_DIR, "params.pth")
torch.save(model.state_dict(), checkpoint_path)
print(f"saved checkpoint to {checkpoint_path}")