# NLP 2025-26 -- LLM Project: Email Routing Agents

**Objective:** Develop three LLM-based agents to automatically route customer support emails to the appropriate department.

**Departments:** Technical Support, Customer Service, Billing and Payments, Sales and Pre-Sales, General Inquiry

**Agents:**
1. Routing with prompting using GPT-2 (frozen model; DistilGPT-2 is also evaluated for comparison)
2. Routing with fine-tuning (LoRA) on GPT-2
3. Routing with discriminative classifier using DistilBERT

## 0. Setup and Imports

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import time
import tracemalloc
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
from transformers import (
    GPT2Tokenizer, GPT2LMHeadModel,
    DistilBertTokenizer, DistilBertForSequenceClassification,
)
from tqdm import tqdm

from datapreparation import load_and_prepare_data

# Prefer the GPU when PyTorch can see one; otherwise run everything on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

## 1. Data Loading and Exploration

In [None]:
# Load the three splits together with the label list and label <-> id mappings.
train_ds, val_ds, test_ds, label_list, label2id, id2label = load_and_prepare_data()

for split_name, split in (("Train", train_ds), ("Validation", val_ds), ("Test", test_ds)):
    print(f"{split_name} size: {len(split)}")
print(f"Labels: {label_list}")

In [None]:
# Visualize label distribution.
# Every panel uses the same category order (label_list), so bars for a given department
# line up across the Train/Validation/Test panels. A Counter orders its keys by first
# appearance and only contains labels that actually occur. Plotting its keys directly
# would give each panel its own x-axis order and hide absent labels. Here, absent labels
# get a zero-height bar instead.
from collections import Counter

fig, axes = plt.subplots(1, 3, figsize=(18, 5))
for ax, (name, split) in zip(axes, [("Train", train_ds), ("Validation", val_ds), ("Test", test_ds)]):
    counts = Counter(split["queue"])
    # Labels outside label_list are not expected, but are appended so no data is hidden.
    known = set(label_list)
    categories = list(label_list) + [label for label in counts if label not in known]
    ax.bar(range(len(categories)), [counts[c] for c in categories])
    ax.set_xticks(range(len(categories)))
    ax.set_xticklabels(categories, rotation=45, ha="right")
    ax.set_title(f"{name} Set Distribution")
    ax.set_ylabel("Count")
plt.tight_layout()
plt.show()

In [None]:
# Show the first three training examples (body cut to 200 characters) as a sanity check
# of the loaded fields.
for idx in range(3):
    example = train_ds[idx]
    print(f"--- Example {idx + 1} ---")
    print(f"Subject: {example['subject']}")
    print(f"Body: {example['body'][:200]}...")
    print(f"Department: {example['queue']}")
    print()

---
## 2. Agent 1: Routing with Prompting (Frozen GPT-2)

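Use a pretrained GPT-2 model (and, for comparison, DistilGPT-2) without updating any weights. The model receives an instruction-style prompt listing the departments. Instead of generating free text, each department name is scored by the average log-probability the model assigns to its tokens as a continuation of the prompt, and the highest-scoring department is taken as the prediction.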

In [ ]:
_DEFAULT_DEPARTMENTS = (
    "Technical Support",
    "Customer Service",
    "Billing and Payments",
    "Sales and Pre-Sales",
    "General Inquiry",
)


def build_prompt(subject, body, max_body_chars=500, departments=None):
    """Build an instruction-style prompt for email routing.

    Args:
        subject: Email subject line. ``None`` is rendered as an empty string rather
            than the literal text "None".
        body: Email body. ``None`` is treated as empty. Only the first
            ``max_body_chars`` characters are kept (then whitespace-stripped) to keep
            the prompt short.
        max_body_chars: Maximum number of body characters included in the prompt.
        departments: Optional iterable of department names listed in the prompt.
            Defaults to the five project departments. Keep it consistent with the
            labels that are later scored, so the prompt and the candidates agree.

    Returns:
        The prompt string. It ends with ``"Department:"`` so that the department name
        is the natural continuation.
    """
    subject = "" if subject is None else subject
    # Truncate first, then strip, so the cut never leaves trailing whitespace.
    body_truncated = ("" if body is None else body)[:max_body_chars].strip()
    department_names = ", ".join(_DEFAULT_DEPARTMENTS if departments is None else departments)
    return (
        "Classify the following customer support email into exactly one department.\n"
        f"Departments: {department_names}\n\n"
        f"Subject: {subject}\n"
        f"Body: {body_truncated}\n\n"
        "Department:"
    )


def score_labels_gpt2(model, tokenizer, prompt, label_list, device):
    """Return the label whose tokens are most likely as a continuation of ``prompt``.

    Each candidate is scored by the mean log-probability that the causal LM assigns to
    the tokens of ``" " + label`` placed directly after the prompt. Averaging rather than
    summing over tokens keeps multi-token labels from being penalised just for being
    longer.

    Args:
        model: Causal LM. It is called as ``model(input_ids)`` and must return an object
            whose ``.logits`` is indexed as (batch, seq_len, vocab_size). If
            ``model.config.n_positions`` exists, it is used as the context-window limit.
        tokenizer: Tokenizer providing ``encode``.
        prompt: Prompt text, normally ending with the cue the label should follow.
        label_list: Candidate label strings.
        device: Device on which the input ids are placed.

    Returns:
        The highest-scoring entry of ``label_list``.

    Raises:
        ValueError: If ``label_list`` is empty.
    """
    if not label_list:
        raise ValueError("label_list must contain at least one label")

    prompt_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
    # Leading space: GPT-2's BPE attaches a preceding space to the following word, so
    # this matches how the label would be tokenized after "Department:" in running text.
    label_token_ids = {
        label: tokenizer.encode(" " + label, add_special_tokens=False) for label in label_list
    }

    # GPT-2 has a fixed context window (config.n_positions, 1024 for the public
    # checkpoints), and longer inputs fail in the position-embedding lookup. If the
    # prompt plus the longest label would not fit, drop tokens from the *start* of the
    # prompt so the cue right before the label is kept. All labels share the same
    # truncated context, so their scores remain comparable.
    max_positions = getattr(getattr(model, "config", None), "n_positions", None)
    longest_label = max(len(ids) for ids in label_token_ids.values())
    if max_positions is not None and prompt_ids.shape[1] + longest_label > max_positions:
        keep = max(1, max_positions - longest_label)
        prompt_ids = prompt_ids[:, -keep:]
    prompt_len = prompt_ids.shape[1]

    scores = {}
    for label, label_ids in label_token_ids.items():
        label_tensor = torch.tensor(label_ids, device=device)
        full_ids = torch.cat([prompt_ids, label_tensor.unsqueeze(0)], dim=1)

        with torch.no_grad():
            logits = model(full_ids).logits  # (1, seq_len, vocab_size)

        # Logits at position t predict token t + 1. Label tokens sit at positions
        # prompt_len .. prompt_len + n - 1, so they are predicted by rows
        # prompt_len - 1 .. prompt_len + n - 2. Only those rows go through log_softmax,
        # instead of the whole sequence.
        n_tokens = len(label_ids)
        step_logits = logits[0, prompt_len - 1 : prompt_len - 1 + n_tokens]
        log_probs = torch.log_softmax(step_logits, dim=-1)
        positions = torch.arange(n_tokens, device=log_probs.device)
        token_log_probs = log_probs[positions, label_tensor.to(log_probs.device)]
        scores[label] = token_log_probs.mean().item()

    return max(scores, key=scores.get)


def evaluate_gpt2_prompting(model, tokenizer, dataset, label_list, device, desc="Evaluating"):
    """Run prompting-based evaluation on a dataset split.

    Each example's ``subject``/``body`` is turned into a prompt with ``build_prompt`` and
    routed with ``score_labels_gpt2``. The gold label is read from ``example["queue"]``.

    Args:
        model: Frozen causal LM (put into eval mode here).
        tokenizer: Matching tokenizer.
        dataset: Iterable of examples with ``subject``, ``body`` and ``queue`` fields.
        label_list: Candidate department labels.
        device: Device the model lives on (``torch.device`` or string).
        desc: Progress-bar label.

    Returns:
        ``(predictions, true_labels, elapsed_seconds, peak_memory_bytes)``.

        ``peak_memory_bytes`` depends on the device:
        - On CUDA, it is the peak bytes allocated by PyTorch on ``device`` during the loop
          (``torch.cuda.max_memory_allocated``). This includes the already-resident
          model weights.
        - Otherwise, it is the peak traced by ``tracemalloc``, which only sees blocks
          allocated through Python's allocator. PyTorch allocates tensor storage with its
          own native allocator, so treat the CPU figure as a lower bound on true usage.
    """
    model.eval()
    predictions = []
    true_labels = []
    use_cuda = torch.device(device).type == "cuda" and torch.cuda.is_available()

    if use_cuda:
        torch.cuda.reset_peak_memory_stats(device)
    tracemalloc.start()
    # perf_counter is a monotonic high-resolution clock meant for measuring durations,
    # unlike time.time, which can jump when the system clock is adjusted.
    start_time = time.perf_counter()
    try:
        for example in tqdm(dataset, desc=desc):
            prompt = build_prompt(example["subject"], example["body"])
            predictions.append(score_labels_gpt2(model, tokenizer, prompt, label_list, device))
            true_labels.append(example["queue"])
        if use_cuda:
            # Make sure all queued GPU work is counted in the timing.
            torch.cuda.synchronize(device)
        elapsed = time.perf_counter() - start_time
        _, traced_peak = tracemalloc.get_traced_memory()
    finally:
        # Always stop tracing, even on an error or KeyboardInterrupt. Otherwise
        # tracemalloc stays active and slows every later allocation in the session.
        tracemalloc.stop()

    peak_memory = torch.cuda.max_memory_allocated(device) if use_cuda else traced_peak
    return predictions, true_labels, elapsed, peak_memory

### 2.1 GPT-2 Prompting

In [None]:
# Load the frozen GPT-2 checkpoint. eval() makes sure dropout is off while scoring.
gpt2_tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
gpt2_model = GPT2LMHeadModel.from_pretrained("gpt2").to(device)
gpt2_model.eval()

# Smoke test: route the first test email and compare it with its gold department.
first_test = test_ds[0]
sample_prompt = build_prompt(first_test["subject"], first_test["body"])
sample_pred = score_labels_gpt2(gpt2_model, gpt2_tokenizer, sample_prompt, label_list, device)
print(f"Sample prediction: {sample_pred}")
print(f"True label: {first_test['queue']}")

In [None]:
# Evaluate GPT-2 on the full test set
gpt2_preds, gpt2_true, gpt2_time, gpt2_mem = evaluate_gpt2_prompting(
    gpt2_model, gpt2_tokenizer, test_ds, label_list, device, desc="GPT-2 Prompting"
)

gpt2_acc = accuracy_score(gpt2_true, gpt2_preds)
print(f"\nGPT-2 Prompting Accuracy: {gpt2_acc:.4f}")
print(f"Time: {gpt2_time:.1f}s | Peak Memory: {gpt2_mem / 1e6:.1f} MB")
print("\nClassification Report:")
# Pass labels= explicitly. Without it, sklearn uses the *sorted* set of labels found in
# y_true/y_pred and pairs target_names with that order. Rows are then mislabeled whenever
# label_list is not alphabetically sorted, and the call fails if the number of observed
# labels differs from len(target_names).
# zero_division=0 reports 0 (without a warning) for classes the frozen model never predicts.
print(classification_report(
    gpt2_true, gpt2_preds, labels=label_list, target_names=label_list, zero_division=0
))

### 2.2 DistilGPT-2 Prompting

In [None]:
import gc

# Free GPT-2 *before* loading DistilGPT-2 so both models are never resident at the same
# time. The globals() check keeps this cell re-runnable: on a second run gpt2_model is
# already deleted, and a bare `del` would raise NameError.
if "gpt2_model" in globals():
    del gpt2_model
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()

# Load DistilGPT-2 (frozen). It uses the same GPT-2 tokenizer class, loaded from the
# "distilgpt2" checkpoint.
distilgpt2_tokenizer = GPT2Tokenizer.from_pretrained("distilgpt2")
distilgpt2_model = GPT2LMHeadModel.from_pretrained("distilgpt2").to(device)
distilgpt2_model.eval()

# Evaluate DistilGPT-2 on the full test set
distilgpt2_preds, distilgpt2_true, distilgpt2_time, distilgpt2_mem = evaluate_gpt2_prompting(
    distilgpt2_model, distilgpt2_tokenizer, test_ds, label_list, device, desc="DistilGPT-2 Prompting"
)

distilgpt2_acc = accuracy_score(distilgpt2_true, distilgpt2_preds)
print(f"\nDistilGPT-2 Prompting Accuracy: {distilgpt2_acc:.4f}")
print(f"Time: {distilgpt2_time:.1f}s | Peak Memory: {distilgpt2_mem / 1e6:.1f} MB")
print("\nClassification Report:")
# labels= pins target_names to label_list order; otherwise sklearn pairs the names with
# the sorted observed labels and can mislabel rows. zero_division=0 silences warnings for
# classes that are never predicted (their precision is reported as 0).
print(classification_report(
    distilgpt2_true, distilgpt2_preds, labels=label_list, target_names=label_list, zero_division=0
))

---
## 3. Agent 2: Routing with Fine-Tuning (LoRA on GPT-2)

Fine-tune GPT-2 using LoRA, then evaluate with the same prompt as Agent 1.

*Implementation coming in feature/agent2-finetuning*

---
## 4. Agent 3: Routing with DistilBERT Classifier

Fine-tune DistilBERT for sequence classification over the 5 department labels.

*Implementation coming in feature/agent3-classifier*

---
## 5. Results Comparison

Compare all agents on the test set in terms of accuracy, computational time, and memory usage.

*Implementation coming in feature/comparison*