In [None]:
import sys
import os

# Put the parent of the working directory on the import path so that modules
# located one level above this notebook can be imported (presumably where the
# local `Architecture` package lives -- confirm against the repo layout).
current_dir = os.getcwd()
parent_dir = os.path.dirname(current_dir)
# Guard against duplicates: re-running this cell must not keep growing sys.path.
if parent_dir not in sys.path:
    sys.path.append(parent_dir)

In [None]:
import torch

import numpy as np
import lightning as L

from torch.utils.data import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
from lightning.pytorch.loggers import WandbLogger

# Local Modules
from Architecture import  Tokenizer, VOCAB_SIZE
from Architecture.ModelConfig import ModelConfig
from Architecture.Decoder import DecoderDataset, DecoderBlock, DecoderModel

In [None]:
# Single source of hyperparameters and run settings used by the rest of the notebook.
CONFIG = ModelConfig()

# Seed both RNGs the notebook uses. NumPy is seeded here as well so that any
# NumPy-based randomness is reproducible even when CONFIG.shuffle_dataset is
# False (the split cell only re-seeds NumPy inside its shuffle branch, with the
# same seed, so the resulting split is unaffected).
torch.manual_seed(CONFIG.random_seed)
np.random.seed(CONFIG.random_seed)

# Last expression: display the installed PyTorch version as the cell output.
torch.__version__

## The Data

In [None]:
# Load the decoder dataset from its file under ./data (path is relative to the
# notebook's working directory).
dataset = DecoderDataset.load_from(
    "./data/decoder_data.pt"
)

In [None]:
# Number of DataLoader worker processes, shared by the train and validation loaders.
NUM_WORKERS = 9

dataset_size = len(dataset)
indices = list(range(dataset_size))
# Number of samples held out for validation (floor of val_split * dataset size).
split = int(np.floor(CONFIG.val_split * dataset_size))

if CONFIG.shuffle_dataset:
    # Seed immediately before shuffling so the train/val split is reproducible
    # regardless of how much NumPy randomness was consumed earlier.
    np.random.seed(CONFIG.random_seed)
    np.random.shuffle(indices)

# The first `split` indices form the validation set; the remainder is training data.
train_indices, val_indices = indices[split:], indices[:split]

# Creating PT data samplers and loaders:
train_sampler = SubsetRandomSampler(train_indices)
val_sampler = SubsetRandomSampler(val_indices)

# Settings common to both loaders, declared once so they cannot drift apart.
loader_kwargs = dict(
    batch_size=CONFIG.batch_size,
    num_workers=NUM_WORKERS,
    persistent_workers=True,
)

train_loader = DataLoader(dataset, sampler=train_sampler, **loader_kwargs)
val_loader = DataLoader(dataset, sampler=val_sampler, **loader_kwargs)

# NOTE: `dataset` is intentionally left bound. Both loaders hold a reference to
# it, so deleting the name would free no memory and would only make this cell
# raise NameError when re-run.

In [None]:
# Report how many batches each loader yields per pass.
for split_name, loader in (("Training", train_loader), ("Validation", val_loader)):
    print(f"Number of {split_name} Batches:", len(loader))

In [None]:
# Pull one batch from the validation loader to inspect its structure.
sample_batch = next(iter(val_loader))
print(sample_batch.keys(), "\n")

# One line per field: the field name padded to a fixed width, then its shape.
for key, item in sample_batch.items():
    field_label = f"{key}:"
    print(field_label.ljust(24), item.shape)

## The Model

In [None]:
# model
# Collect every constructor argument in one mapping (same names, values and
# order as a direct keyword call), then build the decoder from it.
model_kwargs = {
    "decoder_block": DecoderBlock,
    "n_layers": CONFIG.n_layers,
    "n_head": CONFIG.n_head,
    "n_dim": CONFIG.n_dim,
    "max_seq_len": CONFIG.max_seq_len,
    "mlp_dropout": CONFIG.mlp_dropout,
    "attn_dropout": CONFIG.attn_dropout,
    "vocab_size": VOCAB_SIZE,
    "learning_rate": CONFIG.learning_rate,
    "min_learning_rate": CONFIG.min_learning_rate,
    "weight_decay": CONFIG.weight_decay,
    "beta1": CONFIG.beta1,
    "beta2": CONFIG.beta2,
    "bias": CONFIG.bias,
    "log_interval": CONFIG.log_interval,
}
transformer = DecoderModel(**model_kwargs)

# logging
# Start from Lightning's default logger (`logger=True`) so the Trainer below
# never references an undefined `wandb_logger` when CONFIG.wandb_log is False
# (or, worse, a stale one left over from an earlier kernel run).
run_logger = True
if CONFIG.wandb_log:
    wandb_logger = WandbLogger(
        project=CONFIG.wandb_project_name + "-decoder",
        name=CONFIG.wandb_run_name,
        config=CONFIG
    )

    # log gradients and model topology
    wandb_logger.watch(transformer)
    run_logger = wandb_logger

# Define the trainer
trainer = L.Trainer(
    # Directory passed to Lightning as its root for run artifacts.
    default_root_dir="./checkpoints/",
    max_epochs=CONFIG.num_epochs,
    # Validation cadence reuses the logging interval from the config.
    val_check_interval=CONFIG.log_interval,
    log_every_n_steps=1,
    accumulate_grad_batches=CONFIG.grad_accumulation,
    gradient_clip_val=CONFIG.grad_clip,
    profiler="simple",
    logger=run_logger,
    # precision="16-mixed"
)

# Run the training loop; validation batches come from the held-out loader.
trainer.fit(
    transformer,
    train_dataloaders=train_loader,
    val_dataloaders=val_loader,
)