forked from abahde/DeepOBS
/
testproblem.py
212 lines (175 loc) · 8.13 KB
/
testproblem.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
# -*- coding: utf-8 -*-
"""Base class for DeepOBS test problems."""
import torch
import abc
from .. import config
class TestProblem(abc.ABC):
    """Base class for DeepOBS test problems.

    Args:
        batch_size (int): Batch size to use.
        weight_decay (float): Weight decay (L2-regularization) factor to use. If
            not specified, the test problems revert to their respective defaults.
            Note: Some test problems do not use regularization and this value will
            be ignored in such a case.

    Attributes:
        _batch_size: Batch size for the data of this test problem.
        _weight_decay: The regularization factor for this test problem.
        _device: torch.device on which inputs, labels and the net are placed.
        _batch_count: Number of batches drawn so far via ``_get_next_batch``.
        _iterator: Iterator over the dataloader of the current phase
            (set by the ``*_init_op`` methods).
        phase: Name of the current phase, one of "train", "train_eval",
            "valid" or "test" (set by the ``*_init_op`` methods).
        data: The dataset used by the test problem (datasets.DataSet instance).
        loss_function: The loss function for this test problem.
        net: The torch module (the neural network) that is trained.
        regularization_groups: Dict mapping a regularization factor to the list
            of parameters it applies to.

    Methods:
        train_init_op: Initializes the test problem for the training phase.
        train_eval_init_op: Initializes the test problem for evaluating on
            training data.
        valid_init_op: Initializes the test problem for evaluating on
            validation data.
        test_init_op: Initializes the test problem for evaluating on test data.
        _get_next_batch: Returns the next batch of data of the current phase.
        get_batch_loss_and_accuracy: Calculates the loss and accuracy of net on
            the next batch of the current phase.
        set_up: Sets all public attributes.
    """

    def __init__(self, batch_size, weight_decay=None):
        """Creates a new test problem instance.

        Args:
            batch_size (int): Batch size to use.
            weight_decay (float): Weight decay (L2-regularization) factor to use. If
                not specified, the test problems revert to their respective defaults.
                Note: Some test problems do not use regularization and this value will
                be ignored in such a case.
        """
        self._batch_size = batch_size
        self._weight_decay = weight_decay
        self._device = torch.device(config.get_default_device())
        self._batch_count = 0

        # Phase bookkeeping; properly assigned by the *_init_op methods below.
        # Initialized here so accessing them before an init op fails gracefully
        # instead of raising AttributeError.
        self._iterator = None
        self.phase = None

        # Public attributes by which to interact with test problems. These have to
        # be created by the set_up function of sub-classes.
        self.data = None
        self.loss_function = None
        self.net = None
        self.regularization_groups = None

    def train_init_op(self):
        """Initializes the testproblem instance to train mode. I.e.
        sets the iterator to the training set and sets the model to train mode.
        """
        self._iterator = iter(self.data._train_dataloader)
        self.phase = "train"
        self.net.train()

    def train_eval_init_op(self):
        """Initializes the testproblem instance to train eval mode. I.e.
        sets the iterator to the train evaluation set and sets the model to eval mode.
        """
        self._iterator = iter(self.data._train_eval_dataloader)
        self.phase = "train_eval"
        self.net.eval()

    def valid_init_op(self):
        """Initializes the testproblem instance to validation mode. I.e.
        sets the iterator to the validation set and sets the model to eval mode.
        """
        self._iterator = iter(self.data._valid_dataloader)
        self.phase = "valid"
        self.net.eval()

    def test_init_op(self):
        """Initializes the testproblem instance to test mode. I.e.
        sets the iterator to the test set and sets the model to eval mode.
        """
        self._iterator = iter(self.data._test_dataloader)
        self.phase = "test"
        self.net.eval()

    def _get_next_batch(self):
        """Returns the next batch from the iterator of the current phase.

        Raises:
            StopIteration: When the current phase's dataloader is exhausted.
        """
        self._batch_count += 1
        return next(self._iterator)

    def get_batch_loss_and_accuracy(self,
                                    return_forward_func=False,
                                    reduction='mean',
                                    add_regularization_if_available=True,
                                    evaluate_forward_func=True):
        """Gets a new batch and calculates the loss and accuracy (if available)
        on that batch. This is a default implementation for image classification.
        Testproblems with different calculation routines (e.g. RNNs) overwrite this method accordingly.

        Args:
            return_forward_func (bool): If ``True``, the call also returns a function that calculates the loss on the \
            current batch. Can be used if you need to access the forward path twice.
            reduction (str): The reduction that is used for returning the loss. Can be 'mean', 'sum' or 'none' in which \
            case each indivual loss in the mini-batch is returned as a tensor.
            add_regularization_if_available (bool): If ``True``, the regularization loss of the parameter groups is \
            added to the batch loss.
            evaluate_forward_func (bool): Only relevant if ``return_forward_func`` is ``True``. If ``False``, the \
            forward function is returned without being evaluated first.

        Returns:
            float/torch.tensor, float, (callable): loss and accuracy of the model on the current batch. \
            If ``return_forward_func`` is ``True`` it also returns the function that calculates the loss on the current batch.
        """
        inputs, labels = self._get_next_batch()
        inputs = inputs.to(self._device)
        labels = labels.to(self._device)

        def _get_batch_loss_and_accuracy():
            def _forward():
                # Single definition of the forward pass, shared by both the
                # gradient-free evaluation path and the training path.
                outputs = self.net(inputs)
                return outputs, self.loss_function(reduction=reduction)(outputs, labels)

            # In the evaluation phases no gradient is needed.
            if self.phase in ["train_eval", "test", "valid"]:
                with torch.no_grad():
                    outputs, loss = _forward()
            else:
                outputs, loss = _forward()

            # Top-1 accuracy: fraction of samples whose argmax matches the label.
            _, predicted = torch.max(outputs.data, 1)
            total = labels.size(0)
            correct = (predicted == labels).sum().item()
            accuracy = correct / total

            if add_regularization_if_available:
                regularizer_loss = self.get_regularization_loss()
            else:
                # _device is already a torch.device; no need to re-wrap it.
                regularizer_loss = torch.tensor(0.0, device=self._device)

            return loss + regularizer_loss, accuracy

        if return_forward_func:
            if evaluate_forward_func:
                return _get_batch_loss_and_accuracy(), _get_batch_loss_and_accuracy
            else:
                return _get_batch_loss_and_accuracy
        else:
            return _get_batch_loss_and_accuracy()

    def get_regularization_loss(self):
        """Returns the current regularization loss of the network based on the parameter groups.

        Returns:
            int or torch.tensor: If no regularzations is applied, it returns the integer 0. Else a torch.tensor \
            that holds the regularization loss.
        """
        # Collect one L2 penalty term per regularized parameter tensor.
        layer_norms = []
        for regularization, parameter_group in self.regularization_groups.items():
            if regularization > 0.0:
                # L2 regularization: factor * ||theta||^2 for each tensor.
                for parameters in parameter_group:
                    layer_norms.append(regularization * parameters.pow(2).sum())
        # sum([]) == 0, so an unregularized problem yields a plain number.
        regularization_loss = 0.5 * sum(layer_norms)

        return regularization_loss

    @abc.abstractmethod
    def get_regularization_groups(self):
        """Creates regularization groups for the parameters.

        Returns:
            dict: A dictionary where the key is the regularization factor and the value is a list of parameters.
        """
        return

    @abc.abstractmethod
    # TODO get rid of setup structure by parsing individual loss func, network and dataset
    def set_up(self):
        """Sets up the test problem.
        """
        pass
class UnregularizedTestproblem(TestProblem):
    """Base class for test problems that apply no regularization at all.

    All network parameters are placed in a single group with regularization
    factor 0.0, so ``get_regularization_loss`` evaluates to zero.
    """

    def __init__(self, batch_size, weight_decay=None):
        """Creates a new unregularized test problem instance.

        Args:
            batch_size (int): Batch size to use.
            weight_decay (float): Accepted for interface compatibility with
                regularized test problems, but has no effect here.
        """
        super(UnregularizedTestproblem, self).__init__(batch_size, weight_decay)

    def get_regularization_groups(self):
        """Creates regularization groups for the parameters.

        Returns:
            dict: A dictionary where the key is the regularization factor and the value is a list of parameters.
        """
        # Penalize no parameters: one group with factor 0.0 holding everything.
        # Iterate parameters() directly since the parameter names are unused.
        no = 0.0
        return {no: list(self.net.parameters())}

    @abc.abstractmethod
    def set_up(self):
        """Sets up the test problem.
        """
        pass