New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
SWA script added #945
SWA script added #945
Conversation
Hello @ivashnyov! Thanks for updating this PR. We checked the lines you've touched for PEP 8 issues, and found: There are currently no PEP 8 issues detected in this Pull Request. Cheers! 🍻 Comment last updated at 2020-10-13 21:04:26 UTC |
catalyst/dl/scripts/swa.py
Outdated
"--models_mask", "-m", type=str, help="Pattern for models to average" | ||
) | ||
parser.add_argument( | ||
"--save-avaraged-model", |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
averaged?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why do we need this flag? this script should always produce averaged checkpoints, isn't it?
catalyst/dl/utils/swa.py
Outdated
|
||
|
||
def generate_averaged_weights( | ||
logdir: Path, models_mask: str, save_avaraged_model: bool = True |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
save_avaraged_model
do we really need this flag?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
and looks like we need some tests :)
catalyst/dl/utils/swa.py
Outdated
torch.save( | ||
averaged_dict, str(logdir / "checkpoints" / "swa_weights.pth") | ||
) | ||
torch.save(averaged_dict, str(logdir / "checkpoints" / "swa_weights.pth")) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
let's return the model and save it outside
catalyst/dl/tests/test_swa.py
Outdated
|
||
from catalyst.dl.utils.swa import generate_averaged_weights | ||
|
||
sys.path.append(".") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why do we need it?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
still the same question :)
catalyst/dl/scripts/swa.py
Outdated
"""Builds the command line parameters.""" | ||
parser.add_argument("--logdir", type=Path, help="Path to models logdir") | ||
parser.add_argument( | ||
"--models_mask", "-m", type=str, help="Pattern for models to average" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
let's simplify to --models-mask
catalyst/dl/scripts/swa.py
Outdated
"--models_mask", "-m", type=str, help="Pattern for models to average" | ||
) | ||
parser.add_argument( | ||
"--save_path", type=Path, help="Path to save averaged model" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
and save-path
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
btw, dont you mind to rename it to output-path
?
catalyst/dl/scripts/swa.py
Outdated
"--models-mask", | ||
"-m", | ||
type=str, | ||
default="train*", |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🤔 what do you think about *.pth
? I mean, in this case we will take into account all possible checkpoints
default="train*", | |
default="*.pth", |
@ivashnyov Now tests show the error during swa step. Could you please run the first example locally and debug it? |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
test error on minimal requirements
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@ivashnyov could you please add your description to the Changelog?
catalyst/dl/utils/swa.py
Outdated
def average_weights(state_dicts: List[dict]) -> OrderedDict: | ||
""" | ||
Averaging of input weights. | ||
|
||
Args: | ||
state_dicts (List[dict]): Weights to average | ||
|
||
Returns: | ||
Averaged weights | ||
""" | ||
# source https://gist.github.com/qubvel/70c3d5e4cddcde731408f478e12ef87b | ||
|
||
average_dict = OrderedDict() | ||
for k in state_dicts[0].keys(): | ||
average_dict[k] = torch.div( | ||
sum(state_dict[k] for state_dict in state_dicts), len(state_dicts), | ||
) | ||
return average_dict |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hi, what do you think about adding state dict key check mechanics (raise an error when key is missing)?
Probably it will be informative to know that something wrong with <checkpoint-file-name>.pth
.
For example you could check this - https://gist.github.com/Ditwoo/5de19670d9946c80916dee75e93ef545
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@ditwoo your implementation also support fp16 variant, is it correct?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@ivashnyov could you please check it out?
Pull request has been modified.
This pull request is now in conflicts. @ivashnyov, could you fix it? 🙏 |
catalyst/dl/utils/swa.py
Outdated
Averaging of input weights. | ||
|
||
Args: | ||
state_dicts (List[dict]): Weights to average |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
could you please remove type from docs (the type already specified in arguments)
catalyst/dl/utils/swa.py
Outdated
Load weights of a model. | ||
|
||
Args: | ||
path (str): Path to model weights |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
could you please remove type from docs (the type already specified in arguments)
catalyst/dl/utils/swa.py
Outdated
logdir (Path): Path to logs directory | ||
models_mask (str): globe-like pattern for models to average |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
could you please remove type from docs (the type already specified in arguments)
also, correct me if I wrong but logdir
can be a string object and everything will work fine.
so, could you please specify type as Union[str, Path]
.
catalyst/dl/utils/swa.py
Outdated
Returns: | ||
Weights | ||
""" | ||
weights = torch.load(path) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
could you please reuse https://github.com/catalyst-team/catalyst/blob/master/catalyst/utils/checkpoint.py#L126
It has device handling :)
catalyst/dl/scripts/swa.py
Outdated
|
||
def build_args(parser: ArgumentParser): | ||
"""Builds the command line parameters.""" | ||
parser.add_argument("--logdir", type=Path, help="Path to models logdir") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
could we make --logdir
optional? None
by default
so you could use this script like
script --models-mask=/some/path/to/checkpoints*.pth
?
Pull request has been modified.
Before submitting
catalyst-make-codestyle && catalyst-check-codestyle
(pip install -U catalyst-codestyle
).make check-docs
?Description
Related Issue
Type of Change
PR review
Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in Github issues there's a high chance it will not be merged.