# In [None]:
"""
Test Script for Template Graph Builder
Tests workflow without subjective ratings.
"""

import time
import pandas as pd
import torch
from pathlib import Path
from template_graph_builder import TemplateGraphBuilder
from config import Config


def print_section(title):
    """Print *title* framed by 70-character '=' rules, preceded by a blank line."""
    rule = "=" * 70
    # One print call, newline-separated: blank line + rule, title, rule.
    print("\n" + rule, title, rule, sep="\n")


def compute_position_difference_with_pbc(pos_initial, pos_final, cell_size):
    """
    Compute position difference considering PBC.

    Applies the minimum image convention for an orthogonal (box-shaped) cell.
    Each Cartesian component of the raw difference is shifted by a whole number
    of cell lengths so that it lies in [-cell_size/2, cell_size/2].

    Parameters:
    -----------
    pos_initial : torch.Tensor [N, 3]
    pos_final : torch.Tensor [N, 3]
    cell_size : float or torch.Tensor [3]
        Edge length(s) of the supercell in Angstrom. A scalar means a cubic
        cell. A length-3 tensor gives per-axis lengths and is broadcast over
        atoms.

    Returns:
    --------
    diff : torch.Tensor [N, 3]
        Position difference with minimum image convention.

    Notes:
    ------
    Only valid for orthogonal cells. A triclinic cell would need the
    difference to be wrapped in fractional coordinates instead.
    """
    diff = pos_final - pos_initial

    # Subtract the nearest whole number of cell lengths. A single conditional
    # shift only folds differences up to 1.5 cells; this folds any multiple.
    # torch.round rounds half to even, so a component of exactly
    # +/- cell_size/2 is left unchanged, as before.
    return diff - cell_size * torch.round(diff / cell_size)


def test_1_template_creation():
    """Test 1: Template creation.

    Builds a TemplateGraphBuilder from a default Config. Checks the template's
    node count and that every edge distance lies in (0, cutoff_radius].

    Returns:
    --------
    (builder, config), reused by the later tests.

    Raises:
    -------
    AssertionError
        If any template property is out of range.
    """
    print_section("TEST 1: TEMPLATE CREATION")

    config = Config()
    builder = TemplateGraphBuilder(config)

    num_nodes = builder.template_num_nodes
    num_edges = builder.template_edge_index.shape[1]
    print("\nTemplate properties:")
    print(f"  Nodes: {num_nodes}")
    print(f"  Edges: {num_edges}")

    # min()/max() raise an opaque RuntimeError on an empty tensor; fail clearly.
    assert num_edges > 0, "Template has no edges"
    # Kept as tensors so the cutoff comparison happens in the tensor's dtype.
    min_dist = builder.template_edge_attr.min()
    max_dist = builder.template_edge_attr.max()
    print(f"  Edge distances: [{min_dist:.2f}, {max_dist:.2f}] Å")

    # Validation. The node formula is 2 sites per cubic cell minus one
    # (presumably a removed/vacancy site; TODO confirm against the builder).
    expected_nodes = config.supercell_size ** 3 * 2 - 1
    assert num_nodes == expected_nodes, (
        f"Expected {expected_nodes} template nodes, got {num_nodes}"
    )
    assert min_dist > 0, f"Non-positive edge distance: {min_dist.item()}"
    assert max_dist <= config.cutoff_radius, (
        f"Edge distance {max_dist.item()} exceeds cutoff {config.cutoff_radius}"
    )

    print(f"\n✓ TEST 1 PASSED")
    return builder, config


def test_2_element_detection(builder):
    """Test 2: Element detection.

    Reports what the builder detected:
    - its element list;
    - the node-feature width derived from it (3 + n_elements + 4);
    - each element's index with its atomic radius and mass.

    This test is informational and makes no assertions. It fails only if a
    lookup raises.
    """
    print_section("TEST 2: ELEMENT DETECTION")

    elements = builder.elements
    feature_width = 3 + len(elements) + 4
    print(f"\nDetected elements: {elements}")
    print(f"Node feature size: {feature_width}")

    print("\nElement mapping:")
    for symbol, index in builder.element_to_idx.items():
        info = builder.atomic_properties[symbol]
        radius = info['atomic_radius']
        mass = info['atomic_mass']
        print(f"  {symbol} → {index}: radius={radius:.2f}, mass={mass:.2f}")

    print("\n✓ TEST 2 PASSED")


def test_3_graph_pair_with_labels(builder, config):
    """Test 3: Graph pair building with required labels.

    Builds the (initial, final) graph pair for the first CSV row, labelled
    with that row's ``backward_barrier_eV``. Then checks:
    - node counts;
    - feature width;
    - that both graphs carry the same label;
    - that both share the template connectivity.

    Parameters:
    -----------
    builder : TemplateGraphBuilder
    config : Config
        Must provide ``csv_path``.

    Returns:
    --------
    tuple or None
        (initial_cif, final_cif, backward_barrier) for reuse by later tests.
        None when the CSV, its first row, or that row's CIF files are missing.

    Raises:
    -------
    AssertionError
        On any validation failure.
    """
    print_section("TEST 3: GRAPH PAIR WITH LABELS")

    if not Path(config.csv_path).exists():
        print(f"\n✗ SKIPPED: CSV not found")
        return None

    df = pd.read_csv(config.csv_path)
    if len(df) == 0:
        print("\n✗ SKIPPED: No data")
        return None

    sample = df.iloc[0]
    structure_folder = Path(sample['structure_folder'])
    initial_cif = structure_folder / "initial_relaxed.cif"
    final_cif = structure_folder / "final_relaxed.cif"

    if not initial_cif.exists() or not final_cif.exists():
        print(f"\n✗ SKIPPED: CIF not found")
        return None

    print(f"\nTesting: {structure_folder.name}")
    print(f"Composition: {sample['composition_string']}")

    # Get barrier from CSV
    backward_barrier = sample['backward_barrier_eV']
    print(f"Backward barrier: {backward_barrier:.4f} eV")

    # Build graph pair. perf_counter is the monotonic high-resolution timer
    # meant for measuring intervals.
    start = time.perf_counter()
    initial_graph, final_graph = builder.build_pair_graph(
        str(initial_cif),
        str(final_cif),
        backward_barrier=backward_barrier
    )
    elapsed = (time.perf_counter() - start) * 1000

    print(f"\nGraph pair built in {elapsed:.2f}ms")

    pair = (("Initial", initial_graph), ("Final", final_graph))

    # Labels must exist. Check this before the summaries below read ``.y``;
    # otherwise a missing label surfaces as an unrelated AttributeError.
    # getattr(..., None) also covers graph classes that return None for an
    # unset attribute (e.g. torch_geometric's Data).
    for name, graph in pair:
        assert getattr(graph, 'y', None) is not None, f"{name} graph missing label"

    # Validate structure
    for name, graph in pair:
        print(f"\n{name} graph:")
        print(f"  Nodes: {graph.num_nodes}")
        print(f"  Node features: {graph.x.shape}")
        print(f"  Edges: {graph.edge_index.shape[1]}")
        print(f"  Label: {graph.y.item():.4f} eV")

    # Validate. The label is stored as float32, so compare against a float32
    # expectation; .all() keeps the check explicit for any label shape.
    expected_features = 3 + len(builder.elements) + 4
    expected_label = torch.tensor([backward_barrier], dtype=torch.float32)
    for name, graph in pair:
        assert graph.x.shape[1] == expected_features, (
            f"{name} graph has {graph.x.shape[1]} node features, "
            f"expected {expected_features}"
        )
        assert graph.num_nodes == builder.template_num_nodes, (
            f"{name} graph has {graph.num_nodes} nodes, "
            f"template has {builder.template_num_nodes}"
        )
        assert torch.isclose(graph.y, expected_label).all(), (
            f"{name} graph label {graph.y} != barrier {backward_barrier}"
        )
    assert torch.equal(initial_graph.y, final_graph.y), "Labels must be identical"

    # Connectivity must be identical (template-based)
    assert torch.equal(initial_graph.edge_index, final_graph.edge_index), (
        "Initial and final graphs must share the template connectivity"
    )

    print(f"\n✓ TEST 3 PASSED")
    return (initial_cif, final_cif, backward_barrier)


def test_4_position_differences_with_pbc(builder, test_data, config):
    """Test 4: Position differences with PBC correction.

    Rebuilds the graph pair from ``test_data`` and compares raw position
    differences with minimum-image (PBC-corrected) ones. It then checks that
    at least one atom moves by >= 0.5 Å and that the element one-hot columns
    are unchanged.

    The code treats node-feature columns 0-2 as Cartesian positions in Å. It
    treats the supercell as cubic with edge ``supercell_size *
    lattice_parameter``. TODO confirm both against TemplateGraphBuilder.

    Parameters:
    -----------
    builder : TemplateGraphBuilder
    test_data : tuple or None
        (initial_cif, final_cif, backward_barrier) as returned by test 3.
        The test is skipped when None.
    config : Config
        Must provide ``supercell_size`` and ``lattice_parameter``.
    """
    print_section("TEST 4: POSITION DIFFERENCES (PBC)")

    if test_data is None:
        print("\n✗ SKIPPED: No test data")
        return

    initial_cif, final_cif, backward_barrier = test_data

    # Build graphs
    initial_graph, final_graph = builder.build_pair_graph(
        str(initial_cif),
        str(final_cif),
        backward_barrier=backward_barrier
    )

    pos_initial = initial_graph.x[:, :3]
    pos_final = final_graph.x[:, :3]

    # Without PBC correction.
    # "Total" is the sum of |component| differences over all atoms.
    # A per-atom displacement is the Euclidean length of its difference
    # vector; summing its components (an L1 norm) would overstate it by up
    # to a factor of sqrt(3).
    naive_vec = pos_final - pos_initial
    naive_total = naive_vec.abs().sum().item()
    naive_max = torch.linalg.vector_norm(naive_vec, dim=1).max().item()

    print(f"\nWithout PBC correction:")
    print(f"  Total difference: {naive_total:.2f}")
    print(f"  Max displacement: {naive_max:.2f} Å")

    # With PBC correction
    cell_size = config.supercell_size * config.lattice_parameter
    pbc_vec = compute_position_difference_with_pbc(pos_initial, pos_final, cell_size)
    pbc_total = pbc_vec.abs().sum().item()
    pbc_distances = torch.linalg.vector_norm(pbc_vec, dim=1)  # per-atom, Å
    pbc_max = pbc_distances.max().item()
    pbc_moved = (pbc_distances > 0.5).sum().item()

    print(f"\nWith PBC correction:")
    print(f"  Cell size: {cell_size:.2f} Å")
    print(f"  Total difference: {pbc_total:.2f}")
    print(f"  Max displacement: {pbc_max:.2f} Å")
    print(f"  Atoms moved (>0.5 Å): {pbc_moved}/{initial_graph.num_nodes}")

    # Detailed movement analysis
    print(f"\nMovement analysis:")
    max_idx = int(pbc_distances.argmax())  # plain int so it prints as a number
    print(f"  Max moving atom: index {max_idx}, "
          f"displacement {pbc_distances[max_idx].item():.2f} Å")

    # Count atoms by displacement range
    small = (pbc_distances < 0.1).sum().item()
    medium = ((pbc_distances >= 0.1) & (pbc_distances < 0.5)).sum().item()
    large = (pbc_distances >= 0.5).sum().item()

    print(f"  Small (<0.1 Å): {small} atoms")
    print(f"  Medium (0.1-0.5 Å): {medium} atoms")
    print(f"  Large (≥0.5 Å): {large} atoms")

    # Validate
    assert pbc_total > 0, "Positions must differ"
    assert pbc_max < cell_size, "Max displacement must be less than cell size"
    assert large >= 1, "At least 1 atom must move significantly (≥0.5 Å)"

    # Element types unchanged
    one_hot_start = 3
    one_hot_end = 3 + len(builder.elements)
    one_hot_diff = (initial_graph.x[:, one_hot_start:one_hot_end] -
                    final_graph.x[:, one_hot_start:one_hot_end]).abs().sum().item()
    print(f"  One-hot difference: {one_hot_diff:.2f}")
    assert one_hot_diff == 0, "Element types must not change"

    print(f"\n✓ TEST 4 PASSED")


def test_5_multiple_samples(builder, config, max_samples=20):
    """Test 5: Multiple samples.

    Builds a graph pair for each of the first ``max_samples`` CSV rows and
    validates node count, feature width and labels.

    Rows whose CIF files are missing are counted and reported as skipped.
    They are not silently ignored. If no row could be checked at all, the
    test reports SKIPPED rather than PASSED.

    Parameters:
    -----------
    builder : TemplateGraphBuilder
    config : Config
        Must provide ``csv_path``.
    max_samples : int, optional
        Maximum number of leading CSV rows to check (default 20).

    Raises:
    -------
    AssertionError
        If any checked sample failed.
    """
    print_section("TEST 5: MULTIPLE SAMPLES")

    if not Path(config.csv_path).exists():
        print(f"\n✗ SKIPPED: CSV not found")
        return

    df = pd.read_csv(config.csv_path)
    n_test = min(max_samples, len(df))

    if n_test <= 0:
        print("\n✗ SKIPPED: No data")
        return

    print(f"\nTesting {n_test} samples")

    expected_features = 3 + len(builder.elements) + 4
    successful = 0
    failed = 0
    skipped = 0

    for i in range(n_test):
        sample = df.iloc[i]
        structure_folder = Path(sample['structure_folder'])
        initial_cif = structure_folder / "initial_relaxed.cif"
        final_cif = structure_folder / "final_relaxed.cif"

        if not initial_cif.exists() or not final_cif.exists():
            skipped += 1
            continue

        # Collect per-sample failures of any kind, so one bad structure does
        # not hide the results of the rest. They fail the test after the loop.
        try:
            backward_barrier = sample['backward_barrier_eV']

            initial_graph, final_graph = builder.build_pair_graph(
                str(initial_cif),
                str(final_cif),
                backward_barrier=backward_barrier
            )

            # Validate
            assert initial_graph.num_nodes == builder.template_num_nodes, (
                "initial graph node count differs from template"
            )
            assert final_graph.num_nodes == builder.template_num_nodes, (
                "final graph node count differs from template"
            )
            assert initial_graph.x.shape[1] == expected_features, (
                f"expected {expected_features} node features, "
                f"got {initial_graph.x.shape[1]}"
            )
            # getattr(..., None) also catches graph classes that return None
            # for an unset label.
            assert getattr(initial_graph, 'y', None) is not None, "initial graph missing label"
            assert getattr(final_graph, 'y', None) is not None, "final graph missing label"
            assert torch.equal(initial_graph.y, final_graph.y), "labels differ"

            successful += 1
        except Exception as e:
            print(f"  Sample {i} failed: {type(e).__name__}: {e}")
            failed += 1

    print(f"\nResults:")
    print(f"  Successful: {successful}")
    print(f"  Failed: {failed}")
    print(f"  Skipped (CIF not found): {skipped}")

    assert failed == 0, f"{failed} samples failed"

    if successful == 0:
        print("\n✗ SKIPPED: No sample had both CIF files")
        return

    print(f"\n✓ TEST 5 PASSED")


def test_6_speed_benchmark(builder, test_data, n_iterations=100):
    """Test 6: Speed benchmark.

    Times ``n_iterations`` repeated builds of the graph pair in ``test_data``
    and reports the total time, the mean time per pair and the throughput.
    Nothing is asserted about speed.

    Parameters:
    -----------
    builder : TemplateGraphBuilder
    test_data : tuple or None
        (initial_cif, final_cif, backward_barrier) as returned by test 3.
        The test is skipped when None.
    n_iterations : int, optional
        Number of pairs to build (default 100). Must be >= 1.

    Raises:
    -------
    ValueError
        If ``n_iterations`` < 1.
    """
    print_section("TEST 6: SPEED BENCHMARK")

    if test_data is None:
        print("\n✗ SKIPPED: No test data")
        return

    if n_iterations < 1:
        raise ValueError(f"n_iterations must be >= 1, got {n_iterations}")

    initial_cif, final_cif, backward_barrier = test_data
    # Convert paths once, outside the timed loop.
    initial_path, final_path = str(initial_cif), str(final_cif)

    print(f"\nBuilding {n_iterations} graph pairs")

    # perf_counter: monotonic, highest-resolution clock for benchmarking.
    start = time.perf_counter()
    for _ in range(n_iterations):
        builder.build_pair_graph(
            initial_path,
            final_path,
            backward_barrier=backward_barrier
        )
    elapsed = (time.perf_counter() - start) * 1000

    avg_time = elapsed / n_iterations
    # Guard the division if the clock reports zero elapsed time.
    rate = 1000 / avg_time if avg_time > 0 else float("inf")

    print(f"  Total: {elapsed:.1f}ms")
    print(f"  Average: {avg_time:.2f}ms per pair")
    print(f"  Rate: {rate:.0f} pairs/second")

    print(f"\n✓ TEST 6 PASSED")


def test_7_label_requirements(builder, config):
    """Test 7: Label is required.

    Checks two behaviours of ``builder.cif_to_graph``:
    1. Called without ``barrier``, it raises a TypeError whose message
       mentions "barrier".
    2. Called with the first CSV row's barrier, it produces a graph labelled
       with that value.

    Parameters:
    -----------
    builder : TemplateGraphBuilder
    config : Config
        Must provide ``csv_path``.

    Raises:
    -------
    AssertionError
        If the barrier is not required or the label is wrong.
    TypeError
        Re-raised if the missing-barrier call fails for an unrelated reason.
    """
    print_section("TEST 7: LABEL REQUIREMENTS")

    if not Path(config.csv_path).exists():
        print(f"\n✗ SKIPPED: CSV not found")
        return

    df = pd.read_csv(config.csv_path)
    if len(df) == 0:
        print("\n✗ SKIPPED: No data")
        return

    sample = df.iloc[0]
    structure_folder = Path(sample['structure_folder'])
    initial_cif = structure_folder / "initial_relaxed.cif"

    if not initial_cif.exists():
        print(f"\n✗ SKIPPED: CIF not found")
        return

    print("\nTesting that label is required")

    # Try to build graph without label (should fail). Only the call sits in
    # the try body.
    try:
        builder.cif_to_graph(str(initial_cif))
    except TypeError as e:
        if "barrier" not in str(e):
            raise
        print(f"  ✓ Correctly requires barrier parameter")
    else:
        print("  ✗ Should have raised TypeError (missing barrier)")
        # Explicit raise: ``assert False`` is stripped under ``python -O``,
        # which would let this negative check pass silently.
        raise AssertionError("Label should be required")

    # Build with label (should work)
    backward_barrier = sample['backward_barrier_eV']
    graph = builder.cif_to_graph(str(initial_cif), barrier=backward_barrier)

    # getattr(..., None) also catches graph classes that return None for an
    # unset label.
    assert getattr(graph, 'y', None) is not None, "Graph must have label"
    # The label is float32, so compare against a float32 expectation.
    expected = torch.tensor([backward_barrier], dtype=torch.float32)
    assert torch.isclose(graph.y, expected).all(), (
        f"Graph label {graph.y} != barrier {backward_barrier}"
    )

    print(f"  ✓ Graph with label: y={graph.y.item():.4f} eV")

    print(f"\n✓ TEST 7 PASSED")


def test_8_template_immutability(builder, test_data):
    """Test 8: Template is not modified.

    Writes into the first edge of a freshly built initial graph, in place.
    Then verifies that the builder's template ``edge_index`` and
    ``edge_attr`` tensors still equal snapshots taken beforehand. This
    detects built graphs sharing storage with the template.
    """
    print_section("TEST 8: TEMPLATE IMMUTABILITY")

    if test_data is None:
        print("\n✗ SKIPPED: No test data")
        return

    initial_cif, final_cif, backward_barrier = test_data

    # Independent copies of the template, taken before any graph is built.
    snapshot_index = builder.template_edge_index.clone()
    snapshot_attr = builder.template_edge_attr.clone()

    graph_a, _graph_b = builder.build_pair_graph(
        str(initial_cif),
        str(final_cif),
        backward_barrier=backward_barrier,
    )

    # In-place writes that would leak into the template if storage were shared.
    graph_a.edge_index[0, 0] = 999
    graph_a.edge_attr[0] = 999.0

    index_intact = torch.equal(builder.template_edge_index, snapshot_index)
    assert index_intact, "Template edge_index was modified"
    attr_intact = torch.equal(builder.template_edge_attr, snapshot_attr)
    assert attr_intact, "Template edge_attr was modified"

    print("\nTemplate remains unchanged after graph modification")

    print("\n✓ TEST 8 PASSED")


def main():
    """Run the whole suite in order and report the outcome.

    Later tests reuse two results:
    - the builder/config returned by test 1;
    - the (initial_cif, final_cif, backward_barrier) tuple returned by test 3.
    The order is therefore fixed. Execution stops at the first exception.

    Returns:
    --------
    int
        0 if every test passed or was skipped, 1 on the first failure.
        Suitable as a process exit status.
    """
    print("\n" + "=" * 70)
    print("TEMPLATE GRAPH BUILDER TEST SUITE")
    print("=" * 70)

    start_time = time.perf_counter()

    try:
        builder, config = test_1_template_creation()
        test_2_element_detection(builder)
        test_data = test_3_graph_pair_with_labels(builder, config)
        test_4_position_differences_with_pbc(builder, test_data, config)
        test_5_multiple_samples(builder, config)
        test_6_speed_benchmark(builder, test_data)
        test_7_label_requirements(builder, config)
        test_8_template_immutability(builder, test_data)
    except Exception as e:
        # Top-level boundary: report the failure and convert it to a status.
        print("\n" + "=" * 70)
        print("TEST FAILED")
        print("=" * 70)
        # Include the exception type. A bare ``assert`` has an empty message,
        # so printing ``e`` alone would show just "Error: ".
        print(f"\nError: {type(e).__name__}: {e}")
        import traceback
        traceback.print_exc()
        print("\n" + "=" * 70 + "\n")
        return 1
    else:
        elapsed = time.perf_counter() - start_time
        print("\n" + "=" * 70)
        print(f"ALL TESTS PASSED ({elapsed:.1f}s)")
        print("=" * 70 + "\n")

    return 0


if __name__ == "__main__":
    # Raise SystemExit directly instead of using the site-injected ``exit``.
    # That helper is meant for the interactive REPL, is absent under
    # ``python -S``, and in a Jupyter kernel shuts the kernel down.
    raise SystemExit(main())