Skip to content

Commit

Permalink
Fix _load_from_state_dict for num_batches_tracked in batchnorm (pytor…
Browse files Browse the repository at this point in the history
…ch#115285)

I approved pytorch#110850 which did the following

Previously:
`num_batches_tracked` not in state_dict when doing `m.load_state_dict(state_dict)` --> always overwrite module's `num_batches_tracked` in `load_from_state_dict` with a 0 cpu tensor

Now:
`num_batches_tracked` not in state_dict loaded when doing `m.load_state_dict(state_dict)` --> only overwrite module's `num_batches_tracked`  in `load_from_state_dict` with a 0 cpu tensor if module does not have `num_batches_tracked`

This causes the following issue:

```
with torch.device('meta'):
     m = BatchNorm(...)
m.load_state_dict(state_dict, assign=True)
```

If `num_batches_tracked` is not in `state_dict`, since `modules's` `num_batches_tracked` is present on meta device, it is not overwritten with a 0 cpu tensor. When compiling, this error is raised

```
AssertionError: Does not support mixing cuda+meta
```

I am not sure whether the explicit check for meta device makes sense as a fix, will add testing if this fix is ok

Pull Request resolved: pytorch#115285
Approved by: https://github.com/albanD
  • Loading branch information
mikaylagawarecki authored and dmenig committed Dec 21, 2023
1 parent 1d2d5cd commit 74a2b5c
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 1 deletion.
8 changes: 8 additions & 0 deletions test/test_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -5318,6 +5318,14 @@ def test_batchnorm_load_state_dict(self):
bn.load_state_dict(empty_dict, strict=False)
self.assertEqual(bn.state_dict()["num_batches_tracked"], torch.tensor(10))

# test that when `num_batches_tracked` is not in loaded state_dict,
# meta num_batches_tracked is still replaced with singleton 0 tensor
with torch.device('meta'):
meta_bn = torch.nn.BatchNorm2d(3)
self.assertTrue(meta_bn.num_batches_tracked.device == torch.device('meta'))
meta_bn.load_state_dict(empty_dict, assign=True, strict=False)
self.assertEqual(meta_bn.state_dict()["num_batches_tracked"], torch.tensor(0))

def test_pairwise_distance(self):
input1 = torch.randn(4, 4, requires_grad=True, dtype=torch.double)
input2 = torch.randn(4, 4, requires_grad=True, dtype=torch.double)
Expand Down
2 changes: 1 addition & 1 deletion torch/nn/modules/batchnorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def _load_from_state_dict(
if num_batches_tracked_key not in state_dict:
state_dict[num_batches_tracked_key] = (
self.num_batches_tracked
if self.num_batches_tracked is not None
if self.num_batches_tracked is not None and self.num_batches_tracked.device != torch.device('meta')
else torch.tensor(0, dtype=torch.long)
)

Expand Down

0 comments on commit 74a2b5c

Please sign in to comment.