## Imports and jax caching

In [1]:
%reload_ext autoreload
%autoreload 2

from mber_protocols.stable.VHH_binder_design.config import ModelConfig, LossConfig, TrajectoryConfig, EnvironmentConfig, TemplateConfig, EvaluationConfig
from mber_protocols.stable.VHH_binder_design.template import TemplateModule
from mber_protocols.stable.VHH_binder_design.trajectory import TrajectoryModule
from mber_protocols.stable.VHH_binder_design.evaluation import EvaluationModule
from mber_protocols.stable.VHH_binder_design.state import DesignState, TemplateData

import jax
import os
# Point JAX's compilation cache at a directory under the user's home;
# model loading is sped up a lot by caching.
JAX_CACHE_DIR = os.path.expanduser("~/.jax/jax_cache")
jax.config.update("jax_compilation_cache_dir", JAX_CACHE_DIR)

  from pkg_resources import resource_filename


## Set up design state and configs

In [2]:
# One configuration object per pipeline stage, each built with no arguments
# (library defaults).
template_config = TemplateConfig()
model_config = ModelConfig()
loss_config = LossConfig()
trajectory_config = TrajectoryConfig()
evaluation_config = EvaluationConfig()

# Environment settings shared by every module below.
# NOTE(review): af_params_dir contains '~' and is not expanded here --
# confirm that EnvironmentConfig (or its consumers) expands the home directory.
environment_config = EnvironmentConfig(af_params_dir='~/.mber/af_params', device='cuda:0')

# Description of the design problem: target, region, hotspots and the binder
# framework whose '*' positions are the masked CDRs.
pdl1_template_data = TemplateData(
    target_id="Q9NZQ7",  # Uniprot ID for PDL1
    target_name="PDL1",
    region="A:18-132",  # Specific region to consider
    target_hotspot_residues="A54,A56,A66,A115",  # Hotspots given explicitly
    # Example VHH framework with masked CDRs
    masked_binder_seq="EVQLVESGGGLVQPGGSLRLSCAASG*********WFRQAPGKEREF***********NADSVKGRFTISRDNAKNTLYLQMNSLRAEDTAVYYC************WGQGTLVTVSS",
)

# The design state is the single object handed from module to module.
design_state = DesignState(template_data=pdl1_template_data)

## Set up and run template module

In [3]:
# Template stage, built from the template and environment configs with
# verbose output switched on.
template_module = TemplateModule(
    template_config=template_config,
    environment_config=environment_config,
    verbose=True,
)

In [4]:
# Set up the template stage. This stays outside the try block so a failed
# setup is reported as-is and teardown is not called on a module that never
# finished setting up.
template_module.setup(design_state)
try:
    # run() returns the updated design state, which replaces the input state.
    design_state = template_module.run(design_state)
finally:
    # Always tear down, even when run() raises, so whatever setup() loaded is
    # not left behind in the live kernel (presumably this releases the ESM2 and
    # folding models initialized on the device -- confirm in TemplateModule.teardown).
    template_module.teardown(design_state)

[TemplateModule] TemplateModule initialized with logging
[TemplateModule] Initializing ESM2 model on cuda:0
[TemplateModule] Initialize ESM2 model took 1.91 seconds
[TemplateModule] Initializing template folding model: nbb2
[TemplateModule] Initialize folding model took 0.97 seconds
[TemplateModule] TemplateModule._setup_models completed in 2.89 seconds
[TemplateModule] TemplateModule._ensure_template_data completed in 0.00 seconds
[TemplateModule] TemplateModule.setup completed in 2.89 seconds
[TemplateModule] Starting template preparation
[TemplateModule] TemplateModule._process_target completed in 0.48 seconds
[TemplateModule] Using provided hotspots: A54,A56,A66,A115
[TemplateModule] TemplateModule._process_hotspots completed in 0.00 seconds
[TemplateModule] Creating truncation of target structure
[TemplateModule] TemplateModule._create_truncation completed in 0.05 seconds
[TemplateModule] Processing masked binder sequence
[TemplateModule] Generating position-specific bias
[Templat

In [5]:
# Inspect the design state after the template stage: template_data is filled in,
# while the trajectory_data fields are still None.
display(design_state)

DesignState:
  template_data: TemplateData:
      target_id: Q9NZQ7
      target_name: PDL1
      region: A:18-132
      target_hotspot_residues: A54,A56,A66,A115
      masked_binder_seq: EVQLVESGGGLVQPG...****WGQGTLVTVSS (118 chars)
      include_surrounding_context: False
      target_source: None
      template_pdb: [PDB data: 2702 lines]
      full_target_pdb: [PDB data: 922 lines]
      target_chain: A
      binder_chain: H
      binder_len: 118
      binder_seq: EVQLVESGGGLVQPG...YYDYWGQGTLVTVSS (118 chars)
      binder_bias: array(shape=(118, 20), dtype=float64)
      template_preparation_complete: True
      logs: [44 items]
      timings: {'timings': {'Initialize folding model': 0.974374532699585}, 'setup': 2.889068841934204, 'run': 4.623076438903809, 'teardown': 0.2080519199371338}
  trajectory_data: TrajectoryData:
      seed: None
      trajectory_name: None
      metrics: None
      best_pdb: None
      final_seqs: None
      updated_bias: None
      pssm_logits: None
    

## Set up and run trajectory module

In [6]:
# Trajectory stage, built from the model, loss, trajectory and environment
# configs created earlier in the notebook.
trajectory_module = TrajectoryModule(
    model_config=model_config,
    loss_config=loss_config,
    trajectory_config=trajectory_config,
    environment_config=environment_config,
)

In [7]:
# Set up the trajectory stage. This stays outside the try block so a failed
# setup is reported as-is and teardown is not called on a module that never
# finished setting up.
trajectory_module.setup(design_state=design_state)
try:
    # run() returns the updated design state, which replaces the input state.
    design_state = trajectory_module.run(design_state=design_state)
finally:
    # Always tear down, even when run() raises, so whatever setup() loaded is
    # not left behind in the live kernel (presumably this releases the AlphaFold
    # design model -- confirm in TrajectoryModule.teardown).
    trajectory_module.teardown(design_state=design_state)

[TrajectoryModule] TrajectoryModule initialized with logging
[TrajectoryModule] Initializing AlphaFold model for design
[TrajectoryModule] Initialize AF model took 4.06 seconds
[TrajectoryModule] TrajectoryModule._setup_models completed in 4.06 seconds
[TrajectoryModule] Setup models took 4.06 seconds
[TrajectoryModule] Generated random seed for trajectory: 8341692
[TrajectoryModule] Set trajectory name: PDL1_8341692
[TrajectoryModule] TrajectoryModule._setup_trajectory completed in 0.12 seconds
[TrajectoryModule] Setup trajectory took 0.12 seconds
[TrajectoryModule] Configuring optimizer: schedule_free_sgd with lr=0.4
[TrajectoryModule] Configured optimizer: schedule_free_sgd
[TrajectoryModule] TrajectoryModule._setup_optimizer completed in 0.00 seconds
[TrajectoryModule] Setup optimizer took 0.00 seconds
[TrajectoryModule] Setup loss took 0.00 seconds
[TrajectoryModule] Precompiling model for faster execution...
1 models [3] recycles 3 hard 0 soft 0 temp 1 seqid 0.97 loss 3.83 seq_en

In [8]:
# Inspect the design state after the trajectory stage: trajectory_data now
# carries the seed, trajectory name and per-step metrics.
display(design_state)

DesignState:
  template_data: TemplateData:
      target_id: Q9NZQ7
      target_name: PDL1
      region: A:18-132
      target_hotspot_residues: A54,A56,A66,A115
      masked_binder_seq: EVQLVESGGGLVQPG...****WGQGTLVTVSS (118 chars)
      include_surrounding_context: False
      target_source: None
      template_pdb: [PDB data: 2702 lines]
      full_target_pdb: [PDB data: 922 lines]
      target_chain: A
      binder_chain: H
      binder_len: 118
      binder_seq: EVQLVESGGGLVQPG...YYDYWGQGTLVTVSS (118 chars)
      binder_bias: array(shape=(118, 20), dtype=float64)
      template_preparation_complete: True
      logs: [44 items]
      timings: {'timings': {'Initialize folding model': 0.974374532699585}, 'setup': 28.857595682144165, 'run': 4.623076438903809, 'teardown': 0.2080519199371338}
  trajectory_data: TrajectoryData:
      seed: 8341692
      trajectory_name: PDL1_8341692
      metrics: [101 steps, last: con=1.352, dgram_cce=558.679, exp_res=0.011...]
      best_pdb: [PDB dat

In [9]:
# Render the trajectory animation stored on the design state inline.
from IPython.display import HTML

# The payload is handed to HTML() unchanged; as the cell's last expression it
# becomes the rendered output.
trajectory_animation = design_state.trajectory_data.animated_trajectory
HTML(trajectory_animation)

## Set up and run evaluation module

In [10]:
# Evaluation stage, built from the model, loss, evaluation and environment
# configs created earlier in the notebook.
evaluation_module = EvaluationModule(
    model_config=model_config,
    loss_config=loss_config,
    evaluation_config=evaluation_config,
    environment_config=environment_config,
)

In [11]:
# Set up the evaluation stage. This stays outside the try block so a failed
# setup is reported as-is and teardown is not called on a module that never
# finished setting up.
evaluation_module.setup(design_state=design_state)
try:
    # run() returns the updated design state, which replaces the input state.
    design_state = evaluation_module.run(design_state=design_state)
finally:
    # Always tear down, even when run() raises, so whatever setup() loaded is
    # not left behind in the live kernel (presumably this releases the AlphaFold
    # complex, ESM and monomer folding models -- confirm in EvaluationModule.teardown).
    evaluation_module.teardown(design_state=design_state)

[EvaluationModule] EvaluationModule initialized with logging
[EvaluationModule] Initializing AlphaFold model for complex evaluation


[EvaluationModule] Initialize AF complex model took 2.62 seconds
[EvaluationModule] Initializing ESM model for sequence evaluation
[EvaluationModule] Initialize ESM model took 0.52 seconds
[EvaluationModule] Initializing monomer folding model: nbb2
[EvaluationModule] Initialize monomer folding model took 0.76 seconds
[EvaluationModule] EvaluationModule._setup_models completed in 3.89 seconds
[EvaluationModule] Setup models took 3.89 seconds
[EvaluationModule] EvaluationModule._setup_evaluation completed in 0.04 seconds
[EvaluationModule] Setup evaluation parameters took 0.04 seconds
[EvaluationModule] Setup loss functions took 0.00 seconds
[EvaluationModule] EvaluationModule.setup completed in 3.93 seconds
[EvaluationModule] Starting evaluation of 10 binder sequences
[EvaluationModule] Evaluating binder sequence 1/10
predict models [0] recycles 3 hard 1 soft 0 temp 1 seqid 0.86 loss 0.49 seq_ent 2.71 pae 0.17 i_pae 0.18 con 1.43 i_con 2.21 plddt 0.90 ptm 0.81 i_ptm 0.69 rmsd 47.24 hbon

In [12]:
# Inspect the design state once more, after the evaluation stage has run.
display(design_state)

DesignState:
  template_data: TemplateData:
      target_id: Q9NZQ7
      target_name: PDL1
      region: A:18-132
      target_hotspot_residues: A54,A56,A66,A115
      masked_binder_seq: EVQLVESGGGLVQPG...****WGQGTLVTVSS (118 chars)
      include_surrounding_context: False
      target_source: None
      template_pdb: [PDB data: 2702 lines]
      full_target_pdb: [PDB data: 922 lines]
      target_chain: A
      binder_chain: H
      binder_len: 118
      binder_seq: EVQLVESGGGLVQPG...YYDYWGQGTLVTVSS (118 chars)
      binder_bias: array(shape=(118, 20), dtype=float64)
      template_preparation_complete: True
      logs: [44 items]
      timings: {'timings': {'Initialize folding model': 0.974374532699585}, 'setup': 3.9297103881835938, 'run': 4.623076438903809, 'teardown': 0.2080519199371338}
  trajectory_data: TrajectoryData:
      seed: 8341692
      trajectory_name: PDL1_8341692
      metrics: [101 steps, last: con=1.352, dgram_cce=558.679, exp_res=0.011...]
      best_pdb: [PDB dat

## Save the design state to a directory

In [13]:
# Write the final design state to a directory given relative to the current
# working directory.
design_state_dir = './example_outputs/PDL1_example_design_state'
design_state.to_dir(design_state_dir)

<Figure size 1475x400 with 0 Axes>