In [1]:
#!wget https://github.com/jasonkyuyim/se3_diffusion/raw/master/weights/paper_weights.pth
#!wget https://github.com/jasonkyuyim/se3_diffusion/raw/master/weights/best_weights.pth

In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
import py3Dmol

import numpy as np
import torch

from proteome import protein
from proteome.constants import residue_constants
from proteome.models.design.se3_diffusion import config
from proteome.models.design.se3_diffusion.sampler import Sampler
from proteome.models.design.se3_diffusion.score_network import ScoreNetwork
from proteome.models.design.se3_diffusion.se3_diffuser import SE3Diffuser
from proteome.models.design.proteinmpnn.modeling import ProteinMPNNForSequenceDesign
from proteome.models.folding.omegafold.modeling import OmegaFoldForFolding

In [4]:
ckpt = torch.load("best_weights.pth", map_location="cpu")

In [5]:
diffuser = SE3Diffuser(config.Diffuser())

In [6]:
model = ScoreNetwork(config.Model(), diffuser)

In [7]:
ckpt_model = {k.replace('module.', ''):v for k,v in ckpt["model"].items()}
model.load_state_dict(ckpt_model, strict=True)

<All keys matched successfully>

In [8]:
model = model.cuda()
model = model.eval()

In [14]:
sampler = Sampler(model, diffuser, config.SamplerConfig(length=40), config.Data())

In [15]:
res = sampler.sample()

In [16]:
atom_positions = res['prot_traj'][0][:, :4]
length, num_atoms, _ = atom_positions.shape

designed_structure = protein.Protein(
    atom_positions=atom_positions,
    aatype=np.array(length * [residue_constants.restype_order_with_x["G"]]),
    atom_mask=np.ones((length, num_atoms)),
    residue_index=np.arange(0, length),
    b_factors=np.ones((length, num_atoms)),
    chain_index=np.zeros((length,), dtype=np.int32),
)

In [17]:
designed_pdb_str = protein.to_pdb(designed_structure)

In [18]:
PLDDT_BANDS = [
  (0, 50, '#FF7D45'),
  (50, 70, '#FFDB13'),
  (70, 90, '#65CBF3'),
  (90, 100, '#0053D6')
]
view = py3Dmol.view(width=800, height=600)
view.addModelsAsFrames(designed_pdb_str)

color_map = {i: bands[2] for i, bands in enumerate(PLDDT_BANDS)}
style = {'cartoon': {'colorscheme': {'prop': 'b', 'map': color_map}}}

style['stick'] = {}

view.setStyle({'model': -1}, style)
view.zoomTo()

<py3Dmol.view at 0x7fbb30686500>

In [19]:
sequence_designer = ProteinMPNNForSequenceDesign("vanilla_model-20", random_seed=37)
folder = OmegaFoldForFolding()

In [20]:
sequence, score = sequence_designer.design_sequence(designed_structure)
print(f"Sequence: {sequence} with global_score {score}")

Sequence: KLEELELELLGKKLKVKLLKGNVTKEELEKLIKELIEKLK with global_score 1.6189324855804443


In [21]:
predicted_protein, confidence = folder.fold(sequence)
result_pdb = protein.to_pdb(predicted_protein)

In [22]:
PLDDT_BANDS = [
  (0, 50, '#FF7D45'),
  (50, 70, '#FFDB13'),
  (70, 90, '#65CBF3'),
  (90, 100, '#0053D6')
]
view = py3Dmol.view(width=800, height=600)
view.addModelsAsFrames(result_pdb)

color_map = {i: bands[2] for i, bands in enumerate(PLDDT_BANDS)}
style = {'cartoon': {'colorscheme': {'prop': 'b', 'map': color_map}}}

style['stick'] = {}

view.setStyle({'model': -1}, style)
view.zoomTo()

<py3Dmol.view at 0x7fbb29eff3d0>