Skip to content

Commit

Permalink
Merge c46d278 into 90af5ec
Browse files: browse the repository at this point in the history
  • Loading branch information
ferrine authored May 15, 2019
2 parents 90af5ec + c46d278 commit cc29f40
Show file tree
Hide file tree
Showing 9 changed files with 19 additions and 18 deletions.
5 changes: 3 additions & 2 deletions geoopt/optim/radam.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,8 @@ def perform_step(
new_point, exp_avg_new = manifold.retr_transp(
point, exp_avg, u=direction, t=-step_size
)
point.set_(new_point)
# use copy only for user facing point
point.copy_(new_point)
exp_avg.set_(exp_avg_new)

def stabilize_group(self, group):
Expand All @@ -202,7 +203,7 @@ def stabilize_group(self, group):
continue
manifold = p.manifold
exp_avg = state["exp_avg"]
p.set_(manifold.projx(p))
p.copy_(manifold.projx(p))
exp_avg.set_(manifold.proju(p, exp_avg))

def _sanitize_group(self, group):
Expand Down
7 changes: 4 additions & 3 deletions geoopt/optim/rsgd.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,10 +168,11 @@ def perform_step(
point, momentum_buffer, u=grad, t=-lr
)
momentum_buffer.set_(new_momentum_buffer)
point.set_(new_point)
# use copy only for user facing point
point.copy_(new_point)
else:
new_point = manifold.retr(point, grad, -lr)
point.set_(new_point)
point.copy_(new_point)

def stabilize_group(self, group):
with torch.no_grad():
Expand All @@ -180,7 +181,7 @@ def stabilize_group(self, group):
continue
manifold = p.manifold
momentum = group["momentum"]
p.set_(manifold.projx(p))
p.copy_(manifold.projx(p))
if momentum > 0:
param_state = self.state[p]
if not param_state: # due to None grads
Expand Down
2 changes: 1 addition & 1 deletion geoopt/samplers/rhmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def _step(self, p, r, epsilon):

r.add_(epsilon * egrad2rgrad(p, p.grad))
p_, r_ = retr_transp(p, r, u=r, t=epsilon)
p.set_(p_)
p.copy_(p_)
r.set_(r_)

def step(self, closure):
Expand Down
6 changes: 3 additions & 3 deletions geoopt/samplers/rsgld.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,8 @@ def step(self, closure):

n = torch.randn_like(p).mul_(math.sqrt(epsilon))
r = egrad2rgrad(p, 0.5 * epsilon * p.grad + n)

p.set_(retr(p, r, 1.0))
# use copy only for user facing point
p.copy_(retr(p, r, 1.0))
p.grad.zero_()

if not self.burnin:
Expand All @@ -67,4 +67,4 @@ def stabilize(self):
if not isinstance(p, (ManifoldParameter, ManifoldTensor)):
continue

p.set_(p.manifold.projx(p))
p.copy_(p.manifold.projx(p))
4 changes: 2 additions & 2 deletions geoopt/samplers/sgrhmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def step(self, closure):
v = self.state[p]["v"]

p_, v_ = retr_transp(p, v, u=v, t=1.0)
p.set_(p_)
p.copy_(p_)
v.set_(v_)

n = egrad2rgrad(p, torch.randn_like(v))
Expand Down Expand Up @@ -104,6 +104,6 @@ def stabilize(self):
manifold = p.manifold
v = self.state[p]["v"]

p.set_(manifold.projx(p))
p.copy_(manifold.projx(p))
# proj here is ok
v.set_(manifold.proju(p, v))
5 changes: 2 additions & 3 deletions geoopt/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ def __new__(cls, *args, manifold=Euclidean(), requires_grad=False, **kwargs):
instance.manifold = manifold
return instance

@torch.no_grad()
def proj_(self):
"""
Inplace projection to the manifold
Expand All @@ -36,9 +37,7 @@ def proj_(self):
tensor
same instance
"""
with torch.no_grad():
self.set_(self.manifold.projx(self.data))
return self
return self.copy_(self.manifold.projx(self))

@insert_docs(Euclidean.retr.__doc__, r"\s+x : .+\n.+", "")
def retr(self, u, t=1.0, order=None):
Expand Down
3 changes: 1 addition & 2 deletions tests/test_adam.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def closure():
if (X - Xstar).norm() < 1e-5:
break
optim.step(closure)

assert X.is_contiguous()
np.testing.assert_allclose(X.data, Xstar, atol=1e-5, rtol=1e-5)
optim.load_state_dict(optim.state_dict())
optim.step(closure)
Expand All @@ -49,5 +49,4 @@ def closure():

for _ in range(2000):
optim.step(closure)

np.testing.assert_allclose(start.data, ideal, atol=1e-5, rtol=1e-5)
2 changes: 1 addition & 1 deletion tests/test_rhmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def forward(self):

points = np.asarray(points)
points = points[::20]

assert nd.x.is_contiguous()
np.testing.assert_allclose(mu.numpy(), points.mean(axis=0), atol=1e-1)
np.testing.assert_allclose(sigma.numpy(), points.std(axis=0), atol=1e-1)

Expand Down
3 changes: 2 additions & 1 deletion tests/test_rsgd.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def closure():
if (X - Xstar).norm() < 1e-5:
break
optim.step(closure)

assert X.is_contiguous()
np.testing.assert_allclose(X.data, Xstar, atol=1e-5)
optim.load_state_dict(optim.state_dict())
optim.step(closure)
Expand All @@ -56,5 +56,6 @@ def test_init_manifold():
opt.zero_grad()
opt.step()
assert not np.allclose(p0.data, p0old.data)
assert p0.is_contiguous()
np.testing.assert_allclose(p1.data, p1old.data)
np.testing.assert_allclose(p0.data, stiefel.projx(p0old.data), atol=1e-4)

0 comments on commit cc29f40

Please sign in to comment.