[feat] signal sparsity profiling class (#1060)
* added a profiling class

* no more type ignore after merging main

* fixed an int/round bug

* add unit tests

* skip if no cuda for a test

Co-authored-by: Min Xu <min.xu.public@gmail.com>
min-xu-ai and flying-x committed Aug 11, 2022
1 parent 4c830de commit e982b43
Showing 4 changed files with 138 additions and 0 deletions.
1 change: 1 addition & 0 deletions fairscale/experimental/wgit/__init__.py
@@ -22,6 +22,7 @@

from .repo import Repo
from .signal_sparsity import Algo, SignalSparsity
from .signal_sparsity_profiling import EnergyConcentrationProfile
from .version import __version_tuple__

__version__ = ".".join([str(x) for x in __version_tuple__])
62 changes: 62 additions & 0 deletions fairscale/experimental/wgit/signal_sparsity_profiling.py
@@ -0,0 +1,62 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

from typing import List

import torch
from torch import Tensor


class EnergyConcentrationProfile:
    """Compute the "energy" concentration level for a tensor.

    Here the "energy" of a set of values is their total absolute
    magnitude, i.e. the sum of their absolute values.

    Args:
        dim (int):
            The dimension to measure.
        top_k_percents (List[float]):
            List of percentage values. For each value, the `measure`
            function will compute and return the fraction of "energy"
            concentrated in the top-K percent of values along the
            dimension being measured. Note, this is the opposite of
            the sparsity percentage.
    """

    def __init__(self, dim: int, top_k_percents: List[float]) -> None:
        assert isinstance(dim, int)
        self.dim = dim
        self.percents = []
        last_p = 0.0
        for p in top_k_percents:
            assert isinstance(p, (int, float))
            assert p > 0, p
            assert p <= 100, p
            assert p > last_p, f"p {p} should be larger than last_p {last_p}"
            self.percents.append(float(p))
            last_p = p

    def measure(self, in_tensor: Tensor) -> List[Tensor]:
        """Compute and return the results.

        Note, we want this function to be nonblocking and async.

        Returns:
            (List[Tensor])
                List of tensors. Each tensor is a singleton float
                that contains the energy measure for that top_k_percent.
        """
        assert in_tensor.is_floating_point(), in_tensor.dtype
        assert self.dim < len(in_tensor.shape), f"tensor shape {in_tensor.shape} not compatible with dim {self.dim}"
        dim_size = in_tensor.shape[self.dim]
        abs_tensor = in_tensor.abs()
        full_energy = abs_tensor.sum()
        return_tensors = []
        for p in self.percents:
            # Convert the percentage into a top-k count, rounding (not
            # truncating) and keeping at least one element.
            k = max(1, round(p / 100 * dim_size))
            abs_top_k_values, _ = abs_tensor.topk(k, dim=self.dim)
            return_tensors.append(abs_top_k_values.sum() / full_energy)
        return return_tensors

    def measure_fft(self, in_tensor: Tensor) -> List[Tensor]:
        """Like `measure`, but in the FFT frequency domain."""
        return self.measure(torch.fft.fft(in_tensor, dim=self.dim).real)
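For context, here is a minimal usage sketch (not part of the commit) showing how the new class, exported via the `__init__.py` change above, might be exercised. The random input tensor and printed labels are illustrative only:

import torch

from fairscale.experimental.wgit import EnergyConcentrationProfile

# Profile how concentrated the tensor's energy is along dim=1.
ecp = EnergyConcentrationProfile(dim=1, top_k_percents=[1, 5, 25, 100])
x = torch.randn(8, 1024)

# Each result is a singleton tensor: the fraction of the total absolute
# magnitude captured by the top-K percent of entries along `dim`.
for p, e in zip(ecp.percents, ecp.measure(x)):
    print(f"top {p:5.1f}% of values hold {e.item():.3f} of the energy")

# The same measurement in the FFT frequency domain.
for p, e in zip(ecp.percents, ecp.measure_fft(x)):
    print(f"top {p:5.1f}% of FFT coefficients hold {e.item():.3f} of the energy")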
1 change: 1 addition & 0 deletions tests/ci_test_list_3.txt
@@ -28,3 +28,4 @@ tests/experimental/wgit/test_api.py
tests/experimental/wgit/test_pygit.py
tests/experimental/wgit/test_sha1_store.py
tests/experimental/wgit/test_signal_sparsity.py
tests/experimental/wgit/test_signal_sparsity_profiling.py
74 changes: 74 additions & 0 deletions tests/experimental/wgit/test_signal_sparsity_profiling.py
@@ -0,0 +1,74 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import time

import pytest
import torch

from fair_dev.testing.testing import objects_are_equal, skip_if_no_cuda
from fairscale.experimental.wgit.signal_sparsity_profiling import EnergyConcentrationProfile as ECP

# Our own tolerance
ATOL = 1e-6
RTOL = 1e-5

# enable this for debugging.
# torch.set_printoptions(precision=20)


@skip_if_no_cuda
def test_nonblocking():
    """Tests that the CPU runs ahead of the GPU in the measuring process."""
    big = torch.rand(10, 1000, 1000).cuda()
    ecp = ECP(dim=2, top_k_percents=[1, 5, 10, 50, 90])
    start = time.time()
    out = ecp.measure(big)
    out_fft = ecp.measure_fft(big)
    cpu_time = time.time() - start
    # CUDA kernels launch asynchronously, so the calls above should return
    # well before the GPU finishes; synchronize() blocks until it does.
    torch.cuda.synchronize()
    gpu_time = time.time() - start
    assert cpu_time * 5 < gpu_time, f"GPU time should dominate {cpu_time} vs. {gpu_time}"
    for o in [out, out_fft]:
        # validate the output
        p = [x.item() for x in o]
        for n, n1 in zip(p, p[1:]):
            assert n <= n1 and n >= 0 and n <= 100, f"n={n} n1={n1}"


def get_ones():
    """Return test data with ones tensor"""
    return (
        0,
        [1, 5, 10, 100],
        torch.ones(100),
        [torch.tensor(0.01), torch.tensor(0.05), torch.tensor(0.1), torch.tensor(1.0)],
    )


def get_dim_0():
    """Test case for dim=0 for 2D input."""
    return (
        0,
        [1, 3, 33, 66, 90],
        torch.tensor([0.1, 0.2, 0.1, 0.45]).repeat(100, 1),
        [torch.tensor(0.01), torch.tensor(0.03), torch.tensor(0.33), torch.tensor(0.66), torch.tensor(0.9)],
    )


@pytest.mark.parametrize(
    "dim, percents, in_tensor, out_tensors",
    [
        get_ones(),
        get_dim_0(),
    ],
)
def test_expected_output(dim, percents, in_tensor, out_tensors):
    """Test with a few expected inputs & outputs."""
    ecp = ECP(dim, percents)
    out = ecp.measure(in_tensor)
    objects_are_equal(out, out_tensors, raise_exception=True, rtol=RTOL, atol=ATOL)
    out_fft = ecp.measure_fft(torch.fft.ifft(in_tensor, dim=dim))
    objects_are_equal(out_fft, out_tensors, raise_exception=True, rtol=RTOL, atol=ATOL)
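A note on the `measure_fft` case above: since `measure_fft` computes `measure` on `torch.fft.fft(x).real`, feeding it `torch.fft.ifft(in_tensor)` round-trips back to (approximately) the original real tensor, so the expected outputs match the plain `measure` case. A quick self-contained check of that identity (an illustrative snippet, not part of the commit):

x = torch.tensor([0.1, 0.2, 0.1, 0.45])
assert torch.allclose(torch.fft.fft(torch.fft.ifft(x)).real, x, atol=1e-6)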
