Skip to content
32 changes: 20 additions & 12 deletions scripts/eval/lddt_evaluation_script.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,27 +221,33 @@ def main(args: argparse.Namespace):
reference_atom_arrays = {}
for protein_key, protein_config in protein_configs.items():
for occ, sel in itertools.product(args.occupancies, protein_config.selection):
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a reason not to use itertools.product? I used it to facilitate possible future parallelization with joblib.

ref_path, reference_proteins = get_reference_atomarraystack(protein_config, occ)
altloc_occ = {"A": occ, "B": 1.0 - occ}
Comment thread
marcuscollins marked this conversation as resolved.
occ_key = tuple(sorted((k, v) for k, v in altloc_occ.items() if abs(v) > 1e-6))
ref_path, reference_proteins = get_reference_atomarraystack(protein_config, altloc_occ)
if reference_proteins is None:
logger.warning(
f"Could not find ref structure for {protein_key} and occupancy {occ}"
f"Could not find ref structure for {protein_key} and occupancies {altloc_occ}"
)
continue
try:
logger.info(
f"Loaded ref structure for {protein_key} and occupancy {occ}: {ref_path}"
f"Loaded ref structure for {protein_key} "
f"and occupancies {altloc_occ}: {ref_path}"
)
reference_protein_stack, _, _ = map_altlocs_to_stack(
reference_proteins, selection=translate_selection(sel), return_full_array=True
reference_proteins,
selection=translate_selection(sel),
return_full_array=True,
)
# hierarchical dictionary cache makes it lighter weight to parallelize.
if (protein_key, occ) not in reference_atom_arrays:
reference_atom_arrays[(protein_key, occ)] = {}
if (protein_key, occ_key) not in reference_atom_arrays:
reference_atom_arrays[(protein_key, occ_key)] = {}

reference_atom_arrays[(protein_key, occ)][sel] = reference_protein_stack
reference_atom_arrays[(protein_key, occ_key)][sel] = reference_protein_stack
except Exception as e:
logger.error(
f"Error loading ref structure for {protein_key} and occupancy {occ}: {e}"
f"Error loading ref structure for {protein_key} "
f"and occupancies {altloc_occ}: {e}"
)
logger.error(f" Traceback: {traceback.format_exc()}")

Expand All @@ -268,10 +274,11 @@ def main(args: argparse.Namespace):
f"sure you loaded your protein configs with ProteinConfig.from_csv()."
)

if (protein, _exp.occ_a) not in reference_atom_arrays:
atom_array_key = (protein, _exp.occ_key)
if atom_array_key not in reference_atom_arrays:
logger.warning(
f"Skipping {_exp.protein_dir_name}: no reference atom array stack available "
f"for {_exp.protein}, occupancy {_exp.occ_a}."
f"for {_exp.protein} and occupancies {_exp.altloc_occupancies}."
)
# record empty results for all selections, indicating they could not be computed.
for _sel in protein_config.selection:
Expand All @@ -280,12 +287,13 @@ def main(args: argparse.Namespace):
null_results.append(exp_copy)
continue

protein_reference_atom_arrays = reference_atom_arrays[(protein, _exp.occ_a)]
protein_reference_atom_arrays = reference_atom_arrays[atom_array_key]
for _sel in protein_config.selection:
if _sel not in protein_reference_atom_arrays:
logger.warning(
f"Skipping {_exp.protein_dir_name}: no reference atom array stack available "
f"for {_exp.protein}, occupancy {_exp.occ_a} and selection '{_sel}'."
f"for {_exp.protein} and occupancies {_exp.altloc_occupancies}, "
f"selection '{_sel}'."
)
exp_copy = _exp.__dict__.copy()
exp_copy["selection"] = _sel
Expand Down
36 changes: 21 additions & 15 deletions scripts/eval/rscc_grid_search_script.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@
from sampleworks.utils.framework_utils import match_batch



# TODO consolidate eval script logic: https://github.com/diff-use/sampleworks/issues/93
def main(args: argparse.Namespace):
workspace_root = Path(args.workspace_root)
Expand Down Expand Up @@ -99,8 +98,8 @@ def main(args: argparse.Namespace):
logger.info(f"Using device: {_device}")

results = []
base_map_cache: dict[tuple[str, float, str], tuple[XMap, XMap]] = {}
ref_full_structure_cache: dict[tuple[str, float], AtomArrayStack] = {}
base_map_cache: dict[tuple[str, tuple[tuple[str, float], ...], str], tuple[XMap, XMap]] = {}
ref_full_structure_cache: dict[tuple[str, tuple[tuple[str, float], ...]], AtomArrayStack] = {}
# TODO parallelize this loop? It uses GPU, so be careful.
for _i, _exp in enumerate(all_experiments):
if _exp.protein in protein_configs:
Expand All @@ -123,11 +122,11 @@ def main(args: argparse.Namespace):
continue

_selection_coords = ref_coords[(protein, selection)]
_base_map_path = protein_config.get_base_map_path_for_occupancy(_exp.occ_a)
_base_map_path = protein_config.get_base_map_path_for_occupancy(_exp.altloc_occupancies)
if _base_map_path is None:
logger.warning(
f"Skipping {_exp.protein_dir_name}: base map for selection {selection} and "
f"occupancy {_exp.occ_a} not found"
f"occupancy {_exp.altloc_occupancies} not found"
)
continue

Expand All @@ -139,7 +138,7 @@ def main(args: argparse.Namespace):
#
# Load base map for canonical unit cell,
# don't overwrite the base map with selection map--we'll use the full map later too.
if (protein, _exp.occ_a, selection) not in base_map_cache:
if (protein, _exp.occ_key, selection) not in base_map_cache:
_base_xmap = protein_config.load_map(_base_map_path)
if _base_xmap is None:
raise ValueError(f"Failed to load base map from {_base_map_path}")
Expand All @@ -151,11 +150,14 @@ def main(args: argparse.Namespace):
)
logger.info(
f"Caching base and subselected maps for {protein} "
f"occ_a={_exp.occ_a} selection={selection}"
f"altloc_occupancies={_exp.altloc_occupancies} selection={selection}"
)
base_map_cache[(protein, _exp.occ_key, selection)] = (
_base_xmap,
_extracted_base,
)
base_map_cache[(protein, _exp.occ_a, selection)] = (_base_xmap, _extracted_base)
else:
_base_xmap, _extracted_base = base_map_cache[(protein, _exp.occ_a, selection)]
_base_xmap, _extracted_base = base_map_cache[(protein, _exp.occ_key, selection)]

# Validate extraction
if _extracted_base is None or _extracted_base.shape[0] == 0:
Expand All @@ -180,20 +182,24 @@ def main(args: argparse.Namespace):
#
# Align the refined structure to the reference structure
# 1. Get the reference structure path and load from cache if available
if (protein, _exp.occ_a) not in ref_full_structure_cache:
ref_path = protein_config.get_reference_structure_path(_exp.occ_a)
if (protein, _exp.occ_key) not in ref_full_structure_cache:
ref_path = protein_config.get_reference_structure_path(_exp.altloc_occupancies)
if ref_path is None:
raise ValueError(
f"Could not find reference structure for occupancy {_exp.occ_a}"
f"Could not find reference structure for "
f"occupancy {_exp.altloc_occupancies}"
)

# 2. Load the reference structure with parse() to get only the first altloc
ref_structure = parse(ref_path, ccd_mirror_path=None)
ref_atom_array = get_asym_unit_from_structure(ref_structure)
logger.info(f"Caching reference structure for {protein} occ_a={_exp.occ_a}")
ref_full_structure_cache[(protein, _exp.occ_a)] = ref_atom_array
logger.info(
f"Caching reference structure for {protein} "
f"altloc_occupancies={_exp.altloc_occupancies}"
)
ref_full_structure_cache[(protein, _exp.occ_key)] = ref_atom_array
else:
ref_atom_array = ref_full_structure_cache[(protein, _exp.occ_a)]
ref_atom_array = ref_full_structure_cache[(protein, _exp.occ_key)]

# 3. Find the common atoms with non-nan coords between the reference
# and the refined structure
Expand Down
64 changes: 41 additions & 23 deletions src/sampleworks/eval/eval_dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# https://github.com/diff-use/sampleworks/issues/122
class Experiment:
protein: str
occ_a: float
altloc_occupancies: dict[str, float]
Comment thread
coderabbitai[bot] marked this conversation as resolved.
model: str
method: str | None
scaler: str
Expand All @@ -27,6 +27,12 @@ class Experiment:
base_map_path: Path | None = None
error: Exception | None = None

@property
def occ_key(self) -> tuple[tuple[str, float], ...]:
    """Return the altloc occupancies as a sorted, hashable cache key.

    ``altloc_occupancies`` is a dict and therefore cannot be used as a
    dictionary key directly, so it is flattened into a tuple of
    ``(altloc, occupancy)`` pairs in ascending order.

    Altlocs whose occupancy has a magnitude of at most 1e-6 are left out,
    for consistency with occupancy_to_str / extract_protein_and_occupancy.
    """
    populated = []
    for altloc, occupancy in self.altloc_occupancies.items():
        # Keep the comparison in this exact form: a NaN occupancy compares
        # False here and is therefore dropped, just like a zero occupancy.
        if abs(occupancy) > 1e-6:
            populated.append((altloc, occupancy))
    populated.sort()
    return tuple(populated)


class ExperimentList(list[Experiment]):
def summarize(self):
Expand Down Expand Up @@ -59,23 +65,27 @@ def __post_init__(self):
def is_selection_valid(self, selection: str) -> bool:
return selection is not None and selection.strip() != ""

def get_base_map_path_for_occupancy(self, altloc_occupancies: dict[str, float]) -> Path | None:
"""Return the base-map path for the given altloc occupancies, or ``None``.

def get_base_map_path_for_occupancy(self, occupancy_a: float) -> Path | None:
occ_str = occupancy_to_str(occupancy_a, use_6b8x_format=self.protein == "6b8x")
Parameters
----------
altloc_occupancies : dict[str, float]
Mapping of altloc labels to occupancy values,
e.g. ``{"A": 0.5, "B": 0.5}`` or ``{"A": 0.5, "B": 0.3, "C": 0.2}``.
Comment thread
marcuscollins marked this conversation as resolved.
"""
try:
occ_str = occupancy_to_str(**altloc_occupancies)
except ValueError as e:
logger.warning(
f"Cannot determine occupancy string for {self.protein} with occupancies"
f" {altloc_occupancies}: {e}"
)
return None
map_path = self.base_map_dir / self.map_pattern.format(occ_str=occ_str)
if map_path.exists():
return map_path

# TODO: this is a kluge we should work to remove @kchrispens
alt_patterns = []
if self.protein == "6b8x":
alt_patterns.append(f"6b8x_{occupancy_to_str(occupancy_a)}_1.74A.ccp4")

for alt in alt_patterns:
alt_path = self.base_map_dir / alt
if alt_path.exists():
return alt_path

logger.warning(f"Base map for protein {self.protein} ({map_path}) NOT FOUND")
return None

Expand All @@ -90,24 +100,32 @@ def load_map(

return xmap

def get_reference_structure_path(self, occupancy_a: float) -> Path | None:
def get_reference_structure_path(self, altloc_occupancies: dict[str, float]) -> Path | None:
"""Return the reference-structure path for the given altloc occupancies, or ``None``.

Parameters
----------
altloc_occupancies : dict[str, float]
Mapping of altloc labels to occupancy values,
e.g. ``{"A": 0.5, "B": 0.5}`` or ``{"A": 0.5, "B": 0.3, "C": 0.2}``.
"""
if not self.structure_pattern:
return None

occ_str = occupancy_to_str(occupancy_a, use_6b8x_format=self.protein == "6b8x")
try:
occ_str = occupancy_to_str(**altloc_occupancies)
except ValueError as e:
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When does this happen? Can we prevent it with input validation?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This happens when an altloc has a negative or >1 occupancy.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is where input validation can really help. Maybe make an issue to add this feature.

logger.warning(
f"Cannot determine occupancy string for {self.protein} with occupancies"
f" {altloc_occupancies}: {e}"
)
return None
structure_path = self.base_map_dir / self.structure_pattern.format(occ_str=occ_str)
if structure_path.exists():
return structure_path

# Try shifted version for 6b8x
if self.protein == "6b8x":
_pattern = self.structure_pattern.format(occ_str=occ_str)
shifted_path = self.base_map_dir / _pattern.replace(".cif", "_shifted.cif")
if shifted_path.exists():
return shifted_path

logger.warning(
f"Reference structure for {self.protein} with occ {occupancy_a} "
f"Reference structure for {self.protein} with occupancies {altloc_occupancies} "
f"not found: {structure_path}"
)
return None
Expand Down
6 changes: 3 additions & 3 deletions src/sampleworks/eval/grid_search_eval_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def scan_grid_search_results(
model_dir = scaler_dir.parent
protein_dir = model_dir.parent

protein, occ_a = extract_protein_and_occupancy(protein_dir.name)
protein, altloc_occupancies = extract_protein_and_occupancy(protein_dir.name)
method, model = get_method_and_model_name(model_dir.name)

params = parse_experiment_dir(exp_dir)
Expand All @@ -106,7 +106,7 @@ def scan_grid_search_results(
# Validate parameters to satisfy ty
if (
protein is None
or occ_a is None
or not altloc_occupancies
or (model == StructurePredictor.BOLTZ_2 and method is None)
Comment thread
k-chrispens marked this conversation as resolved.
or params["ensemble_size"] is None
or (guidance_weight is None and gd_steps is None)
Expand All @@ -117,7 +117,7 @@ def scan_grid_search_results(
experiments.append(
Experiment(
protein=protein,
occ_a=occ_a,
altloc_occupancies=altloc_occupancies,
model=model,
method=method,
scaler=scaler_dir.name,
Expand Down
Loading