
Commit

Fix OptimizerWrapper creation, test gradient clipping (#593)
mryab committed Oct 8, 2023
1 parent fdda330 · commit fa199ad
Showing 2 changed files with 7 additions and 3 deletions.
hivemind/moe/server/layers/optim.py (3 changes: 1 addition & 2 deletions)
@@ -1,11 +1,10 @@
 import torch
 
 
-class OptimizerWrapper(torch.optim.Optimizer):
+class OptimizerWrapper:
     """A wrapper for pytorch.optim.Optimizer that forwards all methods to the wrapped optimizer"""
 
     def __init__(self, optim: torch.optim.Optimizer):
-        super().__init__(optim.param_groups, optim.defaults)
         self.optim = optim
 
     @property
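
For context, the fix replaces inheritance with composition: the wrapper no longer calls torch.optim.Optimizer.__init__ and simply stores the wrapped optimizer, forwarding everything else to it. Below is a minimal sketch of such a forwarding wrapper; the class name and the specific members forwarded are illustrative assumptions, not hivemind's exact implementation (the rest of the file is truncated above).

import torch


class ForwardingOptimizerWrapper:
    """Sketch of composition-based forwarding: hold the wrapped optimizer and
    delegate to it instead of subclassing torch.optim.Optimizer."""

    def __init__(self, optim: torch.optim.Optimizer):
        self.optim = optim  # keep a reference instead of re-initializing a parent Optimizer

    @property
    def param_groups(self):
        # expose the wrapped optimizer's parameter groups directly
        return self.optim.param_groups

    def step(self, *args, **kwargs):
        # forward step() to the wrapped optimizer; zero_grad, state_dict, etc. work the same way
        return self.optim.step(*args, **kwargs)

    def __getattr__(self, name):
        # anything not defined on the wrapper falls through to the wrapped optimizer
        return getattr(self.optim, name)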
tests/test_training.py (7 changes: 6 additions & 1 deletion)
@@ -20,7 +20,12 @@ def test_training(max_steps: int = 100, threshold: float = 0.9):
     SGD = partial(torch.optim.SGD, lr=0.05)
 
     with background_server(
-        num_experts=2, device="cpu", optim_cls=SGD, hidden_dim=64, num_handlers=1
+        num_experts=2,
+        device="cpu",
+        optim_cls=SGD,
+        hidden_dim=64,
+        num_handlers=1,
+        clip_grad_norm=1.0,
     ) as server_peer_info:
         dht = DHT(initial_peers=server_peer_info.addrs, start=True)
         expert1, expert2 = create_remote_experts(
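
The new clip_grad_norm=1.0 argument makes the test exercise gradient clipping on the server side. As a rough, hedged sketch (not the server's actual code path; backward_and_step is a hypothetical helper), clipping with torch.nn.utils.clip_grad_norm_ is applied between the backward pass and the optimizer step:

import torch


def backward_and_step(module: torch.nn.Module, loss: torch.Tensor,
                      optimizer: torch.optim.Optimizer, clip_grad_norm: float = 1.0) -> None:
    # Illustrative training step: clip the global gradient norm before updating parameters.
    optimizer.zero_grad()
    loss.backward()
    # rescale gradients in place so their total L2 norm does not exceed clip_grad_norm
    torch.nn.utils.clip_grad_norm_(module.parameters(), clip_grad_norm)
    optimizer.step()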
