2 changes: 1 addition & 1 deletion pixi.lock

Some generated files are not rendered by default.

10 changes: 5 additions & 5 deletions pyproject.toml
@@ -4,15 +4,15 @@ requires = ["hatchling"]

 [dependency-groups]
 analysis = [
-    "polars",
-    "mdtraj",
+    "ipython",
     "marimo",
     "matplotlib",
-    "seaborn",
+    "mdtraj",
     "pandas",
+    "polars",
     "pyzmq",
-    "ipython",
-    "scikit-learn"
+    "scikit-learn",
+    "seaborn"
 ]
 boltz = ["boltz", "cuequivariance-torch", "cuequivariance-ops-torch-cu12", "rdkit>=2025.3.6"]
 dev = ["pytest", "pytest-cov", "mypy", "pre-commit", "pyright", "ruff", "pytest-loguru"]
55 changes: 32 additions & 23 deletions scripts/patch_input_cif_files.py
@@ -1,25 +1,27 @@
 # utility script to put all header information from original PDB entry into our CIF files
 import fnmatch
-import joblib
 import re
+import shutil
 from argparse import ArgumentParser
 from pathlib import Path
-from shutil import copy

+import einx
+import joblib
 import numpy as np
-from atomworks.io.parser import parse
-from biotite.structure import AtomArrayStack
-from biotite.structure.io.pdbx import CIFFile, set_structure, CIFColumn
+from atomworks.io.transforms.atom_array import ensure_atom_array_stack
+from atomworks.io.utils.io_utils import load_any
 from biotite.database.rcsb import fetch
+from biotite.structure.io.pdbx import CIFColumn, CIFFile, set_structure
 from loguru import logger


 SAMPLEWORKS_CACHE = Path("~/.sampleworks/rcsb").expanduser()


 def crawl_dir_by_depth(
-        root_dir: str | Path,
-        target_pattern: str,
-        n_levels: int,
+    root_dir: str | Path,
+    target_pattern: str,
+    n_levels: int,
 ) -> list[Path]:
     """
     Recursively crawl `root_dir` up to `n_levels` directory levels deep and return
@@ -55,22 +57,25 @@ def _crawl(current: Path, levels_left: int) -> None:
 def parse_args():
     parser = ArgumentParser()
     parser.add_argument("--input-dir", required=True)
-    parser.add_argument("--cif-pattern", default='refined.cif')
+    parser.add_argument("--cif-pattern", default="refined.cif")
     parser.add_argument(
-        "--rcsb-pattern", default='grid_search_results/(.{4})',
+        "--rcsb-pattern",
+        default="grid_search_results/(.{4})",
         help="Regex pattern for rcsb ids in file paths. "
-        "Must have only one group, surrounding the id"
+        "Must have only one group, surrounding the id",
+    )
+    parser.add_argument(
+        "--depth", type=int, default=4, help="Depth to search the directory tree below input-dir"
     )
-    parser.add_argument("--depth", type=int, default=4,
-                        help="Depth to search the directory tree below input-dir")
     args = parser.parse_args()
     return args

+
 def main(
-        input_dir: str | Path,
-        target_pattern: str,
-        rcsb_regex: str = r"grid_search_results/(.{4})",
-        depth: int = 4
+    input_dir: str | Path,
+    target_pattern: str,
+    rcsb_regex: str = r"grid_search_results/(.{4})",
+    depth: int = 4,
 ) -> None:
     # make sure the cache exists
     SAMPLEWORKS_CACHE.mkdir(parents=True, exist_ok=True)
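For reference, the `--rcsb-pattern` contract is one capture group around the four-character RCSB id. A minimal sketch against a hypothetical results path:

import re

path = "/scratch/grid_search_results/1abc/refined.cif"  # hypothetical layout
match = re.search(r"grid_search_results/(.{4})", path)
assert match is not None
print(match.group(1))  # -> 1abc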
@@ -97,15 +102,20 @@ def patch_individual_cif_file(cif_file: Path, rcsb_regex: str):
         return cif_file

     # write a backup version of the input cif file
-    copy(cif_path, cif_path.parent / (cif_path.name + ".bak"))
+    shutil.copy(cif_path, cif_path.parent / (cif_path.name + ".bak"))

     # fetch only downloads the file if it isn't already present.
     rcsb_path = fetch(rcsb_id, format="cif", target_path=str(SAMPLEWORKS_CACHE))

     # load the copy, and the new coordinates for it.
     template = CIFFile.read(rcsb_path)
-    cif_to_patch = parse(cif_file)
-    asym_unit: AtomArrayStack = cif_to_patch["asym_unit"]
+    asym_unit = load_any(cif_file)
+    asym_unit = ensure_atom_array_stack(asym_unit)

+    # remove any atoms with nan coordinates--these seem to come in because we sometimes use
+    # parse (from AtomWorks), which creates them. Still, we'll do this here just in case.
+    flat_coords = einx.rearrange("a b c -> b (a c)", asym_unit.coord)
+    asym_unit = asym_unit[:, ~np.isnan(flat_coords).any(axis=1)]  # pyright: ignore
+
     # make sure entity ids match in atom_site and entity_poly
     if "entity_poly" in template.block:
@@ -116,7 +126,7 @@ def patch_individual_cif_file(cif_file: Path, rcsb_regex: str):
         entity_id = ep["entity_id"].as_item()
         if "label_entity_id" not in asym_unit.get_annotation_categories():
             asym_unit.add_annotation("label_entity_id", int)
-        asym_unit.label_entity_id = np.ones_like(asym_unit.label_entity_id) * int(entity_id)
+        asym_unit.label_entity_id = np.ones_like(asym_unit.label_entity_id) * int(entity_id)  # pyright: ignore
     else:
         logger.warning("No entity_poly block found in template CIF file. Cannot patch entity ids")

@@ -139,6 +149,5 @@ def patch_individual_cif_file(cif_file: Path, rcsb_regex: str):


 if __name__ == "__main__":
-
     args = parse_args()
-    main(args.input_dir, args.cif_pattern, args.rcsb_pattern, args.depth)
\ No newline at end of file
+    main(args.input_dir, args.cif_pattern, args.rcsb_pattern, args.depth)
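Going only by the signatures visible in this diff, the pieces compose roughly like this (a sketch, not the verbatim body of main()):

# hypothetical driver, mirroring the parse_args() defaults
cif_files = crawl_dir_by_depth("/scratch", "refined.cif", n_levels=4)
for cif_file in cif_files:
    patch_individual_cif_file(cif_file, rcsb_regex=r"grid_search_results/(.{4})")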
127 changes: 127 additions & 0 deletions scripts/run_and_process_phenix_clashscore.py
@@ -0,0 +1,127 @@
import argparse
import json
import subprocess
from pathlib import Path

import joblib
import pandas as pd
from loguru import logger
from sampleworks.eval.eval_dataclasses import Experiment
from sampleworks.eval.grid_search_eval_utils import scan_grid_search_results


def parse_args(description: str | None = None) -> argparse.Namespace:
    """
    Return a common set of arguments for grid search evaluation scripts,
    with a custom description, which is passed to argparse.ArgumentParser.

    All eval scripts should use this same framework.
    """
    parser = argparse.ArgumentParser(description=description)
    parser.add_argument(
        "--workspace-root",
        type=Path,
        required=True,
        help="Path containing the grid search results directory, e.g. if results are "
        "at $HOME/grid_search_results, $HOME should be what you pass",
    )
    parser.add_argument(
        "--n-jobs",
        type=int,
        help="Number of parallel jobs to run. -1 uses all CPUs.",
        default=16,
    )
    return parser.parse_args()


def main(args) -> None:
    # check that phenix is installed and available; bail early if not.
    try:
        subprocess.call("phenix.clashscore", stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
    except FileNotFoundError:
        raise RuntimeError(
            "phenix.clashscore is not available, make sure phenix is installed "
            "and that you have activated it, e.g. `source phenix-dir/phenix_env.sh`"
        )

    workspace_root = Path(args.workspace_root)
    grid_search_dir = workspace_root / "grid_search_results"  # TODO make more general
    all_experiments = scan_grid_search_results(grid_search_dir)
    logger.info(f"Found {len(all_experiments)} experiments with refined.cif files")

    # Now loop over experiments with joblib and collect per-experiment metric DataFrames
    clashscore_metrics = joblib.Parallel(n_jobs=args.n_jobs)(
        joblib.delayed(process_one_experiment)(experiment) for experiment in all_experiments
    )
    if not clashscore_metrics:
        logger.error(
            "No experiments successfully processed, check that result files are available."
        )
        return

    clashscore_df = pd.concat(clashscore_metrics)  # pyright: ignore
    clashscore_df.to_csv(
        workspace_root / "grid_search_results" / "clashscore_metrics.csv", index=False
    )


def process_one_experiment(experiment: Experiment) -> pd.DataFrame:
    # make sure there are no nan lines in the CIF file; this is an extra
    # precaution, even though our CIF writers should now avoid writing nans
    file_with_no_nans = experiment.refined_cif_path.parent / "nonan.cif"
    json_output = experiment.refined_cif_path.parent / "clashscore.json"
    logfile = experiment.refined_cif_path.parent / "clashscore.log"
    logger.info(f"Removing nans from {experiment.refined_cif_path}")

    with file_with_no_nans.open("w") as fn:
        retcode = subprocess.call(
            ["grep", "-viP", r"\bnan\b", str(experiment.refined_cif_path)], stdout=fn
        )
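    # note: with -v, grep exits 0 only when at least one line survives the filter,
    # so an empty or all-nan input file is surfaced as an error below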
    if retcode != 0:
        raise RuntimeError(f"grep failed with code {retcode} on {experiment.refined_cif_path}")

    # phenix needs to be installed and on PATH for this to work. Also, the Python `sh` module
    # won't work with phenix.clashscore because of that pesky period in the name.
    with logfile.open("w") as fn:
        # phenix.clashscore generates a JSON file with both per-model scores as well as per-model
        # lists of clashes.
        retcode = subprocess.call(
            ["phenix.clashscore", str(file_with_no_nans), "--json-filename", str(json_output)],
            stderr=fn,
        )
    if retcode != 0:
        logger.error(f"phenix.clashscore failed, see {logfile} for details")
        return pd.DataFrame()
    return process_clashscore_json_output(json_output)


def process_clashscore_json_output(json_output: Path) -> pd.DataFrame:
    """
    Open the JSON output file `json_output` and parse out the
    "summary_results", flattening them into rows that include the "model_name" field.
    """
    with open(json_output) as f:
        json_data = json.load(f)

    model_name = json_data.get("model_name")
    # For now we're only collecting model-level summary statistics, but
    # there are lists of specific clashes in each model too.
    summary_results = json_data.get("summary_results", {})

    rows = []
    for model_id, results in summary_results.items():
        row = {
            "model_name": model_name,
            "model_id": model_id,
            "clashscore": results.get("clashscore"),
            "num_clashes": results.get("num_clashes"),
        }
        rows.append(row)

    return pd.DataFrame(rows)
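
# A hedged illustration of the flattening above, with made-up numbers in the shape
# this function expects ("model_name" plus per-model "summary_results"):
#
#   {"model_name": "nonan.cif",
#    "summary_results": {"1": {"clashscore": 4.2, "num_clashes": 7},
#                        "2": {"clashscore": 3.1, "num_clashes": 5}}}
#
# flattens to a two-row DataFrame:
#
#   model_name model_id  clashscore  num_clashes
#    nonan.cif        1         4.2            7
#    nonan.cif        2         3.1            5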


if __name__ == "__main__":
    args = parse_args()
    main(args)