Skip to content

Commit

Permalink
initial commit
Browse files Browse the repository at this point in the history
  • Loading branch information
ferrine committed May 15, 2019
1 parent 90af5ec commit c923167
Show file tree
Hide file tree
Showing 9 changed files with 12 additions and 12 deletions.
2 changes: 1 addition & 1 deletion geoopt/optim/radam.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ def perform_step(
new_point, exp_avg_new = manifold.retr_transp(
point, exp_avg, u=direction, t=-step_size
)
point.set_(new_point)
point.set_(new_point.contiguous())
exp_avg.set_(exp_avg_new)

def stabilize_group(self, group):
Expand Down
4 changes: 2 additions & 2 deletions geoopt/optim/rsgd.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,10 +168,10 @@ def perform_step(
point, momentum_buffer, u=grad, t=-lr
)
momentum_buffer.set_(new_momentum_buffer)
point.set_(new_point)
point.set_(new_point.contiguous())
else:
new_point = manifold.retr(point, grad, -lr)
point.set_(new_point)
point.set_(new_point.contiguous())

def stabilize_group(self, group):
with torch.no_grad():
Expand Down
2 changes: 1 addition & 1 deletion geoopt/samplers/rhmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def _step(self, p, r, epsilon):

r.add_(epsilon * egrad2rgrad(p, p.grad))
p_, r_ = retr_transp(p, r, u=r, t=epsilon)
p.set_(p_)
p.set_(p_.contiguous())
r.set_(r_)

def step(self, closure):
Expand Down
2 changes: 1 addition & 1 deletion geoopt/samplers/rsgld.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def step(self, closure):
n = torch.randn_like(p).mul_(math.sqrt(epsilon))
r = egrad2rgrad(p, 0.5 * epsilon * p.grad + n)

p.set_(retr(p, r, 1.0))
p.set_(retr(p, r, 1.0).contiguous())
p.grad.zero_()

if not self.burnin:
Expand Down
2 changes: 1 addition & 1 deletion geoopt/samplers/sgrhmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def step(self, closure):
v = self.state[p]["v"]

p_, v_ = retr_transp(p, v, u=v, t=1.0)
p.set_(p_)
p.set_(p_.contiguous())
v.set_(v_)

n = egrad2rgrad(p, torch.randn_like(v))
Expand Down
4 changes: 2 additions & 2 deletions geoopt/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,10 @@ def proj_(self):
Returns
-------
tensor
same instance
same instance (contiguous)
"""
with torch.no_grad():
self.set_(self.manifold.projx(self.data))
self.set_(self.manifold.projx(self.data).contiguous())
return self

@insert_docs(Euclidean.retr.__doc__, r"\s+x : .+\n.+", "")
Expand Down
3 changes: 1 addition & 2 deletions tests/test_adam.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def closure():
if (X - Xstar).norm() < 1e-5:
break
optim.step(closure)

assert X.is_contiguous()
np.testing.assert_allclose(X.data, Xstar, atol=1e-5, rtol=1e-5)
optim.load_state_dict(optim.state_dict())
optim.step(closure)
Expand All @@ -49,5 +49,4 @@ def closure():

for _ in range(2000):
optim.step(closure)

np.testing.assert_allclose(start.data, ideal, atol=1e-5, rtol=1e-5)
2 changes: 1 addition & 1 deletion tests/test_rhmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def forward(self):

points = np.asarray(points)
points = points[::20]

assert nd.x.is_contiguous()
np.testing.assert_allclose(mu.numpy(), points.mean(axis=0), atol=1e-1)
np.testing.assert_allclose(sigma.numpy(), points.std(axis=0), atol=1e-1)

Expand Down
3 changes: 2 additions & 1 deletion tests/test_rsgd.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def closure():
if (X - Xstar).norm() < 1e-5:
break
optim.step(closure)

assert X.is_contiguous()
np.testing.assert_allclose(X.data, Xstar, atol=1e-5)
optim.load_state_dict(optim.state_dict())
optim.step(closure)
Expand All @@ -56,5 +56,6 @@ def test_init_manifold():
opt.zero_grad()
opt.step()
assert not np.allclose(p0.data, p0old.data)
assert p0.is_contiguous()
np.testing.assert_allclose(p1.data, p1old.data)
np.testing.assert_allclose(p0.data, stiefel.projx(p0old.data), atol=1e-4)

0 comments on commit c923167

Please sign in to comment.