In [None]:
#@title gapTrick - AlphaFold2 with multimeric templates

#@markdown This notebook uses code from https://github.com/gchojnowski/gapTrick and is based on third-party software, including
#@markdown -  AlphaFold2 [code](https://github.com/google-deepmind/alphafold)
#@markdown -  MMseqs2 API [code](https://github.com/soedinglab/MMseqs2)
#@markdown -  ... and many others. See [acknowledgements](https://github.com/deepmind/alphafold/#acknowledgements) in AlphaFold's readme for more details.

#@markdown If you found it useful, remember to cite
#@markdown  - Chojnowski bioRxiv (2025) [01.31.635911](https://doi.org/10.1101/2025.01.31.635911)
#@markdown  - Jumper et al Nature (2021) [596, 583–589](https://www.nature.com/articles/s41586-021-03819-2)
#@markdown  - Mirdita et al Bioinformatics (2021) [btab184](https://doi.org/10.1093/bioinformatics/btab184)


#@markdown **Select options below and hit `Runtime` -> `Run all` to start uploading protein sequences and template**

#@markdown **A few tips on how to use this notebook**
#@markdown - Use TPU v2-8 runtime (most RAM)
#@markdown - Try disabling relaxation if you run out of memory anyway


from google.colab import files
import os, re
import hashlib
import random
from pathlib import Path

from datetime import datetime
# Wall-clock time at the moment the cell starts; formatted into `n` below.
d = datetime.now()


def add_hash(x, y):
  """Return `x`, an underscore, and the first 7 hex digits of SHA-1(`y`).

  `y` is encoded with the default `str.encode()` codec before hashing.
  """
  short_digest = hashlib.sha1(y.encode()).hexdigest()[:7]
  return "_".join((x, short_digest))


# Empty by default; the form field that would set it is commented out below.
chain_ids = ''

# Start time rendered as HHMMddmmYYYY.
n = d.strftime('%H%M%d%m%Y')

# --- Colab form parameters ---------------------------------------------------
# Every assignment carrying a `#@param` annotation is rendered by Colab as a
# form field; the resulting module-level variables are read by later cells.

#query_sequence = 'PIAQIHILEGRSDEQKETLIREVSEAISRSLDAPLTSVRVIITEMAKGHFGIGGELASK' #@param {type:"string"}
# Free-text job label; sanitised into `jobname_prefix` further down in this cell.
jobname = "ja pobrusze a ty poczywaj" #@param {type:"string"}

#@markdown - Upload a template protein model (PDB or mmCIF)
# When True, this cell asks for a template upload and later cells use it.
use_templates = True #@param {type:"boolean"}

#@markdown - Run AMBER relaxation on top-scored prediction
# Forwarded to `runme(do_relax=...)` in the "Run gapTrick" cell.
relax_top_model = True #@param {type:"boolean"}

##@markdown - Force matching following template chains to consecutive input sequences (default: greedy automation)
#chain_ids = '' # @param {type:"string", placeholder:"A,B,C"}

#@markdown - Remove everything except models and figures from output (MSAs, AF2 pkl etc can be pretty large)
# Read by the final cell, which deletes pickles and MSA folders before zipping.
clean_outputs = True #@param {type:"boolean"}

#@markdown - Close session after downloading results
# NOTE(review): the name is misspelled ("disconnet"); it is kept because the
# final cell reads the flag under exactly this name.
disconnet_when_complete = False #@param {type:"boolean"}

#@markdown - You can reduce MSA depth here
# Forwarded to `runme(max_seq=...)` in the "Run gapTrick" cell.
max_msa_depth = 5000 #@param [10, 100, 1000, 2000, 3000, 4000, 5000, 10000] {type:"raw"}

# Turn the job label into a filesystem-friendly prefix: whitespace becomes "_",
# then every remaining character matched by `\W` is dropped.
jobname_prefix = re.sub(r'\s', '_', jobname.strip())
jobname_prefix = re.sub(r'\W+', '', jobname_prefix)

print("Uploading FASTA")
uploaded_fasta = files.upload()
if not uploaded_fasta:
  # A cancelled upload dialog returns an empty dict; fail with a clear message
  # instead of an IndexError.
  raise ValueError("No FASTA file was uploaded - re-run this cell and select a file")
# Only the first uploaded file is used.
fasta_fn, query_sequence = next(iter(uploaded_fasta.items()))
# In this cell the decoded text is only hashed to name the job folder, so
# undecodable bytes are replaced rather than aborting the run (pure-ASCII
# input decodes exactly as before).
query_sequence = query_sequence.decode('utf-8', errors='replace')
jobfolder = add_hash(jobname_prefix, query_sequence)

# All later cells address the job as /content/<jobfolder>, so the folder is
# checked and created there explicitly instead of relative to the cwd.
content_dir = os.path.join('/', 'content')
if os.path.exists(os.path.join(content_dir, jobfolder)):
  # Append the first free numeric suffix so earlier results are never overwritten.
  suffix = 1
  while os.path.exists(os.path.join(content_dir, f"{jobfolder}_{suffix}")):
    suffix += 1
  jobfolder = f"{jobfolder}_{suffix}"

# create a job directory
job_dir = os.path.join(content_dir, jobfolder)
os.makedirs(job_dir, exist_ok=True)

# Move the uploaded FASTA (saved by Colab under its upload name in the cwd)
# into the job directory and keep its absolute path for the later cells.
uploaded_fasta_fn = fasta_fn
fasta_fn = os.path.join(job_dir, os.path.basename(uploaded_fasta_fn))
os.rename(uploaded_fasta_fn, fasta_fn)


# Absolute paths of uploaded template models; stays empty without templates.
template_files = []
if use_templates:
  print("Uploading template in PDB/mmCIF")
  uploaded_template = files.upload()
  if not uploaded_template:
    raise ValueError("No template file was uploaded - re-run this cell and "
                     "select a PDB/mmCIF file, or untick `use_templates`")
  # Only the first uploaded file is used as the template.
  fn = next(iter(uploaded_template))
  template_fn = os.path.join(job_dir, fn)
  os.rename(fn, template_fn)
  template_files.append(template_fn)

In [None]:
# Set environment variables before running any other code.
import os
from sys import version_info
# "<major>.<minor>" of the interpreter running this notebook.
python_version = "{0.major}.{0.minor}".format(version_info)

# Exported up front, as the note at the top of this cell requires.
os.environ.update({
    'TF_FORCE_UNIFIED_MEMORY': '1',
    'XLA_PYTHON_CLIENT_MEM_FRACTION': '4.0',
})


#@title 1. Prepare colab environment

from IPython.utils import io
import os
import subprocess
import tqdm.notebook

# Progress-bar layout used by the installation steps in this cell.
TQDM_BAR_FORMAT = '{l_bar}{bar}| {n_fmt}/{total_fmt} [elapsed: {elapsed} remaining: {remaining}]'

# One-off environment setup: installs py3Dmol, Miniforge (into /opt/conda), the
# conda packages listed below, and clones gapTrick. The presence of /opt/conda
# is used as the "already installed" marker.
# NOTE(review): /opt/conda is created early in the sequence, so a run that fails
# part-way (e.g. before the gapTrick clone) is treated as complete on the next
# run - confirm whether a stricter marker is wanted.
if os.path.exists('/opt/conda'):
  print('Using existing conda environment')
else:
  try:
    # The pbar.update() calls below add up to the bar's total of 100.
    with tqdm.notebook.tqdm(total=100, bar_format=TQDM_BAR_FORMAT) as pbar:
      # Shell output is captured and only printed if a step fails (see below).
      with io.capture_output() as captured:
        # Uninstall default Colab version of TF.
        #%shell pip uninstall -y tensorflow keras
        #%shell sudo apt install --quiet --yes hmmer
        #pbar.update(6)

        # Install py3dmol.
        %shell pip install py3dmol
        pbar.update(10)

        #%shell wget -q -P /tmp \
        #  https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \
        #    && bash /tmp/Miniconda3-latest-Linux-x86_64.sh -bfp /opt/conda \
        #    && rm /tmp/Miniconda3-latest-Linux-x86_64.sh

        # Download the Miniforge installer and install it into /opt/conda.
        %shell wget -qnc https://github.com/conda-forge/miniforge/releases/latest/download/Miniforge3-Linux-x86_64.sh
        %shell bash Miniforge3-Linux-x86_64.sh -bfp /opt/conda
        # Read the current PATH into a Python variable, then prepend the conda
        # bin directory so the `mamba` call below resolves.
        PATH=%env PATH
        %env PATH=/opt/conda/bin:{PATH}

        pbar.update(10)

        # Install the conda packages, pinning Python to the notebook's own
        # major.minor version (`python_version` is set at the top of this cell).
        #%shell mamba config --set auto_update_conda false
        %shell mamba install -y -c conda-forge -c bioconda \
                          python='{python_version}' \
                          matplotlib \
                          kalign2=2.04 \
                          hhsuite=3.3.0 \
                          openmm=8.0.0 \
                          pdbfixer

        #%shell conda install -qy conda==24.11.1 \
        #    && conda install -qy -c conda-forge -c bioconda\
        #      python={python_version} \
        #      openmm=8.0.0 \
        #      matplotlib \
        #      kalign2 \
        #      hhsuite \
        #      pyopenssl==22.0.0 \
        #      pdbfixer
        pbar.update(70)

        # Clone gapTrick (with submodules) into the current working directory.
        %shell git clone --recursive https://github.com/gchojnowski/gapTrick.git
        pbar.update(10)
  except subprocess.CalledProcessError:
    # Show the captured output of the failed step, then re-raise.
    print(captured)
    raise

In [None]:
#@title 2. Download AlphaFold

from sys import version_info
# "<major>.<minor>" of the running interpreter; used to build site-packages paths below.
python_version = f"{version_info.major}.{version_info.minor}"
# AlphaFold source repository and the parameter archive to download.
GIT_REPO = 'https://github.com/deepmind/alphafold'
SOURCE_URL = 'https://storage.googleapis.com/alphafold/alphafold_params_2021-07-14.tar'
PARAMS_DIR = '/content/alphafold/data/params'
# Local path of the downloaded archive (PARAMS_DIR + archive file name).
PARAMS_PATH = os.path.join(PARAMS_DIR, os.path.basename(SOURCE_URL))


# NOTE(review): `os`, `tqdm`, `io`, `subprocess` and TQDM_BAR_FORMAT are not
# imported/defined in this cell; it relies on the previous cell having run.
# The existence of /content/alphafold is the "already installed" marker.
if os.path.exists('/content/alphafold'):
  print("Using an existing AF2 instalation")
else:
  try:
    # The pbar.update() calls below add up to the bar's total of 100.
    with tqdm.notebook.tqdm(total=100, bar_format=TQDM_BAR_FORMAT) as pbar:
      with io.capture_output() as captured:

        # NOTE(review): clones the tip of `main`, i.e. the AlphaFold revision is not pinned.
        %shell git clone --branch main {GIT_REPO} /content/alphafold
        pbar.update(10)
        # Install AlphaFold2 and its dependencies.
        %shell pip3 install -r /content/alphafold/requirements.txt
        %shell pip3 install --no-dependencies /content/alphafold
        pbar.update(30)

        # Download stereo_chemical_props.txt
        # and copy it into alphafold/common/ under the conda site-packages tree.
        %shell wget -q -P /content https://git.scicore.unibas.ch/schwede/openstructure/-/raw/7102c63615b64735c4941278d92b554ec94415f8/modules/mol/alg/src/stereo_chemical_props.txt
        %shell mkdir -p /opt/conda/lib/python{python_version}/site-packages/alphafold/common/
        %shell cp -f /content/stereo_chemical_props.txt /opt/conda/lib/python{python_version}/site-packages/alphafold/common/

        # Download weights
        %shell mkdir --parents "{PARAMS_DIR}"
        %shell wget -O "{PARAMS_PATH}" "{SOURCE_URL}"
        pbar.update(50)

        # Unpack the archive into PARAMS_DIR, then delete the tarball.
        %shell tar --extract --verbose --file="{PARAMS_PATH}" --directory="{PARAMS_DIR}" --preserve-permissions
        %shell rm "{PARAMS_PATH}"
        pbar.update(10)
  except subprocess.CalledProcessError:
    # Show the captured output of the failed step, then re-raise.
    print(captured)
    raise

# Make the conda site-packages and the AlphaFold checkout importable from this kernel.
import sys
sys.path.append(f'/opt/conda/lib/python{python_version}/site-packages')
sys.path.append('/content/alphafold')

In [None]:
#@title 3. Run gapTrick

# Job directory and the sub-directory that receives one MSA file per unique sequence.
jobpath=Path('/', 'content', jobfolder)
msapath=Path('/', 'content', jobfolder, 'msa')
os.makedirs(msapath, exist_ok=True)

from sys import version_info
python_version = f"{version_info.major}.{version_info.minor}"
import sys,os
import glob
import matplotlib.pyplot as plt
plt.close()

import logging
# Remove handlers left on the root logger (e.g. by an earlier run of this
# cell); logging.basicConfig() does nothing when handlers are already set.
for handler in logging.root.handlers[:]:
    logging.root.removeHandler(handler)

# Log to both <jobpath>/logfile.txt and the cell output.
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO, format="%(message)s",\
      handlers=[logging.FileHandler(Path(jobpath, 'logfile.txt')),
                logging.StreamHandler(sys.stdout)])

# Extend the import path once; re-running the cell must not keep growing sys.path.
for extra_path in (f'/opt/conda/lib/python{python_version}/site-packages',
                   '/content/alphafold'):
    if extra_path not in sys.path:
        sys.path.append(extra_path)
if '/content/gapTrick' not in sys.path:
    sys.path.insert(1, '/content/gapTrick')

# The star import supplies `query_mmseqs2` and `runme` used below (later cells
# also rely on names that are not imported anywhere else in the notebook).
from gapTrick.__main__ import *
import gapTrick.version

from Bio import SeqIO

# Bound AFTER the star import on purpose: this cell calls `datetime.now()`,
# which needs the class. The previous `import datetime` bound the module
# instead and only worked if the star import happened to rebind the name.
from datetime import datetime

# Use the explicitly imported module; a bare `version` is only defined if the
# star import happens to export it.
logger.info(f"## gapTrick version {gapTrick.version.__version__} (colab)\n\n")
logger.info(f"Started at {datetime.now().strftime('%H:%M:%S %d/%m/%Y')}\n")

# One MSA path per FASTA record, in input order. Identical sequences share a
# single MSA file, so the MMseqs2 query runs once per unique sequence.
msas=[]
local_msa_dict = {}
with open(fasta_fn) as ifile:
  for record in SeqIO.parse(ifile, "fasta"):
    # Key on the plain string so the lookup does not depend on how Seq objects hash.
    seq_key = str(record.seq)
    a3m_fname=local_msa_dict.get(seq_key, None)

    if a3m_fname:
      logger.info("Found existing MSA")
    else:
      a3m_fname = os.path.join(msapath, f"{len(local_msa_dict):04d}.a3m")
      query_mmseqs2(record.seq, a3m_fname, user_agent='colab')
      local_msa_dict[seq_key]=a3m_fname

    logger.info(f"{record.id}: {a3m_fname}")
    logger.info('')
    msas.append(a3m_fname)


# Run the prediction. Every sequence is given cardinality 1 and a [0, 9999]
# trim entry; each trim entry is its own list so entries are never aliased.
runme(msa_filenames     = msas,
      query_cardinality = [1]*len(msas),
      query_trim        = [[0,9999] for _ in msas],
      jobname           = jobpath,
      do_relax          = relax_top_model,
      data_dir          = Path('/', 'content', 'alphafold', 'data'),
      max_seq           = max_msa_depth,
      pbty_cutoff       = 0.8,
      chain_ids         = chain_ids if chain_ids else None,
      template_fn_list  = template_files if use_templates else [])


logger.info(f"\n\nNormal termination at {datetime.now().strftime('%H:%M:%S %d/%m/%Y')}")

In [None]:
#@title Display 3D model
import locale
# Replace locale.getpreferredencoding so every later call reports "UTF-8".
# NOTE(review): presumably a workaround for Colab shell magics refusing to run
# when the preferred encoding is not UTF-8 after the earlier cells - confirm.
locale.getpreferredencoding = lambda: "UTF-8"

# (upper limit, display colour) of each band, ordered from the lowest limit up.
PLDDT_BANDS = [(50, '#FF7D45'),
               (70, '#FFDB13'),
               (90, '#65CBF3'),
               (100, '#0053D6')]

def set_b_to_plddtbands_bio(structure):
    """Replace every atom's B-factor by the index of its PLDDT_BANDS band.

    The band index is the position of the first limit in PLDDT_BANDS that is
    strictly greater than the atom's current B-factor. Values at or above the
    last limit (e.g. exactly 100) are assigned to the last band; previously
    they wrapped around to band 0.

    Args:
        structure: iterable of chains -> residues -> atoms, where atoms expose
            `get_bfactor()` and `set_bfactor()`.

    Returns:
        None. The atoms are modified in place, so the original B-factor values
        are lost after this call.
    """

    # Limits are in ascending order, which searchsorted requires.
    plddt_lims = np.array([band[0] for band in PLDDT_BANDS])
    top_band = len(PLDDT_BANDS) - 1

    for chain in structure:
        for resi in chain:
            for atm in resi:
                # side='right' counts the limits <= value, i.e. the index of
                # the first limit that is strictly greater than the value.
                band = np.searchsorted(plddt_lims, atm.get_bfactor(), side='right')
                atm.set_bfactor(float(min(band, top_band)))

# -----------------------------------------------------------------------------

def set_3dmol_styles(
    view,
    viewer,
    chain_ids=1,
    color=["lDDT", "rainbow", "chain"][0],
    show_sidechains=False,
    show_mainchains=False,):
    """Apply a cartoon colouring to one panel of a py3Dmol view.

    Args:
        view: object exposing `setStyle(...)` (a py3Dmol view in this notebook).
        viewer: panel selector forwarded as the `viewer=` keyword.
        chain_ids: iterable of chain identifiers; only read when
            `color == "chain"` (the integer default is not usable in that mode).
        color: "lDDT" colours by the B-factor property using the PLDDT_BANDS
            colours keyed by band index; "chain" gives each chain its own
            colour. Any other value, including "rainbow", applies no style.
        show_sidechains: accepted for interface compatibility; not used.
        show_mainchains: accepted for interface compatibility; not used.
    """
    # Local import keeps this helper self-contained within the notebook.
    from itertools import cycle

    if color == "lDDT":
        color_map = {i: band[1] for i, band in enumerate(PLDDT_BANDS)}
        view.setStyle(
            {
                "cartoon": {
                    "colorscheme": {"prop": "b",'map': color_map}
                }
            }, viewer=viewer,)
    elif color == "chain":
        palette = ["lime","cyan","magenta",
                   "yellow","salmon","white",
                   "blue","orange","black",
                   "green","gray",]
        # Cycle through the palette so that every chain gets a style, however
        # many there are (previously chains after the 22nd were left unstyled).
        # A separate loop name avoids shadowing the `color` argument.
        for cid, chain_color in zip(chain_ids, cycle(palette)):
            view.setStyle({"chain": cid},
                          {"cartoon": {"color": chain_color}}, viewer=viewer)

# -----------------------------------------------------------------------------

def bio2pdbstring(structure):
    """Return the text that `PDBIO` writes for `structure`.

    Args:
        structure: object accepted by `PDBIO.set_structure()`.

    Returns:
        str: everything `PDBIO.save()` wrote to an in-memory text buffer.

    NOTE(review): `PDBIO` is not imported in this cell; it is presumably
    provided by the gapTrick star import in the previous cell - confirm.
    """
    # Imported locally because the notebook-level name `io` is bound to
    # `IPython.utils.io` by the environment-setup cell; `io.StringIO` would
    # only resolve if a later import happens to rebind `io` to the stdlib module.
    from io import StringIO

    pio = PDBIO()
    pio.set_structure(structure)
    buffer = StringIO()
    pio.save(buffer)
    return buffer.getvalue()

# -----------------------------------------------------------------------------
# -----------------------------------------------------------------------------
import re
import py3Dmol

# Top-ranked model written by the prediction run. `jobpath` is deliberately
# NOT rebound here: it used to be overwritten with this file path through a
# chained assignment.
top_model_fn = Path('/', 'content', jobfolder, 'input', 'ranked_0.pdb')
parser = PDBParser()
structure = parser.get_structure("AF", top_model_fn)[0]

# Load the first template, trying the PDB parser first and falling back to
# mmCIF if that fails for any reason.
parser = PDBParser(QUIET=True)
template_structure = None
if use_templates and template_files:
  try:
    template_structure = parser.get_structure("template", template_files[0])[0]
  except Exception:
    parser = MMCIFParser(QUIET=True)
    template_structure = parser.get_structure("template", template_files[0])[0]

# Kept under its own name: the global `chain_ids` is an input of the
# "Run gapTrick" cell and must not change when that cell is re-run after this one.
model_chain_ids = [chain.id for chain in structure]

# Semi-transparent black trace used for the template overlay in both panels.
TEMPLATE_STYLE = {"cartoon":
                  {"color": 'black',
                   'style':'trace',
                   'ribbon':True,
                   'opacity':0.7}}


def _overlay_template(view, viewer):
  """Add the template (if one was loaded) to `viewer` and style that last-added model."""
  if template_structure is None:
    return
  view.addModel(bio2pdbstring(template_structure), "pdb", viewer=viewer)
  view.setStyle({"model": -1}, TEMPLATE_STYLE, viewer=viewer)


view = py3Dmol.view(
            width=1000,
            height=500,
            viewergrid=(1, 2),)

# Left panel: model coloured by chain.
viewer = (0, 0)
view.addModel(bio2pdbstring(structure), "pdb", viewer=viewer)
view.zoomTo(viewer=viewer)
set_3dmol_styles(view, viewer, chain_ids=model_chain_ids, color="chain")
_overlay_template(view, viewer)

# Right panel: model coloured by pLDDT band. The B-factors of `structure` are
# overwritten in place here, hence this happens after the left panel was filled.
viewer = (0, 1)
set_b_to_plddtbands_bio(structure)
view.addModel(bio2pdbstring(structure), "pdb", viewer=viewer)
view.zoomTo(viewer=viewer)
set_3dmol_styles(view, viewer, chain_ids=model_chain_ids, color="lDDT")
_overlay_template(view, viewer)

# generate CA position dict, keyed "<residue number><chain id>"
CA_atom_coords = {}
for chain in structure:
  for res in chain:
    # Residues without a CA atom cannot anchor a contact line; skip them
    # instead of raising KeyError.
    if 'CA' in res:
      CA_atom_coords[f'{res.id[1]}{chain.id}'] = list(res['CA'].coord)

# Draw a red line between the CA atoms of every contact listed in contacts.txt.
contact_template = r"\* (?P<rname1>\w+?)/(?P<ch1>\w+?)/(?P<res1>\w+?)\s+(?P<rname2>\w+?)/(?P<ch2>\w+?)/(?P<res2>\w+?)\s+(?P<pbty>[\d\.]*?)"
contacts_fn = Path('/', 'content', jobfolder, 'contacts.txt')
if contacts_fn.exists():
  with open(contacts_fn, 'r') as ifile:
    for contact_str in ifile:
      m = re.match(contact_template, contact_str)
      if not m: continue
      ca_from = CA_atom_coords.get(f"{m.group('res1')}{m.group('ch1')}")
      ca_to = CA_atom_coords.get(f"{m.group('res2')}{m.group('ch2')}")
      # A contact that refers to a residue missing from the model is skipped
      # so that the model is still displayed.
      if ca_from is None or ca_to is None: continue
      xyz_from = dict(zip(('x','y','z'), map(float, ca_from)))
      xyz_to = dict(zip(('x','y','z'), map(float, ca_to)))
      view.addLine({"color":'red',"start":xyz_from,"end":xyz_to})
else:
  print(f"No contacts file found at {contacts_fn}; contact lines are not drawn")

view.show()






In [None]:
#@title Pack and download results, and terminate the session (if requested)

from google.colab import files
import time

import shutil

jobpath = Path('/', 'content', jobfolder)

if clean_outputs:
  # Remove bulky intermediates. Each target is optional: a missing file or
  # folder must not abort the cell before the results are downloaded (the
  # previous `%shell rm` calls raised on any missing target).
  input_dir = jobpath / 'input'
  (input_dir / 'features.pkl').unlink(missing_ok=True)
  for pkl_fn in input_dir.glob('result_model*.pkl'):
    pkl_fn.unlink()
  shutil.rmtree(input_dir / 'msas', ignore_errors=True)
  shutil.rmtree(jobpath / 'msa', ignore_errors=True)

# pack and download
# Creates /content/<jobfolder>.zip with <jobfolder>/ as the archive's top-level
# entry; unlike os.system("zip ..."), a failure raises instead of being ignored.
results_zip = shutil.make_archive(str(jobpath), 'zip',
                                  root_dir=str(jobpath.parent),
                                  base_dir=jobfolder)
files.download(results_zip)


if disconnet_when_complete:
  # NOTE(review): the 300 s pause presumably leaves time for the browser
  # download to finish before the runtime is released - confirm.
  time.sleep(300)
  from google.colab import runtime
  runtime.unassign()
