In [None]:
"""
Test script for dataset.py

Tests:
1. Dataset creation with filtering
2. Dataset statistics
3. Single sample loading
4. Collate function (batching)
5. DataLoader creation (train/val split)
6. Batch iteration
7. Speed benchmark
8. Data integrity checks
"""

import torch
import time
from pathlib import Path
from dataset import (
    DiffusionBarrierDataset,
    collate_fn,
    create_dataloaders
)
from config import Config


def print_section(title):
    """Print *title* framed above and below by a 70-character '=' rule."""
    rule = "=" * 70
    # A single call: the leading newline separates the banner from whatever
    # was printed before it, and sep puts each part on its own line.
    print(f"\n{rule}", title, rule, sep="\n")


def test_1_dataset_creation():
    """Test 1: Dataset creation with filtering.

    Builds a ``DiffusionBarrierDataset`` from ``config.csv_path``, passing the
    config's ``min_barrier`` / ``max_barrier`` through, and asserts that the
    resulting dataset is not empty.

    Returns:
        A ``(dataset, config)`` pair. When the CSV file does not exist the
        test is skipped and ``(None, None)`` is returned, so callers can
        always unpack the result into two names.

    Raises:
        AssertionError: If the dataset contains no samples.
    """
    print_section("TEST 1: DATASET CREATION")

    config = Config()

    # Without the CSV nothing downstream can run; report a skip instead of
    # letting the dataset constructor fail.
    if not Path(config.csv_path).exists():
        print(f"\n✗ SKIPPED: CSV not found at {config.csv_path}")
        # A pair (not a bare None) keeps `dataset, config = ...` in the
        # caller from raising TypeError on the skip path.
        return None, None

    print(f"\nCreating dataset from: {config.csv_path}")

    # Create dataset with filtering
    dataset = DiffusionBarrierDataset(
        config.csv_path,
        config,
        min_barrier=config.min_barrier,
        max_barrier=config.max_barrier,
    )

    print("\n✓ Dataset created:")
    print(f"  Total samples: {len(dataset)}")

    assert len(dataset) > 0, "Dataset should not be empty"

    print("\n✓ TEST 1 PASSED")
    return dataset, config


def test_2_dataset_statistics(dataset):
    """Test 2: Dataset statistics.

    Prints the summary returned by ``dataset.get_statistics()`` and checks
    that the values are mutually consistent.

    Args:
        dataset: Object exposing ``get_statistics()``, which must return a
            mapping with the keys ``count``, ``min``, ``max``, ``mean``,
            ``median`` and ``std``. ``None`` skips the test.

    Raises:
        AssertionError: If any of the sanity checks fails.
    """
    print_section("TEST 2: DATASET STATISTICS")

    if dataset is None:
        print("\n✗ SKIPPED: No dataset")
        return

    stats = dataset.get_statistics()

    print("\nBarrier statistics:")
    print(f"  Count: {stats['count']}")
    for label, key in (
        ("Min", "min"),
        ("Max", "max"),
        ("Mean", "mean"),
        ("Median", "median"),
        ("Std", "std"),
    ):
        print(f"  {label}: {stats[key]:.3f} eV")

    # Validate statistics
    assert stats['min'] >= 0, "Minimum barrier should be non-negative"
    # Equality is legitimate for a single-sample (or constant) dataset, the
    # same case in which std == 0 is accepted below.
    assert stats['max'] >= stats['min'], "Max should not be smaller than min"
    assert stats['mean'] > 0, "Mean should be positive"
    assert stats['std'] >= 0, "Std should be non-negative"

    print("\n✓ TEST 2 PASSED")


def test_3_single_sample(dataset):
    """Test 3: Single sample loading.

    Loads ``dataset[0]``, which is unpacked as
    ``(initial_graph, final_graph, barrier)``, reports how long the load took
    and validates both graphs.

    Args:
        dataset: Indexable dataset; ``None`` skips the test.

    Returns:
        ``(initial_graph, final_graph)`` on success, ``None`` when skipped.

    Raises:
        AssertionError: If a graph has no label, no nodes, a node count that
            differs from its partner, or a label that differs from
            ``barrier``.
    """
    print_section("TEST 3: SINGLE SAMPLE LOADING")

    if dataset is None:
        print("\n✗ SKIPPED: No dataset")
        return None

    print("\nLoading first sample...")

    # perf_counter is the high-resolution clock meant for short intervals.
    start = time.perf_counter()
    initial_graph, final_graph, barrier = dataset[0]
    elapsed = (time.perf_counter() - start) * 1000

    print(f"\n✓ Sample loaded in {elapsed:.2f}ms")

    # The report below dereferences `.y`, so the label check has to come
    # first to be able to fail with its own message.
    assert hasattr(initial_graph, 'y'), "Graph should have label"
    assert hasattr(final_graph, 'y'), "Graph should have label"

    for heading, graph in (
        ("Initial graph", initial_graph),
        ("Final graph", final_graph),
    ):
        print(f"\n{heading}:")
        print(f"  Nodes: {graph.num_nodes}")
        print(f"  Node features: {graph.x.shape}")
        print(f"  Edges: {graph.edge_index.shape[1]}")
        print(f"  Label: {graph.y.item():.4f} eV")

    print(f"\nBarrier: {barrier:.4f} eV")

    # Validate graphs
    assert initial_graph.num_nodes > 0, "Graph should have nodes"
    assert final_graph.num_nodes > 0, "Graph should have nodes"
    assert initial_graph.num_nodes == final_graph.num_nodes, "Same number of nodes"

    expected = torch.tensor([barrier], dtype=torch.float32)
    assert torch.isclose(initial_graph.y, expected), \
        "Initial graph label should equal barrier"
    assert torch.isclose(final_graph.y, expected), \
        "Final graph label should equal barrier"

    print("\n✓ TEST 3 PASSED")
    return (initial_graph, final_graph)


def test_4_collate_function(dataset):
    """Test 4: Collate function (batching).

    Loads up to four samples by index, passes them to ``collate_fn`` and
    checks the resulting ``(initial_batch, final_batch, barriers)`` triple.

    Args:
        dataset: Indexable, sized dataset; ``None`` skips the test.

    Raises:
        AssertionError: If the number of graphs or barriers differs from the
            number of samples, or the two batches disagree on total nodes.
    """
    print_section("TEST 4: COLLATE FUNCTION")

    if dataset is None:
        print("\n✗ SKIPPED: No dataset")
        return

    print("\nCreating batch manually...")

    # Get 4 samples (fewer if the dataset is smaller)
    batch_size = 4
    samples = [dataset[i] for i in range(min(batch_size, len(dataset)))]

    print(f"  Loaded {len(samples)} samples")

    # Apply collate; perf_counter is the high-resolution clock meant for
    # short intervals like this one.
    start = time.perf_counter()
    initial_batch, final_batch, barriers = collate_fn(samples)
    elapsed = (time.perf_counter() - start) * 1000

    print(f"\n✓ Batch created in {elapsed:.2f}ms")

    print("\nBatch structure:")
    # The second heading carries the blank line that separates the two
    # sub-reports.
    for heading, batch in (
        ("  Initial batch:", initial_batch),
        ("\n  Final batch:", final_batch),
    ):
        print(heading)
        print(f"    Total nodes: {batch.num_nodes}")
        print(f"    Num graphs: {batch.num_graphs}")
        print(f"    Node features: {batch.x.shape}")
        print(f"    Edges: {batch.edge_index.shape[1]}")

    print(f"\n  Barriers: {barriers.shape}")
    print(f"    Values: {barriers.tolist()}")

    # Validate batch
    assert initial_batch.num_graphs == len(samples), "Correct number of graphs"
    assert final_batch.num_graphs == len(samples), "Correct number of graphs"
    assert barriers.shape[0] == len(samples), "Correct number of barriers"
    assert initial_batch.num_nodes == final_batch.num_nodes, "Same total nodes"

    print("\n✓ TEST 4 PASSED")


def test_5_dataloaders(config):
    """Test 5: DataLoader creation.

    Calls ``create_dataloaders(config, val_split=0.2)`` and checks that the
    training loader exists and has at least one batch; a validation loader,
    when one is returned, must have at least one batch as well.

    Returns:
        ``(train_loader, val_loader)``, or ``(None, None)`` when ``config``
        is ``None`` and the test is skipped.
    """
    print_section("TEST 5: DATALOADER CREATION")

    if config is None:
        print("\n✗ SKIPPED: No config")
        return None, None

    print("\nCreating dataloaders...")
    loaders = create_dataloaders(config, val_split=0.2)
    train, val = loaders
    print("\n✓ Dataloaders created")

    # Validate: the validation loader is optional, the training loader is not.
    assert train is not None, "Train loader should exist"
    assert len(train) > 0, "Train loader should have batches"
    assert val is None or len(val) > 0, "Val loader should have batches"

    print("\n✓ TEST 5 PASSED")
    return train, val


def test_6_iteration(train_loader, val_loader):
    """Test 6: Iteration through batches.

    Walks the first three training batches and, when a validation loader is
    given, its first two batches. Every batch is unpacked as
    ``(initial_batch, final_batch, barriers)`` and must contain as many
    graphs as barrier values.

    Args:
        train_loader: Iterable of batches; ``None`` skips the test.
        val_loader: Iterable of batches, or ``None`` to check only training.

    Raises:
        AssertionError: If a batch's graph count differs from its number of
            barriers.
    """
    print_section("TEST 6: BATCH ITERATION")

    if train_loader is None:
        print("\n✗ SKIPPED: No train loader")
        return

    print("\nIterating through train loader...")

    for number, (initial_batch, final_batch, barriers) in enumerate(train_loader, start=1):
        print(f"\nBatch {number}:")
        print(f"  Initial: {initial_batch.num_graphs} graphs, {initial_batch.num_nodes} nodes")
        print(f"  Final: {final_batch.num_graphs} graphs, {final_batch.num_nodes} nodes")
        print(f"  Barriers: {barriers.shape}")
        print(f"    Range: [{barriers.min():.3f}, {barriers.max():.3f}] eV")

        # Validate
        assert initial_batch.num_graphs == barriers.shape[0], "Batch size mismatch"
        assert final_batch.num_graphs == barriers.shape[0], "Batch size mismatch"

        if number >= 3:  # Only test first 3 batches
            break

    if val_loader is not None:
        print("\nIterating through val loader...")

        for number, (initial_batch, final_batch, barriers) in enumerate(val_loader, start=1):
            print(f"\nVal Batch {number}:")
            print(f"  Initial: {initial_batch.num_graphs} graphs")
            print(f"  Final: {final_batch.num_graphs} graphs")
            print(f"  Barriers: {barriers.shape}")

            # Validate: hold validation batches to the same invariant as
            # training batches instead of only printing them.
            assert initial_batch.num_graphs == barriers.shape[0], "Batch size mismatch"
            assert final_batch.num_graphs == barriers.shape[0], "Batch size mismatch"

            if number >= 2:  # Only test first 2 batches
                break

    print("\n✓ TEST 6 PASSED")


def test_7_speed_benchmark(train_loader):
    """Test 7: Speed benchmark.

    Times how long it takes to pull up to 10 batches from ``train_loader``
    and prints the total, the per-batch average and the throughput. The
    averages are computed over the number of batches actually fetched.

    Args:
        train_loader: Sized iterable of batches; ``None`` skips the test, and
            so does a loader whose ``len()`` is 0.

    Raises:
        AssertionError: If the loader reports a non-zero length but yields
            no batch.
    """
    print_section("TEST 7: SPEED BENCHMARK")

    if train_loader is None:
        print("\n✗ SKIPPED: No train loader")
        return

    n_batches = min(10, len(train_loader))
    if n_batches == 0:
        # Nothing to time; also avoids dividing by zero below.
        print("\n✗ SKIPPED: Train loader is empty")
        return

    print(f"\nLoading {n_batches} batches...")

    loaded = 0
    start = time.perf_counter()
    for loaded, _batch in enumerate(train_loader, start=1):
        # Stop as soon as the n-th batch has arrived. Checking the index at
        # the top of the *next* iteration would fetch (and time) one batch
        # more than is accounted for in the average.
        if loaded >= n_batches:
            break
    elapsed = time.perf_counter() - start

    assert loaded > 0, "Train loader yielded no batches"

    avg_time = (elapsed / loaded) * 1000
    # Guard against a zero reading from the clock on extremely fast loaders.
    rate = loaded / elapsed if elapsed > 0 else float("inf")

    print("\n✓ Benchmark complete:")
    print(f"  Total time: {elapsed:.2f}s")
    print(f"  Average per batch: {avg_time:.1f}ms")
    print(f"  Batches per second: {rate:.1f}")

    print("\n✓ TEST 7 PASSED")


def test_8_data_integrity(train_loader):
    """Test 8: Data integrity checks.

    Iterates over every batch of ``train_loader`` (each unpacked as
    ``(initial_batch, final_batch, barriers)``), checks per-batch consistency
    and prints aggregate figures for the whole pass.

    Args:
        train_loader: Iterable of batches; ``None`` skips the test.

    Raises:
        AssertionError: If initial/final batches disagree on feature
            dimension or node count, a barrier is not positive, the feature
            dimension varies between batches, or the loader yields nothing.
    """
    print_section("TEST 8: DATA INTEGRITY")

    if train_loader is None:
        print("\n✗ SKIPPED: No train loader")
        return

    print("\nChecking data integrity...")

    all_barriers = []
    all_feature_dims = []
    total_nodes = 0
    total_graphs = 0

    for initial_batch, final_batch, barriers in train_loader:
        # Collect statistics. Nodes and graphs are summed separately so the
        # average below is weighted per graph; a mean of per-batch means
        # would over-weight a short final batch.
        all_barriers.extend(barriers.tolist())
        total_nodes += initial_batch.num_nodes
        total_graphs += initial_batch.num_graphs
        all_feature_dims.append(initial_batch.x.shape[1])

        # Check consistency
        assert initial_batch.x.shape[1] == final_batch.x.shape[1], "Feature dims match"
        assert initial_batch.num_nodes == final_batch.num_nodes, "Node counts match"
        assert torch.all(barriers > 0), "All barriers positive"

    # Fail with a clear message instead of a ValueError / IndexError /
    # ZeroDivisionError from the summary below.
    assert all_barriers and total_graphs > 0, "Train loader yielded no samples"

    dims_consistent = len(set(all_feature_dims)) == 1

    print("\n✓ Integrity checks:")
    print(f"  Total samples checked: {len(all_barriers)}")
    print(f"  Barrier range: [{min(all_barriers):.3f}, {max(all_barriers):.3f}] eV")
    print(f"  Avg nodes per graph: {total_nodes / total_graphs:.1f}")
    print(f"  Feature dimension: {all_feature_dims[0]} (consistent: {dims_consistent})")

    # Validate
    assert dims_consistent, "All graphs should have same feature dim"
    assert all(b > 0 for b in all_barriers), "All barriers should be positive"

    print("\n✓ TEST 8 PASSED")


def main():
    """Run all tests in order and report the outcome.

    Returns:
        0 when every test function returns without raising, 1 when any of
        them raises (the error and its traceback are printed).
    """
    print_section("DATASET TEST SUITE")

    start_time = time.perf_counter()

    try:
        # Test 1 may report a skip as a bare None rather than a pair.
        # Normalise it so the unpack cannot raise; the later tests then skip
        # themselves on their None arguments.
        created = test_1_dataset_creation()
        dataset, config = (None, None) if created is None else created

        test_2_dataset_statistics(dataset)
        test_3_single_sample(dataset)
        test_4_collate_function(dataset)
        train_loader, val_loader = test_5_dataloaders(config)
        test_6_iteration(train_loader, val_loader)
        test_7_speed_benchmark(train_loader)
        test_8_data_integrity(train_loader)

        # Summary. NOTE: skipped tests are not counted as failures, so this
        # banner is also printed when tests were skipped.
        elapsed = time.perf_counter() - start_time
        print_section(f"ALL TESTS PASSED ({elapsed:.1f}s)")
        print()

        return 0

    except Exception as e:
        # Top-level boundary: report any failure and turn it into an exit code.
        print_section("TEST FAILED")
        print(f"\nError: {e}")
        import traceback
        traceback.print_exc()
        print("\n" + "="*70 + "\n")
        return 1


if __name__ == "__main__":
    # `exit` is an interactive helper installed by the `site` module and is
    # not guaranteed to exist (e.g. under `python -S`); raising SystemExit
    # sets the process exit status to main()'s return value without it.
    raise SystemExit(main())


DATASET TEST SUITE

TEST 1: DATASET CREATION

Creating dataset from: database_navi.csv
✓ Detected elements from database: ['Mo', 'Nb', 'O', 'Ta', 'W']
Building template graph...
✓ Template created: 127 nodes, 1764 edges
✓ Node features: 12 (pos: 3, one-hot: 5, props: 4)
Dataset loaded: 419 samples
  Removed 779 samples (barrier filtering)

✓ Dataset created:
  Total samples: 419

✓ TEST 1 PASSED

TEST 2: DATASET STATISTICS

Barrier statistics:
  Count: 419
  Min: 0.107 eV
  Max: 14.917 eV
  Mean: 3.671 eV
  Median: 0.891 eV
  Std: 4.475 eV

✓ TEST 2 PASSED

TEST 3: SINGLE SAMPLE LOADING

Loading first sample...

✓ Sample loaded in 86.87ms

Initial graph:
  Nodes: 127
  Node features: torch.Size([127, 12])
  Edges: 1764
  Label: 0.5891 eV

Final graph:
  Nodes: 127
  Node features: torch.Size([127, 12])
  Edges: 1764
  Label: 0.5891 eV

Barrier: 0.5891 eV

✓ TEST 3 PASSED

TEST 4: COLLATE FUNCTION

Creating batch manually...
  Loaded 4 samples

✓ Batch created in 1.28ms

Batch structur

  structure = parser.parse_structures(primitive=False)[0]



Batch 1:
  Initial: 32 graphs, 4064 nodes
  Final: 32 graphs, 4064 nodes
  Barriers: torch.Size([32])
    Range: [0.125, 14.869] eV

Batch 2:
  Initial: 32 graphs, 4064 nodes
  Final: 32 graphs, 4064 nodes
  Barriers: torch.Size([32])
    Range: [0.171, 14.917] eV

Batch 3:
  Initial: 32 graphs, 4064 nodes
  Final: 32 graphs, 4064 nodes
  Barriers: torch.Size([32])
    Range: [0.137, 12.559] eV

Iterating through val loader...

Val Batch 1:
  Initial: 32 graphs
  Final: 32 graphs
  Barriers: torch.Size([32])

Val Batch 2:
  Initial: 32 graphs
  Final: 32 graphs
  Barriers: torch.Size([32])

✓ TEST 6 PASSED

TEST 7: SPEED BENCHMARK

Loading 10 batches...

✓ Benchmark complete:
  Total time: 7.24s
  Average per batch: 724.1ms
  Batches per second: 1.4

✓ TEST 7 PASSED

TEST 8: DATA INTEGRITY

Checking data integrity...

✓ Integrity checks:
  Total samples checked: 335
  Barrier range: [0.107, 14.917] eV
  Avg nodes per graph: 127.0
  Feature dimension: 12 (consistent: True)

✓ TEST 8 PA

: 