In [23]:
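# Base RFdiffusion checkpoint; the sampler below loads it as ./Base_ckpt.pt (model_directory_path=".").
# Uncomment to download it once:
#!wget http://files.ipd.uw.edu/pub/RFdiffusion/6f5902ac237024bdd0c176cb93063dc4/Base_ckpt.pt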

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import re
import os, time, pickle
import torch
from omegaconf import OmegaConf
from omegaconf import DictConfig
import hydra
import logging
import numpy as np
import random
import glob

from proteome.models.design.rfdiffusion.util import writepdb_multi, writepdb
from proteome.models.design.rfdiffusion.inference import utils as iu

In [3]:
from proteome.models.design.rfdiffusion import config
from dataclasses import asdict

In [4]:
def make_deterministic(seed=0):
    """Seed PyTorch, NumPy's global RNG and Python's ``random`` with the same value.

    Args:
        seed: integer seed applied to every generator (default 0).
    """
    for seed_fn in (torch.manual_seed, np.random.seed, random.seed):
        seed_fn(seed)

In [5]:
# Index of the active CUDA device, or "cpu" when CUDA is unavailable, so this setup
# cell no longer raises on a machine without a GPU.
device = torch.cuda.current_device() if torch.cuda.is_available() else "cpu"

In [6]:
make_deterministic(seed=0)  # seed every RNG before the sampler is built

In [46]:
# Unconditional-style run: input structure 1qys.pdb, checkpoint directory ".",
# contig spec "90-91".
inference_cfg = config.InferenceConfig(input_pdb="1qys.pdb", model_directory_path=".")
contig_cfg = config.ContigMap(contigs="90-91")
cfg = config.RFDiffusionConfig(inference=inference_cfg, contigmap=contig_cfg)
# The sampler consumes an OmegaConf DictConfig, so convert the dataclass tree to a dict first.
conf = DictConfig(asdict(cfg))

In [47]:
# Build the sampler from the config. Per the log below, it reads the checkpoint from
# model_directory_path and applies the checkpoint's model settings over conf.
sampler = iu.sampler_selector(conf)

Reading models from .
This is inf_conf.ckpt_path
./Base_ckpt.pt
Assembling -model, -diffuser and -preprocess configs from checkpoint
USING MODEL CONFIG: self._conf[model][n_extra_block] = 4
USING MODEL CONFIG: self._conf[model][n_main_block] = 32
USING MODEL CONFIG: self._conf[model][n_ref_block] = 4
USING MODEL CONFIG: self._conf[model][d_msa] = 256
USING MODEL CONFIG: self._conf[model][d_msa_full] = 64
USING MODEL CONFIG: self._conf[model][d_pair] = 128
USING MODEL CONFIG: self._conf[model][d_templ] = 64
USING MODEL CONFIG: self._conf[model][n_head_msa] = 8
USING MODEL CONFIG: self._conf[model][n_head_pair] = 4
USING MODEL CONFIG: self._conf[model][n_head_templ] = 4
USING MODEL CONFIG: self._conf[model][d_hidden] = 32
USING MODEL CONFIG: self._conf[model][d_hidden_templ] = 32
USING MODEL CONFIG: self._conf[model][p_drop] = 0.15
USING MODEL CONFIG: self._conf[model][SE3_param_full] = {'num_layers': 1, 'num_channels': 32, 'num_degrees': 2, 'n_heads': 4, 'div': 4, 'l0_in_features': 8, '

In [48]:
# Total number of model parameters (generator expression: no throwaway list).
sum(p.numel() for p in sampler.model.parameters())

59808046

In [49]:
def next_design_index(output_prefix: str) -> int:
    """Return one past the highest design index already written under ``output_prefix``.

    Design files are named ``<output_prefix>_<N>.pdb``. Files under the prefix that do not
    end in ``_<digits>.pdb`` are ignored. Returns 0 when no numbered design exists yet.
    """
    indices = [-1]
    # Escape the prefix so characters such as "[" or "*" in a path are matched literally.
    for path in glob.glob(glob.escape(output_prefix) + "*.pdb"):
        match = re.match(r".*_(\d+)\.pdb$", path)
        if match:
            indices.append(int(match.group(1)))
    return max(indices) + 1


# design_startnum == -1 means "continue after the designs already on disk".
design_startnum = sampler.inf_conf.design_startnum
if design_startnum == -1:
    design_startnum = next_design_index(sampler.inf_conf.output_prefix)

In [50]:
def _makedirs_for_prefix(path_prefix):
    """Create the parent directory of ``path_prefix`` if it has one.

    ``os.makedirs("")`` raises ``FileNotFoundError``, so a bare prefix such as ``"design"``
    (no directory component) is skipped rather than passed through.
    """
    parent = os.path.dirname(path_prefix)
    if parent:
        os.makedirs(parent, exist_ok=True)


for i_des in range(design_startnum, design_startnum + sampler.inf_conf.num_designs):
    if conf.inference.deterministic:
        # Re-seed per design so each design index is reproducible on its own.
        make_deterministic(i_des)

    start_time = time.time()
    out_prefix = f"{sampler.inf_conf.output_prefix}_{i_des}"

    x_init, seq_init = sampler.sample_init()
    denoised_xyz_stack = []
    px0_xyz_stack = []
    seq_stack = []  # per-step sequences, kept for inspection; not written out below
    plddt_stack = []

    x_t = torch.clone(x_init)
    seq_t = torch.clone(seq_init)
    # Loop over reverse diffusion time steps, from t_step_input down to final_step inclusive.
    for t in range(int(sampler.t_step_input), sampler.inf_conf.final_step - 1, -1):
        px0, x_t, seq_t, plddt = sampler.sample_step(
            t=t, x_t=x_t, seq_init=seq_t, final_step=sampler.inf_conf.final_step
        )
        px0_xyz_stack.append(px0)
        denoised_xyz_stack.append(x_t)
        seq_stack.append(seq_t)
        plddt_stack.append(plddt[0])  # remove singleton leading dimension

    # Flip the step axis for better visualization in pymol; afterwards index 0 holds
    # the last (final_step) denoising step.
    denoised_xyz_stack = torch.flip(torch.stack(denoised_xyz_stack), dims=[0])
    px0_xyz_stack = torch.flip(torch.stack(px0_xyz_stack), dims=[0])

    # For logging -- don't flip
    plddt_stack = torch.stack(plddt_stack)

    # Save outputs
    _makedirs_for_prefix(out_prefix)

    # Per-residue token of the initial sequence. Positions equal to 21 are the diffused
    # ones; all others are motif, per the glycine / b-factor rules below.
    seq_init_idx = torch.argmax(seq_init, dim=-1)
    is_diffused = seq_init_idx == 21

    # Output glycines, except for motif region
    final_seq = torch.where(is_diffused, 7, seq_init_idx)  # 7 is glycine

    bfacts = torch.ones_like(final_seq.squeeze())
    # make bfact=0 for diffused coordinates
    bfacts[is_diffused] = 0

    # Final denoised coordinates (x_t of the last step, index 0 after the flip),
    # keeping only the first 4 atom slots per residue, i.e. no sidechains.
    out = f"{out_prefix}.pdb"
    writepdb(
        out,
        denoised_xyz_stack[0, :, :4],
        final_seq,
        sampler.binderlen,
        chain_idx=sampler.chain_idx,
        bfacts=bfacts,
    )

    # run metadata
    trb = dict(
        config=OmegaConf.to_container(sampler._conf, resolve=True),
        plddt=plddt_stack.cpu().numpy(),
        device=torch.cuda.get_device_name(torch.cuda.current_device())
        if torch.cuda.is_available()
        else "CPU",
        time=time.time() - start_time,
    )
    if hasattr(sampler, "contig_map"):
        trb.update(sampler.contig_map.get_mappings())
    with open(f"{out_prefix}.trb", "wb") as f_out:
        pickle.dump(trb, f_out)

    if sampler.inf_conf.write_trajectory:
        # Trajectory pdbs go in a "traj" subdirectory next to the design outputs.
        # os.path.join keeps this relative ("traj/...") when out_prefix has no directory,
        # instead of producing an absolute "/traj/..." path.
        traj_prefix = os.path.join(
            os.path.dirname(out_prefix), "traj", os.path.basename(out_prefix)
        )
        _makedirs_for_prefix(traj_prefix)

        # One multi-model pdb per trajectory: x_t per step ("Xt-1") and predicted x0 ("pX0").
        for traj_name, xyz_stack in (("Xt-1", denoised_xyz_stack), ("pX0", px0_xyz_stack)):
            out = f"{traj_prefix}_{traj_name}_traj.pdb"
            writepdb_multi(
                out,
                xyz_stack,
                bfacts,
                final_seq.squeeze(),
                use_hydrogens=False,
                backbone_only=False,
                chain_ids=sampler.chain_idx,
            )

With this beta schedule (linear schedule, beta_0 = 0.04, beta_T = 0.28), alpha_bar_T = 0.00013696050154976547
With this beta schedule (linear schedule, beta_0 = 0.04, beta_T = 0.28), alpha_bar_T = 0.00013696050154976547
With this beta schedule (linear schedule, beta_0 = 0.04, beta_T = 0.28), alpha_bar_T = 0.00013696050154976547
With this beta schedule (linear schedule, beta_0 = 0.04, beta_T = 0.28), alpha_bar_T = 0.00013696050154976547
With this beta schedule (linear schedule, beta_0 = 0.04, beta_T = 0.28), alpha_bar_T = 0.00013696050154976547
With this beta schedule (linear schedule, beta_0 = 0.04, beta_T = 0.28), alpha_bar_T = 0.00013696050154976547
With this beta schedule (linear schedule, beta_0 = 0.04, beta_T = 0.28), alpha_bar_T = 0.00013696050154976547
With this beta schedule (linear schedule, beta_0 = 0.04, beta_T = 0.28), alpha_bar_T = 0.00013696050154976547
With this beta schedule (linear schedule, beta_0 = 0.04, beta_T = 0.28), alpha_bar_T = 0.00013696050154976547
With this 

In [51]:
import py3Dmol

In [52]:
# Raw PDB text of the input structure, for the first 3D view below.
with open("1qys.pdb") as pdb_file:
    init_pdb = pdb_file.read()

In [53]:
# pLDDT colour bands as (low, high, hex colour). Only the colours are used below: the
# 'map' colorscheme is keyed by band index, so b-factor 0 -> first colour, 1 -> second,
# and so on; the (low, high) ranges are not applied.
# NOTE(review): assumes py3Dmol's 'map' looks up exact b-factor values — confirm. The
# sampling loop writes b-factor 0 (diffused) / 1 (motif), not pLDDT, so check that this
# palette is what's intended.
PLDDT_BANDS = [
    (0, 50, '#FF7D45'),
    (50, 70, '#FFDB13'),
    (70, 90, '#65CBF3'),
    (90, 100, '#0053D6'),
]
color_map = dict(enumerate(hex_color for _low, _high, hex_color in PLDDT_BANDS))
style = {
    'cartoon': {'colorscheme': {'prop': 'b', 'map': color_map}},
    'stick': {},
}

view = py3Dmol.view(width=800, height=600)
view.addModelsAsFrames(init_pdb)
view.setStyle({'model': -1}, style)
view.zoomTo()

<py3Dmol.view at 0x7f5c5d3dc7c0>

In [56]:
# Load the last design written by the sampling loop. The name follows the loop's
# f"{output_prefix}_{i_des}.pdb" convention; a hard-coded file name would go stale once
# design_startnum auto-increments past it.
last_design_index = design_startnum + sampler.inf_conf.num_designs - 1
last_design_path = f"{sampler.inf_conf.output_prefix}_{last_design_index}.pdb"
with open(last_design_path, mode="r") as f:
    designed_pdb = f.read()

In [57]:
# Same cartoon + stick styling as the input-structure view above: reuse its `style`
# instead of redefining the identical colour bands and style dict.
view = py3Dmol.view(width=800, height=600)
view.addModelsAsFrames(designed_pdb)
view.setStyle({'model': -1}, style)
view.zoomTo()

<py3Dmol.view at 0x7f5c5d3de710>