import torch
from torch.optim.optimizer import Optimizer


class BGD(Optimizer):
    """Implements BGD (Bayesian Gradient Descent).

    A simple usage of BGD would be (a runnable sketch of this loop appears at
    the bottom of this file):

        for samples, labels in batches:
            for mc_iter in range(mc_iters):
                optimizer.randomize_weights()
                output = model.forward(samples)
                loss = criterion(output, labels)
                optimizer.zero_grad()
                loss.backward()
                optimizer.aggregate_grads(batch_size)
            optimizer.step()
    """

    def __init__(self, params, std_init, mean_eta=1, mc_iters=10):
        """
        Initializes the BGD optimizer.
        group["mean_param"] is the learned mean.
        group["std_param"] is the learned STD.
        :param params: List of model parameters
        :param std_init: Initialization value for the STD parameter
        :param mean_eta: Eta value (step-size factor for the mean update)
        :param mc_iters: Number of Monte Carlo iterations, used for a correctness
                         check in step(). Use None to disable the check.
        """
        super(BGD, self).__init__(params, defaults={})
        assert mc_iters is None or (isinstance(mc_iters, int) and mc_iters > 0), \
            "mc_iters should be a positive int or None."
        self.std_init = std_init
        self.mean_eta = mean_eta
        self.mc_iters = mc_iters
        # Initialize mu (mean_param) and sigma (std_param).
        for group in self.param_groups:
            assert len(group["params"]) == 1, "BGD optimizer does not support multiple params in a group"
            # group["params"][0] is the weights tensor.
            assert isinstance(group["params"][0], torch.Tensor), "BGD expects each param to be a tensor"
            # The initialization of the weights is used to initialize the mean.
            group["mean_param"] = group["params"][0].data.clone()
            group["std_param"] = torch.zeros_like(group["params"][0].data).add_(self.std_init)
        self._init_accumulators()

    def get_mc_iters(self):
        return self.mc_iters

    def _init_accumulators(self):
        """Resets the MC iteration counter and the per-group gradient accumulators."""
        self.mc_iters_taken = 0
        for group in self.param_groups:
            group["eps"] = None
            group["grad_mul_eps_sum"] = torch.zeros_like(group["params"][0].data)
            group["grad_sum"] = torch.zeros_like(group["params"][0].data)

    def randomize_weights(self, force_std=-1):
        """
        Randomizes the weights according to N(mean, std).
        :param force_std: If force_std >= 0, force_std is used as the STD instead of the learned STD.
        :return: None
        """
        for group in self.param_groups:
            mean = group["mean_param"]
            std = group["std_param"]
            if force_std >= 0:
                std = std.mul(0).add(force_std)
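            # Draw a fresh eps ~ N(0, 1) elementwise for this MC iteration;
            # aggregate_grads() later pairs it with the resulting gradients.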
group["eps"] = torch.normal(torch.zeros_like(mean), 1)
# Reparameterization trick (here we set the weights to their randomized value):
group["params"][0].data.copy_(mean.add(std.mul(group["eps"])))

    def aggregate_grads(self, batch_size):
        """
        Aggregates the gradients of a single Monte Carlo iteration. step() uses the
        aggregated sums to calculate the expectations.
        optimizer.zero_grad() should be used before calling .backward() once again.
        :param batch_size: BGD uses non-normalized (summed) gradients, but PyTorch
                           losses typically average over the batch. Therefore, the
                           gradients are multiplied by the batch size.
        :return: None
        """
        self.mc_iters_taken += 1
        groups_cnt = 0
        for group in self.param_groups:
            if group["params"][0].grad is None:
                continue
            assert group["eps"] is not None, "Must randomize weights before using aggregate_grads"
            groups_cnt += 1
            grad = group["params"][0].grad.data.mul(batch_size)
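            # Accumulate sum(grad) and sum(grad * eps); step() divides these by the
            # number of MC iterations taken to obtain E[grad] and E[grad * eps].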
group["grad_sum"].add_(grad)
group["grad_mul_eps_sum"].add_(grad.mul(group["eps"]))
group["eps"] = None
assert groups_cnt > 0, "Called aggregate_grads, but all gradients were None. Make sure you called .backward()"

    def step(self, closure=None):
        """
        Updates the learned mean and STD.
        :return: None
        """
        # Make sure that self.mc_iters MC iterations have been taken.
        assert self.mc_iters is None or self.mc_iters == self.mc_iters_taken, \
            "MC iters is set to " + str(self.mc_iters) + ", but took " + str(self.mc_iters_taken) + " MC iters"
        for group in self.param_groups:
            mean = group["mean_param"]
            std = group["std_param"]
            # Divide the gradient sums by the number of MC iterations to get the expectations.
            e_grad = group["grad_sum"].div(self.mc_iters_taken)
            e_grad_eps = group["grad_mul_eps_sum"].div(self.mc_iters_taken)
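            # Closed-form BGD updates, as implemented below:
            #   mean <- mean - mean_eta * std^2 * E[grad]
            #   std  <- std * sqrt(1 + (std * E[grad * eps] / 2)^2) - std^2 * E[grad * eps] / 2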
            # Update the mean and STD params.
            mean.add_(-std.pow(2).mul(e_grad).mul(self.mean_eta))
            sqrt_term = torch.sqrt(e_grad_eps.mul(std).div(2).pow(2).add(1)).mul(std)
            std.copy_(sqrt_term.add(-e_grad_eps.mul(std.pow(2)).div(2)))
        # Leave the weights at the updated mean (STD forced to 0) until the next randomization.
        self.randomize_weights(force_std=0)
        self._init_accumulators()
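

# ---------------------------------------------------------------------------
# A minimal, runnable sketch of the training loop from the class docstring.
# The model, data, and hyperparameter values below are illustrative
# assumptions for demonstration only; they are not part of the optimizer.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    torch.manual_seed(0)
    model = torch.nn.Linear(10, 2)                # toy model (assumed)
    criterion = torch.nn.CrossEntropyLoss()       # toy criterion (assumed)
    # One tensor per parameter group, as required by the assert in __init__.
    optimizer = BGD([{"params": [p]} for p in model.parameters()],
                    std_init=0.02, mean_eta=1, mc_iters=10)

    batch_size = 32
    samples = torch.randn(batch_size, 10)         # random stand-in data
    labels = torch.randint(0, 2, (batch_size,))

    for mc_iter in range(optimizer.get_mc_iters()):
        optimizer.randomize_weights()             # sample weights ~ N(mean, std)
        output = model(samples)
        loss = criterion(output, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.aggregate_grads(batch_size)     # accumulate MC gradients
    optimizer.step()                              # update mean and STD
    print("loss after one BGD step:", loss.item())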