# Approach 2: Claim-Phrase NER

Train a RoBERTa-based NER model to directly extract claim phrases from SMS messages.

## Overview
- **Model**: RoBERTa-base for Token Classification
- **Task**: Extract 12 types of claims using BIO tagging
- **Labels**: IDENTITY_CLAIM, DELIVERY_CLAIM, FINANCIAL_CLAIM, ACCOUNT_CLAIM, URGENCY_CLAIM, ACTION_CLAIM, VERIFICATION_CLAIM, SECURITY_CLAIM, REWARD_CLAIM, LEGAL_CLAIM, SOCIAL_CLAIM, CREDENTIALS_CLAIM
- **Advantages**: Direct semantic capture, robust to variations, handles implicit claims

## Setup Instructions
1. Upload `claim_annotations_2000.json` to Colab
2. Run all cells in order
3. The trained model is saved locally to `./claim-ner-final/` (section 8); mount Google Drive (section 1) if you want to keep a copy

## 1. Environment Setup

In [None]:
# Install required packages (IPython shell magic; only valid inside a notebook).
# NOTE(review): versions are unpinned, so whatever releases are current get
# installed; keyword arguments used later (e.g. `eval_strategy` in
# TrainingArguments) may differ across transformers versions.
!pip install -q transformers datasets accelerate seqeval scikit-learn torch

In [None]:
# Import libraries
import json
import torch
import numpy as np
from pathlib import Path
from dataclasses import dataclass
from typing import List, Dict, Tuple
from sklearn.model_selection import train_test_split
from transformers import (
    AutoTokenizer,
    AutoModelForTokenClassification,
    TrainingArguments,
    Trainer,
    DataCollatorForTokenClassification
)
from seqeval.metrics import classification_report, f1_score, precision_score, recall_score
import warnings
# Silence all Python warnings for cleaner notebook output.
# NOTE(review): this also hides library deprecation warnings; narrow the
# filter if something behaves unexpectedly after a library upgrade.
warnings.filterwarnings('ignore')

# Report the runtime so GPU availability is visible before training.
print(f"✅ PyTorch version: {torch.__version__}")
print(f"✅ CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"   Device: {torch.cuda.get_device_name(0)}")

In [None]:
# Mount Google Drive (optional - for saving models)
# from google.colab import drive
# drive.mount('/content/drive')

In [None]:
# Upload data file
from google.colab import files

print("📁 Please upload 'claim_annotations_2000.json'")
uploaded = files.upload()  # dict: uploaded filename -> file contents
if not uploaded:
    # Cancelling the upload dialog returns nothing; fail with a clear message
    # instead of an IndexError further down.
    raise RuntimeError("No file was uploaded; re-run this cell and select the annotations JSON.")
# If several files were selected, the first one is used.
data_file = next(iter(uploaded))
print(f"✅ Uploaded: {data_file}")

## 2. Define Claim Types and Labels

In [None]:
# Define the 12 claim types
CLAIM_TYPES = [
    'IDENTITY_CLAIM',      # "We are Amazon/PayPal/IRS"
    'DELIVERY_CLAIM',      # "Your package is delayed/stuck"
    'FINANCIAL_CLAIM',     # "You won $5000 / Prize available"
    'ACCOUNT_CLAIM',       # "Your account is suspended/locked"
    'URGENCY_CLAIM',       # "Act now / Expires tonight"
    'ACTION_CLAIM',        # "Click here / Call immediately"
    'VERIFICATION_CLAIM',  # "Verify your identity / Confirm details"
    'SECURITY_CLAIM',      # "Suspicious activity / Unauthorized access"
    'REWARD_CLAIM',        # "Loyalty bonus / Cashback available"
    'LEGAL_CLAIM',         # "Legal action / Tax penalty / Court summons"
    'SOCIAL_CLAIM',        # "Friend/family needs help"
    'CREDENTIALS_CLAIM'    # "Update password / Reset PIN"
]

# Create BIO labels: 'O' (outside any claim) comes first, so its id is 0;
# then a B- (beginning) and an I- (inside) tag for every claim type.
labels = ['O']
for claim_type in CLAIM_TYPES:
    labels.extend((f'B-{claim_type}', f'I-{claim_type}'))

label2id = {label: idx for idx, label in enumerate(labels)}
id2label = {idx: label for label, idx in label2id.items()}

print(f"Total labels: {len(labels)}")
print(f"\nLabel structure:")
print(f"  - O (outside): 1 label")
print(f"  - B-/I- tags: {len(CLAIM_TYPES)} × 2 = {len(CLAIM_TYPES)*2} labels")
print(f"\nFirst 10 labels: {labels[:10]}")

## 3. Data Loading and Preprocessing

In [None]:
def convert_to_bio_format(text, claim_spans):
    """
    Project character-level claim spans onto whitespace-delimited words as BIO tags.

    Args:
        text: The SMS message text.
        claim_spans: List of dicts with 'start' / 'end' character offsets
            (end exclusive) and 'label' (claim type without a BIO prefix).
            A 'text' key may be present but is not used here.

    Returns:
        (words, bio_labels): ``text.split()`` and one tag per word.  A word
        that overlaps a span gets ``B-<label>`` when the span starts inside
        it or when the previous word is not part of the same claim type,
        otherwise ``I-<label>``.  When a word overlaps several spans, the
        first one in ``claim_spans`` order wins.
    """
    words = text.split()
    tags = ['O'] * len(words)
    cursor = 0  # search position, so repeated words are located in order

    for i, word in enumerate(words):
        begin = text.find(word, cursor)
        if begin == -1:
            continue
        finish = begin + len(word)
        cursor = finish

        hit = next(
            (s for s in claim_spans if finish > s['start'] and begin < s['end']),
            None,
        )
        if hit is None:
            continue

        claim = hit['label']
        starts_here = begin <= hit['start'] < finish
        continues_prev = i > 0 and tags[i - 1] in (f'B-{claim}', f'I-{claim}')
        tags[i] = f'I-{claim}' if (not starts_here and continues_prev) else f'B-{claim}'

    return words, tags

# Sanity-check the BIO conversion on a hand-annotated message
test_text = "Your Amazon package is delayed. Click here urgently."
test_spans = [
    dict(text='Amazon', start=5, end=11, label='IDENTITY_CLAIM'),
    dict(text='package is delayed', start=12, end=30, label='DELIVERY_CLAIM'),
    dict(text='Click here', start=32, end=42, label='ACTION_CLAIM'),
    dict(text='urgently', start=43, end=51, label='URGENCY_CLAIM'),
]

test_words, test_labels = convert_to_bio_format(test_text, test_spans)
print("Test BIO conversion:")
for word, label in zip(test_words, test_labels):
    print("  {:15} -> {}".format(word, label))

In [None]:
# Load and convert annotations
def load_claim_data(json_file):
    """
    Read the annotation JSON and convert every annotated message to word-level BIO.

    Layout read below: a JSON list whose entries hold the message at
    ``entry['data']['text']`` and annotations at ``entry['annotations']``.
    Only the first annotation of an entry is used.  Each of its ``result``
    items whose ``value['labels']`` is non-empty becomes one claim span,
    labelled with the first label.  Entries without annotations are skipped.
    Annotated entries with no spans are kept, with all-'O' labels.

    Args:
        json_file: Path to the JSON file (read as UTF-8).

    Returns:
        List of dicts with 'id', 'text', 'tokens', 'labels' (one BIO tag per
        token) and 'claim_spans'.
    """
    with open(json_file, 'r', encoding='utf-8') as fh:
        entries = json.load(fh)

    examples = []
    for entry in entries:
        text = entry['data']['text']

        annotation_sets = entry.get('annotations')
        if not annotation_sets:
            continue
        results = annotation_sets[0].get('result') or []

        claim_spans = []
        for result in results:
            value = result.get('value', {})
            span_labels = value.get('labels', [])
            if not span_labels:
                continue
            claim_spans.append({
                'text': value.get('text', ''),
                'start': value.get('start', 0),
                'end': value.get('end', 0),
                'label': span_labels[0],
            })

        tokens, bio_labels = convert_to_bio_format(text, claim_spans)
        examples.append(dict(
            id=entry.get('id'),
            text=text,
            tokens=tokens,
            labels=bio_labels,
            claim_spans=claim_spans,
        ))

    return examples

# Load data
print("Loading data...")
examples = load_claim_data(data_file)
print(f"✅ Loaded {len(examples)} examples")
if not examples:
    # Entries without annotations are skipped by load_claim_data; fail clearly
    # here rather than with an IndexError on examples[0] below.
    raise ValueError(f"No annotated examples found in {data_file}; check that entries have 'annotations'.")

# Show first example
print("\n📝 First example:")
ex = examples[0]
print(f"  Text: {ex['text'][:80]}...")
print(f"  Tokens: {ex['tokens'][:5]}...")
print(f"  Labels: {ex['labels'][:5]}...")
print(f"  Claims: {len(ex['claim_spans'])} spans")

In [None]:
# Split data: 70% train / 15% validation / 15% test (fixed seed for reproducibility)
TEST_FRACTION = 0.15
VAL_FRACTION = 0.15

train_examples, test_examples = train_test_split(
    examples, test_size=TEST_FRACTION, random_state=42
)
# Request the validation set as an absolute count so it is VAL_FRACTION of the
# *full* dataset. A fractional test_size here would be relative to the remaining
# 85%; the previously hard-coded 0.176 only approximated 0.15 / 0.85.
n_val = max(1, round(len(examples) * VAL_FRACTION))
train_examples, val_examples = train_test_split(
    train_examples, test_size=n_val, random_state=42
)

print("Dataset split:")
print(f"  Train: {len(train_examples)} examples")
print(f"  Val:   {len(val_examples)} examples")
print(f"  Test:  {len(test_examples)} examples")

# Count labels (word-level BIO tags) in the training set
from collections import Counter
all_labels = [tag for train_ex in train_examples for tag in train_ex['labels']]

label_counts = Counter(all_labels)
print(f"\n📊 Label distribution in training set:")
for label, count in label_counts.most_common(15):
    print(f"  {label:25} : {count:5} tokens")

## 4. Tokenization and Dataset Preparation

In [None]:
# Load tokenizer
MODEL_NAME = "roberta-base"  # or "distilroberta-base" for faster training
# RoBERTa's byte-level BPE folds a leading space into the token, so
# add_prefix_space=True makes the first word tokenize like every other word.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, add_prefix_space=True)

print(f"✅ Loaded tokenizer: {MODEL_NAME}")
print(f"   Vocab size: {tokenizer.vocab_size}")

In [None]:
def _char_level_labels(text, words, word_labels):
    """Spread word-level BIO labels over the characters of ``text``.

    Returns ``(char_labels, char_word_start)``. ``char_labels`` holds the label
    of every character ('O' between words). ``char_word_start`` holds, for each
    character inside a word, the offset at which that word starts (-1 elsewhere).
    Words are located left to right with ``str.find``, the same way
    ``convert_to_bio_format`` locates them.
    """
    char_labels = ['O'] * len(text)
    char_word_start = [-1] * len(text)
    cursor = 0
    for word, label in zip(words, word_labels):
        word_start = text.find(word, cursor)
        if word_start == -1:
            continue
        word_end = word_start + len(word)
        for j in range(word_start, word_end):
            char_labels[j] = label
            char_word_start[j] = word_start
        cursor = word_end
    return char_labels, char_word_start


def tokenize_and_align_labels(examples, max_length=128, *, tok=None, label_map=None):
    """
    Tokenize examples and align their word-level BIO labels with subword tokens.

    Each subword token takes the label of its first character. A ``B-`` tag is
    kept only on the token that begins its word. Later pieces of the same word
    (subwords, attached punctuation) get the matching ``I-`` tag. Otherwise one
    word split into several subwords would be read as several separate entities,
    both in training and by seqeval. Special and padding tokens (offset ``(0, 0)``)
    get -100 so the loss and the metrics ignore them.

    Args:
        examples: List of dicts with 'text', 'tokens' (words) and 'labels'
            (one BIO tag per word).
        max_length: Every sequence is padded/truncated to this many tokens.
        tok: Tokenizer to use; defaults to the notebook-level ``tokenizer``.
            It must support ``return_offsets_mapping``.
        label_map: Label -> id mapping; defaults to the notebook-level
            ``label2id``. Labels missing from it map to id 0 ('O' in ``label2id``).

    Returns:
        The tokenizer's batch encoding without 'offset_mapping' and with a
        'labels' entry: one list of label ids per example.
    """
    if tok is None:
        tok = tokenizer
    if label_map is None:
        label_map = label2id

    tokenized_inputs = tok(
        [ex['text'] for ex in examples],
        padding='max_length',
        truncation=True,
        max_length=max_length,
        return_offsets_mapping=True,
        is_split_into_words=False
    )

    aligned_labels = []
    for i, example in enumerate(examples):
        char_labels, char_word_start = _char_level_labels(
            example['text'], example['tokens'], example['labels']
        )

        token_label_ids = []
        for start, end in tokenized_inputs['offset_mapping'][i]:
            if start == 0 and end == 0:
                token_label_ids.append(-100)  # special / padding token
                continue
            if start >= len(char_labels):
                token_label_ids.append(0)  # O label
                continue
            label = char_labels[start]
            # Continuation pieces of a B- word must be I-, not a new B-.
            if label.startswith('B-') and start != char_word_start[start]:
                label = 'I-' + label[2:]
            token_label_ids.append(label_map.get(label, 0))

        aligned_labels.append(token_label_ids)

    # Remove offset_mapping (not needed for training)
    tokenized_inputs.pop('offset_mapping')
    tokenized_inputs['labels'] = aligned_labels

    return tokenized_inputs

# Tokenize datasets
print("Tokenizing datasets...")
train_tokenized = tokenize_and_align_labels(train_examples)
val_tokenized = tokenize_and_align_labels(val_examples)
test_tokenized = tokenize_and_align_labels(test_examples)
print("✅ Tokenization complete")

In [None]:
# Create PyTorch datasets
from torch.utils.data import Dataset

class NERDataset(Dataset):
    """Map-style dataset over a tokenized batch encoding.

    ``encodings`` maps field names (e.g. 'input_ids', 'labels') to
    per-example sequences. Item ``idx`` is a dict of those fields, each
    converted to a tensor.
    """

    def __init__(self, encodings):
        self.encodings = encodings

    def __len__(self):
        return len(self.encodings['input_ids'])

    def __getitem__(self, idx):
        item = {}
        for field_name in self.encodings:
            item[field_name] = torch.tensor(self.encodings[field_name][idx])
        return item

train_dataset = NERDataset(train_tokenized)
val_dataset = NERDataset(val_tokenized)
test_dataset = NERDataset(test_tokenized)

print(f"✅ Created PyTorch datasets")
print(f"   Train: {len(train_dataset)}")
print(f"   Val: {len(val_dataset)}")
print(f"   Test: {len(test_dataset)}")

## 5. Model Training

In [None]:
# Load model: a fresh token-classification head sized to the BIO label set,
# with the label maps stored in the model config.
model = AutoModelForTokenClassification.from_pretrained(
    MODEL_NAME,
    num_labels=len(labels),
    id2label=id2label,
    label2id=label2id
)

print(f"✅ Loaded model: {MODEL_NAME}")
print(f"   Parameters: {sum(p.numel() for p in model.parameters()):,}")

In [None]:
# Define metrics
def compute_metrics(pred):
    """
    Entity-level precision / recall / F1 (seqeval) for the Trainer.

    Args:
        pred: ``(predictions, label_ids)`` pair. ``predictions`` are logits
            whose label axis is axis 2; positions whose gold id is -100 are
            dropped before scoring.

    Returns:
        Dict with 'precision', 'recall' and 'f1'.
    """
    logits, label_ids = pred
    predicted_ids = np.argmax(logits, axis=2)

    true_labels, pred_labels = [], []
    for pred_row, gold_row in zip(predicted_ids, label_ids):
        kept = [(g, p) for p, g in zip(pred_row, gold_row) if g != -100]
        true_labels.append([id2label[g] for g, _ in kept])
        pred_labels.append([id2label[p] for _, p in kept])

    return {
        "precision": precision_score(true_labels, pred_labels),
        "recall": recall_score(true_labels, pred_labels),
        "f1": f1_score(true_labels, pred_labels),
    }

In [None]:
# Training arguments: evaluate and checkpoint every epoch, keep the best
# checkpoint by validation entity-level F1 (the "f1" key from compute_metrics).
training_args = TrainingArguments(
    output_dir="./claim-ner-model",
    eval_strategy="epoch",
    save_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=5,
    weight_decay=0.01,
    warmup_ratio=0.1,
    logging_steps=50,
    load_best_model_at_end=True,
    metric_for_best_model="f1",
    push_to_hub=False,
    report_to="none"
)

# Data collator
data_collator = DataCollatorForTokenClassification(tokenizer)

# Create trainer.
# `processing_class` replaces Trainer's deprecated `tokenizer=` argument
# (deprecated since transformers 4.46, slated for removal in v5). The unpinned
# `pip install` above pulls a recent release, and the deprecation warning
# itself is hidden by the global warnings filter.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    processing_class=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics
)

print("✅ Trainer initialized")

In [None]:
# Train the model
print("🚀 Starting training...")
trainer.train()
print("✅ Training complete!")

## 6. Evaluation

In [None]:
# Evaluate on test set
print("📊 Evaluating on test set...")
results = trainer.predict(test_dataset)
predictions = np.argmax(results.predictions, axis=2)

# Convert to label strings, dropping positions labelled -100 (special/padding tokens).
# true_labels / pred_labels are reused by the summary cell at the end.
true_labels = []
pred_labels = []

for prediction, label in zip(predictions, results.label_ids):
    true_label = []
    pred_label = []

    for p, l in zip(prediction, label):
        if l != -100:
            true_label.append(id2label[l])
            pred_label.append(id2label[p])

    true_labels.append(true_label)
    pred_labels.append(pred_label)

# Print detailed classification report
print("\n" + "="*60)
print("CLASSIFICATION REPORT")
print("="*60)
print(classification_report(true_labels, pred_labels))

## 7. Inference Examples

In [None]:
def extract_claims(text, model, tokenizer, id2label):
    """
    Extract claim spans from a single text message.

    The message is tokenized (truncated to 128 tokens) and run through the
    token-classification model. The per-token BIO predictions are then decoded
    into character-level spans:

    * ``B-X`` always opens a new claim of type X.
    * ``I-X`` extends the open claim if it is also of type X. Otherwise (no
      open claim, or an open claim of another type) it opens a new claim of
      type X instead of being dropped.
    * ``O`` closes the open claim.

    This mirrors seqeval's default (non-strict) chunking, so the returned
    spans correspond to the entities the evaluation metrics count.

    Args:
        text: Raw message text.
        model: Token-classification model whose output has ``.logits`` with a
            leading batch dimension of 1.
        tokenizer: Tokenizer supporting ``return_offsets_mapping``.
        id2label: Mapping from label id to BIO label string ('O', 'B-...', 'I-...').

    Returns:
        List of dicts with keys 'type', 'start', 'end' (character offsets into
        ``text``, end exclusive), 'confidence' (mean probability of the
        predicted label over the claim's tokens) and 'text'.
    """
    inputs = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        max_length=128,
        return_offsets_mapping=True
    )
    # Offsets are only needed for decoding; the model does not accept them.
    offset_mapping = inputs.pop('offset_mapping')[0].tolist()

    # After Trainer training, the model is typically on the GPU (when one is
    # available), while the tokenizer returns CPU tensors. Move the inputs over.
    device = getattr(model, 'device', None)
    if device is not None:
        inputs = {name: tensor.to(device) for name, tensor in inputs.items()}

    # Predict
    model.eval()
    with torch.no_grad():
        outputs = model(**inputs)

    logits = outputs.logits[0].detach().cpu()
    predictions = torch.argmax(logits, dim=-1).tolist()
    probabilities = torch.softmax(logits, dim=-1)

    # Decode BIO tags into spans
    claims = []
    current = None  # open claim: {'type', 'start', 'end', 'scores'}

    for token_idx, (pred_id, (start, end)) in enumerate(zip(predictions, offset_mapping)):
        if start == 0 and end == 0:
            continue  # special token
        label = id2label[pred_id]
        confidence = probabilities[token_idx, pred_id].item()

        if label == 'O':
            if current is not None:
                claims.append(current)
                current = None
            continue

        claim_type = label[2:]
        if label.startswith('I-') and current is not None and current['type'] == claim_type:
            current['end'] = end
            current['scores'].append(confidence)
        else:
            # B-X, or an I-X that cannot continue the open claim: start a new one.
            if current is not None:
                claims.append(current)
            current = {'type': claim_type, 'start': start, 'end': end, 'scores': [confidence]}

    if current is not None:
        claims.append(current)

    return [
        {
            'type': claim['type'],
            'start': claim['start'],
            'end': claim['end'],
            # True mean over the claim's tokens (not a running pairwise average).
            'confidence': sum(claim['scores']) / len(claim['scores']),
            'text': text[claim['start']:claim['end']],
        }
        for claim in claims
    ]

# Test with examples (one message per claim pattern, plus a benign one)
test_messages = [
    "Your Amazon package is delayed. Click here urgently to reschedule delivery.",
    "URGENT: Your PayPal account has been suspended. Verify your identity now to avoid legal action.",
    "Congratulations! You've won £5000. Call 0800-123-456 to claim your prize today.",
    "Hi, are we still meeting for lunch?"
]

print("🔍 Testing claim extraction:\n")
for i, msg in enumerate(test_messages, 1):
    print(f"\n{i}. Message: {msg}")
    claims = extract_claims(msg, model, tokenizer, id2label)

    if claims:
        print(f"   Found {len(claims)} claims:")
        for claim in claims:
            print(f"     - {claim['type']:20} : '{claim['text']}' (conf: {claim['confidence']:.2f})")
    else:
        print("   ✅ No claims detected (likely HAM)")

## 8. Save Model

In [None]:
# Save model locally
model.save_pretrained("./claim-ner-final")
tokenizer.save_pretrained("./claim-ner-final")

# Save label mappings alongside the model (JSON turns the int keys of
# id2label into strings on disk).
import json
with open("./claim-ner-final/label_mappings.json", "w") as f:
    json.dump({
        'label2id': label2id,
        'id2label': {int(k): v for k, v in id2label.items()},
        'claim_types': CLAIM_TYPES
    }, f, indent=2)

print("✅ Model saved to ./claim-ner-final/")

# Download model (optional)
# !zip -r claim-ner-final.zip ./claim-ner-final
# from google.colab import files
# files.download('claim-ner-final.zip')

## 9. Results Summary

In [None]:
# Print final summary; the test metrics are recomputed from the
# true_labels / pred_labels built in the evaluation cell.
print("\n".join([
    "=" * 60,
    "TRAINING SUMMARY - APPROACH 2: CLAIM-PHRASE NER",
    "=" * 60,
    f"Model: {MODEL_NAME}",
    f"Training examples: {len(train_dataset)}",
    f"Validation examples: {len(val_dataset)}",
    f"Test examples: {len(test_dataset)}",
    f"\nNumber of claim types: {len(CLAIM_TYPES)}",
    f"Total labels (BIO): {len(labels)}",
    "\nTest Metrics:",
    f"  Precision: {precision_score(true_labels, pred_labels):.3f}",
    f"  Recall:    {recall_score(true_labels, pred_labels):.3f}",
    f"  F1 Score:  {f1_score(true_labels, pred_labels):.3f}",
    "=" * 60,
]))