# Pre-training of GPT-2-124-million 
- pre-train on novel Anna Karenina
- adaptations
    - reduce the context length from 1024 to 256 to make sure the training loop will run locally
    - change it back to 1024 when loading / using pre-trained weights
- preprocess
- create modules
- training loop
- evaluation

In [9]:
import numpy as np
import os
import sys 
from typing import Tuple, Dict, List

# Current working directory; the data path is later built relative to it.
cwd = os.getcwd()

In [None]:
import torch
print(torch.__version__)
# Pick the best available backend: a CUDA GPU first, then Apple-silicon MPS,
# and finally the CPU. `getattr` guards torch builds that do not ship
# `torch.backends.mps` at all, so the lookup itself cannot raise.
_mps_backend = getattr(torch.backends, "mps", None)
if torch.cuda.is_available():
    device = torch.device("cuda")
elif _mps_backend is not None and _mps_backend.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

In [None]:
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn 
# import torch.... as F

## preprocess

In [None]:
# load tokenizer
# `import importlib` alone does not guarantee that the `importlib.metadata`
# submodule is loaded, so import the submodule explicitly (this still binds
# the name `importlib`).
import importlib.metadata
import tiktoken
print("tiktoken version:", importlib.metadata.version("tiktoken"))
# GPT-2 BPE encoding provided by tiktoken
tokenizer = tiktoken.get_encoding("gpt2")

In [None]:
# read in raw text
# NOTE(review): cwd[:-18] drops the last 18 characters of the working
# directory — presumably this notebook's folder name; confirm the length
# still matches if the notebook is moved or renamed.
pdata = f"{cwd[:-18]}traditional-NLP/data/"
# NOTE(review): this adds the data folder to the module search path, which
# open() below does not need — presumably a leftover; confirm before removing.
sys.path.append(pdata)
# read the whole file into a single in-memory string
with open(f"{pdata}anna.txt" , 'r', encoding='utf-8') as f:
    text_data = f.read()
print(f"The type of the raw text: {type(text_data)}")
print(f"The beginning of raw text: \n {text_data[:50]}")

In [None]:
# inspect raw text and tokens
# corpus size measured in characters
total_characters = len(text_data)
print(f"total num of characters in Anna Karenina: {total_characters}")
# corpus size measured in tokens: encode the full text once and count the ids
total_tokens = len(tokenizer.encode(text_data))
print(f"total num of tokens in Anna Karenina with BPE tokenizer: {total_tokens}")
# Presumably the output recorded from an earlier run of this cell:
# total num of characters in Anna Karenina: 1985223
# total num of tokens in Anna Karenina with BPE tokenizer: 508206

### set parameters

In [None]:
# Hyper-parameters for the GPT-2 124M configuration used in this notebook.
CONFIG_GPT2_124M = dict(
    vocab_size=50257,     # number of entries in the vocabulary
    context_length=256,   # tokens per sequence; shortened from the original 1024
    emb_dim=768,          # width of the embeddings
    n_heads=12,           # attention heads
    n_layers=12,          # number of layers
    drop_rate=0.1,        # dropout rate
    qkv_bias=False,       # whether query/key/value projections carry a bias
)

# Seed torch's global random number generator.
torch.manual_seed(123)

### torch dataset dataloader

In [None]:
# create dataset and dataloader

class my_text_dataset(Dataset):
    """Sliding-window next-token-prediction dataset over a single raw text.

    The text is tokenized once. Every sample is a pair ``(x, y)`` of
    ``max_length`` token ids in which ``y`` is ``x`` shifted one token to the
    right, i.e. ``y[k]`` is the token that follows ``x[k]`` in the text.

    Args:
        raw_text: Text to tokenize.
        tokenizer: Object exposing ``encode(text, allowed_special=...)`` that
            returns a sequence of integer token ids.
        max_length: Number of tokens in each input / target sequence.
        stride: Distance, in tokens, between the starts of consecutive
            windows.

    Note:
        A window needs ``max_length + 1`` tokens, so a text with at most
        ``max_length`` tokens produces an empty dataset.
    """

    def __init__(self, raw_text:str, tokenizer, max_length:int, stride:int=1):
        super().__init__()

        # tokenize the entire text once
        tokens = tokenizer.encode(raw_text, allowed_special={"<|endoftext|>"})

        # Stop early enough that the target window, which ends one token
        # after the input window, is always full length.
        window_starts = range(0, len(tokens) - max_length, stride)

        # x is the window itself; y is the same window shifted by one token
        self.input_tokens_x = [
            torch.tensor(tokens[start : start + max_length])
            for start in window_starts
        ]
        self.target_tokens_y = [
            torch.tensor(tokens[start + 1 : start + max_length + 1])
            for start in window_starts
        ]

    # required by torch.utils.data.Dataset consumers such as DataLoader
    def __len__(self) -> int:
        "Returns the number of rows / pairs of x-y sequences in the dataset"
        return len(self.input_tokens_x)

    # required for subclasses of torch.utils.data.Dataset
    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        "Returns one sample as a pair (x, y) of input and target token-id tensors."
        return self.input_tokens_x[idx], self.target_tokens_y[idx]

def my_text_dataloader(raw_text:str, batch_size:int=4, max_length:int=256,
                       stride:int=128, shuffle=True, drop_last=True, num_workers=0):
    """Wrap ``raw_text`` in a ``my_text_dataset`` and return a DataLoader over it.

    The text is tokenized with tiktoken's "gpt2" encoding. ``max_length`` and
    ``stride`` are passed to the dataset; ``batch_size``, ``shuffle``,
    ``drop_last`` and ``num_workers`` are passed to the ``DataLoader``.
    """
    windows = my_text_dataset(
        raw_text, tiktoken.get_encoding("gpt2"), max_length, stride)

    return DataLoader(
        windows,
        batch_size=batch_size,
        shuffle=shuffle,
        drop_last=drop_last,
        num_workers=num_workers,
    )

#### split into train (T), validation (V), and hold-out (H) sets

In [None]:
# Character-level split of the corpus into train / validation / hold-out parts.
total_characters = len(text_data)
print(f"total num of characters in Anna Karenina: {total_characters}")

# proportions of the text assigned to each part
prop_t, prop_v, prop_h = (0.8,0.1,0.1)

# character offsets at which the text is cut
split_idx_t = int(prop_t * total_characters)
split_idx_v = int((prop_t+prop_v) * total_characters)
print(f"Split at character index {split_idx_t} between train and valid sets, and at {split_idx_v} betwee valid and hold sets")

d_train, d_valid, d_hold = (
    text_data[:split_idx_t],
    text_data[split_idx_t:split_idx_v],
    text_data[split_idx_v:],
)
print(len(d_train), len(d_valid), len(d_hold))

In [None]:
# Settings shared by both loaders. The stride equals the window length, so
# consecutive windows do not overlap.
_loader_common = dict(
    batch_size=2,
    max_length=CONFIG_GPT2_124M["context_length"],
    stride=CONFIG_GPT2_124M["context_length"],
    num_workers=0,
)

# training loader: shuffled, incomplete final batch dropped
loader_t = my_text_dataloader(
    raw_text=d_train, drop_last=True, shuffle=True, **_loader_common)

# validation loader: original order, final partial batch kept
loader_v = my_text_dataloader(
    raw_text=d_valid, drop_last=False, shuffle=False, **_loader_common)

# loader_h = my_text_dataloader(
#     raw_text=d_hold,
#     batch_size=2,
#     max_length=CONFIG_GPT2_124M["context_length"],
#     stride=CONFIG_GPT2_124M["context_length"],
#     drop_last=False,
#     shuffle=False,
#     num_workers=0
# )