implementation of lossy_compress method (#1051)
* [Feat] implements lossy_compress with tests

1. Implements a method lossy_compress that takes a dense tensor and returns a lossy reconstruction together with its SST and DST tensors; when the requested sparsity is zero, the dense tensor itself is returned as the reconstruction.
riohib committed Aug 3, 2022
1 parent c1dada4 commit 5c60f33
Showing 3 changed files with 72 additions and 1 deletion.
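
Roughly, the new method is used as follows (a minimal sketch; the constructor arguments mirror those used in the tests below, and the import path assumes SignalSparsity is exported from the signal_sparsity module shown in this diff):

import torch

from fairscale.experimental.wgit.signal_sparsity import SignalSparsity

# Keep the top-2 elements along dim 0 in both the SST and DST stages.
sparser = SignalSparsity(sst_top_k_element=2, sst_top_k_dim=0, dst_top_k_element=2, dst_top_k_dim=0)
lossy_dense, sst, dst = sparser.lossy_compress(torch.randn(4, 8))
# lossy_dense approximates the input; sst and dst are the sparse signal tensors
# (sst is None when the requested sparsity is zero)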
2 changes: 2 additions & 0 deletions fairscale/experimental/wgit/repo.py
@@ -3,6 +3,7 @@
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import copy
from dataclasses import dataclass
from enum import Enum
import json
@@ -276,6 +277,7 @@ def fn(element: Any, names: List[str]) -> Any:
return element

state_dict = torch.load(file_path)
ret_state_dict = copy.deepcopy(state_dict) # This is only a temporary addition for testing.
_recursive_apply_to_elements(state_dict, fn, [])
file_path_or_state_dict = state_dict

42 changes: 41 additions & 1 deletion fairscale/experimental/wgit/signal_sparsity.py
@@ -4,7 +4,7 @@
# LICENSE file in the root directory of this source tree.

from enum import Enum
from typing import Optional
from typing import Optional, Tuple

import torch
from torch import Tensor
@@ -75,6 +75,17 @@ def _top_k_total_size(tensor: Tensor, topk_dim: Optional[int]) -> int:
return top_k_total_size


def _is_sparsity_zero(
dense: Tensor, topk_percent: Optional[float], topk_element: Optional[int], top_k_dim: Optional[int]
) -> bool:
"""Returns True when a given value of topk_percent or topk_element along a particular top_k_dim
for an input tensor results in sparsity=0% (or top-100-percent). Otherwise, returns False.
"""
top_k_total_size = _top_k_total_size(dense, top_k_dim)
k = _get_k_for_topk(topk_percent, topk_element, top_k_total_size)
return k == top_k_total_size
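
# A small illustration of the predicate above (a sketch; topk_percent=100 keeps
# every element, i.e. zero sparsity, while a small topk_element keeps fewer than
# top_k_total_size elements):
#
#     dense = torch.linspace(-0.5, 0.5, 40).reshape(5, 8)
#     _is_sparsity_zero(dense, 100, None, 0)  # True: k == top_k_total_size
#     _is_sparsity_zero(dense, None, 2, 0)    # False, assuming 2 < top_k_total_size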


def _dct_transform(dense: Tensor) -> Tensor:
"""Should take a tensor and perform a Discrete Cosine Transform on the tensor.
@@ -289,6 +300,35 @@ def sst_dst_to_dense(self, sst: Tensor, dst: Optional[Tensor] = None) -> Tensor:
dense_rt += dst
return dense_rt

    def lossy_compress(self, dense: Tensor) -> Tuple[Tensor, Optional[Tensor], Tensor]:
        """From a dense tensor, compute a lossy reconstruction together with its SST and DST
        tensors. If the requested sparsity is zero (or top-100-percent), simply returns the
        input dense tensor as the reconstruction.

        Args:
            dense (Tensor):
                Input dense tensor (no zeros).

        Returns:
            (Tuple[Tensor, Optional[Tensor], Tensor]):
                A tuple of the form (lossy_reconstruction, sst, dst). When sparsity is nonzero,
                all three tensors have the same shape as the dense tensor; when sparsity is
                zero, sst is None and the other two entries are the dense tensor itself.
        """

if _is_sparsity_zero(
dense, self._sst_top_k_percent, self._sst_top_k_element, self._sst_top_k_dim
) and _is_sparsity_zero(dense, self._dst_top_k_percent, self._dst_top_k_element, self._dst_top_k_dim):
# when sparsity is 0% for both sst and dst, the dense tensor itself is returned as the reconstructed
# tensor, sst is returned as None and dst as the dense tensor. This choice is made because with the
# returned sst=None and dst=dense, we should be able to recombine them if needed to retrieve the
# dense tensor again as: dense = inv_transform(sst) + dst, where inv_transform(sst=None) = zero_tensor
# of the same size as dense.
return dense, None, dense
else:
sst = self.dense_to_sst(dense)
dst = self.dense_sst_to_dst(dense, sst)
return self.sst_dst_to_dense(sst, dst), sst, dst


# We could separately have helper functions that work on a state_dict instead of a tensor.
# One option is to extend the above class to handle state_dict as well as tensor
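The zero-sparsity convention above is chosen so that the recombination dense = inv_transform(sst) + dst still holds, with inv_transform(None) treated as a zero tensor. A minimal round-trip sketch (assuming SignalSparsity is importable from fairscale.experimental.wgit.signal_sparsity; the parameter values are illustrative):

import torch

from fairscale.experimental.wgit.signal_sparsity import SignalSparsity

sparser = SignalSparsity(sst_top_k_element=3, sst_top_k_dim=0, dst_top_k_element=3, dst_top_k_dim=0)
dense = torch.randn(4, 8)
lossy_dense, sst, dst = sparser.lossy_compress(dense)
if sst is None:
    # Zero-sparsity case: dst holds the original dense tensor unchanged.
    recovered = dst
else:
    # General case: invert the SST transform and add back the DST residual.
    recovered = sparser.sst_dst_to_dense(sst, dst)
assert torch.equal(recovered, lossy_dense)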
29 changes: 29 additions & 0 deletions tests/experimental/wgit/test_signal_sparsity.py
@@ -243,3 +243,32 @@ def test_sst_dst_to_dense(unused1, sst, dst, expd_rt, dim, unused2, k):
sparser = SignalSparsity(sst_top_k_element=k, sst_top_k_dim=dim, dst_top_k_element=k, dst_top_k_dim=dim)
dense_recons = sparser.sst_dst_to_dense(sst, dst)
objects_are_equal(dense_recons, expd_rt, raise_exception=True)


@pytest.mark.parametrize("tensor, expd_sst, expd_dst, expd_rt, dim, unused, k", get_test_params())
def test_lossy_compress(tensor, expd_sst, expd_dst, expd_rt, dim, unused, k):
"""Tests the lossy_compress method against expected sst, dst and reconstruced tensor."""
sparser = SignalSparsity(sst_top_k_element=k, sst_top_k_dim=dim, dst_top_k_element=k, dst_top_k_dim=dim)
lossy_dense, sst, dst = sparser.lossy_compress(tensor)
objects_are_equal(lossy_dense, expd_rt, raise_exception=True)
objects_are_equal(sst, expd_sst, raise_exception=True)
objects_are_equal(dst, expd_dst, raise_exception=True)


@pytest.mark.parametrize(
"tensor, dim, top_k_percent",
[
(torch.linspace(0.01, 0.06, 40).reshape(5, 8), 0, 100),
(torch.linspace(-0.01, 0.06, 42).reshape(7, 6), 0, 100),
(torch.linspace(-10, 15, 36).reshape(6, 6), 1, 100),
],
)
def test_lossy_compress_sparsity_0(tensor, dim, top_k_percent):
"""Tests whether lossy_compress method simply returns dense tensor when sparsity is 0."""
sparser = SignalSparsity(
sst_top_k_percent=top_k_percent, sst_top_k_dim=dim, dst_top_k_percent=top_k_percent, dst_top_k_dim=dim
)
lossy_dense, sst, dst = sparser.lossy_compress(tensor)
objects_are_equal(lossy_dense, tensor, raise_exception=True)
objects_are_equal(sst, None, raise_exception=True)
objects_are_equal(dst, tensor, raise_exception=True)
