In [None]:
import pado
from pado.math import nm, um
import torch
import matplotlib.pyplot as plt

In [None]:
# Test float32 and MPS acceleration: report the environment, pick a device,
# and define the simulation parameters shared by every test cell below.
print("=" * 60)
print("Testing float32 and MPS acceleration")
print("=" * 60)

# `torch.backends.mps` is absent on older PyTorch builds, so guard the lookup.
# Evaluate it once and reuse the result for both the report and the device
# choice, instead of repeating the conditional expression.
mps_available = hasattr(torch.backends, 'mps') and torch.backends.mps.is_available()

# Check available devices
print(f"\nPyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
print(f"MPS available: {mps_available}")

# Select device: prefer Apple MPS, then the first CUDA GPU, otherwise CPU.
if mps_available:
    device = 'mps'
elif torch.cuda.is_available():
    device = 'cuda:0'
else:
    device = 'cpu'
print(f"\nUsing device: {device}")

# Test parameters. The report below converts pitch with 1e6 -> "um" and
# wvl with 1e9 -> "nm", i.e. pitch and wvl are expressed in meters.
R = 512
C = 512
pitch = 2 * um
wvl = 660 * nm
dim = (1, 1, R, C)  # presumably (batch, channel, rows, cols) — TODO confirm against pado.light.Light

print("\nTest parameters:")
print(f"  Resolution: {R}x{C}")
print(f"  Pitch: {pitch*1e6:.2f} um")
print(f"  Wavelength: {wvl*1e9:.2f} nm")


In [None]:
# Test 1: build a Light on the selected device and confirm its field is
# complex64 (float32 real and imaginary parts) and lives on that device.
print("\n" + "=" * 60)
print("Test 1: Creating Light with float32 (complex64)")
print("=" * 60)

try:
    # No dtype argument is passed, so this exercises pado's default precision.
    light = pado.light.Light(dim, pitch, wvl, device=device)
    field = light.field

    print(f"Light field dtype: {field.dtype}")
    print(f"Light field device: {field.device}")
    print(f"Light field shape: {field.shape}")

    # Compare only the device type: 'cuda:0' -> 'cuda', 'mps' -> 'mps'.
    expected_device_type = device.split(':')[0]
    assert field.dtype == torch.complex64, f"Expected complex64, got {field.dtype}"
    assert field.device.type == expected_device_type, f"Expected device {device}, got {field.device}"

    print("✓ Light created successfully with float32 (complex64)")

except Exception as exc:
    import traceback
    print(f"✗ Error creating Light: {exc}")
    traceback.print_exc()


In [None]:
# Test 2: set a plane wave on the Light from Test 1 and check that the
# real-valued quantities derived from its field come back as float32.
print("\n" + "=" * 60)
print("Test 2: Basic operations with float32")
print("=" * 60)

try:
    # Called on the `light` object created in Test 1, which later cells reuse;
    # presumably this updates its field in place — confirm against pado's API.
    light.set_plane_light(theta=5.0)
    print(f"✓ set_plane_light() completed")

    # Evaluated left-to-right, in the same order as individual calls.
    amplitude, phase, intensity = (
        light.get_amplitude(),
        light.get_phase(),
        light.get_intensity(),
    )

    labelled = (("Amplitude", amplitude), ("Phase", phase), ("Intensity", intensity))
    for label, tensor in labelled:
        print(f"  {label} dtype: {tensor.dtype}, device: {tensor.device}")

    # Real-valued outputs of a complex64 field are expected to be float32.
    for _label, tensor in labelled:
        assert tensor.dtype == torch.float32, f"Expected float32, got {tensor.dtype}"

    print("✓ Basic operations completed successfully")

except Exception as exc:
    import traceback
    print(f"✗ Error in basic operations: {exc}")
    traceback.print_exc()


In [None]:
# Test 3: build a refractive lens and an aperture on the selected device,
# apply them in sequence to the Light from Test 2, and check that each output
# keeps the complex64 dtype and the device (same checks as Tests 1 and 4).
print("\n" + "=" * 60)
print("Test 3: Optical element operations")
print("=" * 60)

try:
    # Compare only the device type: 'cuda:0' -> 'cuda', 'mps' -> 'mps'.
    expected_device_type = device.split(':')[0]

    # Create a lens
    lens = pado.optical_element.RefractiveLens(
        dim=dim,
        pitch=pitch,
        wvl=wvl,
        focal_length=50e-3,  # assumes meters, like `pitch`/`wvl` (i.e. 50 mm)
        device=device
    )
    print(f"✓ Lens created, field_change dtype: {lens.field_change.dtype}, device: {lens.field_change.device}")

    # Apply lens to light
    light_after_lens = lens.forward(light)
    print(f"✓ Lens forward() completed")
    print(f"  Output field dtype: {light_after_lens.field.dtype}, device: {light_after_lens.field.device}")

    # A float64 field_change would silently promote the field to complex128;
    # assert so a precision regression is reported rather than just printed.
    assert light_after_lens.field.dtype == torch.complex64, f"Expected complex64, got {light_after_lens.field.dtype}"
    assert light_after_lens.field.device.type == expected_device_type, f"Expected device {device}, got {light_after_lens.field.device}"

    # Create an aperture
    aperture = pado.optical_element.Aperture(
        dim=dim,
        pitch=pitch,
        wvl=wvl,
        radius=100e-6,  # assumes meters (i.e. 100 um)
        device=device
    )
    print(f"✓ Aperture created, field_change dtype: {aperture.field_change.dtype}, device: {aperture.field_change.device}")

    # Apply aperture
    light_after_aperture = aperture.forward(light_after_lens)
    print(f"✓ Aperture forward() completed")
    print(f"  Output field dtype: {light_after_aperture.field.dtype}, device: {light_after_aperture.field.device}")

    assert light_after_aperture.field.dtype == torch.complex64, f"Expected complex64, got {light_after_aperture.field.dtype}"
    assert light_after_aperture.field.device.type == expected_device_type, f"Expected device {device}, got {light_after_aperture.field.device}"

    print("✓ Optical element operations completed successfully")

except Exception as e:
    print(f"✗ Error in optical element operations: {e}")
    import traceback
    traceback.print_exc()


In [None]:
# Test 4: Propagation with ASM — propagate the Light once, time it, and check
# that the output field stays complex64 on the selected device.
print("\n" + "=" * 60)
print("Test 4: Propagation with ASM (Angular Spectrum Method)")
print("=" * 60)


def _sync_device():
    """Block until queued GPU work on `device` has finished (no-op on CPU)."""
    if device.startswith('cuda'):
        torch.cuda.synchronize()
    elif device == 'mps' and hasattr(torch, 'mps'):
        # `torch.mps` is missing on older PyTorch builds; skip sync there.
        torch.mps.synchronize()


try:
    # Create propagator (doesn't take device parameter, uses light's device)
    propagator = pado.propagator.Propagator(mode='ASM')
    print(f"✓ Propagator created")

    # Propagate light
    import time
    # GPU kernels run asynchronously: synchronize before starting and before
    # stopping the clock so the elapsed time covers the propagation itself,
    # not just the kernel launch. perf_counter is monotonic and high-resolution.
    # This is a single cold call and may include one-time setup costs.
    _sync_device()
    start_time = time.perf_counter()
    light_propagated = propagator.forward(light, z=10e-3)  # assumes meters (10 mm)
    _sync_device()
    elapsed_time = time.perf_counter() - start_time

    print(f"✓ Propagation completed in {elapsed_time:.4f} seconds")
    print(f"  Output field dtype: {light_propagated.field.dtype}, device: {light_propagated.field.device}")
    print(f"  Output field shape: {light_propagated.field.shape}")

    # Verify dtype and device
    assert light_propagated.field.dtype == torch.complex64, f"Expected complex64, got {light_propagated.field.dtype}"
    assert light_propagated.field.device.type == device.split(':')[0], f"Expected device {device}, got {light_propagated.field.device}"

    print("✓ Propagation test completed successfully")

except Exception as e:
    print(f"✗ Error in propagation: {e}")
    import traceback
    traceback.print_exc()


In [None]:
# Test 5: Steady-state ASM propagation benchmark (runs only on MPS)
print("\n" + "=" * 60)
print("Test 5: Performance test")
print("=" * 60)

if device == 'mps':
    try:
        import time

        # synchronize() lives on `torch.mps`, not `torch.backends.mps`; the
        # former is missing on older PyTorch builds. Fall back to a no-op
        # there (timings may then undercount queued GPU work).
        sync = torch.mps.synchronize if hasattr(torch, 'mps') else (lambda: None)

        # Warm up: keep one-time setup costs out of the timed loop.
        _ = propagator.forward(light, z=10e-3)
        sync()

        # Benchmark. Synchronize once after the loop so all queued kernels are
        # included; perf_counter is the monotonic clock intended for timing.
        num_iterations = 10
        start_time = time.perf_counter()
        for _ in range(num_iterations):
            _ = propagator.forward(light, z=10e-3)
        sync()
        elapsed_time = time.perf_counter() - start_time

        avg_time = elapsed_time / num_iterations
        print("✓ Performance test completed")
        print(f"  Average time per propagation: {avg_time:.4f} seconds")
        print(f"  Total time for {num_iterations} iterations: {elapsed_time:.4f} seconds")
        print(f"  Throughput: {num_iterations/elapsed_time:.2f} propagations/second")

    except Exception as e:
        print(f"✗ Error in performance test: {e}")
        import traceback
        traceback.print_exc()
else:
    print("  Skipping performance test (not using MPS device)")


In [None]:
# Test 6: sanity-check the field propagated in Test 4 — report intensity
# statistics and require every value to be finite and non-negative.
print("\n" + "=" * 60)
print("Test 6: Verify results are reasonable")
print("=" * 60)

try:
    intensity = light_propagated.get_intensity()
    intensity_mean, intensity_max, intensity_min = (
        intensity.mean().item(),
        intensity.max().item(),
        intensity.min().item(),
    )

    print(f"  Intensity statistics:")
    for label, value in (("Mean", intensity_mean), ("Min", intensity_min), ("Max", intensity_max)):
        print(f"    {label}: {value:.6f}")

    # The finiteness check runs first: a NaN minimum would otherwise fail the
    # non-negativity check with a misleading message.
    assert torch.isfinite(intensity).all(), "Intensity contains non-finite values"
    assert intensity_min >= 0, "Intensity should be non-negative"

    print("✓ Results verification passed")

except Exception as exc:
    import traceback
    print(f"✗ Error in verification: {exc}")
    traceback.print_exc()

# Reached even when earlier tests failed: each cell catches and reports its own
# exceptions, so look for "✗" lines above to see what did not pass.
banner = "=" * 60
print("\n" + banner)
print("All tests completed!")
print(banner)
