In [1]:
import torch
import yaml
from ml_collections import ConfigDict
from argparse import ArgumentParser
import os

from protein_ebm.model.r3_diffuser import R3Diffuser
from protein_ebm.data.protein_utils import residues_to_features, plot_protein_frame, restypes, restype_order, restype_num
from protein_ebm.model.ebm import ProteinEBM
from protein_ebm.model.boltz_utils import center_random_augmentation
import numpy as np



with open("../protein_ebm/config/base_pretrain.yaml", 'r') as f:
    config = yaml.safe_load(f)
    
config = ConfigDict(config)

# Create models
diffuser = R3Diffuser(config.diffuser)
model = ProteinEBM(config.model, diffuser).cuda()


# Load checkpoint
ckpt = torch.load("../weights/model_1_frozen_1m_md.pt", weights_only=False)



model.load_state_dict({k[len("model."):]: v for k, v in ckpt['state_dict'].items() if k.startswith('model')})

<All keys matched successfully>

In [70]:
from Bio.PDB import PDBParser
from Bio.PDB.Polypeptide import is_aa

#Load the structure
parser = PDBParser(QUIET=True)
structure = parser.get_structure("my_structure", "../eval_data/decoys/natives/2chf.pdb")

chain = [c for c in structure.get_chains()][0]
atom_positions, atom_mask, aatype, residue_idx = residues_to_features([r for r in chain.get_residues() if is_aa(r)])
nres = atom_positions.shape[0]

residue_mask = torch.ones(nres) 
aatype = torch.tensor(aatype, dtype=torch.long)
ca_coords = center_random_augmentation(atom_positions[...,1,:].unsqueeze(0), torch.ones([1, nres])).view([1,-1,3])

  aatype = torch.tensor(aatype, dtype=torch.long)


In [71]:
%matplotlib notebook
plot_protein_frame(atom_positions, atom_mask, cartoon=True)

<py3Dmol.view at 0x14b2649abc50>

In [74]:
num_t = 300
t_max = 1.0
min_t = .01
reverse_steps = np.linspace(min_t, t_max, num_t)[::-1]
dt = (t_max - min_t)/num_t


r_noisy, trans_score = diffuser.forward_marginal(
    ca_coords.numpy(),
    t_max)

r_noisy = torch.tensor(r_noisy, dtype=torch.float)

pos_t = r_noisy.cpu().numpy()
all_pos = [pos_t]

align_steps=False
self_condition=False
aux_score=False
prev0 = torch.zeros(pos_t.shape).cuda()

with torch.no_grad():
    for t in reverse_steps:
        if t > min_t:
            print(t)
            # Set timestep features

            input_feats = {
                'r_noisy': torch.tensor(pos_t, dtype=torch.float).cuda(),
                'aatype': aatype.unsqueeze(0).cuda(),
                'mask': residue_mask.unsqueeze(0).cuda(),
                'residue_idx': residue_idx.unsqueeze(0).cuda(),
                't': torch.tensor([t], dtype=torch.float).cuda(),
                'sc_ca_t' : prev0,
                'atom_mask' : atom_mask.unsqueeze(0).cuda()
            }
            
            #Get model predictions
            out = model.compute_score(input_feats)

            if aux_score:
                score = out['r_update_aux'].cpu().numpy()
            else:
                score = out['trans_score'].cpu().numpy()

            # Reverse diffusion step
            pos_t = diffuser.reverse(
                x_t=pos_t,
                score_t=score,
                mask=residue_mask.unsqueeze(0).numpy(),
                t=t,
                dt=dt,
                center=False,
                noise_scale=1.0,
            )

            # rotate
            pos_t, prev0_new = center_random_augmentation(torch.tensor(pos_t, dtype=torch.float).cuda(), torch.ones([1, nres]).cuda(), rotate=False, second_coords=out['pred_coords'].reshape([1,-1,3]),return_second_coords=True)
            pos_t = pos_t.cpu().numpy()
            
            if self_condition:
                prev0 = prev0_new

            
            all_pos.append(pos_t)

all_pos.append(out['pred_coords'].detach().cpu().numpy())


1.0
0.9966889632107023
0.9933779264214047
0.990066889632107
0.9867558528428093
0.9834448160535116
0.9801337792642141
0.9768227424749164
0.9735117056856187
0.9702006688963211
0.9668896321070234
0.9635785953177257
0.960267558528428
0.9569565217391304
0.9536454849498327
0.9503344481605351
0.9470234113712375
0.9437123745819398
0.9404013377926421
0.9370903010033445
0.9337792642140468
0.9304682274247491
0.9271571906354514
0.9238461538461539
0.9205351170568562
0.9172240802675585
0.9139130434782609
0.9106020066889632
0.9072909698996655
0.9039799331103678
0.9006688963210702
0.8973578595317726
0.8940468227424749
0.8907357859531773
0.8874247491638796
0.8841137123745819
0.8808026755852842
0.8774916387959866
0.8741806020066889
0.8708695652173913
0.8675585284280937
0.864247491638796
0.8609364548494983
0.8576254180602007
0.854314381270903
0.8510033444816053
0.8476923076923076
0.84438127090301
0.8410702341137124
0.8377591973244147
0.8344481605351171
0.8311371237458194
0.8278260869565217
0.824515050167

In [75]:
%matplotlib notebook
non_ca = torch.tensor(all_pos[-1],dtype=torch.float) +  out['sidechain_coords'].squeeze().cpu()
pos_allatom = torch.cat([non_ca[...,:1,:], torch.tensor(all_pos[-1], dtype=torch.float), non_ca[...,1:,:]], dim=-2)

plot_protein_frame(pos_allatom[0], atom_mask, cartoon=True) # remove batch dim

<py3Dmol.view at 0x14b26461ab50>