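# SE3 Diffusion

Pipeline:
1. Sample an 80-residue protein backbone with SE(3) diffusion.
2. Design a sequence for it with ProteinMPNN (CA-only model).
3. Refold that sequence with OmegaFold, so the predicted structure can be compared visually with the design.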

In [1]:
%load_ext autoreload
%autoreload 2

In [3]:
import py3Dmol

import numpy as np
import torch

from proteome import protein
from proteome.constants import residue_constants
from proteome.models.design.se3_diffusion import config
from proteome.models.design.se3_diffusion.modeling import SE3DiffusionForStructureDesign
from proteome.models.design.proteinmpnn.modeling import ProteinMPNNForSequenceDesign
from proteome.models.folding.omegafold.modeling import OmegaFoldForFolding

In [9]:
import string

In [11]:
# Instantiate the SE(3) diffusion backbone generator with no constructor
# arguments, i.e. the library's default settings.
structure_designer = SE3DiffusionForStructureDesign()

In [12]:
# Backbone sampling is stochastic. Seed the global RNGs so that
# Restart & Run All reproduces the same design. (ProteinMPNN is seeded
# separately through its own `random_seed` argument.)
# NOTE(review): assumes the sampler draws from numpy's / torch's global RNGs --
# confirm against SE3DiffusionForStructureDesign.
DESIGN_SEED = 37
DESIGN_LENGTH = 80  # backbone length passed to InferenceConfig

np.random.seed(DESIGN_SEED)
torch.manual_seed(DESIGN_SEED)
designed_structure = structure_designer.design_structure(
    config.InferenceConfig(length=DESIGN_LENGTH)
)

  output = torch._nested_tensor_from_mask(output, src_key_padding_mask.logical_not(), mask_check=False)


In [13]:
# Serialize the designed structure to PDB-format text so py3Dmol can load it.
designed_pdb_str = protein.to_pdb(designed_structure)

In [14]:
# The original pLDDT colorscheme was wrong here for two reasons:
# - py3Dmol's {'prop': 'b', 'map': ...} colorscheme matches exact B-factor
#   values, and the band-index keys 0..3 were never written into this PDB.
# - A diffused backbone is a generated sample, not a structure prediction, so
#   its B-factor column is not a pLDDT score.
# NOTE(review): confirm SE3DiffusionForStructureDesign emits no confidence.
# Instead, color by residue index (N- to C-terminus rainbow) so the chain's
# path through the fold is easy to follow.
design_view = py3Dmol.view(width=800, height=600)
design_view.addModelsAsFrames(designed_pdb_str)
design_view.setStyle(
    {'model': -1},
    {'cartoon': {'color': 'spectrum'}, 'stick': {}},
)
design_view.zoomTo()

<py3Dmol.view at 0x7f004374a350>

In [15]:
# ProteinMPNN model variant (CA-only, matching the CA-only structure it is fed)
# and its fixed seed, so sequence design is reproducible across re-runs.
MPNN_MODEL_NAME = "ca_only_model-2"
MPNN_SEED = 37

sequence_designer = ProteinMPNNForSequenceDesign(MPNN_MODEL_NAME, random_seed=MPNN_SEED)
folder = OmegaFoldForFolding()

In [16]:
# Reduce the design to its CA-only representation, then ask ProteinMPNN
# for a sequence and its global score.
ca_only_design = protein.to_ca_only_protein(designed_structure)
sequence, score = sequence_designer.design_sequence(ca_only_design)
print(f"Sequence: {sequence} with global_score {score}")

Sequence: KHKIYIITKSKKDAKKLCKELKKFIEKTCKVEGVTFKFFGNKNKKIKVLIKLKNITKECVKKLIKFIKKKKKYKVKVTIE with global_score 1.5385401248931885


In [17]:
# Refold the designed sequence with OmegaFold and serialize the result to PDB text.
# The viewer cell colors by the PDB B-factor column, not by `confidence`.
fold_output = folder.fold(sequence)
predicted_protein, confidence = fold_output
result_pdb = protein.to_pdb(predicted_protein)

In [18]:
# AlphaFold-style pLDDT confidence bands: (lower bound, upper bound, hex color).
PLDDT_BANDS = [
  (0, 50, '#FF7D45'),
  (50, 70, '#FFDB13'),
  (70, 90, '#65CBF3'),
  (90, 100, '#0053D6')
]


def band_b_factors(pdb_str: str, bands=PLDDT_BANDS) -> str:
    """Return ``pdb_str`` with each atom's B-factor replaced by its band index.

    py3Dmol's ``{'prop': 'b', 'map': ...}`` colorscheme looks colors up by the
    *exact* B-factor value. Raw pLDDT scores must therefore be binned onto the
    integer keys ``0 .. len(bands) - 1`` used by ``color_map`` before loading.

    A score goes into the first band whose upper bound exceeds it. Scores at or
    above the last upper bound go into the last band. ATOM/HETATM lines with an
    empty B-factor field, and all non-atom records, are passed through unchanged.
    """
    banded_lines = []
    for line in pdb_str.splitlines():
        # PDB fixed-width format: the B-factor occupies columns 61-66 (slice 60:66).
        if line.startswith(('ATOM', 'HETATM')) and line[60:66].strip():
            score = float(line[60:66])
            band_idx = next(
                (i for i, (_, upper, _) in enumerate(bands) if score < upper),
                len(bands) - 1,
            )
            line = f"{line[:60]}{band_idx:6.2f}{line[66:]}"
        banded_lines.append(line)
    return "\n".join(banded_lines)


# NOTE(review): assumes `result_pdb` carries per-residue pLDDT on a 0-100 scale
# in its B-factor column -- confirm against OmegaFoldForFolding.
view = py3Dmol.view(width=800, height=600)
view.addModelsAsFrames(band_b_factors(result_pdb))

# Keys are band indices, matching the values written by band_b_factors.
color_map = {idx: color for idx, (_, _, color) in enumerate(PLDDT_BANDS)}
style = {'cartoon': {'colorscheme': {'prop': 'b', 'map': color_map}}}

style['stick'] = {}

view.setStyle({'model': -1}, style)
view.zoomTo()

<py3Dmol.view at 0x7f0053b08580>