# 03 - Dataset Preparation: Creating Training Data

**Goal**: Transform decoded transactions into instruction-tuning format ready for fine-tuning.

In this notebook, you'll learn:
- How to extract structured intents from decoded transactions
- How to format data in Alpaca instruction-tuning format
- How to split data into train/validation/test sets with stratification
- How to validate dataset quality
- How to visualize dataset characteristics

**Prerequisites**: You have completed `02-data-extraction.ipynb` and have decoded transaction data available.

## Setup

In [None]:
%load_ext autoreload
%autoreload 2

import json
import sys
from pathlib import Path
from collections import Counter
from typing import Any

import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# Add project root to path
project_root = Path.cwd().parent
sys.path.insert(0, str(project_root))

# Import project modules
from eth_finetuning.dataset.intent_extraction import extract_intent, extract_intents_batch
from eth_finetuning.dataset.templates import format_training_example, format_training_examples_batch
from eth_finetuning.dataset.preparation import prepare_dataset, validate_data

# Set plotting style
sns.set_style('whitegrid')
plt.rcParams['figure.figsize'] = (14, 6)

print("✓ Imports successful")
print(f"✓ Project root: {project_root}")

## Loading Decoded Transaction Data

First, let's load the decoded transactions we created in the previous notebook.

In [None]:
# Load decoded transactions from previous notebook
decoded_file = project_root / "data" / "processed" / "decoded_transactions.json"

if decoded_file.exists():
    with open(decoded_file, 'r') as f:
        decoded_txs = json.load(f)
    print(f"✓ Loaded {len(decoded_txs)} decoded transactions from {decoded_file.name}")
else:
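    # Fallback: create sample data for demonstration
    # NOTE: the `from` addresses below are shortened placeholders (36 hex digits
    # instead of 40), so a strict address validator may reject them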
    print("⚠️  No decoded transactions file found")
    print("   Creating sample data for demonstration...")
    
    decoded_txs = [
        {
            "tx_hash": "0x1234567890abcdef1234567890abcdef1234567890abcdef1234567890abcdef",
            "action": "transfer",
            "protocol": "ethereum",
            "from": "0x742d35Cc6634C0532925a3b844Bc9e7595f0",
            "to": "0x1f9840a85d5aF5bf1D1762F925BDADdC4201F984",
            "amount_wei": "1000000000000000000",
            "amount_eth": "1.0",
            "status": "success",
        },
        {
            "tx_hash": "0xabcdef1234567890abcdef1234567890abcdef1234567890abcdef1234567890",
            "action": "transfer",
            "protocol": "erc20",
            "token_address": "0xA0b86991c6218b36c1d19D4a2e9Eb0cE3606eB48",
            "token_symbol": "USDC",
            "from": "0x742d35Cc6634C0532925a3b844Bc9e7595f0",
            "to": "0x1f9840a85d5aF5bf1D1762F925BDADdC4201F984",
            "amount": "1000.0",
            "decimals": 6,
            "status": "success",
        },
    ]
    print(f"✓ Created {len(decoded_txs)} sample transactions")

# Display summary
protocols = [tx.get('protocol', 'unknown') for tx in decoded_txs]
protocol_counts = Counter(protocols)

print("\nProtocol Distribution:")
for protocol, count in protocol_counts.most_common():
    print(f"  {protocol:15s}: {count:3d} transactions")

## Understanding Intent Extraction

### What is an Intent?

An **intent** is a structured representation of what a transaction does:

```json
{
  "action": "transfer",          // What action: transfer, swap, approve
  "assets": ["ETH"],             // Which assets involved
  "protocol": "ethereum",         // Which protocol
  "outcome": "success",           // Did it succeed?
  "amounts": ["1.0"]              // How much
}
```
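
The five fields above can also be written down as a typed schema, which makes it easy to spot an intent that is missing a key. This is a minimal sketch: `Intent` and `missing_intent_keys` are hypothetical names used for illustration and are not part of the `eth_finetuning` package.

```python
from typing import TypedDict


class Intent(TypedDict):
    """Hypothetical schema mirroring the five fields shown above."""
    action: str        # e.g. transfer, swap, approve
    assets: list[str]  # asset symbols involved
    protocol: str      # protocol name
    outcome: str       # e.g. success
    amounts: list[str] # amounts as strings, one per asset


REQUIRED_KEYS = set(Intent.__annotations__)


def missing_intent_keys(candidate: dict) -> set[str]:
    """Return the required keys that are absent from a candidate intent."""
    return REQUIRED_KEYS - set(candidate)


intent: Intent = {
    "action": "transfer",
    "assets": ["ETH"],
    "protocol": "ethereum",
    "outcome": "success",
    "amounts": ["1.0"],
}
print(missing_intent_keys(intent))              # set()
print(sorted(missing_intent_keys({"action": "swap"})))
```

Amounts are kept as strings here, as in the example above, so that large on-chain values are not rounded by floating-point conversion.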

### Why Extract Intents?

Intents provide a **simplified, consistent** format that:
- Makes training data more uniform across protocols
- Focuses on semantic meaning, not implementation details
- Reduces noise from irrelevant technical fields
- Makes it easier for models to learn patterns

## Extracting Intents from Decoded Transactions

Let's see how to extract intents from our decoded transactions.

In [None]:
print("INTENT EXTRACTION EXAMPLES")
print("=" * 80)

# Extract intent from first transaction
if decoded_txs:
    first_tx = decoded_txs[0]
    
    print("\nInput (Decoded Transaction):")
    print(json.dumps(first_tx, indent=2))
    
    # Extract intent
    intent = extract_intent(first_tx)
    
    print("\nOutput (Extracted Intent):")
    print(json.dumps(intent, indent=2))
    
    print("\n✓ Intent extraction successful")
    print(f"\nKey observations:")
    print(f"  • Action simplified to: {intent['action']}")
    print(f"  • Assets extracted: {intent['assets']}")
    print(f"  • Amounts normalized: {intent['amounts']}")
    print(f"  • Outcome captured: {intent['outcome']}")
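
One plausible reading of "amounts normalized" is that raw on-chain integers (wei, or token base units) are scaled by the asset's decimals into a human-readable string. The sketch below illustrates that idea only; `to_display_units` is a hypothetical helper, and `extract_intent` may normalize amounts differently.

```python
from decimal import Decimal, localcontext


def to_display_units(raw_amount: str, decimals: int) -> str:
    """Hypothetical helper: scale an integer base-unit amount by 10**-decimals."""
    with localcontext() as ctx:
        # uint256 values have at most 78 decimal digits, so 80 avoids rounding
        ctx.prec = 80
        scaled = Decimal(raw_amount).scaleb(-decimals)
        # normalize() strips trailing zeros; "f" avoids scientific notation
        return format(scaled.normalize(), "f")


print(to_display_units("1000000000000000000", 18))  # 1 ETH expressed in wei -> "1"
print(to_display_units("1500000", 6))               # a 6-decimal token -> "1.5"
```

Using `Decimal` instead of `float` keeps the conversion exact, which matters because wei amounts routinely exceed the 15 to 17 significant digits a float can represent.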

### Batch Intent Extraction

Now let's extract intents from all our decoded transactions.

In [None]:
print("EXTRACTING INTENTS FROM ALL TRANSACTIONS")
print("=" * 80)

# Extract intents in batch
intents = extract_intents_batch(decoded_txs)

print(f"\n✓ Extracted {len(intents)} intents from {len(decoded_txs)} transactions")

if len(intents) < len(decoded_txs):
    failed_count = len(decoded_txs) - len(intents)
    print(f"⚠️  {failed_count} transaction(s) failed intent extraction")

# Display a few examples
print("\nExample intents:")
for i, intent in enumerate(intents[:3], 1):
    print(f"\n{i}. Intent:")
    print(f"   Action:   {intent['action']}")
    print(f"   Protocol: {intent['protocol']}")
    print(f"   Assets:   {', '.join(intent['assets'])}")
    print(f"   Amounts:  {', '.join(intent['amounts'])}")
    print(f"   Outcome:  {intent['outcome']}")

## Instruction-Tuning Format (Alpaca)

### What is Instruction-Tuning?

Instruction-tuning teaches models to follow instructions through examples of:
- **Instruction**: What task to perform
- **Input**: The data to process
- **Output**: The expected result

### Alpaca Format

```json
{
  "instruction": "Extract the structured intent from this Ethereum transaction.",
  "input": "{...transaction data...}",
  "output": "{...intent JSON...}"
}
```
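
To make the structure concrete, here is a rough sketch of how a record in this shape could be built and then flattened into a single prompt string. Both helpers (`to_alpaca_record`, `to_prompt`) and the `### Instruction / ### Input / ### Response` layout are assumptions for illustration; the project's `format_training_example` may serialize things differently.

```python
import json

INSTRUCTION = "Extract the structured intent from this Ethereum transaction."


def to_alpaca_record(decoded_tx: dict, intent: dict) -> dict:
    """Hypothetical helper: serialize both dicts into the three Alpaca fields."""
    return {
        "instruction": INSTRUCTION,
        "input": json.dumps(decoded_tx, sort_keys=True),
        "output": json.dumps(intent, sort_keys=True),
    }


def to_prompt(record: dict) -> str:
    """Join the three fields into one training string (layout is an assumption)."""
    return (
        f"### Instruction:\n{record['instruction']}\n\n"
        f"### Input:\n{record['input']}\n\n"
        f"### Response:\n{record['output']}"
    )


record = to_alpaca_record(
    {"action": "transfer", "protocol": "ethereum", "amount_eth": "1.0"},
    {"action": "transfer", "assets": ["ETH"], "protocol": "ethereum",
     "outcome": "success", "amounts": ["1.0"]},
)
print(to_prompt(record))
```

Serializing with `sort_keys=True` gives every example the same key order, which keeps the target strings consistent across the dataset.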

### Why This Format?

- ✓ Standard format used by many fine-tuning frameworks
- ✓ Clear separation of task, input, and expected output
- ✓ Loads directly with HuggingFace `datasets`; the fields are then combined into a prompt and tokenized before training
- ✓ Easy to understand and debug

## Formatting Training Examples

Let's convert our decoded transactions and intents into training examples.

In [None]:
print("FORMATTING TRAINING EXAMPLES")
print("=" * 80)

# Format first transaction as example
if decoded_txs and intents:
    example = format_training_example(decoded_txs[0], intents[0])
    
    print("\nTraining Example Structure:")
    print("\n1. INSTRUCTION:")
    print(example['instruction'])
    
    print("\n2. INPUT (transaction data):")
    input_preview = example['input'][:200] + '...' if len(example['input']) > 200 else example['input']
    print(input_preview)
    
    print("\n3. OUTPUT (intent):")
    output_preview = example['output'][:200] + '...' if len(example['output']) > 200 else example['output']
    print(output_preview)
    
    print("\n✓ Training example formatted successfully")
    print(f"\nExample statistics:")
    print(f"  Instruction length: {len(example['instruction'])} chars")
    print(f"  Input length:       {len(example['input'])} chars")
    print(f"  Output length:      {len(example['output'])} chars")

### Batch Formatting

Now format all transactions into training examples.

In [None]:
# Format all examples
training_examples = format_training_examples_batch(decoded_txs, intents)

print(f"✓ Formatted {len(training_examples)} training examples")

import statistics

# Analyze example lengths
lengths = {
    'instruction': [len(ex['instruction']) for ex in training_examples],
    'input': [len(ex['input']) for ex in training_examples],
    'output': [len(ex['output']) for ex in training_examples],
}

print("\nLength Statistics (characters):")
for field, values in lengths.items():
    if not values:
        continue  # Nothing to summarize if no examples were formatted
    print(f"\n{field.upper()}:")
    print(f"  Mean:   {statistics.mean(values):.1f}")
    print(f"  Min:    {min(values)}")
    print(f"  Max:    {max(values)}")
    # statistics.median averages the two middle values for even-length lists
    print(f"  Median: {statistics.median(values)}")

## Data Validation

Before splitting, let's validate data quality to ensure:
- No missing critical fields
- Addresses are properly formatted
- Amounts are numeric
- Protocols are valid
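
The checks listed above could look roughly like the sketch below. It is an illustration under stated assumptions, not the logic of `validate_data`: `find_problems` is a hypothetical helper, the set of known protocols is taken from the sample data, and the address test only checks length and hex characters (it does not verify an EIP-55 checksum).

```python
import re
from decimal import Decimal, InvalidOperation

ADDRESS_RE = re.compile(r"0x[0-9a-fA-F]{40}")
KNOWN_PROTOCOLS = {"ethereum", "erc20"}  # assumption: protocols seen in the sample data


def find_problems(tx: dict) -> list[str]:
    """Hypothetical helper: list the data-quality problems found in one transaction."""
    problems = []
    for field in ("tx_hash", "action", "protocol", "from", "to"):
        if not tx.get(field):
            problems.append(f"missing {field}")
    for field in ("from", "to"):
        value = tx.get(field)
        if value and not ADDRESS_RE.fullmatch(value):
            problems.append(f"malformed address in {field}")
    for field in ("amount", "amount_eth", "amount_wei"):
        if field in tx:
            try:
                Decimal(str(tx[field]))
            except InvalidOperation:
                problems.append(f"non-numeric {field}")
    if tx.get("protocol") and tx["protocol"] not in KNOWN_PROTOCOLS:
        problems.append(f"unknown protocol {tx['protocol']}")
    return problems


good = {
    "tx_hash": "0xabcdef1234567890abcdef1234567890abcdef1234567890abcdef1234567890",
    "action": "transfer",
    "protocol": "erc20",
    "from": "0xA0b86991c6218b36c1d19D4a2e9Eb0cE3606eB48",
    "to": "0x1f9840a85d5aF5bf1D1762F925BDADdC4201F984",
    "amount": "1000.0",
}
bad = {**good, "from": "0x742d35Cc6634C0532925a3b844Bc9e7595f0", "amount": "lots"}

print(find_problems(good))  # []
print(find_problems(bad))   # ['malformed address in from', 'non-numeric amount']
```

Returning a list of problems rather than raising on the first one makes it easier to see every issue in a record at once.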

In [None]:
print("DATA VALIDATION")
print("=" * 80)

try:
    # Run validation
    validate_data(decoded_txs, intents)
    print("\n✓ All validation checks passed")
    
    print("\nValidation checks performed:")
    print("  ✓ No null values in critical fields")
    print("  ✓ Addresses properly checksummed")
    print("  ✓ Amounts are numeric")
    print("  ✓ Protocols are valid")
    print("  ✓ Intents match decoded transactions")
    
except ValueError as e:
    print(f"\n✗ Validation failed: {e}")
    print("\nPlease fix data quality issues before proceeding")

## Train/Validation/Test Split

### Why Split?

- **Training set (70%)**: Used to train the model
- **Validation set (15%)**: Used to tune hyperparameters and monitor training
- **Test set (15%)**: Used for final evaluation (never seen during training)

### Stratification

We'll use **stratified splitting** so that each split has a protocol distribution similar to that of the full dataset.
This keeps validation and test metrics representative and makes it less likely that a rare protocol is missing from a split by chance.
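
The idea behind a stratified split can be sketched in a few lines: group the examples by protocol, shuffle each group, and cut every group with the same ratios. `stratified_split` below is a hypothetical helper written for illustration; `prepare_dataset` may implement the split differently.

```python
import random
from collections import defaultdict


def stratified_split(items, key, ratios=(0.7, 0.15, 0.15), seed=42):
    """Hypothetical helper: split `items` so each `key` group is cut by `ratios`."""
    groups = defaultdict(list)
    for item in items:
        groups[key(item)].append(item)

    rng = random.Random(seed)  # fixed seed keeps the split reproducible
    train, val, test = [], [], []
    for group in groups.values():
        rng.shuffle(group)
        n_train = round(len(group) * ratios[0])
        n_val = round(len(group) * ratios[1])
        train.extend(group[:n_train])
        val.extend(group[n_train:n_train + n_val])
        test.extend(group[n_train + n_val:])  # remainder goes to test
    return train, val, test


txs = [{"protocol": "ethereum"}] * 20 + [{"protocol": "erc20"}] * 20
train, val, test = stratified_split(txs, key=lambda tx: tx["protocol"])
print(len(train), len(val), len(test))  # 28 6 6
```

Note that a protocol with only a handful of examples can end up with an empty validation or test share, so very small groups (such as the two-transaction fallback sample) are not enough for a meaningful stratified split.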

In [None]:
print("DATASET SPLITTING")
print("=" * 80)

# Define output directory
output_dir = project_root / "data" / "datasets"

# Prepare dataset with stratified split
split_counts = prepare_dataset(
    decoded_txs=decoded_txs,
    output_dir=output_dir,
    split_ratios=(0.7, 0.15, 0.15),
    stratify_by_protocol=True,
)

print("\n✓ Dataset prepared and saved")
print(f"\nSplit statistics:")
print(f"  Training:   {split_counts['train']:3d} examples ({split_counts['train']/sum(split_counts.values())*100:.1f}%)")
print(f"  Validation: {split_counts['validation']:3d} examples ({split_counts['validation']/sum(split_counts.values())*100:.1f}%)")
print(f"  Test:       {split_counts['test']:3d} examples ({split_counts['test']/sum(split_counts.values())*100:.1f}%)")
print(f"  Total:      {sum(split_counts.values()):3d} examples")

print(f"\nFiles saved to: {output_dir}")
print(f"  • train.jsonl")
print(f"  • validation.jsonl")
print(f"  • test.jsonl")

## Visualizing Dataset Characteristics

Let's visualize the characteristics of our prepared dataset.

In [None]:
# Load the prepared datasets
def load_jsonl(path):
    """Read a JSONL file into a list of dicts (one JSON object per line)."""
    with open(path, 'r') as f:
        return [json.loads(line) for line in f]

train_data = load_jsonl(output_dir / "train.jsonl")
val_data = load_jsonl(output_dir / "validation.jsonl")
test_data = load_jsonl(output_dir / "test.jsonl")

print(f"✓ Loaded datasets for visualization")
print(f"  Train:      {len(train_data)} examples")
print(f"  Validation: {len(val_data)} examples")
print(f"  Test:       {len(test_data)} examples")

In [None]:
# Visualize split distribution
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Plot 1: Split sizes
splits = ['Train', 'Validation', 'Test']
counts = [len(train_data), len(val_data), len(test_data)]
colors = ['steelblue', 'coral', 'lightgreen']

axes[0].bar(splits, counts, color=colors)
axes[0].set_title('Dataset Split Sizes', fontsize=14, fontweight='bold')
axes[0].set_ylabel('Number of Examples')
axes[0].set_xlabel('Split')

# Add value labels on bars
for i, (split, count) in enumerate(zip(splits, counts)):
    axes[0].text(i, count + max(counts)*0.02, str(count), 
                ha='center', va='bottom', fontweight='bold')

# Plot 2: Length distribution
all_lengths = {
    'Train': [len(ex['input']) + len(ex['output']) for ex in train_data],
    'Val': [len(ex['input']) + len(ex['output']) for ex in val_data],
    'Test': [len(ex['input']) + len(ex['output']) for ex in test_data],
}

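axes[1].boxplot([all_lengths['Train'], all_lengths['Val'], all_lengths['Test']])
# Set tick labels separately: boxplot's `labels=` keyword is deprecated in Matplotlib 3.9+
axes[1].set_xticklabels(splits)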
axes[1].set_title('Example Length Distribution', fontsize=14, fontweight='bold')
axes[1].set_ylabel('Characters (Input + Output)')
axes[1].set_xlabel('Split')

plt.tight_layout()
plt.show()

print("\n📊 Visualization complete")

## Protocol Distribution Across Splits

Let's verify that stratification worked correctly.

In [None]:
# Extract protocols from outputs (intents)
def get_protocols(data):
    protocols = []
    for example in data:
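        try:
            intent = json.loads(example['output'])
            protocols.append(intent.get('protocol', 'unknown'))
        except (KeyError, TypeError, AttributeError, json.JSONDecodeError):
            # Output is missing, not a string, not a JSON object, or malformed
            protocols.append('unknown')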
    return Counter(protocols)

train_protocols = get_protocols(train_data)
val_protocols = get_protocols(val_data)
test_protocols = get_protocols(test_data)

# Create comparison DataFrame
protocol_names = set(train_protocols.keys()) | set(val_protocols.keys()) | set(test_protocols.keys())
comparison = []

for protocol in protocol_names:
    comparison.append({
        'Protocol': protocol,
        'Train': train_protocols.get(protocol, 0),
        'Validation': val_protocols.get(protocol, 0),
        'Test': test_protocols.get(protocol, 0),
    })

df_protocols = pd.DataFrame(comparison)

# Calculate percentages
df_protocols['Train %'] = (df_protocols['Train'] / df_protocols['Train'].sum() * 100).round(1)
df_protocols['Val %'] = (df_protocols['Validation'] / df_protocols['Validation'].sum() * 100).round(1)
df_protocols['Test %'] = (df_protocols['Test'] / df_protocols['Test'].sum() * 100).round(1)

print("PROTOCOL DISTRIBUTION ACROSS SPLITS")
print("=" * 80)
print(df_protocols.to_string(index=False))

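# Check if distributions are similar: largest gap in any protocol's share between splits
pct = df_protocols[['Train %', 'Val %', 'Test %']]
max_gap = (pct.max(axis=1) - pct.min(axis=1)).max()
print(f"\nLargest difference in protocol share between splits: {max_gap:.1f} percentage points")
print("A small gap indicates that stratification kept the protocol mix consistent")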

## Loading Dataset with HuggingFace

Let's verify that our dataset can be loaded by HuggingFace's `datasets` library.

In [None]:
from datasets import load_dataset

print("LOADING DATASET WITH HUGGINGFACE")
print("=" * 80)

try:
    # Load dataset
    dataset = load_dataset(
        'json',
        data_files={
            'train': str(output_dir / 'train.jsonl'),
            'validation': str(output_dir / 'validation.jsonl'),
            'test': str(output_dir / 'test.jsonl'),
        }
    )
    
    print("\n✓ Dataset loaded successfully with HuggingFace")
    print(f"\nDataset structure:")
    print(dataset)
    
    # Show a sample
    print("\nSample from training set:")
    sample = dataset['train'][0]
    print(f"\nInstruction: {sample['instruction'][:100]}...")
    print(f"Input:       {sample['input'][:100]}...")
    print(f"Output:      {sample['output'][:100]}...")
    
except Exception as e:
    print(f"\n✗ Error loading dataset: {e}")

## Quality Checks

Let's perform final quality checks on our prepared dataset.

In [None]:
print("DATASET QUALITY CHECKS")
print("=" * 80)

all_data = train_data + val_data + test_data

# Check 1: All examples have required fields
required_fields = ['instruction', 'input', 'output']
missing_fields = 0
for example in all_data:
    for field in required_fields:
        if field not in example or not example[field]:
            missing_fields += 1

print(f"\n✓ Check 1: Required fields")
if missing_fields == 0:
    print(f"  All {len(all_data)} examples have all required fields")
else:
    print(f"  ⚠️  {missing_fields} missing field(s) detected")

# Check 2: Output is valid JSON
invalid_json = 0
for example in all_data:
    try:
        json.loads(example['output'])
    except (KeyError, TypeError, json.JSONDecodeError):
        # Count missing or non-string outputs as invalid too
        invalid_json += 1

print(f"\n✓ Check 2: Valid JSON outputs")
if invalid_json == 0:
    print(f"  All {len(all_data)} outputs are valid JSON")
else:
    print(f"  ⚠️  {invalid_json} invalid JSON output(s) detected")

# Check 3: Reasonable length distribution
total_lengths = [len(ex['input']) + len(ex['output']) for ex in all_data]
avg_length = sum(total_lengths) / len(total_lengths)
max_length = max(total_lengths)

print(f"\n✓ Check 3: Length distribution")
print(f"  Average total length: {avg_length:.0f} characters")
print(f"  Maximum total length: {max_length} characters")
if max_length < 2048 * 4:  # Assumes a 2048-token budget at roughly 4 characters per token
    print(f"  ✓ All examples within reasonable token limits")
else:
    print(f"  ⚠️  Some examples may exceed token limits")

# Check 4: No duplicates
unique_inputs = set(ex['input'] for ex in all_data)
duplicate_count = len(all_data) - len(unique_inputs)

print(f"\n✓ Check 4: Duplicates")
if duplicate_count == 0:
    print(f"  No duplicate examples detected")
else:
    print(f"  ⚠️  {duplicate_count} duplicate example(s) detected")

print("\n" + "=" * 80)
print("✓ Quality checks complete")
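
Check 4 counts duplicates over the combined data, but it does not say whether a duplicate sits in two different splits, which is the case that leaks test examples into training. A minimal sketch of a per-pair check follows; `find_cross_split_overlap` is a hypothetical helper, and the toy records stand in for `train_data`, `val_data`, and `test_data`.

```python
def find_cross_split_overlap(splits: dict[str, list[dict]]) -> dict[tuple[str, str], int]:
    """Hypothetical helper: count identical inputs shared by each pair of splits."""
    seen = {name: {ex["input"] for ex in data} for name, data in splits.items()}
    names = list(seen)
    overlap = {}
    for i, first in enumerate(names):
        for second in names[i + 1:]:
            overlap[(first, second)] = len(seen[first] & seen[second])
    return overlap


toy_splits = {
    "train": [{"input": "tx-a"}, {"input": "tx-b"}],
    "validation": [{"input": "tx-b"}],  # same input as a training example
    "test": [{"input": "tx-c"}],
}
for pair, shared in find_cross_split_overlap(toy_splits).items():
    print(pair, shared)
```

In the notebook this could be called as `find_cross_split_overlap({"train": train_data, "validation": val_data, "test": test_data})`; any non-zero count is worth investigating before training.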

## Key Takeaways

✓ **Intent Extraction**: Simplifies complex transaction data into consistent semantic structure

✓ **Instruction Format**: Alpaca format clearly separates task, input, and expected output

✓ **Stratified Splitting**: Keeps the protocol distribution consistent across train/val/test

✓ **Data Validation**: Critical for catching quality issues before training

✓ **JSONL Format**: One example per line, compatible with HuggingFace

## Troubleshooting Tips

**Issue**: Validation fails with missing fields
- **Solution**: Check decoded transaction structure, ensure all decoders return required fields

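**Issue**: Imbalanced splits (the protocol mix differs between splits, or one protocol dominates)
- **Solution**: Use stratified splitting to keep the mix consistent across splits; if one protocol dominates the whole dataset, collect more data from underrepresented protocols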

**Issue**: Examples too long (exceed token limits)
- **Solution**: Truncate input data, focus on essential fields only

**Issue**: Invalid JSON in outputs
- **Solution**: Check intent extraction logic, ensure proper JSON serialization

**Issue**: HuggingFace dataset loading fails
- **Solution**: Verify JSONL format (one JSON object per line), check file paths
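
For the JSONL issue above, a quick way to locate the offending lines is to parse the file line by line and report every line that is blank, malformed, or not a JSON object. This is a small sketch; `find_bad_jsonl_lines` is a hypothetical helper, and the temporary file only stands in for `train.jsonl`.

```python
import json
import tempfile
from pathlib import Path


def find_bad_jsonl_lines(path: Path) -> list[int]:
    """Return 1-based numbers of lines that are not a single JSON object."""
    bad = []
    with open(path, "r", encoding="utf-8") as f:
        for number, line in enumerate(f, start=1):
            try:
                if not isinstance(json.loads(line), dict):
                    bad.append(number)  # valid JSON, but not an object
            except json.JSONDecodeError:
                bad.append(number)      # blank or malformed line
    return bad


with tempfile.TemporaryDirectory() as tmp:
    sample = Path(tmp) / "sample.jsonl"
    sample.write_text('{"instruction": "a"}\n\n[1, 2]\n{"broken": \n', encoding="utf-8")
    print(find_bad_jsonl_lines(sample))  # [2, 3, 4]
```

An empty result means every line parsed as a JSON object, which is the shape the `json` loader in `datasets` expects for these files.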

## Next Steps

In the next notebook (**04-fine-tuning.ipynb**), we'll learn how to:
- Configure QLoRA for efficient fine-tuning
- Load and prepare the base model
- Execute training with live monitoring
- Manage VRAM usage on consumer GPUs
- Save and checkpoint models

---

**Ready to continue?** → `notebooks/04-fine-tuning.ipynb`