[feat] add random_sparse_mask api (#1066)
* [feat] add random_sparse_mask api

* correct test skip

Co-authored-by: Min Xu <min.xu.public@gmail.com>
min-xu-ai and flying-x committed Sep 7, 2022
1 parent 19033c3 commit 1a8d234
Showing 3 changed files with 36 additions and 2 deletions.
2 changes: 1 addition & 1 deletion fairscale/experimental/wgit/__init__.py
@@ -21,7 +21,7 @@


 from .repo import Repo
-from .signal_sparsity import Algo, SignalSparsity
+from .signal_sparsity import Algo, SignalSparsity, random_sparse_mask
 from .signal_sparsity_profiling import EnergyConcentrationProfile
 from .version import __version_tuple__

18 changes: 18 additions & 0 deletions fairscale/experimental/wgit/signal_sparsity.py
@@ -426,3 +426,21 @@ def lossy_compress(self, dense: Tensor) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
         sst = self.dense_to_sst(dense)
         dst = self.dense_sst_to_dst(dense, sst)
         return self.sst_dst_to_dense(sst, dst), sst, dst
+
+
+def random_sparse_mask(dense: Tensor, percent: float, dim: int) -> Tensor:
+    """Get a random sparse mask.
+
+    Args:
+        dense (Tensor):
+            Input dense tensor (no zeros).
+        percent (float):
+            Percent of non-zeros.
+        dim (int):
+            Dimension on which the random sparse mask is computed.
+    """
+    assert percent > 0, percent
+    rand = torch.rand_like(dense)
+    ones = torch.ones_like(dense)
+    k = _get_k_for_topk(percent, None, dense.shape[dim])
+    return _scatter_topk_to_sparse_tensor(rand, ones, k, dim)
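
For context, here is a minimal usage sketch of the new API. The import path follows the __init__.py change above, and the percent/dim values mirror the test below; treating the returned mask as an ordinary dense 0/1 tensor is an assumption based on that test's equality checks.

    import torch

    from fairscale.experimental.wgit import random_sparse_mask

    # Same values as the test below: with percent=0.01 and dim=0 on a
    # 100x100 tensor, the test expects exactly one nonzero per slice.
    dense = torch.rand(100, 100)
    mask = random_sparse_mask(dense, percent=0.01, dim=0)
    sparse = dense * mask  # zero out everything the mask did not select
    assert mask.sum() == 100  # one kept entry per slice, as in the test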
18 changes: 17 additions & 1 deletion tests/experimental/wgit/test_signal_sparsity.py
@@ -7,7 +7,7 @@
 import torch
 
 from fair_dev.testing.testing import objects_are_equal
-from fairscale.experimental.wgit.signal_sparsity import SignalSparsity
+from fairscale.experimental.wgit.signal_sparsity import SignalSparsity, random_sparse_mask
 
 # Our own tolerance
 ATOL = 1e-6
@@ -427,3 +427,19 @@ def test_dst_disabled():
     objects_are_equal(rt, result_rt, raise_exception=True, rtol=RTOL, atol=ATOL)
     objects_are_equal(sst, result_sst, raise_exception=True, rtol=RTOL, atol=ATOL)
     assert dst is None
+
+
+@pytest.mark.parametrize("device", ["cpu", "cuda"])
+def test_random_sparse_mask(device):
+    """Tests the random_sparse_mask API."""
+    if device == "cuda" and not torch.cuda.is_available():
+        pytest.skip("no GPU")
+
+    dense = torch.tensor([0.5000, 0.6000, 0.7000, 0.8000, 0.9000]).to(device)
+    mask = random_sparse_mask(dense, 0.2, 0)
+    assert mask.sum() == 1
+    for d in [0, 1]:
+        dense = torch.rand(100, 100).to(device)
+        mask = random_sparse_mask(dense, 0.01, d)
+        assert objects_are_equal(mask.sum(dim=d), torch.ones(100).to(device), raise_exception=True)
+        assert mask.sum() == 100
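
A side note on reproducibility: the mask is drawn via torch.rand_like, i.e. from PyTorch's global RNG, so seeding makes it deterministic. A small sketch, again assuming a dense 0/1 mask as suggested by the test above:

    import torch

    from fairscale.experimental.wgit import random_sparse_mask

    torch.manual_seed(0)  # rand_like draws from the global RNG; a fixed seed fixes the mask
    m1 = random_sparse_mask(torch.ones(100, 100), 0.01, 0)
    torch.manual_seed(0)
    m2 = random_sparse_mask(torch.ones(100, 100), 0.01, 0)
    assert torch.equal(m1, m2)  # same seed, identical masks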
