diff --git a/fastfold/workflow/__init__.py b/fastfold/workflow/__init__.py
new file mode 100644
index 00000000..8a0ce50c
--- /dev/null
+++ b/fastfold/workflow/__init__.py
@@ -0,0 +1 @@
+from .workflow_run import batch_run
\ No newline at end of file
diff --git a/fastfold/workflow/factory/__init__.py b/fastfold/workflow/factory/__init__.py
new file mode 100644
index 00000000..3c7bc689
--- /dev/null
+++ b/fastfold/workflow/factory/__init__.py
@@ -0,0 +1,7 @@
+from .task_factory import TaskFactory
+from .hhblits import HHBlitsFactory
+from .hhsearch import HHSearchFactory
+from .jackhmmer import JackHmmerFactory
+from .alphafold import AlphaFoldFactory
+from .amber_relax import AmberRelaxFactory
+from .hhfilter import HHfilterFactory
\ No newline at end of file
diff --git a/fastfold/workflow/factory/alphafold.py b/fastfold/workflow/factory/alphafold.py
new file mode 100644
index 00000000..f6e4a966
--- /dev/null
+++ b/fastfold/workflow/factory/alphafold.py
@@ -0,0 +1,75 @@
+from datetime import date
+from typing import List
+
+import numpy as np
+import torch
+from ray import workflow
+from ray.workflow.common import Workflow
+
+from fastfold.common import protein, residue_constants
+from fastfold.config import model_config
+from fastfold.data import data_pipeline, feature_pipeline, templates
+from fastfold.distributed import init_dap
+from fastfold.model.hub import AlphaFold
+from fastfold.utils import inject_fastnn
+from fastfold.utils.import_weights import import_jax_weights_
+from fastfold.utils.tensor_utils import tensor_tree_map
+from fastfold.workflow.factory import TaskFactory
+
+
+class AlphaFoldFactory(TaskFactory):
+
+    keywords = ['kalign_bin_path', 'template_mmcif_dir', 'param_path', 'model_name']
+
+    def gen_task(self, fasta_path: str, alignment_dir: str, output_path: str,
+                 after: List[Workflow] = None) -> Workflow:
+
+        self.isReady()
+
+        # set up the feature-generation runners shared by the step below
+        config = model_config(self.config.get('model_name'))
+        template_featurizer = templates.TemplateHitFeaturizer(
+            mmcif_dir=self.config.get('template_mmcif_dir'),
+            max_template_date=date.today().strftime("%Y-%m-%d"),
+            max_hits=config.data.predict.max_templates,
+            kalign_binary_path=self.config.get('kalign_bin_path')
+        )
+
+        data_processor = data_pipeline.DataPipeline(template_featurizer=template_featurizer)
+        feature_processor = feature_pipeline.FeaturePipeline(config.data)
+
+        # generate step function
+        @workflow.step(num_gpus=1)
+        def alphafold_step(fasta_path: str, alignment_dir: str, output_path: str,
+                           after: List[Workflow]) -> None:
+
+            # setup model
+            init_dap()
+            model = AlphaFold(config)
+            import_jax_weights_(model, self.config.get('param_path'), self.config.get('model_name'))
+            model = inject_fastnn(model)
+            model = model.eval()
+            model = model.cuda()
+
+            feature_dict = data_processor.process_fasta(
+                fasta_path=fasta_path,
+                alignment_dir=alignment_dir
+            )
+            processed_feature_dict = feature_processor.process_features(
+                feature_dict,
+                mode='predict'
+            )
+            with torch.no_grad():
+                batch = {k: torch.as_tensor(v).cuda() for k, v in processed_feature_dict.items()}
+                out = model(batch)
+
+            # toss out the recycling dimension and move everything to numpy
+            batch = tensor_tree_map(lambda x: np.array(x[..., -1].cpu()), batch)
+            out = tensor_tree_map(lambda x: np.array(x.cpu()), out)
+            plddt = out["plddt"]
+            plddt_b_factors = np.repeat(plddt[..., None], residue_constants.atom_type_num, axis=-1)
+            unrelaxed_protein = protein.from_prediction(features=batch,
+                                                        result=out,
+                                                        b_factors=plddt_b_factors)
+            with open(output_path, 'w') as f:
+                f.write(protein.to_pdb(unrelaxed_protein))
+
+        return alphafold_step.step(fasta_path, alignment_dir, output_path, after)
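+
+# Example usage (a sketch, not part of the public API; every path below is
+# hypothetical and the parameter file must exist locally):
+#
+#   from fastfold.workflow import batch_run
+#   from fastfold.workflow.factory import AlphaFoldFactory
+#
+#   fac = AlphaFoldFactory(config={
+#       "kalign_bin_path": "/usr/bin/kalign",
+#       "template_mmcif_dir": "/data/pdb_mmcif/mmcif_files",
+#       "param_path": "/data/params/params_model_1.npz",
+#       "model_name": "model_1",
+#   })
+#   step = fac.gen_task("seq.fasta", "alignments/", "unrelaxed.pdb")
+#   batch_run(wfs=[step], workflow_id="alphafold_demo")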
diff --git a/fastfold/workflow/factory/amber_relax.py b/fastfold/workflow/factory/amber_relax.py
new file mode 100644
index 00000000..91b7b947
--- /dev/null
+++ b/fastfold/workflow/factory/amber_relax.py
@@ -0,0 +1,36 @@
+from typing import List
+
+from ray import workflow
+from ray.workflow.common import Workflow
+
+import fastfold.relax.relax as relax
+from fastfold.common import protein
+from fastfold.config import config
+from fastfold.workflow.factory import TaskFactory
+
+
+class AmberRelaxFactory(TaskFactory):
+
+    # no external configuration required
+    keywords = []
+
+    def gen_task(self, unrelaxed_pdb_path: str, output_path: str,
+                 after: List[Workflow] = None) -> Workflow:
+
+        self.isReady()
+
+        # setup runner
+        amber_relaxer = relax.AmberRelaxation(
+            use_gpu=True,
+            **config.relax,
+        )
+
+        # generate step function
+        @workflow.step(num_gpus=1)
+        def amber_relax_step(unrelaxed_pdb_path: str, output_path: str,
+                             after: List[Workflow]) -> None:
+
+            with open(unrelaxed_pdb_path, "r") as f:
+                pdb_str = f.read()
+            unrelaxed_protein = protein.from_pdb_string(pdb_str)
+            relaxed_pdb_str, _, _ = amber_relaxer.process(prot=unrelaxed_protein)
+
+            with open(output_path, "w") as f:
+                f.write(relaxed_pdb_str)
+
+        return amber_relax_step.step(unrelaxed_pdb_path, output_path, after)
diff --git a/fastfold/workflow/factory/hhblits.py b/fastfold/workflow/factory/hhblits.py
new file mode 100644
index 00000000..aecfcc55
--- /dev/null
+++ b/fastfold/workflow/factory/hhblits.py
@@ -0,0 +1,29 @@
+from typing import List
+
+from ray import workflow
+from ray.workflow.common import Workflow
+
+import fastfold.data.tools.hhblits as ffHHBlits
+from fastfold.workflow.factory import TaskFactory
+
+
+class HHBlitsFactory(TaskFactory):
+
+    keywords = ['binary_path', 'databases', 'n_cpu']
+
+    def gen_task(self, fasta_path: str, output_path: str,
+                 after: List[Workflow] = None) -> Workflow:
+
+        self.isReady()
+
+        # setup runner
+        runner = ffHHBlits.HHBlits(
+            binary_path=self.config['binary_path'],
+            databases=self.config['databases'],
+            n_cpu=self.config['n_cpu']
+        )
+
+        # generate step function
+        @workflow.step
+        def hhblits_step(fasta_path: str, output_path: str, after: List[Workflow]) -> None:
+            result = runner.query(fasta_path)
+            with open(output_path, "w") as f:
+                f.write(result["a3m"])
+
+        return hhblits_step.step(fasta_path, output_path, after)
diff --git a/fastfold/workflow/factory/hhfilter.py b/fastfold/workflow/factory/hhfilter.py
new file mode 100644
index 00000000..de680610
--- /dev/null
+++ b/fastfold/workflow/factory/hhfilter.py
@@ -0,0 +1,33 @@
+import logging
+import subprocess
+from typing import List
+
+from ray import workflow
+from ray.workflow.common import Workflow
+
+from fastfold.workflow.factory import TaskFactory
+
+
+class HHfilterFactory(TaskFactory):
+
+    keywords = ['binary_path']
+
+    def gen_task(self, fasta_path: str, output_path: str,
+                 after: List[Workflow] = None) -> Workflow:
+
+        self.isReady()
+
+        # generate step function
+        @workflow.step
+        def hhfilter_step(fasta_path: str, output_path: str, after: List[Workflow]) -> None:
+
+            cmd = [
+                self.config.get('binary_path'),
+            ]
+            # 'id' and 'cov' are optional filter thresholds
+            if 'id' in self.config:
+                cmd += ['-id', str(self.config.get('id'))]
+            if 'cov' in self.config:
+                cmd += ['-cov', str(self.config.get('cov'))]
+            cmd += ['-i', fasta_path, '-o', output_path]
+
+            logging.info(f"HHfilter start: {' '.join(cmd)}")
+
+            # raise if the filter exits with a non-zero status
+            subprocess.run(cmd, check=True)
+
+        return hhfilter_step.step(fasta_path, output_path, after)
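+
+# Example: chaining the filter after an HHBlits search via `after` (a sketch;
+# binary, database, and file paths are hypothetical):
+#
+#   from fastfold.workflow import batch_run
+#   from fastfold.workflow.factory import HHBlitsFactory, HHfilterFactory
+#
+#   hhb = HHBlitsFactory(config={"binary_path": "/usr/bin/hhblits",
+#                                "databases": ["/data/bfd/bfd"], "n_cpu": 8})
+#   flt = HHfilterFactory(config={"binary_path": "/usr/bin/hhfilter",
+#                                 "id": 90, "cov": 50})
+#   wf1 = hhb.gen_task("seq.fasta", "raw.a3m")
+#   wf2 = flt.gen_task("raw.a3m", "filtered.a3m", after=[wf1])
+#   batch_run(wfs=[wf2], workflow_id="msa_filter_demo")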
diff --git a/fastfold/workflow/factory/hhsearch.py b/fastfold/workflow/factory/hhsearch.py
new file mode 100644
index 00000000..d315de07
--- /dev/null
+++ b/fastfold/workflow/factory/hhsearch.py
@@ -0,0 +1,38 @@
+from typing import List
+
+from ray import workflow
+from ray.workflow.common import Workflow
+
+import fastfold.data.tools.hhsearch as ffHHSearch
+from fastfold.workflow.factory import TaskFactory
+
+
+class HHSearchFactory(TaskFactory):
+
+    keywords = ['binary_path', 'databases', 'n_cpu']
+
+    def gen_task(self, a3m_path: str, output_path: str,
+                 after: List[Workflow] = None, atab_path: str = None) -> Workflow:
+
+        self.isReady()
+
+        # setup runner
+        runner = ffHHSearch.HHSearch(
+            binary_path=self.config['binary_path'],
+            databases=self.config['databases'],
+            n_cpu=self.config['n_cpu']
+        )
+
+        # generate step function
+        @workflow.step
+        def hhsearch_step(a3m_path: str, output_path: str, after: List[Workflow],
+                          atab_path: str = None) -> None:
+
+            with open(a3m_path, "r") as f:
+                a3m = f.read()
+            if atab_path:
+                hhsearch_result, atab = runner.query(a3m, gen_atab=True)
+            else:
+                hhsearch_result = runner.query(a3m)
+            with open(output_path, "w") as f:
+                f.write(hhsearch_result)
+            if atab_path:
+                with open(atab_path, "w") as f:
+                    f.write(atab)
+
+        # pass atab_path through so the optional atab output is actually reachable
+        return hhsearch_step.step(a3m_path, output_path, after, atab_path)
diff --git a/fastfold/workflow/factory/jackhmmer.py b/fastfold/workflow/factory/jackhmmer.py
new file mode 100644
index 00000000..ebba4ba9
--- /dev/null
+++ b/fastfold/workflow/factory/jackhmmer.py
@@ -0,0 +1,34 @@
+from typing import List
+
+from ray import workflow
+from ray.workflow.common import Workflow
+
+import fastfold.data.tools.jackhmmer as ffJackHmmer
+from fastfold.data import parsers
+from fastfold.workflow.factory import TaskFactory
+
+
+class JackHmmerFactory(TaskFactory):
+
+    keywords = ['binary_path', 'database_path', 'n_cpu', 'uniref_max_hits']
+
+    def gen_task(self, fasta_path: str, output_path: str,
+                 after: List[Workflow] = None) -> Workflow:
+
+        self.isReady()
+
+        # setup runner
+        runner = ffJackHmmer.Jackhmmer(
+            binary_path=self.config['binary_path'],
+            database_path=self.config['database_path'],
+            n_cpu=self.config['n_cpu']
+        )
+
+        # generate step function
+        @workflow.step
+        def jackhmmer_step(fasta_path: str, output_path: str, after: List[Workflow]) -> None:
+            result = runner.query(fasta_path)[0]
+            uniref90_msa_a3m = parsers.convert_stockholm_to_a3m(
+                result['sto'],
+                max_sequences=self.config['uniref_max_hits']
+            )
+            with open(output_path, "w") as f:
+                f.write(uniref90_msa_a3m)
+
+        return jackhmmer_step.step(fasta_path, output_path, after)
diff --git a/fastfold/workflow/factory/task_factory.py b/fastfold/workflow/factory/task_factory.py
new file mode 100644
index 00000000..dd8c739e
--- /dev/null
+++ b/fastfold/workflow/factory/task_factory.py
@@ -0,0 +1,50 @@
+import json
+from typing import Any, List
+
+from ray.workflow.common import Workflow
+
+
+class TaskFactory:
+
+    keywords = []
+
+    def __init__(self, config: dict = None, config_path: str = None) -> None:
+
+        # skip if no keyword required from config file
+        if not self.__class__.keywords:
+            return
+
+        # setting config for factory
+        if config is not None:
+            self.config = config
+        elif config_path is not None:
+            self.loadConfig(config_path)
+        else:
+            self.loadConfig()
+
+    # Python has no overloading, so the dict- and keyword-based setters need
+    # distinct names; the keyword form keeps the `configure` name because it
+    # is the one called elsewhere in this package.
+    def configure_dict(self, config: dict, purge: bool = False) -> None:
+        if purge:
+            self.config = config
+        else:
+            self.config.update(config)
+
+    def configure(self, keyword: str, value: Any) -> None:
+        self.config[keyword] = value
+
+    def gen_task(self, after: List[Workflow] = None, *args, **kwargs) -> Workflow:
+        raise NotImplementedError
+
+    def isReady(self):
+        for key in self.__class__.keywords:
+            if key not in self.config:
+                raise KeyError(f"{self.__class__.__name__} not ready: \"{key}\" not specified")
+
+    def loadConfig(self, config_path='./config.json'):
+        with open(config_path) as configFile:
+            globalConfig = json.load(configFile)
+            if 'tools' not in globalConfig:
+                raise KeyError("\"tools\" not found in global config file")
+            # strip the "Factory" suffix to get the tool name
+            factoryName = self.__class__.__name__[:-7]
+            if factoryName not in globalConfig['tools']:
+                raise KeyError(f"\"{factoryName}\" not found in the \"tools\" section in config")
+            self.config = globalConfig['tools'][factoryName]
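+
+# loadConfig expects a JSON file shaped as below, keyed by the class name
+# minus its "Factory" suffix. A sketch (all paths are hypothetical):
+#
+#   {
+#       "tools": {
+#           "HHBlits":   {"binary_path": "/usr/bin/hhblits",
+#                         "databases": ["/data/bfd/bfd"],
+#                         "n_cpu": 8},
+#           "JackHmmer": {"binary_path": "/usr/bin/jackhmmer",
+#                         "database_path": "/data/uniref90/uniref90.fasta",
+#                         "n_cpu": 8,
+#                         "uniref_max_hits": 10000}
+#       }
+#   }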
diff --git a/fastfold/workflow/template/__init__.py b/fastfold/workflow/template/__init__.py
new file mode 100644
index 00000000..f9c45c56
--- /dev/null
+++ b/fastfold/workflow/template/__init__.py
@@ -0,0 +1 @@
+from .fastfold_data_workflow import FastFoldDataWorkFlow
\ No newline at end of file
diff --git a/fastfold/workflow/template/fastfold_data_workflow.py b/fastfold/workflow/template/fastfold_data_workflow.py
new file mode 100644
index 00000000..3ecfacd0
--- /dev/null
+++ b/fastfold/workflow/template/fastfold_data_workflow.py
@@ -0,0 +1,140 @@
+import os
+import time
+from multiprocessing import cpu_count
+from typing import Optional
+
+from ray import workflow
+
+from fastfold.workflow import batch_run
+from fastfold.workflow.factory import HHBlitsFactory, HHSearchFactory, JackHmmerFactory
+
+
+class FastFoldDataWorkFlow:
+    def __init__(
+        self,
+        jackhmmer_binary_path: Optional[str] = None,
+        hhblits_binary_path: Optional[str] = None,
+        hhsearch_binary_path: Optional[str] = None,
+        uniref90_database_path: Optional[str] = None,
+        mgnify_database_path: Optional[str] = None,
+        bfd_database_path: Optional[str] = None,
+        uniclust30_database_path: Optional[str] = None,
+        pdb70_database_path: Optional[str] = None,
+        use_small_bfd: Optional[bool] = None,
+        no_cpus: Optional[int] = None,
+        uniref_max_hits: int = 10000,
+        mgnify_max_hits: int = 5000,
+    ):
+        self.db_map = {
+            "jackhmmer": {
+                "binary": jackhmmer_binary_path,
+                "dbs": [
+                    uniref90_database_path,
+                    mgnify_database_path,
+                    bfd_database_path if use_small_bfd else None,
+                ],
+            },
+            "hhblits": {
+                "binary": hhblits_binary_path,
+                "dbs": [
+                    bfd_database_path if not use_small_bfd else None,
+                    # full BFD is searched together with uniclust30
+                    # (the "bfd_uniclust_hits" output below expects it)
+                    uniclust30_database_path if not use_small_bfd else None,
+                ],
+            },
+            "hhsearch": {
+                "binary": hhsearch_binary_path,
+                "dbs": [
+                    pdb70_database_path,
+                ],
+            },
+        }
+
+        for name, dic in self.db_map.items():
+            binary, dbs = dic["binary"], dic["dbs"]
+            if binary is None and not all([x is None for x in dbs]):
+                raise ValueError(
+                    f"{name} DBs provided but {name} binary is None"
+                )
+
+        if (not all([x is None for x in self.db_map["hhsearch"]["dbs"]])
+                and uniref90_database_path is None):
+            raise ValueError(
+                """uniref90_database_path must be specified in order to perform
+                   template search"""
+            )
+
+        self.use_small_bfd = use_small_bfd
+        self.uniref_max_hits = uniref_max_hits
+        self.mgnify_max_hits = mgnify_max_hits
+
+        if no_cpus is None:
+            self.no_cpus = cpu_count()
+        else:
+            self.no_cpus = no_cpus
+
+    def run(self, fasta_path: str, output_dir: str, alignment_dir: str = None) -> None:
+
+        localtime = time.asctime(time.localtime(time.time()))
+        workflow_id = 'fastfold_data_workflow ' + str(localtime)
+        # clear any leftover Ray workflow data under the same id
+        try:
+            workflow.cancel(workflow_id)
+            workflow.delete(workflow_id)
+        except Exception:
+            print("Workflow not found. Clean. Skipping")
+
+        # prepare alignment directory for alignment outputs
+        if alignment_dir is None:
+            alignment_dir = os.path.join(output_dir, "alignment")
+        if not os.path.exists(alignment_dir):
+            os.makedirs(alignment_dir)
+
+        # STEP 1: run JackHmmer on UNIREF90
+        jh_config = {
+            "binary_path": self.db_map["jackhmmer"]["binary"],
+            "database_path": self.db_map["jackhmmer"]["dbs"][0],
+            "n_cpu": self.no_cpus,
+            "uniref_max_hits": self.uniref_max_hits,
+        }
+        jh_fac = JackHmmerFactory(config=jh_config)
+        # set jackhmmer output path
+        uniref90_out_path = os.path.join(alignment_dir, "uniref90_hits.a3m")
+        # generate the workflow with i/o paths
+        wf1 = jh_fac.gen_task(fasta_path, uniref90_out_path)
+
+        # STEP 2: run HHSearch on STEP 1's result with PDB70
+        hhs_config = {
+            "binary_path": self.db_map["hhsearch"]["binary"],
+            "databases": self.db_map["hhsearch"]["dbs"],
+            "n_cpu": self.no_cpus,
+        }
+        hhs_fac = HHSearchFactory(config=hhs_config)
+        # set HHSearch output path
+        pdb70_out_path = os.path.join(alignment_dir, "pdb70_hits.hhr")
+        # generate the workflow (STEP 2 depends on STEP 1)
+        wf2 = hhs_fac.gen_task(uniref90_out_path, pdb70_out_path, after=[wf1])
+
+        # STEP 3: run JackHmmer on MGNIFY
+        # reconfigure the jackhmmer factory to use the MGNIFY DB and hit limit
+        jh_fac.configure('database_path', self.db_map["jackhmmer"]["dbs"][1])
+        jh_fac.configure('uniref_max_hits', self.mgnify_max_hits)
+        # set jackhmmer output path
+        mgnify_out_path = os.path.join(alignment_dir, "mgnify_hits.a3m")
+        # generate workflow for STEP 3
+        wf3 = jh_fac.gen_task(fasta_path, mgnify_out_path)
+
+        # STEP 4: run HHBlits on BFD
+        hhb_config = {
+            "binary_path": self.db_map["hhblits"]["binary"],
+            "databases": self.db_map["hhblits"]["dbs"],
+            "n_cpu": self.no_cpus,
+        }
+        hhb_fac = HHBlitsFactory(config=hhb_config)
+        # set HHBlits output path
+        bfd_out_path = os.path.join(alignment_dir, "bfd_uniclust_hits.a3m")
+        # generate workflow for STEP 4
+        wf4 = hhb_fac.gen_task(fasta_path, bfd_out_path)
+
+        # run all leaf workflows as one batch
+        batch_run(wfs=[wf2, wf3, wf4], workflow_id=workflow_id)
+
+        return
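+
+# run() above assembles this dependency graph before handing the leaves to
+# batch_run:
+#
+#   wf1 (JackHmmer / uniref90) --> wf2 (HHSearch / pdb70)
+#   wf3 (JackHmmer / mgnify)       (independent)
+#   wf4 (HHBlits / bfd)            (independent)
+#
+# Example invocation (a sketch; all paths are hypothetical):
+#
+#   dwf = FastFoldDataWorkFlow(
+#       jackhmmer_binary_path="/usr/bin/jackhmmer",
+#       hhblits_binary_path="/usr/bin/hhblits",
+#       hhsearch_binary_path="/usr/bin/hhsearch",
+#       uniref90_database_path="/data/uniref90/uniref90.fasta",
+#       mgnify_database_path="/data/mgnify/mgy_clusters.fa",
+#       bfd_database_path="/data/bfd/bfd",
+#       uniclust30_database_path="/data/uniclust30/uniclust30",
+#       pdb70_database_path="/data/pdb70/pdb70",
+#       use_small_bfd=False,
+#   )
+#   dwf.run("seq.fasta", output_dir="out/")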
Skipping") + pass + + # prepare alignment directory for alignment outputs + if alignment_dir is None: + alignment_dir = os.path.join(output_dir, "alignment") + if not os.path.exists(alignment_dir): + os.makedirs(alignment_dir) + + # Run JackHmmer on UNIREF90 + # create JackHmmer workflow generator + jh_config = { + "binary_path": self.db_map["jackhmmer"]["binary"], + "database_path": self.db_map["jackhmmer"]["dbs"][0], + "n_cpu": self.no_cpus, + "uniref_max_hits": self.uniref_max_hits, + } + jh_fac = JackHmmerFactory(config = jh_config) + # set jackhmmer output path + uniref90_out_path = os.path.join(alignment_dir, "uniref90_hits.a3m") + # generate the workflow with i/o path + wf1 = jh_fac.gen_task(fasta_path, uniref90_out_path) + + #Run HHSearch on STEP1's result with PDB70""" + # create HHSearch workflow generator + hhs_config = { + "binary_path": self.db_map["hhsearch"]["binary"], + "databases": self.db_map["hhsearch"]["dbs"], + "n_cpu": self.no_cpus, + } + hhs_fac = HHSearchFactory(config=hhs_config) + # set HHSearch output path + pdb70_out_path = os.path.join(alignment_dir, "pdb70_hits.hhr") + # generate the workflow (STEP2 depend on STEP1) + wf2 = hhs_fac.gen_task(uniref90_out_path, pdb70_out_path, after=[wf1]) + + # Run JackHmmer on MGNIFY + # reconfigure jackhmmer factory to use MGNIFY DB instead + jh_fac.configure('database_path', self.db_map["jackhmmer"]["dbs"][1]) + # set jackhmmer output path + mgnify_out_path = os.path.join(alignment_dir, "mgnify_hits.a3m") + # generate workflow for STEP3 + wf3 = jh_fac.gen_task(fasta_path, mgnify_out_path) + + # Run HHBlits on BFD + # create HHBlits workflow generator + hhb_config = { + "binary_path": self.db_map["hhblits"]["binary"], + "databases": self.db_map["hhblits"]["dbs"], + "n_cpu": self.no_cpus, + } + hhb_fac = HHBlitsFactory(config=hhb_config) + # set HHBlits output path + bfd_out_path = os.path.join(alignment_dir, "bfd_uniclust_hits.a3m") + # generate workflow for STEP4 + wf4 = hhb_fac.gen_task(fasta_path, bfd_out_path) + + # run workflow + batch_run(wfs=[wf2, wf3, wf4], workflow_id=workflow_id) + + return \ No newline at end of file diff --git a/fastfold/workflow/workflow_run.py b/fastfold/workflow/workflow_run.py new file mode 100644 index 00000000..196dccfa --- /dev/null +++ b/fastfold/workflow/workflow_run.py @@ -0,0 +1,25 @@ +from ast import Call +from typing import Callable, List +from ray.workflow.common import Workflow +from ray import workflow + +def batch_run(wfs: List[Workflow], workflow_id: str) -> None: + + @workflow.step + def batch_step(wfs) -> None: + return + + batch_wf = batch_step.step(wfs) + + batch_wf.run(workflow_id=workflow_id) + +def wf(after: List[Workflow]=None): + def decorator(f: Callable): + + @workflow.step + def step_func(after: List[Workflow]) -> None: + f() + + return step_func.step(after) + + return decorator diff --git a/inference_with_workflow.py b/inference_with_workflow.py new file mode 100644 index 00000000..a27addc7 --- /dev/null +++ b/inference_with_workflow.py @@ -0,0 +1,285 @@ +# Copyright 2021 AlQuraishi Laboratory +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
diff --git a/inference_with_workflow.py b/inference_with_workflow.py
new file mode 100644
index 00000000..a27addc7
--- /dev/null
+++ b/inference_with_workflow.py
@@ -0,0 +1,285 @@
+# Copyright 2021 AlQuraishi Laboratory
+# Copyright 2021 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import os
+import random
+import sys
+import time
+from datetime import date
+
+import numpy as np
+import torch
+import torch.multiprocessing as mp
+
+import fastfold
+import fastfold.relax.relax as relax
+from fastfold.common import protein, residue_constants
+from fastfold.config import model_config
+from fastfold.data import data_pipeline, feature_pipeline, templates
+from fastfold.model.fastnn import set_chunk_size
+from fastfold.model.hub import AlphaFold
+from fastfold.utils import inject_fastnn
+from fastfold.utils.import_weights import import_jax_weights_
+from fastfold.utils.tensor_utils import tensor_tree_map
+from fastfold.workflow.template import FastFoldDataWorkFlow
+
+
+def add_data_args(parser: argparse.ArgumentParser):
+    parser.add_argument(
+        '--uniref90_database_path',
+        type=str,
+        default=None,
+    )
+    parser.add_argument(
+        '--mgnify_database_path',
+        type=str,
+        default=None,
+    )
+    parser.add_argument(
+        '--pdb70_database_path',
+        type=str,
+        default=None,
+    )
+    parser.add_argument(
+        '--uniclust30_database_path',
+        type=str,
+        default=None,
+    )
+    parser.add_argument(
+        '--bfd_database_path',
+        type=str,
+        default=None,
+    )
+    parser.add_argument('--jackhmmer_binary_path', type=str, default='/usr/bin/jackhmmer')
+    parser.add_argument('--hhblits_binary_path', type=str, default='/usr/bin/hhblits')
+    parser.add_argument('--hhsearch_binary_path', type=str, default='/usr/bin/hhsearch')
+    parser.add_argument('--kalign_binary_path', type=str, default='/usr/bin/kalign')
+    parser.add_argument(
+        '--max_template_date',
+        type=str,
+        default=date.today().strftime("%Y-%m-%d"),
+    )
+    parser.add_argument('--obsolete_pdbs_path', type=str, default=None)
+    parser.add_argument('--release_dates_path', type=str, default=None)
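+
+
+# `inference_model` below is spawned once per GPU from main(); each worker
+# joins the Dynamic Axial Parallelism process group via init_dap() and hands
+# its output back through a shared queue. A minimal sketch of the same launch
+# pattern (the worker body is hypothetical):
+#
+#   import torch.multiprocessing as mp
+#
+#   def worker(rank, world_size, q):
+#       ...            # build the model on cuda:{rank}
+#       q.put(result)  # return the prediction to the parent process
+#
+#   q = mp.Manager().Queue()
+#   mp.spawn(worker, nprocs=world_size, args=(world_size, q))
+#   out = q.get()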
+
+
+def inference_model(rank, world_size, result_q, batch, args):
+    os.environ['RANK'] = str(rank)
+    os.environ['LOCAL_RANK'] = str(rank)
+    os.environ['WORLD_SIZE'] = str(world_size)
+    # init distributed for Dynamic Axial Parallelism
+    fastfold.distributed.init_dap()
+    torch.cuda.set_device(rank)
+    config = model_config(args.model_name)
+    model = AlphaFold(config)
+    import_jax_weights_(model, args.param_path, version=args.model_name)
+
+    model = inject_fastnn(model)
+    model = model.eval()
+    model = model.cuda()
+
+    set_chunk_size(model.globals.chunk_size)
+
+    with torch.no_grad():
+        batch = {k: torch.as_tensor(v).cuda() for k, v in batch.items()}
+
+        t = time.perf_counter()
+        out = model(batch)
+        print(f"Inference time: {time.perf_counter() - t}")
+
+    out = tensor_tree_map(lambda x: np.array(x.cpu()), out)
+
+    result_q.put(out)
+
+    torch.distributed.barrier()
+    torch.cuda.synchronize()
+
+
+def main(args):
+    print("--------------- inference_with_workflow.py ---------------")
+    config = model_config(args.model_name)
+
+    template_featurizer = templates.TemplateHitFeaturizer(
+        mmcif_dir=args.template_mmcif_dir,
+        max_template_date=args.max_template_date,
+        max_hits=config.data.predict.max_templates,
+        kalign_binary_path=args.kalign_binary_path,
+        release_dates_path=args.release_dates_path,
+        obsolete_pdbs_path=args.obsolete_pdbs_path)
+
+    use_small_bfd = args.preset == 'reduced_dbs'
+    if use_small_bfd:
+        assert args.bfd_database_path is not None
+    else:
+        assert args.bfd_database_path is not None
+        assert args.uniclust30_database_path is not None
+
+    data_processor = data_pipeline.DataPipeline(template_featurizer=template_featurizer,)
+
+    output_dir_base = args.output_dir
+    random_seed = args.data_random_seed
+    if random_seed is None:
+        random_seed = random.randrange(sys.maxsize)
+    feature_processor = feature_pipeline.FeaturePipeline(config.data)
+    if not os.path.exists(output_dir_base):
+        os.makedirs(output_dir_base)
+    if args.use_precomputed_alignments is None:
+        alignment_dir = os.path.join(output_dir_base, "alignments")
+    else:
+        alignment_dir = args.use_precomputed_alignments
+
+    # Gather input sequences
+    with open(args.fasta_path, "r") as fp:
+        lines = [l.strip() for l in fp.readlines()]
+
+    tags, seqs = lines[::2], lines[1::2]
+    tags = [l[1:] for l in tags]
+
+    for tag, seq in zip(tags, seqs):
+        fasta_path = os.path.join(args.output_dir, "tmp.fasta")
+        with open(fasta_path, "w") as fp:
+            fp.write(f">{tag}\n{seq}")
+
+        print("Generating features...")
+        local_alignment_dir = os.path.join(alignment_dir, tag)
+        if args.use_precomputed_alignments is None:
+            if not os.path.exists(local_alignment_dir):
+                os.makedirs(local_alignment_dir)
+
+            alignment_data_workflow_runner = FastFoldDataWorkFlow(
+                jackhmmer_binary_path=args.jackhmmer_binary_path,
+                hhblits_binary_path=args.hhblits_binary_path,
+                hhsearch_binary_path=args.hhsearch_binary_path,
+                uniref90_database_path=args.uniref90_database_path,
+                mgnify_database_path=args.mgnify_database_path,
+                bfd_database_path=args.bfd_database_path,
+                uniclust30_database_path=args.uniclust30_database_path,
+                pdb70_database_path=args.pdb70_database_path,
+                use_small_bfd=use_small_bfd,
+                no_cpus=args.cpus,
+            )
+            t = time.perf_counter()
+            alignment_data_workflow_runner.run(fasta_path, output_dir=output_dir_base,
+                                               alignment_dir=local_alignment_dir)
+            print(f"Alignment data workflow time: {time.perf_counter() - t}")
+
+        feature_dict = data_processor.process_fasta(fasta_path=fasta_path,
+                                                    alignment_dir=local_alignment_dir)
+
+        # Remove temporary FASTA file
+        os.remove(fasta_path)
+
+        processed_feature_dict = feature_processor.process_features(
+            feature_dict,
+            mode='predict',
+        )
+
+        batch = processed_feature_dict
+
+        manager = mp.Manager()
+        result_q = manager.Queue()
+        torch.multiprocessing.spawn(inference_model, nprocs=args.gpus,
+                                    args=(args.gpus, result_q, batch, args))
+
+        out = result_q.get()
+
+        # Toss out the recycling dimensions --- we don't need them anymore
+        batch = tensor_tree_map(lambda x: np.array(x[..., -1].cpu()), batch)
+
+        plddt = out["plddt"]
+        plddt_b_factors = np.repeat(plddt[..., None], residue_constants.atom_type_num, axis=-1)
+
+        unrelaxed_protein = protein.from_prediction(features=batch,
+                                                    result=out,
+                                                    b_factors=plddt_b_factors)
+
+        # Save the unrelaxed PDB.
+        unrelaxed_output_path = os.path.join(args.output_dir,
+                                             f'{tag}_{args.model_name}_unrelaxed.pdb')
+        with open(unrelaxed_output_path, 'w') as f:
+            f.write(protein.to_pdb(unrelaxed_protein))
+
+        amber_relaxer = relax.AmberRelaxation(
+            use_gpu=True,
+            **config.relax,
+        )
+
+        # Relax the prediction.
+        t = time.perf_counter()
+        relaxed_pdb_str, _, _ = amber_relaxer.process(prot=unrelaxed_protein)
+        print(f"Relaxation time: {time.perf_counter() - t}")
+
+        # Save the relaxed PDB.
+        relaxed_output_path = os.path.join(args.output_dir,
+                                           f'{tag}_{args.model_name}_relaxed.pdb')
+        with open(relaxed_output_path, 'w') as f:
+            f.write(relaxed_pdb_str)
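+
+
+# Example invocation (a sketch; every path here is hypothetical):
+#
+#   python inference_with_workflow.py seq.fasta /data/pdb_mmcif/mmcif_files \
+#       --output_dir out/ --model_name model_1 --gpus 1 --cpus 12 \
+#       --uniref90_database_path /data/uniref90/uniref90.fasta \
+#       --mgnify_database_path /data/mgnify/mgy_clusters.fa \
+#       --bfd_database_path /data/bfd/bfd \
+#       --uniclust30_database_path /data/uniclust30/uniclust30 \
+#       --pdb70_database_path /data/pdb70/pdb70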
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "fasta_path",
+        type=str,
+    )
+    parser.add_argument(
+        "template_mmcif_dir",
+        type=str,
+    )
+    parser.add_argument("--use_precomputed_alignments",
+                        type=str,
+                        default=None,
+                        help="""Path to alignment directory. If provided, alignment computation
+                            is skipped and database path arguments are ignored.""")
+    parser.add_argument(
+        "--output_dir",
+        type=str,
+        default=os.getcwd(),
+        help="""Name of the directory in which to output the prediction""",
+    )
+    parser.add_argument("--model_name",
+                        type=str,
+                        default="model_1",
+                        help="""Name of a model config. Choose one of model_{1-5} or
+                            model_{1-5}_ptm, as defined on the AlphaFold GitHub.""")
+    parser.add_argument("--param_path",
+                        type=str,
+                        default=None,
+                        help="""Path to model parameters. If None, parameters are selected
+                            automatically according to the model name from
+                            ./data/params""")
+    parser.add_argument("--cpus",
+                        type=int,
+                        default=12,
+                        help="""Number of CPUs with which to run alignment tools""")
+    parser.add_argument("--gpus",
+                        type=int,
+                        default=1,
+                        help="""Number of GPUs with which to run inference""")
+    parser.add_argument('--preset',
+                        type=str,
+                        default='full_dbs',
+                        choices=('reduced_dbs', 'full_dbs'))
+    parser.add_argument('--data_random_seed', type=int, default=None)
+    add_data_args(parser)
+    args = parser.parse_args()
+
+    if args.param_path is None:
+        args.param_path = os.path.join("data", "params", "params_" + args.model_name + ".npz")
+
+    main(args)