conv_type.py
import torch
import torch.autograd as autograd
import torch.nn as nn
import torch.nn.functional as F

import math

from args import args as parser_args

import pdb

DenseConv = nn.Conv2d


class GetSubnet(autograd.Function):
    @staticmethod
    def forward(ctx, scores, k):
        # Get the subnetwork by sorting the scores and using the top k%
        out = scores.clone()
        _, idx = scores.flatten().sort()
        j = int((1 - k) * scores.numel())

        # flat_out and out access the same memory
        flat_out = out.flatten()
        flat_out[idx[:j]] = 0
        flat_out[idx[j:]] = 1

        return out

    @staticmethod
    def backward(ctx, g):
        # send the gradient g straight-through on the backward pass.
        return g, None
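
# Usage sketch for GetSubnet (hypothetical shapes; `scores` are per-weight
# saliency values and `k` is the fraction of weights to keep):
#
#   scores = torch.rand(64, 3, 3, 3)
#   mask = GetSubnet.apply(scores, 0.5)  # {0, 1} mask marking the top-50% scores
#
# The backward pass is a straight-through estimator: gradients flow to `scores`
# as if the thresholding were the identity.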


class GetQuantnet_binary_old(autograd.Function):
    @staticmethod
    def forward(ctx, scores, weights, k):
        # Get the subnetwork by sorting the scores and using the top k%
        out = scores.clone()
        _, idx = scores.flatten().sort()
        j = int((1 - k) * scores.numel())

        # flat_out and out access the same memory
        flat_out = out.flatten()
        flat_out[idx[:j]] = 0
        flat_out[idx[j:]] = 1

        # Perform binary quantization of weights
        abs_wgt = torch.abs(weights.clone())  # Absolute value of original weights
        q_weight = abs_wgt * out  # Remove pruned weights
        num_unpruned = int(k * scores.numel())  # Number of unpruned weights
        alpha = torch.sum(q_weight) / num_unpruned  # alpha = ||q_weight||_1 / (number of unpruned weights)

        # Rebuild q_weight from the reciprocal of the weight magnitudes so that
        # the pruned (zeroed) entries never end up in a denominator
        q_weight = 1 / abs_wgt  # Take reciprocal of absolute value of weights
        q_weight = q_weight * out  # Remove pruned weights
        q_weight = alpha * q_weight  # Multiply each element of q_weight by alpha

        return q_weight

    @staticmethod
    def backward(ctx, g):
        # send the gradient g straight-through on the backward pass.
        return g, None, None


class GetQuantnet_binary(autograd.Function):
    @staticmethod
    def forward(ctx, scores, weights, k):
        # Get the subnetwork by sorting the scores and using the top k%
        out = scores.clone()
        _, idx = scores.flatten().sort()
        j = int((1 - k) * scores.numel())

        # flat_out and out access the same memory
        flat_out = out.flatten()
        flat_out[idx[:j]] = 0
        flat_out[idx[j:]] = 1

        # Perform binary quantization of weights
        abs_wgt = torch.abs(weights.clone())  # Absolute value of original weights
        q_weight = abs_wgt * out  # Remove pruned weights
        num_unpruned = int(k * scores.numel())  # Number of unpruned weights
        alpha = torch.sum(q_weight) / num_unpruned  # alpha = ||q_weight||_1 / (number of unpruned weights)

        # Save absolute value of weights for backward
        ctx.save_for_backward(abs_wgt)

        # Return pruning mask with gain term alpha for binary weights
        return alpha * out

    @staticmethod
    def backward(ctx, g):
        # Get absolute value of weights from saved ctx
        abs_wgt, = ctx.saved_tensors
        # send the gradient g times abs_wgt on the backward pass
        return g * abs_wgt, None, None
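
# How GetQuantnet_binary is consumed (see SubnetConv.forward below): the returned
# tensor is alpha * mask, where alpha = ||w * mask||_1 / (#kept weights).
# Multiplying it by torch.sign(w) yields binary weights in {-alpha, 0, +alpha}.
# Unlike GetQuantnet_binary_old, the backward pass scales the incoming gradient
# by |w| instead of passing it straight through.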


class SubnetConv(nn.Conv2d):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.scores = nn.Parameter(torch.Tensor(self.weight.size()))
        nn.init.kaiming_uniform_(self.scores, a=math.sqrt(5))
        # print("subnet conv init: ", torch.isnan(self.scores).any())

    def set_prune_rate(self, prune_rate):
        self.prune_rate = prune_rate

    @property
    def clamped_scores(self):
        # For unquantized activations
        return self.scores.abs()

    def forward(self, x):
        # Get binary mask and gain term for subnetwork
        quantnet = GetQuantnet_binary.apply(self.clamped_scores, self.weight, self.prune_rate)
        # For debugging gradients: print the maximum value in the gradient, if any
        if parser_args.debug:
            if quantnet.grad is not None:
                print("subnetconv fwd quantnet grad ", torch.max(quantnet.grad))
        # Binarize weights by taking the sign, multiply by pruning mask and gain term (alpha)
        w = torch.sign(self.weight) * quantnet
        # Pass binary subnetwork weights to convolution layer
        x = F.conv2d(
            x, w, self.bias, self.stride, self.padding, self.dilation, self.groups
        )
        # Return output from convolution layer
        return x
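
# Usage sketch for SubnetConv (hypothetical sizes; assumes the `args` module has
# been parsed so that parser_args.debug is set):
#
#   conv = SubnetConv(3, 64, kernel_size=3, padding=1, bias=False)
#   conv.set_prune_rate(0.5)          # keep the top 50% of scores
#   y = conv(torch.randn(8, 3, 32, 32))
#
# The dense weights are typically frozen by the surrounding training code (an
# assumption, not shown here), so only `scores` is learned.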
"""
Sample Based Sparsification
"""
class StraightThroughBinomialSample(autograd.Function):
@staticmethod
def forward(ctx, scores):
output = (torch.rand_like(scores) < scores).float()
return output
@staticmethod
def backward(ctx, grad_outputs):
return grad_outputs, None
class BinomialSample(autograd.Function):
@staticmethod
def forward(ctx, scores):
output = (torch.rand_like(scores) < scores).float()
ctx.save_for_backward(output)
return output
@staticmethod
def backward(ctx, grad_outputs):
subnet, = ctx.saved_variables
grad_inputs = grad_outputs.clone()
grad_inputs[subnet == 0.0] = 0.0
return grad_inputs, None
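
# Both samplers draw a Bernoulli mask with per-weight keep-probability `scores`
# (the convolution below passes sigmoid(scores), so probabilities lie in (0, 1)).
# They differ only in backward: StraightThroughBinomialSample passes the gradient
# through unchanged, while BinomialSample zeroes the gradient for weights that
# were sampled out in the forward pass.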


# Not learning weights, finding subnet
class SampleSubnetConv(nn.Conv2d):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.scores = nn.Parameter(torch.Tensor(self.weight.size()))
        if parser_args.score_init_constant is not None:
            self.scores.data = (
                torch.ones_like(self.scores) * parser_args.score_init_constant
            )
        else:
            nn.init.kaiming_uniform_(self.scores, a=math.sqrt(5))

    @property
    def clamped_scores(self):
        return torch.sigmoid(self.scores)

    def forward(self, x):
        subnet = StraightThroughBinomialSample.apply(self.clamped_scores)
        w = self.weight * subnet
        x = F.conv2d(
            x, w, self.bias, self.stride, self.padding, self.dilation, self.groups
        )
        return x
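
# Usage sketch for SampleSubnetConv (hypothetical sizes; assumes the `args`
# module provides parser_args.score_init_constant):
#
#   conv = SampleSubnetConv(3, 64, kernel_size=3, padding=1, bias=False)
#   y = conv(torch.randn(8, 3, 32, 32))  # a fresh Bernoulli mask is drawn per call
#
# Unlike SubnetConv, this layer keeps the full-precision weights of the sampled
# subnetwork rather than binarizing them.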
"""
Fixed subnets
"""
class FixedSubnetConv(nn.Conv2d):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.scores = nn.Parameter(torch.Tensor(self.weight.size()))
nn.init.kaiming_uniform_(self.scores, a=math.sqrt(5))
def set_prune_rate(self, prune_rate):
self.prune_rate = prune_rate
print("prune_rate_{}".format(self.prune_rate))
def set_subnet(self):
output = self.clamped_scores().clone()
_, idx = self.clamped_scores().flatten().abs().sort()
p = int(self.prune_rate * self.clamped_scores().numel())
flat_oup = output.flatten()
flat_oup[idx[:p]] = 0
flat_oup[idx[p:]] = 1
self.scores = torch.nn.Parameter(output)
self.scores.requires_grad = False
def clamped_scores(self):
return self.scores.abs()
def get_subnet(self):
return self.weight * self.scores
def forward(self, x):
w = self.get_subnet()
x = F.conv2d(
x, w, self.bias, self.stride, self.padding, self.dilation, self.groups
)
return x
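
# Usage sketch for FixedSubnetConv (hypothetical sizes). Note that here
# `prune_rate` is the fraction of weights set to zero, whereas in SubnetConv /
# GetSubnet the `k` argument is the fraction of weights kept:
#
#   conv = FixedSubnetConv(3, 64, kernel_size=3, padding=1, bias=False)
#   conv.set_prune_rate(0.5)   # zero out the bottom 50% of |scores|
#   conv.set_subnet()          # freeze scores into a fixed {0, 1} mask
#   y = conv(torch.randn(8, 3, 32, 32))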