# PrxteinMPNN Training Test Notebook

This notebook tests all possible training paths for PrxteinMPNN:
1. **Models**: Autoregressive (AR) and Diffusion
2. **Features**: Vanilla (no physics features), Electrostatics, Van der Waals (vdW), Both (Electrostatics + vdW)

It runs a single training step on a synthetic batch for each of the 8 model/feature combinations to verify that the code executes without raising an error.

In [None]:
# Install dependencies (if running in Colab)
!pip install jax jaxlib equinox optax array_record msgpack msgpack-numpy
# Install the package in editable mode (assuming we are in the repo root)
%pip install -e .

In [None]:
import jax
import jax.numpy as jnp
import equinox as eqx
import optax
import numpy as np
from prxteinmpnn.model.diffusion_mpnn import DiffusionPrxteinMPNN
from prxteinmpnn.training.diffusion import NoiseSchedule
from prxteinmpnn.training.train_diffusion import train_step as train_step_diff
from prxteinmpnn.training.trainer import train_step as train_step_ar
from prxteinmpnn.model.mpnn import PrxteinMPNN
from prxteinmpnn.utils.types import TrainingMetrics

print("Imports successful!")

In [None]:
# Synthetic Data Generator
def get_batch(batch_size: int = 4, length: int = 50, feature_dim: int = 5):
    """Build a deterministic synthetic batch for smoke-testing the training steps.

    The batch is derived from a fixed seed, so repeated calls with the same
    arguments return identical arrays.

    Args:
        batch_size: Number of examples in the batch.
        length: Number of residues per example.
        feature_dim: Size of the last axis of the physics features.

    Returns:
        A tuple ``(coords, mask, residue_index, chain_index, sequence,
        phys_features)`` where:

        * ``coords`` -- standard-normal samples, shape ``(batch_size, length, 4, 3)``.
        * ``mask`` -- all ones, shape ``(batch_size, length)``.
        * ``residue_index`` -- ``0 .. length - 1`` repeated for every example.
        * ``chain_index`` -- all zeros (``int32``), shape ``(batch_size, length)``.
        * ``sequence`` -- random integers in ``[0, 21)``, shape ``(batch_size, length)``.
        * ``phys_features`` -- standard-normal samples, shape
          ``(batch_size, length, feature_dim)``.
    """
    # JAX keys must not be reused: drawing several tensors from the same key
    # yields correlated (often identical) values. Derive one subkey per draw.
    coords_key, sequence_key, phys_key = jax.random.split(jax.random.PRNGKey(0), 3)

    coords = jax.random.normal(coords_key, (batch_size, length, 4, 3))
    mask = jnp.ones((batch_size, length))
    residue_index = jnp.tile(jnp.arange(length), (batch_size, 1))
    chain_index = jnp.zeros((batch_size, length), dtype=jnp.int32)
    sequence = jax.random.randint(sequence_key, (batch_size, length), 0, 21)

    # Physics features: [B, N, Dim]
    # For 'Both', dim would be 10
    phys_features = jax.random.normal(phys_key, (batch_size, length, feature_dim))

    return coords, mask, residue_index, chain_index, sequence, phys_features

In [None]:
def run_test(model_type, feature_type):
    """Run one training step for a model/feature combination and report the result.

    A small model is built, a synthetic batch is fetched from ``get_batch``,
    and a single training step is executed. The outcome is printed as a
    ``SUCCESS`` line (with the loss) or a ``FAILED`` line followed by a
    traceback; a failing training step does not propagate, so a sweep over
    several configurations keeps going.

    Args:
        model_type: ``"diffusion"`` selects ``DiffusionPrxteinMPNN`` and the
            diffusion training step; any other value selects the
            autoregressive ``PrxteinMPNN`` path.
        feature_type: One of ``"vanilla"``, ``"electrostatics"``, ``"vdw"``
            or ``"both"``.

    Returns:
        ``True`` if the training step completed, ``False`` if it raised.

    Raises:
        ValueError: If ``feature_type`` is not one of the supported names.
    """
    print(f"Testing {model_type} with {feature_type} features...")

    # Width of the physics feature axis per feature set; None means the model
    # is built and trained without physics features.
    physics_dims = {
        "vanilla": None,
        "electrostatics": 5,
        "vdw": 5,
        "both": 10,
    }
    if feature_type not in physics_dims:
        # Fail with a clear message instead of an unbound-variable error later.
        raise ValueError(
            f"Unknown feature_type {feature_type!r}; "
            f"expected one of {sorted(physics_dims)}"
        )
    phys_dim = physics_dims[feature_type]

    # Config
    node_features = 64
    edge_features = 64
    hidden_features = 64
    num_layers = 2

    # Separate keys for parameter initialisation and the training step so the
    # two sources of randomness are independent.
    init_key, step_key = jax.random.split(jax.random.PRNGKey(42))

    # Model Init: both model classes are constructed with the same arguments.
    is_diffusion = model_type == "diffusion"
    model_cls = DiffusionPrxteinMPNN if is_diffusion else PrxteinMPNN
    model = model_cls(
        node_features=node_features,
        edge_features=edge_features,
        hidden_features=hidden_features,
        num_encoder_layers=num_layers,
        num_decoder_layers=num_layers,
        k_neighbors=10,
        physics_feature_dim=phys_dim,
        key=init_key,
    )

    optimizer = optax.adam(1e-4)
    opt_state = optimizer.init(eqx.filter(model, eqx.is_inexact_array))

    # Data: a single batch supplies both the structure inputs and, when
    # requested, physics features of the matching width.
    if phys_dim is None:
        coords, mask, res_idx, chain_idx, seq, _ = get_batch()
        phys_input = None
    else:
        coords, mask, res_idx, chain_idx, seq, phys_input = get_batch(
            feature_dim=phys_dim
        )

    lr_schedule = optax.constant_schedule(1e-4)

    # Train Step
    try:
        if is_diffusion:
            noise_schedule = NoiseSchedule(num_steps=10)

            model, opt_state, metrics = train_step_diff(
                model,
                opt_state,
                optimizer,
                coords,
                mask,
                res_idx,
                chain_idx,
                seq,
                step_key,
                noise_schedule,
                lr_schedule,
                0,
                physics_features=phys_input,
                # Only perturb physics features when they are actually supplied.
                physics_noise_scale=0.5 if phys_input is not None else 0.0,
            )
        else:  # AR
            model, opt_state, metrics = train_step_ar(
                model,
                opt_state,
                optimizer,
                coords,
                mask,
                res_idx,
                chain_idx,
                seq,
                step_key,
                0.1,  # label_smoothing
                0,  # current_step
                lr_schedule,
                physics_features=phys_input,
            )

        print(f"SUCCESS: {model_type} + {feature_type}. Loss: {metrics.loss:.4f}")
        return True

    except Exception as e:
        # Deliberately broad: this is the top-level boundary of a smoke test,
        # and the full traceback is printed so nothing is hidden.
        print(f"FAILED: {model_type} + {feature_type}. Error: {e}")
        import traceback
        traceback.print_exc()
        return False

In [None]:
# Run every model/feature combination (AR first, then diffusion).
modes = ["ar", "diffusion"]
features = ["vanilla", "electrostatics", "vdw", "both"]

# Enumerate the full grid up front, mode-major, then execute each entry in order.
configurations = [
    (model_type, feature_type)
    for model_type in modes
    for feature_type in features
]
for model_type, feature_type in configurations:
    run_test(model_type, feature_type)