# QCML JAX vs PyTorch GPU Performance Test

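This notebook benchmarks the JAX implementation of the QCML matrix trainer against the PyTorch implementation on GPU. For each test case it reports wall-clock training time and final total loss for both backends, plus the speedup (PyTorch time ÷ JAX time; values above 1 mean JAX is faster). On a fresh runtime, run the installation and clone cells before the GPU-availability check.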


In [None]:
# Check GPU availability: report which accelerators PyTorch and JAX can see
# before any benchmarking happens.
import torch

has_cuda = torch.cuda.is_available()
print(f"PyTorch CUDA available: {has_cuda}")
if has_cuda:
    device_name = torch.cuda.get_device_name(0)
    print(f"PyTorch CUDA device: {device_name}")
    total_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f"PyTorch CUDA memory: {total_gb:.1f} GB")

import jax

jax_devices = jax.devices()
print(f"JAX devices: {jax_devices}")
jax_backend = jax.default_backend()
print(f"JAX default backend: {jax_backend}")


In [None]:
# Install required packages
!pip install jax[cuda] -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
!pip install optax matplotlib numpy scipy


In [None]:
# Clone the repository (update with your actual GitHub URL)
!git clone https://github.com/your-username/qcml_new.git
!cd qcml_new && pip install -e .


In [None]:
# Import required modules
import sys

# Put the cloned checkout on the import path directly, in addition to the
# editable install. These paths are relative to the current working directory,
# so run the notebook from the directory that contains `qcml_new/`.
# The membership check keeps re-runs of this cell from piling up duplicates.
for _extra_path in ('qcml_new', 'qcml_new/qcml_fresh'):
    if _extra_path not in sys.path:
        sys.path.append(_extra_path)

import numpy as np
import torch
import jax
import jax.numpy as jnp
import time
import matplotlib.pyplot as plt

# Import our implementations
from qcml.matrix_trainer.matrix_trainer import MatrixConfigurationTrainer as PyTorchTrainer
from qcml.matrix_trainer.jax_matrix_trainer import JAXMatrixTrainer, MatrixTrainerConfig
from qcml.manifolds.sphere import SphereManifold
from qcml.manifolds.spiral import SpiralManifold

print("✅ All imports successful!")


In [None]:
def _sync_cuda():
    """Wait for all queued CUDA work to finish; no-op when CUDA is unavailable.

    PyTorch launches CUDA kernels asynchronously, so reading the clock without
    a sync can stop timing before the GPU has actually done the work.
    """
    if torch.cuda.is_available():
        torch.cuda.synchronize()


def _run_pytorch_case(test_case, points, seed):
    """Train the PyTorch trainer on `points` for one test case.

    Args:
        test_case: dict with keys 'N', 'D', 'w_qf', 'learning_rate',
            'n_epochs' and 'n_points'.
        points: the sampled points, passed as `points_np` to the trainer.
        seed: forwarded as `torch_seed`.

    Returns:
        (elapsed_seconds, final_total_loss) as (float, float).
    """
    trainer = PyTorchTrainer(
        points_np=points,
        N=test_case['N'],
        D=test_case['D'],
        quantum_fluctuation_weight=test_case['w_qf'],
        learning_rate=test_case['learning_rate'],
        torch_seed=seed,
    )

    # Drain any GPU work queued before this run so it isn't billed to training.
    _sync_cuda()
    start = time.perf_counter()
    trainer.train_matrix_configuration(
        n_epochs=test_case['n_epochs'],
        batch_size=min(500, test_case['n_points']),
        verbose=False,
    )
    _sync_cuda()
    elapsed = time.perf_counter() - start

    # Pure evaluation: no autograd graph is needed for the final loss.
    # NOTE(review): torch.tensor(points) is created on the CPU; confirm that
    # forward() moves its input to the trainer's device.
    with torch.no_grad():
        final_loss = trainer.forward(torch.tensor(points))['total_loss'].item()
    return elapsed, final_loss


def _run_jax_case(test_case, points):
    """Train the JAX trainer on `points` for one test case.

    Returns:
        (elapsed_seconds, final_total_loss) as (float, float).
    """
    config = MatrixTrainerConfig(
        N=test_case['N'],
        D=test_case['D'],
        quantum_fluctuation_weight=test_case['w_qf'],
        learning_rate=test_case['learning_rate'],
    )
    # NOTE(review): test_case['n_epochs'] and the batch size are NOT forwarded
    # here, so the JAX run uses MatrixTrainerConfig's defaults and may train
    # for a different number of epochs than PyTorch. Confirm the config fields
    # and pass them through for a like-for-like comparison. No seed is passed
    # either.
    trainer = JAXMatrixTrainer(config)

    # Host->device copy happens outside the timed region. This matches the
    # PyTorch side, where the points are handed to the constructor before
    # timing starts.
    points_jax = jnp.array(points)

    # The measured time includes any JIT compilation triggered by the first call.
    start = time.perf_counter()
    trainer.train(points_jax, verbose=False)
    # JAX dispatches asynchronously; wait until the trained matrices are computed.
    jax.block_until_ready(trainer.matrices)
    elapsed = time.perf_counter() - start

    matrices_jax = jnp.stack(trainer.matrices)
    # NOTE(review): positional args mirror the private _loss_function signature;
    # the literal 0.0 is presumably a penalty weight other than w_qf — confirm.
    loss_dict = trainer._loss_function(
        matrices_jax, points_jax, test_case['N'], test_case['D'], 0.0, test_case['w_qf']
    )
    return elapsed, float(loss_dict['total_loss'])


def gpu_performance_test():
    """Run GPU performance comparison between JAX and PyTorch.

    For each test case, samples points from a manifold, trains both the
    PyTorch and the JAX trainer on the same points, and prints and records
    wall-clock training time, final total loss, and speedup
    (PyTorch time / JAX time; 0 if the JAX time is not positive).

    Returns:
        list[dict]: one dict per test case with keys 'test_case', 'n_points',
        'pytorch_time', 'jax_time', 'speedup', 'pytorch_loss', 'jax_loss' and
        'loss_difference'.
    """
    print("🚀 Starting GPU Performance Test")
    print("=" * 50)

    seed = 42

    # Test parameters
    test_cases = [
        {
            'name': 'sphere_small',
            'manifold': SphereManifold(dimension=3, noise=0.0),
            'n_points': 1000,
            'N': 3, 'D': 3,
            'n_epochs': 100,
            'w_qf': 0.0,
            'learning_rate': 0.001
        },
        {
            'name': 'sphere_large',
            'manifold': SphereManifold(dimension=3, noise=0.0),
            'n_points': 5000,
            'N': 3, 'D': 3,
            'n_epochs': 100,
            'w_qf': 0.0,
            'learning_rate': 0.001
        },
        {
            'name': 'spiral_small',
            'manifold': SpiralManifold(noise=0.0),
            'n_points': 1000,
            'N': 4, 'D': 3,
            'n_epochs': 100,
            'w_qf': 0.0,
            'learning_rate': 0.0005
        }
    ]

    results = []

    for test_case in test_cases:
        print(f"\n🔹 Testing: {test_case['name']}")
        print(f"Points: {test_case['n_points']}, N={test_case['N']}, D={test_case['D']}")

        # Seed BEFORE sampling so the generated points are reproducible as well
        # as the training. This assumes generate_points draws from the global
        # numpy/torch RNGs — confirm.
        torch.manual_seed(seed)
        np.random.seed(seed)
        points = test_case['manifold'].generate_points(n_points=test_case['n_points'])

        print("  PyTorch GPU test...")
        pytorch_time, pytorch_final_loss = _run_pytorch_case(test_case, points, seed)

        print("  JAX GPU test...")
        jax_time, jax_final_loss = _run_jax_case(test_case, points)

        speedup = pytorch_time / jax_time if jax_time > 0 else 0

        result = {
            'test_case': test_case['name'],
            'n_points': test_case['n_points'],
            'pytorch_time': pytorch_time,
            'jax_time': jax_time,
            'speedup': speedup,
            'pytorch_loss': pytorch_final_loss,
            'jax_loss': jax_final_loss,
            'loss_difference': abs(pytorch_final_loss - jax_final_loss)
        }
        results.append(result)

        print(f"    PyTorch: {pytorch_time:.2f}s, Loss: {pytorch_final_loss:.6f}")
        print(f"    JAX:     {jax_time:.2f}s, Loss: {jax_final_loss:.6f}")
        print(f"    Speedup: {speedup:.2f}x")
        print(f"    Loss diff: {result['loss_difference']:.6f}")

    return results

# Run the test
results = gpu_performance_test()
