In [13]:
import re
import os, time, pickle
import torch
from omegaconf import OmegaConf
import hydra
import logging
from rfdiffusion.util import writepdb_multi, writepdb
from rfdiffusion.inference import utils as iu, symmetry
from hydra.core.hydra_config import HydraConfig
import numpy as np
import random
import glob
import string
from random import randint, choice

In [14]:
with hydra.initialize(version_base=None, config_path="./config/inference", job_name="base"):
    conf = hydra.compose(config_name="base")

contigs = '5/Y471-485/10-30/X95-101/10-30/X246-251/10-30/X265-271/1/X273-278/10-30/X313-322/5/0' #@param {type:"string"}
pdb = './examples/input_pdbs/1DBB_interface_30A.pdb' #@param {type:"string"}
num_designs = 1 #@param ["1", "2", "4", "8", "10", "16", "32"] {type:"raw"}
symmetry = "C4" #@param {type:"string"}

contigs_lst = []
for subcon in contigs.split("/"):
  if "-" in subcon and subcon[0].isdigit():
    length_inpaint = random.randint(int(subcon.split("-")[0]), int(subcon.split("-")[1]))
    contigs_lst.append(str(length_inpaint))
  else:
    contigs_lst.append(subcon)
new_contigs = "/".join(contigs_lst)


conf.inference.asy_motif = True
conf.contigmap.contigs = [new_contigs]
conf.inference.input_pdb = pdb
conf.inference.symmetry = symmetry

sampler = iu.sampler_selector(conf)

Reading models from /home/hychim/software/RFdiffusion_asy/rfdiffusion/inference/../../models
This is inf_conf.ckpt_path
/home/hychim/software/RFdiffusion_asy/rfdiffusion/inference/../../models/Base_ckpt.pt
Assembling -model, -diffuser and -preprocess configs from checkpoint
USING MODEL CONFIG: self._conf[model][n_extra_block] = 4
USING MODEL CONFIG: self._conf[model][n_main_block] = 32
USING MODEL CONFIG: self._conf[model][n_ref_block] = 4
USING MODEL CONFIG: self._conf[model][d_msa] = 256
USING MODEL CONFIG: self._conf[model][d_msa_full] = 64
USING MODEL CONFIG: self._conf[model][d_pair] = 128
USING MODEL CONFIG: self._conf[model][d_templ] = 64
USING MODEL CONFIG: self._conf[model][n_head_msa] = 8
USING MODEL CONFIG: self._conf[model][n_head_pair] = 4
USING MODEL CONFIG: self._conf[model][n_head_templ] = 4
USING MODEL CONFIG: self._conf[model][d_hidden] = 32
USING MODEL CONFIG: self._conf[model][d_hidden_templ] = 32
USING MODEL CONFIG: self._conf[model][p_drop] = 0.15
USING MODEL CONF

In [4]:
#@title Run RFdiffusion
x_init, seq_init = sampler.sample_init()
denoised_xyz_stack = []
px0_xyz_stack = []
seq_stack = []
plddt_stack = []

x_t = torch.clone(x_init)
seq_t = torch.clone(seq_init)

for t in range(int(sampler.t_step_input), sampler.inf_conf.final_step - 1, -1):
  print(t)
  if t != sampler.t_step_input:
      subunit_com_lst = []
      un_mask_1d = [not elem for elem in sampler.contig_map.mask_1d]
      x_t_unmasked = x_t[un_mask_1d]
      unmasked_subunit_len = int(x_t_unmasked.shape[0]/sampler.symmetry.order)
      sub_uppercase = string.ascii_uppercase[:sampler.symmetry.order*2]

      for order in range(sampler.symmetry.order):
          unmasked_start_i = unmasked_subunit_len*order
          unmasked_end_i = unmasked_subunit_len*(order+1)
          subunit_com_lst.append(x_t_unmasked[unmasked_start_i:unmasked_end_i].mean(dim=0))

      subunit_pair_com_lst = [((subunit_com_lst[order-1].nan_to_num()+subunit_com_lst[order].nan_to_num())/2) for order in range(sampler.symmetry.order)]
      subunit_pair_com_lst.append(subunit_pair_com_lst.pop(0)) # rotate list

      for order in range(sampler.symmetry.order):
          interface_id = []

          interface_A = string.ascii_uppercase[order*2]
          interface_B = string.ascii_uppercase[order*2+1]

          for i, ref in enumerate(sampler.contig_map.ref):
              if ref[0] == interface_A or ref[0] == interface_B:
                  interface_id.append(i)

          interface_com = x_t[interface_id].mean(dim=0)
          com_diff = subunit_pair_com_lst[order] - interface_com

          for i, ref in enumerate(sampler.contig_map.ref):
              if (ref[0] == interface_A or ref[0] == interface_B):
                  x_t[i] = x_t[i] + com_diff*0.01
                  
      x_t = x_t.nan_to_num()

  px0, x_t, seq_t, plddt = sampler.sample_step(
      t=t, x_t=x_t, seq_init=seq_t, final_step=sampler.inf_conf.final_step
  )
  px0_xyz_stack.append(px0)

  denoised_xyz_stack.append(x_t)
  seq_stack.append(seq_t)
  plddt_stack.append(plddt[0])  # remove singleton leading dimension


IndentationError: unindent does not match any outer indentation level (<tokenize>, line 45)

In [10]:
target_feats = iu.process_target(sampler.inf_conf.input_pdb, parse_hetatom=True, center=False)

In [94]:
target_feats.keys()

dict_keys(['xyz_27', 'mask_27', 'seq', 'pdb_idx', 'xyz_het', 'info_het'])

In [11]:
target_feats['xyz_27']

tensor([[[-10.9840,  25.0500,   7.2450],
         [-11.0090,  26.1810,   6.2970],
         [-12.3870,  26.4620,   5.6970],
         ...,
         [     nan,      nan,      nan],
         [     nan,      nan,      nan],
         [     nan,      nan,      nan]],

        [[-12.5890,  27.6670,   5.0820],
         [-13.8760,  27.9990,   4.4560],
         [-14.1790,  26.9830,   3.3130],
         ...,
         [     nan,      nan,      nan],
         [     nan,      nan,      nan],
         [     nan,      nan,      nan]],

        [[ -6.0100,  22.4990,   7.2450],
         [ -4.6760,  22.3610,   6.7110],
         [ -3.6490,  22.3000,   7.8230],
         ...,
         [     nan,      nan,      nan],
         [     nan,      nan,      nan],
         [     nan,      nan,      nan]],

        ...,

        [[ -3.2380,  30.3300,  -9.5340],
         [ -4.2570,  29.3840,  -9.9720],
         [ -4.7890,  29.7500, -11.3530],
         ...,
         [     nan,      nan,      nan],
         [     nan,   

In [45]:
target_feats_xyz_27 = target_feats['xyz_27']

In [25]:
sampler.symmetry

<rfdiffusion.inference.symmetry.SymGen at 0x7fbf2a31b970>

In [101]:
rot_x = 0
rot_y = 0
rot_z = 180
rot = np.array([[np.cos(rot_y)*np.cos(rot_z),   np.sin(rot_x)*np.sin(rot_y)*np.cos(rot_z)-np.cos(rot_x)*np.sin(rot_z),  np.cos(rot_x)*np.sin(rot_y)*np.cos(rot_z)+np.sin(rot_x)*np.sin(rot_z)],
                [np.cos(rot_y)*np.sin(rot_z),   np.sin(rot_x)*np.sin(rot_y)*np.sin(rot_z)+np.cos(rot_x)*np.cos(rot_z),  np.cos(rot_x)*np.sin(rot_y)*np.sin(rot_z)-np.sin(rot_x)*np.cos(rot_z)],
                [-np.sin(rot_y)             ,   np.sin(rot_x)*np.cos(rot_y)                                          ,  np.cos(rot_x)*np.cos(rot_y)]], dtype=np.float32)

In [102]:
target_feats_xyz_27[0]

tensor([[-10.9840,  25.0500,   7.2450],
        [-11.0090,  26.1810,   6.2970],
        [-12.3870,  26.4620,   5.6970],
        [-13.2460,  25.5460,   5.7410],
        [-10.0170,  25.9340,   5.1570],
        [ -9.7170,  27.1560,   4.3440],
        [-10.5720,  27.6290,   3.3700],
        [ -8.6600,  27.9990,   4.3590],
        [-10.0500,  28.7130,   2.8220],
        [ -8.8920,  28.9580,   3.4060],
        [  0.0000,   0.0000,   0.0000],
        [  0.0000,   0.0000,   0.0000],
        [  0.0000,   0.0000,   0.0000],
        [  0.0000,   0.0000,   0.0000],
        [     nan,      nan,      nan],
        [     nan,      nan,      nan],
        [     nan,      nan,      nan],
        [     nan,      nan,      nan],
        [     nan,      nan,      nan],
        [     nan,      nan,      nan],
        [     nan,      nan,      nan],
        [     nan,      nan,      nan],
        [     nan,      nan,      nan],
        [     nan,      nan,      nan],
        [     nan,      nan,      nan],


In [103]:
torch.einsum('bnj,kj->bnk', target_feats_xyz_27, torch.from_numpy(rot))

tensor([[[ 26.6424,  -6.1916,   7.2450],
         [ 27.5634,  -6.8484,   6.2970],
         [ 28.6132,  -5.9126,   5.6970],
         ...,
         [     nan,      nan,      nan],
         [     nan,      nan,      nan],
         [     nan,      nan,      nan]],

        [[ 29.6995,  -6.4719,   5.0820],
         [ 30.7357,  -5.6395,   4.4560],
         [ 30.1031,  -4.7887,   3.3130],
         ...,
         [     nan,      nan,      nan],
         [     nan,      nan,      nan],
         [     nan,      nan,      nan]],

        [[ 21.6219,  -8.6498,   7.2450],
         [ 20.7130,  -9.6360,   6.7110],
         [ 20.0495, -10.4223,   7.8230],
         ...,
         [     nan,      nan,      nan],
         [     nan,      nan,      nan],
         [     nan,      nan,      nan]],

        ...,

        [[ 26.2368, -15.5572,  -9.5340],
         [ 26.0887, -14.1746,  -9.9720],
         [ 26.7003, -13.9675, -11.3530],
         ...,
         [     nan,      nan,      nan],
         [     nan,   