# **Finetuning for Text Classification**

This notebook focuses on fine-tuning a pre-trained model for a specific task, namely text classification. The alternative, instruction fine-tuning, is covered in the next notebook.

In [1]:
from importlib.metadata import version

pkgs = ["matplotlib",  # Plotting library
        "numpy",       # PyTorch & TensorFlow dependency
        "tiktoken",    # Tokenizer
        "torch",       # Deep learning library
        "tensorflow",  # For OpenAI's pretrained weights
        "pandas"       # Dataset loading
       ]
for p in pkgs:
    print(f"{p} version: {version(p)}")

matplotlib version: 3.10.1
numpy version: 2.2.5
tiktoken version: 0.9.0
torch version: 2.7.0
tensorflow version: 2.19.0
pandas version: 2.2.3


## **The Task: Classification of Spam Emails**

<img src="https://sebastianraschka.com/images/LLMs-from-scratch-images/ch06_compressed/spam-non-spam.webp" width=500px>

Here, the model works with a fixed set of class labels: "spam" and "not spam".

## **1. Preparation of the Dataset**

In [None]:
import urllib.request
import zipfile
import os
from pathlib import Path

url = "https://archive.ics.uci.edu/static/public/228/sms+spam+collection.zip"
zip_path = "data/sms_spam_collection.zip"
extracted_path = "data/sms_spam_collection"
data_file_path = Path(extracted_path) / "SMSSpamCollection.tsv"

def download_and_unzip_spam_data(url, zip_path, extracted_path, data_file_path):
    if data_file_path.exists():
        print(f"{data_file_path} already exists. Skipping download and extraction.")
        return

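    # Downloading the file (create the target directory first if it is missing)
    Path(zip_path).parent.mkdir(parents=True, exist_ok=True)
    with urllib.request.urlopen(url) as response:
        with open(zip_path, "wb") as out_file:
            out_file.write(response.read())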

    # Unzipping the file
    with zipfile.ZipFile(zip_path, "r") as zip_ref:
        zip_ref.extractall(extracted_path)

    # Add .tsv file extension
    original_file_path = Path(extracted_path) / "SMSSpamCollection"
    os.rename(original_file_path, data_file_path)
    print(f"File downloaded and saved as {data_file_path}")

try:
    download_and_unzip_spam_data(url, zip_path, extracted_path, data_file_path)
except (urllib.error.HTTPError, urllib.error.URLError, TimeoutError) as e:
    print(f"Primary URL failed: {e}. Trying backup URL...")
    url = "https://f001.backblazeb2.com/file/LLMs-from-scratch/sms%2Bspam%2Bcollection.zip"
    download_and_unzip_spam_data(url, zip_path, extracted_path, data_file_path) 

File downloaded and saved as data/sms_spam_collection/SMSSpamCollection.tsv


In [3]:
import pandas as pd

df = pd.read_csv(data_file_path, sep="\t", header=None, names=["Label", "Text"])
df

Unnamed: 0,Label,Text
0,ham,"Go until jurong point, crazy.. Available only ..."
1,ham,Ok lar... Joking wif u oni...
2,spam,Free entry in 2 a wkly comp to win FA Cup fina...
3,ham,U dun say so early hor... U c already then say...
4,ham,"Nah I don't think he goes to usf, he lives aro..."
...,...,...
5567,spam,This is the 2nd time we have tried 2 contact u...
5568,ham,Will ü b going to esplanade fr home?
5569,ham,"Pity, * was in mood for that. So...any other s..."
5570,ham,The guy did some bitching but I acted like i'd...


In [4]:
# Distributions
df['Label'].value_counts()

Label
ham     4825
spam     747
Name: count, dtype: int64

To speed up training of the LLM (solely for demonstration purposes), we will subsample the dataset so that it contains 747 instances from each class, which makes the dataset "balanced".

In [5]:
def create_balanced_dataset(df):
    # Count of spam instances
    num_spam = df[df["Label"] == "spam"].shape[0]
    
    # Random sampling of "ham instances" to match the number of "spam" instances
    ham_subset = df[df["Label"] == "ham"].sample(num_spam, random_state=42)
    
    # Combine ham "subset" with "spam"
    bal_df = pd.concat([ham_subset, df[df["Label"] == "spam"]])
    
    return bal_df

balanced_df = create_balanced_dataset(df)
balanced_df["Label"].value_counts()

Label
ham     747
spam    747
Name: count, dtype: int64

In [6]:
# Map the string class labels "ham" and "spam" to the integer labels 0 and 1
balanced_df["Label"] = balanced_df["Label"].map({"ham": 0, "spam": 1})
balanced_df 

Unnamed: 0,Label,Text
3714,0,If i not meeting ü all rite then i'll go home ...
1311,0,"I.ll always be there, even if its just in spir..."
548,0,"Sorry that took so long, omw now"
1324,0,I thk 50 shd be ok he said plus minus 10.. Did...
3184,0,Dunno i juz askin cos i got a card got 20% off...
...,...,...
5537,1,Want explicit SEX in 30 secs? Ring 02073162414...
5540,1,ASKED 3MOBILE IF 0870 CHATLINES INCLU IN FREE ...
5547,1,Had your contract mobile 11 Mnths? Latest Moto...
5566,1,REMINDER FROM O2: To get 2.50 pounds free call...


In [25]:
# Function to randomly divide the dataset into training, validation and test subsets
def random_split(df, train_frac, valid_frac):
    # Shuffle the df
    df = df.sample(frac=1, random_state=42).reset_index(drop=True)
    
    # Calculate split indices
    train_end = int(len(df) * train_frac)
    valid_end = train_end + int(len(df) * valid_frac)
    
    # Split the dataframe
    train_df = df[:train_end]
    valid_df = df[train_end:valid_end]
    test_df = df[valid_end:]
    
    return train_df, valid_df, test_df

# Train fraction is 0.7 and validation fraction is 0.1; the remaining 0.2 is the test set
train_df, valid_df, test_df = random_split(balanced_df, 0.7, 0.1)

train_df.to_csv("data/sms_spam_collection/train.csv", index=None)
valid_df.to_csv("data/sms_spam_collection/valid.csv", index=None)
test_df.to_csv("data/sms_spam_collection/test.csv", index=None)
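
As a quick sanity check on the split, the sketch below repeats the index arithmetic of `random_split` in plain Python and reports how many rows land in each subset. The helper name `split_sizes` is hypothetical (it is not part of the notebook), and the row count 1494 is taken from the balanced dataset above (747 ham + 747 spam).

```python
def split_sizes(n_rows, train_frac, valid_frac):
    """Return (train, valid, test) row counts using the same
    index arithmetic as random_split (hypothetical helper)."""
    train_end = int(n_rows * train_frac)              # end index of the training slice
    valid_end = train_end + int(n_rows * valid_frac)  # end index of the validation slice
    return train_end, valid_end - train_end, n_rows - valid_end

# 1494 rows = 747 ham + 747 spam in the balanced dataset
sizes = split_sizes(1494, 0.7, 0.1)
print(sizes)  # (1045, 149, 300)
```

Because `int()` rounds both fractions down, the test subset absorbs the leftover rows and ends up marginally larger than exactly 20% of the data.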

## **2. Creating the DataLoaders**

- Text messages have variable lengths, so to combine them into batches they must either be truncated to the length of the shortest message in the dataset / batch or padded to the length of the longest message in the dataset / batch.
- We will proceed with padding, using `<|endoftext|>` as the padding token.
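
The two options in the list above can be compared on a toy batch. This is a minimal sketch in plain Python: the token ID lists are made up for illustration, the helper names `truncate_to_shortest` and `pad_to_longest` are hypothetical, and 50256 is assumed to be the ID of `<|endoftext|>` in the GPT-2 vocabulary.

```python
PAD_TOKEN_ID = 50256  # assumed ID of <|endoftext|> in the GPT-2 vocabulary

# Made-up token ID lists standing in for three encoded messages
batch = [[11, 22, 33, 44, 55], [66, 77], [88, 99, 100]]

def truncate_to_shortest(seqs):
    """Cut every sequence down to the length of the shortest one."""
    min_len = min(len(s) for s in seqs)
    return [s[:min_len] for s in seqs]

def pad_to_longest(seqs, pad_id=PAD_TOKEN_ID):
    """Append pad_id to every sequence until it matches the longest one."""
    max_len = max(len(s) for s in seqs)
    return [s + [pad_id] * (max_len - len(s)) for s in seqs]

truncated = truncate_to_shortest(batch)
padded = pad_to_longest(batch)

print(truncated)  # every sequence has length 2; later tokens are discarded
print(padded)     # every sequence has length 5; no tokens are discarded
```

Truncation throws away most of the longer messages, whereas padding keeps every token at the cost of some extra positions, which is one reason to prefer padding here.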

<img src="https://sebastianraschka.com/images/LLMs-from-scratch-images/ch06_compressed/pad-input-sequences.webp?123" width=640px>

In [17]:
import tiktoken

tokenizer = tiktoken.get_encoding("gpt2")
print(tokenizer.encode("<|endoftext|>", allowed_special={"<|endoftext|>"}))

[50256]


In [None]:
# Identify the longest sequence in the training dataset and add the padding token to others to match sequence length
import torch
from torch.utils.data import Dataset

PATH = "data/sms_spam_collection/"

class SpamDataset(Dataset):
    def __init__(self, csv_file, tokenizer, max_length=None, pad_token_id=50256):
        self.data = pd.read_csv(csv_file)
        
        # Pre-tokenize texts
        self.encoded_texts = [
            tokenizer.encode(text) for text in self.data["Text"]
        ]
        
        if max_length is None:
            self.max_length = self._longest_encoded_length()
        else:
            self.max_length = max_length
            # Truncate sequences if they are longer than max_length
            self.encoded_texts = [
                encoded_text[:self.max_length]
                for encoded_text in self.encoded_texts
            ]
            
        # Pad every sequence to self.max_length with the padding token
        self.encoded_texts = [
            encoded_text + [pad_token_id] * (self.max_length - len(encoded_text))
            for encoded_text in self.encoded_texts
        ]
    
    def __getitem__(self, index):
        encoded = self.encoded_texts[index]
        label = self.data.iloc[index]["Label"]
        return (
            torch.tensor(encoded, dtype=torch.long),
            torch.tensor(label, dtype=torch.long)
        )
        
    def __len__(self):
        return len(self.data)
    
    def _longest_encoded_length(self):
        # Length of the longest tokenized text in this dataset
        return max(len(encoded_text) for encoded_text in self.encoded_texts)

In [28]:
train_dataset = SpamDataset(
    csv_file=PATH + "train.csv",  # PATH already ends with "/"
    max_length=None,
    tokenizer=tokenizer
)

print(train_dataset.max_length) # Expect variation due to differences in seed values

109
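
`SpamDataset` also accepts an explicit `max_length`, in which case longer sequences are truncated before padding. One plausible use, stated here as an assumption rather than something shown above, is to pass the training set's length when building other splits so that all inputs share one shape. The sketch below mirrors that truncate-then-pad logic in plain Python on made-up token IDs; the helper name `fit_to_length` is hypothetical.

```python
def fit_to_length(token_ids, max_length, pad_id=50256):
    """Truncate to max_length, then pad with pad_id up to max_length
    (hypothetical helper mirroring the logic in SpamDataset.__init__)."""
    truncated = token_ids[:max_length]
    return truncated + [pad_id] * (max_length - len(truncated))

short_seq = fit_to_length([1, 2, 3], max_length=5)
long_seq = fit_to_length([1, 2, 3, 4, 5, 6, 7], max_length=5)

print(short_seq)  # [1, 2, 3, 50256, 50256]
print(long_seq)   # [1, 2, 3, 4, 5]
```

Either way, every sequence comes out with exactly `max_length` entries, which is what allows the samples to be stacked into a single tensor.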
