Skip to content

Commit

Permalink
Fix a bug in nesting check_sparse_tensor_invariants context managers …
Browse files Browse the repository at this point in the history
…(#95372)

As in the title. The bug was reported in pytorch/pytorch#94728 (comment) and has the following reproducer:
```python
>>> import torch
>>> check_ctx = torch.sparse.check_sparse_tensor_invariants(True)
>>> no_check_ctx = torch.sparse.check_sparse_tensor_invariants(False)
>>> with check_ctx:
...   assert torch.sparse.check_sparse_tensor_invariants.is_enabled()
...   with no_check_ctx:
...     assert not torch.sparse.check_sparse_tensor_invariants.is_enabled()
...   assert torch.sparse.check_sparse_tensor_invariants.is_enabled()
...
Traceback (most recent call last):
  File "<stdin>", line 5, in <module>
AssertionError
```

Pull Request resolved: pytorch/pytorch#95372
Approved by: https://github.com/cpuhrsch
  • Loading branch information
pearu authored and cyyever committed Feb 25, 2023
1 parent 5237753 commit b3dfe0a
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 1 deletion.
27 changes: 27 additions & 0 deletions test/test_sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -4250,6 +4250,33 @@ def create_invalid_tensor(check_invariants=None):
# local context:
self.assertFalse(torch.sparse.check_sparse_tensor_invariants.is_enabled())

# Test nesting of pre-defined context managers
check_ctx = torch.sparse.check_sparse_tensor_invariants(True)
no_check_ctx = torch.sparse.check_sparse_tensor_invariants(False)
with check_ctx:
self.assertTrue(torch.sparse.check_sparse_tensor_invariants.is_enabled())
with no_check_ctx:
self.assertFalse(torch.sparse.check_sparse_tensor_invariants.is_enabled())
self.assertTrue(torch.sparse.check_sparse_tensor_invariants.is_enabled())
self.assertFalse(torch.sparse.check_sparse_tensor_invariants.is_enabled())

# Test an attempt to re-use an activate context manager instance
check_ctx2 = torch.sparse.check_sparse_tensor_invariants(True)
with check_ctx:
self.assertTrue(torch.sparse.check_sparse_tensor_invariants.is_enabled())
with no_check_ctx:
self.assertFalse(torch.sparse.check_sparse_tensor_invariants.is_enabled())
with self.assertRaisesRegex(RuntimeError, "This context manager instance is already activated."
" Use a different context manager instance for context nesting"):
with check_ctx:
self.assertTrue(torch.sparse.check_sparse_tensor_invariants.is_enabled())
self.assertFalse(torch.sparse.check_sparse_tensor_invariants.is_enabled())
with check_ctx2:
self.assertTrue(torch.sparse.check_sparse_tensor_invariants.is_enabled())
self.assertFalse(torch.sparse.check_sparse_tensor_invariants.is_enabled())
self.assertTrue(torch.sparse.check_sparse_tensor_invariants.is_enabled())
self.assertFalse(torch.sparse.check_sparse_tensor_invariants.is_enabled())

def test_generate_simple_inputs(self):
layouts = [torch.strided, torch.sparse_coo, torch.sparse_csr, torch.sparse_csc, torch.sparse_bsr, torch.sparse_bsc]

Expand Down
8 changes: 7 additions & 1 deletion torch/sparse/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,13 +469,19 @@ def disable():
# context manager support
def __init__(self, enable=True):
self.state = enable
self.saved_state = self.is_enabled()
self.saved_state : Optional[bool] = None

def __enter__(self):
if self.saved_state is not None:
raise RuntimeError('This context manager instance is already activated.'
' Use a different context manager instance for context nesting.')
self.saved_state = self.is_enabled()
torch._C._set_check_sparse_tensor_invariants(self.state)

def __exit__(self, type, value, traceback):
assert self.saved_state is not None
torch._C._set_check_sparse_tensor_invariants(self.saved_state)
self.saved_state = None

# decorator support
def __call__(self, mth):
Expand Down

0 comments on commit b3dfe0a

Please sign in to comment.