Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SWA script added #945

Merged
merged 16 commits into from Oct 14, 2020
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
3 changes: 2 additions & 1 deletion CHANGELOG.md
Expand Up @@ -14,6 +14,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- SoftMax, CosFace, ArcFace layers to contrib ([#939](https://github.com/catalyst-team/catalyst/pull/939))
- ArcMargin layer to contrib ([#957](https://github.com/catalyst-team/catalyst/pull/957))
- AdaCos to contrib ([#958](https://github.com/catalyst-team/catalyst/pull/958))
- Manual SWA to utils ([#945](https://github.com/catalyst-team/catalyst/pull/945))

### Changed

Expand Down Expand Up @@ -305,4 +306,4 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Fixed

-
-
5 changes: 5 additions & 0 deletions bin/tests/check_dl_cv.sh
Expand Up @@ -43,6 +43,11 @@ PYTHONPATH=./examples:.:${PYTHONPATH} \
python catalyst/dl/scripts/trace.py \
${LOGDIR}

echo 'pipeline 01 - swa'
PYTHONPATH=./examples:.:${PYTHONPATH} \
python catalyst/dl/scripts/swa.py \
--logdir=${LOGDIR} --output-path=./swa.pth

rm -rf ${LOGDIR}


Expand Down
50 changes: 50 additions & 0 deletions catalyst/dl/scripts/swa.py
@@ -0,0 +1,50 @@
import argparse
from argparse import ArgumentParser
from pathlib import Path

import torch

from catalyst.dl.utils.swa import generate_averaged_weights


def build_args(parser: ArgumentParser):
    """Registers the SWA script options on ``parser`` and returns it."""
    # Directory that contains a ``checkpoints`` folder with saved models.
    parser.add_argument("--logdir", type=Path, help="Path to models logdir")
    # Glob-like pattern selecting which checkpoint files to average.
    parser.add_argument(
        "-m",
        "--models-mask",
        type=str,
        default="*.pth",
        help="Pattern for models to average",
    )
    # Destination file for the averaged state dict.
    parser.add_argument(
        "--output-path",
        type=Path,
        default="./swa.pth",
        help="Path to save averaged model",
    )
    return parser


def parse_args():
    """Parses the command line arguments for the main method."""
    return build_args(argparse.ArgumentParser()).parse_args()


def main(args, _):
    """Main method for ``catalyst-dl swa``.

    Averages every checkpoint matched by ``args.models_mask`` under
    ``args.logdir`` and persists the resulting state dict to
    ``args.output_path``.
    """
    averaged_weights = generate_averaged_weights(
        args.logdir, args.models_mask
    )
    torch.save(averaged_weights, str(args.output_path))


if __name__ == "__main__":
    main(parse_args(), None)
53 changes: 53 additions & 0 deletions catalyst/dl/tests/test_swa.py
@@ -0,0 +1,53 @@
import os
from pathlib import Path
import shutil
import unittest

import torch
import torch.nn as nn

from catalyst.dl.utils.swa import generate_averaged_weights


class Net(nn.Module):
    """Dummy network: a single 2->1 linear layer with constant parameters."""

    def __init__(self, init_weight=4):
        """Creates the layer and fills weight and bias with ``init_weight``."""
        super().__init__()
        self.fc = nn.Linear(2, 1)
        for param in (self.fc.weight, self.fc.bias):
            param.data.fill_(init_weight)


class TestSwa(unittest.TestCase):
    """Tests for ``generate_averaged_weights``."""

    def setUp(self):
        """Creates two checkpoints with known constant weights."""
        # Remove a stale directory left behind by a previously interrupted
        # run; otherwise the bare ``os.mkdir`` below raises FileExistsError.
        shutil.rmtree("./checkpoints", ignore_errors=True)
        os.mkdir("./checkpoints")
        net1 = Net(init_weight=2.0)
        net2 = Net(init_weight=5.0)
        torch.save(net1.state_dict(), "./checkpoints/net1.pth")
        torch.save(net2.state_dict(), "./checkpoints/net2.pth")

    def tearDown(self):
        """Removes the checkpoints created by ``setUp``."""
        shutil.rmtree("./checkpoints")

    def test_averaging(self):
        """SWA of constant weights 2.0 and 5.0 must give their mean 3.5."""
        weights = generate_averaged_weights(
            logdir=Path("./"), models_mask="net*"
        )
        torch.save(weights, "./checkpoints/swa_weights.pth")
        model = Net()
        model.load_state_dict(torch.load("./checkpoints/swa_weights.pth"))

        self.assertEqual(float(model.fc.weight.data[0][0]), 3.5)
        self.assertEqual(float(model.fc.weight.data[0][1]), 3.5)
        self.assertEqual(float(model.fc.bias.data[0]), 3.5)


if __name__ == "__main__":
    unittest.main()
75 changes: 75 additions & 0 deletions catalyst/dl/utils/swa.py
@@ -0,0 +1,75 @@
from typing import List
from collections import OrderedDict
import glob
import os
from pathlib import Path

import torch


def average_weights(state_dicts: List[dict]) -> OrderedDict:
    """
    Averaging of input weights.

    Args:
        state_dicts: Weights to average

    Raises:
        ValueError: If there are no weights to average
        KeyError: If states do not match

    Returns:
        Averaged weights
    """
    # source https://gist.github.com/qubvel/70c3d5e4cddcde731408f478e12ef87b
    if not state_dicts:
        # Guard the degenerate case explicitly instead of failing with an
        # opaque IndexError on ``state_dicts[0]`` below.
        raise ValueError("No state dicts given to average")

    # All checkpoints must share exactly the same parameter names.
    params_keys = list(state_dicts[0].keys())
    for i, state_dict in enumerate(state_dicts[1:], start=1):
        model_params_keys = list(state_dict.keys())
        if params_keys != model_params_keys:
            raise KeyError(
                "For checkpoint {}, expected list of params: {}, "
                "but found: {}".format(i, params_keys, model_params_keys)
            )

    average_dict = OrderedDict()
    for k in params_keys:
        average_dict[k] = torch.div(
            sum(state_dict[k] for state_dict in state_dicts), len(state_dicts),
        )
    return average_dict


def load_weight(path: str) -> dict:
    """
    Load weights of a model.

    Args:
        path: Path to model weights

    Returns:
        Weights
    """
    weights = torch.load(path)
    # Full training checkpoints store the model weights under the
    # ``model_state_dict`` key; unwrap so a plain state dict is returned.
    if "model_state_dict" in weights:
        weights = weights["model_state_dict"]
    return weights


def generate_averaged_weights(logdir: Path, models_mask: str) -> OrderedDict:
    """
    Loads all checkpoints matched by ``models_mask`` under
    ``logdir/checkpoints`` and returns their averaged weights.

    Args:
        logdir: Path to logs directory (a plain ``str`` works as well,
            since it is only passed through ``os.path.join``)
        models_mask: glob-like pattern for models to average

    Raises:
        FileNotFoundError: If the mask does not match any checkpoint

    Returns:
        Averaged weights
    """
    pattern = os.path.join(logdir, "checkpoints", models_mask)
    model_paths = glob.glob(pattern)
    if not model_paths:
        # Fail with a clear message instead of letting ``average_weights``
        # choke on an empty list.
        raise FileNotFoundError(
            "No checkpoints found by pattern: {}".format(pattern)
        )

    all_weights = [load_weight(path) for path in model_paths]
    averaged_dict = average_weights(all_weights)

    return averaged_dict