In [1]:
# Enable IPython's autoreload extension in mode 2: imported modules are
# reloaded before each cell executes, so edits to library code are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
# One-time setup (disabled): uncomment the two lines below to download and
# unpack the pretrained RoseTTAFold weights. The model-loading cell reads
# "weights/RoseTTAFold_e2e.pt", so the archive presumably unpacks into
# ./weights/ — TODO confirm.
#!wget https://files.ipd.uw.edu/pub/RoseTTAFold/weights.tar.gz
#!tar xfz weights.tar.gz

In [3]:
import numpy as np
import py3Dmol
import torch
import torch.nn as nn

In [5]:
from proteome import protein
from proteome.models.folding.rosettafold.kinematics import xyz_to_t2d
from proteome.models.folding.rosettafold.parsers import parse_a3m
from proteome.models.folding.rosettafold.rosettafoldmodel import RoseTTAFold
from proteome.models.folding.rosettafold.trfold import TRFold
from proteome.models.folding.rosettafold.config import RoseTTAFoldConfig, TRFoldConfig

  from .autonotebook import tqdm as notebook_tqdm


In [8]:
# Select the compute device used for the model and every input tensor.
# `device` must always be bound: without the CPU fallback a machine without
# CUDA fails with a NameError at the first `.to(device)` call.
if torch.cuda.is_available():
    # Index of the currently selected CUDA device (an int accepted by `.to`).
    device = torch.cuda.current_device()
else:
    device = torch.device("cpu")

In [9]:
# Build the network, load the pretrained end-to-end weights, then move the
# model to the target device once and switch it to evaluation mode.
model = RoseTTAFold(RoseTTAFoldConfig())
# map_location="cpu" lets a checkpoint that was saved from a GPU be read on any
# machine and avoids materialising a second copy of the weights on the device;
# load_state_dict copies the values into the model's own parameters.
# NOTE(review): torch.load unpickles the file — only load checkpoints that
# come from a trusted source.
state_dict = torch.load("weights/RoseTTAFold_e2e.pt", map_location="cpu")["model_state_dict"]
# `msg` holds the load report (missing / unexpected keys) for inspection.
msg = model.load_state_dict(state_dict)
model = model.to(device).eval()

In [10]:
# Read the input alignment from disk.
msa = parse_a3m("t000.a3m")
# The alignment is 2-D; presumably (number of sequences, sequence length) —
# TODO confirm against parse_a3m.
N, L = msa.shape

# Placeholder template inputs: coordinates filled with NaN and all-zero
# 1D / 0D features, each created directly as float32.
xyz_t = torch.full((1, L, 3, 3), float("nan"), dtype=torch.float32)
t1d = torch.zeros(1, L, 3, dtype=torch.float32)
t0d = torch.zeros(1, 3, dtype=torch.float32)

In [11]:
# Convert the alignment to a batched int64 tensor of shape (1, n_rows, L) and
# build the matching residue index and query sequence (first alignment row).
# torch.as_tensor accepts both the parsed array and an already-converted
# tensor, so re-running this cell does not emit a copy-construct warning.
msa = torch.as_tensor(msa).long().view(1, -1, L)
idx_pdb = torch.arange(L).long().view(1, L)
seq = msa[:, 0]

# template features
# Add the leading batch axis by reshaping to the fixed target shapes instead of
# unsqueeze(0): the first run produces exactly the same tensors, and running
# the cell again is a no-op rather than stacking on yet another axis.
xyz_t = xyz_t.float().reshape(1, 1, L, 3, 3)
t1d = t1d.float().reshape(1, 1, L, 3)
t0d = t0d.float().reshape(1, 1, 3)
t2d = xyz_to_t2d(xyz_t, t0d)

**TODO:** Can we adopt the AlphaFold/OpenFold chunking approach here and unify with it, so that the ugly manual cropping steps below become unnecessary — or, at the very least, so that the codebase has only one method for doing this?

In [13]:
# Crop the inputs to bounded sizes and move everything to the compute device.
# The limits are named so they can be tuned in one place instead of being
# buried as magic numbers in the slices.
MAX_MSA_ROWS = 1000  # keep at most this many alignment rows (axis 1 of msa)
MAX_TEMPLATES = 10   # cap on axis 1 of t1d / t2d — presumably the template axis, TODO confirm

msa = msa[:, :MAX_MSA_ROWS].to(device)
# Re-derive the query sequence so it lives on the same device as msa.
seq = msa[:, 0]

idx_pdb = idx_pdb.to(device)
t1d = t1d[:, :MAX_TEMPLATES].to(device)
t2d = t2d[:, :MAX_TEMPLATES].to(device)

In [23]:
# Forward pass with gradient tracking disabled. The model returns three
# values, bound here as prob_s (iterated over in the next cell), xyz and lddt.
with torch.no_grad():
    prob_s, xyz, lddt = model(msa, seq, idx_pdb, t1d=t1d, t2d=t2d, refine=False)

In [25]:
# Normalise every predicted probability map for TRFold.
# Each entry is flattened to (-1, L, L), offset by a small epsilon so the
# normalisation never divides by zero, and rescaled so that axis 0 sums to 1
# at every (i, j) position.
# The previous device -> CPU -> numpy -> device round trip (with two permutes
# that cancelled each other) produced the same values; it is dropped to avoid
# a needless host synchronisation and copy.
prob_trF = []
for prob in prob_s:
    # Out-of-place add: the model output in prob_s is left untouched.
    prob = prob.reshape(-1, L, L).to(device) + 1e-8
    prob = prob / prob.sum(dim=0, keepdim=True)
    prob_trF.append(prob)

In [26]:
# Fold with TRFold, starting from the network's coordinates and guided by the
# normalised probability maps, then bring the result back as a NumPy array.
trf = TRFold(TRFoldConfig(), device)
# First batch element, atom slot 1 of every residue (presumably the C-alpha
# trace — TODO confirm against the model's output layout).
xyz = xyz[0, :, 1]
xyz = trf.fold(xyz, prob_trF, batch=15, lr=0.1, nsteps=200).detach().cpu().numpy()

In [179]:
# Assemble a Protein record from the folded coordinates.
xyzo = protein.add_oxygen_to_atom_positions(xyz)
predicted_protein = protein.Protein(
    atom_positions=xyzo,
    aatype=seq[0].cpu().numpy(),
    # Every atom is marked present: ones with xyzo's shape minus its last axis.
    atom_mask=np.ones_like(xyzo[..., 0]),
    # Residue numbering starts at 1.
    residue_index=idx_pdb[0].cpu().numpy() + 1,
    # One confidence value per residue, copied across 4 atom columns.
    b_factors=np.repeat(lddt[0].cpu().numpy()[:, None], 4, axis=1),
)

In [181]:
# Serialise the predicted structure (presumably to PDB-format text — it is
# handed to py3Dmol's addModelsAsFrames in the viewer cell).
of_pdb = protein.to_pdb(predicted_protein)

In [182]:
# Confidence bands as (lower bound, upper bound, hex colour). Only the colour
# of each band is used below; the bounds are not read.
PLDDT_BANDS = [
    (0, 50, '#FF7D45'),
    (50, 70, '#FFDB13'),
    (70, 90, '#65CBF3'),
    (90, 100, '#0053D6'),
]

# Map band position (0..3) -> colour, applied to the B-factor property.
# NOTE(review): the map is keyed by band index, while the B-factors written
# earlier come straight from `lddt`; confirm the B-factor values actually
# line up with these keys, otherwise the colouring will not reflect the bands.
color_map = dict(enumerate(color for _low, _high, color in PLDDT_BANDS))
style = {
    'cartoon': {'colorscheme': {'prop': 'b', 'map': color_map}},
    'stick': {},
}

view = py3Dmol.view(width=800, height=600)
view.addModelsAsFrames(of_pdb)
view.setStyle({'model': -1}, style)
# Last expression of the cell: its value is what the notebook displays.
view.zoomTo()

<py3Dmol.view at 0x7f13bff26b30>

In [106]:
# Disabled duplicate of the viewer cell above, pointed at `rf_pdb` instead of
# `of_pdb`. NOTE(review): `rf_pdb` is not defined anywhere in the visible part
# of this notebook, so this cell would fail if uncommented as-is; delete it or
# restore the cell that produced `rf_pdb`.
#PLDDT_BANDS = [
#  (0, 50, '#FF7D45'),
#  (50, 70, '#FFDB13'),
#  (70, 90, '#65CBF3'),
#  (90, 100, '#0053D6')
#]
#view = py3Dmol.view(width=800, height=600)
#view.addModelsAsFrames(rf_pdb)
#color_map = {i: bands[2] for i, bands in enumerate(PLDDT_BANDS)}
#style = {'cartoon': {'colorscheme': {'prop': 'b', 'map': color_map}}}

#style['stick'] = {}

#view.setStyle({'model': -1}, style)
#view.zoomTo()