Skip to content

Load pretrained weights notebook and bug#770

Merged
drewoldag merged 3 commits intomainfrom
issue/742/load-pre-trained-weights
Mar 10, 2026
Merged

Load pretrained weights notebook and bug#770
drewoldag merged 3 commits intomainfrom
issue/742/load-pre-trained-weights

Conversation

@drewoldag
Copy link
Copy Markdown
Collaborator

@drewoldag drewoldag commented Mar 9, 2026

Found a bug that prevented continued training with pretrained weights.
Added initial notebook, fixed bug, updated test.

Change Description

Closes #742

Solution Description

Code Quality

  • I have read the Contribution Guide and agree to the Code of Conduct
  • My code follows the code style of this project
  • My code builds (or compiles) cleanly without any errors or warnings
  • My code contains relevant comments and necessary documentation

…eights. Added notebook, fixed bug, updated test.
@drewoldag drewoldag self-assigned this Mar 9, 2026
@review-notebook-app
Copy link
Copy Markdown

Check out this pull request on  ReviewNB

See visual diffs & provide feedback on Jupyter Notebooks.


Powered by ReviewNB

@github-actions
Copy link
Copy Markdown

github-actions Bot commented Mar 10, 2026

Before [f0869f2] After [e701b0a] Ratio Benchmark (Parameter)
failed failed n/a data_cache_benchmarks.DataCacheBenchmarks.time_preload_cache_hsc1k
failed failed n/a data_cache_benchmarks.DataCacheBenchmarks.track_cache_hsc1k_hyrax_size_undercount
failed failed n/a data_request_benchmarks.DatasetRequestBenchmarks.time_request_all_data
38.3±0.1ms 39.5±0.6ms 1.03 benchmarks.time_nb_obj_construct
39.2±0.3ms 39.8±0.6ms 1.02 benchmarks.time_nb_obj_dir
3.74G 3.83G 1.02 vector_db_benchmarks.VectorDBInsertBenchmarks.peakmem_load_vector_db(16384, 'qdrant')
10.00±0.04ms 10.2±0.03ms 1.02 vector_db_benchmarks.VectorDBSearchBenchmarks.time_search_by_vector_many_shards(64, 'chromadb')
1.93±0.01s 1.95±0.01s 1.01 benchmarks.time_help
1.96±0.01s 1.98±0.01s 1.01 benchmarks.time_prepare_help
1.56G 1.57G 1.01 vector_db_benchmarks.VectorDBInsertBenchmarks.peakmem_load_vector_db(16384, 'chromadb')

Click here to view all benchmarks.

@codecov
Copy link
Copy Markdown

codecov Bot commented Mar 10, 2026

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 64.69%. Comparing base (f0869f2) to head (cea1023).
⚠️ Report is 1 commits behind head on main.

Additional details and impacted files
@@           Coverage Diff           @@
##             main     #770   +/-   ##
=======================================
  Coverage   64.69%   64.69%           
=======================================
  Files          61       61           
  Lines        5894     5894           
=======================================
  Hits         3813     3813           
  Misses       2081     2081           

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.
  • 📦 JS Bundle Analysis: Save yourself from yourself by tracking and limiting bundle sizes in JS merges.

…ict change, added notebook to common workflows, new screenshot too.
@drewoldag drewoldag marked this pull request as ready for review March 10, 2026 18:02
Copilot AI review requested due to automatic review settings March 10, 2026 18:02
@drewoldag drewoldag changed the title WIP - Load pretrained weights notebook and bug Load pretrained weights notebook and bug Mar 10, 2026
Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR fixes a bug that prevented continued training with pretrained weights and adds a demonstration notebook for loading pretrained model weights.

Changes:

  • Bug fix: Removed the assign=True parameter from load_state_dict call in model_registry.py
  • Test update: Added epoch config update to test_train_resume to properly cover the resumed training scenario
  • New notebook: Added load_pretrained_model.ipynb demonstrating fine-tuning from saved weights, with a corresponding entry in common_workflows.rst

Reviewed changes

Copilot reviewed 4 out of 5 changed files in this pull request and generated 1 comment.

File Description
src/hyrax/models/model_registry.py Core bug fix — removes assign=True from load_state_dict
tests/hyrax/test_train.py Updates test to set epoch count before resuming
docs/pre_executed/load_pretrained_model.ipynb New notebook demonstrating pretrained weight loading workflow
docs/common_workflows.rst Registers the new notebook in the documentation table of contents

You can also share your feedback on Copilot code review. Take the survey.

Comment thread docs/pre_executed/load_pretrained_model.ipynb
@drewoldag
Copy link
Copy Markdown
Collaborator Author

The fact that sphinx re-ran the notebook (despite being in the pre-executed directory) causes the overly verbose look of the notebook in the docs.

I'll try to track down what is causing sphinx to run the notebook when it shouldn't - for now let's block the merge.

Once the notebook looks bite-sized again, we can move ahead.

Copy link
Copy Markdown
Collaborator

@aritraghsh09 aritraghsh09 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One comment. Not blocking. But please take a look before merging.

state = torch.load(load_path, weights_only=True, map_location=device)

self.load_state_dict(state, assign=True)
self.load_state_dict(state)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just want to check that this won't cause problems when someone is trying to resume from a checkpoint (as opposed to running from a weights file)? If we haven't tested it yet, maybe we should test it on this branch?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I double checked and was able to reproduce the "hyrax checkpointing" notebook. So I don't think we'll be introducing a regression with this change.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Additionally, the code that resumes from a checkpoint doesn't use this codepath, so we should be double ok :)

@drewoldag drewoldag merged commit a83f3da into main Mar 10, 2026
9 of 10 checks passed
@drewoldag drewoldag deleted the issue/742/load-pre-trained-weights branch March 10, 2026 22:05
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

How to load a pretrained model

3 participants