"""
Classes for constraining the norms of weight matrices.
"""
import numpy as np
import warnings
from theano import tensor as T
from pylearn2.model_extensions.model_extension import ModelExtension
from pylearn2.utils import wraps
class ConstrainFilterL2Norm(ModelExtension):
    """
    Constrains the L2 norm (not the squared L2 norm) of a weight
    matrix to lie inside a closed interval.

    The weight matrix is taken either from the model's `W` field or,
    failing that, from the sole parameter of the model's `transformer`
    field.

    Parameters
    ----------
    limit : float or symbolic float
        Upper bound: the norm along the given axes is constrained
        to be <= limit. May be None for "no upper bound".
    min_limit : float or symbolic float
        Lower bound: the norm along the given axes is constrained
        to be >= min_limit.
    axis : int or tuple of int
        The axis or axes over which the norm is computed. Default is 0.
    """

    def __init__(self, limit, min_limit=0., axis=0):
        self.max_limit = limit
        self.min_limit = min_limit
        # Validate the bounds only when both are concrete (non-None).
        if limit is not None and min_limit is not None and limit < min_limit:
            raise ValueError('The maximum limit must be higher than '
                             'the minimum limit.')
        self.axis = axis

    @wraps(ModelExtension.post_modify_updates)
    def post_modify_updates(self, updates, model):
        # Locate the weight matrix: prefer `model.W`, otherwise the
        # single parameter tensor of `model.transformer`.
        if hasattr(model, 'W'):
            W = model.W
        else:
            if not hasattr(model, 'transformer'):
                raise TypeError("model has neither 'W' nor 'transformer'.")
            params = model.transformer.get_params()
            if len(params) != 1:
                raise TypeError("self.transformer does not have exactly one "
                                "parameter tensor.")
            W, = params
        # Only rescale W when this round of updates actually touches it.
        if W not in updates:
            return
        new_W = updates[W]
        norms = T.sqrt(
            T.square(new_W).sum(axis=self.axis, keepdims=True)
        )
        lower = 0. if self.min_limit is None else self.min_limit
        # A None upper bound degenerates to the largest current norm,
        # which makes the clip a no-op on that side.
        upper = norms.max() if self.max_limit is None else self.max_limit
        target = T.clip(norms, lower, upper)
        # The epsilon guards against division by a zero-norm column.
        updates[W] = new_W * (target / T.maximum(1e-7, norms))
class MaxL2FilterNorm(ConstrainFilterL2Norm):
    """
    Deprecated alias for `ConstrainFilterL2Norm`, preserved so that
    existing code and configuration referring to the old class name
    keep working.

    Parameters
    ----------
    args : list
        Passed on to the superclass.
    kwargs : dict
        Passed on to the superclass.
    """

    def __init__(self, *args, **kwargs):
        # Point users at the new name, then delegate construction
        # entirely to the superclass.
        warnings.warn("MaxL2FilterNorm is deprecated and may be removed on or"
                      " after 2016-01-31. Use ConstrainFilterL2Norm.")
        super(MaxL2FilterNorm, self).__init__(*args, **kwargs)
class ConstrainFilterMaxNorm(ModelExtension):
    """
    Constrains the max norm (largest elementwise absolute value) of a
    weight matrix.

    Expects the weight matrix to either be the
    sole parameter of the model's `transformer` field or to
    be the model's `W` field.

    Parameters
    ----------
    limit : float or symbolic float
        The maximum norm of the weight matrix is constrained
        to be <= limit. May be None for "no upper bound".
    min_limit : float or symbolic float
        The minimum norm of the weight matrix is constrained
        to be >= min_limit. Accepted for interface compatibility with
        `ConstrainFilterL2Norm` but not implemented: any value other
        than None makes `post_modify_updates` raise
        NotImplementedError.
    """

    def __init__(self, limit, min_limit=None):
        self.max_limit = limit
        self.min_limit = min_limit
        # Validate the bounds only when both are concrete (non-None).
        if (limit is not None) and (min_limit is not None):
            if limit < min_limit:
                raise ValueError('The maximum limit must be higher than '
                                 'the minimum limit.')

    @wraps(ModelExtension.post_modify_updates)
    def post_modify_updates(self, updates, model):
        # Locate the weight matrix: prefer `model.W`, otherwise the
        # single parameter tensor of `model.transformer`.
        if hasattr(model, 'W'):
            W = model.W
        else:
            if not hasattr(model, 'transformer'):
                raise TypeError("model has neither 'W' nor 'transformer'.")
            transformer = model.transformer
            params = transformer.get_params()
            if len(params) != 1:
                raise TypeError("self.transformer does not have exactly one "
                                "parameter tensor.")
            W, = params
        # Only rescale W when this round of updates actually touches it.
        if W in updates:
            updated_W = updates[W]
            if self.min_limit is not None:
                # This would be a pretty weird feature to want but the
                # interface is here for compatibility with the L2 norm
                # constraint class.
                raise NotImplementedError()
            if self.max_limit is not None:
                # Elementwise clamp: caps every weight's absolute value,
                # and hence the max norm, at self.max_limit.
                updates[W] = T.clip(updated_W, -self.max_limit,
                                    self.max_limit)