data: perform downsampling in multiplexer provider (tensorflow#3272)
Summary:
The `MultiplexerDataProvider` now respects its `downsample` parameter,
even though the backing `PluginEventMultiplexer` already performs its
own sampling. This serves two purposes:

  - It enforces that clients always specify the `downsample`
    argument, which is required (see the usage sketch below).
  - It enables us to test plugins’ downsampling parameters to verify
    that they will behave correctly with other data providers.
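
As a usage sketch of the enforced contract (setup elided: `multiplexer`
and `logdir` stand in for a `PluginEventMultiplexer` and its log
directory; the experiment ID is a placeholder):

    provider = data_provider.MultiplexerDataProvider(multiplexer, logdir)

    # `downsample` is now mandatory: at most this many points are
    # returned per time series.
    provider.read_scalars(
        experiment_id="unused",
        plugin_name="scalars",
        downsample=1000,
    )

    # Omitting it (or passing None) now raises:
    # TypeError: `downsample` required but not given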

Test Plan:
Unit tests included. Note that changing the `_DEFAULT_DOWNSAMPLING`
constant in (e.g.) the scalars plugin to a small number (like `5`) now
actually causes charts in the frontend to be downsampled.
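
Concretely, that manual check amounts to an edit like the following in
the scalars plugin (the constant's usual value is an assumption here,
not part of this change):

    # In tensorboard/plugins/scalar/scalars_plugin.py, for manual
    # testing only: shrink the sampling budget so frontend charts
    # visibly thin out.
    _DEFAULT_DOWNSAMPLING = 5  # normally a much larger budget (e.g. 1000)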

wchargin-branch: data-mux-downsample
wchargin committed Feb 21, 2020
1 parent 5b7f9ad commit 4670696
Showing 3 changed files with 112 additions and 15 deletions.
69 changes: 56 additions & 13 deletions tensorboard/backend/event_processing/data_provider.py
@@ -21,6 +21,7 @@
 import base64
 import collections
 import json
+import random

 import six

@@ -57,6 +58,16 @@ def _validate_experiment_id(self, experiment_id):
                 % (str, type(experiment_id), experiment_id)
             )

+    def _validate_downsample(self, downsample):
+        if downsample is None:
+            raise TypeError("`downsample` required but not given")
+        if isinstance(downsample, int):
+            return  # OK
+        raise TypeError(
+            "`downsample` must be an int, but got %r: %r"
+            % (type(downsample), downsample)
+        )
+
     def _test_run_tag(self, run_tag_filter, run, tag):
         runs = run_tag_filter.runs
         if runs is not None and run not in runs:
@@ -109,14 +120,11 @@ def list_scalars(self, experiment_id, plugin_name, run_tag_filter=None):
     def read_scalars(
         self, experiment_id, plugin_name, downsample=None, run_tag_filter=None
     ):
-        # TODO(@wchargin): Downsampling not implemented, as the multiplexer
-        # is already downsampled. We could downsample on top of the existing
-        # sampling, which would be nice for testing.
-        del downsample  # ignored for now
+        self._validate_downsample(downsample)
         index = self.list_scalars(
             experiment_id, plugin_name, run_tag_filter=run_tag_filter
         )
-        return self._read(_convert_scalar_event, index)
+        return self._read(_convert_scalar_event, index, downsample)

     def list_tensors(self, experiment_id, plugin_name, run_tag_filter=None):
         self._validate_experiment_id(experiment_id)
@@ -131,14 +139,11 @@ def list_tensors(self, experiment_id, plugin_name, run_tag_filter=None):
     def read_tensors(
         self, experiment_id, plugin_name, downsample=None, run_tag_filter=None
     ):
-        # TODO(@wchargin): Downsampling not implemented, as the multiplexer
-        # is already downsampled. We could downsample on top of the existing
-        # sampling, which would be nice for testing.
-        del downsample  # ignored for now
+        self._validate_downsample(downsample)
         index = self.list_tensors(
             experiment_id, plugin_name, run_tag_filter=run_tag_filter
         )
-        return self._read(_convert_tensor_event, index)
+        return self._read(_convert_tensor_event, index, downsample)

     def _list(
         self,
@@ -191,13 +196,15 @@ def _list(
             )
         return result

-    def _read(self, convert_event, index):
+    def _read(self, convert_event, index, downsample):
         """Helper to read scalar or tensor data from the multiplexer.

         Args:
           convert_event: Takes `plugin_event_accumulator.TensorEvent` to
             either `provider.ScalarDatum` or `provider.TensorDatum`.
           index: The result of `list_scalars` or `list_tensors`.
+          downsample: Non-negative `int`; how many samples to return per
+            time series.

         Returns:
           A dict of dicts of values returned by `convert_event` calls,
@@ -209,7 +216,8 @@ def _read(self, convert_event, index):
             result[run] = result_for_run
             for (tag, metadata) in six.iteritems(tags_for_run):
                 events = self._multiplexer.Tensors(run, tag)
-                result_for_run[tag] = [convert_event(e) for e in events]
+                data = [convert_event(e) for e in events]
+                result_for_run[tag] = _downsample(data, downsample)
         return result

     def list_blob_sequences(
@@ -258,6 +266,7 @@ def read_blob_sequences(
         self, experiment_id, plugin_name, downsample=None, run_tag_filter=None
     ):
         self._validate_experiment_id(experiment_id)
+        self._validate_downsample(downsample)
         index = self.list_blob_sequences(
             experiment_id, plugin_name, run_tag_filter=run_tag_filter
         )
@@ -275,7 +284,7 @@ def read_blob_sequences(
                         experiment_id, plugin_name, run, tag, event
                     )
                 data = [datum for (step, datum) in sorted(data_by_step.items())]
-                result_for_run[tag] = data
+                result_for_run[tag] = _downsample(data, downsample)
         return result

     def read_blob(self, blob_key):
@@ -411,3 +420,37 @@ def _tensor_size(tensor_proto):
     for dim in tensor_proto.tensor_shape.dim:
         result *= dim.size
     return result
+
+
+def _downsample(xs, k):
+    """Downsample `xs` to at most `k` elements.
+
+    If `k` is larger than `len(xs)`, then the contents of `xs` itself
+    will be returned. If `k` is smaller, the last element of `xs` will
+    always be included (unless `k` is `0`) and the preceding elements
+    will be selected uniformly at random.
+
+    This differs from `random.sample` in that it returns a subsequence
+    (i.e., order is preserved) and that it permits `k > len(xs)`.
+
+    The random number generator will always be `random.Random(0)`, so
+    this function is deterministic (within a Python process).
+
+    Args:
+      xs: A sequence (`collections.abc.Sequence`).
+      k: A non-negative integer.
+
+    Returns:
+      A new list whose elements are a subsequence of `xs` of length
+      `min(k, len(xs))` and that is guaranteed to include the last
+      element of `xs`, selected uniformly among such subsequences.
+    """
+
+    if k > len(xs):
+        return list(xs)
+    if k == 0:
+        return []
+    indices = random.Random(0).sample(six.moves.xrange(len(xs) - 1), k - 1)
+    indices.sort()
+    indices += [len(xs) - 1]
+    return [xs[i] for i in indices]
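
For intuition, a self-contained restatement of the sampling strategy
above in plain Python 3 (`range` in place of `six.moves.xrange`; an
illustrative sketch, not the shipped code):

    import random

    def downsample(xs, k):
        # Keep at most k elements: the last element always survives, and
        # the remaining k - 1 slots are filled uniformly at random from
        # the earlier elements, preserving their original order.
        if k > len(xs):
            return list(xs)
        if k == 0:
            return []
        # A freshly seeded RNG on every call makes the result
        # deterministic for a given input.
        indices = random.Random(0).sample(range(len(xs) - 1), k - 1)
        indices.sort()
        indices.append(len(xs) - 1)
        return [xs[i] for i in indices]

    assert downsample("abcdefg", 3)[-1] == "g"  # last element retained
    assert downsample("abcdefg", 3) == downsample("abcdefg", 3)  # deterministic
    assert downsample("abc", 10) == ["a", "b", "c"]  # k may exceed len(xs)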
57 changes: 55 additions & 2 deletions tensorboard/backend/event_processing/data_provider_test.py
@@ -249,7 +249,7 @@ def test_read_scalars(self):
             experiment_id="unused",
             plugin_name=scalar_metadata.PLUGIN_NAME,
             run_tag_filter=run_tag_filter,
-            downsample=None,  # not yet implemented
+            downsample=100,
         )

         self.assertItemsEqual(result.keys(), ["polynomials", "waves"])
@@ -267,6 +267,18 @@ def test_read_scalars(self):
                         tensor_util.make_ndarray(event.tensor_proto).item(),
                     )

+    def test_read_scalars_downsamples(self):
+        multiplexer = self.create_multiplexer()
+        provider = data_provider.MultiplexerDataProvider(
+            multiplexer, self.logdir
+        )
+        result = provider.read_scalars(
+            experiment_id="unused",
+            plugin_name=scalar_metadata.PLUGIN_NAME,
+            downsample=3,
+        )
+        self.assertLen(result["waves"]["sine"], 3)
+
     def test_read_scalars_but_not_rank_0(self):
         provider = self.create_provider()
         run_tag_filter = base_provider.RunTagFilter(["waves"], ["bad"])
@@ -280,6 +292,7 @@ def test_read_scalars_but_not_rank_0(self):
                 experiment_id="unused",
                 plugin_name="greetings",
                 run_tag_filter=run_tag_filter,
+                downsample=100,
             )

     def test_list_tensors_all(self):
@@ -329,7 +342,7 @@ def test_read_tensors(self):
             experiment_id="unused",
             plugin_name=histogram_metadata.PLUGIN_NAME,
             run_tag_filter=run_tag_filter,
-            downsample=None,  # not yet implemented
+            downsample=100,
         )

         self.assertItemsEqual(result.keys(), ["lebesgue"])
@@ -346,6 +359,46 @@ def test_read_tensors(self):
                     tensor_util.make_ndarray(event.tensor_proto),
                 )

+    def test_read_tensors_downsamples(self):
+        multiplexer = self.create_multiplexer()
+        provider = data_provider.MultiplexerDataProvider(
+            multiplexer, self.logdir
+        )
+        result = provider.read_tensors(
+            experiment_id="unused",
+            plugin_name=histogram_metadata.PLUGIN_NAME,
+            downsample=3,
+        )
+        self.assertLen(result["lebesgue"]["uniform"], 3)
+
+
+class DownsampleTest(tf.test.TestCase):
+    """Tests for the `_downsample` private helper function."""
+
+    def test_deterministic(self):
+        xs = "abcdefg"
+        expected = data_provider._downsample(xs, k=4)
+        for _ in range(100):
+            actual = data_provider._downsample(xs, k=4)
+            self.assertEqual(actual, expected)
+
+    def test_underlong_ok(self):
+        xs = list("abcdefg")
+        actual = data_provider._downsample(xs, k=10)
+        expected = list("abcdefg")
+        self.assertIsNot(actual, xs)
+        self.assertEqual(actual, expected)
+
+    def test_inorder(self):
+        xs = list(range(10000))
+        actual = data_provider._downsample(xs, k=100)
+        self.assertEqual(actual, sorted(actual))
+
+    def test_zero(self):
+        xs = "abcdefg"
+        actual = data_provider._downsample(xs, k=0)
+        self.assertEqual(actual, [])
+
 if __name__ == "__main__":
     tf.test.main()
1 change: 1 addition & 0 deletions tensorboard/plugins/graph/graphs_plugin.py
@@ -209,6 +209,7 @@ def graph_impl(
             experiment_id=experiment,
             plugin_name=metadata.PLUGIN_NAME,
             run_tag_filter=provider.RunTagFilter(runs=[run], tags=[tag]),
+            downsample=1,
         )
         blob_datum_list = graph_blob_sequences.get(run, {}).get(tag, ())
         try:
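
The budget of `1` works because `_downsample` always keeps the final
element of a series, so the graphs dashboard gets exactly the latest
graph event for the requested run and tag. A quick check of that
invariant:

    from tensorboard.backend.event_processing import data_provider

    # With a budget of 1, only the most recent datum survives.
    assert data_provider._downsample(["old", "new"], 1) == ["new"]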
