In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import glob
import os
import pickle
import sys

import numpy as np
import torch
import torch.nn as nn
from tqdm import tqdm

import py3Dmol
from proteome.models.design.protein_seq_des import atoms
import proteome.models.design.protein_seq_des.models as models
import proteome.models.design.protein_seq_des.sampler as sampler

  from rosetta import *


PyRosetta-4 2023 [Rosetta PyRosetta4.conda.linux.cxx11thread.serialization.CentOS.python310.Release 2023.27+release.e3ce6ea9faf661ae8fa769511e2a9b8596417e58 2023-07-07T12:00:46] retrieved from: http://www.pyrosetta.org
(C) Copyright Rosetta Commons Member Institutions. Created in JHU by Sergey Lyskov and PyRosetta Team.


In [3]:
from pyrosetta.rosetta.protocols.denovo_design.filters import ExposedHydrophobicsFilterCreator
from pyrosetta.rosetta.protocols.simple_filters import (
    BuriedUnsatHbondFilterCreator, 
    PackStatFilterCreator,
)

In [4]:
from proteome.models.design.protein_seq_des import config

In [5]:
nic = len(atoms.atoms) + 1 + 21

In [6]:
def load_model(model, use_cuda=True, nic=len(atoms.atoms)):
    classifier = models.seqPred(nic=nic)
    if use_cuda:
        classifier.cuda()
    if use_cuda:
        state = torch.load(model)
    else:
        state = torch.load(model, map_location="cpu")
    for k in state.keys():
        if "module" in k:
            print("MODULE")
            classifier = nn.DataParallel(classifier)
        break
    if use_cuda:
        classifier.load_state_dict(torch.load(model))
    else:
        classifier.load_state_dict(torch.load(model, map_location="cpu"))
    return classifier

def load_models(model_list, use_cuda=True, nic=len(atoms.atoms)):
    classifiers = []
    for model in model_list:
        classifier = load_model(model, use_cuda=use_cuda, nic=nic)
        classifiers.append(classifier)
    return classifiers

In [7]:
from proteome.utils.hub_utils import load_state_dict_from_gdrive_zip

In [8]:
state = load_state_dict_from_gdrive_zip(
    "https://drive.google.com/u/0/uc?id=1X66RLbaA2-qTlJLlG9TI53cao8gaKnEt",
    extract_member="models/conditional_model_0.pt"
)

In [9]:
classifier = models.seqPred(nic=nic)
msg = classifier.load_state_dict(state)
classifier = classifier.cuda()
classifier = classifier.eval()

In [10]:
state = load_state_dict_from_gdrive_zip(
    "https://drive.google.com/u/0/uc?id=1X66RLbaA2-qTlJLlG9TI53cao8gaKnEt",
    extract_member="models/baseline_model.pt"
)

In [11]:
init_model = models.seqPred(nic=6)
msg = init_model.load_state_dict(state)
init_model = init_model.cuda()
init_model = init_model.eval()

In [12]:
cfg = config.SamplerConfig(pdb="5L33.pdb")

In [13]:
design_sampler = sampler.Sampler(
    cfg, [classifier], init_models=[init_model], use_cuda=True
)

In [14]:
# initialize sampler
design_sampler.init()

1       9       17      25      33      41      49      57      65      73      
HMPEEEKAARLFIEALEKGDPELMRKVISPDTRMEDNGREFTGDEVVEYVKEIQKRGEQWHLRRYTKEGNSWRFEVQVDN
LLLHHHHHHHHHHHHHHLLLHHHHHHHLLLLLEEEELLEEEEHHHHHHHHHHHHHHLLEEEEEEEEEELLEEEEEEEEEE


81      89      97      105     
NGQTEQWEVQIEVRNGRIKRVTITHV
LLEEEEEEEEEEEELLEEEEEEEELL




In [15]:
best_rosetta_energy = np.inf
best_energy = np.inf

In [16]:
design_sampler.init_seq()

In [17]:
# n_iters in config
logmeans = np.zeros(int(2500))
rosettas = np.zeros(int(2500))

In [18]:
import datetime

In [19]:
# run design
with torch.no_grad():
    for i in tqdm(range(1, int(2500)), desc="running design"):
        # save log_p_means and rosettas
        logmeans[i] = design_sampler.log_p_mean
        rosettas[i] = design_sampler.rosetta_energy

        if design_sampler.log_p_mean < best_energy:
            now = datetime.datetime.now()
            ts = now.strftime("%Y-%m-%d-%H-%M-%S")
            design_sampler.pose.dump_pdb(
                #log.log_path
                "./"
                + "/"
                + "results"
                + "/"
                + "curr_best_log_p_%s.pdb" % ts
            )
            best_energy = design_sampler.log_p_mean

        if design_sampler.rosetta_energy < best_rosetta_energy:
            now = datetime.datetime.now()
            ts = now.strftime("%Y-%m-%d-%H-%M-%S")
            design_sampler.pose.dump_pdb(
                #log.log_path
                "./"
                + "/"
                + "results"
                + "/"
                + "curr_best_rosetta_energy_%s.pdb" % ts
            )
            best_rosetta_energy = design_sampler.rosetta_energy

        # save intermediate models -- comment out if desired
        if (i == 1) or (i % 10 == 0) or (i == 2500 - 1):
            now = datetime.datetime.now()
            ts = now.strftime("%Y-%m-%d-%H-%M-%S")
            design_sampler.pose.dump_pdb(
                #log.log_path
                "./"
                + "/"
                + "results"
                + "/"
                + "curr_%s_%s.pdb" % (i, ts)
            )

        #log.advance_iteration()

running design: 100%|████████████████████████████████████| 2499/2499 [00:00<00:00, 27634.48it/s]


In [20]:
design_sampler.seq

'HMPEEEKAARLFIEALEKGDPELMRKVISPDTRMEDNGREFTGDEVVEYVKEIQKRGEQWHLRRYTKEGNSWRFEVQVDNNGQTEQWEVQIEVRNGRIKRVTITHV'

In [21]:
design_sampler.pose.sequence()

'PIDSEYAESLQILSALETTDPSEIHAKVKEKTKVKELGTEYQSDEVVEYITRFRAQGISYVLHHFIKRGDMIVIEIRISYTGDTLTIRLEIRVTQGAVQAINVMEL'

In [22]:
from proteome.models.folding.omegafold.modeling import OmegaFoldForFolding
from proteome import protein

In [23]:
folder = OmegaFoldForFolding()

In [24]:
predicted_protein, confidence = folder.fold(design_sampler.pose.sequence())
folded_pdb = protein.to_pdb(predicted_protein)

In [25]:
PLDDT_BANDS = [
  (0, 50, '#FF7D45'),
  (50, 70, '#FFDB13'),
  (70, 90, '#65CBF3'),
  (90, 100, '#0053D6')
]
view = py3Dmol.view(width=800, height=600)
view.addModelsAsFrames(folded_pdb)

color_map = {i: bands[2] for i, bands in enumerate(PLDDT_BANDS)}
style = {'cartoon': {'colorscheme': {'prop': 'b', 'map': color_map}}}

style['stick'] = {}

view.setStyle({'model': -1}, style)
view.zoomTo()

<py3Dmol.view at 0x7fc228126020>