Skip to content

Commit

Permalink
Merge 56a7c61 into d15832a
Browse files Browse the repository at this point in the history
  • Loading branch information
ferrine committed Jul 11, 2019
2 parents d15832a + 56a7c61 commit 9a42e91
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 23 deletions.
11 changes: 11 additions & 0 deletions geoopt/samplers/rhmc.py
Expand Up @@ -148,3 +148,14 @@ def step(self, closure):
self.log_probs.append(old_logp)
else:
self.log_probs.append(new_logp)

@torch.no_grad()
def stabilize_group(self, group):
for p in group["params"]:
if not isinstance(p, (ManifoldParameter, ManifoldTensor)):
continue
copy_or_set_(p, p.manifold.projx(p))
state = self.state[p]
if not state: # due to None grads
continue
copy_or_set_(state["old_p"], p.manifold.projx(state["old_p"]))
15 changes: 6 additions & 9 deletions geoopt/samplers/rsgld.py
Expand Up @@ -57,12 +57,9 @@ def step(self, closure):
self.steps += 1
self.log_probs.append(logp.item())

def stabilize(self):
"""Stabilize parameters if they are off-manifold due to numerical reasons
"""
with torch.no_grad():
for group in self.param_groups:
for p in group["params"]:
if not isinstance(p, (ManifoldParameter, ManifoldTensor)):
continue
copy_or_set_(p, p.manifold.projx(p))
@torch.no_grad()
def stabilize_group(self, group):
for p in group["params"]:
if not isinstance(p, (ManifoldParameter, ManifoldTensor)):
continue
copy_or_set_(p, p.manifold.projx(p))
27 changes: 13 additions & 14 deletions geoopt/samplers/sgrhmc.py
Expand Up @@ -91,17 +91,16 @@ def step(self, closure):
self.steps += 1
self.log_probs.append(logp)

def stabilize(self):
"""Stabilize parameters if they are off-manifold due to numerical reasons
"""

for group in self.param_groups:
for p in group["params"]:
if not isinstance(p, (ManifoldParameter, ManifoldTensor)):
continue

manifold = p.manifold
v = self.state[p]["v"]
copy_or_set_(p, manifold.projx(p))
# proj here is ok
v.set_(manifold.proju(p, v))
@torch.no_grad()
def stabilize_group(self, group):
for p in group["params"]:
if not isinstance(p, (ManifoldParameter, ManifoldTensor)):
continue

manifold = p.manifold
copy_or_set_(p, manifold.projx(p))
# proj here is ok
state = self.state[p]
if not state:
continue
state["v"].set_(manifold.proju(p, state["v"]))

0 comments on commit 9a42e91

Please sign in to comment.