In [1]:
%load_ext autoreload
%autoreload 2

In [1]:
import jax
import jax.numpy as jnp
import numpy as np
from typing import Any
from jax import config
print(f"JAX version: {jax.__version__}")
print(f"JAX backend: {jax.default_backend()}")
print(f"Available devices: {jax.devices()}")

JAX version: 0.7.0
JAX backend: cpu
Available devices: [CpuDevice(id=0)]


In [2]:
from prxteinmpnn.io import from_structure_file

print("Loading protein structure...")
protein_structure = from_structure_file(file_path="AF-P10599-F1-model_v4.pdb") # chain and model can also be specified

print("✓ Structure loaded successfully")
print("Note: In practice, use from_structure_file() with your PDB file")

Loading protein structure...


  coords_37 = jnp.zeros((num_residues, 37, 3), dtype=jnp.float64)


✓ Structure loaded successfully
Note: In practice, use from_structure_file() with your PDB file


In [3]:
from prxteinmpnn.mpnn import ProteinMPNNModelVersion, ModelWeights, get_mpnn_model

print("Loading ProteinMPNN model...")

model_version = ProteinMPNNModelVersion.V_48_020
model_weights = ModelWeights.DEFAULT

model = get_mpnn_model(
    model_version=model_version,
    model_weights=model_weights,
)

print(f"✓ Loaded ProteinMPNN model version: {model_version.value}")
print(f"✓ Using model weights: {model_weights.value}")

Loading ProteinMPNN model...
✓ Loaded ProteinMPNN model version: v_48_020.pkl
✓ Using model weights: original


In [4]:
# Convert structure to model inputs
print("Preparing model inputs...")
from prxteinmpnn.io import protein_structure_to_model_inputs

# In practice, you would use:
model_inputs = protein_structure_to_model_inputs(protein_structure)

# For this example, we'll create mock model inputs
print("✓ Model inputs prepared")
print("Note: In practice, use protein_structure_to_model_inputs()")

# Mock some basic parameters for demonstration

print(f"Sequence length: {len(model_inputs.sequence)}")
print(f"Sequence: {model_inputs.sequence[:20]}...")

Preparing model inputs...
✓ Model inputs prepared
Note: In practice, use protein_structure_to_model_inputs()
Sequence length: 105
Sequence: [10 17  8 13  7  3 15  8 16  0  4 13  3  0  9  2  0  0  5  2]...


In [5]:
from prxteinmpnn.sampling import make_sample_sequences, SamplingConfig, SamplingEnum
from prxteinmpnn.utils.decoding_order import random_decoding_order
from prxteinmpnn.utils.residue_constants import order_aa
from prxteinmpnn.utils.aa_convert import mpnn_to_af
import optax
import numpy as np

prng_sampling_key, key = jax.random.split(jax.random.key(0), 2) 


# Demonstrate different sampling strategies
sampling_strategies = [
    (SamplingEnum.TEMPERATURE, "Temperature sampling"),
    (SamplingEnum.STRAIGHT_THROUGH, "Straight-through estimator sampling"),
]

print("Available sampling strategies:")
for strategy, description in sampling_strategies:
    print(f"  - {strategy}: {description}")


print("Setting up sequence sampling...")


print("Temperature based sampling")

temp_config = SamplingConfig(
    sampling_strategy=SamplingEnum.TEMPERATURE,
    temperature=0.1,
)

sample_seq_temp = make_sample_sequences(
    model,
    random_decoding_order,
    config=temp_config,
    model_inputs=model_inputs,
)


Available sampling strategies:
  - SamplingEnum.TEMPERATURE: Temperature sampling
  - SamplingEnum.STRAIGHT_THROUGH: Straight-through estimator sampling
Setting up sequence sampling...
Temperature based sampling


In [17]:
from prxteinmpnn.sampling import make_sample_sequences, SamplingConfig, SamplingEnum
from prxteinmpnn.utils.decoding_order import random_decoding_order
from prxteinmpnn.utils.residue_constants import order_aa
from prxteinmpnn.utils.aa_convert import mpnn_to_af
import optax
import numpy as np

prng_sampling_key, key = jax.random.split(jax.random.key(0), 2) 


# Demonstrate different sampling strategies
sampling_strategies = [
    (SamplingEnum.TEMPERATURE, "Temperature sampling"),
    (SamplingEnum.STRAIGHT_THROUGH, "Straight-through estimator sampling"),
]

print("Available sampling strategies:")
for strategy, description in sampling_strategies:
    print(f"  - {strategy}: {description}")


print("Setting up sequence sampling...")


print("STE based sampling")

ste_config = SamplingConfig(
    sampling_strategy=SamplingEnum.STRAIGHT_THROUGH,
    learning_rate=0.5,
    iterations=100,
)

sample_seq_ste = make_sample_sequences(
    model,
    random_decoding_order,
    config=ste_config,
    model_inputs=model_inputs,
)


Available sampling strategies:
  - SamplingEnum.TEMPERATURE: Temperature sampling
  - SamplingEnum.STRAIGHT_THROUGH: Straight-through estimator sampling
Setting up sequence sampling...
STE based sampling


In [18]:
sampled_ste_seqs, logits_ste, decoding_order_ste = jax.lax.map(
  lambda x: sample_seq_ste(prng_key=x),
  jax.random.split(prng_sampling_key, 10),
  batch_size=10,
  )
sampled_sequence_temp, logits_temp, decoding_order_temp = jax.lax.map(
  lambda x: sample_seq_temp(prng_key=x),
  jax.random.split(prng_sampling_key, 100),
  batch_size=100,
  )

In [19]:
def seq_id(seq, ref_seq):
    return sum(seq == ref_seq) / len(ref_seq)
from functools import partial
seq_id_to_original = partial(seq_id, ref_seq=model_inputs.sequence)

In [20]:
from prxteinmpnn.scoring.score import make_score_sequence
from prxteinmpnn.utils.decoding_order import random_decoding_order
from prxteinmpnn.io import string_to_protein_sequence


key = jax.random.PRNGKey(42)
print("Setting up sequence scoring...")

score_sequence = make_score_sequence(
    model,
    random_decoding_order,
    model_inputs=model_inputs,
)

print("✓ Scoring function created")

same_key_score = partial(score_sequence, key)

scores, _, _ = jax.lax.map(same_key_score, sampled_ste_seqs, batch_size=10)
print("✓ Scored sequences successfully")
temp_scores, _, _ = jax.lax.map(same_key_score, sampled_sequence_temp, batch_size=100)


Setting up sequence scoring...
✓ Scoring function created
✓ Scored sequences successfully


In [21]:
ste_seq_ids = jax.lax.map(seq_id_to_original, sampled_ste_seqs, batch_size=100)
temp_seq_ids = jax.lax.map(seq_id_to_original, sampled_sequence_temp, batch_size=100)
print("Sampled STE sequences score average:", scores.mean())
print("Sampled STE sequences score standard deviaion:", scores.std())
print("Sampled STE sequences score maximum:", scores.max())
print("Sampled STE sequences score minimum:", scores.min())
print("Sampled STE sequences average sequence identity to original:", ste_seq_ids.mean())
print("Sampled Temperature sequences score average:", temp_scores.mean())
print("Sampled Temperature sequences score standard deviation:", temp_scores.std())
print("Sampled Temperature sequences score maximum:", temp_scores.max())
print("Sampled Temperature sequences score minimum:", temp_scores.min())
print("Sampled Temperature sequences average sequence identity to original:", temp_seq_ids.mean())
print("Original sequence score: ", same_key_score(model_inputs.sequence)[0])


Sampled STE sequences score average: 1.0886233
Sampled STE sequences score standard deviaion: 0.025988365
Sampled STE sequences score maximum: 1.1270864
Sampled STE sequences score minimum: 1.0274435
Sampled STE sequences average sequence identity to original: 0.6
Sampled Temperature sequences score average: 1.150066
Sampled Temperature sequences score standard deviation: 0.028019372
Sampled Temperature sequences score maximum: 1.2120124
Sampled Temperature sequences score minimum: 1.0904613
Sampled Temperature sequences average sequence identity to original: 0.6038096
Original sequence score:  1.5782065


In [9]:
from prxteinmpnn.utils.aa_convert import mpnn_to_af
ste_sampled_sequences = [mpnn_to_af(seq) for seq in sampled_ste_seqs]
ste_sampled_sequences_str = [''.join([order_aa.get(aa.item(), 'UNK') for aa in seq]) for seq in ste_sampled_sequences]
temp_sampled_sequences = [mpnn_to_af(seq) for seq in sampled_sequence_temp]
temp_sampled_sequences_str = [''.join([order_aa.get(aa.item(), 'UNK') for aa in seq]) for seq in temp_sampled_sequences]

In [11]:
cbdesign_samples = ['MVIEVTSLEEYEELLKNAGDKLVVVDFYAPWCGPCKKIKPHFEKLSEKYKDVVFLKVDVNKCPEIAKKEGVTSTPTFVFYKNGKKVDSFSGADKEKLEKKIEELK',
        'MVIEINSLEEFEKALKDAGDKLVVVDFYAPWCGPCKKIKPFFEKLSEKYKDVVFLKVDVNKCPEIAKKMGVKATPTFKFFKNGKLVDSFVGANEKKLEEKIKKLS',
        'MVIEVNSKEEYEELLKNAGDKLVVVDFYAPWCGPCKKIKPYFEKLSEKYKDVIFLKVDVKKCPEIAKEEGVTSTPTFLFFKNGKKVASFSGADKEKLEATIEKLK',
        'MVTEINSLEEFEEALKNAGDKLVVIDFYAKWCGPCKKIKPFFEKLSEEYKDVVFLKVDVEKCPEVAKKLGVKSTPTFVFFKNGKKVDSFSGADEEKLKKKIEELS',
        'MVTKITSLEEFEEALKNAGDKLVVVDFYAKWCGPCKKIKPYFEKLSEEYKDVVFLEVDVDECPEIAKKEGVTATPTFKFFKNGKLVDSFSGADKEKLEKKIEELK',
        'MVKEINSLEEFKAALKAAGDKLVVVDFYAPWCGPCKKIEPYFEELSEKYPDVVFLKVDVNKCPEVAKELGVKSTPTFVFFKNGEKVGSFSGADKEKLEKTIKELS',
        'MVKEINSLEEFEAALKAAGDKLVVVDFYAKWCGPCKKIKPYFEELSEKYKDVVFLKVDVDKCPEIAKKEGVTSTPTFVFYKNGKKVDSFSGADKEKLEKKIEELK',
        'MVTEITSRAEFEAALAAAGDRLVVVDFYAAWCGPCKEIEPHFEALSERYPDVVFLKVDVDACPEVAAACGVTSTPTFLFFRRGELVDRFSGADKEKLEATIERLR',
        'MVTEVNSLEEFEALLAAAGDKLVVVDFYAPWCGPCKKIAPHFEALSERYPDVVFLKVDVDKCPEIAKRCGVKSTPTFLFFKNGELVDRFSGADKAKLEARIEELR',
        'MVTEITSKEEFKKALEDAGDKLVVVDFYAPWCGPCKKIKPFFEELSEKYPDVVFLKVDVNKCPEVAKEMGVTATPTFKFFKNGKLVDSFSGADKKKLEERIKKLS']

cbdesign_sequences = jnp.array([string_to_protein_sequence(seq) for seq in cbdesign_samples])
scores_cbdesign, _, _ = jax.lax.map(same_key_score, cbdesign_sequences, batch_size=10)
print("CBDesign sequences score average:", scores_cbdesign.mean())

CBDesign sequences score average: 1.1324364


In [23]:
print(bin(0x00))
print(bin(0x80))

0b0
0b10000000


In [None]:
print(bin(0x0DC6 ^ 0x0E46))

0b1


In [34]:
0x0E0F

3599

In [37]:
bin(0x8C)

'0b10001100'

In [45]:
0xAC

172

In [46]:
0xD4

212