## 8. Phase 1a: Real Partial Charges & VdW (Data Pipeline)

Now that the architecture is validated, we'll replace the *mock* physics with *real* physics.

**Goals:**
1.  Install `openmm` and `pdbfixer` to assign real force-field parameters (PDBFixer adds the hydrogens the force field needs).
2.  Update the data loader to extract per-atom partial charges, sigmas, and epsilons.
3.  Update the physics primitives to include Lennard-Jones (VdW) forces.
4.  Update the `AllAtomModel` and `PhysicsGNN` to use this new, richer edge information.

In [None]:
!pip install openmm pdbfixer --quiet

### 8.1. New Imports & Physics Feature Extractor

First, we'll import `openmm` and `pdbfixer` and write a helper function to extract all-atom physics parameters.

In [None]:
# New imports for this section
import openmm.app as app
import openmm
from pdbfixer import PDBFixer
import numpy as np

def get_all_atom_physics_features(atom_array: bsio.AtomArray) -> Dict[str, jnp.ndarray]:
    """Assign AMBER14 partial charges and Lennard-Jones parameters to heavy atoms.

    The structure is round-tripped through a temporary PDB file so PDBFixer
    can add hydrogens (needed for the AMBER residue templates to match and for
    correct charge assignment). An OpenMM ``System`` is then built with the
    AMBER14 force field, and the per-particle ``NonbondedForce`` parameters
    are read back. Only the atoms present in ``atom_array`` are returned, in
    the same order as ``atom_array``.

    Args:
        atom_array: Heavy-atom structure. Hydrogens are added internally and
            then dropped from the output.

    Returns:
        Dict with ``"charges"`` (elementary charges), ``"sigmas"`` (Angstrom,
        the same length unit as biotite coordinates) and ``"epsilons"``
        (kJ/mol). Each is an array of shape ``[len(atom_array), 1]``.

    Raises:
        ValueError: If the OpenMM system has no ``NonbondedForce``.
        RuntimeError: If an input atom cannot be matched in the OpenMM topology.
    """
    import os
    import tempfile

    from openmm import unit

    # 1. PDBFixer reads from a file. Write the structure to a unique temp file
    #    and always remove it, even if parsing fails. PDBFixer parses the whole
    #    file in its constructor, so the file is no longer needed afterwards.
    fd, tmp_pdb_path = tempfile.mkstemp(suffix=".pdb")
    os.close(fd)
    try:
        bsio.save_structure(tmp_pdb_path, atom_array)
        fixer = PDBFixer(filename=tmp_pdb_path)
    finally:
        os.remove(tmp_pdb_path)

    # 2. Add missing hydrogens at pH 7.0 (crucial for correct charge assignment).
    fixer.addMissingHydrogens(7.0)

    # 3. Load AMBER14. The bundled protein file is 'amber14-all.xml'.
    forcefield = app.ForceField('amber14-all.xml', 'amber14/tip3pfb.xml')

    # 4. Build an OpenMM System purely to read out the assigned parameters.
    system = forcefield.createSystem(fixer.topology, nonbondedMethod=app.NoCutoff)

    # 5. Locate the NonbondedForce that carries charge / sigma / epsilon.
    nonbonded_force = next(
        (f for f in system.getForces() if isinstance(f, openmm.NonbondedForce)),
        None,
    )
    if nonbonded_force is None:
        raise ValueError("Could not find NonbondedForce in OpenMM system.")

    # 6. Extract (charge, sigma, epsilon) for *all* particles (incl. H).
    #    Convert explicitly instead of reading `._value`: OpenMM stores sigma
    #    in nm, while biotite coordinates (and the LJ distances) are in Å.
    all_atom_params = np.array(
        [
            [
                charge.value_in_unit(unit.elementary_charge),
                sigma.value_in_unit(unit.angstrom),
                epsilon.value_in_unit(unit.kilojoule_per_mole),
            ]
            for charge, sigma, epsilon in (
                nonbonded_force.getParticleParameters(i)
                for i in range(system.getNumParticles())
            )
        ],
        dtype=float,
    ).reshape(-1, 3)

    # 7. Map each input heavy atom to its OpenMM topology index. This is a
    #    greedy forward scan (atom name, residue name, residue id). It relies
    #    on PDBFixer preserving the relative order of the original atoms; the
    #    chain ID is not compared.
    omm_atoms = list(fixer.topology.atoms())
    original_atom_indices = []
    cursor = 0
    for heavy_atom in atom_array:
        res_id = str(heavy_atom.res_id)  # OpenMM residue ids are strings
        while cursor < len(omm_atoms):
            omm_atom = omm_atoms[cursor]
            cursor += 1
            if (omm_atom.name == heavy_atom.atom_name
                    and omm_atom.residue.name == heavy_atom.res_name
                    and omm_atom.residue.id == res_id):
                original_atom_indices.append(cursor - 1)
                break
        else:
            # Topology exhausted without a match (the bounds check here also
            # avoids an IndexError on the next lookup).
            raise RuntimeError("Topology matching failed.")

    heavy_atom_params = all_atom_params[np.asarray(original_atom_indices, dtype=int)]

    return {
        "charges": jnp.array(heavy_atom_params[:, 0:1]),   # [N, 1], e
        "sigmas": jnp.array(heavy_atom_params[:, 1:2]),    # [N, 1], Å
        "epsilons": jnp.array(heavy_atom_params[:, 2:3]),  # [N, 1], kJ/mol
    }

### 8.2. Update Data Loader

Now we'll create `load_all_atom_data_v2` to use this feature extractor.

In [None]:
def load_all_atom_data_v2(pdb_id: str, d_gae: int, key: jax.random.PRNGKey):
    """Download a PDB entry and build heavy-atom positions, features and a mock target.

    Args:
        pdb_id: RCSB PDB identifier, e.g. ``"1UBQ"``.
        d_gae: Width of the (still mocked) per-atom GAE embedding ``z_i``.
        key: PRNG key used for the mock embeddings and the mock target.

    Returns:
        ``(positions, features, y_target)`` where ``positions`` is ``[N, 3]``
        (biotite coordinates, Å); ``features`` holds ``"z_vectors"``
        ``[N, d_gae]`` and ``"charges"`` / ``"sigmas"`` / ``"epsilons"``
        ``[N, 1]``; and ``y_target`` is a mock ``[N, 1]`` target.
    """
    import tempfile

    # --- 1. Download and parse PDB ---
    print(f"Downloading {pdb_id}...")
    pdb_path = rcsb.fetch(pdb_id, "pdb", target_path=tempfile.gettempdir())
    # `model=1` selects the first model and returns a plain AtomArray. Biotite
    # does not attach a per-atom `model_id` annotation, so the model must be
    # chosen at load time rather than filtered afterwards.
    atom_array = bsio.load_structure(pdb_path, model=1)

    # Keep polymer heavy atoms: drop HETATM records (waters, ligands) and any
    # hydrogen / deuterium atoms present in the deposited file.
    is_heavy = (atom_array.element != 'H') & (atom_array.element != 'D')
    atom_array = atom_array[~atom_array.hetero & is_heavy]

    positions = jnp.array(atom_array.coord)
    n_atoms = positions.shape[0]
    print(f"Loaded {n_atoms} non-hydrogen atoms.")

    # --- 2. Get REAL Physics Features ---
    print("Assigning real physics parameters with OpenMM...")
    physics_features = get_all_atom_physics_features(atom_array)
    print(f"Charges shape: {physics_features['charges'].shape}")
    print(f"Sigmas shape: {physics_features['sigmas'].shape}")
    print(f"Epsilons shape: {physics_features['epsilons'].shape}")

    # --- 3. Mock GAE Feature (z_i) ---
    # Still mocked: the Graph Autoencoder is the *next* step.
    key, z_key, target_key = jr.split(key, 3)
    mock_z_vectors = jr.normal(z_key, (n_atoms, d_gae))

    # Combine all features
    features = {
        "z_vectors": mock_z_vectors,                  # [N, d_gae]
        "charges": physics_features["charges"],       # [N, 1]
        "sigmas": physics_features["sigmas"],         # [N, 1]
        "epsilons": physics_features["epsilons"],     # [N, 1]
    }

    # --- 4. Mock Ground Truth Target ---
    y_target = jr.normal(target_key, (n_atoms, 1))

    return positions, features, y_target

# --- Smoke-test the v2 loader on ubiquitin (PDB 1UBQ) ---
# Draw a fresh sub-key so the loader's mock features are reproducible.
key, data_key_v2 = jr.split(key)
positions_v2, features_v2, y_target_v2 = load_all_atom_data_v2(
    "1UBQ",
    D_GAE,
    data_key_v2,
)

# Averaged over heavy atoms only: hydrogen charges are dropped by the loader.
print(f"\nMean charge: {jnp.mean(features_v2['charges']):.4f}")

### 8.3. Update Physics Primitives (Add VdW)

We need to add a Lennard-Jones (VdW) force calculation that uses the new per-atom sigma and epsilon features.

In [None]:
# `pairwise_coulomb_force` already exists from an earlier cell.
# Here we add the Lennard-Jones force in vector form, F = -nabla(U).
def pairwise_lennard_jones_force(r_ij, sigma_i, sigma_j, epsilon_i, epsilon_j, eps=1e-7):
    """Lennard-Jones force vector for one atom pair (Lorentz-Berthelot mixing).

    Args:
        r_ij: Displacement vector for the pair, shape ``[3]``. The returned
            force is parallel to ``r_ij`` when repulsive and anti-parallel
            when attractive.
        sigma_i, sigma_j: Per-atom LJ sigma, in the same length unit as ``r_ij``.
        epsilon_i, epsilon_j: Per-atom LJ well depth (expected non-negative).
        eps: Pairs with distance <= ``eps`` (e.g. a self-edge) are treated as
            degenerate and get a zero force. Otherwise ``(sigma/r)^12`` would
            overflow to inf and the direction ``0/0`` would produce NaN.

    Returns:
        Force vector with the same shape as ``r_ij``.
    """
    # 1. Lorentz-Berthelot combining rules.
    sigma = 0.5 * (sigma_i + sigma_j)
    epsilon = jnp.sqrt(epsilon_i * epsilon_j)

    # 2. Distance, with a harmless stand-in for degenerate pairs. The
    #    double-where keeps both the value and its gradient finite; the
    #    result for those pairs is masked out at the end.
    r_sq = jnp.sum(r_ij ** 2, axis=-1)
    is_pair = r_sq > eps ** 2
    r_mag = jnp.sqrt(jnp.where(is_pair, r_sq, 1.0))

    sig_over_r_6 = (sigma / r_mag) ** 6
    sig_over_r_12 = sig_over_r_6 ** 2

    # Force magnitude: F = 24 * epsilon * (2 * (sig/r)^12 - (sig/r)^6) / r
    # (positive = repulsive)
    force_mag = 24.0 * epsilon * (2.0 * sig_over_r_12 - sig_over_r_6) / r_mag

    # 3. Force vector along the unit displacement; zero for degenerate pairs.
    force_vec = force_mag * (r_ij / r_mag)
    return jnp.where(is_pair, force_vec, 0.0)

# Vectorize over graph edges: one displacement and one set of per-atom params per edge.
vmap_lj_force = jax.vmap(
    pairwise_lennard_jones_force,
    in_axes=(0, 0, 0, 0, 0)  # r_ij, s_i, s_j, e_i, e_j
)

### 8.4. Update Model Architecture (v2)

The `AllAtomModel` and `PhysicsGNN` must be updated to handle this new, richer edge information.

In [None]:
# The `RBF` and `PhysicsAttentionTransformer` classes are unchanged.
# We only need to update the GNN and the main Model class.

class PhysicsGNN_v2(eqx.Module):
    """Jraph-based 'combiner' GNN (v2) over force-vector edges.

    Each layer implements Decision 3 ("Benchmark Cheat"):
    1. Sum the 6-dim edge features arriving at every node (3 Coulomb + 3
       Lennard-Jones force components).
    2. Reduce each 3-vector to its norm, so the update only sees
       rotation-invariant quantities.
    3. Apply a residual MLP update to the node features.

    The per-layer MLPs are module fields, so they belong to the Equinox
    pytree. They are initialised once here, and their weights receive
    gradients and optimizer updates. Building them inside the node-update
    function would re-create them on every call and leave them untrainable.
    """
    node_mlps: List[eqx.nn.MLP]

    def __init__(self, d_model: int, d_hidden: int, n_layers: int, key: jax.random.PRNGKey):
        layer_keys = jr.split(key, n_layers)
        # Input is the node features plus one magnitude per force type (+2).
        self.node_mlps = [
            eqx.nn.MLP(
                in_size=d_model + 2,
                out_size=d_model,
                width_size=d_hidden,
                depth=1,
                key=layer_key,
            )
            for layer_key in layer_keys
        ]

    @staticmethod
    def _make_layer(mlp: eqx.nn.MLP):
        """Wrap one node-update MLP in a jraph GraphNetwork (edges/globals untouched)."""

        def update_node_fn(nodes: jnp.ndarray,
                           sent_edges: jnp.ndarray,
                           received_edges: jnp.ndarray,
                           globals_: jnp.ndarray) -> jnp.ndarray:
            """
            nodes: [n_nodes, d_model] (invariant features)
            received_edges: [n_nodes, 6] (sum of incoming forces: 3 for Coulomb, 3 for LJ)
            """
            # Convert equivariant forces to invariant magnitudes.
            coulomb_mag = jnp.linalg.norm(received_edges[..., :3], axis=-1, keepdims=True)
            lj_mag = jnp.linalg.norm(received_edges[..., 3:], axis=-1, keepdims=True)
            in_features = jnp.concatenate([nodes, coulomb_mag, lj_mag], axis=-1)
            # eqx.nn.MLP acts on a single vector, so map it over the nodes.
            return nodes + jax.vmap(mlp)(in_features)  # Residual connection

        return jraph.GraphNetwork(
            update_edge_fn=None,
            update_node_fn=update_node_fn,
            update_global_fn=None,
        )

    def __call__(self, graph: jraph.GraphsTuple) -> jraph.GraphsTuple:
        for mlp in self.node_mlps:
            graph = self._make_layer(mlp)(graph)
        return graph

class AllAtomModel_v2(eqx.Module):
    """Top-level v2 model operating on real force-field features.

    Pipeline:
    1. Transformer attention scores pick ``k`` neighbours per atom.
    2. Coulomb and Lennard-Jones force vectors are computed on those edges.
    3. ``PhysicsGNN_v2`` combines them into per-atom representations.
    4. A per-atom MLP head predicts one scalar per atom.
    """
    transformer: PhysicsAttentionTransformer
    gnn: PhysicsGNN_v2
    prediction_head: eqx.nn.MLP

    def __init__(self, d_gae: int, d_model: int, d_hidden: int, n_gnn_layers: int, key: jax.random.PRNGKey):
        # One sub-key per sub-module, in a fixed order: transformer, GNN, head.
        sub_keys = jr.split(key, 3)
        self.transformer = PhysicsAttentionTransformer(d_gae, d_model, sub_keys[0])
        self.gnn = PhysicsGNN_v2(d_model, d_hidden, n_gnn_layers, sub_keys[1])
        self.prediction_head = eqx.nn.MLP(
            in_size=d_model, out_size=1, width_size=d_hidden, depth=1, key=sub_keys[2]
        )

    def __call__(self, positions: jnp.ndarray, features: Dict, k: int) -> jnp.ndarray:
        num_atoms = positions.shape[0]

        # Dense geometry, then sparsify with the transformer's attention scores.
        displacements, distances = get_geometry(positions)
        logits, atom_embeddings = self.transformer(features["z_vectors"], distances)
        neighbour_idx = jax.lax.top_k(logits, k=k)[1]

        # Edge list: atom i -> each of its k selected neighbours (row-major).
        src = jnp.repeat(jnp.arange(num_atoms), k)
        dst = jnp.ravel(neighbour_idx)
        edge_disp = displacements[src, dst]

        def endpoints(name):
            """Gather a per-atom [N, 1] feature at both ends of every edge."""
            per_atom = jnp.ravel(features[name])
            return per_atom[src], per_atom[dst]

        q_src, q_dst = endpoints("charges")
        s_src, s_dst = endpoints("sigmas")
        e_src, e_dst = endpoints("epsilons")

        # Edge features: [N*k, 3] Coulomb followed by [N*k, 3] LJ -> [N*k, 6].
        coulomb = vmap_coulomb_force(edge_disp, q_src, q_dst)
        lennard_jones = vmap_lj_force(edge_disp, s_src, s_dst, e_src, e_dst)

        graph = jraph.GraphsTuple(
            nodes=atom_embeddings,
            edges=jnp.concatenate([coulomb, lennard_jones], axis=-1),
            senders=src,
            receivers=dst,
            n_node=jnp.array([num_atoms]),
            n_edge=jnp.array([src.shape[0]]),
            globals=None,
        )

        node_reps = self.gnn(graph).nodes
        return jax.vmap(self.prediction_head)(node_reps)

### 8.5. Re-run Benchmarks with Real Physics

You can now copy-paste your training cells from Benchmarks 1 & 2 (cells #5 and #6) to re-run the tests. Just make sure to:
1.  Use `AllAtomModel_v2` (it builds `PhysicsGNN_v2` internally).
2.  Use the `positions_v2`, `features_v2`, and `y_target_v2` loaded from the new function.

This checks that your architecture is *still* rotation-invariant when processing multiple real physics features: the loss curve for randomly rotated inputs should track the static one. Here is the code for the new training cell:

In [None]:
# --- Re-run Benchmarks 1 (Static) and 2 (Rotated) with Real Physics ---
print("=== Running Benchmarks with REAL Physics Features ===")
key, model_key_v2 = jr.split(key)
model_v2 = AllAtomModel_v2(D_GAE, D_MODEL, D_HIDDEN, N_GNN_LAYERS, model_key_v2)

# Optimizer: its state covers only the array leaves of the model.
optimizer_v2 = optax.adam(LEARNING_RATE)
opt_state_v2 = optimizer_v2.init(eqx.filter(model_v2, eqx.is_array))

# `loss_fn` is reused unchanged from the earlier benchmark cells. It is used
# here as returning `(loss, grads)` (e.g. wrapped with
# `eqx.filter_value_and_grad`). NOTE(review): confirm against its definition.
# The train steps are redefined so they use the v2 optimizer.
@eqx.filter_jit
def train_step_v2(model, opt_state, pos, feats, y_true):
    loss, grads = loss_fn(model, pos, feats, y_true)
    # Optax transformations expose `.update(grads, state, params)`.
    # `apply_updates` is the module-level `optax.apply_updates`, not a method.
    updates, opt_state = optimizer_v2.update(
        grads, opt_state, eqx.filter(model, eqx.is_array)
    )
    model = eqx.apply_updates(model, updates)
    return model, opt_state, loss

@eqx.filter_jit
def train_step_rotated_v2(model, opt_state, pos, feats, y_true, key):
    # Apply a fresh random rotation to the input coordinates every step.
    R = random_rotation_matrix(key)
    rotated_pos = (R @ pos.T).T
    loss, grads = loss_fn(model, rotated_pos, feats, y_true)
    updates, opt_state = optimizer_v2.update(
        grads, opt_state, eqx.filter(model, eqx.is_array)
    )
    model = eqx.apply_updates(model, updates)
    return model, opt_state, loss

# --- Loop 1: Static ---
print("Starting Benchmark 1 (Static, v2)...")
static_losses_v2 = []
static_model_v2 = model_v2
static_opt_state_v2 = opt_state_v2

for step in range(STEPS):
    static_model_v2, static_opt_state_v2, loss = train_step_v2(
        static_model_v2, static_opt_state_v2, positions_v2, features_v2, y_target_v2
    )
    static_losses_v2.append(loss)
print(f"Step {STEPS-1}, Final Loss: {static_losses_v2[-1]:.6f}")

# --- Loop 2: Rotated ---
print("\nStarting Benchmark 2 (Rotated, v2)...")
rotated_losses_v2 = []
rotated_model_v2 = model_v2  # Reset to the initial (immutable) model
rotated_opt_state_v2 = opt_state_v2
key, train_key_v2 = jr.split(key)

for step in range(STEPS):
    train_key_v2, step_key = jr.split(train_key_v2)
    rotated_model_v2, rotated_opt_state_v2, loss = train_step_rotated_v2(
        rotated_model_v2, rotated_opt_state_v2, positions_v2, features_v2, y_target_v2, step_key
    )
    rotated_losses_v2.append(loss)
print(f"Step {STEPS-1}, Final Loss: {rotated_losses_v2[-1]:.6f}")

# --- Plot v2 Results ---
plt.figure(figsize=(12, 6))
plt.title("v2 Overfitting Test Results (Real Physics)", fontsize=16)
plt.plot(static_losses_v2, label="Benchmark 1 (Static, Real Physics)")
plt.plot(rotated_losses_v2, label="Benchmark 2 (Rotated, Real Physics)", linestyle='--')
plt.xlabel("Training Step")
plt.ylabel("MSE Loss")
plt.yscale("log")
plt.legend()
plt.grid(True, which="both", ls="--", alpha=0.5)
plt.show()

### Next Steps (Phase 1b & 2)

After this, the next logical steps would be:
1.  **Phase 1b:** Build the Graph Autoencoder to replace the `mock_z_vectors`.
2.  **Phase 2:** Replace `y_target` with real pre-training targets from an APBS solver.