In [1]:
# Enable IPython's autoreload extension in mode 2, so edits to imported
# modules take effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import copy
import glob
import json
import os
import os.path
import random
import shutil
import subprocess
import sys
import time
import warnings

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from proteome.models.design.proteinmpnn.featurizer import decode_sequence, get_sequence_scores, tied_featurize
from proteome.models.design.proteinmpnn.model import ProteinMPNN
from proteome.models.design.proteinmpnn import config

from torch import optim
from torch.utils.data import DataLoader
from torch.utils.data.dataset import Subset, random_split

import py3Dmol
from dataclasses import asdict
from proteome import protein
from proteome.constants import residue_constants
from proteome.models.folding.omegafold.modeling import OmegaFoldForFolding

In [4]:
# Instantiate the OmegaFold wrapper; a later cell calls `folder.fold(seq)` on
# the designed sequence.
folder = OmegaFoldForFolding()

In [5]:
# Read the reference structure from disk as raw PDB text.
with open("example.pdb", "r") as pdb_file:
    gt_pdb = pdb_file.read()

In [6]:
# Backbone representation switch. It is passed to the PDB parser and the
# featurizer, and selects both the checkpoint URL and the model config below.
ca_only = True

In [7]:
# Parse the input structure once and reuse the result for both the length
# lookup and the DesignableProtein fields (previously the same PDB string was
# parsed a second time just to feed `asdict`).
target_protein = protein.from_pdb_string(gt_pdb, ca_only=ca_only, backbone_only=(not ca_only))
chain_length = len(target_protein.aatype)
num_aa = residue_constants.restype_num + 1  # add 1 for X

# Wrap the parsed protein together with per-residue design inputs.
# NOTE(review): the all-ones design mask and zeroed bias/PSSM arrays look like
# "design every position, no constraints" — confirm against DesignableProtein.
target_structure = protein.DesignableProtein(
    design_mask=np.ones(chain_length),
    design_aatype_mask=np.zeros([chain_length, num_aa], np.int32),
    pssm_coef=np.zeros(chain_length),
    pssm_bias=np.zeros([chain_length, num_aa]),
    pssm_log_odds=10000.0 * np.ones([chain_length, num_aa]),
    bias_per_residue=np.zeros([chain_length, num_aa]),
    # `asdict` builds a fresh copy of the fields, so `target_protein` itself
    # is not shared with `target_structure`.
    **asdict(target_protein),
)

In [8]:
# Pick the pretrained checkpoint that matches the backbone mode, load it onto
# the CPU, and report the noise level stored alongside the weights.
url = (
    "https://github.com/dauparas/ProteinMPNN/raw/main/ca_model_weights/v_48_020.pt"
    if ca_only
    else "https://github.com/dauparas/ProteinMPNN/raw/main/vanilla_model_weights/v_48_002.pt"
)

checkpoint = torch.hub.load_state_dict_from_url(url, map_location="cpu")
state_dict, noise_level_print = checkpoint["model_state_dict"], checkpoint["noise_level"]
print(f"Training noise level: {noise_level_print}A")

Training noise level: 0.2A


In [9]:
# Select the compute device. Both branches now yield a `torch.device`, so
# downstream code always sees the same type (the CUDA branch used to produce
# the bare integer index returned by `torch.cuda.current_device()`).
if torch.cuda.is_available():
    device = torch.device("cuda", torch.cuda.current_device())
else:
    device = torch.device("cpu")

In [10]:
# Build ProteinMPNN with the configuration that matches the backbone mode.
model_cfg = config.ProteinMPNNCAOnlyConfig() if ca_only else config.ProteinMPNNConfig()
model = ProteinMPNN(cfg=model_cfg)

In [11]:
# Load the pretrained weights (keeping the load report in `msg`), then move
# the model to `device` and switch it to evaluation mode.
msg = model.load_state_dict(state_dict)
model = model.to(device).eval()

In [12]:
# Seed every RNG in play so sampling is repeatable.
seed = 1227
torch.manual_seed(seed)
random.seed(seed)
np.random.seed(seed)

# Sampling configuration. These names are read by this cell and by the
# sampling loop but were not defined in any visible cell, so a fresh-kernel
# run stopped with NameError. The values mirror the upstream ProteinMPNN
# command-line defaults — adjust them here as needed.
omit_aas = "X"            # letters of `alphabet` that must never be sampled
num_seq_per_target = 1    # total sequences to draw per temperature
batch_size = 1            # sequences drawn per model call
temperatures = [0.1]      # sampling temperatures to sweep
pssm_threshold = 0.0      # cut-off applied to the PSSM log-odds tensor
pssm_multi = 0.0          # forwarded to `model.sample` as `pssm_multi`

alphabet = "ACDEFGHIKLMNPQRSTVWYX"
alphabet_dict = dict(zip(alphabet, range(21)))
# Membership is tested against the string itself: wrapping it in a list
# (`[omit_aas]`) only ever matched when `omit_aas` was a single letter.
omit_aas_np = np.array([aa in omit_aas for aa in alphabet]).astype(np.float32)
bias_aas_np = np.zeros(len(alphabet))


NUM_BATCHES = num_seq_per_target // batch_size
BATCH_COPIES = batch_size
if NUM_BATCHES < 1:
    # With zero batches the sampling loop would silently produce nothing and
    # leave `seq` undefined for the cells that follow.
    raise ValueError(
        f"num_seq_per_target ({num_seq_per_target}) must be >= batch_size ({batch_size})"
    )

# Accumulators filled by the sampling loop.
test_sum, test_weights = 0.0, 0.0
score_list = []
global_score_list = []
all_probs_list = []
all_log_probs_list = []
S_sample_list = []
native_seqs = []
# One independent copy of the target per batch slot.
batch_clones = [copy.deepcopy(target_structure) for _ in range(BATCH_COPIES)]

In [13]:
# Sample sequences for the target backbone at each temperature and score them.
# `temperatures`, `pssm_threshold` and `pssm_multi` must be defined before this
# cell runs.
designed_seqs = []        # decoded sequence of every sample, in sampling order
seq_recovery_rates = []   # per-sample match rate against `tf_out.sequence`

with torch.no_grad():
    tf_out = tied_featurize(
        batch_clones,
        device,
        chain_dict=None,
        ca_only=ca_only,
    )

    pssm_log_odds_mask = (tf_out.pssm_log_odds_all > pssm_threshold).float()
    # Loop-invariant: positions that count towards the per-sample score.
    mask_for_loss = tf_out.mask * tf_out.chain_M * tf_out.chain_M_pos

    for temp in temperatures:
        for j in range(NUM_BATCHES):
            # The same noise tensor feeds both the sampling and scoring passes.
            randn_2 = torch.randn(tf_out.chain_M.shape, device=device)
            sample_dict = model.sample(
                tf_out.atom_positions,
                randn_2,
                tf_out.sequence,
                tf_out.chain_M,
                tf_out.chain_encoding_all,
                tf_out.residue_idx,
                mask=tf_out.mask,
                temperature=temp,
                omit_AAs_np=omit_aas_np,
                bias_AAs_np=bias_aas_np,
                chain_M_pos=tf_out.chain_M_pos,
                omit_AA_mask=tf_out.omit_AA_mask,
                pssm_coef=tf_out.pssm_coef_all,
                pssm_bias=tf_out.pssm_bias_all,
                pssm_multi=pssm_multi,
                pssm_log_odds_flag=False,
                pssm_log_odds_mask=pssm_log_odds_mask,
                pssm_bias_flag=False,
                bias_by_res=tf_out.bias_by_res_all,
            )
            S_sample = sample_dict["S"]

            # Re-run the model on the sampled sequence, reusing the decoding
            # order from sampling, to obtain log-probabilities for scoring.
            log_probs = model(
                tf_out.atom_positions,
                S_sample,
                tf_out.mask,
                tf_out.chain_M * tf_out.chain_M_pos,
                tf_out.residue_idx,
                tf_out.chain_encoding_all,
                randn_2,
                use_input_decoding_order=True,
                decoding_order=sample_dict["decoding_order"],
            )
            scores = get_sequence_scores(S_sample, log_probs, mask_for_loss)
            scores = scores.detach().cpu().numpy()

            global_scores = get_sequence_scores(S_sample, log_probs, tf_out.mask)
            global_scores = global_scores.detach().cpu().numpy()

            all_probs_list.append(sample_dict["probs"].detach().cpu().numpy())
            all_log_probs_list.append(log_probs.detach().cpu().numpy())
            S_sample_list.append(S_sample.detach().cpu().numpy())

            for b_ix in range(BATCH_COPIES):
                # The one-hot product summed over the class axis is 1 where the
                # sampled residue equals the input residue, 0 elsewhere.
                matches = torch.sum(
                    F.one_hot(tf_out.sequence[b_ix], 21) * F.one_hot(S_sample[b_ix], 21),
                    dim=-1,
                )
                seq_recovery_rate = torch.sum(matches * mask_for_loss[b_ix]) / torch.sum(
                    mask_for_loss[b_ix]
                )
                seq_recovery_rates.append(seq_recovery_rate.item())

                # `seq` stays bound to the most recent sample because the
                # display and folding cells further down read it.
                seq = decode_sequence(S_sample[b_ix], tf_out.chain_M[b_ix])
                designed_seqs.append(seq)

                score_list.append(scores[b_ix])
                global_score_list.append(global_scores[b_ix])

In [14]:
# Display the last sequence decoded by the sampling loop.
seq

'SAAARIRRALAEARRARRRAEEARRRAEEARARGDLAAARRALAEARRAERRAREAERRAEELRRRLLAPPRR'

In [15]:
# Fold the designed sequence with OmegaFold and serialise the prediction to
# PDB text for the viewer. `confidence` is not used by any visible cell.
predicted_protein, confidence = folder.fold(seq)
result_pdb = protein.to_pdb(predicted_protein)

In [16]:
# Confidence bands as (lower bound, upper bound, hex colour).
PLDDT_BANDS = [
    (0, 50, '#FF7D45'),
    (50, 70, '#FFDB13'),
    (70, 90, '#65CBF3'),
    (90, 100, '#0053D6'),
]

# Colour lookup keyed by each band's position in PLDDT_BANDS.
# NOTE(review): the keys are band indices (0-3), not pLDDT values, so colouring
# by the 'b' property only lines up if the B-factor column of `result_pdb`
# holds band indices — confirm what `protein.to_pdb` writes there.
color_map = {
    band_index: band_color
    for band_index, (_low, _high, band_color) in enumerate(PLDDT_BANDS)
}

# Cartoon coloured via the B-factor property, plus default-styled sticks.
style = {
    'cartoon': {'colorscheme': {'prop': 'b', 'map': color_map}},
    'stick': {},
}

view = py3Dmol.view(width=800, height=600)
view.addModelsAsFrames(result_pdb)
view.setStyle({'model': -1}, style)
view.zoomTo()

<py3Dmol.view at 0x7f554e7b7a90>