# AlphaFold Sandbox


In [1]:
import os

os.environ['TF_FORCE_UNIFIED_MEMORY'] = '1'
os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '2.0'

import json
import pathlib
import pickle
import random
import sys
import time
from tqdm.notebook import tqdm

from typing import Dict

from absl import app
from absl import flags
from absl import logging
import numpy as np

from alphafold.common import protein

from alphafold.data import parsers
from alphafold.data import pipeline
from alphafold.data import templates
from alphafold.data.tools import jackhmmer

from alphafold.model import data
from alphafold.model import config
from alphafold.model import model
from alphafold.relax import relax
from alphafold.relax import utils


In [2]:
PROJECT = 'jk-mlops-dev'
REGION = 'us-central1'
ALPHAFOLD_DATASETS_BUCKET_NAME = 'jk-alphafold-datasets'
STAGING_BUCKET_NAME = 'jk-alphafold-staging'

API_ENDPOINT = '{}-aiplatform.googleapis.com'.format(REGION)
PARENT = "projects/" + PROJECT + "/locations/" + REGION

ALPHA_FOLD_IMAGE_NAME = f'gcr.io/{PROJECT}/alphafold'

In [3]:
MIN_SEQUENCE_LENGTH = 16
MAX_SEQUENCE_LENGTH = 2500


MAX_TEMPLATE_HITS = 20
RELAX_MAX_ITERATIONS = 0
RELAX_ENERGY_TOLERANCE = 2.39
RELAX_STIFFNESS = 10.0
RELAX_EXCLUDE_RESIDUES = []
RELAX_MAX_OUTER_ITERATIONS = 20

PLDDT_BANDS = [(0, 50, '#FF7D45'),
               (50, 70, '#FFDB13'),
               (70, 90, '#65CBF3'),
               (90, 100, '#0053D6')]

ALPHAFOLD_DATASETS_ROOT = '/home/jupyter/alphafold-datasets/jk-alphafold-datasets'
JACKHMMR_BINARY_PATH = '/usr/bin/jackhmmer'

## Prepare test file

In [4]:
sequence = 'MAAHKGAEHHHKAAEHHEQAAKHHHAAAEHHEKGEHEQAAHHADTAYAHHKHAEEHAAQAAKHDAEHHAPKPH' 
sequence_file = 'target.fasts'

# Remove all whitespaces, tabs and end lines; upper-case
sequence = sequence.translate(str.maketrans('', '', ' \n\t')).upper()
aatypes = set('ACDEFGHIKLMNPQRSTVWY')  # 20 standard aatypes
if not set(sequence).issubset(aatypes):
  raise Exception(f'Input sequence contains non-amino acid letters: {set(sequence) - aatypes}. AlphaFold only supports 20 standard amino acids as inputs.')
if len(sequence) < MIN_SEQUENCE_LENGTH:
  raise Exception(f'Input sequence is too short: {len(sequence)} amino acids, while the minimum is {MIN_SEQUENCE_LENGTH}')
if len(sequence) > MAX_SEQUENCE_LENGTH:
  raise Exception(f'Input sequence is too long: {len(sequence)} amino acids, while the maximum is {MAX_SEQUENCE_LENGTH}. Please use the full AlphaFold system for long sequences.')

with open(sequence_file, 'wt') as f:
  f.write(f'>query\n{sequence}')

In [5]:
! cat {sequence_file}

>query
MAAHKGAEHHHKAAEHHEQAAKHHHAAAEHHEKGEHEQAAHHADTAYAHHKHAEEHAAQAAKHDAEHHAPKPH

## Feature Engineering

### Search genetic databases

In [6]:
dbs = []
num_jackhmmer_chunks = {'uniref90': 59, 'smallbfd': 17, 'mgnify': 71}

with tqdm(total=3) as pbar:
    def jackhmmer_chunk_callback(i):
        pbar.update(i)
        
    pbar.set_description('Searching uniref90')
    jackhmmer_uniref90_runner = jackhmmer.Jackhmmer(
        binary_path=JACKHMMR_BINARY_PATH,
        database_path=os.path.join(ALPHAFOLD_DATASETS_ROOT, 'uniref90', 'uniref90.fasta'),
    )
    dbs.append(('uniref90', jackhmmer_uniref90_runner.query(sequence_file)))
    pbar.update(1)
    
    pbar.set_description('Searching small_bfd')
    jackhmmer_smallbfd_runner = jackhmmer.Jackhmmer(
        binary_path=JACKHMMR_BINARY_PATH,
        database_path=os.path.join(ALPHAFOLD_DATASETS_ROOT, 'small_bfd', 'bfd-first_non_consensus_sequences.fasta'),
    )
    dbs.append(('smallbfd', jackhmmer_smallbfd_runner.query(sequence_file)))
    pbar.update(1)
    
    pbar.set_description('Searching mgnify')
    jackhmmer_mgnify_runner = jackhmmer.Jackhmmer(
        binary_path=JACKHMMR_BINARY_PATH,
        database_path=os.path.join(ALPHAFOLD_DATASETS_ROOT, 'mgnify', 'mgy_clusters.fa'),
    )
    dbs.append(('mgnify', jackhmmer_mgnify_runner.query(sequence_file)))
    pbar.update(1)
    
    

  0%|          | 0/3 [00:00<?, ?it/s]

In [15]:
dbs[0][1][0].keys()

dict_keys(['sto', 'tbl', 'stderr', 'n_iter', 'e_value'])

### Save the results

In [18]:
results_dir = 'search_results'

for db_name, db_results in dbs:
    file_name = f'{db_name}.sto'
    with open(os.path.join(results_dir, file_name), 'wt') as f:
        f.write(db_results[0]['sto'])

### Extract and visualize MSAs

In [21]:
mgnify_max_hits = 501

msas = []
deletion_matrices = []
full_msa = []

for db_name, db_results in dbs:
    unsorted_results = []
    for i, result in enumerate(db_results):
        msa, deletion_matrix, target_names = parsers.parse_stockholm(result['sto'])
        e_values_dict = parsers.parse_e_values_from_tblout(result['tbl'])
        e_values = [e_values_dict[t.split('/')[0]] for t in target_names]
        zipped_results = zip(msa, deletion_matrix, target_names, e_values)
        if i != 0:
            # Only take query from the first chunk
           zipped_results = [x for x in zipped_results if x[2] != 'query']
        unsorted_results.extend(zipped_results)
  

KeyError: 'UniRef90_D7BIZ4'