Skip to content

Commit

Permalink
Release code for v2.3.0
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 494507694
  • Loading branch information
Augustin-Zidek committed Dec 11, 2022
1 parent 4494af8 commit 9b18d6a
Show file tree
Hide file tree
Showing 30 changed files with 894 additions and 498 deletions.
242 changes: 130 additions & 112 deletions README.md

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion afdb/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -304,9 +304,9 @@ fractionPlddtVeryHigh | `FLOAT64` | Fraction of the residues in the predi
fractionPlddtVeryLow | `FLOAT64` | Fraction of the residues in the prediction with pLDDT less than 50
gene | `STRING` | The name of the gene if known, e.g. "COII"
geneSynonyms | `ARRAY<STRING>` | Additional synonyms for the gene
globalMetricValue | `FLOAT64` | The mean pLDDT of this prediction
isReferenceProteome | `BOOL` | Is this protein part of the reference proteome?
isReviewed | `BOOL` | Has this protein been reviewed, i.e. is it part of SwissProt?
globalMetricValue | `FLOAT64` | The mean pLDDT of this prediction
latestVersion | `INT64` | The latest AFDB version for this prediction
modelCreatedDate | `DATE` | The date of creation for this entry, e.g. "2022-06-01"
organismCommonNames | `ARRAY<STRING>` | List of common organism names
Expand Down
14 changes: 7 additions & 7 deletions alphafold/data/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def __init__(self,
uniref90_database_path: str,
mgnify_database_path: str,
bfd_database_path: Optional[str],
uniclust30_database_path: Optional[str],
uniref30_database_path: Optional[str],
small_bfd_database_path: Optional[str],
template_searcher: TemplateSearcher,
template_featurizer: templates.TemplateHitFeaturizer,
Expand All @@ -135,9 +135,9 @@ def __init__(self,
binary_path=jackhmmer_binary_path,
database_path=small_bfd_database_path)
else:
self.hhblits_bfd_uniclust_runner = hhblits.HHBlits(
self.hhblits_bfd_uniref_runner = hhblits.HHBlits(
binary_path=hhblits_binary_path,
databases=[bfd_database_path, uniclust30_database_path])
databases=[bfd_database_path, uniref30_database_path])
self.jackhmmer_mgnify_runner = jackhmmer.Jackhmmer(
binary_path=jackhmmer_binary_path,
database_path=mgnify_database_path)
Expand Down Expand Up @@ -211,14 +211,14 @@ def process(self, input_fasta_path: str, msa_output_dir: str) -> FeatureDict:
use_precomputed_msas=self.use_precomputed_msas)
bfd_msa = parsers.parse_stockholm(jackhmmer_small_bfd_result['sto'])
else:
bfd_out_path = os.path.join(msa_output_dir, 'bfd_uniclust_hits.a3m')
hhblits_bfd_uniclust_result = run_msa_tool(
msa_runner=self.hhblits_bfd_uniclust_runner,
bfd_out_path = os.path.join(msa_output_dir, 'bfd_uniref_hits.a3m')
hhblits_bfd_uniref_result = run_msa_tool(
msa_runner=self.hhblits_bfd_uniref_runner,
input_fasta_path=input_fasta_path,
msa_out_path=bfd_out_path,
msa_format='a3m',
use_precomputed_msas=self.use_precomputed_msas)
bfd_msa = parsers.parse_a3m(hhblits_bfd_uniclust_result['a3m'])
bfd_msa = parsers.parse_a3m(hhblits_bfd_uniref_result['a3m'])

templates_result = self.template_featurizer.get_templates(
query_sequence=input_sequence,
Expand Down
61 changes: 61 additions & 0 deletions alphafold/model/common_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,3 +128,64 @@ def __call__(self, inputs):

return output


class LayerNorm(hk.LayerNorm):
  """LayerNorm module.

  Equivalent to hk.LayerNorm but with different parameter shapes: they are
  always vectors rather than possibly higher-rank tensors. This makes it easier
  to change the layout whilst keeping the model weight-compatible.
  """

  def __init__(self,
               axis,
               create_scale: bool,
               create_offset: bool,
               eps: float = 1e-5,
               scale_init=None,
               offset_init=None,
               use_fast_variance: bool = False,
               name=None,
               param_axis=None):
    """Constructs the module.

    Args:
      axis: Forwarded unchanged to hk.LayerNorm.
      create_scale: Whether this module creates a 'scale' parameter itself.
      create_offset: Whether this module creates an 'offset' parameter itself.
      eps: Forwarded unchanged to hk.LayerNorm.
      scale_init: Optional initialiser for 'scale'. When None, the parent's
        default is kept.
      offset_init: Optional initialiser for 'offset'. When None, the parent's
        default is kept.
      use_fast_variance: Forwarded unchanged to hk.LayerNorm.
      name: Forwarded unchanged to hk.LayerNorm.
      param_axis: Forwarded unchanged to hk.LayerNorm; its first entry selects
        the axis of the input that sizes the parameters.
    """
    # The parent never creates parameters: this class creates them itself in
    # __call__ as vectors and passes them in explicitly. The initialisers are
    # withheld from the parent as well.
    # NOTE(review): presumably hk.LayerNorm rejects an initialiser when it is
    # not creating the matching parameter - confirm against the Haiku version
    # in use.
    super().__init__(
        axis=axis,
        create_scale=False,
        create_offset=False,
        eps=eps,
        scale_init=None,
        offset_init=None,
        use_fast_variance=use_fast_variance,
        name=name,
        param_axis=param_axis)
    self._temp_create_scale = create_scale
    self._temp_create_offset = create_offset
    # Honour caller-supplied initialisers. Previously they were accepted but
    # silently dropped, so __call__ always used the parent's defaults.
    if scale_init is not None:
      self.scale_init = scale_init
    if offset_init is not None:
      self.offset_init = offset_init

  def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
    """Applies layer normalisation to `x`.

    bfloat16 inputs are normalised in float32 and cast back to bfloat16.

    Args:
      x: Input array.

    Returns:
      The normalised array; bfloat16 if the input was bfloat16.
    """
    is_bf16 = (x.dtype == jnp.bfloat16)
    if is_bf16:
      x = x.astype(jnp.float32)

    param_axis = self.param_axis[0] if self.param_axis else -1
    # Parameters are always stored as vectors sized by the parameter axis.
    param_shape = (x.shape[param_axis],)

    # Shape used to broadcast the vector against `x`: size 1 on every axis
    # except the parameter axis.
    param_broadcast_shape = [1] * x.ndim
    param_broadcast_shape[param_axis] = x.shape[param_axis]
    scale = None
    offset = None
    if self._temp_create_scale:
      scale = hk.get_parameter(
          'scale', param_shape, x.dtype, init=self.scale_init)
      scale = scale.reshape(param_broadcast_shape)

    if self._temp_create_offset:
      offset = hk.get_parameter(
          'offset', param_shape, x.dtype, init=self.offset_init)
      offset = offset.reshape(param_broadcast_shape)

    out = super().__call__(x, scale=scale, offset=offset)

    if is_bf16:
      out = out.astype(jnp.bfloat16)

    return out

88 changes: 64 additions & 24 deletions alphafold/model/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,12 @@
def model_config(name: str) -> ml_collections.ConfigDict:
"""Get the ConfigDict of a CASP14 model."""

if 'multimer' in name:
return CONFIG_MULTIMER

if name not in CONFIG_DIFFS:
raise ValueError(f'Invalid model name {name}.')
cfg = copy.deepcopy(CONFIG)
if 'multimer' in name:
cfg = copy.deepcopy(CONFIG_MULTIMER)
else:
cfg = copy.deepcopy(CONFIG)
cfg.update_from_flattened_dict(CONFIG_DIFFS[name])
return cfg

Expand All @@ -52,11 +52,11 @@ def model_config(name: str) -> ml_collections.ConfigDict:
'model_5_ptm',
),
'multimer': (
'model_1_multimer_v2',
'model_2_multimer_v2',
'model_3_multimer_v2',
'model_4_multimer_v2',
'model_5_multimer_v2',
'model_1_multimer_v3',
'model_2_multimer_v3',
'model_3_multimer_v3',
'model_4_multimer_v3',
'model_5_multimer_v3',
),
}
MODEL_PRESETS['monomer_casp14'] = MODEL_PRESETS['monomer']
Expand Down Expand Up @@ -118,8 +118,32 @@ def model_config(name: str) -> ml_collections.ConfigDict:
},
'model_5_ptm': {
'model.heads.predicted_aligned_error.weight': 0.1
}
},
'model_1_multimer_v3': {},
'model_2_multimer_v3': {},
'model_3_multimer_v3': {},
'model_4_multimer_v3': {
'model.embeddings_and_evoformer.num_extra_msa': 1152
},
'model_5_multimer_v3': {
'model.embeddings_and_evoformer.num_extra_msa': 1152
},
}
# Key differences between multimer v1/v2 and v3, mostly due to numerical
# optimisations in the TriangleMultiplication module.
common_updates = {
'model.embeddings_and_evoformer.num_msa': 252,
'model.embeddings_and_evoformer.num_extra_msa': 1152,
'model.embeddings_and_evoformer.evoformer.triangle_multiplication_incoming.fuse_projection_weights': False,
'model.embeddings_and_evoformer.evoformer.triangle_multiplication_outgoing.fuse_projection_weights': False,
'model.embeddings_and_evoformer.template.template_pair_stack.triangle_multiplication_incoming.fuse_projection_weights': False,
'model.embeddings_and_evoformer.template.template_pair_stack.triangle_multiplication_outgoing.fuse_projection_weights': False,
}
CONFIG_DIFFS.update(
{f'model_{i}_multimer': common_updates for i in range(1, 6)})
CONFIG_DIFFS.update(
{f'model_{i}_multimer_v2': common_updates for i in range(1, 6)})


CONFIG = ml_collections.ConfigDict({
'data': {
Expand Down Expand Up @@ -260,14 +284,16 @@ def model_config(name: str) -> ml_collections.ConfigDict:
'equation': 'ikc,jkc->ijc',
'num_intermediate_channel': 128,
'orientation': 'per_row',
'shared_dropout': True
'shared_dropout': True,
'fuse_projection_weights': False,
},
'triangle_multiplication_incoming': {
'dropout_rate': 0.25,
'equation': 'kjc,kic->ijc',
'num_intermediate_channel': 128,
'orientation': 'per_row',
'shared_dropout': True
'shared_dropout': True,
'fuse_projection_weights': False,
},
'pair_transition': {
'dropout_rate': 0.0,
Expand Down Expand Up @@ -328,14 +354,16 @@ def model_config(name: str) -> ml_collections.ConfigDict:
'equation': 'ikc,jkc->ijc',
'num_intermediate_channel': 64,
'orientation': 'per_row',
'shared_dropout': True
'shared_dropout': True,
'fuse_projection_weights': False,
},
'triangle_multiplication_incoming': {
'dropout_rate': 0.25,
'equation': 'kjc,kic->ijc',
'num_intermediate_channel': 64,
'orientation': 'per_row',
'shared_dropout': True
'shared_dropout': True,
'fuse_projection_weights': False,
},
'pair_transition': {
'dropout_rate': 0.0,
Expand All @@ -354,7 +382,7 @@ def model_config(name: str) -> ml_collections.ConfigDict:
'multimer_mode': False,
'subbatch_size': 4,
'use_remat': False,
'zero_init': True
'zero_init': True,
},
'heads': {
'distogram': {
Expand Down Expand Up @@ -483,27 +511,29 @@ def model_config(name: str) -> ml_collections.ConfigDict:
'gating': True,
'num_head': 4,
'orientation': 'per_row',
'shared_dropout': True
'shared_dropout': True,
},
'triangle_multiplication_incoming': {
'dropout_rate': 0.25,
'equation': 'kjc,kic->ijc',
'num_intermediate_channel': 128,
'orientation': 'per_row',
'shared_dropout': True
'shared_dropout': True,
'fuse_projection_weights': True,
},
'triangle_multiplication_outgoing': {
'dropout_rate': 0.25,
'equation': 'ikc,jkc->ijc',
'num_intermediate_channel': 128,
'orientation': 'per_row',
'shared_dropout': True
'shared_dropout': True,
'fuse_projection_weights': True,
}
},
'extra_msa_channel': 64,
'extra_msa_stack_num_block': 4,
'num_msa': 252,
'num_extra_msa': 1152,
'num_msa': 508,
'num_extra_msa': 2048,
'masked_msa': {
'profile_prob': 0.1,
'replace_fraction': 0.15,
Expand Down Expand Up @@ -564,24 +594,28 @@ def model_config(name: str) -> ml_collections.ConfigDict:
'equation': 'kjc,kic->ijc',
'num_intermediate_channel': 64,
'orientation': 'per_row',
'shared_dropout': True
'shared_dropout': True,
'fuse_projection_weights': True,
},
'triangle_multiplication_outgoing': {
'dropout_rate': 0.25,
'equation': 'ikc,jkc->ijc',
'num_intermediate_channel': 64,
'orientation': 'per_row',
'shared_dropout': True
'shared_dropout': True,
'fuse_projection_weights': True,
}
}
},
},
'global_config': {
'bfloat16': True,
'bfloat16_output': False,
'deterministic': False,
'multimer_mode': True,
'subbatch_size': 4,
'use_remat': False,
'zero_init': True
'zero_init': True,
},
'heads': {
'distogram': {
Expand Down Expand Up @@ -651,7 +685,13 @@ def model_config(name: str) -> ml_collections.ConfigDict:
}
},
'num_ensemble_eval': 1,
'num_recycle': 3,
'num_recycle': 20,
# A negative value indicates that no early stopping will occur, i.e.
# the model will always run `num_recycle` number of recycling
# iterations. A positive value will enable early stopping if the
# difference in pairwise distances is less than the tolerance between
# recycling steps.
'recycle_early_stop_tolerance': 0.5,
'resample_msa_in_recycling': True
}
})
8 changes: 4 additions & 4 deletions alphafold/model/folding.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,7 +331,7 @@ def safe_dropout_fn(tensor, safe_key):
safe_key, *sub_keys = safe_key.split(3)
sub_keys = iter(sub_keys)
act = safe_dropout_fn(act, next(sub_keys))
act = hk.LayerNorm(
act = common_modules.LayerNorm(
axis=[-1],
create_scale=True,
create_offset=True,
Expand All @@ -353,7 +353,7 @@ def safe_dropout_fn(tensor, safe_key):
act = jax.nn.relu(act)
act += input_act
act = safe_dropout_fn(act, next(sub_keys))
act = hk.LayerNorm(
act = common_modules.LayerNorm(
axis=[-1],
create_scale=True,
create_offset=True,
Expand Down Expand Up @@ -410,7 +410,7 @@ def generate_affines(representations, batch, config, global_config,
c = config
sequence_mask = batch['seq_mask'][:, None]

act = hk.LayerNorm(
act = common_modules.LayerNorm(
axis=[-1],
create_scale=True,
create_offset=True,
Expand All @@ -433,7 +433,7 @@ def generate_affines(representations, batch, config, global_config,
'affine': affine.to_tensor(),
}

act_2d = hk.LayerNorm(
act_2d = common_modules.LayerNorm(
axis=[-1],
create_scale=True,
create_offset=True,
Expand Down
8 changes: 4 additions & 4 deletions alphafold/model/folding_multimer.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,7 +427,7 @@ def safe_dropout_fn(tensor, safe_key):
safe_key, *sub_keys = safe_key.split(3)
sub_keys = iter(sub_keys)
act = safe_dropout_fn(act, next(sub_keys))
act = hk.LayerNorm(
act = common_modules.LayerNorm(
axis=-1,
create_scale=True,
create_offset=True,
Expand All @@ -448,7 +448,7 @@ def safe_dropout_fn(tensor, safe_key):
act = jax.nn.relu(act)
act += input_act
act = safe_dropout_fn(act, next(sub_keys))
act = hk.LayerNorm(
act = common_modules.LayerNorm(
axis=-1,
create_scale=True,
create_offset=True,
Expand Down Expand Up @@ -500,7 +500,7 @@ def generate_monomer_rigids(representations: Mapping[str, jnp.ndarray],
"""
c = config
sequence_mask = batch['seq_mask'][:, None]
act = hk.LayerNorm(
act = common_modules.LayerNorm(
axis=-1, create_scale=True, create_offset=True, name='single_layer_norm')(
representations['single'])

Expand All @@ -523,7 +523,7 @@ def generate_monomer_rigids(representations: Mapping[str, jnp.ndarray],
rigid
}

act_2d = hk.LayerNorm(
act_2d = common_modules.LayerNorm(
axis=-1,
create_scale=True,
create_offset=True,
Expand Down
Loading

1 comment on commit 9b18d6a

@richaxviv
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

#@title 4. Search against genetic databases

#@markdown Once this cell has been executed, you will see
#@markdown statistics about the multiple sequence alignment
#@markdown (MSA) that will be used by AlphaFold. In particular,
#@markdown you’ll see how well each residue is covered by similar
#@markdown sequences in the MSA.

# Track cell execution to ensure correct order.

notebook_utils.check_cell_execution_order(executed_cells, 4)

# --- Python imports ---

import collections
import copy
from concurrent import futures
import json
import random
import shutil

from urllib import request
from google.colab import files
from matplotlib import gridspec
import matplotlib.pyplot as plt
import numpy as np
import py3Dmol

from alphafold.model import model
from alphafold.model import config
from alphafold.model import data

from alphafold.data import feature_processing
from alphafold.data import msa_pairing
from alphafold.data import pipeline
from alphafold.data import pipeline_multimer
from alphafold.data.tools import jackhmmer

from alphafold.common import protein

from alphafold.relax import relax
from alphafold.relax import utils

from IPython import display
from ipywidgets import GridspecLayout
from ipywidgets import Output

# Color bands for visualizing pLDDT.

PLDDT_BANDS = [(0, 50, '#FF7D45'),
(50, 70, '#FFDB13'),
(70, 90, '#65CBF3'),
(90, 100, '#0053D6')]

# --- Find the closest source ---

test_url_pattern = 'https://storage.googleapis.com/alphafold-colab{:s}/latest/uniref90_2022_01.fasta.1'
ex = futures.ThreadPoolExecutor(3)


def fetch(source):
  """Downloads the probe file for the given bucket suffix and returns the suffix."""
  probe_url = test_url_pattern.format(source)
  request.urlretrieve(probe_url)
  return source


# Probe all candidate buckets concurrently and keep the suffix of whichever
# download completes first.
fs = [ex.submit(fetch, suffix) for suffix in ['', '-europe', '-asia']]
source = None
f = next(futures.as_completed(fs), None)
if f is not None:
  source = f.result()
  ex.shutdown()

JACKHMMER_BINARY_PATH = '/usr/bin/jackhmmer'
DB_ROOT_PATH = f'https://storage.googleapis.com/alphafold-colab{source}/latest/'

# The z_value is the number of sequences in a database.

MSA_DATABASES = [
{'db_name': 'uniref90',
'db_path': f'{DB_ROOT_PATH}uniref90_2022_01.fasta',
'num_streamed_chunks': 62,
'z_value': 144_113_457},
{'db_name': 'smallbfd',
'db_path': f'{DB_ROOT_PATH}bfd-first_non_consensus_sequences.fasta',
'num_streamed_chunks': 17,
'z_value': 65_984_053},
{'db_name': 'mgnify',
'db_path': f'{DB_ROOT_PATH}mgy_clusters_2022_05.fasta',
'num_streamed_chunks': 120,
'z_value': 623_796_864},
]

# Search UniProt and construct the all_seq features only for heteromers, not homomers.

if model_type_to_use == ModelType.MULTIMER and len(set(sequences)) > 1:
MSA_DATABASES.extend([
# Swiss-Prot and TrEMBL are concatenated together as UniProt.
{'db_name': 'uniprot',
'db_path': f'{DB_ROOT_PATH}uniprot_2021_04.fasta',
'num_streamed_chunks': 101,
'z_value': 225_013_025 + 565_928},
])

TOTAL_JACKHMMER_CHUNKS = sum([cfg['num_streamed_chunks'] for cfg in MSA_DATABASES])

MAX_HITS = {
'uniref90': 10_000,
'smallbfd': 5_000,
'mgnify': 501,
'uniprot': 50_000,
}

def get_msa(sequences):
  """Runs a chunked Jackhmmer search for the given sequences.

  Args:
    sequences: A list of sequences to search against all databases.

  Returns:
    A dictionary mapping unique sequences to dictionaries mapping each database
    to a list of results, one for each chunk of the database.
  """
  # Deduplicate to not do redundant work for multiple copies of the same chain
  # in homomers; each unique sequence gets its own query FASTA file.
  fasta_path_by_sequence = {}
  for file_index, unique_sequence in enumerate(sorted(set(sequences)), 1):
    target_path = f'target_{file_index:02d}.fasta'
    with open(target_path, 'wt') as fasta_file:
      fasta_file.write(f'>query\n{unique_sequence}')
    fasta_path_by_sequence[unique_sequence] = target_path

  # Run the search against chunks of genetic databases (since the genetic
  # databases don't fit in Colab disk).
  msa_results = {unique_sequence: {} for unique_sequence in fasta_path_by_sequence}
  print('\nGetting MSA for all sequences')
  with tqdm.notebook.tqdm(total=TOTAL_JACKHMMER_CHUNKS, bar_format=TQDM_BAR_FORMAT) as pbar:
    def jackhmmer_chunk_callback(i):
      # Advance the progress bar by one step per callback invocation.
      pbar.update(n=1)

    for db_config in MSA_DATABASES:
      db_name = db_config['db_name']
      pbar.set_description(f'Searching {db_name}')
      runner = jackhmmer.Jackhmmer(
          binary_path=JACKHMMER_BINARY_PATH,
          database_path=db_config['db_path'],
          get_tblout=True,
          num_streamed_chunks=db_config['num_streamed_chunks'],
          streaming_callback=jackhmmer_chunk_callback,
          z_value=db_config['z_value'])
      # Query all unique sequences against each chunk of the database to
      # prevent redundantly fetching each chunk for each unique sequence.
      unique_sequences = list(fasta_path_by_sequence)
      fasta_paths = [fasta_path_by_sequence[s] for s in unique_sequences]
      per_sequence_results = runner.query_multiple(fasta_paths)
      for unique_sequence, sequence_result in zip(unique_sequences, per_sequence_results):
        msa_results[unique_sequence][db_name] = sequence_result

  return msa_results

features_for_chain = {}
raw_msa_results_for_sequence = get_msa(sequences)
for sequence_index, sequence in enumerate(sequences, start=1):
raw_msa_results = copy.deepcopy(raw_msa_results_for_sequence[sequence])

# Extract the MSAs from the Stockholm files.
# NB: deduplication happens later in pipeline.make_msa_features.

single_chain_msas = []
uniprot_msa = None
for db_name, db_results in raw_msa_results.items():
merged_msa = notebook_utils.merge_chunked_msa(
results=db_results, max_hits=MAX_HITS.get(db_name))
if merged_msa.sequences and db_name != 'uniprot':
single_chain_msas.append(merged_msa)
msa_size = len(set(merged_msa.sequences))
print(f'{msa_size} unique sequences found in {db_name} for sequence {sequence_index}')
elif merged_msa.sequences and db_name == 'uniprot':
uniprot_msa = merged_msa

notebook_utils.show_msa_info(single_chain_msas=single_chain_msas, sequence_index=sequence_index)

# Turn the raw data into model features.

feature_dict = {}
feature_dict.update(pipeline.make_sequence_features(
sequence=sequence, description='query', num_res=len(sequence)))
feature_dict.update(pipeline.make_msa_features(msas=single_chain_msas))

# We don't use templates in the AlphaFold Colab notebook; add only empty placeholder features.

feature_dict.update(notebook_utils.empty_placeholder_template_features(
num_templates=0, num_res=len(sequence)))

# Construct the all_seq features only for heteromers, not homomers.

if model_type_to_use == ModelType.MULTIMER and len(set(sequences)) > 1:
valid_feats = msa_pairing.MSA_FEATURES + (
'msa_species_identifiers',
)
all_seq_features = {
f'{k}_all_seq': v for k, v in pipeline.make_msa_features([uniprot_msa]).items()
if k in valid_feats}
feature_dict.update(all_seq_features)

features_for_chain[protein.PDB_CHAIN_IDS[sequence_index - 1]] = feature_dict

# Do further feature post-processing depending on the model type.

if model_type_to_use == ModelType.MONOMER:
np_example = features_for_chain[protein.PDB_CHAIN_IDS[0]]

elif model_type_to_use == ModelType.MULTIMER:
all_chain_features = {}
for chain_id, chain_features in features_for_chain.items():
all_chain_features[chain_id] = pipeline_multimer.convert_monomer_features(
chain_features, chain_id)

all_chain_features = pipeline_multimer.add_assembly_features(all_chain_features)

np_example = feature_processing.pair_and_merge(
all_chain_features=all_chain_features)

# Pad MSA to avoid zero-sized extra_msa.

np_example = pipeline_multimer.pad_msa(np_example, min_num_seq=512)

executed_cells.add(4)

Please sign in to comment.