# **Email Spam Classification with a Large Language Model (LLM)**
In this notebook, we will:

- Load a pre-trained instruction-following model from Hugging Face
- Import the spam.csv dataset
- Apply prompt engineering to classify each email as either spam or ham (not spam)
- Configure the model to output only the class label for each email
- Evaluate the predictions against the dataset's true labels
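
# **1. Install and Import Dependencies**
We install the required packages and import the Hugging Face `pipeline` helper, pandas, and tqdm.
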
In [1]:
!pip install transformers pandas tqdm scikit-learn

from transformers import pipeline
import pandas as pd
from tqdm.auto import tqdm



# **2. Load the LLM and Set Up a Text-to-Text Generation Pipeline**
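We’ll use google/flan-t5-base, an instruction-tuned sequence-to-sequence (encoder-decoder) model, through a `text2text-generation` pipeline. Generation is capped at a few new tokens so the model replies with a short label ("spam" or "ham") rather than free text. The same cell also loads spam.csv and prints a preview.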

In [14]:
from transformers import pipeline
import pandas as pd

# Build a text-to-text generation pipeline around FLAN-T5 (weights download on first use)
classifier = pipeline("text2text-generation", model="google/flan-t5-base")

# Load the dataset with latin-1 decoding
try:
    df = pd.read_csv('/content/spam.csv', encoding='latin1')
    print("Dataset loaded successfully. Here's a preview:")
    print(df.head())
    print("\nDataset info:")
    df.info()  # info() prints its summary directly and returns None
except FileNotFoundError:
    # Stop here instead of calling exit(), which would shut down the notebook kernel
    raise FileNotFoundError("spam.csv not found. Upload it to /content/ or update the path.") from None

Dataset loaded successfully. Here's a preview:
                                                text target
0  Go until jurong point, crazy.. Available only ...    ham
1                      Ok lar... Joking wif u oni...    ham
2  Free entry in 2 a wkly comp to win FA Cup fina...   spam
3  U dun say so early hor... U c already then say...    ham
4  Nah I don't think he goes to usf, he lives aro...    ham

Dataset info:
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 5572 entries, 0 to 5571
Data columns (total 2 columns):
 #   Column  Non-Null Count  Dtype 
---  ------  --------------  ----- 
 0   text    5572 non-null   object
 1   target  5572 non-null   object
dtypes: object(2)
memory usage: 87.2+ KB
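Before running the model, it helps to check the class balance: a trivial classifier that always answers the majority label already reaches an accuracy equal to that label's share, so the LLM's accuracy is best read against that baseline. This is a minimal standard-library sketch; `majority_baseline` is a hypothetical helper and the toy labels are made up. On the real data you would pass `df['target'].tolist()`.

```python
from collections import Counter


def majority_baseline(labels):
    """Return (majority_label, accuracy) for always predicting the most common label."""
    # Normalize labels the same way the evaluation step does (strip + lowercase)
    counts = Counter(label.strip().lower() for label in labels)
    majority_label, majority_count = counts.most_common(1)[0]
    return majority_label, majority_count / len(labels)


# Toy example: 3 ham, 1 spam
label, acc = majority_baseline(["ham", "spam", "ham", "ham"])
print(label, acc)  # ham 0.75
```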


# **3. Define the Prompt Template and Classification Function**
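We design a prompt that explicitly tells the model to return only the class label, and we cap generation at a few new tokens to keep the output concise. Because the model may still add punctuation or capitalization, its raw answer is then mapped onto 'spam' or 'ham'.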

In [16]:
# Classify a single email with the FLAN-T5 pipeline (`classifier`) created in section 2
def classify_email(text):
    # Prompt engineering: ask for the label and nothing else
    prompt = f"Classify the following email as 'ham' or 'spam'. Respond with only 'ham' or 'spam'.\n\nEmail: \"{text}\"\nClassification:"
    try:
        # A text2text-generation pipeline returns [{'generated_text': ...}];
        # max_new_tokens keeps the answer to a short label
        llm_output = classifier(prompt, max_new_tokens=5)[0]['generated_text']

        # Map the free-text answer onto 'spam' or 'ham'
        answer = llm_output.strip().lower()
        if answer.startswith('spam'):
            return 'spam'
        if answer.startswith('ham'):
            return 'ham'
        return 'unknown'  # the model replied with something else
    except Exception as e:
        print(f"Error classifying email: '{text[:50]}...' - {e}")
        return "error_in_classification"

# Classify a batch of emails, one prompt at a time
def classify_email_batch(texts):
    return [classify_email(text) for text in texts]

print("\nBatch classification function defined.")


Batch classification function defined.
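The prompt construction and label mapping can be checked without downloading the model by swapping the Hugging Face pipeline for a stub that returns the same `[{'generated_text': ...}]` shape a `text2text-generation` pipeline returns. This is a sketch under that assumption: `build_prompt`, `normalize_label`, `fake_classifier`, and `classify_with` are hypothetical names, and the stub's canned answers stand in for real FLAN-T5 output.

```python
def build_prompt(text):
    """Instruction prompt asking for a single-word label."""
    return (
        "Classify the following email as 'ham' or 'spam'. "
        "Respond with only 'ham' or 'spam'.\n\n"
        f"Email: \"{text}\"\nClassification:"
    )


def normalize_label(raw_output):
    """Map free-text model output onto 'spam', 'ham', or 'unknown'."""
    answer = raw_output.strip().lower()
    if answer.startswith("spam"):
        return "spam"
    if answer.startswith("ham"):
        return "ham"
    return "unknown"


def fake_classifier(prompt, **generate_kwargs):
    """Stub mimicking the pipeline's output shape: a list holding one dict."""
    # Hypothetical rule for the demo: anything mentioning 'free' is spam
    answer = " Spam" if "free" in prompt.lower() else "ham."
    return [{"generated_text": answer}]


def classify_with(classifier, text):
    """Run one email through any pipeline-shaped callable and normalize the answer."""
    output = classifier(build_prompt(text), max_new_tokens=5)[0]["generated_text"]
    return normalize_label(output)


print(classify_with(fake_classifier, "Free entry in 2 a wkly comp"))  # spam
print(classify_with(fake_classifier, "Ok lar... Joking wif u oni"))   # ham
```

Passing the real `classifier` to `classify_with` instead of the stub runs the same path against FLAN-T5.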


# **4. Classify Every Email in the Dataset**
We’ll run the classifier over every email in the DataFrame, in chunks of `batch_size` rows so tqdm can report progress, and append a new column containing the predicted spam or ham label.

In [None]:
# Collect predictions for every email, in dataset order
all_predicted_targets = []

# Emails per chunk. Chunking only drives the progress bar: each email
# is still sent to the model individually inside classify_email_batch.
batch_size = 1000

print(f"\nStarting batch processing with batch size: {batch_size}")
print(f"Total emails to process: {len(df)}")

# Process the DataFrame in chunks (classify_email_batch is defined in section 3)
for i in tqdm(range(0, len(df), batch_size)):
    batch_texts = df['text'].iloc[i:i + batch_size].tolist()
    all_predicted_targets.extend(classify_email_batch(batch_texts))

print("\nBatch processing complete.")

# Store predictions under the column name used in the evaluation step
df['predicted_target'] = all_predicted_targets

display(df.head())


Starting batch processing with batch size: 1000
Total emails to process: 5572


  0%|          | 0/6 [00:00<?, ?it/s]
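The loop walks the DataFrame in fixed-size slices: with 5,572 rows and `batch_size = 1000` that is ceil(5572 / 1000) = 6 chunks, which matches the `0/6` total in the progress bar, and the last chunk holds the remaining 572 rows. A minimal standalone sketch of the same slicing over a plain list, using a hypothetical `iter_chunks` helper:

```python
import math


def iter_chunks(items, size):
    """Yield consecutive slices of at most `size` items, like df.iloc[i:i + size]."""
    for start in range(0, len(items), size):
        yield items[start:start + size]


rows = list(range(5572))
chunks = list(iter_chunks(rows, 1000))
print(len(chunks), math.ceil(len(rows) / 1000))  # 6 6
print(len(chunks[-1]))  # 572
```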

# **5. Evaluate and Display Sample Results**
We’ll calculate accuracy, a per-class classification report (precision, recall, and F1-score), and a confusion matrix to assess the classification performance.

In [6]:
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix

# Add the predicted targets back to the original DataFrame
if len(all_predicted_targets) == len(df):
    df['predicted_target'] = all_predicted_targets
    print("\nPredicted targets added to DataFrame. Here's a sample:")
    print(df[['text', 'target', 'predicted_target']].head(10))

    # Normalize labels so values such as ' Spam' and 'spam' compare equal
    df['target_normalized'] = df['target'].str.strip().str.lower()
    df['predicted_target_normalized'] = df['predicted_target'].str.strip().str.lower()

    # Filter out emails the pipeline failed on
    valid_predictions_df = df[df['predicted_target_normalized'] != 'error_in_classification']

    if not valid_predictions_df.empty:
        true_labels = valid_predictions_df['target_normalized']
        pred_labels = valid_predictions_df['predicted_target_normalized']

        # Union of true and predicted labels (may include labels other than spam/ham)
        all_labels = sorted(set(true_labels) | set(pred_labels))

        print(f"\nAccuracy: {accuracy_score(true_labels, pred_labels):.4f}")
        print("\nClassification Report:")
        print(classification_report(true_labels, pred_labels, labels=all_labels, zero_division=0))

        # Rows = true label, columns = predicted label, both in all_labels order
        print(f"Confusion Matrix (labels: {all_labels}):")
        print(confusion_matrix(true_labels, pred_labels, labels=all_labels))
    else:
        print("\nNo valid predictions to evaluate.")

else:
    print(f"Mismatch in number of predictions ({len(all_predicted_targets)}) and emails ({len(df)}).")
    print("Something went wrong during batch processing.")

print("\nPractical Exercise 1: Email classification setup complete.")

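For the positive class (spam), the report's metrics follow from the confusion-matrix counts: precision = TP / (TP + FP), recall = TP / (TP + FN), and F1 is their harmonic mean, 2PR / (P + R). The sketch below recomputes them by hand on made-up toy labels so the numbers in `classification_report` can be sanity-checked. `binary_metrics` is a hypothetical helper, and it counts any prediction other than 'spam' (such as 'unknown') as not-spam, which is an assumption of this sketch.

```python
def binary_metrics(true_labels, pred_labels, positive="spam"):
    """Precision, recall, and F1 for one positive class."""
    pairs = list(zip(true_labels, pred_labels))
    # Confusion-matrix counts for the positive class
    tp = sum(t == positive and p == positive for t, p in pairs)
    fp = sum(t != positive and p == positive for t, p in pairs)
    fn = sum(t == positive and p != positive for t, p in pairs)
    # Guard each ratio against a zero denominator (mirrors zero_division=0)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return precision, recall, f1


true = ["spam", "spam", "ham", "ham", "spam"]
pred = ["spam", "ham", "spam", "ham", "spam"]
# TP = 2, FP = 1, FN = 1, so precision, recall, and F1 are all 2/3
print(binary_metrics(true, pred))
```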