In [1]:
%load_ext autoreload
%autoreload 2

In [None]:
import sys
# Make the local AlphaFold checkout importable for the `alphafold.*` imports
# further down in this cell.
# NOTE(review): absolute, machine-specific path — adjust when running elsewhere.
sys.path.append('/data/morrisq/baalii/alphafold/')

import os
# Both variables are set before `import jax` later in this cell.
# NOTE(review): presumably they let JAX/XLA use unified (host + device) memory
# and request more than the physical GPU memory — confirm against the JAX GPU
# memory allocation documentation.
os.environ['TF_FORCE_UNIFIED_MEMORY'] = '1'
os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '2.0'

"""Full AlphaFold protein structure prediction script."""
import json
import os
import pathlib
import pickle
import random
import time
from typing import Dict

from absl import app
from absl import flags
from absl import logging
from alphafold.common import protein
from alphafold.common import residue_constants
from alphafold.data import pipeline
from alphafold.data import templates
from alphafold.model import data
from alphafold.model import config
from alphafold.model import model
from alphafold.relax import relax
import numpy as np

from IPython.utils import io
import subprocess
import tqdm.notebook

import jax
# Refuse to continue when the first local JAX device is a TPU or a CPU.
_first_device_platform = jax.local_devices()[0].platform
if _first_device_platform == 'tpu':
    raise RuntimeError('Colab TPU runtime not supported. Change it to GPU via Runtime -> Change Runtime Type -> Hardware accelerator -> GPU.')
if _first_device_platform == 'cpu':
    raise RuntimeError('Colab CPU runtime not supported. Change it to GPU via Runtime -> Change Runtime Type -> Hardware accelerator -> GPU.')

# Passed as `max_hits` to templates.TemplateHitFeaturizer.
MAX_TEMPLATE_HITS = 20
# The RELAX_* values below are forwarded to relax.AmberRelaxation as
# max_iterations, tolerance, stiffness, exclude_residues and
# max_outer_iterations respectively.
# NOTE(review): 0 presumably means "no iteration limit" — confirm in the relaxer.
RELAX_MAX_ITERATIONS = 0
# NOTE(review): units are not visible here (likely an energy tolerance) — confirm.
RELAX_ENERGY_TOLERANCE = 2.39
RELAX_STIFFNESS = 10.0
# Empty list: no residues are excluded.
RELAX_EXCLUDE_RESIDUES = []
RELAX_MAX_OUTER_ITERATIONS = 20

In [3]:



def _check_flag(flag_name: str, preset: str, should_be_set: bool):
    if should_be_set != bool(FLAGS[flag_name].value):
        verb = 'be' if should_be_set else 'not be'
        raise ValueError(f'{flag_name} must {verb} set for preset "{preset}"')


def predict_structure(
    fasta_path: str,
    fasta_name: str,
    output_dir_base: str,
    data_pipeline: pipeline.DataPipeline,
    model_runners: Dict[str, model.RunModel],
    amber_relaxer: relax.AmberRelaxation,
    benchmark: bool,
    random_seed: int):
    """Predicts structure using AlphaFold for the given sequence.

    Runs the data pipeline on one FASTA file, runs every model in
    `model_runners` on the resulting features, relaxes each prediction and
    writes the results below `<output_dir_base>/<fasta_name>/`:

      * `msas/`                     - directory handed to the data pipeline.
      * `features.pkl`              - pickled feature dictionary.
      * `result_<model>.pkl`        - pickled raw prediction per model.
      * `unrelaxed_<model>.pdb`     - prediction before relaxation.
      * `relaxed_<model>.pdb`       - prediction after relaxation.
      * `ranked_<idx>.pdb`          - relaxed PDBs ordered by mean pLDDT,
                                      best first.
      * `ranking_debug.json`        - mean pLDDT per model and the ranking.
      * `timings.json`              - wall-clock seconds per stage.

    Args:
        fasta_path: Path of the input FASTA file.
        fasta_name: Name of the per-target output sub-directory.
        output_dir_base: Directory under which the per-target directory is
            created.
        data_pipeline: Object whose `process` method produces the features.
        model_runners: Mapping from model name to the runner for that model.
        amber_relaxer: Object whose `process` method relaxes a prediction.
        benchmark: If True, run `predict` a second time per model and record
            that timing separately.
        random_seed: Seed passed to `process_features` of every runner.

    Returns:
        None. All results are written to disk.
    """
    timings = {}
    output_dir = os.path.join(output_dir_base, fasta_name)
    msa_output_dir = os.path.join(output_dir, 'msas')
    # Creates `output_dir` as well; exist_ok avoids the check-then-create race.
    os.makedirs(msa_output_dir, exist_ok=True)

    # Get features.
    t_0 = time.time()
    feature_dict = data_pipeline.process(
        input_fasta_path=fasta_path,
        msa_output_dir=msa_output_dir)
    timings['features'] = time.time() - t_0

    # Write out features as a pickled dictionary.
    features_output_path = os.path.join(output_dir, 'features.pkl')
    with open(features_output_path, 'wb') as f:
        pickle.dump(feature_dict, f, protocol=4)

    relaxed_pdbs = {}
    plddts = {}

    # Run the models.
    for model_name, model_runner in model_runners.items():
        logging.info('Running model %s', model_name)
        t_0 = time.time()
        processed_feature_dict = model_runner.process_features(
            feature_dict, random_seed=random_seed)
        timings[f'process_features_{model_name}'] = time.time() - t_0

        t_0 = time.time()
        prediction_result = model_runner.predict(processed_feature_dict)
        t_diff = time.time() - t_0
        timings[f'predict_and_compile_{model_name}'] = t_diff
        # The duration is in seconds; the previous '%.0f?' format was garbled.
        logging.info(
            'Total JAX model %s predict time (includes compilation time, see --benchmark): %.1fs',
            model_name, t_diff)
        if benchmark:
            # Second call measures prediction time on its own.
            t_0 = time.time()
            model_runner.predict(processed_feature_dict)
            timings[f'predict_benchmark_{model_name}'] = time.time() - t_0

        # Get mean pLDDT confidence metric. Cast to a plain float so that
        # json.dumps below cannot fail on a NumPy scalar type.
        plddt = prediction_result['plddt']
        plddts[model_name] = float(np.mean(plddt))

        # Save the model outputs.
        result_output_path = os.path.join(output_dir, f'result_{model_name}.pkl')
        with open(result_output_path, 'wb') as f:
            pickle.dump(prediction_result, f, protocol=4)

        # Add the predicted LDDT in the b-factor column.
        # Note that higher predicted LDDT value means higher model confidence.
        plddt_b_factors = np.repeat(
            plddt[:, None], residue_constants.atom_type_num, axis=-1)
        unrelaxed_protein = protein.from_prediction(
            features=processed_feature_dict,
            result=prediction_result,
            b_factors=plddt_b_factors)

        unrelaxed_pdb_path = os.path.join(output_dir, f'unrelaxed_{model_name}.pdb')
        with open(unrelaxed_pdb_path, 'w') as f:
            f.write(protein.to_pdb(unrelaxed_protein))

        # Relax the prediction.
        t_0 = time.time()
        relaxed_pdb_str, _, _ = amber_relaxer.process(prot=unrelaxed_protein)
        timings[f'relax_{model_name}'] = time.time() - t_0

        relaxed_pdbs[model_name] = relaxed_pdb_str

        # Save the relaxed PDB.
        relaxed_output_path = os.path.join(output_dir, f'relaxed_{model_name}.pdb')
        with open(relaxed_output_path, 'w') as f:
            f.write(relaxed_pdb_str)

    # Rank by pLDDT (highest first) and write out relaxed PDBs in rank order.
    ranked_order = []
    for idx, (model_name, _) in enumerate(
        sorted(plddts.items(), key=lambda x: x[1], reverse=True)):
        ranked_order.append(model_name)
        ranked_output_path = os.path.join(output_dir, f'ranked_{idx}.pdb')
        with open(ranked_output_path, 'w') as f:
            f.write(relaxed_pdbs[model_name])

    ranking_output_path = os.path.join(output_dir, 'ranking_debug.json')
    with open(ranking_output_path, 'w') as f:
        f.write(json.dumps({'plddts': plddts, 'order': ranked_order}, indent=4))

    logging.info('Final timings for %s: %s', fasta_name, timings)

    timings_output_path = os.path.join(output_dir, 'timings.json')
    with open(timings_output_path, 'w') as f:
        f.write(json.dumps(timings, indent=4))

In [4]:
# Local import: only this cell needs it, to locate the external tool binaries.
import shutil


def _require_binary(name: str) -> str:
    """Returns the path of executable `name` found on PATH.

    Raises:
        FileNotFoundError: If no such executable is found, instead of silently
            passing a bogus path on to the data pipeline.
    """
    binary_path = shutil.which(name)
    if binary_path is None:
        raise FileNotFoundError(
            f'Required binary "{name}" was not found on PATH.')
    return binary_path


# ---- Inputs, outputs and databases ----
fasta_paths = ['../data/input.fasta']
output_dir = '../alphafold_output'
model_names = ['model_1']
data_dir = '../data'
uniref90_database_path = '../data/uniref90/uniref90.fasta'
mgnify_database_path = '../data/mgnify/mgy_clusters_2018_12.fa'
# NOTE(review): the directory is spelled 'small_fbd' — confirm it matches the
# directory on disk (as opposed to 'small_bfd').
small_bfd_database_path = '../data/small_fbd/bfd-first_non_consensus_sequences.fasta'
pdb70_database_path = '../data/pdb70/pdb70'
template_mmcif_dir = '../data/pdb_mmcif/mmcif_files'
max_template_date = '2020-05-14'
obsolete_pdbs_path = '../data/pdb_mmcif/obsolete.dat'
preset = 'reduced_dbs'

# Not provided in this configuration; passed to the data pipeline as None.
bfd_database_path = None
uniclust30_database_path = None

# ---- External tool binaries, resolved from PATH (change me if required) ----
hhblits_binary_path = _require_binary('hhblits')
hhsearch_binary_path = _require_binary('hhsearch')
jackhmmer_binary_path = _require_binary('jackhmmer')
kalign_binary_path = _require_binary('kalign')

random_seed = 1234
benchmark = False

# ---- Preset-derived settings ----
use_small_bfd = preset == 'reduced_dbs'
if preset in ('reduced_dbs', 'full_dbs'):
    num_ensemble = 1
elif preset == 'casp14':
    num_ensemble = 8
else:
    # Fail here rather than with a NameError on `num_ensemble` in a later cell.
    raise ValueError(
        f'Unknown preset "{preset}"; expected one of: reduced_dbs, full_dbs, casp14.')

# Check for duplicate FASTA file names: the stem is used as the name of the
# per-target output directory.
fasta_names = [pathlib.Path(p).stem for p in fasta_paths]
print(fasta_names)
if len(fasta_names) != len(set(fasta_names)):
    raise ValueError('All FASTA paths must have a unique basename.')

# ---- Pipeline objects ----
template_featurizer = templates.TemplateHitFeaturizer(
    mmcif_dir=template_mmcif_dir,
    max_template_date=max_template_date,
    max_hits=MAX_TEMPLATE_HITS,
    kalign_binary_path=kalign_binary_path,
    release_dates_path=None,
    obsolete_pdbs_path=obsolete_pdbs_path)

data_pipeline = pipeline.DataPipeline(
    jackhmmer_binary_path=jackhmmer_binary_path,
    hhblits_binary_path=hhblits_binary_path,
    hhsearch_binary_path=hhsearch_binary_path,
    uniref90_database_path=uniref90_database_path,
    mgnify_database_path=mgnify_database_path,
    bfd_database_path=bfd_database_path,
    uniclust30_database_path=uniclust30_database_path,
    small_bfd_database_path=small_bfd_database_path,
    pdb70_database_path=pdb70_database_path,
    template_featurizer=template_featurizer,
    use_small_bfd=use_small_bfd)

['input']


In [None]:
# Build one model runner per requested model name.
model_runners = {}
for model_name in model_names:
    model_config = config.model_config(model_name)
    model_config.data.eval.num_ensemble = num_ensemble
    model_params = data.get_model_haiku_params(
        model_name=model_name, data_dir=data_dir)
    model_runners[model_name] = model.RunModel(model_config, model_params)

logging.info('Have %d models: %s', len(model_runners),
             list(model_runners.keys()))

amber_relaxer = relax.AmberRelaxation(
    max_iterations=RELAX_MAX_ITERATIONS,
    tolerance=RELAX_ENERGY_TOLERANCE,
    stiffness=RELAX_STIFFNESS,
    exclude_residues=RELAX_EXCLUDE_RESIDUES,
    max_outer_iterations=RELAX_MAX_OUTER_ITERATIONS)

# Honour the `random_seed` chosen in the configuration cell so runs are
# reproducible. A random seed is drawn only when that value is None; the
# previous unconditional `random_seed = None` always discarded it.
if random_seed is None:
    random_seed = random.randrange(sys.maxsize)
logging.info('Using random seed %d for the data pipeline', random_seed)

In [50]:
# Debug cell: run only feature processing and the model forward pass for the
# first FASTA, using the `features.pkl` written by an earlier full run. The
# complete pipeline (saving results, relaxation, ranking, timings) lives in
# predict_structure().
output_dir_base = '../alphafold_output/'
fasta_path, fasta_name = fasta_paths[0], fasta_names[0]

timings = {}
# Deliberately NOT named `output_dir`: the batch-prediction cell reads the
# configured `output_dir` as its base directory, and rebinding it here would
# make that cell write into a nested `<name>/<name>` directory.
cached_output_dir = os.path.join(output_dir_base, fasta_name)

# Load the cached features instead of re-running the data pipeline.
t_0 = time.time()
features_output_path = os.path.join(cached_output_dir, 'features.pkl')
# NOTE: unpickling can execute arbitrary code; only load files this pipeline
# produced itself.
with open(features_output_path, 'rb') as f:
    feature_dict = pickle.load(f)
timings['features'] = time.time() - t_0

# Run the models (forward pass only).
for model_name, model_runner in model_runners.items():
    print(f'Running model {model_name}')
    t_0 = time.time()
    processed_feature_dict = model_runner.process_features(
        feature_dict, random_seed=random_seed)
    timings[f'process_features_{model_name}'] = time.time() - t_0

    t_0 = time.time()
    prediction_result = model_runner.predict(processed_feature_dict)
    # Record the measurement; previously the timer was started but never read.
    timings[f'predict_and_compile_{model_name}'] = time.time() - t_0

Running model model_1
line 432 num_iter: 3
hk.while_loop
line 164 batch0:  dict_keys(['aatype', 'atom14_atom_exists', 'atom37_atom_exists', 'bert_mask', 'extra_deletion_value', 'extra_has_deletion', 'extra_msa', 'extra_msa_mask', 'extra_msa_row_mask', 'is_distillation', 'msa_feat', 'msa_mask', 'msa_row_mask', 'random_crop_to_size_seed', 'residue_index', 'residx_atom14_to_atom37', 'residx_atom37_to_atom14', 'seq_length', 'seq_mask', 'target_feat', 'template_aatype', 'template_all_atom_masks', 'template_all_atom_positions', 'template_mask', 'template_pseudo_beta', 'template_pseudo_beta_mask', 'template_sum_probs', 'true_msa', 'prev_msa_first_row', 'prev_pair', 'prev_pos'])
out_x --> dict_keys(['prev_msa_first_row', 'prev_pair', 'prev_pos'])
stacks: Traced<ShapedArray(float32[3,323,256])>with<DynamicJaxprTrace(level=0/1)>
prev: {'prev_msa_first_row': Traced<ShapedArray(float32[323,256])>with<DynamicJaxprTrace(level=0/1)>, 'prev_pair': Traced<ShapedArray(float32[323,323,128])>with<DynamicJ

In [66]:
# Predict a structure for each input sequence by delegating to
# predict_structure() rather than repeating its body inline.
#
# `output_dir` is only read here, never rebound inside the loop, so re-running
# this cell keeps writing to `<output_dir>/<fasta_name>/` instead of nesting
# one directory level deeper on every run.
output_dir_base = output_dir

for fasta_path, fasta_name in zip(fasta_paths, fasta_names):
    print(f'Predicting structure for {fasta_name} ({fasta_path})')
    # Per-model progress and the final timings are emitted through absl
    # logging inside predict_structure(); timings are also written to
    # `timings.json` in the target's output directory.
    predict_structure(
        fasta_path=fasta_path,
        fasta_name=fasta_name,
        output_dir_base=output_dir_base,
        data_pipeline=data_pipeline,
        model_runners=model_runners,
        amber_relaxer=amber_relaxer,
        benchmark=benchmark,
        random_seed=random_seed)

W1030 17:18:52.907283 47955570354304 templates.py:131] Template structure not in release dates dict: 3oyr
W1030 17:18:53.504036 47955570354304 templates.py:131] Template structure not in release dates dict: 5ze6
W1030 17:18:53.505498 47955570354304 templates.py:131] Template structure not in release dates dict: 4jyx
W1030 17:18:55.962148 47955570354304 templates.py:131] Template structure not in release dates dict: 3mzv
W1030 17:18:56.850518 47955570354304 templates.py:131] Template structure not in release dates dict: 1wy0
W1030 17:18:57.217008 47955570354304 templates.py:131] Template structure not in release dates dict: 3aqb
W1030 17:18:58.051584 47955570354304 templates.py:131] Template structure not in release dates dict: 3aqc
W1030 17:18:59.192491 47955570354304 templates.py:131] Template structure not in release dates dict: 3aq0
W1030 17:19:02.650371 47955570354304 templates.py:131] Template structure not in release dates dict: 3pko
W1030 17:19:03.595981 47955570354304 templates

NameError: name 'k' is not defined