Skip to content

Commit

Permalink
[feat] support optional SST and DST (#1063)
Browse files Browse the repository at this point in the history
* [feat] support sst disabled and dst disabled cases

* added tests

Co-authored-by: Min Xu <min.xu.public@gmail.com>
  • Loading branch information
min-xu-ai and flying-x committed Aug 26, 2022
1 parent 15d4cf1 commit 3cc7fa8
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 19 deletions.
56 changes: 43 additions & 13 deletions fairscale/experimental/wgit/signal_sparsity.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,9 @@ def _is_sparsity_zero(
"""Returns True when a given value of topk_percent or topk_element along a particular top_k_dim
for an input tensor results in sparsity=0% (or top-100-percent). Otherwise, returns False.
"""
if topk_percent is None and topk_element is None:
return False # 100% sparse

top_k_total_size = _top_k_total_size(dense, top_k_dim)
k = _get_k_for_topk(topk_percent, topk_element, top_k_total_size)
return k == top_k_total_size
Expand Down Expand Up @@ -245,11 +248,20 @@ def __init__(
self._dst_top_k_percent = dst_top_k_percent

self._validate_conf()
# TODO (Min): Type checking for the following
self._transform, self._inverse_transform = (
(_fft_transform, _ifft_transform) if algo is Algo.FFT else (_dct_transform, _idct_transform)
)

@property
def _sst_enabled(self) -> bool:
"""True if SST is enabled."""
return self._sst_top_k_element is not None or self._sst_top_k_percent is not None

@property
def _dst_enabled(self) -> bool:
"""True if DST is enabled."""
return self._dst_top_k_element is not None or self._dst_top_k_percent is not None

def _validate_conf(self) -> None:
"""Validating if the config is valid.
Expand All @@ -262,16 +274,14 @@ def _validate_conf(self) -> None:
If validation fails.
"""
# assert that both top_k_elements and top_k_percent aren't set for sst and dst
def one_and_only(a: Optional[int], b: Optional[float]) -> bool:
return (a is None) ^ (b is None)
def both_set(a: Optional[int], b: Optional[float]) -> bool:
return (a is not None) and (b is not None)

if not (
one_and_only(self._sst_top_k_element, self._sst_top_k_percent)
and one_and_only(self._dst_top_k_element, self._dst_top_k_percent)
if both_set(self._sst_top_k_element, self._sst_top_k_percent) or both_set(
self._dst_top_k_element, self._dst_top_k_percent
):
raise ValueError(
"One and only one of top_k_element and top_k_percent for "
"each of sst and dst must be provided as an argument.\n"
"top_k_element and top_k_percent can't be both set\n"
f"Input values are: sst element={self._sst_top_k_element}, sst percent={self._sst_top_k_percent}, "
f"dst element={self._dst_top_k_element}, dst percent={self._dst_top_k_percent}"
)
Expand All @@ -296,7 +306,7 @@ def none_or_greater_0(a: Optional[int]) -> bool:
f"and dst element={self._dst_top_k_element}"
)

def dense_to_sst(self, dense: Tensor) -> Tensor:
def dense_to_sst(self, dense: Tensor) -> Optional[Tensor]:
"""Get Signal Sparse Tensor (SST) from a dense tensor
Dense -> fft -> top-k -> results.
Expand All @@ -310,10 +320,14 @@ def dense_to_sst(self, dense: Tensor) -> Tensor:
Input dense tensor (no zeros).
Returns:
(Tensor):
(Tensor, optional):
Same shaped tensor as the input dense tensor, still in dense format but in frequency
domain (complex valued) and has zeros.
"""
if not self._sst_enabled:
# Special case, SST is simply None, which represents an all-zero tensor.
return None

top_k_total_size = _top_k_total_size(dense, self._sst_top_k_dim)
k = _get_k_for_topk(self._sst_top_k_percent, self._sst_top_k_element, top_k_total_size)
dense_freq = self._transform(dense, dim=self._sst_top_k_dim)
Expand All @@ -325,7 +339,7 @@ def dense_to_sst(self, dense: Tensor) -> Tensor:
real_dense_freq = dense_freq.real.abs()
return _scatter_topk_to_sparse_tensor(real_dense_freq, dense_freq, k, dim=self._sst_top_k_dim)

def dense_sst_to_dst(self, dense: Tensor, sst: Tensor) -> Tensor:
def dense_sst_to_dst(self, dense: Tensor, sst: Optional[Tensor]) -> Optional[Tensor]:
"""Calculates DST from input dense and SST tensors.
dense - inverse_transform(sst)[using sst_dst_to_dense method] -> top-k -> dst
Expand All @@ -340,6 +354,13 @@ def dense_sst_to_dst(self, dense: Tensor, sst: Tensor) -> Tensor:
(Tensor):
Same shaped tensor, still dense format but has zeros. Non-zeros are top-k delta values.
"""
if not self._dst_enabled:
# Special case, DST is simply None, which represents an all-zero tensor.
return None

if sst is None:
sst = torch.zeros_like(dense, dtype=torch.complex64)

if not (dense.shape == sst.shape):
raise ValueError("dense and sst have different shapes!")

Expand All @@ -349,7 +370,7 @@ def dense_sst_to_dst(self, dense: Tensor, sst: Tensor) -> Tensor:
del dense
return _scatter_topk_to_sparse_tensor(delta.abs(), delta, k, dim=self._dst_top_k_dim)

def sst_dst_to_dense(self, sst: Tensor, dst: Optional[Tensor] = None) -> Tensor:
def sst_dst_to_dense(self, sst: Optional[Tensor], dst: Optional[Tensor] = None) -> Tensor:
"""From SST and DST returns a dense reconstructed tensor (RT). When argument dst=None, simply returns
the inverse transform of the SST tensor.
Expand All @@ -363,12 +384,19 @@ def sst_dst_to_dense(self, sst: Tensor, dst: Optional[Tensor] = None) -> Tensor:
(Tensor):
A dense tensor in real number domain from the SST.
"""
assert not (sst is None and dst is None), "both-None-case is not useful"

if sst is None:
# Simply the delta is the reconstruction.
return dst

# Now, ifft and then add the delta.
dense_rt = torch.real(self._inverse_transform(sst, dim=self._sst_top_k_dim))
if dst is not None:
dense_rt += dst
return dense_rt

def lossy_compress(self, dense: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
def lossy_compress(self, dense: Tensor) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
"""From dense tensor to lossy reconstruction of dense tensor with the help of SST and DST
tensor calculation. If requested sparsity is zero (or top_100_percent) then simply returns
the input dense tensor as the reconstruction.
Expand All @@ -393,6 +421,8 @@ def lossy_compress(self, dense: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
# of the same size as dense.
return dense, None, dense
else:
# depending on whether self._sst_enabled and self._dst_enabled, None SST/DST tensors can be returned
# below as well.
sst = self.dense_to_sst(dense)
dst = self.dense_sst_to_dst(dense, sst)
return self.sst_dst_to_dense(sst, dst), sst, dst
40 changes: 34 additions & 6 deletions tests/experimental/wgit/test_signal_sparsity.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,18 +57,15 @@ def kwargs(vals_list):
return dict(zip(arg_key_list, vals_list))

# Validate value error is raised when, either:
# 1. One and only one of sst (or dst) percent and element is not provided a value (not None).
# 2. Both of sst (or dst) percent and element is set to None.
# 3. top_k_percent and top_k_element are not in valid range (elem > 0) and for 0 < percent <= 100.
# 1. Both sst (or dst) percent and element are given values (i.e. both are not None).
# 2. top_k_percent or top_k_element is out of its valid range (element must be > 0; percent must satisfy 0 < percent <= 100).
element = 10
percent = 50
dim = 0
args_list = [
[element, percent, dim, element, None, dim], # case 1.
[element, None, dim, element, percent, dim],
[None, None, dim, element, None, dim], # case 2.
[element, None, dim, None, None, dim],
[0, None, dim, None, None, dim], # case 3.
[0, None, dim, None, None, dim], # case 2.
[None, 0, dim, None, None, dim],
[element, None, dim, 0, None, dim],
[element, None, dim, None, 0, dim],
Expand Down Expand Up @@ -399,3 +396,34 @@ def test_lossy_compress_sparsity_0(tensor, dim, top_k_percent, device):
objects_are_equal(lossy_dense.to(device), tensor.to(device), raise_exception=True, rtol=RTOL, atol=ATOL)
objects_are_equal(sst, None, raise_exception=True, rtol=RTOL, atol=ATOL)
objects_are_equal(dst.to(device), tensor.to(device), raise_exception=True, rtol=RTOL, atol=ATOL)


def test_sst_disabled():
    """Verify lossy_compress when only DST is configured: SST must come back as None
    and the reconstruction equals the top-k delta (DST) alone."""
    src = torch.tensor([0.5000, 0.6000, 0.7000, 0.8000])
    expected = torch.tensor([0.0, 0.0, 0.7000, 0.8000])
    sparser = SignalSparsity(dst_top_k_element=2, dst_top_k_dim=0)
    rt, sst, dst = sparser.lossy_compress(src)
    # SST disabled -> represented by None (an all-zero tensor).
    assert sst is None
    objects_are_equal(rt, expected, raise_exception=True, rtol=RTOL, atol=ATOL)
    objects_are_equal(dst, expected, raise_exception=True, rtol=RTOL, atol=ATOL)


def test_dst_disabled():
    """Verify lossy_compress when only SST is configured: DST must come back as None
    and the reconstruction is the inverse transform of the SST alone."""
    src = torch.tensor([0.5000, 0.6000, 0.7000, 0.8000, 0.9000])
    expected_sst = torch.tensor(
        [
            3.50000000000000000000 + 0.00000000000000000000j,
            0.00000000000000000000 + 0.00000000000000000000j,
            -0.25000002980232238770 + 0.08122986555099487305j,
            -0.25000002980232238770 - 0.08122986555099487305j,
            0.00000000000000000000 + 0.00000000000000000000j,
        ]
    )
    expected_rt = torch.tensor([0.6000, 0.7618, 0.7000, 0.6382, 0.8000])
    sparser = SignalSparsity(sst_top_k_element=3, sst_top_k_dim=0)
    rt, sst, dst = sparser.lossy_compress(src)
    # DST disabled -> represented by None (an all-zero tensor).
    assert dst is None
    objects_are_equal(rt, expected_rt, raise_exception=True, rtol=RTOL, atol=ATOL)
    objects_are_equal(sst, expected_sst, raise_exception=True, rtol=RTOL, atol=ATOL)

0 comments on commit 3cc7fa8

Please sign in to comment.