Skip to content

Commit

Permalink
change group increments (#165)
Browse files (browse the repository at this point in the history)
  • Loading branch information
ferrine committed Apr 11, 2021
1 parent 244bcd0 commit e7e4dd7
Show file tree
Hide file tree
Showing 4 changed files with 4 additions and 8 deletions.
4 changes: 1 addition & 3 deletions geoopt/optim/radam.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ def step(self, closure=None):
eps = group["eps"]
learning_rate = group["lr"]
amsgrad = group["amsgrad"]
group["step"] += 1
for point in group["params"]:
grad = point.grad
if grad is None:
Expand All @@ -75,7 +76,6 @@ def step(self, closure=None):

# State initialization
if len(state) == 0:
state["step"] = 0
# Exponential moving average of gradient values
state["exp_avg"] = torch.zeros_like(point)
# Exponential moving average of squared gradient values
Expand All @@ -101,7 +101,6 @@ def step(self, closure=None):
denom = max_exp_avg_sq.sqrt().add_(eps)
else:
denom = exp_avg_sq.sqrt().add_(eps)
group["step"] += 1
bias_correction1 = 1 - betas[0] ** group["step"]
bias_correction2 = 1 - betas[1] ** group["step"]
step_size = (
Expand All @@ -119,7 +118,6 @@ def step(self, closure=None):
copy_or_set_(point, new_point)
exp_avg.set_(exp_avg_new)

group["step"] += 1
if (
group["stabilize"] is not None
and group["step"] % group["stabilize"] == 0
Expand Down
2 changes: 1 addition & 1 deletion geoopt/optim/rsgd.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ def step(self, closure=None):
dampening = group["dampening"]
nesterov = group["nesterov"]
learning_rate = group["lr"]
group["step"] += 1
for point in group["params"]:
grad = point.grad
if grad is None:
Expand Down Expand Up @@ -113,7 +114,6 @@ def step(self, closure=None):
new_point = manifold.retr(point, -learning_rate * grad)
copy_or_set_(point, new_point)

group["step"] += 1
if (
group["stabilize"] is not None
and group["step"] % group["stabilize"] == 0
Expand Down
4 changes: 1 addition & 3 deletions geoopt/optim/sparse_radam.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ def step(self, closure=None):
eps = group["eps"]
learning_rate = group["lr"]
amsgrad = group["amsgrad"]
group["step"] += 1
for point in group["params"]:
grad = point.grad
if grad is None:
Expand All @@ -92,7 +93,6 @@ def step(self, closure=None):

# State initialization
if len(state) == 0:
state["step"] = 0
# Exponential moving average of gradient values
state["exp_avg"] = torch.zeros_like(point)
# Exponential moving average of squared gradient values
Expand Down Expand Up @@ -124,7 +124,6 @@ def step(self, closure=None):
state["max_exp_avg_sq"][rows] = max_exp_avg_sq
else:
denom = exp_avg_sq.sqrt().add_(eps)
group["step"] += 1
bias_correction1 = 1 - betas[0] ** group["step"]
bias_correction2 = 1 - betas[1] ** group["step"]
step_size = (
Expand All @@ -143,7 +142,6 @@ def step(self, closure=None):
state["exp_avg"][rows] = exp_avg_new
state["exp_avg_sq"][rows] = exp_avg_sq

group["step"] += 1
if (
group["stabilize"] is not None
and group["step"] % group["stabilize"] == 0
Expand Down
2 changes: 1 addition & 1 deletion geoopt/optim/sparse_rsgd.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ def step(self, closure=None):
dampening = group["dampening"]
nesterov = group["nesterov"]
learning_rate = group["lr"]
group["step"] += 1
for point in group["params"]:
grad = point.grad
if grad is None:
Expand Down Expand Up @@ -115,7 +116,6 @@ def step(self, closure=None):
new_point = manifold.retr(point, -learning_rate * grad)
full_point[rows] = new_point

group["step"] += 1
if (
group["stabilize"] is not None
and group["step"] % group["stabilize"] == 0
Expand Down

0 comments on commit e7e4dd7

Please sign in to comment.