<a href="https://colab.research.google.com/github/ccccclw/ColabDesign/blob/main/af/examples/alphafolding.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

### This notebook supports
- running iterative predictions with AlphaFold2 (monomer model 1,2) and visualization of structure predictions. For predictions that succesfully find the native state, the structure predictions before native state can possibly resemble protein folding intermediates.


In [1]:
#@title setup
%%time
import os
if not os.path.isdir("params"):
  # get code
  os.system("pip -q install git+https://github.com/ccccclw/ColabDesign.git")
  # for debugging
  os.system("ln -s /usr/local/lib/python3.*/dist-packages/colabdesign colabdesign")
  # download params
  os.system("mkdir params")
  os.system("apt-get install aria2 -qq")
  os.system("aria2c -q -x 16 https://storage.googleapis.com/alphafold/alphafold_params_2022-12-06.tar")
  os.system("tar -xf alphafold_params_2022-12-06.tar -C params")

import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)
from Bio.PDB import *
import os, re
from colabdesign import mk_afdesign_model, clear_mem
from colabdesign.mpnn import mk_mpnn_model
from colabdesign.af.alphafold.common import residue_constants
from colabdesign.shared.protein import _np_get_cb
import pickle
from colabdesign import af
from google.colab import files
import numpy as np
from IPython.display import HTML
import jax.numpy as jnp
import jax
from scipy.special import softmax
import sys
import tqdm.notebook
import argparse
import matplotlib.pyplot as plt
TQDM_BAR_FORMAT = '{l_bar}{bar}| {n_fmt}/{total_fmt} [elapsed: {elapsed} remaining: {remaining}]'

##util functions
def get_pdb(pdb_code=""):
  if pdb_code is None or pdb_code == "":
    upload_dict = files.upload()
    pdb_string = upload_dict[list(upload_dict.keys())[0]]
    with open("tmp.pdb","wb") as out: out.write(pdb_string)
    return "tmp.pdb"
  elif os.path.isfile(pdb_code):
    return pdb_code
  elif len(pdb_code) == 4:
    os.system(f"wget -qnc https://files.rcsb.org/view/{pdb_code}.pdb")
    return f"{pdb_code}.pdb"
  else:
    os.system(f"wget -qnc https://alphafold.ebi.ac.uk/files/AF-{pdb_code}-F1-model_v3.pdb")
    return f"AF-{pdb_code}-F1-model_v3.pdb"


def get_dgram(positions, num_bins=39, min_bin=3.25, max_bin=50.75):
  atom_idx = residue_constants.atom_order
  atoms = {k:positions[...,atom_idx[k],:] for k in ["N","CA","C"]}
  cb = _np_get_cb(**atoms, use_jax=False)
  dist2 = np.square(cb[None,:] - cb[:,None]).sum(-1,keepdims=True)
  lower_breaks = np.linspace(min_bin, max_bin, num_bins)
  lower_breaks = np.square(lower_breaks)
  upper_breaks = np.concatenate([lower_breaks[1:],np.array([1e8], dtype=jnp.float32)], axis=-1)
  return ((dist2 > lower_breaks) * (dist2 < upper_breaks)).astype(float)

def sample_gumbel(shape, eps=1e-10):                  
  """Sample from Gumbel(0, 1)"""
  U = np.random.uniform(size=shape)
  return -np.log(-np.log(U + eps) + eps)
 
def sample_uniform(shape, eps=1e-10): 
  """Sample from Uniform(0, 1)"""
  U = np.random.uniform(size=shape)
  return U + eps
 
from colabdesign.af.alphafold.common import residue_constants
def xyz_atom37(pdb_file):
  """
  Convert atom coordinates [num_atom, 3] from xyz read from file such as pdb to atom37 format.
  """
  atom37_order = residue_constants.atom_order
  parser = PDBParser()
  structure = parser.get_structure("A", pdb_file)
  atoms = list(structure.get_atoms())
  length = len(list(structure.get_residues()))
  atom37_coord = np.zeros((length, 37, 3))
  
  for atom in atoms:
    atom37_index = atom37_order[atom.get_name()]
    residue_index = atom.get_parent().id[1]
    atom37_coord[residue_index-1][atom37_index] = atom.get_coord()
  return atom37_coord


UsageError: Line magic function `%%time` not found.


In [None]:
#@title input preparation
starting_seq = "" #@param {type:"string"}
starting_seq = re.sub("[^A-Z]", "", starting_seq.upper())
##default sequence is PDB:3GB1 if no sequence is provided
starting_seq = "MTYKLILNGKTLKGETTTEAVDAATAEKVFKQYANDNGVDGEWTYDDATKTFTVTE" if len(starting_seq) == 0 else starting_seq
length = len(starting_seq)
template = "None" #@param ["custom","None"]
if template == "custom":
  custom_template_path = os.path.join(template,f"template")
  os.makedirs(custom_template_path, exist_ok=True)
  uploaded = files.upload()
  for fn in uploaded.keys():
    os.rename(fn,os.path.join(custom_template_path,fn))
  template_path = os.path.join(custom_template_path,fn)


In [None]:
#@title initialize the model with parameters and run
clear_mem()
model_name = "model_1_ptm" #@param ["model_1_ptm", "model_2_ptm","both"]
use_multimer = False 
model_name = None if model_name == "both" else [model_name]
af_model = mk_afdesign_model(protocol="hallucination",
                             use_templates=True,
                             debug=True, 
                             model_names=model_name,
                             use_multimer=use_multimer)
af_model.prep_inputs(length=length)

mode = "dgram" #@param ["dgram","dgram_retrain"]
if "dgram" in mode:
  if "retrain" in mode and not use_multimer:
    # update distogram head to return all 39 bins
    af_model._cfg.model.heads.distogram.first_break = 3.25
    af_model._cfg.model.heads.distogram.last_break = 50.75
    af_model._cfg.model.heads.distogram.num_bins = 39
    af_model._model = af_model._get_model(af_model._cfg)
    from colabdesign.af.weights import __file__ as af_path
    template_dgram_head = np.load(os.path.join(os.path.dirname(af_path),'template_dgram_head.npy'))
    for k in range(len(af_model._model_params)):
      params = {"weights":jnp.array(template_dgram_head[k]),"bias":jnp.zeros(39)}
      af_model._model_params[k]["alphafold/alphafold_iteration/distogram_head/half_logits"] = params
  else:
    dgram_map = np.eye(39)[np.repeat(np.append(0,np.arange(15)),4)]
    dgram_map[-1,:] = 0 

iterations = 50 #@param [50, 100, 200] {type:"raw"}
use_dgram_noise = None #@param ["g","u","None"]
use_dropout = False #@param {type:"boolean"}
seqsep_mask =  0 #@param {type:"integer"}
num_recycles = 2 #@param {type:"integer"}

sample_models = True if model_name == "both" else False
dgram_noise_type = use_dgram_noise
use_dgram_noise = False if use_dgram_noise is None else True

L = sum(af_model._lengths)
af_model.restart(mode="gumbel")
af_model._inputs["rm_template_seq"] = False
# gather info about inputs
if "offset" in af_model._inputs:           
  offset = af_model._inputs
else:
  idx = af_model._inputs["residue_index"]
  offset = idx[:,None] - idx[None,:]

# initialize sequence
if len(starting_seq) > 1:
  af_model.set_seq(seq=starting_seq)
af_model._inputs["bias"] = np.zeros((L,20))

# initialize coordinates/dgram
af_model._inputs["batch"] = {"aatype":np.zeros(L).astype(int),
                             "all_atom_mask":np.zeros((L,37)),
                             "all_atom_positions":np.zeros((L,37,3)),
                             "dgram":np.zeros((L,L,39))}

if template == "custom":
  xyz = xyz_atom37(pdb_file=template_path)
  af_model._inputs["batch"]["all_atom_positions"] = xyz
  dgram = get_dgram(xyz)
  mask = np.abs(offset) > seqsep_mask
  af_model._inputs["batch"]["dgram"] = dgram * mask[:,:,None]
  if use_dgram_noise:
    if dgram_noise_type == "g":   
      noise = sample_gumbel(dgram.shape) * (1 - k/iterations)
      dgram = softmax(np.log(dgram + 1e-8) + noise, -1)
    elif dgram_noise_type == 'u':  
      noise = sample_uniform(dgram.shape) * (1 - k/iterations)
      dgram = softmax(np.log(dgram + 1e-8) + noise, -1)
plddts = []
print(f"running seq {starting_seq} with model: {'both' if model_name is None else model_name} for {iterations} steps")
for k in range(iterations):
  # noise
  if k > 0:
    dgram_xyz = get_dgram(xyz)
    dgram_prob = softmax(dgram_logits,-1)

    if mode == "xyz":
      dgram = dgram_xyz
    if mode == "dgram":
      dgram = dgram_prob @ dgram_map
      dgram[...,14:] = dgram_xyz[...,14:] * dgram_prob[...,-1:]
    if mode == "dgram_retrain":
      dgram = dgram_prob
    
    if use_dgram_noise:
      if dgram_noise_type == "g":   
        noise = sample_gumbel(dgram.shape) * (1 - k/iterations)
        dgram = softmax(np.log(dgram + 1e-8) + noise, -1)
      elif dgram_noise_type == 'u':  
        noise = sample_uniform(dgram.shape) * (1 - k/iterations)
        dgram = softmax(np.log(dgram + 1e-8) + noise, -1)

    # add mask to avoid local contacts being fixed (otherwise there is a bias toward helix)
    mask = np.abs(offset) > seqsep_mask
    af_model._inputs["batch"]["dgram"] = dgram * mask[:,:,None]

  # prediction
  aux = af_model.predict(return_aux=True, verbose=False,
                        sample_models=sample_models,
                        dropout=use_dropout, num_recycles=num_recycles)
  plddt = aux["plddt"]
  plddts.append(np.average(plddt))
  seq = aux["seq"]["hard"][0].argmax(-1)   
  xyz = aux["atom_positions"].copy()
  dgram_logits = aux["debug"]["outputs"]["distogram"]["logits"] 
  
  # update inputs    
  af_model._inputs["batch"]["aatype"] = seq
  af_model._inputs["batch"]["all_atom_mask"][:,:4] = np.sqrt(plddt)[:,None]
  af_model._inputs["batch"]["all_atom_positions"] = xyz
  
  # save results
  af_model._save_results(aux)
  af_model._k += 1
  af_model.save_pdb(f"iter_{k}.pdb")

In [None]:
#@title visualization
fig,ax=plt.subplots(1,1,figsize=(7.4,2))
ax.scatter(range(len(plddts)),np.array(plddts)*100,s=12, color='grey', zorder=1)
ax.plot(np.array(plddts)*100,'darkorange',zorder=0)
ax.set_xlabel("Prediction iteration")
ax.set_ylabel("pLDDT")
ax.text(ax.get_xlim()[0]+(ax.get_xlim()[1]-ax.get_xlim()[0])*0.85,\
        ax.get_ylim()[0]+(ax.get_ylim()[1]-ax.get_ylim()[0])*0.05,f"recycle# {num_recycles}")
HTML(af_model.animate(dpi=80, interval=300))