![ML Handover Hero](assets/hero.png)

# ML Integration: Data Factory Flow 🤖

This notebook demonstrates how to use `synth-pdb` as a high-speed data factory for training protein AI models.

We leverage the `BatchedGenerator` to produce thousands of structures in milliseconds and feed them directly into **PyTorch** and **JAX**, with a **zero-copy** memory handover where the framework and dtype allow it.

### The Data Factory Workflow
Traditional structural biology tools are optimized for processing one PDB file at a time. `synth-pdb` is optimized for **tensor throughput**.

![Protein Data Factory Workflow](assets/workflow.png)


In [None]:
# @title Setup & Installation (Run if on Colab) { display-mode: "form" }
import os
import sys

if 'google.colab' in str(get_ipython()):
    print("Running on Google Colab. Installing dependencies...")
    !pip install biotite numba
    
    if not os.path.exists('synth-pdb'):
        !git clone https://github.com/elkins-lab/synth-pdb.git
    
    sys.path.append(os.path.abspath('synth-pdb'))
    print("Setup Complete.")
else:
    print("Running locally. Ensure synth_pdb is in your python path.")

In [None]:
import numpy as np
import time
import matplotlib.pyplot as plt
from synth_pdb.batch_generator import BatchedGenerator

print("Libraries Loaded.")

## 1. High-Speed Generation
We'll generate a batch of 1,000 peptides, each 49 residues long. A traditional serial loop builds these one at a time; `synth-pdb` computes the whole batch with vectorized array operations.

In [None]:
# Construct a clean sequence of 49 residues
residues = ["ALA", "GLY", "SER", "LEU", "VAL", "ILE", "MET"] * 7
sequence = "-".join(residues)
n_batch = 1000

generator = BatchedGenerator(sequence, n_batch=n_batch, full_atom=False)

start = time.perf_counter()
batch = generator.generate_batch(drift=5.0)
elapsed = time.perf_counter() - start
print(f"Generated {n_batch} structures in {elapsed * 1000:.1f} ms.")
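
To make the idea of one batched operation concrete, here is a toy NumPy sketch. It is not how `BatchedGenerator` is implemented: the helper names, the straight-line template, and the 3.8 Å spacing are illustrative assumptions. Both functions add Gaussian noise to the same template, one structure per Python iteration versus all structures in one broadcasted array operation.

```python
import numpy as np


def make_template(n_atoms, spacing=3.8):
    """Toy template (hypothetical): atoms evenly spaced along the x-axis."""
    template = np.zeros((n_atoms, 3))
    template[:, 0] = np.arange(n_atoms) * spacing
    return template


def noisy_chain_loop(template, n_batch, sigma, rng):
    """Serial version: perturb one structure per Python iteration."""
    out = np.empty((n_batch,) + template.shape)
    for i in range(n_batch):
        out[i] = template + rng.normal(0.0, sigma, size=template.shape)
    return out


def noisy_chain_batched(template, n_batch, sigma, rng):
    """Batched version: one noise draw and one broadcasted addition."""
    noise = rng.normal(0.0, sigma, size=(n_batch,) + template.shape)
    return template[np.newaxis, ...] + noise


template = make_template(49)
loop_coords = noisy_chain_loop(template, 1000, 0.5, np.random.default_rng(0))
batched_coords = noisy_chain_batched(template, 1000, 0.5, np.random.default_rng(0))
print(loop_coords.shape, batched_coords.shape)
```

Both calls return an array of shape `(n_batch, n_atoms, 3)`; the batched version avoids the per-structure Python overhead.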


### Benchmark: Serial vs. Batched Generation
Why use `BatchedGenerator`? Below we time serial generation on a 100-structure subset, scale that to the cost of 1,000 structures, and compare it with generating 1,000 structures in a single batch.

In [None]:
from synth_pdb.generator import generate_pdb_content

def run_benchmark(n=100):
    # 1. Traditional Serial Generation
    # We generate structures one-by-one, including building the PDB string.
    # This is how most traditional bio scripts work.
    start_serial = time.time()
    for _ in range(n):
        _ = generate_pdb_content(sequence_str=sequence, minimize_energy=False)
    serial_dt = time.time() - start_serial
    
    # 2. Batched Generation
    # We generate the entire batch at once using vectorized NumPy math.
    start_batched = time.time()
    _ = generator.generate_batch(drift=1.0)
    batched_dt = time.time() - start_batched
    
    return serial_dt, batched_dt

# We run for a smaller subset to keep the notebook responsive
n_test = 100
s_time, b_time = run_benchmark(n_test)

# Normalize both timings to the cost per 1,000 structures.
# The serial loop produced n_test structures; the batched call produced n_batch.
s_1k = s_time * (1000 / n_test)
b_1k = b_time * (1000 / n_batch)

plt.figure(figsize=(8, 4))
bars = plt.bar(["Traditional Serial", "synth-pdb Batched"], [s_1k, b_1k], color=["#ff9999", "#667eea"])
plt.ylabel("Seconds per 1,000 Structures")
plt.title("Real-World Performance Comparison")
plt.grid(axis="y", linestyle="--", alpha=0.7)

# Annotate each bar with its measured value
for bar in bars:
    yval = bar.get_height()
    plt.text(bar.get_x() + bar.get_width()/2, yval + (s_1k*0.02), f"{yval:.3f}s", ha="center", va="bottom", fontweight="bold")

plt.show()

print(f"Vectorization Speedup: {s_1k / b_1k:.1f}x")
print(f"Measured batched throughput: {1000 / b_1k:.0f} structures/sec")
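
A single wall-clock measurement can be noisy, especially on a shared Colab machine. The sketch below shows one common way to steady it: run the workload several times and keep the fastest run. `best_of` and `per_thousand` are hypothetical helpers, not part of `synth-pdb`, and the workload here is a stand-in.

```python
import time


def best_of(fn, repeats=5):
    """Return the fastest wall-clock time (seconds) over several runs of fn()."""
    timings = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn()
        timings.append(time.perf_counter() - start)
    return min(timings)


def per_thousand(elapsed, n_structures):
    """Scale a measured time to the cost of 1,000 structures."""
    return elapsed * (1000 / n_structures)


# Stand-in workload; in this notebook it could be replaced by a call such as
# lambda: generator.generate_batch(drift=1.0)
elapsed = best_of(lambda: sum(range(10_000)))
print(f"{per_thousand(elapsed, 100):.6f} s per 1,000 structures")
```

Taking the minimum rather than the mean discards runs that were slowed by unrelated background activity.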

## 2. PyTorch Handover (Zero-Copy)
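`torch.from_numpy` wraps a NumPy array without copying it, so any change to the NumPy array is reflected in the tensor (and vice versa). The sharing holds only while the dtype is unchanged: a cast such as `.float()` on a `float64` array allocates a new `float32` tensor.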

In [None]:
try:
    import torch
    
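    # torch.from_numpy shares memory with batch.coords (zero-copy).
    # .float() is a no-op if the data is already float32; otherwise it
    # allocates a float32 copy that no longer aliases the NumPy buffer.
    torch_tensor = torch.from_numpy(batch.coords).float()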
    
    print("✅ PyTorch Handover successful!")
    print(f"Tensor Device: {torch_tensor.device}")
    print(f"Contiguous in memory: {torch_tensor.is_contiguous()}")
except ImportError:
    print("❌ PyTorch not found. Use 'pip install torch' to see this in action.")
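
Whether the handover stays zero-copy depends on the array's dtype and memory layout. The sketch below checks this on the NumPy side before wrapping. `prepare_for_torch` is a hypothetical helper, and the `float64` input is an assumption about what a generator might return, not a statement about `batch.coords`.

```python
import numpy as np


def prepare_for_torch(coords):
    """Return (array, shared): a float32 C-contiguous array and whether it
    still aliases the input buffer."""
    prepared = np.ascontiguousarray(coords, dtype=np.float32)
    return prepared, np.shares_memory(prepared, coords)


coords64 = np.zeros((4, 10, 3), dtype=np.float64)  # assumed dtype
coords32 = np.zeros((4, 10, 3), dtype=np.float32)

_, shared64 = prepare_for_torch(coords64)
prepared32, shared32 = prepare_for_torch(coords32)
print(f"float64 input shares memory: {shared64}")
print(f"float32 input shares memory: {shared32}")

try:
    import torch

    tensor = torch.from_numpy(prepared32)  # wraps the buffer, no copy
    coords32[0, 0, 0] = 1.0  # write through the NumPy array
    print(f"Tensor sees the write: {tensor[0, 0, 0].item() == 1.0}")
except ImportError:
    print("PyTorch not installed; the NumPy checks above still apply.")
```

If the generator already emits `float32`, no cast is needed and the tensor aliases the original buffer.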

## 3. JAX / MLX Handover
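JAX also converts NumPy arrays efficiently, but not by sharing memory: `jnp.array` copies the data onto JAX's default device. Unless 64-bit mode is enabled, JAX also stores `float64` input as `float32`.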

In [None]:
try:
    import jax.numpy as jnp
    
    jax_array = jnp.array(batch.coords)
    print("✅ JAX Handover successful!")
    print(f"JAX Devices: {jax_array.devices()}")
except ImportError:
    print("❌ JAX not found.")

## 4. Educational Note: Why does this matter?

In deep learning for proteins, the **Data Loading** step is often the bottleneck. If your GPU has to wait for Python loops to calculate coordinates, it sits idle. 

By using `BatchedGenerator`, you can:
1. Keep generation on the CPU/AMX units while the GPU trains.
2. Avoid expensive serialized PDB parsing.
3. Feed thousands of "Hard Decoys" (structures with noise) to help your model learn the energy landscape.
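
Point 1 can be sketched with a background producer thread that fills a bounded queue while the training loop consumes it. This is a minimal standard-library illustration, not a `synth-pdb` feature: `make_batch` is a hypothetical stand-in for a call such as `generator.generate_batch(...)`, and the queue depth of 2 is an arbitrary choice.

```python
import queue
import threading

import numpy as np


def make_batch(rng, n_batch=8, n_atoms=49):
    """Hypothetical stand-in for a batched structure generator."""
    return rng.normal(size=(n_batch, n_atoms, 3)).astype(np.float32)


def prefetch(n_steps, depth=2, seed=0):
    """Yield n_steps batches produced by a background thread."""
    buffer = queue.Queue(maxsize=depth)  # bounded: producer waits when full
    sentinel = object()

    def producer():
        rng = np.random.default_rng(seed)
        for _ in range(n_steps):
            buffer.put(make_batch(rng))
        buffer.put(sentinel)  # signal that no more batches are coming

    threading.Thread(target=producer, daemon=True).start()
    while True:
        item = buffer.get()
        if item is sentinel:
            return
        yield item  # the training step would run here


shapes = [b.shape for b in prefetch(n_steps=5)]
print(shapes)
```

The bounded queue keeps memory use fixed: the producer blocks once `depth` batches are waiting, and resumes as soon as the consumer takes one.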

## 5. Visualizing the Data: Structural Ensembles
In ML, we often want to train on "Hard Decoys": structures that are mostly correct but carry physical noise. `BatchedGenerator` produces these ensembles in a single batched call.

In [None]:
# Visualize the first 10 structures in the batch overlaid
plt.figure(figsize=(10, 6))
for i in range(10):
    # Projecting 3D to 2D for simple matplotlib viz
    plt.plot(batch.coords[i, :, 0], batch.coords[i, :, 1], alpha=0.3, label=f"Model {i}" if i==0 else "")

plt.title("Ensemble Drift: Structural Noise for ML Training")
plt.xlabel("X (Å)")
plt.ylabel("Y (Å)")
plt.legend()
plt.show()
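
The spread of such an ensemble can be quantified as well as plotted. The sketch below computes each model's RMSD from the ensemble mean, assuming coordinates shaped `(n_models, n_atoms, 3)` as the indexing `batch.coords[i, :, 0]` suggests. It runs on synthetic data, performs no superposition, and `rmsd_to_mean` is a hypothetical helper.

```python
import numpy as np


def rmsd_to_mean(coords):
    """Per-model RMSD (same units as coords) from the ensemble mean.

    coords: array of shape (n_models, n_atoms, 3). No alignment is performed.
    """
    mean_structure = coords.mean(axis=0)                      # (n_atoms, 3)
    sq_dist = ((coords - mean_structure) ** 2).sum(axis=-1)   # (n_models, n_atoms)
    return np.sqrt(sq_dist.mean(axis=-1))                     # (n_models,)


rng = np.random.default_rng(0)
template = np.cumsum(rng.normal(size=(49, 3)), axis=0)  # synthetic random-walk chain
ensemble = template + rng.normal(0.0, 0.5, size=(10, 49, 3))

rmsd = rmsd_to_mean(ensemble)
print(np.round(rmsd, 2))
```

In the notebook, the same function could be applied to `batch.coords[:10]`, provided padded atoms at the origin are masked out first.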

### Interactive 3D Inspection
Use `py3Dmol`, the Python wrapper around 3Dmol.js, to inspect a sample structure from the batch.

In [None]:
try:
    import py3Dmol
    import numpy as np
    from synth_pdb.batch_generator import BatchedPeptide
    
    # 1. Coordinate Extraction & Zero-Point Removal
    c = batch.coords[0].copy()
    mask = np.any(c != 0, axis=1)
    c_clean = c[mask]
    
    # 2. Precise Bounding-Box Centering
    center = (c_clean.min(axis=0) + c_clean.max(axis=0)) / 2
    c_centered = c_clean - center
    
    # 3. Create view-ready Peptide
    p = BatchedPeptide(
        c_centered[np.newaxis, ...], 
        batch.sequence, 
        np.array(batch.atom_names)[mask].tolist(), 
        np.array(batch.residue_indices)[mask].tolist()
    )
    
    # 4. Rendering
    view = py3Dmol.view(width=800, height=400)
    view.setBackgroundColor("#fdfdfd")
    view.addModel(p.to_pdb(0), "pdb")
    view.setStyle({"stick": {"radius": 0.15}, "cartoon": {"color": "spectrum"}})
    
    view.zoomTo()
    view.center()
    view.zoom(1.2)
    view.show()
    
    print(f"Viewer Ready. Visualizing {len(c_clean)} atoms.")
except ImportError:
    print("py3Dmol not installed.")