# Approach 5: Pure NER Claim Extraction (No Context Filtering)

## Overview
- Extract claims directly using NER model
- Fixed dataset: Removed 31 near-duplicates
- Grouped rare claims into OTHER_CLAIM
- Fixed word-splitting issues with better BIO alignment

## Improvements:
1. Resolved 112 overlapping spans
2. Better tokenization alignment (fixes claims such as "Hurry!" being split into "Hur", "ry", "!")
3. Uses the balanced 2,000-message dataset (1,000 HAM, 1,000 SMISH)
4. No duplicates

## 1. Environment Setup

In [None]:
# Install required packages.
# NOTE: the leading "!" is Jupyter/Colab shell magic; this line is not valid in a plain .py script.
!pip install -q transformers datasets accelerate seqeval scikit-learn torch

In [None]:
# Import libraries
import json
import torch
import numpy as np
from pathlib import Path
from typing import List, Dict, Tuple
from collections import defaultdict, Counter
from sklearn.model_selection import train_test_split
from transformers import (
    AutoTokenizer,
    AutoModelForTokenClassification,
    TrainingArguments,
    Trainer,
    DataCollatorForTokenClassification
)
from seqeval.metrics import classification_report, f1_score, precision_score, recall_score, accuracy_score
import warnings
warnings.filterwarnings('ignore')

# Report the runtime so results can be tied to the hardware they were produced on.
print(f"PyTorch version: {torch.__version__}")
cuda_ok = torch.cuda.is_available()
print(f"CUDA available: {cuda_ok}")
if cuda_ok:
    print(f"Device: {torch.cuda.get_device_name(0)}")

In [None]:
# Mount Google Drive so saved models and result files outlive the Colab session.
from google.colab import drive
drive.mount('/content/drive')

import os
# Every artifact written by this notebook (model, config, evaluation JSON) goes under save_dir.
save_dir = os.path.join('/content/drive/MyDrive/sms_claim_models', 'approach5_pure_ner')
os.makedirs(save_dir, exist_ok=True)
print(f"Models will be saved to: {save_dir}")

In [None]:
# Upload dataset.
# The uploaded file is later opened by name (load_claim_data(data_file)).
from google.colab import files

print("Please upload 'claim_annotations_2000_balanced.json'")
uploaded = files.upload()
if not uploaded:
    # An empty upload (e.g. a cancelled dialog) used to surface as a bare IndexError.
    raise RuntimeError("No file uploaded: please upload 'claim_annotations_2000_balanced.json'")
if len(uploaded) > 1:
    print(f"Multiple files uploaded; using the first one: {list(uploaded)}")
data_file = next(iter(uploaded))
print(f"Uploaded: {data_file}")

## 2. Define Claim Types and Labels

In [None]:
# Claim taxonomy used for training. Rare fine-grained claim types are folded
# into OTHER_CLAIM to counter data scarcity.
CLAIM_TYPES = [
    'ACTION_CLAIM',
    'URGENCY_CLAIM',
    'REWARD_CLAIM',
    'FINANCIAL_CLAIM',
    'ACCOUNT_CLAIM',
    'DELIVERY_CLAIM',
    'VERIFICATION_CLAIM',
    'OTHER_CLAIM'  # Groups: SECURITY, IDENTITY, CREDENTIALS, LEGAL, SOCIAL
]

RARE_CLAIMS = ['SECURITY_CLAIM', 'IDENTITY_CLAIM', 'CREDENTIALS_CLAIM', 'LEGAL_CLAIM', 'SOCIAL_CLAIM']


def normalize_claim_type(claim_type):
    """Map a raw annotation label to its training label.

    Labels listed in RARE_CLAIMS become 'OTHER_CLAIM'; any other label is
    returned unchanged.
    """
    if claim_type in RARE_CLAIMS:
        return 'OTHER_CLAIM'
    return claim_type


# BIO tag set: 'O' first, then a B-/I- pair for each claim type in CLAIM_TYPES order.
labels = ['O'] + [f'{tag}-{claim}' for claim in CLAIM_TYPES for tag in ('B', 'I')]

label2id = {name: index for index, name in enumerate(labels)}
id2label = dict(enumerate(labels))

print(f"Total labels: {len(labels)}")
print(f"Claim types: {CLAIM_TYPES}")

## 3. Data Loading with Overlap Resolution

KEY FIX: Character-level BIO tagging, so that claim boundaries no longer split words

In [None]:
def resolve_overlapping_spans(claim_spans):
    """Drop overlapping claim spans, preferring the longer one.

    Spans are visited in order of start offset (longer first on ties). A span
    that overlaps an already kept span replaces it only if its 'text' is
    strictly longer; otherwise it is discarded. Spans touching at a boundary
    (one's end == the other's start) do not overlap.

    Args:
        claim_spans: Iterable of dicts with at least 'start', 'end' and 'text'.

    Returns:
        A new list of the kept span dicts, sorted by 'start'. Returns [] for
        empty/None input.
    """
    if not claim_spans:
        return []

    ordered = sorted(claim_spans, key=lambda s: (s['start'], -len(s['text'])))
    kept = []

    for candidate in ordered:
        # First kept span that shares at least one character with the candidate.
        clash = next(
            (k for k in kept if candidate['end'] > k['start'] and candidate['start'] < k['end']),
            None,
        )
        if clash is None:
            kept.append(candidate)
        elif len(candidate['text']) > len(clash['text']):
            kept.remove(clash)
            kept.append(candidate)

    return sorted(kept, key=lambda s: s['start'])

def convert_to_bio_format(text, claim_spans):
    """
    Convert text and claim spans to BIO format.

    Labels are assigned per character first (so tokenization can later align on
    character offsets), then projected onto whitespace-split words.

    Args:
        text: The raw message.
        claim_spans: Dicts with 'start'/'end' character offsets (end exclusive),
            'label' and 'text'. Overlaps are resolved before labelling.

    Returns:
        (words, word_labels, char_labels):
        - words: ``text.split()``
        - word_labels: one BIO label per word; an ``I-X`` never follows a
          non-``X`` word (it is promoted to ``B-X``)
        - char_labels: one BIO label per character of ``text``
    """
    claim_spans = resolve_overlapping_spans(claim_spans)
    text_len = len(text)

    # Character-level labels: B- on the first labelled character, I- on the rest.
    char_labels = ['O'] * text_len

    for span in claim_spans:
        claim_label = normalize_claim_type(span['label'])
        # Clamp to the text; a negative start would otherwise index from the end.
        start = max(0, span['start'])
        end = min(span['end'], text_len)

        # Annotated spans often include surrounding whitespace. If B- lands on a
        # space, no token/word starts there and the claim is learned as I- only.
        while start < end and text[start].isspace():
            start += 1
        while end > start and text[end - 1].isspace():
            end -= 1
        if start >= end:
            # Empty (or whitespace-only) span: nothing to label.
            continue

        char_labels[start] = f'B-{claim_label}'
        for i in range(start + 1, end):
            char_labels[i] = f'I-{claim_label}'

    # Convert to word-level for compatibility.
    words = text.split()
    word_labels = []
    char_pos = 0
    prev_label = 'O'

    for word in words:
        word_start = text.find(word, char_pos)
        if word_start == -1:
            word_labels.append('O')
            prev_label = 'O'
            continue

        word_end = word_start + len(word)
        covered = char_labels[word_start:word_end]
        # A span that starts mid-word still gives the word its B- tag;
        # otherwise the word takes the label of its first character.
        label = next((lab for lab in covered if lab.startswith('B-')), covered[0])
        # Keep the sequence valid BIO: I-X must continue an X entity.
        if label.startswith('I-') and prev_label[2:] != label[2:]:
            label = 'B-' + label[2:]

        word_labels.append(label)
        prev_label = label
        char_pos = word_end

    return words, word_labels, char_labels

print("BIO conversion functions loaded with improved alignment")

In [None]:
def load_claim_data(json_file):
    """Load claim annotations and convert them into training examples.

    The expected layout (apparently a Label Studio JSON export — confirm) is a
    list of tasks with ``task['data']['text']`` and ``task['annotations']``.
    Only the first annotation of a task is used. Tasks with no annotations
    are skipped. Tasks whose annotation has no labelled results are kept as
    claim-free examples. Overlapping spans are resolved and their number is
    reported.

    Returns:
        List of dicts with 'id', 'text', 'tokens', 'labels' (word-level BIO),
        'char_labels' and 'claim_spans' (overlap-resolved spans, raw labels).
    """
    with open(json_file, 'r', encoding='utf-8') as fh:
        tasks = json.load(fh)

    examples = []
    overlaps_resolved = 0

    for task in tasks:
        text = task['data']['text']

        annotation_list = task.get('annotations')
        if not annotation_list:
            continue

        first_annotation = annotation_list[0]
        if 'result' in first_annotation and first_annotation['result']:
            results = first_annotation['result']
        else:
            results = []

        raw_spans = []
        for result in results:
            value = result.get('value', {})
            span_labels = value.get('labels', [])
            if not span_labels:
                continue
            raw_spans.append({
                'text': value.get('text', ''),
                'start': value.get('start', 0),
                'end': value.get('end', 0),
                'label': span_labels[0]
            })

        claim_spans = resolve_overlapping_spans(raw_spans)
        # Resolution never adds spans, so the difference is the number dropped.
        overlaps_resolved += max(0, len(raw_spans) - len(claim_spans))

        tokens, bio_labels, char_labels = convert_to_bio_format(text, claim_spans)

        examples.append({
            'id': task.get('id'),
            'text': text,
            'tokens': tokens,
            'labels': bio_labels,
            'char_labels': char_labels,
            'claim_spans': claim_spans
        })

    print(f"Resolved {overlaps_resolved} overlapping spans")
    return examples

print("Loading data...")
examples = load_claim_data(data_file)
print(f"Loaded {len(examples)} examples")

## 4. Data Splitting

In [None]:
# Stratified 80/20 split. A message counts as SMISH when it has at least one
# annotated claim span and as HAM otherwise, so both splits keep the class balance.
example_labels = ['SMISH' if ex.get('claim_spans', []) else 'HAM' for ex in examples]

train_examples, test_examples = train_test_split(
    examples, test_size=0.20, random_state=42, stratify=example_labels
)

print(f"Dataset split:")
for split_name, split_items in (("Train", train_examples), ("Test", test_examples)):
    print(f"  {split_name}: {len(split_items)}")

## 5. Tokenization with Fixed Alignment

KEY FIX: Align subword tokens to labels through their character offsets (using the character-level labels) instead of word indices

In [None]:
# Load tokenizer (the same checkpoint name is reused for the model below).
MODEL_NAME = "roberta-base"
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_NAME,
    # RoBERTa's BPE treats a leading space as part of a word; adding one
    # makes the first word encode like every other word.
    add_prefix_space=True,
)
print(f"Loaded tokenizer: {MODEL_NAME}")

In [None]:
def tokenize_and_align_labels(examples, max_length=128):
    """
    Tokenize raw texts and give every subword token a BIO label id.

    Labels are taken from each example's character-level ``char_labels`` via the
    tokenizer's offset mapping, using the module-level ``tokenizer`` and
    ``label2id``.

    A token is labelled ``B-X`` if any of its characters carries ``B-X`` (the
    span starts inside the token); otherwise it takes the label of its first
    character. An ``I-X`` token that does not follow an ``X`` token is promoted
    to ``B-X``, so every entity in the targets starts with ``B-`` as BIO
    requires. Previously, spans whose first labelled character was not a token
    start were taught as ``I-`` only.

    Args:
        examples: Dicts with 'text' and 'char_labels' (one label per character).
        max_length: Sequences are truncated/padded to this many tokens.

    Returns:
        The tokenizer output without 'offset_mapping', plus 'labels': one list
        of label ids per example. Positions with a (0, 0) offset (special
        tokens and padding) get -100.
    """
    tokenized_inputs = tokenizer(
        [ex['text'] for ex in examples],
        padding='max_length',
        truncation=True,
        max_length=max_length,
        return_offsets_mapping=True,
        is_split_into_words=False
    )

    aligned_labels = []

    for i, example in enumerate(examples):
        char_labels = example['char_labels']
        token_label_ids = []
        prev_label = 'O'

        for start, end in tokenized_inputs['offset_mapping'][i]:
            if start == 0 and end == 0:
                # Special/padding token: ignored by the loss and by the metrics.
                token_label_ids.append(-100)
                continue

            covered = char_labels[start:end]
            if covered:
                label = next((lab for lab in covered if lab.startswith('B-')), covered[0])
            else:
                # Offset lies beyond the labelled text.
                label = 'O'

            # Keep the target sequence valid BIO.
            if label.startswith('I-') and prev_label[2:] != label[2:]:
                label = 'B-' + label[2:]

            token_label_ids.append(label2id.get(label, 0))
            prev_label = label

        aligned_labels.append(token_label_ids)

    tokenized_inputs.pop('offset_mapping')
    tokenized_inputs['labels'] = aligned_labels

    return tokenized_inputs

print("Tokenizing datasets...")
train_tokenized = tokenize_and_align_labels(train_examples)
test_tokenized = tokenize_and_align_labels(test_examples)
print("Tokenization complete")

In [None]:
# Create PyTorch datasets
from torch.utils.data import Dataset

class NERDataset(Dataset):
    """Map-style dataset over column-oriented encodings.

    ``encodings`` maps field names (e.g. 'input_ids', 'labels') to per-example
    lists. Item ``idx`` is a dict with the same keys, and each value is that
    example's row converted with ``torch.tensor``.
    """

    def __init__(self, encodings):
        self.encodings = encodings

    def __len__(self):
        # The example count is taken from the 'input_ids' column.
        return len(self.encodings['input_ids'])

    def __getitem__(self, idx):
        item = {}
        for field, column in self.encodings.items():
            item[field] = torch.tensor(column[idx])
        return item

# Wrap the tokenized splits so the Trainer can index them example by example.
train_dataset, test_dataset = (NERDataset(enc) for enc in (train_tokenized, test_tokenized))

print(f"Created datasets: Train={len(train_dataset)}, Test={len(test_dataset)}")

## 6. Model Training

In [None]:
# Load model: token-classification head sized to the BIO label set, with the
# id<->label maps stored in the model config.
model = AutoModelForTokenClassification.from_pretrained(
    MODEL_NAME,
    num_labels=len(labels),
    id2label=id2label,
    label2id=label2id,
)
param_count = sum(param.numel() for param in model.parameters())

print(f"Loaded model: {MODEL_NAME}")
print(f"Parameters: {param_count:,}")

In [None]:
# Metrics
def compute_metrics(pred):
    """Entity-level precision/recall/F1 (seqeval) for the Trainer's evaluation.

    Args:
        pred: ``(logits, label_ids)`` pair. Logits are reduced with argmax over
            axis 2, which holds the label scores.

    Returns:
        Dict with 'precision', 'recall' and 'f1'. Positions whose gold id is
        -100 (special/padding tokens) are excluded before scoring.
    """
    logits, gold_ids = pred
    best_ids = np.argmax(logits, axis=2)

    true_labels, pred_labels = [], []
    for row_pred, row_gold in zip(best_ids, gold_ids):
        kept = [(id2label[g], id2label[p]) for p, g in zip(row_pred, row_gold) if g != -100]
        true_labels.append([gold for gold, _ in kept])
        pred_labels.append([guess for _, guess in kept])

    return {
        "precision": precision_score(true_labels, pred_labels),
        "recall": recall_score(true_labels, pred_labels),
        "f1": f1_score(true_labels, pred_labels),
    }

In [None]:
# Training arguments.
# Effective train batch per device = per_device_train_batch_size * gradient_accumulation_steps = 16.
# NOTE(review): eval_dataset is the held-out *test* split, and load_best_model_at_end
# picks the checkpoint with the best test F1. The test metrics reported later are
# therefore optimistically biased; a validation split carved out of train_examples
# would remove this selection bias.
training_args = TrainingArguments(
    output_dir="./claim-ner-model",
    eval_strategy="epoch",
    save_strategy="epoch",
    save_total_limit=3,
    learning_rate=2e-5,  # Slightly lower for better convergence
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=15,  # More epochs for better learning
    weight_decay=0.01,
    warmup_ratio=0.2,
    logging_steps=50,
    load_best_model_at_end=True,
    metric_for_best_model="f1",
    push_to_hub=False,
    report_to="none",
    # fp16 mixed precision needs a CUDA device. Enabling it unconditionally makes
    # TrainingArguments fail on CPU-only runtimes.
    fp16=torch.cuda.is_available(),
    gradient_accumulation_steps=2,
    lr_scheduler_type="cosine"  # Better learning rate schedule
)

# Pads input_ids/attention_mask/labels per batch (labels padded with -100).
data_collator = DataCollatorForTokenClassification(tokenizer)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics
)

print("Trainer initialized")

In [None]:
# Train (the best epoch by eval F1 is reloaded at the end: load_best_model_at_end=True).
banner = "=" * 60
for line in (banner, "TRAINING NER MODEL", banner):
    print(line)
trainer.train()
print("\nTraining complete!")

## 7. Claim Extraction Function with Fixed Merging

In [None]:
# Function words that are never a meaningful claim on their own.
_CLAIM_STOPWORDS = frozenset({'to', 'a', 'an', 'the', 'of', 'in', 'on', 'at', 'by', 'for', 'with', 'from'})


def _decode_bio_claims(pred_ids, confidences, offsets, id2label):
    """Group per-token BIO predictions into contiguous claim spans.

    Decoding is lenient, like seqeval's default chunking: an ``I-X`` token that
    does not continue an open ``X`` claim starts a new claim instead of being
    dropped, and a tag of any other type (or ``O``) closes the open claim.

    Args:
        pred_ids: Predicted label id per token (ints).
        confidences: Probability of the predicted label per token (floats).
        offsets: ``(start, end)`` character offsets per token. ``(0, 0)`` marks
            special tokens, which are skipped.
        id2label: Mapping from label id to BIO label string.

    Returns:
        List of dicts with 'type', 'start', 'end', 'confidence' (mean over the
        claim's tokens) and 'token_count'.
    """
    claims = []
    current = None

    for pred_id, confidence, (start, end) in zip(pred_ids, confidences, offsets):
        if start == 0 and end == 0:
            continue

        prefix, _, claim_type = id2label[pred_id].partition('-')

        if prefix == 'I' and current is not None and current['type'] == claim_type:
            # Continuation of the open claim: extend it and update the running mean.
            current['end'] = end
            current['token_count'] += 1
            current['confidence'] += (confidence - current['confidence']) / current['token_count']
            continue

        # Any other tag ends the open claim.
        if current is not None:
            claims.append(current)
            current = None

        if prefix in ('B', 'I'):
            current = {
                'type': claim_type,
                'start': start,
                'end': end,
                'confidence': confidence,
                'token_count': 1
            }

    if current is not None:
        claims.append(current)

    return claims


def _finalize_claims(text, claims, confidence_threshold):
    """Attach each claim's text (whitespace-stripped, offsets adjusted to match) and drop weak claims."""
    kept = []
    for claim in claims:
        claim.pop('token_count', None)
        raw = text[claim['start']:claim['end']]
        claim_text = raw.strip()
        # Keep start/end consistent with the stripped text.
        claim['start'] += len(raw) - len(raw.lstrip())
        claim['end'] = claim['start'] + len(claim_text)
        claim['text'] = claim_text

        # FILTER 1: confidence threshold (removes weak predictions such as "To" at 0.415).
        if claim['confidence'] < confidence_threshold:
            continue
        # FILTER 2: minimum length (avoid one- or two-character claims).
        if len(claim_text) < 3:
            continue
        # FILTER 3: pure stopwords/prepositions.
        if claim_text.lower() in _CLAIM_STOPWORDS:
            continue
        # FILTER 4: must contain at least one alphanumeric character.
        if not any(ch.isalnum() for ch in claim_text):
            continue

        kept.append(claim)
    return kept


def extract_claims_with_ner(text, model, tokenizer, id2label, confidence_threshold=0.5):
    """
    Extract claims from one message with the token-classification model.

    Only the first 128 tokens are analysed (longer texts are truncated).

    Args:
        text: Message to analyse.
        model: Token-classification model whose ``logits`` index labels by ``id2label``.
        tokenizer: Fast tokenizer supporting ``return_offsets_mapping``.
        id2label: Mapping from label id to BIO label string.
        confidence_threshold: Claims whose mean token probability is below this are dropped.

    Returns:
        List of dicts with 'type', 'start', 'end', 'confidence' and 'text', where
        ``text[start:end] == claim['text']``.
    """
    inputs = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        max_length=128,
        return_offsets_mapping=True
    )

    # The model does not accept offset_mapping; keep it for span recovery.
    offsets = inputs.pop('offset_mapping')[0].tolist()

    # Move inputs to the model's device.
    device = next(model.parameters()).device
    inputs = {k: v.to(device) for k, v in inputs.items()}

    model.eval()
    with torch.no_grad():
        logits = model(**inputs).logits[0]

    pred_ids = torch.argmax(logits, dim=-1).tolist()
    probabilities = torch.softmax(logits, dim=-1)
    confidences = [probabilities[pos, label_id].item() for pos, label_id in enumerate(pred_ids)]

    claims = _decode_bio_claims(pred_ids, confidences, offsets, id2label)
    return _finalize_claims(text, claims, confidence_threshold)

print("Claim extraction function loaded with confidence filtering and post-processing")

## 8. Evaluation

In [None]:
# Evaluate on test set: token predictions -> seqeval entity-level report.
print("Evaluating NER model on test set...")
results = trainer.predict(test_dataset)
predictions = np.argmax(results.predictions, axis=2)

true_labels = []
pred_labels = []
for pred_row, gold_row in zip(predictions, results.label_ids):
    # -100 marks special/padding positions, which are not scored.
    scored = gold_row != -100
    true_labels.append([id2label[g] for g in gold_row[scored]])
    pred_labels.append([id2label[p] for p in pred_row[scored]])

rule = "=" * 70
print("\n" + rule)
print("CLAIM-LEVEL NER PERFORMANCE")
print(rule)
print(classification_report(true_labels, pred_labels))

print(f"\nOverall Token Accuracy: {accuracy_score(true_labels, pred_labels):.3f}")
print(rule)

In [None]:
# Print ground truth vs. model output for a fixed random sample of test
# messages: up to 3 without claims (HAM) and up to 10 with claims (SMISH).
import random
random.seed(42)

divider = "=" * 70
print("\n" + divider)
print("TEST SET EXAMPLES - Claim Extraction Analysis")
print(divider)

ham_examples = [ex for ex in test_examples if not ex['claim_spans']]
smish_examples = [ex for ex in test_examples if ex['claim_spans']]

sampled = (
    random.sample(ham_examples, min(3, len(ham_examples)))
    + random.sample(smish_examples, min(10, len(smish_examples)))
)

for idx, example in enumerate(sampled, 1):
    text = example['text']
    gt_claims = example['claim_spans']

    print(f"\n{divider}")
    print(f"Example {idx}")
    print(divider)
    print(f"Message: {text}")

    print(f"\nGround Truth:")
    if not gt_claims:
        print(f"  Label: HAM (no claims)")
    else:
        print(f"  Label: SMISH ({len(gt_claims)} claims)")
        for i, claim in enumerate(gt_claims, 1):
            print(f"  {i}. {claim['label']:20} : '{claim['text']}'")
            print(f"     Position: [{claim['start']}:{claim['end']}]")

    pred_claims = extract_claims_with_ner(text, model, tokenizer, id2label)

    print(f"\nModel Prediction:")
    if not pred_claims:
        print(f"  No claims extracted")
    else:
        print(f"  Extracted {len(pred_claims)} claims:")
        for i, claim in enumerate(pred_claims, 1):
            print(f"  {i}. {claim['type']:20} : '{claim['text']}'")
            print(f"     Position: [{claim['start']}:{claim['end']}]")
            print(f"     Confidence: {claim['confidence']:.3f}")

    print(f"\nAnalysis:")
    n_gt, n_pred = len(gt_claims), len(pred_claims)
    if n_gt == n_pred:
        print(f"  Claim count: MATCH ({n_gt} claims)")
    else:
        print(f"  Claim count: MISMATCH (GT: {n_gt}, Pred: {n_pred})")

print("\n" + divider)

In [None]:
# Save detailed test results - PROPER NER EVALUATION
print("\nSaving detailed test results...")


def _count_matched_claims(pred_claims, gt_claims):
    """Count predictions that overlap a ground-truth claim of the same (normalized) type.

    Matching is one-to-one: once a ground-truth claim has been credited it cannot
    be matched again. Previously several fragments of one GT claim each counted
    as a hit, which let precision/recall exceed 1.0.
    """
    used_gt = set()
    matched = 0
    for pred in pred_claims:
        for gt_idx, gt in enumerate(gt_claims):
            if gt_idx in used_gt:
                continue
            overlaps = min(pred['end'], gt['end']) > max(pred['start'], gt['start'])
            if overlaps and pred['type'] == normalize_claim_type(gt['label']):
                used_gt.add(gt_idx)
                matched += 1
                break
    return matched


test_results = []

for example in test_examples:
    text = example['text']
    gt_claims = example['claim_spans']

    # Get model predictions
    pred_claims = extract_claims_with_ner(text, model, tokenizer, id2label)

    # Claim-level matches (overlap + same type, one-to-one)
    matched_claims = _count_matched_claims(pred_claims, gt_claims)

    # Build result entry
    result_entry = {
        'id': example.get('id'),
        'text': text,
        'ground_truth': {
            'num_claims': len(gt_claims),
            'claims': [
                {
                    'type': claim['label'],
                    'text': claim['text'],
                    'start': claim['start'],
                    'end': claim['end']
                }
                for claim in gt_claims
            ]
        },
        'prediction': {
            'num_claims': len(pred_claims),
            'claims': [
                {
                    'type': claim['type'],
                    'text': claim['text'],
                    'start': claim['start'],
                    'end': claim['end'],
                    'confidence': round(claim['confidence'], 3)
                }
                for claim in pred_claims
            ]
        },
        'evaluation': {
            'claim_count_match': len(gt_claims) == len(pred_claims),
            'claim_count_diff': len(pred_claims) - len(gt_claims),
            'matched_claims': matched_claims,
            'precision': matched_claims / len(pred_claims) if len(pred_claims) > 0 else 0,
            # A claim-free message with no predictions counts as perfect recall.
            'recall': matched_claims / len(gt_claims) if len(gt_claims) > 0 else (1 if len(pred_claims) == 0 else 0)
        }
    }

    test_results.append(result_entry)

# Save to JSON
results_path = f"{save_dir}/test_results_detailed.json"
with open(results_path, 'w', encoding='utf-8') as f:
    json.dump(test_results, f, indent=2, ensure_ascii=False)

# Calculate NER-specific metrics
total = len(test_results)
denom = max(total, 1)  # avoid ZeroDivisionError on an empty test split
total_gt_claims = sum(len(r['ground_truth']['claims']) for r in test_results)
total_pred_claims = sum(len(r['prediction']['claims']) for r in test_results)
total_matched = sum(r['evaluation']['matched_claims'] for r in test_results)

# Aggregate (micro-averaged) precision and recall
overall_precision = total_matched / total_pred_claims if total_pred_claims > 0 else 0
overall_recall = total_matched / total_gt_claims if total_gt_claims > 0 else 0
overall_f1 = 2 * (overall_precision * overall_recall) / (overall_precision + overall_recall) if (overall_precision + overall_recall) > 0 else 0

count_correct = sum(1 for r in test_results if r['evaluation']['claim_count_match'])

summary = {
    'total_test_examples': total,
    'total_ground_truth_claims': total_gt_claims,
    'total_predicted_claims': total_pred_claims,
    'total_matched_claims': total_matched,
    'metrics': {
        'claim_count_accuracy': round(count_correct / denom, 3),
        'precision': round(overall_precision, 3),
        'recall': round(overall_recall, 3),
        'f1_score': round(overall_f1, 3)
    },
    'claim_statistics': {
        'avg_claims_per_message_gt': round(total_gt_claims / denom, 2),
        'avg_claims_per_message_pred': round(total_pred_claims / denom, 2),
        'messages_with_claims_gt': sum(1 for r in test_results if len(r['ground_truth']['claims']) > 0),
        'messages_with_claims_pred': sum(1 for r in test_results if len(r['prediction']['claims']) > 0)
    }
}

# Save summary
summary_path = f"{save_dir}/test_results_summary.json"
with open(summary_path, 'w', encoding='utf-8') as f:
    json.dump(summary, f, indent=2)

print(f"\n{'='*70}")
print("TEST RESULTS SAVED - NER EVALUATION")
print(f"{'='*70}")
print(f"Detailed results: {results_path}")
print(f"Summary: {summary_path}")
print(f"\nClaim-Level NER Metrics:")
print(f"  Total GT claims: {total_gt_claims}")
print(f"  Total Predicted claims: {total_pred_claims}")
print(f"  Matched claims: {total_matched}")
print(f"\n  Precision: {summary['metrics']['precision']:.1%}")
print(f"  Recall: {summary['metrics']['recall']:.1%}")
print(f"  F1 Score: {summary['metrics']['f1_score']:.1%}")
print(f"\n  Exact count match: {summary['metrics']['claim_count_accuracy']:.1%}")
print(f"{'='*70}")


In [None]:
# Flexible evaluation with partial credit
#
# Each prediction is greedily paired, in prediction order, with the not-yet-matched
# ground-truth claim it overlaps most (character IoU), then scored:
#   exact_match      IoU >= 0.7, same type                                   +1.0
#   partial_overlap  0.3 <= IoU < 0.7, same type                             +0.5
#   type_match       0 < IoU < 0.3, same type                                +0.3
#   valid_detection  no overlap with any unmatched GT claim, confidence > 0.8  +0.2
#   wrong_type       IoU >= 0.3, different type                              penalty
#   false_positive   no overlap and confidence <= 0.8, or
#                    0 < IoU < 0.3 with a different type                     penalty
#   missed_claim     GT claim never matched (only exact/partial consume one)  penalty
# Example score = (rewards - 0.5 * penalties) / max(#GT, #pred, 1), clipped at 0.
print("\n" + "="*70)
print("FLEXIBLE EVALUATION WITH PARTIAL CREDIT")
print("="*70)

flexible_results = []

for example in test_examples:
    text = example['text']
    gt_claims = example['claim_spans']
    pred_claims = extract_claims_with_ner(text, model, tokenizer, id2label)

    # Scoring system
    score_details = {
        'id': example.get('id'),
        'text': text[:80] + '...' if len(text) > 80 else text,
        'gt_count': len(gt_claims),
        'pred_count': len(pred_claims),
        'scores': {
            'exact_match': 0,        # Perfect claim match
            'partial_overlap': 0,    # Overlapping position but different boundaries
            'type_match': 0,         # Same type but different position
            'valid_detection': 0,    # Model found valid claim not in GT
        },
        'penalties': {
            'wrong_type': 0,         # Detected claim but wrong type
            'false_positive': 0,     # Detected claim where there isn't one
            'missed_claim': 0        # Failed to detect GT claim
        }
    }
    scores = score_details['scores']
    penalties = score_details['penalties']

    matched_gt = set()

    # Match predictions to ground truth
    for pred in pred_claims:
        best_match = None
        best_iou = 0

        for gt_idx, gt in enumerate(gt_claims):
            if gt_idx in matched_gt:
                continue

            overlap_len = max(0, min(pred['end'], gt['end']) - max(pred['start'], gt['start']))
            union_len = (gt['end'] - gt['start']) + (pred['end'] - pred['start']) - overlap_len
            iou = overlap_len / union_len if union_len > 0 else 0

            if iou > best_iou:
                best_iou = iou
                best_match = (gt_idx, gt)

        if best_match is None:
            # No overlap with any unmatched GT claim. A confident prediction may be
            # a valid claim missing from GT; otherwise it is a false positive.
            if pred['confidence'] > 0.8:
                scores['valid_detection'] += 0.2
            else:
                penalties['false_positive'] += 1
            continue

        gt_idx, gt = best_match
        same_type = normalize_claim_type(gt['label']) == pred['type']

        if best_iou >= 0.7 and same_type:
            # Exact match (high overlap + correct type)
            scores['exact_match'] += 1
            matched_gt.add(gt_idx)
        elif best_iou >= 0.3:
            if same_type:
                scores['partial_overlap'] += 0.5
                matched_gt.add(gt_idx)
            else:
                penalties['wrong_type'] += 1
        elif same_type:
            # Same type but different location (weak overlap)
            scores['type_match'] += 0.3
        else:
            # FIX: weak overlap with the wrong type used to be neither rewarded nor penalised.
            penalties['false_positive'] += 1

    # Unmatched GT claims are missed
    penalties['missed_claim'] = len(gt_claims) - len(matched_gt)

    # Calculate final score
    total_positive = sum(scores.values())
    total_negative = sum(penalties.values())

    # Normalize score (0-1 scale)
    max_possible = max(len(gt_claims), len(pred_claims), 1)
    normalized_score = max(0, (total_positive - total_negative * 0.5) / max_possible)

    score_details['total_score'] = round(normalized_score, 3)
    score_details['performance_category'] = (
        'excellent' if normalized_score >= 0.8 else
        'good' if normalized_score >= 0.6 else
        'fair' if normalized_score >= 0.4 else
        'poor'
    )

    flexible_results.append(score_details)

# Save flexible evaluation
flexible_path = f"{save_dir}/test_results_flexible.json"
with open(flexible_path, 'w', encoding='utf-8') as f:
    json.dump(flexible_results, f, indent=2, ensure_ascii=False)

# Calculate aggregate statistics
num_results = max(len(flexible_results), 1)  # avoid ZeroDivisionError on an empty test split
avg_score = sum(r['total_score'] for r in flexible_results) / num_results
excellent = sum(1 for r in flexible_results if r['performance_category'] == 'excellent')
good = sum(1 for r in flexible_results if r['performance_category'] == 'good')
fair = sum(1 for r in flexible_results if r['performance_category'] == 'fair')
poor = sum(1 for r in flexible_results if r['performance_category'] == 'poor')

total_exact = sum(r['scores']['exact_match'] for r in flexible_results)
total_partial = sum(r['scores']['partial_overlap'] for r in flexible_results)
total_type = sum(r['scores']['type_match'] for r in flexible_results)
total_valid = sum(r['scores']['valid_detection'] for r in flexible_results)

print(f"\nFlexible Evaluation Results:")
print(f"{'='*70}")
print(f"Average Score: {avg_score:.3f}")
print(f"\nPerformance Distribution:")
print(f"  Excellent (≥0.8): {excellent} ({excellent/num_results*100:.1f}%)")
print(f"  Good (≥0.6):      {good} ({good/num_results*100:.1f}%)")
print(f"  Fair (≥0.4):      {fair} ({fair/num_results*100:.1f}%)")
print(f"  Poor (<0.4):      {poor} ({poor/num_results*100:.1f}%)")
print(f"\nPositive Scores:")
print(f"  Exact matches:        {total_exact:.1f} (full points)")
print(f"  Partial overlaps:     {total_partial:.1f} (0.5 points each)")
print(f"  Type matches:         {total_type:.1f} (0.3 points each)")
print(f"  Valid detections:     {total_valid:.1f} (0.2 points each)")
print(f"\nSaved to: {flexible_path}")
print("="*70)

# Show examples of each category
print("\nExample Cases:")
print("="*70)

for category in ['excellent', 'good', 'fair', 'poor']:
    # Named category_examples: this loop previously rebound the global `examples`
    # and clobbered the loaded dataset.
    category_examples = [r for r in flexible_results if r['performance_category'] == category]
    if category_examples:
        ex = category_examples[0]
        print(f"\n{category.upper()} Performance (score: {ex['total_score']}):")
        print(f"  Text: {ex['text']}")
        print(f"  GT claims: {ex['gt_count']}, Predicted: {ex['pred_count']}")
        print(f"  Exact: {ex['scores']['exact_match']}, Partial: {ex['scores']['partial_overlap']}")
        print(f"  Missed: {ex['penalties']['missed_claim']}, FP: {ex['penalties']['false_positive']}")


## 9. Save Model

In [None]:
# Save model weights/config and tokenizer files to the same directory, so
# from_pretrained(model_path) can restore both. Then write a run description
# next to it.
print("\nSaving model to Google Drive...")

model_path = f"{save_dir}/final_model"
for artifact in (model, tokenizer):
    artifact.save_pretrained(model_path)

import pandas as pd

config = {
    'approach': 'approach5_pure_ner_fixed',
    'training_date': str(pd.Timestamp.now()),
    'model_name': MODEL_NAME,
    'num_train_examples': len(train_dataset),
    'claim_types': CLAIM_TYPES,
    'label_mappings': {
        'label2id': label2id,
        'id2label': {int(idx): name for idx, name in id2label.items()},
    },
    'improvements': [
        'Fixed word splitting with character-level labels',
        'Better token merging in extraction',
        'Resolved overlapping spans',
        'Grouped rare claims',
        'Used balanced deduplicated dataset'
    ]
}

config_path = f"{save_dir}/config.json"
with open(config_path, "w") as fh:
    json.dump(config, fh, indent=2)

print(f"Model saved to: {save_dir}")
print("Ready to use!")