-
Notifications
You must be signed in to change notification settings - Fork 264
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[feat] signal sparsity profiling class (#1060)
* added a profiling class * no more type ignore after merging main * fixed a int/round bug * add unit tests * skip if no cuda for a test Co-authored-by: Min Xu <min.xu.public@gmail.com>
- Loading branch information
Showing
4 changed files
with
138 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. | ||
# | ||
# This source code is licensed under the BSD license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
from typing import List | ||
|
||
import torch | ||
from torch import Tensor | ||
|
||
|
||
class EnergyConcentrationProfile: | ||
"""Compute "energy" concentration level for a tensor | ||
Args: | ||
dim (int): | ||
The dimension to measure. | ||
top_k_percents (List[float]): | ||
List of percentage values. For each value, the `measure` | ||
function will compute and return the percentage of "energy" | ||
concentrated on that top-K percent of values in the dimension | ||
to measure. Note, this is the opposite of the sparsity percentage. | ||
""" | ||
|
||
def __init__(self, dim: int, top_k_percents: List[float]) -> None: | ||
assert isinstance(dim, int) | ||
self.dim = dim | ||
self.percents = [] | ||
last_p = 0.0 | ||
for p in top_k_percents: | ||
assert isinstance(p, (int, float)) | ||
assert p > 0, p | ||
assert p <= 100, p | ||
assert p > last_p, f"p {p} should be larger than last_p {last_p}" | ||
self.percents.append(float(p)) | ||
last_p = p | ||
|
||
def measure(self, in_tensor: Tensor) -> List[Tensor]: | ||
"""Compute the return the results | ||
Note, we want this function to be nonblocking and async. | ||
Returns: | ||
(List[Tensor]) | ||
List of tensors. Each tensor is a singleton float | ||
that contains the energy measure for that top_k_percent. | ||
""" | ||
assert in_tensor.is_floating_point(), in_tensor.dtype | ||
assert self.dim < len(in_tensor.shape), f"tensor shape {in_tensor.shape} not compatible with dim {self.dim}" | ||
dim_size = in_tensor.shape[self.dim] | ||
abs_tensor = in_tensor.abs() | ||
full_energy = abs_tensor.sum() | ||
return_tensors = [] | ||
for p in self.percents: | ||
k = max(1, round(p / 100 * dim_size)) | ||
abs_top_k_values, _ = abs_tensor.topk(k, dim=self.dim) | ||
return_tensors.append(abs_top_k_values.sum() / full_energy) | ||
return return_tensors | ||
|
||
def measure_fft(self, in_tensor: Tensor) -> List[Tensor]: | ||
"""Like measure, but do it in FFT frequency domain.""" | ||
return self.measure(torch.fft.fft(in_tensor, dim=self.dim).real) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
# Copyright (c) Facebook, Inc. and its affiliates. | ||
# | ||
# This source code is licensed under the BSD license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
import time | ||
|
||
import pytest | ||
import torch | ||
|
||
from fair_dev.testing.testing import objects_are_equal, skip_if_no_cuda | ||
from fairscale.experimental.wgit.signal_sparsity_profiling import EnergyConcentrationProfile as ECP | ||
|
||
# Our own tolerance | ||
ATOL = 1e-6 | ||
RTOL = 1e-5 | ||
|
||
# enable this for debugging. | ||
# torch.set_printoptions(precision=20) | ||
|
||
|
||
@skip_if_no_cuda | ||
def test_nonblocking(): | ||
"""Tests cpu runs ahead of the GPU in the measuring process.""" | ||
big = torch.rand(10, 1000, 1000).cuda() | ||
ecp = ECP(dim=2, top_k_percents=[1, 5, 10, 50, 90]) | ||
start = time.time() | ||
out = ecp.measure(big) | ||
out_fft = ecp.measure_fft(big) | ||
cpu_time = time.time() - start | ||
torch.cuda.synchronize() | ||
gpu_time = time.time() - start | ||
assert cpu_time * 5 < gpu_time, f"GPU time should dominate {cpu_time} vs. {gpu_time}" | ||
for o in [out, out_fft]: | ||
# validate the output | ||
p = [x.item() for x in o] | ||
for n, n1 in zip(p, p[1:]): | ||
assert n <= n1 and n >= 0 and n <= 100, f"n={n} n1={n1}" | ||
|
||
|
||
def get_ones(): | ||
"""Return test data with ones tensor""" | ||
return ( | ||
0, | ||
[1, 5, 10, 100], | ||
torch.ones(100), | ||
[torch.tensor(0.01), torch.tensor(0.05), torch.tensor(0.1), torch.tensor(1.0)], | ||
) | ||
|
||
|
||
def get_dim_0(): | ||
"""Test case for dim=0 for 2D input.""" | ||
return ( | ||
0, | ||
[1, 3, 33, 66, 90], | ||
torch.tensor([0.1, 0.2, 0.1, 0.45]).repeat(100, 1), | ||
[torch.tensor(0.01), torch.tensor(0.03), torch.tensor(0.33), torch.tensor(0.66), torch.tensor(0.9)], | ||
) | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"dim, percents, in_tensor, out_tensors", | ||
[ | ||
get_ones(), | ||
get_dim_0(), | ||
], | ||
) | ||
def test_expected_output(dim, percents, in_tensor, out_tensors): | ||
"""Test with a few expected input & outputs.""" | ||
ecp = ECP(dim, percents) | ||
out = ecp.measure(in_tensor) | ||
objects_are_equal(out, out_tensors, raise_exception=True, rtol=RTOL, atol=ATOL) | ||
out_fft = ecp.measure_fft(torch.fft.ifft(in_tensor, dim=dim)) | ||
objects_are_equal(out_fft, out_tensors, raise_exception=True, rtol=RTOL, atol=ATOL) |