# Approach 4: Contrastive Classification (Ham vs Smish)

Train a RoBERTa-based binary classifier to detect whether an SMS is legitimate (HAM) or an SMS-phishing scam (SMISH).

## Overview
- **Model**: RoBERTa-base for Sequence Classification
- **Task**: Binary Classification (HAM vs SMISH)
- **Use Case**: First-line defense - quick ham/smish detection before claim extraction
- **Advantages**: Simple, fast, direct answer to "is this a scam?"

## Key Difference
This is the **only approach with direct classification**. Other approaches focus on **claim extraction**.

## Setup Instructions
1. Upload `claim_annotations_2000.json` when the upload cell prompts for it in Colab
2. Run all cells in order
3. The trained model will classify messages as HAM or SMISH

## 1. Environment Setup

In [None]:
# Install required packages (-q keeps pip's output quiet).
# The leading "!" is notebook shell syntax, so this line only runs inside a
# Jupyter/Colab cell, not as plain Python.
!pip install -q transformers datasets accelerate scikit-learn torch

In [None]:
# Import libraries
import json
import torch
import numpy as np
from pathlib import Path
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score, f1_score
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    TrainingArguments,
    Trainer
)
import warnings
# Blanket filter: suppress every Python warning raised from here on
warnings.filterwarnings('ignore')

# Report the runtime: PyTorch build and whether a CUDA GPU can be used
cuda_ready = torch.cuda.is_available()
print(f"‚úÖ PyTorch version: {torch.__version__}")
print(f"‚úÖ CUDA available: {cuda_ready}")
if cuda_ready:
    print(f"   Device: {torch.cuda.get_device_name(0)}")

In [None]:
# Upload data file
from google.colab import files

print("üìÅ Please upload 'claim_annotations_2000.json'")
# files.upload() opens the browser file picker and returns a dict keyed by
# the name of each uploaded file.
uploaded = files.upload()

# If nothing was uploaded (e.g. the dialog was dismissed) there is no file
# name to pick; fail with a clear message instead of an opaque IndexError.
if not uploaded:
    raise RuntimeError(
        "No file was uploaded - re-run this cell and select 'claim_annotations_2000.json'"
    )

# The first uploaded file is used as the dataset
data_file = next(iter(uploaded))
print(f"‚úÖ Uploaded: {data_file}")

## 2. Data Loading

In [None]:
# Load data: read the uploaded file as UTF-8 text and parse it as JSON
data = json.loads(Path(data_file).read_text(encoding='utf-8'))

print(f"‚úÖ Loaded {len(data)} examples")

# Build parallel lists of message texts and binary labels (0=HAM, 1=SMISH)
texts, labels = [], []

for record in data:
    message = record['data']['text']

    # Claim results of the first annotation, if there is one.
    # An absent/empty annotation list or result list means "no claims".
    first_results = ()
    if 'annotations' in record and len(record['annotations']) > 0:
        first_annotation = record['annotations'][0]
        if 'result' in first_annotation:
            first_results = first_annotation['result']
    has_claims = len(first_results) > 0

    # An explicit 'ham' label in the metadata overrides the annotation check
    marked_ham = 'meta' in record and record['meta'].get('label') == 'ham'

    texts.append(message)
    # SMISH only when claims were annotated and the metadata does not say ham
    labels.append(1 if has_claims and not marked_ham else 0)  # 0=HAM, 1=SMISH

# Show distribution
from collections import Counter


def _preview(message, limit=80):
    """Return `message` cut to `limit` characters, adding '...' only when it was cut."""
    return message if len(message) <= limit else message[:limit] + "..."


# The percentages below divide by the dataset size, so stop early with a
# clear message instead of a ZeroDivisionError when no examples were loaded.
if not labels:
    raise ValueError("Dataset is empty: no examples to summarise")

label_dist = Counter(labels)
total = len(labels)

print(f"\nüìä Dataset distribution:")
print(f"   HAM (0):   {label_dist[0]} messages ({label_dist[0]/total*100:.1f}%)")
print(f"   SMISH (1): {label_dist[1]} messages ({label_dist[1]/total*100:.1f}%)")

# Show examples: the first three messages of each class
print(f"\nüìù Example HAM messages:")
ham_examples = [t for t, l in zip(texts, labels) if l == 0][:3]
for i, ex in enumerate(ham_examples, 1):
    print(f"   {i}. {_preview(ex)}")

print(f"\n‚ö†Ô∏è  Example SMISH messages:")
smish_examples = [t for t, l in zip(texts, labels) if l == 1][:3]
for i, ex in enumerate(smish_examples, 1):
    print(f"   {i}. {_preview(ex)}")

In [None]:
# Two-stage stratified split (class balance is kept in every subset).
# Stage 1 holds out 15% of all examples as the test set.
rest_texts, test_texts, rest_labels, test_labels = train_test_split(
    texts,
    labels,
    test_size=0.15,
    stratify=labels,
    random_state=42,
)

# Stage 2 takes 0.176 of the remaining 85% (~15% of the full dataset) as the
# validation set, leaving roughly 70% for training.
train_texts, val_texts, train_labels, val_labels = train_test_split(
    rest_texts,
    rest_labels,
    test_size=0.176,
    stratify=rest_labels,
    random_state=42,
)

# Report how many examples ended up in each subset
n_train, n_val, n_test = len(train_texts), len(val_texts), len(test_texts)
print(f"Dataset split:")
print(f"  Train: {n_train} examples")
print(f"  Val:   {n_val} examples")
print(f"  Test:  {n_test} examples")

## 3. Tokenization

In [None]:
# Load tokenizer
MODEL_NAME = "roberta-base"  # or "distilroberta-base"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

print(f"‚úÖ Loaded tokenizer: {MODEL_NAME}")

# Tokenize every split with the same settings: truncate to at most 128
# tokens and pad (padding=True) within each call.
encode_options = dict(truncation=True, padding=True, max_length=128)
train_encodings, val_encodings, test_encodings = (
    tokenizer(split_texts, **encode_options)
    for split_texts in (train_texts, val_texts, test_texts)
)

print(f"‚úÖ Tokenization complete")

In [None]:
# Create PyTorch datasets
from torch.utils.data import Dataset

class SMSDataset(Dataset):
    """Map-style dataset pairing tokenizer encodings with integer labels.

    `encodings` is a mapping from field name to a per-example sequence
    (indexable by example position); `labels` is a sequence with one label
    per example.
    """

    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __len__(self):
        # One item per label
        return len(self.labels)

    def __getitem__(self, idx):
        """Return example `idx` as a dict of tensors, with its label under 'labels'."""
        sample = {}
        for field, values in self.encodings.items():
            sample[field] = torch.tensor(values[idx])
        sample['labels'] = torch.tensor(self.labels[idx])
        return sample

# Wrap each split's encodings and labels in an SMSDataset
train_dataset, val_dataset, test_dataset = (
    SMSDataset(split_encodings, split_labels)
    for split_encodings, split_labels in (
        (train_encodings, train_labels),
        (val_encodings, val_labels),
        (test_encodings, test_labels),
    )
)

print(f"‚úÖ Created PyTorch datasets")

## 4. Model Training

In [None]:
# Load model with a 2-way classification head; the label maps are derived
# from a single id -> name table so the two directions cannot disagree.
ID_TO_LABEL = {0: "HAM", 1: "SMISH"}
model = AutoModelForSequenceClassification.from_pretrained(
    MODEL_NAME,
    num_labels=len(ID_TO_LABEL),
    id2label=ID_TO_LABEL,
    label2id={name: idx for idx, name in ID_TO_LABEL.items()},
)

# Total number of scalar parameters across all model tensors
param_count = sum(weights.numel() for weights in model.parameters())
print(f"‚úÖ Loaded model: {MODEL_NAME}")
print(f"   Parameters: {param_count:,}")

In [None]:
# Define metrics
def compute_metrics(pred):
    """Return accuracy and weighted F1 for one evaluation pass.

    `pred` must expose `label_ids` (true labels) and `predictions` (per-class
    scores); the predicted class is the argmax over the last axis.
    """
    y_true = pred.label_ids
    y_pred = pred.predictions.argmax(-1)
    return {
        'accuracy': accuracy_score(y_true, y_pred),
        'f1': f1_score(y_true, y_pred, average='weighted'),
    }

In [None]:
# Training arguments
training_args = TrainingArguments(
    # Output and reporting: local checkpoints only, no external trackers
    output_dir="./ham-smish-classifier",
    logging_steps=50,
    report_to="none",
    push_to_hub=False,
    # Evaluate and checkpoint once per epoch so the checkpoint with the best
    # validation "f1" (the key returned by compute_metrics) can be reloaded
    # when training ends
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="f1",
    # Optimisation schedule
    num_train_epochs=3,
    learning_rate=2e-5,
    weight_decay=0.01,
    warmup_ratio=0.1,
    # Batch sizes per device
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
)

# Create trainer: trains on the train split, evaluates on the validation split
trainer = Trainer(
    args=training_args,
    model=model,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    compute_metrics=compute_metrics,
)

print("‚úÖ Trainer initialized")

In [None]:
# Train the model; the value returned by train() is kept for later inspection
print("üöÄ Starting training...")
train_output = trainer.train()
print("‚úÖ Training complete!")

## 5. Evaluation

In [None]:
# Evaluate on test set
print("üìä Evaluating on test set...")
predictions = trainer.predict(test_dataset)
preds = predictions.predictions.argmax(-1)  # predicted class id per message

# Classification report
# labels=[0, 1] pins the class set: without it scikit-learn infers the classes
# from the data, and if only one class shows up in test_labels/preds the
# two-name target_names list no longer matches and the report call fails.
print("\n" + "="*60)
print("CLASSIFICATION REPORT")
print("="*60)
print(classification_report(test_labels, preds, labels=[0, 1], target_names=["HAM", "SMISH"]))

# Confusion matrix
# Pinned to 2x2 for the same reason, so the cm[1][...] lookups below are
# always valid. Rows are actual classes, columns are predicted classes.
cm = confusion_matrix(test_labels, preds, labels=[0, 1])
print("\nConfusion Matrix:")
print("                Predicted")
print("              HAM    SMISH")
print(f"Actual HAM    {cm[0][0]:4}   {cm[0][1]:4}")
print(f"       SMISH  {cm[1][0]:4}   {cm[1][1]:4}")

## 6. Inference Examples

In [None]:
def classify_message(text, model, tokenizer):
    """
    Classify a message as HAM or SMISH.

    Args:
        text: The message text to classify.
        model: Sequence-classification model whose output exposes `.logits`
            with class 0 = HAM and class 1 = SMISH.
        tokenizer: Callable that turns `text` into a mapping of PyTorch
            tensors accepted by `model` as keyword arguments.

    Returns:
        dict with 'prediction' ('HAM' or 'SMISH'), 'confidence' (probability
        of the predicted class), 'ham_prob' and 'smish_prob'.
    """
    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=128)

    # The tokenizer returns CPU tensors, but the model may sit on another
    # device (e.g. on the GPU after training). Move the inputs to the model's
    # device so the forward pass does not fail with a device mismatch.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        inputs = {name: tensor.to(first_param.device) for name, tensor in inputs.items()}

    # Inference mode: eval() switches the model out of training mode and
    # no_grad() skips gradient tracking.
    model.eval()
    with torch.no_grad():
        outputs = model(**inputs)

    # Single message in the batch, so take row 0 of the class probabilities
    probs = torch.softmax(outputs.logits, dim=1)[0]
    pred = torch.argmax(probs).item()

    return {
        'prediction': 'HAM' if pred == 0 else 'SMISH',
        'confidence': probs[pred].item(),
        'ham_prob': probs[0].item(),
        'smish_prob': probs[1].item()
    }

# Test with examples: a mix of scam-style and ordinary messages
test_messages = [
    "Your Amazon package is delayed. Click here urgently to reschedule.",
    "URGENT: Your PayPal account suspended. Verify identity now.",
    # Pound sign restored from its mis-encoded form; this text is fed to the model
    "You've won £5000! Call 0800-123-456 to claim your prize.",
    "Hi, are we still meeting for lunch today?",
    "Your appointment is confirmed for tomorrow at 3pm.",
    "FINAL NOTICE: Tax debt must be paid immediately or face legal action."
]

print("üîç Testing message classification:\n")
print("="*70)
for i, msg in enumerate(test_messages, 1):
    result = classify_message(msg, model, tokenizer)

    emoji = "‚úÖ" if result['prediction'] == 'HAM' else "‚ö†Ô∏è"
    # Show at most 60 characters, marking the cut with "..." only when the
    # message was actually shortened
    preview = msg if len(msg) <= 60 else msg[:60] + "..."
    print(f"\n{i}. {emoji} {result['prediction']} (confidence: {result['confidence']:.2%})")
    print(f"   Message: {preview}")
    print(f"   Probabilities: HAM={result['ham_prob']:.2%}, SMISH={result['smish_prob']:.2%}")
    print("-"*70)

## 7. Save Model

In [None]:
# Save model and tokenizer side by side in the same directory
for artifact in (model, tokenizer):
    artifact.save_pretrained("./ham-smish-classifier-final")

print("‚úÖ Model saved to ./ham-smish-classifier-final/")

# Download model (optional)
# !zip -r ham-smish-classifier.zip ./ham-smish-classifier-final
# from google.colab import files
# files.download('ham-smish-classifier.zip')

## 8. Results Summary

In [None]:
# Print final summary
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

# Overall test metrics; precision, recall and F1 use the 'weighted' average
acc = accuracy_score(test_labels, preds)
prec, rec, f1 = (
    metric(test_labels, preds, average='weighted')
    for metric in (precision_score, recall_score, f1_score)
)

divider = "="*60
print(divider)
print("TRAINING SUMMARY - APPROACH 4: CONTRASTIVE CLASSIFICATION")
print(divider)
print(f"Model: {MODEL_NAME}")
print(f"Task: Binary Classification (HAM vs SMISH)")
print(f"\nDataset:")
print(f"  Train: {len(train_dataset)} examples")
print(f"  Val:   {len(val_dataset)} examples")
print(f"  Test:  {len(test_dataset)} examples")
print(f"\nTest Metrics:")
print(f"  Accuracy:  {acc:.3f}")
print(f"  Precision: {prec:.3f}")
print(f"  Recall:    {rec:.3f}")
print(f"  F1 Score:  {f1:.3f}")
print(divider)

# Class-specific metrics: each class in turn is treated as the positive label
# (0 = HAM, 1 = SMISH)
ham_prec, smish_prec = (
    precision_score(test_labels, preds, pos_label=class_id) for class_id in (0, 1)
)
ham_rec, smish_rec = (
    recall_score(test_labels, preds, pos_label=class_id) for class_id in (0, 1)
)

print(f"\nPer-Class Performance:")
print(f"  HAM:   Precision={ham_prec:.3f}, Recall={ham_rec:.3f}")
print(f"  SMISH: Precision={smish_prec:.3f}, Recall={smish_rec:.3f}")
print(divider)