forked from catalyst-team/catalyst
-
Notifications
You must be signed in to change notification settings - Fork 1
/
test_swa.py
53 lines (41 loc) · 1.52 KB
/
test_swa.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import os
from pathlib import Path
import shutil
import unittest
import torch
import torch.nn as nn
from catalyst.dl.utils.swa import generate_averaged_weights
from catalyst.utils.checkpoint import load_checkpoint
class Net(nn.Module):
    """Dummy network: one ``nn.Linear(2, 1)`` layer with constant parameters.

    Every element of the layer's weight and bias is set to ``init_weight``,
    so the result of averaging several such networks is easy to predict.
    """

    def __init__(self, init_weight: float = 4) -> None:
        """Build the layer and fill its weight and bias with ``init_weight``.

        Args:
            init_weight: value written into every element of ``fc.weight``
                and ``fc.bias``.
        """
        super().__init__()
        self.fc = nn.Linear(2, 1)
        # nn.init.constant_ fills the parameters in place without going
        # through the discouraged ``.data`` attribute.
        nn.init.constant_(self.fc.weight, init_weight)
        nn.init.constant_(self.fc.bias, init_weight)
class TestSwa(unittest.TestCase):
    """Test that ``generate_averaged_weights`` averages saved checkpoints.

    Each test works on two ``Net`` checkpoints, filled with 2.0 and 5.0,
    saved under ``./checkpoints`` relative to the current working directory.
    """

    CHECKPOINTS_DIR = "./checkpoints"

    def setUp(self):
        """Create the checkpoint directory and save the two checkpoints."""
        # Deliberately strict (no exist_ok): if the directory already
        # exists it was not created by this test and must not be deleted.
        os.mkdir(self.CHECKPOINTS_DIR)
        # Registered right after creation: cleanups still run when the rest
        # of setUp raises, whereas tearDown would be skipped in that case
        # and the leftover directory would break every later run.
        self.addCleanup(shutil.rmtree, self.CHECKPOINTS_DIR)

        for file_name, fill_value in (("net1.pth", 2.0), ("net2.pth", 5.0)):
            torch.save(
                Net(init_weight=fill_value).state_dict(),
                os.path.join(self.CHECKPOINTS_DIR, file_name),
            )

    def test_averaging(self):
        """Averaging the 2.0 and 5.0 checkpoints must give 3.5 everywhere."""
        weights = generate_averaged_weights(
            logdir=Path("./"), models_mask="net*"
        )
        # The file name does not match the "net*" mask used above.
        swa_path = os.path.join(self.CHECKPOINTS_DIR, "swa_weights.pth")
        torch.save(weights, swa_path)

        model = Net()
        model.load_state_dict(load_checkpoint(swa_path))

        averaged_values = (
            model.fc.weight[0][0],
            model.fc.weight[0][1],
            model.fc.bias[0],
        )
        for value in averaged_values:
            # Floating-point results are compared with a tolerance rather
            # than exact equality.
            self.assertAlmostEqual(value.item(), 3.5, places=6)
# Run the test suite only when this file is executed as a script;
# importing it (e.g. by a test runner) must not start the tests.
if __name__ == "__main__":
    unittest.main()