In [1]:
!pip install uv
!uv venv gpt2-clone
!source /kaggle/working/gpt2-clone/bin/activate

Collecting uv
  Downloading uv-0.7.21-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (11 kB)
Downloading uv-0.7.21-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (18.6 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m18.6/18.6 MB[0m [31m75.2 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[?25hInstalling collected packages: uv
Successfully installed uv-0.7.21
Using CPython 3.11.13 interpreter at: [36m/usr/bin/python3[39m
Creating virtual environment at: [36mgpt2-clone[39m
Activate with: [32msource gpt2-clone/bin/activate[39m


In [2]:
!uv pip install -q torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
!uv pip install -q huggingface tiktoken datasets transformers tqdm

In [3]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

/kaggle/input/shakesphere-book/shakesphere_book.txt


## Imports

In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from datasets import load_dataset
from transformers import get_cosine_schedule_with_warmup
from tqdm.auto import tqdm

# Dataset Preparation

In [6]:
# Downloading Dataset
def load_bookcorpus():
    full_book_corpus = load_dataset("bookcorpus", trust_remote_code=True)

    if "train" in full_book_corpus:
        full_dataset = full_book_corpus["train"]
        total_dataset_size = len(full_dataset)
        print(f"Total Dataset Size: {full_dataset}")
        print("Extracting subset of 10,000,000")
        book_corpus = full_dataset.select(range(500_000))
        print(f"Final Dataset Size after Cropping: {len(book_corpus)}")
        first_example = book_corpus[0]
        key = list(first_example.keys())[0]
        for i in range(10):
            print(book_corpus[i][key])
        return book_corpus
    else:
        print("The dataset has no training split.")
        book_corpus = full_book_corpus.select(range(500_000))
        return book_corpus

In [None]:
torch.manual_seed(47)
book_corpus = load_bookcorpus()

README.md: 0.00B [00:00, ?B/s]

bookcorpus.py: 0.00B [00:00, ?B/s]

Downloading data:   0%|          | 0.00/1.18G [00:00<?, ?B/s]

Generating train split:   0%|          | 0/74004228 [00:00<?, ? examples/s]

# GPT-2 Architecture Implementation

In [None]:
GPT2_CONFIG_124M = {
    "vocab_size": 50257,
    "context_length": 512,
    "embedding_dim": 768,
    "num_heads": 12,
    "num_layers": 12,
    "drop_rate": 0.1,
    "qkv_bias": False,
}

### Layer Normalization Block

In [None]:
class LayerNormalization(nn.Module):
    def __init__(self, embedding_dim):
        super().__init__()
        self.eps = 1e-5
        self.scale = nn.Parameter(torch.ones(embedding_dim))
        self.shift = nn.Parameter(torch.zeros(embedding_dim))

    def forward(self, x):
        mean = x.mean(dim=-1, keepdim=True)
        var = x.var(dim=-1, keepdim=True, unbiased=False)
        norm_x = (x - mean) / (torch.sqrt(var + self.eps))
        return self.scale * norm_x + self.shift

### Feed-Forward Block

In [None]:
# GELU implementation
class GELU(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return 0.5 * x * (1 + torch.tanh(
            torch.sqrt(torch.tensor(2.0 / torch.pi)) *
            (x + 0.044715 * torch.pow(x, 3))
        ))

In [None]:
class FeedForwardNNBlock(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(cfg["embedding_dim"], 4 * cfg["embedding_dim"]),
            GELU(),
            nn.Linear(4 * cfg["embedding_dim"], cfg["embedding_dim"]),
        )

    def forward(self, x):
        return self.layers(x)

### Multi-Head Attention Block

In [None]:
class MultiHeadAttention(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        assert (d_out % num_heads == 0), \
            "d_out should be divisble by num_heads"
        
        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads

        # initializing weight matrices
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.out_proj = nn.Linear(d_out, d_out)
        self.dropout = nn.Dropout(dropout)
        self.register_buffer(
            "mask",
            torch.triu(
                torch.ones(context_length, context_length),
                diagonal=1
            )
        )

    def forward(self, x):
        batch_size, num_tokens, d_in = x.shape

        # input * weight matrices
        queries = self.W_query(x)
        keys = self.W_key(x)
        values = self.W_value(x)

        # Roll out last dim "d_out" to num_heads and head_dim
        # (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim)
        queries = queries.view(batch_size, num_tokens, self.num_heads, self.head_dim)
        keys = keys.view(batch_size, num_tokens, self.num_heads, self.head_dim)
        values = values.view(batch_size, num_tokens, self.num_heads, self.head_dim)

        # Transpose to (b, num_heads, num_tokens, head_dim)
        queries = queries.transpose(1, 2)
        keys = keys.transpose(1, 2)
        values = values.transpose(1, 2)

        # Compute Attention scores
        attn_scores = queries @ keys.transpose(2, 3)

        # mask future tokens
        mask_bool = self.mask.bool()[:num_tokens, :num_tokens]
        attn_scores.masked_fill(mask_bool, -torch.inf)

        # Compute Attention Weights
        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # Compute Context Vector Matrix
        context_vec = (attn_weights @ values).transpose(1, 2)
        context_vec = context_vec.contiguous().view(batch_size, num_tokens, self.d_out)
        context_vec = self.out_proj(context_vec)

        return context_vec

## Transformer Block

In [None]:
class TransformerBlock(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.mask_attn = MultiHeadAttention(
            d_in=cfg["embedding_dim"],
            d_out=cfg["embedding_dim"],
            context_length=cfg["context_length"],
            dropout=cfg["drop_rate"],
            num_heads=cfg["num_heads"],
            qkv_bias=cfg["qkv_bias"],
        )
        self.ffn_block = FeedForwardNNBlock(cfg)
        self.norm_1 = LayerNormalization(cfg["embedding_dim"])
        self.norm_2 = LayerNormalization(cfg["embedding_dim"])
        self.drop_shortcut = nn.Dropout(cfg["drop_rate"])

    def forward(self, x):
        # Block 1
        # residual connection for attention block
        shortcut = x
        x = self.norm_1(x)
        x = self.mask_attn(x)
        x = self.drop_shortcut(x)
        x = x + shortcut

        # Block 2
        shortcut = x
        x = self.norm_2(x)
        x = self.ffn_block(x)
        x = self.drop_shortcut(x)
        x = x + shortcut

        return x

In [None]:
class GPT2Model(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.tok_embeddings = nn.Embedding(cfg["vocab_size"], cfg["embedding_dim"])
        self.pos_embeddings = nn.Embedding(cfg["context_length"], cfg["embedding_dim"])
        self.drop_embeddings = nn.Dropout(cfg["drop_rate"])

        self.transformer_blocks = nn.Sequential(
            *[TransformerBlock(cfg) for _ in range(cfg["num_layers"])]
        )

        self.final_norm = LayerNormalization(cfg["embedding_dim"])
        self.out_head = nn.Linear(
            cfg["embedding_dim"], cfg["vocab_size"], bias=False
        )

    def forward(self, in_idx):
        batch_size, seq_len = in_idx.shape
        tok_embeddings = self.tok_embeddings(in_idx)
        pos_embeddings = self.pos_embeddings(torch.arange(seq_len, device=in_idx.device))

        x = tok_embeddings + pos_embeddings
        x = self.drop_embeddings(x)
        x = self.transformer_blocks(x)
        x = self.final_norm(x)
        logits = self.out_head(x)
        return logits

## Model Initialization

In [None]:
torch.manual_seed(47)
model = GPT2Model(GPT2_CONFIG_124M)

In [None]:
total_params = sum(p.numel() for p in model.parameters())
print(f"total number of parameters in the model: {total_params:,}")

In [None]:
gpt_2_model_params = total_params - sum(p.numel() for p in model.out_head.parameters())
print(f"Total Architecture trainable parameters without output head weights: {gpt_2_model_params:,}")

## Logits to output tokens.

In [None]:
def generate_text_v1(model, idx, max_new_tokens, context_size):
    """Get last row from logits for each bach. fetch token with max value. 
        Append to input and repeat.
        Optional: Convert token id to text and display the generated text.
    """
    for _ in range(max_new_tokens):

        # 1. truncate input if larger than context size
        idx_cond = idx[:, -context_size:]

        # 2. Get the predictions
        with torch.no_grad():
            logits = model(idx_cond)

        # 3. Retrive only the last row from each batch
        logits = logits[:, -1, :]

        # 4. Applying softmax to logits
        probas = torch.softmax(logits, dim=-1)

        # 5. Get index of the vocab entry with the highest probability
        idx_next = torch.argmax(probas, dim=-1, keepdim=True)

        # 6. Append Retrived token id to original input
        idx = torch.cat((idx, idx_next), dim=1)

    return idx

In [None]:
print(f"Token Embedding Shape: {model.tok_embeddings.weight.shape}")
print(f"Output layer shape: {model.out_head.weight.shape}")

## Creating DataLoaders

In [None]:
# Creating Dataset for training.
import tiktoken
from torch.utils.data import Dataset, DataLoader

class BookCorpusDataset(Dataset):
    def __init__(self, book_corpus, tokenizer, context_length, stride):
        self.input_ids = []
        self.target_ids = []

        # Joining iterable dict into a string
        print("Tokenizing and Chunking data...")
        all_text = [
            content['text'] for content in tqdm(book_corpus, desc="📚 Reading examples")
        ]
        full_text = " ".join(all_text)

        # Tokenize the entire text
        token_ids = tokenizer.encode(full_text, allowed_special={"<|endoftext|>"})
        total_tokens = len(token_ids)

        # sliding window approach to chunk the text as input and output tokens of context_size
        for i in range(0, len(token_ids) - context_length, stride):
            input_chunk = token_ids[i : i+context_length]
            target_chunk = token_ids[i+1 : i+context_length+1]
            self.input_ids.append(torch.tensor(input_chunk))
            self.target_ids.append(torch.tensor(target_chunk))

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, idx):
        return self.input_ids[idx], self.target_ids[idx]

def create_dataloader(book_corpus, batch_size=4, context_length=512,
                      stride=512, shuffle=True, drop_last=True, num_workers=0):

    try:
        # initializing tokenizer
        tokenizer = tiktoken.get_encoding("gpt2")

        # creating dataset
        dataset = BookCorpusDataset(book_corpus, tokenizer, context_length, stride)

        # create dataloader
        dataloader = DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            drop_last=drop_last,
            pin_memory=True,
            num_workers=num_workers
        )

        return dataloader
    except Exception as e:
        print(e)
    
    

In [None]:
# Info about dataset
train_ratio = 0.75
test_ratio = 0.25
total_dataset_size = len(book_corpus)
split_dataset = book_corpus.train_test_split(test_size=test_ratio, seed=47)
# splitting training set to train and validate
train_data = split_dataset['train']
validation_data = split_dataset['test']
print("Size of Train set: ", len(train_data))
print("Size of Validation set: ", len(validation_data))

In [None]:
train_loader = create_dataloader(
    train_data,
    batch_size=4,
    context_length=GPT2_CONFIG_124M["context_length"],
    stride=GPT2_CONFIG_124M["context_length"],
    drop_last=True,
    shuffle=True,
    num_workers=0,
)

val_loader = create_dataloader(
    validation_data,
    batch_size=4,
    context_length=GPT2_CONFIG_124M["context_length"],
    stride=GPT2_CONFIG_124M["context_length"],
    drop_last=False,
    shuffle=False,
    num_workers=0,
)

## Sanity checks before Training

In [24]:
sample_tokenizer = tiktoken.get_encoding("gpt2")
total_tokens = sample_tokenizer.encode(book_corpus['text'][:])
print("Total_tokens : ", total_tokens)

TypeError: expected string or buffer

In [None]:
# Sanity Checks
if total_tokens * (train_ratio) < GPT2_CONFIG_124M["context_length"]:
    print("Not enough tokens for training loader. Try to lower GPT2_CONFIG_124M['context_length']")

if total_tokens * (1 - train_ratio) < GPT2_CONFIG_124M["context_length"]:
    print("Not enough tokens for validation set.")

## Examining Input and target matrices.

In [25]:
print("Train loader: ")
for x, y in train_loader:
    print(x.shape, y.shape)

print("Validation loader: ")
for x, y in val_loader:
    print(x.shape, y.shape)

Train loader: 
torch.Size([4, 512]) torch.Size([4, 512])
torch.Size([4, 512]) torch.Size([4, 512])
torch.Size([4, 512]) torch.Size([4, 512])
torch.Size([4, 512]) torch.Size([4, 512])
torch.Size([4, 512]) torch.Size([4, 512])
torch.Size([4, 512]) torch.Size([4, 512])
torch.Size([4, 512]) torch.Size([4, 512])
torch.Size([4, 512]) torch.Size([4, 512])
torch.Size([4, 512]) torch.Size([4, 512])
torch.Size([4, 512]) torch.Size([4, 512])
torch.Size([4, 512]) torch.Size([4, 512])
torch.Size([4, 512]) torch.Size([4, 512])
torch.Size([4, 512]) torch.Size([4, 512])
torch.Size([4, 512]) torch.Size([4, 512])
torch.Size([4, 512]) torch.Size([4, 512])
torch.Size([4, 512]) torch.Size([4, 512])
torch.Size([4, 512]) torch.Size([4, 512])
torch.Size([4, 512]) torch.Size([4, 512])
torch.Size([4, 512]) torch.Size([4, 512])
torch.Size([4, 512]) torch.Size([4, 512])
torch.Size([4, 512]) torch.Size([4, 512])
torch.Size([4, 512]) torch.Size([4, 512])
torch.Size([4, 512]) torch.Size([4, 512])
torch.Size([4, 512]

## Loss function for model evaluation

In [None]:
def calculate_loss_batch(input_batch, target_batch, model, device):
    input_batch, target_batch = input_batch.to(device), target_batch.to(device)
    logits = model(input_batch)
    loss = torch.nn.functional.cross_entropy(logits.flatten(0, 1), target_batch.flatten())
    return loss

def calculate_loss_loader(data_loader, model, device, num_batches=None):
    total_loss = 0.
    if len(data_loader) == 0:
        return float("nan")
    elif num_batches is None:
        num_batches = len(data_loader)
    else:
        # reduce number of batches to match number of batches in the data loader
        num_batches = min(num_batches, len(data_loader))

    for i, (input_batch, target_batch) in enumerate(data_loader):
        if i < num_batches:
            loss = calculate_loss_batch(input_batch, target_batch, model, device)
            total_loss += loss.item()
            
        else:
            break
    return total_loss / num_batches

## Model Evaluation

In [None]:
def evaluate_model(model, train_loader, val_loader, device, eval_iter):
    # disable Dropout layer
    model.eval()

    # Get predictions
    with torch.no_grad():
        train_loss = calculate_loss_loader(train_loader, model, device, num_batches=eval_iter)
        val_loss = calculate_loss_loader(val_loader, model, device, num_batches=eval_iter)
    model.train()
    return train_loss, val_loss

## Generate and Print Sample

In [None]:
def text_to_token_ids(text, tokenizer):
    encoded = tokenizer.encode(text, allowed_special={"<|endoftext|>"})
    encoded_tensor = torch.tensor(encoded).unsqueeze(0)
    return encoded_tensor

def token_ids_to_text(token_ids, tokenizer):
    flat = token_ids.squeeze(0)
    return tokenizer.decode(flat.tolist())

In [25]:
def generate_and_print_sample(model, tokenizer, device, start_context):
    model.eval()
    context_size = model.pos_embeddings.weight.shape[0]
    encoded = text_to_token_ids(start_context, tokenizer).to(device)
    with torch.no_grad():
        token_ids = generate_text(
            model=model, idx=encoded, max_new_tokens=50, context_size=context_size
        )

    decoded_text = token_ids_to_text(token_ids, tokenizer)
    print(decoded_text.replace("\n", " "))
    model.train()

## Training Loop

In [26]:
def train_model_v1(model, train_loader, val_loader, optimizer,
          device, num_epochs, eval_freq, eval_iter,
          start_context, tokenizer):

    # initlization of lists to track losses and token seen
    train_losses = val_losses = track_tokens_seen = []
    tokens_seen, global_step = 0, -1

    # Main loop
    # For each epoch
    for epoch in range(num_epochs):
        model.train()

        # For each batch in epoch
        for input_batch, target_batch in train_loader:
            # Reset loss gradients from previous batch
            optimizer.zero_grad()

            # calculate loss
            loss = calculate_loss_batch(input_batch, target_batch, model, device)

            # compute gradient loss
            loss.backward()

            # update weights
            optimizer.step()

            # update tokens seen at step
            tokens_seen += input_batch.numel()

            # update epoch count
            global_step += 1

            if global_step % eval_freq == 0:
                train_loss, val_loss = evaluate_model(
                    model, train_loader, val_loader, device, eval_iter
                )

                train_losses.append(train_loss)
                val_losses.append(val_loss)
                track_tokens_seen.append(tokens_seen)
                print(f"Epoch: {epoch+1} (Step {global_step:06d}): "
                      f"Train loss: {train_loss:.3f}, Val loss: {val_loss:.3f}")

        # generate output from tokens for visualization after each epoch
        generate_and_print_sample(
            model, tokenizer, device, start_context
        )
    return train_losses, val_losses, track_tokens_seen

In [27]:
def train_model_v2(model, train_loader, val_loader, optimizer,
          scheduler, device, num_epochs, eval_freq, eval_iter,
          start_context, tokenizer):

    # initlization of lists to track losses and token seen
    train_losses = val_losses = track_tokens_seen = []
    tokens_seen, global_step = 0, -1

    # Main loop
    # For each epoch
    for epoch in range(num_epochs):
        model.train()

        total_train_loss = 0

        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}", leave=False)

        # For each batch in epoch
        for input_batch, target_batch in progress_bar:
            # Reset loss gradients from previous batch
            optimizer.zero_grad()

            # calculate loss
            loss = calculate_loss_batch(input_batch, target_batch, model, device)

            # compute gradient loss
            loss.backward()

            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

            # update weights
            optimizer.step()
            scheduler.step()

            total_train_loss += loss.item()
            # update tokens seen at step
            tokens_seen += input_batch.numel()

            # update step count
            # global_step += 1
            progress_bar.set_postfix({"loss": f"{loss.item():.3f}"})

        # Evaluation after each epoch
        model.eval()
        with torch.no_grad():
            val_loss = calculate_loss_loader(val_loader, model, device)

        # avg training loss for each epoch
        avg_train_loss = total_train_loss / len(train_loader)

        train_losses.append(avg_train_loss)
        val_losses.append(val_loss)

        print(f"Epoch {epoch+1:02d} | Avg Train Loss: {avg_train_loss:.3f} | Val loss: {val_loss:.3f}")

        # generate output from tokens for visualization after each epoch
        generate_and_print_sample(
            model, tokenizer, device, start_context
        )
    return train_losses, val_losses, track_tokens_seen

In [None]:
if torch.cuda.is_available():
   device = torch.device("cuda")
elif torch.backends.mps.is_available():
   device = torch.device("mps")
else:
   device = torch.device("cpu")

print(f"Using {device} device.")

In [41]:
import time
start_time = time.time()

tokenizer = tiktoken.get_encoding("gpt2")
model.to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=0.0004, weight_decay=0.1)

num_epochs = 5

total_training_steps = num_epochs * len(train_loader)
warmup_steps = int(total_training_steps * 0.1)

scheduler = get_cosine_schedule_with_warmup(
    optimizer,
    num_warmup_steps=warmup_steps,
    num_training_steps=total_training_steps,
)

train_losses, val_losses, tokens_seen = train_model_v2(
    model, train_loader, val_loader, optimizer, scheduler, 
    device, num_epochs=num_epochs, eval_freq=5, eval_iter=5,
    start_context="he 'd seen the movie almost by mistake , considering he was a little young for the pg cartoon , but with older cousins ", tokenizer=tokenizer
)

end_time = time.time()
execution_time = (end_time - start_time) / 60
print(f"Training completed in {execution_time:.2f} minutes.")

Epoch 1/5:   0%|          | 0/572 [00:00<?, ?it/s]

Epoch 01 | Avg Train Loss: 2.919 | Val loss: 0.481
Every effort moves you you you you you you you you you you you you you you you you you you you you you you you you you you you you you you you you you you you you you you you you you you you you you you you you you you you


Epoch 2/5:   0%|          | 0/572 [00:00<?, ?it/s]

Epoch 02 | Avg Train Loss: 0.271 | Val loss: 0.135
Every effort moves you you you you you you you you you you you you you you you you you you you you you you you you you you you you you you you you you you you you you you you you you you you you you you you you you you you


Epoch 3/5:   0%|          | 0/572 [00:00<?, ?it/s]

Epoch 03 | Avg Train Loss: 0.072 | Val loss: 0.092
Every effort moves you big you you you you you you you you you you you you you you you you you you you you you you you you you you you you you you you you you you you you you you you you you you you you you you you you you


Epoch 4/5:   0%|          | 0/572 [00:00<?, ?it/s]

Epoch 04 | Avg Train Loss: 0.032 | Val loss: 0.084
Every effort moves you moves you you you you you you you you you you you you you you you you you you you you you you you you you you you you you you you you you you you you you you you you you you you you you you you you you


Epoch 5/5:   0%|          | 0/572 [00:00<?, ?it/s]

Epoch 05 | Avg Train Loss: 0.021 | Val loss: 0.084
Every effort moves you moves you you you you you you you you you you you you you you you you you you you you you you you you you you you you you you you you you you you you you you you you you you you you you you you you you
Training completed in 17.99 minutes.


## Temperature Scaling + Selecting Top-k Logits for Output Tokens 

In [None]:
def generate_text(model, idx, max_new_tokens, context_size, temperature=0.7, top_k=5, eos_id=None):

    # 1. Get Logits.
    for _ in range(max_new_tokens):
        idx_cond = idx[:, -context_size:]
        with torch.no_grad():
            logits = model(idx_cond)
        logits = logits[:, -1, :]

        # 2. Select Top_k elements
        if top_k is not None:
            top_logits, idx_numbers = torch.topk(logits, top_k)
            # 3. Set all other logits except top k to -inf
            min_val = top_logits[:, -1]
            logits = torch.where(logits < min_val, torch.tensor(float("-inf")).to(logits.device), logits)

        # 3. Scale by temperature value
        if temperature > 0.0:
            logits = logits / temperature

            # 4. Apply softmax
            probs = torch.softmax(logits, dim=-1)

            # 5. Sample from Multinomial distribution
            idx_next = torch.multinomial(probs, num_samples=1)

        else:
            idx_next = torch.argmax(logits, dim=-1, keepdim=True)

        if idx_next == eos_id:
            break

        idx = torch.cat((idx, idx_next), dim=1)

    return idx

In [None]:
def train_model_v3(model, train_loader, val_loader, optimizer,
          scheduler, device, num_epochs, eval_freq, eval_iter,
          start_context, tokenizer):

    # initlization of lists to track losses and token seen
    train_losses = val_losses = track_tokens_seen = []
    tokens_seen = 0

    # Main loop
    # For each epoch
    for epoch in range(num_epochs):
        model.train()
        total_train_loss = 0
        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}", leave=False)

        # For each batch in epoch
        for input_batch, target_batch in progress_bar:
            # Reset loss gradients from previous batch
            optimizer.zero_grad()

            # calculate loss
            loss = calculate_loss_batch(input_batch, target_batch, model, device)

            # compute gradient loss
            loss.backward()

            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

            # update weights
            optimizer.step()
            scheduler.step()

            total_train_loss += loss.item()
            # update tokens seen at step
            tokens_seen += input_batch.numel()

            progress_bar.set_postfix({"loss": f"{loss.item():.3f}", "lr": f"{scheduler.get_last_lr()[0]:.1e}"})

        # Evaluation after each epoch
        model.eval()
        with torch.no_grad():
            val_loss = calculate_loss_loader(val_loader, model, device)

        # avg training loss for each epoch
        avg_train_loss = total_train_loss / len(train_loader)

        train_losses.append(avg_train_loss)
        val_losses.append(val_loss)

        print(f"Epoch {epoch+1:02d} | Avg Train Loss: {avg_train_loss:.3f} | Val loss: {val_loss:.3f}")

        # generate output from tokens for visualization after each epoch
        start_ids = text_to_token_ids(start_context, tokenizer).to(device)
        context_size = model.pos_embeddings.weight.shape[0]

        with torch.no_grad():
            output_ids = generate_text(
                model=model,
                idx=start_ids,
                max_new_tokens=50,
                context_size=context_size,
                top_k=50,
            )

        generated_text = token_ids_to_text(output_ids, tokenizer)
        print(f"Sample: {generated_text.replace(chr(10), ' ')}")
    return train_losses, val_losses, track_tokens_seen

In [None]:
import time
start_time = time.time()

tokenizer = tiktoken.get_encoding("gpt2")
model.to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=0.0004, weight_decay=0.1)

num_epochs = 30

total_training_steps = num_epochs * len(train_loader)
warmup_steps = int(total_training_steps * 0.1)

scheduler = get_cosine_schedule_with_warmup(
    optimizer,
    num_warmup_steps=warmup_steps,
    num_training_steps=total_training_steps,
)

train_losses, val_losses, tokens_seen = train_model_v3(
    model, train_loader, val_loader, optimizer, scheduler, 
    device, num_epochs=num_epochs, eval_freq=5, eval_iter=5,
    start_context="Every effort moves you ", tokenizer=tokenizer
)

end_time = time.time()
execution_time = (end_time - start_time) / 60
print(f"Training completed in {execution_time:.2f} minutes.")

## Saving Trained Model Weights

In [None]:
torch.save({
    "model_state_dict": model.state_dict(),
    "optimizer_state_dict": optimizer.state_dict()
    },
    "pre-trained_llm_and_optimizer.pth")
