In [None]:
%%bash
# One-time setup: install the ColabDesign fork and download the AlphaFold
# params tarball plus three OpenFold weight files into ./params.
# Fail fast (also inside the curl | tar pipeline) and only mark setup as
# complete after every step succeeded, so an interrupted run is retried
# instead of being skipped merely because the params/ directory exists.
set -euo pipefail
if [ ! -f params/.setup_done ]; then
  pip -q install git+https://github.com/hunarbatra/ColabDesign.git
  mkdir -p params
  curl -fsSL https://storage.googleapis.com/alphafold/alphafold_params_2021-07-14.tar | tar x -C params
  for W in openfold_model_ptm_1 openfold_model_ptm_2 openfold_model_no_templ_ptm_1; do
    # skip weight files already present (e.g. from an earlier partial run)
    [ -f "params/${W}.npz" ] || wget -q "https://files.ipd.uw.edu/krypton/openfold/${W}.npz" -P params
  done
  touch params/.setup_done
fi

In [None]:
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)

import os
from colabdesign import mk_afdesign_model, clear_mem
from IPython.display import HTML
from google.colab import files
import numpy as np

def get_pdb(pdb_code=""):
  """Return the path of a local PDB file, downloading or uploading it if needed.

  Args:
    pdb_code: RCSB PDB identifier (e.g. "7LWV"). If None or "", the user is
      asked to upload a file through the Colab file picker instead.

  Returns:
    "tmp.pdb" for an upload, otherwise f"{pdb_code}.pdb" in the working dir.

  Raises:
    ValueError: if nothing was uploaded, or if pdb_code contains characters
      other than letters, digits and underscores (it is embedded in a URL and
      a file name, so path separators / shell metacharacters are rejected).
    RuntimeError: if the download did not produce the expected file.
  """
  import re
  import subprocess

  if not pdb_code:
    upload_dict = files.upload()
    if not upload_dict:
      raise ValueError("no file was uploaded")
    # only the first uploaded file is used
    pdb_bytes = next(iter(upload_dict.values()))
    with open("tmp.pdb", "wb") as out:
      out.write(pdb_bytes)
    return "tmp.pdb"

  if not re.fullmatch(r"[A-Za-z0-9_]+", pdb_code):
    raise ValueError(f"invalid PDB code: {pdb_code!r}")

  pdb_filename = f"{pdb_code}.pdb"
  if not os.path.isfile(pdb_filename):
    # argv list, no shell: pdb_code is never interpreted by a shell
    subprocess.run(["wget", "-q", f"https://files.rcsb.org/view/{pdb_code}.pdb"],
                   check=False)
    # wget's exit status alone is not trusted; the file must actually exist
    if not os.path.isfile(pdb_filename):
      raise RuntimeError(f"failed to download {pdb_filename} from RCSB")
  return pdb_filename

In [None]:
# Bias matrix read from a whitespace-delimited text file; it is assigned to
# model.opt["bias"] in the design cell. Expected shape presumably
# (length, 20) -- TODO confirm against the design length.
bias = np.loadtxt(fname="bias.txt")

In [None]:
from colabdesign.af.alphafold.common import residue_constants
import jax.numpy as jnp

def custom_loss(inputs, outputs, opt, atom_order=None):
  """Backbone compactness term for AfDesign's ``loss_callback``.

  For each backbone atom type (CA, C, N) this computes the mean squared
  distance of that atom from its own centroid over all residues. The returned
  "rg" is the average of the three values. There is no square root, so this
  is a mean *squared* radius (Rg^2-like), not Rg itself.

  Args:
    inputs: unused; kept for the callback signature.
    outputs: model outputs; reads
      ``outputs["structure_module"]["final_atom_positions"]``, indexed as
      ``[residue, atom_type, xyz]`` (presumably (L, 37, 3) -- TODO confirm).
    opt: unused; kept for the callback signature.
    atom_order: optional mapping atom name -> atom-type index. When None
      (the default) ``residue_constants.atom_order`` is used; exposing it lets
      the loss be evaluated without colabdesign.

  Returns:
    ``{"rg": scalar}``.
  """
  if atom_order is None:
    atom_order = residue_constants.atom_order
  positions = outputs["structure_module"]["final_atom_positions"]

  def _mean_sq_radius(atom_name):
    # coordinates of one atom type for every residue (xyz on the last axis)
    xyz = positions[:, atom_order[atom_name]]
    # squared distance of each atom from the centroid, averaged over residues
    return jnp.square(xyz - xyz.mean(0)).sum(-1).mean()

  # each backbone atom uses its own index (N is no longer read from CA)
  rg = sum(_mean_sq_radius(name) for name in ("CA", "C", "N")) / 3
  return {"rg": rg}

In [None]:

# ---- design target --------------------------------------------------------
length = 993
copies = 1

# ---- loss weights ---------------------------------------------------------
pae, plddt, helix, con = 0.1, 0.1, 0.0, 1.0

# ---- contact ("con") loss options ------------------------------------------
# cutoff / num are strings so that the sentinel "max" can be used:
#   cutoff "max" -> 21.6875 ; num "max" -> the full design length
seqsep = 9
cutoff = "14"
num = "2"
binary = False
cutoff = 21.6875 if cutoff == "max" else cutoff
num = length if num == "max" else num

opt = {"con": dict(seqsep=int(seqsep), cutoff=float(cutoff),
                   num=int(num), binary=binary)}
weights = {name: float(value) for name, value in (("con", con),
                                                  ("helix", helix),
                                                  ("pae", pae),
                                                  ("plddt", plddt))}

# (Re)build the model only when none exists yet or when length/copies changed,
# so re-running this cell with new loss settings reuses the existing model.
# NOTE(review): changes to crop_len or loss_callback are not detected here and
# would need a manual `del model` to take effect.
if "model" not in dir() or model._len != length or model._copies != copies:
  clear_mem()  # presumably frees memory held by a previous model -- colabdesign helper
  model = mk_afdesign_model(protocol="hallucination", crop_len=128, loss_callback=custom_loss)
  # structure 7LWV is downloaded (or reused from disk) by get_pdb
  model.prep_inputs(pdb_filename=get_pdb('7LWV'), length=length, copies=copies)

# NOTE(review): `weights` has no "rg" entry for the custom_loss term; confirm
# how AfDesign weights callback losses (it may need model.opt["weights"]["rg"]).

# pre-design with gumbel initialization and softmax activation
model.restart(mode="gumbel",opt=opt,weights=weights)
# bias / num_recycles are set after restart(), presumably because restart()
# resets model.opt -- TODO confirm
model.opt["bias"] = bias 
model.opt["num_recycles"] = 1
model.design_soft(10)  # presumably 10 soft-design iterations

# three stage design  
# seeded with the sequence from the pre-design round; history is kept
model.restart(seq=model.aux["seq"]["pseudo"],
              keep_history=True, opt=opt,weights=weights)
model.opt["bias"] = bias 
model.opt["num_recycles"] = 1
# presumably iteration counts for the soft / annealing / hard stages -- confirm
model.design_3stage(50,50,10)

1 models [2] recycles 1 hard 0 soft 1 temp 1 loss 5.15 plddt 0.54 pae 0.75 con 5.02 ptm 0.12
['IGERVAPRPLDYCMCNFWDIRLKPPWHMYDHKIWANQAEGLGRGAWYNVKRETMMICTKALQKICGREVQKLCKMSVLAEAKEVAQHDWLEPTNEMDKKTIPAPKIGRYILTGGIKYRKDEKVKFSIQITYIPVDISMPQLPSKNVQNRFRYSFYQIWYFWMAYYYAKLGLFLLIVVCWMFDQWGWSDLWQIHARNFRQFPLWITLHYRGLPNFWPDCTRPDYICMRHSFYHTLVPERGYGTSTYLDYYAPLMTPWIRKCPKPGLIWNFDWRWGYKMNANIVWDPVWFQVEWFWMWGFMMMGPQQIHSTMCKPFWRNTYPRPRKNHWTDSSWTTMPMQVVFRMSGYCNDGQNIVNTIHALSKFRFQAPIWMMNESNSHVAVCNYVTMVHPVGMDYHSYSSGMLWCESESLWSLPIVCNHTVQCMIQDWWTMSLSEDVMMHHWHYSYFMFYQVAGWSENRSRNDQFWVFPYVVGEAWAQMEVVWATLMDGRFTEMWEFMHRAPWSAPICISEPMEHTGSFLIVEIPLFEHEVYFRAIASESIIYDVWSASPNGRGNGNWNVKSANKMMIPGRFYEDKVDWFFQSFPDWNQMFAQLDHMGNGCFTQWHSVEVKKVGQGRFHYHKGESSGLGEDTCKSAHKIETGYMWYFRCWVDTTAYADWFIEVRTGHHIRHKYMFDVAFLCKQEGAMTGVWPITSDDVIMLKQFFHTYSHLPKHMWDDINCTDWWTFMMVCDYELRFQCNYEKTCLARCGMWYHELLSVPPQQHMWDSDSCPSCNVYDQMCPSPLLCQMTMDDHETDTNGRRDILSHPSEHRAWPHKRVMVWMVLIITCHNHYISRPWVWEYKDPQCAYHAGSVCDWIRMKQVNGTVWKDCINKGIGTITTPFHTWYMFRHEDNADIFQMTFQGW

In [None]:
#@markdown ## display hallucinated protein {run: "auto"}
color = "rainbow" #@param ["chain", "pLDDT", "rainbow"]
show_sidechains = False #@param {type:"boolean"}
show_mainchains = False #@param {type:"boolean"}

# Write the current design to "<protocol>.pdb", then render it using the
# display options chosen in the form fields above.
model.save_pdb(f"{model.protocol}.pdb")
model.plot_pdb(
    color=color,
    show_sidechains=show_sidechains,
    show_mainchains=show_mainchains,
)

In [None]:
# Render the markup returned by model.animate() (presumably an animation of
# the design trajectory) as rich HTML output.
HTML(data=model.animate())

In [None]:
# Display the sequence(s) currently held by the model after the design rounds.
model.get_seqs()

['MLLLLLLLPLVSSQLLDLDILLLPPDNLYDVTISTNVADSLLNNLSSNNSRSTVSISTSANGTIVGSYVLVNLTLSVLATALEVASTLYLNATNESDSTTITVLNAVRNILTGLITYLDSSTGETSDSIDNNTVDISLTLLLSLVNGLIVLYSFLLSDYENNKSYYSPEELLLLTVVLVSFLLTLLDDLSSSNTTEERQLVSNLLGYFRTLSNFLNTYVRPLFISLSNSFNSTNNPENGPGTPTYLLYYAVLLRSVLRSLGSPNLISLNDTLLVYVMSNGTVVDVVTFQVLSNNLSNFNNPGPDDPNSTLTILFVTNTYVSPRSNDGTDSSSVTNPTVVNTTTVSYCNDGQLIVNSSSTLSGFLSNSPTDNSNSSNSNVAVCLVVTTVGSVGDGGLSYSSGNLLVELLSLSSLTIVTFTFTCCSNADNGTLSLSSDVLTGVSSSSSNLTYTVVGSSDGTSGNNVFLVFTYVVGGATAQTTTVSSTTTDGRSTVSVENSSRAILLALIYILKMLESDGNFLIVVITLFGDEVLFVTITVTSTITNVTVTIVNGGGTGIVVVVSLNKCTLPGGFYIDLVTGECQSFSDSNSSFLQLVLTTNSSTTCTLVVSVSSNGSGTSVVLLTLSSGGGTTTTTNATTITTIYSLYSSCLTNGTAYADCLSSVPTGSTSLTTNLSSVTSLCTSSGNSTGSNTITSNLVNLLTTFLTTTSSLPKNSNDPNPCTVLNTLLNVLDYNLSKQCSSNSTTAQNCKTQYQELCSSYYSTLNSLSAAYLILNVALLNALSDLLGLLLVSDLSTINLSALLATLSSLLKRALPDPEVAGNLVLVNGTLAGTFQQCSSGQSKNDPICAQYYSALDLIPGVITNGQTAQAQAQLIAGAIAAALAWWAAAAIDIPAAIANAQRLNGIAYQLQLLLNNQKVLSSSFSTALSSIQSAFSSVASAINKTQDVVNSNTQALDQVTQQLSSNFGAINGTLNSINARLDKIERKAGKCLG']