# sgrhmc.py
import math
import torch
from geoopt.tensor import ManifoldParameter, ManifoldTensor
from geoopt.samplers.base import Sampler
from ..utils import copy_or_set_
__all__ = ["SGRHMC"]
class SGRHMC(Sampler):
    r"""
    Stochastic Gradient Riemannian Hamiltonian Monte-Carlo.

    Parameters
    ----------
    params : iterable
        iterables of tensors for which to perform sampling
    epsilon : float
        step size
    n_steps : int
        number of leapfrog steps
    alpha : float
        :math:`(1 - alpha)` -- momentum term
    """

    def __init__(self, params, epsilon=1e-3, n_steps=1, alpha=0.1):
        """Set up per-group hyperparameters and delegate to the base Sampler.

        Only ``epsilon`` and ``alpha`` go into the per-group defaults;
        ``n_steps`` is shared across all groups and kept on the instance.
        """
        defaults = dict(epsilon=epsilon, alpha=alpha)
        super().__init__(params, defaults)
        self.n_steps = n_steps

    def step(self, closure):
        """Perform one SGRHMC sampling step.

        Parameters
        ----------
        closure : callable
            re-evaluates the model and returns the log-probability tensor
            (it is called once per inner iteration and must allow
            ``.backward()`` on its result).
        """
        # NOTE(review): H_old / H_new accumulate a kinetic-energy-like
        # quantity 0.5 * ||v/epsilon||^2 but are never read afterwards in
        # this class (no Metropolis-Hastings correction is visible here).
        H_old = 0.0
        H_new = 0.0
        # Resample the momentum buffer for every parameter:
        # v ~ N(0, epsilon^2), stored lazily in optimizer state on first use.
        for group in self.param_groups:
            for p in group["params"]:
                state = self.state[p]
                if "v" not in state:
                    # zeros_like gives a plain tensor (requires_grad=False),
                    # so the in-place ops below are not tracked by autograd
                    state["v"] = torch.zeros_like(p)
                epsilon = group["epsilon"]
                v = state["v"]
                v.normal_().mul_(epsilon)
                r = v / epsilon
                H_old += 0.5 * (r * r).sum().item()

        # nan sentinel: only the value from the final inner iteration is
        # appended to the log below (logp is overwritten each pass)
        logp = float("nan")
        # NOTE(review): the loop runs n_steps + 1 times, i.e. one more than
        # the documented "number of leapfrog steps" — confirm against the
        # upstream intent before changing.
        for i in range(self.n_steps + 1):
            logp = closure()
            logp.backward()
            logp = logp.item()

            with torch.no_grad():
                for group in self.param_groups:
                    for p in group["params"]:
                        # Manifold-valued tensors carry their own geometry;
                        # plain tensors fall back to the sampler's default
                        # manifold (provided by the Sampler base class).
                        if isinstance(p, (ManifoldParameter, ManifoldTensor)):
                            manifold = p.manifold
                        else:
                            manifold = self._default_manifold

                        egrad2rgrad = manifold.egrad2rgrad
                        retr_transp = manifold.retr_transp
                        epsilon, alpha = group["epsilon"], group["alpha"]

                        v = self.state[p]["v"]
                        # Move p along v via retraction and transport v to
                        # the new point (geoopt manifold API); then write
                        # both back in place.
                        p_, v_ = retr_transp(p, v, v)
                        copy_or_set_(p, p_)
                        v.set_(v_)

                        # Fresh Gaussian noise mapped to a Riemannian
                        # (tangent-space) vector at p.
                        n = egrad2rgrad(p, torch.randn_like(v))
                        # SGHMC-style momentum update:
                        #   v <- (1 - alpha) * v            (friction)
                        #        + epsilon * grad(logp)     (stochastic grad)
                        #        + sqrt(2*alpha*epsilon)*n  (injected noise)
                        v.mul_(1 - alpha).add_(epsilon * p.grad).add_(
                            math.sqrt(2 * alpha * epsilon) * n
                        )
                        # Clear the gradient so the next closure() call
                        # accumulates from zero.
                        p.grad.zero_()

                        r = v / epsilon
                        H_new += 0.5 * (r * r).sum().item()

        # burnin / steps / log_probs are maintained by the Sampler base
        # class; only post-burnin iterations count as recorded samples.
        if not self.burnin:
            self.steps += 1
            self.log_probs.append(logp)

    @torch.no_grad()
    def stabilize_group(self, group):
        """Re-project drifted state for one param group back onto the manifold.

        Numerical error accumulates over many in-place updates; this snaps
        each manifold parameter back onto its manifold and its momentum back
        into the tangent space at the (re-projected) point.
        """
        for p in group["params"]:
            # Plain tensors live in Euclidean space — nothing to project.
            if not isinstance(p, (ManifoldParameter, ManifoldTensor)):
                continue
            manifold = p.manifold
            copy_or_set_(p, manifold.projx(p))
            # proj here is ok
            state = self.state[p]
            # State is created lazily in step(); skip params never stepped.
            if not state:
                continue
            state["v"].set_(manifold.proju(p, state["v"]))