Implementation of dense_sst_to_dst and sst_dst_to_dense (#1048)
[Feat] Implements dense_sst_to_dst and sst_dst_to_dense methods and adds tests

1. Implements the dense_sst_to_dst and sst_dst_to_dense methods.
2. Adds tests for perfect reconstruction when top-k covers all elements, across different dims.
3. Adds tests for the two new methods.
riohib committed Jul 31, 2022
1 parent d3bda79 commit c1dada4
Showing 2 changed files with 158 additions and 50 deletions.
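For orientation, here is a minimal sketch of the round trip these two methods complete. This is an illustration only: the constructor keywords and module path follow the changed files below, while the input tensor and k values are made up.

import torch

from fairscale.experimental.wgit.signal_sparsity import SignalSparsity

dense = torch.arange(12).reshape(4, 3).float()
sparser = SignalSparsity(sst_top_k_element=2, sst_top_k_dim=None, dst_top_k_element=2, dst_top_k_dim=None)

sst = sparser.dense_to_sst(dense)  # top-k components of fft(dense), complex and mostly zero
dst = sparser.dense_sst_to_dst(dense, sst)  # top-k entries of dense - real(ifft(sst))
recons = sparser.sst_dst_to_dense(sst, dst)  # approximate reconstruction of dense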
52 changes: 32 additions & 20 deletions fairscale/experimental/wgit/signal_sparsity.py
@@ -88,6 +88,19 @@ def _dct_transform(dense: Tensor) -> Tensor:
raise NotImplementedError("Support for DCT has not been implemented yet!")


def _inverse_dct_transform(sst: Tensor) -> Tensor:
"""Should take a tensor and perform an inverse Discrete Cosine Transform and return a new tensor.
Args:
sst (Tensor):
Input sst tensor (may have zeros) in frequency domain.
Returns:
(Tensor):
A new, transformed dense tensor with real domain values.
"""
raise NotImplementedError("Support for iDCT has not been implemented yet!")


class Algo(Enum):
FFT = 0
DCT = 1
@@ -156,7 +169,7 @@ def __init__(

         self._validate_conf()
         # TODO (Min): Type checking for the following
-        self._transform = torch.fft.fft if algo is Algo.FFT else _dct_transform  # type: ignore
+        self._transform, self._inverse_transform = (torch.fft.fft, torch.fft.ifft) if algo is Algo.FFT else (_dct_transform, _inverse_dct_transform)  # type: ignore

     def _validate_conf(self) -> None:
         """Validating if the config is valid.
@@ -230,15 +243,13 @@ def dense_to_sst(self, dense: Tensor) -> Tensor:
         # or DCT transformed components when using DCT (currently not implemented).
         # TODO: In case of the FFT, the imaginary part can perhaps be quantized or pruning can be
         # done on the smaller phases.
-        real_dense_freq = dense_freq.real.abs()
+        real_dense_freq = torch.real(dense_freq).abs()
         return _scatter_topk_to_sparse_tensor(real_dense_freq, dense_freq, k, dim=self._sst_top_k_dim)

     def dense_sst_to_dst(self, dense: Tensor, sst: Tensor) -> Tensor:
-        """From dense and SST to a DST
-        This will use sst_dst_to_dense below but with dst=None.
+        """Calculates DST from input dense and SST tensors.

-        dense - ifft(sst)[using sst_dst_to_dense below) -> top-k -> result
+        dense - inverse_transform(sst)[using sst_dst_to_dense method] -> top-k -> dst

         Args:
             dense (Tensor):
@@ -250,32 +261,33 @@ def dense_sst_to_dst(self, dense: Tensor, sst: Tensor) -> Tensor:
             (Tensor):
                 Same shaped tensor, still dense format but has zeros. Non-zeros are top-k delta values.
         """
-        pass
+        if not (dense.shape == sst.shape):
+            raise ValueError("dense and sst have different shapes!")
+
+        top_k_total_size = _top_k_total_size(dense, self._dst_top_k_dim)
+        k = _get_k_for_topk(self._dst_top_k_percent, self._dst_top_k_element, top_k_total_size)
+        delta = dense - self.sst_dst_to_dense(sst)  # sst_dst_to_dense(sst) returns the inverse transform here
+        del dense
+        return _scatter_topk_to_sparse_tensor(delta.abs(), delta, k, dim=self._dst_top_k_dim)

-    def sst_dst_to_dense(self, sst: Tensor, dst: Tensor = None) -> Tensor:
-        """From SST and dst back to a dense
-
-        result = ifft(sst)
-        if dst is not None:
-            result += dst
-        return result
+    def sst_dst_to_dense(self, sst: Tensor, dst: Optional[Tensor] = None) -> Tensor:
+        """From SST and DST returns a dense reconstructed tensor (RT). When argument dst=None, simply returns
+        the inverse transform of the SST tensor.

         Args:
             sst (Tensor):
                 Signal sparse tensor. Required argument.
-            dst (Tensor, optinoal):
+            dst (Tensor, optional):
                 Delta sparse tensor, optional.

         Returns:
             (Tensor):
                 A dense tensor in real number domain from the SST.
         """
-        pass
-
-    def sst_or_dst_to_mask(self) -> None:
-        # we shouldn't need this function since going from SST/DST to mask should be a
-        # trivial call in pytorch. Maybe I am missing something.
-        pass
+        dense_rt = torch.real(self._inverse_transform(sst))
+        if dst is not None:
+            dense_rt += dst
+        return dense_rt


 # We could separately have helper functions that work on state_dict instead of a tensor.
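Taken together, the two new methods implement the decomposition dense ≈ real(inverse_transform(sst)) + dst. A self-contained sketch of that identity with plain torch.fft, illustrative only: the explicit boolean mask below stands in for the module's _scatter_topk_to_sparse_tensor helper, and the tensor and k are made up.

import torch

dense = torch.arange(12).reshape(4, 3).float()
freq = torch.fft.fft(dense)

# rank components by |real part| and keep the global top-2, as dense_to_sst does for dim=None
k = 2
mask = torch.zeros(freq.numel(), dtype=torch.bool)
mask[torch.real(freq).abs().flatten().topk(k).indices] = True
sst = freq * mask.reshape(freq.shape)

# the DST is a top-k of whatever the SST misses; keeping all of it makes reconstruction exact
delta = dense - torch.real(torch.fft.ifft(sst))
assert torch.allclose(dense, torch.real(torch.fft.ifft(sst)) + delta)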
156 changes: 126 additions & 30 deletions tests/experimental/wgit/test_signal_sparsity.py
@@ -12,57 +12,114 @@

 def get_test_params():
     """Helper function to create and return a list of tuples of the form:
-    (in_tensor, expected_tensor, dim, percent, top_k_element) to be used as parameters for tests.
+    (dense, expected_sst, expected_dst, expected_reconstructed_tensor (RT), dim, percent, top_k_element)
+    to be used as parameters for tests.
     """
     # input in_tensors
-    tensor_4x3 = torch.arange(12).reshape(4, 3)
-    tensor_2x2x3 = torch.arange(12).reshape(3, 2, 2)
+    tensor_4x3_None = torch.arange(12).reshape(4, 3).float()
+    tensor_4x3_0 = torch.arange(50, 62).reshape(4, 3) / 100
+    tensor_3x3_1 = torch.linspace(-5, 5, 9).reshape(3, 3)
+    tensor_2x2x3 = torch.arange(12).reshape(3, 2, 2).float()

     # Expected SST output tensors for 4x3 tensor of ascending ints
-    expected_4x3_None = torch.tensor(
+    # with dim=None, top-2
+    expd_sst_4x3_None = torch.tensor(
         [
-            [0.0 + 0.0j, 0.0 + 0.0j, 0.0 + 0.0j],  # with dim=None, top-2
+            [0.0 + 0.0j, 0.0 + 0.0j, 0.0 + 0.0j],
             [0.0 + 0.0j, 0.0 + 0.0j, 0.0 + 0.0j],
             [21.0 + 0.0j, 0.0 + 0.0j, 0.0 + 0.0j],
             [30.0 + 0.0j, 0.0 + 0.0j, 0.0 + 0.0j],
         ],
         dtype=torch.complex64,
     )

-    expected_4x3_0 = torch.tensor(
+    # with dim=None, top-2
+    expd_dst_4x3_None = torch.tensor(
+        [[0.0, 0.0, 0.0], [0.0, 4.0, 5.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]], dtype=torch.float32
+    )
+
+    # expected_reconstructed_tensor with dim=None and top-2 for both sst and dst
+    expd_rt_4x3_None = torch.tensor(
+        [[0.0, 0.0, 0.0], [0.0, 4.0, 5.0], [7.0, 7.0, 7.0], [10.0, 10.0, 10.0]], dtype=torch.float32
+    )
+
+    # with dim=0, top-2
+    expd_sst_4x3_0 = torch.tensor(
         [
-            [0.0000000 + 0.0000000j, 0.0000000 + 0.0000000j, 0.0000000 + 0.0000000j],  # with dim=0, top-2
-            [0.0000000 + 0.0000000j, 0.0000000 + 0.0000000j, 0.0000000 + 0.0000000j],
-            [21.0000000 + 0.0000000j, -1.5000000 + 0.8660254j, -1.5000000 - 0.8660254j],
-            [30.0000000 + 0.0000000j, -1.5000000 + 0.8660254j, -1.5000000 - 0.8660254j],
+            [0.0000000000 + 0.0000000000j, 0.0000000000 + 0.0000000000j, 0.0000000000 + 0.0000000000j],
+            [0.0000000000 + 0.0000000000j, -0.0150000453 + 0.0086602457j, -0.0150000453 - 0.0086602457j],
+            [1.7100000381 + 0.0000000000j, 0.0000000000 + 0.0000000000j, 0.0000000000 + 0.0000000000j],
+            [1.7999999523 + 0.0000000000j, -0.0150000453 + 0.0086602457j, -0.0150000453 - 0.0086602457j],
         ],
         dtype=torch.complex64,
     )

-    expected_4x3_1 = torch.tensor(
+    # with dim=0, top-2
+    expd_dst_4x3_0 = torch.tensor(
+        [[0.5000, 0.5100, 0.5200], [0.5400, 0.5400, 0.5400], [0.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000]]
+    )
+
+    # expected_reconstructed_tensor with dim=0 and top-2 for both sst and dst
+    expd_rt_4x3_0 = torch.tensor(
+        [[0.5000, 0.5100, 0.5200], [0.5300, 0.5400, 0.5500], [0.5700, 0.5700, 0.5700], [0.5900, 0.6000, 0.6100]]
+    )
+
+    # with dim=1, top-2
+    expd_sst_3x3_1 = torch.tensor(
         [
-            [3.0000000 + 0.0000000j, -1.5000000 + 0.8660254j, 0.0000000 + 0.0000000j],  # with dim=1, top-2
-            [12.0000000 + 0.0000000j, -1.5000000 + 0.8660254j, 0.0000000 + 0.0000000j],
-            [21.0000000 + 0.0000000j, -1.5000000 + 0.8660254j, 0.0000000 + 0.0000000j],
-            [30.0000000 + 0.0000000j, -1.5000000 + 0.8660254j, 0.0000000 + 0.0000000j],
+            [-11.2500000000 + 0.0000000000j, -1.8750000000 + 1.0825316906j, 0.0000000000 + 0.0000000000j],
+            [0.0000000000 + 0.0000000000j, -1.8750000000 + 1.0825316906j, -1.8750000000 - 1.0825316906j],
+            [11.2500000000 + 0.0000000000j, -1.8750000000 + 1.0825316906j, 0.0000000000 + 0.0000000000j],
         ],
         dtype=torch.complex64,
     )

-    expected_2x2x3_1 = torch.tensor(
+    # with dim=1, top-2
+    expd_dst_3x3_1 = torch.tensor(
+        [
+            [-6.2500000000e-01, 0.0000000000e00, 6.2500000000e-01],
+            [0.0000000000e00, -4.8244856998e-08, 0.0000000000e00],
+            [-6.2500000000e-01, 0.0000000000e00, 6.2500000000e-01],
+        ]
+    )
+
+    # expected_reconstructed_tensor with dim=1 and top-2 for both sst and dst
+    expd_rt_3x3_1 = torch.tensor([[-5.0000, -3.7500, -2.5000], [-1.2500, 0.0000, 1.2500], [2.5000, 3.7500, 5.0000]])
+
+    # with dim=1, top-1
+    expd_sst_2x2x3_1 = torch.tensor(
         [
-            [[1.0 + 0.0j, -1.0 + 0.0j], [5.0 + 0.0j, -1.0 + 0.0j]],  # with dim=1, top-2
-            [[9.0 + 0.0j, -1.0 + 0.0j], [13.0 + 0.0j, -1.0 + 0.0j]],
-            [[17.0 + 0.0j, -1.0 + 0.0j], [21.0 + 0.0j, -1.0 + 0.0j]],
+            [[0.0 + 0.0j, -1.0 + 0.0j], [5.0 + 0.0j, 0.0 + 0.0j]],
+            [[0.0 + 0.0j, -1.0 + 0.0j], [13.0 + 0.0j, 0.0 + 0.0j]],
+            [[0.0 + 0.0j, -1.0 + 0.0j], [21.0 + 0.0j, 0.0 + 0.0j]],
         ],
         dtype=torch.complex64,
     )

+    # with dim=1, top-1
+    expd_dst_2x2x3_1 = torch.tensor(
+        [
+            [[0.5000, 0.5000], [0.0000, 0.0000]],
+            [[4.5000, 4.5000], [0.0000, 0.0000]],
+            [[8.5000, 8.5000], [0.0000, 0.0000]],
+        ],
+        dtype=torch.float32,
+    )
+
+    # expected_reconstructed_tensor with dim=1 and top-1 for both sst and dst
+    expd_rt_2x2x3_1 = torch.tensor(
+        [
+            [[0.0000, 1.0000], [2.5000, 2.5000]],
+            [[4.0000, 5.0000], [6.5000, 6.5000]],
+            [[8.0000, 9.0000], [10.5000, 10.5000]],
+        ],
+        dtype=torch.float32,
+    )

     return [
-        (tensor_4x3, expected_4x3_None, None, 20, 2),
-        (tensor_4x3, expected_4x3_0, 0, 50, 2),
-        (tensor_4x3, expected_4x3_1, 1, 70, 2),
-        (tensor_2x2x3, expected_2x2x3_1, 1, 100, 2),
+        (tensor_4x3_None, expd_sst_4x3_None, expd_dst_4x3_None, expd_rt_4x3_None, None, 20, 2),
+        (tensor_4x3_0, expd_sst_4x3_0, expd_dst_4x3_0, expd_rt_4x3_0, 0, 50, 2),
+        (tensor_3x3_1, expd_sst_3x3_1, expd_dst_3x3_1, expd_rt_3x3_1, 1, 70, 2),
+        (tensor_2x2x3, expd_sst_2x2x3_1, expd_dst_2x2x3_1, expd_rt_2x2x3_1, 1, 50, 1),
     ]


@@ -128,16 +185,16 @@ def test_dense_to_sst_perfect_recons(tensor, dim):
     assert all((sparser_2d.dense_to_sst(tensor) == torch.fft.fft(tensor)).flatten())


@pytest.mark.parametrize("tensor, expected, dim, percent, k", get_test_params())
def test_dense_to_sst_fixed(tensor, expected, dim, percent, k):
"""Tests for fixed input dense tensor and fixed expected output SST tensor for top-2 elements."""
@pytest.mark.parametrize("tensor, expd_sst, unused1, unused2, dim, unused3, k", get_test_params())
def test_dense_to_sst_fixed(tensor, expd_sst, unused1, unused2, dim, unused3, k):
"""Tests for fixed input dense tensor and fixed expected output SST tensor."""
sparser_2d = SignalSparsity(sst_top_k_percent=None, sst_top_k_element=k, sst_top_k_dim=dim, dst_top_k_percent=100)
sst = sparser_2d.dense_to_sst(tensor)
objects_are_equal(sst, expected, raise_exception=True)
objects_are_equal(sst, expd_sst, raise_exception=True)


@pytest.mark.parametrize("tensor, expected, dim, percent, k", get_test_params())
def test_percent_element(tensor, expected, dim, percent, k):
@pytest.mark.parametrize("tensor, unused1, unused2, unused3, dim, percent, k", get_test_params())
def test_percent_element(tensor, unused1, unused2, unused3, dim, percent, k):
"""Tests whether comparative values for top_k_element and top_k_percent returns same outputs"""
sparser_2d = SignalSparsity(sst_top_k_percent=None, sst_top_k_element=k, sst_top_k_dim=dim, dst_top_k_percent=100)
sst_element = sparser_2d.dense_to_sst(tensor)
@@ -147,3 +204,42 @@ def test_percent_element(tensor, expected, dim, percent, k):
     )
     sst_percent = sparser_2d.dense_to_sst(tensor)
     objects_are_equal(sst_element, sst_percent, raise_exception=True)


@pytest.mark.parametrize("tensor, sst, expd_dst, unused1, dim, unused2, k", get_test_params())
def test_dense_sst_to_dst(tensor, sst, expd_dst, unused1, dim, unused2, k):
"""Tests fixed expected output DST tensor with fixed input dense and SST tensors."""
sparser_2d = SignalSparsity(sst_top_k_percent=None, sst_top_k_element=k, dst_top_k_element=k, dst_top_k_dim=dim)
dst = sparser_2d.dense_sst_to_dst(tensor, sst)
objects_are_equal(dst, expd_dst, raise_exception=True)


@pytest.mark.parametrize(
"dense, k, dim",
[
(torch.linspace(0.01, 0.06, 40).reshape(5, 8), 40, None), # top-40, dim=None
(torch.linspace(0.1, 0.6, 30).reshape(5, 6), 5, 0), # top-5, dim=0
(torch.linspace(-0.1, 0.6, 35).reshape(7, 5), 5, 1), # top-5, dim=1
(torch.arange(60).float().reshape(10, 6), 60, None), # top-60, dim=None
(torch.arange(60).float().reshape(10, 6), 10, 0), # top-10, dim=0
(torch.arange(60).float().reshape(10, 6), 6, 1), # top-6, dim=1
(torch.arange(60).float().reshape(2, 5, 6), 5, 1), # top-5, dim=1
],
)
def test_sst_dst_to_perfect_dense_reconstruction(dense, k, dim):
"""Tests whether perfect reconstruction of input dense tensor is generated when top-k matches the numel
across some dimension dim for both SST and DST.
"""
sparser = SignalSparsity(sst_top_k_element=k, sst_top_k_dim=dim, dst_top_k_element=k, dst_top_k_dim=dim)
sst = sparser.dense_to_sst(dense)
dst = sparser.dense_sst_to_dst(dense, sst)
dense_recons = sparser.sst_dst_to_dense(sst, dst)
objects_are_equal(dense, dense_recons, raise_exception=True)


@pytest.mark.parametrize("unused1, sst, dst, expd_rt, dim, unused2, k", get_test_params())
def test_sst_dst_to_dense(unused1, sst, dst, expd_rt, dim, unused2, k):
"""Tests the correct expected reconstruction from frozen sst and dst tensors."""
sparser = SignalSparsity(sst_top_k_element=k, sst_top_k_dim=dim, dst_top_k_element=k, dst_top_k_dim=dim)
dense_recons = sparser.sst_dst_to_dense(sst, dst)
objects_are_equal(dense_recons, expd_rt, raise_exception=True)
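
The perfect-reconstruction test rests on a simple fact: when k equals the full size along the chosen dim, dense_to_sst keeps every frequency component, so real(ifft(sst)) already equals the input up to floating point error, and the delta captured by the DST is essentially zero. A quick check of the same property outside pytest, mirroring the top-6, dim=1 case from the parametrization above (torch.allclose used here for tolerance; the test itself asserts equality via objects_are_equal):

import torch

from fairscale.experimental.wgit.signal_sparsity import SignalSparsity

dense = torch.arange(60).float().reshape(10, 6)
sparser = SignalSparsity(sst_top_k_element=6, sst_top_k_dim=1, dst_top_k_element=6, dst_top_k_dim=1)
sst = sparser.dense_to_sst(dense)  # all 6 components per row survive the top-k
dst = sparser.dense_sst_to_dst(dense, sst)
assert torch.allclose(dense, sparser.sst_dst_to_dense(sst, dst))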
