
Updating dataset.py #43

Merged
vratins merged 17 commits into main from dev_dataset_updates on Mar 23, 2026
Conversation

Contributor

@vratins vratins commented Mar 3, 2026

  • Refactor dataset to support multiple encoder types (GVP, SLAE, ESM) with separate cache directories for geometry and embeddings
  • Fix residue counting to include insertion codes
  • Pre-compute and cache PP edge features (unit vectors, RBF) during preprocessing to avoid recomputation at training time (see the sketch after this list)
  • Simplify API by removing unused chain_filter parameter and supporting only <pdb_id>_final format in PDB list files
  • Shift reused functions to utils.py
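
For readers unfamiliar with the preprocessing step above, here is a minimal sketch of the kind of PP edge precomputation described in the third bullet, written against plain PyTorch. The function names, the cutoff default, and the NUM_RBF value are illustrative assumptions, not the repository's actual implementation (the real constant lives in src/constants.py).

import torch

NUM_RBF = 16  # assumed value; the real constant is defined in src/constants.py

def rbf_expand(dist, d_min=0.0, d_max=10.0, num_rbf=NUM_RBF):
    # Expand scalar edge distances into Gaussian radial basis features of shape (E, num_rbf).
    centers = torch.linspace(d_min, d_max, num_rbf, device=dist.device)
    width = (d_max - d_min) / num_rbf
    return torch.exp(-((dist.unsqueeze(-1) - centers) / width) ** 2)

def precompute_pp_edges(coords, cutoff=10.0):
    # Build a radius graph over protein atoms and return cacheable edge features.
    diff = coords.unsqueeze(0) - coords.unsqueeze(1)   # (N, N, 3); diff[i, j] = coords[j] - coords[i]
    dist = diff.norm(dim=-1)                           # (N, N)
    src, dst = ((dist < cutoff) & (dist > 0)).nonzero(as_tuple=True)
    edge_index = torch.stack([src, dst])               # (2, E)
    edge_vec = diff[src, dst]                          # (E, 3)
    unit_vec = edge_vec / edge_vec.norm(dim=-1, keepdim=True)
    edge_rbf = rbf_expand(dist[src, dst])              # (E, NUM_RBF)
    return edge_index, unit_vec, edge_rbf

Storing edge_index, unit_vec, and edge_rbf in the geometry cache is what lets training skip this O(N²) pairwise work on every epoch.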

Summary by CodeRabbit

  • New Features

    • Unified support for cached protein embeddings (automatic loading, padding, and encoder dispatch).
    • Insertion-code–aware residue/water handling, improved symmetry-mate integration, and separate precomputed geometry caching for faster dataset loads.
    • More flexible dataset filtering and encoder selection options, plus updated data loader defaults for performance.
  • Tests

    • Expanded unit/integration tests covering embeddings, dataset caching/behavior, filtering, and edge-feature invariants.
  • Chores

    • Lowered CI coverage enforcement threshold.


Copilot AI left a comment


Pull request overview

This PR refactors the protein-water dataset pipeline to support multiple encoder types (GVP/SLAE/ESM) with a directory-separated cache layout, fixes residue counting to include insertion codes, and precomputes/caches PP edge geometry features during preprocessing to avoid recomputation at training time.

Changes:

  • Introduces directory-based cache separation for geometry vs. embeddings and adds encoder-specific embedding loading (a layout sketch follows this list).
  • Fixes residue identity handling by incorporating PDB insertion codes into residue keys (and related water-quality lookup keys).
  • Precomputes and stores PP edge features (unit vectors + RBF) in the geometry cache.
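
As a rough illustration of the directory-separated layout mentioned in the first bullet, the sketch below shows one plausible arrangement; the directory names and the cache_paths helper are assumptions for the example, not the PR's actual API.

from pathlib import Path

def cache_paths(processed_dir, pdb_id, geometry_cache_name="geometry", encoder_type="gvp"):
    # Geometry (coordinates, PP edges, RBF features) and per-encoder embeddings live in
    # separate subdirectories, so regenerating embeddings for a new encoder type never
    # invalidates the geometry cache.
    root = Path(processed_dir)
    return {
        "geometry": root / geometry_cache_name / f"{pdb_id}.pt",
        "embedding": root / f"{encoder_type}_embeddings" / f"{pdb_id}.pt",
    }

# Example: cache_paths("data/processed", "1deu_final", encoder_type="esm")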

Reviewed changes

Copilot reviewed 6 out of 6 changed files in this pull request and generated 3 comments.

Summary per file:
  • src/dataset.py: Refactors dataset caching/layout, adds encoder-specific embedding loading, caches PP edge features, updates insertion-code handling, removes chain-filter support.
  • src/utils.py: Adds shared helpers for edge geometry/features and insertion-code normalization.
  • src/constants.py: Adds shared feature-dimension constants (e.g., NUM_RBF).
  • tests/test_dataset.py: Updates dataset tests for the new cache layout, adds PP-edge caching coverage, and adds insertion-code integration tests.
  • tests/test_embedding_generation.py: Updates embedding-loading tests to reflect the modular cache layout and encoder selection.
  • tests/test_utils.py: Adds unit tests for the new edge-geometry and insertion-code normalization helpers.


Copilot AI review requested due to automatic review settings March 21, 2026 00:31

coderabbitai Bot commented Mar 21, 2026

Caution

Review failed

The pull request is closed.

ℹ️ Recent review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 9221aa0f-c588-4612-894e-c6eedd11eec3

📥 Commits

Reviewing files that changed from the base of the PR and between f60d8f6 and 77aadaa.

📒 Files selected for processing (1)
  • src/dataset.py

📝 Walkthrough

Walkthrough

Refactors dataset caching to separate geometry and embedding caches, adds shared element constants, normalizes insertion-code keys for EDIA/B-factors, introduces embedding-loading/padding helpers and encoder-key config, updates dataloader defaults, and expands tests for embedding, geometry, and PP precomputation.
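
To make the insertion-code normalization concrete, here is a minimal sketch of what insertion-code-aware residue keys can look like; it assumes normalize_ins_code in src/utils.py behaves roughly like this, and the exact signature and placeholder set may differ.

def normalize_ins_code(ins_code):
    # Treat blank or placeholder insertion codes as empty so keys stay consistent across parsers.
    return "" if ins_code in ("", " ", "?", ".") else ins_code.strip()

def residue_key(chain, resseq, ins_code):
    # Key residues by chain and "resseq + icode" so residues 100 and 100A stay distinct.
    return chain, f"{resseq}{normalize_ins_code(ins_code)}"

assert residue_key("A", 100, " ") != residue_key("A", 100, "A")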

Changes

Cohort / file(s) / summary:
  • CI Workflow (.github/workflows/build.yml): Lowered the pytest coverage fail-under threshold from 80 to 70.
  • Constants (src/constants.py): Added ELEMENT_VOCAB (15 element symbols) and ELEM_IDX (symbol→index mapping).
  • Dataset core (src/dataset.py): Major changes: separate geometry vs. embedding caches (configurable geometry_cache_name, encoder_type), normalize insertion codes for EDIA/B-factors, precompute/store PP graph geometry/edge features, add embedding loaders/padding (load_slae_embedding, load_esm_embedding, _pad_atom_embeddings_for_mates), change pdb_list parsing to <pdb>_final, adjust filtering defaults and signatures, and update DataLoader defaults/options.
  • Encoder base (src/encoder_base.py): CachedEmbeddingEncoder.from_config now accepts an optional embedding_key (defaults to "embedding").
  • Scripts — inference/train (scripts/inference.py, scripts/train.py): Propagate the dataset filter config from the training config into inference; unify cached-embedding handling to protein.embedding with embedding_key="embedding" and embedding_dim; the CLI now uses --embedding_dim and validates it against gvp.
  • Tests — fixtures & dataset (tests/conftest.py, tests/test_dataset.py): Added a create_mock_dataset fixture; updated tests to expect the whole-PDB <id>_final format; expanded tests for geometry cache naming, PP edge features, insertion-code key shapes, embedding loaders/padding, encoder-type validation, mates handling, and many edge cases.
  • Tests — removed/reorganized (tests/test_embedding_generation.py): Deleted (coverage migrated into the updated dataset tests).
  • Tests — utils & encoder (tests/test_utils.py, tests/test_encoder.py): Added tests for compute_edge_geometry, compute_edge_features, and normalize_ins_code; updated encoder tests to use the generic protein.embedding and embedding_type.
  • Tests — train config (tests/test_train_config.py): New tests validating embedding resolution, embedding_dim propagation, CLI validation, and extraction of the dataset filter config for inference.

Sequence Diagram(s)

sequenceDiagram
    autonumber
    participant User
    participant Trainer as Trainer/Inference
    participant Dataset as ProteinWaterDataset
    participant GeomCache as GeometryCache(.pt)
    participant EmbCache as EmbeddingCache(.pt)
    participant Encoder

    User->>Trainer: start run / inference
    Trainer->>Dataset: instantiate(encoder_type, geometry_cache_name, filter_config...)
    Dataset->>GeomCache: load geometry cache (positions, PP edges, edge features)
    alt encoder_type uses cached embeddings
        Dataset->>EmbCache: load `embedding_key` (protein.embedding) for cache_key
        EmbCache-->>Dataset: embedding tensor + embedding_type
        Dataset->>Dataset: pad/align embeddings for mates / residue→atom mapping
    end
    Dataset-->>Trainer: return HeteroData sample (protein.embedding + embedding_type)
    Trainer->>Encoder: build/validate encoder (embedding_key, embedding_dim)
    Trainer->>Encoder: forward(sample)

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Possibly related PRs

  • Updating dataset.py  #43: Mirrors the dataset caching, embedding loaders/padding, insertion-code changes, and related test/script updates in this PR.

Suggested reviewers

  • marcuscollins

Poem

🐰 I hopped through caches split in twain,

geometry snug, embeddings lain,
insertion codes now neat and true,
mates appended, indices grew—
a tiny burrow ready for you.

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 inconclusive)
  • Title check (❓ Inconclusive): The title 'Updating dataset.py' is vague and overly generic, using a non-descriptive verb ('Updating') without specifying the nature or scope of changes. Resolution: use a more specific title that captures the main change, such as 'Refactor dataset.py to support multiple encoder types and cache embeddings separately' or 'Pre-compute PP edge features and fix residue counting with insertion codes in dataset.py'.
✅ Passed checks (2 passed)
  • Description check (✅ Passed): Check skipped; CodeRabbit's high-level summary is enabled.
  • Docstring coverage (✅ Passed): Docstring coverage is 88.28%, which is sufficient; the required threshold is 80.00%.




Copilot AI left a comment


Pull request overview

Copilot reviewed 8 out of 8 changed files in this pull request and generated 6 comments.




@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 5

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
src/dataset.py (1)

926-932: ⚠️ Potential issue | 🟠 Major

Include symmetry copy identifier in mate residue key.

mate_residue_keys = [(a.chain, a.resi) for a in crystal_data["mate_atoms"]] is insufficient because PyMOL's sym* atoms preserve original chain/resi values across different symmetry copies. When multiple copies contribute the same chain/residue label, they collapse to a single residue_index, causing residue-level pooling operations (e.g., in gvp_encoder._pool_by_residue()) to incorrectly aggregate atom features from spatially distinct structural positions. Add a symmetry identifier (e.g., a.model which encodes the symmetry operator, or a.segi if enabled) to the key to distinguish copies.
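
In code, the suggested change amounts to something like the following; whether a.segi or a.model carries the symmetry-copy identifier is an assumption and should be checked against how get_crystal_contacts_pymol populates mate_atoms.

mate_residue_keys = [
    (a.chain, a.resi, a.segi or a.model)  # third element distinguishes symmetry copies
    for a in crystal_data["mate_atoms"]
]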

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/dataset.py` around lines 926 - 932, The current mate residue key
generation (mate_residue_keys = [(a.chain, a.resi) for a in
crystal_data["mate_atoms"]]) collapses different symmetry copies that share the
same chain/resi; update mate_residue_keys to include a symmetry copy identifier
(e.g., a.model or a.segi) so each physical copy is distinct: for each atom in
crystal_data["mate_atoms"] build the key tuple (a.chain, a.resi, a.model) or
(a.chain, a.resi, a.segi) (choose a.model if present else a.segi) then
regenerate unique_mate_res, mate_res_map and mate_res_idx from that augmented
key so downstream functions like gvp_encoder._pool_by_residue() pool by actual
spatial residue copies rather than collapsed labels.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@src/dataset.py`:
- Around line 634-639: The geometry cache directory (self.geometry_dir)
currently only keys on geometry_cache_name and include_mates, so cached tensors
like pp_edge_index/pp_edge_rbf can be reused across different graph settings
(self.cutoff, NUM_RBF) and produce stale loads; update the caching strategy by
incorporating graph parameters into the cache namespace (e.g., append formatted
self.cutoff and NUM_RBF to geometry_cache_name when building geometry_dir) or
alternatively persist and validate metadata alongside the .pt files (store
cutoff and NUM_RBF when saving and check them when loading
pp_edge_index/pp_edge_rbf), and apply the same change/validation logic wherever
geometry_dir is constructed or geometry files are loaded (references:
geometry_dir, processed_dir, pp_edge_index, pp_edge_rbf, self.cutoff, NUM_RBF).
- Around line 662-669: Move the validation of the encoder_type before any
preprocessing side-effects: check self.encoder_type against {"gvp","slae","esm"}
at the start of the constructor (or before calling self._preprocess_all()) and
raise the ValueError early so _preprocess_all() never runs for an invalid
encoder_type; update the order around the existing self._preprocess_all() call
and the encoder_type check to validate first, then call self._preprocess_all().
- Line 1024: Change the unsafe torch.load calls that load cache files to use
weights_only=True: locate the torch.load call that assigns to slae_cached and
the other torch.load calls that load the other two cache dicts, and update each
to call torch.load(..., weights_only=True) so only tensor data is deserialized
while keeping the existing key validation logic intact.
- Around line 1199-1202: The current lambda used as collate_fn for the
DataLoader cannot be pickled when DataLoader spawns worker processes
(num_workers default 8), causing PicklingError on Windows/macOS; replace the
inline lambda with a module-level function named _collate_heterodata that
accepts batch: list[HeteroData] and returns Batch.from_data_list(batch), then
pass _collate_heterodata to DataLoader's collate_fn parameter (references:
collate_fn lambda, DataLoader, _collate_heterodata, Batch.from_data_list).

In `@tests/test_dataset.py`:
- Around line 49-88: Remove the module-local fixtures pdb_base_dir, pdb_6eey,
pdb_2b5w, pdb_8dzt, and pdb_1deu and rely on the env-aware fixtures defined in
tests/conftest.py instead; delete these fixture definitions from
tests/test_dataset.py so the tests will pick up the shared fixtures (which honor
ENV_PDB_DIR and the bundled tests/test_files fallback) and update any test
functions that referenced the removed fixture names to use the corresponding
fixture names from conftest if they differ.

---

Outside diff comments:
In `@src/dataset.py`:
- Around line 926-932: The current mate residue key generation
(mate_residue_keys = [(a.chain, a.resi) for a in crystal_data["mate_atoms"]])
collapses different symmetry copies that share the same chain/resi; update
mate_residue_keys to include a symmetry copy identifier (e.g., a.model or
a.segi) so each physical copy is distinct: for each atom in
crystal_data["mate_atoms"] build the key tuple (a.chain, a.resi, a.model) or
(a.chain, a.resi, a.segi) (choose a.model if present else a.segi) then
regenerate unique_mate_res, mate_res_map and mate_res_idx from that augmented
key so downstream functions like gvp_encoder._pool_by_residue() pool by actual
spatial residue copies rather than collapsed labels.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 7f50184b-fbfb-43c5-b128-35d9c560b14f

📥 Commits

Reviewing files that changed from the base of the PR and between 19cc330 and eef657b.

📒 Files selected for processing (8)
  • .github/workflows/build.yml
  • src/constants.py
  • src/dataset.py
  • src/utils.py
  • tests/conftest.py
  • tests/test_dataset.py
  • tests/test_embedding_generation.py
  • tests/test_utils.py
💤 Files with no reviewable changes (1)
  • tests/test_embedding_generation.py

Collaborator

@marcuscollins marcuscollins left a comment


This is pretty close. There are a couple of changes I think you need to make to avoid later confusion. One is that you appear to have two almost identical methods with almost the same name; I think you should just fix that in this PR. The other is a hard-coded local path in the tests. For that one we could create an issue and reference it in the code here, but it might be worth it to just fix it now.

Comment on lines 794 to 796
asu_water_indices = match_atoms_to_coords(
water_atoms, crystal_data["asu_coords"]
)
Collaborator


I am actually a bit puzzled by what this function is doing here. Aren't the water_atoms you get from parse_asu_with_biotite already in the "ASU"?

I put "ASU" in quotation marks because I think the issue you want to address with get_crystal_contacts_pymol is that some waters appear quite far from the protein but are really just closer to another protein copy related by symmetry, or offset by +/- unit cell lengths. Is my understanding correct @stephaniewankowicz @vratins? If so, I think the approach should be to apply the symmetry operation (and convert to/from fractional coordinates) to put the waters in the "correct" positions, which is what I am currently doing for the water data analysis.

Contributor Author


Right, so I think the comment here may be a bit misleading (I have added a clearer one, imo). The filtering is there to ensure consistency between the biotite and PyMOL parsing of the same PDB file; both read the deposited ASU coordinates, but they may differ in altloc selection or other edge cases. match_atoms_to_coords keeps only the waters that both parsers agree on.
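
For readers skimming the thread, a hedged sketch of what such a consistency filter can look like; it assumes both inputs are already (N, 3) and (M, 3) coordinate arrays, and the tolerance and function body are illustrative rather than the repository's implementation.

import numpy as np

def match_atoms_to_coords(atoms_xyz, ref_coords, tol=1e-3):
    # Keep only atoms whose coordinates have a near-identical counterpart in ref_coords,
    # i.e. waters that both the biotite and PyMOL parses agree on.
    d = np.linalg.norm(atoms_xyz[:, None, :] - ref_coords[None, :, :], axis=-1)  # (N, M) pairwise distances
    return np.where(d.min(axis=1) < tol)[0]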

Copilot AI review requested due to automatic review settings March 23, 2026 16:47
@vratins vratins review requested due to automatic review settings March 23, 2026 16:47

@coderabbitai coderabbitai Bot left a comment


♻️ Duplicate comments (2)
src/dataset.py (2)

1251-1260: ⚠️ Potential issue | 🟠 Major

Replace lambda collate_fn with a module-level function.

Line 1259 uses a lambda as collate_fn. With multiprocessing workers (default num_workers=8), this can fail under spawn-based worker startup due to pickling constraints.

Suggested patch
+def _collate_heterodata(batch: list[HeteroData]) -> Batch:
+    return Batch.from_data_list(batch)
...
-        collate_fn=lambda batch: Batch.from_data_list(batch),
+        collate_fn=_collate_heterodata,
#!/bin/bash
# Verify DataLoader + collate_fn usage
rg -n 'collate_fn\s*=\s*lambda' src/dataset.py -C2
rg -n 'DataLoader\(' src/dataset.py -C6
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/dataset.py` around lines 1251 - 1260, The collate_fn is currently an
inline lambda passed to DataLoader which is not picklable under spawn-based
multiprocessing; replace the lambda with a top-level module function (e.g.,
define a module-level function like collate_batch(batch) that returns
Batch.from_data_list(batch)) and pass that function as collate_fn in the
DataLoader call; ensure the new function lives at module scope (not nested or
bound to a class) so DataLoader (with num_workers>0) can pickle it for worker
processes.

224-224: ⚠️ Potential issue | 🟠 Major

Switch cache deserialization to weights_only=True.

Line 224, Line 265, and Line 1154 still use weights_only=False for .pt cache loads. For cache directories, this keeps unnecessary pickle attack surface; these paths should use safe tensor-only loading.

Suggested patch
-    slae_cached = torch.load(slae_cache_path, weights_only=False)
+    slae_cached = torch.load(slae_cache_path, weights_only=True)
...
-    esm_cached = torch.load(esm_cache_path, weights_only=False)
+    esm_cached = torch.load(esm_cache_path, weights_only=True)
...
-        cached = torch.load(cache_path, weights_only=False)
+        cached = torch.load(cache_path, weights_only=True)
#!/bin/bash
# Verify remaining unsafe loads and inspect save sites
rg -n 'torch\.load\([^)]*weights_only=False' src/dataset.py -C2
fd 'generate_.*embeddings\.py' | xargs -r rg -n 'torch\.save\(' -C2
rg -n 'torch\.save\(' src/dataset.py -C3

Also applies to: 265-265, 1154-1154

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/dataset.py` at line 224, Change torch.load calls that deserialize cache
files to use weights_only=True to avoid unpickling attack surface; specifically
update the call that assigns slae_cached (variable slae_cached) and the other
torch.load invocations in this module to pass weights_only=True instead of
weights_only=False, and verify corresponding torch.save sites still produce
tensor-only content (so loading with weights_only=True succeeds) by inspecting
the save functions that write those .pt caches (search for torch.save usage in
this file).
🧹 Nitpick comments (1)
tests/test_encoder.py (1)

141-157: Add one regression test for non-default embedding_key.

Current tests cover the default "embedding" path well, but the new configurable key path in from_config() is not exercised. A small custom-key test would prevent regressions.

Minimal test addition
+    def test_build_encoder_custom_embedding_key(self, device, sample_hetero_data):
+        config = {"encoder_type": "slae", "embedding_key": "custom_emb"}
+        encoder = build_encoder(config, device)
+        sample_hetero_data["protein"].custom_emb = torch.randn(
+            sample_hetero_data["protein"].num_nodes, 64, device=device
+        )
+        s, V, edge_attr = encoder(sample_hetero_data)
+        assert s.shape[1] == 64
+        assert V.shape[1] == 0
+        assert edge_attr is None

Also applies to: 216-228

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/test_encoder.py` around lines 141 - 157, Add a regression test that
verifies a non-default embedding key is honored by from_config/build_encoder:
create a config with "encoder_type" (e.g., "slae" or "esm") and "embedding_key":
"my_embedding", call build_encoder(config, device) (which uses the classmethod
from_config internally), then assert the returned object is the expected type
(e.g., CachedEmbeddingEncoder) and that its embedding_key (or equivalent
attribute used by from_config) equals "my_embedding"; name the test something
like test_build_encoder_custom_embedding_key to cover the new configurable path.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Duplicate comments:
In `@src/dataset.py`:
- Around line 1251-1260: The collate_fn is currently an inline lambda passed to
DataLoader which is not picklable under spawn-based multiprocessing; replace the
lambda with a top-level module function (e.g., define a module-level function
like collate_batch(batch) that returns Batch.from_data_list(batch)) and pass
that function as collate_fn in the DataLoader call; ensure the new function
lives at module scope (not nested or bound to a class) so DataLoader (with
num_workers>0) can pickle it for worker processes.
- Line 224: Change torch.load calls that deserialize cache files to use
weights_only=True to avoid unpickling attack surface; specifically update the
call that assigns slae_cached (variable slae_cached) and the other torch.load
invocations in this module to pass weights_only=True instead of
weights_only=False, and verify corresponding torch.save sites still produce
tensor-only content (so loading with weights_only=True succeeds) by inspecting
the save functions that write those .pt caches (search for torch.save usage in
this file).

---

Nitpick comments:
In `@tests/test_encoder.py`:
- Around line 141-157: Add a regression test that verifies a non-default
embedding key is honored by from_config/build_encoder: create a config with
"encoder_type" (e.g., "slae" or "esm") and "embedding_key": "my_embedding", call
build_encoder(config, device) (which uses the classmethod from_config
internally), then assert the returned object is the expected type (e.g.,
CachedEmbeddingEncoder) and that its embedding_key (or equivalent attribute used
by from_config) equals "my_embedding"; name the test something like
test_build_encoder_custom_embedding_key to cover the new configurable path.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 8015f904-6137-4e9e-b12e-747297b58e64

📥 Commits

Reviewing files that changed from the base of the PR and between eef657b and a557745.

📒 Files selected for processing (8)
  • src/dataset.py
  • src/encoder_base.py
  • tests/test_dataset.py
  • tests/test_encoder.py
  • tests/test_files/1deu/1deu_final.pdb
  • tests/test_files/2b5w/2b5w_final.pdb
  • tests/test_files/8dzt/8dzt_final.pdb
  • tests/test_utils.py
✅ Files skipped from review due to trivial changes (1)
  • tests/test_dataset.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • tests/test_utils.py

Collaborator

@marcuscollins marcuscollins left a comment


Questions addressed to my satisfaction.

Copilot AI review requested due to automatic review settings March 23, 2026 19:25

Copilot AI left a comment


Pull request overview

Copilot reviewed 12 out of 15 changed files in this pull request and generated 4 comments.




@coderabbitai coderabbitai Bot left a comment


♻️ Duplicate comments (2)
src/dataset.py (2)

224-224: ⚠️ Potential issue | 🟠 Major

Use weights_only=True for safer cache loading.

This is a security concern flagged in a previous review. The SLAE cache file contains only tensors and primitive types, which are fully supported by weights_only=True. Using weights_only=False allows arbitrary code execution from malicious cache files.

Proposed fix
-    slae_cached = torch.load(slae_cache_path, weights_only=False)
+    slae_cached = torch.load(slae_cache_path, weights_only=True)

Also applies to lines 265 and 1154.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/dataset.py` at line 224, The torch.load call that assigns slae_cached
uses weights_only=False which permits unsafe deserialization; update the
torch.load invocations (the one assigning slae_cached and the other two
occurrences referenced at lines ~265 and ~1154) to pass weights_only=True so
only tensor/primitive data is loaded safely, leaving all other arguments
unchanged and keeping the same variable names (e.g., slae_cached) and
surrounding logic.

1257-1259: ⚠️ Potential issue | 🟠 Major

Replace the lambda collate_fn with a module-level function.

With num_workers=8 as the default, the DataLoader will use multiprocessing. On Windows/macOS (which use the spawn start method), lambda functions cannot be pickled, causing a PicklingError. This breaks the code on those platforms.

Proposed fix: Add module-level function

Add near the top of the file:

def _collate_heterodata(batch: list[HeteroData]) -> Batch:
    """Collate function for HeteroData batches."""
    return Batch.from_data_list(batch)

Then update the DataLoader call:

-        collate_fn=lambda batch: Batch.from_data_list(batch),
+        collate_fn=_collate_heterodata,
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/dataset.py` around lines 1257 - 1259, The collate_fn currently uses an
inline lambda (collate_fn=lambda batch: Batch.from_data_list(batch)) which is
not picklable for multiprocessing on spawn-based platforms; create a
module-level function named _collate_heterodata(batch: list[HeteroData]) ->
Batch that simply returns Batch.from_data_list(batch) and replace the lambda in
the DataLoader call with collate_fn=_collate_heterodata so the function can be
pickled across processes.
🧹 Nitpick comments (1)
tests/test_train_config.py (1)

75-77: Minor: Redundant assertion.

Line 77 is redundant since line 76 already verifies that embedding_dim exists and has the correct value. Consider removing it for clarity.

Proposed simplification
     assert config["embedding_key"] == "embedding"
     assert config["embedding_dim"] == 128
-    assert "embedding_dim" in config
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/test_train_config.py` around lines 75 - 77, Remove the redundant
existence check for embedding_dim in the test: the assertion "assert
'embedding_dim' in config" is unnecessary because the prior assertion "assert
config['embedding_dim'] == 128" already verifies the key exists and has the
expected value; update the test in tests/test_train_config.py by deleting the
redundant assertion so only "assert config['embedding_key'] == 'embedding'" and
"assert config['embedding_dim'] == 128" remain.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Duplicate comments:
In `@src/dataset.py`:
- Line 224: The torch.load call that assigns slae_cached uses weights_only=False
which permits unsafe deserialization; update the torch.load invocations (the one
assigning slae_cached and the other two occurrences referenced at lines ~265 and
~1154) to pass weights_only=True so only tensor/primitive data is loaded safely,
leaving all other arguments unchanged and keeping the same variable names (e.g.,
slae_cached) and surrounding logic.
- Around line 1257-1259: The collate_fn currently uses an inline lambda
(collate_fn=lambda batch: Batch.from_data_list(batch)) which is not picklable
for multiprocessing on spawn-based platforms; create a module-level function
named _collate_heterodata(batch: list[HeteroData]) -> Batch that simply returns
Batch.from_data_list(batch) and replace the lambda in the DataLoader call with
collate_fn=_collate_heterodata so the function can be pickled across processes.

---

Nitpick comments:
In `@tests/test_train_config.py`:
- Around line 75-77: Remove the redundant existence check for embedding_dim in
the test: the assertion "assert 'embedding_dim' in config" is unnecessary
because the prior assertion "assert config['embedding_dim'] == 128" already
verifies the key exists and has the expected value; update the test in
tests/test_train_config.py by deleting the redundant assertion so only "assert
config['embedding_key'] == 'embedding'" and "assert config['embedding_dim'] ==
128" remain.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: c899e27f-fb48-43fe-8ea8-34ff2e093578

📥 Commits

Reviewing files that changed from the base of the PR and between a557745 and f60d8f6.

📒 Files selected for processing (4)
  • scripts/inference.py
  • scripts/train.py
  • src/dataset.py
  • tests/test_train_config.py

Copilot AI review requested due to automatic review settings March 23, 2026 19:37
@vratins vratins review requested due to automatic review settings March 23, 2026 19:37
@vratins vratins merged commit 8dcfc69 into main Mar 23, 2026
