In [26]:
%load_ext autoreload
%autoreload 2

PyRosetta-4 2023 [Rosetta PyRosetta4.conda.linux.cxx11thread.serialization.CentOS.python310.Release 2023.27+release.e3ce6ea9faf661ae8fa769511e2a9b8596417e58 2023-07-07T12:00:46] retrieved from: http://www.pyrosetta.org
(C) Copyright Rosetta Commons Member Institutions. Created in JHU by Sergey Lyskov and PyRosetta Team.
The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [27]:
import glob
import os
import pickle
import sys

import numpy as np
import torch
import torch.nn as nn
from tqdm import tqdm

import py3Dmol
from proteome.models.design.protein_seq_des import atoms
import proteome.models.design.protein_seq_des.models as models
import proteome.models.design.protein_seq_des.sampler as sampler

In [28]:
from pyrosetta.rosetta.protocols.denovo_design.filters import ExposedHydrophobicsFilterCreator
from pyrosetta.rosetta.protocols.simple_filters import (
    BuriedUnsatHbondFilterCreator, 
    PackStatFilterCreator,
)

In [52]:
from proteome.models.design.protein_seq_des import config
from dataclasses import asdict

In [30]:
import Bio.PDB

In [31]:
from proteome import protein

In [34]:
# Read the whole input structure file (5L33.pdb, resolved relative to the
# notebook's working directory) into memory as text.
with open("5L33.pdb") as f:
    pdb_str = f.read()

In [40]:
# Parse the PDB text into proteome's in-memory protein representation.
structure = protein.from_pdb_string(pdb_str)

In [53]:
# Re-wrap the parsed structure as a DesignableProtein, carrying every existing
# field over and setting is_tim=False.
# The override is merged into the kwargs dict (instead of being passed as a
# separate keyword) so that re-running this cell -- when `structure` already
# carries an `is_tim` field -- does not raise "multiple values for keyword
# argument 'is_tim'".
structure = config.DesignableProtein(**{**asdict(structure), "is_tim": False})

In [54]:
# Value passed as `nic` to the conditional model: one entry per atom type
# plus 1 + 21 extra.
# NOTE(review): presumably `nic` is the number of input channels, with the
# extra 1 + 21 covering an additional channel and a 21-way residue encoding
# -- confirm against models.SeqPred.
nic = len(atoms.atoms) + 1 + 21

In [55]:
def load_model(model, use_cuda=True, nic=len(atoms.atoms)):
    """Build a sequence-prediction classifier and load its weights from a checkpoint.

    Args:
        model: Checkpoint location, passed straight to ``torch.load``.
        use_cuda: When True the classifier is moved to the GPU and the
            checkpoint is loaded with torch's default device mapping; when
            False every tensor in the checkpoint is mapped to the CPU.
        nic: Forwarded as the ``nic`` argument of the classifier constructor.

    Returns:
        The classifier with the checkpoint weights loaded. It is wrapped in
        ``nn.DataParallel`` when the first key of the checkpoint contains
        "module".
    """
    # The other cells of this notebook construct `models.SeqPred`; the
    # lowercase `seqPred` spelling used here before did not match them.
    classifier = models.SeqPred(nic=nic)
    if use_cuda:
        classifier.cuda()

    # Deserialize the checkpoint once and reuse it below (it used to be read
    # from disk a second time just for load_state_dict).
    state = torch.load(model, map_location=None if use_cuda else "cpu")

    # Only the first key is inspected: a "module" prefix means the weights
    # were saved from a DataParallel wrapper, so wrap before loading.
    first_key = next(iter(state), "")
    if "module" in first_key:
        print("MODULE")
        classifier = nn.DataParallel(classifier)

    classifier.load_state_dict(state)
    return classifier

def load_models(model_list, use_cuda=True, nic=len(atoms.atoms)):
    """Load one classifier per checkpoint in ``model_list``.

    Args:
        model_list: Iterable of checkpoint locations, each handed to
            ``load_model``.
        use_cuda: Forwarded unchanged to ``load_model``.
        nic: Forwarded unchanged to ``load_model``.

    Returns:
        A list of classifiers in the same order as ``model_list``.
    """
    return [
        load_model(checkpoint, use_cuda=use_cuda, nic=nic)
        for checkpoint in model_list
    ]

In [56]:
from proteome.utils.hub_utils import load_state_dict_from_gdrive_zip

In [44]:
# Fetch the state dict stored as models/conditional_model_0.pt inside the
# zip archive behind this Google Drive link. `state` is consumed by the cell
# that builds `classifier`.
state = load_state_dict_from_gdrive_zip(
    "https://drive.google.com/u/0/uc?id=1X66RLbaA2-qTlJLlG9TI53cao8gaKnEt",
    extract_member="models/conditional_model_0.pt"
)

In [47]:
# Build the model with `nic` inputs, load the weights currently held in
# `state`, then move it to the GPU and switch it to evaluation mode.
classifier = models.SeqPred(nic=nic)
msg = classifier.load_state_dict(state)
classifier = classifier.cuda().eval()

In [48]:
# Fetch the state dict stored as models/baseline_model.pt from the same
# Google Drive archive. This rebinds `state`, so the cell that builds
# `init_model` must run after this one.
state = load_state_dict_from_gdrive_zip(
    "https://drive.google.com/u/0/uc?id=1X66RLbaA2-qTlJLlG9TI53cao8gaKnEt",
    extract_member="models/baseline_model.pt"
)

In [49]:
# Build the model that is handed to the sampler as `init_model` (nic=6),
# load the weights currently held in `state`, then move it to the GPU and
# switch it to evaluation mode.
init_model = models.SeqPred(nic=6)
msg = init_model.load_state_dict(state)
init_model = init_model.cuda().eval()

In [50]:
# Sampler settings with every field left at its default.
cfg = config.SamplerConfig()

In [91]:
# Assemble the design sampler from the config, the designable structure, a
# one-element list holding the conditional classifier, and the model passed
# as `init_model`.
design_sampler = sampler.Sampler(
    cfg, structure, [classifier], init_model=init_model
)

In [125]:
# Convert the structure into a Rosetta pose.
pose = protein.to_rosetta_pose(structure)

In [126]:
# Convert the pose back into a protein object (round trip of the cell above).
# NOTE(review): `struct_back` is not used anywhere else in the visible
# notebook -- presumably a conversion sanity check.
struct_back = protein.from_rosetta_pose(pose)

In [129]:
# initialize sampler
# In the recorded run this call printed the residue numbering, the input
# sequence and a per-residue H/E/L string.
design_sampler.init()

1       9       17      25      33      41      49      57      65      73      
HMPEEEKAARLFIEALEKGDPELMRKVISPDTRMEDNGREFTGDEVVEYVKEIQKRGEQWHLRRYTKEGNSWRFEVQVDN
LLLHHHHHHHHHHHHHHLLLHHHHHHHLLLLLEEEELLEEEEHHHHHHHHHHHHHHLLEEEEEEEEEELLEEEEEEEEEE


81      89      97      105     
NGQTEQWEVQIEVRNGRIKRVTITHV
LLEEEEEEEEEEEELLEEEEEEEELL




In [130]:
# Running minima for the design loop; starting at +inf guarantees that the
# first observed value counts as an improvement.
best_rosetta_energy = best_energy = np.inf

In [131]:
# Set up the sampler's starting sequence.
# NOTE(review): presumably this is where `init_model` is used to propose the
# initial sequence -- confirm in sampler.Sampler.
design_sampler.init_seq()

In [132]:
# n_iters in config
logmeans = np.zeros(int(2500))
rosettas = np.zeros(int(2500))

In [133]:
import datetime

In [134]:
# run design
results_dir = os.path.join(".", "results")
# Make sure the output directory exists before any PDB is written into it.
os.makedirs(results_dir, exist_ok=True)

# Size the loop from the trace buffers so the indices written below can never
# run past their end.
n_iters = min(len(logmeans), len(rosettas))


def _dump_pose(prefix):
    """Write the sampler's current pose to results/<prefix>_<timestamp>.pdb."""
    ts = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
    design_sampler.pose.dump_pdb(
        os.path.join(results_dir, "%s_%s.pdb" % (prefix, ts))
    )


# NOTE(review): nothing in this loop advances `design_sampler` (the only
# per-iteration hook, log.advance_iteration(), had been commented out), and
# the recorded run finished all 2499 iterations in well under a second.
# Confirm whether a sampler update call is missing here.
with torch.no_grad():
    for i in tqdm(range(1, n_iters), desc="running design"):
        # save log_p_means and rosettas
        logmeans[i] = design_sampler.log_p_mean
        rosettas[i] = design_sampler.rosetta_energy

        # Snapshot the pose whenever the model score reaches a new minimum.
        if design_sampler.log_p_mean < best_energy:
            _dump_pose("curr_best_log_p")
            best_energy = design_sampler.log_p_mean

        # Snapshot the pose whenever the Rosetta energy reaches a new minimum.
        if design_sampler.rosetta_energy < best_rosetta_energy:
            _dump_pose("curr_best_rosetta_energy")
            best_rosetta_energy = design_sampler.rosetta_energy

        # save intermediate models -- comment out if desired
        # (first iteration, every 10th iteration, and the last iteration)
        if (i == 1) or (i % 10 == 0) or (i == n_iters - 1):
            _dump_pose("curr_%s" % i)

running design: 100%|████████████████████████████████████| 2499/2499 [00:00<00:00, 27548.92it/s]


In [135]:
# Show the sequence string stored on the sampler.
design_sampler.seq

'HMPEEEKAARLFIEALEKGDPELMRKVISPDTRMEDNGREFTGDEVVEYVKEIQKRGEQWHLRRYTKEGNSWRFEVQVDNNGQTEQWEVQIEVRNGRIKRVTITHV'

In [136]:
# Show the sequence of the sampler's current pose; this is the sequence that
# gets folded further down.
design_sampler.pose.sequence()

'SELTRLAHVQEKLEAMALGDPDAIQTVLTQNTEVSANGEQYECNEVALFVNRYKASGIQFHVKQFAFVGDKVRITVMVTLEGKEYQLTAVFSVQDDNVVAIQVIDY'

In [137]:
from proteome.models.folding.omegafold.modeling import OmegaFoldForFolding
from proteome import protein

In [138]:
# Instantiate the OmegaFold wrapper with its default arguments.
folder = OmegaFoldForFolding()

In [139]:
# Fold the sequence of the sampler's current pose, then serialize the
# predicted structure to PDB text for the viewer.
# NOTE(review): `confidence` is not used in the visible cells.
predicted_protein, confidence = folder.fold(design_sampler.pose.sequence())
folded_pdb = protein.to_pdb(predicted_protein)

In [140]:
# Confidence bands as (lower bound, upper bound, hex colour).
PLDDT_BANDS = [
    (0, 50, "#FF7D45"),
    (50, 70, "#FFDB13"),
    (70, 90, "#65CBF3"),
    (90, 100, "#0053D6"),
]

# Map each band's position in the list to its colour.
# NOTE(review): the map is keyed by band index (0-3), so it presumably
# expects B-factors that already hold band indices rather than raw pLDDT
# values -- confirm what protein.to_pdb writes into the B-factor column.
color_map = {
    band_idx: colour for band_idx, (_, _, colour) in enumerate(PLDDT_BANDS)
}

# Cartoon coloured through the B-factor property, plus default sticks.
style = {
    "cartoon": {"colorscheme": {"prop": "b", "map": color_map}},
    "stick": {},
}

view = py3Dmol.view(width=800, height=600)
view.addModelsAsFrames(folded_pdb)
view.setStyle({"model": -1}, style)
view.zoomTo()

<py3Dmol.view at 0x7fc0c96a2ef0>