Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow specifying extra args for default model in ModelConfig #1713

Merged
merged 1 commit into from Feb 20, 2023

Conversation

AdeelH
Copy link
Collaborator

@AdeelH AdeelH commented Feb 20, 2023

Overview

This PR adds an extra_args field to ModelConfig which accepts a dict. In each ModelConfig subclass, these args are passed as keyword args to the torchvision function used to construct the default model.

This is mainly motivated by the FasterCNN used in OD which accepts a large number of keyword args that can be important but are too many to be added as individual fields to ObjectDetectionModelConfig.

Checklist

  • Added needs-backport label if PR is bug fix that applies to previous minor release
  • Ran scripts/format_code and committed any changes
  • Documentation updated if needed
  • PR has a name that won't get you publicly shamed for vagueness

Notes

N/A

Testing Instructions

See new unit test.

@AdeelH AdeelH added the needs-backport This PR needs to be backported to release branches label Feb 20, 2023
@AdeelH AdeelH marked this pull request as ready for review February 20, 2023 15:43
@codecov
Copy link

codecov bot commented Feb 20, 2023

Codecov Report

Merging #1713 (9b5b493) into master (40afafc) will decrease coverage by 0.04%.
The diff coverage is 81.81%.

❗ Current head 9b5b493 differs from pull request most recent head a6981b3. Consider uploading reports for the commit a6981b3 to get more accurate results

@@            Coverage Diff             @@
##           master    #1713      +/-   ##
==========================================
- Coverage   75.74%   75.71%   -0.04%     
==========================================
  Files         192      192              
  Lines        9314     9309       -5     
==========================================
- Hits         7055     7048       -7     
- Misses       2259     2261       +2     
Impacted Files Coverage Δ
...pytorch_learner/object_detection_learner_config.py 84.78% <ø> (ø)
...ch_learner/semantic_segmentation_learner_config.py 90.00% <ø> (ø)
...ision/pytorch_learner/regression_learner_config.py 82.10% <71.42%> (-0.51%) ⬇️
...n/pytorch_learner/classification_learner_config.py 93.84% <100.00%> (+0.09%) ⬆️
...ner/rastervision/pytorch_learner/learner_config.py 82.16% <100.00%> (+0.03%) ⬆️
...data/vector_source/geojson_vector_source_config.py 85.71% <0.00%> (-3.76%) ⬇️
...n/core/data/vector_source/geojson_vector_source.py 86.66% <0.00%> (-3.34%) ⬇️
...ision_core/rastervision/core/data/utils/geojson.py 94.92% <0.00%> (-0.73%) ⬇️

Help us with your feedback. Take ten seconds to tell us how you rate us. Have a feature suggestion? Share it here.

@AdeelH AdeelH merged commit 979f666 into azavea:master Feb 20, 2023
@AdeelH AdeelH deleted the model_cfg_extra_args branch February 20, 2023 16:17
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
needs-backport This PR needs to be backported to release branches
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

1 participant