# Fine-tune GPT Model

- We'll be finetuning the LLM on a specific target task, such as classifying text.

In [1]:
%load_ext watermark
%watermark -v -p numpy,pandas,polars,torch,lightning --conda

Python implementation: CPython
Python version       : 3.11.8
IPython version      : 8.22.2

numpy    : 1.26.4
pandas   : 2.2.1
polars   : 0.20.18
torch    : 2.2.2
lightning: 2.2.1

conda environment: torch_p11



In [2]:
# Built-in library
from pathlib import Path
import re
import json
from typing import Any, Optional, Union
import logging
import warnings

# Standard imports
import numpy as np
import numpy.typing as npt
from pprint import pprint
import pandas as pd
import polars as pl
from rich.console import Console
from rich.theme import Theme

custom_theme = Theme(
    {
        "info": "#76FF7B",
        "warning": "#FBDDFE",
        "error": "#FF0000",
    }
)
console = Console(theme=custom_theme)

# Visualization
import matplotlib.pyplot as plt


# Pandas settings
pd.options.display.max_rows = 1_000
pd.options.display.max_columns = 1_000
pd.options.display.max_colwidth = 600

warnings.filterwarnings("ignore")


# Black code formatter (Optional)
%load_ext lab_black

# auto reload imports
%load_ext autoreload
%autoreload 2

In [3]:
import torch
from torch import nn, Tensor
import torch.nn.functional as F

In [4]:
GPT_CONFIG_124M: dict[str, Any] = {
    "vocab_size": 50_257,
    "context_length": 1_024,
    "emb_dim": 768,
    "n_heads": 12,  # Number of attention heads
    "n_layers": 12,
    "drop_rate": 0.1,  # Dropout rate
    "qkv_bias": False,
}

### Dataset Download

In [5]:
from urllib import request
import zipfile
import os


def download_and_unzip_spam_data(
    url: str, zip_path: str, extracted_path: str, data_file_path: Path
) -> None:
    if data_file_path.exists():
        console.print(
            f"{str(data_file_path)!r} already exists. Skipping download and extraction."
        )
        return None

    with request.urlopen(url) as response:
        with open(zip_path, "wb") as out_file:
            out_file.write(response.read())

    with zipfile.ZipFile(zip_path, "r") as zip_ref:
        zip_ref.extractall(extracted_path)

    original_file_path: Path = Path(extracted_path) / "SMSSpamCollection"
    os.rename(original_file_path, data_file_path)
    console.print(f"File downloaded and saved as {data_file_path!r}")

In [6]:
url: str = "https://archive.ics.uci.edu/static/public/228/sms+spam+collection.zip"
zip_path: str = "sms_spam_collection.zip"
extracted_path: str = "sms_spam_collection"
data_file_path: Path = Path(extracted_path) / "SMSSpamCollection.tsv"


download_and_unzip_spam_data(
    url=url,
    zip_path=zip_path,
    extracted_path=extracted_path,
    data_file_path=data_file_path,
)

In [7]:
df: pl.DataFrame = pl.read_csv(
    source=data_file_path,
    separator="\t",
    has_header=False,
).rename({"column_1": "Label", "column_2": "Text"})

print(f"{df.shape[0]:,} rows")
df.head()

5,278 rows


Label,Text
str,str
"""ham""","""Go until juron…"
"""ham""","""Ok lar... Joki…"
"""spam""","""Free entry in …"
"""ham""","""U dun say so e…"
"""ham""","""Nah I don't th…"


In [8]:
df.group_by("Label").agg(pl.len())

Label,len
str,u32
"""spam""",697
"""ham""",4581


In [9]:
seed: int = 123
sample_size: int = int(df.filter(pl.col("Label").eq("spam")).shape[0] * 1.2)
print(f"sample_size: {sample_size:,}")
spam: pl.DataFrame = df.filter(pl.col("Label").eq("spam"))
ham: pl.DataFrame = df.filter(pl.col("Label").eq("ham")).sample(
    n=sample_size, seed=seed
)
data: pl.DataFrame = pl.concat([spam, ham], how="vertical").sample(
    seed=seed, fraction=1, shuffle=True
)
print(f"Data shape: {data.shape[0]:,} rows")

data.head()

sample_size: 836
Data shape: 1,533 rows


Label,Text
str,str
"""spam""","""T-Mobile custo…"
"""ham""","""How will I cre…"
"""spam""","""important info…"
"""ham""","""I love to give…"
"""ham""","""We stopped to …"


In [10]:
data.group_by("Label").agg(pl.len())

Label,len
str,u32
"""ham""",836
"""spam""",697


In [11]:
# Encode the labels
data = data.with_columns(
    Label=pl.when(pl.col("Label").eq("ham")).then(pl.lit(0)).otherwise(pl.lit(1))
)

console.print(data.head())

In [12]:
## Split the data into tran, validation and test sets