diff --git a/scripts/eval/lddt_evaluation_script.py b/scripts/eval/lddt_evaluation_script.py index 6f0288cc..d0deae9e 100644 --- a/scripts/eval/lddt_evaluation_script.py +++ b/scripts/eval/lddt_evaluation_script.py @@ -1,4 +1,5 @@ import argparse +import itertools import re import sys import traceback @@ -9,9 +10,9 @@ from atomworks.io.transforms.atom_array import ensure_atom_array_stack from atomworks.io.utils.io_utils import load_any from biotite.structure import AtomArray, AtomArrayStack +from joblib import delayed, Parallel from loguru import logger -from sampleworks.eval.constants import OCCUPANCY_LEVELS -from sampleworks.eval.eval_dataclasses import ProteinConfig +from sampleworks.eval.eval_dataclasses import Experiment, ProteinConfig from sampleworks.eval.grid_search_eval_utils import parse_args, scan_grid_search_results from sampleworks.eval.structure_utils import get_reference_atomarraystack from sampleworks.metrics.lddt import AllAtomLDDT @@ -175,7 +176,11 @@ def translate_selection(selection: str) -> str: # current selection strings are pymol like, and we want to convert to atomworks/pandas like # this should be a temporary measure only until we switch to atomworks style in the RSCC script - logger.warning( + if any(x in selection for x in ("==", ">", "<", "<=", ">=", " in ")): + # assume this is already atomworks/pandas style and ignore. + return selection + + DeprecationWarning( "DEPRECATED: translate_selection converts from some pymol-like selection strings to " "AtomWorks selection strings, but is not guaranteed to be correct for all cases." ) @@ -215,8 +220,7 @@ def main(args: argparse.Namespace): logger.info("Pre-loading reference structures for each protein for coordinate extraction") reference_atom_arrays = {} for protein_key, protein_config in protein_configs.items(): - # TODO: should the reference occupancy should be specified in the config? 
- for occ in OCCUPANCY_LEVELS: + for occ, sel in itertools.product(args.occupancies, protein_config.selection): ref_path, reference_proteins = get_reference_atomarraystack(protein_config, occ) if reference_proteins is None: logger.warning( @@ -227,63 +231,144 @@ def main(args: argparse.Namespace): logger.info( f"Loaded ref structure for {protein_key} and occupancy {occ}: {ref_path}" ) - # ignoring the altloc_id and occupancy arrays - reference_protein_stack, _, _ = map_altlocs_to_stack(reference_proteins) - if reference_proteins is not None: - reference_atom_arrays[(protein_key, occ)] = reference_protein_stack + reference_protein_stack, _, _ = map_altlocs_to_stack( + reference_proteins, selection=translate_selection(sel), return_full_array=True + ) + # hierarchical dictionary cache makes it lighter weight to parallelize. + if (protein_key, occ) not in reference_atom_arrays: + reference_atom_arrays[(protein_key, occ)] = {} + + reference_atom_arrays[(protein_key, occ)][sel] = reference_protein_stack except Exception as e: logger.error( f"Error loading ref structure for {protein_key} and occupancy {occ}: {e}" ) logger.error(f" Traceback: {traceback.format_exc()}") - all_results = [] - for _i, _exp in enumerate(all_experiments): - result = _exp.__dict__.copy() - if _exp.protein not in protein_configs: + # Do the quick pass through all the "rows" of our output table to filter in those we can run. + filtered_experiments = [] + null_results = [] + for _exp in all_experiments: + if _exp.protein in protein_configs: + protein = _exp.protein + elif _exp.protein.upper() in protein_configs: + protein = _exp.protein.upper() + elif _exp.protein.lower() in protein_configs: + protein = _exp.protein.lower() + else: + # These we just skip over--we assume that the user has told us via the config file + # what results they are interested in. 
logger.warning(f"Skipping protein with no configuration: {_exp.protein}") continue - protein_config = protein_configs[_exp.protein] + protein_config = protein_configs[protein] + if protein_config.protein != protein: + raise ValueError( + f"Protein name mismatch: expected {protein_config.protein}, got {protein}, make" + f"sure you loaded your protein configs with ProteinConfig.from_csv()." + ) - atom_array_key = (_exp.protein, _exp.occ_a) - if atom_array_key not in reference_atom_arrays: + if (protein, _exp.occ_a) not in reference_atom_arrays: logger.warning( f"Skipping {_exp.protein_dir_name}: no reference atom array stack available " - f"for {_exp.protein} and occupancy {_exp.occ_a}." + f"for {_exp.protein}, occupancy {_exp.occ_a}." ) + # record empty results for all selections, indicating they could not be computed. + for _sel in protein_config.selection: + exp_copy = _exp.__dict__.copy() + exp_copy["selection"] = _sel + null_results.append(exp_copy) continue - try: - reference_atom_array_stack = reference_atom_arrays[atom_array_key] - # these shouldn't have altlocs, don't need altloc="all". - predicted_atom_array_stack = load_any(_exp.refined_cif_path) - clustering_results = nn_lddt_clustering( - reference_atom_array_stack, - ensure_atom_array_stack(predicted_atom_array_stack), - translate_selection(protein_config.selection), - ) + protein_reference_atom_arrays = reference_atom_arrays[(protein, _exp.occ_a)] + for _sel in protein_config.selection: + if _sel not in protein_reference_atom_arrays: + logger.warning( + f"Skipping {_exp.protein_dir_name}: no reference atom array stack available " + f"for {_exp.protein}, occupancy {_exp.occ_a} and selection '{_sel}'." 
+ ) + exp_copy = _exp.__dict__.copy() + exp_copy["selection"] = _sel + null_results.append(exp_copy) + continue - result.update( - { - k: clustering_results[k] - for k in ("occupancies", "avg_silhouette", "avg_silhouette_to_ref") - } - ) - except Exception as e: - logger.error(f"Error processing experiment {_exp.exp_dir}: {e}") - logger.error(f" Traceback: {traceback.format_exc()}") - result["error"] = str(e) - result["avg_silhouette"] = np.nan - result["avg_silhouette_to_ref"] = np.nan - result["occupancies"] = [] + px_seln_refernce_atom_array = protein_reference_atom_arrays[_sel] + filtered_experiments.append((_exp, protein_config, px_seln_refernce_atom_array, _sel)) - all_results.append(result) + # now we can more easily parallelize this loop. + logger.debug("Starting LDDT evaluation loop. This may take a while...") + all_results = Parallel(n_jobs=args.n_jobs)( + delayed(process_exp_with_selection)( + _exp, protein_config, px_seln_refernce_atom_array, selection + ) + for _exp, protein_config, px_seln_refernce_atom_array, selection in filtered_experiments + ) - df = pd.DataFrame(all_results) + df = pd.DataFrame(null_results + all_results) df.to_csv(grid_search_dir / "lddt_results.csv", index=False) +def process_exp_with_selection( + exp: Experiment, + protein_config: ProteinConfig, + px_seln_refernce_atom_array: AtomArrayStack, + selection_string: str, +) -> dict[str, str | float | list[float]]: + """ + + Parameters + ---------- + exp: Experiment, a description of the structure generation experiment + protein_config: ProteinConfig, specifying the locations of reference structures and maps + px_seln_refernce_atom_array: AtomArrayStack, + the atom array stack for the reference structure, which in principle could be fetched + using the protein_config, but for efficiency we load once previously and pass in here, + since this method will run many times in parallel using the same structure + selection_string: str, the selection string for the evaluation + + Returns + 
------- + A dictionary of results of LDDT-based clustering that can be collated in a dataframe. + In addition to the data in the `exp` object, this dictionary contains: + - occupancies: list[float], the occupancies of the selected atoms, computed as the + fraction of structures in the experiment that are closest to one or the other + altloc of the reference structure. . + - avg_silhouette: float, the average silhouette score for the LDDT-based clustering + - avg_silhouette_to_ref: float, + a sort of silhouette score, measuring each structure's relative "closeness" to the + assigned reference altloc. + """ + logger.debug(f"Evaluating selection {selection_string} for protein {protein_config}") + result = exp.__dict__.copy() + result["selection"] = selection_string + + try: + # generated structures shouldn't have altlocs, don't need altloc="all". + predicted_atom_array_stack = load_any(exp.refined_cif_path) + clustering_results = nn_lddt_clustering( + px_seln_refernce_atom_array, + ensure_atom_array_stack(predicted_atom_array_stack), + translate_selection(selection_string), + ) + + lddt_result_keys = ("occupancies", "avg_silhouette", "avg_silhouette_to_ref") + result.update({k: clustering_results[k] for k in lddt_result_keys}) + + logger.info( + f"Successfully processed {exp.protein_dir_name} w/ selection {selection_string}" + ) + + except Exception as e: + logger.error(f"Error processing experiment {exp.exp_dir}: {e}") + logger.error(f" Traceback: {traceback.format_exc()}") + result["error"] = str(e) + result["avg_silhouette"] = np.nan + result["avg_silhouette_to_ref"] = np.nan + result["occupancies"] = [] + + return result + + if __name__ == "__main__": args = parse_args("Evaluate LDDT on grid search results.") main(args) diff --git a/scripts/eval/rscc_grid_search_script.py b/scripts/eval/rscc_grid_search_script.py index 3a712c24..d9c77bfd 100644 --- a/scripts/eval/rscc_grid_search_script.py +++ b/scripts/eval/rscc_grid_search_script.py @@ -27,7 +27,7 @@ from 
atomworks.io.parser import parse from loguru import logger from sampleworks.core.forward_models.xray.real_space_density_deps.qfit.volume import XMap -from sampleworks.eval.constants import DEFAULT_SELECTION_PADDING, OCCUPANCY_LEVELS +from sampleworks.eval.constants import DEFAULT_SELECTION_PADDING from sampleworks.eval.eval_dataclasses import ProteinConfig from sampleworks.eval.grid_search_eval_utils import parse_args, scan_grid_search_results from sampleworks.eval.metrics import rscc @@ -38,9 +38,10 @@ from sampleworks.utils.density_utils import compute_density_from_atomarray +# TODO consolidate eval script logic: https://github.com/diff-use/sampleworks/issues/93 def main(args: argparse.Namespace): workspace_root = Path(args.workspace_root) - grid_search_dir = workspace_root / "grid_search_results" # TODO make more general + grid_search_dir = workspace_root / "grid_search_results" # Protein configurations: base map paths, structure selections, and resolutions protein_inputs_dir = args.grid_search_inputs_path or workspace_root @@ -52,13 +53,15 @@ def main(args: argparse.Namespace): # Test base map path resolution logger.debug("Testing base map path resolution:") for _, config in protein_configs.items(): - for _occ in OCCUPANCY_LEVELS: # TODO make configurable + for _occ in args.occupancies: _path = config.get_base_map_path_for_occupancy(_occ) # will warn if not found if _path: logger.debug(f" {config.protein} occ={_occ}: {_path}") # Scan for experiments (look for refined.cif files) - all_experiments = scan_grid_search_results(grid_search_dir) + all_experiments = scan_grid_search_results( + grid_search_dir, target_filename=args.target_filename + ) logger.info(f"Found {len(all_experiments)} experiments with refined.cif files") if all_experiments: @@ -67,14 +70,12 @@ def main(args: argparse.Namespace): logger.info("Pre-loading reference structures for each protein for coordinate extraction") ref_coords = {} for protein_key, protein_config in protein_configs.items(): - 
# NOTE THAT THIS will be by default _two_ structures, one computed for occupancy 0 and - # one for occupancy 1 of altloc A. Historically this is because these coordinates are - # used here only as a mask for map comparisons. - # TODO: change that method to return the coordinates for occupancy 0 and 1 separately, - # and then we can merge them here. + # NOTE THAT THIS will be by default include all altlocs, as we use them to create a mask + # for where to judge the maps' correlation. protein_ref_coords = get_reference_structure_coords(protein_config, protein_key) if protein_ref_coords is not None: - ref_coords[protein_key] = protein_ref_coords + for selection in protein_ref_coords.keys(): + ref_coords[(protein_key, selection)] = protein_ref_coords[selection] # Calculate RSCC for all experiments # (BIG) TODO: implement a sliding-window version (global can be achieved with diff't selections. @@ -87,91 +88,111 @@ def main(args: argparse.Namespace): logger.info(f"Using device: {_device}") results = [] - base_map_cache: dict[tuple[str, float], tuple[XMap, XMap]] = {} + base_map_cache: dict[tuple[str, float, str], tuple[XMap, XMap]] = {} # TODO parallelize this loop? It uses GPU, so be careful. for _i, _exp in enumerate(all_experiments): - if _exp.protein not in protein_configs: + if _exp.protein in protein_configs: + protein = _exp.protein + elif _exp.protein.upper() in protein_configs: + protein = _exp.protein.upper() + else: logger.warning(f"Skipping protein with no configuration: {_exp.protein}") continue - protein_config = protein_configs[_exp.protein] + protein_config = protein_configs[protein] + for selection in protein_config.selection: + # Check if we have reference coordinates for region extraction + if (protein, selection) not in ref_coords: + logger.warning( + f"Skipping {_exp.protein_dir_name}/{selection}: no reference structure " + f"available for {_exp.protein}, this may be due to a selection with zero atoms " + f"or NaN/Inf coordinates. Check logs above." 
+ ) + continue + + _selection_coords = ref_coords[(protein, selection)] + _base_map_path = protein_config.get_base_map_path_for_occupancy(_exp.occ_a) + if _base_map_path is None: + logger.warning( + f"Skipping {_exp.protein_dir_name}: base map for selection {selection} and " + f"occupancy {_exp.occ_a} not found" + ) + continue + + exp_result = _exp.__dict__.copy() + exp_result["selection"] = selection + exp_result["error"] = None + try: + # TODO: this needs to be better unified with what's in generate_synthetic_density + # + # Load base map for canonical unit cell, + # don't overwrite the base map with selection map--we'll use the full map later too. + if (protein, _exp.occ_a, selection) not in base_map_cache: + _base_xmap = protein_config.load_map(_base_map_path) + if _base_xmap is None: + raise ValueError(f"Failed to load base map from {_base_map_path}") + + # Extract the region around altloc residues from the base map, using the + # union of boxes around each atom. _extracted_base is no longer an XMap + _, _extracted_base = _base_xmap.extract_tight( + _selection_coords, padding=DEFAULT_SELECTION_PADDING + ) + logger.info( + f"Caching base and subselected maps for {protein} " + f"occ_a={_exp.occ_a} selection={selection}" + ) + base_map_cache[(protein, _exp.occ_a, selection)] = (_base_xmap, _extracted_base) + else: + _base_xmap, _extracted_base = base_map_cache[(protein, _exp.occ_a, selection)] - # Check if we have reference coordinates for region extraction - if _exp.protein not in ref_coords: - logger.warning( - f"Skipping {_exp.protein_dir_name}: no reference structure available " - f"for {_exp.protein}, this may be due to a selection with zero atoms " - f"or NaN/Inf coordinates. Check logs above." 
- ) - continue + # Validate extraction + if _extracted_base is None or _extracted_base.shape[0] == 0: + raise ValueError(f"Extracted base map from {_base_map_path} is empty") - _selection_coords = ref_coords[_exp.protein] - _base_map_path = protein_config.get_base_map_path_for_occupancy(_exp.occ_a) - if _base_map_path is None: - logger.warning( - f"Skipping {_exp.protein_dir_name}: base map for occupancy {_exp.occ_a} not found" - ) - continue + # Load refined structure + _structure = parse(_exp.refined_cif_path, ccd_mirror_path=None) - try: - # Load base map for canonical unit cell, - # don't overwrite the base map with selection map as we'll use the full map later too. - if (_exp.protein, _exp.occ_a) not in base_map_cache: - _base_xmap = protein_config.load_map(_base_map_path) - if _base_xmap is None: - raise ValueError(f"Failed to load base map from {_base_map_path}") - - # Extract the region around altloc residues from the base map, using the - # union of boxes around each atom. _extracted_base is no longer an XMap - _, _extracted_base = _base_xmap.extract_tight( - _selection_coords, padding=DEFAULT_SELECTION_PADDING - ) - logger.info( - f"Caching base and subselected maps for {_exp.protein} occ_a={_exp.occ_a}" - ) - base_map_cache[(_exp.protein, _exp.occ_a)] = (_base_xmap, _extracted_base) - else: - _base_xmap, _extracted_base = base_map_cache[(_exp.protein, _exp.occ_a)] + # Compute density from refined structure + atom_array = get_asym_unit_from_structure(_structure) + if not hasattr(atom_array, "coord") or atom_array.coord is None: + raise AttributeError("AtomArray | AtomArrayStack is missing coordinates") - # Validate extraction - if _extracted_base is None or _extracted_base.shape[0] == 0: - raise ValueError(f"Extracted base map from {_base_map_path} is empty") + if not hasattr(atom_array, "b_factor"): + logger.warning( + f"No b-factor array found in {_exp.refined_cif_path}, setting to 20." 
+ ) + atom_array.set_annotation("b_factor", np.full(atom_array.coord.shape[-2], 20.0)) - # Load refined structure - _structure = parse(_exp.refined_cif_path, ccd_mirror_path=None) + _computed_density, _ = compute_density_from_atomarray( + atom_array, xmap=_base_xmap, em_mode=False, device=_device + ) - # Compute density from refined structure - atom_array = get_asym_unit_from_structure(_structure) - _computed_density, _ = compute_density_from_atomarray( - atom_array, xmap=_base_xmap, em_mode=False, device=_device - ) + # Create an XMap from the computed density by copying the base xmap + # and replacing its array with the computed density + _computed_xmap = copy.deepcopy(_base_xmap) + _computed_xmap.array = _computed_density.cpu().numpy().squeeze() + _, _extracted_computed = _computed_xmap.extract_tight( + _selection_coords, padding=DEFAULT_SELECTION_PADDING + ) - # Create an XMap from the computed density by copying the base xmap - # and replacing its array with the computed density - _computed_xmap = copy.deepcopy(_base_xmap) - _computed_xmap.array = _computed_density.cpu().numpy().squeeze() - _, _extracted_computed = _computed_xmap.extract_tight( - _selection_coords, padding=DEFAULT_SELECTION_PADDING - ) + # Validate extraction + if _extracted_computed is None or _extracted_computed.shape[0] == 0: + raise ValueError("Extracted computed map is empty") - # Validate extraction - if _extracted_computed is None or _extracted_computed.shape[0] == 0: - raise ValueError("Extracted computed map is empty") + # Calculate RSCC on extracted regions + # TODO: don't alter the input object _exp, just get a copy of it as a dict. 
+                exp_result["rscc"] = rscc(_extracted_base, _extracted_computed) +                exp_result["base_map_path"] = _base_map_path - # Calculate RSCC on extracted regions - _exp.rscc = rscc(_extracted_base, _extracted_computed) - _exp.base_map_path = _base_map_path + except Exception as _e: + logger.error(f"ERROR processing {_exp.exp_dir}: {_e}") + logger.error(f" Traceback: {traceback.format_exc()}") + exp_result["error"] = _e + exp_result["rscc"] = np.nan # this is the default, but better to be explicit. + exp_result["base_map_path"] = _base_map_path - except Exception as _e: - logger.error(f"ERROR processing {_exp.exp_dir}: {_e}") - logger.error(f" Traceback: {traceback.format_exc()}") - _exp.error = _e - _exp.rscc = np.nan # this is the default, but better to be explicit. - _exp.base_map_path = _base_map_path + results.append(exp_result) - exp_dict_copy = _exp.__dict__.copy() - exp_dict_copy["selection"] = protein_config.selection - results.append(exp_dict_copy) if (_i + 1) % 10 == 0 or _i == 0: logger.debug( f" [{_i + 1}/{len(all_experiments)}] {_exp.protein_dir_name} / " @@ -206,94 +227,6 @@ def main(args: argparse.Namespace): ) logger.info(summary) - # Calculate correlation between base maps and pure conformer maps - logger.info("Calculating correlations between base maps and pure conformer maps...") - logger.info("This shows how well single conformers explain occupancy-mixed data") - - base_pure_correlations = [] - - for protein_key, protein_config in protein_configs.items(): - if protein_key not in ref_coords: - print(f"Skipping {protein_key}: no reference coordinates available") - continue - - # We re-use the selection coordinates from the reference structure computed - # at 0.5 occupancy above. 
- _selection_coords = ref_coords[protein_key] - - logger.info(f"\nProcessing {protein_key} single conformer explanatory power:") - map_path_1occA = protein_config.get_base_map_path_for_occupancy(1.0) - map_path_1occB = protein_config.get_base_map_path_for_occupancy(0.0) - if map_path_1occA is None or map_path_1occB is None: - logger.warning(f"Skipping {protein_key}: pure conformer maps not found") - continue - try: - # Load pure conformer maps--returns canonical unit cell by default, - # extract selection with padding 0.0 - _extracted_pure_A = protein_config.load_map( - map_path_1occA, selection_coords=_selection_coords - ) - _extracted_pure_B = protein_config.load_map( - map_path_1occB, selection_coords=_selection_coords - ) - - logger.info( - f" Pure A reference: {map_path_1occA}\n Pure B reference: {map_path_1occB}" - ) - - # Calculate correlations for each occupancy - for _occ_a in OCCUPANCY_LEVELS: # TODO make configurable - try: - _base_map_path = protein_config.get_base_map_path_for_occupancy(_occ_a) - if _base_map_path is None: # map file not found, will warn. - continue - - logger.info(f" Processing occ_A={_occ_a}: {_base_map_path.name}") - - # Load the base map for this occupancy, and do the selection - # default padding is zero. 
- _extracted_base = protein_config.load_map( - _base_map_path, selection_coords=_selection_coords - ) - - if ( - _extracted_base is None - or _extracted_pure_A is None - or _extracted_pure_B is None - ): - raise ValueError("One of the extracted maps is empty") - - # Calculate correlations - _corr_base_vs_pureA = rscc(_extracted_base.array, _extracted_pure_A.array) - _corr_base_vs_pureB = rscc(_extracted_base.array, _extracted_pure_B.array) - - base_pure_correlations.append( - { - "protein": protein_key, - "occ_a": _occ_a, - "base_vs_1occA": _corr_base_vs_pureA, - "base_vs_1occB": _corr_base_vs_pureB, - } - ) - - logger.info(f" Base map vs pure A: {_corr_base_vs_pureA:.4f}") - logger.info(f" Base map vs pure B: {_corr_base_vs_pureB:.4f}") - - except Exception as _e: - logger.error(f" Error processing occ_A={_occ_a} for {protein_key}: {_e}") - logger.error(f" Traceback: {traceback.format_exc()}") - - except Exception as _e: - logger.error(f"Error calculating correlations for {protein_key}: {_e}") - logger.error(f" Traceback: {traceback.format_exc()}") - - df_base_vs_pure = pd.DataFrame(base_pure_correlations) - df.to_csv(grid_search_dir / "rscc_results_for_pure_conformer_maps.csv", index=False) - logger.info( - f"\nCalculated single conformer explanatory power for " - f"{len(df_base_vs_pure)} occupancy points" - ) - if __name__ == "__main__": args = parse_args("Evaluate RSCC on grid search results.") diff --git a/src/sampleworks/core/rewards/real_space_density.py b/src/sampleworks/core/rewards/real_space_density.py index 2da73c1f..c5351964 100644 --- a/src/sampleworks/core/rewards/real_space_density.py +++ b/src/sampleworks/core/rewards/real_space_density.py @@ -43,7 +43,7 @@ def setup_scattering_params( containing scattering coefficients for each element type """ elements = atom_array.element - unique_elements = sorted(set(normalize_element(e) for e in elements)) # ty: ignore[not-iterable] + unique_elements = sorted(set(normalize_element(e) for e in elements)) 
atomic_num_dict = {elem: ELEMENT_TO_ATOMIC_NUM[elem] for elem in unique_elements} structure_factors = ELECTRON_SCATTERING_FACTORS if em_mode else ATOM_STRUCTURE_FACTORS @@ -98,6 +98,7 @@ def extract_density_inputs_from_atomarray( ``(1, n_atoms)``. For an ``AtomArrayStack`` with *M* models the batch dimension is *M* instead of 1. """ + is_stack = isinstance(atom_array_or_stack, AtomArrayStack) coords = cast(np.ndarray[Any, np.dtype[np.float64]], atom_array_or_stack.coord) @@ -130,6 +131,7 @@ def extract_density_inputs_from_atomarray( f"({n_valid}/{n_total} atoms remaining)" ) + # TODO can we use [..., valid_mask] here? if is_stack: valid_coords = coords[:, valid_mask] else: diff --git a/src/sampleworks/eval/eval_dataclasses.py b/src/sampleworks/eval/eval_dataclasses.py index db1f2654..e2db5b08 100644 --- a/src/sampleworks/eval/eval_dataclasses.py +++ b/src/sampleworks/eval/eval_dataclasses.py @@ -9,6 +9,8 @@ @dataclass +# TODO rename to make consistend w/ hub.diffuse.science +# https://github.com/diff-use/sampleworks/issues/122 class Experiment: protein: str occ_a: float @@ -42,7 +44,7 @@ class ProteinConfig: protein: str base_map_dir: Path - selection: str + selection: list[str] resolution: float map_pattern: str structure_pattern: str = "" @@ -50,6 +52,13 @@ class ProteinConfig: def __post_init__(self): # TODO validate structure patterns? Anything else we should check to avoid later errors? 
self.base_map_dir = Path(self.base_map_dir) # just in case someone passes a string + self.selection = [s.strip() for s in self.selection if self.is_selection_valid(s)] + if not self.selection: + raise ValueError(f"No valid selection provided for protein {self.protein}") + + def is_selection_valid(self, selection: str) -> bool: + return selection is not None and selection.strip() != "" + def get_base_map_path_for_occupancy(self, occupancy_a: float) -> Path | None: occ_str = occupancy_to_str(occupancy_a, use_6b8x_format=self.protein == "6b8x") @@ -190,7 +199,7 @@ def from_csv(cls, workspace_root: Path, csv_path: Path) -> dict[str, "ProteinCon config = cls( protein=protein, base_map_dir=base_map_dir, - selection=row["selection"].strip(), + selection=row["selection"].split(";"), resolution=resolution, map_pattern=row["map_pattern"].strip(), structure_pattern=structure_pattern, diff --git a/src/sampleworks/eval/grid_search_eval_utils.py b/src/sampleworks/eval/grid_search_eval_utils.py index cec2fd5a..89dae772 100644 --- a/src/sampleworks/eval/grid_search_eval_utils.py +++ b/src/sampleworks/eval/grid_search_eval_utils.py @@ -17,6 +17,7 @@ # TODO: this either (both) needs tests or (and) there needs to be a clearer "API" # for how the folder names are generated. +# https://github.com/diff-use/sampleworks/issues/121 def parse_experiment_dir(exp_dir: Path) -> dict[str, int | float | None]: """Parse experiment directory name to extract parameters. 
@@ -75,6 +76,7 @@ def scan_grid_search_results( ) return experiments + # FIXME https://github.com/diff-use/sampleworks/issues/121 # Check if we found a refined.cif file in the current directory refined_cif = current_directory / target_filename if current_depth == target_depth and refined_cif.exists(): @@ -183,7 +185,18 @@ def parse_args(description: str | None = None): "--occupancies", nargs="+", type=float, - help="Occupancies to evaluate", + help=f"Occupancies to evaluate, defaults to {OCCUPANCY_LEVELS}", default=OCCUPANCY_LEVELS, ) + parser.add_argument( + "--target-filename", + default="refined.cif", + help="Target filename for the CIF files to process, defaults to 'refined.cif'", + ) + parser.add_argument( + "--n-jobs", + type=int, + help="Number of parallel jobs to run. -1 uses all CPUs.", + default=16, + ) return parser.parse_args() diff --git a/src/sampleworks/eval/structure_utils.py b/src/sampleworks/eval/structure_utils.py index fab3c7fa..efbbfd22 100644 --- a/src/sampleworks/eval/structure_utils.py +++ b/src/sampleworks/eval/structure_utils.py @@ -2,7 +2,7 @@ import traceback from dataclasses import dataclass, replace from pathlib import Path -from typing import cast +from typing import Any, cast import numpy as np import torch @@ -16,6 +16,7 @@ from sampleworks.utils.atom_reconciler import AtomReconciler from sampleworks.utils.framework_utils import match_batch +ATOMWORKS_COMPARISON_OPS = ("==", ">", "<", "<=", ">=", " in ") try: from sampleworks.models.protenix.structure_processing import ( @@ -243,6 +244,8 @@ def process_structure_to_trajectory_input( ) +# TODO: migrate this to atomworks's selection algebra that is added to AtomArray/Stack +# https://github.com/diff-use/sampleworks/issues/56 def parse_selection_string(selection: str) -> tuple[str | None, int | None, int | None]: """Parse a selection string like 'chain A and resi 326-339'. 
@@ -303,8 +306,24 @@ def apply_selection(atom_array: AtomArray, selection: str | None) -> AtomArray: if selection is None: return atom_array + if not any(x in selection for x in ATOMWORKS_COMPARISON_OPS): + mask = get_mask_from_old_selection_string(atom_array, selection) + else: + mask = atom_array.mask(selection) + + return cast(AtomArray, atom_array[mask]) + + +def get_mask_from_old_selection_string( + atom_array: AtomArray | AtomArrayStack, selection: str +) -> np.ndarray[tuple[int], np.dtype[Any]]: + DeprecationWarning( + f"Using old-style selection strings like {selection} is deprecated." + f" Use atomworks/pandas style selection strings instead." + ) chain_id, resi_start, resi_end = parse_selection_string(selection) - mask = np.ones(len(atom_array), dtype=bool) + # use the length of any of the required non-coord attributes to get the mask shape + mask = np.ones(len(atom_array.res_id), dtype=bool) if chain_id is not None: mask &= atom_array.chain_id == chain_id @@ -318,8 +337,7 @@ def apply_selection(atom_array: AtomArray, selection: str | None) -> AtomArray: if mask.sum() == 0: raise ValueError(f"Selection '{selection}' matched no atoms") - - return cast(AtomArray, atom_array[mask]) + return mask def extract_selection_coordinates( @@ -349,24 +367,10 @@ def extract_selection_coordinates( else: working_array = atom_array - chain_id, resi_start, resi_end = parse_selection_string(selection) - - # Create the selection mask, don't rely on len(atom_array) in case it is the ensemble size - mask = np.ones(len(working_array), dtype=bool) - - if chain_id is not None: - mask &= working_array.chain_id == chain_id - - if resi_start is not None: - res_ids = cast(np.ndarray, working_array.res_id) - if resi_end is not None: - # Explicitly check for None to satisfy pyright - start: int = resi_start - end: int = resi_end - mask &= (res_ids >= start) & (res_ids <= end) - else: - start = resi_start - mask &= res_ids == start + if not any(x in selection for x in 
ATOMWORKS_COMPARISON_OPS): + mask = get_mask_from_old_selection_string(atom_array, selection) + else: + mask = working_array.mask(selection) selected_coords = cast(np.ndarray, working_array.coord)[mask] @@ -374,7 +378,6 @@ def extract_selection_coordinates( if len(selected_coords) == 0: raise RuntimeError( f"No atoms matched selection: '{selection}'. " - f"Chain ID: {chain_id}, Residue range: {resi_start}-{resi_end}. " f"Total atoms in structure: {len(atom_array)}" ) @@ -431,45 +434,43 @@ def get_reference_atomarraystack( def get_reference_structure_coords( protein_config: ProteinConfig, protein_key: str, occ_list: tuple[float, ...] = (0.0, 1.0) -) -> np.ndarray | None: +) -> dict[str, np.ndarray] | None: """ This has a slightly odd function, which is to output an array of all possible coordinates of a structure, with altlocs mixed in. It returns NO information about which atom is which or whether there are duplicates. It's used for masking density maps. """ - protein_ref_coords_list = [] + protein_ref_coords_list = {selection: [] for selection in protein_config.selection} for occ in occ_list: ref_path, ref_struct = get_reference_atomarraystack(protein_config, occ) if ref_path and ref_struct: # if not None, it is already a validated Path object - try: - # TODO: enumerate actual exceptions this can raise. - coords = extract_selection_coordinates(ref_struct, protein_config.selection) - if not len(coords): - logger.warning( - f" No atoms in selection '{protein_config.selection}' for {protein_key}" - ) - elif not np.isfinite(coords).all(): - logger.warning( - f" NaN/Inf coordinates in selection " - f"'{protein_config.selection}' for {protein_key}" + for selection in protein_config.selection: + try: + # TODO: enumerate actual exceptions this can raise. 
+ coords = extract_selection_coordinates(ref_struct, selection) + if not len(coords): + logger.warning(f" No atoms in selection '{selection}' for {protein_key}") + elif not np.isfinite(coords).all(): + logger.warning( + f" NaN/Inf coordinates in selection '{selection}' for {protein_key}" + ) + else: + protein_ref_coords_list[selection].append(coords) + logger.info( + f" Loaded reference structure for {protein_key}: " + f"{len(coords)} atoms in selection '{selection}'" + ) + except Exception as _e: + _selection = selection if selection else "(none)" + logger.error( + f" ERROR: Failed to load reference structure for {protein_key}: {_e}\n" + f" Path: {ref_path}\n" + f" Selection: {_selection}\n" + f" Traceback: {traceback.format_exc()}" ) - else: - protein_ref_coords_list.append(coords) - logger.info( - f" Loaded reference structure for {protein_key}: " - f"{len(coords)} atoms in selection '{protein_config.selection}'" - ) - except Exception as _e: - _selection = protein_config.selection if protein_config.selection else "(none)" - logger.error( - f" ERROR: Failed to load reference structure for {protein_key}: {_e}\n" - f" Path: {ref_path}\n" - f" Selection: {_selection}\n" - f" Traceback: {traceback.format_exc()}" - ) - - if not protein_ref_coords_list: - logger.error(f"No reference structures found for {protein_key}") - return None - - return np.vstack(protein_ref_coords_list) + + return { + k: np.vstack(protein_ref_coords_list[k]) + for k in protein_ref_coords_list + if protein_ref_coords_list[k] + } diff --git a/src/sampleworks/utils/atom_array_utils.py b/src/sampleworks/utils/atom_array_utils.py index e8ca4c01..a0ebab89 100644 --- a/src/sampleworks/utils/atom_array_utils.py +++ b/src/sampleworks/utils/atom_array_utils.py @@ -186,14 +186,24 @@ def save_structure_to_cif( return output_path -def find_all_altloc_ids(atom_array: AtomArray | AtomArrayStack) -> set[str]: +def find_all_altloc_ids( + atom_array: AtomArray | AtomArrayStack, altloc_label: str = "altloc_id" 
+) -> set[str]: """ Find all unique alternate location indicator (altloc) IDs in an AtomArray or AtomArrayStack. """ - if hasattr(atom_array, "altloc_id"): - altloc_ids = np.unique(atom_array.altloc_id) + if hasattr(atom_array, altloc_label): + altloc_ids = np.unique(getattr(atom_array, altloc_label)) + altloc_values = getattr(atom_array, altloc_label) + if altloc_values is None: + raise AttributeError( + f"atom_array.{altloc_label} exists but is None; cannot find altloc IDs" + ) else: - raise AttributeError("atom_array must have `altloc_id` annotation") + raise AttributeError( + "atom_array must have altloc annotation, defaults to 'altloc_id'; " + f"you provided {altloc_label}" + ) return set(altloc_ids.tolist()) - BLANK_ALTLOC_IDS @@ -234,12 +244,22 @@ def detect_altlocs(atom_array: AtomArray) -> AltlocInfo: def map_altlocs_to_stack( atom_array: AtomArray | AtomArrayStack, + selection: str | None = None, + return_full_array: bool = True, ) -> tuple[AtomArrayStack, np.ndarray, np.ndarray]: """ Map alternate location indicators (altloc) to separate structures in a new AtomArrayStack. - Note: this will take _only_ the first structure if you pass an AtomArrayStack. It will raise - an error if there is more than one structure in the input AtomArrayStack. + Note: This will raise an error if you pass an AtomArrayStack containing multiple structures + + Parameters: + atom_array: AtomArray or AtomArrayStack to map altlocs from. + selection: str + Optional selection string to apply to the atom array before mapping altlocs. + If the return_full_array is also True, this selection applies _only_ to the altlocs. + return_full_array: bool + If True, return the full AtomArrayStack with all atoms, even those with no altlocs. + If False, only return the specific atoms with altlocs. 
     Returns:
         Tuple containing:
@@ -253,11 +273,42 @@ def map_altlocs_to_stack(
         if len(atom_array) > 1:
             raise ValueError("Cannot map altlocs with multiple structures each containing altlocs")
         atom_array = atom_array[0]
+
+    if not hasattr(atom_array, "altloc_id") or atom_array.altloc_id is None:
+        raise ValueError("The passed AtomArray | AtomArrayStack must have attribute 'altloc_id'")
+
+    if not hasattr(atom_array, "mask") or not hasattr(atom_array, "query"):
+        raise ValueError(
+            "map_altlocs_to_stack requires atom_arrays loaded with "
+            "atomworks.io.utils.io_utils.load_any or a similar method that provides .mask, .query"
+        )
+
     altloc_ids = sorted(list(find_all_altloc_ids(atom_array)))
+    if selection is not None and return_full_array:
+        # In this case we select atoms with no altlocs, plus atoms in the selection
+        # that do have altlocs, then construct the query. (Note: np.isin doesn't seem
+        # to like sets, hence the conversion to list.)
+        no_altloc_mask = np.isin(atom_array.altloc_id, list(BLANK_ALTLOC_IDS))
+        altloc_mask = np.isin(atom_array.altloc_id, altloc_ids)
+        selection_mask = atom_array.mask(selection)
+        mask = np.logical_or(no_altloc_mask, np.logical_and(altloc_mask, selection_mask))
+        atom_array = atom_array[mask]
+    elif selection is not None:
+        atom_array = atom_array.query(selection)
+
+    # The available altloc ids might have changed if one or more are missing from the selection.
+ altloc_ids = sorted(list(find_all_altloc_ids(atom_array))) + if len(altloc_ids) == 0: + raise ValueError( + f"No altlocs found in selection '{selection}' for AtomArray with altloc_id" + ) altloc_list = [ - select_altloc(atom_array, altloc_id, return_full_array=True) for altloc_id in altloc_ids + select_altloc(atom_array, altloc_id, return_full_array=return_full_array) + for altloc_id in altloc_ids ] # ensure that each structure has the same number of atoms + # it's critical that we've updated the altloc id list above, or this will return only + # residues with no altlocs in case one or more altloc ids is/are missing from the selection. atom_arrays = filter_to_common_atoms(*altloc_list) altloc_ids = np.vstack([r.altloc_id for r in atom_arrays]) occupancies = np.vstack([r.occupancy for r in atom_arrays]) diff --git a/src/sampleworks/utils/density_utils.py b/src/sampleworks/utils/density_utils.py index cdbe4b87..93063ace 100644 --- a/src/sampleworks/utils/density_utils.py +++ b/src/sampleworks/utils/density_utils.py @@ -152,7 +152,7 @@ def compute_density_from_atomarray( atom_array, device ) - # need to make sure these all have the same batch dimension or the transformer will fail. + # need to make sure these all have the same batch dimension, or the transformer will fail. elements = elements.expand(coords.shape[0], -1) b_factors = b_factors.expand(coords.shape[0], -1) occupancies = occupancies.expand(coords.shape[0], -1)