In [None]:
#@title Install SALAD and utilities
!pip install -q git+https://github.com/mjendrusch/flexcraft.git

In [None]:
#@title Download parameters
!wget https://zenodo.org/records/14711580/files/salad_params.tar.gz
!tar -xzf salad_params.tar.gz
!mkdir pmpnn_params
!cd pmpnn_params
!wget https://github.com/sokrypton/ColabDesign/raw/refs/heads/main/colabdesign/mpnn/weights/v_48_030.pkl
!curl -fsSL https://storage.googleapis.com/alphafold/alphafold_params_2022-03-02.tar | tar x -C params

In [None]:
#@title Run salad -> ProteinMPNN -> AF2
#@markdown ### salad settings:
#@markdown num_aa is a ":"-separated string of protein chain lengths.
#@markdown E.g. "100:50" will produce a heterodimer with one chain of length 100 and one chain of length 50.
num_aa = '100' #@param {type:"string"}
num_designs = 10 #@param {type:"raw"}
config = 'default_ve_scaled' #@param ["default_vp", "default_ve_scaled"] {type:"string"}

# Timescale expression passed to the sampler, keyed by the selectable
# config name from the form field above.
config_timescale = {
    "default_vp": "cosine(t)",
    "default_ve_scaled": "ve(t)",
}

import os
import time
from copy import deepcopy

import jax
import jax.numpy as jnp

import haiku as hk

import salad.inference as si

from flexcraft.sequence.mpnn import make_pmpnn
from flexcraft.sequence.sample import *
from flexcraft.sequence.aa_codes import PMPNN_AA_CODE, AF2_AA_CODE, decode_sequence
from flexcraft.structure.af import *
from flexcraft.utils import Keygen, parse_options, strip_aa, data_from_protein

# A single denoising step, packaged for use by the sampler.
def model_step(config):
    """Build the pure haiku apply function for one salad denoising step.

    Args:
        config: salad model configuration. It is deep-copied, so the
            caller's object is left untouched.

    Returns:
        The ``apply`` function of the haiku-transformed step. The wrapped
        step takes ``(data, prev)`` and returns ``(out, prev)``.
    """
    # Work on a private copy in eval mode; this switches off model
    # components that are only needed during training.
    eval_config = deepcopy(config)
    eval_config.eval = True

    def denoise(data, prev):
        # salad modules are haiku modules, so they have to be created
        # inside the function that gets hk.transform-ed.
        add_noise = si.StructureDiffusionNoise(eval_config)
        denoiser = si.StructureDiffusionPredict(eval_config)
        # hook: edit the structure before noise is applied
        ...
        # noise the input; the result is merged into `data` in place
        data.update(add_noise(data))
        # hook: edit the noised structure
        ...
        # run the denoising model
        out, prev = denoiser(data, prev)
        # hook: edit the model output
        ...
        return out, prev

    # hk.transform turns the step into a pair of pure functions;
    # only apply is needed for sampling.
    return hk.transform(denoise).apply


# All run options in one place. The first three entries come from the form
# fields at the top of the cell; the rest are fixed for this notebook.
opt = {
    "num_aa": num_aa,
    "num_designs": num_designs,
    # ProteinMPNN sequences tried per generated backbone
    "num_samples": 10,
    "config": config,
    "out_path": "outputs/",
    "salad_params": f"params/{config}-200k.jax",
    "pmpnn_params": "v_48_030.pkl",
    "af2_params": "./",
    "af2_model": "model_1_ptm",
    "temperature": 0.1,
    # NOTE(review): stored as the string "True"; confirm how it is consumed.
    "center": "True",
    "seed": 42,
}
# Create the results folder up front; a no-op if it already exists.
os.makedirs(opt["out_path"], exist_ok=True)

# Key generator seeded from the options; every later `key()` call draws a
# fresh key from it, so the order of the calls below determines the results.
key = Keygen(opt["seed"])
# jit-compiled ProteinMPNN built from the downloaded weights file
pmpnn = jax.jit(make_pmpnn(opt["pmpnn_params"], eps=0.05))

# make salad model
salad_config, salad_params = si.make_salad_model(opt["config"], opt["salad_params"])

# initialize salad data and prev from the num_aa specification
# NOTE: this rebinds `num_aa` from the form string to the parsed value;
# the original string stays available as opt["num_aa"].
num_aa, resi, chain, is_cyclic, cyclic_mask = si.parse_num_aa(opt["num_aa"])
salad_data, init_prev = si.data.from_config(
    salad_config,
    num_aa=num_aa,
    residue_index=resi,
    chain_index=chain,
    cyclic_mask=cyclic_mask)

# pure apply function for one denoising step (see model_step)
salad_step = model_step(salad_config)

# make AF2 model
# data_dir is the folder that contains the AF2 parameter files
# (presumably in a params/ subfolder, as unpacked by the download cell).
af2_params = get_model_haiku_params(
    model_name=opt["af2_model"],
    data_dir=opt["af2_params"], fuse=True)
af2_config = model_config(opt["af2_model"])
af2_config.model.global_config.use_dgram = False
# jit-compiled predictor configured with num_recycle=2
af2 = jax.jit(make_predict(make_af2(af2_config), num_recycle=2))

# build a sampler object for sampling
# the timescale expression is looked up from the selected config name
sampler = si.Sampler(
    salad_step,
    out_steps=400,
    timescale=config_timescale[config],
)
# Design loop: for each of the requested designs, generate one backbone with
# salad, then try up to opt["num_samples"] ProteinMPNN sequences and keep the
# first one whose AF2 prediction passes the pLDDT filter.
print("Starting design...")
# mean pLDDT (as returned by AFResult.plddt) a prediction must exceed to be saved
plddt_cutoff = 0.8
for idx in range(opt["num_designs"]):
    # generate a structure in each step
    start = time.time()
    design = sampler(salad_params, key(), salad_data, init_prev)
    print(f"Design {idx} in {time.time() - start:.2f} s")
    # convert it to ProteinMPNN input
    pmpnn_data = data_from_protein(si.data.to_protein(design))
    pmpnn_data = strip_aa(pmpnn_data)
    # logits for the sequence-free backbone; their mean over axis 0 is used
    # to center the logits before sampling
    step_0_logits = pmpnn(key(), pmpnn_data)["logits"]
    center = step_0_logits.mean(axis=0)
    # ensure one TRP/W residue: pick the position with the highest
    # centered probability for W
    p_W = jax.nn.softmax(step_0_logits - center)[:, PMPNN_AA_CODE.index("W")]
    best_W = jnp.argmax(p_W, axis=0)
    # set up logit transform
    transform = transform_logits([
        toggle_transform(
            center_logits(center=center), use=True),
        scale_by_temperature(temperature=opt["temperature"]),
        forbid("C", PMPNN_AA_CODE),
        norm_logits
    ])
    pmpnn_sampler = sample(pmpnn, logit_transform=transform)
    for ids in range(opt["num_samples"]):
        # start every attempt from a sequence-free input
        pmpnn_data = strip_aa(pmpnn_data)
        # ensure one TRP/W residue (at the most probable position)
        pmpnn_data["aa"][best_W] = PMPNN_AA_CODE.index("W")
        # sample all other residues (the returned log-probability is not used)
        sampled, _log_p = pmpnn_sampler(key(), pmpnn_data)
        # AF2 uses a different amino acid ordering than ProteinMPNN
        pmpnn_data["aa"] = reindex_aatype(sampled["aa"], PMPNN_AA_CODE, AF2_AA_CODE)
        sequence = decode_sequence(pmpnn_data["aa"], AF2_AA_CODE)
        af_data = make_af_data(pmpnn_data)
        prediction = AFResult(inputs=af_data,
                              result=af2(af2_params, key(), af_data))
        plddt = prediction.plddt.mean()
        print(idx, ids, sequence, plddt)
        if plddt > plddt_cutoff:
            # keep the first passing sequence and move on to the next design
            prediction.save_pdb(os.path.join(opt["out_path"], f"result_{idx}.pdb"))
            break
    else:
        # every attempt stayed at or below the cutoff; make that visible
        print(f"Design {idx}: no sequence exceeded pLDDT {plddt_cutoff}, nothing saved")


In [None]:
#@title Visualize designs
import py3Dmol
import glob
from IPython.display import HTML

# Only designs that passed the pLDDT filter were written to disk, so the
# grid is sized from the PDB files that actually exist rather than from
# num_designs. Sorting gives a stable (lexicographic) order.
pdb_paths = sorted(glob.glob(os.path.join(opt["out_path"], "*.pdb")))
if not pdb_paths:
  print(f"No PDB files found in {opt['out_path']}")

num_cols = 5
# round up so a partially filled last row still gets its viewers,
# and always keep at least one row so the view can be constructed
num_rows = max(1, (len(pdb_paths) + num_cols - 1) // num_cols)

view = py3Dmol.view(width=800, viewergrid=(num_rows, num_cols), linked=True)
for idx, path in enumerate(pdb_paths):
  # fill the grid row by row
  viewer_id = (idx // num_cols, idx % num_cols)
  with open(path, 'r') as pdb_file:
    view.addModel(pdb_file.read(), 'pdb', viewer=viewer_id)
view.setBackgroundColor('#444444')
# `chain` is the chain index array parsed from num_aa in the design cell;
# a maximum above 0 means there is more than one chain
if chain.max() > 0:
  view.setStyle({'chain': 'A'}, {'cartoon': {'color':'#882288'}})
  view.setStyle({'chain': 'B'}, {'cartoon': {'color':'#22EEEE'}})
  view.setStyle({'chain': 'C'}, {'cartoon': {'color':'#EEEE22'}})
  view.setStyle({'chain': 'D'}, {'cartoon': {'color':'#222222'}})
else:
  view.setStyle({'cartoon': {'color':'spectrum'}})

#view.setStyle({'resn': 'TRP'}, {'sphere': {'scale': 1.0}})
view.zoom(0.15)
view.show()

In [None]:
#@title Download structures
!zip -r results.zip outputs
from google.colab import files
files.download("/content/results.zip")