Skip to content

Commit

Permalink
Merge pull request #7276 from Hakuyume/tabular-helpers
Browse files Browse the repository at this point in the history
Add `chainer.dataset.tabular.DelegateDataset`
  • Loading branch information
beam2d committed Aug 20, 2019
2 parents 4bebd21 + 6bd5faa commit 4381df8
Show file tree
Hide file tree
Showing 4 changed files with 91 additions and 0 deletions.
1 change: 1 addition & 0 deletions chainer/dataset/tabular/__init__.py
Expand Up @@ -5,4 +5,5 @@
from chainer.dataset.tabular import _transform # NOQA
from chainer.dataset.tabular import _with_converter # NOQA

from chainer.dataset.tabular.delegate_dataset import DelegateDataset # NOQA
from chainer.dataset.tabular.from_data import from_data # NOQA
59 changes: 59 additions & 0 deletions chainer/dataset/tabular/delegate_dataset.py
@@ -0,0 +1,59 @@
from chainer.dataset.tabular import tabular_dataset


class DelegateDataset(tabular_dataset.TabularDataset):

    """A thin wrapper that forwards the tabular interface to another dataset.

    An instance holds one underlying dataset in :attr:`dataset` and answers
    ``len()``, :attr:`keys`, :attr:`mode` and :meth:`get_examples` by asking
    that dataset. Subclassing it is a convenient way to build a custom
    dataset class: construct the underlying dataset in ``__init__`` and pass
    it to the base initializer, as in the example below.

    >>> import numpy as np
    >>>
    >>> from chainer.dataset import tabular
    >>>
    >>> class MyDataset(tabular.DelegateDataset):
    ...
    ...     def __init__(self):
    ...         super().__init__(tabular.from_data((
    ...             ('a', np.arange(10)),
    ...             ('b', self.get_b),
    ...             ('c', [3, 1, 4, 5, 9, 2, 6, 8, 7, 0]),
    ...             (('d', 'e'), self.get_de))))
    ...
    ...     def get_b(self, i):
    ...         return 'b[{}]'.format(i)
    ...
    ...     def get_de(self, i):
    ...         return {'d': 'd[{}]'.format(i), 'e': 'e[{}]'.format(i)}
    ...
    >>> dataset = MyDataset()
    >>> len(dataset)
    10
    >>> dataset.keys
    ('a', 'b', 'c', 'd', 'e')
    >>> dataset[0]
    (0, 'b[0]', 3, 'd[0]', 'e[0]')

    Args:
        dataset (chainer.dataset.TabularDataset): An underlying dataset.

    """

    def __init__(self, dataset):
        # Kept as a public attribute; every other member reads from it.
        self.dataset = dataset

    def __len__(self):
        """Return the length reported by the underlying dataset."""
        wrapped = self.dataset
        return len(wrapped)

    @property
    def keys(self):
        """The ``keys`` of the underlying dataset."""
        wrapped = self.dataset
        return wrapped.keys

    @property
    def mode(self):
        """The ``mode`` of the underlying dataset."""
        wrapped = self.dataset
        return wrapped.mode

    def get_examples(self, indices, key_indices):
        """Forward ``indices`` and ``key_indices`` to the underlying dataset.

        Both arguments are passed through unchanged and the underlying
        dataset's result is returned as-is.
        """
        fetch = self.dataset.get_examples
        return fetch(indices, key_indices)
1 change: 1 addition & 0 deletions docs/source/reference/datasets.rst
Expand Up @@ -49,6 +49,7 @@ Tabular Dataset Helpers
:toctree: generated/
:nosignatures:

chainer.dataset.tabular.DelegateDataset
chainer.dataset.tabular.from_data


Expand Down
@@ -0,0 +1,30 @@
import unittest

import chainer
from chainer.dataset import tabular
from chainer import testing


from chainer_tests.dataset_tests.tabular_tests import dummy_dataset


@testing.parameterize(
    {'mode': tuple},
    {'mode': dict},
    {'mode': None},
)
class TestDelegateDataset(unittest.TestCase):

    """Check that ``DelegateDataset`` mirrors the dataset it wraps.

    Parameterized over the ``mode`` passed to ``DummyDataset``
    (``tuple``, ``dict`` and ``None``).
    """

    def test_delegate_dataset(self):
        source = dummy_dataset.DummyDataset(mode=self.mode)
        dataset = tabular.DelegateDataset(source)

        # The wrapper itself must be a TabularDataset.
        self.assertIsInstance(dataset, chainer.dataset.TabularDataset)

        # Each delegated member is compared with the same member read
        # through the wrapper's public ``dataset`` attribute.
        self.assertEqual(len(dataset), len(dataset.dataset))
        self.assertEqual(dataset.keys, dataset.dataset.keys)
        self.assertEqual(dataset.mode, dataset.dataset.mode)

        via_wrapper = dataset.get_example(3)
        via_wrapped = dataset.dataset.get_example(3)
        self.assertEqual(via_wrapper, via_wrapped)


# NOTE(review): presumably runs this module's tests when the file is executed
# directly -- confirm against ``chainer.testing.run_module``.
testing.run_module(__name__, __file__)

0 comments on commit 4381df8

Please sign in to comment.