In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
#!wget https://files.ipd.uw.edu/pub/RoseTTAFold/weights.tar.gz
#!tar xfz weights.tar.gz

In [3]:
import os
import sys
import time
from collections import namedtuple

import numpy as np
import torch
import torch.nn as nn
from torch.utils import data
import py3Dmol

In [4]:
import proteome.models.folding.rosettafold.util as util

In [5]:
#from proteome.models.folding.rosettafold.ffindex import *
from proteome.models.folding.rosettafold.kinematics import xyz_to_t2d
from proteome.models.folding.rosettafold.parsers import parse_a3m, read_templates
from proteome.models.folding.rosettafold.rosettafoldmodel import RoseTTAFold
from proteome.models.folding.rosettafold.trfold import TRFold
from proteome.models.folding.rosettafold.config import RoseTTAFoldConfig, TRFoldConfig

  from .autonotebook import tqdm as notebook_tqdm


In [6]:
def extend(a, b, c, L, A, D):
    """
    Place a fourth point from three reference points and internal coordinates.

    input:  3 coords (a,b,c), (L)ength, (A)ngle, and (D)ihedral
    output: 4th coord

    The new point sits at distance L from c. A is the angle b-c-new and D is the
    dihedral a-b-c-new. A and D are in radians (they go through np.cos/np.sin).
    The inputs may be single points of shape (3,) or stacked arrays (..., 3).
    Operations run along the last axis, so batches are handled elementwise.
    """
    def _unit(vec):
        # The small epsilon avoids division by zero for degenerate (zero-length) vectors.
        return vec / np.sqrt(np.square(vec).sum(-1, keepdims=True) + 1e-8)

    # Local orthonormal frame anchored at c.
    axis_bc = _unit(b - c)
    normal = _unit(np.cross(b - a, axis_bc))
    basis = (axis_bc, np.cross(normal, axis_bc), normal)

    # Components of the new offset along each basis vector.
    coeffs = (
        L * np.cos(A),
        L * np.sin(A) * np.cos(D),
        -L * np.sin(A) * np.sin(D),
    )
    offset = sum(vec * coef for vec, coef in zip(basis, coeffs))
    return c + offset

In [7]:
# Pick the compute device: the current CUDA device index when a GPU is available,
# otherwise the CPU. Without the CPU fallback, `device` would be undefined on
# CPU-only machines and every later `.to(device)` would raise NameError.
if torch.cuda.is_available():
    device = torch.cuda.current_device()
else:
    device = "cpu"

In [8]:
# Build RoseTTAFold with its default configuration and load the end-to-end checkpoint.
# NOTE(review): torch.load unpickles the file, which can execute arbitrary code.
# Only load weights from a trusted source.
# map_location="cpu" lets the checkpoint load on CPU-only machines even if it was
# saved from GPU tensors. load_state_dict then copies the weights into the
# parameters, which already live on `device`.
model = RoseTTAFold(RoseTTAFoldConfig()).to(device)
state_dict = torch.load("weights/RoseTTAFold_e2e.pt", map_location="cpu")["model_state_dict"]
# strict=True by default: raises if keys are missing or unexpected.
msg = model.load_state_dict(state_dict)
model = model.eval()

In [9]:
# Parse the query alignment into a 2-D array of shape (N, L), presumably N aligned
# sequences of length L. Then create empty template placeholders, all float32:
#   xyz_t -> (1, L, 3, 3), filled with NaN (no template coordinates)
#   t1d   -> (1, L, 3), zeros
#   t0d   -> (1, 3), zeros
msa = parse_a3m("t000.a3m")
N, L = msa.shape
t0d = torch.zeros((1, 3), dtype=torch.float32)
t1d = torch.zeros((1, L, 3), dtype=torch.float32)
xyz_t = torch.full((1, L, 3, 3), np.nan, dtype=torch.float32)

In [10]:
# Query inputs:
#   msa     -> (1, N, L) long tensor
#   seq     -> first MSA row, (1, L)
#   idx_pdb -> residue indices 0..L-1, (1, L) long tensor
msa = torch.tensor(msa).long().view(1, -1, L)
seq = msa[:, 0]
idx_pdb = torch.arange(L, dtype=torch.long).view(1, L)

# Template features: prepend a batch dimension to each placeholder, then derive
# the pairwise template features t2d from the coordinates and t0d.
# NOTE(review): this cell rebinds xyz_t/t1d/t0d. Re-running it without re-running
# the previous cell adds another leading dimension.
xyz_t, t1d, t0d = (feat.float().unsqueeze(0) for feat in (xyz_t, t1d, t0d))
t2d = xyz_to_t2d(xyz_t, t0d)

In [12]:
#for pname,param in model.named_parameters():
#    if "se3" in pname:
#        print(pname)
#        break

**TODO:** Can we adopt the AlphaFold/OpenFold chunking utilities here and unify with them? That would let us drop the ad-hoc cropping in the next cell, or at least leave the codebase with a single implementation of it.

In [13]:
# Crop the inputs and move them to `device`:
#   - keep at most MAX_MSA_SEQS rows (dim 1) of the (1, N, L) MSA
#   - keep at most MAX_TEMPLATES entries along dim 1 of the template features
#     (presumably the template axis; TODO confirm)
# Slicing is idempotent, so re-running this cell is safe.
MAX_MSA_SEQS = 1000
MAX_TEMPLATES = 10

msa = msa[:, :MAX_MSA_SEQS].to(device)
seq = msa[:, 0]  # re-derive from the cropped, on-device MSA

idx_pdb = idx_pdb.to(device)
t1d = t1d[:, :MAX_TEMPLATES].to(device)
t2d = t2d[:, :MAX_TEMPLATES].to(device)

In [14]:
# Forward pass of the end-to-end network with autograd disabled (refine=False).
# The three outputs are bound to prob_s, xyz and lddt for the following cells.
with torch.no_grad():
    prob_s, xyz, lddt = model(
        msa,
        seq,
        idx_pdb,
        t1d=t1d,
        t2d=t2d,
        refine=False,
    )

  assert input.numel() == input.storage().size(), (


In [15]:
# Turn each raw probability output into a normalized (-1, L, L) tensor on `device`
# for TRFold:
#   1. Add a small epsilon so no entry is exactly zero.
#   2. Divide by the sum over dim 0, so the values for every (i, j) pair sum to 1.
# The reshape runs directly on the tensor; there is no NumPy/host round trip.
# The epsilon is added out of place so the tensors in prob_s are not modified.
prob_trF = []
for prob in prob_s:
    prob = prob.detach().reshape(-1, L, L).to(device) + 1e-8
    prob = prob / prob.sum(dim=0, keepdim=True)
    prob_trF.append(prob)

In [16]:
# Refine the structure with TRFold, using the normalized probability maps in prob_trF.
# fold() receives index 1 of the third axis of the first batch element
# (presumably the CA atom; TODO confirm). The result is detached and converted to
# a NumPy array on the host.
# NOTE(review): this cell rebinds `xyz`. Re-running it without re-running the
# inference cell would slice the already-sliced array again.
xyz = xyz[0, :, 1]
trf = TRFold(TRFoldConfig(), device)
xyz = (
    trf.fold(xyz, prob_trF, batch=15, lr=0.1, nsteps=200)
    .detach()
    .cpu()
    .numpy()
)

In [17]:
# Add the backbone carbonyl O as a 4th atom for each residue. Only O is added; no Cb.
# Atoms 0, 1, 2 of each residue are taken as N, CA and C. O is placed with
# `extend` from the *next* residue's N (np.roll(..., -1)) plus this residue's
# CA and C.
# NOTE(review): np.roll wraps around, so the last residue's O is placed using the
# first residue's N. This is a boundary artifact at the C-terminus.
# NOTE(review): re-running this cell appends another O column to `xyz`.
O_BOND_LENGTH = 1.231  # C=O bond length (presumably in Å)
O_BOND_ANGLE = 2.108   # CA-C-O angle, radians (~120.8°)
O_DIHEDRAL = -3.142    # dihedral N(i+1)-CA-C-O, radians (~ -π)

N = xyz[:, 0, :]
CA = xyz[:, 1, :]
C = xyz[:, 2, :]
O = extend(np.roll(N, -1, axis=0), CA, C, O_BOND_LENGTH, O_BOND_ANGLE, O_DIHEDRAL)
xyz = np.concatenate((xyz, O[:, None, :]), axis=1)

In [22]:
#write_pdb(seq[0], xyz, idx_pdb[0], Bfacts=lddt[0], prefix="./result")

In [30]:
# Read the predicted structure back as PDB text for visualization.
# NOTE(review): the write_pdb call that would create ./result.pdb is commented out
# in this notebook, so on a fresh kernel the file may be missing or stale.
# The name `relaxed_pdb` is kept for the viewer cell, although this notebook has
# no relaxation step.
with open("./result.pdb") as f:
    relaxed_pdb = f.read()

In [33]:
# pLDDT confidence bands: (lower bound, upper bound, hex colour) on a 0-100 scale.
PLDDT_BANDS = [
  (0, 50, '#FF7D45'),
  (50, 70, '#FFDB13'),
  (70, 90, '#65CBF3'),
  (90, 100, '#0053D6')
]


def plddt_band_index(plddt, bands=PLDDT_BANDS):
    """Return the index of the first band whose [lower, upper] range contains `plddt`.

    Values outside every band are clamped to the first or last band.
    """
    for idx, (lower, upper, _color) in enumerate(bands):
        if lower <= plddt <= upper:
            return idx
    return 0 if plddt < bands[0][0] else len(bands) - 1


def band_pdb_bfactors(pdb_str, bands=PLDDT_BANDS):
    """Return `pdb_str` with each ATOM/HETATM B-factor replaced by its band index.

    py3Dmol's ``'map'`` colorscheme looks colours up by exact property value.
    The B-factors must therefore be the band indices 0..len(bands)-1 for the
    colour map to apply.

    If every B-factor is <= 1, the values are assumed to be confidences in [0, 1]
    and are scaled by 100 to match the bands. This is a heuristic; TODO confirm
    the scale written by the PDB writer. Records without a B-factor field are left
    unchanged.
    """
    lines = pdb_str.splitlines()
    # PDB temperature factor: columns 61-66 (1-based), i.e. [60:66].
    atom_rows = [
        i for i, line in enumerate(lines)
        if line.startswith(("ATOM", "HETATM")) and line[60:66].strip()
    ]
    values = [float(lines[i][60:66]) for i in atom_rows]
    scale = 100.0 if values and max(values) <= 1.0 else 1.0
    for i, value in zip(atom_rows, values):
        band = plddt_band_index(value * scale, bands)
        lines[i] = f"{lines[i][:60]}{band:6.2f}{lines[i][66:]}"
    banded = "\n".join(lines)
    return banded + "\n" if pdb_str.endswith("\n") else banded


view = py3Dmol.view(width=800, height=600)
# Rewrite the B-factors to band indices so the 'map' colorscheme below matches them.
view.addModelsAsFrames(band_pdb_bfactors(relaxed_pdb))
color_map = {i: bands[2] for i, bands in enumerate(PLDDT_BANDS)}
style = {'cartoon': {'colorscheme': {'prop': 'b', 'map': color_map}}}

style['stick'] = {}

view.setStyle({'model': -1}, style)
view.zoomTo()

<py3Dmol.view at 0x7f5d4a4b6140>