In [None]:
%%bash
# One-time setup: install the ColabDesign fork and download model weights into ./params.
# The whole block is skipped when ./params already exists. Because `mkdir params` runs
# before the downloads, a run that failed partway leaves params/ behind, and re-running
# this cell will NOT retry (delete params/ to force a fresh setup).
if [ ! -d params ]; then
  # get code
  # NOTE(review): installs the unpinned HEAD of a git fork; pin a commit for reproducibility.
  pip -q install git+https://github.com/hunarbatra/ColabDesign.git
  # download params
  mkdir params
  # AlphaFold weights (2021-07-14 release): stream the tarball from curl straight into tar.
  curl -fsSL https://storage.googleapis.com/alphafold/alphafold_params_2021-07-14.tar | tar x -C params
  # OpenFold weights: -q quiet, -nc skip files that already exist, -P save into params/.
  for W in openfold_model_ptm_1 openfold_model_ptm_2 openfold_model_no_templ_ptm_1
  do wget -qnc https://files.ipd.uw.edu/krypton/openfold/${W}.npz -P params; done
fi

In [None]:
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)

import os
from colabdesign import mk_afdesign_model, clear_mem
from IPython.display import HTML
from google.colab import files
import numpy as np

def get_pdb(pdb_code=""):
  """Return the path of a local PDB file, downloading or uploading it as needed.

  Args:
    pdb_code: RCSB PDB identifier (e.g. "7LWV"). If None or "", the user is
      prompted to upload a file through the Colab file picker instead.

  Returns:
    "tmp.pdb" for an uploaded file, otherwise "<pdb_code>.pdb".

  Raises:
    ValueError: if the upload dialog returned no file, or if pdb_code contains
      characters other than ASCII letters, digits and underscores.
    FileNotFoundError: if "<pdb_code>.pdb" is still missing after the download
      attempt (or if the wget binary itself is unavailable).
  """
  import re
  import subprocess

  if pdb_code is None or pdb_code == "":
    upload_dict = files.upload()
    if not upload_dict:
      raise ValueError("no file was uploaded")
    # Only the first uploaded file is used; it is always saved as tmp.pdb.
    pdb_bytes = next(iter(upload_dict.values()))
    with open("tmp.pdb", "wb") as out:
      out.write(pdb_bytes)
    return "tmp.pdb"

  pdb_code = str(pdb_code)
  # pdb_code is spliced into a URL and a filename, so restrict it to a safe alphabet.
  if not re.fullmatch(r"[A-Za-z0-9_]+", pdb_code):
    raise ValueError(f"invalid PDB code: {pdb_code!r}")

  pdb_filename = f"{pdb_code}.pdb"
  if not os.path.isfile(pdb_filename):
    # Pass an argument list (no shell), so pdb_code is never interpreted by a shell.
    subprocess.run(
        ["wget", "-qnc", f"https://files.rcsb.org/view/{pdb_code}.pdb"],
        check=False,
    )
  # wget -q fails silently, so check the result explicitly instead of
  # handing a non-existent path to the caller.
  if not os.path.isfile(pdb_filename):
    raise FileNotFoundError(f"could not download {pdb_filename} from RCSB")
  return pdb_filename

# Backbone-atom (N, CA, C) radius-of-gyration loss

In [None]:
# Baseline fixbb model for chain A of PDB 7LWV, with debug enabled; report its setup.
clear_mem()
af_model = mk_afdesign_model(debug=True, protocol="fixbb")
af_model.prep_inputs(
  chain="A",
  length=993,
  pdb_filename=get_pdb("7LWV"),
)

print("length", af_model._len)
print("weights", af_model.opt["weights"])

length 993
weights {'con': 0.0, 'dgram_cce': 1.0, 'fape': 0.0, 'helix': 0.0, 'pae': 0.01, 'plddt': 0.01, 'rmsd': 0.0}


In [None]:
# Disabled experiment: restart + soft run with backprop=False (presumably a forward-only check).
# NOTE(review): commented-out cell — delete it or re-enable it before sharing the notebook.
# af_model.restart(mode="gumbel",seed=0)
# af_model.set_opt(soft=True)
# af_model.run(backprop=False)

In [None]:
# Re-check the length the model reports.
print(f"length {af_model._len}")

length 993


In [None]:
from colabdesign.af.alphafold.common import residue_constants
import jax.numpy as jnp

def custom_loss(inputs, outputs, opt):
  """Backbone radius-of-gyration loss, registered as the model's `loss_callback`.

  For each backbone atom type (N, CA, C) separately, takes the mean squared
  distance of that atom's predicted positions from their centroid, then
  averages the three values. No square root is taken, so this is Rg**2; it
  is what the "rg" entry in af_model.opt["weights"] scales.

  Keep the parameter names as they are: the library may match callback
  arguments by name.

  Args:
    inputs: unused.
    outputs: model outputs; reads outputs["structure_module"]["final_atom_positions"],
      assumed to be (L, atom_types, 3) coordinates — TODO confirm against the fork.
    opt: unused.

  Returns:
    {"rg": scalar}. The key must match af_model.opt["weights"]["rg"].
  """
  positions = outputs["structure_module"]["final_atom_positions"]
  # Each backbone atom must come from its own slot (N was previously read from CA).
  backbone_idx = [residue_constants.atom_order[name] for name in ("N", "CA", "C")]
  backbone = positions[:, backbone_idx]              # (L, 3 backbone atoms, xyz)
  centers = backbone.mean(0)                         # centroid for each atom type
  sq_dist = jnp.square(backbone - centers).sum(-1)   # squared distance to that centroid
  rg = sq_dist.mean(0).mean()                        # mean over residues, then over atom types
  return {"rg": rg}

In [None]:
# Sequence bias that is later written to af_model.opt["bias"] (labelled there as ESM-1B bias).
# NOTE(review): bias.txt is not produced by this notebook; it must be supplied next to it.
BIAS_PATH = "bias.txt"
bias = np.loadtxt(BIAS_PATH)

In [None]:
# Rebuild the fixbb model with crop_len=128 and the custom backbone Rg loss attached.
clear_mem()
af_model = mk_afdesign_model(
  protocol="fixbb",
  crop_len=128,
  debug=False,
  loss_callback=custom_loss,  # register the custom loss with the model
)

af_model.prep_inputs(chain="A", pdb_filename=get_pdb("7LWV"))
# NOTE(review): overriding the private _len after prep_inputs may not resize the prepared
# inputs; the first model cell passes length=993 to prep_inputs instead — confirm intent.
af_model._len = 993
print("weights", af_model.opt["weights"])

weights {'con': 0.0, 'dgram_cce': 1.0, 'fape': 0.0, 'helix': 0.0, 'pae': 0.01, 'plddt': 0.01, 'rmsd': 0.0}


In [None]:
# Restart the design (gumbel mode, fixed seed for reproducibility), then apply overrides.
# The overrides come after restart(); presumably restart() would otherwise reset them — confirm.
af_model.restart(mode="gumbel", seed=0)

af_model.opt["bias"] = bias # ESM-1B bias
af_model.opt["num_recycles"] = 1
# "rg" must match the key returned by custom_loss; 0.1 is the weight on that term.
af_model.opt["weights"]["rg"] = 0.1 # add our loss to weights (so we can later control it)

# NOTE(review): presumably (soft, temp, hard) iteration counts — confirm against the
# fork's design_3stage signature and consider passing them as keywords.
af_model.design_3stage(50,50,10)

1 models [3] recycles 1 hard 0.00 soft 0.02 temp 1.00 seqid 0.06 loss 156.16 plddt 0.39 pae 0.72 dgram_cce 2.42 ptm 0.12 rmsd 25.82 rg 1537.29
['MLLLLLLLPLVFSQLLNLTLLLLLLLNPNSNTLSTFLLNSGTNYVVSNNSSGSSSSSSSSNGTSTDLLYLLNITLLFLLLLLLLSSTFLLNSDNSDSGTYYLVNNSSSNNSNSYLLLYDLETGETEGSLDSPNSSSTILSFFTLLGGLYILVSLLNLDSSTNTTLSNSSVSLLFDLDTLNITLTLSTSVSVENSTDSTSNTSNTSNLSLTSVLSSNSVSSTSLPDLSSPTNSTSNDPSSPSNSVSLLLLSLLLNLILLLLLSSSSSSSNRLVLVFLNSNGTLVSVIELSLLSSNSSNLDLSNFDDSNTVLLLLLGTLLLLLLLLLLYLCCCMNGSGSTSDGSNTSCGSDSDGNDGSGGPGGGSSSSSGSDSSNSGNSNSDGSNTSNSGGSGNPGNGSTLNSGTVVVLLAGTTVTLIVFTFTLVSGSSGSSSNGSGTLPVLSTSSSGNGNVSVYDDDDGTVGTLVLTGGLLPSSGTTVTGTSGSSGNSNGGNSSGNNGGPSNGSNGDDGGNDNDNNSDGGGGTDVGTIFVLLNGSTYTVTVSPVTYSATLDSTGGVLVTGDLLLFVFTLGVGVLINPVTGELTSGSGNGGSSGGTTNSTSGTAGGVTTTTSDSGGTGTPVVVPLGGTNGGGTVSSSSSGTLVVSSVTPSSGTLLVVLGTIVLLMASLSSSALSNLLSSTCCCCCSSNSGSGSSSSPSPSSSIISLLNISISIPSNSNYPLKSDLIKLNLVSVNIDCSKYVCGKNKKCRKLLKKYGSACSTIDSALSSLQLEILILLLALLSSLSPTSSLPVIIDLSSFNLSSLLSTIPTPSGRSFIEDLLFNKTVVSDQGFLKDYNDCVKGLKIKDLICAQYYNGI

In [None]:
# Write the designed structure to "<protocol>.pdb" (here fixbb.pdb), then plot it.
pdb_out = f"{af_model.protocol}.pdb"
af_model.save_pdb(pdb_out)
af_model.plot_pdb()

In [None]:
# Display the animation markup returned by af_model.animate() as rich HTML output.
animation_html = af_model.animate()
HTML(animation_html)

In [None]:
# Show the designed sequence(s) as the cell's output.
designed_seqs = af_model.get_seqs()
designed_seqs

['MFVFLVLLPLVSSQCVNLTTRTQLPPAYTNSFTRGVYYPDKVFRSSVLHSTQDLFLPFFSNVTWFHVISGTNGTKRFDNPVLPFNDGVYFASTEKSNIIRGWIFGTTLDSKTQSLLIVNAATNVVIKVCEFQCCNDPFLGVYHKNNKSWMESEFRVYSSANNCTFEYVSQPFLMDLEGKQGNFKNLREFVFKNIDGYFKIYSKHTPINLVDDLPQGFSALEPLVDLPIGINITRFQTLLALHRSYLTPGDSSSGWTAGAAAYYVGYLQPRTFLLKYNENGTITDAVDCALDPLSETKCTLKSFTVEKGIYQTSNFRVQPTESIVRFPNITNRFVKLVLSPLVSSQCANLIQRFVIVPYRSNRGSFCGDSGIRFSTYYSTQTQKLLELDNRFWFFNGKCTNGLTIRNPFCLHKNSVFYSVGVEPARIFYSVLDAPALGSKPFFLLLLCNIAVNNAKKACQKISTPPDFRDLVKKHCLKPIENSAAQVSEAADSWTNKKVQLNLKRFQGLAFLFRSLLDEANIKDVGVVKGPEDRIKFLLCFLFQRVQWVISEQKDLFYVSTVLRIPQKKKTKHTQDLGDLIVANLNNAFASKYAIYYQWDSEYLAENLRTNLLVVCASKNDCTLDSGGTPPKTEGGTQIFSTEWDAYNLQSFARVLKVITANSMFSKLVLLLLCSYQSGNLITRRQSYTNSDTRFYRNCQKPKSNFRKPVTQTQNNPEFPFNTESWVHSSHNSTGTTGRDFFFNLVFLGFNEGTKSVMKEFSEFLGWYFGTSEDYLQLVINQNSTNKVVQFEEFYQSFCDLLFTNHKKVRMEVVIESLQSMEKVPEPFLVNTRPTLEYLELPDSITVFLRVFKQGGIQKIFIWTNPPATILVNLANFEHLETLRIEPAGFAACLNFIVLLAYHDKSYLQDDKYGLSTSGLAYCTVTGLITKAVTSCIKESSYIDLACVRADQDSASPEFATDCQEARKGTQSTSDNVVPPGEFAREFPLIETCQ']