New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
SWA script added #945
SWA script added #945
Changes from 14 commits
a9af7cd
27828ff
1f770dc
78ff817
8260b0a
7ab43ab
504f6be
6482869
89d079c
4b0ce99
a4e07c5
5b00ca3
95034ca
44f750a
17a0ff1
669bbb9
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
import argparse | ||
from argparse import ArgumentParser | ||
from pathlib import Path | ||
|
||
import torch | ||
|
||
from catalyst.dl.utils.swa import generate_averaged_weights | ||
|
||
|
||
def build_args(parser: ArgumentParser):
    """Register the ``catalyst-dl swa`` command line options on ``parser``.

    Args:
        parser: parser to extend with the SWA options

    Returns:
        the same parser, for chaining
    """
    add = parser.add_argument
    add("--logdir", type=Path, help="Path to models logdir")
    add(
        "--models-mask",
        "-m",
        type=str,
        default="*.pth",
        help="Pattern for models to average",
    )
    add(
        "--output-path",
        type=Path,
        default="./swa.pth",
        help="Path to save averaged model",
    )
    return parser
|
||
|
||
def parse_args():
    """Parse the ``catalyst-dl swa`` command line arguments.

    Returns:
        the parsed argument namespace
    """
    return build_args(argparse.ArgumentParser()).parse_args()
|
||
|
||
def main(args, _):
    """Main method for ``catalyst-dl swa``.

    Averages the weights of all checkpoints found under
    ``args.logdir`` that match ``args.models_mask`` and saves the
    result to ``args.output_path``.
    """
    averaged_weights = generate_averaged_weights(
        args.logdir, args.models_mask
    )
    # ``torch.save`` expects a string-like path, so normalize the Path.
    torch.save(averaged_weights, str(args.output_path))
# Script entry point: parse CLI arguments and run the SWA averaging.
if __name__ == "__main__":
    main(parse_args(), None)
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
import os | ||
from pathlib import Path | ||
import shutil | ||
import unittest | ||
|
||
import torch | ||
import torch.nn as nn | ||
|
||
from catalyst.dl.utils.swa import generate_averaged_weights | ||
|
||
|
||
class Net(nn.Module):
    """Dummy single-layer network whose parameters are all one constant."""

    def __init__(self, init_weight=4):
        """Build the layer and fill every parameter with ``init_weight``."""
        super().__init__()
        self.fc = nn.Linear(2, 1)
        for param in (self.fc.weight, self.fc.bias):
            param.data.fill_(init_weight)
|
||
|
||
class TestSwa(unittest.TestCase):
    """Tests for the SWA weight-averaging utility."""

    def setUp(self):
        """Create a checkpoints dir holding two constant-weight models."""
        os.mkdir("./checkpoints")
        for name, weight in (("net1", 2.0), ("net2", 5.0)):
            model = Net(init_weight=weight)
            torch.save(
                model.state_dict(), "./checkpoints/{}.pth".format(name)
            )

    def tearDown(self):
        """Remove the temporary checkpoints directory."""
        shutil.rmtree("./checkpoints")

    def test_averaging(self):
        """Averaging models filled with 2.0 and 5.0 must yield 3.5 everywhere."""
        weights = generate_averaged_weights(
            logdir=Path("./"), models_mask="net*"
        )
        torch.save(weights, "./checkpoints/swa_weights.pth")
        model = Net()
        model.load_state_dict(torch.load("./checkpoints/swa_weights.pth"))

        expected = 3.5
        self.assertEqual(float(model.fc.weight.data[0][0]), expected)
        self.assertEqual(float(model.fc.weight.data[0][1]), expected)
        self.assertEqual(float(model.fc.bias.data[0]), expected)
|
||
|
||
# Allow running this test module directly via ``python <file>``.
if __name__ == "__main__":
    unittest.main()
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
from typing import List | ||
from collections import OrderedDict | ||
import glob | ||
import os | ||
from pathlib import Path | ||
|
||
import torch | ||
|
||
|
||
def average_weights(state_dicts: List[dict]) -> OrderedDict:
    """
    Averaging of input weights.

    Args:
        state_dicts: Weights to average

    Raises:
        ValueError: If no state dicts are provided
        KeyError: If states do not match

    Returns:
        Averaged weights
    """
    # source https://gist.github.com/qubvel/70c3d5e4cddcde731408f478e12ef87b
    if not state_dicts:
        # Without this guard an empty list would fall through the key
        # check below and crash with an opaque IndexError at
        # ``state_dicts[0]``.
        raise ValueError("No state dicts to average")

    # All state dicts must expose exactly the same parameter keys,
    # in the same order, before averaging makes sense.
    params_keys = None
    for i, state_dict in enumerate(state_dicts):
        model_params_keys = list(state_dict.keys())
        if params_keys is None:
            params_keys = model_params_keys
        elif params_keys != model_params_keys:
            raise KeyError(
                "For checkpoint {}, expected list of params: {}, "
                "but found: {}".format(i, params_keys, model_params_keys)
            )

    average_dict = OrderedDict()
    for k in state_dicts[0].keys():
        average_dict[k] = torch.div(
            sum(state_dict[k] for state_dict in state_dicts), len(state_dicts),
        )
    return average_dict
|
||
|
||
def load_weight(path: str) -> dict:
    """
    Load the weights of a model stored at ``path``.

    Args:
        path: path to the serialized model weights

    Returns:
        the model state dict
    """
    # NOTE(review): no ``map_location`` is passed, so GPU-saved
    # checkpoints load onto their original device — confirm intended.
    checkpoint = torch.load(path)
    # Full training checkpoints wrap the weights under
    # "model_state_dict"; plain weight files are returned as-is.
    if "model_state_dict" in checkpoint:
        return checkpoint["model_state_dict"]
    return checkpoint
|
||
|
||
def generate_averaged_weights(logdir: Path, models_mask: str) -> OrderedDict:
    """
    Average the weights of every checkpoint matching ``models_mask``.

    Checkpoints are looked up inside ``<logdir>/checkpoints``.

    Args:
        logdir: path to logs directory
        models_mask: glob-like pattern for models to average

    Returns:
        Averaged weights
    """
    pattern = os.path.join(logdir, "checkpoints", models_mask)
    all_weights = [load_weight(path) for path in glob.glob(pattern)]
    return average_weights(all_weights)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
could we make
--logdir
optional?None
by defaultso you could use this script like
?