# Exercise: Use a foundation model to build a spam email classifier

A foundation model can serve as a building block for a wide range of applications. In this exercise, we explore one of them: building a spam email classifier using only a prompt, without any model training. By leveraging the capabilities of a foundation model, this project aims to accurately identify and filter out unwanted and potentially harmful messages, improving user experience and security.

## Steps

1. Identify and gather relevant data
2. Build and evaluate the spam email classifier
3. Build an improved classifier?

## Step 1: Identify and gather relevant data

To test the spam classifier (and, later, to show it a few labelled examples), you will need a dataset of messages labeled as spam or not spam. It is important to choose a dataset that covers a wide range of spam and non-spam messages.
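Before prompting, it can help to check how balanced the labels are, because a heavily skewed dataset makes raw accuracy look better than it is. The sketch below is a minimal, assumption-laden example: `label_distribution` is a hypothetical helper (not part of the `datasets` library), and the toy labels are made up.

```python
from collections import Counter


def label_distribution(labels, names=None):
    """Return the fraction of examples for each label id.

    labels: any iterable of label ids (e.g. 0 and 1).
    names: optional dict mapping ids to readable names (hypothetical usage).
    """
    counts = Counter(labels)
    total = sum(counts.values())
    if total == 0:
        return {}
    names = names or {}
    # Sort by label id so the output order is stable.
    return {
        names.get(label, label): count / total
        for label, count in sorted(counts.items())
    }


# Toy usage with made-up labels (not the real dataset):
print(label_distribution([0, 0, 0, 1], names={0: "NOT SPAM", 1: "SPAM"}))
# {'NOT SPAM': 0.75, 'SPAM': 0.25}
```

Once the dataset is loaded, something like `label_distribution(dataset["label"])` should show the real split.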

In [None]:
!pip install --upgrade datasets==3.2.0 huggingface-hub==0.27.1

In [None]:
# Find a spam dataset at https://huggingface.co/datasets and load it using the datasets library

from datasets import load_dataset

dataset = load_dataset("sms_spam", split="train")

for entry in dataset.select(range(3)):
    sms = entry["sms"]
    label = entry["label"]
    print(f"label={label}, sms={sms}")

Those labels could be easier to read. Let's create some dictionaries to convert between numerical ids and labels.

In [None]:
# Convenient dictionaries to convert between labels and ids
id2label = {0: "NOT SPAM", 1: "SPAM"}
label2id = {"NOT SPAM": 0, "SPAM": 1}

for entry in dataset.select(range(3)):
    sms = entry["sms"]
    label_id = entry["label"]
    print(f"label={id2label[label_id]}, sms={sms}")
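Since `label2id` is simply the inverse of `id2label`, one option is to derive it rather than write it out by hand, so the two mappings cannot drift apart if a label is ever renamed. A minimal sketch:

```python
# Define the id -> label mapping once...
id2label = {0: "NOT SPAM", 1: "SPAM"}

# ...and derive the label -> id mapping by swapping keys and values.
label2id = {label: label_id for label_id, label in id2label.items()}

print(label2id)  # {'NOT SPAM': 0, 'SPAM': 1}
```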

## Step 2: Build and evaluate the spam email classifier

Using the foundation model and the prepared dataset, you can create a spam email classifier.

Let's write a prompt that asks the model to classify a few messages (here, the 8 messages numbered 7 to 14) as either "spam" or "not spam". For easier parsing, we can ask the LLM to respond in JSON.

In [None]:
# Let's start with this helper function that will help us format sms messages
# for the LLM.
def get_sms_messages_string(dataset, item_numbers, include_labels=False):
    sms_messages_string = ""
    for item_number, entry in zip(item_numbers, dataset.select(item_numbers)):
        sms = entry["sms"]
        label_id = entry["label"]

        if include_labels:
            sms_messages_string += (
                f"{item_number} (label={id2label[label_id]}) -> {sms}\n"
            )
        else:
            sms_messages_string += f"{item_number} -> {sms}\n"

    return sms_messages_string


print(get_sms_messages_string(dataset, range(3), include_labels=True))

Now let's write a bit of code that produces your prompt. Your prompt should include a few SMS messages to be labelled, as well as instructions for the LLM.

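Many LLMs will format their output as JSON if you ask them to, e.g. "Respond in JSON format." The accuracy check below expects the response as a Python dictionary mapping each message number to its predicted label, such as `{"7": "NOT SPAM", "8": "SPAM"}`.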
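LLM replies often wrap the JSON in extra text or a Markdown code fence, so it can help to extract the object before using it. The sketch below is one simple approach, not the only one: `parse_llm_json` is a hypothetical helper, and it assumes the reply contains exactly one flat JSON object (message number to label).

```python
import json


def parse_llm_json(text):
    """Extract a flat JSON object from an LLM reply and return it as a dict.

    Assumes a single JSON object; any surrounding prose or code fence is
    ignored by slicing from the first "{" to the last "}".
    """
    start = text.find("{")
    end = text.rfind("}")
    if start == -1 or end == -1 or end < start:
        raise ValueError("No JSON object found in the LLM response")
    data = json.loads(text[start : end + 1])
    # Normalize keys to strings and strip stray whitespace from labels.
    return {str(key): str(value).strip() for key, value in data.items()}


reply = 'Sure! Here are the labels:\n{"7": "NOT SPAM", "8": "SPAM"}'
parsed = parse_llm_json(reply)
print(parsed)  # {'7': 'NOT SPAM', '8': 'SPAM'}
```

Slicing between the outermost braces is deliberately crude; it breaks if the reply contains more than one JSON object or braces in its surrounding prose.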

In [None]:
# Replace <MASK> with your code

# Get a few messages and format them as a string
sms_messages_string = get_sms_messages_string(dataset, range(7, 15))

# Construct a query to send to the LLM including the sms messages.
# Ask it to respond in JSON format.
query = <MASK>

print(query)

In [None]:
# Replace <MASK> with your LLM's response, as a dict of
# {message number: predicted label}

response = <MASK>

In [None]:
# Estimate the accuracy of your classifier by comparing your responses to the labels in the dataset


def get_accuracy(response, dataset, original_indices):
    correct = 0
    total = 0

    for entry_number, prediction in response.items():
        if int(entry_number) not in original_indices:
            continue

        label_id = dataset[int(entry_number)]["label"]
        label = id2label[label_id]

        # If the prediction from the LLM matches the label in the dataset,
        # we increment the number of correct predictions.
        # (LLMs may vary capitalization, e.g. "Spam" vs. "SPAM", so we
        # compare lowercase versions of both strings.)
        if prediction.lower() == label.lower():
            correct += 1

        # increment the total number of predictions
        total += 1

    try:
        accuracy = correct / total
    except ZeroDivisionError:
        print("No matching results found!")
        return

    return round(accuracy, 2)


print(f"Accuracy: {get_accuracy(response, dataset, range(7, 15))}")

That's not bad, assuming you used an LLM capable of handling this task!

It certainly won't be correct for every example we throw at it, but it's a great start, especially since we gave it no examples or training data.

On this sample, the model is able to distinguish between spam and non-spam messages with high accuracy. Keep in mind that the sample is only 8 messages, so a single mistake changes the accuracy by 0.125. Still, this is a good example of how a foundation model can be used to build a spam classifier.
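Spam is typically the minority class in datasets like this one, so an accuracy score is easier to interpret next to a trivial baseline that always predicts the most common label. A minimal sketch; `majority_baseline_accuracy` is a hypothetical helper, and the labels in the example are made up.

```python
from collections import Counter


def majority_baseline_accuracy(true_labels):
    """Accuracy of a 'classifier' that always predicts the most common label."""
    if not true_labels:
        return None
    most_common_count = Counter(true_labels).most_common(1)[0][1]
    return round(most_common_count / len(true_labels), 2)


# Hypothetical example: 6 of 8 messages are NOT SPAM, so always answering
# NOT SPAM already scores 0.75.
labels = ["NOT SPAM"] * 6 + ["SPAM"] * 2
print(majority_baseline_accuracy(labels))  # 0.75
```

With the notebook's variables, the baseline for the evaluated messages would be `majority_baseline_accuracy([id2label[dataset[i]["label"]] for i in range(7, 15)])`. If the LLM's accuracy is not clearly above this number, it may not be learning much beyond "most messages are not spam".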

## Step 3: Build an improved classifier?

Providing the LLM with a few labelled examples of the task (few-shot prompting) can sometimes improve its performance. Let's try that out here.
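One possible way to lay out a few-shot prompt is: instructions first, then the labelled examples, then the messages to classify, then the required output format. The sketch below is an assumption, not a prescribed template; `build_few_shot_prompt` is a hypothetical helper, and other wordings may work as well or better.

```python
def build_few_shot_prompt(labelled_examples, unlabelled_messages):
    """Assemble a few-shot classification prompt as a single string.

    Both arguments are pre-formatted strings, e.g. the output of
    get_sms_messages_string(..., include_labels=True) and (..., False).
    """
    return (
        "Classify each SMS message as SPAM or NOT SPAM.\n\n"
        "Here are some labelled examples:\n"
        f"{labelled_examples}\n"
        "Now classify these messages:\n"
        f"{unlabelled_messages}\n"
        "Respond only with a JSON object mapping each message number to its "
        'label, for example {"7": "NOT SPAM"}.'
    )


# Toy usage with made-up messages:
prompt = build_few_shot_prompt(
    "54 (label=SPAM) -> You have won a prize! Reply now.\n",
    "7 -> Are we still meeting for lunch?\n",
)
print(prompt)
```

Putting the examples before the unlabelled messages keeps the examples in the same `number (label=...) -> text` format the model then sees without labels.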

In [None]:
# Replace <MASK> with your code that constructs a query to send to the LLM

# Get a few labelled messages and format them as a string
sms_messages_string_w_labels = get_sms_messages_string(
    dataset, range(54, 60), include_labels=True
)

# Get a few unlabelled messages and format them as a string
sms_messages_string_no_labels = get_sms_messages_string(dataset, range(7, 15))


# Construct a query to send to the LLM including the labelled messages
# as well as the unlabelled messages. Ask it to respond in JSON format
query = <MASK>

print(query)

Paste in your response from the LLM below:

In [None]:
# Replace <MASK> with your LLM's response, as a dict of
# {message number: predicted label}

response = <MASK>

Let's check the accuracy now.

In [None]:
# What's the accuracy?

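# get_accuracy already rounds its result (and returns None if nothing matched),
# so we print it directly instead of applying a float format.
print(f"Accuracy: {get_accuracy(response, dataset, range(7, 15))}")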

If there are any misclassified items, let's view them.

In [None]:
# Show the messages that were misclassified, if you have any


def print_misclassified_messages(response, dataset):
    for entry_number, prediction in response.items():
        label_id = dataset[int(entry_number)]["label"]
        label = id2label[label_id]

        if prediction.lower() != label.lower():
            sms = dataset[int(entry_number)]["sms"]
            print("---")
            print(f"Message: {sms}")
            print(f"Label: {label}")
            print(f"Prediction: {prediction}")


print_misclassified_messages(response, dataset)

Interesting (if there were any mistakes). What do you think is going on?
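One way to start answering that is to tally the mistakes by type: is spam slipping through, or are legitimate messages being flagged? A minimal sketch; `confusion_counts` is a hypothetical helper that treats SPAM as the positive class, counts any prediction other than "spam" as NOT SPAM, and, unlike `get_accuracy`, also counts messages the LLM left out of its response.

```python
def confusion_counts(predictions, true_labels):
    """Tally TP/FP/FN/TN (SPAM = positive) plus messages with no prediction.

    predictions: dict of {message number: predicted label} from the LLM.
    true_labels: dict of {message number: true label} for the evaluated set.
    """
    counts = {"TP": 0, "FP": 0, "FN": 0, "TN": 0, "missing": 0}
    # Normalize keys to strings and labels to lowercase for comparison.
    normalized = {str(k): str(v).strip().lower() for k, v in predictions.items()}
    for number, truth in true_labels.items():
        pred = normalized.get(str(number))
        if pred is None:
            counts["missing"] += 1
            continue
        is_spam_pred = pred == "spam"
        is_spam_true = truth.strip().lower() == "spam"
        if is_spam_pred and is_spam_true:
            counts["TP"] += 1
        elif is_spam_pred:
            counts["FP"] += 1
        elif is_spam_true:
            counts["FN"] += 1
        else:
            counts["TN"] += 1
    return counts


# Toy usage with made-up predictions and labels:
preds = {"7": "NOT SPAM", "8": "spam", "9": "NOT SPAM"}
truth = {7: "NOT SPAM", 8: "SPAM", 9: "SPAM", 10: "NOT SPAM"}
print(confusion_counts(preds, truth))
# {'TP': 1, 'FP': 0, 'FN': 1, 'TN': 1, 'missing': 1}
```

With the notebook's variables, `truth` could be built as `{i: id2label[dataset[i]["label"]] for i in range(7, 15)}` and passed together with `response`.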