From 34eb10e0b8f1061e87ad076294df78579ac25019 Mon Sep 17 00:00:00 2001 From: "Marcus D. Collins" Date: Fri, 27 Feb 2026 08:17:50 -0800 Subject: [PATCH 1/6] core changes to enable multiple selections per protein # Conflicts: # src/sampleworks/core/rewards/real_space_density.py # Conflicts: # src/sampleworks/utils/atom_array_utils.py --- .../core/rewards/real_space_density.py | 3 + src/sampleworks/eval/eval_dataclasses.py | 4 +- src/sampleworks/eval/structure_utils.py | 110 +++++++++--------- src/sampleworks/utils/atom_array_utils.py | 71 ++++++++--- src/sampleworks/utils/density_utils.py | 2 +- 5 files changed, 115 insertions(+), 75 deletions(-) diff --git a/src/sampleworks/core/rewards/real_space_density.py b/src/sampleworks/core/rewards/real_space_density.py index 2da73c1f..762fa996 100644 --- a/src/sampleworks/core/rewards/real_space_density.py +++ b/src/sampleworks/core/rewards/real_space_density.py @@ -2,6 +2,7 @@ import numpy as np import torch +from atomworks.io.transforms.atom_array import ensure_atom_array_stack from biotite.structure import AtomArray, AtomArrayStack from jaxtyping import ArrayLike, Float, Int from loguru import logger @@ -98,6 +99,7 @@ def extract_density_inputs_from_atomarray( ``(1, n_atoms)``. For an ``AtomArrayStack`` with *M* models the batch dimension is *M* instead of 1. """ + is_stack = isinstance(atom_array_or_stack, AtomArrayStack) coords = cast(np.ndarray[Any, np.dtype[np.float64]], atom_array_or_stack.coord) @@ -130,6 +132,7 @@ def extract_density_inputs_from_atomarray( f"({n_valid}/{n_total} atoms remaining)" ) + # TODO can we use [..., valid_mask] here? 
if is_stack: valid_coords = coords[:, valid_mask] else: diff --git a/src/sampleworks/eval/eval_dataclasses.py b/src/sampleworks/eval/eval_dataclasses.py index db1f2654..7dc35c4e 100644 --- a/src/sampleworks/eval/eval_dataclasses.py +++ b/src/sampleworks/eval/eval_dataclasses.py @@ -42,7 +42,7 @@ class ProteinConfig: protein: str base_map_dir: Path - selection: str + selection: list[str] resolution: float map_pattern: str structure_pattern: str = "" @@ -190,7 +190,7 @@ def from_csv(cls, workspace_root: Path, csv_path: Path) -> dict[str, "ProteinCon config = cls( protein=protein, base_map_dir=base_map_dir, - selection=row["selection"].strip(), + selection=row["selection"].strip().split(";"), resolution=resolution, map_pattern=row["map_pattern"].strip(), structure_pattern=structure_pattern, diff --git a/src/sampleworks/eval/structure_utils.py b/src/sampleworks/eval/structure_utils.py index fab3c7fa..3e74c807 100644 --- a/src/sampleworks/eval/structure_utils.py +++ b/src/sampleworks/eval/structure_utils.py @@ -2,7 +2,7 @@ import traceback from dataclasses import dataclass, replace from pathlib import Path -from typing import cast +from typing import cast, Any import numpy as np import torch @@ -243,6 +243,8 @@ def process_structure_to_trajectory_input( ) +# TODO: migrate this to atomworks's selection algebra that is added to AtomArray/Stack +# https://github.com/diff-use/sampleworks/issues/56 def parse_selection_string(selection: str) -> tuple[str | None, int | None, int | None]: """Parse a selection string like 'chain A and resi 326-339'. 
@@ -303,8 +305,22 @@ def apply_selection(atom_array: AtomArray, selection: str | None) -> AtomArray: if selection is None: return atom_array + if not any(x in selection for x in ("==", ">", "<", "<=", ">=", " in ")): + mask = get_mask_from_old_selection_string(atom_array, selection) + else: + mask = atom_array.mask(selection) + + return cast(AtomArray, atom_array[mask]) + + +def get_mask_from_old_selection_string( + atom_array: AtomArray, selection: str +) -> np.ndarray[tuple[int], np.dtype[Any]]: + DeprecationWarning(f"Using old-style selection strings like {selection} is deprecated." + f" Use atomworks/pandas style selection strings instead.") chain_id, resi_start, resi_end = parse_selection_string(selection) - mask = np.ones(len(atom_array), dtype=bool) + # use the length of any of the required non-coord attributes to get the mask shape + mask = np.ones(len(atom_array.res_id), dtype=bool) if chain_id is not None: mask &= atom_array.chain_id == chain_id @@ -318,8 +334,7 @@ def apply_selection(atom_array: AtomArray, selection: str | None) -> AtomArray: if mask.sum() == 0: raise ValueError(f"Selection '{selection}' matched no atoms") - - return cast(AtomArray, atom_array[mask]) + return mask def extract_selection_coordinates( @@ -349,24 +364,10 @@ def extract_selection_coordinates( else: working_array = atom_array - chain_id, resi_start, resi_end = parse_selection_string(selection) - - # Create the selection mask, don't rely on len(atom_array) in case it is the ensemble size - mask = np.ones(len(working_array), dtype=bool) - - if chain_id is not None: - mask &= working_array.chain_id == chain_id - - if resi_start is not None: - res_ids = cast(np.ndarray, working_array.res_id) - if resi_end is not None: - # Explicitly check for None to satisfy pyright - start: int = resi_start - end: int = resi_end - mask &= (res_ids >= start) & (res_ids <= end) - else: - start = resi_start - mask &= res_ids == start + if not any(x in selection for x in ("==", ">", "<", "<=", 
">=")): + mask = get_mask_from_old_selection_string(atom_array, selection) + else: + mask = atom_array.mask(selection) selected_coords = cast(np.ndarray, working_array.coord)[mask] @@ -374,7 +375,6 @@ def extract_selection_coordinates( if len(selected_coords) == 0: raise RuntimeError( f"No atoms matched selection: '{selection}'. " - f"Chain ID: {chain_id}, Residue range: {resi_start}-{resi_end}. " f"Total atoms in structure: {len(atom_array)}" ) @@ -431,45 +431,43 @@ def get_reference_atomarraystack( def get_reference_structure_coords( protein_config: ProteinConfig, protein_key: str, occ_list: tuple[float, ...] = (0.0, 1.0) -) -> np.ndarray | None: +) -> dict[str, np.ndarray] | None: """ This has a slightly odd function, which is to output an array of all possible coordinates of a structure, with altlocs mixed in. It returns NO information about which atom is which or whether there are duplicates. It's used for masking density maps. """ - protein_ref_coords_list = [] + protein_ref_coords_list = {selection: [] for selection in protein_config.selection} for occ in occ_list: ref_path, ref_struct = get_reference_atomarraystack(protein_config, occ) if ref_path and ref_struct: # if not None, it is already a validated Path object - try: - # TODO: enumerate actual exceptions this can raise. - coords = extract_selection_coordinates(ref_struct, protein_config.selection) - if not len(coords): - logger.warning( - f" No atoms in selection '{protein_config.selection}' for {protein_key}" - ) - elif not np.isfinite(coords).all(): - logger.warning( - f" NaN/Inf coordinates in selection " - f"'{protein_config.selection}' for {protein_key}" + for selection in protein_config.selection: + try: + # TODO: enumerate actual exceptions this can raise. 
+ coords = extract_selection_coordinates(ref_struct, selection) + if not len(coords): + logger.warning(f" No atoms in selection '{selection}' for {protein_key}") + elif not np.isfinite(coords).all(): + logger.warning( + f" NaN/Inf coordinates in selection '{selection}' for {protein_key}" + ) + else: + protein_ref_coords_list[selection].append(coords) + logger.info( + f" Loaded reference structure for {protein_key}: " + f"{len(coords)} atoms in selection '{selection}'" + ) + except Exception as _e: + _selection = selection if selection else "(none)" + logger.error( + f" ERROR: Failed to load reference structure for {protein_key}: {_e}\n" + f" Path: {ref_path}\n" + f" Selection: {_selection}\n" + f" Traceback: {traceback.format_exc()}" ) - else: - protein_ref_coords_list.append(coords) - logger.info( - f" Loaded reference structure for {protein_key}: " - f"{len(coords)} atoms in selection '{protein_config.selection}'" - ) - except Exception as _e: - _selection = protein_config.selection if protein_config.selection else "(none)" - logger.error( - f" ERROR: Failed to load reference structure for {protein_key}: {_e}\n" - f" Path: {ref_path}\n" - f" Selection: {_selection}\n" - f" Traceback: {traceback.format_exc()}" - ) - - if not protein_ref_coords_list: - logger.error(f"No reference structures found for {protein_key}") - return None - - return np.vstack(protein_ref_coords_list) + + return { + k: np.vstack(protein_ref_coords_list[k]) + for k in protein_ref_coords_list + if protein_ref_coords_list[k] + } diff --git a/src/sampleworks/utils/atom_array_utils.py b/src/sampleworks/utils/atom_array_utils.py index e8ca4c01..9f18d881 100644 --- a/src/sampleworks/utils/atom_array_utils.py +++ b/src/sampleworks/utils/atom_array_utils.py @@ -186,14 +186,19 @@ def save_structure_to_cif( return output_path -def find_all_altloc_ids(atom_array: AtomArray | AtomArrayStack) -> set[str]: +def find_all_altloc_ids( + atom_array: AtomArray | AtomArrayStack, altloc_label: str = "altloc_id" 
+) -> set[str]: """ Find all unique alternate location indicator (altloc) IDs in an AtomArray or AtomArrayStack. """ - if hasattr(atom_array, "altloc_id"): - altloc_ids = np.unique(atom_array.altloc_id) + if hasattr(atom_array, altloc_label): + altloc_ids = np.unique(getattr(atom_array, altloc_label)) # ty: ignore[invalid-argument-type] else: - raise AttributeError("atom_array must have `altloc_id` annotation") + raise AttributeError( + "atom_array must have altloc annotation, defaults to 'altloc_id'; " + f"you provided {altloc_label}" + ) return set(altloc_ids.tolist()) - BLANK_ALTLOC_IDS @@ -234,12 +239,22 @@ def detect_altlocs(atom_array: AtomArray) -> AltlocInfo: def map_altlocs_to_stack( atom_array: AtomArray | AtomArrayStack, + selection: str | None = None, + return_full_array: bool = True, ) -> tuple[AtomArrayStack, np.ndarray, np.ndarray]: """ Map alternate location indicators (altloc) to separate structures in a new AtomArrayStack. - Note: this will take _only_ the first structure if you pass an AtomArrayStack. It will raise - an error if there is more than one structure in the input AtomArrayStack. + Note: This will raise an error if you pass an AtomArrayStack containing multiple structures + + Parameters: + atom_array: AtomArray or AtomArrayStack to map altlocs from. + selection: str + Optional selection string to apply to the atom array before mapping altlocs. + If the return_full_array is also True, this selection applies _only_ to the altlocs. + return_full_array: bool + If True, return the full AtomArrayStack with all atoms, even those with no altlocs. + If False, only return structures with altlocs. 
Returns: Tuple containing: @@ -252,15 +267,43 @@ def map_altlocs_to_stack( if isinstance(atom_array, AtomArrayStack): if len(atom_array) > 1: raise ValueError("Cannot map altlocs with multiple structures each containing altlocs") - atom_array = atom_array[0] + atom_array = atom_array[0] # ty: ignore[invalid-assignment] + + if not hasattr(atom_array, "altloc_id") or atom_array.altloc_id is None: + raise ValueError("The passed AtomArray | AtomArrayStack must have attribute 'altloc_id'") + + if not hasattr(atom_array, "mask") or not hasattr(atom_array, "query"): + raise ValueError( + "map_altlocs_to_stack requires atom_arrays loaded with " + "atomworks.io.utils.io_utils.load_any or a similar method that provides .mask, .query" + ) + + + altloc_ids = sorted(list(find_all_altloc_ids(atom_array))) + if selection is not None and return_full_array: + # in this case, we select atoms with no altlocs, and atoms in the selection + # that do have altlocs construct the query (note np.isin doesn't seem to list + # sets, hence conversion to list. + no_altloc_mask = np.isin(atom_array.altloc_id, list(BLANK_ALTLOC_IDS)) + altloc_mask = np.isin(atom_array.altloc_id, altloc_ids) + selection_mask = atom_array.mask(selection) + mask = np.logical_or(no_altloc_mask, np.logical_and(altloc_mask, selection_mask)) + atom_array = atom_array[mask] + elif selection is not None: + atom_array = atom_array.query(selection) + + # The available altloc ids might have changed if one or more are missing from the selection. 
altloc_ids = sorted(list(find_all_altloc_ids(atom_array))) altloc_list = [ - select_altloc(atom_array, altloc_id, return_full_array=True) for altloc_id in altloc_ids + select_altloc(atom_array, altloc_id, return_full_array=return_full_array) + for altloc_id in altloc_ids ] # ensure that each structure has the same number of atoms + # it's critical that we've updated the altloc id list above, or this will return only + # residues with no altlocs in case one or more altloc ids is/are missing from the selection. atom_arrays = filter_to_common_atoms(*altloc_list) - altloc_ids = np.vstack([r.altloc_id for r in atom_arrays]) - occupancies = np.vstack([r.occupancy for r in atom_arrays]) + altloc_ids = np.vstack([r.altloc_id for r in atom_arrays]) # ty: ignore + occupancies = np.vstack([r.occupancy for r in atom_arrays]) # ty: ignore # remove those annotations or we cannot stack arrays. for array in atom_arrays: @@ -305,12 +348,8 @@ def select_altloc( if return_full_array: mask = np.isin( - atom_array.altloc_id, - list( - { - altloc_id, - }.union(BLANK_ALTLOC_IDS) - ), + atom_array.altloc_id, # ty: ignore[invalid-argument-type] + list({altloc_id,}.union(BLANK_ALTLOC_IDS)), ) else: mask = atom_array.altloc_id == altloc_id diff --git a/src/sampleworks/utils/density_utils.py b/src/sampleworks/utils/density_utils.py index cdbe4b87..4241b4e7 100644 --- a/src/sampleworks/utils/density_utils.py +++ b/src/sampleworks/utils/density_utils.py @@ -152,7 +152,7 @@ def compute_density_from_atomarray( atom_array, device ) - # need to make sure these all have the same batch dimension or the transformer will fail. + # need to make sure these all have the same batch dimension. elements = elements.expand(coords.shape[0], -1) b_factors = b_factors.expand(coords.shape[0], -1) occupancies = occupancies.expand(coords.shape[0], -1) From cfa619a6b7dc3a2bbb4ea158419ab3bdbbe1448e Mon Sep 17 00:00:00 2001 From: "Marcus D. 
Collins" Date: Fri, 27 Feb 2026 09:56:26 -0800 Subject: [PATCH 2/6] [WIP] Partially functional multi-selection code. --- scripts/eval/lddt_evaluation_script.py | 97 ++++--- scripts/eval/rscc_grid_search_script.py | 266 +++++++----------- .../eval/grid_search_eval_utils.py | 7 +- 3 files changed, 159 insertions(+), 211 deletions(-) diff --git a/scripts/eval/lddt_evaluation_script.py b/scripts/eval/lddt_evaluation_script.py index 6f0288cc..35ca4509 100644 --- a/scripts/eval/lddt_evaluation_script.py +++ b/scripts/eval/lddt_evaluation_script.py @@ -1,4 +1,5 @@ import argparse +import itertools import re import sys import traceback @@ -10,7 +11,6 @@ from atomworks.io.utils.io_utils import load_any from biotite.structure import AtomArray, AtomArrayStack from loguru import logger -from sampleworks.eval.constants import OCCUPANCY_LEVELS from sampleworks.eval.eval_dataclasses import ProteinConfig from sampleworks.eval.grid_search_eval_utils import parse_args, scan_grid_search_results from sampleworks.eval.structure_utils import get_reference_atomarraystack @@ -175,7 +175,11 @@ def translate_selection(selection: str) -> str: # current selection strings are pymol like, and we want to convert to atomworks/pandas like # this should be a temporary measure only until we switch to atomworks style in the RSCC script - logger.warning( + if any(x in selection for x in ("==", ">", "<", "<=", ">=", " in ")): + # assume this is already atomworks/pandas style and ignore. + return selection + + DeprecationWarning( "DEPRECATED: translate_selection converts from some pymol-like selection strings to " "AtomWorks selection strings, but is not guaranteed to be correct for all cases." ) @@ -215,8 +219,7 @@ def main(args: argparse.Namespace): logger.info("Pre-loading reference structures for each protein for coordinate extraction") reference_atom_arrays = {} for protein_key, protein_config in protein_configs.items(): - # TODO: should the reference occupancy should be specified in the config? 
- for occ in OCCUPANCY_LEVELS: + for occ, sel in itertools.product(args.occupancies, protein_config.selection): ref_path, reference_proteins = get_reference_atomarraystack(protein_config, occ) if reference_proteins is None: logger.warning( @@ -227,10 +230,10 @@ def main(args: argparse.Namespace): logger.info( f"Loaded ref structure for {protein_key} and occupancy {occ}: {ref_path}" ) - # ignoring the altloc_id and occupancy arrays - reference_protein_stack, _, _ = map_altlocs_to_stack(reference_proteins) - if reference_proteins is not None: - reference_atom_arrays[(protein_key, occ)] = reference_protein_stack + reference_protein_stack, _, _ = map_altlocs_to_stack( + reference_proteins, selection=translate_selection(sel), return_full_array=True + ) + reference_atom_arrays[(protein_key, occ, sel)] = reference_protein_stack except Exception as e: logger.error( f"Error loading ref structure for {protein_key} and occupancy {occ}: {e}" @@ -238,47 +241,55 @@ def main(args: argparse.Namespace): logger.error(f" Traceback: {traceback.format_exc()}") all_results = [] + # TODO parallelize this loop? It will require replicating `reference_atom_arrays` + # https://github.com/diff-use/sampleworks/issues/98 for _i, _exp in enumerate(all_experiments): - result = _exp.__dict__.copy() - if _exp.protein not in protein_configs: + if _exp.protein in protein_configs: + protein = _exp.protein + elif _exp.protein.upper() in protein_configs: + protein = _exp.protein.upper() + else: logger.warning(f"Skipping protein with no configuration: {_exp.protein}") continue - protein_config = protein_configs[_exp.protein] + protein_config = protein_configs[protein] - atom_array_key = (_exp.protein, _exp.occ_a) - if atom_array_key not in reference_atom_arrays: - logger.warning( - f"Skipping {_exp.protein_dir_name}: no reference atom array stack available " - f"for {_exp.protein} and occupancy {_exp.occ_a}." 
- ) - continue + for selection in protein_config.selection: + result = _exp.__dict__.copy() + result["selection"] = selection + atom_array_key = (protein, _exp.occ_a, selection) + if atom_array_key not in reference_atom_arrays: + logger.warning( + f"Skipping {_exp.protein_dir_name}: no reference atom array stack available " + f"for {_exp.protein}, occupancy {_exp.occ_a} and selection '{selection}'." + ) + continue + + try: + reference_atom_array_stack = reference_atom_arrays[atom_array_key] + # generated structures shouldn't have altlocs, don't need altloc="all". + predicted_atom_array_stack = load_any(_exp.refined_cif_path) + clustering_results = nn_lddt_clustering( + reference_atom_array_stack, + ensure_atom_array_stack(predicted_atom_array_stack), + translate_selection(selection), + ) + + result.update( + { + k: clustering_results[k] + for k in ("occupancies", "avg_silhouette", "avg_silhouette_to_ref") + } + ) + except Exception as e: + logger.error(f"Error processing experiment {_exp.exp_dir}: {e}") + logger.error(f" Traceback: {traceback.format_exc()}") + result["error"] = str(e) + result["avg_silhouette"] = np.nan + result["avg_silhouette_to_ref"] = np.nan + result["occupancies"] = [] - try: - reference_atom_array_stack = reference_atom_arrays[atom_array_key] - # these shouldn't have altlocs, don't need altloc="all". 
- predicted_atom_array_stack = load_any(_exp.refined_cif_path) - clustering_results = nn_lddt_clustering( - reference_atom_array_stack, - ensure_atom_array_stack(predicted_atom_array_stack), - translate_selection(protein_config.selection), - ) - - result.update( - { - k: clustering_results[k] - for k in ("occupancies", "avg_silhouette", "avg_silhouette_to_ref") - } - ) - except Exception as e: - logger.error(f"Error processing experiment {_exp.exp_dir}: {e}") - logger.error(f" Traceback: {traceback.format_exc()}") - result["error"] = str(e) - result["avg_silhouette"] = np.nan - result["avg_silhouette_to_ref"] = np.nan - result["occupancies"] = [] - - all_results.append(result) + all_results.append(result) df = pd.DataFrame(all_results) df.to_csv(grid_search_dir / "lddt_results.csv", index=False) diff --git a/scripts/eval/rscc_grid_search_script.py b/scripts/eval/rscc_grid_search_script.py index 3a712c24..8d9152a5 100644 --- a/scripts/eval/rscc_grid_search_script.py +++ b/scripts/eval/rscc_grid_search_script.py @@ -27,7 +27,7 @@ from atomworks.io.parser import parse from loguru import logger from sampleworks.core.forward_models.xray.real_space_density_deps.qfit.volume import XMap -from sampleworks.eval.constants import DEFAULT_SELECTION_PADDING, OCCUPANCY_LEVELS +from sampleworks.eval.constants import DEFAULT_SELECTION_PADDING from sampleworks.eval.eval_dataclasses import ProteinConfig from sampleworks.eval.grid_search_eval_utils import parse_args, scan_grid_search_results from sampleworks.eval.metrics import rscc @@ -38,9 +38,10 @@ from sampleworks.utils.density_utils import compute_density_from_atomarray +# TODO consolidate eval script logic: https://github.com/diff-use/sampleworks/issues/93 def main(args: argparse.Namespace): workspace_root = Path(args.workspace_root) - grid_search_dir = workspace_root / "grid_search_results" # TODO make more general + grid_search_dir = workspace_root / "grid_search_results" # Protein configurations: base map paths, structure 
selections, and resolutions protein_inputs_dir = args.grid_search_inputs_path or workspace_root @@ -52,13 +53,15 @@ def main(args: argparse.Namespace): # Test base map path resolution logger.debug("Testing base map path resolution:") for _, config in protein_configs.items(): - for _occ in OCCUPANCY_LEVELS: # TODO make configurable + for _occ in args.occupancies: _path = config.get_base_map_path_for_occupancy(_occ) # will warn if not found if _path: logger.debug(f" {config.protein} occ={_occ}: {_path}") # Scan for experiments (look for refined.cif files) - all_experiments = scan_grid_search_results(grid_search_dir) + all_experiments = scan_grid_search_results( + grid_search_dir, target_filename=args.target_filename + ) logger.info(f"Found {len(all_experiments)} experiments with refined.cif files") if all_experiments: @@ -67,14 +70,12 @@ def main(args: argparse.Namespace): logger.info("Pre-loading reference structures for each protein for coordinate extraction") ref_coords = {} for protein_key, protein_config in protein_configs.items(): - # NOTE THAT THIS will be by default _two_ structures, one computed for occupancy 0 and - # one for occupancy 1 of altloc A. Historically this is because these coordinates are - # used here only as a mask for map comparisons. - # TODO: change that method to return the coordinates for occupancy 0 and 1 separately, - # and then we can merge them here. + # NOTE THAT THIS will be by default include all altlocs, as we use them to create a mask + # for where to judge the maps' correlation. protein_ref_coords = get_reference_structure_coords(protein_config, protein_key) if protein_ref_coords is not None: - ref_coords[protein_key] = protein_ref_coords + for selection in protein_ref_coords.keys(): + ref_coords[(protein_key, selection)] = protein_ref_coords[selection] # Calculate RSCC for all experiments # (BIG) TODO: implement a sliding-window version (global can be achieved with diff't selections. 
@@ -87,91 +88,110 @@ def main(args: argparse.Namespace): logger.info(f"Using device: {_device}") results = [] - base_map_cache: dict[tuple[str, float], tuple[XMap, XMap]] = {} + base_map_cache: dict[tuple[str, float, str], tuple[XMap, XMap]] = {} # TODO parallelize this loop? It uses GPU, so be careful. for _i, _exp in enumerate(all_experiments): - if _exp.protein not in protein_configs: + if _exp.protein in protein_configs: + protein = _exp.protein + elif _exp.protein.upper() in protein_configs: + protein = _exp.protein.upper() + else: logger.warning(f"Skipping protein with no configuration: {_exp.protein}") continue - protein_config = protein_configs[_exp.protein] + protein_config = protein_configs[protein] + for selection in protein_config.selection: + # Check if we have reference coordinates for region extraction + if (protein, selection) not in ref_coords: + logger.warning( + f"Skipping {_exp.protein_dir_name}/{selection}: no reference structure " + f"available for {_exp.protein}, this may be due to a selection with zero atoms " + f"or NaN/Inf coordinates. Check logs above." + ) + continue + + _selection_coords = ref_coords[(protein, selection)] + _base_map_path = protein_config.get_base_map_path_for_occupancy(_exp.occ_a) + if _base_map_path is None: + logger.warning( + f"Skipping {_exp.protein_dir_name}: base map for selection {selection} and " + f"occupancy {_exp.occ_a} not found" + ) + continue + + try: + # TODO: this needs to be better unified with what's in generate_synthetic_density + # + # Load base map for canonical unit cell, + # don't overwrite the base map with selection map--we'll use the full map later too. + if (protein, _exp.occ_a, selection) not in base_map_cache: + _base_xmap = protein_config.load_map(_base_map_path) + if _base_xmap is None: + raise ValueError(f"Failed to load base map from {_base_map_path}") + + # Extract the region around altloc residues from the base map, using the + # union of boxes around each atom. 
_extracted_base is no longer an XMap + _, _extracted_base = _base_xmap.extract_tight( + _selection_coords, padding=DEFAULT_SELECTION_PADDING + ) + logger.info( + f"Caching base and subselected maps for {protein} " + f"occ_a={_exp.occ_a} selection={selection}" + ) + base_map_cache[(protein, _exp.occ_a, selection)] = (_base_xmap, _extracted_base) + else: + _base_xmap, _extracted_base = base_map_cache[(protein, _exp.occ_a, selection)] - # Check if we have reference coordinates for region extraction - if _exp.protein not in ref_coords: - logger.warning( - f"Skipping {_exp.protein_dir_name}: no reference structure available " - f"for {_exp.protein}, this may be due to a selection with zero atoms " - f"or NaN/Inf coordinates. Check logs above." - ) - continue + # Validate extraction + if _extracted_base is None or _extracted_base.shape[0] == 0: + raise ValueError(f"Extracted base map from {_base_map_path} is empty") - _selection_coords = ref_coords[_exp.protein] - _base_map_path = protein_config.get_base_map_path_for_occupancy(_exp.occ_a) - if _base_map_path is None: - logger.warning( - f"Skipping {_exp.protein_dir_name}: base map for occupancy {_exp.occ_a} not found" - ) - continue + # Load refined structure + _structure = parse(_exp.refined_cif_path, ccd_mirror_path=None) - try: - # Load base map for canonical unit cell, - # don't overwrite the base map with selection map as we'll use the full map later too. - if (_exp.protein, _exp.occ_a) not in base_map_cache: - _base_xmap = protein_config.load_map(_base_map_path) - if _base_xmap is None: - raise ValueError(f"Failed to load base map from {_base_map_path}") - - # Extract the region around altloc residues from the base map, using the - # union of boxes around each atom. 
_extracted_base is no longer an XMap - _, _extracted_base = _base_xmap.extract_tight( - _selection_coords, padding=DEFAULT_SELECTION_PADDING - ) - logger.info( - f"Caching base and subselected maps for {_exp.protein} occ_a={_exp.occ_a}" - ) - base_map_cache[(_exp.protein, _exp.occ_a)] = (_base_xmap, _extracted_base) - else: - _base_xmap, _extracted_base = base_map_cache[(_exp.protein, _exp.occ_a)] + # Compute density from refined structure + atom_array = get_asym_unit_from_structure(_structure) + if not hasattr(atom_array, "coord") or atom_array.coord is None: + raise AttributeError("AtomArray | AtomArrayStack is missing coordinates") - # Validate extraction - if _extracted_base is None or _extracted_base.shape[0] == 0: - raise ValueError(f"Extracted base map from {_base_map_path} is empty") + if not hasattr(atom_array, "b_factor"): + logger.warning( + f"No b-factor array found in {_exp.refined_cif_path}, setting to 20." + ) + atom_array.set_annotation("b_factor", np.full(atom_array.coord.shape[-2], 20.0)) - # Load refined structure - _structure = parse(_exp.refined_cif_path, ccd_mirror_path=None) + _computed_density, _ = compute_density_from_atomarray( + atom_array, xmap=_base_xmap, em_mode=False, device=_device + ) - # Compute density from refined structure - atom_array = get_asym_unit_from_structure(_structure) - _computed_density, _ = compute_density_from_atomarray( - atom_array, xmap=_base_xmap, em_mode=False, device=_device - ) + # Create an XMap from the computed density by copying the base xmap + # and replacing its array with the computed density + _computed_xmap = copy.deepcopy(_base_xmap) + _computed_xmap.array = _computed_density.cpu().numpy().squeeze() + _, _extracted_computed = _computed_xmap.extract_tight( + _selection_coords, padding=DEFAULT_SELECTION_PADDING + ) - # Create an XMap from the computed density by copying the base xmap - # and replacing its array with the computed density - _computed_xmap = copy.deepcopy(_base_xmap) - 
_computed_xmap.array = _computed_density.cpu().numpy().squeeze() - _, _extracted_computed = _computed_xmap.extract_tight( - _selection_coords, padding=DEFAULT_SELECTION_PADDING - ) + # Validate extraction + if _extracted_computed is None or _extracted_computed.shape[0] == 0: + raise ValueError("Extracted computed map is empty") - # Validate extraction - if _extracted_computed is None or _extracted_computed.shape[0] == 0: - raise ValueError("Extracted computed map is empty") + # Calculate RSCC on extracted regions + # TODO: don't alter the input object _exp, just get a copy of it as a dict. + _exp.rscc = rscc(_extracted_base, _extracted_computed) + _exp.base_map_path = _base_map_path - # Calculate RSCC on extracted regions - _exp.rscc = rscc(_extracted_base, _extracted_computed) - _exp.base_map_path = _base_map_path + except Exception as _e: + logger.error(f"ERROR processing {_exp.exp_dir}: {_e}") + logger.error(f" Traceback: {traceback.format_exc()}") + _exp.error = _e + _exp.rscc = np.nan # this is the default, but better to be explicit. + _exp.base_map_path = _base_map_path - except Exception as _e: - logger.error(f"ERROR processing {_exp.exp_dir}: {_e}") - logger.error(f" Traceback: {traceback.format_exc()}") - _exp.error = _e - _exp.rscc = np.nan # this is the default, but better to be explicit. 
- _exp.base_map_path = _base_map_path + exp_dict_copy = _exp.__dict__.copy() + exp_dict_copy["selection"] = selection + results.append(exp_dict_copy) - exp_dict_copy = _exp.__dict__.copy() - exp_dict_copy["selection"] = protein_config.selection - results.append(exp_dict_copy) if (_i + 1) % 10 == 0 or _i == 0: logger.debug( f" [{_i + 1}/{len(all_experiments)}] {_exp.protein_dir_name} / " @@ -206,94 +226,6 @@ def main(args: argparse.Namespace): ) logger.info(summary) - # Calculate correlation between base maps and pure conformer maps - logger.info("Calculating correlations between base maps and pure conformer maps...") - logger.info("This shows how well single conformers explain occupancy-mixed data") - - base_pure_correlations = [] - - for protein_key, protein_config in protein_configs.items(): - if protein_key not in ref_coords: - print(f"Skipping {protein_key}: no reference coordinates available") - continue - - # We re-use the selection coordinates from the reference structure computed - # at 0.5 occupancy above. 
- _selection_coords = ref_coords[protein_key] - - logger.info(f"\nProcessing {protein_key} single conformer explanatory power:") - map_path_1occA = protein_config.get_base_map_path_for_occupancy(1.0) - map_path_1occB = protein_config.get_base_map_path_for_occupancy(0.0) - if map_path_1occA is None or map_path_1occB is None: - logger.warning(f"Skipping {protein_key}: pure conformer maps not found") - continue - try: - # Load pure conformer maps--returns canonical unit cell by default, - # extract selection with padding 0.0 - _extracted_pure_A = protein_config.load_map( - map_path_1occA, selection_coords=_selection_coords - ) - _extracted_pure_B = protein_config.load_map( - map_path_1occB, selection_coords=_selection_coords - ) - - logger.info( - f" Pure A reference: {map_path_1occA}\n Pure B reference: {map_path_1occB}" - ) - - # Calculate correlations for each occupancy - for _occ_a in OCCUPANCY_LEVELS: # TODO make configurable - try: - _base_map_path = protein_config.get_base_map_path_for_occupancy(_occ_a) - if _base_map_path is None: # map file not found, will warn. - continue - - logger.info(f" Processing occ_A={_occ_a}: {_base_map_path.name}") - - # Load the base map for this occupancy, and do the selection - # default padding is zero. 
- _extracted_base = protein_config.load_map( - _base_map_path, selection_coords=_selection_coords - ) - - if ( - _extracted_base is None - or _extracted_pure_A is None - or _extracted_pure_B is None - ): - raise ValueError("One of the extracted maps is empty") - - # Calculate correlations - _corr_base_vs_pureA = rscc(_extracted_base.array, _extracted_pure_A.array) - _corr_base_vs_pureB = rscc(_extracted_base.array, _extracted_pure_B.array) - - base_pure_correlations.append( - { - "protein": protein_key, - "occ_a": _occ_a, - "base_vs_1occA": _corr_base_vs_pureA, - "base_vs_1occB": _corr_base_vs_pureB, - } - ) - - logger.info(f" Base map vs pure A: {_corr_base_vs_pureA:.4f}") - logger.info(f" Base map vs pure B: {_corr_base_vs_pureB:.4f}") - - except Exception as _e: - logger.error(f" Error processing occ_A={_occ_a} for {protein_key}: {_e}") - logger.error(f" Traceback: {traceback.format_exc()}") - - except Exception as _e: - logger.error(f"Error calculating correlations for {protein_key}: {_e}") - logger.error(f" Traceback: {traceback.format_exc()}") - - df_base_vs_pure = pd.DataFrame(base_pure_correlations) - df.to_csv(grid_search_dir / "rscc_results_for_pure_conformer_maps.csv", index=False) - logger.info( - f"\nCalculated single conformer explanatory power for " - f"{len(df_base_vs_pure)} occupancy points" - ) - if __name__ == "__main__": args = parse_args("Evaluate RSCC on grid search results.") diff --git a/src/sampleworks/eval/grid_search_eval_utils.py b/src/sampleworks/eval/grid_search_eval_utils.py index cec2fd5a..3cb26e7a 100644 --- a/src/sampleworks/eval/grid_search_eval_utils.py +++ b/src/sampleworks/eval/grid_search_eval_utils.py @@ -183,7 +183,12 @@ def parse_args(description: str | None = None): "--occupancies", nargs="+", type=float, - help="Occupancies to evaluate", + help=f"Occupancies to evaluate, defaults to {OCCUPANCY_LEVELS}", default=OCCUPANCY_LEVELS, ) + parser.add_argument( + "--target-filename", + default="refined.cif", + help="Target 
filename for the CIF files to process, defaults to 'refined.cif'", + ) return parser.parse_args() From 50ef437e44273ab9fb231304d8be7b1be59773ce Mon Sep 17 00:00:00 2001 From: "Marcus D. Collins" Date: Fri, 27 Feb 2026 13:43:09 -0800 Subject: [PATCH 3/6] Prep work for parallelization of LDDT eval script --- scripts/eval/lddt_evaluation_script.py | 145 +++++++++++++++++++------ 1 file changed, 111 insertions(+), 34 deletions(-) diff --git a/scripts/eval/lddt_evaluation_script.py b/scripts/eval/lddt_evaluation_script.py index 35ca4509..1940af42 100644 --- a/scripts/eval/lddt_evaluation_script.py +++ b/scripts/eval/lddt_evaluation_script.py @@ -11,7 +11,7 @@ from atomworks.io.utils.io_utils import load_any from biotite.structure import AtomArray, AtomArrayStack from loguru import logger -from sampleworks.eval.eval_dataclasses import ProteinConfig +from sampleworks.eval.eval_dataclasses import ProteinConfig, Experiment from sampleworks.eval.grid_search_eval_utils import parse_args, scan_grid_search_results from sampleworks.eval.structure_utils import get_reference_atomarraystack from sampleworks.metrics.lddt import AllAtomLDDT @@ -134,7 +134,9 @@ def nn_lddt_clustering( # Second, compute the cross-LDDT matrix between all structures in the predicted stack and # all structures in the reference stack - cross_lddt_matrix = compute_cross_lddts(ref_atom_array_stack, pred_atom_array_stack, selection) + cross_lddt_matrix = compute_cross_lddts( + ref_atom_array_stack, pred_atom_array_stack, selection + ) # Assign each predicted structure to the closest reference structure based on LDDT score (i.e., # the reference structure with the highest LDDT score is assigned to the predicted structure) @@ -233,7 +235,11 @@ def main(args: argparse.Namespace): reference_protein_stack, _, _ = map_altlocs_to_stack( reference_proteins, selection=translate_selection(sel), return_full_array=True ) - reference_atom_arrays[(protein_key, occ, sel)] = reference_protein_stack + # hierarchical 
dictionary cache makes it lighter weight to parallelize. + if (protein_key, occ) not in reference_atom_arrays: + reference_atom_arrays[(protein_key, occ)] = {} + + reference_atom_arrays[(protein_key, occ)][sel] = reference_protein_stack except Exception as e: logger.error( f"Error loading ref structure for {protein_key} and occupancy {occ}: {e}" @@ -243,58 +249,129 @@ def main(args: argparse.Namespace): all_results = [] # TODO parallelize this loop? It will require replicating `reference_atom_arrays` # https://github.com/diff-use/sampleworks/issues/98 - for _i, _exp in enumerate(all_experiments): + + # Do the quick pass through all the "rows" of our output table to filter in those we can run. + filtered_experiments = [] + for _exp in all_experiments: + if _exp.protein in protein_configs: protein = _exp.protein elif _exp.protein.upper() in protein_configs: protein = _exp.protein.upper() + elif _exp.protein.lower() in protein_configs: + protein = _exp.protein.lower() else: + # These we just skip over--we assume that the user has told us via the config file + # what results they are interested in. logger.warning(f"Skipping protein with no configuration: {_exp.protein}") continue protein_config = protein_configs[protein] + if protein_config.protein != protein: + raise ValueError( + f"Protein name mismatch: expected {protein_config.protein}, got {protein}, make " + f"sure you loaded your protein configs with ProteinConfig.from_csv()." + ) + + null_results = [] + if (protein, _exp.occ_a) not in reference_atom_arrays: + logger.warning( + f"Skipping {_exp.protein_dir_name}: no reference atom array stack available " + f"for {_exp.protein}, occupancy {_exp.occ_a}." + ) + # record empty results for all selections, indicating they could not be computed.
+ for _sel in protein_config.selection: + exp_copy = _exp.__dict__.copy() + exp_copy["selection"] = _sel + null_results.append(exp_copy) + continue - for selection in protein_config.selection: - result = _exp.__dict__.copy() - result["selection"] = selection - atom_array_key = (protein, _exp.occ_a, selection) - if atom_array_key not in reference_atom_arrays: + protein_reference_atom_arrays = reference_atom_arrays[(protein, _exp.occ_a)] + for _sel in protein_config.selection: + if _sel not in protein_reference_atom_arrays: logger.warning( f"Skipping {_exp.protein_dir_name}: no reference atom array stack available " - f"for {_exp.protein}, occupancy {_exp.occ_a} and selection '{selection}'." + f"for {_exp.protein}, occupancy {_exp.occ_a} and selection '{_sel}'." ) + exp_copy = _exp.__dict__.copy() + exp_copy["selection"] = _sel + null_results.append(exp_copy) continue - try: - reference_atom_array_stack = reference_atom_arrays[atom_array_key] - # generated structures shouldn't have altlocs, don't need altloc="all". - predicted_atom_array_stack = load_any(_exp.refined_cif_path) - clustering_results = nn_lddt_clustering( - reference_atom_array_stack, - ensure_atom_array_stack(predicted_atom_array_stack), - translate_selection(selection), - ) - - result.update( - { - k: clustering_results[k] - for k in ("occupancies", "avg_silhouette", "avg_silhouette_to_ref") - } - ) - except Exception as e: - logger.error(f"Error processing experiment {_exp.exp_dir}: {e}") - logger.error(f" Traceback: {traceback.format_exc()}") - result["error"] = str(e) - result["avg_silhouette"] = np.nan - result["avg_silhouette_to_ref"] = np.nan - result["occupancies"] = [] + px_seln_refernce_atom_array = protein_reference_atom_arrays[_sel] + filtered_experiments.append( + (_exp, protein_config, px_seln_refernce_atom_array, _sel) + ) - all_results.append(result) + # now we can more easily parallelize this loop. + logger.debug("Starting LDDT evaluation loop. 
This may take a while...") + for _i, (_exp, protein_config, px_seln_refernce_atom_array, selection) in enumerate(filtered_experiments): + result = process_exp_with_selection(_exp, protein_config, px_seln_refernce_atom_array, selection) + all_results.append(result) df = pd.DataFrame(all_results) df.to_csv(grid_search_dir / "lddt_results.csv", index=False) +def process_exp_with_selection( + exp: Experiment, + protein_config: ProteinConfig, + px_seln_refernce_atom_array: AtomArrayStack, + selection_string: str +) -> dict[str, str | float | list[float]]: + """ + + Parameters + ---------- + exp: Experiment, a description of the structure generation experiment + protein_config: ProteinConfig, specifying the locations of reference structures and maps + px_seln_refernce_atom_array: AtomArrayStack, + the atom array stack for the reference structure, which in principle could be fetched + using the protein_config, but for efficiency we load once previously and pass in here, + since this method will run many times in parallel using the same structure + selection_string: str, the selection string for the evaluation + + Returns + ------- + A dictionary of results of LDDT-based clustering that can be collated in a dataframe. + In addition to the data in the `exp` object, this dictionary contains: + - occupancies: list[float], the occupancies of the selected atoms, computed as the + fraction of structures in the experiment that are closest to one or the other + altloc of the reference structure. . + - avg_silhouette: float, the average silhouette score for the LDDT-based clustering + - avg_silhouette_to_ref: float, + a sort of silhouette score, measuring each structure's relative "closeness" to the + assigned reference altloc. + """ + logger.debug(f"Evaluating selection {selection_string} for protein {protein_config}") + result = exp.__dict__.copy() + result["selection"] = selection_string + + try: + # generated structures shouldn't have altlocs, don't need altloc="all". 
+ predicted_atom_array_stack = load_any(exp.refined_cif_path) + clustering_results = nn_lddt_clustering( + px_seln_refernce_atom_array, + ensure_atom_array_stack(predicted_atom_array_stack), + translate_selection(selection_string), + ) + + lddt_result_keys = ("occupancies", "avg_silhouette", "avg_silhouette_to_ref") + result.update({k: clustering_results[k] for k in lddt_result_keys}) + + logger.info(f"Successfully processed {exp.protein_dir_name} w/ selection {selection_string}") + + except Exception as e: + logger.error(f"Error processing experiment {exp.exp_dir}: {e}") + logger.error(f" Traceback: {traceback.format_exc()}") + result["error"] = str(e) + result["avg_silhouette"] = np.nan + result["avg_silhouette_to_ref"] = np.nan + result["occupancies"] = [] + + return result + + if __name__ == "__main__": args = parse_args("Evaluate LDDT on grid search results.") main(args) From ba249f1397103d7e18c2e55664b7016bbbf37b33 Mon Sep 17 00:00:00 2001 From: "Marcus D. Collins" Date: Fri, 27 Feb 2026 14:04:25 -0800 Subject: [PATCH 4/6] Parallelizing LDDT calculations with joblib. Parallelizing LDDT script with joblib. 
Closes https://github.com/diff-use/sampleworks/issues/98 --- scripts/eval/lddt_evaluation_script.py | 17 ++++++++--------- src/sampleworks/eval/grid_search_eval_utils.py | 6 ++++++ 2 files changed, 14 insertions(+), 9 deletions(-) diff --git a/scripts/eval/lddt_evaluation_script.py b/scripts/eval/lddt_evaluation_script.py index 1940af42..b54d255c 100644 --- a/scripts/eval/lddt_evaluation_script.py +++ b/scripts/eval/lddt_evaluation_script.py @@ -7,6 +7,7 @@ import numpy as np import pandas as pd +from joblib import Parallel, delayed from atomworks.io.transforms.atom_array import ensure_atom_array_stack from atomworks.io.utils.io_utils import load_any from biotite.structure import AtomArray, AtomArrayStack @@ -246,12 +247,9 @@ def main(args: argparse.Namespace): ) logger.error(f" Traceback: {traceback.format_exc()}") - all_results = [] - # TODO parallelize this loop? It will require replicating `reference_atom_arrays` - # https://github.com/diff-use/sampleworks/issues/98 - # Do the quick pass through all the "rows" of our output table to filter in those we can run. filtered_experiments = [] + null_results = [] for _exp in all_experiments: if _exp.protein in protein_configs: @@ -273,7 +271,7 @@ def main(args: argparse.Namespace): f"sure you loaded your protein configs with ProteinConfig.from_csv()." ) - null_results = [] + if (protein, _exp.occ_a) not in reference_atom_arrays: logger.warning( f"Skipping {_exp.protein_dir_name}: no reference atom array stack available " @@ -305,11 +303,12 @@ def main(args: argparse.Namespace): # now we can more easily parallelize this loop. logger.debug("Starting LDDT evaluation loop. 
This may take a while...") - for _i, (_exp, protein_config, px_seln_refernce_atom_array, selection) in enumerate(filtered_experiments): - result = process_exp_with_selection(_exp, protein_config, px_seln_refernce_atom_array, selection) - all_results.append(result) + all_results = Parallel(n_jobs=args.n_jobs)( + delayed(process_exp_with_selection)(_exp, protein_config, px_seln_refernce_atom_array, selection) + for _exp, protein_config, px_seln_refernce_atom_array, selection in filtered_experiments + ) - df = pd.DataFrame(all_results) + df = pd.DataFrame(null_results + all_results) df.to_csv(grid_search_dir / "lddt_results.csv", index=False) diff --git a/src/sampleworks/eval/grid_search_eval_utils.py b/src/sampleworks/eval/grid_search_eval_utils.py index 3cb26e7a..f9434316 100644 --- a/src/sampleworks/eval/grid_search_eval_utils.py +++ b/src/sampleworks/eval/grid_search_eval_utils.py @@ -191,4 +191,10 @@ def parse_args(description: str | None = None): default="refined.cif", help="Target filename for the CIF files to process, defaults to 'refined.cif'", ) + parser.add_argument( + "--n-jobs", + type=int, + help="Number of parallel jobs to run. -1 uses all CPUs.", + default=16, + ) return parser.parse_args() From 51ee774368e3b9236627d2f8e87049e3bd5cc06e Mon Sep 17 00:00:00 2001 From: "Marcus D. 
Collins" Date: Sun, 1 Mar 2026 03:04:01 +0000 Subject: [PATCH 5/6] ruff/ty fixes --- scripts/eval/lddt_evaluation_script.py | 30 +++++++++---------- .../core/rewards/real_space_density.py | 3 +- src/sampleworks/eval/structure_utils.py | 10 ++++--- src/sampleworks/utils/atom_array_utils.py | 17 ++++++----- 4 files changed, 31 insertions(+), 29 deletions(-) diff --git a/scripts/eval/lddt_evaluation_script.py b/scripts/eval/lddt_evaluation_script.py index b54d255c..d0deae9e 100644 --- a/scripts/eval/lddt_evaluation_script.py +++ b/scripts/eval/lddt_evaluation_script.py @@ -7,12 +7,12 @@ import numpy as np import pandas as pd -from joblib import Parallel, delayed from atomworks.io.transforms.atom_array import ensure_atom_array_stack from atomworks.io.utils.io_utils import load_any from biotite.structure import AtomArray, AtomArrayStack +from joblib import delayed, Parallel from loguru import logger -from sampleworks.eval.eval_dataclasses import ProteinConfig, Experiment +from sampleworks.eval.eval_dataclasses import Experiment, ProteinConfig from sampleworks.eval.grid_search_eval_utils import parse_args, scan_grid_search_results from sampleworks.eval.structure_utils import get_reference_atomarraystack from sampleworks.metrics.lddt import AllAtomLDDT @@ -135,9 +135,7 @@ def nn_lddt_clustering( # Second, compute the cross-LDDT matrix between all structures in the predicted stack and # all structures in the reference stack - cross_lddt_matrix = compute_cross_lddts( - ref_atom_array_stack, pred_atom_array_stack, selection - ) + cross_lddt_matrix = compute_cross_lddts(ref_atom_array_stack, pred_atom_array_stack, selection) # Assign each predicted structure to the closest reference structure based on LDDT score (i.e., # the reference structure with the highest LDDT score is assigned to the predicted structure) @@ -251,7 +249,6 @@ def main(args: argparse.Namespace): filtered_experiments = [] null_results = [] for _exp in all_experiments: - if _exp.protein in protein_configs: 
protein = _exp.protein elif _exp.protein.upper() in protein_configs: @@ -271,7 +268,6 @@ def main(args: argparse.Namespace): f"sure you loaded your protein configs with ProteinConfig.from_csv()." ) - if (protein, _exp.occ_a) not in reference_atom_arrays: logger.warning( f"Skipping {_exp.protein_dir_name}: no reference atom array stack available " @@ -297,14 +293,14 @@ def main(args: argparse.Namespace): continue px_seln_refernce_atom_array = protein_reference_atom_arrays[_sel] - filtered_experiments.append( - (_exp, protein_config, px_seln_refernce_atom_array, _sel) - ) + filtered_experiments.append((_exp, protein_config, px_seln_refernce_atom_array, _sel)) # now we can more easily parallelize this loop. logger.debug("Starting LDDT evaluation loop. This may take a while...") all_results = Parallel(n_jobs=args.n_jobs)( - delayed(process_exp_with_selection)(_exp, protein_config, px_seln_refernce_atom_array, selection) + delayed(process_exp_with_selection)( + _exp, protein_config, px_seln_refernce_atom_array, selection + ) for _exp, protein_config, px_seln_refernce_atom_array, selection in filtered_experiments ) @@ -313,10 +309,10 @@ def main(args: argparse.Namespace): def process_exp_with_selection( - exp: Experiment, - protein_config: ProteinConfig, - px_seln_refernce_atom_array: AtomArrayStack, - selection_string: str + exp: Experiment, + protein_config: ProteinConfig, + px_seln_refernce_atom_array: AtomArrayStack, + selection_string: str, ) -> dict[str, str | float | list[float]]: """ @@ -358,7 +354,9 @@ def process_exp_with_selection( lddt_result_keys = ("occupancies", "avg_silhouette", "avg_silhouette_to_ref") result.update({k: clustering_results[k] for k in lddt_result_keys}) - logger.info(f"Successfully processed {exp.protein_dir_name} w/ selection {selection_string}") + logger.info( + f"Successfully processed {exp.protein_dir_name} w/ selection {selection_string}" + ) except Exception as e: logger.error(f"Error processing experiment {exp.exp_dir}: {e}") diff 
--git a/src/sampleworks/core/rewards/real_space_density.py b/src/sampleworks/core/rewards/real_space_density.py index 762fa996..c5351964 100644 --- a/src/sampleworks/core/rewards/real_space_density.py +++ b/src/sampleworks/core/rewards/real_space_density.py @@ -2,7 +2,6 @@ import numpy as np import torch -from atomworks.io.transforms.atom_array import ensure_atom_array_stack from biotite.structure import AtomArray, AtomArrayStack from jaxtyping import ArrayLike, Float, Int from loguru import logger @@ -44,7 +43,7 @@ def setup_scattering_params( containing scattering coefficients for each element type """ elements = atom_array.element - unique_elements = sorted(set(normalize_element(e) for e in elements)) # ty: ignore[not-iterable] + unique_elements = sorted(set(normalize_element(e) for e in elements)) atomic_num_dict = {elem: ELEMENT_TO_ATOMIC_NUM[elem] for elem in unique_elements} structure_factors = ELECTRON_SCATTERING_FACTORS if em_mode else ATOM_STRUCTURE_FACTORS diff --git a/src/sampleworks/eval/structure_utils.py b/src/sampleworks/eval/structure_utils.py index 3e74c807..bd34f5eb 100644 --- a/src/sampleworks/eval/structure_utils.py +++ b/src/sampleworks/eval/structure_utils.py @@ -2,7 +2,7 @@ import traceback from dataclasses import dataclass, replace from pathlib import Path -from typing import cast, Any +from typing import Any, cast import numpy as np import torch @@ -314,10 +314,12 @@ def apply_selection(atom_array: AtomArray, selection: str | None) -> AtomArray: def get_mask_from_old_selection_string( - atom_array: AtomArray, selection: str + atom_array: AtomArray | AtomArrayStack, selection: str ) -> np.ndarray[tuple[int], np.dtype[Any]]: - DeprecationWarning(f"Using old-style selection strings like {selection} is deprecated." - f" Use atomworks/pandas style selection strings instead.") + DeprecationWarning( + f"Using old-style selection strings like {selection} is deprecated." + f" Use atomworks/pandas style selection strings instead." 
+ ) chain_id, resi_start, resi_end = parse_selection_string(selection) # use the length of any of the required non-coord attributes to get the mask shape mask = np.ones(len(atom_array.res_id), dtype=bool) diff --git a/src/sampleworks/utils/atom_array_utils.py b/src/sampleworks/utils/atom_array_utils.py index 9f18d881..3e38bee1 100644 --- a/src/sampleworks/utils/atom_array_utils.py +++ b/src/sampleworks/utils/atom_array_utils.py @@ -193,7 +193,7 @@ def find_all_altloc_ids( Find all unique alternate location indicator (altloc) IDs in an AtomArray or AtomArrayStack. """ if hasattr(atom_array, altloc_label): - altloc_ids = np.unique(getattr(atom_array, altloc_label)) # ty: ignore[invalid-argument-type] + altloc_ids = np.unique(getattr(atom_array, altloc_label)) else: raise AttributeError( "atom_array must have altloc annotation, defaults to 'altloc_id'; " @@ -267,7 +267,7 @@ def map_altlocs_to_stack( if isinstance(atom_array, AtomArrayStack): if len(atom_array) > 1: raise ValueError("Cannot map altlocs with multiple structures each containing altlocs") - atom_array = atom_array[0] # ty: ignore[invalid-assignment] + atom_array = atom_array[0] if not hasattr(atom_array, "altloc_id") or atom_array.altloc_id is None: raise ValueError("The passed AtomArray | AtomArrayStack must have attribute 'altloc_id'") @@ -278,7 +278,6 @@ def map_altlocs_to_stack( "atomworks.io.utils.io_utils.load_any or a similar method that provides .mask, .query" ) - altloc_ids = sorted(list(find_all_altloc_ids(atom_array))) if selection is not None and return_full_array: # in this case, we select atoms with no altlocs, and atoms in the selection @@ -302,8 +301,8 @@ def map_altlocs_to_stack( # it's critical that we've updated the altloc id list above, or this will return only # residues with no altlocs in case one or more altloc ids is/are missing from the selection. 
atom_arrays = filter_to_common_atoms(*altloc_list) - altloc_ids = np.vstack([r.altloc_id for r in atom_arrays]) # ty: ignore - occupancies = np.vstack([r.occupancy for r in atom_arrays]) # ty: ignore + altloc_ids = np.vstack([r.altloc_id for r in atom_arrays]) + occupancies = np.vstack([r.occupancy for r in atom_arrays]) # remove those annotations or we cannot stack arrays. for array in atom_arrays: @@ -348,8 +347,12 @@ def select_altloc( if return_full_array: mask = np.isin( - atom_array.altloc_id, # ty: ignore[invalid-argument-type] - list({altloc_id,}.union(BLANK_ALTLOC_IDS)), + atom_array.altloc_id, + list( + { + altloc_id, + }.union(BLANK_ALTLOC_IDS) + ), ) else: mask = atom_array.altloc_id == altloc_id From 2945059bdb79fd13d0c72df9283e19bb854ef02e Mon Sep 17 00:00:00 2001 From: "Marcus D. Collins" Date: Mon, 2 Mar 2026 13:56:05 -0800 Subject: [PATCH 6/6] Addressing coderabbit and k-chrispen's f/b on PR 116 (new structure/atom selection capabilities for eval scripts) --- scripts/eval/rscc_grid_search_script.py | 17 +++++++++-------- src/sampleworks/eval/eval_dataclasses.py | 11 ++++++++++- src/sampleworks/eval/grid_search_eval_utils.py | 2 ++ src/sampleworks/eval/structure_utils.py | 7 ++++--- src/sampleworks/utils/atom_array_utils.py | 11 ++++++++++- src/sampleworks/utils/density_utils.py | 2 +- 6 files changed, 36 insertions(+), 14 deletions(-) diff --git a/scripts/eval/rscc_grid_search_script.py b/scripts/eval/rscc_grid_search_script.py index 8d9152a5..d9c77bfd 100644 --- a/scripts/eval/rscc_grid_search_script.py +++ b/scripts/eval/rscc_grid_search_script.py @@ -119,6 +119,9 @@ def main(args: argparse.Namespace): ) continue + exp_result = _exp.__dict__.copy() + exp_result["selection"] = selection + exp_result["error"] = None try: # TODO: this needs to be better unified with what's in generate_synthetic_density # @@ -178,19 +181,17 @@ def main(args: argparse.Namespace): # Calculate RSCC on extracted regions # TODO: don't alter the input object _exp, just get 
a copy of it as a dict. - _exp.rscc = rscc(_extracted_base, _extracted_computed) - _exp.base_map_path = _base_map_path + exp_result["rscc"] = rscc(_extracted_base, _extracted_computed) + exp_result["base_map_path"] = _base_map_path except Exception as _e: logger.error(f"ERROR processing {_exp.exp_dir}: {_e}") logger.error(f" Traceback: {traceback.format_exc()}") - _exp.error = _e - _exp.rscc = np.nan # this is the default, but better to be explicit. - _exp.base_map_path = _base_map_path + exp_result["error"] = _e + exp_result["rscc"] = np.nan # this is the default, but better to be explicit. + exp_result["base_map_path"] = _base_map_path - exp_dict_copy = _exp.__dict__.copy() - exp_dict_copy["selection"] = selection - results.append(exp_dict_copy) + results.append(exp_result) if (_i + 1) % 10 == 0 or _i == 0: logger.debug( diff --git a/src/sampleworks/eval/eval_dataclasses.py b/src/sampleworks/eval/eval_dataclasses.py index 7dc35c4e..e2db5b08 100644 --- a/src/sampleworks/eval/eval_dataclasses.py +++ b/src/sampleworks/eval/eval_dataclasses.py @@ -9,6 +9,8 @@ @dataclass +# TODO rename to make consistent w/ hub.diffuse.science +# https://github.com/diff-use/sampleworks/issues/122 class Experiment: protein: str occ_a: float @@ -50,6 +52,13 @@ class ProteinConfig: def __post_init__(self): # TODO validate structure patterns? Anything else we should check to avoid later errors?
self.base_map_dir = Path(self.base_map_dir) # just in case someone passes a string + self.selection = [s.strip() for s in self.selection if self.is_selection_valid(s)] + if not self.selection: + raise ValueError(f"No valid selection provided for protein {self.protein}") + + def is_selection_valid(self, selection: str) -> bool: + return selection is not None and selection.strip() != "" + def get_base_map_path_for_occupancy(self, occupancy_a: float) -> Path | None: occ_str = occupancy_to_str(occupancy_a, use_6b8x_format=self.protein == "6b8x") @@ -190,7 +199,7 @@ def from_csv(cls, workspace_root: Path, csv_path: Path) -> dict[str, "ProteinCon config = cls( protein=protein, base_map_dir=base_map_dir, - selection=row["selection"].strip().split(";"), + selection=row["selection"].split(";"), resolution=resolution, map_pattern=row["map_pattern"].strip(), structure_pattern=structure_pattern, diff --git a/src/sampleworks/eval/grid_search_eval_utils.py b/src/sampleworks/eval/grid_search_eval_utils.py index f9434316..89dae772 100644 --- a/src/sampleworks/eval/grid_search_eval_utils.py +++ b/src/sampleworks/eval/grid_search_eval_utils.py @@ -17,6 +17,7 @@ # TODO: this either (both) needs tests or (and) there needs to be a clearer "API" # for how the folder names are generated. +# https://github.com/diff-use/sampleworks/issues/121 def parse_experiment_dir(exp_dir: Path) -> dict[str, int | float | None]: """Parse experiment directory name to extract parameters. 
@@ -75,6 +76,7 @@ def scan_grid_search_results( ) return experiments + # FIXME https://github.com/diff-use/sampleworks/issues/121 # Check if we found a refined.cif file in the current directory refined_cif = current_directory / target_filename if current_depth == target_depth and refined_cif.exists(): diff --git a/src/sampleworks/eval/structure_utils.py b/src/sampleworks/eval/structure_utils.py index bd34f5eb..efbbfd22 100644 --- a/src/sampleworks/eval/structure_utils.py +++ b/src/sampleworks/eval/structure_utils.py @@ -16,6 +16,7 @@ from sampleworks.utils.atom_reconciler import AtomReconciler from sampleworks.utils.framework_utils import match_batch +ATOMWORKS_COMPARISON_OPS = ("==", ">", "<", "<=", ">=", " in ") try: from sampleworks.models.protenix.structure_processing import ( @@ -305,7 +306,7 @@ def apply_selection(atom_array: AtomArray, selection: str | None) -> AtomArray: if selection is None: return atom_array - if not any(x in selection for x in ("==", ">", "<", "<=", ">=", " in ")): + if not any(x in selection for x in ATOMWORKS_COMPARISON_OPS): mask = get_mask_from_old_selection_string(atom_array, selection) else: mask = atom_array.mask(selection) @@ -366,10 +367,10 @@ def extract_selection_coordinates( else: working_array = atom_array - if not any(x in selection for x in ("==", ">", "<", "<=", ">=")): + if not any(x in selection for x in ATOMWORKS_COMPARISON_OPS): mask = get_mask_from_old_selection_string(atom_array, selection) else: - mask = atom_array.mask(selection) + mask = working_array.mask(selection) selected_coords = cast(np.ndarray, working_array.coord)[mask] diff --git a/src/sampleworks/utils/atom_array_utils.py b/src/sampleworks/utils/atom_array_utils.py index 3e38bee1..a0ebab89 100644 --- a/src/sampleworks/utils/atom_array_utils.py +++ b/src/sampleworks/utils/atom_array_utils.py @@ -194,6 +194,11 @@ def find_all_altloc_ids( """ if hasattr(atom_array, altloc_label): altloc_ids = np.unique(getattr(atom_array, altloc_label)) + altloc_values = 
getattr(atom_array, altloc_label) + if altloc_values is None: + raise AttributeError( + f"atom_array.{altloc_label} exists but is None; cannot find altloc IDs" + ) else: raise AttributeError( "atom_array must have altloc annotation, defaults to 'altloc_id'; " @@ -254,7 +259,7 @@ def map_altlocs_to_stack( If the return_full_array is also True, this selection applies _only_ to the altlocs. return_full_array: bool If True, return the full AtomArrayStack with all atoms, even those with no altlocs. - If False, only return structures with altlocs. + If False, only return the specific atoms with altlocs. Returns: Tuple containing: @@ -293,6 +298,10 @@ def map_altlocs_to_stack( # The available altloc ids might have changed if one or more are missing from the selection. altloc_ids = sorted(list(find_all_altloc_ids(atom_array))) + if len(altloc_ids) == 0: + raise ValueError( + f"No altlocs found in selection '{selection}' for AtomArray with altloc_id" + ) altloc_list = [ select_altloc(atom_array, altloc_id, return_full_array=return_full_array) for altloc_id in altloc_ids diff --git a/src/sampleworks/utils/density_utils.py b/src/sampleworks/utils/density_utils.py index 4241b4e7..93063ace 100644 --- a/src/sampleworks/utils/density_utils.py +++ b/src/sampleworks/utils/density_utils.py @@ -152,7 +152,7 @@ def compute_density_from_atomarray( atom_array, device ) - # need to make sure these all have the same batch dimension. + # need to make sure these all have the same batch dimension, or the transformer will fail. elements = elements.expand(coords.shape[0], -1) b_factors = b_factors.expand(coords.shape[0], -1) occupancies = occupancies.expand(coords.shape[0], -1)