<a href="https://colab.research.google.com/gist/jkobject/4871462939758394793fde702666ba1c/colabfold_with_precomputed_humanmsa.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>


## ColabFold with pre-computed MSA

Protein structure and complex prediction using
[AlphaFold2](https://www.nature.com/articles/s41586-021-03819-2) and
[AlphaFold2-multimer](https://www.biorxiv.org/content/10.1101/2021.10.04.463034v1)
with pre-computed alignments for the human proteome. All pre-computed
alignments are available from the
[human PPI database](http://prodata.swmed.edu/humanPPI). This Colab notebook is
view-only. To run the code, use playground mode:
[run code](https://colab.research.google.com/drive/1suhoIB5q6xn0APFHJE8c1eMiCuv9gCk_#offline=true&sandboxMode=true)


In [None]:
#@title Input protein(s), then hit `Submit` -> `Submit selections`
import os
import re
import hashlib
import random
import ipywidgets as widgets
import IPython
from IPython.display import display, clear_output, HTML
from sys import version_info
python_version = f"{version_info.major}.{version_info.minor}"

os.system("wget --no-check-certificate -qnc https://conglab.swmed.edu/humanPPI/uniprot_function")

global collected_protein_data
collected_protein_data = []

# Predefined options for dropdowns
with open('uniprot_function') as f:
  prots=f.readlines()[1:]
prots=[i.strip().split('\t') for i in prots]
options=[]
map2uniprot={}
for i in prots:
  options.append(i[0])
  map2uniprot[i[0]]=i[0]
  for j in i[1].split():
    options.append(j)
    map2uniprot[j]=i[0]
  options.append(i[2])
  map2uniprot[i[2]]=i[0]

# Create container for entries
entry_container = widgets.VBox([])
display(entry_container)

def create_entry():
    dropdown = widgets.Combobox(
        options=options,
        description='Protein:',
        placeholder='Type or select a protein',
        ensure_option=True,
        disabled=False
    )
    textbox = widgets.Text(
        description='Copies:',
        disabled=False
    )
    hbox = widgets.HBox([dropdown, textbox],layout=widgets.Layout(display='flex',justify_content='center', width='100%'))
    return hbox

first_entry = create_entry()
entry_container.children = [first_entry]


def on_add_button_clicked(b):
    new_entry = create_entry()
    entry_container.children = tuple(list(entry_container.children) + [new_entry])


add_button = widgets.Button(description="Add entity",layout=widgets.Layout(width='200px'))
hbox_addbutton = widgets.HBox(children=[add_button],layout=widgets.Layout(display='flex',justify_content='center', width='100%'))
#add_button = widgets.Button(description="Add More Protein", layout=button_layout)
display(hbox_addbutton)

add_button.on_click(on_add_button_clicked)

def gather_and_process_inputs():
    protein_data = []
    for entry in entry_container.children:
        protein_name = entry.children[0].value
        copy_number = entry.children[1].value
        if protein_name and copy_number:
            protein_data.append((protein_name, copy_number))
    return protein_data

submit_button = widgets.Button(description="Submit",layout=widgets.Layout(width='200px'))
hbox_submitbutton = widgets.HBox(children=[submit_button],layout=widgets.Layout(display='flex',justify_content='center', width='100%'))
display(hbox_submitbutton)

def display_protein_data():
    print("Current Protein Data:", collected_protein_data)

def on_submit_clicked(b):
    global collected_protein_data
    collected_protein_data = gather_and_process_inputs()
    collected_protein_data=[[map2uniprot[i[0]],i[1]] for i in collected_protein_data]
    if len(set([i[0] for i in collected_protein_data]))!=len(collected_protein_data):
      print("Duplicated entries submitted, please check.")
      display_protein_data()
      return
    display_protein_data()

submit_button.on_click(on_submit_clicked)

collected_protein_data=[[map2uniprot[i[0]],i[1]] for i in collected_protein_data]



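The first cell turns `uniprot_function` into a dropdown list (`options`) and an alias-to-accession map (`map2uniprot`). Below is a minimal sketch of the same lookup. It assumes, as the parsing code implies, a header row followed by tab-separated lines of accession, space-separated gene names and protein name. `build_alias_map` and the sample row are hypothetical.

```python
def build_alias_map(lines):
    """Return (options, alias_to_accession) from tab-separated lines.

    Assumed columns: accession, space-separated gene names, protein name.
    """
    options = []
    alias_to_accession = {}
    for line in lines[1:]:  # skip the header row, as the notebook does
        fields = line.strip().split("\t")
        accession, gene_names, protein_name = fields[0], fields[1], fields[2]
        # The accession, every gene name and the protein name all resolve
        # to the same accession.
        for alias in [accession, *gene_names.split(), protein_name]:
            options.append(alias)
            alias_to_accession[alias] = accession
    return options, alias_to_accession


# Hypothetical file content: one header row and one entry (not real UniProt data).
sample = [
    "accession\tgenes\tname\n",
    "P00001\tGENEA GENEA2\tExample protein A\n",
]
options, alias_map = build_alias_map(sample)
print(alias_map["GENEA2"])  # P00001
```

Every alias resolves to a single accession. Submitting both `GENEA` and `P00001` would therefore be caught by the notebook's duplicate-entry check, which compares accessions after mapping.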

In [None]:
import os
import re
import hashlib
import random


def add_hash(x, y):
    return x + "_" + hashlib.sha1(y.encode()).hexdigest()[:5]


jobname = "test2"  # @param {type:"string"}
# number of top-ranked structures to relax with Amber
num_relax = 0  # @param [0, 1, 5] {type:"raw"}
# @markdown - specify how many of the top ranked structures to relax using amber
template_mode = "pdb100"  # @param ["none", "pdb100","custom"]
# @markdown - `none` = no template information is used. `pdb100` = detect templates in pdb100 (see [notes](#pdb100)). `custom` - upload and search own templates (PDB or mmCIF format, see [notes](#custom_templates))

basejobname = "".join(jobname.split())
basejobname = re.sub(r"\W+", "", basejobname)
jobname = add_hash(basejobname, "_".join([i[0] for i in collected_protein_data]))


# check if directory with jobname exists
def check(folder):
    if os.path.exists(folder):
        return False
    else:
        return True


if not check(jobname):
    n = 0
    while not check(f"{jobname}_{n}"):
        n += 1
    jobname = f"{jobname}_{n}"

# make directory to save results
os.makedirs(jobname, exist_ok=True)

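# honour the template_mode form field; this notebook has no template upload
# step, so custom_template_path stays None
use_templates = template_mode != "none"
custom_template_path = None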

print("jobname", jobname)

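The job name is sanitized, suffixed with a short hash of the submitted accessions, and then made unique on disk. Below is a self-contained sketch of that logic. `add_hash` mirrors the notebook's function. `make_jobname` and the accessions are hypothetical.

```python
import hashlib
import os
import re
import tempfile


def add_hash(x, y):
    # As in the notebook: append the first 5 hex digits of SHA-1(y).
    return x + "_" + hashlib.sha1(y.encode()).hexdigest()[:5]


def make_jobname(raw_name, accessions, base_dir="."):
    # Remove whitespace and non-word characters, then tag with a hash of the inputs.
    base = re.sub(r"\W+", "", "".join(raw_name.split()))
    jobname = add_hash(base, "_".join(accessions))
    # If the folder already exists, try jobname_0, jobname_1, ... until one is free.
    candidate, n = jobname, 0
    while os.path.exists(os.path.join(base_dir, candidate)):
        candidate = f"{jobname}_{n}"
        n += 1
    return candidate


with tempfile.TemporaryDirectory() as workdir:
    first = make_jobname("my job", ["P00001", "P00002"], workdir)
    os.makedirs(os.path.join(workdir, first))
    second = make_jobname("my job", ["P00001", "P00002"], workdir)
    print(first, second)  # second is first + "_0"
```

The hash depends on the order of the accessions. Submitting the same proteins in a different order yields a different job name.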


In [None]:
#@title Install dependencies
%%time
import os
USE_TEMPLATES = use_templates
PYTHON_VERSION = python_version

if not os.path.isfile("COLABFOLD_READY"):
  print("installing colabfold...")
  os.system("uv pip install -q --no-warn-conflicts 'colabfold[alphafold-minus-jax] @ git+https://github.com/sokrypton/ColabFold'")
  if os.environ.get('TPU_NAME', False) != False:
    os.system("uv pip uninstall -y jax jaxlib")
    os.system("uv pip install --no-warn-conflicts --upgrade dm-haiku==0.0.10 'jax[cuda12_pip]'==0.3.25 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html")
  os.system("ln -s /usr/local/lib/python3.*/dist-packages/colabfold colabfold")
  os.system("ln -s /usr/local/lib/python3.*/dist-packages/alphafold alphafold")
  # hack to fix TF crash
  os.system("rm -f /usr/local/lib/python3.*/dist-packages/tensorflow/core/kernels/libtfkernel_sobol_op.so")
  os.system("touch COLABFOLD_READY")

if not os.path.isfile("HH_READY"):
  if not os.path.isfile("CONDA_READY"):
    print("installing conda...")
    os.system("wget -qnc https://github.com/conda-forge/miniforge/releases/latest/download/Miniforge3-Linux-x86_64.sh")
    os.system("bash Miniforge3-Linux-x86_64.sh -bfp ../")
    os.system("mamba config --set auto_update_conda false")
    os.system("touch CONDA_READY")
  os.system(f"mamba install -y -c conda-forge -c bioconda hhsuite=3.3.0")
  os.system("touch HH_READY")

if USE_TEMPLATES:
  os.system(f"mamba install -y -c conda-forge -c bioconda kalign2=2.04 python='{PYTHON_VERSION}'")

if not os.path.exists('bad_pairs'):
  print("downloading the supplementary file")
  os.system("wget --no-check-certificate -qnc https://conglab.swmed.edu/humanPPI/bad_pairs.gz")
  os.system('gzip -d bad_pairs.gz')
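When `filtering_flag` is on, `pair_msa` reads `bad_pairs`. For each pair of submitted proteins, it drops the listed taxa from the paired rows. The sketch below assumes the layout implied by the parsing code:

- whitespace-separated fields,
- the two accessions first,
- a comma-separated list of taxa last.

Any middle columns are ignored; the `x` column in the sample is made up. The helper names and sample lines are hypothetical.

```python
from itertools import combinations


def load_bad_pairs(lines, proteins):
    """Return {frozenset({acc1, acc2}): set of taxa} for pairs of submitted proteins."""
    bad_pairs = {}
    for line in lines:
        fields = line.split()
        if len(fields) < 3:
            continue  # skip blank or malformed lines
        a, b = fields[0], fields[1]
        if a in proteins and b in proteins:
            bad_pairs[frozenset((a, b))] = set(fields[-1].split(","))
    return bad_pairs


def taxa_to_exclude(chain_subset, bad_pairs):
    """Union of the flagged taxa over every protein pair in `chain_subset`."""
    excluded = set()
    for pair in combinations(chain_subset, 2):
        excluded |= bad_pairs.get(frozenset(pair), set())
    return excluded


lines = [
    "P00001 P00002 x taxA,taxB\n",
    "P00001 P00009 x taxC\n",  # P00009 is not submitted, so this line is skipped
]
bad_pairs = load_bad_pairs(lines, {"P00001", "P00002", "P00003"})
print(sorted(taxa_to_exclude(["P00001", "P00002", "P00003"], bad_pairs)))  # ['taxA', 'taxB']
```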


In [None]:
# @markdown ## Other settings
msa_mode = "custom"
# @markdown #### MSA generation settings
MSA_pair_mode = (
    "paired_unpaired"  # @param ["paired_unpaired","paired","unpaired"] {type:"string"}
)
# @markdown - "paired_unpaired" = pair sequences from the same species + unpaired MSA, "unpaired" = separate MSA for each chain, "paired" = only use paired sequences.

pairing_strategy = "greedy"  # @param ["greedy", "complete"] {type:"string"}
# @markdown - `greedy` = pair any taxonomically matching subsets, `complete` = all sequences have to match in one line.

filtering_val = "True"  # @param ["True", "False"] {type:"string"}
# @markdown - Whether to filter out paralogs and poor-quality paired alignments (taxa listed in `bad_pairs`)
filtering_flag = "1" if filtering_val == "True" else "0"

identity_filt = 95.0  # @param {type:"number"}
# @markdown - Maximum pairwise sequence identity (%) passed to `hhfilter -id`; more similar sequences are removed before modeling

lineage = "All"  # @param ['All','Mammalia (Mammal)','Chordata'] {type:"string"}
lineage = lineage.split()[0]
# @markdown - Select lineage to include for modeling

# @markdown #### Advanced model settings
model_type = "alphafold2_multimer_v3"  # @param ["auto", "alphafold2_ptm", "alphafold2_multimer_v1", "alphafold2_multimer_v2", "alphafold2_multimer_v3", "deepfold_v1"]
# @markdown Any of the model types can be used (regardless of whether the input is a monomer or a complex).
# @markdown - if `auto` selected, will use `alphafold2_ptm` for monomer prediction and `alphafold2_multimer_v3` for complex prediction.
num_models = 1  # @param ["1", "2", "3", "4", "5"] {type:"raw"}
# @markdown - number of models to run

num_recycles = "3"  # @param ["auto", "0", "1", "3", "6", "12", "24", "48"]
# @markdown - if `auto` selected, will use `num_recycles=20` if `model_type=alphafold2_multimer_v3`, else `num_recycles=3` .
recycle_early_stop_tolerance = "auto"  # @param ["auto", "0.0", "0.5", "1.0"]
# @markdown - if `auto` selected, will use `tol=0.5` if `model_type=alphafold2_multimer_v3` else `tol=0.0`.
relax_max_iterations = 200  # @param [0, 200, 2000] {type:"raw"}
# @markdown - max amber relax iterations, `0` = unlimited (AlphaFold2 default, can take very long)


# @markdown #### Sample settings
# @markdown -  enable dropouts and increase number of seeds to sample predictions from uncertainty of the model.
# @markdown -  decrease `max_msa` to increase uncertainty
max_msa = "auto"  # @param ["auto", "512:1024", "256:512", "64:128", "32:64", "16:32"]
num_seeds = 1  # @param [1,2,4,8,16] {type:"raw"}
use_dropout = False  # @param {type:"boolean"}

num_recycles = None if num_recycles == "auto" else int(num_recycles)
recycle_early_stop_tolerance = (
    None
    if recycle_early_stop_tolerance == "auto"
    else float(recycle_early_stop_tolerance)
)
if max_msa == "auto":
    max_msa = None

# @markdown #### Save settings
save_all = True  # @param {type:"boolean"}
save_recycles = False  # @param {type:"boolean"}
# @markdown -  if `save_to_google_drive` is set to `True` and a PyDrive `drive` client is available, the result zip is also uploaded to your Google Drive
dpi = 200  # @param {type:"integer"}
# @markdown - set dpi for image resolution
os.system(
    "wget --no-check-certificate -qnc http://prodata.swmed.edu/download/pub/test_fetch/bad_pairs"
)

# @markdown After updating the form, run this cell and every cell below it. Avoid `Runtime` -> `Run all`, which re-runs the input cell and clears the submitted proteins.
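`identity_filt` is passed to `hhfilter -id`, which removes sequences that are too similar to one another. The sketch below illustrates the idea only, as a simplified greedy filter on aligned, equal-length sequences. It is not hhfilter's actual algorithm, and both helper names are hypothetical.

```python
def pairwise_identity(a, b):
    """Percent identity over columns where both aligned sequences have a residue."""
    pairs = [(x, y) for x, y in zip(a, b) if x != "-" and y != "-"]
    if not pairs:
        return 0.0
    same = sum(1 for x, y in pairs if x == y)
    return 100.0 * same / len(pairs)


def greedy_identity_filter(seqs, max_id=95.0):
    """Keep sequences in input order, dropping any that exceed `max_id` percent
    identity to an already kept sequence. The first (query) is always kept."""
    kept = []
    for s in seqs:
        if all(pairwise_identity(s, k) <= max_id for k in kept):
            kept.append(s)
    return kept


msa = [
    "MKTAYIAKQR",  # query
    "MKTAYIAKQR",  # identical to the query, so it is removed
    "MKTAYLAKQR",  # 90% identical, so it is kept at max_id=95
    "MR--YIGKQK",
]
print(greedy_identity_filter(msa, 95.0))  # ['MKTAYIAKQR', 'MKTAYLAKQR', 'MR--YIGKQK']
```

Raising `max_id` keeps more near-duplicates. Lowering it yields a smaller, more diverse alignment.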

In [None]:
#@title Download and generate MSA
import os, copy,shutil, time
from itertools import combinations
def pair_msa(queries_def, tmpdir, lineage, MSA_pair_mode, pairing_strategy, a3m_file, filtering_flag):
    prots = set([i[0] for i in queries_def])
    bad_pairs = {}
    f=open('bad_pairs')
    for line in f:
        line = line.split()
        if line[0] in prots and line[1] in prots:
            bad_pairs[frozenset([line[0],line[1]])] = line[-1].split(',')
    f.close()
    def read_fasta(filename):
        with open(filename) as f:
            header = ''
            sequence = ''
            for line in f:
                line = line.strip()
                if line.startswith('>'):
                    if header and sequence:
                        yield (header, sequence)
                    header = line[1:]
                    sequence = ''
                else:
                    sequence += line
            if header and sequence:
                yield (header, sequence)

    def write_header(f, queries_def):
        lengths = []
        copies = []
        queries_order = []
        for protein, copy, path in queries_def:
            for header, sequence in read_fasta(path):
                lengths.append(str(len(sequence)))
                break
            copies.append(copy)
            queries_order.append(protein)
        f.write(f"#{','.join(lengths)}\t{','.join(copies)}\n")
        f.write('>' + '\t'.join(queries_order) + '\n')
        for ent in queries_def:
            for _, sequence in read_fasta(ent[-1]):
                f.write(sequence)
                break
        f.write('\n')
        return queries_order, {p: l for p, l in zip(queries_order, lengths)}

    if lineage != 'All':
        for ent in queries_def:
            if not os.path.exists(f'{tmpdir}/{ent[0]}_{lineage}.fas'):
                filename = get_msa_lineage(ent[0], lineage, tmpdir)
                ent.append(filename)
            else:
                ent.append(f'{tmpdir}/{ent[0]}_{lineage}.fas')
    else:
        for ent in queries_def:
            shutil.copy(f"{tmpdir}/{ent[0]}.fas", f'{tmpdir}/{ent[0]}_all.fas')
            ent.append(f'{tmpdir}/{ent[0]}_all.fas')

    taxon_count = {}
    for prot_index, protein in enumerate(queries_def):
        for header, sequence in read_fasta(protein[-1]):
            if header != 'query':
                taxon = header.split()[0]
                if taxon not in taxon_count:
                    taxon_count[taxon] = []
                taxon_count[taxon].append(prot_index)
    protindex2taxons = {}
    for taxon in taxon_count:
        for k in range(1,len(taxon_count[taxon])+1):
            combs = combinations(taxon_count[taxon], k)
            for i in combs:
                i=sorted(i)
                try:
                    protindex2taxons[tuple(i)].add(taxon)
                except:
                    protindex2taxons[tuple(i)] = set([taxon])


    done_taxons = set()
    fw = open(a3m_file, 'w')
    queries_order, length_dic = write_header(fw, queries_def)
    if MSA_pair_mode in ['paired', 'paired_unpaired']:
        queries_index = [i for i in range(len(queries_def))]
        print(queries_index)
        for i in range(len(queries_def), 1, -1):
            comb = combinations(queries_index, i)
            all_good_taxons = set()
            for j in comb:
                j = tuple(sorted(j))
                pairs4exclusion = set()
                if filtering_flag == '1':
                    pairs = combinations(j, 2)
                    for pair in pairs:
                        prot_pair = frozenset([queries_def[p][0] for p in pair])
                        if prot_pair in bad_pairs:
                            pairs4exclusion = pairs4exclusion.union(bad_pairs[prot_pair])
                if j in protindex2taxons:
                    good_taxons = protindex2taxons[j] - pairs4exclusion - done_taxons
                    all_good_taxons = all_good_taxons.union(good_taxons)
                    seqs = {}
                    for taxon in good_taxons:
                        seqs[taxon] = [0]*len(queries_order)
                    for index, p in enumerate(queries_def):
                        if index not in j:
                            for taxon in good_taxons:
                                seqs[taxon][index] = '-'*int(length_dic[p[0]])
                        else:
                            for header, sequence in read_fasta(p[-1]):
                                taxon = header.split()[0]
                                if taxon in good_taxons:
                                    seqs[taxon][index] = sequence
                    defline = []
                    for taxon in good_taxons:
                        defline = [taxon+'_'+queries_def[k][0] if k in j else 'DUMMY' for k in range(len(queries_def))]
                        defline = '\t'.join(defline)
                        fw.write('>' + defline + '\n' + ''.join(seqs[taxon]).upper()+'\n')
            done_taxons = done_taxons.union(all_good_taxons)
            if pairing_strategy == 'complete':
                break
    print(done_taxons)
    if MSA_pair_mode in ['unpaired', 'paired_unpaired']:
        unpaired_taxons = set(taxon_count.keys()) - done_taxons

        for index, query in enumerate(queries_def):
            padding_left = ''
            padding_right = ''
            for i in range(len(queries_def)):
                if i<index:
                    padding_left = padding_left + '-'*int(length_dic[queries_def[i][0]])
                elif i>index:
                    padding_right = padding_right + '-'*int(length_dic[queries_def[i][0]])
            file_name = query[-1]
            for head, sequence in read_fasta(file_name):
                if head == 'query':
                    fw.write('>' + query[0] + '\n')
                    fw.write(f'{padding_left}{sequence}{padding_right}\n')
                elif head.split()[0] in unpaired_taxons:
                    fw.write(f'>{head}\n{padding_left}{sequence.upper()}{padding_right}\n')
    fw.close()
    return a3m_file, length_dic

def get_msa_lineage(pid, lineage, tmpdir):
    f = open(f"{tmpdir}/{pid}.fas", 'r')
    fw = open(f"{tmpdir}/{pid}_{lineage}.fas", 'w')
    flag = 0
    fw.write(next(f))
    fw.write(next(f))
    for line in f:
        if line[0] == '>' and lineage in line.split()[1].split(':'):
            flag =1
            header = line
        elif line[0] == '>' and lineage not in line.split()[1].split(':'):
            flag = 0
            header = line
        elif line[0] != '>':
            if flag == 1 :
                fw.write(header)
                fw.write(line)
    fw.close()
    f.close()
    return f"{tmpdir}/{pid}_{lineage}.fas"


a3m_file = os.path.join(jobname,f"{jobname}.custom.a3m")
queries_path = a3m_file
tmpdir = f'./{jobname}/tmp_msa'
if not os.path.exists(tmpdir):
    os.makedirs(tmpdir)

for i in collected_protein_data:
    os.system(f'wget --no-check-certificate -qnc  https://conglab.swmed.edu/humanPPI/MSAs/{i[0]}.fas -O ./{jobname}/tmp_msa/{i[0]}.fas')

a3m_file, length_dic=pair_msa([list(i) for i in collected_protein_data], tmpdir, lineage, MSA_pair_mode, pairing_strategy, a3m_file, filtering_flag)
time.sleep(2)
cmd = f'../bin/hhfilter -i {a3m_file} -o {a3m_file[:-4]}.i{identity_filt}.a3m -id {identity_filt} -M first >& tmplog'
!{cmd}
os.system(f'mv {a3m_file} {a3m_file[:-4]}.noidentfilter.fas')
os.system(f'mv {a3m_file[:-4]}.i{identity_filt}.a3m {a3m_file}')
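`pair_msa` builds the complex alignment by taxon:

1. Starting from the largest subset of chains, it emits one row per species that has a hit in every chain of the subset.
2. Chains absent from the subset are filled with gaps.
3. In `paired_unpaired` mode, it then appends the leftover hits as block-diagonal unpaired rows.

Below is a simplified, in-memory sketch of that scheme. It assumes at most one hit per taxon per chain, and it uses a single global exclusion set instead of per-pair `bad_pairs`. It also omits the header and query rows. The function names and toy data are hypothetical.

```python
from itertools import combinations


def build_paired_rows(chain_hits, lengths, strategy="greedy", excluded=frozenset()):
    """Pair hits across chains by taxon, largest chain subsets first.

    chain_hits: one {taxon: aligned_sequence} dict per chain.
    lengths:    query length of each chain, used for gap padding.
    Returns (rows, used_taxa), where rows is a list of (taxon, sequence).
    """
    n = len(chain_hits)
    rows, used = [], set()
    for size in range(n, 1, -1):
        newly_used = set()
        for subset in combinations(range(n), size):
            # Taxa with a hit in every chain of this subset.
            shared = set.intersection(*(set(chain_hits[c]) for c in subset))
            for taxon in sorted(shared - used - set(excluded)):
                row = "".join(
                    chain_hits[c][taxon] if c in subset else "-" * lengths[c]
                    for c in range(n)
                )
                rows.append((taxon, row))
                newly_used.add(taxon)
        used |= newly_used  # like done_taxons, updated once per subset size
        if strategy == "complete":
            break  # keep only rows that cover every chain
    return rows, used


def build_unpaired_rows(chain_hits, lengths, used):
    """Block-diagonal rows for taxa that were never paired."""
    rows = []
    for c, hits in enumerate(chain_hits):
        left = "-" * sum(lengths[:c])
        right = "-" * sum(lengths[c + 1:])
        for taxon in sorted(set(hits) - used):
            rows.append((taxon, left + hits[taxon] + right))
    return rows


# Toy alignment: chain 1 has length 3, chain 2 has length 2.
chain_hits = [
    {"taxA": "MKV", "taxB": "MRV", "taxC": "LKV"},
    {"taxA": "GS", "taxB": "GT"},
]
lengths = [3, 2]
paired, used = build_paired_rows(chain_hits, lengths)
unpaired = build_unpaired_rows(chain_hits, lengths, used)
for taxon, row in paired + unpaired:
    print(taxon, row)
# taxA MKVGS
# taxB MRVGT
# taxC LKV--
```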

In [None]:
# @title Run Prediction
display_images = True  # @param {type:"boolean"}

import sys
import warnings

warnings.simplefilter(action="ignore", category=FutureWarning)
from pathlib import Path
from colabfold.download import download_alphafold_params, default_data_dir
from colabfold.utils import setup_logging
from colabfold.batch import get_queries, run, set_model_type
from colabfold.plot import plot_msa_v2

import os
import numpy as np

try:
    K80_chk = os.popen('nvidia-smi | grep "Tesla K80" | wc -l').read()
except:
    K80_chk = "0"
    pass
if "1" in K80_chk:
    print("WARNING: found GPU Tesla K80: limited to total length < 1000")
    if "TF_FORCE_UNIFIED_MEMORY" in os.environ:
        del os.environ["TF_FORCE_UNIFIED_MEMORY"]
    if "XLA_PYTHON_CLIENT_MEM_FRACTION" in os.environ:
        del os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"]

from colabfold.colabfold import plot_protein
from pathlib import Path
import matplotlib.pyplot as plt


def input_features_callback(input_features):
    if display_images:
        plot_msa_v2(input_features)
        plt.show()
        plt.close()


def prediction_callback(protein_obj, length, prediction_result, input_features, mode):
    model_name, relaxed = mode
    if not relaxed:
        if display_images:
            fig = plot_protein(protein_obj, Ls=length, dpi=150)
            plt.show()
            plt.close()


def calculate_contactprob_and_interactprob(jobname, collected_protein_data, length_dic):
    import os
    import pickle

    import numpy as np
    import scipy.special

    chain2prot = {}
    prot2chain = {}
    resid2index = {}
    allchains = []

    lens = []
    CHARS = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
    index_count = 0
    chain_count = 0
    # One chain letter per copy, in input order (A, B, C, ...); resid2index maps
    # each 1-based residue number of a chain to its row/column in the distogram.
    for prot in collected_protein_data:
        for _ in range(int(prot[1])):
            chain = CHARS[chain_count]
            lens.append(int(length_dic[prot[0]]))
            allchains.append(chain)
            chain2prot[chain] = prot[0]
            prot2chain.setdefault(prot[0], []).append(chain)
            resid2index[chain] = {}
            for k in range(int(length_dic[prot[0]])):
                resid2index[chain][k + 1] = index_count
                index_count = index_count + 1
            chain_count += 1

    total_lens = sum(lens)

    pickle_files = sorted(
        os.path.join(jobname, f) for f in os.listdir(jobname) if f.endswith(".pickle")
    )
    for name in pickle_files:
        with open(name, "rb") as f:
            prediction_result = pickle.load(f)
        pdist = prediction_result["distogram"]["logits"][:total_lens][:, :total_lens]
        pdist = scipy.special.softmax(pdist, axis=-1)
        prob12 = np.sum(pdist[:, :, :32], axis=-1)
        prediction_result["contact_prob"] = prob12.astype(np.float16)
        prediction_result["interaction_prob"] = {}
        #    pickle.dump(prediction_result, open(name, 'wb'))

        print(allchains)
        resid2CAcoor = {chain: {} for chain in allchains}
        resid2coors = {chain: {} for chain in allchains}
        fp = open(
            name.replace(".pickle", ".pdb").replace("all_rank", "unrelaxed_rank"), "r"
        )
        for line in fp:
            if len(line) >= 50:
                if line[:4] == "ATOM":
                    atom = line[12:16].strip()
                    resid = int(line[22:26])
                    chainid = line[21]
                    coorx = float(line[30:38])
                    coory = float(line[38:46])
                    coorz = float(line[46:54])
                    if atom == "CA":
                        resid2CAcoor[chainid][resid] = [coorx, coory, coorz]
                    try:
                        resid2coors[chainid][resid].append([coorx, coory, coorz])
                    except KeyError:
                        resid2coors[chainid][resid] = [[coorx, coory, coorz]]
        fp.close()

        for idx1, chainid1 in enumerate(allchains):
            for idx2, chainid2 in enumerate(allchains):
                if idx1 < idx2:
                    mask = np.zeros((total_lens, total_lens), dtype=bool)
                    residsA = resid2CAcoor[chainid1].keys()
                    residsB = resid2CAcoor[chainid2].keys()
                    for residA in residsA:
                        for residB in residsB:
                            caA = resid2CAcoor[chainid1][residA]
                            caB = resid2CAcoor[chainid2][residB]
                            cadist = (
                                (caA[0] - caB[0]) ** 2
                                + (caA[1] - caB[1]) ** 2
                                + (caA[2] - caB[2]) ** 2
                            ) ** 0.5

                            if cadist < 20:
                                dists = []
                                coorsA = resid2coors[chainid1][residA]
                                coorsB = resid2coors[chainid2][residB]
                                for coorA in coorsA:
                                    for coorB in coorsB:
                                        dist = (
                                            (coorA[0] - coorB[0]) ** 2
                                            + (coorA[1] - coorB[1]) ** 2
                                            + (coorA[2] - coorB[2]) ** 2
                                        ) ** 0.5
                                        dists.append(dist)
                                if min(dists) < 8:
                                    mask[
                                        resid2index[chainid1][residA],
                                        resid2index[chainid2][residB],
                                    ] = True

                    try:
                        probs = np.copy(prediction_result["contact_prob"])
                        IFprobs = probs[mask]
                        if len(IFprobs) > 0:
                            maxprob = np.max(IFprobs)
                        else:
                            maxprob = 0.0
                        prediction_result["interaction_prob"][
                            chainid1 + chainid2
                        ] = maxprob
                    except:
                        prediction_result["interaction_prob"][
                            chainid1 + chainid2
                        ] = "error"
                        print("error in getting interaction probability")
        with open(name, "wb") as f:
            pickle.dump(prediction_result, f)
    return


result_dir = jobname
log_filename = os.path.join(jobname, "log.txt")
setup_logging(Path(log_filename))

queries, is_complex = get_queries(queries_path)
model_type = set_model_type(is_complex, model_type)

if "multimer" in model_type and max_msa is not None:
    use_cluster_profile = False
else:
    use_cluster_profile = True

download_alphafold_params(model_type, Path("."))
results = run(
    queries=queries,
    result_dir=result_dir,
    use_templates=use_templates,
    custom_template_path=custom_template_path,
    num_relax=num_relax,
    msa_mode=msa_mode,
    model_type=model_type,
    num_models=num_models,
    num_recycles=num_recycles,
    relax_max_iterations=relax_max_iterations,
    recycle_early_stop_tolerance=recycle_early_stop_tolerance,
    num_seeds=num_seeds,
    use_dropout=use_dropout,
    model_order=[3, 1, 2, 4, 5],
    is_complex=is_complex,
    data_dir=Path("."),
    keep_existing_results=False,
    rank_by="auto",
    pair_mode=MSA_pair_mode,
    pairing_strategy=pairing_strategy,
    stop_at_score=float(100),
    prediction_callback=prediction_callback,
    dpi=dpi,
    zip_results=False,
    save_all=save_all,
    max_msa=max_msa,
    use_cluster_profile=use_cluster_profile,
    input_features_callback=input_features_callback,
    save_recycles=save_recycles,
    user_agent="colabfold/google-colab-main",
)

calculate_contactprob_and_interactprob(jobname, collected_protein_data, length_dic)
results_zip = f"{jobname}.result.zip"
os.system(f"zip -r {results_zip} {jobname}")
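`calculate_contactprob_and_interactprob` turns the distogram logits into a contact probability. It applies a softmax over the distance bins and sums the first 32 bins. For each chain pair, `interaction_prob` is the maximum of that probability over residue pairs whose closest atoms come within 8 Å in the predicted model.

Below is a NumPy-only sketch. It assumes AlphaFold2's default 64-bin distogram (breaks from 2.3125 to 21.6875 Å), in which the first 32 bins cover distances below about 12 Å. The helper names are hypothetical, and `scipy.special.softmax` would work in place of `softmax`.

```python
import numpy as np


def softmax(x, axis=-1):
    # Subtract the max for numerical stability before exponentiating.
    z = np.exp(x - np.max(x, axis=axis, keepdims=True))
    return z / np.sum(z, axis=axis, keepdims=True)


def contact_probability(logits, n_bins=32):
    """Probability mass in the first `n_bins` distogram bins for each residue pair."""
    return softmax(logits, axis=-1)[..., :n_bins].sum(axis=-1)


def interface_probability(contact_prob, mask):
    """Maximum contact probability over the residue pairs selected by `mask`."""
    selected = contact_prob[mask]
    return float(selected.max()) if selected.size else 0.0


# Toy example: 5 residues, chain A = residues 0-2, chain B = residues 3-4.
rng = np.random.default_rng(0)
logits = rng.normal(size=(5, 5, 64))
contact_prob = contact_probability(logits)

mask = np.zeros((5, 5), dtype=bool)
mask[0, 3] = mask[2, 4] = True  # pretend these residue pairs touch in the model
print(round(interface_probability(contact_prob, mask), 3))
```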

In [None]:
# @title Display 3D structure {run: "auto"}
import py3Dmol
import glob
import matplotlib.pyplot as plt
from colabfold.colabfold import plot_plddt_legend
from colabfold.colabfold import pymol_color_list, alphabet_list

rank_num = 1  # @param ["1", "2", "3", "4", "5"] {type:"raw"}
color = "lDDT"  # @param ["chain", "lDDT", "rainbow"]
show_sidechains = False  # @param {type:"boolean"}
show_mainchains = False  # @param {type:"boolean"}

tag = results["rank"][0][rank_num - 1]
jobname_prefix = ".custom" if msa_mode == "custom" else ""
pdb_filename = f"{jobname}/{jobname}{jobname_prefix}_unrelaxed_{tag}.pdb"
pdb_file = glob.glob(pdb_filename)


def show_pdb(rank_num=1, show_sidechains=False, show_mainchains=False, color="lDDT"):
    model_name = f"rank_{rank_num}"
    view = py3Dmol.view(
        js="https://3dmol.org/build/3Dmol.js",
    )
    view.addModel(open(pdb_file[0], "r").read(), "pdb")

    if color == "lDDT":
        view.setStyle(
            {
                "cartoon": {
                    "colorscheme": {
                        "prop": "b",
                        "gradient": "roygb",
                        "min": 50,
                        "max": 90,
                    }
                }
            }
        )
    elif color == "rainbow":
        view.setStyle({"cartoon": {"color": "spectrum"}})
    elif color == "chain":
        chains = len(queries[0][1]) + 1 if is_complex else 1
        for n, chain, color in zip(range(chains), alphabet_list, pymol_color_list):
            view.setStyle({"chain": chain}, {"cartoon": {"color": color}})

    if show_sidechains:
        BB = ["C", "O", "N"]
        view.addStyle(
            {
                "and": [
                    {"resn": ["GLY", "PRO"], "invert": True},
                    {"atom": BB, "invert": True},
                ]
            },
            {"stick": {"colorscheme": f"WhiteCarbon", "radius": 0.3}},
        )
        view.addStyle(
            {"and": [{"resn": "GLY"}, {"atom": "CA"}]},
            {"sphere": {"colorscheme": f"WhiteCarbon", "radius": 0.3}},
        )
        view.addStyle(
            {"and": [{"resn": "PRO"}, {"atom": ["C", "O"], "invert": True}]},
            {"stick": {"colorscheme": f"WhiteCarbon", "radius": 0.3}},
        )
    if show_mainchains:
        BB = ["C", "O", "N", "CA"]
        view.addStyle(
            {"atom": BB}, {"stick": {"colorscheme": f"WhiteCarbon", "radius": 0.3}}
        )

    view.zoomTo()
    return view


show_pdb(rank_num, show_sidechains, show_mainchains, color).show()
if color == "lDDT":
    plot_plddt_legend().show()
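The lDDT coloring reads the B-factor column (`"prop": "b"`), where ColabFold stores each residue's pLDDT, and maps 50 to 90 onto a red-to-blue gradient. The sketch below pulls those values out of PDB text for further analysis. It assumes standard fixed-column PDB `ATOM` records. `plddt_per_residue` and `atom_line` are hypothetical helpers.

```python
def plddt_per_residue(pdb_text):
    """Return {(chain, residue_number): pLDDT} from the B-factor of CA atoms."""
    scores = {}
    for line in pdb_text.splitlines():
        if line.startswith("ATOM") and line[12:16].strip() == "CA":
            chain = line[21]
            resnum = int(line[22:26])
            scores[(chain, resnum)] = float(line[60:66])  # B-factor, columns 61-66
    return scores


def atom_line(serial, name, resn, chain, resi, xyz, bfactor):
    # Build a fixed-column PDB ATOM record (occupancy fixed at 1.00).
    x, y, z = xyz
    return (
        f"ATOM  {serial:5d} {name:<4s} {resn:>3s} {chain}{resi:4d}    "
        f"{x:8.3f}{y:8.3f}{z:8.3f}{1.0:6.2f}{bfactor:6.2f}"
    )


pdb_text = "\n".join([
    atom_line(1, "N", "MET", "A", 1, (0.0, 0.0, 0.0), 85.0),
    atom_line(2, "CA", "MET", "A", 1, (1.458, 0.0, 0.0), 87.5),
    atom_line(3, "CA", "GLY", "B", 1, (10.0, 2.0, -1.0), 42.0),
])
print(plddt_per_residue(pdb_text))  # {('A', 1): 87.5, ('B', 1): 42.0}
```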

In [None]:
# @title Plots {run: "auto"}
from IPython.display import display, HTML
import base64
from html import escape


# see: https://stackoverflow.com/a/53688522
def image_to_data_url(filename):
    ext = filename.split(".")[-1]
    prefix = f"data:image/{ext};base64,"
    with open(filename, "rb") as f:
        img = f.read()
    return prefix + base64.b64encode(img).decode("utf-8")


pae = ""
pae_file = os.path.join(jobname, f"{jobname}{jobname_prefix}_pae.png")
if os.path.isfile(pae_file):
    pae = image_to_data_url(pae_file)
cov = image_to_data_url(
    os.path.join(jobname, f"{jobname}{jobname_prefix}_coverage.png")
)
plddt = image_to_data_url(os.path.join(jobname, f"{jobname}{jobname_prefix}_plddt.png"))
display(
    HTML(
        f"""
<style>
  img {{
    float:left;
  }}
  .full {{
    max-width:100%;
  }}
  .half {{
    max-width:50%;
  }}
  @media (max-width:640px) {{
    .half {{
      max-width:100%;
    }}
  }}
</style>
<div style="max-width:90%; padding:2em;">
  <h1>Plots for {escape(jobname)}</h1>
  { '<!--' if pae == '' else '' }<img src="{pae}" class="full" />{ '-->' if pae == '' else '' }
  <img src="{cov}" class="half" />
  <img src="{plddt}" class="half" />
</div>
"""
    )
)

In [None]:
ls

In [None]:
# @title Package and download results
# @markdown If you are having issues downloading the result archive, try disabling your adblocker and run this cell again. If that fails, click on the little folder icon to the left, navigate to file: `jobname.result.zip`, right-click and select "Download" (see [screenshot](https://pbs.twimg.com/media/E6wRW2lWUAEOuoe?format=jpg&name=small)).
from google.colab import files

if msa_mode == "custom":
    print("Don't forget to cite your custom MSA generation method.")

files.download(f"{jobname}.result.zip")

# Optional upload: only runs if `save_to_google_drive = True` and an authenticated
# PyDrive client named `drive` were defined earlier (this notebook defines neither).
if globals().get("save_to_google_drive", False) and globals().get("drive"):
    uploaded = drive.CreateFile({"title": f"{jobname}.result.zip"})
    uploaded.SetContentFile(f"{jobname}.result.zip")
    uploaded.Upload()
    print(f"Uploaded {jobname}.result.zip to Google Drive with ID {uploaded.get('id')}")

# Instructions <a name="Instructions"></a>

**Quick start**

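1. Run the first cell. Pick each protein from the dropdown (type to search), enter
   its number of copies, add more proteins with "Add entity", then press
   "Submit".
2. Run the remaining cells in order. Avoid "Runtime" -> "Run all", which re-runs
   the input cell and clears the submitted proteins.
3. The currently running cell is indicated by a circle with a stop sign next to
   it.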

**Result zip file contents**

1. PDB-formatted structures, ranked by average pLDDT for monomers and, for
   complexes, by the multimer confidence score (0.8·ipTM + 0.2·pTM) when
   `rank_by="auto"` (unrelaxed, plus relaxed if `num_relax` > 0).
2. Plots of the model quality.
3. Plots of the MSA coverage.
4. Parameter log file.
5. The input MSA in A3M format.
6. A `predicted_aligned_error_v1.json` using
   [AlphaFold-DB's format](https://alphafold.ebi.ac.uk/faq#faq-7) and a
   `scores.json` for each model which contains an array (list of lists) for PAE,
   a list with the average pLDDT and the pTMscore.
7. BibTeX file with citations for all used tools and databases.
8. With `save_all` enabled, one `.pickle` per model with the raw outputs, to
   which this notebook adds `contact_prob` (per residue pair) and
   `interaction_prob` (per chain pair).
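The per-model `scores.json` files can be summarized with a few lines of Python. The sketch below assumes the keys `plddt` (per-residue list), `pae` (matrix), `ptm` and, for complexes, `iptm`. Check the keys in your own files, since they may differ between ColabFold versions. The file written here is synthetic, and `summarize_scores` is hypothetical.

```python
import json
import os
import statistics
import tempfile


def summarize_scores(path):
    """Summarize a scores JSON; the key names are assumptions (see above)."""
    with open(path) as f:
        scores = json.load(f)
    summary = {
        "mean_plddt": statistics.fmean(scores["plddt"]),
        "max_pae": max(max(row) for row in scores["pae"]),
        "ptm": scores["ptm"],
    }
    if "iptm" in scores:
        # AlphaFold-Multimer style ranking confidence: 0.8 * ipTM + 0.2 * pTM.
        summary["multimer_score"] = 0.8 * scores["iptm"] + 0.2 * scores["ptm"]
    return summary


# Synthetic example file (made-up numbers, not a real prediction).
example = {"plddt": [80.0, 90.0], "pae": [[0.0, 5.0], [6.0, 0.0]], "ptm": 0.5, "iptm": 0.25}
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "scores.json")
    with open(path, "w") as f:
        json.dump(example, f)
    summary = summarize_scores(path)
print(summary)
```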

When the last cell runs, your browser downloads `jobname.result.zip`. If
`save_to_google_drive` is set to `True` and a PyDrive `drive` client is
available, the archive is also uploaded to your Google Drive.

**Using custom templates** <a name="custom_templates"></a>

To predict the structure with a custom template (PDB or mmCIF formatted): (1)
change the `template_mode` to "custom" in the execute cell and (2) wait for an
upload box to appear at the end of the "Input Protein" box. Select and upload
your templates (multiple choices are possible).

- Templates must follow the four letter PDB naming with lower case letters.

- Templates in mmCIF format must contain `_entity_poly_seq`. An error is thrown
  if this field is not present. The field
  `_pdbx_audit_revision_history.revision_date` is automatically generated if it
  is not present.

- Templates in PDB format are automatically converted to the mmCIF format.
  `_entity_poly_seq` and `_pdbx_audit_revision_history.revision_date` are
  automatically generated.

**Troubleshooting**

- Check that the runtime type is set to GPU at "Runtime" -> "Change runtime
  type".
- Try to restart the session "Runtime" -> "Factory reset runtime".
- Check your input sequence.

**Known issues**

- Google Colab assigns different types of GPUs with varying amounts of memory.
  Some might not have enough memory to predict the structure for a long
  sequence.
- Your browser can block the pop-up for downloading the result file. You can
  set `save_to_google_drive = True` (with a PyDrive `drive` client) to upload to
  Google Drive instead, or manually download the result file: Click on the
  little folder icon to the left, navigate to file: `jobname.result.zip`,
  right-click and select "Download" (see
  [screenshot](https://pbs.twimg.com/media/E6wRW2lWUAEOuoe?format=jpg&name=small)).

**Description of the plots**

- **Number of sequences per position** - We want to see at least 30 sequences
  per position; for best performance, ideally 100.
- **Predicted lDDT per position** - model confidence (out of 100) at each
  position. The higher the better.
- **Predicted Alignment Error** - For homooligomers, this could be a useful
  metric to assess how confident the model is about the interface. The lower the
  better.
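The "Number of sequences per position" plot is essentially a count of non-gap residues in each alignment column. Below is a rough sketch for checking an MSA against the 30/100 guideline. It assumes A3M-style rows, in which lowercase letters are insertions (removed before counting) and `-` marks gaps. The helper names are hypothetical, and ColabFold's own plot may differ in detail.

```python
def sequences_per_position(rows):
    """Count non-gap residues per column; lowercase (insertion) letters are dropped first."""
    aligned = ["".join(ch for ch in row if not ch.islower()) for row in rows]
    length = len(aligned[0])
    if any(len(row) != length for row in aligned):
        raise ValueError("rows must have equal length after removing insertions")
    counts = [0] * length
    for row in aligned:
        for i, ch in enumerate(row):
            if ch != "-":
                counts[i] += 1
    return counts


def low_coverage_positions(counts, minimum=30):
    """0-based positions covered by fewer than `minimum` sequences."""
    return [i for i, c in enumerate(counts) if c < minimum]


msa = ["MKV", "M-V", "--V", "MaKV"]  # "a" is an insertion and is ignored
counts = sequences_per_position(msa)
print(counts)  # [3, 2, 4]
print(low_coverage_positions(counts, minimum=3))  # [1]
```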

**License**

The source code of ColabFold is licensed under
[MIT](https://raw.githubusercontent.com/sokrypton/ColabFold/main/LICENSE).
Additionally, this notebook uses the AlphaFold2 source code and its parameters
licensed under
[Apache 2.0](https://raw.githubusercontent.com/deepmind/alphafold/main/LICENSE)
and [CC BY 4.0](https://creativecommons.org/licenses/by/4.0/) respectively.
Read more about the AlphaFold license
[here](https://github.com/deepmind/alphafold).

**Acknowledgments**

- We thank the AlphaFold and ColabFold teams for developing excellent models
  and open-sourcing their software.
