Skip to content

Commit

Permalink
[minor] fix doc and assert and test around percent (#1067)
Browse files Browse the repository at this point in the history
  • Loading branch information
min-xu-ai committed Sep 7, 2022
1 parent 1a8d234 commit 454537d
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 4 deletions.
4 changes: 2 additions & 2 deletions fairscale/experimental/wgit/signal_sparsity.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,11 +435,11 @@ def random_sparse_mask(dense: Tensor, percent: float, dim: int) -> Tensor:
dense (Tensor):
Input dense tensor (no zeros).
percent (float):
Percent of non-zeros.
Percent of non-zeros (0, 100].
dim (int):
Dimension on which the random sparse mask is computed.
"""
assert percent > 0, percent
assert percent > 0 and percent <= 100, percent
rand = torch.rand_like(dense)
ones = torch.ones_like(dense)
k = _get_k_for_topk(percent, None, dense.shape[dim])
Expand Down
4 changes: 2 additions & 2 deletions tests/experimental/wgit/test_signal_sparsity.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,10 +436,10 @@ def test_random_sparse_mask(device):
pytest.skip("no GPU")

dense = torch.tensor([0.5000, 0.6000, 0.7000, 0.8000, 0.9000]).to(device)
mask = random_sparse_mask(dense, 0.2, 0)
mask = random_sparse_mask(dense, 20, 0)
assert mask.sum() == 1
for d in [0, 1]:
dense = torch.rand(100, 100).to(device)
mask = random_sparse_mask(dense, 0.01, d)
mask = random_sparse_mask(dense, 1, d)
assert objects_are_equal(mask.sum(dim=d), torch.ones(100).to(device), raise_exception=True)
assert mask.sum() == 100

0 comments on commit 454537d

Please sign in to comment.