# Data Preparation: Entity Extraction with Langextract

This notebook prepares the training and test datasets by extracting named entities from raw Vietnamese news articles using Google's Langextract with the Gemini API.

## Features
- ✅ Parallel processing (4-8x faster extraction)
- ✅ Automatic resume on crash
- ✅ Incremental saving (no data loss)
- ✅ Multi-encoding support (UTF-8, Latin-1, CP1252, UTF-16)
- ✅ Progress tracking

## Output
- Individual JSON files: `data/processed/{split}/{category}/json/{article}.json`
- Combined datasets: `langextract_train.json`, `langextract_test.json`
- Finetuning format: `langextract_train_finetuning.jsonl`

## 1. Setup and Imports

In [None]:
import sys
import json
from pathlib import Path
from typing import List
import os

from loguru import logger
from tqdm.notebook import tqdm
from concurrent.futures import ThreadPoolExecutor, as_completed

# Add project root to path
sys.path.append("..")

from src.config import PROCESSED_DATA_DIR, RAW_DATA_DIR, NERLangExtractConfig
from src.langextract_pipeline import LangExtractNERExtractor
from src.services.data_processor import DataProcessorService

## 2. Configuration

In [None]:
# Report whether GEMINI_API_KEY is present in the environment.
# Only the presence of the variable is checked; the key itself is not validated.
if os.getenv("GEMINI_API_KEY"):
    print("✅ GEMINI_API_KEY is set")
else:
    print("⚠️ GEMINI_API_KEY not found!")
    print("Please set it: export GEMINI_API_KEY='your-key-here'")
    print("Or create a .env file with: GEMINI_API_KEY=your-key-here")

# Extraction settings shared by the cells below (constructed with no arguments).
config = NERLangExtractConfig()
print("\nConfiguration:")
print(f"  Model: {config.model_id}")
print(f"  Extraction passes: {config.extraction_passes}")
print(f"  Max workers: {config.max_workers}")
print(f"  Max char buffer: {config.max_char_buffer}")

## 3. Explore Available Data

In [None]:
def get_available_categories(split: str) -> List[str]:
    """Return the sorted names of the category directories of a split.

    The split is looked up under ``RAW_DATA_DIR``; when that directory
    does not exist an empty list is returned.
    """
    split_dir = RAW_DATA_DIR / split
    if split_dir.exists():
        return sorted(entry.name for entry in split_dir.iterdir() if entry.is_dir())
    return []

def count_articles(split: str, category: str) -> int:
    """Return the number of ``*.txt`` files directly inside a category folder.

    A category folder that does not exist counts as zero articles.
    """
    folder = RAW_DATA_DIR / split / category
    if folder.exists():
        return sum(1 for _ in folder.glob("*.txt"))
    return 0

# List every raw category of each split together with its article count
print("Available Categories:\n")
for split in ("train", "test"):
    categories = get_available_categories(split)
    print(f"{split.upper()}:")
    for cat in categories:
        print(f"  - {cat}: {count_articles(split, cat)} articles")
    print()

## 4. Check Processing Progress

In [None]:
def count_processed_samples(split: str, category: str, output_dir: Path = PROCESSED_DATA_DIR) -> int:
    """Return how many ``*.json`` result files exist for a category.

    Files are looked up in ``output_dir / split / category / "json"``;
    a missing directory counts as zero processed samples.
    """
    json_dir = output_dir / split / category / "json"
    if json_dir.exists():
        return sum(1 for _ in json_dir.glob("*.json"))
    return 0

def show_progress():
    """Print the processing progress of every category in both splits.

    Each category line shows a status marker, the processed/total counts
    and the percentage. The marker is ✅ when the two counts are equal,
    ⏳ when at least one result exists, and ❌ otherwise.
    """
    print("Processing Progress:\n")
    for split in ["train", "test"]:
        categories = get_available_categories(split)
        print(f"{split.upper()}:")
        for cat in categories:
            total = count_articles(split, cat)
            processed = count_processed_samples(split, cat)
            # An empty category would otherwise divide by zero.
            percentage = (processed / total * 100) if total > 0 else 0
            if processed == total:
                status = "✅"
            elif processed > 0:
                status = "⏳"
            else:
                status = "❌"
            print(f"  {status} {cat}: {processed}/{total} ({percentage:.1f}%)")
        print()

# Print the current progress table for the train and test splits.
show_progress()

## 5. Test Extraction on Sample Article

In [None]:
# Test on a single article first
def test_single_article(split="train", category="Doi song"):
    """Run entity extraction on the first article of a category and print it.

    Args:
        split: Sub-folder of ``RAW_DATA_DIR`` to read from.
        category: Category folder inside the split.

    Returns:
        The dict returned by ``extract_entities``, or ``None`` when the
        category folder is missing or no articles could be loaded.
    """
    category_path = RAW_DATA_DIR / split / category

    if not category_path.exists():
        print(f"❌ Category path not found: {category_path}")
        return

    # Load the folder and use its first article as the sample
    articles = DataProcessorService.load_articles_from_folder(category_path)

    if not articles:
        print("❌ No articles found")
        return

    article = articles[0]
    print(f"Testing on: {article['file_name']}")
    print(f"Text length: {len(article['text'])} characters")
    print(f"\nText preview:\n{article['text'][:300]}...\n")

    # Extract entities
    extractor = LangExtractNERExtractor(config=config)
    print("Extracting entities...")

    # Passes and workers are hard-coded here rather than read from `config`.
    entities = extractor.extract_entities(
        text=article["text"],
        extraction_passes=2,
        max_workers=4
    )

    print("\n✅ Extracted entities:")
    print(f"  Person ({len(entities.get('person', []))}): {entities.get('person', [])}")
    print(f"  Organizations ({len(entities.get('organizations', []))}): {entities.get('organizations', [])}")
    print(f"  Address ({len(entities.get('address', []))}): {entities.get('address', [])}")

    # Total counts every entity type returned, not only the three printed above
    total = sum(len(v) for v in entities.values())
    print(f"\nTotal entities: {total}")

    return entities

# Run the single-article check with its defaults (split="train", category="Doi song")
test_entities = test_single_article()

## 6. Process Single Article with Saving

In [None]:
def process_single_article(
    article: dict,
    category_name: str,
    split: str,
    output_dir: Path,
    extractor: LangExtractNERExtractor,
    extraction_passes: int = 2,
    max_workers: int = 4
) -> dict:
    """Extract entities for one article and cache the result as a JSON file.

    The result lives at
    ``output_dir / split / category_name / "json" / <article stem>.json``.
    When that file already exists and can be parsed, its content is
    returned and the extractor is not called; an unreadable cache file is
    reported and the article is extracted again.

    Args:
        article: Mapping with the keys ``"file_name"`` and ``"text"``.
        category_name: Category folder name; also stored in the sample.
        split: Dataset split, used as a path component.
        output_dir: Root directory for processed output.
        extractor: Object whose ``extract_entities`` method is called.
        extraction_passes: Forwarded to ``extract_entities``.
        max_workers: Forwarded to ``extract_entities``.

    Returns:
        The cached sample, or a new dict with the keys ``file_name``,
        ``category``, ``text`` and ``entities``.

    Raises:
        Exception: Whatever ``extract_entities`` or the JSON write raises.
            A failed write leaves no result file behind.
    """
    # Create output directory
    category_output_dir = output_dir / split / category_name / "json"
    category_output_dir.mkdir(parents=True, exist_ok=True)

    # Generate output filename
    file_stem = Path(article["file_name"]).stem
    output_file = category_output_dir / f"{file_stem}.json"

    # Reuse a previous result when one can be read back
    if output_file.exists():
        try:
            with open(output_file, "r", encoding="utf-8") as f:
                return json.load(f)
        except (OSError, ValueError) as e:
            # ValueError covers both JSONDecodeError and UnicodeDecodeError;
            # fall through and extract again, overwriting the bad file.
            print(f"⚠️ Failed to load cached result for {article['file_name']}: {e}")

    # Extract entities
    entities = extractor.extract_entities(
        text=article["text"],
        extraction_passes=extraction_passes,
        max_workers=max_workers
    )

    # Create sample
    sample = {
        "file_name": article["file_name"],
        "category": category_name,
        "text": article["text"],
        "entities": entities
    }

    # Write to a sibling temp file and rename it into place, so a failure
    # half-way through never leaves a truncated "*.json" that would be
    # counted as a processed sample. The ".tmp" suffix keeps the temp file
    # out of "*.json" globs.
    tmp_file = output_file.with_name(output_file.name + ".tmp")
    try:
        with open(tmp_file, "w", encoding="utf-8") as f:
            json.dump(sample, f, ensure_ascii=False, indent=2)
        tmp_file.replace(output_file)
    except BaseException:
        if tmp_file.exists():
            tmp_file.unlink()
        raise

    return sample

## 7. Process Category with Parallel Processing

In [None]:
def process_category(
    category_name: str,
    split: str,
    extraction_passes: int = 2,
    max_workers: int = 4,
    output_dir: Path = PROCESSED_DATA_DIR
) -> List[dict]:
    """Extract entities for every article of a category using a thread pool.

    Each article is handed to ``process_single_article``, which returns a
    cached result when one exists. Articles whose task raises are reported
    and left out of the result.

    Args:
        category_name: Category folder under ``RAW_DATA_DIR / split``.
        split: Dataset split to read from and write to.
        extraction_passes: Forwarded to ``process_single_article``.
        max_workers: Number of articles processed concurrently.
        output_dir: Root directory for the per-article JSON files.

    Returns:
        The successfully processed samples, in the order the articles
        were loaded (not in thread-completion order, so repeated runs
        produce the same ordering). Empty when the category folder is
        missing or contains no articles.
    """
    category_path = RAW_DATA_DIR / split / category_name

    if not category_path.exists():
        print(f"❌ Category path not found: {category_path}")
        return []

    print(f"\n{'='*70}")
    print(f"Processing: {category_name} ({split})")
    print(f"{'='*70}")

    # Load articles
    articles = DataProcessorService.load_articles_from_folder(category_path)

    if not articles:
        print("❌ No valid articles found")
        return []

    # Check for already processed samples
    processed_count = count_processed_samples(split, category_name, output_dir)
    if processed_count > 0:
        print(f"✅ Found {processed_count} already processed samples - will skip those")

    # One extractor instance is shared by all worker threads
    extractor = LangExtractNERExtractor(config=config)

    results_by_index = {}
    failed_articles = []

    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        # Submit all tasks, remembering each article's position
        future_to_index = {
            executor.submit(
                process_single_article,
                article,
                category_name,
                split,
                output_dir,
                extractor,
                extraction_passes,
                1  # Use 1 worker per article since we're parallelizing at article level
            ): index
            for index, article in enumerate(articles)
        }

        # Consume results as they finish so the progress bar advances live
        for future in tqdm(as_completed(future_to_index), total=len(articles), desc=category_name):
            index = future_to_index[future]
            try:
                results_by_index[index] = future.result()
            except Exception as e:
                # Per-article boundary: one failure must not abort the batch
                file_name = articles[index]["file_name"]
                print(f"\n❌ Failed to extract from {file_name}: {e}")
                failed_articles.append(file_name)

    # Restore the input order; completion order varies from run to run
    samples = [results_by_index[index] for index in sorted(results_by_index)]

    if failed_articles:
        preview = ', '.join(failed_articles[:5])
        suffix = '...' if len(failed_articles) > 5 else ''
        print(f"\n⚠️ Failed articles ({len(failed_articles)}): {preview}{suffix}")

    print(f"\n✅ Processed {len(samples)}/{len(articles)} articles from {category_name}")
    return samples

## 8. Process Selected Categories

Choose which categories to process. Start with one category to test, then process all.

In [None]:
# Run parameters used by the processing cells below
CATEGORIES_TO_PROCESS = ["Doi song"]  # Start with one category
# CATEGORIES_TO_PROCESS = get_available_categories("train")  # Or process all

EXTRACTION_PASSES = 2  # More passes = better quality but slower
MAX_WORKERS = 4  # More workers = faster but may hit rate limits

# Echo the chosen settings, one per line
print(
    "Configuration:",
    f"  Categories: {CATEGORIES_TO_PROCESS}",
    f"  Extraction passes: {EXTRACTION_PASSES}",
    f"  Max workers: {MAX_WORKERS}",
    sep="\n",
)

### 8a. Process Training Data

In [None]:
# Collect the samples of every selected category from the training split
train_samples = []

for category in CATEGORIES_TO_PROCESS:
    train_samples += process_category(
        category_name=category,
        split="train",
        extraction_passes=EXTRACTION_PASSES,
        max_workers=MAX_WORKERS,
    )

print("\n" + "=" * 70)
print(f"Total training samples: {len(train_samples)}")
print("=" * 70)

### 8b. Process Test Data

In [None]:
# Collect the samples of every selected category from the test split
test_samples = []

for category in CATEGORIES_TO_PROCESS:
    test_samples += process_category(
        category_name=category,
        split="test",
        extraction_passes=EXTRACTION_PASSES,
        max_workers=MAX_WORKERS,
    )

print("\n" + "=" * 70)
print(f"Total test samples: {len(test_samples)}")
print("=" * 70)

## 9. Save Combined Datasets

In [None]:
# Training split: deduplicate, validate and write the combined files.
# `valid_train` is only defined when there are training samples; the later
# cells that read it are guarded by the same `if train_samples:` check.
if train_samples:
    # Deduplicate entities
    train_samples = DataProcessorService.deduplicate_entities(train_samples)

    # Validate; only the valid samples are written out
    valid_train, invalid_train = DataProcessorService.validate_samples(train_samples)
    print(f"Train: {len(valid_train)} valid, {len(invalid_train)} invalid")

    # Save in different formats
    train_json_path = PROCESSED_DATA_DIR / "langextract_train.json"
    DataProcessorService.save_dataset(valid_train, train_json_path)

    train_jsonl_path = PROCESSED_DATA_DIR / "langextract_train.jsonl"
    DataProcessorService.save_jsonl(valid_train, train_jsonl_path)

    train_finetuning_path = PROCESSED_DATA_DIR / "langextract_train_finetuning.jsonl"
    DataProcessorService.export_for_finetuning(valid_train, train_finetuning_path, format="chat")

    print("\n✅ Saved training data to:")
    print(f"  - {train_json_path}")
    print(f"  - {train_jsonl_path}")
    print(f"  - {train_finetuning_path}")

# Test split: same steps, writing the langextract_test* files.
if test_samples:
    # Deduplicate entities
    test_samples = DataProcessorService.deduplicate_entities(test_samples)

    # Validate; only the valid samples are written out
    valid_test, invalid_test = DataProcessorService.validate_samples(test_samples)
    print(f"\nTest: {len(valid_test)} valid, {len(invalid_test)} invalid")

    # Save in different formats
    test_json_path = PROCESSED_DATA_DIR / "langextract_test.json"
    DataProcessorService.save_dataset(valid_test, test_json_path)

    test_jsonl_path = PROCESSED_DATA_DIR / "langextract_test.jsonl"
    DataProcessorService.save_jsonl(valid_test, test_jsonl_path)

    test_finetuning_path = PROCESSED_DATA_DIR / "langextract_test_finetuning.jsonl"
    DataProcessorService.export_for_finetuning(valid_test, test_finetuning_path, format="chat")

    print("\n✅ Saved test data to:")
    print(f"  - {test_json_path}")
    print(f"  - {test_jsonl_path}")
    print(f"  - {test_finetuning_path}")

## 10. Compute and Display Statistics

In [None]:
import pandas as pd

def display_statistics(samples: List[dict], split_name: str):
    """Compute dataset statistics, print a formatted report and return them.

    The statistics come from ``DataProcessorService.compute_statistics``;
    the returned value is that result, unchanged.
    """
    stats = DataProcessorService.compute_statistics(samples)
    rule = '=' * 70

    print(f"\n{rule}")
    print(f"{split_name.upper()} STATISTICS")
    print(rule)
    print(f"Total samples: {stats['total_samples']}")
    print(f"Samples with entities: {stats['samples_with_entities']}")
    print(f"Samples without entities: {stats['samples_without_entities']}")

    print("\nEntity Counts:")
    for label, key in (("Person", 'person'), ("Organizations", 'organizations'), ("Address", 'address')):
        print(f"  {label}: {stats['entity_counts'][key]}")
    print(f"  Total: {stats['total_entities']}")

    print("\nAverages:")
    print(f"  Entities per sample: {stats['avg_entities_per_sample']:.2f}")
    print(f"  Text length: {stats['avg_text_length']:.0f} chars")

    print("\nText Length Range:")
    print(f"  Min: {stats['min_text_length']} chars")
    print(f"  Max: {stats['max_text_length']} chars")

    return stats

# Display statistics for the validated samples of each split.
# Both names stay None when the corresponding split has no samples.
train_stats = None
test_stats = None

if train_samples:
    train_stats = display_statistics(valid_train, "train")

if test_samples:
    test_stats = display_statistics(valid_test, "test")

# Save combined statistics, keyed by split, when at least one split has any
if train_stats or test_stats:
    combined_stats = {}
    if train_stats:
        combined_stats["train"] = train_stats
    if test_stats:
        combined_stats["test"] = test_stats

    stats_path = PROCESSED_DATA_DIR / "langextract_statistics.json"
    with open(stats_path, "w", encoding="utf-8") as f:
        json.dump(combined_stats, f, ensure_ascii=False, indent=2)

    print(f"\n✅ Statistics saved to: {stats_path}")

## 11. Visualize Entity Distribution

In [None]:
import matplotlib.pyplot as plt

def plot_entity_distribution(train_stats, test_stats):
    """Draw side-by-side bar charts of entity counts for train and test.

    A panel is only filled when its statistics argument is truthy; the
    other panel is left as an empty axes. Each bar is labelled with its
    count.
    """
    _fig, axes = plt.subplots(1, 2, figsize=(14, 5))
    bar_colors = ['#3498db', '#e74c3c', '#2ecc71']

    # (axes, statistics, title) for the left and right panel
    panels = (
        (axes[0], train_stats, 'Training Set - Entity Distribution'),
        (axes[1], test_stats, 'Test Set - Entity Distribution'),
    )

    for ax, split_stats, title in panels:
        if not split_stats:
            continue

        counts = split_stats['entity_counts']
        ax.bar(counts.keys(), counts.values(), color=bar_colors)
        ax.set_title(title, fontsize=14, fontweight='bold')
        ax.set_ylabel('Count', fontsize=12)
        ax.set_xlabel('Entity Type', fontsize=12)

        # Write each count just above its bar
        for position, value in enumerate(counts.values()):
            ax.text(position, value, str(value), ha='center', va='bottom', fontweight='bold')

    plt.tight_layout()
    plt.show()

# Plot only when statistics exist for at least one split
if train_stats or test_stats:
    plot_entity_distribution(train_stats, test_stats)

## 12. Sample Inspection

In [None]:
# Show a few sample extractions
def show_samples(samples: List[dict], n: int = 3):
    """Print a random selection of samples with their extracted entities.

    At most ``n`` samples are drawn without replacement (fewer when the
    list is shorter). For each one the file name, category, the first 300
    characters of the text and the three entity lists are printed.
    """
    import random

    rule = '=' * 70
    chosen = random.sample(samples, min(n, len(samples)))

    for position, item in enumerate(chosen, start=1):
        print(f"\n{rule}")
        print(f"Sample {position}: {item['file_name']}")
        print(f"Category: {item['category']}")
        print(rule)
        print(f"\nText preview:\n{item['text'][:300]}...")
        print("\nExtracted Entities:")
        found = item['entities']
        for label, key in (("Person", 'person'), ("Organizations", 'organizations'), ("Address", 'address')):
            values = found[key]
            print(f"  {label} ({len(values)}): {values}")

# Preview a few random validated training samples under a banner
if train_samples:
    print("\n" + "#" * 70, "TRAINING SAMPLES", "#" * 70, sep="\n")
    show_samples(valid_train, n=3)

## 13. Final Summary

In [None]:
# Closing report: list what exists on disk and recap the statistics.
print("\n" + "="*70)
print("DATA PREPARATION COMPLETE")
print("="*70)

print("\n📁 Output Files:")
print("\nIndividual results:")
# Per-article JSON folders of the selected categories, if they exist
for split in ["train", "test"]:
    for category in CATEGORIES_TO_PROCESS:
        json_dir = PROCESSED_DATA_DIR / split / category / "json"
        if json_dir.exists():
            count = len(list(json_dir.glob("*.json")))
            print(f"  - {json_dir}: {count} files")

print("\nCombined datasets:")
# Only files that are actually present are listed
for file_name in [
    "langextract_train.json",
    "langextract_train.jsonl",
    "langextract_train_finetuning.jsonl",
    "langextract_test.json",
    "langextract_test.jsonl",
    "langextract_test_finetuning.jsonl",
    "langextract_statistics.json"
]:
    file_path = PROCESSED_DATA_DIR / file_name
    if file_path.exists():
        print(f"  ✅ {file_path}")

print("\n📊 Statistics:")
if train_stats:
    print(f"  Train: {train_stats['total_samples']} samples, {train_stats['total_entities']} entities")
if test_stats:
    print(f"  Test: {test_stats['total_samples']} samples, {test_stats['total_entities']} entities")

print("\n🎉 Ready for model training and evaluation!")
print("\nNext steps:")
print("  1. Review the extracted entities for quality")
print("  2. Use langextract_train_finetuning.jsonl for LLM finetuning")
print("  3. Use langextract_test.json for model evaluation")
print("  4. Compare with other NER approaches (RAG, Prompt Engineering)")