diff --git a/README.md b/README.md index eb44b29b..3b872113 100644 --- a/README.md +++ b/README.md @@ -90,25 +90,35 @@ python inference.py target.fasta data/pdb_mmcif/mmcif_files/ \ --jackhmmer_binary_path `which jackhmmer` \ --hhblits_binary_path `which hhblits` \ --hhsearch_binary_path `which hhsearch` \ - --kalign_binary_path `which kalign` + --kalign_binary_path `which kalign` ``` -or run the script `./inference.sh`, you can change +or run the script `./inference.sh`; you can change the parameters in the script ```shell ./inference.sh ``` #### inference with data workflow -alphafold's data pre-processing takes a lot of time, so we speed up the data pre-process by [ray](https://docs.ray.io/en/latest/workflows/concepts.html) workflow, to run the intference with ray workflow, you should install the package by +alphafold's data pre-processing takes a lot of time, so we speed up the data pre-process by [ray](https://docs.ray.io/en/latest/workflows/concepts.html) workflow, to run the inference with ray workflow, you should install the package and add the parameter `--enable_workflow` to the command line or shell script `./inference.sh` ```shell pip install ray pyarrow ``` - -Than you can run by the script `./inference_with_workflow.sh` - ```shell -./inference_with_flow.sh +python inference.py target.fasta data/pdb_mmcif/mmcif_files/ \ + --output_dir ./ \ + --gpus 2 \ + --uniref90_database_path data/uniref90/uniref90.fasta \ + --mgnify_database_path data/mgnify/mgy_clusters_2018_12.fa \ + --pdb70_database_path data/pdb70/pdb70 \ + --uniclust30_database_path data/uniclust30/uniclust30_2018_08/uniclust30_2018_08 \ + --bfd_database_path data/bfd/bfd_metaclust_clu_complete_id30_c90_final_seq.sorted_opt \ + --jackhmmer_binary_path `which jackhmmer` \ + --hhblits_binary_path `which hhblits` \ + --hhsearch_binary_path `which hhsearch` \ + --kalign_binary_path `which kalign` \ + --enable_workflow ``` + ## Performance Benchmark We have included a performance benchmark script in 
`./benchmark`. You can benchmark the performance of Evoformer using different settings. diff --git a/fastfold/workflow/factory/__init__.py b/fastfold/workflow/factory/__init__.py index 3c7bc689..6e70de4d 100644 --- a/fastfold/workflow/factory/__init__.py +++ b/fastfold/workflow/factory/__init__.py @@ -2,6 +2,4 @@ from .hhblits import HHBlitsFactory from .hhsearch import HHSearchFactory from .jackhmmer import JackHmmerFactory -from .alphafold import AlphaFoldFactory -from .amber_relax import AmberRelaxFactory from .hhfilter import HHfilterFactory \ No newline at end of file diff --git a/fastfold/workflow/factory/alphafold.py b/fastfold/workflow/factory/alphafold.py deleted file mode 100644 index f6e4a966..00000000 --- a/fastfold/workflow/factory/alphafold.py +++ /dev/null @@ -1,75 +0,0 @@ -from datetime import date -import imp -from ray import workflow -from typing import List -import time - -import torch -import numpy as np -from fastfold.workflow.factory import TaskFactory -from ray.workflow.common import Workflow - -from fastfold.distributed import init_dap -from fastfold.model.hub import AlphaFold -from fastfold.common import protein, residue_constants -from fastfold.config import model_config -from fastfold.data import data_pipeline, feature_pipeline, templates -from fastfold.utils import inject_fastnn -from fastfold.utils.import_weights import import_jax_weights_ -from fastfold.utils.tensor_utils import tensor_tree_map - -class AlphaFoldFactory(TaskFactory): - - keywords = ['kalign_bin_path', 'template_mmcif_dir', 'param_path', 'model_name'] - - def gen_task(self, fasta_path: str, alignment_dir: str, output_path: str, after: List[Workflow]=None): - - self.isReady() - - # setup runners - config = model_config(self.config.get('model_name')) - template_featurizer = templates.TemplateHitFeaturizer( - mmcif_dir=self.config.get('template_mmcif_dir'), - max_template_date=date.today().strftime("%Y-%m-%d"), - max_hits=config.data.predict.max_templates, - 
kalign_binary_path=self.config.get('kalign_bin_path') - ) - - data_processor = data_pipeline.DataPipeline(template_featurizer=template_featurizer) - feature_processor = feature_pipeline.FeaturePipeline(config.data) - - # generate step function - @workflow.step(num_gpus=1) - def alphafold_step(fasta_path: str, alignment_dir: str, output_path: str, after: List[Workflow]) -> None: - - # setup model - init_dap() - model = AlphaFold(config) - import_jax_weights_(model, self.config.get('param_path'), self.config.get('model_name')) - model = inject_fastnn(model) - model = model.eval() - model = model.cuda() - - feature_dict = data_processor.process_fasta( - fasta_path=fasta_path, - alignment_dir=alignment_dir - ) - processed_feature_dict = feature_processor.process_features( - feature_dict, - mode='predict' - ) - with torch.no_grad(): - batch = {k: torch.as_tensor(v).cuda() for k, v in processed_feature_dict.items()} - out = model(batch) - batch = tensor_tree_map(lambda x: np.array(x[..., -1].cpu()), batch) - out = tensor_tree_map(lambda x: np.array(x.cpu()), out) - plddt = out["plddt"] - mean_plddt = np.mean(plddt) - plddt_b_factors = np.repeat(plddt[..., None], residue_constants.atom_type_num, axis=-1) - unrelaxed_protein = protein.from_prediction(features=batch, - result=out, - b_factors=plddt_b_factors) - with open(output_path, 'w') as f: - f.write(protein.to_pdb(unrelaxed_protein)) - - return alphafold_step.step(fasta_path, alignment_dir, output_path, after) diff --git a/fastfold/workflow/factory/amber_relax.py b/fastfold/workflow/factory/amber_relax.py deleted file mode 100644 index 91b7b947..00000000 --- a/fastfold/workflow/factory/amber_relax.py +++ /dev/null @@ -1,36 +0,0 @@ -from ray import workflow -from typing import List -from fastfold.workflow.factory import TaskFactory -from ray.workflow.common import Workflow - -from fastfold.config import config -import fastfold.relax.relax as relax -from fastfold.common import protein - -class 
AmberRelaxFactory(TaskFactory): - - keywords = [] - - def gen_task(self, unrelaxed_pdb_path: str, output_path: str, after: List[Workflow]=None) -> Workflow: - - self.isReady() - - # setup runner - amber_relaxer = relax.AmberRelaxation( - use_gpu=True, - **config.relax, - ) - - # generate step function - @workflow.step(num_gpus=1) - def amber_relax_step(unrelaxed_pdb_path: str, output_path: str, after: List[Workflow]) -> None: - - with open(unrelaxed_pdb_path, "r") as f: - pdb_str = f.read() - unrelaxed_protein = protein.from_pdb_string(pdb_str) - relaxed_pdb_str, _, _ = amber_relaxer.process(prot=unrelaxed_protein) - - with open(output_path, "w") as f: - f.write(relaxed_pdb_str) - - return amber_relax_step.step(unrelaxed_pdb_path, output_path, after) diff --git a/inference.py b/inference.py index 5f6fdf16..53108507 100644 --- a/inference.py +++ b/inference.py @@ -31,6 +31,7 @@ from fastfold.config import model_config from fastfold.model.fastnn import set_chunk_size from fastfold.data import data_pipeline, feature_pipeline, templates +from fastfold.workflow.template import FastFoldDataWorkFlow from fastfold.utils import inject_fastnn from fastfold.utils.import_weights import import_jax_weights_ from fastfold.utils.tensor_utils import tensor_tree_map @@ -73,7 +74,7 @@ def add_data_args(parser: argparse.ArgumentParser): ) parser.add_argument('--obsolete_pdbs_path', type=str, default=None) parser.add_argument('--release_dates_path', type=str, default=None) - + parser.add_argument('--enable_workflow', default=False, action='store_true', help='run inference with ray workflow or not') def inference_model(rank, world_size, result_q, batch, args): os.environ['RANK'] = str(rank) @@ -158,20 +159,37 @@ def main(args): if (args.use_precomputed_alignments is None): if not os.path.exists(local_alignment_dir): os.makedirs(local_alignment_dir) - - alignment_runner = data_pipeline.AlignmentRunner( - jackhmmer_binary_path=args.jackhmmer_binary_path, - 
hhblits_binary_path=args.hhblits_binary_path, - hhsearch_binary_path=args.hhsearch_binary_path, - uniref90_database_path=args.uniref90_database_path, - mgnify_database_path=args.mgnify_database_path, - bfd_database_path=args.bfd_database_path, - uniclust30_database_path=args.uniclust30_database_path, - pdb70_database_path=args.pdb70_database_path, - use_small_bfd=use_small_bfd, - no_cpus=args.cpus, - ) - alignment_runner.run(fasta_path, local_alignment_dir) + if args.enable_workflow: + print("Running alignment with ray workflow...") + alignment_data_workflow_runner = FastFoldDataWorkFlow( + jackhmmer_binary_path=args.jackhmmer_binary_path, + hhblits_binary_path=args.hhblits_binary_path, + hhsearch_binary_path=args.hhsearch_binary_path, + uniref90_database_path=args.uniref90_database_path, + mgnify_database_path=args.mgnify_database_path, + bfd_database_path=args.bfd_database_path, + uniclust30_database_path=args.uniclust30_database_path, + pdb70_database_path=args.pdb70_database_path, + use_small_bfd=use_small_bfd, + no_cpus=args.cpus, + ) + t = time.perf_counter() + alignment_data_workflow_runner.run(fasta_path, output_dir=output_dir_base, alignment_dir=local_alignment_dir) + print(f"Alignment data workflow time: {time.perf_counter() - t}") + else: + alignment_runner = data_pipeline.AlignmentRunner( + jackhmmer_binary_path=args.jackhmmer_binary_path, + hhblits_binary_path=args.hhblits_binary_path, + hhsearch_binary_path=args.hhsearch_binary_path, + uniref90_database_path=args.uniref90_database_path, + mgnify_database_path=args.mgnify_database_path, + bfd_database_path=args.bfd_database_path, + uniclust30_database_path=args.uniclust30_database_path, + pdb70_database_path=args.pdb70_database_path, + use_small_bfd=use_small_bfd, + no_cpus=args.cpus, + ) + alignment_runner.run(fasta_path, local_alignment_dir) feature_dict = data_processor.process_fasta(fasta_path=fasta_path, alignment_dir=local_alignment_dir) diff --git a/inference.sh b/inference.sh index 
637b58b4..9520e9a0 100755 --- a/inference.sh +++ b/inference.sh @@ -1,15 +1,16 @@ rm -rf alignments/ rm -rf *.pdb + python inference.py target.fasta /data/scratch/alphafold/alphafold/pdb_mmcif/mmcif_files \ --output_dir ./ \ --gpus 2 \ - --uniref90_database_path /data/scratch/alphafold/alphafold/uniref90/uniref90.fasta \ - --mgnify_database_path /data/scratch/alphafold/alphafold/mgnify/mgy_clusters_2018_12.fa \ - --pdb70_database_path /data/scratch/alphafold/alphafold/pdb70/pdb70 \ - --param_path /data/scratch/alphafold/alphafold/params/params_model_1.npz \ - --uniclust30_database_path /data/scratch/alphafold/alphafold/uniclust30/uniclust30_2018_08/uniclust30_2018_08 \ - --bfd_database_path /data/scratch/alphafold/alphafold/bfd/bfd_metaclust_clu_complete_id30_c90_final_seq.sorted_opt \ + --uniref90_database_path data/uniref90/uniref90.fasta \ + --mgnify_database_path data/mgnify/mgy_clusters_2018_12.fa \ + --pdb70_database_path data/pdb70/pdb70 \ + --uniclust30_database_path data/uniclust30/uniclust30_2018_08/uniclust30_2018_08 \ + --bfd_database_path data/bfd/bfd_metaclust_clu_complete_id30_c90_final_seq.sorted_opt \ --jackhmmer_binary_path `which jackhmmer` \ --hhblits_binary_path `which hhblits` \ --hhsearch_binary_path `which hhsearch` \ - --kalign_binary_path `which kalign` \ No newline at end of file + --kalign_binary_path `which kalign` \ + # --enable_workflow \ No newline at end of file diff --git a/inference_with_flow.sh b/inference_with_flow.sh deleted file mode 100755 index b752b3a7..00000000 --- a/inference_with_flow.sh +++ /dev/null @@ -1,15 +0,0 @@ -rm -rf alignments/ -rm -rf *.pdb -python inference_with_workflow.py target.fasta /data/scratch/alphafold/alphafold/pdb_mmcif/mmcif_files \ - --output_dir ./ \ - --gpus 2 \ - --uniref90_database_path /data/scratch/alphafold/alphafold/uniref90/uniref90.fasta \ - --mgnify_database_path /data/scratch/alphafold/alphafold/mgnify/mgy_clusters_2018_12.fa \ - --pdb70_database_path 
/data/scratch/alphafold/alphafold/pdb70/pdb70 \ - --param_path /data/scratch/alphafold/alphafold/params/params_model_1.npz \ - --uniclust30_database_path /data/scratch/alphafold/alphafold/uniclust30/uniclust30_2018_08/uniclust30_2018_08 \ - --bfd_database_path /data/scratch/alphafold/alphafold/bfd/bfd_metaclust_clu_complete_id30_c90_final_seq.sorted_opt \ - --jackhmmer_binary_path `which jackhmmer` \ - --hhblits_binary_path `which hhblits` \ - --hhsearch_binary_path `which hhsearch` \ - --kalign_binary_path `which kalign` \ No newline at end of file diff --git a/inference_with_workflow.py b/inference_with_workflow.py deleted file mode 100644 index 090cdba7..00000000 --- a/inference_with_workflow.py +++ /dev/null @@ -1,284 +0,0 @@ -# Copyright 2021 AlQuraishi Laboratory -# Copyright 2021 DeepMind Technologies Limited -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import argparse -import os -import random -import sys -import time -from datetime import date - -import numpy as np -import torch -import torch.multiprocessing as mp -from fastfold.model.hub import AlphaFold - -import fastfold -import fastfold.relax.relax as relax -from fastfold.common import protein, residue_constants -from fastfold.config import model_config -from fastfold.model.fastnn import set_chunk_size -from fastfold.data import data_pipeline, feature_pipeline, templates -from fastfold.workflow.template import FastFoldDataWorkFlow -from fastfold.utils import inject_fastnn -from fastfold.utils.import_weights import import_jax_weights_ -from fastfold.utils.tensor_utils import tensor_tree_map - - -def add_data_args(parser: argparse.ArgumentParser): - parser.add_argument( - '--uniref90_database_path', - type=str, - default=None, - ) - parser.add_argument( - '--mgnify_database_path', - type=str, - default=None, - ) - parser.add_argument( - '--pdb70_database_path', - type=str, - default=None, - ) - parser.add_argument( - '--uniclust30_database_path', - type=str, - default=None, - ) - parser.add_argument( - '--bfd_database_path', - type=str, - default=None, - ) - parser.add_argument('--jackhmmer_binary_path', type=str, default='/usr/bin/jackhmmer') - parser.add_argument('--hhblits_binary_path', type=str, default='/usr/bin/hhblits') - parser.add_argument('--hhsearch_binary_path', type=str, default='/usr/bin/hhsearch') - parser.add_argument('--kalign_binary_path', type=str, default='/usr/bin/kalign') - parser.add_argument( - '--max_template_date', - type=str, - default=date.today().strftime("%Y-%m-%d"), - ) - parser.add_argument('--obsolete_pdbs_path', type=str, default=None) - parser.add_argument('--release_dates_path', type=str, default=None) - - -def inference_model(rank, world_size, result_q, batch, args): - os.environ['RANK'] = str(rank) - os.environ['LOCAL_RANK'] = str(rank) - os.environ['WORLD_SIZE'] = str(world_size) - # init distributed for Dynamic Axial 
Parallelism - fastfold.distributed.init_dap() - torch.cuda.set_device(rank) - config = model_config(args.model_name) - model = AlphaFold(config) - import_jax_weights_(model, args.param_path, version=args.model_name) - - model = inject_fastnn(model) - model = model.eval() - model = model.cuda() - - set_chunk_size(model.globals.chunk_size) - - with torch.no_grad(): - batch = {k: torch.as_tensor(v).cuda() for k, v in batch.items()} - - t = time.perf_counter() - out = model(batch) - print(f"Inference time: {time.perf_counter() - t}") - - out = tensor_tree_map(lambda x: np.array(x.cpu()), out) - - result_q.put(out) - - torch.distributed.barrier() - torch.cuda.synchronize() - - -def main(args): - config = model_config(args.model_name) - - template_featurizer = templates.TemplateHitFeaturizer( - mmcif_dir=args.template_mmcif_dir, - max_template_date=args.max_template_date, - max_hits=config.data.predict.max_templates, - kalign_binary_path=args.kalign_binary_path, - release_dates_path=args.release_dates_path, - obsolete_pdbs_path=args.obsolete_pdbs_path) - - use_small_bfd = args.preset == 'reduced_dbs' # (args.bfd_database_path is None) - if use_small_bfd: - assert args.bfd_database_path is not None - else: - assert args.bfd_database_path is not None - assert args.uniclust30_database_path is not None - - data_processor = data_pipeline.DataPipeline(template_featurizer=template_featurizer,) - - output_dir_base = args.output_dir - random_seed = args.data_random_seed - if random_seed is None: - random_seed = random.randrange(sys.maxsize) - feature_processor = feature_pipeline.FeaturePipeline(config.data) - if not os.path.exists(output_dir_base): - os.makedirs(output_dir_base) - if (args.use_precomputed_alignments is None): - alignment_dir = os.path.join(output_dir_base, "alignments") - else: - alignment_dir = args.use_precomputed_alignments - - # Gather input sequences - with open(args.fasta_path, "r") as fp: - lines = [l.strip() for l in fp.readlines()] - - tags, seqs = 
lines[::2], lines[1::2] - tags = [l[1:] for l in tags] - - for tag, seq in zip(tags, seqs): - batch = [None] - - fasta_path = os.path.join(args.output_dir, "tmp.fasta") - with open(fasta_path, "w") as fp: - fp.write(f">{tag}\n{seq}") - - print("Generating features...") - local_alignment_dir = os.path.join(alignment_dir, tag) - if (args.use_precomputed_alignments is None): - if not os.path.exists(local_alignment_dir): - os.makedirs(local_alignment_dir) - - alignment_data_workflow_runner = FastFoldDataWorkFlow( - jackhmmer_binary_path=args.jackhmmer_binary_path, - hhblits_binary_path=args.hhblits_binary_path, - hhsearch_binary_path=args.hhsearch_binary_path, - uniref90_database_path=args.uniref90_database_path, - mgnify_database_path=args.mgnify_database_path, - bfd_database_path=args.bfd_database_path, - uniclust30_database_path=args.uniclust30_database_path, - pdb70_database_path=args.pdb70_database_path, - use_small_bfd=use_small_bfd, - no_cpus=args.cpus, - ) - t = time.perf_counter() - alignment_data_workflow_runner.run(fasta_path, output_dir=output_dir_base, alignment_dir=local_alignment_dir) - print(f"Alignment data workflow time: {time.perf_counter() - t}") - - feature_dict = data_processor.process_fasta(fasta_path=fasta_path, - alignment_dir=local_alignment_dir) - - # Remove temporary FASTA file - os.remove(fasta_path) - - processed_feature_dict = feature_processor.process_features( - feature_dict, - mode='predict', - ) - - batch = processed_feature_dict - - manager = mp.Manager() - result_q = manager.Queue() - torch.multiprocessing.spawn(inference_model, nprocs=args.gpus, args=(args.gpus, result_q, batch, args)) - - out = result_q.get() - - # Toss out the recycling dimensions --- we don't need them anymore - batch = tensor_tree_map(lambda x: np.array(x[..., -1].cpu()), batch) - - plddt = out["plddt"] - mean_plddt = np.mean(plddt) - - plddt_b_factors = np.repeat(plddt[..., None], residue_constants.atom_type_num, axis=-1) - - unrelaxed_protein = 
protein.from_prediction(features=batch, - result=out, - b_factors=plddt_b_factors) - - # Save the unrelaxed PDB. - unrelaxed_output_path = os.path.join(args.output_dir, - f'{tag}_{args.model_name}_unrelaxed.pdb') - with open(unrelaxed_output_path, 'w') as f: - f.write(protein.to_pdb(unrelaxed_protein)) - - amber_relaxer = relax.AmberRelaxation( - use_gpu=True, - **config.relax, - ) - - # Relax the prediction. - t = time.perf_counter() - relaxed_pdb_str, _, _ = amber_relaxer.process(prot=unrelaxed_protein) - print(f"Relaxation time: {time.perf_counter() - t}") - - # Save the relaxed PDB. - relaxed_output_path = os.path.join(args.output_dir, - f'{tag}_{args.model_name}_relaxed.pdb') - with open(relaxed_output_path, 'w') as f: - f.write(relaxed_pdb_str) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument( - "fasta_path", - type=str, - ) - parser.add_argument( - "template_mmcif_dir", - type=str, - ) - parser.add_argument("--use_precomputed_alignments", - type=str, - default=None, - help="""Path to alignment directory. If provided, alignment computation - is skipped and database path arguments are ignored.""") - parser.add_argument( - "--output_dir", - type=str, - default=os.getcwd(), - help="""Name of the directory in which to output the prediction""", - ) - parser.add_argument("--model_name", - type=str, - default="model_1", - help="""Name of a model config. Choose one of model_{1-5} or - model_{1-5}_ptm, as defined on the AlphaFold GitHub.""") - parser.add_argument("--param_path", - type=str, - default=None, - help="""Path to model parameters. 
If None, parameters are selected - automatically according to the model name from - ./data/params""") - parser.add_argument("--cpus", - type=int, - default=12, - help="""Number of CPUs with which to run alignment tools""") - parser.add_argument("--gpus", - type=int, - default=1, - help="""Number of GPUs with which to run inference""") - parser.add_argument('--preset', - type=str, - default='full_dbs', - choices=('reduced_dbs', 'full_dbs')) - parser.add_argument('--data_random_seed', type=str, default=None) - add_data_args(parser) - args = parser.parse_args() - - if (args.param_path is None): - args.param_path = os.path.join("data", "params", "params_" + args.model_name + ".npz") - - main(args)