Commit
make sure that a unit test would catch a regression
blefaudeux committed Mar 30, 2021
1 parent 7f5d020 commit 30ea553
Showing 2 changed files with 5 additions and 2 deletions.
2 changes: 1 addition & 1 deletion fairscale/nn/data_parallel/sharded_ddp.py

@@ -177,8 +177,8 @@ def __init__(
         self._bucket_list: List[GradBucket] = []
 
         # - setup backward hooks which will be called by Torch's autograd in due time
-        self._manual_reduce: List[Callable] = []
         self._grad_hooks: List[Any] = []
+        self._manual_reduce: List[Callable] = []
 
         # passing a handle to torch.nn.SyncBatchNorm layer
         self._passing_sync_batchnorm_handle(self.module)
5 changes: 4 additions & 1 deletion tests/nn/data_parallel/test_sharded_ddp_features.py

@@ -386,10 +386,13 @@ def run_test_gpt2(rank, world_size, backend, device, temp_file_name):
     np.random.seed(rank)
     model = GPT2(
         embed_dim=256, num_heads=2, num_layers=12, num_positions=INPUT_DIM * INPUT_DIM, num_vocab=512, num_classes=2
-    ).to(device)
+    )
     optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-3, momentum=0.99)
     ddp_model = ShardedDataParallel(model, optimizer)
+
+    # Move the model to another device post-construction
+    model = model.to(device)
 
     # Optim loop
     def closure():
         optimizer.zero_grad()
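
For context, the pattern the updated test now exercises looks roughly like the sketch below. This is a minimal sketch rather than the fairscale test itself: the toy torch.nn.Linear model, the run_post_construction_move helper name, and the single training step are assumptions for illustration, and it presumes a torch.distributed process group has already been initialized per rank (the real test sets that up itself and runs the GPT2 model shown above).

# Minimal sketch (not the fairscale test itself) of the pattern the updated test
# exercises: build the model on CPU, wrap it with OSS and ShardedDataParallel,
# and only move it to the target device afterwards.
# Assumes a torch.distributed process group is already initialized
# (e.g. a single-rank "gloo" group), as the real test does per rank.
import torch

from fairscale.nn.data_parallel import ShardedDataParallel
from fairscale.optim import OSS


def run_post_construction_move(device: str) -> None:  # hypothetical helper name
    model = torch.nn.Linear(8, 4)  # toy stand-in for the GPT2 model used in the test
    optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-3, momentum=0.99)
    ddp_model = ShardedDataParallel(model, optimizer)

    # Move the model to another device post-construction; the regression the
    # commit guards against is this move breaking gradient reduction.
    model = model.to(device)

    def closure():
        optimizer.zero_grad()
        loss = ddp_model(torch.rand(2, 8, device=device)).abs().sum()
        loss.backward()
        return loss

    # One optimizer step should run without error and produce a finite loss.
    loss = optimizer.step(closure)
    assert torch.isfinite(loss)

The key point the test change makes explicit is that ShardedDataParallel keeps a reference to the wrapped module, and nn.Module.to() moves parameters in place, so a device move after construction must still leave gradient reduction working.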
