In [1]:
# Enable the autoreload extension (mode 2) so edits to imported modules are picked up live.
%load_ext autoreload
%autoreload 2

In [2]:
import os
import py3Dmol
import torch
from proteome import protein
from proteome.models.folding.esm import pretrained
from proteome.models.folding.openfold.utils.feats import atom14_to_atom37

In [3]:
# Query sequence; it is handed to model.infer as a one-element list further down.
sequence = 'MAAHKGAEHHHKAAEHHEQAAKHHHAAAEHHEKGEHEQAAHHADTAYAHHKHAEEHAAQAAKHDAEHHAPKPH'

In [4]:
# Build the pretrained esmfold_v0 model, then apply .half() and .eval() in a
# single chain (each call's return value feeds the next, exactly as the
# step-by-step rebinding would).
model = pretrained.esmfold_v0().half().eval()

# Chunk size used by the model; 512 is the value this notebook runs with.
model.set_chunk_size(512)

  from .autonotebook import tqdm as notebook_tqdm


In [5]:
# Move the model to the GPU via .cuda(); this cell needs a CUDA device to be available.
model = model.cuda()

In [6]:
# Run inference on a one-element batch holding `sequence`; the returned
# mapping is post-processed in the following cells.
output = model.infer([sequence])

In [7]:
# Convert the last entry of output["positions"] with atom14_to_atom37 and
# bring the result to host memory as a NumPy array.
final_atom_positions = atom14_to_atom37(output["positions"][-1], output).cpu().numpy()

# Rebind `output` to a dict of CPU NumPy arrays.
# NOTE(review): because `output` is rebound, re-running this cell without
# re-running inference first will presumably fail (the values are no longer
# tensors) - confirm before relying on re-execution.
output = dict((name, value.to("cpu").numpy()) for name, value in output.items())

final_atom_mask = output["atom37_atom_exists"]

In [8]:
# Squeeze every array stored in `output` (removing all length-1 axes),
# writing the results back into the same dict object.
output.update({name: array.squeeze() for name, array in output.items()})

In [9]:
# Assemble the Protein record from the post-processed model output.
predicted_protein = protein.Protein(
    # Taken straight from the (already squeezed) output dict.
    aatype=output["aatype"],
    # Residue indices are shifted up by one - presumably so numbering starts
    # at 1; confirm against protein.Protein's expectations.
    residue_index=output["residue_index"] + 1,
    # pLDDT values are stored in the b-factor field.
    b_factors=output["plddt"],
    # Optional entry: None when the output has no "chain_index" key.
    chain_index=output.get("chain_index"),
    # These two live outside `output`, so they are squeezed here.
    atom_positions=final_atom_positions.squeeze(),
    atom_mask=final_atom_mask.squeeze(),
)

In [10]:
# Inspect the pLDDT array that was used as b_factors when building predicted_protein.
output["plddt"]

array([[55.22, 53.84, 55.06, ..., 47.97, 38.34, 53.22],
       [62.34, 61.12, 62.34, ..., 52.  , 47.16, 52.62],
       [72.75, 71.1 , 72.9 , ..., 57.1 , 55.28, 52.1 ],
       ...,
       [67.56, 67.4 , 67.5 , ..., 51.56, 52.97, 56.3 ],
       [68.44, 68.3 , 68.44, ..., 55.12, 58.94, 53.66],
       [60.94, 60.34, 60.66, ..., 52.4 , 46.72, 52.88]], dtype=float16)

In [11]:
# Convert predicted_protein with protein.to_pdb; its b_factors field carries the raw pLDDT values.
result_pdb = protein.to_pdb(predicted_protein)

In [12]:
# pLDDT confidence bands as (lower bound, upper bound, hex colour).
PLDDT_BANDS = [
    (0, 50, '#FF7D45'),
    (50, 70, '#FFDB13'),
    (70, 90, '#65CBF3'),
    (90, 100, '#0053D6'),
]

# The colour map used below is keyed by band index (0..3) and is matched
# against the b-factor column.  predicted_protein stores raw pLDDT values
# (e.g. 55.22) in that column, which practically never equal a key, so the
# band colours would not be applied.  Build a visualisation-only copy whose
# b-factors are the band indices instead.
#
# A value falls into the first band whose upper bound it does not exceed, so
# its band index is the number of inner upper bounds (50, 70, 90) it is
# strictly greater than.
plddt = output["plddt"]
plddt_band_index = sum(
    (plddt > upper).astype("float32") for _, upper, _ in PLDDT_BANDS[:-1]
)

# Same fields as predicted_protein, except for the b-factor column.
banded_protein = protein.Protein(
    aatype=output["aatype"],
    atom_positions=final_atom_positions.squeeze(),
    atom_mask=final_atom_mask.squeeze(),
    residue_index=output["residue_index"] + 1,
    b_factors=plddt_band_index,
    chain_index=output["chain_index"] if "chain_index" in output else None,
)
banded_pdb = protein.to_pdb(banded_protein)

view = py3Dmol.view(width=800, height=600)
view.addModelsAsFrames(banded_pdb)

# Band index -> hex colour, looked up via the (now banded) b-factor property.
color_map = {i: bands[2] for i, bands in enumerate(PLDDT_BANDS)}
style = {
    'cartoon': {'colorscheme': {'prop': 'b', 'map': color_map}},
    # Also draw sticks, with default stick styling.
    'stick': {},
}

view.setStyle({'model': -1}, style)
# Last expression of the cell: its return value is what the notebook displays.
view.zoomTo()

<py3Dmol.view at 0x7fa6f3c29ff0>