In [None]:
"""
Test Notebook for inference.py

Tests all components step by step in one notebook.
Run cells sequentially to test each level.

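All model, strategy, seed, and path parameters come from Config; only the small
sample counts that keep individual tests fast (e.g. n_test, n_query) are hardcoded.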
"""

# ============================================================================
# SETUP & IMPORTS
# ============================================================================

import pandas as pd
import numpy as np
from pathlib import Path

from config import Config
from oracle import Oracle
from inference import (
    generate_test_compositions,
    select_samples_by_error,
    generate_test_data_with_oracle,
    predict_barriers_for_test_set,
    run_inference_cycle
)

print("✓ All imports successful")


# ============================================================================
# LEVEL 1A: Test Composition Generation
# ============================================================================
# Generate a few compositions over the Config element set and check that the
# fractions of every composition sum to 1.0 (within an absolute 1e-6).

print("\n" + "=" * 70 + "\nLEVEL 1A: COMPOSITION GENERATION TEST\n" + "=" * 70)

# Elements, strategy and seed all come from Config.
config = Config()
elements = config.elements
print(f"\nElements from Config: {elements}")

comps = generate_test_compositions(
    elements,
    n_test=10,  # small fixed sample size keeps this smoke test quick
    strategy=config.al_test_strategy,
    seed=config.al_seed,
)
print(f"\n✓ Generated {len(comps)} compositions using '{config.al_test_strategy}' strategy")

print("\nFirst 3 compositions:")
for rank, composition in enumerate(comps[:3], start=1):
    formatted = ", ".join(
        f"{element}={fraction:.3f}" for element, fraction in composition.items()
    )
    print(f"  {rank}. {formatted}")
    print(f"     Sum: {sum(composition.values()):.6f} (should be 1.0)")

# Pass criterion: every composition's fractions add up to 1.0.
compositions_sum_to_one = all(abs(sum(c.values()) - 1.0) < 1e-6 for c in comps)
if compositions_sum_to_one:
    print("\n✓ Level 1A PASSED")
else:
    print("\n✗ Level 1A FAILED")


# ============================================================================
# LEVEL 1B: Test Error-Weighted Sampling
# ============================================================================
# Feed a hand-built prediction table to select_samples_by_error ten times with
# different seeds and tally how often each row is chosen. This is a visual
# check: the pass message at the end is printed unconditionally.

print("\n" + "=" * 70 + "\nLEVEL 1B: ERROR-WEIGHTED SAMPLING TEST\n" + "=" * 70)

# Four synthetic predictions. relative_error is |predicted - oracle| / oracle
# (row 3's 0.0667 is rounded to 0.07).
fake_predictions = pd.DataFrame({
    'composition': [
        'Mo0.200Nb0.200O0.200Ta0.200W0.200',
        'Mo0.300Nb0.100O0.300Ta0.100W0.200',
        'Mo0.100Nb0.100O0.500Ta0.100W0.200',
        'Mo0.400Nb0.100O0.200Ta0.100W0.200'
    ],
    'oracle_barrier': [1.0, 2.0, 3.0, 4.0],
    'predicted_barrier': [1.2, 2.5, 2.8, 4.3],
    'relative_error': [0.20, 0.25, 0.07, 0.075],
    'absolute_error': [0.2, 0.5, 0.2, 0.3],
    'structure_folder': ['path1', 'path2', 'path3', 'path4']
})

print("\nFake predictions:")
print(fake_predictions[['composition', 'relative_error']])

print(f"\nRunning 10 sampling iterations using '{config.al_query_strategy}' strategy:")
selection_counts = dict.fromkeys(range(len(fake_predictions)), 0)

for trial in range(10):
    # A different but reproducible seed per trial; the selector gets a copy so
    # the reference table is left untouched even if it mutates its input.
    selected = select_samples_by_error(
        fake_predictions.copy(),
        n_query=2,
        strategy=config.al_query_strategy,
        seed=config.al_seed + trial,
    )
    for pick in selected:
        # Map each picked sample back to its row via the composition string.
        is_match = fake_predictions['composition'].eq(pick['composition_str']).to_numpy()
        selection_counts[fake_predictions.index[is_match][0]] += 1

print("\nSelection frequency (out of 10 trials, 2 samples each = 20 total):")
row_errors = fake_predictions['relative_error']
for row, times in selection_counts.items():
    print(f"  Sample {row} (error={row_errors.iloc[row]:.3f}): selected {times} times")

print("\nExpected: Higher errors (0.25, 0.20) should be selected more often than lower errors (0.075, 0.07)")
print("✓ Level 1B PASSED (check if distribution makes sense)")


# ============================================================================
# LEVEL 2C: Test Oracle Integration (OPTIONAL - SLOW!)
# ============================================================================
# Label two freshly generated compositions with the Oracle and cache the result
# in temp_test_data.csv so Level 2D can reuse it. A yes/no prompt gates this
# level because the Oracle calls are slow.

print("\n" + "=" * 70 + "\nLEVEL 2C: ORACLE INTEGRATION TEST (OPTIONAL)\n" + "=" * 70)
print("\n⚠️  WARNING: This test calls Oracle 2 times (~10 minutes)")

oracle_answer = input("Run Oracle test? (yes/no): ")
run_oracle_test = oracle_answer.strip().lower() == 'yes'

if not run_oracle_test:
    print("\n⊘ Skipped Oracle test")
else:
    # Build a fresh Config and Oracle for this level.
    config = Config()
    oracle = Oracle(config)
    elements = config.elements

    test_comps = generate_test_compositions(
        elements=elements,
        n_test=2,  # kept tiny: the warning above budgets 2 Oracle calls
        strategy=config.al_test_strategy,
        seed=config.al_seed,
    )
    print(f"\nGenerating test data with Oracle ({len(test_comps)} compositions)...")
    print(f"Using elements from Config: {elements}")

    test_data = generate_test_data_with_oracle(
        compositions=test_comps,
        oracle=oracle,
        config=config,
        verbose=True,
    )
    print("\nTest data:")
    print(test_data)

    # Persist the labelled samples so Level 2D can pick them up.
    cache_csv = 'temp_test_data.csv'
    test_data.to_csv(cache_csv, index=False)
    print(f"\n✓ Test data saved to: {cache_csv}")
    print("✓ Level 2C PASSED")


# ============================================================================
# LEVEL 2D: Test Model Prediction (requires test data from 2C)
# ============================================================================
# Run the checkpointed model on the cached Oracle data and print predicted vs.
# Oracle barriers. Skip with a hint if the data file or the checkpoint is missing.

print("\n" + "=" * 70 + "\nLEVEL 2D: MODEL PREDICTION TEST\n" + "=" * 70)

test_data_file = Path('temp_test_data.csv')

if not test_data_file.exists():
    print("\n⊘ No test data found. Run Level 2C first or create temp_test_data.csv manually")
else:
    print("\n✓ Found test data from Level 2C")

    model_path = Path(config.checkpoint_dir) / 'best_model.pt'
    if not model_path.exists():
        print(f"\n⊘ Model not found at: {model_path}")
        print("   Train a model first or adjust path in Config")
    else:
        config = Config()
        test_data = pd.read_csv(test_data_file)

        print(f"\nMaking predictions for {len(test_data)} samples...")
        print(f"Using Config elements: {config.elements}")

        predictions = predict_barriers_for_test_set(
            model_path=str(model_path),
            test_data=test_data,
            config=config,
            verbose=True,
        )

        shown_columns = ['composition', 'oracle_barrier', 'predicted_barrier', 'relative_error']
        print("\nPredictions:")
        print(predictions[shown_columns])
        print("\n✓ Level 2D PASSED")


# ============================================================================
# LEVEL 3: Mini-Cycle Test (OPTIONAL - SLOW!)
# ============================================================================
# Run one complete run_inference_cycle (cycle=0) using the model checkpoint, a
# fresh Oracle and the current `config`, then print the selected samples.
# Optionally shrink al_n_test / al_n_query for a quicker run.

print("\n" + "="*70)
print("LEVEL 3: MINI-CYCLE TEST (OPTIONAL)")
print("="*70)
print(f"\n⚠️  WARNING: This runs a complete inference cycle with n_test={config.al_n_test}")
print(f"   Adjust config.al_n_test for faster testing (currently: {config.al_n_test})")

run_cycle_test = input("Run mini-cycle test? (yes/no): ").strip().lower() == 'yes'

if run_cycle_test:
    # Check if model exists
    model_path = Path(config.checkpoint_dir) / 'best_model.pt'

    if model_path.exists():
        # Deliberately reuse the existing `config` instead of building a fresh
        # Config(). The warning above reported this object's al_n_test and told
        # the user to adjust it; a fresh instance would silently discard that
        # adjustment, so the cycle would run with values other than the ones shown.
        print("\nCurrent Active Learning Config:")
        print(f"  Elements: {config.elements}")
        print(f"  al_n_test: {config.al_n_test}")
        print(f"  al_n_query: {config.al_n_query}")
        print(f"  al_test_strategy: {config.al_test_strategy}")
        print(f"  al_query_strategy: {config.al_query_strategy}")

        # Optional: override for faster testing (mutates `config` in place, so
        # the summary below also reports the overridden values).
        use_small_test = input("\nUse small test values (n_test=5, n_query=2) for faster testing? (yes/no): ").strip().lower() == 'yes'
        if use_small_test:
            config.al_n_test = 5
            config.al_n_query = 2
            print(f"  Overridden: al_n_test={config.al_n_test}, al_n_query={config.al_n_query}")

        oracle = Oracle(config)

        print("\nRunning inference cycle...")

        selected, predictions = run_inference_cycle(
            cycle=0,
            model_path=str(model_path),
            oracle=oracle,
            config=config,
            verbose=True
        )

        print("\n" + "="*70)
        print("CYCLE RESULTS")
        print("="*70)

        print(f"\n✓ Selected {len(selected)} samples for training:")
        for i, s in enumerate(selected, 1):
            print(f"  {i}. {s['composition_str']}")
            print(f"     Oracle: {s['oracle_barrier']:.3f} eV, Predicted: {s['predicted_barrier']:.3f} eV")
            print(f"     Relative error: {s['relative_error']:.3f}")

        # NOTE(review): this path assumes run_inference_cycle writes
        # cycle_<n>_predictions.csv into al_results_dir; it is not checked here.
        print(f"\n✓ Predictions saved to: {config.al_results_dir}/cycle_0_predictions.csv")

        print("\n✓ Level 3 PASSED")
    else:
        print(f"\n⊘ Model not found at: {model_path}")
        print("   Train a model first")
else:
    print("\n⊘ Skipped mini-cycle test")


# ============================================================================
# SUMMARY
# ============================================================================
# Print one status line per level, then the active configuration.
# Each status is derived from the same condition the level itself checks, so a
# level that failed (1A) or bailed out without running (3, missing model) is
# not reported as PASSED.

print("\n" + "="*70)
print("TEST SUMMARY")
print("="*70)

# Re-apply the Level 1A criterion (fractions sum to 1.0 within 1e-6) instead
# of assuming it passed.
level_1a_passed = all(abs(sum(c.values()) - 1.0) < 1e-6 for c in comps)
if level_1a_passed:
    print("\n✓ Level 1A: Composition Generation - PASSED")
else:
    print("\n✗ Level 1A: Composition Generation - FAILED")

# Level 1B has no automatic criterion (the distribution is checked by eye), so
# this mirrors the unconditional message printed by that level.
print("✓ Level 1B: Error-Weighted Sampling - PASSED")

# Once opted in, Level 2C has no skip path: it either reaches PASSED or raises.
if run_oracle_test:
    print("✓ Level 2C: Oracle Integration - PASSED")
else:
    print("⊘ Level 2C: Oracle Integration - SKIPPED")

model_path = Path(config.checkpoint_dir) / 'best_model.pt'
model_available = model_path.exists()

if test_data_file.exists() and model_available:
    print("✓ Level 2D: Model Prediction - PASSED")
else:
    print("⊘ Level 2D: Model Prediction - SKIPPED (missing data or model)")

# Level 3 runs the cycle only when the user opted in AND the checkpoint exists.
# Otherwise it just prints "Model not found", which must not count as a pass.
if run_cycle_test and model_available:
    print("✓ Level 3: Mini-Cycle - PASSED")
elif run_cycle_test:
    print("⊘ Level 3: Mini-Cycle - SKIPPED (model not found)")
else:
    print("⊘ Level 3: Mini-Cycle - SKIPPED")

print("\n" + "="*70)
print("\nConfiguration Summary:")
print(f"  Elements: {config.elements}")
print(f"  Test strategy: {config.al_test_strategy}")
print(f"  Query strategy: {config.al_query_strategy}")
print(f"  n_test: {config.al_n_test}")
print(f"  n_query: {config.al_n_query}")
print(f"  Results dir: {config.al_results_dir}")

print("\nNext steps:")
print("1. If all tests passed → Ready for active_learning_loop.py")
print("2. If tests skipped → Run them when Oracle/Model available")
print("3. Adjust Config parameters for production (config.al_n_test, config.al_n_query, etc.)")
print("="*70 + "\n")