# Approach 1: Entity-First NER

Two-step approach: Extract entities first, then parse them into claims.

## Overview
- **Step 1**: Train RoBERTa-based NER to extract entities (BRAND, PHONE, URL, EMAIL, etc.)
- **Step 2**: Parse extracted entities into structured claims using rules
- **Advantages**: Entities are concrete, provide a clear intermediate representation, and are reusable
- **Use Case**: When you need explicit entity extraction for other purposes

## Entity Types
- BRAND (Amazon, PayPal, IRS, etc.)
- PHONE (phone numbers)
- URL (links)
- EMAIL (email addresses)
- AMOUNT (monetary amounts)
- DATE (time references)
- ACCOUNT (account numbers/references)
- PERSON (person names)
- LOCATION (places)

## Setup Instructions
1. Upload `entity_annotations_2000.json` to Colab
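2. Run the cells in order, but run the section 6 cell (which defines `extract_entities` and `parse_entities_to_claims`) before the section 5 cell that saves detailed test results, because that cell calls `extract_entities`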
3. Model extracts entities, then parses to claims

## 1. Environment Setup

In [None]:
# Install required packages (`!` runs this line as a shell command in the notebook; -q keeps pip's output quiet)
!pip install -q transformers datasets accelerate seqeval scikit-learn torch

In [None]:
# Import libraries
import json
import torch
import numpy as np
import re
from pathlib import Path
from dataclasses import dataclass
from typing import List, Dict, Tuple
from sklearn.model_selection import train_test_split
from transformers import (
    AutoTokenizer,
    AutoModelForTokenClassification,
    TrainingArguments,
    Trainer,
    DataCollatorForTokenClassification
)
from seqeval.metrics import classification_report, f1_score, precision_score, recall_score
import warnings
warnings.filterwarnings('ignore')  # silence all library warnings for cleaner notebook output

# Report the runtime so results can be tied to the hardware they were produced on
has_cuda = torch.cuda.is_available()
print(f" PyTorch version: {torch.__version__}")
print(f" CUDA available: {has_cuda}")
if has_cuda:
    print(f"   Device: {torch.cuda.get_device_name(0)}")

In [None]:
from google.colab import files
from tqdm.auto import tqdm

# Ask for the annotation export. The string literal used to be split across
# lines, which is a SyntaxError.
print("Please upload 'entity_annotations_2000.json'")
uploaded = files.upload()
if not uploaded:
    # Indexing an empty upload result would otherwise raise an opaque IndexError.
    raise RuntimeError(
        "No file was uploaded - please upload 'entity_annotations_2000.json' "
        "and re-run this cell."
    )
# Only the first uploaded file is used; its name is later passed to load_entity_data.
data_file = list(uploaded.keys())[0]
print(f"Uploaded: {data_file}")

## 2. Define Entity Types

In [None]:
# Entity categories the NER model predicts
ENTITY_TYPES = [
    'BRAND',      # Company/organization names
    'PHONE',      # Phone numbers
    'URL',        # Web links
    'EMAIL',      # Email addresses
    'AMOUNT',     # Money/prizes
    'DATE',       # Time references
    'ACCOUNT',    # Account numbers/IDs
    'PERSON',     # Person names
    'LOCATION'    # Places
]

# BIO tag set: 'O' (outside) first, then a B-/I- pair per entity type, in ENTITY_TYPES order
labels = ['O'] + [f'{prefix}-{kind}' for kind in ENTITY_TYPES for prefix in ('B', 'I')]

# Bidirectional index <-> tag lookups (index = position in `labels`)
label2id = {tag: position for position, tag in enumerate(labels)}
id2label = dict(enumerate(labels))

print(f"Total labels: {len(labels)}")
print(f"Entity types: {len(ENTITY_TYPES)}")
print(f"Labels: {labels[:10]}...")

## 3. Data Loading (Same as Approach 2)

In [None]:
def convert_to_bio_format(text, entity_spans):
    """
    Convert character-level entity spans into whitespace-word BIO tags.

    Args:
        text: Raw message text; words are produced by ``text.split()``.
        entity_spans: Dicts with 'start', 'end' (character offsets) and 'label'.

    Returns:
        (words, tags): the split words and one BIO tag per word. A word that
        overlaps a span is tagged with the first overlapping span (in list
        order). It gets ``B-`` if the span begins inside it. It gets ``I-`` if
        the span began earlier and the previous word carries the same entity
        type. Otherwise it also gets ``B-``. Words that overlap no span get 'O'.
    """
    words = text.split()
    tags = []
    cursor = 0  # search position, so repeated words map to successive occurrences

    for word in words:
        begin = text.find(word, cursor)
        if begin < 0:
            tags.append('O')
            continue
        finish = begin + len(word)
        cursor = finish

        tag = 'O'
        for span in entity_spans:
            lo, hi, kind = span['start'], span['end'], span['label']
            if finish <= lo or begin >= hi:
                continue  # no character overlap with this span
            starts_here = begin <= lo < finish
            continues_prev = bool(tags) and tags[-1] in (f'B-{kind}', f'I-{kind}')
            tag = f'I-{kind}' if (not starts_here and continues_prev) else f'B-{kind}'
            break
        tags.append(tag)

    return words, tags

def load_entity_data(json_file):
    """
    Load annotated messages from a JSON export and convert them to BIO examples.

    Each record must provide ``record['data']['text']``. Records whose
    'annotations' list is missing or empty are skipped. Only the first
    annotation of a record is used, and within it only results whose
    ``value['labels']`` is non-empty (the first label is kept).

    Returns:
        List of dicts with 'id', 'text', 'tokens', 'labels' (word-level BIO
        tags) and 'entity_spans' (dicts with 'text', 'start', 'end', 'label').
    """
    with open(json_file, 'r', encoding='utf-8') as handle:
        records = json.load(handle)

    examples = []

    for record in records:
        text = record['data']['text']

        annotation_list = record.get('annotations')
        if not annotation_list:
            continue
        first_annotation = annotation_list[0]

        entity_spans = []
        for item in first_annotation.get('result') or []:
            value = item.get('value', {})
            span_labels = value.get('labels', [])
            if not span_labels:
                continue  # not a labelled span
            entity_spans.append({
                'text': value.get('text', ''),
                'start': value.get('start', 0),
                'end': value.get('end', 0),
                'label': span_labels[0],
            })

        tokens, bio_labels = convert_to_bio_format(text, entity_spans)

        examples.append({
            'id': record.get('id'),
            'text': text,
            'tokens': tokens,
            'labels': bio_labels,
            'entity_spans': entity_spans,
        })

    return examples

# Load data
print("Loading data...")
examples = load_entity_data(data_file)
print(f" Loaded {len(examples)} examples")

# Preview the first of the leading five examples that has any entity spans
print("\n First example with entities:")
sample = next((candidate for candidate in examples[:5] if candidate['entity_spans']), None)
if sample is not None:
    print(f"  Text: {sample['text'][:60]}...")
    print(f"  Entities: {len(sample['entity_spans'])}")
    for span in sample['entity_spans'][:3]:
        print(f"    - {span['label']:10} : '{span['text']}'")

In [None]:
# Split data with stratification (balanced ham/smish)
# UPDATED: Match Approach 5 - 80/20 split, no validation set
from collections import Counter


def _ham_or_smish(example):
    """Stratification class: 'SMISH' if the example has any entity span, else 'HAM'."""
    return 'SMISH' if len(example.get('entity_spans', [])) > 0 else 'HAM'


# SMISH = has entities, HAM = no entities
example_labels = [_ham_or_smish(ex) for ex in examples]

# Stratified split to maintain ham/smish balance
train_examples, test_examples = train_test_split(
    examples,
    test_size=0.20,  # Same as Approach 5
    random_state=42,
    stratify=example_labels,
)

print(f"Dataset split:")
print(f"  Train: {len(train_examples)} examples")
print(f"  Test:  {len(test_examples)} examples")

# Count ham/smish distribution in each split
train_labels = [_ham_or_smish(ex) for ex in train_examples]
test_labels = [_ham_or_smish(ex) for ex in test_examples]
train_dist = Counter(train_labels)
test_dist = Counter(test_labels)

print(f"\nTrain distribution:")
print(f"  HAM:   {train_dist['HAM']} ({train_dist['HAM']/len(train_examples)*100:.1f}%)")
print(f"  SMISH: {train_dist['SMISH']} ({train_dist['SMISH']/len(train_examples)*100:.1f}%)")
print(f"\nTest distribution:")
print(f"  HAM:   {test_dist['HAM']} ({test_dist['HAM']/len(test_examples)*100:.1f}%)")
print(f"  SMISH: {test_dist['SMISH']} ({test_dist['SMISH']/len(test_examples)*100:.1f}%)")

# Flattened word-level BIO tags of the training split
all_labels = [tag for ex in train_examples for tag in ex['labels']]

## 4. Tokenization and Model Training (Same as Approach 2)

In [None]:
# Load tokenizer
MODEL_NAME = "roberta-base"
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_NAME,
    # RoBERTa's byte-level BPE then encodes the first word as if preceded by a space, like every other word
    add_prefix_space=True,
)

# Tokenization function (same as Approach 2)
def tokenize_and_align_labels(examples, max_length=128):
    """
    Tokenize raw message texts and project word-level BIO labels onto subword tokens.

    Each word's label (``example['labels']``, aligned with ``example['tokens']``)
    is spread over the word's characters. Every subword token then takes the
    label of the character at its start offset. Only the first subword of a
    word keeps a ``B-`` tag; later subwords of the same word get the matching
    ``I-`` tag. Without this, a multi-piece entity such as "0800-123-456"
    would be labelled B, B, B and read as several one-token entities by
    seqeval and by ``extract_entities``. Tokens with offset ``(0, 0)``
    (special and padding tokens) get -100 so the loss ignores them.

    Args:
        examples: list of dicts with 'text', 'tokens' and 'labels' keys.
        max_length: number of tokens each sequence is padded/truncated to.

    Returns:
        The tokenizer output with 'offset_mapping' removed and a 'labels'
        entry holding per-token label ids.
    """
    tokenized_inputs = tokenizer(
        [ex['text'] for ex in examples],
        padding='max_length',
        truncation=True,
        max_length=max_length,
        return_offsets_mapping=True,
        is_split_into_words=False
    )

    aligned_labels = []

    for i, example in enumerate(examples):
        text = example['text']
        offset_mapping = tokenized_inputs['offset_mapping'][i]

        # Per-character word label, plus the start offset of the word that owns
        # each character (-1 for whitespace / characters outside any word).
        char_labels = ['O'] * len(text)
        char_word_start = [-1] * len(text)
        char_pos = 0
        for word, label in zip(example['tokens'], example['labels']):
            word_start = text.find(word, char_pos)
            if word_start == -1:
                continue
            word_end = word_start + len(word)
            for j in range(word_start, word_end):
                char_labels[j] = label
                char_word_start[j] = word_start
            char_pos = word_end

        token_labels = []
        for start, end in offset_mapping:
            if start == 0 and end == 0:
                token_labels.append(-100)
            elif start < len(char_labels):
                label = char_labels[start]
                # A subword that does not begin its word continues the entity
                if label.startswith('B-') and char_word_start[start] != start:
                    label = 'I-' + label[2:]
                token_labels.append(label2id.get(label, 0))
            else:
                token_labels.append(0)

        aligned_labels.append(token_labels)

    tokenized_inputs.pop('offset_mapping')
    tokenized_inputs['labels'] = aligned_labels

    return tokenized_inputs

# Tokenize both splits - UPDATED: No validation set
print("Tokenizing...")
train_tokenized, test_tokenized = (
    tokenize_and_align_labels(split) for split in (train_examples, test_examples)
)
print("Done")

# Create datasets
from torch.utils.data import Dataset

class NERDataset(Dataset):
    """Map-style dataset that serves one example's features as tensors."""

    def __init__(self, encodings):
        # Mapping of feature name -> per-example sequences; must include 'input_ids'
        self.encodings = encodings

    def __len__(self):
        return len(self.encodings['input_ids'])

    def __getitem__(self, idx):
        item = {}
        for name, column in self.encodings.items():
            item[name] = torch.tensor(column[idx])
        return item

train_dataset, test_dataset = NERDataset(train_tokenized), NERDataset(test_tokenized)

print(f"Created datasets: Train={len(train_dataset)}, Test={len(test_dataset)}")

In [None]:
# Load model: pretrained encoder with a token-classification head sized to the BIO tag set
model = AutoModelForTokenClassification.from_pretrained(
    MODEL_NAME,
    num_labels=len(labels),
    label2id=label2id,
    id2label=id2label,
)

print(f" Model loaded")

In [None]:
# Metrics
def compute_metrics(pred):
    predictions, labels = pred
    predictions = np.argmax(predictions, axis=2)
    
    true_labels = []
    pred_labels = []
    
    for prediction, label in zip(predictions, labels):
        true_label = []
        pred_label = []
        
        for p, l in zip(prediction, label):
            if l != -100:
                true_label.append(id2label[l])
                pred_label.append(id2label[p])
        
        true_labels.append(true_label)
        pred_labels.append(pred_label)
    
    return {
        "precision": precision_score(true_labels, pred_labels),
        "recall": recall_score(true_labels, pred_labels),
        "f1": f1_score(true_labels, pred_labels),
    }

In [None]:
# Training arguments - IMPROVED (Match Approach 5)
# NOTE(review): there is no validation split, so `eval_dataset=test_dataset`
# together with `load_best_model_at_end=True` picks the checkpoint by test-set F1.
# The test metrics reported later are therefore optimistically biased; carve a
# validation split out of `train_examples` for model selection when possible.
training_args = TrainingArguments(
    output_dir="./entity-ner-model",
    eval_strategy="epoch",
    save_strategy="epoch",
    save_total_limit=3,
    learning_rate=2e-5,  # Slightly lower for better convergence
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=15,  # More epochs for better learning
    weight_decay=0.01,
    warmup_ratio=0.2,
    logging_steps=50,
    load_best_model_at_end=True,
    metric_for_best_model="f1",
    push_to_hub=False,
    report_to="none",
    # fp16 mixed precision requires a GPU; TrainingArguments rejects fp16=True on
    # CPU-only runtimes, so enable it only when CUDA is available.
    fp16=torch.cuda.is_available(),
    gradient_accumulation_steps=2,
    lr_scheduler_type="cosine"  # Better learning rate schedule
)

# Pads input_ids/attention_mask and pads labels with -100 when batching
data_collator = DataCollatorForTokenClassification(tokenizer)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics
)

print("Trainer initialized")

In [None]:
# Train
divider = "=" * 60
print(divider)
print("TRAINING ENTITY NER MODEL")
print(divider)
trainer.train()
print("\nTraining complete!")

## 5. Evaluation

In [None]:
# Evaluate on the test split and build per-sequence tag lists for seqeval
print(" Evaluating...")
results = trainer.predict(test_dataset)
predictions = np.argmax(results.predictions, axis=2)

true_labels = []
pred_labels = []
for pred_row, gold_row in zip(predictions, results.label_ids):
    # Drop special/padding positions (gold id -100) before mapping ids to tags
    kept = [(gold, guess) for guess, gold in zip(pred_row, gold_row) if gold != -100]
    true_labels.append([id2label[gold] for gold, _ in kept])
    pred_labels.append([id2label[guess] for _, guess in kept])

heavy_rule = "=" * 60
print("\n" + heavy_rule)
print("CLASSIFICATION REPORT")
print(heavy_rule)
print(classification_report(true_labels, pred_labels))

In [None]:
# Save detailed test results
print("\nSaving detailed test results...")


def _match_entities_relaxed(pred_entities, gt_entities, min_overlap=0.5):
    """
    Greedily pair predicted entities with ground-truth spans.

    Predictions are visited in order. Each one takes the still-unmatched GT span
    of the same type with the largest overlap ratio (overlap length divided by
    the GT span length), provided that ratio exceeds ``min_overlap``.
    Zero-length GT spans never match.

    Returns:
        (matched_gt, matched_pred): sets of matched indices into each list.
    """
    matched_gt = set()
    matched_pred = set()
    for i, pred in enumerate(pred_entities):
        best_match = None
        best_overlap = 0
        for j, gt in enumerate(gt_entities):
            if j in matched_gt:
                continue
            overlap_len = max(0, min(pred['end'], gt['end']) - max(pred['start'], gt['start']))
            gt_len = gt['end'] - gt['start']
            overlap_ratio = overlap_len / gt_len if gt_len > 0 else 0
            # Require >min_overlap coverage of the GT span AND the same entity type
            if overlap_ratio > min_overlap and pred['type'] == gt['label'] and overlap_ratio > best_overlap:
                best_match = j
                best_overlap = overlap_ratio
        if best_match is not None:
            matched_gt.add(best_match)
            matched_pred.add(i)
    return matched_gt, matched_pred


def _ratio(numerator, denominator):
    """Return numerator / denominator, or 0 when the denominator is not positive."""
    return numerator / denominator if denominator > 0 else 0


# extract_entities is defined in section 6 further down the notebook. Fail with an
# actionable message instead of a bare NameError if this cell runs first.
if 'extract_entities' not in globals():
    raise RuntimeError(
        "extract_entities is not defined yet: run the cell in section 6 "
        "(Entity Extraction + Claim Parsing), then re-run this cell."
    )

test_results = []

for example in test_examples:
    text = example['text']
    gt_entities = example['entity_spans']

    # Get model predictions
    pred_entities = extract_entities(text, model, tokenizer, id2label)

    # Entity-level matching - RELAXED (>50% of the GT span covered, same type)
    matched_gt, matched_pred = _match_entities_relaxed(pred_entities, gt_entities)
    matched_entities = len(matched_gt)

    # Per-message scores. An empty denominator scores 1 only when the other side is
    # empty too (no entities expected and none predicted). Precision and recall
    # now follow the same convention; precision used to score that case 0.
    if pred_entities:
        example_precision = len(matched_pred) / len(pred_entities)
    else:
        example_precision = 1 if not gt_entities else 0
    if gt_entities:
        example_recall = matched_entities / len(gt_entities)
    else:
        example_recall = 1 if not pred_entities else 0

    # Build result entry
    result_entry = {
        'id': example.get('id'),
        'text': text,
        'ground_truth': {
            'num_entities': len(gt_entities),
            'entities': [
                {
                    'type': entity['label'],
                    'text': entity['text'],
                    'start': entity['start'],
                    'end': entity['end']
                }
                for entity in gt_entities
            ]
        },
        'prediction': {
            'num_entities': len(pred_entities),
            'entities': [
                {
                    'type': entity['type'],
                    'text': entity['text'],
                    'start': entity['start'],
                    'end': entity['end'],
                    'confidence': round(entity['confidence'], 3)
                }
                for entity in pred_entities
            ]
        },
        'evaluation': {
            'entity_count_match': len(gt_entities) == len(pred_entities),
            'entity_count_diff': len(pred_entities) - len(gt_entities),
            'matched_entities': matched_entities,
            'precision': example_precision,
            'recall': example_recall
        }
    }

    test_results.append(result_entry)

# Save to Google Drive
from google.colab import drive
import os

try:
    drive.mount('/content/drive')
except Exception as err:
    # Not a bare `except:`, so Ctrl-C still interrupts; report the failure instead of hiding it.
    print(f"Could not mount Google Drive: {err}")

save_dir = '/content/drive/MyDrive/sms_claim_models/approach1_entity_ner'
os.makedirs(save_dir, exist_ok=True)

results_path = f"{save_dir}/test_results_detailed.json"
with open(results_path, 'w', encoding='utf-8') as f:
    json.dump(test_results, f, indent=2, ensure_ascii=False)

# Calculate aggregate (micro-averaged) metrics; every division is guarded so an
# empty test split does not raise ZeroDivisionError.
total = len(test_results)
total_gt_entities = sum(len(r['ground_truth']['entities']) for r in test_results)
total_pred_entities = sum(len(r['prediction']['entities']) for r in test_results)
total_matched = sum(r['evaluation']['matched_entities'] for r in test_results)

overall_precision = _ratio(total_matched, total_pred_entities)
overall_recall = _ratio(total_matched, total_gt_entities)
overall_f1 = _ratio(2 * (overall_precision * overall_recall), overall_precision + overall_recall)

count_correct = sum(1 for r in test_results if r['evaluation']['entity_count_match'])

summary = {
    'approach': 'approach1_entity_first_ner',
    'evaluation_method': 'relaxed_matching_50pct_overlap',
    'total_test_examples': total,
    'total_ground_truth_entities': total_gt_entities,
    'total_predicted_entities': total_pred_entities,
    'total_matched_entities': total_matched,
    'metrics': {
        'entity_count_accuracy': round(_ratio(count_correct, total), 3),
        'precision': round(overall_precision, 3),
        'recall': round(overall_recall, 3),
        'f1_score': round(overall_f1, 3)
    },
    'entity_statistics': {
        'avg_entities_per_message_gt': round(_ratio(total_gt_entities, total), 2),
        'avg_entities_per_message_pred': round(_ratio(total_pred_entities, total), 2),
        'messages_with_entities_gt': sum(1 for r in test_results if len(r['ground_truth']['entities']) > 0),
        'messages_with_entities_pred': sum(1 for r in test_results if len(r['prediction']['entities']) > 0)
    }
}

summary_path = f"{save_dir}/test_results_summary.json"
with open(summary_path, 'w', encoding='utf-8') as f:
    json.dump(summary, f, indent=2)

print(f"\n{'='*70}")
print("TEST RESULTS SAVED - RELAXED ENTITY MATCHING (>50% overlap)")
print(f"{'='*70}")
print(f"Detailed results: {results_path}")
print(f"Summary: {summary_path}")
print(f"\nEntity-Level Metrics (Relaxed):")
print(f"  Total GT entities: {total_gt_entities}")
print(f"  Total Predicted entities: {total_pred_entities}")
print(f"  Matched entities: {total_matched}")
print(f"\n  Precision: {summary['metrics']['precision']:.1%}")
print(f"  Recall: {summary['metrics']['recall']:.1%}")
print(f"  F1 Score: {summary['metrics']['f1_score']:.1%}")
print(f"\n  Exact count match: {summary['metrics']['entity_count_accuracy']:.1%}")
print(f"\nNote: Accepts matches with >50% overlap with ground truth")
print(f"{'='*70}")

## 6. Entity Extraction + Claim Parsing

In [None]:
def extract_entities(text, model, tokenizer, id2label):
    """
    Extract entity spans from ``text`` with a token-classification model.

    Predicted tags are decoded as BIO chunks:

    - ``B-X`` always opens a new entity.
    - ``I-X`` extends the open entity when its type is X. Otherwise (no open
      entity, or one of a different type) it opens a new X entity, the same
      lenient chunking seqeval applies when scoring in ``compute_metrics``.
      Previously such tags were ignored, so a later ``I-`` could stretch an
      entity across a token of another type.
    - ``O`` closes the open entity.

    Tokens with offset ``(0, 0)`` (special tokens) are skipped.

    Args:
        text: Raw message; at most 128 tokens (including special tokens) are scored.
        model: Token-classification model; inputs are moved to its device.
        tokenizer: Tokenizer supporting ``return_offsets_mapping``.
        id2label: Mapping from class index to BIO tag string.

    Returns:
        List of dicts with 'type', 'start', 'end' (character offsets into
        ``text``), 'confidence' and 'text' (the covered substring).
        'confidence' is the mean softmax probability of the predicted tag over
        the entity's tokens. It used to be a running pairwise average that
        overweighted the last tokens.
    """
    inputs = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        max_length=128,
        return_offsets_mapping=True
    )

    offset_mapping = inputs.pop('offset_mapping')[0]

    # Move inputs to same device as model
    device = next(model.parameters()).device
    inputs = {k: v.to(device) for k, v in inputs.items()}

    model.eval()
    with torch.no_grad():
        outputs = model(**inputs)

    predictions = torch.argmax(outputs.logits, dim=2)[0]
    probabilities = torch.softmax(outputs.logits, dim=2)[0]

    entities = []
    current = None  # open entity; carries a running confidence sum and token count

    for pred, prob, (start, end) in zip(predictions, probabilities, offset_mapping):
        start, end = start.item(), end.item()
        if start == 0 and end == 0:
            continue  # special token

        pred_id = pred.item()
        label = id2label[pred_id]
        confidence = prob[pred_id].item()

        if not label.startswith(('B-', 'I-')):
            current = None  # 'O' closes any open entity
            continue

        entity_type = label[2:]
        if label.startswith('I-') and current is not None and current['type'] == entity_type:
            current['end'] = end
            current['_conf_sum'] += confidence
            current['_n_tokens'] += 1
        else:
            current = {
                'type': entity_type,
                'start': start,
                'end': end,
                '_conf_sum': confidence,
                '_n_tokens': 1,
            }
            entities.append(current)

    for entity in entities:
        entity['confidence'] = entity.pop('_conf_sum') / entity.pop('_n_tokens')
        entity['text'] = text[entity['start']:entity['end']]

    return entities

# Urgency cue words, tried in this order; only the first one found yields a claim.
URGENCY_KEYWORDS = ('urgent', 'now', 'immediately', 'asap', 'today', 'expires', '24 hours')

# (entity types, claim type, evidence builder), in the order claim groups are emitted.
_ENTITY_CLAIM_RULES = (
    (('BRAND',), 'IDENTITY_CLAIM', lambda e: f"Claims to be from {e['text']}"),
    (('PHONE', 'URL'), 'ACTION_CLAIM', lambda e: f"Requests action via {e['type'].lower()}"),
    (('AMOUNT',), 'FINANCIAL_CLAIM', lambda e: f"Mentions money: {e['text']}"),
    (('ACCOUNT',), 'ACCOUNT_CLAIM', lambda e: f"References account: {e['text']}"),
)


def parse_entities_to_claims(entities, text):
    """
    Parse extracted entities into structured claims (rule-based).

    Claims are grouped by kind in this order: IDENTITY (BRAND), ACTION
    (PHONE/URL), FINANCIAL (AMOUNT), ACCOUNT (ACCOUNT). Within a group they
    keep entity order, and each reuses the entity's text and confidence. At
    most one URGENCY_CLAIM (confidence 0.8) is appended, for the first
    keyword in ``URGENCY_KEYWORDS`` found in the lower-cased text.

    Keywords must start at a word boundary, so 'now' no longer fires inside
    'know', 'known' or 'snow'. Suffixes are still allowed, so 'urgent'
    matches 'urgently'.

    Args:
        entities: Dicts with 'type', 'text' and 'confidence' (as returned by
            ``extract_entities``).
        text: The original message.

    Returns:
        List of claim dicts with 'type', 'text', 'evidence' and 'confidence'.
    """
    claims = []

    for entity_types, claim_type, evidence in _ENTITY_CLAIM_RULES:
        for entity in entities:
            if entity['type'] in entity_types:
                claims.append({
                    'type': claim_type,
                    'text': entity['text'],
                    'evidence': evidence(entity),
                    'confidence': entity['confidence']
                })

    # URGENCY_CLAIM: keyword must begin a word (leading \b only)
    text_lower = text.lower()
    for keyword in URGENCY_KEYWORDS:
        if re.search(r'\b' + re.escape(keyword), text_lower):
            claims.append({
                'type': 'URGENCY_CLAIM',
                'text': keyword,
                'evidence': f"Uses urgency language: '{keyword}'",
                'confidence': 0.8
            })
            break

    return claims

## 7. Complete Pipeline Demo

In [None]:
# Run the two-step pipeline (entity extraction, then rule-based claim parsing) on sample messages
test_messages = [
    "Your Amazon package is delayed. Click here urgently to reschedule delivery.",
    "URGENT: Your PayPal account suspended. Call 0800-123-456 now to verify.",
    "You've won £5000! Visit www.claim-prize.com immediately.",
]

print(" TWO-STEP PIPELINE: Entity Extraction → Claim Parsing\n")
print("=" * 70)

for number, message in enumerate(test_messages, start=1):
    print(f"\n{number}. Message: {message}")
    print("-" * 70)

    # Step 1: Extract entities
    entities = extract_entities(message, model, tokenizer, id2label)
    print(f"\n   STEP 1 - Entities Extracted:")
    if not entities:
        print("     (no entities found)")
    for entity in entities:
        print(f"     - {entity['type']:10} : '{entity['text']}' (conf: {entity['confidence']:.2f})")

    # Step 2: Parse to claims
    claims = parse_entities_to_claims(entities, message)
    print(f"\n   STEP 2 - Claims Parsed:")
    if not claims:
        print("     (no claims generated)")
    for claim in claims:
        print(f"     - {claim['type']:20} : {claim['evidence']}")

    print("=" * 70)

## 8. Save Model

In [None]:
# Mount Google Drive (if not already mounted)
from google.colab import drive
import os

try:
    drive.mount('/content/drive')
except Exception as err:
    # Catch Exception rather than a bare `except:` so Ctrl-C still interrupts, and
    # report the actual failure instead of assuming the drive is already mounted.
    print(f"Could not mount Google Drive: {err}")

# Save to Google Drive
save_dir = '/content/drive/MyDrive/sms_claim_models/approach1_entity_ner'
os.makedirs(save_dir, exist_ok=True)

print(f"\nSaving model to Google Drive: {save_dir}")

model_path = f"{save_dir}/final_model"
model.save_pretrained(model_path)
tokenizer.save_pretrained(model_path)

# Label mappings are saved alongside the model so inference can decode tag ids
with open(f"{save_dir}/label_mappings.json", "w", encoding="utf-8") as f:
    json.dump({
        'label2id': label2id,
        'id2label': {int(k): v for k, v in id2label.items()},
        'entity_types': ENTITY_TYPES
    }, f, indent=2)

print(f"✓ Model saved to: {save_dir}")
print(f"  - Model files: {model_path}/")
print(f"  - Label mappings: {save_dir}/label_mappings.json")
print("\nYou can access it from Google Drive!")


## 9. Results Summary

In [None]:
# Final recap: pipeline description plus seqeval token-classification scores on the test split
rule = "=" * 60
print(rule)
print("APPROACH 1: ENTITY-FIRST NER")
print(rule)
print("Two-step pipeline:")
print("  1. Extract entities (BRAND, PHONE, URL, etc.)")
print("  2. Parse entities → structured claims")
print("\nAdvantages:")
for point in ("Entities are concrete and well-defined",
              "Clear intermediate representation",
              "Reusable entity extraction"):
    print(f"   {point}")
print("\nTest Metrics:")
print(f"  Precision: {precision_score(true_labels, pred_labels):.3f}")
print(f"  Recall:    {recall_score(true_labels, pred_labels):.3f}")
print(f"  F1 Score:  {f1_score(true_labels, pred_labels):.3f}")
print(rule)