Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions fastfold/workflow/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .workflow_run import batch_run
7 changes: 7 additions & 0 deletions fastfold/workflow/factory/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from .task_factory import TaskFactory
from .hhblits import HHBlitsFactory
from .hhsearch import HHSearchFactory
from .jackhmmer import JackHmmerFactory
from .alphafold import AlphaFoldFactory
from .amber_relax import AmberRelaxFactory
from .hhfilter import HHfilterFactory
75 changes: 75 additions & 0 deletions fastfold/workflow/factory/alphafold.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
from datetime import date
import imp
from ray import workflow
from typing import List
import time

import torch
import numpy as np
from fastfold.workflow.factory import TaskFactory
from ray.workflow.common import Workflow

from fastfold.distributed import init_dap
from fastfold.model.hub import AlphaFold
from fastfold.common import protein, residue_constants
from fastfold.config import model_config
from fastfold.data import data_pipeline, feature_pipeline, templates
from fastfold.utils import inject_fastnn
from fastfold.utils.import_weights import import_jax_weights_
from fastfold.utils.tensor_utils import tensor_tree_map

class AlphaFoldFactory(TaskFactory):
    """Factory producing Ray workflow steps that run AlphaFold inference on one GPU."""

    # Config keys required before gen_task may run (validated by isReady()).
    keywords = ['kalign_bin_path', 'template_mmcif_dir', 'param_path', 'model_name']

    def gen_task(self, fasta_path: str, alignment_dir: str, output_path: str, after: List[Workflow]=None):
        """Build a workflow step that predicts a structure for the given FASTA.

        Args:
            fasta_path: path to the input FASTA file.
            alignment_dir: directory holding precomputed alignments for the query.
            output_path: where the unrelaxed PDB prediction is written.
            after: upstream workflow steps this step depends on.

        Returns:
            The Ray workflow step produced by ``alphafold_step.step(...)``.
        """
        self.isReady()

        # setup runners
        config = model_config(self.config.get('model_name'))
        # Template search is capped at today's date and the model's max template count.
        template_featurizer = templates.TemplateHitFeaturizer(
            mmcif_dir=self.config.get('template_mmcif_dir'),
            max_template_date=date.today().strftime("%Y-%m-%d"),
            max_hits=config.data.predict.max_templates,
            kalign_binary_path=self.config.get('kalign_bin_path')
        )

        data_processor = data_pipeline.DataPipeline(template_featurizer=template_featurizer)
        feature_processor = feature_pipeline.FeaturePipeline(config.data)

        # generate step function
        @workflow.step(num_gpus=1)
        def alphafold_step(fasta_path: str, alignment_dir: str, output_path: str, after: List[Workflow]) -> None:
            """Run AlphaFold on one sequence and write the unrelaxed PDB to output_path."""

            # setup model
            init_dap()
            model = AlphaFold(config)
            # Load pretrained JAX weights, then swap in FastFold's optimized modules.
            import_jax_weights_(model, self.config.get('param_path'), self.config.get('model_name'))
            model = inject_fastnn(model)
            model = model.eval()
            model = model.cuda()

            feature_dict = data_processor.process_fasta(
                fasta_path=fasta_path,
                alignment_dir=alignment_dir
            )
            processed_feature_dict = feature_processor.process_features(
                feature_dict,
                mode='predict'
            )
            with torch.no_grad():
                batch = {k: torch.as_tensor(v).cuda() for k, v in processed_feature_dict.items()}
                out = model(batch)
            # Keep only the last slice of each input feature (x[..., -1]) and move
            # everything to CPU numpy for PDB construction.
            batch = tensor_tree_map(lambda x: np.array(x[..., -1].cpu()), batch)
            out = tensor_tree_map(lambda x: np.array(x.cpu()), out)
            plddt = out["plddt"]
            mean_plddt = np.mean(plddt)  # NOTE(review): computed but never used — log or remove?
            # Broadcast per-residue pLDDT into per-atom B-factors for the PDB output.
            plddt_b_factors = np.repeat(plddt[..., None], residue_constants.atom_type_num, axis=-1)
            unrelaxed_protein = protein.from_prediction(features=batch,
                                                        result=out,
                                                        b_factors=plddt_b_factors)
            with open(output_path, 'w') as f:
                f.write(protein.to_pdb(unrelaxed_protein))

        return alphafold_step.step(fasta_path, alignment_dir, output_path, after)
36 changes: 36 additions & 0 deletions fastfold/workflow/factory/amber_relax.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
from ray import workflow
from typing import List
from fastfold.workflow.factory import TaskFactory
from ray.workflow.common import Workflow

from fastfold.config import config
import fastfold.relax.relax as relax
from fastfold.common import protein

class AmberRelaxFactory(TaskFactory):
    """Factory emitting Ray workflow steps that relax a PDB structure with Amber."""

    # Relaxation needs no tool-specific configuration keys.
    keywords = []

    def gen_task(self, unrelaxed_pdb_path: str, output_path: str, after: List[Workflow]=None) -> Workflow:
        """Create a GPU step that reads a PDB, relaxes it, and writes the result.

        Args:
            unrelaxed_pdb_path: path to the input (unrelaxed) PDB file.
            output_path: where the relaxed PDB text is written.
            after: upstream workflow steps this step depends on.
        """
        self.isReady()

        # Build the relaxation runner once, captured by the step closure below.
        relaxer = relax.AmberRelaxation(
            use_gpu=True,
            **config.relax,
        )

        @workflow.step(num_gpus=1)
        def amber_relax_step(unrelaxed_pdb_path: str, output_path: str, after: List[Workflow]) -> None:
            # Parse the unrelaxed structure from disk.
            with open(unrelaxed_pdb_path, "r") as pdb_file:
                source_protein = protein.from_pdb_string(pdb_file.read())

            # Run Amber relaxation; only the relaxed PDB string is kept.
            relaxed_pdb, _, _ = relaxer.process(prot=source_protein)

            with open(output_path, "w") as out_file:
                out_file.write(relaxed_pdb)

        return amber_relax_step.step(unrelaxed_pdb_path, output_path, after)
29 changes: 29 additions & 0 deletions fastfold/workflow/factory/hhblits.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
from ray import workflow
from typing import List
from fastfold.workflow.factory import TaskFactory
from ray.workflow.common import Workflow
import fastfold.data.tools.hhblits as ffHHBlits

class HHBlitsFactory(TaskFactory):
    """Factory emitting Ray workflow steps that run HHBlits MSA searches."""

    # Config keys required to construct the HHBlits runner.
    keywords = ['binary_path', 'databases', 'n_cpu']

    def gen_task(self, fasta_path: str, output_path: str, after: List[Workflow]=None) -> Workflow:
        """Create a step that queries HHBlits with `fasta_path` and writes the a3m MSA.

        Args:
            fasta_path: input FASTA file to search with.
            output_path: where the resulting a3m alignment is written.
            after: upstream workflow steps this step depends on.
        """
        self.isReady()

        # Construct the runner once; the step closure captures it.
        cfg = self.config
        hhblits_runner = ffHHBlits.HHBlits(
            binary_path=cfg['binary_path'],
            databases=cfg['databases'],
            n_cpu=cfg['n_cpu'],
        )

        @workflow.step
        def hhblits_step(fasta_path: str, output_path: str, after: List[Workflow]) -> None:
            # Run the search and persist only the a3m alignment.
            query_result = hhblits_runner.query(fasta_path)
            with open(output_path, "w") as out_file:
                out_file.write(query_result["a3m"])

        return hhblits_step.step(fasta_path, output_path, after)
33 changes: 33 additions & 0 deletions fastfold/workflow/factory/hhfilter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import subprocess
import logging
from ray import workflow
from typing import List
from fastfold.workflow.factory import TaskFactory
from ray.workflow.common import Workflow

class HHfilterFactory(TaskFactory):
    """Factory emitting Ray workflow steps that run the hhfilter binary."""

    # Only the binary location is mandatory; 'id' and 'cov' are optional filters.
    keywords = ['binary_path']

    def gen_task(self, fasta_path: str, output_path: str, after: List[Workflow]=None) -> Workflow:
        """Create a step that filters `fasta_path` with hhfilter into `output_path`.

        Args:
            fasta_path: input alignment/sequence file passed via -i.
            output_path: filtered output written via -o.
            after: upstream workflow steps this step depends on.
        """
        self.isReady()

        @workflow.step
        def hhfilter_step(fasta_path: str, output_path: str, after: List[Workflow]) -> None:
            # Assemble the command line; -id and -cov are only added when configured.
            cmd = [self.config.get('binary_path')]
            for flag in ('id', 'cov'):
                if flag in self.config:
                    cmd.extend([f"-{flag}", str(self.config.get(flag))])
            cmd.extend(['-i', fasta_path, '-o', output_path])

            logging.info(f"HHfilter start: {' '.join(cmd)}")

            subprocess.run(cmd)

        return hhfilter_step.step(fasta_path, output_path, after)
38 changes: 38 additions & 0 deletions fastfold/workflow/factory/hhsearch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
from fastfold.workflow.factory import TaskFactory
from ray import workflow
from ray.workflow.common import Workflow
import fastfold.data.tools.hhsearch as ffHHSearch
from typing import List

class HHSearchFactory(TaskFactory):
    """Factory emitting Ray workflow steps that run HHSearch on an a3m alignment."""

    # Config keys required to construct the HHSearch runner.
    keywords = ['binary_path', 'databases', 'n_cpu']

    def gen_task(self, a3m_path: str, output_path: str, after: List[Workflow]=None, atab_path: str = None) -> Workflow:
        """Create a step that searches `a3m_path` and writes hits to `output_path`.

        Args:
            a3m_path: input alignment in a3m format.
            output_path: where the search result text is written.
            after: upstream workflow steps this step depends on.
            atab_path: optional path; when given, an atab file is also written.
                Bug fix: the original never forwarded this to the inner step, so
                the atab branch there was unreachable. The new parameter defaults
                to None, keeping existing callers unchanged.
        """
        self.isReady()

        # setup runner
        runner = ffHHSearch.HHSearch(
            binary_path=self.config['binary_path'],
            databases=self.config['databases'],
            n_cpu=self.config['n_cpu']
        )

        # generate step function
        @workflow.step
        def hhsearch_step(a3m_path: str, output_path: str, after: List[Workflow], atab_path: str = None) -> None:
            with open(a3m_path, "r") as f:
                a3m = f.read()
            # Only request atab generation when a destination path was supplied.
            if atab_path:
                hhsearch_result, atab = runner.query(a3m, gen_atab=True)
            else:
                hhsearch_result = runner.query(a3m)
            with open(output_path, "w") as f:
                f.write(hhsearch_result)
            if atab_path:
                with open(atab_path, "w") as f:
                    f.write(atab)

        # Forward atab_path so the optional atab output actually works.
        return hhsearch_step.step(a3m_path, output_path, after, atab_path)
34 changes: 34 additions & 0 deletions fastfold/workflow/factory/jackhmmer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
from fastfold.workflow.factory import TaskFactory
from ray import workflow
from ray.workflow.common import Workflow
import fastfold.data.tools.jackhmmer as ffJackHmmer
from fastfold.data import parsers
from typing import List

class JackHmmerFactory(TaskFactory):
    """Factory emitting Ray workflow steps that run Jackhmmer MSA searches."""

    # Config keys required for the runner and the a3m conversion limit.
    keywords = ['binary_path', 'database_path', 'n_cpu', 'uniref_max_hits']

    def gen_task(self, fasta_path: str, output_path: str, after: List[Workflow]=None) -> Workflow:
        """Create a step that queries Jackhmmer and writes the a3m-converted MSA.

        Args:
            fasta_path: input FASTA file to search with.
            output_path: where the converted a3m alignment is written.
            after: upstream workflow steps this step depends on.
        """
        self.isReady()

        # Construct the runner once; the step closure captures it.
        hmmer_runner = ffJackHmmer.Jackhmmer(
            binary_path=self.config['binary_path'],
            database_path=self.config['database_path'],
            n_cpu=self.config['n_cpu'],
        )

        @workflow.step
        def jackhmmer_step(fasta_path: str, output_path: str, after: List[Workflow]) -> None:
            # Only the first query result is used; convert its Stockholm MSA to a3m.
            first_hit = hmmer_runner.query(fasta_path)[0]
            a3m_msa = parsers.convert_stockholm_to_a3m(
                first_hit['sto'],
                max_sequences=self.config['uniref_max_hits'],
            )
            with open(output_path, "w") as out_file:
                out_file.write(a3m_msa)

        return jackhmmer_step.step(fasta_path, output_path, after)
50 changes: 50 additions & 0 deletions fastfold/workflow/factory/task_factory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
from ast import keyword
import json
from ray.workflow.common import Workflow
from os import path
from typing import List

class TaskFactory:
    """Base class for workflow task factories.

    Subclasses declare required config keys in ``keywords`` and implement
    ``gen_task`` to emit a Ray workflow step.
    """

    # Config keys a subclass requires; validated by isReady().
    keywords = []

    def __init__(self, config: dict = None, config_path: str = None) -> None:
        """Initialize factory configuration.

        Args:
            config: explicit configuration mapping for this factory.
            config_path: path to a global JSON config file with a "tools" section.

        Precedence: explicit dict, then explicit path, then ./config.json.
        Bug fix: the original early-returned without ever setting ``self.config``
        when ``keywords`` was empty, so configure()/isReady() later raised
        AttributeError on such factories.
        """
        # Always give the factory a config mapping so later calls are safe.
        self.config = {}

        # Skip loading when no keyword is required from a config file.
        if not self.__class__.keywords:
            return

        if config is not None:
            self.config = config
        elif config_path is not None:
            self.loadConfig(config_path)
        else:
            self.loadConfig()

    def configure(self, *args, purge=False) -> None:
        """Update the factory configuration.

        Supports both call styles the original API declared:
          * configure(config_dict, purge=False) — merge the mapping into the
            current config, or replace it entirely when ``purge`` is true.
          * configure(keyword, value) — set a single config entry.

        Bug fix: the original defined ``configure`` twice, so the dict-based
        overload was silently shadowed by the (keyword, value) form; this
        dispatch restores both.

        Raises:
            TypeError: when called with neither a single dict nor two positionals.
        """
        if len(args) == 1 and isinstance(args[0], dict):
            if purge:
                self.config = args[0]
            else:
                self.config.update(args[0])
        elif len(args) == 2:
            keyword, value = args
            self.config[keyword] = value
        else:
            raise TypeError("configure() expects (config_dict) or (keyword, value)")

    def gen_task(self, after: List[Workflow]=None, *args, **kwargs) -> Workflow:
        """Generate a Ray workflow step; must be implemented by subclasses."""
        raise NotImplementedError

    def isReady(self):
        """Raise KeyError if any required keyword is missing from the config."""
        for key in self.__class__.keywords:
            if key not in self.config:
                raise KeyError(f"{self.__class__.__name__} not ready: \"{key}\" not specified")

    def loadConfig(self, config_path='./config.json'):
        """Load this factory's section from a global JSON config file.

        The file must contain a "tools" mapping keyed by the factory class name
        without its "Factory" suffix (e.g. "HHBlits" for HHBlitsFactory).

        Raises:
            KeyError: when the "tools" section or this factory's entry is absent.
        """
        with open(config_path) as configFile:
            globalConfig = json.load(configFile)
        if 'tools' not in globalConfig:
            raise KeyError("\"tools\" not found in global config file")
        # Strip the trailing "Factory" (7 chars) to get the tool section name.
        factoryName = self.__class__.__name__[:-7]
        if factoryName not in globalConfig['tools']:
            raise KeyError(f"\"{factoryName}\" not found in the \"tools\" section in config")
        self.config = globalConfig['tools'][factoryName]
1 change: 1 addition & 0 deletions fastfold/workflow/template/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .fastfold_data_workflow import FastFoldDataWorkFlow
Loading