Conversation
Pull request overview
This PR refactors the protein-water dataset pipeline to support multiple encoder types (GVP/SLAE/ESM) with a directory-separated cache layout, fixes residue counting to include insertion codes, and precomputes/caches PP edge geometry features during preprocessing to avoid recomputation at training time.
Changes:
- Introduces directory-based cache separation for geometry vs. embeddings and adds encoder-specific embedding loading.
- Fixes residue identity handling by incorporating PDB insertion codes into residue keys (and related water-quality lookup keys).
- Precomputes and stores PP edge features (unit vectors + RBF) in the geometry cache (sketched below).
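For context, "unit vectors + RBF" PP edge features are typically computed as follows. This is a minimal sketch assuming row-wise 3D positions and a `[2, E]` edge index; the function names, `num_rbf=16` default, and 10 Å span are illustrative, not the repo's actual constants:

```python
import torch

def rbf_expand(dist: torch.Tensor, num_rbf: int = 16, d_max: float = 10.0) -> torch.Tensor:
    """Expand distances [E] into Gaussian RBF features [E, num_rbf]."""
    centers = torch.linspace(0.0, d_max, num_rbf, device=dist.device)
    width = d_max / num_rbf
    return torch.exp(-(((dist.unsqueeze(-1) - centers) / width) ** 2))

def pp_edge_features(pos: torch.Tensor, edge_index: torch.Tensor):
    """Return unit vectors [E, 3] and RBF features [E, num_rbf] per edge."""
    src, dst = edge_index
    vec = pos[dst] - pos[src]
    dist = vec.norm(dim=-1)
    unit = vec / dist.clamp(min=1e-8).unsqueeze(-1)  # avoid divide-by-zero
    return unit, rbf_expand(dist)
```

Precomputing these once at preprocessing time and storing them in the geometry cache is what removes the per-batch recomputation cost at training time.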
Reviewed changes
Copilot reviewed 6 out of 6 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
| src/dataset.py | Refactors dataset caching/layout, adds encoder-specific embedding loading, caches PP edge features, updates insertion-code handling, removes chain-filter support. |
| src/utils.py | Adds shared helpers for edge geometry/features and insertion-code normalization (sketched below). |
| src/constants.py | Adds shared feature-dimension constants (e.g., NUM_RBF). |
| tests/test_dataset.py | Updates dataset tests for the new cache layout, adds PP-edge caching coverage, and adds insertion-code integration tests. |
| tests/test_embedding_generation.py | Updates embedding-loading tests to reflect the modular cache layout and encoder selection. |
| tests/test_utils.py | Adds unit tests for the new edge-geometry and insertion-code normalization helpers. |
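The insertion-code normalization helpers listed for src/utils.py above presumably work along these lines. A hedged sketch; the helper name and key shape are assumptions, not the repo's API:

```python
def normalize_residue_key(chain: str, resi: str) -> tuple[str, str, str]:
    """Split a PDB residue id like '52A' into (chain, number, icode).

    PyMOL's resi field can carry a trailing insertion code, while biotite
    keeps res_id and ins_code separate; a shared normalized key lets the
    two parsers agree on residue identity.
    """
    resi = resi.strip()
    if resi and resi[-1].isalpha():
        return chain, resi[:-1], resi[-1].upper()
    return chain, resi, ""
```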
81299fe to 4d78b16
Caution: Review failed. The pull request is closed.

ℹ️ Recent review info
⚙️ Run configuration: Configuration used: defaults | Review profile: CHILL | Plan: Pro | Run ID:
📒 Files selected for processing (1)
📝 Walkthrough

Refactors dataset caching to separate geometry and embedding caches, adds shared element constants, normalizes insertion-code keys for EDIA/B-factors, introduces embedding-loading/padding helpers and encoder-key config, updates dataloader defaults, and expands tests for embedding, geometry, and PP precomputation.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    autonumber
    participant User
    participant Trainer as Trainer/Inference
    participant Dataset as ProteinWaterDataset
    participant GeomCache as GeometryCache(.pt)
    participant EmbCache as EmbeddingCache(.pt)
    participant Encoder
    User->>Trainer: start run / inference
    Trainer->>Dataset: instantiate(encoder_type, geometry_cache_name, filter_config...)
    Dataset->>GeomCache: load geometry cache (positions, PP edges, edge features)
    alt encoder_type uses cached embeddings
        Dataset->>EmbCache: load embedding_key (protein.embedding) for cache_key
        EmbCache-->>Dataset: embedding tensor + embedding_type
        Dataset->>Dataset: pad/align embeddings for mates / residue→atom mapping
    end
    Dataset-->>Trainer: return HeteroData sample (protein.embedding + embedding_type)
    Trainer->>Encoder: build/validate encoder (embedding_key, embedding_dim)
    Trainer->>Encoder: forward(sample)
```
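The residue→atom step in the diagram amounts to an index gather: one embedding row per residue, broadcast out to that residue's atoms. A minimal sketch with assumed tensor names and shapes:

```python
import torch

# residue_embeddings: [num_residues, D] rows from the embedding cache;
# atom_residue_idx: [num_atoms] maps each atom to its residue row.
residue_embeddings = torch.randn(5, 128)
atom_residue_idx = torch.tensor([0, 0, 1, 2, 2, 2, 3, 4])
atom_embeddings = residue_embeddings[atom_residue_idx]  # [num_atoms, D]
```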
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
🚥 Pre-merge checks: ✅ 2 | ❌ 1
❌ Failed checks (1 inconclusive)
✅ Passed checks (2 passed)
Pull request overview
Copilot reviewed 8 out of 8 changed files in this pull request and generated 6 comments.
Actionable comments posted: 5
Caution: Some comments are outside the diff and can't be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
src/dataset.py (1)
926-932: ⚠️ Potential issue | 🟠 Major: Include symmetry copy identifier in mate residue key.

`mate_residue_keys = [(a.chain, a.resi) for a in crystal_data["mate_atoms"]]` is insufficient because PyMOL's `sym*` atoms preserve original chain/resi values across different symmetry copies. When multiple copies contribute the same chain/residue label, they collapse to a single `residue_index`, causing residue-level pooling operations (e.g., in `gvp_encoder._pool_by_residue()`) to incorrectly aggregate atom features from spatially distinct structural positions. Add a symmetry identifier (e.g., `a.model`, which encodes the symmetry operator, or `a.segi` if enabled) to the key to distinguish copies.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/dataset.py` around lines 926 - 932, The current mate residue key generation (mate_residue_keys = [(a.chain, a.resi) for a in crystal_data["mate_atoms"]]) collapses different symmetry copies that share the same chain/resi; update mate_residue_keys to include a symmetry copy identifier (e.g., a.model or a.segi) so each physical copy is distinct: for each atom in crystal_data["mate_atoms"] build the key tuple (a.chain, a.resi, a.model) or (a.chain, a.resi, a.segi) (choose a.model if present else a.segi) then regenerate unique_mate_res, mate_res_map and mate_res_idx from that augmented key so downstream functions like gvp_encoder._pool_by_residue() pool by actual spatial residue copies rather than collapsed labels.
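A minimal sketch of the fix this comment describes, assuming PyMOL-style atom objects whose `model` attribute distinguishes symmetry copies (the choice of `model` vs. `segi` follows the comment's own suggestion and is an assumption about the repo's data):

```python
# Distinguish symmetry copies that share the same chain/resi labels by
# folding a per-copy identifier into the residue key.
mate_residue_keys = [
    (a.chain, a.resi, a.model)  # or a.segi, if segment ids are populated
    for a in crystal_data["mate_atoms"]
]

# Rebuild the residue index from the augmented keys so pooling treats
# each symmetry copy as a distinct residue.
unique_mate_res = sorted(set(mate_residue_keys))
mate_res_map = {key: idx for idx, key in enumerate(unique_mate_res)}
mate_res_idx = [mate_res_map[key] for key in mate_residue_keys]
```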
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@src/dataset.py`:
- Around line 634-639: The geometry cache directory (self.geometry_dir)
currently only keys on geometry_cache_name and include_mates, so cached tensors
like pp_edge_index/pp_edge_rbf can be reused across different graph settings
(self.cutoff, NUM_RBF) and produce stale loads; update the caching strategy by
incorporating graph parameters into the cache namespace (e.g., append formatted
self.cutoff and NUM_RBF to geometry_cache_name when building geometry_dir) or
alternatively persist and validate metadata alongside the .pt files (store
cutoff and NUM_RBF when saving and check them when loading
pp_edge_index/pp_edge_rbf), and apply the same change/validation logic wherever
geometry_dir is constructed or geometry files are loaded (references:
geometry_dir, processed_dir, pp_edge_index, pp_edge_rbf, self.cutoff, NUM_RBF). A sketch of the cache-namespace option appears after this list.
- Around line 662-669: Move the validation of the encoder_type before any
preprocessing side-effects: check self.encoder_type against {"gvp","slae","esm"}
at the start of the constructor (or before calling self._preprocess_all()) and
raise the ValueError early so _preprocess_all() never runs for an invalid
encoder_type; update the order around the existing self._preprocess_all() call
and the encoder_type check to validate first, then call self._preprocess_all().
- Line 1024: Change the unsafe torch.load calls that load cache files to use
weights_only=True: locate the torch.load call that assigns to slae_cached and
the other torch.load calls that load the other two cache dicts, and update each
to call torch.load(..., weights_only=True) so only tensor data is deserialized
while keeping the existing key validation logic intact.
- Around line 1199-1202: The current lambda used as collate_fn for the
DataLoader cannot be pickled when DataLoader spawns worker processes
(num_workers default 8), causing PicklingError on Windows/macOS; replace the
inline lambda with a module-level function named _collate_heterodata that
accepts batch: list[HeteroData] and returns Batch.from_data_list(batch), then
pass _collate_heterodata to DataLoader's collate_fn parameter (references:
collate_fn lambda, DataLoader, _collate_heterodata, Batch.from_data_list).
In `@tests/test_dataset.py`:
- Around line 49-88: Remove the module-local fixtures pdb_base_dir, pdb_6eey,
pdb_2b5w, pdb_8dzt, and pdb_1deu and rely on the env-aware fixtures defined in
tests/conftest.py instead; delete these fixture definitions from
tests/test_dataset.py so the tests will pick up the shared fixtures (which honor
ENV_PDB_DIR and the bundled tests/test_files fallback) and update any test
functions that referenced the removed fixture names to use the corresponding
fixture names from conftest if they differ.
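For the first src/dataset.py item above, folding graph parameters into the geometry cache namespace could look like the following. A sketch under assumed names: `geometry_cache_name`, `cutoff`, and `NUM_RBF` mirror the prompt, while the function and path layout are invented for illustration:

```python
from pathlib import Path

NUM_RBF = 16  # assumed value; the repo defines this in src/constants.py

def build_geometry_dir(processed_dir: str, geometry_cache_name: str,
                       include_mates: bool, cutoff: float) -> Path:
    """Namespace the geometry cache by every parameter that shapes it.

    Tensors such as pp_edge_index / pp_edge_rbf built under one
    (cutoff, NUM_RBF) setting can then never be loaded under another.
    """
    name = (
        f"{geometry_cache_name}"
        f"_mates{int(include_mates)}"
        f"_cutoff{cutoff:g}_rbf{NUM_RBF}"
    )
    path = Path(processed_dir) / name
    path.mkdir(parents=True, exist_ok=True)
    return path
```

The alternative the prompt mentions, storing `cutoff` and `NUM_RBF` as metadata inside the saved `.pt` files and validating them on load, trades a proliferation of directories for an explicit staleness check.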
---
Outside diff comments:
In `@src/dataset.py`:
- Around line 926-932: The current mate residue key generation
(mate_residue_keys = [(a.chain, a.resi) for a in crystal_data["mate_atoms"]])
collapses different symmetry copies that share the same chain/resi; update
mate_residue_keys to include a symmetry copy identifier (e.g., a.model or
a.segi) so each physical copy is distinct: for each atom in
crystal_data["mate_atoms"] build the key tuple (a.chain, a.resi, a.model) or
(a.chain, a.resi, a.segi) (choose a.model if present else a.segi) then
regenerate unique_mate_res, mate_res_map and mate_res_idx from that augmented
key so downstream functions like gvp_encoder._pool_by_residue() pool by actual
spatial residue copies rather than collapsed labels.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 7f50184b-fbfb-43c5-b128-35d9c560b14f
📒 Files selected for processing (8)
- .github/workflows/build.yml
- src/constants.py
- src/dataset.py
- src/utils.py
- tests/conftest.py
- tests/test_dataset.py
- tests/test_embedding_generation.py
- tests/test_utils.py
💤 Files with no reviewable changes (1)
- tests/test_embedding_generation.py
marcuscollins
left a comment
This is pretty close. There are a couple changes I think you need to make in order to avoid later confusion. One is that you appear to have two almost identical methods with almost the same name. I think you should just fix it in this PR. The other is a hard-coded local path in the tests. That one I guess we could create an issue and reference it in the code here, but it might be worth it to just fix it now.
```python
asu_water_indices = match_atoms_to_coords(
    water_atoms, crystal_data["asu_coords"]
)
```
I am actually a bit puzzled by what this function is doing here. Aren't the `water_atoms` you get from `parse_asu_with_biotite` already in the "ASU"?

I put "ASU" in quotation marks because I think the issue you may want to address with `get_crystal_contacts_pymol` is that some waters appear quite far from the protein, when really they are just closer to another protein copy related by symmetry, or off by +/- unit cell lengths. Is my understanding correct @stephaniewankowicz @vratins ? If so, I think the approach should be to apply the sym op (and convert to/from fractional coordinates) to put the waters in the "correct" positions, which is what I am currently doing for water data analysis (a sketch follows this exchange).
Right, so I think the comment here may be a bit misleading (I have added a clearer one, imo). The filtering is to ensure consistency between biotite and PyMOL parsing of the same PDB file; both read the deposited ASU coordinates but may differ in altloc selection or edge cases. `match_atoms_to_coords` keeps only the waters that both parsers agree on.
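For reference, the sym-op approach described above usually looks like the following. A sketch assuming a fractional-space operator (R, t) and the crystal's Cartesian/fractional transform matrices; all names here are illustrative, not this repo's API:

```python
import numpy as np

def apply_sym_op(cart_coords: np.ndarray, frac_mat: np.ndarray,
                 orth_mat: np.ndarray, rotation: np.ndarray,
                 translation: np.ndarray) -> np.ndarray:
    """Map Cartesian coords [N, 3] through one symmetry operator.

    frac_mat: 3x3 Cartesian -> fractional transform
    orth_mat: 3x3 fractional -> Cartesian transform
    rotation, translation: the operator (R, t) acting in fractional space
    """
    frac = cart_coords @ frac_mat.T           # to fractional coordinates
    frac = frac @ rotation.T + translation    # apply (R, t)
    frac -= np.floor(frac)                    # wrap by +/- unit cell lengths
    return frac @ orth_mat.T                  # back to Cartesian
```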
♻️ Duplicate comments (2)
src/dataset.py (2)
1251-1260: ⚠️ Potential issue | 🟠 Major: Replace lambda `collate_fn` with a module-level function.

Line 1259 uses a lambda as `collate_fn`. With multiprocessing workers (default `num_workers=8`), this can fail under spawn-based worker startup due to pickling constraints.

Suggested patch

```diff
+def _collate_heterodata(batch: list[HeteroData]) -> Batch:
+    return Batch.from_data_list(batch)
 ...
-        collate_fn=lambda batch: Batch.from_data_list(batch),
+        collate_fn=_collate_heterodata,
```

```bash
#!/bin/bash
# Verify DataLoader + collate_fn usage
rg -n 'collate_fn\s*=\s*lambda' src/dataset.py -C2
rg -n 'DataLoader\(' src/dataset.py -C6
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/dataset.py` around lines 1251 - 1260, The collate_fn is currently an inline lambda passed to DataLoader which is not picklable under spawn-based multiprocessing; replace the lambda with a top-level module function (e.g., define a module-level function like collate_batch(batch) that returns Batch.from_data_list(batch)) and pass that function as collate_fn in the DataLoader call; ensure the new function lives at module scope (not nested or bound to a class) so DataLoader (with num_workers>0) can pickle it for worker processes.
224-224: ⚠️ Potential issue | 🟠 Major: Switch cache deserialization to `weights_only=True`.

Line 224, line 265, and line 1154 still use `weights_only=False` for `.pt` cache loads. For cache directories, this keeps unnecessary pickle attack surface; these paths should use safe tensor-only loading.

Suggested patch

```diff
-    slae_cached = torch.load(slae_cache_path, weights_only=False)
+    slae_cached = torch.load(slae_cache_path, weights_only=True)
 ...
-    esm_cached = torch.load(esm_cache_path, weights_only=False)
+    esm_cached = torch.load(esm_cache_path, weights_only=True)
 ...
-    cached = torch.load(cache_path, weights_only=False)
+    cached = torch.load(cache_path, weights_only=True)
```

```bash
#!/bin/bash
# Verify remaining unsafe loads and inspect save sites
rg -n 'torch\.load\([^)]*weights_only=False' src/dataset.py -C2
fd 'generate_.*embeddings\.py' | xargs -r rg -n 'torch\.save\(' -C2
rg -n 'torch\.save\(' src/dataset.py -C3
```

Also applies to: 265-265, 1154-1154
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/dataset.py` at line 224, Change torch.load calls that deserialize cache files to use weights_only=True to avoid unpickling attack surface; specifically update the call that assigns slae_cached (variable slae_cached) and the other torch.load invocations in this module to pass weights_only=True instead of weights_only=False, and verify corresponding torch.save sites still produce tensor-only content (so loading with weights_only=True succeeds) by inspecting the save functions that write those .pt caches (search for torch.save usage in this file).
🧹 Nitpick comments (1)
tests/test_encoder.py (1)
141-157: Add one regression test for a non-default `embedding_key`.

Current tests cover the default `"embedding"` path well, but the new configurable key path in `from_config()` is not exercised. A small custom-key test would prevent regressions.

Minimal test addition

```diff
+    def test_build_encoder_custom_embedding_key(self, device, sample_hetero_data):
+        config = {"encoder_type": "slae", "embedding_key": "custom_emb"}
+        encoder = build_encoder(config, device)
+        sample_hetero_data["protein"].custom_emb = torch.randn(
+            sample_hetero_data["protein"].num_nodes, 64, device=device
+        )
+        s, V, edge_attr = encoder(sample_hetero_data)
+        assert s.shape[1] == 64
+        assert V.shape[1] == 0
+        assert edge_attr is None
```

Also applies to: 216-228
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/test_encoder.py` around lines 141 - 157, Add a regression test that verifies a non-default embedding key is honored by from_config/build_encoder: create a config with "encoder_type" (e.g., "slae" or "esm") and "embedding_key": "my_embedding", call build_encoder(config, device) (which uses the classmethod from_config internally), then assert the returned object is the expected type (e.g., CachedEmbeddingEncoder) and that its embedding_key (or equivalent attribute used by from_config) equals "my_embedding"; name the test something like test_build_encoder_custom_embedding_key to cover the new configurable path.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Duplicate comments:
In `@src/dataset.py`:
- Around line 1251-1260: The collate_fn is currently an inline lambda passed to
DataLoader which is not picklable under spawn-based multiprocessing; replace the
lambda with a top-level module function (e.g., define a module-level function
like collate_batch(batch) that returns Batch.from_data_list(batch)) and pass
that function as collate_fn in the DataLoader call; ensure the new function
lives at module scope (not nested or bound to a class) so DataLoader (with
num_workers>0) can pickle it for worker processes.
- Line 224: Change torch.load calls that deserialize cache files to use
weights_only=True to avoid unpickling attack surface; specifically update the
call that assigns slae_cached (variable slae_cached) and the other torch.load
invocations in this module to pass weights_only=True instead of
weights_only=False, and verify corresponding torch.save sites still produce
tensor-only content (so loading with weights_only=True succeeds) by inspecting
the save functions that write those .pt caches (search for torch.save usage in
this file).
---
Nitpick comments:
In `@tests/test_encoder.py`:
- Around line 141-157: Add a regression test that verifies a non-default
embedding key is honored by from_config/build_encoder: create a config with
"encoder_type" (e.g., "slae" or "esm") and "embedding_key": "my_embedding", call
build_encoder(config, device) (which uses the classmethod from_config
internally), then assert the returned object is the expected type (e.g.,
CachedEmbeddingEncoder) and that its embedding_key (or equivalent attribute used
by from_config) equals "my_embedding"; name the test something like
test_build_encoder_custom_embedding_key to cover the new configurable path.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 8015f904-6137-4e9e-b12e-747297b58e64
📒 Files selected for processing (8)
- src/dataset.py
- src/encoder_base.py
- tests/test_dataset.py
- tests/test_encoder.py
- tests/test_files/1deu/1deu_final.pdb
- tests/test_files/2b5w/2b5w_final.pdb
- tests/test_files/8dzt/8dzt_final.pdb
- tests/test_utils.py
✅ Files skipped from review due to trivial changes (1)
- tests/test_dataset.py
🚧 Files skipped from review as they are similar to previous changes (1)
- tests/test_utils.py
marcuscollins
left a comment
Questions addressed to my satisfaction.
Pull request overview
Copilot reviewed 12 out of 15 changed files in this pull request and generated 4 comments.
♻️ Duplicate comments (2)
src/dataset.py (2)
224-224: ⚠️ Potential issue | 🟠 Major: Use `weights_only=True` for safer cache loading.

This is a security concern flagged in a previous review. The SLAE cache file contains only tensors and primitive types, which are fully supported by `weights_only=True`. Using `weights_only=False` allows arbitrary code execution from malicious cache files.

Proposed fix

```diff
-    slae_cached = torch.load(slae_cache_path, weights_only=False)
+    slae_cached = torch.load(slae_cache_path, weights_only=True)
```

Also applies to lines 265 and 1154.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/dataset.py` at line 224, The torch.load call that assigns slae_cached uses weights_only=False which permits unsafe deserialization; update the torch.load invocations (the one assigning slae_cached and the other two occurrences referenced at lines ~265 and ~1154) to pass weights_only=True so only tensor/primitive data is loaded safely, leaving all other arguments unchanged and keeping the same variable names (e.g., slae_cached) and surrounding logic.
1257-1259: ⚠️ Potential issue | 🟠 Major: Replace the lambda `collate_fn` with a module-level function.

With `num_workers=8` as the default, the DataLoader will use multiprocessing. On Windows/macOS (which use the `spawn` start method), lambda functions cannot be pickled, causing a `PicklingError`. This breaks the code on those platforms.

Proposed fix: add a module-level function near the top of the file:

```python
def _collate_heterodata(batch: list[HeteroData]) -> Batch:
    """Collate function for HeteroData batches."""
    return Batch.from_data_list(batch)
```

Then update the DataLoader call:

```diff
-        collate_fn=lambda batch: Batch.from_data_list(batch),
+        collate_fn=_collate_heterodata,
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/dataset.py` around lines 1257 - 1259, The collate_fn currently uses an inline lambda (collate_fn=lambda batch: Batch.from_data_list(batch)) which is not picklable for multiprocessing on spawn-based platforms; create a module-level function named _collate_heterodata(batch: list[HeteroData]) -> Batch that simply returns Batch.from_data_list(batch) and replace the lambda in the DataLoader call with collate_fn=_collate_heterodata so the function can be pickled across processes.
🧹 Nitpick comments (1)
tests/test_train_config.py (1)
75-77: Minor: Redundant assertion.

Line 77 is redundant, since line 76 already verifies that `embedding_dim` exists and has the correct value. Consider removing it for clarity.

Proposed simplification

```diff
 assert config["embedding_key"] == "embedding"
 assert config["embedding_dim"] == 128
-assert "embedding_dim" in config
```
Verify each finding against the current code and only fix it if needed. In `@tests/test_train_config.py` around lines 75 - 77, Remove the redundant existence check for embedding_dim in the test: the assertion "assert 'embedding_dim' in config" is unnecessary because the prior assertion "assert config['embedding_dim'] == 128" already verifies the key exists and has the expected value; update the test in tests/test_train_config.py by deleting the redundant assertion so only "assert config['embedding_key'] == 'embedding'" and "assert config['embedding_dim'] == 128" remain.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Duplicate comments:
In `@src/dataset.py`:
- Line 224: The torch.load call that assigns slae_cached uses weights_only=False
which permits unsafe deserialization; update the torch.load invocations (the one
assigning slae_cached and the other two occurrences referenced at lines ~265 and
~1154) to pass weights_only=True so only tensor/primitive data is loaded safely,
leaving all other arguments unchanged and keeping the same variable names (e.g.,
slae_cached) and surrounding logic.
- Around line 1257-1259: The collate_fn currently uses an inline lambda
(collate_fn=lambda batch: Batch.from_data_list(batch)) which is not picklable
for multiprocessing on spawn-based platforms; create a module-level function
named _collate_heterodata(batch: list[HeteroData]) -> Batch that simply returns
Batch.from_data_list(batch) and replace the lambda in the DataLoader call with
collate_fn=_collate_heterodata so the function can be pickled across processes.
---
Nitpick comments:
In `@tests/test_train_config.py`:
- Around line 75-77: Remove the redundant existence check for embedding_dim in
the test: the assertion "assert 'embedding_dim' in config" is unnecessary
because the prior assertion "assert config['embedding_dim'] == 128" already
verifies the key exists and has the expected value; update the test in
tests/test_train_config.py by deleting the redundant assertion so only "assert
config['embedding_key'] == 'embedding'" and "assert config['embedding_dim'] ==
128" remain.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: c899e27f-fb48-43fe-8ea8-34ff2e093578
📒 Files selected for processing (4)
- scripts/inference.py
- scripts/train.py
- src/dataset.py
- tests/test_train_config.py

Summary by CodeRabbit
New Features
Tests
Chores