<a href="https://colab.research.google.com/github/eelliiff/ColabDock/blob/main/ColabDock_ipynb_adl%C4%B1_not_defterinin_kopyas%C4%B1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# ColabDock
Inverting structure prediction model for protein-protein docking with experimental restraints



In [None]:
#@title Download AlphaFold2 and AF-Multimer params
%%time
import os
if not os.path.isdir("params"):
  # First run only: the presence of ./params is used as the "already set up" marker.
  # NOTE(review): ./params is created before the download/extract commands and the
  # os.system return codes are not checked, so a failed download still leaves
  # ./params behind and later runs skip this branch — confirm this is acceptable.
  # get code
  os.system("pip -q install git+https://github.com/eelliiff/ColabDock.git@main")
  # download params
  os.system("mkdir params")
  os.system("apt-get install aria2 -qq")
  os.system("aria2c -q -x 16 https://storage.googleapis.com/alphafold/alphafold_params_2022-12-06.tar")
  os.system("tar -xf alphafold_params_2022-12-06.tar -C params")

# Clone the repository only to copy the bundled 4HFF example into the working
# directory, then delete the clone again.
os.system("git clone -b main https://github.com/eelliiff/ColabDock.git")
os.system("cp -r ./ColabDock/protein/4HFF ./")
os.system("rm -r ./ColabDock")

import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)

import numpy as np
import re
import ml_collections
from ipywidgets import widgets, HBox
from IPython.display import display

from colabdock.utils import prep_path
from colabdock.model import ColabDock

CPU times: user 4.91 ms, sys: 1.02 ms, total: 5.93 ms
Wall time: 851 ms


In [None]:
config = {}

#@title Input structures
#@markdown - docking template
#@markdown -- upload your docking template and set the path
template = './4HFF/PDB/MTCH.pdb' #@param {type:"string"}

#@markdown - native structures, used to calculate RMSD. If not provided, set it to None.
native = 'None' #@param ["None"] {type:"string",allow-input:true}

#@markdown - docking chains
chains = "A,B" #@param {type:"string"}

#@markdown - fixed relative positions between chains. If not provided, set it to None,
#@markdown which means the relative positions between chains in the provided template are ignored.
#@markdown -- example:<br />
#@markdown ['A,B', 'C,D']<br />
#@markdown The relative position of chain A and B is fixed, also that of chain C and D.
fixed_chains = None #@param ["None"] {type:"raw",allow-input:true}

# Validate the form values and store them in the shared config dict.
config['chains'] = chains

# The docking template is mandatory and has to exist on disk.
if os.path.exists(template):
  config['template'] = template
else:
  raise Exception('Please upload the docking template and set the correct path!')

# The native structure is optional: the literal string 'None' switches it off,
# any other value must be an existing path.
if native != 'None' and not os.path.exists(native):
  raise Exception('Please upload the native structure and set the correct path!')
config['native'] = None if native == 'None' else native

# fixed_chains is either None or a list of chain groups.
if fixed_chains is not None and type(fixed_chains) is not list:
  raise Exception('Please set fixed_chains according to the descriptive information!')
config['fixed_chains'] = fixed_chains

## Restraints settings
You can provide 1v1, 1vN, MvN, and/or repulsive restraints. If no restraints are provided, ColabDock will only minimize the distogram, pLDDT, and ipAE losses.



In [None]:
#@title Threshold
#@markdown - Threshold of the restraints, between 2Å and 22Å.
res_thres = 8.0 #@param [8.0] {type:"raw",allow-input:true}

#@markdown - Threshold of the repulsive restraints, between 2Å and 22Å.<br />
#@markdown Repulsive restraints means the distance of two residues is above the given threshold.
rep_thres = 12.0 #@param [12.0] {type:"raw",allow-input:true}

# check the inputs
# A whole number typed into the raw form field (e.g. 8) arrives as an int, so
# both int and float are accepted and the value is stored as a float.
# bool is rejected explicitly because it is a subclass of int.
for thres_name, thres_value in (('res_thres', res_thres), ('rep_thres', rep_thres)):
  if isinstance(thres_value, bool) or not isinstance(thres_value, (int, float)):
    raise Exception(f'Please set {thres_name} according to the descriptive information!')
  config[thres_name] = float(thres_value)

In [None]:
#@title 1v1 restraints
#@markdown - 1v1 restraints<br />
#@markdown -- description:<br />
#@markdown The distance between two residues is below a given threshold.<br />
#@markdown If there is no such restraints, set it to None.<br />
#@markdown If you have multiple 1v1 restraints, list them in a [].<br />
#@markdown The order number in a 1v1 restraint refers to the residue in the complex sequence.
#@markdown The complex sequence is concatenated by the chain sequences and the order is determined by the "docking chains" provided above.
#@markdown This is the same for the remaining types of restraints. The order number starts from 1.<br />
#@markdown -- example:<br />
#@markdown [[78,198],[20,50]]<br />
#@markdown The distance between 78th and 198th residue is below a given threshold,
#@markdown as well as the distance between 20th and 50th residue.
rest_1v1 = "None" #@param ["None"] {type:"string",allow-input:true}

# check the inputs
if rest_1v1 == "None":
  config['rest_1v1'] = None
else:
  try:
    # The brackets only group the pairs visually: strip them and read the
    # flat, comma-separated list of residue numbers.
    flat_1v1 = [int(re.sub(r'\[|\]', '', tok)) for tok in rest_1v1.split(',')]
  except ValueError as err:
    raise Exception('Please set rest_1v1 according to the descriptive information!') from err
  # An odd count means one residue has no partner; reject it instead of
  # silently dropping the trailing value.
  if len(flat_1v1) % 2 != 0:
    raise Exception('Please set rest_1v1 according to the descriptive information!')
  # Consecutive values (0,1), (2,3), ... form the residue pairs.
  config['rest_1v1'] = [[flat_1v1[i], flat_1v1[i + 1]] for i in range(0, len(flat_1v1), 2)]



In [None]:
#@title 1vN restraints
#@markdown - 1vN restraints<br />
#@markdown -- description:<br />
#@markdown The distance between one residue and a residue set is below a given threshold.<br />
#@markdown If there is no such restraints, set to None.<br />
#@markdown If you have multiple 1vN restraints, list them in [].<br />
#@markdown The order number starts from 1.<br />
#@markdown -- example:<br />
#@markdown [36,(160-170,178,190)]<br />
#@markdown the distance between the 36th residue and at least a residue from 160th to 170th, 178th, and 190th is below a given threshold.<br />
rest_1vN = "None" #@param ["None"] {type:"string",allow-input:true}

# check the inputs
if rest_1vN == "None":
  config['rest_1vN'] = None
else:
  parsed_1vN = []
  try:
    # Every match has the shape "[<residue>,(<set>)]".
    for item_txt in re.findall(r'\[[0-9]*,\([0-9\-,]*\)\]', rest_1vN):
      # Drop the outer brackets and split at the first comma: the single
      # residue on the left, the parenthesised residue set on the right.
      head_txt, set_txt = item_txt[1:-1].split(',', 1)
      ind_1 = int(head_txt)
      ind_N = []
      # set_txt[1:-1] removes the parentheses the pattern guarantees.
      for tok in set_txt[1:-1].split(','):
        if '-' in tok:
          # "a-b" is an inclusive residue range.
          bounds = tok.split('-')
          ind_N.extend(range(int(bounds[0]), int(bounds[1]) + 1))
        else:
          ind_N.append(int(tok))
      parsed_1vN.append([ind_1, ind_N])
  except ValueError as err:
    raise Exception('Please set rest_1vN according to the descriptive information!') from err
  # Text that matches nothing would otherwise be stored as an empty list and
  # only fail later, when the docking cell indexes its first element.
  if not parsed_1vN:
    raise Exception('Please set rest_1vN according to the descriptive information!')
  config['rest_1vN'] = parsed_1vN

In [None]:
#@title MvN restraints
#@markdown - MvN restraints<br />
#@markdown -- description:<br />
#@markdown contain several 1vN restraints, and only a specific number of them are satisfied.<br />
#@markdown If there is no such restraints, set to None.<br />
#@markdown If you have multiple MvN restraints, list them in [].<br />
#@markdown The order number starts from 1.<br />
#@markdown -- example:<br />
#@markdown [[10,(160-170)],[78,(160-170)],[120,(160-170)],2]<br />
#@markdown 2 of the 3 given 1vN restraints should be satisfied.<br />
rest_MvN = "None" #@param ["None"] {type:"string",allow-input:true}

# check the inputs
if rest_MvN == "None":
  config['rest_MvN'] = None
else:
  parsed_MvN = []
  try:
    # Every match has the shape "[[r,(set)],[r,(set)],...,M]".
    for group_txt in re.findall(r'\[(?:\[[0-9]*,\([0-9\-,]*\)\],)+[0-9]+\]', rest_MvN):
      group_txt = group_txt[1:-1]
      # The trailing integer is the number of 1vN restraints to satisfy.
      MvN_num = int(group_txt.split(',')[-1])
      irest_MvN = []
      for item_txt in re.findall(r'\[[0-9]*,\([0-9\-,]*\)\]', group_txt):
        # Single residue on the left of the first comma, parenthesised
        # residue set on the right.
        head_txt, set_txt = item_txt[1:-1].split(',', 1)
        ind_1 = int(head_txt)
        ind_N = []
        # set_txt[1:-1] removes the parentheses the pattern guarantees.
        for tok in set_txt[1:-1].split(','):
          if '-' in tok:
            # "a-b" is an inclusive residue range.
            bounds = tok.split('-')
            ind_N.extend(range(int(bounds[0]), int(bounds[1]) + 1))
          else:
            ind_N.append(int(tok))
        irest_MvN.append([ind_1, ind_N])
      # Layout expected downstream: the 1vN items followed by the count.
      irest_MvN.append(MvN_num)
      parsed_MvN.append(irest_MvN)
  except ValueError as err:
    raise Exception('Please set rest_MvN according to the descriptive information!') from err
  # Text that matches nothing would otherwise be stored as an empty list and
  # only fail later, when the docking cell indexes its last element.
  if not parsed_MvN:
    raise Exception('Please set rest_MvN according to the descriptive information!')
  config['rest_MvN'] = parsed_MvN

In [None]:
#@title Repulsive restraints
#@markdown - repulsive restraints<br />
#@markdown -- description:<br />
#@markdown The distance between two residues is above a given threshold.<br />
#@markdown If there is no such restraints, set to None.<br />
#@markdown If you have multiple repulsive restraints, list them in [].<br />
#@markdown The order number starts from 1.<br />
#@markdown -- example:<br />
#@markdown [[78,198],[20,50]]<br />
rest_rep = "None" #@param ["None"] {type:"string",allow-input:true}

# check the inputs
if rest_rep == "None":
  config['rest_rep'] = None
else:
  try:
    # The brackets only group the pairs visually: strip them and read the
    # flat, comma-separated list of residue numbers.
    flat_rep = [int(re.sub(r'\[|\]', '', tok)) for tok in rest_rep.split(',')]
  except ValueError as err:
    raise Exception('Please set rest_rep according to the descriptive information!') from err
  # An odd count means one residue has no partner; reject it instead of
  # silently dropping the trailing value.
  if len(flat_rep) % 2 != 0:
    raise Exception('Please set rest_rep according to the descriptive information!')
  # Step by two so the pairs are (0,1), (2,3), ...; stepping by one would
  # produce overlapping pairs such as [78,198],[198,20].
  config['rest_rep'] = [[flat_rep[i], flat_rep[i + 1]] for i in range(0, len(flat_rep), 2)]

## Other settings

In [None]:
#@title Computational settings
#@markdown - Use AF2-Multimer or AF2
#@markdown -- True for AF2-Multimer, False for AF2
use_multimer = False #@param ["True", "False"] {type:"raw"}

#@markdown - path to save the results
save_path = './results' #@param {type:"string"}

#@markdown - Segment based optimization
#@markdown -- Setting to None is suggested. If out of memory error is encountered in the generation stage, consider setting it to 200.
#@markdown But this may lead to degenerated performance. For more details, please refer to the paper.
crop_len = None #@param ["None", 200] {type:"raw",allow-input:true}

#@markdown - Rounds
#@markdown -- Large rounds can achieve better performance but lead to longer time.
rounds = 2 #@param [1,5,10] {type:"raw",allow-input:true}

#@markdown - Steps
#@markdown -- The number of backpropagations in each round.
#@markdown -- If in segment based optimization, set to larger value, for example 150. Otherwise, setting to 50 is enough.
steps = 50 #@param [50, 150] {type:"raw",allow-input:true}

#@markdown - save_every_n_step
#@markdown -- Save one conformation in every save_every_n_step step.
#@markdown Useful in segment based optimization, since the number of steps is larger
#@markdown and saving conformations in every step will take too much time.
#@markdown If in segment based optimization, set to larger value, for example 3. Otherwise, setting to 1 is OK.
save_every_n_step = 1 #@param [1, 3] {type:"raw",allow-input:true}

#@markdown - bfloat
#@markdown -- Use AF2 or AF-Multimer in bfloat mode. Turning this on can save GPU memory and time.
bfloat = True #@param ["True", "False"] {type:"raw"}

# Copy all form values into the shared config in one step. The keys are listed
# in the same order as the form fields above; data_dir is the fixed directory
# the parameter archive was extracted into.
config.update({
  'use_multimer': use_multimer,
  'save_path': save_path,
  'crop_len': crop_len,
  'rounds': rounds,
  'steps': steps,
  'save_every_n_step': save_every_n_step,
  'bfloat': bfloat,
  'data_dir': './params',
})

In [None]:
#@title Advanced settings
#@markdown - The weights of each chain in the complex. Run this cell and set using the
#@markdown displayed sliders.
#@markdown -- If you allow the structures of certain chains in the final docking structure
#@markdown different from those in the input template, to better satisfy the given restraints,
#@markdown you can set this parameter.
#@markdown -- Each chain has a value between 0 and 1. With this value increasing,
#@markdown the structure of the chain in the generation stage is getting similar
#@markdown to that in the input template.
#@markdown -- Normally, if your input template is accurate, leave it as the default value.

# Chain identifiers in the order given in the "docking chains" field.
chains_lst = [chain_id.strip() for chain_id in chains.split(",")]

# One weight slider per chain, ranging from 0 to 1 and starting at 1.
slider_lst = [
  widgets.FloatSlider(
    value=1.00,
    min=0.00,
    max=1.00,
    step=0.01,
    description=f'Chain {chain_id}',
    disabled=False,
    continuous_update=False,
    orientation='horizontal',
    readout=True,
    readout_format='.2f'
  )
  for chain_id in chains_lst
]

# Lay the sliders out side by side and show them below the cell.
ui = widgets.HBox(slider_lst)
display(ui)

HBox(children=(FloatSlider(value=1.0, continuous_update=False, description='Chain A', max=1.0, step=0.01), Flo…

## Run and display

In [None]:
#@title Run Docking
# Collect the chain weights from the sliders. When every slider is at 1 the
# weights are left as None; otherwise they are stored per chain identifier.
if sum([islider.value for islider in slider_lst]) == len(slider_lst):
  chain_weights = None
else:
  chain_weights = {}
  for ith in range(len(chains_lst)):
    chain_weights[chains_lst[ith]] = slider_lst[ith].value
config['chain_weights'] = chain_weights

# Wrap the plain dict so the settings can be read as attributes below.
config_ml = ml_collections.ConfigDict(config)
save_path = config_ml.save_path
# NOTE(review): prep_path is imported from colabdock.utils; presumably it
# prepares the output directory — confirm against the package.
prep_path(save_path)
######################################################################################
# template and native structure
######################################################################################
template_r = config_ml.template
native_r = config_ml.native
chains = config_ml.chains
# Both structures are described by a PDB path plus the chain string; for the
# native structure pdb_path is None when no native structure was given.
template = {'pdb_path': template_r,
       'chains': chains}
native = {'pdb_path': native_r,
      'chains': chains}
fixed_chains = config_ml.fixed_chains

######################################################################################
# experimental restraints
######################################################################################
rest_MvN_r = config_ml.rest_MvN
rest_non_r = config_ml.rest_rep
rest_1vN_r = config_ml.rest_1vN
rest_1v1_r = config_ml.rest_1v1
# print restraints
# The summary is accumulated in print_str and printed once, after setup().
print_str = f'restraints:\n'
if rest_1v1_r is None:
  print_str += '\tno 1v1 restraints provided.\n'
else:
  print_str += f'\t1v1 restraints:\n\t\t{rest_1v1_r}\n'

if rest_1vN_r is None:
  print_str += '\tno 1vN restraints provided.\n'
else:
  print_str += f'\t1vN restraints:\n\t\t{rest_1vN_r}\n'

if rest_MvN_r is None:
  print_str += '\tno MvN restraints provided.\n'
else:
  print_str += f'\tMvN restraints:\n\t\t{rest_MvN_r}\n'

if rest_non_r is None:
  print_str += '\tno repulsive restraints provided.\n'
else:
  print_str += f'\trepulsive restraints:\n\t\t{rest_non_r}\n'

# The form uses residue numbers starting from 1; each section below subtracts 1
# to obtain 0-based indices and wraps a single restraint in a list so that
# one restraint and several restraints share the same layout.
# 1v1
if rest_1v1_r is not None:
  if type(rest_1v1_r[0]) is not list:
    rest_1v1_r = [rest_1v1_r]
  rest_1v1 = np.array(rest_1v1_r) - 1
else:
  rest_1v1 = None

# 1vN
if rest_1vN_r is not None:
  if type(rest_1vN_r[0]) is not list:
    rest_1vN_r = [rest_1vN_r]
  rest_1vN = []
  for irest_1vN in rest_1vN_r:
    rest_1vN.append([irest_1vN[0] - 1, np.array(irest_1vN[1]) - 1])
else:
  rest_1vN = None

# MvN
if rest_MvN_r is not None:
  # A single MvN restraint ends with its integer count, so a non-list last
  # element means the value is one restraint rather than a list of them.
  if type(rest_MvN_r[-1]) is not list:
    rest_MvN_r = [rest_MvN_r]
  rest_MvN = []
  for irest_MvN in rest_MvN_r:
    irest = []
    for irest_1vN in irest_MvN[:-1]:
      irest.append([irest_1vN[0] - 1, np.array(irest_1vN[1]) - 1])
    # The trailing count is copied unchanged (it is not a residue index).
    irest.append(irest_MvN[-1])
    rest_MvN.append(irest)
else:
  rest_MvN = None

# repulsive
if rest_non_r is not None:
  if type(rest_non_r[0]) is not list:
    rest_non_r = [rest_non_r]
  rest_non = np.array(rest_non_r) - 1
else:
  rest_non = None

restraints = {'1v1': rest_1v1,
        '1vN': rest_1vN,
        'MvN': rest_MvN,
        'non': rest_non}

res_thres = config_ml.res_thres
non_thres = config_ml.rep_thres

######################################################################################
# optimization parameters
######################################################################################
rounds = config_ml.rounds
crop_len = config_ml.crop_len
step_num = config_ml.steps
save_every_n_step = config_ml.save_every_n_step
data_dir = config_ml.data_dir
bfloat = config_ml.bfloat
use_multimer = config_ml.use_multimer

######################################################################################
# chain weights
######################################################################################
chain_weights = config_ml.chain_weights

######################################################################################
# print setting
######################################################################################
# List only the restraint losses that are actually active.
print_str += '\nOptimization losses include:\n\t'
if rest_1v1 is not None:
    print_str += '1v1 restraint loss, '
if rest_1vN is not None:
    print_str += '1vN restraint loss, '
if rest_MvN is not None:
    print_str += 'MvN restraint loss, '
if rest_non is not None:
    print_str += 'repulsive restraint loss, '
print_str += 'distogram loss, pLDDT, and ipAE.\n'

if chain_weights:
  print_str += f'\nChain weights:\n\t'
  for ik, iv in chain_weights.items():
    print_str += f'{ik}:{iv:.2f}\t'

######################################################################################
# start docking
######################################################################################
dock_model = ColabDock(template,
             restraints,
             save_path,
             data_dir,
             structure_gt=native,
             crop_len=crop_len,
             fixed_chains=fixed_chains,
             chain_weights=chain_weights,
             round_num=rounds,
             step_num=step_num,
             bfloat=bfloat,
             res_thres=res_thres,
             non_thres=non_thres,
             save_every_n_step=save_every_n_step,
             use_multimer=use_multimer)
dock_model.setup()
# crop_len is read back from the model after setup() rather than from the form value.
if dock_model.crop_len is not None:
    print_str += 'Colabdock will work in segment based mode.'
print(print_str)
print('\nStart optimization')
dock_model.dock_rank()

restraints:
	no 1v1 restraints provided.
	no 1vN restraints provided.
	no MvN restraints provided.
	no repulsive restraints provided.

Optimization losses include:
	distogram loss, pLDDT, and ipAE.


Start optimization
1 models [0] recycles 0 hard 0 soft 0 temp 1 loss 1.966 plddt 0.084 i_pae 0.412 dgram_cce 1.278
2 models [0] recycles 0 hard 0 soft 0 temp 1 loss 1.774 plddt 0.064 i_pae 0.397 dgram_cce 1.152
3 models [0] recycles 0 hard 0 soft 0 temp 1 loss 1.627 plddt 0.050 i_pae 0.388 dgram_cce 1.055
4 models [1] recycles 0 hard 0 soft 0 temp 1 loss 1.601 plddt 0.046 i_pae 0.400 dgram_cce 1.037
5 models [0] recycles 0 hard 0 soft 0 temp 1 loss 1.532 plddt 0.039 i_pae 0.384 dgram_cce 0.993
6 models [1] recycles 0 hard 0 soft 0 temp 1 loss 1.499 plddt 0.036 i_pae 0.392 dgram_cce 0.971
7 models [1] recycles 0 hard 0 soft 0 temp 1 loss 1.432 plddt 0.033 i_pae 0.389 dgram_cce 0.927
8 models [0] recycles 0 hard 0 soft 0 temp 1 loss 1.415 plddt 0.031 i_pae 0.377 dgram_cce 0.916
9 models [1] 

100%|██████████| 50/50 [05:03<00:00,  6.07s/it]


infer epoch 2


100%|██████████| 50/50 [04:42<00:00,  5.66s/it]


1st_best structure:
	iptm: 0.818, 0 out of 0 restraints are satisfied.
2nd_best structure:
	iptm: 0.817, 0 out of 0 restraints are satisfied.
3rd_best structure:
	iptm: 0.814, 0 out of 0 restraints are satisfied.
4th_best structure:
	iptm: 0.811, 0 out of 0 restraints are satisfied.
5th_best structure:
	iptm: 0.812, 0 out of 0 restraints are satisfied.


In [4]:
#@title Display the best structure {run: "auto"}

from string import ascii_uppercase,ascii_lowercase
import matplotlib.pyplot as plt
import matplotlib
import matplotlib.patheffects
import py3Dmol

# Colour palette used by show_pdb when colouring by chain; entries are paired
# with chains in order.
pymol_color_list = ["#33ff33","#00ffff","#ff33cc","#ffff00","#ff9999","#e5e5e5","#7f7fff","#ff7f00",
                    "#7fff7f","#199999","#ff007f","#ffdd5e","#8c3f99","#b2b2b2","#007fff","#c4b200",
                    "#8cb266","#00bfbf","#b27f7f","#fcd1a5","#ff7f7f","#ffbfdd","#7fffff","#ffff7f",
                    "#00ff7f","#337fcc","#d8337f","#bfff3f","#ff7fff","#d8d8ff","#3fffbf","#b78c4c",
                    "#339933","#66b2b2","#ba8c84","#84bf00","#b24c66","#7f7f7f","#3f3fa5","#a5512b"]

# Chain identifiers A-Z followed by a-z, paired with the palette above in show_pdb.
alphabet_list = list(ascii_uppercase+ascii_lowercase)


# Form values passed to show_pdb at the end of this cell.
rank_num = 1 #@param ["1", "2", "3", "4", "5"] {type:"raw"}
color = "rainbow" #@param ["chain", "rainbow"]
show_sidechains = False #@param {type:"boolean"}
show_mainchains = False #@param {type:"boolean"}


def show_pdb(rank_num=1, show_sidechains=False, show_mainchains=False, color="chain"):
  """Build a py3Dmol view of one of the ranked docked structures.

  Args:
    rank_num: rank of the structure to show, 1 (best) to 5.
    show_sidechains: add stick/sphere styles for side-chain atoms.
    show_mainchains: add stick styles for the backbone atoms C, O, N and CA.
    color: "rainbow" colours the cartoon with the spectrum scheme, "chain"
      gives every docked chain its own colour from pymol_color_list.

  Returns:
    The configured py3Dmol view; call .show() on it to render.

  Raises:
    ValueError: if rank_num is not between 1 and 5.
  """
  # NOTE(review): file names are assumed to follow the labels printed by the
  # docking run (1st_best ... 5th_best) — confirm against the docked/ folder.
  rank_labels = {1: '1st', 2: '2nd', 3: '3rd', 4: '4th', 5: '5th'}
  if rank_num not in rank_labels:
    raise ValueError(f'rank_num must be one of {sorted(rank_labels)}, got {rank_num!r}')

  view = py3Dmol.view(js='https://3dmol.org/build/3Dmol.js',)
  pdb_file = f'{config_ml.save_path}/docked/{rank_labels[rank_num]}_best.pdb'
  print(pdb_file)
  # Read through a context manager so the file handle is closed again.
  with open(pdb_file, 'r') as pdb_handle:
    view.addModel(pdb_handle.read(), 'pdb')

  if color == "rainbow":
    view.setStyle({'cartoon': {'color':'spectrum'}})
  elif color == "chain":
    # The chain string lives in the config (e.g. "A,B"); one colour per chain.
    n_chains = len(config_ml.chains.split(','))
    for chain_id, chain_color in zip(alphabet_list[:n_chains], pymol_color_list):
      view.setStyle({'chain': chain_id}, {'cartoon': {'color': chain_color}})

  if show_sidechains:
    BB = ['C','O','N']
    # Side chains of every residue except GLY and PRO (backbone atoms excluded).
    view.addStyle({'and':[{'resn':["GLY","PRO"],'invert':True},{'atom':BB,'invert':True}]},
                        {'stick':{'colorscheme':"WhiteCarbon",'radius':0.3}})
    # GLY is drawn as a sphere on its CA atom.
    view.addStyle({'and':[{'resn':"GLY"},{'atom':'CA'}]},
                        {'sphere':{'colorscheme':"WhiteCarbon",'radius':0.3}})
    # PRO: every atom except C and O is drawn as sticks.
    view.addStyle({'and':[{'resn':"PRO"},{'atom':['C','O'],'invert':True}]},
                        {'stick':{'colorscheme':"WhiteCarbon",'radius':0.3}})
  if show_mainchains:
    BB = ['C','O','N','CA']
    view.addStyle({'atom':BB},{'stick':{'colorscheme':"WhiteCarbon",'radius':0.3}})

  view.zoomTo()
  return view

# Render the structure using the options selected in the form fields of this cell.
show_pdb(rank_num, show_sidechains, show_mainchains, color).show()


ModuleNotFoundError: No module named 'py3Dmol'

In [3]:
# Path of the top-ranked docked structure. config_ml is created in the
# "Run Docking" cell, so that cell has to be executed before this one.
pdb_file = f'{config_ml.save_path}/docked/1st_best.pdb'

NameError: name 'config_ml' is not defined