<a href="https://colab.research.google.com/github/clami66/AF_unmasked/blob/notebook/notebooks/AF_unmasked.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# AF_unmasked: a simplified notebook

<img src="https://github.com/clami66/AF_unmasked/raw/main/fig/header.png" height="200">

This notebook lets you run AF_unmasked on multimeric sequences and templates of your choice. Not all features implemented in the command-line version of AF_unmasked are currently available in the notebook, but more will come later.

This version of AF_unmasked relies on MMseqs2 alignments, run by the [ColabFold](https://github.com/sokrypton/ColabFold) MSA server. Some of the code in this notebook is also based on or taken from the ColabFold notebook.

If you use this version of AF_unmasked in your research, consider citing:

- Mirabello et al.: "Unmasking AlphaFold to integrate experiments and predictions in multimeric complexes". [Nature Communications volume 15, 8724 (2024)](https://www.nature.com/articles/s41467-024-52951-w)
- Jumper et al.: "Highly accurate protein structure prediction with AlphaFold". [Nature volume 596, pages 583–589 (2021)](https://www.nature.com/articles/s41586-021-03819-2)
- Evans et al.: "Protein complex prediction with AlphaFold-Multimer". [bioRxiv](https://www.biorxiv.org/content/early/2021/10/04/2021.10.04.463034)
- Mirdita et al. "ColabFold: making protein folding accessible to all" [Nature Methods volume 19, pages 679–682 (2022)](https://www.nature.com/articles/s41592-022-01488-1)

In [None]:
# Standard-library / packaging imports used by the setup and run cells.
import os
import sys
import pickle
import shutil
import importlib_metadata
from pathlib import Path
# Working directory at start-up; later added to sys.path and used as the
# base directory for prediction output.
cwd = Path.cwd()
from string import ascii_uppercase, ascii_lowercase
# Alphabet of chain identifiers: "A".."Z" followed by "a".."z" (52 symbols).
ascii_upperlower = ascii_uppercase + ascii_lowercase

# Running interpreter version formatted as "<major>.<minor>".
python_version = f"{sys.version_info.major}.{sys.version_info.minor}"
PYTHON_VERSION = python_version

print("Setting up the environment...")

# Fetch the "notebook" branch of AF_unmasked into ./AF_unmasked.
os.system("git clone -b notebook https://github.com/clami66/AF_unmasked.git")

# COLABFOLD_READY is a marker file: once it exists the (slow) pip install is
# skipped on subsequent executions of this cell.
if not os.path.isfile("COLABFOLD_READY"):
  print("installing colabfold...")
  # The numpy requirement MUST be quoted. Unquoted, the shell parses
  # `numpy<1.20` as the word `numpy` plus an input redirection from a file
  # named "1.20", so the version constraint is lost and, when that file does
  # not exist, pip is never executed at all.
  install_status = os.system(
      "pip install 'numpy<1.20' biopython==1.79 "
      "'colabfold[alphafold-minus-jax] @ git+https://github.com/sokrypton/ColabFold'"
  )
  #os.system("ln -s /home/claudio/miniconda3/envs/mercury/lib/python3.10/site-packages/colabfold colabfold")
  #os.system("ln -s /home/claudio/miniconda3/envs/mercury/lib/python3.10/site-packages/alphafold alphafold")
  if install_status == 0:
    # Only record success when pip actually succeeded, so that a failed
    # installation is retried the next time this cell runs.
    os.system("touch COLABFOLD_READY")
  else:
    print("   ... colabfold installation failed; it will be retried on the next run.")

# Command-line tools needed by the AlphaFold data pipeline (hmmsearch/hmmbuild, kalign, hhblits).
os.system("mamba install -y -c conda-forge -c bioconda hmmer kalign2=2.04 hhsuite=3.3.0")
# Make both the cloned repository and the working directory importable.
sys.path.insert(0, f"{cwd}/AF_unmasked")
sys.path.insert(0, f"{cwd}")

import mercury as mr

from AF_unmasked.alphafold.data.prepare_templates import *
from AF_unmasked.alphafold.data.mmseqs_2_uniprot import *
from Bio import Align, SeqIO, AlignIO
from Bio.PDB.mmcifio import MMCIFIO

from colabfold.batch import get_msa_and_templates, msa_to_str
from colabfold.utils import DEFAULT_API_SERVER, get_commit
from colabfold.download import download_alphafold_params

from AF_unmasked.run_alphafold import predict_structure
from AF_unmasked.alphafold.data import pipeline, pipeline_multimer

from AF_unmasked.alphafold.data.tools import hmmsearch, jackhmmer
from AF_unmasked.alphafold.data import templates

from AF_unmasked.alphafold.model import model, data, config

import py3Dmol
import matplotlib.pyplot as plt
from colabfold.colabfold import plot_plddt_legend
from colabfold.colabfold import pymol_color_list, alphabet_list

print("   ... done.")

In [None]:
# Configure the Mercury application that wraps this notebook: code and
# prompts are hidden, the widget sidebar is shown, and execution stops at
# the first error (stop_on_error=True).
app = mr.App(title="AF_unmasked",
        description="",
        show_code=False,
        show_prompt=False,
        continuous_update=True,
        static_notebook=False,
        show_sidebar=True,
        full_screen=True,
        allow_download=True,
        allow_share=True,
        stop_on_error=True
)

#mr.Markdown(text="""### Job settings
#""")

# Free-text job name; its value is later used to name the output directory
# and the copy of the uploaded fasta file.
jobname = mr.Text(value="", label="Name your job:", rows=1)
# No job name yet: call mr.Stop() (presumably halts the cells below until
# the user fills the field in -- confirm against the Mercury documentation).
if not jobname.value:
    mr.Stop()


In [None]:
# Every artefact of a job is stored in a directory named after the job.
out_dir = Path(str(jobname.value))
out_dir.mkdir(exist_ok=True, parents=True)
# Location of the job-local copy of the uploaded fasta file.
targets = out_dir / f"{jobname.value}.fasta"

In [None]:
# File-upload widgets: one for the target sequences (fasta) and one for the
# structural template (PDB or mmCIF). `filepath` is read from each widget
# right after it is created.
f = mr.File(label="Upload a single .fasta file for a multimeric target", max_file_size="1MB")
fasta = f.filepath
p = mr.File(label="Upload a template PDB (.pdb, .cif)", max_file_size="20MB")
pdb = p.filepath

# Both uploads are required; call mr.Stop() while either path is still empty.
if not fasta or not pdb:
    mr.Stop()

In [None]:
# Derive the template format from the extension of the uploaded file.
# The comparison is case-insensitive so that uploads such as "model.PDB" or
# "model.CIF" are accepted as well.
template_filename = p.filename.lower()
if template_filename.endswith(".pdb"):
  template_format = "pdb"
elif template_filename.endswith(".cif"):
  template_format = "cif"
else:
  raise ValueError("Template must be in .pdb or .cif format")

# From here on the uploaded files are handled as pathlib paths.
fasta = Path(fasta)
pdb = Path(pdb)

In [None]:
# Keep a copy of the uploaded fasta file inside the job directory.
_ = shutil.copyfile(fasta, targets)

if not is_fasta(targets):
  raise ValueError("""The input does not appear to be in fasta format
  Example of fasta format input:
  > H1142_A
  GLEKDFLPLYFGWFLTK...
  > H1142_B
  EVQLEESGGGLVQAGGS...
  """)

# Echo the input back to the user. A dedicated handle name is used so that
# `f`, which holds the fasta upload widget, is not rebound to a closed file.
with open(targets, "r") as fasta_handle:
  print("Fasta sequences:")
  print(fasta_handle.read())
  print()

# Map each distinct sequence to the list of chain IDs that carry it (chain
# IDs are assigned in file order: A, B, ..., Z, a, ..., z). Keys are plain
# strings so that later lookups with string sequences do not depend on how
# Biopython's Seq class implements hashing/equality.
seq2chain = {}
chain_idx = 0
for record in SeqIO.parse(targets, "fasta"):
  if chain_idx >= len(ascii_upperlower):
    # Fail with a clear message instead of an IndexError on the alphabet.
    raise ValueError(
        f"Too many sequences in the fasta file: at most {len(ascii_upperlower)} chains are supported"
    )
  seq2chain.setdefault(str(record.seq), []).append(ascii_upperlower[chain_idx])
  chain_idx += 1

In [None]:
# Keep a copy of the uploaded template inside the job directory.
template = out_dir.joinpath(p.filename)
_ = shutil.copyfile(pdb, template)

# template data: load the structure, strip hetero atoms and extract one
# sequence per chain.
template_model = load_PDB(template, is_mmcif=(template_format == "cif"))
template_chains = [c.id for c in template_model]
remove_extra_chains(template_model, template_chains)
remove_hetatms(template_model)
template_sequences = [
    get_fastaseq(template_model, chain) for chain in template_chains
]

print("Template sequences:")
for seq, chain in zip(template_sequences, template_chains):
  print(f"Chain {chain}: {seq}")

# target data: chains, sequences and models derived from the fasta file.
target_chains, target_sequences, target_models = get_target_data(
    [str(targets)],
    chains=None,
    is_fasta=True,
    is_mmcif=False,
)
# Explicit check instead of `assert`: assertions are stripped when Python
# runs with -O, and the other input checks in this notebook raise ValueError.
if len(target_chains) > len(template_chains):
  raise ValueError(
      "Not enough chains in the template structure to cover all target chains. Partial templates are currently not supported on the colab version of AF_unmasked."
  )

In [None]:
# Prediction settings exposed in the sidebar.
model_type = mr.Select(label="Select AF model type (v2 or v3)",
                          choices=["alphafold2_multimer_v2", "alphafold2_multimer_v3"])

predictions_per_model = mr.Numeric(value=1, min=1, max=10, label="Number of predictions per model:", step=1)

num_recycles = mr.Numeric(value=20, min=0, max=100, label="Number of recycles:", step=1)
# The tolerance lives in [0, 1]; a step of 1 would make every fractional
# value other than the default unreachable from the widget, so use 0.1.
recycle_early_stop_tolerance = mr.Numeric(value=0.5, min=0.0, max=1.0, label="Recycle early stop tolerance:", step=0.1)
use_dropout = mr.Checkbox(value=True, label="Use dropout to increase sampling noise:")

msa_mode = mr.Select(label="MSA mode:", choices=["no_MSA", "mmseqs2_uniref", "mmseqs2_uniref_env"])

msa_depth = mr.Select(label="Number of sequences in MSA",
                          choices=["auto", "1", "32", "64", "256", "512"])

# From here on `msa_depth` is the parsed setting, not the widget:
# None for "auto", otherwise the selected number as an int.
msa_depth = None if msa_depth.value == "auto" else int(msa_depth.value)

inpaint_clashes = mr.Checkbox(value=True, label="Automatically inpaint clashes")
align = mr.Checkbox(value=True, label="Align template sequences to targets")

# Alignment method used when template sequences are aligned to the targets.
align_tool = "blast"

In [None]:
# Human-readable description of every template chain: ID plus the first ten
# residues. The run cell parses the chain ID back out of these strings
# (second whitespace-separated token), so the "Chain <id> ..." layout matters.
template_preview = [f"Chain {template_chain} (seq: {template_seq[:10]}...)" for template_chain, template_seq in zip(template_chains, template_sequences)]

# One selector per target chain, pre-set to the template chain in the same
# position. The label now closes the "(seq: ...)" preview the same way the
# template descriptions do (the closing "...)" was missing).
template_c = []
for i, _ in enumerate(target_chains):
    template_c.append(mr.Select(label=f"Select template chain for fasta sequence {i+1} (seq: {target_sequences[i][:10]}...)", value=template_preview[i], choices=template_preview, url_key=i))


In [None]:
def show_pdb(pdb_file, extension, show_sidechains=False, show_mainchains=False, color="lDDT"):
  """Build a py3Dmol view of a structure file.

  Args:
    pdb_file: path of the structure file to display.
    extension: format name handed to py3Dmol's ``addModel`` (e.g. "pdb", "cif").
    show_sidechains: whether to draw side chains as sticks. Either a plain
      bool or an object exposing the flag as ``.value`` (e.g. a checkbox widget).
    show_mainchains: whether to draw backbone atoms as sticks. Same accepted
      types as ``show_sidechains``.
    color: colouring scheme, one of "lDDT", "rainbow" or "chain". Either a
      plain string or an object exposing the choice as ``.value``.

  Returns:
    The configured ``py3Dmol.view`` object (call ``.show()`` to display it).
  """
  # Callers pass either plain values (matching the defaults) or widgets.
  # Unwrap widgets here: previously the bool defaults crashed on `.value`,
  # and a widget passed as `color` never compared equal to any scheme name,
  # so no colouring was applied.
  show_sidechains = getattr(show_sidechains, "value", show_sidechains)
  show_mainchains = getattr(show_mainchains, "value", show_mainchains)
  color = getattr(color, "value", color)

  view = py3Dmol.view(js='https://3dmol.org/build/3Dmol.js',)
  # Read through a context manager so the file handle is always closed.
  with open(pdb_file, 'r') as structure_file:
    view.addModel(structure_file.read(), extension)

  if color == "lDDT":
    # Colour by the B-factor column over the 50-90 range.
    view.setStyle({'cartoon': {'colorscheme': {'prop':'b','gradient': 'roygb','min':50,'max':90}}})
  elif color == "rainbow":
    view.setStyle({'cartoon': {'color':'spectrum'}})
  elif color == "chain":
    # One colour per chain; the number of chains comes from the notebook-level
    # `target_sequences`.
    chains = len(target_sequences) + 1
    for _, chain, chain_color in zip(range(chains), alphabet_list, pymol_color_list):
      view.setStyle({'chain':chain},{'cartoon': {'color':chain_color}})

  if show_sidechains:
    BB = ['C','O','N']
    view.addStyle({'and':[{'resn':["GLY","PRO"],'invert':True},{'atom':BB,'invert':True}]},
                        {'stick':{'colorscheme':"WhiteCarbon",'radius':0.3}})
    view.addStyle({'and':[{'resn':"GLY"},{'atom':'CA'}]},
                        {'sphere':{'colorscheme':"WhiteCarbon",'radius':0.3}})
    view.addStyle({'and':[{'resn':"PRO"},{'atom':['C','O'],'invert':True}]},
                        {'stick':{'colorscheme':"WhiteCarbon",'radius':0.3}})
  if show_mainchains:
    BB = ['C','O','N','CA']
    view.addStyle({'atom':BB},{'stick':{'colorscheme':"WhiteCarbon",'radius':0.3}})

  view.zoomTo()
  return view

# Run cell: when the button is clicked this
#   1. writes the selected template chains as mmCIF / seqres / flag files,
#   2. optionally aligns each template chain to its target sequence,
#   3. builds the per-chain MSA files (single sequence or ColabFold server),
#   4. sets up the AlphaFold data pipeline and model runners, and
#   5. runs the prediction.
repeat_template = mr.Select(label="Repeat template:", choices=["1 time", "2 times", "3 times", "4 times"], value="1 time")

run_temp = mr.Button(label="Run AF_unmasked")
done = False
if run_temp.clicked:
    # "3 times" -> 3
    temp_reps = int(repeat_template.value.split()[0])
    # Selector values look like "Chain A (seq: ...)"; the second token is the
    # chain ID. The selection is kept under its own name so that the full
    # `template_chains` / `template_sequences` lists (used to build the
    # selectors) are not overwritten when this cell runs.
    selected_chains = [temp.value.split()[1] for temp in template_c]

    if len(selected_chains) != len(set(selected_chains)):
        raise ValueError("Must select a different template chain for each fasta sequence")

    # Widget objects are always truthy: read the checkbox state from `.value`
    # (as is done for `use_dropout` below), otherwise unticking has no effect.
    inpaint = inpaint_clashes.value
    align_to_targets = align.value

    # Sequences must follow the order of the selected chains even when clash
    # inpainting (which recomputes them) is switched off.
    selected_sequences = [
        get_fastaseq(template_model, chain) for chain in selected_chains
    ]

    append = False

    mmcif_path = Path(out_dir, "template_data", "mmcif_files")
    mmcif_path.mkdir(parents=True, exist_ok=True)

    for rep in range(temp_reps):
        print(f"Filling template slot n. {rep+1}...")
        # The first slot is always "0000"; further slots ask for the next free ID.
        next_id = get_next_id(mmcif_path) if append else "0000"

        io = MMCIFIO()
        template_mmcif_path = os.path.join(
            out_dir, "template_data", "mmcif_files", f"{next_id}.cif"
        )

        if inpaint:
            template_model = detect_and_remove_clashes(template_model)
            selected_sequences = [
                get_fastaseq(template_model, chain) for chain in selected_chains
            ]

        io.set_structure(template_model)
        io.save(template_mmcif_path)

        fix_mmcif(
            template_mmcif_path, selected_chains, selected_sequences, "2100-01-01"
        )

        pdb_seqres_path = Path(out_dir, "template_data", "pdb_seqres.txt").resolve()
        write_seqres(
            pdb_seqres_path,
            selected_sequences,
            selected_chains,
            seq_id=next_id,
            append=append,
        )

        # extra flagfile for AF usage
        af_flagfile_path = Path(out_dir, "template_data", "templates.flag")
        if not af_flagfile_path.is_file():  # don't overwrite file if already there
            with open(af_flagfile_path, "w") as flagfile:
                flagfile.write(f"--template_mmcif_dir={mmcif_path.resolve()}\n")
                flagfile.write(f"--pdb_seqres_database_path={pdb_seqres_path}\n")
                if align_to_targets:  # means we are not going to let AF overwrite pdb_hits.sto
                    flagfile.write("--use_precomputed_msas\n")

        if align_to_targets:
            if len(target_chains) != len(selected_chains):
                raise ValueError(
                    f"The number of chains to align from target ({target_chains}) doesn't match the number of chains in the template ({selected_chains}). Make sure that the files contain the same number of chains or select the chains that should be paired with --target_chains, --template_chains"
                )
            chain_pairs = zip(
                selected_chains,
                selected_sequences,
                target_chains,
                target_sequences,
                target_models,
            )
            # `chain_index` (not `i`) so the repetition counter is not shadowed.
            for chain_index, (
                template_chain,
                template_sequence,
                target_chain,
                target_sequence,
                target_model,
            ) in enumerate(chain_pairs):
                msa_chain = ascii_upperlower[chain_index]
                # pickle round-trip = deep copy of in-memory models created by
                # this notebook (no untrusted data is unpickled here).
                this_template_model = pickle.loads(pickle.dumps(template_model, -1))
                this_target_model = pickle.loads(pickle.dumps(target_model, -1))
                print(f"Aligning fasta sequence {chain_index+1} (seq: {target_sequence[0:10]}...) to template chain {template_chain} (seq: {template_sequence[0:10]}...)")
                alignment = do_align(
                    template_sequence,
                    this_template_model,
                    target_sequence,
                    this_target_model,
                    alignment_type=align_tool,
                )
                sto_alignment = format_alignment_stockholm(
                    alignment, hit_id=next_id, hit_chain=template_chain
                )

                msa_path = f"msas/{msa_chain}"

                # write alignment to file
                Path(out_dir, msa_path).mkdir(parents=True, exist_ok=True)
                with open(
                    Path(out_dir, msa_path, "pdb_hits.sto"),
                    mode="a" if append else "w",
                ) as pdb_hits:
                    for line in sto_alignment:
                        pdb_hits.write(line)
        if temp_reps > 1:
            # Subsequent repetitions add to the files written by the first one.
            append = True

    if msa_mode.value == "no_MSA": # same as "no_MSA" on the AF_unmasked paper
        # One single-sequence "MSA" per *unique* sequence, in first-seen
        # order. (A set has no defined order and may be shorter than a
        # per-chain list, so zipping the two could attach an MSA to the
        # wrong sequence.)
        query_seqs_unique = list(dict.fromkeys(target_sequences))
        unpaired_msa = [
            f"> seq_{seq_index}\n{ts}"
            for seq_index, ts in enumerate(query_seqs_unique)
        ]
    else: # Alignments rely on ColabFold's MSA server
        print("Querying ColabFold's MSA server")
        msa_lines = None
        use_templates = False
        custom_template_path = None
        pair_mode = "unpaired"
        pairing_strategy = "greedy"
        host_url = DEFAULT_API_SERVER
        version = importlib_metadata.version("colabfold")
        commit = get_commit()
        if commit:
            version += f" ({commit})"
        user_agent = f"colabfold/{version}"

        # Pass the selected mode string, not the widget: ColabFold compares
        # `msa_mode` against string constants to decide e.g. whether the
        # environmental databases are searched.
        unpaired_msa, paired_msa, query_seqs_unique, query_seqs_cardinality, template_features = get_msa_and_templates(
            jobname.value, target_sequences, msa_lines, out_dir, msa_mode.value, use_templates,
            custom_template_path, pair_mode, pairing_strategy, host_url, user_agent)

    # Write the MSA files for every chain carrying each unique sequence.
    for sequence, msa in zip(query_seqs_unique, unpaired_msa):
        for chain in seq2chain[sequence]:
            chain_dir = out_dir.joinpath("msas", chain)
            # The directory only pre-exists when the alignment step ran.
            chain_dir.mkdir(parents=True, exist_ok=True)
            chain_dir.joinpath("bfd_uniref_hits.a3m").write_text(msa)

            # Placeholder "uniprot" alignment containing just the query,
            # written as a3m and converted to Stockholm format.
            uniprot_a3m = chain_dir.joinpath("uniprot_hits.a3m")
            uniprot_a3m.write_text(f"> {chain}\n{sequence}")
            with open(uniprot_a3m, "r") as input_handle, open(
                chain_dir.joinpath("uniprot_hits.sto"), "w"
            ) as output_handle:
                alignments = AlignIO.parse(input_handle, "fasta")
                AlignIO.write(alignments, output_handle, "stockholm")

    data_dir = Path("./")
    # Always delegate to ColabFold's downloader, which returns immediately
    # when the parameters of the requested model type are already present.
    # Checking for *any* "*_finished.txt" marker here would skip the download
    # after switching between the v2 and v3 models.
    print("Checking AF parameters (downloaded on first use)...")
    download_alphafold_params(model_type.value, data_dir)

    template_searcher = hmmsearch.Hmmsearch(
        binary_path=shutil.which("hmmsearch"),
        hmmbuild_binary_path=shutil.which("hmmbuild"),
        database_path=out_dir.joinpath("template_data/pdb_seqres.txt"))

    template_featurizer = templates.HmmsearchHitFeaturizer(
        mmcif_dir=mmcif_path.resolve(),
        max_template_date="2100-01-01",
        max_hits=4,
        kalign_binary_path=shutil.which("kalign"),
        release_dates_path=None,
        obsolete_pdbs_path=None)

    # All MSAs are precomputed above, so the sequence database paths are
    # placeholders.
    monomer_data_pipeline = pipeline.DataPipeline(
        jackhmmer_binary_path=shutil.which("jackhmmer"),
        hhblits_binary_path=shutil.which("hhblits"),
        uniref90_database_path=".",
        mgnify_database_path="",
        bfd_database_path="",
        uniref30_database_path="",
        small_bfd_database_path="",
        template_searcher=template_searcher,
        template_featurizer=template_featurizer,
        use_small_bfd=False,
        use_precomputed_msas=True,
        mgnify_max_hits=1,
        uniref_max_hits=1,
        bfd_max_hits=msa_depth,
        no_uniref=True,
        no_mgnify=True)

    data_pipeline = pipeline_multimer.DataPipeline(
        monomer_data_pipeline=monomer_data_pipeline,
        jackhmmer_binary_path=shutil.which("jackhmmer"),
        uniprot_database_path=None,
        use_precomputed_msas=True,
        max_uniprot_hits=1,
        separate_homomer_msas=True)

    model_names = ["model_5_multimer_v2"] if model_type.value == "alphafold2_multimer_v2" else ["model_5_multimer_v3"]

    model_runners = {}

    for model_name in model_names:
        model_config = config.model_config(model_name)
        model_config.model.num_ensemble_eval = 1
        model_config.model.embeddings_and_evoformer.cross_chain_templates = True
        model_config.model.num_recycle = int(num_recycles.value)
        model_config.model.global_config.eval_dropout = use_dropout.value
        model_config.model.recycle_early_stop_tolerance = recycle_early_stop_tolerance.value

        model_params = data.get_model_haiku_params(
            model_name=model_name, data_dir=str(data_dir))
        model_runner = model.RunModel(model_config, model_params)
        # The same runner object is registered once per requested prediction.
        for pred_index in range(int(predictions_per_model.value)):
            model_runners[f'{model_name}_pred_{pred_index}'] = model_runner

    print("Predicting...")
    predict_structure(
        fasta_path=targets,
        fasta_name=jobname.value,
        output_dir_base=f"{cwd}",
        data_pipeline=data_pipeline,
        model_runners=model_runners,
        benchmark=False,
        random_seed=0,
        models_to_relax=None)

    done = True

my_folder = mr.OutputDir()

In [None]:
# Results viewer: only meaningful once a prediction has finished.
if not done:
    mr.Stop()

# Offer only the ranks that were actually written to the job directory.
# A fixed 0-5 list points at missing files whenever fewer than six
# predictions were produced (the default is a single prediction).
ranked_indices = sorted(
    int(ranked_file.stem.split("_")[1])
    for ranked_file in Path(out_dir).glob("ranked_*.pdb")
    if ranked_file.stem.split("_")[1].isdigit()
)
rank_choices = [str(rank) for rank in ranked_indices] or ["0"]

rank_num = mr.Select(label="Select structure by rank", choices=rank_choices, value=rank_choices[0])
color = mr.Select(label="Select structure coloring", choices=["chain", "lDDT", "rainbow"], value="chain")
show_sidechains = mr.Checkbox(value=False, label="Show sidechains")
show_mainchains = mr.Checkbox(value=False, label="Show mainchains")

prediction_pdb = f"{out_dir}/ranked_{rank_num.value}.pdb"
# First template slot written by the run cell.
template_pdb = f"{mmcif_path}/0000.cif"

# The colouring scheme is passed as a string (`color.value`): show_pdb
# compares it against scheme names, which a widget object never equals.
print("Template")
show_pdb(template_pdb, "cif", show_sidechains, show_mainchains, color.value).show()
print("Prediction")
show_pdb(prediction_pdb, "pdb", show_sidechains, show_mainchains, color.value).show()