Skip to content

Commit

Permalink
Merge pull request #7493 from Hakuyume/tabular-unary
Browse files Browse the repository at this point in the history
Add unary mode to TabularDataset
  • Loading branch information
beam2d committed Jun 26, 2019
2 parents 3029bba + 479f84b commit eccd20d
Show file tree
Hide file tree
Showing 11 changed files with 190 additions and 69 deletions.
5 changes: 4 additions & 1 deletion chainer/dataset/tabular/_join.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,10 @@ def keys(self):

@property
def mode(self):
return self._datasets[0].mode
for dataset in self._datasets:
if dataset.mode:
return dataset.mode
return tuple

def get_examples(self, indices, key_indices):
if key_indices is None:
Expand Down
15 changes: 14 additions & 1 deletion chainer/dataset/tabular/_slice.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,14 @@
class _Slice(tabular_dataset.TabularDataset):

def __init__(self, dataset, indices, keys):
if keys is None:
self._unary = None
elif isinstance(keys, tuple):
self._unary = False
else:
self._unary = True
keys = keys,

self._dataset = dataset
self._indices = _as_indices(indices, len(dataset))
self._key_indices = _as_key_indices(keys, dataset.keys)
Expand All @@ -32,7 +40,12 @@ def keys(self):

@property
def mode(self):
return self._dataset.mode
if self._unary is None:
return self._dataset.mode
elif self._unary:
return None
else:
return self._dataset.mode or tuple

def get_examples(self, indices, key_indices):
indices = _merge_indices(
Expand Down
41 changes: 29 additions & 12 deletions chainer/dataset/tabular/_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ def __init__(self, dataset, keys, transform):

self._dataset = dataset
self._keys = keys
self._mode = None
self._transform = transform

def __len__(self):
Expand All @@ -23,7 +22,7 @@ def keys(self):

@property
def mode(self):
if self._mode is None:
if not hasattr(self, '_mode'):
self.get_examples([0], None)
return self._mode

Expand All @@ -40,24 +39,32 @@ def get_examples(self, indices, key_indices):
elif self._dataset.mode is dict:
out_example = self._transform(
**dict(six.moves.zip(self._dataset.keys, in_example)))
elif self._dataset.mode is None:
out_example = self._transform(*in_example)

if not isinstance(out_example, (tuple, dict)):
out_example = out_example,
if isinstance(out_example, tuple):
if self._mode and self._mode is not tuple:
if hasattr(self, '_mode') and self._mode is not tuple:
raise ValueError(
'transform must not change its return type')
self._mode = tuple
for col_index, key_index in enumerate(key_indices):
out_examples[col_index].append(out_example[key_index])
elif isinstance(out_example, dict):
if self._mode and self._mode is not dict:
if hasattr(self, '_mode') and self._mode is not dict:
raise ValueError(
'transform must not change its return type')
self._mode = dict
for col_index, key_index in enumerate(key_indices):
out_examples[col_index].append(
out_example[self._keys[key_index]])
else:
if hasattr(self, '_mode') and self._mode is not None:
raise ValueError(
'transform must not change its return type')
self._mode = None
out_example = out_example,
for col_index, key_index in enumerate(key_indices):
out_examples[col_index].append(out_example[key_index])

return out_examples

Expand All @@ -70,7 +77,6 @@ def __init__(self, dataset, keys, transform_batch):

self._dataset = dataset
self._keys = keys
self._mode = None
self._transform_batch = transform_batch

def __len__(self):
Expand All @@ -82,7 +88,7 @@ def keys(self):

@property
def mode(self):
if self._mode is None:
if not hasattr(self, '_mode'):
self.get_examples([0], None)
return self._mode

Expand All @@ -105,11 +111,11 @@ def get_examples(self, indices, key_indices):
elif self._dataset.mode is dict:
out_examples = self._transform_batch(
**dict(six.moves.zip(self._dataset.keys, in_examples)))
elif self._dataset.mode is None:
out_examples = self._transform_batch(*in_examples)

if not isinstance(out_examples, (tuple, dict)):
out_examples = out_examples,
if isinstance(out_examples, tuple):
if self._mode and self._mode is not tuple:
if hasattr(self, '_mode') and self._mode is not tuple:
raise ValueError(
'transform_batch must not change its return type')
self._mode = tuple
Expand All @@ -119,7 +125,7 @@ def get_examples(self, indices, key_indices):
return tuple(out_examples[key_index]
for key_index in key_indices)
elif isinstance(out_examples, dict):
if self._mode and self._mode is not dict:
if hasattr(self, '_mode') and self._mode is not dict:
raise ValueError(
'transform_batch must not change its return type')
self._mode = dict
Expand All @@ -128,3 +134,14 @@ def get_examples(self, indices, key_indices):
'transform_batch must not change the length of data')
return tuple(out_examples[self._keys[key_index]]
for key_index in key_indices)
else:
if hasattr(self, '_mode') and self._mode is not None:
raise ValueError(
'transform_batch must not change its return type')
self._mode = None
out_examples = out_examples,
if not all(len(col) == len_ for col in out_examples):
raise ValueError(
'transform_batch must not change the length of data')
return tuple(out_examples[key_index]
for key_index in key_indices)
15 changes: 11 additions & 4 deletions chainer/dataset/tabular/tabular_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ class TabularDataset(dataset_mixin.DatasetMixin):
Since an example can be represented by both tuple and dict (
:obj:`(a[i], b[i], c[i])` and :obj:`{'a': a[i], 'b': b[i], 'c': c[i]}`),
this class uses :attr:`mode` to indicate which representation will be used.
If there is only one column, an example can also be represented by a single
value (:obj:`a[i]`). In this case, :attr:`mode` is :obj:`None`.
A subclass should implement
:meth:`__len__`, :attr:`keys`, :attr:`mode` and :meth:`get_examples`.
Expand Down Expand Up @@ -91,7 +93,7 @@ def mode(self):
This indicates the type of value returned
by :meth:`fetch` and :meth:`__getitem__`.
:class:`tuple` and :class:`dict` are supported.
:class:`tuple`, :class:`dict`, and :obj:`None` are supported.
"""
raise NotImplementedError

Expand All @@ -115,7 +117,7 @@ def slice(self):
Args:
indices (list/array of ints/bools or slice): Requested rows.
keys (tuple of ints/strs): Requested columns.
keys (tuple of ints/strs or int or str): Requested columns.
Returns:
A view of specified range.
Expand All @@ -127,8 +129,9 @@ def fetch(self):
This method fetches all data of the dataset/view.
Note that this method returns a column-major data
(i.e. :obj:`([a[0], ..., a[3]], ..., [c[0], ... c[3]])` or
:obj:`{'a': [a[0], ..., a[3]], ..., 'c': [c[0], ..., c[3]]}`).
(i.e. :obj:`([a[0], ..., a[3]], ..., [c[0], ... c[3]])`,
:obj:`{'a': [a[0], ..., a[3]], ..., 'c': [c[0], ..., c[3]]}`, or
:obj:`[a[0], ..., a[3]]`).
Returns:
If :attr:`mode` is :class:`tuple`,
Expand All @@ -141,6 +144,8 @@ def fetch(self):
return examples
elif self.mode is dict:
return dict(six.moves.zip(self.keys, examples))
elif self.mode is None:
return examples[0]

def as_tuple(self):
"""Return a view with tuple mode.
Expand Down Expand Up @@ -223,3 +228,5 @@ def get_example(self, i):
return example
elif self.mode is dict:
return dict(six.moves.zip(self.keys, example))
elif self.mode is None:
return example[0]
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@ class DummyDataset(chainer.dataset.TabularDataset):
def __init__(
self, size=10, keys=('a', 'b', 'c'), mode=tuple,
return_array=False, callback=None):
if mode is None:
keys = keys[0],

self._keys = keys
self._mode = mode
self._return_array = return_array
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
@testing.parameterize(
{'mode': tuple},
{'mode': dict},
{'mode': None},
)
class TestAsTuple(unittest.TestCase):

Expand All @@ -23,6 +24,7 @@ def test_as_tuple(self):
@testing.parameterize(
{'mode': tuple},
{'mode': dict},
{'mode': None},
)
class TestAsDict(unittest.TestCase):

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@

@testing.parameterize(*testing.product_dict(
testing.product({
'mode_a': [tuple, dict],
'mode_b': [tuple, dict],
'mode_a': [tuple, dict, None],
'mode_b': [tuple, dict, None],
'return_array': [True, False],
}),
[
Expand Down Expand Up @@ -42,6 +42,7 @@ def callback_a(indices, key_indices):
self.assertIsNone(key_indices)

dataset_a = dummy_dataset.DummyDataset(
keys=('a', 'b', 'c') if self.mode_b else ('a',),
mode=self.mode_a,
return_array=self.return_array, callback=callback_a)

Expand All @@ -50,7 +51,9 @@ def callback_b(indices, key_indices):
self.assertIsNone(key_indices)

dataset_b = dummy_dataset.DummyDataset(
size=5, mode=self.mode_b,
size=5,
keys=('a', 'b', 'c') if self.mode_a else ('a',),
mode=self.mode_b,
return_array=self.return_array, callback=callback_b)

view = dataset_a.concat(dataset_b)
Expand Down
63 changes: 42 additions & 21 deletions tests/chainer_tests/dataset_tests/tabular_tests/test_join.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,28 +8,49 @@
from chainer_tests.dataset_tests.tabular_tests import dummy_dataset


@testing.parameterize(*testing.product_dict(
testing.product({
'mode_a': [tuple, dict],
'mode_b': [tuple, dict],
'return_array': [True, False],
}),
[
{'key_indices': None,
'expected_key_indices_a': None,
'expected_key_indices_b': None},
{'key_indices': (0, 4, 1),
'expected_key_indices_a': (0, 1),
'expected_key_indices_b': (1,)},
{'key_indices': (0, 2),
'expected_key_indices_a': (0, 2)},
{'key_indices': (1, 2, 1),
'expected_key_indices_a': (1, 2)},
{'key_indices': ()},
],
))
def _filter_params(params):
for param in params:
key_size = 0
key_size += 3 if param['mode_a'] else 1
key_size += 2 if param['mode_b'] else 1

if param['key_indices'] and \
any(key_size <= key_index for key_index in param['key_indices']):
continue

yield param


@testing.parameterize(*_filter_params(testing.product({
'mode_a': [tuple, dict, None],
'mode_b': [tuple, dict, None],
'return_array': [True, False],
'key_indices': [None, (0, 4, 1), (0, 2), (1, 0), ()],
})))
class TestJoin(unittest.TestCase):

def setUp(self):
    """Derive the expected per-dataset key indices from ``key_indices``.

    Reads ``self.key_indices`` and ``self.mode_a``. When ``key_indices``
    is ``None``, both ``expected_key_indices_a`` and
    ``expected_key_indices_b`` are set to ``None``. Otherwise the indices
    are split at ``key_size_a`` (3 when ``mode_a`` is truthy, else 1):
    indices below it go to ``expected_key_indices_a`` as-is, the others go
    to ``expected_key_indices_b`` shifted down by ``key_size_a``.
    """
    if self.key_indices is None:
        self.expected_key_indices_a = None
        self.expected_key_indices_b = None
        return

    key_size_a = 3 if self.mode_a else 1

    key_indices_a = tuple(
        key_index
        for key_index in self.key_indices
        if key_index < key_size_a)
    key_indices_b = tuple(
        key_index - key_size_a
        for key_index in self.key_indices
        if key_size_a <= key_index)

    # The attribute is deliberately left unset when a side receives no
    # keys; only non-empty index tuples are recorded.
    if key_indices_a:
        self.expected_key_indices_a = key_indices_a
    if key_indices_b:
        self.expected_key_indices_b = key_indices_b

def test_join(self):
def callback_a(indices, key_indices):
self.assertIsNone(indices)
Expand All @@ -51,7 +72,7 @@ def callback_b(indices, key_indices):
self.assertIsInstance(view, chainer.dataset.TabularDataset)
self.assertEqual(len(view), len(dataset_a))
self.assertEqual(view.keys, dataset_a.keys + dataset_b.keys)
self.assertEqual(view.mode, dataset_a.mode)
self.assertEqual(view.mode, dataset_a.mode or dataset_b.mode or tuple)

output = view.get_examples(None, self.key_indices)

Expand Down
Loading

0 comments on commit eccd20d

Please sign in to comment.