In [None]:
import sys
import os

# Add the parent of the current working directory to sys.path so that the
# project-local imports in the next cell can be resolved when the notebook is
# launched from its own folder.
current_dir = os.getcwd()
parent_dir = os.path.dirname(current_dir)

# Guard the append: re-running this cell must not stack duplicate entries.
if parent_dir not in sys.path:
    sys.path.append(parent_dir)

In [None]:
import torch
import copy

import numpy as np
import lightning as L

from torch.utils.data import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
from lightning.pytorch.loggers import WandbLogger

# Local Modules
from Architecture import PositionalEncoder, Tokenizer, VOCAB_SIZE
from Architecture.ModelConfig import ModelConfig
from Architecture.Decoder import DecoderDataset, DecoderBlock

In [None]:
# Build the run configuration and seed the random number generators up front.
CONFIG = ModelConfig()

torch.manual_seed(CONFIG.random_seed)
# Seed NumPy's global RNG too. Without this it is only seeded further down,
# and only when CONFIG.shuffle_dataset is enabled.
np.random.seed(CONFIG.random_seed)

# Last expression of the cell: displays the installed PyTorch version.
torch.__version__

## The Data

In [None]:
# Load the pre-built decoder dataset from disk (path is relative to the
# notebook's working directory).
DECODER_DATA_PATH = "./data/decoder_data.pt"
dataset = DecoderDataset.load_from(DECODER_DATA_PATH)

In [None]:
# Split the dataset indices into a validation part (the first `val_split`
# fraction, after the optional shuffle) and a training part (the remainder).
dataset_size = len(dataset)
indices = list(range(dataset_size))
split = int(np.floor(CONFIG.val_split * dataset_size))

if CONFIG.shuffle_dataset:
    # Re-seed immediately before shuffling so the split depends only on
    # CONFIG.random_seed, not on any NumPy randomness consumed earlier.
    np.random.seed(CONFIG.random_seed)
    np.random.shuffle(indices)

train_indices, val_indices = indices[split:], indices[:split]

# Creating PT data samplers and loaders:
train_sampler = SubsetRandomSampler(train_indices)
val_sampler = SubsetRandomSampler(val_indices)

# Both loaders read from the same dataset object and differ only in sampler.
loader_kwargs = dict(
    batch_size=CONFIG.batch_size,
    num_workers=1,
    persistent_workers=True,
)

train_loader = DataLoader(dataset, sampler=train_sampler, **loader_kwargs)
val_loader = DataLoader(dataset, sampler=val_sampler, **loader_kwargs)

# NOTE: `dataset` is deliberately left bound. Both loaders hold a reference to
# it, so deleting the name would not release any memory; it would only make
# this cell raise NameError when re-run without reloading the data first.

In [None]:
# Report how many batches each loader yields per pass.
n_train_batches = len(train_loader)
n_val_batches = len(val_loader)

print("Number of Training Batches:", n_train_batches)
print("Number of Validation Batches:", n_val_batches)

In [None]:
# Pull a single validation batch and list its fields with their shapes.
val_batches = iter(val_loader)
sample_batch = next(val_batches)
print(sample_batch.keys(), "\n")

for field_name, field_value in sample_batch.items():
    label = f"{field_name}:".ljust(24)
    print(label, field_value.shape)

## The Model

In [None]:
class DecoderModel(L.LightningModule):
    """Decoder stack wrapped as a Lightning module.

    The model embeds token ids, applies the positional encoder, runs the
    result through ``n_layers`` decoder blocks and projects it onto the
    vocabulary. Train / validation / test steps share one code path that
    computes cross-entropy loss and token accuracy (both skip positions whose
    target is the padding token) and logs them as ``<prefix>_loss`` and
    ``<prefix>_accuracy``.
    """

    def __init__(
        self, decoder_block,
        # Hyperparameters and Config
        n_layers, n_head, n_dim, max_seq_len, mlp_dropout, attn_dropout,
        vocab_size, learning_rate, min_learning_rate,
        weight_decay, beta1, beta2, bias=False, log_interval=1
    ):
        """Build the layers and record all constructor arguments as hparams.

        Args:
            decoder_block: Callable invoked as ``decoder_block(n_head, n_dim,
                max_seq_len, mlp_dropout, attn_dropout, bias)`` to build one
                decoder layer.
            n_layers: Number of decoder layers in the stack.
            n_head, n_dim, max_seq_len, mlp_dropout, attn_dropout, bias:
                Forwarded to ``decoder_block``; ``n_dim`` and ``max_seq_len``
                are also used for the embedding / positional encoder.
            vocab_size: Size of the embedding table and of the output logits.
            learning_rate, weight_decay, beta1, beta2: AdamW settings.
            min_learning_rate: Lower bound handed to ``ReduceLROnPlateau``.
            log_interval: Used as the scheduler ``frequency`` in
                ``configure_optimizers``.
        """
        super().__init__()
        # Makes every constructor argument available as self.hparams.<name>.
        self.save_hyperparameters()

        self.criterion = torch.nn.CrossEntropyLoss(
            ignore_index=Tokenizer.pad_token_id
        )

        dec = decoder_block(
            n_head,
            n_dim,
            max_seq_len,
            mlp_dropout,
            attn_dropout,
            bias
        )

        # Each layer is an independent deep copy of the same freshly built
        # block, so all layers start from identical initial parameters.
        self.decoder_layers = torch.nn.ModuleList(
            [copy.deepcopy(dec) for _ in range(n_layers)]
        )
        self.embedding = torch.nn.Embedding(vocab_size, n_dim)
        self.pos_encoder = PositionalEncoder(n_dim, max_seq_len)
        self.final_linear = torch.nn.Linear(n_dim, vocab_size)

    def forward(self, x, tgt_key_pad_mask, memory=None, memory_key_pad_mask=None):
        """Return vocabulary logits for the token ids in ``x``.

        ``tgt_key_pad_mask``, ``memory`` and ``memory_key_pad_mask`` are passed
        unchanged to every decoder layer; their exact semantics are defined by
        the decoder block implementation.
        """
        x = self.pos_encoder(self.embedding(x))

        for layer in self.decoder_layers:
            x = layer(x, tgt_key_pad_mask, memory, memory_key_pad_mask)

        logits = self.final_linear(x)
        return logits

    def training_step(self, batch):
        """Log training metrics and return the loss to optimize."""
        _, loss, _ = self._compute_and_log_metrics(batch, "train")
        return loss

    def validation_step(self, batch):
        """Log validation metrics and return the validation loss."""
        _, loss, _ = self._compute_and_log_metrics(batch, "validation")
        return loss

    def test_step(self, batch):
        """Log test metrics; nothing is returned."""
        self._compute_and_log_metrics(batch, "test")

    def configure_optimizers(self):
        """Create the AdamW optimizer and a ReduceLROnPlateau scheduler.

        The scheduler monitors ``validation_loss``, which is the name produced
        by ``validation_step`` through ``_compute_and_log_metrics``.
        """
        optimizer = torch.optim.AdamW(
            params=self.parameters(),
            lr=self.hparams.learning_rate,
            weight_decay=self.hparams.weight_decay,
            betas=(self.hparams.beta1, self.hparams.beta2),
        )

        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer=optimizer,
            patience=10,
            min_lr=self.hparams.min_learning_rate,
        )

        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "interval": "step",
                "frequency": self.hparams.log_interval,
                "monitor": "validation_loss",
                "strict": True,
            }
        }

    def _compute_and_log_metrics(self, batch, prefix):
        """Run the model on ``batch`` and log loss / accuracy under ``prefix``.

        ``batch`` must provide the keys ``"inputs"``, ``"attn_masks"`` and
        ``"targets"``. Returns ``(logits, loss, accuracy)``.
        """
        logits = self(batch["inputs"], batch["attn_masks"])
        loss = self._compute_loss(logits, batch["targets"])
        acc = self._compute_accuracy(logits, batch["targets"])

        self.log_dict(
            { f"{prefix}_loss": loss, f"{prefix}_accuracy":  acc },
            on_step=True, on_epoch=True, logger=True
        )

        return logits, loss, acc

    def _compute_loss(self, logits, targets):
        """Cross-entropy over all positions, flattened to (N, vocab) vs (N,).

        ``reshape`` is used instead of ``view`` so that non-contiguous inputs
        (e.g. sliced or transposed tensors) do not raise.
        """
        flat_logits = logits.reshape(-1, logits.size(-1))
        flat_targets = targets.reshape(-1)
        return self.criterion(flat_logits, flat_targets)

    def _compute_accuracy(self, logits, targets):
        """Fraction of non-padding targets whose arg-max prediction is correct.

        Returns a Python float; ``0.0`` when every target is padding.
        """
        # Get the index of the maximum logit as the predicted token
        _, predicted = torch.max(logits, dim=-1)

        # Mask out padding positions
        non_padding_mask = (targets != Tokenizer.pad_token_id)
        total_non_padding = non_padding_mask.sum().item()

        correct_predictions = (
            predicted[non_padding_mask] == targets[non_padding_mask]
        ).sum().item()

        accuracy = correct_predictions / total_non_padding if total_non_padding > 0 else 0.0

        return accuracy

    def _generate(self, src, src_pad_mask, tgt_seed, max_new_tokens):
        """Sample up to ``max_new_tokens`` tokens and append them to ``tgt_seed``.

        At every step the running sequence (cropped to the last
        ``max_seq_len`` tokens) is fed through the model, the logits of the
        final position are turned into probabilities, and one token per row is
        drawn with ``torch.multinomial``. Generation stops early as soon as the
        first row samples the EOS token.

        Args:
            src: Passed to ``forward`` as ``memory``.
            src_pad_mask: Passed to ``forward`` as ``memory_key_pad_mask``.
            tgt_seed: Token ids to continue; new tokens are concatenated along
                dim 1, so a 2-D tensor is expected.
            max_new_tokens: Upper bound on the number of sampled tokens.

        Returns:
            ``tgt_seed`` with the sampled tokens appended along dim 1.
        """
        for _ in range(max_new_tokens):
            # Crop to the last max_seq_len tokens. Tokens are appended along
            # dim 1 below, so the crop must also be along dim 1 (slicing dim 0
            # would drop rows instead of old tokens).
            idx_cond = tgt_seed[:, -self.hparams.max_seq_len:]

            # Keyword arguments keep the call aligned with forward(): the
            # running target goes in as `x`, the source as `memory`.
            # NOTE(review): no padding mask is supplied for the running target;
            # confirm the decoder block accepts tgt_key_pad_mask=None.
            logits = self(
                idx_cond,
                tgt_key_pad_mask=None,
                memory=src,
                memory_key_pad_mask=src_pad_mask,
            )
            # Only the prediction for the last position is needed.
            logits = logits[:, -1]

            probs = torch.nn.functional.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)

            # append sampled index to the running sequence
            tgt_seed = torch.cat((tgt_seed, idx_next), dim=1)

            # Early stop is decided by the first row only.
            if idx_next[0][0] == Tokenizer.eos_token_id:
                break

        return tgt_seed

In [None]:
# model
transformer = DecoderModel(
    decoder_block=DecoderBlock,
    n_layers=CONFIG.n_layers,
    n_head=CONFIG.n_head,
    n_dim=CONFIG.n_dim,
    max_seq_len=CONFIG.max_seq_len,
    mlp_dropout=CONFIG.mlp_dropout,
    attn_dropout=CONFIG.attn_dropout,
    vocab_size=VOCAB_SIZE,
    learning_rate=CONFIG.learning_rate,
    min_learning_rate=CONFIG.min_learning_rate,
    weight_decay=CONFIG.weight_decay,
    beta1=CONFIG.beta1,
    beta2=CONFIG.beta2,
    bias=CONFIG.bias,
    log_interval=CONFIG.log_interval
)

# logging
# Bind the name unconditionally so the Trainer below never hits a NameError
# when Weights & Biases logging is switched off in the config.
wandb_logger = None
if CONFIG.wandb_log:
    wandb_logger = WandbLogger(
        project=CONFIG.wandb_project_name + "-decoder",
        name=CONFIG.wandb_run_name,
        config=CONFIG
    )

    # log gradients and model topology
    wandb_logger.watch(transformer)

# Define the trainer
trainer = L.Trainer(
    default_root_dir="./checkpoints/",
    max_epochs=CONFIG.num_epochs,
    val_check_interval=CONFIG.log_interval,
    log_every_n_steps=1,
    accumulate_grad_batches=CONFIG.grad_accumulation,
    gradient_clip_val=CONFIG.grad_clip,
    profiler="simple",
    # `True` selects Lightning's default logger when W&B is disabled.
    logger=wandb_logger if wandb_logger is not None else True,
    precision="16-mixed"
)

# train model
trainer.fit(
    model=transformer,
    train_dataloaders=train_loader,
    val_dataloaders=val_loader
)