# PrxteinMPNN Example Notebook

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/maraxen/PrxteinMPNN/blob/main/examples/example_notebook.ipynb)

Welcome to the PrxteinMPNN example notebook! This notebook demonstrates the core functionality of PrxteinMPNN, a functional interface for ProteinMPNN built with JAX.

## What you'll learn:
- How to load and work with protein structures
- How to score sequences using ProteinMPNN
- How to sample new sequences
- How to leverage JAX transformations for performance

In [1]:
# Install PrxteinMPNN and restart kernel
!git clone https://github.com/maraxen/PrxteinMPNN.git
%cd PrxteinMPNN
!pip install -e .

Cloning into 'PrxteinMPNN'...
remote: Enumerating objects: 1125, done.[K
remote: Counting objects: 100% (616/616), done.[K
remote: Compressing objects: 100% (325/325), done.[K
remote: Total 1125 (delta 263), reused 530 (delta 220), pack-reused 509 (from 1)[K
Receiving objects: 100% (1125/1125), 49.21 MiB | 18.62 MiB/s, done.
Resolving deltas: 100% (551/551), done.
/content/PrxteinMPNN
Obtaining file:///content/PrxteinMPNN
  Installing build dependencies ... [?25l[?25hdone
  Checking if build backend supports build_editable ... [?25l[?25hdone
  Getting requirements to build editable ... [?25l[?25hdone
  Preparing editable metadata (pyproject.toml) ... [?25l[?25hdone
Collecting jaxtyping (from prxteinmpnn==0.1.0)
  Downloading jaxtyping-0.3.2-py3-none-any.whl.metadata (7.0 kB)
Collecting biotite (from prxteinmpnn==0.1.0)
  Downloading biotite-1.4.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (5.5 kB)
Collecting foldcomp (from prxteinmpnn==0.1.0)
  Down

## 1. Setup and Imports

First, let's import the necessary libraries and modules from PrxteinMPNN.

In [5]:
import jax
import jax.numpy as jnp
import numpy as np
from typing import Any

print(f"JAX version: {jax.__version__}")
print(f"JAX backend: {jax.default_backend()}")
print(f"Available devices: {jax.devices()}")

JAX version: 0.5.2
JAX backend: cpu
Available devices: [CpuDevice(id=0)]


## 2. Load a Protein Structure

We'll start by loading a protein structure. For this example, you can upload a protein structure to colab. You can use any protein structure file, trajectory and template file (with from_trajectory), pdb string (from pdb_string), or foldcomp accession (see prxteinmpnn.utils.foldcomp_utils).

In [1]:
from prxteinmpnn.io import from_structure_file

print("Loading protein structure...")
protein_structure = from_structure_file(file_path="/content/PrxteinMPNN/examples/AF-P10599-F1-model_v4.pdb") # chain and model can also be specified

print("✓ Structure loaded successfully")
print("Note: In practice, use from_structure_file() with your PDB file")

Loading protein structure...
✓ Structure loaded successfully
Note: In practice, use from_structure_file() with your PDB file


## 3. Load the ProteinMPNN Model

Now let's load the ProteinMPNN model with the desired version and weights.

Enums specify the model version and model weights (original or soluble).

In [2]:
from prxteinmpnn.mpnn import ProteinMPNNModelVersion, ModelWeights, get_mpnn_model

print("Loading ProteinMPNN model...")

model_version = ProteinMPNNModelVersion.V_48_020
model_weights = ModelWeights.DEFAULT

model = get_mpnn_model(
    model_version=model_version,
    model_weights=model_weights,
)

print(f"✓ Loaded ProteinMPNN model version: {model_version.value}")
print(f"✓ Using model weights: {model_weights.value}")

Loading ProteinMPNN model...
✓ Loaded ProteinMPNN model version: v_48_020.pkl
✓ Using model weights: original


## 4. Prepare Model Inputs

Convert the protein structure to model inputs that can be used with ProteinMPNN.

In [3]:
# Convert structure to model inputs
print("Preparing model inputs...")
from prxteinmpnn.io import protein_structure_to_model_inputs

# In practice, you would use:
model_inputs = protein_structure_to_model_inputs(protein_structure)

# For this example, we'll create mock model inputs
print("✓ Model inputs prepared")
print("Note: In practice, use protein_structure_to_model_inputs()")

# Mock some basic parameters for demonstration

print(f"Sequence length: {len(model_inputs.sequence)}")
print(f"Sequence: {model_inputs.sequence[:20]}...")

Preparing model inputs...
✓ Model inputs prepared
Note: In practice, use protein_structure_to_model_inputs()
Sequence length: 105
Sequence: [10 17  8 13  7  3 15  8 16  0  4 13  3  0  9  2  0  0  5  2]...


## 5. Sequence Scoring

Let's score sequences using the ProteinMPNN model. This demonstrates how to evaluate the likelihood of specific amino acid sequences for a given structure.

Random-number generator keys for stochastic aspects are explicitly defined.

JAX enables

In [6]:
from prxteinmpnn.scoring.score import make_score_sequence
from prxteinmpnn.utils.decoding_order import random_decoding_order
from prxteinmpnn.io import string_to_protein_sequence


# Initialize random key for JAX operations
key = jax.random.PRNGKey(42)
print("Setting up sequence scoring...")

# Create scoring function
score_sequence = make_score_sequence(
    model,
    random_decoding_order,
    model_inputs=model_inputs,
)

print("✓ Scoring function created")

# Example of scoring multiple sequences
sequences_to_score = [
    "MKFLVNVALVFMVVYISYIYAAIYIQASLLVASVGGTLIPALYQFAIWIIKMKFLVNVALVFMVVYISYIYAAIYIQASLLVASVGGTLIPALYQFAIWIIKIJK",
    "VNVALVFMVVYISYIYAAIYIQASLLVASVGGTLIPALYQFAIWIIKIJKMKFLVNVALVFMVVYISYIYAAIYIQASLLVASVGGTLIPALYQFAIWIIKMKFL",
    "VNVALVFMVVYISYIYAAIYIQASLLVASVGGTLIPALYQFAIWIIKMKFLVNVALVFMVVYISYIYAAIYIQASLLVASVGGTLIPALYQFAIWIIKIJKMKFL",
]

# Prepare sequences for scoring
sequences_to_score = jnp.array([string_to_protein_sequence(seq) for seq in sequences_to_score])

print(f"Prepared {len(sequences_to_score)} sequences for scoring")

print("✓ Prepared sequences successfully")
print("Scoring sequences with same key for decoding order...")
# Score the prepared sequences
scores, logits, decoding_orders = jax.vmap(score_sequence, in_axes=(None, 0))(key, sequences_to_score) # type: ignore[arg-type]
# using None for in_axes to keep the key fixed for all sequences
# you could also use jax.random.split to create a new key for each sequence if needed
print("✓ Scored sequences successfully")


Setting up sequence scoring...
✓ Scoring function created
Prepared 3 sequences for scoring
✓ Prepared sequences successfully
Scoring sequences with same key for decoding order...
✓ Scored sequences successfully


In [7]:
print(f"Scores: {scores}")


Scores: [3.3719664 3.4745483 3.4547827]


### Massively parallel scoring with memory efficiency using jax.lax.map and batch size

Often, you may have memory limitations and jax will throw out of memory errors.

jax.lax.map offers a convenient way to automatically vmap batches of inputs, effectively balancing efficiency and memory limitations.

In [8]:

# Example of scoring with batching
num_sequences = 128#_000_000
batch_size = 64

from jaxtyping import PRNGKeyArray
from prxteinmpnn.utils.types import ProteinSequence
from functools import partial

@partial(jax.jit, static_argnames=("mutation_rate",))
def mutate(key: PRNGKeyArray, input_sequence: ProteinSequence, mutation_rate: float=0.25) -> ProteinSequence:
    """Mutate a protein sequence by randomly changing some amino acids.

    See github.com/maraxen/proteinsmc for more advanced mutation strategies and sampling methods.

    """
    # Randomly mutate the input sequence
    mutation_indices = jax.random.choice(key, jnp.arange(len(input_sequence)), shape=(int(len(input_sequence) * mutation_rate),), replace=False)
    return input_sequence.at[mutation_indices].set(jax.random.choice(key, jnp.arange(0, 20),shape=mutation_indices.shape))

sequences_to_score = jax.lax.map(
    lambda key: mutate(key, model_inputs.sequence, 0.25),
    jax.random.split(key, num_sequences),
    batch_size=batch_size,
)

scores, logits, scoring_decoding_orders = jax.lax.map(
    lambda seq: score_sequence(key, seq),  # type: ignore[arg-type]
    sequences_to_score,
    batch_size=batch_size,
)




In [9]:
scores[:5]

Array([2.2299085, 2.2776206, 2.1614697, 2.2135031, 2.2152264], dtype=float32)

## 6. Sequence Sampling

Now let's sample new sequences using different sampling strategies. This is the core functionality for protein design.

In [None]:
from prxteinmpnn.sampling import make_sample_sequences, SamplingEnum


# Demonstrate different sampling strategies
sampling_strategies = [
    (SamplingEnum.TEMPERATURE, "Temperature sampling"),
    (SamplingEnum.STRAIGHT_THROUGH, "Straight-through estimator sampling"),
]

print("Available sampling strategies:")
for strategy, description in sampling_strategies:
    print(f"  - {strategy}: {description}")


print("Setting up sequence sampling...")

sample_sequence = make_sample_sequences(
    model,
    random_decoding_order,
    sampling_strategy=SamplingEnum.STRAIGHT_THROUGH,
    model_inputs=model_inputs,
)

print("✓ Sampling function created")

print("Sampling sequences with straight-through estimator...")

# Example of sampling sequences
prng_sampling_key, key = jax.random.split(key, 2)
hyperparameters = (0.01,) # Example hyperparameters, in this case the learning rate for the straight-through estimator
iterations = 100  # Number of sequences to sample
sampled_sequence, logits, decoding_order = sample_sequence(prng_sampling_key, hyperparameters, iterations)  # type: ignore[arg-type]

from prxteinmpnn.utils.residue_constants import order_aa
# Display sampled sequences
print(f"Sampled sequence {''.join([order_aa[aa] for aa in sampled_sequence])} with hyperparameters {hyperparameters} after {iterations} iterations")


## 7. Summary and Next Steps

Congratulations! You've successfully explored the core functionality of PrxteinMPNN. Here's what we covered:

### What we learned:
- ✅ Loading protein structures and preparing model inputs
- ✅ Scoring sequences with ProteinMPNN  
- ✅ Sampling new sequences with different strategies
- ✅ Leveraging JAX transformations for performance
- ✅ Advanced sampling configurations and iterative refinement

### Next steps:
1. **Try with real data**: Replace the mock examples with actual protein structures
2. **Explore the API**: Check out the [full documentation](http://maraxen.github.io/PrxteinMPNN)
3. **Performance optimization**: Experiment with different JAX transformations
4. **Custom workflows**: Build your own protein design pipelines

### Key advantages of PrxteinMPNN:
- **🔍 Transparency**: Clear, functional interface to understand model operations
- **⚡ Performance**: JAX-powered acceleration with JIT, vmap, and scan
- **🧩 Modularity**: Easy to extend and customize for specific use cases
- **🔄 Compatibility**: Seamless integration with the JAX ecosystem

## Resources

- **Documentation**: [http://maraxen.github.io/PrxteinMPNN](http://maraxen.github.io/PrxteinMPNN)
- **GitHub Repository**: [https://github.com/maraxen/PrxteinMPNN](https://github.com/maraxen/PrxteinMPNN)
- **Issues & Support**: [GitHub Issues](https://github.com/maraxen/PrxteinMPNN/issues)

Happy protein designing! 🧬✨

# Addenda: Comparison to colabdesign (WIP)

Stay tuned for the comparison to colabdesign showing parity.