<a href="https://colab.research.google.com/github/sokrypton/ColabDesign/blob/v1.1.1/af/examples/peptide_binder_design.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# AfDesign - peptide binder design
For a given protein target and protein binder length, generate (hallucinate) a protein binder sequence that AlphaFold predicts will bind to the target structure. To do this, we maximize the number of contacts at the interface and the pLDDT of the binder.

**WARNING**
1.   This notebook is in active development and was designed for demonstration purposes only.
2.   Using AfDesign as the only "loss" function for design might be a bad idea: you may find adversarial sequences (i.e., sequences that trick AlphaFold).

In [1]:
import os

# The colabdesign package is only found when the working directory is the
# ColabDesign main directory, so switch there before importing it.
# NOTE(review): this is an absolute, machine-specific path — adjust it when
# running the notebook elsewhere.
repo_root = '/usr/users/fatma.chafra01/ColabDesign'

print(os.getcwd())  # working directory before the switch
os.chdir(repo_root)
print(os.getcwd())  # working directory after the switch

/home/mpg01/MBPC/fatma.chafra01/ColabDesign/af/examples
/home/mpg01/MBPC/fatma.chafra01/ColabDesign


In [2]:
#@title **setup**

# The Colab bootstrap below (package install, debugging symlink, AlphaFold
# parameter download into ./params) is commented out; the imports further down
# therefore expect `colabdesign` to be importable already.
# NOTE(review): presumably the AlphaFold parameters were unpacked by hand on
# this machine — confirm before running in a fresh environment.
#if not os.path.isdir("params"):
  # get code
  # os.system("pip -q install git+https://github.com/sokrypton/ColabDesign.git@v1.1.1")
  # for debugging
  #os.system("ln -s /usr/local/lib/python3.*/dist-packages/colabdesign colabdesign")
  # download params
  #os.system("mkdir params")
  #os.system("apt-get install aria2 -qq")
  #os.system("aria2c -q -x 16 https://storage.googleapis.com/alphafold/alphafold_params_2022-12-06.tar")
  #os.system("tar -xf alphafold_params_2022-12-06.tar -C params")

import warnings
# Hide FutureWarning messages from here on.
warnings.simplefilter(action='ignore', category=FutureWarning)

import os
from colabdesign import mk_afdesign_model, clear_mem
from colabdesign.shared.utils import copy_dict
from colabdesign.af.alphafold.common import residue_constants

from IPython.display import HTML
# Colab-only file helper; disabled in this copy of the notebook.
#from google.colab import files
import numpy as np


In [3]:
#@title **prep inputs**
import re


def sanitize_positions(text):
  """Normalize a user-typed position string such as "1-10, 12,15".

  Digits, commas and the hyphen that marks a range are kept; every other
  character (including whitespace) is dropped.

  Returns the cleaned string, or None when nothing is left.
  """
  cleaned = re.sub("[^0-9,-]", "", text)
  return cleaned if cleaned else None

#@markdown ---
#@markdown **target info**
pdb = "8EE2" #@param {type:"string"}
#@markdown - enter PDB code or UniProt code (to fetch AlphaFoldDB model) or leave blank to upload your own
target_chain = "A" #@param {type:"string"}
target_hotspot = "" #@param {type:"string"}
if target_hotspot == "": target_hotspot = None
#@markdown - restrict loss to predefined positions on target (eg. "1-10,12,15")

pos = "" #@param {type:"string"}
# The hyphen must survive the clean-up: the earlier pattern "[^0-9,]" removed
# it, which turned a range such as "1-10" into the single number "110".
# NOTE(review): `pos` is not read by any later cell visible in this notebook.
pos = sanitize_positions(pos)
#@markdown - positions in the binder that should remain fixed during redesign (eg. "1-10,12,15")
target_flexible = False #@param {type:"boolean"}
#@markdown - allow backbone of target structure to be flexible

#@markdown ---
#@markdown **binder info**
binder_len = 102 #@param {type:"integer"}
#@markdown - length of binder to hallucinate
binder_seq = "" #@param {type:"string"}
# Keep only upper-case letters; an empty result means "no starting sequence".
binder_seq = re.sub("[^A-Z]", "", binder_seq.upper())
if len(binder_seq) > 0:
  # A given starting sequence dictates the binder length.
  binder_len = len(binder_seq)
else:
  binder_seq = None
#@markdown - if defined, will initialize design with this sequence

binder_chain = "C" #@param {type:"string"}
if binder_chain == "": binder_chain = None
#@markdown - if defined, supervised loss is used (binder_len is ignored)
# tried with fix_pos but apparently only available for fixbb and partial hallucination ones
# fix_pos = "1-27, 35-53, 59-99, 112-121" #@param {type:"string"}
# fix_pos = sanitize_positions(fix_pos)
fix_seq = True #@param {type:"boolean"}
#@markdown - When set to True, it maintains the original sequence at the fixed positions specified by fix_pos.
#@markdown ---
#@markdown **model config**
use_multimer = True #@param {type:"boolean"}
#@markdown - use alphafold-multimer for design
num_recycles = 3 #@param ["0", "1", "3", "6"] {type:"raw"}
num_models = "5" #@param ["1", "2", "3", "4", "5", "all"]
# The form delivers a string; "all" means all five models.
num_models = 5 if num_models == "all" else int(num_models)
#@markdown - number of trained models to use during optimization


# Settings for this design run, collected in one dict.
# The PDB entry points at the locally modified structure file; the get_pdb
# download helper is not used.
x = dict(
    pdb_filename='/usr/users/fatma.chafra01/ColabDesign/8ee2.pdb',
    chain=target_chain,
    binder_len=binder_len,
    binder_chain=binder_chain,
    hotspot=target_hotspot,
    use_multimer=use_multimer,
    rm_target_seq=target_flexible,
    fix_seq=fix_seq,
)

# Memory is cleared unconditionally; the former "only when the settings
# changed" guard (x != x_prev) is disabled, and the model itself is built in
# the following cell.
clear_mem()

# Snapshot of the settings used for this run.
x_prev = copy_dict(x)


In [4]:
# The params directory is only found from the af/examples directory; elsewhere
# the model warns that it cannot find model_1_v3 etc., so switch there first.
# NOTE(review): absolute, machine-specific path — adjust when running elsewhere.
print(os.getcwd())
os.chdir('/home/mpg01/MBPC/fatma.chafra01/ColabDesign/af/examples')
print(os.getcwd())

# Simplified model initialization (after Chris's suggestions).
# TODO: warnings appeared for the first time with this form — find out why.
model = mk_afdesign_model(protocol="binder",
                          use_multimer=x["use_multimer"],
                          num_recycles=num_recycles,
                          recycle_mode="sample",
                          use_templates=True)

# Structure file and chain IDs are read from `x` (filled in by the prep-inputs
# cell) instead of being typed in a second time, so the form values and the
# model cannot drift apart.
model.prep_inputs(pdb_filename=x["pdb_filename"],
                  target_chain=x["chain"],
                  binder_chain=x["binder_chain"],
                  rm_template_ic=True,
                  ignore_missing=False,
                  num_models=num_models)

print("target length:", model._target_len)
print("binder length:", model._binder_len)
# From here on the binder length is the one taken from the prepared model.
binder_len = model._binder_len

/home/mpg01/MBPC/fatma.chafra01/ColabDesign
/home/mpg01/MBPC/fatma.chafra01/ColabDesign/af/examples
target length: 143
binder length: 117


In [7]:
# Print a key from the model before and after seeding, to see the effect of set_seed.
print(model.key())
# NOTE(review): a later cell calls model.restart(..., seed=12332) with a different
# seed — confirm which seed is meant to govern the design run.
model.set_seed(123324)
print(model.key())

[4033073092 1506481687]
[3036458098 2370171162]


In [8]:
# model weights
print('weights', model.opt["weights"].keys())
for keys in model.opt["weights"].keys():
    print(f'{keys}', model.opt["weights"][keys])
print(dir(model))
print(model._inputs)

weights dict_keys(['con', 'dgram_cce', 'exp_res', 'fape', 'helix', 'i_con', 'i_pae', 'pae', 'plddt', 'rmsd', 'seq_ent'])
con 0.0
dgram_cce 1.0
exp_res 0.0
fape 0.0
helix 0.0
i_con 0.0
i_pae 0.0
pae 0.0
plddt 0.0
rmsd 0.0
seq_ent 0.0
['__class__', '__delattr__', '__dict__', '__dir__', '__doc__', '__eq__', '__format__', '__ge__', '__getattribute__', '__gt__', '__hash__', '__init__', '__init_subclass__', '__le__', '__lt__', '__module__', '__ne__', '__new__', '__reduce__', '__reduce_ex__', '__repr__', '__setattr__', '__sizeof__', '__str__', '__subclasshook__', '__weakref__', '_args', '_binder_len', '_callbacks', '_cfg', '_design_mcmc', '_fix_pos', '_get_loss', '_get_model', '_get_model_nums', '_get_seq', '_inputs', '_k', '_len', '_lengths', '_loss_binder', '_loss_fixbb', '_loss_hallucination', '_loss_partial', '_loss_unsupervised', '_model', '_model_names', '_model_params', '_mutate', '_norm_seq_grad', '_num', '_opt', '_optimizer', '_params', '_pdb', '_prep_binder', '_prep_features', '_pre

In [9]:
# Residue intervals (1-based, inclusive, numbered along `seq`) that must keep
# their amino acid. Before the PDB file was modified the intervals were
# (1, 27), (35,53), (59,99), (112,121); the truncation made these new bounds
# necessary.
fix_pos = [(19, 27), (35,53), (59,99), (112,127)]
# Earlier, shorter version of the sequence:
#seq = 'MAEVQLVESGGGLVQPGGSLRLSCTTSTSLFSITTMGWYRQAPGKQRELVASIKRGGGTNYADSMKGRFTISRDNARNTVFLEMNNLTTEDTAVYYCNAAILAYTGEVTNYWGQGTQVTV'
seq = "MAEVQLVESGGGLVQPGGSLRLSCTTSTSLFSITTMGWYRQAPGKQRELVASIKRGGGTNYADSMKGRFTISRDNARNTVFLEMNNLTTEDTAVYYCNAAILAYTGEVTNYWGQGTQVTVSSGQAGQ"
# Reminder: parts of the original sequence that are missing from the structure
# will not appear in the generated sequence either.
print('seq length', len(seq))
print(model._binder_len)

# Fixed positions are enforced through a bias matrix, as suggested in
# https://github.com/sokrypton/ColabDesign/issues/107
bias = np.zeros((model._binder_len,20))

# Bias row = 0-based index into `seq` minus 18, i.e. residue 19 lands in row 0
# (the index shifted once the PDB file was truncated).
# NOTE(review): an interval starting before residue 19 would give a negative
# row and wrap around to the end of the matrix.
row_shift = 18
for first, last in fix_pos:
  print(first - 1, last - 1)
  for seq_idx in range(first - 1, last):
    aa = seq[seq_idx]
    print(seq_idx, aa)
    row = seq_idx - row_shift
    bias[row, residue_constants.restype_order[aa]] = 1e8
    print(f'bias added to:{row} as {aa}')
print(bias)

seq length 127
117
18 26
18 S
bias added to:0 as S
19 L
bias added to:1 as L
20 R
bias added to:2 as R
21 L
bias added to:3 as L
22 S
bias added to:4 as S
23 C
bias added to:5 as C
24 T
bias added to:6 as T
25 T
bias added to:7 as T
26 S
bias added to:8 as S
34 52
34 T
bias added to:16 as T
35 M
bias added to:17 as M
36 G
bias added to:18 as G
37 W
bias added to:19 as W
38 Y
bias added to:20 as Y
39 R
bias added to:21 as R
40 Q
bias added to:22 as Q
41 A
bias added to:23 as A
42 P
bias added to:24 as P
43 G
bias added to:25 as G
44 K
bias added to:26 as K
45 Q
bias added to:27 as Q
46 R
bias added to:28 as R
47 E
bias added to:29 as E
48 L
bias added to:30 as L
49 V
bias added to:31 as V
50 A
bias added to:32 as A
51 S
bias added to:33 as S
52 I
bias added to:34 as I
58 98
58 T
bias added to:40 as T
59 N
bias added to:41 as N
60 Y
bias added to:42 as Y
61 A
bias added to:43 as A
62 D
bias added to:44 as D
63 S
bias added to:45 as S
64 M
bias added to:46 as M
65 K
bias added to:47 as K


In [10]:
# Spot-check the bias matrix. Row 0 holds residue 19 of `seq` (a serine: row = residue - 19),
# so column 15 of that row is expected to carry the bias value.
# Columns follow residue_constants.restype_order (the mapping used when the matrix was filled),
# which is not the alphabetical one-letter order — for example valine is the last column,
# not tyrosine.
print(bias[0,15])
# NOTE(review): the earlier note here read "120 is V so at pos 101", but the row printed is 75,
# which by row = residue - 19 is residue 94 — confirm which row was meant.
print(bias[75,:])
# Residue 119 is T, which maps to row 100.
print(bias[100,:])
def find_rows_without_value(arr, value):
    """Return the indices of the rows of `arr` that contain no entry close to `value`.

    Closeness is decided by np.isclose with its default tolerances.
    """
    # "No entry matches" is the same as "every entry fails to match".
    no_match_in_row = np.all(~np.isclose(arr, value), axis=1)
    return np.nonzero(no_match_in_row)[0]

# List the bias rows that hold no 1e8 entry, i.e. the positions that received no bias.
print("Rows without 1.0e+08:", find_rows_without_value(bias, 1.0e+08))

100000000.0
[0.e+00 0.e+00 0.e+00 0.e+00 0.e+00 0.e+00 0.e+00 0.e+00 0.e+00 0.e+00
 0.e+00 0.e+00 0.e+00 0.e+00 0.e+00 0.e+00 0.e+00 0.e+00 0.e+00 1.e+08]
[0.e+00 0.e+00 0.e+00 0.e+00 0.e+00 0.e+00 0.e+00 0.e+00 0.e+00 0.e+00
 0.e+00 0.e+00 0.e+00 0.e+00 0.e+00 0.e+00 1.e+08 0.e+00 0.e+00 0.e+00]
Rows without 1.0e+08: [  9  10  11  12  13  14  15  35  36  37  38  39  81  82  83  84  85  86
  87  88  89  90  91  92 109 110 111 112 113 114 115 116]


In [7]:
# Restart the model (optional starting sequence, fixed seed), then attach the
# bias matrix built earlier.
model.restart(seq=binder_seq, seed=12332)
model.set_seq(bias=bias)

# Loss weights for this run.
loss_weights = {
    "i_con": 2.0,
    "i_pae": 2.0,
    "con": 0.5,
    "dgram_cce": 10.0,
    "exp_res": 1.0,
    "fape": 0.1,
    "pae": 10.0,
    "plddt": 10.0,
    "rmsd": 20.0,
    "seq_ent": 1.0,
}
model.opt["weights"].update(loss_weights)

print('----------')
for name, value in model.opt["weights"].items():
  print(f'{name}', value)
print('----------')
print(model.opt.keys())
# `bias` is not a key of model.opt; the matrix is found under model._inputs,
# which confirms it is associated with the model.
print(model._inputs["bias"])

----------
con 0.5
dgram_cce 10.0
exp_res 1.0
fape 0.1
helix 0.0
i_con 2.0
i_pae 2.0
pae 10.0
plddt 10.0
rmsd 20.0
seq_ent 1.0
----------
dict_keys(['alpha', 'con', 'dropout', 'fape_cutoff', 'hard', 'i_con', 'learning_rate', 'norm_seq_grad', 'num_models', 'num_recycles', 'pssm_hard', 'sample_models', 'soft', 'temp', 'template', 'weights'])
[[0.e+00 0.e+00 0.e+00 ... 0.e+00 0.e+00 0.e+00]
 [0.e+00 0.e+00 0.e+00 ... 0.e+00 0.e+00 0.e+00]
 [0.e+00 1.e+08 0.e+00 ... 0.e+00 0.e+00 0.e+00]
 ...
 [0.e+00 0.e+00 0.e+00 ... 0.e+00 0.e+00 0.e+00]
 [0.e+00 0.e+00 0.e+00 ... 0.e+00 0.e+00 0.e+00]
 [0.e+00 0.e+00 0.e+00 ... 0.e+00 0.e+00 0.e+00]]


In [None]:
#@title **run AfDesign**
from scipy.special import softmax

optimizer = "pssm_semigreedy" #@param ["pssm_semigreedy", "3stage", "semigreedy", "pssm", "logits", "soft", "hard"]
#@markdown - `pssm_semigreedy` - uses the designed PSSM to bias semigreedy opt. (Recommended)
#@markdown - `3stage` - gradient based optimization (GD) (logits → soft → hard)
#@markdown - `pssm` - GD optimize (logits → soft) to get a sequence profile (PSSM).
#@markdown - `semigreedy` - tries X random mutations, accepts those that decrease loss
#@markdown - `logits` - GD optimize logits inputs (continuous)
#@markdown - `soft` - GD optimize softmax(logits) inputs (probabilities)
#@markdown - `hard` - GD optimize one_hot(logits) inputs (discrete)

#@markdown WARNING: The output sequence from `pssm`,`logits`,`soft` is not one_hot. To get a valid sequence use the other optimizers, or redesign the output backbone with another protocol like ProteinMPNN.

#@markdown ----
#@markdown #### advanced GD settings
GD_method = "adam" #@param ["adabelief", "adafactor", "adagrad", "adam", "adamw", "fromage", "lamb", "lars", "noisy_sgd", "dpsgd", "radam", "rmsprop", "sgd", "sm3", "yogi"]
learning_rate = 0.01 #@param {type:"raw"}
norm_seq_grad = True #@param {type:"boolean"}
dropout = True #@param {type:"boolean"}

# No model.restart(...) here: the model was already restarted (and given its
# bias matrix and loss weights) in the previous cell.
model.set_optimizer(optimizer=GD_method,
                    learning_rate=learning_rate,
                    norm_seq_grad=norm_seq_grad)
models = model._model_names[:num_models]

# Options shared by every design_* call below.
flags = {"num_recycles":num_recycles,
         "models":models,
         "dropout":dropout}

# In every branch `pssm` is a softmax over the LAST axis of the logits, so that
# the values along the amino-acid axis sum to one. The plotting cell takes
# pssm.mean(0).T, i.e. it expects a 3-D array.
# NOTE(review): assumes the logits are laid out as (models, positions, amino
# acids) — confirm against the ColabDesign version in use.

if optimizer == "3stage":
  model.design_3stage(120, 60, 10, **flags)
  pssm = softmax(model._tmp["seq_logits"],-1)

if optimizer == "pssm_semigreedy":
  # Shortened schedule (1, 5); the regular setting would be (120, 32):
  #model.design_pssm_semigreedy(120, 32, **flags)
  model.design_pssm_semigreedy(1, 5, verbose=1, **flags)
  # Was softmax(..., 1): on a 3-D array that normalizes across positions, not
  # across amino acids, and disagreed with all other branches.
  pssm = softmax(model._tmp["seq_logits"],-1)

if optimizer == "semigreedy":
  model.design_pssm_semigreedy(0, 32, **flags)
  pssm = None

if optimizer == "pssm":
  model.design_logits(120, e_soft=1.0, num_models=1, ramp_recycles=True, **flags)
  model.design_soft(32, num_models=1, **flags)
  # Final steps: dropout off, keep the best result.
  flags.update({"dropout":False,"save_best":True})
  model.design_soft(10, num_models=num_models, **flags)
  pssm = softmax(model.aux["seq"]["logits"],-1)

# Plain gradient-descent modes share one two-phase schedule.
O = {"logits":model.design_logits,
     "soft":model.design_soft,
     "hard":model.design_hard}

if optimizer in O:
  O[optimizer](120, num_models=1, ramp_recycles=True, **flags)
  flags.update({"dropout":False,"save_best":True})
  O[optimizer](10, num_models=num_models, **flags)
  pssm = softmax(model.aux["seq"]["logits"],-1)

# The file name records the settings of this run.
model.save_pdb(f"8ee2_{model.protocol}_use_templates_true_rm_template_ic_true_w_bias_num_recyles_3_soft_1_hard_5_sol1.pdb")


Stage 1: running (logits → soft)
1 models [2] recycles 3 hard 0 soft 1 temp 1 seqid 0.01 loss 6321.34 seq_ent 0.82 pae 0.70 i_pae 0.75 exp_res 0.07 con 3.83 i_con 4.55 dgram_cce 526.66 fape 252.57 plddt 0.32 ptm 0.54 i_ptm 0.17 rmsd 50.12
Running semigreedy optimization...
2 models [1] recycles 3 hard 1 soft 0 temp 1 seqid 0.01 loss 6295.93 seq_ent 2.71 pae 0.72 i_pae 0.78 exp_res 0.04 con 4.10 i_con 4.65 dgram_cce 524.11 fape 253.79 plddt 0.29 ptm 0.54 i_ptm 0.15 rmsd 49.97
3 models [2] recycles 3 hard 1 soft 0 temp 1 seqid 0.01 loss 6474.33 seq_ent 2.71 pae 0.67 i_pae 0.77 exp_res 0.03 con 3.55 i_con 4.61 dgram_cce 539.79 fape 252.91 plddt 0.34 ptm 0.53 i_ptm 0.16 rmsd 51.13
4 models [4] recycles 3 hard 1 soft 0 temp 1 seqid 0.01 loss 6353.92 seq_ent 2.71 pae 0.67 i_pae 0.74 exp_res 0.01 con 3.94 i_con 4.49 dgram_cce 529.92 fape 252.41 plddt 0.30 ptm 0.55 i_ptm 0.19 rmsd 50.03
5 models [0] recycles 3 hard 1 soft 0 temp 1 seqid 0.01 loss 6334.59 seq_ent 2.71 pae 0.73 i_pae 0.83 exp_re

In [22]:
#@title display hallucinated protein {run: "auto"}
color = "pLDDT" #@param ["chain", "pLDDT", "rainbow"]
show_sidechains = False #@param {type:"boolean"}
show_mainchains = False #@param {type:"boolean"}
color_HP = False #@param {type:"boolean"}
animate = True #@param {type:"boolean"}

# Hand the form values to the viewer as keyword arguments.
plot_options = dict(color=color,
                    color_HP=color_HP,
                    show_sidechains=show_sidechains,
                    show_mainchains=show_mainchains,
                    animate=animate)
model.plot_pdb(**plot_options)

In [24]:
# Wrap the result of model.animate(dpi=100) in an IPython HTML object so it is shown as the cell output.
HTML(model.animate(dpi=100))

In [None]:
# Save the structure under a generic "<protocol>.pdb" name.
# NOTE(review): the run cell already saved a PDB under a descriptive file name;
# this writes an additional file.
model.save_pdb(f"{model.protocol}.pdb")
# Last expression of the cell: whatever get_seqs() returns is displayed as the output.
model.get_seqs()

In [None]:
#@markdown ### Amino acid probabilities
import plotly.express as px
# NOTE(review): `alphabet` is not used by the plot below; the row labels come
# from residue_constants.restypes.
alphabet = "ACDEFGHIKLMNPQRSTVWY"
# Plot only when the run cell defined a PSSM and it is not None.
if "pssm" in dir() and pssm is not None:
  # Average over the first axis and transpose, so amino acids run along y.
  mean_profile = pssm.mean(0).T
  axis_labels = dict(x="positions", y="amino acids", color="probability")
  fig = px.imshow(mean_profile,
                  labels=axis_labels,
                  y=residue_constants.restypes,
                  zmin=0,
                  zmax=1,
                  template="simple_white")
  fig.update_xaxes(side="top")
  fig.show()

In [None]:
# Display the log entry stored under model._tmp["best"]["aux"].
# NOTE(review): presumably the metrics of the best design kept during the run — confirm.
model._tmp["best"]["aux"]["log"]