Load pretrained weights notebook and bug#770
Conversation
…eights. Added notebook, fixed bug, updated test.
|
Check out this pull request on See visual diffs & provide feedback on Jupyter Notebooks. Powered by ReviewNB |
Click here to view all benchmarks. |
Codecov Report✅ All modified and coverable lines are covered by tests. Additional details and impacted files@@ Coverage Diff @@
## main #770 +/- ##
=======================================
Coverage 64.69% 64.69%
=======================================
Files 61 61
Lines 5894 5894
=======================================
Hits 3813 3813
Misses 2081 2081 ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
…ict change, added notebook to common workflows, new screenshot too.
There was a problem hiding this comment.
Pull request overview
This PR fixes a bug that prevented continued training with pretrained weights and adds a demonstration notebook for loading pretrained model weights.
Changes:
- Bug fix: Removed the
assign=Trueparameter fromload_state_dictcall inmodel_registry.py - Test update: Added epoch config update to
test_train_resumeto properly cover the resumed training scenario - New notebook: Added
load_pretrained_model.ipynbdemonstrating fine-tuning from saved weights, with a corresponding entry incommon_workflows.rst
Reviewed changes
Copilot reviewed 4 out of 5 changed files in this pull request and generated 1 comment.
| File | Description |
|---|---|
src/hyrax/models/model_registry.py |
Core bug fix — removes assign=True from load_state_dict |
tests/hyrax/test_train.py |
Updates test to set epoch count before resuming |
docs/pre_executed/load_pretrained_model.ipynb |
New notebook demonstrating pretrained weight loading workflow |
docs/common_workflows.rst |
Registers the new notebook in the documentation table of contents |
You can also share your feedback on Copilot code review. Take the survey.
|
The fact that sphinx re-ran the notebook (despite being in the pre-executed directory) causes the overly verbose look of the notebook in the docs. I'll try to track down what is causing sphinx to run the notebook when it shouldn't - for now let's block the merge. Once the notebook looks bite-sized again, we can move ahead. |
aritraghsh09
left a comment
There was a problem hiding this comment.
One comment. Not blocking. But please take a look before merging.
| state = torch.load(load_path, weights_only=True, map_location=device) | ||
|
|
||
| self.load_state_dict(state, assign=True) | ||
| self.load_state_dict(state) |
There was a problem hiding this comment.
I just want to check that this won't cause problems when someone is trying to resume from a checkpoint (as opposed to running from a weights file)? If we haven't tested it yet, maybe we should test it on this branch?
There was a problem hiding this comment.
I double checked and was able to reproduce the "hyrax checkpointing" notebook. So I don't think we'll be introducing a regression with this change.
There was a problem hiding this comment.
Additionally, the code that resumes from a checkpoint doesn't use this codepath, so we should be double ok :)
Found a bug that prevented continued training with pretrained weights.
Added initial notebook, fixed bug, updated test.
Change Description
Closes #742
Solution Description
Code Quality