# Fine-tune GPT Model

- We'll be finetuning the LLM on a specific target task, such as classifying text.

In [1]:
%load_ext watermark
%watermark -v -p numpy,pandas,polars,torch,lightning --conda

Python implementation: CPython
Python version       : 3.11.8
IPython version      : 8.22.2

numpy    : 1.26.4
pandas   : 2.2.1
polars   : 0.20.18
torch    : 2.2.2
lightning: 2.2.1

conda environment: torch_p11



In [2]:
# Built-in library
from pathlib import Path
import re
import json
from typing import Any, Optional, Union
import logging
import warnings

# Standard imports
import numpy as np
import numpy.typing as npt
from pprint import pprint
import pandas as pd
import polars as pl
from rich.console import Console
from rich.theme import Theme

custom_theme = Theme(
    {
        "info": "#76FF7B",
        "warning": "#FBDDFE",
        "error": "#FF0000",
    }
)
console = Console(theme=custom_theme)

# Visualization
import matplotlib.pyplot as plt


# Pandas settings
pd.options.display.max_rows = 1_000
pd.options.display.max_columns = 1_000
pd.options.display.max_colwidth = 600

warnings.filterwarnings("ignore")


# Black code formatter (Optional)
%load_ext lab_black

# auto reload imports
%load_ext autoreload
%autoreload 2

In [3]:
import torch
from torch import nn, Tensor
import torch.nn.functional as F

In [4]:
GPT_CONFIG_124M: dict[str, Any] = {
    "vocab_size": 50_257,
    "context_length": 1_024,
    "emb_dim": 768,
    "n_heads": 12,  # Number of attention heads
    "n_layers": 12,
    "drop_rate": 0.1,  # Dropout rate
    "qkv_bias": False,
}

### Dataset Download

In [5]:
from urllib import request
import zipfile
import os


def download_and_unzip_spam_data(
    url: str, zip_path: str, extracted_path: str, data_file_path: Path
) -> None:
    original_file_path: Path = Path(extracted_path) / "SMSSpamCollection"

    if data_file_path.exists():
        console.print(
            f"{str(data_file_path)!r} already exists. Skipping download and extraction."
        )
        return None

    # Download the file to the specified directory
    with request.urlopen(url) as response:
        with open(zip_path, "wb") as out_file:
            out_file.write(response.read())

    # Unzip the file to the specified directory
    with zipfile.ZipFile(zip_path, "r") as zip_ref:
        zip_ref.extractall(extracted_path)

    os.rename(original_file_path, data_file_path)
    console.print(f"File downloaded and saved as {data_file_path!r}")
    return None

In [6]:
url: str = "https://archive.ics.uci.edu/static/public/228/sms+spam+collection.zip"
zip_path: str = "sms_spam_collection.zip"
extracted_path: str = "sms_spam_collection"
data_file_path: Path = Path(extracted_path) / "SMSSpamCollection.tsv"


download_and_unzip_spam_data(
    url=url,
    zip_path=zip_path,
    extracted_path=extracted_path,
    data_file_path=data_file_path,
)

In [7]:
df: pl.DataFrame = pl.read_csv(
    source=data_file_path,
    separator="\t",
    has_header=False,
).rename({"column_1": "Label", "column_2": "Text"})

print(f"{df.shape[0]:,} rows")
df.head()

5,278 rows


Label,Text
str,str
"""ham""","""Go until juron…"
"""ham""","""Ok lar... Joki…"
"""spam""","""Free entry in …"
"""ham""","""U dun say so e…"
"""ham""","""Nah I don't th…"


In [8]:
df.group_by("Label").agg(pl.len())

Label,len
str,u32
"""spam""",697
"""ham""",4581


In [9]:
def generate_sample_data(
    data: pl.DataFrame, seed: int = 123, print_shape: bool = False
) -> pl.DataFrame:

    sample_size: int = int(data.filter(pl.col("Label").eq("spam")).shape[0] * 1.2)
    print(f"sample_size: {sample_size:,}")
    spam: pl.DataFrame = data.filter(pl.col("Label").eq("spam"))
    ham: pl.DataFrame = data.filter(pl.col("Label").eq("ham")).sample(
        n=sample_size, seed=seed
    )
    data_df: pl.DataFrame = pl.concat([spam, ham], how="vertical").sample(
        seed=seed, fraction=1, shuffle=True
    )
    if print_shape:
        print(f"Data shape: {data_df.shape[0]:,} rows")

    return data_df

In [10]:
seed: int = 123

data = generate_sample_data(data=df, seed=seed, print_shape=True)
data.head()

sample_size: 836
Data shape: 1,533 rows


Label,Text
str,str
"""spam""","""T-Mobile custo…"
"""ham""","""How will I cre…"
"""spam""","""important info…"
"""ham""","""I love to give…"
"""ham""","""We stopped to …"


In [11]:
data.group_by("Label").agg(pl.len())

Label,len
str,u32
"""spam""",697
"""ham""",836


In [12]:
# Encode the labels
data = data.with_columns(
    Label=pl.when(pl.col("Label").eq("ham")).then(pl.lit(0)).otherwise(pl.lit(1))
)

console.print(data.head())

In [13]:
## Split the data into tran, validation and test sets
from sklearn.model_selection import train_test_split


train_data: pl.DataFrame
val_data: pl.DataFrame
test_data: pl.DataFrame

train_data, test_data = train_test_split(
    data, stratify=data.select("Label"), test_size=0.1, random_state=seed
)
train_data, val_data = train_test_split(
    train_data, stratify=train_data.select("Label"), test_size=0.1, random_state=seed
)

print(f"{train_data.shape = }, {val_data.shape = }, {test_data.shape = }")

train_data.shape = (1241, 2), val_data.shape = (138, 2), test_data.shape = (154, 2)


In [14]:
# Save the data
save_path: Path = Path("../../data/sms_data")
train_data.write_parquet(file=save_path / "train.parquet", use_pyarrow=True)
val_data.write_parquet(file=save_path / "val.parquet", use_pyarrow=True)
test_data.write_parquet(file=save_path / "test.parquet", use_pyarrow=True)

print(train_data.group_by("Label").agg(pl.len()))
print(val_data.group_by("Label").agg(pl.len()))
print(test_data.group_by("Label").agg(pl.len()))

shape: (2, 2)
┌───────┬─────┐
│ Label ┆ len │
│ ---   ┆ --- │
│ i32   ┆ u32 │
╞═══════╪═════╡
│ 1     ┆ 564 │
│ 0     ┆ 677 │
└───────┴─────┘
shape: (2, 2)
┌───────┬─────┐
│ Label ┆ len │
│ ---   ┆ --- │
│ i32   ┆ u32 │
╞═══════╪═════╡
│ 0     ┆ 75  │
│ 1     ┆ 63  │
└───────┴─────┘
shape: (2, 2)
┌───────┬─────┐
│ Label ┆ len │
│ ---   ┆ --- │
│ i32   ┆ u32 │
╞═══════╪═════╡
│ 1     ┆ 70  │
│ 0     ┆ 84  │
└───────┴─────┘


### Create Datasets And Data Loaders

- Pad all the texts to the same length.
- Pad using the index of the pad token.
  - `"<|endoftext|>"` is the padding token.
  - it has an index of 50256 (using tiktoken)

In [33]:
from torch.utils.data import Dataset, DataLoader


class SpamDataset(Dataset):
    def __init__(
        self,
        data: pl.DataFrame,
        tokenizer: Any,
        max_length: int | None = None,
        pad_token: int = 50_256,
    ) -> None:
        self.data = data
        self.encoded_texts: list[int] = [
            tokenizer.encode(text) for text in self.data.select("Text").to_series()
        ]

        if max_length is None:
            self.max_length: int = self._calculate_max_length()
        else:
            assert (
                max_length > 0
            ), "max_length must be a positive integer or None, not a negative integer."
            self.max_length = max_length

        # Truncate text
        self.encoded_texts = [
            tok_ids[: -self.max_length] for tok_ids in self.encoded_texts
        ]
        # Pad Text
        self.encoded_texts = [
            tok_ids + [pad_token] * (self.max_length - len(tok_ids))
            for tok_ids in self.encoded_texts
        ]
        # Targets
        self.targets = self.data.select("Label").to_series()

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]:
        text: Tensor = torch.tensor(self.encoded_texts[idx], dtype=torch.long)
        label: Tensor = torch.tensor(self.targets[idx], dtype=torch.long)

        return (text, label)

    def _calculate_max_length(self) -> int:
        return max([len(tok_ids) for tok_ids in self.encoded_texts])

In [41]:
import tiktoken


torch.manual_seed(123)
tokenizer = tiktoken.get_encoding("gpt2")
# end_of_text: str = "<|endoftext|>"
# tokenizer.encode(text, allowed_special={end_of_text})

train_dataset: Dataset = SpamDataset(data=train_data, tokenizer=tokenizer)
val_dataset: Dataset = SpamDataset(
    data=val_data, tokenizer=tokenizer, max_length=train_dataset.max_length
)
test_dataset: Dataset = SpamDataset(
    data=test_data, tokenizer=tokenizer, max_length=train_dataset.max_length
)


# Create data loaders
batch_size: int = 8
num_workers: int = 0

train_loader: DataLoader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
    drop_last=True,
)
val_loader: DataLoader = DataLoader(
    val_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=num_workers,
    drop_last=False,
)
test_loader: DataLoader = DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=num_workers,
    drop_last=False,
)

In [42]:
for inp_batch, target_batch in train_loader:
    print(f"{inp_batch.shape = }")
    print(f"{target_batch.shape = }\n\n")

    break


print(f"{len(train_loader) = }")
print(f"{len(val_loader) = }")
print(f"{len(test_loader) = }")

inp_batch.shape = torch.Size([8, 257])
target_batch.shape = torch.Size([8])


len(train_loader) = 155
len(val_loader) = 18
len(test_loader) = 20


In [44]:
MODEL_CHOICE: str = "gpt2-small (124M)"
INPUT_PROMPT: str = "Every effort moves"
BASE_CONFIG = {
    "vocab_size": 50_257,  # Vocabulary size
    "context_length": 1_024,  # Context length
    "drop_rate": 0.0,  # Dropout rate
    "qkv_bias": True,  # Query-key-value bias
}
model_configs: dict[str, dict] = {
    "gpt2-small (124M)": {"emb_dim": 768, "n_layers": 12, "n_heads": 12},
    "gpt2-medium (355M)": {"emb_dim": 1024, "n_layers": 24, "n_heads": 16},
    "gpt2-large (774M)": {"emb_dim": 1280, "n_layers": 36, "n_heads": 20},
    "gpt2-xl (1558M)": {"emb_dim": 1600, "n_layers": 48, "n_heads": 25},
}
BASE_CONFIG.update(model_configs[MODEL_CHOICE])

assert train_dataset.max_length <= BASE_CONFIG.get("context_length"), (
    f"Dataset length {train_dataset.max_length} exceeds model's context "
    f"length {BASE_CONFIG.get('context_length')}. Reinitialize data sets with "
    f"`max_length={BASE_CONFIG.get('context_length')}`"
)