In [1]:
import tiktoken

In [2]:
# Load the GPT-2 tokenizer and round-trip a sample string through
# encode/decode as a quick sanity check.
tokenizer = tiktoken.get_encoding("gpt2")
text = "Hello, how are you?"
print(tokenizer.encode(text))  # list of integer token ids
print(tokenizer.decode(tokenizer.encode(text)))  # ids decoded back to text
# Plain-string default; later cells rebind `device` to a torch.device chosen
# with torch.cuda.is_available().
device = "cuda"  # Use "cpu" if CUDA is not available

[15496, 11, 703, 389, 345, 30]
Hello, how are you?


In [3]:
import torch
from torch.utils.data import Dataset, DataLoader

class GPTDatasetV1(Dataset):
    """Sliding-window dataset of (input, target) token-id chunks.

    The complete text is tokenized once and cut into windows of
    ``max_length`` tokens whose start positions are ``stride`` tokens apart.
    Every target chunk is its input chunk shifted one token to the right.

    Args:
        txt: The full text to tokenize.
        tokenizer: Object exposing ``encode(txt, allowed_special=...)``.
        max_length: Number of tokens in each chunk.
        stride: Number of tokens the window advances between chunks.
    """

    def __init__(self, txt, tokenizer, max_length, stride):
        # Tokenize the whole text in a single pass.
        token_ids = tokenizer.encode(txt, allowed_special={"<|endoftext|>"})
        # At least max_length + 1 tokens are required so that the first
        # window still has a shifted-by-one target.
        assert len(token_ids) > max_length

        # Start offsets of every window; a range can be iterated repeatedly.
        window_starts = range(0, len(token_ids) - max_length, stride)
        self.input_ids = [
            torch.tensor(token_ids[start:start + max_length])
            for start in window_starts
        ]
        self.target_ids = [
            torch.tensor(token_ids[start + 1:start + max_length + 1])
            for start in window_starts
        ]

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, idx):
        return self.input_ids[idx], self.target_ids[idx]

In [4]:
def create_dataloader_v1(txt, batch_size=4, max_length=256,
                         stride=128, shuffle=True, drop_last=True,
                         num_workers=0):
    """Build a DataLoader of sliding-window (input, target) chunks from raw text.

    Args:
        txt: The text to tokenize with the GPT-2 encoding.
        batch_size: Number of chunks per batch.
        max_length: Number of tokens per chunk.
        stride: Offset between the starts of consecutive chunks.
        shuffle: Forwarded to ``DataLoader``.
        drop_last: Forwarded to ``DataLoader``.
        num_workers: Forwarded to ``DataLoader``.

    Returns:
        A ``DataLoader`` wrapping a ``GPTDatasetV1``.
    """
    gpt2_tokenizer = tiktoken.get_encoding("gpt2")
    windows = GPTDatasetV1(txt, gpt2_tokenizer, max_length, stride)

    loader_options = dict(
        batch_size=batch_size,
        shuffle=shuffle,
        drop_last=drop_last,
        num_workers=num_workers,
    )
    return DataLoader(windows, **loader_options)

In [5]:
# Sizes of the two embedding tables created below.
vocab_size, output_dim, context_length = 50257, 256, 1024

# One vector of size `output_dim` per token id ...
token_embedding_layer = torch.nn.Embedding(
    num_embeddings=vocab_size, embedding_dim=output_dim
)
# ... and one per position index below `context_length`.
pos_embedding_layer = torch.nn.Embedding(
    num_embeddings=context_length, embedding_dim=output_dim
)

In [6]:
import torch.nn as nn
class CausalAttention(nn.Module):
    """Single-head scaled dot-product self-attention with a causal mask.

    Args:
        d_in: Size of the last dimension of the input.
        d_out: Size of the query/key/value projections and of the output.
        context_length: Side length of the pre-built causal mask.
        dropout: Dropout probability applied to the attention weights.
        qkv_bias: Whether the three projections carry a bias term.
    """

    def __init__(self, d_in, d_out, context_length,
                 dropout, qkv_bias=False):
        super().__init__()
        self.d_out = d_out
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key   = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.dropout = nn.Dropout(dropout)
        # Ones strictly above the diagonal mark the "future" positions.
        future_positions = torch.ones(context_length, context_length).triu(diagonal=1)
        self.register_buffer('mask', future_positions)

    def forward(self, x):
        """Return context vectors of shape (batch, num_tokens, d_out) for a 3-D input."""
        _, seq_len, _ = x.shape
        k = self.W_key(x)
        q = self.W_query(x)
        v = self.W_value(x)

        scores = torch.matmul(q, k.transpose(1, 2))
        # The mask is cropped to the actual sequence length, which may be
        # shorter than the context length it was built for.
        hide = self.mask.bool()[:seq_len, :seq_len]
        scores.masked_fill_(hide, -torch.inf)  # in-place

        scaled_scores = scores / k.shape[-1] ** 0.5
        weights = self.dropout(torch.softmax(scaled_scores, dim=-1))
        return weights @ v

In [7]:
class MultiHeadAttention(nn.Module):
    """Multi-head causal self-attention with a final output projection.

    The query/key/value projections have width ``d_out`` and are split into
    ``num_heads`` heads of ``d_out // num_heads`` features each.

    Args:
        d_in: Size of the last dimension of the input.
        d_out: Total width of the projections and of the output.
        context_length: Side length of the pre-built causal mask.
        dropout: Dropout probability applied to the attention weights.
        num_heads: Number of heads; must divide ``d_out``.
        qkv_bias: Whether the query/key/value projections carry a bias.
    """

    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        assert (d_out % num_heads == 0), \
            "d_out must be divisible by num_heads"

        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads  # per-head feature width

        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.out_proj = nn.Linear(d_out, d_out)  # mixes the concatenated heads
        self.dropout = nn.Dropout(dropout)
        # Ones strictly above the diagonal mark the "future" positions.
        future_positions = torch.ones(context_length, context_length).triu(diagonal=1)
        self.register_buffer("mask", future_positions)

    def _split_heads(self, projected, batch, seq_len):
        """Reshape (batch, seq, d_out) into (batch, num_heads, seq, head_dim)."""
        per_head = projected.view(batch, seq_len, self.num_heads, self.head_dim)
        return per_head.transpose(1, 2)

    def forward(self, x):
        """Return context vectors of shape (batch, num_tokens, d_out) for a 3-D input."""
        batch, seq_len, _ = x.shape

        k = self._split_heads(self.W_key(x), batch, seq_len)
        q = self._split_heads(self.W_query(x), batch, seq_len)
        v = self._split_heads(self.W_value(x), batch, seq_len)

        # Per-head dot products between every query and every key.
        scores = q @ k.transpose(2, 3)

        # Crop the mask to the actual sequence length and hide future tokens.
        hide = self.mask.bool()[:seq_len, :seq_len]
        scores.masked_fill_(hide, -torch.inf)

        weights = torch.softmax(scores / k.shape[-1] ** 0.5, dim=-1)
        weights = self.dropout(weights)

        # (batch, heads, seq, head_dim) -> (batch, seq, heads, head_dim) -> (batch, seq, d_out)
        per_head_context = (weights @ v).transpose(1, 2)
        merged = per_head_context.contiguous().view(batch, seq_len, self.d_out)
        return self.out_proj(merged)


In [8]:
# Hyperparameters consumed by GPTModel and its sub-modules.
GPT_CONFIG_124M = dict(
    vocab_size=50257,     # Vocabulary size
    context_length=1024,  # Context length
    emb_dim=768,          # Embedding dimension
    n_heads=12,           # Number of attention heads
    n_layers=12,          # Number of layers
    drop_rate=0.1,        # Dropout rate
    qkv_bias=False,       # Query-Key-Value bias
)

In [9]:
class LayerNorm(nn.Module):
    """Normalize the last dimension to zero mean / unit variance, then scale and shift.

    Args:
        emb_dim: Size of the last dimension; also the length of the
            learnable ``scale`` and ``shift`` vectors.
    """

    def __init__(self, emb_dim):
        super().__init__()
        self.eps = 1e-5  # added to the variance before the square root
        self.scale = nn.Parameter(torch.ones(emb_dim))
        self.shift = nn.Parameter(torch.zeros(emb_dim))

    def forward(self, x):
        centered = x - x.mean(dim=-1, keepdim=True)
        # Population variance (no Bessel correction).
        variance = x.var(dim=-1, keepdim=True, unbiased=False)
        normalized = centered / torch.sqrt(variance + self.eps)
        return self.scale * normalized + self.shift

In [10]:
class GELU(nn.Module):
    """GELU activation computed with the tanh approximation."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        sqrt_two_over_pi = torch.sqrt(torch.tensor(2.0 / torch.pi))
        tanh_arg = sqrt_two_over_pi * (x + 0.044715 * torch.pow(x, 3))
        return 0.5 * x * (1 + torch.tanh(tanh_arg))


In [11]:
class FeedForward(nn.Module):
    """Position-wise feed-forward block: expand to 4x width, GELU, project back.

    Args:
        cfg: Mapping providing ``"emb_dim"``.
    """

    def __init__(self, cfg):
        super().__init__()
        emb_dim = cfg["emb_dim"]
        hidden_dim = 4 * emb_dim
        self.layers = nn.Sequential(
            nn.Linear(emb_dim, hidden_dim),
            GELU(),
            nn.Linear(hidden_dim, emb_dim),
        )

    def forward(self, x):
        return self.layers(x)

In [12]:
class TransformerBlock(nn.Module):
    """Pre-norm transformer block: attention and feed-forward, each with a shortcut.

    Args:
        cfg: Mapping providing ``"emb_dim"``, ``"context_length"``,
            ``"n_heads"``, ``"drop_rate"`` and ``"qkv_bias"``.
    """

    def __init__(self, cfg):
        super().__init__()
        emb_dim = cfg["emb_dim"]
        self.att = MultiHeadAttention(
            d_in=emb_dim,
            d_out=emb_dim,
            context_length=cfg["context_length"],
            num_heads=cfg["n_heads"],
            dropout=cfg["drop_rate"],
            qkv_bias=cfg["qkv_bias"])
        self.ff = FeedForward(cfg)
        self.norm1 = LayerNorm(emb_dim)
        self.norm2 = LayerNorm(emb_dim)
        self.drop_shortcut = nn.Dropout(cfg["drop_rate"])

    def forward(self, x):
        # Attention sub-layer: normalize, attend, drop, add the input back.
        attended = self.drop_shortcut(self.att(self.norm1(x)))
        x = attended + x

        # Feed-forward sub-layer with the same normalize/drop/add pattern.
        transformed = self.drop_shortcut(self.ff(self.norm2(x)))
        return transformed + x

In [13]:
class GPTModel(nn.Module):
    """GPT-style decoder: token + position embeddings, transformer blocks, LM head.

    Args:
        cfg: Mapping providing ``"vocab_size"``, ``"context_length"``,
            ``"emb_dim"``, ``"drop_rate"``, ``"n_layers"`` and the keys
            required by ``TransformerBlock``.
    """

    def __init__(self, cfg):
        super().__init__()
        emb_dim = cfg["emb_dim"]
        self.tok_emb = nn.Embedding(cfg["vocab_size"], emb_dim)
        self.pos_emb = nn.Embedding(cfg["context_length"], emb_dim)
        self.drop_emb = nn.Dropout(cfg["drop_rate"])

        blocks = [TransformerBlock(cfg) for _ in range(cfg["n_layers"])]
        self.trf_blocks = nn.Sequential(*blocks)

        self.final_norm = LayerNorm(emb_dim)
        self.out_head = nn.Linear(emb_dim, cfg["vocab_size"], bias=False)

    def forward(self, in_idx):
        """Map token ids of shape (batch, seq_len) to logits of shape (batch, seq_len, vocab_size)."""
        _, seq_len = in_idx.shape
        positions = torch.arange(seq_len, device=in_idx.device)
        # Position embeddings broadcast over the batch dimension.
        hidden = self.tok_emb(in_idx) + self.pos_emb(positions)
        hidden = self.drop_emb(hidden)
        hidden = self.trf_blocks(hidden)
        hidden = self.final_norm(hidden)
        return self.out_head(hidden)

In [14]:
def generate(model, idx, max_new_tokens, context_size, temperature=0.0, top_k=None, eos_id=None):
    """Autoregressively extend ``idx`` by up to ``max_new_tokens`` tokens.

    Works for any batch size. With a single sequence the behaviour matches
    the previous implementation exactly.

    Args:
        model: Callable mapping token ids (batch, seq) to logits
            (batch, seq, vocab_size).
        idx: Integer tensor (batch, seq) holding the prompt token ids.
        max_new_tokens: Maximum number of tokens to append.
        context_size: Only the last ``context_size`` tokens are fed to the model.
        temperature: ``> 0`` samples from the temperature-scaled softmax;
            ``0`` (default) picks the arg-max token.
        top_k: If given, only the ``top_k`` highest logits per row stay
            eligible (values larger than the vocabulary are clamped).
        eos_id: If given, a row stops once it produces this token. Generation
            ends, without appending the final token, as soon as every row has
            stopped; rows that stopped earlier are padded with ``eos_id``.

    Returns:
        Tensor (batch, seq + n_generated) with the prompt followed by the
        generated token ids.
    """
    # Rows that have already emitted eos_id (only consulted when eos_id is set).
    finished = torch.zeros(idx.shape[0], dtype=torch.bool, device=idx.device)

    for _ in range(max_new_tokens):
        # Crop to the supported context and keep only the last time step.
        idx_cond = idx[:, -context_size:]
        with torch.no_grad():
            logits = model(idx_cond)
        logits = logits[:, -1, :]

        if top_k is not None:
            k = min(top_k, logits.size(-1))
            top_logits, _ = torch.topk(logits, k)
            # Keep the trailing dimension so the per-row threshold broadcasts
            # as (batch, 1) against (batch, vocab) for any batch size.
            min_val = top_logits[:, -1:]
            logits = logits.masked_fill(logits < min_val, float("-inf"))

        if temperature > 0.0:
            logits = logits / temperature

            # Subtract the row-wise max before softmax for numerical stability.
            logits = logits - logits.max(dim=-1, keepdim=True).values

            probs = torch.softmax(logits, dim=-1)  # (batch_size, vocab_size)
            idx_next = torch.multinomial(probs, num_samples=1)  # (batch_size, 1)
        else:
            idx_next = torch.argmax(logits, dim=-1, keepdim=True)  # (batch_size, 1)

        if eos_id is not None:
            # Rows that already ended keep emitting eos_id as padding so the
            # batch stays rectangular.
            idx_next = torch.where(
                finished.unsqueeze(1), torch.full_like(idx_next, eos_id), idx_next
            )
            finished = finished | (idx_next.squeeze(1) == eos_id)
            if finished.all():
                break

        idx = torch.cat((idx, idx_next), dim=1)  # (batch_size, num_tokens+1)

    return idx

In [15]:
# Seed before building the model so the randomly initialised weights, and
# therefore the text sampled from this untrained model, are the same on every
# Restart & Run All. 123 is the seed the later generation cells use.
torch.manual_seed(123)
model = GPTModel(GPT_CONFIG_124M)
# Evaluation mode disables the dropout layers.
model.eval()

GPTModel(
  (tok_emb): Embedding(50257, 768)
  (pos_emb): Embedding(1024, 768)
  (drop_emb): Dropout(p=0.1, inplace=False)
  (trf_blocks): Sequential(
    (0): TransformerBlock(
      (att): MultiHeadAttention(
        (W_query): Linear(in_features=768, out_features=768, bias=False)
        (W_key): Linear(in_features=768, out_features=768, bias=False)
        (W_value): Linear(in_features=768, out_features=768, bias=False)
        (out_proj): Linear(in_features=768, out_features=768, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (ff): FeedForward(
        (layers): Sequential(
          (0): Linear(in_features=768, out_features=3072, bias=True)
          (1): GELU()
          (2): Linear(in_features=3072, out_features=768, bias=True)
        )
      )
      (norm1): LayerNorm()
      (norm2): LayerNorm()
      (drop_shortcut): Dropout(p=0.1, inplace=False)
    )
    (1): TransformerBlock(
      (att): MultiHeadAttention(
        (W_query): Linear(in_feature

In [16]:
def text_to_token_ids(text, tokenizer):
    """Encode ``text`` and return the ids as a tensor with a leading batch dimension of 1."""
    ids = tokenizer.encode(text, allowed_special={'<|endoftext|>'})
    return torch.tensor(ids).unsqueeze(0)

def token_ids_to_text(token_ids, tokenizer):
    """Drop the leading batch dimension of ``token_ids`` and decode the ids to text."""
    return tokenizer.decode(token_ids.squeeze(0).tolist())

start_context = "Every effort moves you"
tokenizer = tiktoken.get_encoding("gpt2")
inference_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(inference_device)

# Seed right before sampling, as the later generation cells do, so re-running
# this cell with the same model yields the same continuation.
torch.manual_seed(123)

token_ids = generate(
    model=model,
    # Reuse start_context instead of repeating the prompt literal.
    idx=text_to_token_ids(start_context, tokenizer).to(inference_device),
    max_new_tokens=15,
    context_size=GPT_CONFIG_124M["context_length"],
    top_k=25,
    temperature=1.4
)
print("Output text:\n", token_ids_to_text(token_ids, tokenizer))

Output text:
 Every effort moves youaffiliated harbour Downtro423 isolationape156 occur criticizingdomain Aircraft glossyikingoler


In [17]:
def calc_loss_batch(input_batch, target_batch, model, device):
    """Return the cross-entropy loss of ``model`` on one (input, target) batch.

    Both batches are moved to ``device`` first. Logits are flattened over
    their first two dimensions and the targets are flattened completely
    before the loss is computed.
    """
    inputs = input_batch.to(device)
    targets = target_batch.to(device)
    predictions = model(inputs)
    return torch.nn.functional.cross_entropy(
        predictions.flatten(0, 1), targets.flatten()
    )


def calc_loss_loader(data_loader, model, device, num_batches=None):
    """Average ``calc_loss_batch`` over the first batches of ``data_loader``.

    Args:
        data_loader: Sized iterable yielding ``(input_batch, target_batch)`` pairs.
        model: Model forwarded to ``calc_loss_batch``.
        device: Device forwarded to ``calc_loss_batch``.
        num_batches: Number of leading batches to evaluate. ``None`` uses all
            batches; larger values are capped at ``len(data_loader)``.

    Returns:
        The mean loss as a float, or ``nan`` when there is nothing to average
        (empty loader, or ``num_batches`` <= 0).
    """
    if len(data_loader) == 0:
        return float("nan")

    if num_batches is None:
        num_batches = len(data_loader)
    else:
        # Never ask for more batches than the loader can provide.
        num_batches = min(num_batches, len(data_loader))

    # Zero or negative requests used to end in a division by zero (or a
    # meaningless -0.0); report "no data" the same way as an empty loader.
    if num_batches <= 0:
        return float("nan")

    total_loss = 0.
    for i, (input_batch, target_batch) in enumerate(data_loader):
        if i >= num_batches:
            break
        loss = calc_loss_batch(input_batch, target_batch, model, device)
        total_loss += loss.item()
    return total_loss / num_batches

In [18]:
# Use the GPU when one is available, otherwise the CPU, and move the model there.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

GPTModel(
  (tok_emb): Embedding(50257, 768)
  (pos_emb): Embedding(1024, 768)
  (drop_emb): Dropout(p=0.1, inplace=False)
  (trf_blocks): Sequential(
    (0): TransformerBlock(
      (att): MultiHeadAttention(
        (W_query): Linear(in_features=768, out_features=768, bias=False)
        (W_key): Linear(in_features=768, out_features=768, bias=False)
        (W_value): Linear(in_features=768, out_features=768, bias=False)
        (out_proj): Linear(in_features=768, out_features=768, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (ff): FeedForward(
        (layers): Sequential(
          (0): Linear(in_features=768, out_features=3072, bias=True)
          (1): GELU()
          (2): Linear(in_features=3072, out_features=768, bias=True)
        )
      )
      (norm1): LayerNorm()
      (norm2): LayerNorm()
      (drop_shortcut): Dropout(p=0.1, inplace=False)
    )
    (1): TransformerBlock(
      (att): MultiHeadAttention(
        (W_query): Linear(in_feature

In [19]:
import os
import urllib.request
# Hyperparameters of the model whose saved state dict is loaded below.
BASE_CONFIG = {
    "vocab_size": 50257,
    "context_length": 1024,
    "emb_dim": 1024,
    "n_heads": 16,
    "n_layers": 24,
    "drop_rate": 0.1,
    "qkv_bias": True
}
file_name = "gpt2-medium-355M.pth"
url = f"https://huggingface.co/rasbt/gpt2-from-scratch-pytorch/resolve/main/{file_name}"

if not os.path.exists(file_name):
    # Download under a temporary name and rename only once the transfer has
    # completed. An interrupted download then never leaves a truncated
    # checkpoint that the exists() check above would accept on the next run.
    partial_file_name = f"{file_name}.part"
    urllib.request.urlretrieve(url, partial_file_name)
    os.replace(partial_file_name, file_name)
    print(f"Downloaded to {file_name}")

gpt = GPTModel(BASE_CONFIG)
# Deserialize onto the CPU regardless of the device the checkpoint was saved
# from; the model is moved to the target device just below.
gpt.load_state_dict(torch.load(file_name, map_location="cpu", weights_only=True))
gpt.eval()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
gpt.to(device)


# Fixed seed so the sampled continuation is reproducible.
torch.manual_seed(123)

token_ids = generate(
    model=gpt,
    idx=text_to_token_ids("Every effort moves you", tokenizer).to(device),
    max_new_tokens=25,
    context_size=BASE_CONFIG["context_length"],
    top_k=50,
    temperature=1.5
)

print("Output text:\n", token_ids_to_text(token_ids, tokenizer))


Downloaded to gpt2-medium-355M.pth
Output text:
 Every effort moves you as far as the natural capacity is capable," the lawyer wrote, "which permits extraordinary actions." "That includes (dressing


In [20]:


import os

import requests
import json
import numpy as np
import tensorflow as tf
from tqdm import tqdm


def download_and_load_gpt2(model_size, models_dir):
    """Download the GPT-2 checkpoint files for ``model_size`` and load them.

    Args:
        model_size: One of ``"124M"``, ``"355M"``, ``"774M"``, ``"1558M"``.
        models_dir: Directory under which a ``model_size`` sub-directory is
            created to hold the downloaded files.

    Returns:
        ``(settings, params)``: the parsed ``hparams.json`` dict and the
        nested parameter dict built by ``load_gpt2_params_from_tf_ckpt``.

    Raises:
        ValueError: If ``model_size`` is not one of the allowed sizes.
    """
    # Validate model size
    allowed_sizes = ("124M", "355M", "774M", "1558M")
    if model_size not in allowed_sizes:
        raise ValueError(f"Model size not in {allowed_sizes}")

    # Define paths
    model_dir = os.path.join(models_dir, model_size)
    base_url = "https://openaipublic.blob.core.windows.net/gpt-2/models"

    filenames = [
        "checkpoint", "encoder.json", "hparams.json",
        "model.ckpt.data-00000-of-00001", "model.ckpt.index",
        "model.ckpt.meta", "vocab.bpe"
    ]

    # Download files
    os.makedirs(model_dir, exist_ok=True)
    for filename in filenames:
        # URLs always use "/" separators; os.path.join would insert
        # backslashes on Windows and produce an invalid URL.
        file_url = f"{base_url}/{model_size}/{filename}"
        file_path = os.path.join(model_dir, filename)
        download_file(file_url, file_path)

    # Load settings and params
    tf_ckpt_path = tf.train.latest_checkpoint(model_dir)
    # Context manager so the file handle is closed deterministically.
    with open(os.path.join(model_dir, "hparams.json"), "r", encoding="utf-8") as hparams_file:
        settings = json.load(hparams_file)
    params = load_gpt2_params_from_tf_ckpt(tf_ckpt_path, settings)

    return settings, params


def download_file(url, destination, backup_url=None):
    """Stream ``url`` to ``destination``, optionally falling back to ``backup_url``.

    A download is skipped when ``destination`` already exists and its size
    equals the server-reported ``Content-Length``. Request failures are
    reported with ``print`` rather than raised (best-effort behaviour).

    Args:
        url: Primary download URL.
        destination: Local path the file is written to.
        backup_url: Optional alternative URL tried when the primary fails.

    Returns:
        None.
    """
    def _attempt_download(download_url):
        # The context manager releases the streamed connection on every exit
        # path, including the early "already up-to-date" return.
        with requests.get(download_url, stream=True, timeout=60) as response:
            response.raise_for_status()

            file_size = int(response.headers.get("Content-Length", 0))

            # Check if file exists and has same size
            if os.path.exists(destination):
                file_size_local = os.path.getsize(destination)
                if file_size and file_size == file_size_local:
                    print(f"File already exists and is up-to-date: {destination}")
                    return True

            block_size = 1024  # 1 KB
            desc = os.path.basename(download_url)
            with tqdm(total=file_size, unit="iB", unit_scale=True, desc=desc) as progress_bar:
                with open(destination, "wb") as file:
                    for chunk in response.iter_content(chunk_size=block_size):
                        if chunk:
                            file.write(chunk)
                            progress_bar.update(len(chunk))
            return True

    # Try the primary URL first and the backup (if any) only after a failure.
    candidate_urls = [url] if backup_url is None else [url, backup_url]
    for candidate_url in candidate_urls:
        try:
            if _attempt_download(candidate_url):
                return
        except requests.exceptions.RequestException:
            print(f"Failed to download from URL ({candidate_url})")





def load_gpt2_params_from_tf_ckpt(ckpt_path, settings):
    """Read every variable of a TensorFlow checkpoint into a nested dict.

    Variable names are split on ``"/"`` and their first component is dropped.
    A leading component of the form ``h<N>`` routes the variable into
    ``params["blocks"][N]``; everything else lands at the top level. Any
    components between the first and the last become nested dictionaries.

    Args:
        ckpt_path: Checkpoint path passed to ``tf.train``.
        settings: Mapping providing ``"n_layer"``, the number of block dicts
            to pre-create.

    Returns:
        Dict with a ``"blocks"`` list plus the non-block variables, all
        stored as squeezed arrays.
    """
    params = {"blocks": [{} for _ in range(settings["n_layer"])]}

    for name, _ in tf.train.list_variables(ckpt_path):
        # Load the variable and drop singleton dimensions.
        array = np.squeeze(tf.train.load_variable(ckpt_path, name))

        # Skip the leading 'model/' prefix.
        path = name.split("/")[1:]
        head, nested_keys, leaf = path[0], path[1:-1], path[-1]

        # Per-layer variables go into their block; the rest stay top-level.
        if head.startswith("h"):
            node = params["blocks"][int(head[1:])]
        else:
            node = params

        # Walk (creating on demand) the intermediate dictionaries.
        for key in nested_keys:
            node = node.setdefault(key, {})

        node[leaf] = array

    return params

In [21]:

# Fetch the 355M checkpoint files into gpt2/355M (files whose size already
# matches the server's are not re-downloaded), then load the hyperparameters
# (`settings`) and the nested weight dictionary (`params`).
settings, params = download_and_load_gpt2(model_size="355M", models_dir="gpt2")

checkpoint: 100%|██████████| 77.0/77.0 [00:00<00:00, 154kiB/s]
encoder.json: 100%|██████████| 1.04M/1.04M [00:00<00:00, 2.72MiB/s]
hparams.json: 100%|██████████| 91.0/91.0 [00:00<00:00, 198kiB/s]
model.ckpt.data-00000-of-00001: 100%|██████████| 1.42G/1.42G [02:42<00:00, 8.72MiB/s]
model.ckpt.index: 100%|██████████| 10.4k/10.4k [00:00<00:00, 15.1MiB/s]
model.ckpt.meta: 100%|██████████| 927k/927k [00:00<00:00, 2.81MiB/s]
vocab.bpe: 100%|██████████| 456k/456k [00:00<00:00, 1.56MiB/s]


In [22]:
# Size-specific hyperparameters, keyed by model name.
model_configs = {
    "gpt2-small (124M)": dict(emb_dim=768, n_layers=12, n_heads=12),
    "gpt2-medium (355M)": dict(emb_dim=1024, n_layers=24, n_heads=16),
    "gpt2-large (774M)": dict(emb_dim=1280, n_layers=36, n_heads=20),
    "gpt2-xl (1558M)": dict(emb_dim=1600, n_layers=48, n_heads=25),
}

# Settings shared by every size; the selected entry above overrides them.
BASE_CONFIG = dict(
    vocab_size=50257,
    context_length=1024,
    emb_dim=1024,
    n_heads=16,
    n_layers=24,
    drop_rate=0.1,
    qkv_bias=True,
)
model_name = "gpt2-medium (355M)"  # Example model name
# Later entries win: base settings, then the size-specific ones, then the
# explicit context length / bias overrides.
NEW_CONFIG = {
    **BASE_CONFIG,
    **model_configs[model_name],
    "context_length": 1024,
    "qkv_bias": True,
}

gpt = GPTModel(NEW_CONFIG)
gpt.eval()

GPTModel(
  (tok_emb): Embedding(50257, 1024)
  (pos_emb): Embedding(1024, 1024)
  (drop_emb): Dropout(p=0.1, inplace=False)
  (trf_blocks): Sequential(
    (0): TransformerBlock(
      (att): MultiHeadAttention(
        (W_query): Linear(in_features=1024, out_features=1024, bias=True)
        (W_key): Linear(in_features=1024, out_features=1024, bias=True)
        (W_value): Linear(in_features=1024, out_features=1024, bias=True)
        (out_proj): Linear(in_features=1024, out_features=1024, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (ff): FeedForward(
        (layers): Sequential(
          (0): Linear(in_features=1024, out_features=4096, bias=True)
          (1): GELU()
          (2): Linear(in_features=4096, out_features=1024, bias=True)
        )
      )
      (norm1): LayerNorm()
      (norm2): LayerNorm()
      (drop_shortcut): Dropout(p=0.1, inplace=False)
    )
    (1): TransformerBlock(
      (att): MultiHeadAttention(
        (W_query): Linear(i

In [23]:
def assign(left, right):
    """Return ``right`` wrapped as a new ``torch.nn.Parameter`` after a shape check.

    Args:
        left: Existing tensor/parameter; only its shape is used.
        right: Array-like holding the replacement values.

    Raises:
        ValueError: If the two shapes differ.
    """
    if left.shape == right.shape:
        return torch.nn.Parameter(torch.tensor(right))
    raise ValueError(f"Shape mismatch. Left: {left.shape}, Right: {right.shape}")

In [24]:
import numpy as np

def load_weights_into_gpt(gpt, params):
    """Copy the arrays in ``params`` into the matching parameters of ``gpt``.

    Every copy goes through ``assign``, so a shape mismatch raises
    ``ValueError``. The fused ``c_attn`` weight and bias are split into three
    equal parts (query, key, value) along the last axis, and every weight
    matrix is transposed before assignment.

    Args:
        gpt: Model exposing ``pos_emb``, ``tok_emb``, ``trf_blocks``,
            ``final_norm`` and ``out_head``.
        params: Nested dict with ``"wpe"``, ``"wte"``, ``"g"``, ``"b"`` and a
            ``"blocks"`` list of per-layer dicts.
    """
    gpt.pos_emb.weight = assign(gpt.pos_emb.weight, params['wpe'])
    gpt.tok_emb.weight = assign(gpt.tok_emb.weight, params['wte'])

    for layer_idx, layer_params in enumerate(params["blocks"]):
        # Fused attention projection: weights first ...
        q_w, k_w, v_w = np.split(
            layer_params["attn"]["c_attn"]["w"], 3, axis=-1)
        block = gpt.trf_blocks[layer_idx]
        qkv_layers = (block.att.W_query, block.att.W_key, block.att.W_value)
        for linear, weight in zip(qkv_layers, (q_w, k_w, v_w)):
            linear.weight = assign(linear.weight, weight.T)

        # ... then the biases.
        q_b, k_b, v_b = np.split(
            layer_params["attn"]["c_attn"]["b"], 3, axis=-1)
        for linear, bias in zip(qkv_layers, (q_b, k_b, v_b)):
            linear.bias = assign(linear.bias, bias)

        # Attention output projection and the two feed-forward layers.
        dense_pairs = (
            (block.att.out_proj, layer_params["attn"]["c_proj"]),
            (block.ff.layers[0], layer_params["mlp"]["c_fc"]),
            (block.ff.layers[2], layer_params["mlp"]["c_proj"]),
        )
        for linear, source in dense_pairs:
            linear.weight = assign(linear.weight, source["w"].T)
            linear.bias = assign(linear.bias, source["b"])

        # The two layer norms of the block.
        for norm, source_key in ((block.norm1, "ln_1"), (block.norm2, "ln_2")):
            source = layer_params[source_key]
            norm.scale = assign(norm.scale, source["g"])
            norm.shift = assign(norm.shift, source["b"])

    gpt.final_norm.scale = assign(gpt.final_norm.scale, params["g"])
    gpt.final_norm.shift = assign(gpt.final_norm.shift, params["b"])
    # The output head is initialised from the token-embedding matrix.
    gpt.out_head.weight = assign(gpt.out_head.weight, params["wte"])


# Copy the downloaded checkpoint arrays into the freshly built model, then
# move it to the active device (the trailing ';' suppresses the module repr).
load_weights_into_gpt(gpt, params)
gpt.to(device);

In [25]:
# Fixed seed so the sampled continuation is reproducible.
torch.manual_seed(123)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Sample 25 tokens from the model with the loaded weights.
# Note: the prompt literal ends with a space, which is encoded with the prompt.
token_ids = generate(
    model=gpt,
    idx=text_to_token_ids("honesty is the ", tokenizer).to(device),
    max_new_tokens=25,
    context_size=NEW_CONFIG["context_length"],
    top_k=50,
    temperature=1.5
)


print("Output text:\n", token_ids_to_text(token_ids, tokenizer))

Output text:
 honesty is the  right word to call me). I was told that although we had written a complaint on it after receiving this message, I


In [26]:
import requests
import zipfile
import os
from pathlib import Path

# Download location, local archive name, extraction directory, and the final
# path of the data file after it has been renamed with a .tsv extension.
url = "https://archive.ics.uci.edu/static/public/228/sms+spam+collection.zip"
zip_path = "sms_spam_collection.zip"
extracted_path = "sms_spam_collection"
data_file_path = Path(extracted_path) / "SMSSpamCollection.tsv"


def download_and_unzip_spam_data(url, zip_path, extracted_path, data_file_path):
    """Download the zipped SMS spam collection, extract it, and add a .tsv suffix.

    Nothing is downloaded when ``data_file_path`` already exists.

    Args:
        url: Location of the zip archive.
        zip_path: Local path the archive is saved to.
        extracted_path: Directory the archive is extracted into.
        data_file_path: ``pathlib.Path`` the extracted ``SMSSpamCollection``
            file is renamed to.

    Raises:
        requests.exceptions.RequestException: If the request fails or the
            server answers with an error status.
    """
    if data_file_path.exists():
        print(f"{data_file_path} already exists. Skipping download and extraction.")
        return

    # Downloading the file; the context manager releases the streamed
    # connection even when the status check or a write fails.
    with requests.get(url, stream=True, timeout=60) as response:
        response.raise_for_status()
        with open(zip_path, "wb") as out_file:
            for chunk in response.iter_content(chunk_size=8192):
                if chunk:
                    out_file.write(chunk)

    # Unzipping the file
    with zipfile.ZipFile(zip_path, "r") as zip_ref:
        zip_ref.extractall(extracted_path)

    # Add .tsv file extension
    original_file_path = Path(extracted_path) / "SMSSpamCollection"
    os.rename(original_file_path, data_file_path)
    print(f"File downloaded and saved as {data_file_path}")


try:
    download_and_unzip_spam_data(url, zip_path, extracted_path, data_file_path)
except (requests.exceptions.RequestException, TimeoutError) as e:
    print(f"Primary URL failed: {e}. Trying backup URL...")
    # The fallback must point at a different host than the primary; retrying
    # the identical UCI address (as before) could not recover from an outage.
    url = "https://f001.backblazeb2.com/file/LLMs-from-scratch/sms%2Bspam%2Bcollection.zip"
    download_and_unzip_spam_data(url, zip_path, extracted_path, data_file_path)

File downloaded and saved as sms_spam_collection/SMSSpamCollection.tsv


In [27]:

import pandas as pd

# The file is tab-separated without a header row: label first, message second.
df = pd.read_csv(data_file_path, sep="\t", header=None, names=["Label", "Text"])
# Rename the label in the Label column only. A frame-wide replace would also
# rewrite any message whose whole text is "ham". Assigning the result instead
# of using inplace=True keeps the step explicit.
df["Label"] = df["Label"].replace({"ham": "not spam"})
df.value_counts("Label")


Unnamed: 0_level_0,count
Label,Unnamed: 1_level_1
not spam,4825
spam,747


In [28]:
def create_balanced_dataset(df):
    """Downsample the "not spam" rows so both labels occur equally often.

    Args:
        df: DataFrame with a ``"Label"`` column holding ``"spam"`` and
            ``"not spam"``.

    Returns:
        A new DataFrame: the sampled "not spam" rows followed by all "spam"
        rows, with their original index labels.
    """
    spam_count = df[df["Label"] == "spam"].shape[0]

    # Fixed random_state keeps the sampled subset reproducible.
    not_spam_sample = df[df["Label"] == "not spam"].sample(spam_count, random_state=123)

    spam_rows = df[df["Label"] == "spam"]
    return pd.concat([not_spam_sample, spam_rows])

In [29]:


# Downsample the majority class and confirm both labels now have equal counts.
balanced_df = create_balanced_dataset(df)
print(balanced_df["Label"].value_counts())

Label
not spam    747
spam        747
Name: count, dtype: int64


In [30]:
def random_split(df, train_frac, validation_frac):
    """Shuffle ``df`` reproducibly and cut it into train / validation / test parts.

    Args:
        df: DataFrame to split.
        train_frac: Fraction of rows for the training part.
        validation_frac: Fraction of rows for the validation part; whatever
            remains becomes the test part.

    Returns:
        ``(train_df, validation_df, test_df)`` as consecutive slices of the
        shuffled, re-indexed frame.
    """
    # Fixed random_state gives the same shuffle on every call.
    shuffled = df.sample(frac=1, random_state=123).reset_index(drop=True)

    n_rows = len(shuffled)
    train_end = int(n_rows * train_frac)
    validation_end = train_end + int(n_rows * validation_frac)

    return (
        shuffled[:train_end],
        shuffled[train_end:validation_end],
        shuffled[validation_end:],
    )

# 70% of the rows for training, 10% for validation, the remainder for testing.
train_df, validation_df, test_df = random_split(balanced_df, 0.7, 0.1)


# Persist the splits; SpamDataset reads them back from these CSV files.
train_df.to_csv("train.csv", index=None)
validation_df.to_csv("validation.csv", index=None)
test_df.to_csv("test.csv", index=None)

In [31]:
import torch
from torch.utils.data import Dataset


class SpamDataset(Dataset):
    """Dataset of tokenized, equal-length texts with integer labels read from a CSV.

    Args:
        csv_file: CSV with a ``"Text"`` and a ``"Label"`` column.
        tokenizer: Object exposing ``encode(text)`` returning a list of ids.
        max_length: Target sequence length. ``None`` uses the longest encoded
            text; otherwise longer sequences are truncated to this length.
        pad_token_id: Id appended to shorter sequences up to the target length.
    """

    def __init__(self, csv_file, tokenizer, max_length=None, pad_token_id=50256):
        self.data = pd.read_csv(csv_file)

        # Tokenize every text up front.
        self.encoded_texts = [tokenizer.encode(text) for text in self.data["Text"]]

        if max_length is not None:
            self.max_length = max_length
            # Cut off anything beyond the requested length.
            self.encoded_texts = [ids[:self.max_length] for ids in self.encoded_texts]
        else:
            self.max_length = self._longest_encoded_length()

        # Right-pad every sequence to exactly max_length ids.
        self.encoded_texts = [
            ids + [pad_token_id] * (self.max_length - len(ids))
            for ids in self.encoded_texts
        ]

    def __getitem__(self, index):
        token_ids = self.encoded_texts[index]
        label = self.data.iloc[index]["Label"]
        inputs = torch.tensor(token_ids, dtype=torch.long)
        target = torch.tensor(label, dtype=torch.long)
        return inputs, target

    def __len__(self):
        return len(self.data)

    def _longest_encoded_length(self):
        """Length of the longest encoded text (0 when there are none)."""
        return max((len(ids) for ids in self.encoded_texts), default=0)

In [32]:
# max_length=None makes the dataset pad every text to the longest encoded
# training text; that length is reused for the other splits.
train_dataset = SpamDataset(
    csv_file="train.csv",
    max_length=None,
    tokenizer=tokenizer
)

print(train_dataset.max_length)

120


In [33]:
# Validation and test texts are truncated/padded to the training length so
# all three splits produce sequences of the same size.
val_dataset = SpamDataset(
    csv_file="validation.csv",
    max_length=train_dataset.max_length,
    tokenizer=tokenizer
)
test_dataset = SpamDataset(
    csv_file="test.csv",
    max_length=train_dataset.max_length,
    tokenizer=tokenizer
)

In [34]:
from torch.utils.data import DataLoader

# Loader settings shared by all three splits.
num_workers = 0
batch_size = 8

# Seed the global RNG before building the shuffling training loader.
torch.manual_seed(123)

# Training batches are shuffled and the last incomplete batch is dropped.
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
    drop_last=True,
)

# Evaluation loaders keep the file order and every sample.
val_loader = DataLoader(
    dataset=val_dataset,
    batch_size=batch_size,
    num_workers=num_workers,
    drop_last=False,
)

test_loader = DataLoader(
    dataset=test_dataset,
    batch_size=batch_size,
    num_workers=num_workers,
    drop_last=False,
)

In [35]:
# Number of batches each loader yields per pass.
print(f"{len(train_loader)} training batches")
print(f"{len(val_loader)} validation batches")
print(f"{len(test_loader)} test batches")

130 training batches
19 validation batches
38 test batches


In [36]:
# 1. Redefine the function to look for "ham" (the correct label in the raw data)
def create_balanced_dataset(df):
    """Return a class-balanced copy of ``df`` by undersampling "ham".

    All rows whose ``"Label"`` is ``"spam"`` are kept, and an equally sized
    random sample of ``"ham"`` rows (fixed ``random_state=123``) is drawn.
    The result lists the sampled ham rows first, followed by the spam rows.
    """
    spam_rows = df[df["Label"] == "spam"]
    ham_rows = df[df["Label"] == "ham"]

    # Draw exactly as many ham rows as there are spam rows.
    ham_sample = ham_rows.sample(spam_rows.shape[0], random_state=123)

    return pd.concat([ham_sample, spam_rows])

# 2. Reload the RAW data from source (tab-separated, no header row)
df = pd.read_csv(data_file_path, sep="\t", header=None, names=["Label", "Text"])

# 3. Create the balanced dataset using the fixed function
balanced_df = create_balanced_dataset(df)

# 4. Map string labels to integers (ham -> 0, spam -> 1)
balanced_df["Label"] = balanced_df["Label"].map({"ham": 0, "spam": 1})

# 5. Save the corrected data files
# random_split is defined outside this view; presumably 0.7 / 0.1 are the
# train / validation fractions and the remainder is the test split -- TODO confirm.
train_df, validation_df, test_df = random_split(balanced_df, 0.7, 0.1)
train_df.to_csv("train.csv", index=None)
validation_df.to_csv("validation.csv", index=None)
test_df.to_csv("test.csv", index=None)

# 6. Reload Datasets and Loaders
# This overwrites the dataset/loader variables built in the earlier cells.
train_dataset = SpamDataset(csv_file="train.csv", max_length=None, tokenizer=tokenizer)
val_dataset = SpamDataset(csv_file="validation.csv", max_length=train_dataset.max_length, tokenizer=tokenizer)
test_dataset = SpamDataset(csv_file="test.csv", max_length=train_dataset.max_length, tokenizer=tokenizer)

# NOTE(review): unlike the earlier loader cell, no torch.manual_seed call
# precedes these loaders, so the shuffle order depends on prior kernel state.
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, drop_last=True)
val_loader = DataLoader(dataset=val_dataset, batch_size=batch_size, num_workers=num_workers, drop_last=False)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, num_workers=num_workers, drop_last=False)

print("Function fixed and data reset complete.")

Function fixed and data reset complete.


In [37]:
# Key into model_configs below; also parsed later to pick the weights to load.
CHOOSE_MODEL = "gpt2-medium (355M)"
INPUT_PROMPT = "Every effort moves"

# Settings shared by every model size.
BASE_CONFIG = {
    "vocab_size": 50257,     # Vocabulary size
    "context_length": 1024,  # Context length
    "drop_rate": 0.0,        # Dropout rate
    "qkv_bias": True         # Query-key-value bias
}

# Size-specific settings, keyed by the same strings as CHOOSE_MODEL.
model_configs = {
    "gpt2-small (124M)": {"emb_dim": 768, "n_layers": 12, "n_heads": 12},
    "gpt2-medium (355M)": {"emb_dim": 1024, "n_layers": 24, "n_heads": 16},
    "gpt2-large (774M)": {"emb_dim": 1280, "n_layers": 36, "n_heads": 20},
    "gpt2-xl (1558M)": {"emb_dim": 1600, "n_layers": 48, "n_heads": 25},
}

# Merge the chosen size's settings into the shared config (in place).
BASE_CONFIG.update(model_configs[CHOOSE_MODEL])

# Fail early if the dataset's sequence length cannot fit in the model context.
assert train_dataset.max_length <= BASE_CONFIG["context_length"], (
    f"Dataset length {train_dataset.max_length} exceeds model's context "
    f"length {BASE_CONFIG['context_length']}. Reinitialize data sets with "
    f"`max_length={BASE_CONFIG['context_length']}`"
)

In [38]:
# ============================================================================
# COMPREHENSIVE FIX FOR ALL ERRORS
# ============================================================================

import torch
import torch.nn.functional as F
import numpy as np
from sklearn.metrics import precision_recall_fscore_support, confusion_matrix

# Step 1: Delete any classifier left over from an earlier run of this cell and
# create a new one with the dimensions described below.
try:
    del classifier
except NameError:
    # First execution in this kernel: there is no classifier to delete.
    # Only NameError is expected here; anything else should surface.
    pass

# The classifier consumes the model's last-token output:
#   input  = vocab_size  (last token has vocab_size dimensions)
#   output = num_classes (2 for binary classification)
num_classes = 2
vocab_size = BASE_CONFIG["vocab_size"]  # 50257

classifier = torch.nn.Linear(vocab_size, num_classes).to(device)
print(f"✓ Classifier created: Linear({vocab_size}, {num_classes})")

# Step 2: Recreate the optimizer over the model's and the classifier's parameters.
# NOTE(review): `model` must already exist in the kernel when this cell runs;
# the only assignment visible nearby is in a later cell -- confirm the intended
# execution order.
optimizer = torch.optim.AdamW(
    list(model.parameters()) + list(classifier.parameters()),
    lr=3e-5,
    weight_decay=0.01
)
print("✓ Optimizer created")

# Step 3: Define corrected training function
def train_epoch(model, classifier, loader, optimizer, device):
    """Train ``model`` and ``classifier`` for one pass over ``loader``.

    For every ``(input_ids, labels)`` batch the model output at the last
    sequence position is fed to ``classifier`` and a cross-entropy loss
    against ``labels`` is back-propagated; ``optimizer`` is stepped once per
    batch.

    Args:
        model: Module called as ``model(input_ids)``; its output is indexed
            as ``[:, -1, :]``.
        classifier: Module mapping the last-token output to class logits.
        loader: Sized iterable of ``(input_ids, labels)`` tensor pairs.
        optimizer: Optimizer whose gradients are reset and stepped per batch.
        device: Device the batches are moved to.

    Returns:
        dict with keys ``"total"`` and ``"cls"``: the per-batch mean of the
        classification loss (both keys hold the same value). If ``loader`` is
        empty both values are ``float("nan")`` instead of raising
        ``ZeroDivisionError``, matching what ``calc_loss_loader`` returns for
        an empty loader.
    """
    model.train()
    classifier.train()

    num_batches = len(loader)
    if num_batches == 0:
        # Nothing to average over; report NaN rather than dividing by zero.
        return {"total": float("nan"), "cls": float("nan")}

    totals = {"total": 0.0, "cls": 0.0}

    for input_ids, labels in loader:
        input_ids = input_ids.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()

        # The model is called with input ids only (no labels argument).
        logits = model(input_ids)  # [batch, seq_len, vocab_size]
        last_token_logits = logits[:, -1, :]  # [batch, vocab_size]

        cls_logits = classifier(last_token_logits)  # [batch, num_classes]
        cls_loss = F.cross_entropy(cls_logits, labels)

        cls_loss.backward()
        optimizer.step()

        # The classification loss is the only loss term, so both totals match.
        batch_loss = cls_loss.item()
        totals["total"] += batch_loss
        totals["cls"] += batch_loss

    return {k: v / num_batches for k, v in totals.items()}

print("✓ train_epoch defined")

# Step 4: Define corrected evaluation function
@torch.no_grad()
def evaluate(model, classifier, loader, device):
    """Evaluate ``model`` + ``classifier`` on ``loader`` without gradients.

    Each batch's last-position model output is passed through ``classifier``;
    the arg-max class is compared with the labels.

    Returns:
        dict with the mean per-batch cross-entropy (``"cls_loss"``), overall
        ``"accuracy"``, binary ``"precision"`` / ``"recall"`` / ``"f1"``
        (``zero_division=0``) and the ``"confusion_matrix"``.
    """
    model.eval()
    classifier.eval()

    predictions, references, batch_losses = [], [], []

    for input_ids, labels in loader:
        input_ids, labels = input_ids.to(device), labels.to(device)

        # Classify from the model output at the final sequence position.
        cls_logits = classifier(model(input_ids)[:, -1, :])

        batch_losses.append(F.cross_entropy(cls_logits, labels).item())
        predictions.extend(cls_logits.argmax(dim=-1).cpu().numpy())
        references.extend(labels.cpu().numpy())

    predictions = np.array(predictions)
    references = np.array(references)
    accuracy = (predictions == references).mean()

    precision, recall, f1, _ = precision_recall_fscore_support(
        references, predictions, average="binary", zero_division=0
    )

    return {
        "cls_loss": np.mean(batch_losses),
        "accuracy": accuracy,
        "precision": precision,
        "recall": recall,
        "f1": f1,
        "confusion_matrix": confusion_matrix(references, predictions),
    }

print("✓ evaluate defined")
print("\n" + "="*70)
print("ALL FIXES APPLIED - Ready to train!")
print("="*70)


✓ Classifier created: Linear(50257, 2)
✓ Optimizer created
✓ train_epoch defined
✓ evaluate defined

ALL FIXES APPLIED - Ready to train!


In [39]:
# Take the text inside the trailing parentheses of CHOOSE_MODEL,
# e.g. "gpt2-medium (355M)" -> "355M".
model_size = CHOOSE_MODEL.split(" ")[-1].lstrip("(").rstrip(")")
settings, params = download_and_load_gpt2(model_size=model_size, models_dir="gpt2")

# Build the model from BASE_CONFIG, load the downloaded weights into it and
# switch it to evaluation mode (the trailing ';' suppresses the cell output).
model = GPTModel(BASE_CONFIG)
load_weights_into_gpt(model, params)
model.eval();

File already exists and is up-to-date: gpt2/355M/checkpoint
File already exists and is up-to-date: gpt2/355M/encoder.json
File already exists and is up-to-date: gpt2/355M/hparams.json
File already exists and is up-to-date: gpt2/355M/model.ckpt.data-00000-of-00001
File already exists and is up-to-date: gpt2/355M/model.ckpt.index
File already exists and is up-to-date: gpt2/355M/model.ckpt.meta
File already exists and is up-to-date: gpt2/355M/vocab.bpe


In [40]:

# Sanity check of the loaded weights: continue a short prompt by 15 tokens
# and print the decoded text.
text_1 = "Every effort moves you"

token_ids = generate(
    model=model,
    idx=text_to_token_ids(text_1, tokenizer),
    max_new_tokens=15,
    context_size=BASE_CONFIG["context_length"]
)

print(token_ids_to_text(token_ids, tokenizer))

Every effort moves you forward, but you must be careful. You must not let your guard down


In [41]:
# Freeze every parameter currently in the model.
for param in model.parameters():
    param.requires_grad = False

In [42]:
# Replace the output head with a new linear layer mapping emb_dim features to
# num_classes outputs. It is created after the freeze loop above, so its
# parameters keep nn.Linear's default requires_grad=True.
num_classes = 2
model.out_head = torch.nn.Linear(in_features=BASE_CONFIG["emb_dim"], out_features=num_classes)

In [43]:
# Re-enable gradients for the last transformer block ...
for param in model.trf_blocks[-1].parameters():
    param.requires_grad = True

# ... and for the final normalization layer.
for param in model.final_norm.parameters():
    param.requires_grad = True

In [44]:
def calc_accuracy_loader(data_loader, model, device, num_batches=None):
    """Return the classification accuracy of ``model`` over ``data_loader``.

    The prediction for each example is the arg-max of the model output at the
    last sequence position. The model is switched to eval mode and left there.

    Args:
        data_loader: Sized iterable of ``(input_batch, target_batch)`` pairs.
        model: Module called as ``model(input_batch)``; its output is indexed
            as ``[:, -1, :]``.
        device: Device the batches are moved to.
        num_batches: Evaluate at most this many batches; ``None`` means all.

    Returns:
        Fraction of correct predictions as a float, or ``float("nan")`` when
        no example was evaluated (empty loader or ``num_batches <= 0``)
        instead of raising ``ZeroDivisionError``.
    """
    model.eval()

    if num_batches is None:
        num_batches = len(data_loader)
    else:
        num_batches = min(num_batches, len(data_loader))

    correct_predictions, num_examples = 0, 0
    # range() comes first in zip so iteration stops without pulling an extra,
    # unused batch from the loader once the limit is reached.
    for _, (input_batch, target_batch) in zip(range(num_batches), data_loader):
        input_batch, target_batch = input_batch.to(device), target_batch.to(device)

        with torch.no_grad():
            logits = model(input_batch)[:, -1, :]  # Logits of last output token
        predicted_labels = torch.argmax(logits, dim=-1)

        num_examples += predicted_labels.shape[0]
        correct_predictions += (predicted_labels == target_batch).sum().item()

    if num_examples == 0:
        return float("nan")
    return correct_predictions / num_examples

In [45]:
# Use the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

print("Device:", device)

model.to(device) # no assignment model = model.to(device) necessary for nn.Module classes

torch.manual_seed(123)

# Baseline accuracy before any fine-tuning, estimated on at most 10 batches
# per split.
train_accuracy = calc_accuracy_loader(train_loader, model, device, num_batches=10)
val_accuracy = calc_accuracy_loader(val_loader, model, device, num_batches=10)
test_accuracy = calc_accuracy_loader(test_loader, model, device, num_batches=10)

print(f"Training accuracy: {train_accuracy*100:.2f}%")
print(f"Validation accuracy: {val_accuracy*100:.2f}%")
print(f"Test accuracy: {test_accuracy*100:.2f}%")

Device: cuda
Training accuracy: 53.75%
Validation accuracy: 55.00%
Test accuracy: 51.25%


In [46]:
def calc_loss_batch(input_batch, target_batch, model, device):
    """Return the cross-entropy loss of one batch.

    The batch is moved to ``device`` and only the model output at the last
    sequence position is compared against ``target_batch``.
    """
    inputs = input_batch.to(device)
    targets = target_batch.to(device)
    last_token_logits = model(inputs)[:, -1, :]
    return torch.nn.functional.cross_entropy(last_token_logits, targets)

In [47]:

def calc_loss_loader(data_loader, model, device, num_batches=None):
    """Return the mean per-batch loss of ``model`` over ``data_loader``.

    Args:
        data_loader: Sized iterable of ``(input_batch, target_batch)`` pairs.
        model: Model passed through to ``calc_loss_batch``.
        device: Device passed through to ``calc_loss_batch``.
        num_batches: Average over at most this many batches; ``None`` means
            all batches.

    Returns:
        The mean of the per-batch losses as a float, or ``float("nan")`` when
        there is nothing to average (empty loader, or ``num_batches <= 0``,
        which previously raised ``ZeroDivisionError`` for 0).
    """
    if len(data_loader) == 0:
        return float("nan")

    if num_batches is None:
        num_batches = len(data_loader)
    else:
        num_batches = min(num_batches, len(data_loader))

    if num_batches <= 0:
        # No batch requested: avoid dividing by zero below.
        return float("nan")

    total_loss = 0.
    # range() comes first in zip so iteration stops without pulling an extra,
    # unused batch from the loader once the limit is reached.
    for _, (input_batch, target_batch) in zip(range(num_batches), data_loader):
        loss = calc_loss_batch(input_batch, target_batch, model, device)
        total_loss += loss.item()
    return total_loss / num_batches

In [48]:
# Baseline losses before fine-tuning, averaged over at most 5 batches per split.
with torch.no_grad(): # Disable gradient tracking for efficiency
    train_loss = calc_loss_loader(train_loader, model, device, num_batches=5)
    val_loss = calc_loss_loader(val_loader, model, device, num_batches=5)
    test_loss = calc_loss_loader(test_loader, model, device, num_batches=5)

print(f"Training loss: {train_loss:.3f}")
print(f"Validation loss: {val_loss:.3f}")
print(f"Test loss: {test_loss:.3f}")

Training loss: 2.757
Validation loss: 2.604
Test loss: 2.883


In [49]:
def train_classifier_simple(model, train_loader, val_loader, optimizer, device, num_epochs,
                            eval_freq, eval_iter):
    """Fine-tune ``model`` for ``num_epochs`` epochs with periodic evaluation.

    Every ``eval_freq`` optimizer steps (starting with step 0) the train and
    validation losses are measured via ``evaluate_model`` on ``eval_iter``
    batches and printed. After each epoch the train/validation accuracies are
    measured on ``eval_iter`` batches and printed.

    Returns:
        ``(train_losses, val_losses, train_accs, val_accs, examples_seen)``:
        the four recorded metric lists and the number of training examples
        processed.
    """
    train_losses, val_losses = [], []
    train_accs, val_accs = [], []
    examples_seen = 0
    global_step = -1  # incremented before use, so the first step is 0

    for epoch in range(num_epochs):
        model.train()

        for input_batch, target_batch in train_loader:
            # Standard optimisation step on one batch.
            optimizer.zero_grad()
            batch_loss = calc_loss_batch(input_batch, target_batch, model, device)
            batch_loss.backward()
            optimizer.step()

            examples_seen += input_batch.shape[0]  # count examples, not tokens
            global_step += 1

            if global_step % eval_freq != 0:
                continue

            # Periodic loss evaluation.
            train_loss, val_loss = evaluate_model(
                model, train_loader, val_loader, device, eval_iter)
            train_losses.append(train_loss)
            val_losses.append(val_loss)
            print(f"Ep {epoch+1} (Step {global_step:06d}): "
                  f"Train loss {train_loss:.3f}, Val loss {val_loss:.3f}")

        # End-of-epoch accuracy on a subset of batches.
        train_accuracy = calc_accuracy_loader(train_loader, model, device, num_batches=eval_iter)
        val_accuracy = calc_accuracy_loader(val_loader, model, device, num_batches=eval_iter)
        print(f"Training accuracy: {train_accuracy*100:.2f}% | ", end="")
        print(f"Validation accuracy: {val_accuracy*100:.2f}%")
        train_accs.append(train_accuracy)
        val_accs.append(val_accuracy)

    return train_losses, val_losses, train_accs, val_accs, examples_seen

In [50]:
def evaluate_model(model, train_loader, val_loader, device, eval_iter):
    """Return ``(train_loss, val_loss)`` measured on ``eval_iter`` batches each.

    The model is put in eval mode for the measurement (with gradients
    disabled) and switched back to train mode before returning.
    """
    model.eval()
    with torch.no_grad():
        losses = tuple(
            calc_loss_loader(loader, model, device, num_batches=eval_iter)
            for loader in (train_loader, val_loader)
        )
    model.train()
    return losses

In [51]:
import time

# Time the whole fine-tuning run.
start_time = time.time()

torch.manual_seed(123)

# The optimizer is handed all model parameters.
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5, weight_decay=0.1)

# Train for 5 epochs, evaluating the losses every 50 steps on 5 batches.
num_epochs = 5
train_losses, val_losses, train_accs, val_accs, examples_seen = train_classifier_simple(
    model, train_loader, val_loader, optimizer, device,
    num_epochs=num_epochs, eval_freq=50, eval_iter=5,
)

end_time = time.time()
execution_time_minutes = (end_time - start_time) / 60
print(f"Training completed in {execution_time_minutes:.2f} minutes.")

Ep 1 (Step 000000): Train loss 2.694, Val loss 2.455
Ep 1 (Step 000050): Train loss 0.519, Val loss 0.557
Ep 1 (Step 000100): Train loss 0.366, Val loss 0.488
Training accuracy: 80.00% | Validation accuracy: 77.50%
Ep 2 (Step 000150): Train loss 0.545, Val loss 0.429
Ep 2 (Step 000200): Train loss 0.432, Val loss 0.415
Ep 2 (Step 000250): Train loss 0.418, Val loss 0.416
Training accuracy: 80.00% | Validation accuracy: 82.50%
Ep 3 (Step 000300): Train loss 0.290, Val loss 0.389
Ep 3 (Step 000350): Train loss 0.309, Val loss 0.290
Training accuracy: 95.00% | Validation accuracy: 87.50%
Ep 4 (Step 000400): Train loss 0.154, Val loss 0.272
Ep 4 (Step 000450): Train loss 0.123, Val loss 0.269
Ep 4 (Step 000500): Train loss 0.173, Val loss 0.264
Training accuracy: 95.00% | Validation accuracy: 92.50%
Ep 5 (Step 000550): Train loss 0.104, Val loss 0.173
Ep 5 (Step 000600): Train loss 0.184, Val loss 0.153
Training accuracy: 100.00% | Validation accuracy: 95.00%
Training completed in 2.89 min

In [52]:
# Accuracy after fine-tuning, computed over every batch of each loader
# (num_batches is left at its default of None).
train_accuracy = calc_accuracy_loader(train_loader, model, device)
val_accuracy = calc_accuracy_loader(val_loader, model, device)
test_accuracy = calc_accuracy_loader(test_loader, model, device)

print(f"Training accuracy: {train_accuracy*100:.2f}%")
print(f"Validation accuracy: {val_accuracy*100:.2f}%")
print(f"Test accuracy: {test_accuracy*100:.2f}%")

Training accuracy: 96.63%
Validation accuracy: 95.97%
Test accuracy: 96.00%


In [53]:
def classify_review(text, model, tokenizer, device, max_length=None, pad_token_id=50256):
    """Classify ``text`` as ``"spam"`` or ``"not spam"``.

    The text is tokenized, truncated to ``max_length`` tokens, right-padded
    with ``pad_token_id`` up to ``max_length`` and passed through ``model``;
    the arg-max of the output at the last position is the predicted label
    (1 -> "spam", anything else -> "not spam").

    Args:
        text: Input string.
        model: Model exposing ``pos_emb.weight`` (its first dimension is used
            as the supported context length); put into eval mode here.
        tokenizer: Object with an ``encode(text)`` method returning token ids.
        device: Device the input tensor is created on.
        max_length: Required sequence length; must not exceed the model's
            supported context length.
        pad_token_id: Token id used for padding.

    Raises:
        AssertionError: If ``max_length`` is ``None`` or exceeds the model's
            supported context length.
    """
    model.eval()

    supported_context_length = model.pos_emb.weight.shape[0]

    # Validate max_length BEFORE it is used for slicing; otherwise the default
    # None fails with an unrelated TypeError and this message is never shown.
    assert max_length is not None, (
        "max_length must be specified. If you want to use the full model context, "
        "pass max_length=model.pos_emb.weight.shape[0]."
    )
    assert max_length <= supported_context_length, (
        f"max_length ({max_length}) exceeds model's supported context length ({supported_context_length})."
    )

    # Truncate sequences if they are too long (max_length is already known to
    # fit in the supported context length).
    input_ids = tokenizer.encode(text)[:max_length]

    # Pad sequences up to max_length
    input_ids += [pad_token_id] * (max_length - len(input_ids))
    input_tensor = torch.tensor(input_ids, device=device).unsqueeze(0) # add batch dimension

    # Model inference
    with torch.no_grad():
        logits = model(input_tensor)[:, -1, :]  # Logits of the last output token
    predicted_label = torch.argmax(logits, dim=-1).item()

    # Return the classified result
    return "spam" if predicted_label == 1 else "not spam"

In [54]:
# Try the fine-tuned classifier on an everyday message, padded/truncated to
# the training dataset's max_length.
text_1 = ("Hi lets catch up sometime tomorrow"
)
print(classify_review(
    text_1, model, tokenizer, device, max_length=train_dataset.max_length
))

not spam


In [55]:
# Try the fine-tuned classifier on a phishing-style message.
text_2 = (
"URGENT: Your bank account has been locked due to suspicious activity. "
    "Click the link immediately to verify your identity: http://fake-bank-link.com"
)

print(classify_review(
    text_2, model, tokenizer, device, max_length=train_dataset.max_length
))

spam


In [56]:
# Show the text of the first row labelled 1 (spam) in the balanced dataset.
balanced_df[balanced_df["Label"] == 1]["Text"].iloc[0]

"Free entry in 2 a wkly comp to win FA Cup final tkts 21st May 2005. Text FA to 87121 to receive entry question(std txt rate)T&C's apply 08452810075over18's"

In [57]:
# Sample 45 new tokens with top-k (k=50) and temperature 1.5.
# NOTE(review): `gpt` is not defined in the visible part of the notebook; it
# presumably refers to a model created in an earlier cell -- confirm it exists
# on a fresh kernel and lives on `device`.
token_ids = generate(
    model=gpt,
    idx=text_to_token_ids("i am writing ", tokenizer).to(device),
    max_new_tokens=45,
    context_size=BASE_CONFIG["context_length"],
    top_k=50,
    temperature=1.5
)

print("Output text:\n", token_ids_to_text(token_ids, tokenizer))

Output text:
 i am writing  and the writing section is very long (and many entries) but very easy to follow. As you may know I did not write a whole new article about The Hobbit's setting. After visiting the Hobbitland, I


In [58]:

# import torch
# import torch.nn as nn
# import torch.nn.functional as F
# from copy import deepcopy
# from typing import Tuple, List
# import time



# class FeatureExtractor:
#     """Extract text representations from the final_norm layer using hooks."""

#     def __init__(self, model, layer_name="final_norm"):
#         self.features = None
#         self.hook = None
#         self.model = model
#         self.layer_name = layer_name
#         self._register_hook()

#     def _register_hook(self):
#         """Register a forward hook on final_norm."""
#         layer = getattr(self.model, self.layer_name)

#         def hook_fn(module, input, output):
#             self.features = output.detach()

#         self.hook = layer.register_forward_hook(hook_fn)

#     def get_features(self, input_batch):
#         """Extract features for the last token (used for classification)."""
#         _ = self.model(input_batch)
#         return self.features[:, -1, :]  # (batch_size, emb_dim)

#     def remove_hook(self):
#         """Remove the forward hook."""
#         if self.hook is not None:
#             self.hook.remove()



# class SelfDistillationLearner(nn.Module):
#     """
#     Student-Teacher architecture with EMA updates.

#     Loss = CrossEntropy(student, label) + λ * MSE(student_features, teacher_features)
#     """

#     def __init__(self, student_model, ema_decay=0.999, lambda_distill=0.5, device="cuda"):
#         super().__init__()

#         self.student = student_model
#         self.ema_decay = ema_decay
#         self.lambda_distill = lambda_distill
#         self.device = device

#         # Create frozen teacher copy
#         self.teacher = deepcopy(self.student)
#         self.teacher.eval()
#         for param in self.teacher.parameters():
#             param.requires_grad = False
#         self.teacher = self.teacher.to(device)

#         # Feature extractors
#         self.student_extractor = FeatureExtractor(self.student, "final_norm")
#         self.teacher_extractor = FeatureExtractor(self.teacher, "final_norm")

#         # Loss functions
#         self.ce_loss = nn.CrossEntropyLoss()
#         self.mse_loss = nn.MSELoss()

#     @torch.no_grad()
#     def update_teacher(self):
#         """Update teacher using EMA: teacher """
#         for teacher_param, student_param in zip(
#             self.teacher.parameters(), self.student.parameters()
#         ):
#             teacher_param.data.mul_(self.ema_decay).add_(
#                 student_param.data, alpha=1 - self.ema_decay
#             )

#     def forward(self, input_batch, target_batch):
#         """Compute combined loss: CE + distillation."""
#         # Student forward
#         student_logits = self.student(input_batch)[:, -1, :]
#         student_features = self.student_extractor.get_features(input_batch)

#         # Teacher forward (frozen)
#         with torch.no_grad():
#             teacher_features = self.teacher_extractor.get_features(input_batch)

#         # Combined loss
#         ce_loss = self.ce_loss(student_logits, target_batch)
#         distill_loss = self.mse_loss(student_features, teacher_features)
#         total_loss = ce_loss + self.lambda_distill * distill_loss

#         loss_dict = {
#             'total': total_loss.item(),
#             'ce': ce_loss.item(),
#             'distill': distill_loss.item()
#         }

#         return total_loss, loss_dict

#     def cleanup(self):
#         """Remove hooks."""
#         self.student_extractor.remove_hook()
#         self.teacher_extractor.remove_hook()

# def calc_accuracy_distill(data_loader, learner, device, num_batches=None):
#     """Calculate accuracy using the student model."""
#     learner.student.eval()
#     correct_predictions, num_examples = 0, 0

#     if num_batches is None:
#         num_batches = len(data_loader)
#     else:
#         num_batches = min(num_batches, len(data_loader))

#     for i, (input_batch, target_batch) in enumerate(data_loader):
#         if i >= num_batches:
#             break

#         input_batch = input_batch.to(device)
#         target_batch = target_batch.to(device)

#         with torch.no_grad():
#             logits = learner.student(input_batch)[:, -1, :]
#             predicted_labels = torch.argmax(logits, dim=-1)
#             num_examples += predicted_labels.shape[0]
#             correct_predictions += (predicted_labels == target_batch).sum().item()

#     return correct_predictions / num_examples


# def train_distillation(learner, train_loader, val_loader, optimizer, device,
#                        num_epochs=5, eval_freq=50, eval_iter=5, verbose=True):
#     """Train using self-distillation."""

#     train_losses, val_losses = [], []
#     train_accs, val_accs = [], []
#     global_step = 0

#     if verbose:

#         print(f" EMA Decay: {learner.ema_decay}")
#         print(f" Lambda (Distillation Weight): {learner.lambda_distill}")
#         print(f"Training for {num_epochs} epochs")

#     start_time = time.time()

#     for epoch in range(num_epochs):
#         learner.student.train()
#         epoch_loss = 0.0
#         epoch_ce_loss = 0.0
#         epoch_distill_loss = 0.0
#         num_batches_in_epoch = 0

#         for input_batch, target_batch in train_loader:
#             input_batch = input_batch.to(device)
#             target_batch = target_batch.to(device)

#             optimizer.zero_grad()
#             total_loss, loss_dict = learner(input_batch, target_batch)
#             total_loss.backward()
#             optimizer.step()

#             # Update teacher with EMA
#             learner.update_teacher()

#             epoch_loss += loss_dict['total']
#             epoch_ce_loss += loss_dict['ce']
#             epoch_distill_loss += loss_dict['distill']
#             num_batches_in_epoch += 1

#             # Periodic evaluation
#             if global_step % eval_freq == 0:
#                 learner.student.eval()

#                 val_loss = 0.0
#                 val_batches = 0
#                 with torch.no_grad():
#                     for i, (val_input, val_target) in enumerate(val_loader):
#                         if i >= eval_iter:
#                             break
#                         val_input = val_input.to(device)
#                         val_target = val_target.to(device)
#                         loss, _ = learner(val_input, val_target)
#                         val_loss += loss.item()
#                         val_batches += 1

#                 val_loss /= val_batches
#                 train_loss = epoch_loss / max(num_batches_in_epoch, 1)

#                 train_losses.append(train_loss)
#                 val_losses.append(val_loss)

#                 if verbose:
#                     print(f"Ep {epoch+1} (Step {global_step:06d}): "
#                           f"Train {train_loss:.3f} "
#                           f"[CE: {epoch_ce_loss/max(num_batches_in_epoch, 1):.3f}, "
#                           f"Distill: {epoch_distill_loss/max(num_batches_in_epoch, 1):.3f}] | "
#                           f"Val {val_loss:.3f}")

#                 learner.student.train()

#             global_step += 1

#         # End-of-epoch evaluation
#         train_accuracy = calc_accuracy_distill(train_loader, learner, device, num_batches=eval_iter)
#         val_accuracy = calc_accuracy_distill(val_loader, learner, device, num_batches=eval_iter)

#         train_accs.append(train_accuracy)
#         val_accs.append(val_accuracy)

#         if verbose:
#             print(f" Epoch {epoch+1}: Train acc {train_accuracy*100:.2f}% | Val acc {val_accuracy*100:.2f}%")


#     end_time = time.time()

#     if verbose:

#         print(f" Training completed in {(end_time - start_time) / 60:.2f} minutes")


#     return train_losses, val_losses, train_accs, val_accs




In [59]:


# learner = SelfDistillationLearner(
#     student_model=model,
#     ema_decay=0.999,
#     lambda_distill=0.5,       # Equal weight to CE and distillation
#     device=device
# )

# print(f" Student model: {sum(p.numel() for p in learner.student.parameters())} parameters")
# print(f" Teacher model: {sum(p.numel() for p in learner.teacher.parameters())} parameters")
# print(f" Feature extractors registered on 'final_norm' layer")




# optimizer = torch.optim.AdamW(
#     learner.student.parameters(),
#     lr=5e-5,              # Same learning rate as original training
#     weight_decay=0.1
# )

# print(f" Optimizer created for {sum(p.numel() for p in learner.student.parameters() if p.requires_grad)} trainable parameters")





# print("Starting Self-Distillation Training")



# torch.manual_seed(123)

# train_losses, val_losses, train_accs, val_accs = train_distillation(
#     learner=learner,
#     train_loader=train_loader,
#     val_loader=val_loader,
#     optimizer=optimizer,
#     device=device,
#     num_epochs=5,
#     eval_freq=50,
#     eval_iter=5,
#     verbose=True
# )


# # Evaluate on all datasets
# train_accuracy = calc_accuracy_distill(train_loader, learner, device)
# val_accuracy = calc_accuracy_distill(val_loader, learner, device)
# test_accuracy = calc_accuracy_distill(test_loader, learner, device)

# print(f"Final Results:")
# print(f"   Training accuracy:   {train_accuracy*100:.2f}%")
# print(f"   Validation accuracy: {val_accuracy*100:.2f}%")
# print(f"   Test accuracy:       {test_accuracy*100:.2f}%")


# def classify_review_distill(text, learner, tokenizer, device, max_length, pad_token_id=50256):
#     """Modified classify_review to work with the learner."""
#     learner.student.eval()

#     input_ids = tokenizer.encode(text)
#     input_ids = input_ids[:min(max_length, learner.student.pos_emb.weight.shape[0])]
#     input_ids += [pad_token_id] * (max_length - len(input_ids))
#     input_tensor = torch.tensor(input_ids, device=device).unsqueeze(0)

#     with torch.no_grad():
#         logits = learner.student(input_tensor)[:, -1, :]
#     predicted_label = torch.argmax(logits, dim=-1).item()

#     return "spam" if predicted_label == 1 else "not spam"


# text_1 = "Hi lets catch up sometime tomorrow"
# text_2 = ("URGENT: Your bank account has been locked due to suspicious activity. "
#           "Click the link immediately to verify your identity: http://fake-bank-link.com")

# print(f"\nTest 1 (should be 'not spam'): '{text_1}'")
# result_1 = classify_review_distill(text_1, learner, tokenizer, device, max_length=120)
# print(f"   Prediction: {result_1}")

# print(f"\nTest 2 (should be 'spam'): '{text_2[:50]}...'")
# result_2 = classify_review_distill(text_2, learner, tokenizer, device, max_length=120)
# print(f"   Prediction: {result_2}")


# print("Training Progress Summary")

# print(f" Loss progression:")
# print(f"   Initial train loss: {train_losses[0]:.3f}")
# print(f"   Final train loss:   {train_losses[-1]:.3f}")
# print(f"   Initial val loss:   {val_losses[0]:.3f}")
# print(f"   Final val loss:     {val_losses[-1]:.3f}")

# print(f" Accuracy progression:")
# print(f"   Initial train acc:  {train_accs[0]*100:.2f}%")
# print(f"   Final train acc:    {train_accs[-1]*100:.2f}%")
# print(f"   Initial val acc:    {val_accs[0]*100:.2f}%")
# print(f"   Final val acc:      {val_accs[-1]*100:.2f}%")



In [60]:
# text_2 = (
# "you won a prize in our competition !!! click the link to claim it now: http://fake-prize-link.com"
# )

# # print(classify_review(
#     text_2, model, tokenizer, device, max_length=train_dataset.max_length
# ))

In [61]:
import torch
import torch.nn as nn
import torch.nn.functional as F



class ModifiedTransformerBlock(nn.Module):
    """
    Transformer block with two residual sub-layers.

    Each sub-layer normalizes its input, applies attention (first) or the
    feed-forward network (second), applies dropout and adds the result back
    onto the sub-layer's input.
    """
    def __init__(self, cfg):
        super().__init__()
        emb_dim = cfg["emb_dim"]
        self.att = MultiHeadAttention(
            d_in=emb_dim,
            d_out=emb_dim,
            context_length=cfg["context_length"],
            num_heads=cfg["n_heads"],
            dropout=cfg["drop_rate"],
            qkv_bias=cfg["qkv_bias"])
        self.ff = FeedForward(cfg)
        self.norm1 = LayerNorm(emb_dim)
        self.norm2 = LayerNorm(emb_dim)
        self.drop_shortcut = nn.Dropout(cfg["drop_rate"])

    def forward(self, x):
        # Attention sub-layer with shortcut connection.
        x = self.drop_shortcut(self.att(self.norm1(x))) + x
        # Feed-forward sub-layer with shortcut connection.
        x = self.drop_shortcut(self.ff(self.norm2(x))) + x
        return x

class HierarchicalGPTModel(nn.Module):
    """GPT-style model that can also return its intermediate hidden states.

    ``forward`` embeds the token ids (token + positional embeddings), runs
    them through ``cfg["n_layers"]`` transformer blocks and a final norm, and
    projects to ``cfg["vocab_size"]`` logits.
    """

    def __init__(self, cfg):
        super().__init__()
        emb_dim = cfg["emb_dim"]
        self.tok_emb = nn.Embedding(cfg["vocab_size"], emb_dim)
        self.pos_emb = nn.Embedding(cfg["context_length"], emb_dim)
        self.drop_emb = nn.Dropout(cfg["drop_rate"])

        self.trf_blocks = nn.ModuleList(
            [ModifiedTransformerBlock(cfg) for _ in range(cfg["n_layers"])]
        )

        self.final_norm = LayerNorm(emb_dim)
        self.out_head = nn.Linear(emb_dim, cfg["vocab_size"], bias=False)

    def forward(self, in_idx, return_hidden_states=False):
        """Return logits, or ``(logits, hidden_states)`` if requested.

        ``hidden_states`` holds the output of every transformer block followed
        by the output of the final norm.
        """
        _, seq_len = in_idx.shape
        positions = torch.arange(seq_len, device=in_idx.device)
        x = self.drop_emb(self.tok_emb(in_idx) + self.pos_emb(positions))

        hidden_states = []
        for block in self.trf_blocks:
            x = block(x)
            if return_hidden_states:
                hidden_states.append(x)

        x = self.final_norm(x)
        logits = self.out_head(x)

        if not return_hidden_states:
            return logits

        hidden_states.append(x)
        return logits, hidden_states

In [62]:
class AttentionPooling(nn.Module):
    """
    Collapse a sequence of token embeddings into one vector per sequence.

    A learnable linear layer scores every token; the scores are soft-maxed
    over the sequence dimension and used as weights for a weighted sum of the
    token embeddings.
    """
    def __init__(self, emb_dim):
        super().__init__()
        self.attention_weights = nn.Linear(emb_dim, 1)

    def forward(self, hidden_states, mask=None):
        """Pool ``hidden_states`` of shape [B, S, E] to [B, E].

        Positions where ``mask`` (shape [B, S]) equals 0 are excluded by
        setting their score to -inf before the softmax.
        """
        token_scores = self.attention_weights(hidden_states)  # [B, S, 1]

        if mask is not None:
            excluded = mask.unsqueeze(-1).eq(0)
            token_scores = token_scores.masked_fill(excluded, float("-inf"))

        token_weights = token_scores.softmax(dim=1)  # [B, S, 1]
        return (hidden_states * token_weights).sum(dim=1)  # [B, E]

class ConceptProjector(nn.Module):
    """
    Projects sentence embeddings into a ``concept_dim``-dimensional space and
    L2-normalizes the result.

    The projection is a two-layer MLP (Linear -> GELU -> Linear) whose hidden
    width is twice ``input_dim``.
    """
    def __init__(self, input_dim, concept_dim=2048):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, input_dim * 2),
            nn.GELU(),
            nn.Linear(input_dim * 2, concept_dim)
        )

    def forward(self, x):
        """
        Args:
            x: tensor whose last dimension is ``input_dim``
               (e.g. [Batch, input_dim]).

        Returns:
            Tensor with the same leading dimensions and a last dimension of
            ``concept_dim``, unit L2 norm along that last dimension.
        """
        # Project
        x = self.net(x)
        # L2-normalize along the feature axis so every vector lies on the unit
        # hypersphere (needed by cosine / contrastive losses). Using dim=-1
        # instead of dim=1 is identical for [B, D] input and also supports
        # unbatched [D] or sequence-shaped [B, S, D] input.
        x = F.normalize(x, p=2, dim=-1)
        return x

In [63]:
# Select the compute device once and keep the result; the bare expression used
# to be evaluated and thrown away, so `device` stayed whatever an earlier cell
# had hard-coded. Falls back to CPU when CUDA is unavailable.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class HierarchicalLLM(nn.Module):
    """
    Wraps a GPT-style model with a pooling layer and a concept head and
    combines three training signals into a single loss:

    * token loss    - next-token cross-entropy on the GPT logits (``alpha``)
    * sentence loss - ``1 - cosine`` between the pooled sentence embedding and
      a gradient-free second pass over the same input (``beta``)
    * concept loss  - supervised contrastive loss over the concept embeddings,
      driven by the class labels (``gamma``)

    Args:
        gpt_model: module called as ``gpt_model(input_ids, return_hidden_states=True)``
            that returns ``(logits, hidden_states)``.
        pooling_layer: maps the last hidden state ``[B, S, D]`` to ``[B, D]``.
        concept_head: maps the pooled embedding to the concept embedding.
        alpha, beta, gamma: weights of the token / sentence / concept losses.
        temperature: divisor applied to the contrastive similarities; must be > 0.

    Raises:
        AssertionError: if ``gpt_model`` is itself a ``HierarchicalLLM``.
        ValueError: if ``temperature`` is not strictly positive.
    """
    def __init__(self, gpt_model, pooling_layer, concept_head,
                 alpha=1.0, beta=0.5, gamma=0.5, temperature=0.07):
        super().__init__()

        # Avoid double-wrapping (checked before anything is registered)
        assert not isinstance(gpt_model, HierarchicalLLM), "Do not double-wrap models"
        if temperature <= 0:
            raise ValueError(f"temperature must be > 0, got {temperature}")

        self.gpt = gpt_model
        self.pooler = pooling_layer
        self.concept_head = concept_head

        # Loss Weights
        self.alpha = alpha # Token
        self.beta = beta   # Sentence
        self.gamma = gamma # Concept
        self.temp = temperature

    def compute_token_loss(self, logits, targets):
        """
        Standard causal language-modelling loss.

        The logits at position ``t`` are scored against the token at position
        ``t + 1``, so the last logit and the first target are dropped.
        """
        loss = F.cross_entropy(
            logits[:, :-1, :].reshape(-1, logits.size(-1)),
            targets[:, 1:].reshape(-1)
        )
        return loss

    def compute_sentence_loss(self, gen_emb, target_emb):
        """
        Cosine Semantic Alignment.
        Loss = 1 - mean(CosineSimilarity(Generated, Target))

        Both inputs are [Batch, Dim]; they do not need to be normalized because
        ``F.cosine_similarity`` normalizes internally.
        """
        sim = F.cosine_similarity(gen_emb, target_emb, dim=-1)
        loss = 1.0 - sim.mean()
        return loss

    def compute_concept_loss(self, concepts, labels):
        """
        Supervised Contrastive Loss (SupCon).

        For every anchor the loss is the negative mean log-probability of the
        other samples that share its label, where the probability is a softmax
        over the similarities to all *other* samples in the batch. Anchors
        without any positive contribute 0.

        Args:
            concepts: [B, Dim] concept embeddings.
            labels: [B] class labels, or None to disable the loss.

        Returns:
            Scalar tensor. 0 when ``labels`` is None or the batch holds fewer
            than two samples.
        """
        if labels is None:
            return torch.tensor(0.0, device=concepts.device)

        batch_size = concepts.shape[0]
        if batch_size < 2:
            # Nothing to contrast against. Stay attached to the graph so the
            # combined loss can still be back-propagated.
            return concepts.sum() * 0.0

        # Similarity matrix [B, B]
        sim_matrix = torch.matmul(concepts, concepts.T) / self.temp

        # off_diagonal[i, j] is True for every pair except self-contrast (i == j)
        off_diagonal = ~torch.eye(batch_size, dtype=torch.bool, device=concepts.device)

        # Positives: same label, excluding the anchor itself
        labels = labels.contiguous().view(-1, 1)
        positives = (torch.eq(labels, labels.T).to(concepts.device) & off_diagonal).float()

        # Log-softmax over all other samples. logsumexp subtracts the row
        # maximum internally, so large similarity / temperature ratios no
        # longer overflow exp() (the previous exp(sim) * mask produced
        # inf * 0 = NaN on the diagonal for small temperatures).
        log_denominator = torch.logsumexp(
            sim_matrix.masked_fill(~off_diagonal, float("-inf")),
            dim=1,
            keepdim=True,
        )
        log_prob = sim_matrix - log_denominator

        # Mean log-likelihood over positive pairs (epsilon guards anchors
        # that have no positive in the batch)
        mean_log_prob_pos = (positives * log_prob).sum(dim=1) / (positives.sum(dim=1) + 1e-8)

        loss = - mean_log_prob_pos.mean()
        return loss

    def forward(self, input_ids, labels=None):
        """
        Run the wrapped model and compute all losses.

        Args:
            input_ids: [B, S] token ids; also used as the language-model targets.
            labels: optional [B] class labels for the concept loss.

        Returns:
            dict with keys ``loss``, ``l_token``, ``l_sentence``, ``l_concept``,
            ``logits`` and ``concept_emb``.
        """
        # 1. Generation Branch (Gradient Flows)
        logits, hidden_states = self.gpt(input_ids, return_hidden_states=True)
        last_hidden = hidden_states[-1] # [B, S, D]

        # Sentence Embedding
        sent_emb = self.pooler(last_hidden)

        # Concept Embedding
        concept_emb = self.concept_head(sent_emb)

        # 2. Target Branch (NO Gradient Flow - Stop Gradient)
        # Second pass over the same input with the same weights. Because the
        # weights and input are identical, any difference to the first pass
        # comes from stochastic layers (e.g. dropout in train mode).
        with torch.no_grad():
            _, target_hidden_states = self.gpt(input_ids, return_hidden_states=True)
            target_sent_emb = self.pooler(target_hidden_states[-1])

        # 3. Calculate Losses

        # A. Syntax (Token) Loss
        L_token = self.compute_token_loss(logits, input_ids)

        # B. Context (Sentence) Loss
        # target_sent_emb was produced under no_grad, so it is already detached.
        L_sentence = self.compute_sentence_loss(sent_emb, target_sent_emb)

        # C. Concept (Contrastive) Loss
        # Requires class labels (e.g., spam/ham) to cluster semantics
        L_concept = self.compute_concept_loss(concept_emb, labels)

        # Total Loss
        L_total = (self.alpha * L_token) + (self.beta * L_sentence) + (self.gamma * L_concept)

        return {
            "loss": L_total,
            "l_token": L_token,
            "l_sentence": L_sentence,
            "l_concept": L_concept,
            "logits": logits,
            "concept_emb": concept_emb
        }

In [64]:
def train_hierarchical_model(model, dataloader, optimizer, device, num_epochs=1):
    """Train a dict-returning hierarchical model for ``num_epochs`` epochs.

    Args:
        model: module called as ``model(input_ids, labels=labels)`` that returns
            a dict with at least ``loss``, ``l_token``, ``l_sentence`` and
            ``l_concept`` scalar tensors.
        dataloader: iterable of ``(input_ids, labels)`` batches.
        optimizer: optimizer over the parameters to update.
        device: device the model and every batch are moved to.
        num_epochs: number of passes over ``dataloader``.

    Returns:
        The same ``model`` object, trained in place.
    """
    model.to(device)
    model.train()

    for epoch in range(num_epochs):
        print(f"--- Epoch {epoch+1} ---")
        total_loss = 0.0
        num_batches = 0

        for batch_idx, (input_ids, labels) in enumerate(dataloader):
            input_ids = input_ids.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()

            # Forward pass through wrapper
            outputs = model(input_ids, labels=labels)

            # Backward
            loss = outputs["loss"]
            loss.backward()

            # Gradient Clipping (global L2 norm capped at 1.0)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

            optimizer.step()

            total_loss += loss.item()
            num_batches += 1

            if batch_idx % 10 == 0:
                print(f"Batch {batch_idx} | "
                      f"Total: {loss.item():.4f} | "
                      f"Tok: {outputs['l_token'].item():.4f} | "
                      f"Sent: {outputs['l_sentence'].item():.4f} | "
                      f"Con: {outputs['l_concept'].item():.4f}")

        # The running total used to be accumulated and then discarded;
        # report it as the epoch mean (skipped for an empty dataloader).
        if num_batches:
            print(f"Epoch {epoch+1} average loss: {total_loss / num_batches:.4f}")

    return model

In [65]:
# 1. Configuration (scaled-down demo; see the emb_dim note for the 355M setup)
HIER_CONFIG = {
    "vocab_size": 50257,
    "context_length": 1024,
    "emb_dim": 768,       # Scaled for demo, use 1024 for medium
    "n_heads": 12,
    "n_layers": 12,
    "drop_rate": 0.1,
    "qkv_bias": True
}

# 2. Instantiate Base Components
base_gpt = HierarchicalGPTModel(HIER_CONFIG)
pooler = AttentionPooling(emb_dim=HIER_CONFIG["emb_dim"])
projector = ConceptProjector(input_dim=HIER_CONFIG["emb_dim"], concept_dim=2048)

# 3. Wrap Logic
hierarchical_model = HierarchicalLLM(
    gpt_model=base_gpt,
    pooling_layer=pooler,
    concept_head=projector,
    alpha=1.0,  # Focus on syntax
    beta=0.5,   # Enforce sentence consistency
    gamma=0.1   # Shape clusters (gentle pressure)
)

# 4. Optimizer
optimizer = torch.optim.AdamW(hierarchical_model.parameters(), lr=5e-5, weight_decay=0.1)

# 5. Runtime Assertion Checks
# Ensure shapes match expectations before training loop.
# Token ids are sampled from the configured vocabulary so the check stays
# valid if HIER_CONFIG["vocab_size"] is changed (it used to hard-code 50257).
dummy_input = torch.randint(0, HIER_CONFIG["vocab_size"], (2, 128))
dummy_labels = torch.tensor([0, 1]) # Example labels
try:
    hierarchical_model.eval()
    with torch.no_grad():
        out = hierarchical_model(dummy_input, dummy_labels)

    print("Sanity Check Passed:")
    print(f"Logits Shape: {out['logits'].shape}") # [2, 128, vocab_size]
    print(f"Concept Embed Shape: {out['concept_emb'].shape}") # [2, 2048]
    print(f"Total Loss: {out['loss'].item()}")
except Exception as e:
    # Deliberately best-effort: report the failure and let the notebook continue.
    print(f"Sanity Check Failed: {e}")

# 6. Run Training (assuming train_loader and device defined in your notebook)
# train_hierarchical_model(hierarchical_model, train_loader, optimizer, device)

Sanity Check Passed:
Logits Shape: torch.Size([2, 128, 50257])
Concept Embed Shape: torch.Size([2, 2048])
Total Loss: 10.973592758178711


In [66]:
# ============================================================================
# HIERARCHICAL LLM TRAINING LOOP (Multi-Level Loss)
# ============================================================================

import torch
import torch.nn.functional as F
import numpy as np
from sklearn.metrics import precision_recall_fscore_support, confusion_matrix
# Select the device and keep the result. The bare expression used to be
# evaluated and discarded, leaving `device` at whatever an earlier cell set.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("="*70)
print("TRAINING HIERARCHICAL LLM FOR SPAM CLASSIFICATION")
print("="*70)


print(f"\nModel: HierarchicalLLM (Token + Sentence + Concept losses)")
print(f"Device: {device}")
print(f"Train batches: {len(train_loader)}")
print(f"Val batches: {len(val_loader)}")

# Create classifier on top of concept embeddings
concept_dim = 2048  # From ConceptProjector
num_classes = 2
hierarchical_classifier = torch.nn.Linear(concept_dim, num_classes).to(device)

print(f"Classifier: Linear({concept_dim}, {num_classes})")

# Move the hierarchical_model to the correct device
hierarchical_model.to(device)

# Optimizer for hierarchical model + classifier: one optimizer updates both,
# so the classification loss also shapes the wrapped model.
hierarchical_optimizer = torch.optim.AdamW(
    list(hierarchical_model.parameters()) + list(hierarchical_classifier.parameters()),
    lr=3e-5,
    weight_decay=0.01
)

print("\nOptimizer: AdamW (lr=3e-5)")
print("="*70)

# Training function for HierarchicalLLM
def train_epoch_hierarchical(
    model,
    classifier,
    loader,
    optimizer,
    device
):
    """Run one training pass over ``loader`` for the hierarchical model.

    Each step adds the cross-entropy of ``classifier`` on the model's concept
    embedding to the model's own combined loss and back-propagates the sum.

    Args:
        model: module returning a dict with ``loss``, ``l_token``,
            ``l_sentence``, ``l_concept`` and ``concept_emb``.
        classifier: module mapping the concept embedding to class logits.
        loader: iterable of ``(input_ids, labels)`` batches; must support ``len``.
        optimizer: optimizer over the parameters to update.
        device: device every batch is moved to.

    Returns:
        dict of per-batch means keyed ``total``, ``token``, ``sentence``,
        ``concept`` and ``cls``.
    """
    model.train()
    classifier.train()

    running = dict.fromkeys(("total", "token", "sentence", "concept", "cls"), 0.0)

    for batch_ids, batch_labels in loader:
        batch_ids = batch_ids.to(device)
        batch_labels = batch_labels.to(device)

        optimizer.zero_grad()

        # The wrapper returns its losses and embeddings in a dictionary.
        result = model(batch_ids, labels=batch_labels)

        # Classification head sits on top of the concept embedding.
        head_loss = F.cross_entropy(classifier(result["concept_emb"]), batch_labels)
        combined = result["loss"] + head_loss

        combined.backward()
        optimizer.step()

        per_batch = {
            "total": combined,
            "token": result["l_token"],
            "sentence": result["l_sentence"],
            "concept": result["l_concept"],
            "cls": head_loss,
        }
        for key, value in per_batch.items():
            running[key] += value.item()

    num_batches = len(loader)
    return {key: value / num_batches for key, value in running.items()}


# Evaluation function for HierarchicalLLM
@torch.no_grad()
def evaluate_hierarchical(model, classifier, loader, device):
    """Evaluate the hierarchical model plus classifier on ``loader``.

    Runs without gradient tracking and with both modules in eval mode.

    Args:
        model: module returning a dict with ``l_token``, ``l_sentence``,
            ``l_concept`` and ``concept_emb``.
        classifier: module mapping the concept embedding to class logits.
        loader: iterable of ``(input_ids, labels)`` batches; must support ``len``.
        device: device every batch is moved to.

    Returns:
        dict with the per-batch mean ``token_loss``, ``sentence_loss``,
        ``concept_loss`` and ``cls_loss``, plus ``accuracy``, binary
        ``precision`` / ``recall`` / ``f1`` and the ``confusion_matrix``.
    """
    model.eval()
    classifier.eval()

    predictions, references = [], []
    loss_sums = {name: 0.0 for name in ("token", "sentence", "concept", "cls")}

    for batch_ids, batch_labels in loader:
        batch_ids = batch_ids.to(device)
        batch_labels = batch_labels.to(device)

        result = model(batch_ids, labels=batch_labels)
        for name in ("token", "sentence", "concept"):
            loss_sums[name] += result[f"l_{name}"].item()

        # Classify from the concept embedding
        class_scores = classifier(result["concept_emb"])
        loss_sums["cls"] += F.cross_entropy(class_scores, batch_labels).item()

        predictions.extend(torch.argmax(class_scores, dim=-1).cpu().numpy())
        references.extend(batch_labels.cpu().numpy())

    # Aggregate metrics over the whole loader
    predictions = np.array(predictions)
    references = np.array(references)
    accuracy = (predictions == references).mean()

    precision, recall, f1, _ = precision_recall_fscore_support(
        references, predictions, average="binary", zero_division=0
    )
    matrix = confusion_matrix(references, predictions)

    num_batches = len(loader)
    report = {f"{name}_loss": total / num_batches for name, total in loss_sums.items()}
    report.update(
        accuracy=accuracy,
        precision=precision,
        recall=recall,
        f1=f1,
        confusion_matrix=matrix,
    )
    return report


# Training loop
# Number of full passes over train_loader.
epochs = 5
print(f"\nTraining for {epochs} epochs...")
print("="*70)

for epoch in range(epochs):
    # Training phase: one optimisation pass, returns per-batch mean losses
    train_metrics = train_epoch_hierarchical(
        hierarchical_model,
        hierarchical_classifier,
        train_loader,
        hierarchical_optimizer,
        device
    )

    # Validation phase: gradient-free evaluation on val_loader
    val_metrics = evaluate_hierarchical(
        hierarchical_model,
        hierarchical_classifier,
        val_loader,
        device
    )

    # Print epoch results.
    # Only the classification metrics of the validation pass are shown; the
    # validation losses in val_metrics are computed but not printed here.
    print(f"\nEpoch {epoch+1}/{epochs}")
    print(f"{'─'*70}")
    print(f"  Train Total Loss:     {train_metrics['total']:.4f}")
    print(f"    ├─ Token Loss:      {train_metrics['token']:.4f}")
    print(f"    ├─ Sentence Loss:   {train_metrics['sentence']:.4f}")
    print(f"    ├─ Concept Loss:    {train_metrics['concept']:.4f}")
    print(f"    └─ Classifier Loss: {train_metrics['cls']:.4f}")
    print(f"  Val Accuracy:         {val_metrics['accuracy']*100:.2f}%")
    print(f"  Val Precision:        {val_metrics['precision']:.4f}")
    print(f"  Val Recall:           {val_metrics['recall']:.4f}")
    print(f"  Val F1 Score:         {val_metrics['f1']:.4f}")

print(f"\n{'='*70}")
print("HIERARCHICAL TRAINING COMPLETE!")
print("="*70)

# Final evaluation on test set (run once, after all epochs have finished)
print("\nEvaluating on test set...")
test_metrics = evaluate_hierarchical(
    hierarchical_model,
    hierarchical_classifier,
    test_loader,
    device
)

print(f"\n{'='*70}")
print("FINAL TEST SET RESULTS (Hierarchical LLM)")
print("="*70)
print(f"  Token Loss:      {test_metrics['token_loss']:.4f}")
print(f"  Sentence Loss:   {test_metrics['sentence_loss']:.4f}")
print(f"  Concept Loss:    {test_metrics['concept_loss']:.4f}")
print(f"  Classifier Loss: {test_metrics['cls_loss']:.4f}")
print(f"  Test Accuracy:   {test_metrics['accuracy']*100:.2f}%")
print(f"  Test Precision:  {test_metrics['precision']:.4f}")
print(f"  Test Recall:     {test_metrics['recall']:.4f}")
print(f"  Test F1 Score:   {test_metrics['f1']:.4f}")
print(f"\nConfusion Matrix:")
print(test_metrics['confusion_matrix'])
print("="*70)



print(" Hierarchical LLM training complete!")


TRAINING HIERARCHICAL LLM FOR SPAM CLASSIFICATION

Model: HierarchicalLLM (Token + Sentence + Concept losses)
Device: cuda
Train batches: 130
Val batches: 19
Classifier: Linear(2048, 2)

Optimizer: AdamW (lr=3e-5)

Training for 5 epochs...

Epoch 1/5
──────────────────────────────────────────────────────────────────────
  Train Total Loss:     3.4428
    ├─ Token Loss:      2.6042
    ├─ Sentence Loss:   0.0011
    ├─ Concept Loss:    1.9049
    └─ Classifier Loss: 0.6476
  Val Accuracy:         73.83%
  Val Precision:        0.9545
  Val Recall:           0.5316
  Val F1 Score:         0.6829

Epoch 2/5
──────────────────────────────────────────────────────────────────────
  Train Total Loss:     2.6711
    ├─ Token Loss:      1.9336
    ├─ Sentence Loss:   0.0026
    ├─ Concept Loss:    1.8830
    └─ Classifier Loss: 0.5479
  Val Accuracy:         87.92%
  Val Precision:        1.0000
  Val Recall:           0.7722
  Val F1 Score:         0.8714

Epoch 3/5
───────────────────────────

In [67]:
class SpamClassifier(nn.Module):
    """Linear classification head on top of a concept embedding.

    Args:
        concept_dim: size of the incoming embedding (the head's input width).
        num_classes: number of output classes; defaults to 2 (spam / not spam).
    """
    def __init__(self, concept_dim, num_classes=2):
        super().__init__()
        # Input width is the embedding size, output width the number of classes.
        self.linear = nn.Linear(concept_dim, num_classes)

    def forward(self, concept_emb):
        """Return unnormalized class logits of shape [..., num_classes]."""
        return self.linear(concept_emb)


In [68]:
# The head's input width must be the concept-embedding size. Passing
# num_classes here built a Linear(2, 2) that cannot consume the embeddings.
classifier = SpamClassifier(concept_dim=concept_dim).to(device)

In [69]:
# A single AdamW instance updates the backbone and the classification head
# together. Both `model` and `classifier` must already exist in the kernel.
optimizer = torch.optim.AdamW(
    [*model.parameters(), *classifier.parameters()],
    lr=3e-5,
)


In [70]:
def train_epoch(model, classifier, loader, optimizer, device):
    """
    One training pass for a model whose forward returns plain logits.

    ``model(input_ids)`` is expected to return a ``[batch, seq_len, vocab_size]``
    tensor (not a dictionary). The slice at the last sequence position is fed
    to ``classifier`` and trained with cross-entropy against ``labels``.

    Returns:
        dict with the per-batch mean loss under both ``total`` and ``cls``.
    """
    model.train()
    classifier.train()

    loss_sum = 0.0

    for token_ids, targets in loader:
        token_ids = token_ids.to(device)
        targets = targets.to(device)

        optimizer.zero_grad()

        # Keep only the final position: [batch, seq_len, vocab] -> [batch, vocab]
        final_position = model(token_ids)[:, -1, :]
        batch_loss = F.cross_entropy(classifier(final_position), targets)

        batch_loss.backward()
        optimizer.step()

        loss_sum += batch_loss.item()

    num_batches = len(loader)
    mean_loss = loss_sum / num_batches
    # The classification loss is the only loss here, so both keys agree.
    return {"total": mean_loss, "cls": mean_loss}


In [71]:
@torch.no_grad()
def evaluate(model, classifier, loader, device):
    """
    Gradient-free evaluation for a model whose forward returns plain logits.

    The logits at the last sequence position are passed through ``classifier``;
    predictions are the arg-max class.

    Returns:
        dict with mean ``cls_loss``, ``accuracy``, binary ``precision`` /
        ``recall`` / ``f1`` and the ``confusion_matrix``.
    """
    model.eval()
    classifier.eval()

    predictions, references, batch_losses = [], [], []

    for token_ids, targets in loader:
        token_ids = token_ids.to(device)
        targets = targets.to(device)

        # [batch, seq_len, vocab_size] -> [batch, vocab_size]
        final_position = model(token_ids)[:, -1, :]
        class_scores = classifier(final_position)
        batch_loss = F.cross_entropy(class_scores, targets)

        predictions.extend(torch.argmax(class_scores, dim=-1).cpu().numpy())
        references.extend(targets.cpu().numpy())
        batch_losses.append(batch_loss.item())

    # Aggregate over the whole loader
    predictions = np.array(predictions)
    references = np.array(references)

    accuracy = (predictions == references).mean()
    precision, recall, f1, _ = precision_recall_fscore_support(
        references, predictions, average="binary", zero_division=0
    )
    matrix = confusion_matrix(references, predictions)

    return dict(
        cls_loss=np.mean(batch_losses),
        accuracy=accuracy,
        precision=precision,
        recall=recall,
        f1=f1,
        confusion_matrix=matrix,
    )


In [72]:
def train_epoch(
    model,
    classifier,
    loader,
    optimizer,
    device,
    alpha=1.0,
    beta=0.5,
    gamma=0.5
):
    """One training pass for a model whose forward returns a loss dictionary.

    The model's combined loss (``outputs["loss"]``) plus the cross-entropy of
    ``classifier`` on ``outputs["concept_emb"]`` is back-propagated each step.

    Args:
        model: module called as ``model(input_ids, labels=labels)`` returning a
            dict with ``loss``, ``l_token``, ``l_sentence``, ``l_concept`` and
            ``concept_emb``.
        classifier: module mapping the concept embedding to class logits.
        loader: iterable of ``(input_ids, labels)`` batches; must support ``len``.
        optimizer: optimizer over the parameters to update.
        device: device every batch is moved to.
        alpha, beta, gamma: accepted for signature compatibility but not used
            here; the loss weighting is whatever ``model`` applies internally.

    Returns:
        dict of per-batch means keyed ``total``, ``token``, ``sentence``,
        ``concept`` and ``cls``.
    """
    model.train()
    classifier.train()

    totals = {
        "total": 0.0,
        "token": 0.0,
        "sentence": 0.0,
        "concept": 0.0,
        "cls": 0.0
    }

    for input_ids, labels in loader:
        input_ids = input_ids.to(device)
        labels = labels.to(device)

        # No separate target tensor is built: the model derives its language
        # modelling targets from input_ids itself (the former per-batch
        # input_ids.clone() was never read).
        optimizer.zero_grad()

        # Forward pass returning a DICTIONARY
        outputs = model(input_ids, labels=labels)

        # Extract components from the dictionary
        loss = outputs["loss"]
        lt = outputs["l_token"]
        ls = outputs["l_sentence"]
        lc = outputs["l_concept"]
        concept_g = outputs["concept_emb"]

        # Classification loss (Spam vs Not Spam) on top of the concept embeddings
        cls_logits = classifier(concept_g)
        cls_loss = F.cross_entropy(cls_logits, labels)

        # Combine losses
        full_loss = loss + cls_loss

        full_loss.backward()
        optimizer.step()

        totals["total"] += full_loss.item()
        totals["token"] += lt.item()
        totals["sentence"] += ls.item()
        totals["concept"] += lc.item()
        totals["cls"] += cls_loss.item()

    n = len(loader)
    return {k: v / n for k, v in totals.items()}

In [73]:
import numpy as np
from sklearn.metrics import precision_recall_fscore_support, confusion_matrix

@torch.no_grad()
def evaluate(
    model,
    classifier,
    loader,
    device
):
    """Evaluate a dict-returning model plus classifier and analyse the concept space.

    Besides the classification metrics, the concept embeddings of the whole
    loader are collected and their pairwise cosine similarities are averaged
    separately for same-label pairs (excluding self-pairs) and different-label
    pairs.

    Args:
        model: module called as ``model(input_ids, labels=labels)`` returning a
            dict with ``l_token``, ``l_sentence`` and ``concept_emb``.
        classifier: module mapping the concept embedding to class logits.
        loader: iterable of ``(input_ids, labels)`` batches.
        device: device every batch is moved to.

    Returns:
        dict with mean ``token_loss`` and ``sentence_loss``, ``accuracy``,
        binary ``precision`` / ``recall`` / ``f1``, ``confusion_matrix``,
        ``concept_intra_similarity`` and ``concept_inter_similarity``.
    """
    model.eval()
    classifier.eval()

    all_preds = []
    all_labels = []
    token_losses = []
    sentence_losses = []

    concept_embs = []

    for input_ids, labels in loader:
        input_ids = input_ids.to(device)
        labels = labels.to(device)

        # Forward pass (returns dict)
        outputs = model(input_ids, labels=labels)

        # Extract losses directly from model output
        token_losses.append(outputs["l_token"].item())
        sentence_losses.append(outputs["l_sentence"].item())

        concept_g = outputs["concept_emb"]

        # Classification
        cls_logits = classifier(concept_g)
        preds = torch.argmax(cls_logits, dim=-1)

        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())
        concept_embs.append(concept_g.cpu())

    # Metrics
    all_preds = np.array(all_preds)
    all_labels = np.array(all_labels)

    acc = (all_preds == all_labels).mean()

    # zero_division=0 matches the other evaluation helpers in this notebook:
    # when no positive is predicted the metric is reported as 0 without an
    # undefined-metric warning.
    precision, recall, f1, _ = precision_recall_fscore_support(
        all_labels, all_preds, average="binary", zero_division=0
    )

    cm = confusion_matrix(all_labels, all_preds)

    # Concept separability analysis
    concept_embs = torch.cat(concept_embs, dim=0)

    # Normalize for cosine similarity
    concept_embs = F.normalize(concept_embs, dim=-1)

    sim_matrix = torch.matmul(concept_embs, concept_embs.T)

    # Create mask for same-label pairs
    # Note: this builds an N x N matrix, so memory and time grow quadratically
    # with the number of evaluated samples.
    labels_tensor = torch.tensor(all_labels)
    same = labels_tensor[:, None] == labels_tensor[None, :]

    # Intra-class (same label) vs inter-class (different label) similarity.
    # The diagonal (self-similarity) is excluded from the intra-class mean.
    n = same.shape[0]
    diag_mask = ~torch.eye(n, dtype=torch.bool)

    intra_mask = same & diag_mask
    inter_mask = ~same

    intra = sim_matrix[intra_mask].mean().item() if intra_mask.any() else 0.0
    inter = sim_matrix[inter_mask].mean().item() if inter_mask.any() else 0.0

    return {
        "token_loss": np.mean(token_losses),
        "sentence_loss": np.mean(sentence_losses), # Lower is better (1 - sim)
        "accuracy": acc,
        "precision": precision,
        "recall": recall,
        "f1": f1,
        "confusion_matrix": cm,
        "concept_intra_similarity": intra,
        "concept_inter_similarity": inter
    }

In [80]:
# ============================================================================
# SPAM CLASSIFICATION TRAINING LOOP
# ============================================================================

import torch
import torch.nn.functional as F
import numpy as np
from sklearn.metrics import precision_recall_fscore_support, confusion_matrix

print("=" * 70)
print("STARTING SPAM CLASSIFICATION TRAINING")
print("=" * 70)

# Fine-tune the GPT model that an earlier cell bound to `gpt`
# (per the original note, the loaded GPT2-medium 355M weights).
model = gpt
model.to(device)

# Freeze the entire network first ...
model.requires_grad_(False)

# ... then swap the vocabulary head for a fresh classification head that maps
# the final-norm embeddings (emb_dim) to num_classes scores, already placed on
# the target device.
model.out_head = torch.nn.Linear(
    in_features=NEW_CONFIG["emb_dim"],
    out_features=num_classes,
).to(device)

# Only the new head is trainable.
model.out_head.requires_grad_(True)

print(f"\nModel: {model.__class__.__name__}")
print(f"Device: {device}")
for split_name, split_loader in (
    ("Train", train_loader),
    ("Val", val_loader),
    ("Test", test_loader),
):
    print(f"{split_name} batches: {len(split_loader)}")

# Redefine train_epoch function to work with the modified model (no external classifier needed)
def train_epoch(model, loader, optimizer, device):
    """
    One training pass for a model whose ``out_head`` already is the classifier.

    ``model(input_ids)`` is expected to return ``[batch, seq_len, num_classes]``;
    the scores at the last sequence position are trained with cross-entropy.

    Returns:
        dict with the per-batch mean loss under ``total``.
    """
    model.train()

    loss_sum = 0.0

    for token_ids, targets in loader:
        token_ids, targets = token_ids.to(device), targets.to(device)

        optimizer.zero_grad()

        # Last position only: [batch, seq_len, num_classes] -> [batch, num_classes]
        class_scores = model(token_ids)[:, -1, :]
        batch_loss = F.cross_entropy(class_scores, targets)

        batch_loss.backward()
        optimizer.step()

        loss_sum += batch_loss.item()

    return {"total": loss_sum / len(loader)}

# Redefine evaluate function to work with the modified model (no external classifier needed)
@torch.no_grad()
def evaluate(model, loader, device):
    """
    Gradient-free evaluation for a model whose ``out_head`` is the classifier.

    Predictions are the arg-max of the scores at the last sequence position.

    Returns:
        dict with mean ``cls_loss``, ``accuracy``, binary ``precision`` /
        ``recall`` / ``f1`` and the ``confusion_matrix``.
    """
    model.eval()

    predictions, references, batch_losses = [], [], []

    for token_ids, targets in loader:
        token_ids, targets = token_ids.to(device), targets.to(device)

        # [batch, seq_len, num_classes] -> [batch, num_classes]
        class_scores = model(token_ids)[:, -1, :]
        batch_loss = F.cross_entropy(class_scores, targets)

        predictions.extend(torch.argmax(class_scores, dim=-1).cpu().numpy())
        references.extend(targets.cpu().numpy())
        batch_losses.append(batch_loss.item())

    # Aggregate over the whole loader
    predictions = np.array(predictions)
    references = np.array(references)

    accuracy = (predictions == references).mean()
    precision, recall, f1, _ = precision_recall_fscore_support(
        references, predictions, average="binary", zero_division=0
    )
    matrix = confusion_matrix(references, predictions)

    return dict(
        cls_loss=np.mean(batch_losses),
        accuracy=accuracy,
        precision=precision,
        recall=recall,
        f1=f1,
        confusion_matrix=matrix,
    )


# Training configuration: number of full passes over train_loader
epochs = 5
print(f"\nTraining for {epochs} epochs...")
print("="*70)

# Initialize optimizer for the trainable parameters only. The filter keeps
# parameters with requires_grad=True, i.e. whatever was left unfrozen above
# (the new out_head).
optimizer = torch.optim.AdamW(
    filter(lambda p: p.requires_grad, model.parameters()),
    lr=3e-5,
    weight_decay=0.01
)

# Training loop
for epoch in range(epochs):
    # Training phase: one optimisation pass, returns the mean batch loss
    train_metrics = train_epoch(
        model,
        train_loader,
        optimizer,
        device
    )

    # Validation phase: gradient-free evaluation on val_loader
    val_metrics = evaluate(
        model,
        val_loader,
        device
    )

    # Print epoch results
    print(f"\nEpoch {epoch+1}/{epochs}")
    print(f"{'─'*70}")
    print(f"  Train Loss:      {train_metrics['total']:.4f}")
    print(f"  Val Loss:        {val_metrics['cls_loss']:.4f}")
    print(f"  Val Accuracy:    {val_metrics['accuracy']*100:.2f}%")
    print(f"  Val Precision:   {val_metrics['precision']:.4f}")
    print(f"  Val Recall:      {val_metrics['recall']:.4f}")
    print(f"  Val F1 Score:    {val_metrics['f1']:.4f}")

print(f"\n{'='*70}")
print("TRAINING COMPLETE!")
print("="*70)

# Final evaluation on test set (run once, after all epochs have finished)
print("\nEvaluating on test set...")
test_metrics = evaluate(
    model,
    test_loader,
    device
)

print(f"\n{'='*70}")
print("FINAL TEST SET RESULTS")
print("="*70)
print(f"  Test Loss:       {test_metrics['cls_loss']:.4f}")
print(f"  Test Accuracy:   {test_metrics['accuracy']*100:.2f}%")
print(f"  Test Precision:  {test_metrics['precision']:.4f}")
print(f"  Test Recall:     {test_metrics['recall']:.4f}")
print(f"  Test F1 Score:   {test_metrics['f1']:.4f}")
print(f"\nConfusion Matrix:")
print(test_metrics['confusion_matrix'])
print("="*70)

# Save a checkpoint (optional). model.state_dict() covers the whole model,
# including its out_head, which acts as the classifier here; 'epoch' stores
# the configured epoch count.
print("\n💾 Saving trained model and classifier...")
torch.save({
    'epoch': epochs,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'test_accuracy': test_metrics['accuracy'],
    'test_f1': test_metrics['f1']
}, 'spam_classifier_checkpoint.pth')
print("✓ Checkpoint saved to 'spam_classifier_checkpoint.pth'")

print("\n🎉 All done!")


STARTING SPAM CLASSIFICATION TRAINING

Model: GPTModel
Device: cuda
Train batches: 130
Val batches: 19
Test batches: 38

Training for 5 epochs...

Epoch 1/5
──────────────────────────────────────────────────────────────────────
  Train Loss:      0.7117
  Val Loss:        0.6847
  Val Accuracy:    52.35%
  Val Precision:   0.5278
  Val Recall:      0.9620
  Val F1 Score:    0.6816

Epoch 2/5
──────────────────────────────────────────────────────────────────────
  Train Loss:      0.6709
  Val Loss:        0.6652
  Val Accuracy:    76.51%
  Val Precision:   0.7973
  Val Recall:      0.7468
  Val F1 Score:    0.7712

Epoch 3/5
──────────────────────────────────────────────────────────────────────
  Train Loss:      0.6418
  Val Loss:        0.6524
  Val Accuracy:    57.72%
  Val Precision:   0.7857
  Val Recall:      0.2785
  Val F1 Score:    0.4112

Epoch 4/5
──────────────────────────────────────────────────────────────────────
  Train Loss:      0.6205
  Val Loss:        0.6433
  Val 

In [83]:
# ============================================================================
# TEST INDIVIDUAL MESSAGES (Inference)
# ============================================================================

def classify_message(text, model, tokenizer, device, max_length=120, pad_token_id=50256):
    """Label one text message as "SPAM" or "NOT SPAM".

    The model is switched to eval mode, the text is encoded with
    ``tokenizer.encode`` and brought to exactly ``max_length`` token ids
    (cut when longer, right-padded with ``pad_token_id`` otherwise), and the
    model output at the last sequence position is read as the class scores.
    As in the original note: the model's out_head is expected to be the
    classification layer.

    Args:
        text: Message to classify.
        model: Callable model; ``model(batch)`` is indexed as ``[:, -1, :]``.
        tokenizer: Object exposing ``encode(text)``.
        device: Device on which the input tensor is created.
        max_length: Number of token ids fed to the model.
        pad_token_id: Token id used for right-padding.

    Returns:
        ``(label, confidence)`` where ``label`` is "SPAM" when the predicted
        class index is 1 and "NOT SPAM" otherwise, and ``confidence`` is the
        softmax probability of the predicted class, times 100.
    """
    model.eval()

    token_ids = tokenizer.encode(text)

    # A negative shortfall means the message is longer than max_length: cut it.
    # Otherwise append exactly as many pad ids as are missing (possibly none).
    shortfall = max_length - len(token_ids)
    if shortfall < 0:
        token_ids = token_ids[:max_length]
    else:
        token_ids = token_ids + [pad_token_id] * shortfall

    # Add a leading batch dimension of size 1.
    batch = torch.tensor(token_ids, device=device).unsqueeze(0)

    with torch.no_grad():
        # Scores at the last token position are used as the classification scores.
        last_token_logits = model(batch)[:, -1, :]
        class_probs = torch.softmax(last_token_logits, dim=-1)[0]
        predicted_class = torch.argmax(last_token_logits, dim=-1).item()

    confidence = class_probs[predicted_class].item() * 100

    if predicted_class == 1:
        return "SPAM", confidence
    return "NOT SPAM", confidence

# Demo: run the classifier over a few hand-written example messages
print("="*70, "TESTING SPAM CLASSIFIER ON INDIVIDUAL MESSAGES", "="*70, sep="\n")

test_messages = [
    "Hi, let's catch up for coffee tomorrow!",
    "URGENT: Your account will be closed! Click here now to verify: http://fake-link.com",
    "Meeting scheduled for 3pm on Tuesday in conference room B",
    "Congratulations! You've won $1,000,000! Claim your prize now!!!",
    "Can you send me the report when you get a chance?",
    "FREE MONEY! No purchase necessary! Limited time offer! Act now!"
]

for i, message in enumerate(test_messages, start=1):
    # max_length and pad_token_id are left at classify_message's defaults
    label, confidence = classify_message(message, model, tokenizer, device)

    # Show at most the first 60 characters, adding "..." when the message is longer
    message_preview = message[:60] + ("..." if len(message) > 60 else "")
    print(f"\n{i}. Message: \"{message_preview}\"")
    print(f"   Prediction: {label} ({confidence:.1f}% confidence)")

# Blank line followed by the closing rule
print()
print("="*70)


TESTING SPAM CLASSIFIER ON INDIVIDUAL MESSAGES

1. Message: "Hi, let's catch up for coffee tomorrow!"
   Prediction: NOT SPAM (68.5% confidence)

2. Message: "URGENT: Your account will be closed! Click here now to verif..."
   Prediction: NOT SPAM (62.9% confidence)

3. Message: "Meeting scheduled for 3pm on Tuesday in conference room B"
   Prediction: NOT SPAM (67.5% confidence)

4. Message: "Congratulations! You've won $1,000,000! Claim your prize now..."
   Prediction: NOT SPAM (66.0% confidence)

5. Message: "Can you send me the report when you get a chance?"
   Prediction: NOT SPAM (66.7% confidence)

6. Message: "FREE MONEY! No purchase necessary! Limited time offer! Act n..."
   Prediction: NOT SPAM (67.1% confidence)

