# Fine tuning for spam classification

This follows chapter 6 of the book [Build a Large Language Model (From Scratch)](https://www.manning.com/books/build-a-large-language-model-from-scratch).

In [48]:
import import_ipynb
import openai # type:ignore
import gpt # type:ignore
import pandas as pd
import urllib.request
import ssl
import zipfile
import os
from pathlib import Path
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import tiktoken
import time

def get_device() -> torch.device:
    if torch.cuda.is_available(): # type: ignore[attr-defined]
        return torch.device("cuda")
    elif torch.backends.mps.is_available(): # type: ignore[attr-defined]
        return torch.device("mps:0")
    else:
        return torch.device("cpu")

## Download and preprocess the UCI spam data

The fine folks at the University of California at Irvine have provided a nice little data set for SMS spam.
Let's download that and save it in a convenient CSV format.

In [49]:
url = "https://archive.ics.uci.edu/static/public/228/sms+spam+collection.zip"
zip_path = "sms_spam_collection.zip"
extracted_path = "sms_spam_collection"
data_file_path = Path(extracted_path) / "SMSSpamCollection.tsv"

# NOTE: as of 6/21/25, the UCI archive server is unreachable. I downloaded this file
# manually from a mirror.
def download_and_unzip_spam_data(url, zip_path, extracted_path, data_file_path):
    if data_file_path.exists():
        print(f"{data_file_path} already exists. Skipping download and extraction.")
        return
     
    ssl_context = ssl._create_unverified_context()

    with urllib.request.urlopen(url, context=ssl_context) as response:
        with open(zip_path, "wb") as out_file:
            out_file.write(response.read())
    
    with zipfile.ZipFile(zip_path, "r") as zip_ref:
        zip_ref.extractall(extracted_path)

    original_file_path = Path(extracted_path) / "SMSSpamCollection"
    os.rename(original_file_path, data_file_path)    
    print(f"File downloaded and saved as {data_file_path}")


The data set contains 4825 ham messages and only 747 spam messages. Since we want an equal number of both, we'll have to take 747 ham messages at random and discard the rest.

In [50]:
def create_balanced_dataset(df):
    df = pd.read_csv(data_file_path, sep="\t", header=None, names=["Label", "Text"])
    num_spam = df[df["Label"] == "spam"].shape[0]
    ham_subset = df[df["Label"] == "ham"].sample(num_spam, random_state=123)
    balanced_df = pd.concat([ham_subset, df[df["Label"] == "spam"]])
    return balanced_df

Now we want to create the following splits:
- 70% for training
- 10% for validation
- 20% for testing

In [51]:
def random_split(df, train_frac, validation_frac):
    df = df.sample(frac=1, random_state=123).reset_index(drop=True)

    train_end = int(len(df) * train_frac)
    validation_end = train_end + int(len(df) * validation_frac)

    train_df = df[:train_end]
    validation_df = df[train_end:validation_end]
    test_df = df[validation_end:]

    return train_df, validation_df, test_df

In [52]:
def save_csv():
    download_and_unzip_spam_data(url, zip_path, extracted_path, data_file_path)
    df = pd.read_csv(data_file_path, sep="\t", header=None, names=["Label", "Text"])
    balanced_df = create_balanced_dataset(df)
    balanced_df["Label"] = balanced_df["Label"].map({"ham": 0, "spam": 1})
    train_df, validation_df, test_df = random_split(balanced_df, 0.7, 0.1)
    train_df.to_csv("train.csv", index=None) # type:ignore
    validation_df.to_csv("validation.csv", index=None) # type:ignore
    test_df.to_csv("test.csv", index=None) # type:ignore

# Uncomment if you haven't saved this yet
# save_csv()

## SpamDataset

This class:
1. Pre-tokenizes the texts from the dataset
2. Truncates any sequences that are longer than the maximum length (or the longest text, if no maximum is set).
3. Pads any sequences shorter than the max length.

In [53]:
class SpamDataset(Dataset):
    def __init__(self, csv_file: Path, tokenizer: tiktoken.Encoding, max_length:int|None=None, pad_token_id:int=50256):
        self.data = pd.read_csv(csv_file)

        # Pre-tokenize texts
        self.encoded_texts = [
            tokenizer.encode(text) for text in self.data["Text"]
        ]

        if max_length is None:
            self.max_length = self._longest_encoded_length()
        else:
            self.max_length = max_length
        
            # truncate sequences that are longer than max_length
            self.encoded_texts = [
                encoded_text[:self.max_length]
                for encoded_text in self.encoded_texts
            ]

        # pad the sequences
        self.encoded_texts = [
            encoded_text + [pad_token_id] * (self.max_length - len(encoded_text))
            for encoded_text in self.encoded_texts
        ]

    def __getitem__(self, idx):
        encoded = self.encoded_texts[idx]
        label = self.data.iloc[idx]["Label"]
        return (
            torch.tensor(encoded, dtype=torch.long),
            torch.tensor(label, dtype=torch.long)
        )

    def __len__(self):
        return len(self.data)

    def _longest_encoded_length(self):
        return max([len(txt) for txt in self.encoded_texts])

In [54]:
tokenizer = tiktoken.get_encoding("gpt2")
train_dataset = SpamDataset(
    csv_file=Path("train.csv"),
    tokenizer=tokenizer,
)
max_length = train_dataset.max_length
val_dataset = SpamDataset(
    csv_file=Path("validation.csv"),
    max_length=max_length,
    tokenizer=tokenizer,
)
test_dataset = SpamDataset(
    csv_file=Path("test.csv"),
    max_length=max_length,
    tokenizer=tokenizer,
)

In [55]:
num_workers = 0
batch_size = 8

gpt.manual_seed(123)

train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
    drop_last=True,
)

val_loader = DataLoader(
    dataset=val_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
    drop_last=True,
)

test_loader = DataLoader(
    dataset=test_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
    drop_last=True,
)

# Note: I think drop_last is true because the last batch won't have enough messages in it,
# not because the last message might be short.

## Creating the ClassifierGPT class

The ClassifierGPT class wraps and modifies a normal SimplifiedGPT model.
Maybe it would be better to use a function that explicitly modifies its argument?

### What's the difference between this and SimplifiedGPT?

Truthfully, not much: 
- The final output layer has dimensions $\text{context\_length}\times\text{classifications}$, rather than $\text{context\_length}\times\text{vocabulary}$.
- We discard the gradients for the inner layers, since those are adequately trained already.


In [56]:
class ClassifierGPT(nn.Module):
    """Wraps a SimplifiedGPT model and bases a classification model on it.
    Note that the model argument WILL BE MODIFIED."""
    def __init__(self, model: gpt.SimplifiedGPT, classifications:int):
        super().__init__()
        self.model = model
        cfg = model.cfg
        for param in self.model.parameters():
            param.requires_grad = False
        self.model.output = nn.Linear(cfg["emb_dim"], classifications)
        for param in self.model.transformer_blocks[-1].parameters():
            param.requires_grad = True
        for param in self.model.layer_norm.parameters():
            param.requires_grad = True

    def forward(self, in_idx: torch.Tensor) -> torch.Tensor:
        return self.model(in_idx)

In [57]:
gpt124m = openai.load_openai_model(openai.GPT_CONFIG_124M, "124M").model

File already exists and is up-to-date: gpt2/124M/checkpoint
File already exists and is up-to-date: gpt2/124M/encoder.json
File already exists and is up-to-date: gpt2/124M/hparams.json
File already exists and is up-to-date: gpt2/124M/model.ckpt.data-00000-of-00001
File already exists and is up-to-date: gpt2/124M/model.ckpt.index
File already exists and is up-to-date: gpt2/124M/model.ckpt.meta
File already exists and is up-to-date: gpt2/124M/vocab.bpe
124M model loaded.


In [58]:
clas = ClassifierGPT(gpt124m, 2).to(get_device())

In [102]:
inputs = tokenizer.encode("URGENT!! Your 4* Costa Del Sol Holiday or £5000 await collection. Call 09050090044 Now toClaim. SAE, TC s, POBox334, Stockport, SK38xh, Cost£1.50/pm, Max10mins")
inputs = torch.tensor(inputs).unsqueeze(0).to(get_device())
print(f"Inputs {inputs}")
print(f"Input dimensions: {inputs.shape}")

output = clas(inputs)
print(f"Outputs:\n{output}")
print(f"Output dimensions: {output.shape}")

Inputs tensor([[ 4261,    38,  3525,  3228,  3406,   604,     9, 18133,  4216,  4294,
         22770,   393,  4248, 27641, 25507,  4947,    13,  4889,  7769,  2713,
           405, 12865,  2598,  2735,   284, 44819,    13,   311, 14242,    11,
         17283,   264,    11,   350,  9864,  1140, 31380,    11, 10500,   634,
            11, 14277,  2548,    87,    71,    11,  6446, 14988,    16,    13,
          1120,    14,  4426,    11,  5436,   940, 42951]], device='cuda:0')
Input dimensions: torch.Size([1, 57])
Outputs:
tensor([[[-0.1383, -0.8963],
         [-5.2545, -1.9323],
         [-6.2596, -3.4245],
         [-5.2016, -3.4344],
         [-6.5946, -3.3080],
         [-5.2458, -2.9205],
         [-6.4475, -2.9775],
         [-6.7713, -3.1688],
         [-3.3267, -2.2104],
         [-7.6786, -4.0148],
         [-7.3234, -3.7454],
         [-6.9835, -3.6252],
         [-6.4855, -5.9448],
         [-7.2155, -3.5419],
         [-6.6906, -3.4634],
         [-7.4141, -4.8514],
         [

In [70]:
torch.softmax(output[:, -1, :], dim=-1)

tensor([[0.1321, 0.8679]], device='cuda:0', grad_fn=<SoftmaxBackward0>)

In [61]:
def calc_accuracy_loader(dataloader: DataLoader, model: ClassifierGPT, device: torch.device, num_batches:int|None=None) -> float:
    model.eval()
    correct_predictions, num_examples = 0, 0

    if num_batches is None:
        num_batches = len(dataloader)
    else:
        num_batches = min(num_batches, len(dataloader))
    for i, (input_batch, target_batch) in enumerate(dataloader):
        if i < num_batches:
            input_batch, target_batch = input_batch.to(device), target_batch.to(device)

            with torch.no_grad():
                logits = model(input_batch)[:, -1, :]
            predicted_labels = torch.argmax(logits, dim=-1)

            num_examples += predicted_labels.shape[0]
            correct_predictions += (predicted_labels == target_batch).sum().item()
        else:
            break
    return correct_predictions / num_examples

In [62]:
def calc_loss_batch(input_batch: torch.Tensor, target_batch: torch.Tensor, model: ClassifierGPT, device: torch.device) -> torch.Tensor:
    input_batch, target_batch = input_batch.to(device), target_batch.to(device)
    logits = model(input_batch)[:, -1, :]
    loss = nn.functional.cross_entropy(logits, target_batch)
    return loss

In [63]:
def calc_loss_loader(dataloader: DataLoader, model: ClassifierGPT, device: torch.device, num_batches:int|None=None) -> float:
    total_loss = 0.
    if len(dataloader) == 0:
        return float("nan")
    elif num_batches is None:
        num_batches = len(dataloader)
    else:
        num_batches = min(num_batches, len(dataloader))
    for i, (input_batch, target_batch) in enumerate(dataloader):
        if i < num_batches:
            loss = calc_loss_batch(input_batch, target_batch, model, device)
            total_loss += loss.item()
        else:
            break
    return total_loss / num_batches

In [71]:
with torch.no_grad(): # disable gradient tracking for efficiency right now
    train_loss = calc_loss_loader(train_loader, clas, get_device(), num_batches=5)
    val_loss = calc_loss_loader(val_loader, clas, get_device(), num_batches=5)
    test_loss = calc_loss_loader(test_loader, clas, get_device(), num_batches=5)

print(f"Training loss: {train_loss:.3f}")
print(f"Validation loss: {val_loss:.3f}")
print(f"Test loss: {test_loss:.3f}")

Training loss: 0.088
Validation loss: 0.073
Test loss: 0.326


In [65]:
def evaluate_model(model: ClassifierGPT, train_loader: DataLoader, val_loader: DataLoader, device: torch.device, eval_iter:int):
    model.eval()
    with torch.no_grad():
        train_loss = calc_loss_loader(train_loader, model, device, num_batches=eval_iter)
        val_loss = calc_loss_loader(val_loader, model, device, num_batches=eval_iter)
    model.train()
    return train_loss, val_loss

In [66]:
def train_classifier_simple(model: ClassifierGPT, train_loader: DataLoader, val_loader: DataLoader, optimizer, device: torch.device, num_epochs:int, eval_freq:int, eval_iter:int):
    train_losses, val_losses, train_accs, val_accs = [], [], [], []
    examples_seen, global_step = 0, -1

    for epoch in range(num_epochs):
        model.train()

        for input_batch, target_batch in train_loader:
            optimizer.zero_grad()
            loss = calc_loss_batch(input_batch, target_batch, model, device)
            loss.backward()
            optimizer.step()
            examples_seen += input_batch.shape[0]
            global_step += 1

            if global_step % eval_freq == 0:
                train_loss, val_loss = evaluate_model(model, train_loader, val_loader, device, eval_iter)
                train_losses.append(train_loss)
                val_losses.append(val_loss)
                print(f"Ep {epoch+1} (Step {global_step:06d}): "
                      f"Train loss {train_loss:.3f}, Val loss {val_loss:.3f}")

        train_accuracy = calc_accuracy_loader(train_loader, model, device, num_batches=eval_iter)
        val_accuracy = calc_accuracy_loader(val_loader, model, device, num_batches=eval_iter)
        print(f"Training accuracy: {train_accuracy*100:.2f}% | ", end ="")
        print(f"Validation accuracy: {val_accuracy*100:.2f}%")
        train_accs.append(train_accuracy)
        val_accs.append(val_accuracy)

    return train_losses, val_losses, train_accs, val_accs, examples_seen

In [68]:
def training_run():
    start_time = time.time()
    gpt.manual_seed(123)
    optimizer = torch.optim.AdamW(clas.parameters(), lr=5e-5, weight_decay=0.1)
    num_epochs = 5
    train_losses, val_losses, train_accs, val_accs, examples_seen = train_classifier_simple(
        clas, train_loader, val_loader, optimizer, get_device(), num_epochs=num_epochs, eval_freq=50, eval_iter=5,
    )
    end_time = time.time()
    execution_time_minutes = (end_time - start_time) / 60
    print(f"Training completed in {execution_time_minutes:.2f} minutes.")

training_run()

Ep 1 (Step 000000): Train loss 0.112, Val loss 0.069
Ep 1 (Step 000050): Train loss 0.094, Val loss 0.076
Ep 1 (Step 000100): Train loss 0.052, Val loss 0.170
Training accuracy: 92.50% | Validation accuracy: 100.00%
Ep 2 (Step 000150): Train loss 0.054, Val loss 0.100
Ep 2 (Step 000200): Train loss 0.104, Val loss 0.153
Ep 2 (Step 000250): Train loss 0.153, Val loss 0.095
Training accuracy: 95.00% | Validation accuracy: 100.00%
Ep 3 (Step 000300): Train loss 0.033, Val loss 0.084
Ep 3 (Step 000350): Train loss 0.193, Val loss 0.063
Training accuracy: 97.50% | Validation accuracy: 97.50%
Ep 4 (Step 000400): Train loss 0.210, Val loss 0.023
Ep 4 (Step 000450): Train loss 0.072, Val loss 0.138
Ep 4 (Step 000500): Train loss 0.036, Val loss 0.138
Training accuracy: 97.50% | Validation accuracy: 100.00%
Ep 5 (Step 000550): Train loss 0.090, Val loss 0.042
Ep 5 (Step 000600): Train loss 0.061, Val loss 0.060
Training accuracy: 100.00% | Validation accuracy: 97.50%
Training completed in 0.53 

In [100]:
def model_accuracy(model: ClassifierGPT, device):
    train_accuracy = calc_accuracy_loader(train_loader, model, device)
    val_accuracy = calc_accuracy_loader(val_loader, model, device)
    test_accuracy = calc_accuracy_loader(test_loader, model, device)

    print(f"Training accuracy: {train_accuracy*100:.2f}%")
    print(f"Validation accuracy: {val_accuracy*100:.2f}%")
    print(f"Test accuracy: {test_accuracy*100:.2f}%")

model_accuracy(clas, get_device())

Training accuracy: 98.17%
Validation accuracy: 97.92%
Test accuracy: 95.27%


In [None]:
def classify(text: str, model: ClassifierGPT, tokenizer: tiktoken.Encoding, device: torch.device, max_length:int=0, pad_token_id:int=50256) -> str:
    model.eval()
    input_ids = tokenizer.encode(text)
    supported_context_length = model.model.cfg['context_length']
    if max_length == 0:
        max_length = supported_context_length

    # truncate if too long
    input_ids = input_ids[:min(max_length, supported_context_length)]
    # pad if too short
    input_ids += [pad_token_id] * (max_length - len(input_ids))
    input_tensor = torch.tensor(input_ids, device=device, dtype=torch.long).unsqueeze(0)

    with torch.no_grad():
        logits = model(input_tensor)[:, -1, :]
    print(f"Logits: {logits}")
    predicted_label = torch.argmax(logits, dim=-1).item()

    return "spam" if predicted_label == 1 else "not spam"

In [109]:
max_length = train_dataset.max_length
sample = "hey dude, what up wit it"
classify(sample, clas, tokenizer, get_device(), max_length=max_length)


Logits: tensor([[ 2.7385, -3.7317]], device='cuda:0')


'not spam'