# Zero-Copy Handover to JAX

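This notebook demonstrates how to hand the coordinates stored in a `synth_pdb.BatchedPeptide` object to JAX as a `jax.Array` (the type called `DeviceArray` in older JAX releases). Reading `.coords` gives direct access to the underlying NumPy array, but whether the handover to JAX is truly zero-copy depends on the backend and dtype: data bound for a GPU or TPU must be copied into device memory, and float64 data is converted to float32 unless JAX's 64-bit mode is enabled.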

### Install Dependencies
If you don't have JAX installed, you can install it using pip:

In [None]:
# !pip install jax jaxlib

In [None]:
import numpy as np
from synth_pdb.generator import BatchedGenerator

try:
    # Note: JAX defaults to float32. If float64 precision is critical, enable it with
    # jax.config.update("jax_enable_x64", True) before creating any arrays; this demo keeps the default.
    import jax.numpy as jnp
except ImportError:
    print("\033[91mError: JAX not found.\033[0m")
    print("JAX is required for this notebook. Please install it using: !pip install jax jaxlib")

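If float64 precision matters, the optional cell below is a minimal sketch of switching JAX to 64-bit mode with `jax.config.update("jax_enable_x64", True)`. It should run before any JAX arrays are created (ideally at the top of a fresh kernel), and it changes the dtype the later cells report. The random `coords64` array is a hypothetical stand-in for `peptide_batch.coords`, not real `synth_pdb` output, and its `(10, 25, 3)` shape is an assumption.

```python
import numpy as np

try:
    import jax

    # Enable 64-bit types; this must happen before any JAX arrays are created.
    jax.config.update("jax_enable_x64", True)
    import jax.numpy as jnp
except ImportError:
    jnp = None

# Hypothetical stand-in for peptide_batch.coords (float64, assumed shape).
coords64 = np.random.default_rng(0).normal(size=(10, 25, 3))

if jnp is not None:
    coords_jax = jnp.asarray(coords64)
    # With x64 enabled, the dtype is preserved instead of being downcast to float32.
    print(coords_jax.dtype)  # float64
    np.testing.assert_array_equal(coords64, np.asarray(coords_jax))
```
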
In [None]:
# 1. Generate a batch of 10 peptides, each 5 residues long
sequence = "A" * 5
generator = BatchedGenerator(sequence_str=sequence, n_batch=10)
peptide_batch = generator.generate_batch(conformation='alpha')

# 2. Access the underlying C-contiguous NumPy array of coordinates
coords_np = peptide_batch.coords

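The step above relies on `coords` being C-contiguous. A cheap way to confirm that before the handover is to inspect `ndarray.flags`; `np.ascontiguousarray` returns an array sharing the input's memory when it already has C layout and the requested dtype, and makes one copy otherwise. The arrays below are hypothetical stand-ins for `peptide_batch.coords`, and the `(n_batch, n_atoms, 3)` shape is an assumption.

```python
import numpy as np

# Hypothetical stand-in for peptide_batch.coords: (n_batch, n_atoms, 3).
coords = np.zeros((10, 25, 3), dtype=np.float64)
print(coords.flags["C_CONTIGUOUS"])  # True

# A transposed view shares memory but is no longer C-contiguous.
swapped = coords.transpose(1, 0, 2)
print(swapped.flags["C_CONTIGUOUS"])  # False

# No copy when layout and dtype already match...
same = np.ascontiguousarray(coords)
print(np.shares_memory(coords, same))  # True

# ...but a non-contiguous view is copied into a fresh C-ordered buffer.
fixed = np.ascontiguousarray(swapped)
print(fixed.flags["C_CONTIGUOUS"], np.shares_memory(coords, fixed))  # True False
```
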
In [None]:
# 3. Create a JAX array from the NumPy array (float64 input is stored as float32 unless x64 mode is enabled)
try:
    coords_jax = jnp.asarray(coords_np)
    print(f"NumPy array shape: {coords_np.shape} (dtype: {coords_np.dtype})")
    print(f"JAX array shape: {coords_jax.shape} (dtype: {coords_jax.dtype})")
except NameError:
    print("Skipping JAX conversion due to missing dependency.")

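With JAX's default 32-bit mode, handing it a float64 array means rewriting the values into a new float32 buffer, so that step cannot be zero-copy. One option, sketched below under the assumption that float32 precision is acceptable, is to cast once, explicitly, on the NumPy side; the JAX array then holds exactly the same values and an exact comparison succeeds. `coords64` is a hypothetical stand-in for `peptide_batch.coords`, and the JAX part is skipped when JAX is not installed.

```python
import numpy as np

try:
    import jax.numpy as jnp
except ImportError:
    jnp = None

# Hypothetical stand-in for a float64 coordinate batch (Angstroms).
coords64 = np.random.default_rng(0).uniform(-20.0, 20.0, size=(10, 25, 3))

# Cast once, explicitly, into a C-contiguous float32 buffer.
coords32 = np.ascontiguousarray(coords64, dtype=np.float32)

if jnp is not None:
    coords_jax = jnp.asarray(coords32)
    # The input is already float32, so JAX has nothing left to round.
    np.testing.assert_array_equal(coords32, np.asarray(coords_jax))
    print(coords_jax.dtype)  # float32
```
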
In [None]:
# 4. Verify that the data is the same
try:
    # We use assert_allclose instead of assert_array_equal because, unless 64-bit
    # mode is enabled, JAX stores float64 input as float32, so values differ by rounding error.
    np.testing.assert_allclose(coords_np, coords_jax, atol=1e-5)
    print("Verification successful: data is consistent between NumPy and JAX.")
except NameError:
    print("Skipping verification due to missing dependency.")
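
The `atol=1e-5` tolerance can be sanity-checked with a short calculation: rounding a float64 value `x` to float32 changes it by at most `eps / 2 * |x|`, where `eps = np.finfo(np.float32).eps` (about `1.19e-7`). The sketch below checks this bound on hypothetical coordinates spanning ±50 Å (an assumed range, not `synth_pdb` output); for coordinates of that magnitude the rounding error stays below about `3e-6`, well inside `1e-5`.

```python
import numpy as np

# Hypothetical stand-in coordinates in Angstroms (assumed range, not synth_pdb output).
coords64 = np.random.default_rng(42).uniform(-50.0, 50.0, size=(10, 25, 3))
coords32 = coords64.astype(np.float32)

# Largest absolute change caused by rounding to float32.
max_err = float(np.max(np.abs(coords32.astype(np.float64) - coords64)))

# Round-to-nearest bound: |fl(x) - x| <= (eps / 2) * |x|.
eps32 = float(np.finfo(np.float32).eps)
bound = 0.5 * eps32 * float(np.max(np.abs(coords64)))

print(f"max rounding error: {max_err:.2e} (bound: {bound:.2e}, atol: 1e-5)")
```
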