In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
import py3Dmol
import torch
from proteome import protein
from proteome.models.folding.esm import pretrained
from proteome.models.folding.openfold.utils.feats import atom14_to_atom37

In [3]:
# Amino-acid sequence (one-letter codes) whose structure is predicted by `model.infer` below.
sequence = 'MAAHKGAEHHHKAAEHHEQAAKHHHAAAEHHEKGEHEQAAHHADTAYAHHKHAEEHAAQAAKHDAEHHAPKPH'

In [4]:
# Build the pretrained ESMFold v1 model, cast its weights to half precision and
# put it in evaluation mode, all in one chained expression.
model = pretrained.esmfold_v1().half().eval()

# Use a chunk size of 512; what exactly is chunked is defined by
# `set_chunk_size` in the model class, which is not shown in this notebook.
model.set_chunk_size(512)

  from .autonotebook import tqdm as notebook_tqdm


In [5]:
# Move the model to the GPU; this cell requires a CUDA device to be available.
model = model.cuda()

In [6]:
# Run inference on a batch that contains only `sequence`. The result is used
# below as a mapping (iterated with `.items()`) whose values are moved to the
# CPU and converted to NumPy arrays.
output = model.infer([sequence])

In [7]:
# Convert the last entry of output["positions"] with `atom14_to_atom37` while
# `output` still holds torch tensors, then copy the result to the CPU as NumPy.
final_atom_positions = atom14_to_atom37(output["positions"][-1], output).cpu().numpy()

# From here on `output` maps every key to a CPU NumPy array.
# NOTE(review): because `output` is rebound, re-running this cell without
# re-running the inference cell fails (NumPy arrays have no `.cpu()`).
output = {name: tensor.cpu().numpy() for name, tensor in output.items()}
final_atom_mask = output["atom37_atom_exists"]

In [8]:
# Remove all length-1 axes from every array stored in `output`, updating the
# dictionary in place (iterating over a snapshot of its keys).
for key in list(output):
    output[key] = output[key].squeeze()

In [9]:
# Bundle the prediction into a `protein.Protein` record.
predicted_protein = protein.Protein(
    aatype=output["aatype"],
    # Positions and mask were taken before `output` was squeezed, so squeeze them here.
    atom_positions=final_atom_positions.squeeze(),
    atom_mask=final_atom_mask.squeeze(),
    # Shift the model's residue indices up by one.
    residue_index=output["residue_index"] + 1,
    # The pLDDT confidence values are stored in the B-factor field.
    b_factors=output["plddt"],
    # `None` when the model output has no "chain_index" entry.
    chain_index=output.get("chain_index"),
)

In [12]:
# Inspect the pLDDT array that was stored as the B-factors of `predicted_protein`.
output["plddt"]

array([[57.38, 61.47, 61.  , ..., 61.9 , 39.34, 52.2 ],
       [71.56, 71.56, 73.1 , ..., 71.  , 49.6 , 55.62],
       [70.6 , 70.94, 73.44, ..., 70.9 , 47.8 , 55.2 ],
       ...,
       [63.1 , 63.56, 63.56, ..., 60.5 , 43.9 , 50.53],
       [64.25, 65.  , 64.6 , ..., 65.5 , 44.25, 52.25],
       [59.47, 64.06, 58.53, ..., 60.1 , 44.34, 52.75]], dtype=float16)

In [10]:
# Serialise the predicted structure with `protein.to_pdb`; the result is what
# the py3Dmol viewer below loads.
result_pdb = protein.to_pdb(predicted_protein)

In [11]:
# pLDDT confidence bands as (lower bound, upper bound, hex colour).
PLDDT_BANDS = [
  (0, 50, '#FF7D45'),
  (50, 70, '#FFDB13'),
  (70, 90, '#65CBF3'),
  (90, 100, '#0053D6')
]

# 3Dmol's `map` colour scheme looks up each atom's property value *exactly* in
# the map. The map below is keyed by band index (0..3), whereas the B-factors
# of `predicted_protein` are raw pLDDT values (e.g. 57.38), so colouring
# `result_pdb` directly would match almost no atom. Bin the pLDDT values into
# band indices first: each value's index is the number of band lower bounds
# (beyond the first band) that it reaches.
plddt = predicted_protein.b_factors
band_index = sum((plddt >= lower) for lower, _, _ in PLDDT_BANDS[1:])

# Copy of the prediction whose B-factors are band indices. It is used for
# display only; `result_pdb` keeps the real pLDDT values.
# NOTE(review): assumes `protein.Protein` exposes its constructor arguments as
# attributes of the same name - confirm against proteome.protein.
banded_protein = protein.Protein(
    aatype=predicted_protein.aatype,
    atom_positions=predicted_protein.atom_positions,
    atom_mask=predicted_protein.atom_mask,
    residue_index=predicted_protein.residue_index,
    b_factors=band_index,
    chain_index=predicted_protein.chain_index,
)
banded_pdb = protein.to_pdb(banded_protein)

view = py3Dmol.view(width=800, height=600)
view.addModelsAsFrames(banded_pdb)

# Band index -> band colour, matching the banded B-factor column.
color_map = {i: bands[2] for i, bands in enumerate(PLDDT_BANDS)}
style = {
    'cartoon': {'colorscheme': {'prop': 'b', 'map': color_map}},
    # Sticks are drawn with the viewer's default stick settings.
    'stick': {},
}

view.setStyle({'model': -1}, style)
# Last expression of the cell: the returned view is rendered by the notebook.
view.zoomTo()

<py3Dmol.view at 0x7feadd44aec0>