# <font color='#FFE15D'><b>💎 Train, Evaluate, and Generate Functions (General)</b></font>

# 🔴 **Environment Setup**

## 🟠 Change the font size of the output cells

In [None]:
from IPython.display import HTML, display
shell = get_ipython()


def adjust_font_size():
    """Inject CSS that enlarges the ``body`` font size to 24px in rendered output.

    Registered below as an IPython ``pre_execute`` hook, so it runs before
    every cell execution.
    """
    display(HTML('''<style>
    body {
      font-size: 24px;
    }
  </style>'''))


# Re-running this cell defines a *new* function object, so an identity check
# against the registered callbacks never matches and a duplicate hook would be
# added on every re-run. Remove earlier registrations by name, then register
# the current definition exactly once.
for _stale_hook in list(shell.events.callbacks['pre_execute']):
    if getattr(_stale_hook, '__name__', None) == adjust_font_size.__name__:
        shell.events.unregister('pre_execute', _stale_hook)
shell.events.register('pre_execute', adjust_font_size)

## 🟠 `pip`

In [None]:
# !pip install -q datasets torchmetrics

# 🔴 **Imports**

In [None]:
import time
import math
import random
from tqdm import tqdm
from dataclasses import dataclass

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
sns.set()

from datasets import load_dataset
from tokenizers import Tokenizer

import torch
from torch.utils.data import Dataset, DataLoader
from torch import nn
from torch.nn import functional as F

from torchmetrics.aggregation import MeanMetric

# 🔴 **Utils**

In [None]:
def prepare_data(tokens, seq_len):
    """Reshape a flat token tensor into non-overlapping rows of ``seq_len`` tokens.

    Trailing tokens that do not fill a complete row are dropped.

    Args:
        tokens: Tensor of token ids, expected to be 1-D (only ``shape[0]``,
            slicing and ``view`` are used).
        seq_len: Number of tokens per row; must be positive.

    Returns:
        A ``(num_rows, seq_len)`` view of ``tokens``. ``num_rows`` is 0 when
        there are fewer than ``seq_len`` tokens.

    Raises:
        ValueError: If ``seq_len`` is not positive.
    """
    if seq_len <= 0:
        raise ValueError(f"seq_len must be positive, got {seq_len}")
    num_rows = tokens.shape[0] // seq_len
    # Pass the explicit row count instead of -1: view(-1, seq_len) cannot infer
    # the missing dimension of an empty tensor and raises rather than
    # returning a (0, seq_len) tensor.
    return tokens[: num_rows * seq_len].view(num_rows, seq_len)

In [None]:
def num_trainable_params(model):
    """Return how many trainable parameters ``model`` has, in millions.

    Only parameters with ``requires_grad`` set are counted; the raw element
    count is divided by 1e6, so the result is a float.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total / 1e6

In [None]:
# Benchmarking function
def calculate_time(model, x, num_runs=10):
    """Return the mean wall-clock time, in seconds, of one ``model(x)`` call.

    The model is called ``num_runs`` times back to back and the total elapsed
    time is divided by ``num_runs``. No warm-up runs are done, and autograd is
    not disabled here; wrap the call in ``torch.no_grad()`` to time inference
    only.

    Args:
        model: Any callable taking ``x``.
        x: Input passed unchanged to ``model`` on every run.
        num_runs: Number of timed calls; must be positive.

    Returns:
        Average seconds per call, as a float.

    Raises:
        ValueError: If ``num_runs`` is not positive.
    """
    if num_runs <= 0:
        raise ValueError(f"num_runs must be positive, got {num_runs}")

    def _sync():
        # CUDA kernels launch asynchronously, so wait for them to finish so the
        # timer brackets the real work. Skip this on CPU-only machines, where
        # torch.cuda.synchronize() would raise.
        if torch.cuda.is_available():
            torch.cuda.synchronize()

    _sync()
    # perf_counter is monotonic and high-resolution, unlike time.time().
    start = time.perf_counter()
    for _ in range(num_runs):
        model(x)
    _sync()
    return (time.perf_counter() - start) / num_runs

# 🔴 **Init**

In [None]:
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
device  # bare name so the notebook displays the chosen device

# 🔴 **Dataset**

In [None]:
# Load the TinyStories dataset via `datasets.load_dataset`; the bare name on
# the last line makes the notebook display the loaded object.
dataset = load_dataset(path="roneneldan/TinyStories")
dataset

In [None]:
# Load the BPE tokenizer serialized as JSON at this path, then display it.
tokenizer_file = "bpe-tokenizer_tinystories.json"
tokenizer = Tokenizer.from_file(tokenizer_file)
tokenizer

In [None]:
# Load tokens from pytorch file
train_token_ids = torch.load('tokenized-train-samples_vocab-10k.pt')
valid_token_ids = torch.load('tokenized-valid-samples_vocab-10k.pt')

print("ðŸ“Š Number of Tokens")
print(f"ðŸ”¹ Train: {len(train_token_ids):,} tokens")
print(f"ðŸ”¹ Valid: {len(valid_token_ids):,} tokens")

In [None]:
class TinyStoriesDataset(Dataset):
    """Map-style dataset of fixed-length token windows.

    The flat token tensor is cut by ``prepare_data`` into non-overlapping rows
    of ``seq_len + 1`` tokens. Each item is one full row cast to int64. The
    extra token presumably lets the training loop split a row into inputs
    ``row[:-1]`` and targets ``row[1:]``; confirm against the caller.
    """

    def __init__(self, data, seq_len):
        self.seq_len = seq_len
        window = seq_len + 1
        self.data = prepare_data(data, window)

    def __len__(self):
        # One item per row of the 2-D token matrix.
        return len(self.data)

    def __getitem__(self, idx):
        row = self.data[idx]
        return row.long()

# 🔴 **Model**

## 🟠 Multi-Head Attention

In [None]:
class MultiHeadAttention(nn.Module):
    """Causal multi-head self-attention with a fused QKV projection.

    ``config`` must provide ``n_embd`` (model width) and ``n_head`` (number of
    heads). ``n_embd`` must be divisible by ``n_head`` so the heads can be
    merged back into the model width.

    Raises:
        ValueError: If ``n_embd`` is not divisible by ``n_head``.
    """

    def __init__(self, config):
        super().__init__()
        self.n_embd = config.n_embd
        self.n_head = config.n_head
        # Fail at construction instead of inside forward's view() with an
        # opaque shape error.
        if self.n_embd % self.n_head != 0:
            raise ValueError(
                f"n_embd ({self.n_embd}) must be divisible by n_head ({self.n_head})"
            )
        self.head_size = self.n_embd // self.n_head

        # One matmul produces queries, keys and values for all heads at once.
        self.qkv_proj = nn.Linear(self.n_embd, 3 * self.n_embd, bias=False)
        self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=False)
        # Marker attribute; presumably read by weight-initialisation code
        # elsewhere to treat residual-branch projections specially. Confirm.
        self.c_proj.residual = True

    def forward(self, x):
        """Apply causal self-attention to ``x`` of shape ``(B, T, n_embd)``.

        Returns a tensor of the same shape ``(B, T, n_embd)``.
        """
        B, T, C = x.shape
        # (B, T, 3C) -> (B, T, 3*n_head, head_size) -> (B, 3*n_head, T, head_size),
        # then split the head axis into q, k, v, each (B, n_head, T, head_size).
        q, k, v = (
            self.qkv_proj(x)
            .view(B, T, 3 * self.n_head, self.head_size)
            .transpose(1, 2)
            .chunk(3, dim=-3)
        )
        # Scaled dot-product attention with a causal mask, via PyTorch.
        y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        # Merge heads back: (B, n_head, T, head_size) -> (B, T, C).
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        return self.c_proj(y)