# 🔤 Tokenization for LLaMA Insurance Fine-tuning

This notebook handles tokenization and data formatting for LLaMA model training:

## What this notebook does:
1. Load and configure the LLaMA tokenizer
2. Format datasets for instruction tuning
3. Tokenize training data with proper attention masks
4. Handle context length and padding
5. Create data loaders for training
6. Validate tokenization quality

**⚠️ Important: The Llama 2 checkpoints are gated on Hugging Face. Request access to the model and authenticate (e.g. `huggingface-cli login`) before running this notebook.**

## 1. Import Libraries and Setup

In [None]:
import os
import json
import torch
from pathlib import Path
from typing import Dict, List, Any, Optional
import warnings
from tqdm.auto import tqdm
import pandas as pd
import numpy as np

# Transformers and datasets
from transformers import (
    AutoTokenizer, 
    AutoModelForCausalLM,
    DataCollatorForLanguageModeling,
    TrainingArguments
)
from datasets import Dataset, DatasetDict, load_from_disk
from torch.utils.data import DataLoader

# Suppress all Python warnings for the rest of the session to keep cell output short.
# NOTE(review): this also hides deprecation/config warnings that may matter; consider narrowing.
warnings.filterwarnings('ignore')
# Registers tqdm's `progress_apply` helpers on pandas objects (not used by later cells shown here).
tqdm.pandas()

# Environment report: library versions and GPU availability.
print("✅ Libraries imported successfully")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")

## 2. Configuration and Model Setup

In [None]:
# Configuration
MODEL_NAME = "meta-llama/Llama-2-7b-chat-hf"  # or "meta-llama/Llama-2-7b-hf"
PROCESSED_DATA_DIR = Path("data/processed")
TOKENIZED_DATA_DIR = Path("data/tokenized")
# parents=True: create data/ as well on a fresh checkout instead of raising FileNotFoundError.
TOKENIZED_DATA_DIR.mkdir(parents=True, exist_ok=True)

# Tokenization parameters
MAX_LENGTH = 2048  # Maximum sequence length for training; longer examples are truncated when TRUNCATION is set
PADDING_SIDE = "right"  # For causal LM, use right padding
TRUNCATION = True
# NOTE(review): not read by the formatting/tokenization code in this notebook; the EOS
# token actually comes from INSTRUCTION_TEMPLATE['assistant_suffix']. Confirm before relying on it.
ADD_EOS_TOKEN = True

# Instruction formatting (LLaMA-2 chat markers).
# NOTE(review): 'system' and 'assistant_prefix' are not used by format_instruction; only the
# user prefix/suffix and assistant suffix end up in the training text.
INSTRUCTION_TEMPLATE = {
    'system': "You are a helpful AI assistant specialized in insurance and financial services. Provide accurate, helpful, and compliant information.",
    'user_prefix': "[INST]",
    'user_suffix': "[/INST]",
    'assistant_prefix': "",
    'assistant_suffix': "</s>"
}

# Task-specific prompts, keyed by the `task_type` field of each example. They are used as the
# instruction when an example has none of its own.
TASK_PROMPTS = {
    'CLAIM_CLASSIFICATION': "Classify the following insurance claim into the appropriate category:",
    'POLICY_SUMMARIZATION': "Summarize the following insurance policy document:",
    'FAQ_GENERATION': "Generate frequently asked questions for the following insurance document:",
    'COMPLIANCE_CHECK': "Identify compliance requirements in the following insurance document:",
    'CONTRACT_QA': "Answer the following question based on the insurance document provided:"
}

print("Configuration loaded:")
print(f"- Model: {MODEL_NAME}")
print(f"- Max length: {MAX_LENGTH}")
print(f"- Processed data: {PROCESSED_DATA_DIR}")
print(f"- Tokenized output: {TOKENIZED_DATA_DIR}")

## 3. Load and Configure Tokenizer

In [None]:
def setup_tokenizer(model_name: str) -> AutoTokenizer:
    """Load the tokenizer for *model_name* and make sure its special tokens are set.

    If the tokenizer has no pad token, the EOS token is reused for padding, or a new
    ``[PAD]`` token is added when there is no EOS either. Missing BOS/EOS/UNK tokens
    are added with their LLaMA defaults. A short summary is printed.

    Raises:
        Whatever ``AutoTokenizer.from_pretrained`` (or the token setup) raises, after
        printing an authentication hint.
    """
    print(f"Loading tokenizer for {model_name}...")

    try:
        # NOTE(review): trust_remote_code executes code shipped with the model repo; it is
        # likely unnecessary for Llama checkpoints. Consider removing it.
        tok = AutoTokenizer.from_pretrained(
            model_name,
            use_fast=True,
            trust_remote_code=True,
            padding_side=PADDING_SIDE,
        )

        # Prefer reusing EOS for padding so the vocabulary does not grow.
        # NOTE(review): if '[PAD]' is added, the model's embeddings must be resized to len(tok).
        if tok.pad_token is None:
            if tok.eos_token is None:
                tok.add_special_tokens({'pad_token': '[PAD]'})
            else:
                tok.pad_token = tok.eos_token

        # Add only the standard special tokens this tokenizer is missing.
        defaults = (('bos_token', '<s>'), ('eos_token', '</s>'), ('unk_token', '<unk>'))
        missing = {attr: token for attr, token in defaults if getattr(tok, attr) is None}
        if missing:
            tok.add_special_tokens(missing)

        print(f"✅ Tokenizer loaded successfully")
        print(f"  Vocab size: {len(tok)}")
        print(f"  Pad token: {tok.pad_token} (ID: {tok.pad_token_id})")
        print(f"  EOS token: {tok.eos_token} (ID: {tok.eos_token_id})")
        print(f"  BOS token: {tok.bos_token} (ID: {tok.bos_token_id})")
        return tok

    except Exception as err:
        print(f"❌ Error loading tokenizer: {err}")
        print("Make sure you're authenticated with Hugging Face and have access to LLaMA models")
        raise

# Load tokenizer (setup_tokenizer prints an auth hint and re-raises on failure)
tokenizer = setup_tokenizer(MODEL_NAME)

# Test tokenization: encode/decode round trip on a short sentence.
# encode() is called with its default special-token handling, so the IDs and the
# decoded text may include a leading BOS token.
test_text = "This is a test of the LLaMA tokenizer for insurance documents."
test_tokens = tokenizer.encode(test_text)
decoded_text = tokenizer.decode(test_tokens)

print(f"\nTokenization test:")
print(f"Original: {test_text}")
print(f"Tokens: {test_tokens} ({len(test_tokens)} tokens)")
print(f"Decoded: {decoded_text}")

## 4. Load Processed Datasets

In [None]:
def _load_split_files(source_dir: Path) -> DatasetDict:
    """Load whichever of train/validation/test ``.json`` files exist in *source_dir*.

    Missing split files are skipped, so the result may be empty. Errors raised by
    ``Dataset.from_json`` propagate to the caller.
    """
    splits = DatasetDict()
    for split in ['train', 'validation', 'test']:
        json_file = source_dir / f"{split}.json"
        if json_file.exists():
            dataset = Dataset.from_json(str(json_file))
            splits[split] = dataset
            print(f"  {split}: {len(dataset)} examples")
    return splits


def load_processed_datasets() -> Dict[str, DatasetDict]:
    """Load the combined and per-task processed datasets from PROCESSED_DATA_DIR.

    Expected layout: ``PROCESSED_DATA_DIR/<name>/{train,validation,test}.json``.
    The ``combined`` directory is stored under the key ``'combined'``. Every other
    sub-directory is stored under its upper-cased name. A directory that fails to
    load is reported and skipped.

    Returns:
        Mapping of dataset name to DatasetDict. It is empty when PROCESSED_DATA_DIR
        does not exist or contains no loadable splits.
    """
    datasets = {}

    print(f"Loading processed datasets from {PROCESSED_DATA_DIR}...")

    # iterdir() below would raise FileNotFoundError. Returning an empty dict lets the
    # caller show its "run the preprocessing notebook first" hint instead.
    if not PROCESSED_DATA_DIR.is_dir():
        print(f"❌ Processed data directory not found: {PROCESSED_DATA_DIR}")
        return datasets

    # Check for combined dataset first
    combined_dir = PROCESSED_DATA_DIR / "combined"
    if combined_dir.exists():
        print(f"Loading combined dataset...")
        try:
            combined_dataset = _load_split_files(combined_dir)
            if combined_dataset:
                datasets['combined'] = combined_dataset
                print(f"✅ Combined dataset loaded")
        except Exception as e:
            print(f"❌ Error loading combined dataset: {e}")

    # Individual task datasets, sorted so they are loaded and processed in a stable order
    task_dirs = sorted(
        d for d in PROCESSED_DATA_DIR.iterdir() if d.is_dir() and d.name != 'combined'
    )

    for task_dir in task_dirs:
        task_name = task_dir.name.upper()
        print(f"Loading {task_name} dataset...")

        try:
            task_dataset = _load_split_files(task_dir)
            if task_dataset:
                datasets[task_name] = task_dataset
                print(f"✅ {task_name} dataset loaded")
        except Exception as e:
            print(f"❌ Error loading {task_name} dataset: {e}")

    return datasets

# Load datasets
datasets = load_processed_datasets()

if not datasets:
    print("❌ No datasets found. Please run 01_data_preprocessing.ipynb first.")
else:
    print(f"\nLoaded {len(datasets)} datasets:")
    for name, dataset_dict in datasets.items():
        total_examples = sum(len(split_ds) for split_ds in dataset_dict.values())
        print(f"  {name}: {total_examples} total examples")

        # Peek at the first training example of each dataset
        if 'train' in dataset_dict and len(dataset_dict['train']) > 0:
            sample = dataset_dict['train'][0]
            print(f"    Sample keys: {list(sample.keys())}")
            # A merged JSON schema gives every row an 'instruction' key, possibly None;
            # slicing None would raise TypeError, so require an actual value.
            if sample.get('instruction'):
                print(f"    Sample instruction: {sample['instruction'][:100]}...")
        print()

## 5. Format Data for Instruction Tuning

In [None]:
def format_instruction(example: Dict[str, Any]) -> str:
    """Render one example as a LLaMA-2 chat-style training string.

    Layout: ``[INST] {instruction}\\n\\n{input} [/INST] {output}</s>``. The
    ``\\n\\n{input}`` part is omitted when the example has no input.

    Args:
        example: one dataset row. Recognised keys are ``task_type``, ``instruction``,
            ``input`` and ``output``. ``CONTRACT_QA`` rows may also use ``question``,
            ``context`` and ``answer``. A missing key and a ``None`` value are treated
            the same way. When a JSON file mixes rows with different fields, the loaded
            dataset fills absent columns with ``None``.

    Returns:
        The formatted text, ending with INSTRUCTION_TEMPLATE['assistant_suffix'].
    """
    def _text(key: str) -> Any:
        # Map None to '' so missing fields never render as the string "None".
        value = example.get(key)
        return '' if value is None else value

    # Get the task-specific prompt
    task_type = example.get('task_type', 'POLICY_SUMMARIZATION')
    task_prompt = TASK_PROMPTS.get(task_type, "Complete the following task:")

    # An explicit, non-empty instruction wins; otherwise use the task prompt.
    instruction = example.get('instruction') or task_prompt

    user_input = _text('input')
    assistant_output = _text('output')

    # Q&A rows carry context/question/answer. Require an actual question: in a merged
    # schema every row has the key, and a None question must not replace a valid input.
    if task_type == 'CONTRACT_QA' and example.get('question'):
        user_input = f"Context: {_text('context')}\n\nQuestion: {example['question']}"
        answer = example.get('answer')
        if answer is not None:
            assistant_output = answer

    # Avoid a dangling "\n\n " before [/INST] when there is no input.
    user_block = f"{instruction}\n\n{user_input}" if user_input != '' else instruction

    # NOTE(review): INSTRUCTION_TEMPLATE['system'] is not included in the prompt.
    formatted_text = (
        f"{INSTRUCTION_TEMPLATE['user_prefix']} {user_block} "
        f"{INSTRUCTION_TEMPLATE['user_suffix']} {assistant_output}"
        f"{INSTRUCTION_TEMPLATE['assistant_suffix']}"
    )

    return formatted_text

def format_dataset_for_training(dataset: Dataset) -> Dataset:
    """Map every row of *dataset* to ``text`` / ``task_type`` / ``original_id``.

    ``text`` is built by ``format_instruction``. All original columns are dropped
    except ``task_type`` and ``doc_id``. ``task_type`` is then overwritten with its
    defaulted value, and ``doc_id`` is also copied into ``original_id``.
    """
    preserved = ('task_type', 'doc_id')
    dropped = [column for column in dataset.column_names if column not in preserved]

    def _to_training_row(row):
        # The mapper still sees the full row; columns are removed after it returns.
        text = format_instruction(row)
        return {
            'text': text,
            'task_type': row.get('task_type', 'POLICY_SUMMARIZATION'),
            'original_id': row.get('doc_id', 'unknown'),
        }

    return dataset.map(_to_training_row, remove_columns=dropped)

# Test instruction formatting on the first combined training example.
# Guard the split as well: a combined dataset may exist without a 'train' split.
if datasets and 'train' in datasets.get('combined', {}):
    sample_dataset = datasets['combined']['train']
    if len(sample_dataset) > 0:
        sample_example = sample_dataset[0]
        formatted_sample = format_instruction(sample_example)
        
        print("Instruction formatting test:")
        print("=" * 50)
        print(formatted_sample)
        print("=" * 50)
        print(f"Formatted length: {len(formatted_sample)} characters")
        
        # Check tokenization length (includes special tokens added by encode())
        tokens = tokenizer.encode(formatted_sample)
        print(f"Token count: {len(tokens)} tokens")
        
        if len(tokens) > MAX_LENGTH:
            print(f"⚠️ Warning: Example exceeds max length ({MAX_LENGTH} tokens)")
        else:
            print(f"✅ Example fits within max length")

## 6. Tokenize Datasets

In [None]:
def tokenize_function(examples: Dict[str, List[str]]) -> Dict[str, List[List[int]]]:
    """Tokenize a batch of formatted ``text`` values for causal-LM training.

    Sequences are truncated to MAX_LENGTH when TRUNCATION is set and are left
    unpadded; the collator pads each batch. The tokenizer output is extended with
    ``labels`` (a copy of ``input_ids``) and ``length`` (tokens per example).
    """
    encoded = tokenizer(
        examples['text'],
        truncation=TRUNCATION,
        padding=False,          # padding happens per batch in the collator
        max_length=MAX_LENGTH,
        return_tensors=None,    # plain Python lists
        add_special_tokens=True,
    )

    token_ids = encoded['input_ids']
    # For causal LM, labels are the same as input_ids
    encoded['labels'] = token_ids.copy()
    # Token counts, used later for length filtering and statistics
    encoded['length'] = list(map(len, token_ids))
    return encoded

def tokenize_dataset(dataset: Dataset, dataset_name: str) -> Dataset:
    """Format, tokenize and length-filter one split.

    Args:
        dataset: raw rows accepted by ``format_dataset_for_training``.
        dataset_name: label used only in progress messages.

    Returns:
        The tokenized split with the columns produced by ``tokenize_function``
        (``input_ids``, ``labels``, ``length`` plus whatever else the tokenizer
        returns), keeping only examples with at least 10 tokens.
    """
    print(f"Tokenizing {dataset_name} dataset ({len(dataset)} examples)...")
    
    # First format for instruction tuning
    formatted_dataset = format_dataset_for_training(dataset)
    
    # Then tokenize
    tokenized_dataset = formatted_dataset.map(
        tokenize_function,
        batched=True,
        remove_columns=formatted_dataset.column_names,
        desc=f"Tokenizing {dataset_name}"
    )

    # tokenize_function truncates to MAX_LENGTH (when TRUNCATION is set), so the upper
    # bound in filter_length can never reject anything. Over-long examples are kept in
    # truncated form and, with right-side truncation, lose their trailing EOS. Count the
    # examples sitting at the limit so this is visible rather than silent.
    at_limit = sum(1 for n in tokenized_dataset['length'] if n >= MAX_LENGTH)
    
    # Filter out very short examples (and, without truncation, over-long ones)
    def filter_length(example):
        length = example['length']
        return 10 <= length <= MAX_LENGTH
    
    filtered_dataset = tokenized_dataset.filter(filter_length)
    
    print(f"  Original: {len(tokenized_dataset)} examples")
    print(f"  After filtering: {len(filtered_dataset)} examples")
    if at_limit:
        print(f"  ⚠️ {at_limit} examples reached MAX_LENGTH ({MAX_LENGTH} tokens) and were likely truncated")
    
    if len(filtered_dataset) > 0:
        # Column access reads just the lengths instead of materialising every row
        lengths = list(filtered_dataset['length'])
        print(f"  Length stats - Min: {min(lengths)}, Max: {max(lengths)}, Avg: {np.mean(lengths):.1f}")
    
    return filtered_dataset

def tokenize_all_datasets(datasets: Dict[str, DatasetDict]) -> Dict[str, DatasetDict]:
    """Tokenize every non-empty split of every dataset.

    Returns a mapping keyed by the same dataset names. A dataset whose splits are all
    empty is left out of the result.
    """
    result = {}

    for name, splits in datasets.items():
        print(f"\nProcessing {name} dataset...")
        processed = DatasetDict()

        for split, split_ds in splits.items():
            if len(split_ds) == 0:
                print(f"  Skipping empty {split} split")
                continue
            processed[split] = tokenize_dataset(split_ds, f"{name}_{split}")

        if processed:
            result[name] = processed

    return result

# Tokenize all datasets.
# Bind `tokenized_datasets` unconditionally (an empty dict when there is nothing to
# tokenize). The collator, DataLoader, save and validation cells below all read it
# and would otherwise raise NameError.
tokenized_datasets = {}
if datasets:
    print("Starting tokenization process...")
    tokenized_datasets = tokenize_all_datasets(datasets)
    
    print(f"\n✅ Tokenization complete!")
    print(f"Tokenized {len(tokenized_datasets)} datasets:")
    
    for name, dataset_dict in tokenized_datasets.items():
        total_examples = sum(len(split_ds) for split_ds in dataset_dict.values())
        print(f"  {name}: {total_examples} total examples")
        
        # Show sample tokenized example
        if 'train' in dataset_dict and len(dataset_dict['train']) > 0:
            sample = dataset_dict['train'][0]
            print(f"    Sample keys: {list(sample.keys())}")
            print(f"    Input length: {len(sample['input_ids'])} tokens")
            print(f"    First 10 tokens: {sample['input_ids'][:10]}")
else:
    print("❌ No datasets available for tokenization")

## 7. Create Data Collator

In [None]:
class InstructionDataCollator:
    """Pad tokenized examples into a causal-LM training batch.

    ``input_ids`` are padded with ``tokenizer.pad``, which also builds
    ``attention_mask``. ``labels`` are padded with ``label_pad_token_id`` (-100 by
    default) so that padded positions are ignored by the loss.

    Labels are padded by position, not by comparing token values with
    ``pad_token_id``. ``setup_tokenizer`` reuses the EOS token as the pad token when
    the tokenizer has none. Value-based masking would then also hide every real EOS
    from the loss, and the model would never be trained to stop.
    """
    
    def __init__(self, tokenizer, max_length=MAX_LENGTH, label_pad_token_id=-100):
        self.tokenizer = tokenizer
        self.max_length = max_length  # kept for reference; padding=True pads to the longest in the batch
        self.label_pad_token_id = label_pad_token_id
    
    def __call__(self, features):
        # Extract input_ids and labels
        input_ids = [f['input_ids'] for f in features]
        labels = [f['labels'] for f in features]
        
        # Pad inputs to the longest sequence in the batch; the attention mask is
        # derived from sequence lengths, so real EOS tokens keep mask 1.
        batch = self.tokenizer.pad(
            {'input_ids': input_ids},
            padding=True,
            max_length=self.max_length,
            return_tensors='pt'
        )
        
        # Pad labels positionally to the same width, honouring the padding side.
        seq_len = batch['input_ids'].shape[1]
        padded_labels = torch.full(
            (len(labels), seq_len), self.label_pad_token_id, dtype=torch.long
        )
        pad_left = self.tokenizer.padding_side == 'left'
        for row, seq in enumerate(labels):
            seq = list(seq)[:seq_len]
            if not seq:
                continue
            values = torch.tensor(seq, dtype=torch.long)
            if pad_left:
                padded_labels[row, seq_len - len(seq):] = values
            else:
                padded_labels[row, :len(seq)] = values
        
        batch['labels'] = padded_labels
        
        return batch

# Create data collator
data_collator = InstructionDataCollator(tokenizer, max_length=MAX_LENGTH)

print(f"✅ Data collator created")
print(f"  Max length: {MAX_LENGTH}")
print(f"  Pad token ID: {tokenizer.pad_token_id}")

# Test data collator with a small batch (guard the split: 'combined' may lack 'train')
if tokenized_datasets and 'train' in tokenized_datasets.get('combined', {}):
    test_dataset = tokenized_datasets['combined']['train']
    if len(test_dataset) >= 2:
        test_batch = [test_dataset[i] for i in range(2)]
        collated_batch = data_collator(test_batch)
        
        print(f"\nData collator test:")
        print(f"  Batch keys: {list(collated_batch.keys())}")
        print(f"  Input IDs shape: {collated_batch['input_ids'].shape}")
        print(f"  Labels shape: {collated_batch['labels'].shape}")
        print(f"  Attention mask shape: {collated_batch['attention_mask'].shape}")
        
        # Check that padding worked correctly: the loss should ignore exactly the
        # padded positions. Real tokens with label -100 (e.g. EOS masked because it
        # doubles as the pad token) would never be trained on.
        ignored = collated_batch['labels'] == -100
        padded = collated_batch['attention_mask'] == 0
        if torch.equal(ignored, padded):
            print(f"  ✅ Batch created successfully")
        else:
            wrongly_ignored = (ignored & ~padded).sum().item()
            print(f"  ⚠️ Batch created, but label padding does not match the attention mask "
                  f"({wrongly_ignored} real tokens have label -100)")

## 8. Create Data Loaders

In [None]:
def create_dataloaders(tokenized_datasets: Dict[str, DatasetDict], batch_size: int = 4) -> Dict[str, Dict[str, DataLoader]]:
    """Wrap every non-empty split in a PyTorch DataLoader using ``data_collator``.

    Only the ``train`` split is shuffled. Memory is pinned when CUDA is available,
    and loading stays in the main process (num_workers=0). Every dataset gets an
    entry in the result, possibly an empty dict.
    """
    pin = torch.cuda.is_available()
    loaders_by_dataset = {}

    for name, splits in tokenized_datasets.items():
        print(f"Creating DataLoaders for {name}...")
        split_loaders = {}

        for split, split_ds in splits.items():
            if len(split_ds) == 0:
                continue

            loader = DataLoader(
                split_ds,
                batch_size=batch_size,
                shuffle=split == 'train',
                collate_fn=data_collator,
                pin_memory=pin,
                num_workers=0,  # Use 0 for Colab compatibility
            )
            split_loaders[split] = loader
            print(f"  {split}: {len(split_ds)} examples, {len(loader)} batches")

        loaders_by_dataset[name] = split_loaders

    return loaders_by_dataset

# Create dataloaders for every tokenized dataset, then pull one batch from
# combined/train as a smoke test.
if tokenized_datasets:
    # Use small batch size for Colab
    BATCH_SIZE = 2  # Adjust based on GPU memory
    
    print(f"Creating DataLoaders with batch size {BATCH_SIZE}...")
    dataloaders = create_dataloaders(tokenized_datasets, batch_size=BATCH_SIZE)
    
    print(f"\n✅ DataLoaders created for {len(dataloaders)} datasets")
    
    # Test a dataloader
    if 'combined' in dataloaders and 'train' in dataloaders['combined']:
        test_loader = dataloaders['combined']['train']
        
        print(f"\nTesting DataLoader...")
        try:
            batch = next(iter(test_loader))
            print(f"  Batch loaded successfully")
            print(f"  Input IDs shape: {batch['input_ids'].shape}")
            print(f"  Labels shape: {batch['labels'].shape}")
            print(f"  Attention mask shape: {batch['attention_mask'].shape}")
            
            # Share of label positions set to -100, i.e. excluded from the loss by the collator
            masked_labels = (batch['labels'] == -100).sum().item()
            total_labels = batch['labels'].numel()
            print(f"  Masked labels: {masked_labels}/{total_labels} ({masked_labels/total_labels:.1%})")
            
        except Exception as e:
            # Smoke test only: report the failure and keep the notebook running
            print(f"  ❌ Error testing DataLoader: {e}")
else:
    print("❌ No tokenized datasets available for DataLoader creation")

## 9. Save Tokenized Datasets

In [None]:
def save_tokenized_datasets(tokenized_datasets: Dict[str, DatasetDict], tokenizer: AutoTokenizer):
    """Write the tokenizer, every tokenized split and a metadata file under TOKENIZED_DATA_DIR.

    Layout:
        tokenizer/                      -- tokenizer.save_pretrained output
        <dataset>/<split>/              -- Dataset.save_to_disk (Arrow)
        <dataset>/<split>.json          -- JSON export for inspection
        tokenization_metadata.json      -- config, special-token IDs and split sizes
    """
    print(f"Saving tokenized datasets to {TOKENIZED_DATA_DIR}...")

    tokenizer_dir = TOKENIZED_DATA_DIR / "tokenizer"
    tokenizer.save_pretrained(tokenizer_dir)
    print(f"  ✅ Tokenizer saved to {tokenizer_dir}")

    for name, splits in tokenized_datasets.items():
        out_dir = TOKENIZED_DATA_DIR / name
        out_dir.mkdir(exist_ok=True)

        for split, split_ds in splits.items():
            split_ds.save_to_disk(out_dir / split)        # reloadable with load_from_disk
            split_ds.to_json(out_dir / f"{split}.json")   # human-readable copy

        print(f"  ✅ {name} dataset saved to {out_dir}")

    split_sizes = {
        name: {split: len(split_ds) for split, split_ds in splits.items()}
        for name, splits in tokenized_datasets.items()
    }

    metadata = {
        'model_name': MODEL_NAME,
        'max_length': MAX_LENGTH,
        'padding_side': PADDING_SIDE,
        'vocab_size': len(tokenizer),
        'pad_token_id': tokenizer.pad_token_id,
        'eos_token_id': tokenizer.eos_token_id,
        'bos_token_id': tokenizer.bos_token_id,
        'datasets': split_sizes,
        'task_prompts': TASK_PROMPTS,
        'instruction_template': INSTRUCTION_TEMPLATE,
    }

    metadata_file = TOKENIZED_DATA_DIR / "tokenization_metadata.json"
    with metadata_file.open('w', encoding='utf-8') as handle:
        json.dump(metadata, handle, indent=2, ensure_ascii=False)

    print(f"  ✅ Metadata saved to {metadata_file}")

def create_tokenization_summary():
    """Build a JSON-serialisable summary of the tokenization config and split sizes.

    Reads the module-level ``tokenized_datasets`` and ``tokenizer``. Each dataset entry
    maps split name to example count, plus a ``'total'`` key; ``total_examples`` is the
    sum over all datasets.
    """
    per_dataset = {}
    for name, splits in tokenized_datasets.items():
        counts = {split: len(split_ds) for split, split_ds in splits.items()}
        counts['total'] = sum(counts.values())
        per_dataset[name] = counts

    return {
        'tokenization_config': {
            'model_name': MODEL_NAME,
            'max_length': MAX_LENGTH,
            'vocab_size': len(tokenizer),
            'padding_side': PADDING_SIDE,
        },
        'datasets': per_dataset,
        'total_examples': sum(counts['total'] for counts in per_dataset.values()),
    }

# Save tokenized datasets: tokenizer, splits (Arrow + JSON) and metadata via
# save_tokenized_datasets, then a summary JSON next to them.
if tokenized_datasets:
    save_tokenized_datasets(tokenized_datasets, tokenizer)
    
    # Create and display summary
    summary = create_tokenization_summary()
    print(f"\n📊 Tokenization Summary:")
    print(json.dumps(summary, indent=2))
    
    # Save summary
    summary_file = TOKENIZED_DATA_DIR / "tokenization_summary.json"
    with open(summary_file, 'w', encoding='utf-8') as f:
        json.dump(summary, f, indent=2, ensure_ascii=False)
    
    print(f"\n✅ Tokenization complete!")
    print(f"\nTokenized data saved to: {TOKENIZED_DATA_DIR}")
    print(f"\nNext steps:")
    print(f"1. Review tokenized datasets in {TOKENIZED_DATA_DIR}")
    print(f"2. Run 03_finetuning_lora.ipynb to start training")
    print(f"3. Monitor training progress and adjust hyperparameters as needed")
else:
    print("❌ No tokenized datasets to save")

## 10. Validation and Quality Checks

In [None]:
def validate_tokenization_quality(tokenized_datasets: Dict[str, DatasetDict], tokenizer: AutoTokenizer):
    """Run sanity checks on every tokenized split and print a report.

    Per split, it checks that:
    - the split is non-empty;
    - token lengths fall within [10, MAX_LENGTH];
    - the first example contains BOS and EOS (when the tokenizer defines them);
    - the first example has as many labels as input_ids;
    - the first, middle and last examples decode to text containing ``[INST]`` and
      ``[/INST]``.

    Returns:
        A list of human-readable issue strings (empty when every check passed).
    """
    print("🔍 Performing tokenization quality checks...")
    
    issues = []
    bos_id = tokenizer.bos_token_id
    eos_id = tokenizer.eos_token_id
    
    for dataset_name, dataset_dict in tokenized_datasets.items():
        print(f"\nChecking {dataset_name} dataset:")
        
        for split_name, dataset in dataset_dict.items():
            if len(dataset) == 0:
                issues.append(f"{dataset_name}/{split_name} is empty")
                continue
            
            issues_before = len(issues)
            print(f"  {split_name} split ({len(dataset)} examples):")
            
            # Check sequence lengths (column access avoids materialising whole rows)
            lengths = [len(ids) for ids in dataset['input_ids']]
            min_len, max_len, avg_len = min(lengths), max(lengths), np.mean(lengths)
            
            print(f"    Length - Min: {min_len}, Max: {max_len}, Avg: {avg_len:.1f}")
            
            if max_len > MAX_LENGTH:
                issues.append(f"{dataset_name}/{split_name} has sequences longer than {MAX_LENGTH}")
            
            if min_len < 10:
                issues.append(f"{dataset_name}/{split_name} has very short sequences (< 10 tokens)")
            
            # Special tokens, checked on the first example only. Compare the IDs with None
            # rather than relying on truthiness: 0 is a valid token ID.
            sample = dataset[0]
            sample_ids = sample['input_ids']
            if bos_id is not None and bos_id not in sample_ids:
                issues.append(f"{dataset_name}/{split_name} missing BOS tokens")
            if eos_id is not None and eos_id not in sample_ids:
                issues.append(f"{dataset_name}/{split_name} missing EOS tokens")
            
            # Check label alignment
            if len(sample_ids) != len(sample['labels']):
                issues.append(f"{dataset_name}/{split_name} input_ids and labels length mismatch")
            
            # Decode the first, middle and last examples and look for the chat markers
            if len(dataset) >= 3:
                for i in (0, len(dataset) // 2, len(dataset) - 1):
                    decoded = tokenizer.decode(dataset[i]['input_ids'], skip_special_tokens=False)
                    if '[INST]' not in decoded or '[/INST]' not in decoded:
                        issues.append(f"{dataset_name}/{split_name} example {i} missing instruction formatting")
                        break
            
            # Only report success when this split added no new issues
            found = len(issues) - issues_before
            if found:
                print(f"    ⚠️ {found} issue(s) found")
            else:
                print(f"    ✅ Basic checks passed")
    
    # Summary
    if issues:
        print(f"\n⚠️ Found {len(issues)} issues:")
        for issue in issues:
            print(f"  - {issue}")
    else:
        print(f"\n✅ All quality checks passed!")
    
    return issues

def show_tokenization_examples(tokenized_datasets: Dict[str, DatasetDict], tokenizer: AutoTokenizer, num_examples: int = 2):
    """Print up to *num_examples* decoded training examples from the first dataset only.

    For each example it prints the token count, the full decoded text (special tokens
    kept), and the first 20 token IDs alongside their individually decoded pieces.
    """
    print(f"\n📝 Sample tokenized examples:")

    first_entry = next(iter(tokenized_datasets.items()), None)
    if first_entry is None:
        return

    dataset_name, dataset_dict = first_entry
    if 'train' not in dataset_dict or len(dataset_dict['train']) == 0:
        return

    train_split = dataset_dict['train']
    print(f"\n{dataset_name} examples:")
    print("=" * 80)

    for idx in range(min(num_examples, len(train_split))):
        token_ids = train_split[idx]['input_ids']
        head = token_ids[:20]

        print(f"\nExample {idx + 1}:")
        print(f"Length: {len(token_ids)} tokens")
        print(f"Decoded text:")
        print(tokenizer.decode(token_ids, skip_special_tokens=False))
        print(f"\nFirst 20 token IDs: {head}")
        print(f"First 20 tokens: {[tokenizer.decode([tid]) for tid in head]}")
        print("-" * 40)

# Run quality checks, show a couple of decoded examples, and report the total
# example count when no issues were found.
if not tokenized_datasets:
    print("❌ No tokenized datasets available for validation")
else:
    issues = validate_tokenization_quality(tokenized_datasets, tokenizer)
    show_tokenization_examples(tokenized_datasets, tokenizer, num_examples=2)

    if not issues:
        grand_total = sum(
            len(split_ds)
            for dataset_dict in tokenized_datasets.values()
            for split_ds in dataset_dict.values()
        )
        print(f"\n🎉 Tokenization completed successfully!")
        print(f"Ready for training with {grand_total} total examples")