In [1]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
from dataclasses import asdict

import py3Dmol
import numpy as np

from proteome import protein
from proteome.constants import residue_constants
from proteome.models.design.proteinmpnn import config
from proteome.models.design.proteinmpnn.modeling import ProteinMPNNForSequenceDesign
from proteome.models.folding.omegafold.modeling import OmegaFoldForFolding

In [3]:
# Checkpoint name and seed handed to the ProteinMPNN wrapper. The "ca_only"
# checkpoint is paired with ca_only=True when the target structure is parsed.
DESIGN_MODEL_NAME = "ca_only_model-20"
DESIGN_RANDOM_SEED = 37

# Sequence designer first, then the structure predictor (default construction).
designer = ProteinMPNNForSequenceDesign(
    DESIGN_MODEL_NAME,
    random_seed=DESIGN_RANDOM_SEED,
)
folder = OmegaFoldForFolding()

Downloading: "https://github.com/dauparas/ProteinMPNN/raw/main/ca_model_weights/v_48_020.pt" to /home/conradry71/.cache/torch/hub/checkpoints/ca_only_model-20.pt
100%|██████████████████████████████████████████████████| 6.32M/6.32M [00:00<00:00, 78.5MB/s]


In [4]:
# Read the reference structure into memory as raw PDB text.
with open("5L33.pdb", "r") as pdb_file:
    gt_pdb = pdb_file.read()

In [5]:
# Parse the reference PDB text. backbone_only is always the opposite of
# ca_only, so with ca_only=True it is switched off.
ca_only = True
target_protein = protein.from_pdb_string(
    gt_pdb,
    ca_only=ca_only,
    backbone_only=not ca_only,
)
chain_length = len(target_protein.aatype)
num_aa = 1 + residue_constants.restype_num  # one extra slot for X

# Copy the parsed fields into a plain dict and overwrite aatype with zeros of
# the same shape and dtype.
# NOTE(review): presumably this hides the native sequence from the designer - confirm.
target_protein_dict = asdict(target_protein)
target_protein_dict["aatype"] = np.zeros_like(target_protein_dict["aatype"])

# Shape shared by every per-residue, per-amino-acid array below.
per_residue_shape = (chain_length, num_aa)

# Design inputs: an all-ones design mask, zeroed masks/biases/PSSM terms, and
# a constant 10000.0 for the PSSM log-odds.
target_structure = protein.DesignableProtein(
    **target_protein_dict,
    design_mask=np.ones(chain_length),
    design_aatype_mask=np.zeros(per_residue_shape, dtype=np.int32),
    pssm_coef=np.zeros(chain_length),
    pssm_bias=np.zeros(per_residue_shape),
    pssm_log_odds=np.full(per_residue_shape, 10000.0),
    bias_per_residue=np.zeros(per_residue_shape),
)

In [17]:
# Design a sequence for the prepared structure with a default-constructed
# inference config, then report the sequence together with its score.
inference_config = config.InferenceConfig()
sequence, score = designer.design_sequence(target_structure, inference_config)
print(f"Sequence: {sequence} with global_score {score}")

Sequence: MLSPEEAIALDFIKALEKRDPELMEKVVGPDTELEVNGKKFKGDEIVEFVKKLKEKGVKVKLESYEWVGDKYVYKLKVEKNGKEKEVKVTIEVEDGKIKKVKIEIE with global_score 1.0129684209823608


In [18]:
# Fold the designed sequence; the folder hands back the predicted structure
# together with a confidence value.
fold_result = folder.fold(sequence)
predicted_protein, confidence = fold_result

# Serialise the prediction to PDB text for the viewer.
result_pdb = protein.to_pdb(predicted_protein)

In [15]:
# Confidence bands as (lower bound, upper bound, hex colour).
PLDDT_BANDS = [
  (0, 50, '#FF7D45'),
  (50, 70, '#FFDB13'),
  (70, 90, '#65CBF3'),
  (90, 100, '#0053D6')
]


def band_b_factors(pdb_string, bands):
    """Return ``pdb_string`` with each atom's B-factor replaced by its band index.

    The viewer colour map below is keyed by band index (0, 1, 2, ...), so the
    B-factor column has to hold those indices rather than the raw values for
    the lookup to hit.

    Args:
        pdb_string: PDB-format text.
        bands: sequence of ``(lower, upper, colour)`` tuples. Bounds are
            inclusive and the first matching band wins.

    Returns:
        The PDB text with columns 61-66 of every ATOM/HETATM record rewritten
        to the index of the matching band. Records that are too short, whose
        B-factor cannot be parsed, or whose B-factor falls outside every band
        are returned unchanged. Line endings are preserved.
    """
    banded_lines = []
    for line in pdb_string.splitlines(keepends=True):
        record = line.rstrip("\r\n")
        line_ending = line[len(record):]
        if record.startswith(("ATOM", "HETATM")) and len(record) >= 66:
            try:
                b_factor = float(record[60:66])
            except ValueError:
                b_factor = None
            if b_factor is not None:
                for band_index, (lower, upper, _) in enumerate(bands):
                    if lower <= b_factor <= upper:
                        record = f"{record[:60]}{band_index:6.2f}{record[66:]}"
                        break
        banded_lines.append(record + line_ending)
    return "".join(banded_lines)


view = py3Dmol.view(width=800, height=600)
# NOTE(review): assumes the B-factor column of result_pdb carries per-residue
# confidence on a 0-100 scale - confirm against protein.to_pdb / folder.fold.
view.addModelsAsFrames(band_b_factors(result_pdb, PLDDT_BANDS))

# Band index -> colour; matches the indices written by band_b_factors.
color_map = {i: bands[2] for i, bands in enumerate(PLDDT_BANDS)}
style = {'cartoon': {'colorscheme': {'prop': 'b', 'map': color_map}}}

style['stick'] = {}

view.setStyle({'model': -1}, style)
view.zoomTo()

<py3Dmol.view at 0x7f57606a1ae0>

In [16]:
# Show the reference structure that was read from disk.
view = py3Dmol.view(width=800, height=600)
view.addModelsAsFrames(gt_pdb)

# gt_pdb comes straight from a PDB file, so its B-factor column presumably
# holds experimental values rather than pLDDT; a pLDDT band colour map does
# not apply to it. Colour the cartoon with the viewer's built-in 'spectrum'
# scheme instead.
style = {'cartoon': {'color': 'spectrum'}}

style['stick'] = {}

view.setStyle({'model': -1}, style)
view.zoomTo()

<py3Dmol.view at 0x7f58753b6e30>