SWA script added #945

Merged
merged 16 commits into from Oct 14, 2020

Conversation

ivashnyov (Contributor) commented Sep 27, 2020

Before submitting

  • Was this discussed/approved via a Github issue? (no need for typos and docs improvements)
  • Did you read the contribution guide?
  • Did you check the code style? catalyst-make-codestyle && catalyst-check-codestyle (pip install -U catalyst-codestyle).
  • Did you make sure to update the docs? We use Google format for all the methods and classes.
  • Did you check the docs with make check-docs?
  • Did you write any new necessary tests?
  • Did you add your new functionality to the docs?
  • Did you update the CHANGELOG?
  • You can use 'Login as guest' to see Teamcity build logs.

Description

Related Issue

Type of Change

  • Examples / docs / tutorials / contributors update
  • Bug fix (non-breaking change which fixes an issue)
  • Improvement (non-breaking change which improves an existing feature)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to change)

PR review

Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in Github issues there's a high chance it will not be merged.

pep8speaks commented Sep 27, 2020

Hello @ivashnyov! Thanks for updating this PR. We checked the lines you've touched for PEP 8 issues, and found:

There are currently no PEP 8 issues detected in this Pull Request. Cheers! 🍻

Comment last updated at 2020-10-13 21:04:26 UTC

"--models_mask", "-m", type=str, help="Pattern for models to average"
)
parser.add_argument(
"--save-avaraged-model",
Member

averaged?

Member

why do we need this flag? this script should always produce averaged checkpoints, shouldn't it?



def generate_averaged_weights(
logdir: Path, models_mask: str, save_avaraged_model: bool = True
Member

save_avaraged_model do we really need this flag?

Scitator (Member) left a comment

and looks like we need some tests :)

torch.save(
averaged_dict, str(logdir / "checkpoints" / "swa_weights.pth")
)
torch.save(averaged_dict, str(logdir / "checkpoints" / "swa_weights.pth"))
Member

let's return the model and save it outside

catalyst/dl/scripts/swa.py (resolved)

from catalyst.dl.utils.swa import generate_averaged_weights

sys.path.append(".")
Member

why do we need it?

Member

still the same question :)

"""Builds the command line parameters."""
parser.add_argument("--logdir", type=Path, help="Path to models logdir")
parser.add_argument(
"--models_mask", "-m", type=str, help="Pattern for models to average"
Member

let's simplify to --models-mask

"--models_mask", "-m", type=str, help="Pattern for models to average"
)
parser.add_argument(
"--save_path", type=Path, help="Path to save averaged model"
Member

and save-path

Member

btw, would you mind renaming it to output-path?

catalyst/dl/scripts/swa.py (resolved)
catalyst/dl/scripts/swa.py (outdated, resolved)
"--models-mask",
"-m",
type=str,
default="train*",
Member

🤔 what do you think about *.pth? I mean, in this case we will take into account all possible checkpoints

Suggested change
default="train*",
default="*.pth",
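
To make the difference between the two defaults concrete, here is a small sketch using the standard-library `fnmatch` module, which `glob` relies on for matching file names. The file names below are hypothetical and only illustrate what a checkpoints directory might contain; they are not taken from this PR.

```python
from fnmatch import fnmatch

# Hypothetical directory contents, for illustration only.
names = ["best.pth", "last.pth", "train.1.pth", "train.2.pth", "train.log"]


def select(pattern: str) -> list:
    """Return the names that match a shell-style wildcard pattern."""
    return [name for name in names if fnmatch(name, pattern)]


# "train*" keeps only names starting with "train", whatever the extension.
train_only = select("train*")
# "*.pth" keeps every .pth file, whatever the prefix.
all_pth = select("*.pth")

print(train_only)
print(all_pth)
```

With these names, `train*` also picks up the non-checkpoint `train.log`, while `*.pth` picks up `best.pth` and `last.pth` in addition to the per-epoch files, so either default may need narrowing depending on what should be averaged.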

Scitator (Member) commented Oct 7, 2020

@ivashnyov The tests now fail during the SWA step. Could you please run the first example locally and debug it?

Scitator (Member) left a comment

test error on minimal requirements

@mergify mergify bot dismissed Scitator’s stale review October 7, 2020 20:15

Pull request has been modified.

Scitator previously approved these changes Oct 8, 2020

Scitator (Member) left a comment

@ivashnyov could you please add your description to the Changelog?

Comment on lines 10 to 27
def average_weights(state_dicts: List[dict]) -> OrderedDict:
    """
    Averaging of input weights.

    Args:
        state_dicts (List[dict]): Weights to average

    Returns:
        Averaged weights
    """
    # source https://gist.github.com/qubvel/70c3d5e4cddcde731408f478e12ef87b

    average_dict = OrderedDict()
    for k in state_dicts[0].keys():
        average_dict[k] = torch.div(
            sum(state_dict[k] for state_dict in state_dicts), len(state_dicts),
        )
    return average_dict
Contributor

Hi, what do you think about adding a state dict key check (raise an error when a key is missing)?
It would be informative to know that something is wrong with <checkpoint-file-name>.pth.

For example you could check this - https://gist.github.com/Ditwoo/5de19670d9946c80916dee75e93ef545
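
A minimal sketch of such a key check follows. It is not the implementation from the linked gist or from this PR. The function name `average_with_key_check` and the `names` argument (labels such as checkpoint file names, used only in error messages) are assumptions made for illustration. It is written against plain mappings so it runs without torch; the same `sum` and `/` arithmetic applies elementwise to tensors, although integer-valued entries would become floats.

```python
from collections import OrderedDict
from typing import Any, Mapping, Sequence


def average_with_key_check(
    state_dicts: Sequence[Mapping[str, Any]], names: Sequence[str]
) -> "OrderedDict[str, Any]":
    """Average values key by key, failing loudly on mismatched keys.

    `names` are hypothetical labels (e.g. checkpoint file names) that are
    only used to make the error message point at the offending input.
    """
    if not state_dicts:
        raise ValueError("no state dicts to average")
    if len(state_dicts) != len(names):
        raise ValueError("state_dicts and names must have the same length")

    # The first state dict defines the expected set of keys.
    reference_keys = set(state_dicts[0].keys())
    for name, state_dict in zip(names, state_dicts):
        keys = set(state_dict.keys())
        if keys != reference_keys:
            missing = sorted(reference_keys - keys)
            unexpected = sorted(keys - reference_keys)
            raise KeyError(
                f"{name}: missing keys {missing}, unexpected keys {unexpected}"
            )

    averaged = OrderedDict()
    for key in state_dicts[0].keys():
        # sum() and "/" work for floats and, elementwise, for tensors.
        averaged[key] = sum(sd[key] for sd in state_dicts) / len(state_dicts)
    return averaged


result = average_with_key_check(
    [{"w": 1.0, "b": 2.0}, {"w": 3.0, "b": 4.0}], ["a.pth", "b.pth"]
)
print(result)
```

Checking all inputs before averaging means a broken checkpoint is reported by name instead of surfacing as a bare `KeyError` on some parameter in the middle of the loop.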

Member

@ditwoo your implementation also supports the fp16 variant, is that correct?

Member

@ivashnyov could you please check it out?

@mergify mergify bot dismissed Scitator’s stale review October 11, 2020 21:04

Pull request has been modified.

mergify bot commented Oct 11, 2020

This pull request is now in conflicts. @ivashnyov, could you fix it? 🙏

Averaging of input weights.

Args:
state_dicts (List[dict]): Weights to average
Contributor

could you please remove the type from the docs (the type is already specified in the arguments)

Load weights of a model.

Args:
path (str): Path to model weights
Contributor

could you please remove the type from the docs (the type is already specified in the arguments)

Comment on lines 64 to 65
logdir (Path): Path to logs directory
models_mask (str): glob-like pattern for models to average
Contributor

could you please remove the type from the docs (the type is already specified in the arguments)

also, correct me if I'm wrong, but logdir can be a string object and everything will work fine.
so, could you please specify the type as Union[str, Path]?
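
One way to honour a `Union[str, Path]` annotation is to normalize the argument once at the top of the function. The sketch below is an illustration, not code from this PR: the helper name `checkpoints_dir` is hypothetical, and the `checkpoints` subfolder is taken from the `torch.save` call quoted earlier in this thread.

```python
from pathlib import Path
from typing import Union


def checkpoints_dir(logdir: Union[str, Path]) -> Path:
    """Return the checkpoints folder for a logdir given as str or Path."""
    # Path(...) accepts both str and Path, so callers need not convert.
    return Path(logdir) / "checkpoints"


from_str = checkpoints_dir("logs")
from_path = checkpoints_dir(Path("logs"))
print(from_str == from_path)
```

Without the `Path(...)` call, `logdir / "checkpoints"` would raise a `TypeError` for a plain string, since `str` does not support the `/` operator.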

ditwoo previously approved these changes Oct 13, 2020
Returns:
Weights
"""
weights = torch.load(path)


def build_args(parser: ArgumentParser):
"""Builds the command line parameters."""
parser.add_argument("--logdir", type=Path, help="Path to models logdir")
Member

could we make --logdir optional? None by default
so you could use this script like

script --models-mask=/some/path/to/checkpoints*.pth

?
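
The suggestion could be sketched as follows. This is an assumption about how the options might be wired, not the code that was merged: `build_args_sketch` and `resolve_pattern` are hypothetical names, the `*.pth` default comes from the suggestion earlier in this thread, and the `checkpoints` subfolder comes from the `torch.save` call quoted above.

```python
from argparse import ArgumentParser
from pathlib import Path
from typing import Optional


def build_args_sketch(parser: ArgumentParser) -> ArgumentParser:
    """Register the options; --logdir may be omitted (None by default)."""
    parser.add_argument(
        "--logdir", type=Path, default=None, help="Path to models logdir"
    )
    parser.add_argument(
        "--models-mask",
        "-m",
        type=str,
        default="*.pth",
        help="Pattern for models to average",
    )
    return parser


def resolve_pattern(logdir: Optional[Path], models_mask: str) -> str:
    """Build the glob pattern that locates the checkpoints."""
    if logdir is None:
        # No logdir: the mask is used as a complete pattern on its own.
        return models_mask
    # With a logdir, the mask is relative to its checkpoints folder.
    return str(logdir / "checkpoints" / models_mask)


args = build_args_sketch(ArgumentParser()).parse_args(
    ["--models-mask", "/some/path/to/checkpoints*.pth"]
)
print(resolve_pattern(args.logdir, args.models_mask))
```

The resulting string can then be handed to `glob.glob` to list the files to average.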

@mergify mergify bot dismissed stale reviews from Scitator and ditwoo October 13, 2020 21:05

Pull request has been modified.

@Scitator Scitator merged commit f2acebb into catalyst-team:master Oct 14, 2020
4 participants