In [2]:
"""
Test script for Oracle
Run this to verify Oracle works correctly
"""

from config import Config
from oracle import Oracle

def test_single_calculation():
    """Run one NEB calculation for the equimolar Mo-Nb-Ta-W composition.

    Output locations in the hints are taken from ``Config`` (``database_dir``
    and ``csv_path``) rather than hard-coded, so they stay correct when the
    configuration changes.

    Returns:
        bool: True if ``Oracle.calculate`` reported success, False otherwise.
    """
    print("="*70)
    print("TEST 1: Single NEB Calculation")
    print("="*70)

    config = Config()

    # Test composition (quaternary, equimolar)
    composition = {
        'Mo': 0.25,
        'Nb': 0.25,
        'Ta': 0.25,
        'W': 0.25
    }

    print(f"\nTesting with composition: {composition}")
    print("This will take a few minutes...\n")

    # Oracle is used as a context manager so its cleanup runs even if
    # calculate() raises.
    with Oracle(config) as oracle:
        success = oracle.calculate(composition)

    if not success:
        print("\n" + "="*70)
        print("✗✗✗ TEST 1 FAILED ✗✗✗")
        print("="*70)
        return False

    print("\n" + "="*70)
    print("✓✓✓ TEST 1 PASSED ✓✓✓")
    print("="*70)
    print("\nCheck:")
    # The composition folder name is assumed to follow Oracle's naming scheme
    # (percent-per-element) — TODO confirm against Oracle if it changes.
    print(f"  - {config.database_dir}/Mo25Nb25Ta25W25/run_1/")
    print(f"  - {config.csv_path}")
    return True


def test_multiple_runs(n_runs=2):
    """Run the same equimolar composition several times in one Oracle session.

    Args:
        n_runs: Number of consecutive calculations to perform. The default of
            2 keeps the suite's total at the advertised 4 NEB calculations.

    Returns:
        bool: True if every run reported success, False on the first failure.
    """
    print("\n" + "="*70)
    print("TEST 2: Multiple Runs (same composition)")
    print("="*70)

    config = Config()

    composition = {
        'Mo': 0.25,
        'Nb': 0.25,
        'Ta': 0.25,
        'W': 0.25
    }

    print(f"\nRunning {n_runs} calculations for: {composition}")

    # A single Oracle instance serves all runs; leaving the `with` block via
    # the early return still triggers its cleanup.
    with Oracle(config) as oracle:
        for run in range(1, n_runs + 1):
            print(f"\nRun {run}/{n_runs}:")
            if not oracle.calculate(composition):
                print("\n" + "="*70)
                print("✗✗✗ TEST 2 FAILED ✗✗✗")
                print("="*70)
                return False

    print("\n" + "="*70)
    print("✓✓✓ TEST 2 PASSED ✓✓✓")
    print("="*70)
    print("\nShould have created:")
    # Test 1 already wrote run_1 for this composition, so the folders expected
    # after this test are run_1 .. run_{n_runs + 1}. This presumes Oracle
    # numbers runs sequentially per composition — TODO confirm.
    for k in range(1, n_runs + 2):
        print(f"  - {config.database_dir}/Mo25Nb25Ta25W25/run_{k}/")
    return True


def test_different_composition():
    """Run one NEB calculation for a binary Mo-W composition.

    Nb and Ta are passed with zero fraction to check that Oracle handles
    absent elements.

    Returns:
        bool: True if ``Oracle.calculate`` reported success, False otherwise.
    """
    print("\n" + "="*70)
    print("TEST 3: Different Composition (Binary)")
    print("="*70)

    config = Config()

    composition = {
        'Mo': 0.5,
        'Nb': 0.0,
        'Ta': 0.0,
        'W': 0.5
    }

    print(f"\nTesting with composition: {composition}")

    with Oracle(config) as oracle:
        success = oracle.calculate(composition)

    if not success:
        print("\n" + "="*70)
        print("✗✗✗ TEST 3 FAILED ✗✗✗")
        print("="*70)
        return False

    print("\n" + "="*70)
    print("✓✓✓ TEST 3 PASSED ✓✓✓")
    print("="*70)
    print("\nShould have created:")
    # Zero-fraction elements are expected to be omitted from the folder name
    # — assumption about Oracle's naming, TODO confirm.
    print(f"  - {config.database_dir}/Mo50W50/run_1/")
    return True


def check_results(csv_path="data.csv"):
    """Print the results CSV and list the files of its first entry's folder.

    Args:
        csv_path: Path to the results CSV. Defaults to ``"data.csv"``; pass
            ``config.csv_path`` to follow a non-default configuration.

    Returns:
        None. Findings go to stdout. Problems are reported rather than
        raised, so that a summary printed afterwards still runs. These
        problems are a missing or zero-byte CSV, a missing
        ``structure_folder`` column, or a blank first folder value.
    """
    print("\n" + "="*70)
    print("CHECKING RESULTS")
    print("="*70)

    import pandas as pd
    from pathlib import Path

    csv_path = Path(csv_path)

    if not csv_path.exists():
        print(f"✗ No {csv_path} found!")
        return

    try:
        df = pd.read_csv(csv_path)
    except pd.errors.EmptyDataError:
        # A zero-byte file has no header row, which read_csv rejects.
        print(f"✗ {csv_path} is empty!")
        return

    print(f"\n✓ CSV found with {len(df)} entries\n")
    print(df.to_string())

    # Check file structure for first entry
    if df.empty:
        return
    if "structure_folder" not in df.columns:
        print("\n\n✗ No 'structure_folder' column in CSV!")
        return

    first_folder = df.iloc[0]["structure_folder"]
    if pd.isna(first_folder):
        # Blank cells load as NaN; Path(NaN) would raise TypeError.
        print("\n\n✗ First entry has no structure_folder value!")
        return

    structure_folder = Path(str(first_folder))
    print(f"\n\nFiles in {structure_folder}:")
    if not structure_folder.exists():
        print("  ✗ Folder not found!")
        return

    for file in sorted(structure_folder.glob("*")):
        size_kb = file.stat().st_size / 1024
        print(f"  ✓ {file.name:<30} ({size_kb:.1f} KB)")


def print_config(config):
    """Pretty-print the active configuration, grouped into titled sections.

    Args:
        config: Object providing supercell_size, lattice_parameter, elements,
            the neb_* and relax_* settings, database_dir and csv_path.
    """
    side = config.supercell_size
    divider = "-" * 70

    # (section title, [(label, value), ...]); labels are left-padded to a
    # fixed width so all values line up in one column.
    layout = [
        ("Calculator", [
            ("Type:", "CHGNet"),
            ("Model:", "CHGNet pretrained"),
        ]),
        ("BCC Structure", [
            ("Supercell size:", f"{side}x{side}x{side}"),
            ("Lattice parameter:", f"{config.lattice_parameter} Å"),
            # Two atoms per conventional BCC cubic cell.
            ("Total atoms:", f"{2 * side**3} (before vacancy)"),
            ("Elements:", ", ".join(config.elements)),
        ]),
        ("NEB Parameters", [
            ("Number of images:", config.neb_images),
            ("Force convergence:", f"{config.neb_fmax} eV/Å"),
            ("Max steps:", config.neb_max_steps),
            ("Spring constant:", config.neb_spring_constant),
            ("Climbing image:", config.neb_climb),
        ]),
        ("Relaxation", [
            ("Force convergence:", f"{config.relax_fmax} eV/Å"),
            ("Max steps:", config.relax_max_steps),
            ("Relax cell:", config.relax_cell),
        ]),
        ("Data Storage", [
            ("Database directory:", config.database_dir),
            ("CSV path:", config.csv_path),
        ]),
    ]

    print("\nCURRENT CONFIGURATION:")
    print(divider)
    for title, entries in layout:
        print(f"\n{title}:")
        for label, value in entries:
            print(f"  {label:<21}{value}")
    print(divider)


if __name__ == "__main__":
    heavy_rule = "=" * 70

    print("\n" + heavy_rule)
    print("ORACLE TEST SUITE")
    print(heavy_rule)

    # Show the active configuration before committing to the expensive runs.
    config = Config()
    print_config(config)

    print("\nThis will run 4 NEB calculations.\n")

    input("Press ENTER to start tests (or Ctrl+C to cancel)...")

    # Test 1 gates the rest: if a single calculation fails there is no point
    # running the longer tests.
    test1_passed = test_single_calculation()

    if not test1_passed:
        print("\n⚠️  First test failed, stopping here")
    else:
        test2_passed = test_multiple_runs()
        test3_passed = test_different_composition()

        check_results()

        print("\n" + heavy_rule)
        print("TEST SUMMARY")
        print(heavy_rule)

        outcomes = [
            ("Test 1 (Single calculation): ", test1_passed),
            ("Test 2 (Multiple runs):      ", test2_passed),
            ("Test 3 (Binary composition): ", test3_passed),
        ]
        for caption, ok in outcomes:
            print(caption + ("✓ PASSED" if ok else "✗ FAILED"))

        if all(ok for _, ok in outcomes):
            print("\n🎉 ALL TESTS PASSED! 🎉")
            print("\nOracle is working correctly!")
        else:
            print("\n⚠️  SOME TESTS FAILED")

    print("\n" + heavy_rule)


ORACLE TEST SUITE

CURRENT CONFIGURATION:
----------------------------------------------------------------------

Calculator:
  Type:                CHGNet
  Model:               CHGNet pretrained

BCC Structure:
  Supercell size:      4x4x4
  Lattice parameter:   3.2 Å
  Total atoms:         128 (before vacancy)
  Elements:            Mo, Nb, Ta, W

NEB Parameters:
  Number of images:    5
  Force convergence:   0.05 eV/Å
  Max steps:           200
  Spring constant:     0.5
  Climbing image:      True

Relaxation:
  Force convergence:   0.05 eV/Å
  Max steps:           500
  Relax cell:          False

Data Storage:
  Database directory:  neb_database
  CSV path:            data.csv
----------------------------------------------------------------------

This will run 4 NEB calculations.

TEST 1: Single NEB Calculation

Testing with composition: {'Mo': 0.25, 'Nb': 0.25, 'Ta': 0.25, 'W': 0.25}
This will take a few minutes...

Oracle initialized
  Database: neb_database
  CSV: data.csv