In [None]:
import sys
import os

if 'google.colab' in sys.modules:
    prev_dir = os.getcwd()

    # Mount Google Drive
    from google.colab import drive
    drive.mount('/content/drive')

    # Replace with correct location
    %cd /content/drive/MyDrive/Colab Notebooks/CodeNet-Sentinel/Decoder

    !pip install datasets transformers lightning wandb

In [None]:
# Make packages in the parent of the working directory (e.g. the local
# `Architecture` package imported below) importable. The membership check
# keeps re-runs of this cell from appending duplicate sys.path entries.
current_dir = os.getcwd()
parent_dir = os.path.dirname(current_dir)
if parent_dir not in sys.path:
    sys.path.append(parent_dir)

In [None]:
import torch
import math

import numpy as np
import lightning as L

from torch.utils.data import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
from lightning.pytorch.loggers import WandbLogger
from lightning.pytorch.callbacks import LearningRateMonitor

# Local Modules
from Architecture import  Tokenizer, VOCAB_SIZE
from Architecture.ModelConfig import ModelConfig
from Architecture.Decoder import DecoderBlock
from Architecture.Encoder import EncoderBlock
from Architecture.CodeGenTransformer import CodeGenDataset, CodeGenModel

In [None]:
CONFIG = ModelConfig()

# Overrides specific to this notebook's run
CONFIG.wandb_project_name += "-CodeGen"
CONFIG.batch_size = 1
CONFIG.grad_accumulation = 8

# Seed Python's `random`, NumPy and PyTorch (CPU and CUDA) in one call, so every
# source of randomness (shuffling, dropout, sampling) is reproducible — not just torch.
L.seed_everything(CONFIG.random_seed)
torch.__version__

## The Data

In [None]:
# Pre-built training split; the cell displays its number of examples.
TRAINING_DATA_PATH = "./data/codegen_data_training.pt"
training_dataset = CodeGenDataset.load_from(TRAINING_DATA_PATH)
len(training_dataset)

In [None]:
# Pre-built validation split; the cell displays its number of examples.
VALIDATION_DATA_PATH = "./data/codegen_data_validation.pt"
validation_dataset = CodeGenDataset.load_from(VALIDATION_DATA_PATH)
len(validation_dataset)

In [None]:
# Training batches are reshuffled every epoch (seeded in the config cell);
# validation keeps a fixed order so successive validation losses are comparable.
train_loader = DataLoader(
    training_dataset,
    batch_size=CONFIG.batch_size,
    shuffle=True,
    num_workers=2,
    persistent_workers=True
)

val_loader = DataLoader(
    validation_dataset,
    batch_size=CONFIG.batch_size,
    num_workers=2,
    persistent_workers=True
)

In [None]:
# Report how many batches each loader yields per epoch.
for label, loader in (
    ("Number of Training Batches:", train_loader),
    ("Number of Validation Batches:", val_loader),
):
    print(label, len(loader))

In [None]:
# Pull one validation batch and list each field with its tensor shape.
sample_batch = next(iter(val_loader))
print(sample_batch.keys(), "\n")

for key in sample_batch.keys():
    padded_label = f"{key}:".ljust(32)
    print(padded_label, sample_batch[key].shape)

In [None]:
# unfortunately, the connection to Google Drive dies after a couple
# of hours, crashing the entire notebook. To prevent any issues when
# running the notebok in the background, we must unmount and make
# the previous directory (`prev_dir`) our current directory.
if 'google.colab' in sys.modules:
    %cd $prev_dir
    drive.flush_and_unmount()

## The Model

In [None]:
# model
# Architecture and optimizer hyperparameters are forwarded from CONFIG under
# the same keyword names CodeGenModel expects.
_config_hparams = {
    name: getattr(CONFIG, name)
    for name in (
        "n_layers",
        "n_head",
        "n_dim",
        "max_instruct_len",
        "max_seq_len",
        "mlp_dropout",
        "attn_dropout",
        "learning_rate",
        "min_learning_rate",
        "weight_decay",
        "beta1",
        "beta2",
        "bias",
    )
}

# Total number of training batches over the whole run.
# NOTE(review): with gradient accumulation, the optimizer steps only once every
# CONFIG.grad_accumulation batches. If CodeGenModel's LR schedule advances per
# optimizer step, this overstates its length — confirm in CodeGenModel.
_total_train_batches = len(train_loader) * CONFIG.num_epochs

model = CodeGenModel(
    decoder_block=DecoderBlock,
    encoder_block=EncoderBlock,
    vocab_size=VOCAB_SIZE,
    max_iters=_total_train_batches,
    **_config_hparams,
)

### Callbacks and Logging

In [None]:
# logging
# `wandb_logger` is always bound so later cells can refer to it; it stays
# None when W&B logging is disabled in the config.
wandb_logger = None
if CONFIG.wandb_log:
    wandb_logger = WandbLogger(
        project=CONFIG.wandb_project_name,
        name=CONFIG.wandb_run_name,
        config=CONFIG
    )

    # To also log gradients and model topology, uncomment:
    # wandb_logger.watch(model)

class GenerateTextCallback(L.Callback):
    """Log sample generations to W&B at the end of every validation epoch.

    Each sample pairs an encoder instruction with a decoder "response seed".
    The model continues the seed via ``pl_module._generate``, and the
    (instruction, seed, generation) triples are logged as a W&B text table.
    Nothing is generated when the trainer has no ``WandbLogger``.

    Args:
        max_new_tokens: forwarded to ``pl_module._generate`` for every sample.
    """

    def __init__(self, max_new_tokens=1024):
        super().__init__()
        self.max_new_tokens = max_new_tokens

        self.instructions_text = [
            "[CLS]Write a javascript function to make an API get request to retrieve a song based on an api endpoint and a song id.[SEP]",
            "[CLS]rust[SEP]",
            "[CLS]Complete this python function to compute the determinant of a square matrix.[SEP]",
        ]

        # Instructions are padded/truncated to the encoder's fixed length; the
        # attention mask is kept so padding can be masked out during generation.
        self.tokenized_instructions = Tokenizer.batch_encode_plus(
            batch_text_or_text_pairs=self.instructions_text,
            truncation=True,
            max_length=CONFIG.max_instruct_len,
            padding="max_length",
            return_attention_mask=True,
            return_tensors="pt",
        )

        self.response_seed_text = [
            "[BOS]",
            "[BOS]use crate::virtual_machine::{RuntimeResult, VirtualMachine};\n\nimpl VirtualMachine {\n  /// Executes the instructions in a chunk of byte code\n  pub(crate) fn run(&mut self) -> RuntimeResult {\n    loop {\n      let instruction = self.next_op_code();",
            "[BOS]def compute_determinant(matrix):\n"
        ]

        # Seeds differ in length, so each is encoded separately instead of being
        # padded into one batch. They are moved to the model's device at use time.
        self.tokenized_response_seeds = [
            Tokenizer.encode(seed_text, return_tensors="pt").to("cpu")
            for seed_text in self.response_seed_text
        ]

    def on_validation_epoch_end(self, trainer, pl_module):
        # Only generate when a W&B logger is attached to receive the samples;
        # this also keeps the callback usable when W&B logging is disabled.
        wandb_loggers = [
            logger for logger in trainer.loggers if isinstance(logger, WandbLogger)
        ]
        if not wandb_loggers:
            return

        was_training = pl_module.training
        pl_module.eval()

        # The prompts were tokenized on the CPU; move them to wherever the model lives.
        device = pl_module.device
        generated_text = []
        with torch.no_grad():
            for input_ids, attention_mask, seed_ids in zip(
                self.tokenized_instructions["input_ids"],
                self.tokenized_instructions["attention_mask"],
                self.tokenized_response_seeds,
            ):
                # Inverted mask: assuming the usual tokenizer convention
                # (1 = real token), this is True at padding positions — presumably
                # the key-padding-mask form `_generate` expects; TODO confirm.
                padding_mask = ~attention_mask.bool()
                generated = pl_module._generate(
                    input_ids.to(device),
                    padding_mask.to(device),
                    seed_ids.to(device),
                    max_new_tokens=self.max_new_tokens
                )

                generated_text.append(Tokenizer.decode(generated[0]))

        # Restore whatever mode Lightning had the module in.
        pl_module.train(was_training)

        # Log the generated text to W&B
        columns = ["Instruction", "Response Seed", "Generated Response"]
        data = list(zip(self.instructions_text, self.response_seed_text, generated_text))
        wandb_loggers[0].log_text(key="Text Generation Samples", columns=columns, data=data)

# Callbacks handed to the Trainer: sample generation at the end of each
# validation epoch, and learning-rate logging at every step.
text_gen_callback = GenerateTextCallback()
lr_monitor = LearningRateMonitor(logging_interval='step')

### The Trainer

In [None]:
# Define the trainer
trainer = L.Trainer(
    default_root_dir="./checkpoints/",
    max_epochs=CONFIG.num_epochs,
    # An int `val_check_interval` counts training batches: a validation epoch (whose
    # metrics are logged once per validation epoch, not per validation step) runs every
    # `CONFIG.log_interval` batches. Lightning rejects an interval larger than the
    # number of training batches, so clamp it for small datasets.
    val_check_interval=min(CONFIG.log_interval, len(train_loader)),
    # `log_every_n_steps` counts optimizer (global) steps, and with gradient accumulation
    # each optimizer step consumes `CONFIG.grad_accumulation` batches. Dividing the log
    # interval by the accumulation factor aligns training logs with validation logs; the
    # extra division by 10 logs the training loss 10 times per validation loss.
    log_every_n_steps=math.ceil(CONFIG.log_interval / 10 / CONFIG.grad_accumulation),
    accumulate_grad_batches=CONFIG.grad_accumulation,
    gradient_clip_val=CONFIG.grad_clip,
    profiler="simple",
    # Without W&B, fall back to Lightning's default logger (`True`) rather than none:
    # `LearningRateMonitor` refuses to run on a Trainer without a logger.
    logger=wandb_logger if CONFIG.wandb_log else True,
    precision="16-mixed",
    callbacks=[lr_monitor, text_gen_callback]
)

# train model
trainer.fit(
    model=model,
    train_dataloaders=train_loader,
    val_dataloaders=val_loader,
)