utils_sgm.py
import numpy as np
import torch
import torch.nn as nn


def backward_hook(gamma):
    # SGM: scale the gradient flowing back through each ReLU by gamma.
    # Note: Module.register_backward_hook is deprecated in recent PyTorch
    # in favor of register_full_backward_hook, but still fires.
    def _backward_hook(module, grad_in, grad_out):
        if isinstance(module, nn.ReLU):
            return (gamma * grad_in[0],)
    return _backward_hook
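
# Illustrative sanity check (not part of the original file): registering the
# hook on a single nn.ReLU scales the gradient reaching its input by gamma.
#
#   relu = nn.ReLU()
#   relu.register_backward_hook(backward_hook(0.5))
#   x = torch.randn(5, requires_grad=True)
#   relu(x).sum().backward()
#   # x.grad is now 0.5 * (x > 0).float() instead of (x > 0).float()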


def backward_hook_norm(module, grad_in, grad_out):
    # Normalize the gradient to avoid gradient explosion or vanishing.
    std = torch.std(grad_in[0])
    return (grad_in[0] / std,)


def register_hook_for_resnet(model, arch, gamma):
    # There is only one ReLU in the Conv module of ResNet-18/34, but two in
    # the Conv module of ResNet-50/101/152, so the decay is split as
    # gamma ** 0.5 per ReLU to give an overall factor of gamma per block.
    if arch in ['resnet50', 'resnet101', 'resnet152']:
        gamma = np.power(gamma, 0.5)
    backward_hook_sgm = backward_hook(gamma)

    for name, module in model.named_modules():
        if 'relu' in name and '0.relu' not in name:
            module.register_backward_hook(backward_hook_sgm)

        # e.g., 1.layer1.1, 1.layer4.2, ...
        # if len(name.split('.')) == 3:
        if len(name.split('.')) >= 2 and 'layer' in name.split('.')[-2]:
            module.register_backward_hook(backward_hook_norm)


def register_hook_for_densenet(model, arch, gamma):
    # There are two ReLUs in the Conv module of DenseNet-121/169/201,
    # so the decay is again split as gamma ** 0.5 per ReLU.
    gamma = np.power(gamma, 0.5)
    backward_hook_sgm = backward_hook(gamma)

    for name, module in model.named_modules():
        if 'relu' in name and 'transition' not in name:
            module.register_backward_hook(backward_hook_sgm)
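

# ---------------------------------------------------------------------------
# Usage sketch (an illustration, not part of the original file). It assumes
# torchvision is installed; randomly initialized weights are enough to
# exercise the hooks. On recent PyTorch, register_backward_hook emits
# deprecation warnings but still works.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    import torchvision.models as models

    model = models.resnet50()
    model.eval()

    # gamma < 1.0 decays the gradient of the residual modules relative to
    # the skip connections, as in the Skip Gradient Method.
    register_hook_for_resnet(model, arch='resnet50', gamma=0.5)

    x = torch.randn(1, 3, 224, 224, requires_grad=True)
    out = model(x)
    out[0, 0].backward()
    print('mean |input gradient|:', x.grad.abs().mean().item())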