From aec21403d246c3a5e5d1575da66c9a6985e9d045 Mon Sep 17 00:00:00 2001
From: Kemal Kurniawan
Date: Fri, 25 Jan 2019 19:01:14 +0700
Subject: [PATCH 001/162] Put pytest as a requirement

---
 requirements.txt | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/requirements.txt b/requirements.txt
index f0fecb0..d4af228 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,11 +1,17 @@
+atomicwrites==1.2.1
 attrs==18.2.0
 certifi==2018.11.29
 flake8==3.6.0
 flake8-mypy==17.8.0
 mccabe==0.6.1
+more-itertools==5.0.0
 mypy==0.660
 mypy-extensions==0.4.1
+pluggy==0.8.1
+py==1.7.0
 pycodestyle==2.4.0
 pyflakes==2.0.0
+pytest==4.1.1
+six==1.12.0
 typed-ast==1.2.0
 yapf==0.25.0

From 4e30dcb253a41543fb4e296e004338fa7eaf4c75 Mon Sep 17 00:00:00 2001
From: Kemal Kurniawan
Date: Fri, 25 Jan 2019 20:06:33 +0700
Subject: [PATCH 002/162] Add flake8 and yapf config files

---
 .flake8     | 6 ++++++
 .style.yapf | 9 +++++++++
 2 files changed, 15 insertions(+)
 create mode 100644 .flake8
 create mode 100644 .style.yapf

diff --git a/.flake8 b/.flake8
new file mode 100644
index 0000000..60e3acd
--- /dev/null
+++ b/.flake8
@@ -0,0 +1,6 @@
+# Config file for flake8
+# Symlink this to ~/.config/flake8
+
+[flake8]
+ignore = E,W  # let yapf handle stylistic issues
+show-source = True
diff --git a/.style.yapf b/.style.yapf
new file mode 100644
index 0000000..6114d2c
--- /dev/null
+++ b/.style.yapf
@@ -0,0 +1,9 @@
+# Config file for YAPF Python formatter
+# Symlink this to ~/.config/yapf/style
+
+[style]
+based_on_style = pep8
+coalesce_brackets = True
+column_limit = 96
+split_before_first_argument = True
+split_complex_comprehension = True

From 1172de97952fb6947756128b16923bcd194a36c2 Mon Sep 17 00:00:00 2001
From: Kemal Kurniawan
Date: Fri, 25 Jan 2019 20:06:52 +0700
Subject: [PATCH 003/162] Make package installable

---
 setup.py | 24 ++++++++++++++++++++++++
 1 file changed, 24 insertions(+)
 create mode 100644 setup.py

diff --git a/setup.py b/setup.py
new file mode 100644
index 0000000..5dfd37d
--- /dev/null
+++ b/setup.py
@@ -0,0 +1,24 @@
+from pathlib import Path
+from setuptools import setup, find_packages
+
+readme = Path(__file__).resolve().parent / 'README.rst'
+
+setup(
+    name='text2tensor',
+    version='0.0.1',
+    description='Convert your NLP text data to tensors!',
+    long_description=readme.read_text(),
+    url='https://github.com/kmkurn/text2tensor',
+    author='Kemal Kurniawan',
+    author_email='kemal@kkurniawan.com',
+    license='MIT',
+    classifiers=[
+        'Development Status :: 3 - Alpha',
+        'Intended Audience :: Developers',
+        'Intended Audience :: Science/Research',
+        'License :: OSI Approved :: MIT License',
+        'Programming Language :: Python :: 3.6',
+    ],
+    packages=find_packages(),
+    python_requires='>=3.6, <4',
+)

From 29cf0d97a23e97629d38d4c1efb1c01383cd91ae Mon Sep 17 00:00:00 2001
From: Kemal Kurniawan
Date: Fri, 25 Jan 2019 21:57:48 +0700
Subject: [PATCH 004/162] Implement Dataset class

---
 tests/conftest.py       |  7 +++++
 tests/test_dataset.py   | 64 +++++++++++++++++++++++++++++++++++++++++
 text2tensor/__init__.py | 57 ++++++++++++++++++++++++++++++
 3 files changed, 128 insertions(+)
 create mode 100644 tests/conftest.py
 create mode 100644 tests/test_dataset.py
 create mode 100644 text2tensor/__init__.py

diff --git a/tests/conftest.py b/tests/conftest.py
new file mode 100644
index 0000000..fbdb524
--- /dev/null
+++ b/tests/conftest.py
@@ -0,0 +1,7 @@
+import pytest
+
+
+@pytest.fixture
+def setup_rng():
+    import random
+    random.seed(42)
diff --git a/tests/test_dataset.py b/tests/test_dataset.py
new file mode 100644
index 0000000..ca07728
--- /dev/null
+++ b/tests/test_dataset.py
@@ -0,0 +1,64 @@
+from collections.abc import Sequence
+
+from text2tensor import Dataset
+
+import pytest
+
+
+def test_init():
+    dat = Dataset(range(5))
+    assert isinstance(dat, Sequence)
+
+
+def test_init_samples_non_sequence():
+    with pytest.raises(TypeError) as exc:
+        Dataset(10)
+    assert '"samples" is not a sequence' in str(exc.value)
+
+
+@pytest.fixture
+def dataset():
+    return Dataset(range(5))
+
+
+def test_getitem(dataset):
+    for i in range(5):
+        assert dataset[i] == i
+
+
+def test_len(dataset):
+    assert len(dataset) == 5
+
+
+def test_shuffle(setup_rng, dataset):
+    before = list(dataset)
+    retval = dataset.shuffle()
+    assert list(dataset) != before
+    assert retval is dataset
+
+
+def test_batch(dataset):
+    minibatches = dataset.batch(2)
+    assert isinstance(minibatches, Sequence)
+    assert len(minibatches) == 3
+    assert minibatches[0] == range(0, 2)
+    assert minibatches[1] == range(2, 4)
+    assert minibatches[2] == range(4, 5)
+
+
+def test_batch_exactly(dataset):
+    minibatches = dataset.batch_exactly(2)
+    assert isinstance(minibatches, Sequence)
+    assert len(minibatches) == 2
+    assert minibatches[0] == range(0, 2)
+    assert minibatches[1] == range(2, 4)
+
+
+def test_batch_nonpositive_batch_size(dataset):
+    with pytest.raises(ValueError) as exc:
+        dataset.batch(0)
+    assert 'batch size must be greater than 0' in str(exc.value)
+
+    with pytest.raises(ValueError) as exc:
+        dataset.batch_exactly(0)
+    assert 'batch size must be greater than 0' in str(exc.value)
diff --git a/text2tensor/__init__.py b/text2tensor/__init__.py
new file mode 100644
index 0000000..348f825
--- /dev/null
+++ b/text2tensor/__init__.py
@@ -0,0 +1,57 @@
+from collections.abc import MutableSequence, Sequence
+import random
+
+
+class Dataset(Sequence):
+    def __init__(self, samples: Sequence) -> None:
+        if not isinstance(samples, Sequence):
+            raise TypeError('"samples" is not a sequence')
+
+        self._samples = samples
+
+    def __getitem__(self, index):
+        return self._samples[index]
+
+    def __len__(self) -> int:
+        return len(self._samples)
+
+    def shuffle(self) -> 'Dataset':
+        if isinstance(self._samples, list):
+            random.shuffle(self._samples)
+        elif isinstance(self._samples, MutableSequence):
+            self._shuffle_inplace()
+        else:
+            self._shuffle_copy()
+        return self
+
+    def batch(self, batch_size: int) -> list:
+        if batch_size <= 0:
+            raise ValueError('batch size must be greater than 0')
+
+        minibatches = []
+        for begin in range(0, len(self._samples), batch_size):
+            end = begin + batch_size
+            minibatches.append(self._samples[begin:end])
+        return minibatches
+
+    def batch_exactly(self, batch_size: int) -> list:
+        minibatches = self.batch(batch_size)
+        if len(self._samples) % batch_size != 0:
+            assert len(minibatches[-1]) < batch_size
+            minibatches = minibatches[:-1]
+        return minibatches
+
+    def _shuffle_inplace(self) -> None:
+        assert isinstance(self._samples, MutableSequence)
+        n = len(self._samples)
+        for i in range(n):
+            j = random.randrange(n)
+            temp = self._samples[i]
+            self._samples[i] = self._samples[j]
+            self._samples[j] = temp
+
+    def _shuffle_copy(self) -> None:
+        shuf_indices = list(range(len(self._samples)))
+        random.shuffle(shuf_indices)
+        shuf_samples = [self._samples[i] for i in shuf_indices]
+        self._samples = shuf_samples
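
A quick usage sketch for the Dataset class as of PATCH 004 (illustrative only, not part
of the patch series; the sample values are made up):

    from text2tensor import Dataset

    dataset = Dataset(list(range(10)))   # made-up samples
    dataset.shuffle()                    # in-place, since a list is mutable
    print(dataset.batch(3))              # four batches: sizes 3, 3, 3, 1
    print(dataset.batch_exactly(3))      # three batches; the short one is dropped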

From eb98947325fc5c5e5f9b7be0261c8c32eaf8ccf7 Mon Sep 17 00:00:00 2001
From: Kemal Kurniawan
Date: Fri, 25 Jan 2019 22:46:11 +0700
Subject: [PATCH 005/162] Implement StreamDataset class

---
 tests/test_stream_dataset.py | 84 ++++++++++++++++++++++++++++++++++++
 text2tensor/__init__.py      | 41 +++++++++++++++++++++++-
 2 files changed, 124 insertions(+), 1 deletion(-)
 create mode 100644 tests/test_stream_dataset.py

diff --git a/tests/test_stream_dataset.py b/tests/test_stream_dataset.py
new file mode 100644
index 0000000..d1784dd
--- /dev/null
+++ b/tests/test_stream_dataset.py
@@ -0,0 +1,84 @@
+from collections.abc import Iterable
+from itertools import takewhile
+
+import pytest
+
+from text2tensor import StreamDataset
+
+
+class Counter:
+    def __init__(self):
+        self._count = 0
+
+    def __iter__(self):
+        while True:
+            yield self._count
+            self._count += 1
+
+
+def test_init():
+    dat = StreamDataset(Counter())
+    assert isinstance(dat, Iterable)
+
+
+def test_init_stream_non_iterable():
+    with pytest.raises(TypeError) as exc:
+        StreamDataset(5)
+    assert '"stream" is not iterable' in str(exc.value)
+
+
+@pytest.fixture
+def stream_dataset():
+    return StreamDataset(Counter())
+
+
+@pytest.fixture
+def finite_stream_dataset():
+    return StreamDataset(range(11))
+
+
+def test_iter(stream_dataset):
+    it = takewhile(lambda x: x < 5, stream_dataset)
+    assert list(it) == list(range(5))
+
+
+def test_batch(finite_stream_dataset):
+    bsize = 2
+    minibatches = finite_stream_dataset.batch(bsize)
+    assert isinstance(minibatches, Iterable)
+
+    it = iter(minibatches)
+    assert next(it) == [0, 1]
+    assert next(it) == [2, 3]
+    assert next(it) == [4, 5]
+    while True:
+        try:
+            assert len(next(it)) <= bsize
+        except StopIteration:
+            break
+
+
+def test_batch_exactly(finite_stream_dataset):
+    bsize = 2
+    minibatches = finite_stream_dataset.batch_exactly(bsize)
+    assert isinstance(minibatches, Iterable)
+
+    it = iter(minibatches)
+    assert next(it) == [0, 1]
+    assert next(it) == [2, 3]
+    assert next(it) == [4, 5]
+    while True:
+        try:
+            assert len(next(it)) == bsize
+        except StopIteration:
+            break
+
+
+def test_batch_nonpositive_batch_size(stream_dataset):
+    with pytest.raises(ValueError) as exc:
+        stream_dataset.batch(0)
+    assert 'batch size must be greater than 0' in str(exc.value)
+
+    with pytest.raises(ValueError) as exc:
+        stream_dataset.batch_exactly(0)
+    assert 'batch size must be greater than 0' in str(exc.value)
diff --git a/text2tensor/__init__.py b/text2tensor/__init__.py
index 348f825..d188906 100644
--- a/text2tensor/__init__.py
+++ b/text2tensor/__init__.py
@@ -1,4 +1,4 @@
-from collections.abc import MutableSequence, Sequence
+from collections.abc import Iterable, Iterator, MutableSequence, Sequence
 import random
 
 
@@ -55,3 +55,42 @@ def _shuffle_copy(self) -> None:
         random.shuffle(shuf_indices)
         shuf_samples = [self._samples[i] for i in shuf_indices]
         self._samples = shuf_samples
+
+
+class StreamDataset(Iterable):
+    def __init__(self, stream: Iterable) -> None:
+        if not isinstance(stream, Iterable):
+            raise TypeError('"stream" is not iterable')
+
+        self._stream = stream
+
+    def __iter__(self) -> Iterator:
+        return iter(self._stream)
+
+    def batch(self, batch_size: int) -> Iterable:
+        return _Minibatches(self._stream, batch_size)
+
+    def batch_exactly(self, batch_size: int) -> Iterable:
+        return _Minibatches(self._stream, batch_size, drop=True)
+
+
+class _Minibatches(Iterable):
+    def __init__(self, stream: Iterable, bsize: int, drop: bool = False) -> None:
+        if bsize <= 0:
+            raise ValueError('batch size must be greater than 0')
+
+        self._stream = stream
+        self._bsize = bsize
+        self._drop = drop
+
+    def __iter__(self) -> Iterator:
+        it, exhausted = iter(self._stream), False
+        while not exhausted:
+            minibatch: list = []
+            while not exhausted and len(minibatch) < self._bsize:
+                try:
+                    minibatch.append(next(it))
+                except StopIteration:
+                    exhausted = True
+            if not self._drop or len(minibatch) == self._bsize:
+                yield minibatch

From 94db6d6a82a3368d62d5f2ba46035d2f598a9433 Mon Sep 17 00:00:00 2001
From: Kemal Kurniawan
Date: Fri, 25 Jan 2019 23:11:23 +0700
Subject: [PATCH 006/162] Make an abstract class for datasets

---
 text2tensor/__init__.py | 15 +++++++++++++--
 1 file changed, 13 insertions(+), 2 deletions(-)

diff --git a/text2tensor/__init__.py b/text2tensor/__init__.py
index d188906..dc08679 100644
--- a/text2tensor/__init__.py
+++ b/text2tensor/__init__.py
@@ -1,8 +1,19 @@
 from collections.abc import Iterable, Iterator, MutableSequence, Sequence
+import abc
 import random
 
 
-class Dataset(Sequence):
+class DatasetABC(Iterable, metaclass=abc.ABCMeta):
+    @abc.abstractmethod
+    def batch(self, batch_size: int) -> Iterable:
+        pass
+
+    @abc.abstractmethod
+    def batch_exactly(self, batch_size: int) -> Iterable:
+        pass
+
+
+class Dataset(DatasetABC, Sequence):
     def __init__(self, samples: Sequence) -> None:
         if not isinstance(samples, Sequence):
             raise TypeError('"samples" is not a sequence')
@@ -57,7 +68,7 @@ def _shuffle_copy(self) -> None:
         self._samples = shuf_samples
 
 
-class StreamDataset(Iterable):
+class StreamDataset(DatasetABC, Iterable):
     def __init__(self, stream: Iterable) -> None:
         if not isinstance(stream, Iterable):
             raise TypeError('"stream" is not iterable')

From a0dd1170044f115004371b2897b8a756df30104a Mon Sep 17 00:00:00 2001
From: Kemal Kurniawan
Date: Sat, 26 Jan 2019 05:07:35 +0700
Subject: [PATCH 007/162] Install pytest-cov

---
 requirements.txt | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/requirements.txt b/requirements.txt
index d4af228..c9fe43a 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,6 +1,7 @@
 atomicwrites==1.2.1
 attrs==18.2.0
 certifi==2018.11.29
+coverage==4.5.2
 flake8==3.6.0
 flake8-mypy==17.8.0
 mccabe==0.6.1
@@ -12,6 +13,7 @@ py==1.7.0
 pycodestyle==2.4.0
 pyflakes==2.0.0
 pytest==4.1.1
+pytest-cov==2.6.1
 six==1.12.0
 typed-ast==1.2.0
 yapf==0.25.0

From 1f809f86127e89998767f0e9ab838caf45961909 Mon Sep 17 00:00:00 2001
From: Kemal Kurniawan
Date: Sat, 26 Jan 2019 05:09:01 +0700
Subject: [PATCH 008/162] Add pytest config file

---
 pytest.ini | 2 ++
 1 file changed, 2 insertions(+)
 create mode 100644 pytest.ini

diff --git a/pytest.ini b/pytest.ini
new file mode 100644
index 0000000..fc62745
--- /dev/null
+++ b/pytest.ini
@@ -0,0 +1,2 @@
+[pytest]
+addopts = --cov --cov-report term-missing
\ No newline at end of file

From d3f16d81a06298aff4438ea53765ff60d454c915 Mon Sep 17 00:00:00 2001
From: Kemal Kurniawan
Date: Sat, 26 Jan 2019 05:09:15 +0700
Subject: [PATCH 009/162] Make tests for shuffle stronger

---
 tests/test_dataset.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/tests/test_dataset.py b/tests/test_dataset.py
index ca07728..4ac563c 100644
--- a/tests/test_dataset.py
+++ b/tests/test_dataset.py
@@ -33,8 +33,10 @@ def test_len(dataset):
 def test_shuffle(setup_rng, dataset):
     before = list(dataset)
     retval = dataset.shuffle()
-    assert list(dataset) != before
     assert retval is dataset
+    assert len(dataset) == len(before)
+    assert all(data in dataset for data in before)
+    assert list(dataset) != before
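
The _Minibatches iterator from PATCH 005 above groups a (possibly infinite) stream
lazily, one batch at a time. A minimal sketch of how it behaves, assuming a finite
generator as the stream (illustrative only; the values are made up):

    from text2tensor import StreamDataset

    def tokens():
        yield from [3, 1, 4, 1, 5, 9, 2, 6]  # made-up sample stream

    stream = StreamDataset(tokens())
    for minibatch in stream.batch(3):
        print(minibatch)  # [3, 1, 4], then [1, 5, 9], then [2, 6]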

From 9fd47866fb01c2056e5ff03f93cfb0eb597592ac Mon Sep 17 00:00:00 2001
From: Kemal Kurniawan
Date: Sat, 26 Jan 2019 05:16:01 +0700
Subject: [PATCH 010/162] Ignore abstract class from coverage

---
 text2tensor/__init__.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/text2tensor/__init__.py b/text2tensor/__init__.py
index dc08679..fbffa8e 100644
--- a/text2tensor/__init__.py
+++ b/text2tensor/__init__.py
@@ -3,7 +3,7 @@ import random
 
 
-class DatasetABC(Iterable, metaclass=abc.ABCMeta):
+class DatasetABC(Iterable, metaclass=abc.ABCMeta):  # pragma: no cover
     @abc.abstractmethod
     def batch(self, batch_size: int) -> Iterable:
         pass

From cea29653550ec01b4b0e59a7f7ae045b1178d028 Mon Sep 17 00:00:00 2001
From: Kemal Kurniawan
Date: Sat, 26 Jan 2019 05:19:35 +0700
Subject: [PATCH 011/162] Refactor a bit

---
 tests/test_dataset.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/test_dataset.py b/tests/test_dataset.py
index 4ac563c..ae7af46 100644
--- a/tests/test_dataset.py
+++ b/tests/test_dataset.py
@@ -35,7 +35,7 @@ def test_shuffle(setup_rng, dataset):
     retval = dataset.shuffle()
     assert retval is dataset
     assert len(dataset) == len(before)
-    assert all(data in dataset for data in before)
+    assert all(s in dataset for s in before)
     assert list(dataset) != before

From a0c355ff0e3cc7469de253217011371271998043 Mon Sep 17 00:00:00 2001
From: Kemal Kurniawan
Date: Sat, 26 Jan 2019 05:42:37 +0700
Subject: [PATCH 012/162] Make coverage 100%

---
 tests/test_dataset.py   | 39 ++++++++++++++++++++++-------------
 text2tensor/__init__.py |  4 +---
 2 files changed, 27 insertions(+), 16 deletions(-)

diff --git a/tests/test_dataset.py b/tests/test_dataset.py
index ae7af46..a28f2a8 100644
--- a/tests/test_dataset.py
+++ b/tests/test_dataset.py
@@ -18,7 +18,7 @@ def test_init_samples_non_sequence():
 
 @pytest.fixture
 def dataset():
-    return Dataset(range(5))
+    return Dataset(list(range(5)))
 
 
 def test_getitem(dataset):
@@ -30,30 +30,43 @@ def test_len(dataset):
     assert len(dataset) == 5
 
 
-def test_shuffle(setup_rng, dataset):
-    before = list(dataset)
-    retval = dataset.shuffle()
-    assert retval is dataset
-    assert len(dataset) == len(before)
-    assert all(s in dataset for s in before)
-    assert list(dataset) != before
+class TestShuffle:
+    @pytest.fixture
+    def tuple_dataset(self):
+        return Dataset(tuple(range(5)))
+
+    def assert_shuffle(self, dataset):
+        before = list(dataset)
+        retval = dataset.shuffle()
+        after = list(dataset)
+
+        assert retval is dataset
+        assert len(before) == len(after)
+        assert all(v in after for v in before)
+        assert before != after
+
+    def test_mutable_seq(self, setup_rng, dataset):
+        self.assert_shuffle(dataset)
+
+    def test_immutable_seq(self, setup_rng, tuple_dataset):
+        self.assert_shuffle(tuple_dataset)
 
 
 def test_batch(dataset):
     minibatches = dataset.batch(2)
     assert isinstance(minibatches, Sequence)
     assert len(minibatches) == 3
-    assert minibatches[0] == range(0, 2)
-    assert minibatches[1] == range(2, 4)
-    assert minibatches[2] == range(4, 5)
+    assert minibatches[0] == [0, 1]
+    assert minibatches[1] == [2, 3]
+    assert minibatches[2] == [4]
 
 
 def test_batch_exactly(dataset):
     minibatches = dataset.batch_exactly(2)
     assert isinstance(minibatches, Sequence)
     assert len(minibatches) == 2
-    assert minibatches[0] == range(0, 2)
-    assert minibatches[1] == range(2, 4)
+    assert minibatches[0] == [0, 1]
+    assert minibatches[1] == [2, 3]
 
 
 def test_batch_nonpositive_batch_size(dataset):
diff --git a/text2tensor/__init__.py b/text2tensor/__init__.py
index fbffa8e..9e97905 100644
--- a/text2tensor/__init__.py
+++ b/text2tensor/__init__.py
@@ -27,9 +27,7 @@ def __len__(self) -> int:
         return len(self._samples)
 
     def shuffle(self) -> 'Dataset':
-        if isinstance(self._samples, list):
-            random.shuffle(self._samples)
-        elif isinstance(self._samples, MutableSequence):
+        if isinstance(self._samples, MutableSequence):
             self._shuffle_inplace()
         else:
             self._shuffle_copy()
         return self
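
PATCH 012 above makes Dataset.shuffle exercise both code paths: mutable sequences are
shuffled in place, while immutable ones are copied first. A sketch of the observable
difference (illustrative only):

    from text2tensor import Dataset

    mutable = Dataset([0, 1, 2, 3, 4])
    mutable.shuffle()        # the underlying list itself is reordered

    immutable = Dataset((0, 1, 2, 3, 4))
    immutable.shuffle()      # the tuple is left intact; a shuffled mutable copy
                             # replaces it, so later shuffles happen in place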

From 5dc282077ed14cbbda249111344ec08ded5c4490 Mon Sep 17 00:00:00 2001
From: Kemal Kurniawan
Date: Sat, 26 Jan 2019 05:45:51 +0700
Subject: [PATCH 013/162] Fix import order

---
 tests/test_dataset.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/tests/test_dataset.py b/tests/test_dataset.py
index a28f2a8..c05c52a 100644
--- a/tests/test_dataset.py
+++ b/tests/test_dataset.py
@@ -1,9 +1,9 @@
 from collections.abc import Sequence
 
-from text2tensor import Dataset
-
 import pytest
 
+from text2tensor import Dataset
+
 
 def test_init():
     dat = Dataset(range(5))

From 6c9f557f8b87fbe9dedaaf333b7dc557ff48f653 Mon Sep 17 00:00:00 2001
From: Kemal Kurniawan
Date: Sat, 26 Jan 2019 05:55:20 +0700
Subject: [PATCH 014/162] Use Counter as finite stream

---
 tests/test_stream_dataset.py | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/tests/test_stream_dataset.py b/tests/test_stream_dataset.py
index d1784dd..a242d57 100644
--- a/tests/test_stream_dataset.py
+++ b/tests/test_stream_dataset.py
@@ -7,13 +7,16 @@
 
 
 class Counter:
-    def __init__(self):
+    def __init__(self, limit=None):
         self._count = 0
+        self._limit = limit
 
     def __iter__(self):
         while True:
             yield self._count
             self._count += 1
+            if self._limit is not None and self._count >= self._limit:
+                break
 
 
 def test_init():
@@ -34,7 +37,7 @@ def stream_dataset():
 
 @pytest.fixture
 def finite_stream_dataset():
-    return StreamDataset(range(11))
+    return StreamDataset(Counter(limit=11))

From 77f8ea6b5724d6e4f6aaf2c94eaa212b17972a02 Mon Sep 17 00:00:00 2001
From: Kemal Kurniawan
Date: Sat, 26 Jan 2019 06:23:55 +0700
Subject: [PATCH 015/162] Add docstrings

---
 text2tensor/__init__.py | 59 +++++++++++++++++++++++++++++++++++
 1 file changed, 59 insertions(+)

diff --git a/text2tensor/__init__.py b/text2tensor/__init__.py
index 9e97905..c683e18 100644
--- a/text2tensor/__init__.py
+++ b/text2tensor/__init__.py
@@ -14,6 +14,12 @@ def batch_exactly(self, batch_size: int) -> Iterable:
 
 
 class Dataset(DatasetABC, Sequence):
+    """A dataset that fits in memory (no streaming).
+
+    Args:
+        samples: Sequence of samples the dataset should contain.
+    """
+
     def __init__(self, samples: Sequence) -> None:
         if not isinstance(samples, Sequence):
             raise TypeError('"samples" is not a sequence')
@@ -27,6 +33,15 @@ def __len__(self) -> int:
         return len(self._samples)
 
     def shuffle(self) -> 'Dataset':
+        """Shuffle the dataset.
+
+        This method shuffles in-place if ``samples`` is a mutable sequence.
+        Otherwise, a copy is made and then shuffled. This copy is a mutable
+        sequence, so subsequent shuffling will be done in-place.
+
+        Returns:
+            The dataset object itself (useful for chaining).
+        """
         if isinstance(self._samples, MutableSequence):
             self._shuffle_inplace()
         else:
@@ -34,6 +49,14 @@ def shuffle(self) -> 'Dataset':
         return self
 
     def batch(self, batch_size: int) -> list:
+        """Group the samples in the dataset into batches.
+
+        Args:
+            batch_size: Maximum number of samples in each batch.
+
+        Returns:
+            The list of batches.
+        """
         if batch_size <= 0:
             raise ValueError('batch size must be greater than 0')
 
@@ -44,6 +67,17 @@ def batch(self, batch_size: int) -> list:
         return minibatches
 
     def batch_exactly(self, batch_size: int) -> list:
+        """Group the samples in the dataset into batches of exact size.
+
+        If the length of ``samples`` is not divisible by ``batch_size``, the last
+        batch (whose length is less than ``batch_size``) is dropped.
+
+        Args:
+            batch_size: Number of samples in each batch.
+
+        Returns:
+            The list of batches.
+        """
         minibatches = self.batch(batch_size)
         if len(self._samples) % batch_size != 0:
             assert len(minibatches[-1]) < batch_size
@@ -67,6 +101,12 @@ def _shuffle_copy(self) -> None:
 
 
 class StreamDataset(DatasetABC, Iterable):
+    """A dataset that streams its samples.
+
+    Args:
+        stream: Stream of examples the dataset should stream from.
+    """
+
     def __init__(self, stream: Iterable) -> None:
         if not isinstance(stream, Iterable):
             raise TypeError('"stream" is not iterable')
@@ -77,9 +117,28 @@ def __iter__(self) -> Iterator:
         return iter(self._stream)
 
     def batch(self, batch_size: int) -> Iterable:
+        """Group the samples in the dataset into batches.
+
+        Args:
+            batch_size: Maximum number of samples in each batch.
+
+        Returns:
+            The iterable of batches.
+        """
         return _Minibatches(self._stream, batch_size)
 
     def batch_exactly(self, batch_size: int) -> Iterable:
+        """Group the samples in the dataset into batches of exact size.
+
+        If the length of ``samples`` is not divisible by ``batch_size``, the last
+        batch (whose length is less than ``batch_size``) is dropped.
+
+        Args:
+            batch_size: Number of samples in each batch.
+
+        Returns:
+            The iterable of batches.
+        """
         return _Minibatches(self._stream, batch_size, drop=True)

From eabde787c5ed2bc63168d725a6a61e6c95c937d8 Mon Sep 17 00:00:00 2001
From: Kemal Kurniawan
Date: Sat, 26 Jan 2019 06:26:30 +0700
Subject: [PATCH 016/162] Rename "minibatches" to "batches"

---
 text2tensor/__init__.py | 30 +++++++++++++++---------------
 1 file changed, 15 insertions(+), 15 deletions(-)

diff --git a/text2tensor/__init__.py b/text2tensor/__init__.py
index c683e18..6085e05 100644
--- a/text2tensor/__init__.py
+++ b/text2tensor/__init__.py
@@ -60,11 +60,11 @@ def batch(self, batch_size: int) -> list:
         if batch_size <= 0:
             raise ValueError('batch size must be greater than 0')
 
-        minibatches = []
+        batches = []
         for begin in range(0, len(self._samples), batch_size):
             end = begin + batch_size
-            minibatches.append(self._samples[begin:end])
-        return minibatches
+            batches.append(self._samples[begin:end])
+        return batches
 
     def batch_exactly(self, batch_size: int) -> list:
         """Group the samples in the dataset into batches of exact size.
@@ -78,11 +78,11 @@ def batch_exactly(self, batch_size: int) -> list:
         Returns:
             The list of batches.
         """
-        minibatches = self.batch(batch_size)
+        batches = self.batch(batch_size)
         if len(self._samples) % batch_size != 0:
-            assert len(minibatches[-1]) < batch_size
-            minibatches = minibatches[:-1]
-        return minibatches
+            assert len(batches[-1]) < batch_size
+            batches = batches[:-1]
+        return batches
 
     def _shuffle_inplace(self) -> None:
         assert isinstance(self._samples, MutableSequence)
@@ -125,7 +125,7 @@ def batch(self, batch_size: int) -> Iterable:
         Returns:
             The iterable of batches.
         """
-        return _Minibatches(self._stream, batch_size)
+        return _Batches(self._stream, batch_size)
 
     def batch_exactly(self, batch_size: int) -> Iterable:
@@ -139,10 +139,10 @@ def batch_exactly(self, batch_size: int) -> Iterable:
         Returns:
             The iterable of batches.
         """
-        return _Minibatches(self._stream, batch_size, drop=True)
+        return _Batches(self._stream, batch_size, drop=True)
 
 
-class _Minibatches(Iterable):
+class _Batches(Iterable):
     def __init__(self, stream: Iterable, bsize: int, drop: bool = False) -> None:
         if bsize <= 0:
             raise ValueError('batch size must be greater than 0')
@@ -154,11 +154,11 @@ def __init__(self, stream: Iterable, bsize: int, drop: bool = False) -> None:
     def __iter__(self) -> Iterator:
         it, exhausted = iter(self._stream), False
         while not exhausted:
-            minibatch: list = []
-            while not exhausted and len(minibatch) < self._bsize:
+            batch: list = []
+            while not exhausted and len(batch) < self._bsize:
                 try:
-                    minibatch.append(next(it))
+                    batch.append(next(it))
                 except StopIteration:
                     exhausted = True
-            if not self._drop or len(minibatch) == self._bsize:
-                yield minibatch
+            if not self._drop or len(batch) == self._bsize:
+                yield batch

From 26ef7ce6693bf9b5ca9c9930b94990fdb810327f Mon Sep 17 00:00:00 2001
From: Kemal Kurniawan
Date: Sat, 26 Jan 2019 06:34:05 +0700
Subject: [PATCH 017/162] Make type hints a bit more specific

---
 text2tensor/__init__.py | 17 +++++++++--------
 1 file changed, 9 insertions(+), 8 deletions(-)

diff --git a/text2tensor/__init__.py b/text2tensor/__init__.py
index 6085e05..945ac86 100644
--- a/text2tensor/__init__.py
+++ b/text2tensor/__init__.py
@@ -1,15 +1,16 @@
 from collections.abc import Iterable, Iterator, MutableSequence, Sequence
 import abc
 import random
+import typing as ty
 
 
 class DatasetABC(Iterable, metaclass=abc.ABCMeta):  # pragma: no cover
     @abc.abstractmethod
-    def batch(self, batch_size: int) -> Iterable:
+    def batch(self, batch_size: int) -> ty.Iterable[Sequence]:
         pass
 
     @abc.abstractmethod
-    def batch_exactly(self, batch_size: int) -> Iterable:
+    def batch_exactly(self, batch_size: int) -> ty.Iterable[Sequence]:
        pass
@@ -48,7 +49,7 @@ def shuffle(self) -> 'Dataset':
         return self
 
-    def batch(self, batch_size: int) -> list:
+    def batch(self, batch_size: int) -> ty.List[Sequence]:
         """Group the samples in the dataset into batches.
 
-    def batch_exactly(self, batch_size: int) -> list:
+    def batch_exactly(self, batch_size: int) -> ty.List[Sequence]:
         """Group the samples in the dataset into batches of exact size.
 
-    def batch(self, batch_size: int) -> Iterable:
+    def batch(self, batch_size: int) -> ty.Iterable[Sequence]:
         """Group the samples in the dataset into batches.
 
-    def batch_exactly(self, batch_size: int) -> Iterable:
+    def batch_exactly(self, batch_size: int) -> ty.Iterable[Sequence]:
         """Group the samples in the dataset into batches of exact size.
 
-class _Batches(Iterable):
+class _Batches(ty.Iterable[Sequence]):
     def __init__(self, stream: Iterable, bsize: int, drop: bool = False) -> None:
         if bsize <= 0:
             raise ValueError('batch size must be greater than 0')
@@ -151,7 +152,7 @@ def __init__(self, stream: Iterable, bsize: int, drop: bool = False) -> None:
         self._bsize = bsize
         self._drop = drop
 
-    def __iter__(self) -> Iterator:
+    def __iter__(self) -> ty.Iterator[Sequence]:
         it, exhausted = iter(self._stream), False
         while not exhausted:
             batch: list = []

From b8be8b32bed95fdfc88781039dce687272d81f01 Mon Sep 17 00:00:00 2001
From: Kemal Kurniawan
Date: Sat, 26 Jan 2019 07:14:11 +0700
Subject: [PATCH 018/162] Make the simplest kind of dataset: iterable of ints

---
 text2tensor/__init__.py | 46 ++++++++++++++++++++++-------------------
 1 file changed, 25 insertions(+), 21 deletions(-)

diff --git a/text2tensor/__init__.py b/text2tensor/__init__.py
index 945ac86..5e7b13c 100644
--- a/text2tensor/__init__.py
+++ b/text2tensor/__init__.py
@@ -1,28 +1,32 @@
-from collections.abc import Iterable, Iterator, MutableSequence, Sequence
+from collections.abc import \
+    Iterable as IterableABC, MutableSequence as MutableSequenceABC, Sequence as SequenceABC
+from typing import Iterable, Iterator, List, Sequence
 import abc
 import random
-import typing as ty
 
 
-class DatasetABC(Iterable, metaclass=abc.ABCMeta):  # pragma: no cover
+class DatasetABC(Iterable[int], metaclass=abc.ABCMeta):  # pragma: no cover
     @abc.abstractmethod
-    def batch(self, batch_size: int) -> ty.Iterable[Sequence]:
+    def batch(self, batch_size: int) -> Iterable[Sequence[int]]:
         pass
 
     @abc.abstractmethod
-    def batch_exactly(self, batch_size: int) -> ty.Iterable[Sequence]:
+    def batch_exactly(self, batch_size: int) -> Iterable[Sequence[int]]:
         pass
 
 
-class Dataset(DatasetABC, Sequence):
+Batch = Sequence[int]
+
+
+class Dataset(DatasetABC, Sequence[int]):
     """A dataset that fits in memory (no streaming).
 
     Args:
         samples: Sequence of samples the dataset should contain.
     """
 
-    def __init__(self, samples: Sequence) -> None:
-        if not isinstance(samples, Sequence):
+    def __init__(self, samples: Sequence[int]) -> None:
+        if not isinstance(samples, SequenceABC):
             raise TypeError('"samples" is not a sequence')
 
         self._samples = samples
@@ -43,13 +47,13 @@ def shuffle(self) -> 'Dataset':
         Returns:
             The dataset object itself (useful for chaining).
         """
-        if isinstance(self._samples, MutableSequence):
+        if isinstance(self._samples, MutableSequenceABC):
             self._shuffle_inplace()
         else:
             self._shuffle_copy()
         return self
 
-    def batch(self, batch_size: int) -> ty.List[Sequence]:
+    def batch(self, batch_size: int) -> List[Batch]:
         """Group the samples in the dataset into batches.
 
-    def batch_exactly(self, batch_size: int) -> ty.List[Sequence]:
+    def batch_exactly(self, batch_size: int) -> List[Batch]:
         """Group the samples in the dataset into batches of exact size.
 
     def _shuffle_inplace(self) -> None:
-        assert isinstance(self._samples, MutableSequence)
+        assert isinstance(self._samples, MutableSequenceABC)
         n = len(self._samples)
         for i in range(n):
             j = random.randrange(n)
 
-class StreamDataset(DatasetABC, Iterable):
+class StreamDataset(DatasetABC, Iterable[int]):
     """A dataset that streams its samples.
 
     Args:
         stream: Stream of examples the dataset should stream from.
     """
 
-    def __init__(self, stream: Iterable) -> None:
-        if not isinstance(stream, Iterable):
+    def __init__(self, stream: Iterable[int]) -> None:
+        if not isinstance(stream, IterableABC):
             raise TypeError('"stream" is not iterable')
 
         self._stream = stream
 
-    def __iter__(self) -> Iterator:
+    def __iter__(self) -> Iterator[int]:
         return iter(self._stream)
 
-    def batch(self, batch_size: int) -> ty.Iterable[Sequence]:
+    def batch(self, batch_size: int) -> Iterable[Batch]:
         """Group the samples in the dataset into batches.
 
-    def batch_exactly(self, batch_size: int) -> ty.Iterable[Sequence]:
+    def batch_exactly(self, batch_size: int) -> Iterable[Batch]:
         """Group the samples in the dataset into batches of exact size.
 
-class _Batches(ty.Iterable[Sequence]):
-    def __init__(self, stream: Iterable, bsize: int, drop: bool = False) -> None:
+class _Batches(Iterable[Batch]):
+    def __init__(self, stream: Iterable[int], bsize: int, drop: bool = False) -> None:
         if bsize <= 0:
             raise ValueError('batch size must be greater than 0')
 
@@ -152,7 +156,7 @@ def __init__(self, stream: Iterable, bsize: int, drop: bool = False) -> None:
         self._bsize = bsize
         self._drop = drop
 
-    def __iter__(self) -> ty.Iterator[Sequence]:
+    def __iter__(self) -> Iterator[Batch]:
         it, exhausted = iter(self._stream), False
         while not exhausted:
             batch: list = []

From f7c85f0cbe707d4e3afcd010dadd8d5cabd60d2b Mon Sep 17 00:00:00 2001
From: Kemal Kurniawan
Date: Sat, 26 Jan 2019 09:01:26 +0700
Subject: [PATCH 019/162] Implement Batches class

---
 tests/conftest.py       |  7 ++++++
 tests/test_batches.py   | 50 +++++++++++++++++++++++++++++++++
 tests/test_dataset.py   |  7 ++----
 text2tensor/__init__.py | 48 +++++++++++++++++++++++++++++++
 4 files changed, 107 insertions(+), 5 deletions(-)
 create mode 100644 tests/test_batches.py

diff --git a/tests/conftest.py b/tests/conftest.py
index fbdb524..55426b3 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,7 +1,14 @@
 import pytest
 
+from text2tensor import Dataset
+
 
 @pytest.fixture
 def setup_rng():
     import random
     random.seed(42)
+
+
+@pytest.fixture
+def dataset():
+    return Dataset(list(range(5)))
diff --git a/tests/test_batches.py b/tests/test_batches.py
new file mode 100644
index 0000000..9a61036
--- /dev/null
+++ b/tests/test_batches.py
@@ -0,0 +1,50 @@
+from collections.abc import Sequence
+
+import pytest
+import torch
+
+from text2tensor import Batches
+
+
+def test_init(dataset):
+    bs = Batches(dataset, 2)
+    assert bs.batch_size == 2
+    assert isinstance(bs, Sequence)
+    assert len(bs) == 3
+    assert bs[0] == [0, 1]
+    assert bs[1] == [2, 3]
+    assert bs[2] == [4]
+
+
+def test_init_kwargs(dataset):
+    bs = Batches(dataset, 2, drop_last=True)
+    assert len(bs) == 2
+    assert bs[0] == [0, 1]
+    assert bs[1] == [2, 3]
+
+
+@pytest.fixture
+def batches(dataset):
+    return Batches(dataset, 2)
+
+
+def test_getitem_negative_index(batches):
+    assert batches[-1] == [4]
+
+
+def test_getitem_index_error(batches):
+    with pytest.raises(IndexError) as exc:
+        batches[len(batches)]
+    assert 'index out of range' in str(exc.value)
+
+
+def test_to_tensors(batches):
+    ts = batches.to_tensors()
+    assert isinstance(ts, Sequence)
+    assert len(ts) == len(batches)
+    for i in range(len(ts)):
+        t, b = ts[i], batches[i]
+        assert torch.is_tensor(t)
+        assert t.dtype == torch.long
+        assert t.dim() == 1
+        assert t.size(0) == len(b)
diff --git a/tests/test_dataset.py b/tests/test_dataset.py
index c05c52a..94a10f7 100644
--- a/tests/test_dataset.py
+++ b/tests/test_dataset.py
@@ -16,16 +16,13 @@ def test_init_samples_non_sequence():
     assert '"samples" is not a sequence' in str(exc.value)
 
 
-@pytest.fixture
-def dataset():
-    return Dataset(list(range(5)))
-
-
+# TODO put this inside test_init
 def test_getitem(dataset):
     for i in range(5):
         assert dataset[i] == i
 
 
+# TODO put this inside test_init
 def test_len(dataset):
     assert len(dataset) == 5
diff --git a/text2tensor/__init__.py b/text2tensor/__init__.py
index 5e7b13c..3994d20 100644
--- a/text2tensor/__init__.py
+++ b/text2tensor/__init__.py
@@ -4,6 +4,8 @@ import abc
 import random
 
+import torch
+
 
 class DatasetABC(Iterable[int], metaclass=abc.ABCMeta):  # pragma: no cover
     @abc.abstractmethod
@@ -167,3 +169,49 @@ def __iter__(self) -> Iterator[Batch]:
                     exhausted = True
             if not self._drop or len(batch) == self._bsize:
                 yield batch
+
+
+class Batches(Sequence[Batch]):
+    """A class to represent a sequence of minibatches.
+
+    Args:
+        dataset: Dataset to make batches from.
+        batch_size: Maximum number of samples in each batch.
+        drop_last (optional): Whether to drop the last batch when ``batch_size`` does not
+            evenly divide the length of ``dataset``.
+    """
+
+    def __init__(self, dataset: Dataset, batch_size: int, drop_last: bool = False) -> None:
+        self._dataset = dataset
+        self._bsize = batch_size
+        self._drop = drop_last
+
+    @property
+    def batch_size(self) -> int:
+        return self._bsize
+
+    def __getitem__(self, index):
+        if index >= len(self):
+            raise IndexError('index out of range')
+        if index < 0:
+            index += len(self)
+
+        begin = index * self._bsize
+        end = begin + self._bsize
+        return self._dataset[begin:end]
+
+    def __len__(self) -> int:
+        q, r = divmod(len(self._dataset), self._bsize)
+        return q + (1 if r > 0 and not self._drop else 0)
+
+    def to_tensors(self) -> List[torch.LongTensor]:
+        """Convert each minibatch into a tensor.
+
+        Returns:
+            The list of tensors.
+        """
+        ts = []
+        for b in self:
+            t = torch.tensor(b, dtype=torch.long)
+            ts.append(t)
+        return ts
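
Batches.to_tensors from PATCH 019 above converts each minibatch of ints into a 1-D
torch.LongTensor. A minimal sketch (illustrative only):

    import torch
    from text2tensor import Batches, Dataset

    batches = Batches(Dataset(list(range(5))), 2)
    tensors = batches.to_tensors()
    assert torch.equal(tensors[0], torch.tensor([0, 1]))
    assert tensors[-1].tolist() == [4]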
+ """ + ts = [] + for b in self: + t = torch.tensor(b, dtype=torch.long) + ts.append(t) + return ts From 07655df4841abee3fd57a1afafbf2057f17ebf43 Mon Sep 17 00:00:00 2001 From: Kemal Kurniawan Date: Sat, 26 Jan 2019 10:33:53 +0700 Subject: [PATCH 020/162] Test if batch size is nonpositive --- tests/test_batches.py | 35 ++++++++++++++++++++--------------- text2tensor/__init__.py | 24 ++++++++++++++++++++++++ 2 files changed, 44 insertions(+), 15 deletions(-) diff --git a/tests/test_batches.py b/tests/test_batches.py index 9a61036..49850d4 100644 --- a/tests/test_batches.py +++ b/tests/test_batches.py @@ -6,21 +6,26 @@ from text2tensor import Batches -def test_init(dataset): - bs = Batches(dataset, 2) - assert bs.batch_size == 2 - assert isinstance(bs, Sequence) - assert len(bs) == 3 - assert bs[0] == [0, 1] - assert bs[1] == [2, 3] - assert bs[2] == [4] - - -def test_init_kwargs(dataset): - bs = Batches(dataset, 2, drop_last=True) - assert len(bs) == 2 - assert bs[0] == [0, 1] - assert bs[1] == [2, 3] +class TestInit: + def test_ok(self, dataset): + bs = Batches(dataset, 2) + assert bs.batch_size == 2 + assert isinstance(bs, Sequence) + assert len(bs) == 3 + assert bs[0] == [0, 1] + assert bs[1] == [2, 3] + assert bs[2] == [4] + + def test_kwargs(self, dataset): + bs = Batches(dataset, 2, drop_last=True) + assert len(bs) == 2 + assert bs[0] == [0, 1] + assert bs[1] == [2, 3] + + def test_nonpositive_batch_size(self, dataset): + with pytest.raises(ValueError) as exc: + Batches(dataset, 0) + assert 'batch size must be greater than 0' in str(exc.value) @pytest.fixture diff --git a/text2tensor/__init__.py b/text2tensor/__init__.py index 3994d20..8d998e5 100644 --- a/text2tensor/__init__.py +++ b/text2tensor/__init__.py @@ -182,6 +182,9 @@ class Batches(Sequence[Batch]): """ def __init__(self, dataset: Dataset, batch_size: int, drop_last: bool = False) -> None: + if batch_size <= 0: + raise ValueError('batch size must be greater than 0') + self._dataset = dataset self._bsize = batch_size self._drop = drop_last @@ -215,3 +218,24 @@ def to_tensors(self) -> List[torch.LongTensor]: t = torch.tensor(b, dtype=torch.long) ts.append(t) return ts + + +class StreamBatches(Iterable[Batch]): + def __init__(self, dataset: StreamDataset, batch_size: int) -> None: + self._dataset = dataset + self._bsize = batch_size + + @property + def batch_size(self) -> int: + return self._bsize + + def __iter__(self) -> Iterator[Batch]: + it, exhausted = iter(self._dataset), False + while not exhausted: + batch: list = [] + while not exhausted and len(batch) < self._bsize: + try: + batch.append(next(it)) + except StopIteration: + exhausted = True + yield batch From 5f417d5ccaf593edee12dae79742f543e10e927e Mon Sep 17 00:00:00 2001 From: Kemal Kurniawan Date: Sat, 26 Jan 2019 11:27:42 +0700 Subject: [PATCH 021/162] Implement StreamBatches class fully --- tests/conftest.py | 39 +++++++++++++++++++- tests/test_stream_batches.py | 70 ++++++++++++++++++++++++++++++++++++ tests/test_stream_dataset.py | 27 ++------------ text2tensor/__init__.py | 20 +++++++++-- 4 files changed, 128 insertions(+), 28 deletions(-) create mode 100644 tests/test_stream_batches.py diff --git a/tests/conftest.py b/tests/conftest.py index 55426b3..062e4c0 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,6 +1,22 @@ import pytest -from text2tensor import Dataset +from text2tensor import Dataset, StreamDataset + + +class Counter: + def __init__(self, limit=None): + self._count = 0 + self._limit = limit + + def reset(self): + self._count = 0 + + 
def __iter__(self): + while True: + yield self._count + self._count += 1 + if self._limit is not None and self._count >= self._limit: + break @pytest.fixture @@ -12,3 +28,24 @@ def setup_rng(): @pytest.fixture def dataset(): return Dataset(list(range(5))) + + +@pytest.fixture +def counter(): + c = Counter() + yield Counter() + c.reset() + + +@pytest.fixture +def stream_dataset(): + c = Counter() + yield StreamDataset(c) + c.reset() + + +@pytest.fixture +def finite_stream_dataset(): + c = Counter(limit=11) + yield StreamDataset(c) + c.reset() diff --git a/tests/test_stream_batches.py b/tests/test_stream_batches.py new file mode 100644 index 0000000..d75ab77 --- /dev/null +++ b/tests/test_stream_batches.py @@ -0,0 +1,70 @@ +from collections.abc import Iterable +from itertools import takewhile + +import pytest +import torch + +from text2tensor import StreamBatches + + +class TestInit: + def test_init(self, finite_stream_dataset): + bs = StreamBatches(finite_stream_dataset, 2) + assert bs.batch_size == 2 + assert isinstance(bs, Iterable) + it = iter(bs) + assert next(it) == [0, 1] + assert next(it) == [2, 3] + assert next(it) == [4, 5] + assert next(it) == [6, 7] + assert next(it) == [8, 9] + assert next(it) == [10] + with pytest.raises(StopIteration): + next(it) + + def test_kwargs(self, finite_stream_dataset): + bs = StreamBatches(finite_stream_dataset, 2, drop_last=True) + it = iter(bs) + assert next(it) == [0, 1] + assert next(it) == [2, 3] + assert next(it) == [4, 5] + assert next(it) == [6, 7] + assert next(it) == [8, 9] + with pytest.raises(StopIteration): + next(it) + + def test_nonpositive_batch_size(self, finite_stream_dataset): + with pytest.raises(ValueError) as exc: + StreamBatches(finite_stream_dataset, 0) + assert 'batch size must be greater than 0' in str(exc.value) + + +@pytest.fixture +def stream_batches(stream_dataset): + return StreamBatches(stream_dataset, 2) + + +@pytest.fixture +def finite_stream_batches(finite_stream_dataset): + return StreamBatches(finite_stream_dataset, 2) + + +def test_to_tensors(stream_batches): + ts = stream_batches.to_tensors() + assert isinstance(ts, Iterable) + + bs = takewhile(lambda b: sum(b) < 30, stream_batches) + for t, b in zip(ts, bs): + assert torch.is_tensor(t) + assert t.dtype == torch.long + assert t.dim() == 1 + assert t.size(0) == len(b) + + +def test_to_tensors_returns_iterable(finite_stream_batches): + ts = finite_stream_batches.to_tensors() + print(list(ts)) + ts_lst1 = list(ts) + ts_lst2 = list(ts) + assert len(ts_lst1) == len(ts_lst2) + assert len(ts_lst2) > 0 diff --git a/tests/test_stream_dataset.py b/tests/test_stream_dataset.py index a242d57..b89ee7a 100644 --- a/tests/test_stream_dataset.py +++ b/tests/test_stream_dataset.py @@ -6,21 +6,8 @@ from text2tensor import StreamDataset -class Counter: - def __init__(self, limit=None): - self._count = 0 - self._limit = limit - - def __iter__(self): - while True: - yield self._count - self._count += 1 - if self._limit is not None and self._count >= self._limit: - break - - -def test_init(): - dat = StreamDataset(Counter()) +def test_init(counter): + dat = StreamDataset(counter) assert isinstance(dat, Iterable) @@ -30,16 +17,6 @@ def test_init_stream_non_iterable(): assert '"stream" is not iterable' in str(exc.value) -@pytest.fixture -def stream_dataset(): - return StreamDataset(Counter()) - - -@pytest.fixture -def finite_stream_dataset(): - return StreamDataset(Counter(limit=11)) - - def test_iter(stream_dataset): it = takewhile(lambda x: x < 5, stream_dataset) assert list(it) == 
list(range(5)) diff --git a/text2tensor/__init__.py b/text2tensor/__init__.py index 8d998e5..90e1ff0 100644 --- a/text2tensor/__init__.py +++ b/text2tensor/__init__.py @@ -221,9 +221,14 @@ def to_tensors(self) -> List[torch.LongTensor]: class StreamBatches(Iterable[Batch]): - def __init__(self, dataset: StreamDataset, batch_size: int) -> None: + def __init__( + self, dataset: StreamDataset, batch_size: int, drop_last: bool = False) -> None: + if batch_size <= 0: + raise ValueError('batch size must be greater than 0') + self._dataset = dataset self._bsize = batch_size + self._drop = drop_last @property def batch_size(self) -> int: @@ -238,4 +243,15 @@ def __iter__(self) -> Iterator[Batch]: batch.append(next(it)) except StopIteration: exhausted = True - yield batch + if not self._drop or len(batch) == self._bsize: + yield batch + + def to_tensors(self) -> Iterable[torch.LongTensor]: + return self._StreamTensors(self) + + class _StreamTensors(Iterable[torch.LongTensor]): + def __init__(self, bs: 'StreamBatches') -> None: + self._bs = bs + + def __iter__(self) -> Iterator[torch.LongTensor]: + yield from (torch.tensor(b, dtype=torch.long) for b in self._bs) From c57ec2b7d24064c63956edaa4a3cc6fdf81f1141 Mon Sep 17 00:00:00 2001 From: Kemal Kurniawan Date: Sat, 26 Jan 2019 11:29:11 +0700 Subject: [PATCH 022/162] Refactor a bit --- tests/conftest.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 062e4c0..416c02d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -38,10 +38,8 @@ def counter(): @pytest.fixture -def stream_dataset(): - c = Counter() - yield StreamDataset(c) - c.reset() +def stream_dataset(counter): + return StreamDataset(counter) @pytest.fixture From 14d4e9f5ec7b0aa3b5b0d70ab32ecd78c7e9300b Mon Sep 17 00:00:00 2001 From: Kemal Kurniawan Date: Sat, 26 Jan 2019 11:31:42 +0700 Subject: [PATCH 023/162] Add docstrings --- text2tensor/__init__.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/text2tensor/__init__.py b/text2tensor/__init__.py index 90e1ff0..83f02c4 100644 --- a/text2tensor/__init__.py +++ b/text2tensor/__init__.py @@ -221,6 +221,15 @@ def to_tensors(self) -> List[torch.LongTensor]: class StreamBatches(Iterable[Batch]): + """A class to represent an iterable of minibatches. + + Args: + dataset: Dataset to make batches from. + batch_size: Maximum number of samples in each batch. + drop_last (optional): Whether to drop the last batch when ``batch_size`` does not + evenly divide the length of ``dataset``. + """ + def __init__( self, dataset: StreamDataset, batch_size: int, drop_last: bool = False) -> None: if batch_size <= 0: @@ -247,6 +256,11 @@ def __iter__(self) -> Iterator[Batch]: yield batch def to_tensors(self) -> Iterable[torch.LongTensor]: + """Convert each minibatch into a tensor. + + Returns: + The iterable of tensors. 
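
StreamBatches.to_tensors (PATCH 021 above) deliberately returns a small iterable
wrapper rather than a generator, so the conversion can be restarted as long as the
underlying stream can be. A sketch of the idea (illustrative only):

    from text2tensor import StreamBatches, StreamDataset

    batches = StreamBatches(StreamDataset(range(6)), 2)
    tensors = batches.to_tensors()   # no work happens yet
    for t in tensors:                # tensors are produced lazily here
        print(t.tolist())            # [0, 1], [2, 3], [4, 5]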
+ """ return self._StreamTensors(self) class _StreamTensors(Iterable[torch.LongTensor]): From b4507c0aceaa3e966be81b140871a83ff1d44a6b Mon Sep 17 00:00:00 2001 From: Kemal Kurniawan Date: Sat, 26 Jan 2019 11:36:40 +0700 Subject: [PATCH 024/162] Create an abstract class for batches --- text2tensor/__init__.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/text2tensor/__init__.py b/text2tensor/__init__.py index 83f02c4..007c65e 100644 --- a/text2tensor/__init__.py +++ b/text2tensor/__init__.py @@ -171,7 +171,18 @@ def __iter__(self) -> Iterator[Batch]: yield batch -class Batches(Sequence[Batch]): +class BatchesABC(Iterable[Batch], metaclass=abc.ABCMeta): # pragma: no cover + @property + @abc.abstractmethod + def batch_size(self) -> int: + pass + + @abc.abstractmethod + def to_tensors(self) -> Iterable[torch.LongTensor]: + pass + + +class Batches(BatchesABC, Sequence[Batch]): """A class to represent a sequence of minibatches. Args: @@ -220,7 +231,7 @@ def to_tensors(self) -> List[torch.LongTensor]: return ts -class StreamBatches(Iterable[Batch]): +class StreamBatches(BatchesABC, Iterable[Batch]): """A class to represent an iterable of minibatches. Args: From 8fd68caf805168d0044d469072c0b7d4a6dc0f68 Mon Sep 17 00:00:00 2001 From: Kemal Kurniawan Date: Sat, 26 Jan 2019 12:01:14 +0700 Subject: [PATCH 025/162] Add property drop_last to batches --- tests/test_batches.py | 2 ++ tests/test_stream_batches.py | 4 ++++ text2tensor/__init__.py | 13 +++++++++++++ 3 files changed, 19 insertions(+) diff --git a/tests/test_batches.py b/tests/test_batches.py index 49850d4..c7183fc 100644 --- a/tests/test_batches.py +++ b/tests/test_batches.py @@ -10,6 +10,7 @@ class TestInit: def test_ok(self, dataset): bs = Batches(dataset, 2) assert bs.batch_size == 2 + assert not bs.drop_last assert isinstance(bs, Sequence) assert len(bs) == 3 assert bs[0] == [0, 1] @@ -18,6 +19,7 @@ def test_ok(self, dataset): def test_kwargs(self, dataset): bs = Batches(dataset, 2, drop_last=True) + assert bs.drop_last assert len(bs) == 2 assert bs[0] == [0, 1] assert bs[1] == [2, 3] diff --git a/tests/test_stream_batches.py b/tests/test_stream_batches.py index d75ab77..64c1dca 100644 --- a/tests/test_stream_batches.py +++ b/tests/test_stream_batches.py @@ -11,7 +11,9 @@ class TestInit: def test_init(self, finite_stream_dataset): bs = StreamBatches(finite_stream_dataset, 2) assert bs.batch_size == 2 + assert not bs.drop_last assert isinstance(bs, Iterable) + it = iter(bs) assert next(it) == [0, 1] assert next(it) == [2, 3] @@ -24,6 +26,8 @@ def test_init(self, finite_stream_dataset): def test_kwargs(self, finite_stream_dataset): bs = StreamBatches(finite_stream_dataset, 2, drop_last=True) + assert bs.drop_last + it = iter(bs) assert next(it) == [0, 1] assert next(it) == [2, 3] diff --git a/text2tensor/__init__.py b/text2tensor/__init__.py index 007c65e..9eb35c5 100644 --- a/text2tensor/__init__.py +++ b/text2tensor/__init__.py @@ -177,6 +177,11 @@ class BatchesABC(Iterable[Batch], metaclass=abc.ABCMeta): # pragma: no cover def batch_size(self) -> int: pass + @property + @abc.abstractmethod + def drop_last(self) -> bool: + pass + @abc.abstractmethod def to_tensors(self) -> Iterable[torch.LongTensor]: pass @@ -204,6 +209,10 @@ def __init__(self, dataset: Dataset, batch_size: int, drop_last: bool = False) - def batch_size(self) -> int: return self._bsize + @property + def drop_last(self) -> bool: + return self._drop + def __getitem__(self, index): if index >= len(self): raise IndexError('index out of 
range') @@ -254,6 +263,10 @@ def __init__( def batch_size(self) -> int: return self._bsize + @property + def drop_last(self) -> bool: + return self._drop + def __iter__(self) -> Iterator[Batch]: it, exhausted = iter(self._dataset), False while not exhausted: From 47db7f50e7cec05d0654aec8b252333d496b5da8 Mon Sep 17 00:00:00 2001 From: Kemal Kurniawan Date: Sat, 26 Jan 2019 12:10:10 +0700 Subject: [PATCH 026/162] Let batch() and batch_exactly() return Batches object --- tests/test_dataset.py | 21 ++++++------- tests/test_stream_dataset.py | 36 ++++++--------------- text2tensor/__init__.py | 61 +++++++++--------------------------- 3 files changed, 32 insertions(+), 86 deletions(-) diff --git a/tests/test_dataset.py b/tests/test_dataset.py index 94a10f7..2f974d1 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -2,7 +2,7 @@ import pytest -from text2tensor import Dataset +from text2tensor import Batches, Dataset def test_init(): @@ -50,20 +50,17 @@ def test_immutable_seq(self, setup_rng, tuple_dataset): def test_batch(dataset): - minibatches = dataset.batch(2) - assert isinstance(minibatches, Sequence) - assert len(minibatches) == 3 - assert minibatches[0] == [0, 1] - assert minibatches[1] == [2, 3] - assert minibatches[2] == [4] + bs = dataset.batch(2) + assert isinstance(bs, Batches) + assert bs.batch_size == 2 + assert not bs.drop_last def test_batch_exactly(dataset): - minibatches = dataset.batch_exactly(2) - assert isinstance(minibatches, Sequence) - assert len(minibatches) == 2 - assert minibatches[0] == [0, 1] - assert minibatches[1] == [2, 3] + bs = dataset.batch_exactly(2) + assert isinstance(bs, Batches) + assert bs.batch_size == 2 + assert bs.drop_last def test_batch_nonpositive_batch_size(dataset): diff --git a/tests/test_stream_dataset.py b/tests/test_stream_dataset.py index b89ee7a..bcb53ff 100644 --- a/tests/test_stream_dataset.py +++ b/tests/test_stream_dataset.py @@ -3,7 +3,7 @@ import pytest -from text2tensor import StreamDataset +from text2tensor import StreamDataset, StreamBatches def test_init(counter): @@ -23,35 +23,17 @@ def test_iter(stream_dataset): def test_batch(finite_stream_dataset): - bsize = 2 - minibatches = finite_stream_dataset.batch(bsize) - assert isinstance(minibatches, Iterable) - - it = iter(minibatches) - assert next(it) == [0, 1] - assert next(it) == [2, 3] - assert next(it) == [4, 5] - while True: - try: - assert len(next(it)) <= bsize - except StopIteration: - break + bs = finite_stream_dataset.batch(2) + assert isinstance(bs, StreamBatches) + assert bs.batch_size == 2 + assert not bs.drop_last def test_batch_exactly(finite_stream_dataset): - bsize = 2 - minibatches = finite_stream_dataset.batch_exactly(bsize) - assert isinstance(minibatches, Iterable) - - it = iter(minibatches) - assert next(it) == [0, 1] - assert next(it) == [2, 3] - assert next(it) == [4, 5] - while True: - try: - assert len(next(it)) == bsize - except StopIteration: - break + bs = finite_stream_dataset.batch_exactly(2) + assert isinstance(bs, StreamBatches) + assert bs.batch_size == 2 + assert bs.drop_last def test_batch_nonpositive_batch_size(stream_dataset): diff --git a/text2tensor/__init__.py b/text2tensor/__init__.py index 9eb35c5..b008303 100644 --- a/text2tensor/__init__.py +++ b/text2tensor/__init__.py @@ -9,11 +9,11 @@ class DatasetABC(Iterable[int], metaclass=abc.ABCMeta): # pragma: no cover @abc.abstractmethod - def batch(self, batch_size: int) -> Iterable[Sequence[int]]: + def batch(self, batch_size: int) -> 'BatchesABC': pass @abc.abstractmethod - def 
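
With PATCH 025 above, both batch views expose their configuration. A small sketch of
the invariants (illustrative only):

    from text2tensor import Batches, Dataset

    dataset = Dataset(list(range(5)))
    assert len(Batches(dataset, 2)) == 3                   # last short batch kept
    assert len(Batches(dataset, 2, drop_last=True)) == 2   # short batch dropped
    assert Batches(dataset, 2, drop_last=True).drop_last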

From 47db7f50e7cec05d0654aec8b252333d496b5da8 Mon Sep 17 00:00:00 2001
From: Kemal Kurniawan
Date: Sat, 26 Jan 2019 12:10:10 +0700
Subject: [PATCH 026/162] Let batch() and batch_exactly() return Batches object

---
 tests/test_dataset.py        | 21 ++++++-------
 tests/test_stream_dataset.py | 36 ++++++---------------
 text2tensor/__init__.py      | 61 +++++++++---------------------------
 3 files changed, 32 insertions(+), 86 deletions(-)

diff --git a/tests/test_dataset.py b/tests/test_dataset.py
index 94a10f7..2f974d1 100644
--- a/tests/test_dataset.py
+++ b/tests/test_dataset.py
@@ -2,7 +2,7 @@
 
 import pytest
 
-from text2tensor import Dataset
+from text2tensor import Batches, Dataset
 
 
 def test_init():
@@ -50,20 +50,17 @@ def test_immutable_seq(self, setup_rng, tuple_dataset):
 
 
 def test_batch(dataset):
-    minibatches = dataset.batch(2)
-    assert isinstance(minibatches, Sequence)
-    assert len(minibatches) == 3
-    assert minibatches[0] == [0, 1]
-    assert minibatches[1] == [2, 3]
-    assert minibatches[2] == [4]
+    bs = dataset.batch(2)
+    assert isinstance(bs, Batches)
+    assert bs.batch_size == 2
+    assert not bs.drop_last
 
 
 def test_batch_exactly(dataset):
-    minibatches = dataset.batch_exactly(2)
-    assert isinstance(minibatches, Sequence)
-    assert len(minibatches) == 2
-    assert minibatches[0] == [0, 1]
-    assert minibatches[1] == [2, 3]
+    bs = dataset.batch_exactly(2)
+    assert isinstance(bs, Batches)
+    assert bs.batch_size == 2
+    assert bs.drop_last
 
 
 def test_batch_nonpositive_batch_size(dataset):
diff --git a/tests/test_stream_dataset.py b/tests/test_stream_dataset.py
index b89ee7a..bcb53ff 100644
--- a/tests/test_stream_dataset.py
+++ b/tests/test_stream_dataset.py
@@ -3,7 +3,7 @@
 
 import pytest
 
-from text2tensor import StreamDataset
+from text2tensor import StreamDataset, StreamBatches
 
 
 def test_init(counter):
@@ -23,35 +23,17 @@ def test_iter(stream_dataset):
 
 
 def test_batch(finite_stream_dataset):
-    bsize = 2
-    minibatches = finite_stream_dataset.batch(bsize)
-    assert isinstance(minibatches, Iterable)
-
-    it = iter(minibatches)
-    assert next(it) == [0, 1]
-    assert next(it) == [2, 3]
-    assert next(it) == [4, 5]
-    while True:
-        try:
-            assert len(next(it)) <= bsize
-        except StopIteration:
-            break
+    bs = finite_stream_dataset.batch(2)
+    assert isinstance(bs, StreamBatches)
+    assert bs.batch_size == 2
+    assert not bs.drop_last
 
 
 def test_batch_exactly(finite_stream_dataset):
-    bsize = 2
-    minibatches = finite_stream_dataset.batch_exactly(bsize)
-    assert isinstance(minibatches, Iterable)
-
-    it = iter(minibatches)
-    assert next(it) == [0, 1]
-    assert next(it) == [2, 3]
-    assert next(it) == [4, 5]
-    while True:
-        try:
-            assert len(next(it)) == bsize
-        except StopIteration:
-            break
+    bs = finite_stream_dataset.batch_exactly(2)
+    assert isinstance(bs, StreamBatches)
+    assert bs.batch_size == 2
+    assert bs.drop_last
 
 
 def test_batch_nonpositive_batch_size(stream_dataset):
diff --git a/text2tensor/__init__.py b/text2tensor/__init__.py
index 9eb35c5..b008303 100644
--- a/text2tensor/__init__.py
+++ b/text2tensor/__init__.py
@@ -9,11 +9,11 @@
 class DatasetABC(Iterable[int], metaclass=abc.ABCMeta):  # pragma: no cover
     @abc.abstractmethod
-    def batch(self, batch_size: int) -> Iterable[Sequence[int]]:
+    def batch(self, batch_size: int) -> 'BatchesABC':
         pass
 
     @abc.abstractmethod
-    def batch_exactly(self, batch_size: int) -> Iterable[Sequence[int]]:
+    def batch_exactly(self, batch_size: int) -> 'BatchesABC':
         pass
@@ -55,25 +55,18 @@ def shuffle(self) -> 'Dataset':
             self._shuffle_copy()
         return self
 
-    def batch(self, batch_size: int) -> List[Batch]:
+    def batch(self, batch_size: int) -> 'Batches':
         """Group the samples in the dataset into batches.
 
         Args:
             batch_size: Maximum number of samples in each batch.
 
         Returns:
-            The list of batches.
+            The batches.
         """
-        if batch_size <= 0:
-            raise ValueError('batch size must be greater than 0')
-
-        batches = []
-        for begin in range(0, len(self._samples), batch_size):
-            end = begin + batch_size
-            batches.append(self._samples[begin:end])
-        return batches
+        return Batches(self, batch_size)
 
-    def batch_exactly(self, batch_size: int) -> List[Batch]:
+    def batch_exactly(self, batch_size: int) -> 'Batches':
         """Group the samples in the dataset into batches of exact size.
 
         If the length of ``samples`` is not divisible by ``batch_size``, the last
         batch (whose length is less than ``batch_size``) is dropped.
 
         Args:
             batch_size: Number of samples in each batch.
 
         Returns:
-            The list of batches.
+            The batches.
         """
-        batches = self.batch(batch_size)
-        if len(self._samples) % batch_size != 0:
-            assert len(batches[-1]) < batch_size
-            batches = batches[:-1]
-        return batches
+        return Batches(self, batch_size, drop_last=True)
 
     def _shuffle_inplace(self) -> None:
         assert isinstance(self._samples, MutableSequenceABC)
@@ -112,18 +105,18 @@ def __iter__(self) -> Iterator[int]:
         return iter(self._stream)
 
-    def batch(self, batch_size: int) -> Iterable[Batch]:
+    def batch(self, batch_size: int) -> 'StreamBatches':
         """Group the samples in the dataset into batches.
 
         Args:
             batch_size: Maximum number of samples in each batch.
 
         Returns:
-            The iterable of batches.
+            The batches.
         """
-        return _Batches(self._stream, batch_size)
+        return StreamBatches(self, batch_size)
 
-    def batch_exactly(self, batch_size: int) -> Iterable[Batch]:
+    def batch_exactly(self, batch_size: int) -> 'StreamBatches':
         """Group the samples in the dataset into batches of exact size.
 
         If the length of ``samples`` is not divisible by ``batch_size``, the last
         batch (whose length is less than ``batch_size``) is dropped.
 
         Args:
             batch_size: Number of samples in each batch.
 
         Returns:
-            The iterable of batches.
+            The batches.
         """
-        return _Batches(self._stream, batch_size, drop=True)
-
-
-class _Batches(Iterable[Batch]):
-    def __init__(self, stream: Iterable[int], bsize: int, drop: bool = False) -> None:
-        if bsize <= 0:
-            raise ValueError('batch size must be greater than 0')
-
-        self._stream = stream
-        self._bsize = bsize
-        self._drop = drop
-
-    def __iter__(self) -> Iterator[Batch]:
-        it, exhausted = iter(self._stream), False
-        while not exhausted:
-            batch: list = []
-            while not exhausted and len(batch) < self._bsize:
-                try:
-                    batch.append(next(it))
-                except StopIteration:
-                    exhausted = True
-            if not self._drop or len(batch) == self._bsize:
-                yield batch
+        return StreamBatches(self, batch_size, drop_last=True)

From 110f82963b03e1fa9665d518522c057a96072065 Mon Sep 17 00:00:00 2001
From: Kemal Kurniawan
Date: Sat, 26 Jan 2019 12:13:35 +0700
Subject: [PATCH 027/162] Combine some tests

---
 tests/test_dataset.py        | 14 +++-----------
 tests/test_stream_dataset.py |  7 ++-----
 2 files changed, 5 insertions(+), 16 deletions(-)

diff --git a/tests/test_dataset.py b/tests/test_dataset.py
index 2f974d1..da41ae5 100644
--- a/tests/test_dataset.py
+++ b/tests/test_dataset.py
@@ -8,6 +8,9 @@
 def test_init():
     dat = Dataset(range(5))
     assert isinstance(dat, Sequence)
+    assert len(dat) == 5
+    for i in range(len(dat)):
+        assert dat[i] == i
 
 
 def test_init_samples_non_sequence():
@@ -16,17 +19,6 @@ def test_init_samples_non_sequence():
     assert '"samples" is not a sequence' in str(exc.value)
 
 
-# TODO put this inside test_init
-def test_getitem(dataset):
-    for i in range(5):
-        assert dataset[i] == i
-
-
-# TODO put this inside test_init
-def test_len(dataset):
-    assert len(dataset) == 5
-
-
 class TestShuffle:
diff --git a/tests/test_stream_dataset.py b/tests/test_stream_dataset.py
index bcb53ff..51a9b6b 100644
--- a/tests/test_stream_dataset.py
+++ b/tests/test_stream_dataset.py
@@ -9,6 +9,8 @@
 def test_init(counter):
     dat = StreamDataset(counter)
     assert isinstance(dat, Iterable)
+    it = takewhile(lambda x: x < 5, dat)
+    assert list(it) == list(range(5))
 
 
 def test_init_stream_non_iterable():
@@ -17,11 +19,6 @@ def test_init_stream_non_iterable():
     assert '"stream" is not iterable' in str(exc.value)
 
 
-def test_iter(stream_dataset):
-    it = takewhile(lambda x: x < 5, stream_dataset)
-    assert list(it) == list(range(5))
-
-
 def test_batch(finite_stream_dataset):
int) -> 'BatchesABC': - pass - - @abc.abstractmethod - def batch_exactly(self, batch_size: int) -> 'BatchesABC': - pass - - -Batch = Sequence[int] - - -class Dataset(DatasetABC, Sequence[int]): - """A dataset that fits in memory (no streaming). - - Args: - samples: Sequence of samples the dataset should contain. - """ - - def __init__(self, samples: Sequence[int]) -> None: - if not isinstance(samples, SequenceABC): - raise TypeError('"samples" is not a sequence') - - self._samples = samples - - def __getitem__(self, index): - return self._samples[index] - - def __len__(self) -> int: - return len(self._samples) - - def shuffle(self) -> 'Dataset': - """Shuffle the dataset. - - This method shuffles in-place if ``samples`` is a mutable sequence. - Otherwise, a copy is made and then shuffled. This copy is a mutable - sequence, so subsequent shuffling will be done in-place. - - Returns: - The dataset object itself (useful for chaining). - """ - if isinstance(self._samples, MutableSequenceABC): - self._shuffle_inplace() - else: - self._shuffle_copy() - return self - - def batch(self, batch_size: int) -> 'Batches': - """Group the samples in the dataset into batches. - - Args: - batch_size: Maximum number of samples in each batch. - - Returns: - The batches. - """ - return Batches(self, batch_size) - - def batch_exactly(self, batch_size: int) -> 'Batches': - """Group the samples in the dataset into batches of exact size. - - If the length of ``samples`` is not divisible by ``batch_size``, the last - batch (whose length is less than ``batch_size``) is dropped. - - Args: - batch_size: Number of samples in each batch. - - Returns: - The batches. - """ - return Batches(self, batch_size, drop_last=True) - - def _shuffle_inplace(self) -> None: - assert isinstance(self._samples, MutableSequenceABC) - n = len(self._samples) - for i in range(n): - j = random.randrange(n) - temp = self._samples[i] - self._samples[i] = self._samples[j] - self._samples[j] = temp - - def _shuffle_copy(self) -> None: - shuf_indices = list(range(len(self._samples))) - random.shuffle(shuf_indices) - shuf_samples = [self._samples[i] for i in shuf_indices] - self._samples = shuf_samples - - -class StreamDataset(DatasetABC, Iterable[int]): - """A dataset that streams its samples. - - Args: - stream: Stream of examples the dataset should stream from. - """ - - def __init__(self, stream: Iterable[int]) -> None: - if not isinstance(stream, IterableABC): - raise TypeError('"stream" is not iterable') - - self._stream = stream - - def __iter__(self) -> Iterator[int]: - return iter(self._stream) - - def batch(self, batch_size: int) -> 'StreamBatches': - """Group the samples in the dataset into batches. - - Args: - batch_size: Maximum number of samples in each batch. - - Returns: - The batches. - """ - return StreamBatches(self, batch_size) - - def batch_exactly(self, batch_size: int) -> 'StreamBatches': - """Group the samples in the dataset into batches of exact size. - - If the length of ``samples`` is not divisible by ``batch_size``, the last - batch (whose length is less than ``batch_size``) is dropped. - - Args: - batch_size: Number of samples in each batch. - - Returns: - The batches. 
- """ - return StreamBatches(self, batch_size, drop_last=True) - - -class BatchesABC(Iterable[Batch], metaclass=abc.ABCMeta): # pragma: no cover - @property - @abc.abstractmethod - def batch_size(self) -> int: - pass - - @property - @abc.abstractmethod - def drop_last(self) -> bool: - pass - - @abc.abstractmethod - def to_tensors(self) -> Iterable[torch.LongTensor]: - pass - - -class Batches(BatchesABC, Sequence[Batch]): - """A class to represent a sequence of minibatches. - - Args: - dataset: Dataset to make batches from. - batch_size: Maximum number of samples in each batch. - drop_last (optional): Whether to drop the last batch when ``batch_size`` does not - evenly divide the length of ``dataset``. - """ - - def __init__(self, dataset: Dataset, batch_size: int, drop_last: bool = False) -> None: - if batch_size <= 0: - raise ValueError('batch size must be greater than 0') - - self._dataset = dataset - self._bsize = batch_size - self._drop = drop_last - - @property - def batch_size(self) -> int: - return self._bsize - - @property - def drop_last(self) -> bool: - return self._drop - - def __getitem__(self, index): - if index >= len(self): - raise IndexError('index out of range') - if index < 0: - index += len(self) - - begin = index * self._bsize - end = begin + self._bsize - return self._dataset[begin:end] - - def __len__(self) -> int: - q, r = divmod(len(self._dataset), self._bsize) - return q + (1 if q > 0 and not self._drop else 0) - - def to_tensors(self) -> List[torch.LongTensor]: - """Convert each minibatch into a tensor. - - Returns: - The list of tensors. - """ - ts = [] - for b in self: - t = torch.tensor(b, dtype=torch.long) - ts.append(t) - return ts - - -class StreamBatches(BatchesABC, Iterable[Batch]): - """A class to represent an iterable of minibatches. - - Args: - dataset: Dataset to make batches from. - batch_size: Maximum number of samples in each batch. - drop_last (optional): Whether to drop the last batch when ``batch_size`` does not - evenly divide the length of ``dataset``. - """ - - def __init__( - self, dataset: StreamDataset, batch_size: int, drop_last: bool = False) -> None: - if batch_size <= 0: - raise ValueError('batch size must be greater than 0') - - self._dataset = dataset - self._bsize = batch_size - self._drop = drop_last - - @property - def batch_size(self) -> int: - return self._bsize - - @property - def drop_last(self) -> bool: - return self._drop - - def __iter__(self) -> Iterator[Batch]: - it, exhausted = iter(self._dataset), False - while not exhausted: - batch: list = [] - while not exhausted and len(batch) < self._bsize: - try: - batch.append(next(it)) - except StopIteration: - exhausted = True - if not self._drop or len(batch) == self._bsize: - yield batch - - def to_tensors(self) -> Iterable[torch.LongTensor]: - """Convert each minibatch into a tensor. - - Returns: - The iterable of tensors. 
- """ - return self._StreamTensors(self) - - class _StreamTensors(Iterable[torch.LongTensor]): - def __init__(self, bs: 'StreamBatches') -> None: - self._bs = bs - - def __iter__(self) -> Iterator[torch.LongTensor]: - yield from (torch.tensor(b, dtype=torch.long) for b in self._bs) +from .datasets import Dataset, StreamDataset +from .batches import Batches, StreamBatches + +__all__ = [ + Dataset, + StreamDataset, + Batches, + StreamBatches, +] diff --git a/text2tensor/batches.py b/text2tensor/batches.py new file mode 100644 index 0000000..211c9b3 --- /dev/null +++ b/text2tensor/batches.py @@ -0,0 +1,132 @@ +from typing import Iterable, Iterator, List, Sequence +import abc + +import torch + +from .datasets import Dataset, StreamDataset + +Batch = Sequence[int] + + +class BatchesABC(Iterable[Batch], metaclass=abc.ABCMeta): # pragma: no cover + @property + @abc.abstractmethod + def batch_size(self) -> int: + pass + + @property + @abc.abstractmethod + def drop_last(self) -> bool: + pass + + @abc.abstractmethod + def to_tensors(self) -> Iterable[torch.LongTensor]: + pass + + +class Batches(BatchesABC, Sequence[Batch]): + """A class to represent a sequence of minibatches. + + Args: + dataset: Dataset to make batches from. + batch_size: Maximum number of samples in each batch. + drop_last (optional): Whether to drop the last batch when ``batch_size`` does not + evenly divide the length of ``dataset``. + """ + + def __init__(self, dataset: Dataset, batch_size: int, drop_last: bool = False) -> None: + if batch_size <= 0: + raise ValueError('batch size must be greater than 0') + + self._dataset = dataset + self._bsize = batch_size + self._drop = drop_last + + @property + def batch_size(self) -> int: + return self._bsize + + @property + def drop_last(self) -> bool: + return self._drop + + def __getitem__(self, index): + if index >= len(self): + raise IndexError('index out of range') + if index < 0: + index += len(self) + + begin = index * self._bsize + end = begin + self._bsize + return self._dataset[begin:end] + + def __len__(self) -> int: + q, r = divmod(len(self._dataset), self._bsize) + return q + (1 if q > 0 and not self._drop else 0) + + def to_tensors(self) -> List[torch.LongTensor]: + """Convert each minibatch into a tensor. + + Returns: + The list of tensors. + """ + ts = [] + for b in self: + t = torch.tensor(b, dtype=torch.long) + ts.append(t) + return ts + + +class StreamBatches(BatchesABC, Iterable[Batch]): + """A class to represent an iterable of minibatches. + + Args: + dataset: Dataset to make batches from. + batch_size: Maximum number of samples in each batch. + drop_last (optional): Whether to drop the last batch when ``batch_size`` does not + evenly divide the length of ``dataset``. 
+ """ + + def __init__( + self, dataset: StreamDataset, batch_size: int, drop_last: bool = False) -> None: + if batch_size <= 0: + raise ValueError('batch size must be greater than 0') + + self._dataset = dataset + self._bsize = batch_size + self._drop = drop_last + + @property + def batch_size(self) -> int: + return self._bsize + + @property + def drop_last(self) -> bool: + return self._drop + + def __iter__(self) -> Iterator[Batch]: + it, exhausted = iter(self._dataset), False + while not exhausted: + batch: list = [] + while not exhausted and len(batch) < self._bsize: + try: + batch.append(next(it)) + except StopIteration: + exhausted = True + if not self._drop or len(batch) == self._bsize: + yield batch + + def to_tensors(self) -> Iterable[torch.LongTensor]: + """Convert each minibatch into a tensor. + + Returns: + The iterable of tensors. + """ + return self._StreamTensors(self) + + class _StreamTensors(Iterable[torch.LongTensor]): + def __init__(self, bs: 'StreamBatches') -> None: + self._bs = bs + + def __iter__(self) -> Iterator[torch.LongTensor]: + yield from (torch.tensor(b, dtype=torch.long) for b in self._bs) diff --git a/text2tensor/datasets.py b/text2tensor/datasets.py new file mode 100644 index 0000000..d5ec334 --- /dev/null +++ b/text2tensor/datasets.py @@ -0,0 +1,137 @@ +from collections.abc import \ + Iterable as IterableABC, MutableSequence as MutableSequenceABC, Sequence as SequenceABC +from typing import Iterable, Iterator, Sequence +import abc +import random + + +class DatasetABC(Iterable[int], metaclass=abc.ABCMeta): # pragma: no cover + @abc.abstractmethod + def batch(self, batch_size: int) -> 'BatchesABC': + pass + + @abc.abstractmethod + def batch_exactly(self, batch_size: int) -> 'BatchesABC': + pass + + +class Dataset(DatasetABC, Sequence[int]): + """A dataset that fits in memory (no streaming). + + Args: + samples: Sequence of samples the dataset should contain. + """ + + def __init__(self, samples: Sequence[int]) -> None: + if not isinstance(samples, SequenceABC): + raise TypeError('"samples" is not a sequence') + + self._samples = samples + + def __getitem__(self, index): + return self._samples[index] + + def __len__(self) -> int: + return len(self._samples) + + def shuffle(self) -> 'Dataset': + """Shuffle the dataset. + + This method shuffles in-place if ``samples`` is a mutable sequence. + Otherwise, a copy is made and then shuffled. This copy is a mutable + sequence, so subsequent shuffling will be done in-place. + + Returns: + The dataset object itself (useful for chaining). + """ + if isinstance(self._samples, MutableSequenceABC): + self._shuffle_inplace() + else: + self._shuffle_copy() + return self + + def batch(self, batch_size: int) -> 'Batches': + """Group the samples in the dataset into batches. + + Args: + batch_size: Maximum number of samples in each batch. + + Returns: + The batches. + """ + return Batches(self, batch_size) + + def batch_exactly(self, batch_size: int) -> 'Batches': + """Group the samples in the dataset into batches of exact size. + + If the length of ``samples`` is not divisible by ``batch_size``, the last + batch (whose length is less than ``batch_size``) is dropped. + + Args: + batch_size: Number of samples in each batch. + + Returns: + The batches. 
+ """ + return Batches(self, batch_size, drop_last=True) + + def _shuffle_inplace(self) -> None: + assert isinstance(self._samples, MutableSequenceABC) + n = len(self._samples) + for i in range(n): + j = random.randrange(n) + temp = self._samples[i] + self._samples[i] = self._samples[j] + self._samples[j] = temp + + def _shuffle_copy(self) -> None: + shuf_indices = list(range(len(self._samples))) + random.shuffle(shuf_indices) + shuf_samples = [self._samples[i] for i in shuf_indices] + self._samples = shuf_samples + + +class StreamDataset(DatasetABC, Iterable[int]): + """A dataset that streams its samples. + + Args: + stream: Stream of examples the dataset should stream from. + """ + + def __init__(self, stream: Iterable[int]) -> None: + if not isinstance(stream, IterableABC): + raise TypeError('"stream" is not iterable') + + self._stream = stream + + def __iter__(self) -> Iterator[int]: + return iter(self._stream) + + def batch(self, batch_size: int) -> 'StreamBatches': + """Group the samples in the dataset into batches. + + Args: + batch_size: Maximum number of samples in each batch. + + Returns: + The batches. + """ + return StreamBatches(self, batch_size) + + def batch_exactly(self, batch_size: int) -> 'StreamBatches': + """Group the samples in the dataset into batches of exact size. + + If the length of ``samples`` is not divisible by ``batch_size``, the last + batch (whose length is less than ``batch_size``) is dropped. + + Args: + batch_size: Number of samples in each batch. + + Returns: + The batches. + """ + return StreamBatches(self, batch_size, drop_last=True) + + +# Need to import here to avoid circular dependency +from .batches import BatchesABC, Batches, StreamBatches From ca588cc577b881b3e9b98b261198fc16b1dc19e6 Mon Sep 17 00:00:00 2001 From: Kemal Kurniawan Date: Sat, 26 Jan 2019 13:54:19 +0700 Subject: [PATCH 029/162] Fix Counter not reset --- tests/conftest.py | 16 +++++++++------- tests/test_stream_batches.py | 1 - 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 416c02d..b03523c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -12,6 +12,7 @@ def reset(self): self._count = 0 def __iter__(self): + self._count = 0 while True: yield self._count self._count += 1 @@ -32,9 +33,12 @@ def dataset(): @pytest.fixture def counter(): - c = Counter() - yield Counter() - c.reset() + return Counter() + + +@pytest.fixture +def finite_counter(): + return Counter(limit=11) @pytest.fixture @@ -43,7 +47,5 @@ def stream_dataset(counter): @pytest.fixture -def finite_stream_dataset(): - c = Counter(limit=11) - yield StreamDataset(c) - c.reset() +def finite_stream_dataset(finite_counter): + return StreamDataset(finite_counter) diff --git a/tests/test_stream_batches.py b/tests/test_stream_batches.py index 64c1dca..2d376b7 100644 --- a/tests/test_stream_batches.py +++ b/tests/test_stream_batches.py @@ -67,7 +67,6 @@ def test_to_tensors(stream_batches): def test_to_tensors_returns_iterable(finite_stream_batches): ts = finite_stream_batches.to_tensors() - print(list(ts)) ts_lst1 = list(ts) ts_lst2 = list(ts) assert len(ts_lst1) == len(ts_lst2) From d6a224f5007eebc6cc23cd9031269a47098fae93 Mon Sep 17 00:00:00 2001 From: Kemal Kurniawan Date: Sat, 26 Jan 2019 13:58:35 +0700 Subject: [PATCH 030/162] Add test to check StreamDataset can be iterated > 1x --- tests/test_stream_dataset.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/tests/test_stream_dataset.py b/tests/test_stream_dataset.py index 51a9b6b..8d05f9f 100644 --- 
a/tests/test_stream_dataset.py +++ b/tests/test_stream_dataset.py @@ -13,6 +13,14 @@ def test_init(counter): assert list(it) == list(range(5)) +def test_init_returns_iterable(finite_counter): + dat = StreamDataset(finite_counter) + dat_lst1 = list(dat) + dat_lst2 = list(dat) + assert len(dat_lst1) == len(dat_lst2) + assert len(dat_lst2) > 0 + + def test_init_stream_non_iterable(): with pytest.raises(TypeError) as exc: StreamDataset(5) From 29fa8a457bb9e9695107665c75cadff52caa2d57 Mon Sep 17 00:00:00 2001 From: Kemal Kurniawan Date: Sat, 26 Jan 2019 14:10:06 +0700 Subject: [PATCH 031/162] Fix when dataset size is divisible by batch size --- tests/test_stream_batches.py | 4 ++++ text2tensor/batches.py | 2 +- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/test_stream_batches.py b/tests/test_stream_batches.py index 2d376b7..e9a2103 100644 --- a/tests/test_stream_batches.py +++ b/tests/test_stream_batches.py @@ -37,6 +37,10 @@ def test_kwargs(self, finite_stream_dataset): with pytest.raises(StopIteration): next(it) + def test_divisible_length(self, finite_stream_dataset): + bs = StreamBatches(finite_stream_dataset, 1) + assert list(bs) == [[i] for i, _ in enumerate(finite_stream_dataset)] + def test_nonpositive_batch_size(self, finite_stream_dataset): with pytest.raises(ValueError) as exc: StreamBatches(finite_stream_dataset, 0) diff --git a/text2tensor/batches.py b/text2tensor/batches.py index 211c9b3..e0ebce4 100644 --- a/text2tensor/batches.py +++ b/text2tensor/batches.py @@ -113,7 +113,7 @@ def __iter__(self) -> Iterator[Batch]: batch.append(next(it)) except StopIteration: exhausted = True - if not self._drop or len(batch) == self._bsize: + if len(batch) == self._bsize or (batch and not self._drop): yield batch def to_tensors(self) -> Iterable[torch.LongTensor]: From c3b6f1d0fbfea45a0170b288f166517ba9f863a7 Mon Sep 17 00:00:00 2001 From: Kemal Kurniawan Date: Sat, 26 Jan 2019 14:17:30 +0700 Subject: [PATCH 032/162] Refactor some tests --- tests/conftest.py | 3 --- tests/test_stream_batches.py | 21 ++------------------- 2 files changed, 2 insertions(+), 22 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index b03523c..5da81da 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -8,9 +8,6 @@ def __init__(self, limit=None): self._count = 0 self._limit = limit - def reset(self): - self._count = 0 - def __iter__(self): self._count = 0 while True: diff --git a/tests/test_stream_batches.py b/tests/test_stream_batches.py index e9a2103..5f2e44c 100644 --- a/tests/test_stream_batches.py +++ b/tests/test_stream_batches.py @@ -13,29 +13,12 @@ def test_init(self, finite_stream_dataset): assert bs.batch_size == 2 assert not bs.drop_last assert isinstance(bs, Iterable) - - it = iter(bs) - assert next(it) == [0, 1] - assert next(it) == [2, 3] - assert next(it) == [4, 5] - assert next(it) == [6, 7] - assert next(it) == [8, 9] - assert next(it) == [10] - with pytest.raises(StopIteration): - next(it) + assert list(bs) == [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9], [10]] def test_kwargs(self, finite_stream_dataset): bs = StreamBatches(finite_stream_dataset, 2, drop_last=True) assert bs.drop_last - - it = iter(bs) - assert next(it) == [0, 1] - assert next(it) == [2, 3] - assert next(it) == [4, 5] - assert next(it) == [6, 7] - assert next(it) == [8, 9] - with pytest.raises(StopIteration): - next(it) + assert list(bs) == [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]] def test_divisible_length(self, finite_stream_dataset): bs = StreamBatches(finite_stream_dataset, 1) From 
ca4b91a12f98bb79f5215499a13abed49bdeee85 Mon Sep 17 00:00:00 2001 From: Kemal Kurniawan Date: Sat, 26 Jan 2019 18:59:39 +0700 Subject: [PATCH 033/162] Convert to numpy's ndarray instead of torch's tensor --- requirements.txt | 1 + tests/test_batches.py | 10 +++++----- tests/test_stream_batches.py | 10 +++++----- text2tensor/batches.py | 16 ++++++++-------- 4 files changed, 19 insertions(+), 18 deletions(-) diff --git a/requirements.txt b/requirements.txt index c9fe43a..48581a9 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,6 +8,7 @@ mccabe==0.6.1 more-itertools==5.0.0 mypy==0.660 mypy-extensions==0.4.1 +numpy==1.16.0 pluggy==0.8.1 py==1.7.0 pycodestyle==2.4.0 diff --git a/tests/test_batches.py b/tests/test_batches.py index c7183fc..a4559be 100644 --- a/tests/test_batches.py +++ b/tests/test_batches.py @@ -1,7 +1,7 @@ from collections.abc import Sequence +import numpy as np import pytest -import torch from text2tensor import Batches @@ -51,7 +51,7 @@ def test_to_tensors(batches): assert len(ts) == len(batches) for i in range(len(ts)): t, b = ts[i], batches[i] - assert torch.is_tensor(t) - assert t.dtype == torch.long - assert t.dim() == 1 - assert t.size(0) == len(b) + assert isinstance(t, np.ndarray) + assert t.dtype == np.int32 + assert t.ndim == 1 + assert t.shape[0] == len(b) diff --git a/tests/test_stream_batches.py b/tests/test_stream_batches.py index 5f2e44c..d848ea5 100644 --- a/tests/test_stream_batches.py +++ b/tests/test_stream_batches.py @@ -1,8 +1,8 @@ from collections.abc import Iterable from itertools import takewhile +import numpy as np import pytest -import torch from text2tensor import StreamBatches @@ -46,10 +46,10 @@ def test_to_tensors(stream_batches): bs = takewhile(lambda b: sum(b) < 30, stream_batches) for t, b in zip(ts, bs): - assert torch.is_tensor(t) - assert t.dtype == torch.long - assert t.dim() == 1 - assert t.size(0) == len(b) + assert isinstance(t, np.ndarray) + assert t.dtype == np.int32 + assert t.ndim == 1 + assert t.shape[0] == len(b) def test_to_tensors_returns_iterable(finite_stream_batches): diff --git a/text2tensor/batches.py b/text2tensor/batches.py index e0ebce4..2975b6f 100644 --- a/text2tensor/batches.py +++ b/text2tensor/batches.py @@ -1,7 +1,7 @@ from typing import Iterable, Iterator, List, Sequence import abc -import torch +import numpy as np from .datasets import Dataset, StreamDataset @@ -20,7 +20,7 @@ def drop_last(self) -> bool: pass @abc.abstractmethod - def to_tensors(self) -> Iterable[torch.LongTensor]: + def to_tensors(self) -> Iterable[np.ndarray]: pass @@ -64,7 +64,7 @@ def __len__(self) -> int: q, r = divmod(len(self._dataset), self._bsize) return q + (1 if q > 0 and not self._drop else 0) - def to_tensors(self) -> List[torch.LongTensor]: + def to_tensors(self) -> List[np.ndarray]: """Convert each minibatch into a tensor. Returns: @@ -72,7 +72,7 @@ def to_tensors(self) -> List[torch.LongTensor]: """ ts = [] for b in self: - t = torch.tensor(b, dtype=torch.long) + t = np.array(b, np.int32) ts.append(t) return ts @@ -116,7 +116,7 @@ def __iter__(self) -> Iterator[Batch]: if len(batch) == self._bsize or (batch and not self._drop): yield batch - def to_tensors(self) -> Iterable[torch.LongTensor]: + def to_tensors(self) -> Iterable[np.ndarray]: """Convert each minibatch into a tensor. 
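To see the NumPy switch in action, a sketch of the converted output after this patch (the sample values are hypothetical):

    import numpy as np
    from text2tensor import Dataset

    ts = Dataset([1, 2, 3]).batch(2).to_tensors()
    assert all(isinstance(t, np.ndarray) and t.dtype == np.int32 for t in ts)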
Returns: @@ -124,9 +124,9 @@ def to_tensors(self) -> Iterable[torch.LongTensor]: """ return self._StreamTensors(self) - class _StreamTensors(Iterable[torch.LongTensor]): + class _StreamTensors(Iterable[np.ndarray]): def __init__(self, bs: 'StreamBatches') -> None: self._bs = bs - def __iter__(self) -> Iterator[torch.LongTensor]: - yield from (torch.tensor(b, dtype=torch.long) for b in self._bs) + def __iter__(self) -> Iterator[np.ndarray]: + yield from (np.array(b, np.int32) for b in self._bs) From fb5fe78719859f8f02bdf2212f5ab2e082e95de5 Mon Sep 17 00:00:00 2001 From: Kemal Kurniawan Date: Sat, 26 Jan 2019 19:07:01 +0700 Subject: [PATCH 034/162] Change name to text2array --- README.rst | 6 +++--- setup.py | 6 +++--- tests/conftest.py | 2 +- tests/test_batches.py | 6 +++--- tests/test_dataset.py | 2 +- tests/test_stream_batches.py | 10 +++++----- tests/test_stream_dataset.py | 2 +- {text2tensor => text2array}/__init__.py | 0 {text2tensor => text2array}/batches.py | 18 +++++++++--------- {text2tensor => text2array}/datasets.py | 0 10 files changed, 26 insertions(+), 26 deletions(-) rename {text2tensor => text2array}/__init__.py (100%) rename {text2tensor => text2array}/batches.py (89%) rename {text2tensor => text2array}/datasets.py (100%) diff --git a/README.rst b/README.rst index 908da3f..3373cfe 100644 --- a/README.rst +++ b/README.rst @@ -1,4 +1,4 @@ -text2tensor -^^^^^^^^^^^ +text2array +^^^^^^^^^^ -Convert your NLP text dataset to (batched) tensors! +Convert your NLP text dataset to arrays! diff --git a/setup.py b/setup.py index 5dfd37d..1826cf4 100644 --- a/setup.py +++ b/setup.py @@ -4,11 +4,11 @@ readme = Path(__file__).resolve().parent / 'README.rst' setup( - name='text2tensor', + name='text2array', version='0.0.1', - description='Convert your NLP text data to tensors!', + description='Convert your NLP text data to arrays!', long_description=readme.read_text(), - url='https://github.com/kmkurn/text2tensor', + url='https://github.com/kmkurn/text2array', author='Kemal Kurniawan', author_email='kemal@kkurniawan.com', license='MIT', diff --git a/tests/conftest.py b/tests/conftest.py index 5da81da..3fbcd5b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,6 +1,6 @@ import pytest -from text2tensor import Dataset, StreamDataset +from text2array import Dataset, StreamDataset class Counter: diff --git a/tests/test_batches.py b/tests/test_batches.py index a4559be..32f7d78 100644 --- a/tests/test_batches.py +++ b/tests/test_batches.py @@ -3,7 +3,7 @@ import numpy as np import pytest -from text2tensor import Batches +from text2array import Batches class TestInit: @@ -45,8 +45,8 @@ def test_getitem_index_error(batches): assert 'index out of range' in str(exc.value) -def test_to_tensors(batches): - ts = batches.to_tensors() +def test_to_arrays(batches): + ts = batches.to_arrays() assert isinstance(ts, Sequence) assert len(ts) == len(batches) for i in range(len(ts)): diff --git a/tests/test_dataset.py b/tests/test_dataset.py index da41ae5..91b6f3a 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -2,7 +2,7 @@ import pytest -from text2tensor import Batches, Dataset +from text2array import Batches, Dataset def test_init(): diff --git a/tests/test_stream_batches.py b/tests/test_stream_batches.py index d848ea5..12f09a4 100644 --- a/tests/test_stream_batches.py +++ b/tests/test_stream_batches.py @@ -4,7 +4,7 @@ import numpy as np import pytest -from text2tensor import StreamBatches +from text2array import StreamBatches class TestInit: @@ -40,8 +40,8 @@ def 
finite_stream_batches(finite_stream_dataset): return StreamBatches(finite_stream_dataset, 2) -def test_to_tensors(stream_batches): - ts = stream_batches.to_tensors() +def test_to_arrays(stream_batches): + ts = stream_batches.to_arrays() assert isinstance(ts, Iterable) bs = takewhile(lambda b: sum(b) < 30, stream_batches) @@ -52,8 +52,8 @@ def test_to_tensors(stream_batches): assert t.shape[0] == len(b) -def test_to_tensors_returns_iterable(finite_stream_batches): - ts = finite_stream_batches.to_tensors() +def test_to_arrays_returns_iterable(finite_stream_batches): + ts = finite_stream_batches.to_arrays() ts_lst1 = list(ts) ts_lst2 = list(ts) assert len(ts_lst1) == len(ts_lst2) diff --git a/tests/test_stream_dataset.py b/tests/test_stream_dataset.py index 8d05f9f..8f0f722 100644 --- a/tests/test_stream_dataset.py +++ b/tests/test_stream_dataset.py @@ -3,7 +3,7 @@ import pytest -from text2tensor import StreamDataset, StreamBatches +from text2array import StreamDataset, StreamBatches def test_init(counter): diff --git a/text2tensor/__init__.py b/text2array/__init__.py similarity index 100% rename from text2tensor/__init__.py rename to text2array/__init__.py diff --git a/text2tensor/batches.py b/text2array/batches.py similarity index 89% rename from text2tensor/batches.py rename to text2array/batches.py index 2975b6f..ad03ed3 100644 --- a/text2tensor/batches.py +++ b/text2array/batches.py @@ -20,7 +20,7 @@ def drop_last(self) -> bool: pass @abc.abstractmethod - def to_tensors(self) -> Iterable[np.ndarray]: + def to_arrays(self) -> Iterable[np.ndarray]: pass @@ -64,11 +64,11 @@ def __len__(self) -> int: q, r = divmod(len(self._dataset), self._bsize) return q + (1 if q > 0 and not self._drop else 0) - def to_tensors(self) -> List[np.ndarray]: - """Convert each minibatch into a tensor. + def to_arrays(self) -> List[np.ndarray]: + """Convert each minibatch into an ndarray. Returns: - The list of tensors. + The list of arrays. """ ts = [] for b in self: @@ -116,15 +116,15 @@ def __iter__(self) -> Iterator[Batch]: if len(batch) == self._bsize or (batch and not self._drop): yield batch - def to_tensors(self) -> Iterable[np.ndarray]: - """Convert each minibatch into a tensor. + def to_arrays(self) -> Iterable[np.ndarray]: + """Convert each minibatch into an ndarray. Returns: - The iterable of tensors. + The iterable of arrays. 
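After the rename, usage is unchanged apart from the package and method names, e.g. this sketch:

    from text2array import Dataset

    arrs = Dataset([0, 1, 2]).batch(2).to_arrays()
    assert [a.tolist() for a in arrs] == [[0, 1], [2]]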
""" - return self._StreamTensors(self) + return self._StreamArrays(self) - class _StreamTensors(Iterable[np.ndarray]): + class _StreamArrays(Iterable[np.ndarray]): def __init__(self, bs: 'StreamBatches') -> None: self._bs = bs diff --git a/text2tensor/datasets.py b/text2array/datasets.py similarity index 100% rename from text2tensor/datasets.py rename to text2array/datasets.py From ec423b937191aa48f4c895497e073dc9922d440b Mon Sep 17 00:00:00 2001 From: Kemal Kurniawan Date: Sat, 26 Jan 2019 19:25:15 +0700 Subject: [PATCH 035/162] Make sure to report only the project's files --- pytest.ini | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytest.ini b/pytest.ini index fc62745..467972f 100644 --- a/pytest.ini +++ b/pytest.ini @@ -1,2 +1,2 @@ [pytest] -addopts = --cov --cov-report term-missing \ No newline at end of file +addopts = --cov text2array --cov-report term-missing From cf57cfa0ea338dc9c017b20d6f2cb2f5063e8268 Mon Sep 17 00:00:00 2001 From: Kemal Kurniawan Date: Sat, 26 Jan 2019 19:27:15 +0700 Subject: [PATCH 036/162] Also report coverage as HTML --- pytest.ini | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytest.ini b/pytest.ini index 467972f..43e269b 100644 --- a/pytest.ini +++ b/pytest.ini @@ -1,2 +1,2 @@ [pytest] -addopts = --cov text2array --cov-report term-missing +addopts = --cov text2array --cov-report term-missing --cov-report html From d369e4e6583290d0155e7861fbc7c98ed2ae6e14 Mon Sep 17 00:00:00 2001 From: Kemal Kurniawan Date: Sat, 26 Jan 2019 20:51:13 +0700 Subject: [PATCH 037/162] Simplify some tests --- tests/test_batches.py | 3 +-- tests/test_stream_batches.py | 14 +++++--------- text2array/batches.py | 2 +- 3 files changed, 7 insertions(+), 12 deletions(-) diff --git a/tests/test_batches.py b/tests/test_batches.py index 32f7d78..0fd2f98 100644 --- a/tests/test_batches.py +++ b/tests/test_batches.py @@ -53,5 +53,4 @@ def test_to_arrays(batches): t, b = ts[i], batches[i] assert isinstance(t, np.ndarray) assert t.dtype == np.int32 - assert t.ndim == 1 - assert t.shape[0] == len(b) + assert t.tolist() == list(b) diff --git a/tests/test_stream_batches.py b/tests/test_stream_batches.py index 12f09a4..89e1d66 100644 --- a/tests/test_stream_batches.py +++ b/tests/test_stream_batches.py @@ -1,5 +1,4 @@ from collections.abc import Iterable -from itertools import takewhile import numpy as np import pytest @@ -40,16 +39,13 @@ def finite_stream_batches(finite_stream_dataset): return StreamBatches(finite_stream_dataset, 2) -def test_to_arrays(stream_batches): - ts = stream_batches.to_arrays() +def test_to_arrays(finite_stream_batches): + ts = finite_stream_batches.to_arrays() assert isinstance(ts, Iterable) - bs = takewhile(lambda b: sum(b) < 30, stream_batches) - for t, b in zip(ts, bs): - assert isinstance(t, np.ndarray) - assert t.dtype == np.int32 - assert t.ndim == 1 - assert t.shape[0] == len(b) + assert all(isinstance(t, np.ndarray) for t in ts) + assert all(t.dtype == np.int32 for t in ts) + assert [t.tolist() for t in ts] == [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9], [10]] def test_to_arrays_returns_iterable(finite_stream_batches): diff --git a/text2array/batches.py b/text2array/batches.py index ad03ed3..ce2bc06 100644 --- a/text2array/batches.py +++ b/text2array/batches.py @@ -129,4 +129,4 @@ def __init__(self, bs: 'StreamBatches') -> None: self._bs = bs def __iter__(self) -> Iterator[np.ndarray]: - yield from (np.array(b, np.int32) for b in self._bs) + return (np.array(b, np.int32) for b in self._bs) From 
2279ab3b5a47a5e021e9d0198be69fca64aa5eb9 Mon Sep 17 00:00:00 2001 From: Kemal Kurniawan Date: Sat, 26 Jan 2019 21:18:57 +0700 Subject: [PATCH 038/162] Strengthen and refactor tests --- tests/conftest.py | 12 ++++++------ tests/test_batches.py | 8 ++++++++ tests/test_dataset.py | 4 +--- tests/test_stream_batches.py | 19 +++++++++---------- tests/test_stream_dataset.py | 23 ++++++++++------------- text2array/batches.py | 2 +- 6 files changed, 35 insertions(+), 33 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 3fbcd5b..f634e90 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -5,15 +5,15 @@ class Counter: def __init__(self, limit=None): - self._count = 0 - self._limit = limit + self.count = 0 + self.limit = limit def __iter__(self): - self._count = 0 + self.count = 0 while True: - yield self._count - self._count += 1 - if self._limit is not None and self._count >= self._limit: + yield self.count + self.count += 1 + if self.limit is not None and self.count >= self.limit: break diff --git a/tests/test_batches.py b/tests/test_batches.py index 0fd2f98..1b1168f 100644 --- a/tests/test_batches.py +++ b/tests/test_batches.py @@ -37,13 +37,21 @@ def batches(dataset): def test_getitem_negative_index(batches): assert batches[-1] == [4] + assert batches[-2] == [2, 3] + assert batches[-3] == [0, 1] def test_getitem_index_error(batches): + # index too large with pytest.raises(IndexError) as exc: batches[len(batches)] assert 'index out of range' in str(exc.value) + # index too small + with pytest.raises(IndexError) as exc: + batches[-len(batches) - 1] + assert 'index out of range' in str(exc.value) + def test_to_arrays(batches): ts = batches.to_arrays() diff --git a/tests/test_dataset.py b/tests/test_dataset.py index 91b6f3a..56c18c6 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -30,9 +30,7 @@ def assert_shuffle(self, dataset): after = list(dataset) assert retval is dataset - assert len(before) == len(after) - assert all(v in after for v in before) - assert before != after + assert before != after and sorted(before) == sorted(after) def test_mutable_seq(self, setup_rng, dataset): self.assert_shuffle(dataset) diff --git a/tests/test_stream_batches.py b/tests/test_stream_batches.py index 89e1d66..3251a7b 100644 --- a/tests/test_stream_batches.py +++ b/tests/test_stream_batches.py @@ -7,7 +7,7 @@ class TestInit: - def test_init(self, finite_stream_dataset): + def test_ok(self, finite_stream_dataset): bs = StreamBatches(finite_stream_dataset, 2) assert bs.batch_size == 2 assert not bs.drop_last @@ -21,7 +21,7 @@ def test_kwargs(self, finite_stream_dataset): def test_divisible_length(self, finite_stream_dataset): bs = StreamBatches(finite_stream_dataset, 1) - assert list(bs) == [[i] for i, _ in enumerate(finite_stream_dataset)] + assert list(bs) == [[0], [1], [2], [3], [4], [5], [6], [7], [8], [9], [10]] def test_nonpositive_batch_size(self, finite_stream_dataset): with pytest.raises(ValueError) as exc: @@ -39,6 +39,13 @@ def finite_stream_batches(finite_stream_dataset): return StreamBatches(finite_stream_dataset, 2) +def test_can_be_iterated_twice(finite_stream_batches): + bs_lst1 = list(finite_stream_batches) + bs_lst2 = list(finite_stream_batches) + assert len(bs_lst1) == len(bs_lst2) + assert len(bs_lst2) > 0 + + def test_to_arrays(finite_stream_batches): ts = finite_stream_batches.to_arrays() assert isinstance(ts, Iterable) @@ -46,11 +53,3 @@ def test_to_arrays(finite_stream_batches): assert all(isinstance(t, np.ndarray) for t in ts) assert all(t.dtype == 
np.int32 for t in ts) assert [t.tolist() for t in ts] == [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9], [10]] - - -def test_to_arrays_returns_iterable(finite_stream_batches): - ts = finite_stream_batches.to_arrays() - ts_lst1 = list(ts) - ts_lst2 = list(ts) - assert len(ts_lst1) == len(ts_lst2) - assert len(ts_lst2) > 0 diff --git a/tests/test_stream_dataset.py b/tests/test_stream_dataset.py index 8f0f722..ed42953 100644 --- a/tests/test_stream_dataset.py +++ b/tests/test_stream_dataset.py @@ -1,24 +1,14 @@ from collections.abc import Iterable -from itertools import takewhile import pytest from text2array import StreamDataset, StreamBatches -def test_init(counter): - dat = StreamDataset(counter) - assert isinstance(dat, Iterable) - it = takewhile(lambda x: x < 5, dat) - assert list(it) == list(range(5)) - - -def test_init_returns_iterable(finite_counter): +def test_init(finite_counter): dat = StreamDataset(finite_counter) - dat_lst1 = list(dat) - dat_lst2 = list(dat) - assert len(dat_lst1) == len(dat_lst2) - assert len(dat_lst2) > 0 + assert isinstance(dat, Iterable) + assert list(dat) == list(range(finite_counter.limit)) def test_init_stream_non_iterable(): @@ -27,6 +17,13 @@ def test_init_stream_non_iterable(): assert '"stream" is not iterable' in str(exc.value) +def test_can_be_iterated_twice(finite_stream_dataset): + dat_lst1 = list(finite_stream_dataset) + dat_lst2 = list(finite_stream_dataset) + assert len(dat_lst1) == len(dat_lst2) + assert len(dat_lst2) > 0 + + def test_batch(finite_stream_dataset): bs = finite_stream_dataset.batch(2) assert isinstance(bs, StreamBatches) diff --git a/text2array/batches.py b/text2array/batches.py index ce2bc06..670ba0e 100644 --- a/text2array/batches.py +++ b/text2array/batches.py @@ -51,7 +51,7 @@ def drop_last(self) -> bool: return self._drop def __getitem__(self, index): - if index >= len(self): + if index >= len(self) or index < -len(self): raise IndexError('index out of range') if index < 0: index += len(self) From 16d2195297aedeb53002bfa9fde19062bb56191b Mon Sep 17 00:00:00 2001 From: Kemal Kurniawan Date: Sat, 26 Jan 2019 21:22:31 +0700 Subject: [PATCH 039/162] Make all stream fixtures finite --- tests/conftest.py | 10 ---------- tests/test_stream_batches.py | 31 +++++++++++++------------------ tests/test_stream_dataset.py | 20 ++++++++++---------- 3 files changed, 23 insertions(+), 38 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index f634e90..995820b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -30,19 +30,9 @@ def dataset(): @pytest.fixture def counter(): - return Counter() - - -@pytest.fixture -def finite_counter(): return Counter(limit=11) @pytest.fixture def stream_dataset(counter): return StreamDataset(counter) - - -@pytest.fixture -def finite_stream_dataset(finite_counter): - return StreamDataset(finite_counter) diff --git a/tests/test_stream_batches.py b/tests/test_stream_batches.py index 3251a7b..6931314 100644 --- a/tests/test_stream_batches.py +++ b/tests/test_stream_batches.py @@ -7,25 +7,25 @@ class TestInit: - def test_ok(self, finite_stream_dataset): - bs = StreamBatches(finite_stream_dataset, 2) + def test_ok(self, stream_dataset): + bs = StreamBatches(stream_dataset, 2) assert bs.batch_size == 2 assert not bs.drop_last assert isinstance(bs, Iterable) assert list(bs) == [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9], [10]] - def test_kwargs(self, finite_stream_dataset): - bs = StreamBatches(finite_stream_dataset, 2, drop_last=True) + def test_kwargs(self, stream_dataset): + bs = StreamBatches(stream_dataset, 
2, drop_last=True) assert bs.drop_last assert list(bs) == [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]] - def test_divisible_length(self, finite_stream_dataset): - bs = StreamBatches(finite_stream_dataset, 1) + def test_divisible_length(self, stream_dataset): + bs = StreamBatches(stream_dataset, 1) assert list(bs) == [[0], [1], [2], [3], [4], [5], [6], [7], [8], [9], [10]] - def test_nonpositive_batch_size(self, finite_stream_dataset): + def test_nonpositive_batch_size(self, stream_dataset): with pytest.raises(ValueError) as exc: - StreamBatches(finite_stream_dataset, 0) + StreamBatches(stream_dataset, 0) assert 'batch size must be greater than 0' in str(exc.value) @@ -34,20 +34,15 @@ def stream_batches(stream_dataset): return StreamBatches(stream_dataset, 2) -@pytest.fixture -def finite_stream_batches(finite_stream_dataset): - return StreamBatches(finite_stream_dataset, 2) - - -def test_can_be_iterated_twice(finite_stream_batches): - bs_lst1 = list(finite_stream_batches) - bs_lst2 = list(finite_stream_batches) +def test_can_be_iterated_twice(stream_batches): + bs_lst1 = list(stream_batches) + bs_lst2 = list(stream_batches) assert len(bs_lst1) == len(bs_lst2) assert len(bs_lst2) > 0 -def test_to_arrays(finite_stream_batches): - ts = finite_stream_batches.to_arrays() +def test_to_arrays(stream_batches): + ts = stream_batches.to_arrays() assert isinstance(ts, Iterable) assert all(isinstance(t, np.ndarray) for t in ts) diff --git a/tests/test_stream_dataset.py b/tests/test_stream_dataset.py index ed42953..7bf5ad5 100644 --- a/tests/test_stream_dataset.py +++ b/tests/test_stream_dataset.py @@ -5,10 +5,10 @@ from text2array import StreamDataset, StreamBatches -def test_init(finite_counter): - dat = StreamDataset(finite_counter) +def test_init(counter): + dat = StreamDataset(counter) assert isinstance(dat, Iterable) - assert list(dat) == list(range(finite_counter.limit)) + assert list(dat) == list(range(counter.limit)) def test_init_stream_non_iterable(): @@ -17,22 +17,22 @@ def test_init_stream_non_iterable(): assert '"stream" is not iterable' in str(exc.value) -def test_can_be_iterated_twice(finite_stream_dataset): - dat_lst1 = list(finite_stream_dataset) - dat_lst2 = list(finite_stream_dataset) +def test_can_be_iterated_twice(stream_dataset): + dat_lst1 = list(stream_dataset) + dat_lst2 = list(stream_dataset) assert len(dat_lst1) == len(dat_lst2) assert len(dat_lst2) > 0 -def test_batch(finite_stream_dataset): - bs = finite_stream_dataset.batch(2) +def test_batch(stream_dataset): + bs = stream_dataset.batch(2) assert isinstance(bs, StreamBatches) assert bs.batch_size == 2 assert not bs.drop_last -def test_batch_exactly(finite_stream_dataset): - bs = finite_stream_dataset.batch_exactly(2) +def test_batch_exactly(stream_dataset): + bs = stream_dataset.batch_exactly(2) assert isinstance(bs, StreamBatches) assert bs.batch_size == 2 assert bs.drop_last From 8609841d5f0b8382bac8dabcdd90b67ca17e954d Mon Sep 17 00:00:00 2001 From: Kemal Kurniawan Date: Sat, 26 Jan 2019 21:24:22 +0700 Subject: [PATCH 040/162] Shorten streams --- tests/conftest.py | 6 +++--- tests/test_stream_batches.py | 8 ++++---- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 995820b..8e495b7 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -4,9 +4,9 @@ class Counter: - def __init__(self, limit=None): + def __init__(self): self.count = 0 - self.limit = limit + self.limit = 5 def __iter__(self): self.count = 0 @@ -30,7 +30,7 @@ def dataset(): @pytest.fixture def 
counter(): - return Counter(limit=11) + return Counter() @pytest.fixture diff --git a/tests/test_stream_batches.py b/tests/test_stream_batches.py index 6931314..b72f873 100644 --- a/tests/test_stream_batches.py +++ b/tests/test_stream_batches.py @@ -12,16 +12,16 @@ def test_ok(self, stream_dataset): assert bs.batch_size == 2 assert not bs.drop_last assert isinstance(bs, Iterable) - assert list(bs) == [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9], [10]] + assert list(bs) == [[0, 1], [2, 3], [4]] def test_kwargs(self, stream_dataset): bs = StreamBatches(stream_dataset, 2, drop_last=True) assert bs.drop_last - assert list(bs) == [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]] + assert list(bs) == [[0, 1], [2, 3]] def test_divisible_length(self, stream_dataset): bs = StreamBatches(stream_dataset, 1) - assert list(bs) == [[0], [1], [2], [3], [4], [5], [6], [7], [8], [9], [10]] + assert list(bs) == [[0], [1], [2], [3], [4]] def test_nonpositive_batch_size(self, stream_dataset): with pytest.raises(ValueError) as exc: @@ -47,4 +47,4 @@ def test_to_arrays(stream_batches): assert all(isinstance(t, np.ndarray) for t in ts) assert all(t.dtype == np.int32 for t in ts) - assert [t.tolist() for t in ts] == [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9], [10]] + assert [t.tolist() for t in ts] == [[0, 1], [2, 3], [4]] From 7694af1998b89fce75d4354e5e8da0b685e2518a Mon Sep 17 00:00:00 2001 From: Kemal Kurniawan Date: Sat, 26 Jan 2019 21:38:16 +0700 Subject: [PATCH 041/162] Mention that samples must support what kinds of indexing --- text2array/datasets.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/text2array/datasets.py b/text2array/datasets.py index d5ec334..eae949c 100644 --- a/text2array/datasets.py +++ b/text2array/datasets.py @@ -19,7 +19,9 @@ class Dataset(DatasetABC, Sequence[int]): """A dataset that fits in memory (no streaming). Args: - samples: Sequence of samples the dataset should contain. + samples: Sequence of samples the dataset should contain. This sequence should + support indexing by a positive/negative index of type :obj:`int` or a + :obj:`slice` object. """ def __init__(self, samples: Sequence[int]) -> None: From e0950861de3a4bd74f9131a54a82f046997a5643 Mon Sep 17 00:00:00 2001 From: Kemal Kurniawan Date: Sat, 26 Jan 2019 21:49:08 +0700 Subject: [PATCH 042/162] Make Batches.to_array() return type more abstract --- text2array/batches.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/text2array/batches.py b/text2array/batches.py index 670ba0e..63f6da3 100644 --- a/text2array/batches.py +++ b/text2array/batches.py @@ -1,4 +1,4 @@ -from typing import Iterable, Iterator, List, Sequence +from typing import Iterable, Iterator, Sequence import abc import numpy as np @@ -64,11 +64,11 @@ def __len__(self) -> int: q, r = divmod(len(self._dataset), self._bsize) return q + (1 if q > 0 and not self._drop else 0) - def to_arrays(self) -> List[np.ndarray]: + def to_arrays(self) -> Sequence[np.ndarray]: """Convert each minibatch into an ndarray. Returns: - The list of arrays. + The sequence of arrays. 
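With the return type widened to ``Sequence``, callers should rely only on indexing and ``len()`` rather than list-specific methods, e.g. (``dataset`` here is hypothetical):

    arrs = dataset.batch(2).to_arrays()
    first, total = arrs[0], len(arrs)   # Sequence operations only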
""" ts = [] for b in self: From 0f883b191dea1e4c05ce02bbc905d86411b4ebc3 Mon Sep 17 00:00:00 2001 From: Kemal Kurniawan Date: Sat, 26 Jan 2019 22:42:48 +0700 Subject: [PATCH 043/162] Try to make samples/stream contains objects --- tests/conftest.py | 48 +++++++++++++++++++++--------------- tests/test_batches.py | 28 ++++++++++++--------- tests/test_dataset.py | 10 ++++---- tests/test_stream_batches.py | 22 ++++++++++++----- tests/test_stream_dataset.py | 6 ++--- 5 files changed, 69 insertions(+), 45 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 8e495b7..d3d5896 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -3,20 +3,6 @@ from text2array import Dataset, StreamDataset -class Counter: - def __init__(self): - self.count = 0 - self.limit = 5 - - def __iter__(self): - self.count = 0 - while True: - yield self.count - self.count += 1 - if self.limit is not None and self.count >= self.limit: - break - - @pytest.fixture def setup_rng(): import random @@ -24,15 +10,37 @@ def setup_rng(): @pytest.fixture -def dataset(): - return Dataset(list(range(5))) +def samples(): + return [Sample(i, i * i) for i in range(5)] + + +@pytest.fixture +def dataset(samples): + return Dataset(samples) @pytest.fixture -def counter(): - return Counter() +def stream(samples): + return Stream(samples) @pytest.fixture -def stream_dataset(counter): - return StreamDataset(counter) +def stream_dataset(stream): + return StreamDataset(stream) + + +class Sample: + def __init__(self, x, y): + self.x = x + self.y = y + + def __lt__(self, s): + return self.x < s.x or (self.x == s.x and self.y < s.y) + + +class Stream: + def __init__(self, samples): + self.samples = samples + + def __iter__(self): + yield from self.samples diff --git a/tests/test_batches.py b/tests/test_batches.py index 1b1168f..a424226 100644 --- a/tests/test_batches.py +++ b/tests/test_batches.py @@ -13,16 +13,16 @@ def test_ok(self, dataset): assert not bs.drop_last assert isinstance(bs, Sequence) assert len(bs) == 3 - assert bs[0] == [0, 1] - assert bs[1] == [2, 3] - assert bs[2] == [4] + assert bs[0] == [dataset[0], dataset[1]] + assert bs[1] == [dataset[2], dataset[3]] + assert bs[2] == [dataset[4]] def test_kwargs(self, dataset): bs = Batches(dataset, 2, drop_last=True) assert bs.drop_last assert len(bs) == 2 - assert bs[0] == [0, 1] - assert bs[1] == [2, 3] + assert bs[0] == [dataset[0], dataset[1]] + assert bs[1] == [dataset[2], dataset[3]] def test_nonpositive_batch_size(self, dataset): with pytest.raises(ValueError) as exc: @@ -36,9 +36,9 @@ def batches(dataset): def test_getitem_negative_index(batches): - assert batches[-1] == [4] - assert batches[-2] == [2, 3] - assert batches[-3] == [0, 1] + n = len(batches) + for i in range(n): + assert batches[-i - 1] == batches[n - i - 1] def test_getitem_index_error(batches): @@ -53,12 +53,18 @@ def test_getitem_index_error(batches): assert 'index out of range' in str(exc.value) +@pytest.mark.skip def test_to_arrays(batches): ts = batches.to_arrays() assert isinstance(ts, Sequence) assert len(ts) == len(batches) for i in range(len(ts)): t, b = ts[i], batches[i] - assert isinstance(t, np.ndarray) - assert t.dtype == np.int32 - assert t.tolist() == list(b) + + assert isinstance(t.x, np.ndarray) + assert t.x.dtype == np.int32 + assert t.x.tolist() == list(b.x) + + assert isinstance(t.y, np.ndarray) + assert t.y.dtype == np.int32 + assert t.y.tolist() == list(b.y) diff --git a/tests/test_dataset.py b/tests/test_dataset.py index 56c18c6..c664c7b 100644 --- a/tests/test_dataset.py +++ 
b/tests/test_dataset.py @@ -5,12 +5,12 @@ from text2array import Batches, Dataset -def test_init(): - dat = Dataset(range(5)) +def test_init(samples): + dat = Dataset(samples) assert isinstance(dat, Sequence) assert len(dat) == 5 for i in range(len(dat)): - assert dat[i] == i + assert dat[i] == samples[i] def test_init_samples_non_sequence(): @@ -21,8 +21,8 @@ def test_init_samples_non_sequence(): class TestShuffle: @pytest.fixture - def tuple_dataset(self): - return Dataset(tuple(range(5))) + def tuple_dataset(self, samples): + return Dataset(tuple(samples)) def assert_shuffle(self, dataset): before = list(dataset) diff --git a/tests/test_stream_batches.py b/tests/test_stream_batches.py index b72f873..3c01dac 100644 --- a/tests/test_stream_batches.py +++ b/tests/test_stream_batches.py @@ -12,16 +12,19 @@ def test_ok(self, stream_dataset): assert bs.batch_size == 2 assert not bs.drop_last assert isinstance(bs, Iterable) - assert list(bs) == [[0, 1], [2, 3], [4]] + dat = list(stream_dataset) + assert list(bs) == [[dat[0], dat[1]], [dat[2], dat[3]], [dat[4]]] def test_kwargs(self, stream_dataset): bs = StreamBatches(stream_dataset, 2, drop_last=True) assert bs.drop_last - assert list(bs) == [[0, 1], [2, 3]] + dat = list(stream_dataset) + assert list(bs) == [[dat[0], dat[1]], [dat[2], dat[3]]] def test_divisible_length(self, stream_dataset): bs = StreamBatches(stream_dataset, 1) - assert list(bs) == [[0], [1], [2], [3], [4]] + dat = list(stream_dataset) + assert list(bs) == [[dat[0]], [dat[1]], [dat[2]], [dat[3]], [dat[4]]] def test_nonpositive_batch_size(self, stream_dataset): with pytest.raises(ValueError) as exc: @@ -41,10 +44,17 @@ def test_can_be_iterated_twice(stream_batches): assert len(bs_lst2) > 0 +@pytest.mark.skip def test_to_arrays(stream_batches): ts = stream_batches.to_arrays() assert isinstance(ts, Iterable) - assert all(isinstance(t, np.ndarray) for t in ts) - assert all(t.dtype == np.int32 for t in ts) - assert [t.tolist() for t in ts] == [[0, 1], [2, 3], [4]] + bs = list(stream_batches) + for t, b in zip(ts, bs): + assert isinstance(t.x, np.ndarray) + assert t.x.dtype == np.int32 + assert t.x.tolist() == list(b.x) + + assert isinstance(t.y, np.ndarray) + assert t.y.dtype == np.int32 + assert t.y.tolist() == list(b.y) diff --git a/tests/test_stream_dataset.py b/tests/test_stream_dataset.py index 7bf5ad5..ac80f69 100644 --- a/tests/test_stream_dataset.py +++ b/tests/test_stream_dataset.py @@ -5,10 +5,10 @@ from text2array import StreamDataset, StreamBatches -def test_init(counter): - dat = StreamDataset(counter) +def test_init(stream): + dat = StreamDataset(stream) assert isinstance(dat, Iterable) - assert list(dat) == list(range(counter.limit)) + assert list(dat) == list(stream) def test_init_stream_non_iterable(): From 54281868f2082d6c20a7f12b100c9ccbd7bd21b7 Mon Sep 17 00:00:00 2001 From: Kemal Kurniawan Date: Sun, 27 Jan 2019 07:46:11 +0700 Subject: [PATCH 044/162] Try to make custom Batch class --- tests/conftest.py | 10 +++++++--- tests/test_batch.py | 36 ++++++++++++++++++++++++++++++++++++ text2array/__init__.py | 5 ++++- text2array/batches.py | 18 +++++++++++++++++- text2array/samples.py | 12 ++++++++++++ 5 files changed, 76 insertions(+), 5 deletions(-) create mode 100644 tests/test_batch.py create mode 100644 text2array/samples.py diff --git a/tests/conftest.py b/tests/conftest.py index d3d5896..5b2df53 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,6 +1,6 @@ import pytest -from text2array import Dataset, StreamDataset +from text2array import Dataset, 
Sample, StreamDataset @pytest.fixture @@ -11,7 +11,7 @@ def setup_rng(): @pytest.fixture def samples(): - return [Sample(i, i * i) for i in range(5)] + return [TestSample(i, i * i) for i in range(5)] @pytest.fixture @@ -29,7 +29,7 @@ def stream_dataset(stream): return StreamDataset(stream) -class Sample: +class TestSample(Sample): def __init__(self, x, y): self.x = x self.y = y @@ -37,6 +37,10 @@ def __init__(self, x, y): def __lt__(self, s): return self.x < s.x or (self.x == s.x and self.y < s.y) + @property + def fields(self): + return {'x': self.x, 'y': self.y} + class Stream: def __init__(self, samples): diff --git a/tests/test_batch.py b/tests/test_batch.py new file mode 100644 index 0000000..03013a3 --- /dev/null +++ b/tests/test_batch.py @@ -0,0 +1,36 @@ +from collections.abc import Sequence + +import pytest + +from text2array import Batch + + +def test_init(samples): + b = Batch(samples) + assert isinstance(b, Sequence) + assert len(b) == len(samples) + for i in range(len(b)): + assert b[i] == samples[i] + + +@pytest.fixture +def batch(samples): + return Batch(samples) + + +def test_getattr(batch): + assert isinstance(batch.x, Sequence) + assert len(batch.x) == len(batch) + for i in range(len(batch)): + assert batch.x[i] == batch[i].x + + assert isinstance(batch.y, Sequence) + assert len(batch.y) == len(batch) + for i in range(len(batch)): + assert batch.y[i] == batch[i].y + + +def test_getattr_invalid_name(batch): + with pytest.raises(AttributeError) as exc: + batch.z + assert "some samples have no field 'z'" in str(exc.value) diff --git a/text2array/__init__.py b/text2array/__init__.py index 5811ca8..a28768f 100644 --- a/text2array/__init__.py +++ b/text2array/__init__.py @@ -1,9 +1,12 @@ from .datasets import Dataset, StreamDataset -from .batches import Batches, StreamBatches +from .batches import Batch, Batches, StreamBatches +from .samples import Sample __all__ = [ + Sample, Dataset, StreamDataset, + Batch, Batches, StreamBatches, ] diff --git a/text2array/batches.py b/text2array/batches.py index 63f6da3..d9e05d0 100644 --- a/text2array/batches.py +++ b/text2array/batches.py @@ -4,8 +4,24 @@ import numpy as np from .datasets import Dataset, StreamDataset +from .samples import FieldValue, Sample -Batch = Sequence[int] + +class Batch(Sequence[Sample]): + def __init__(self, samples: Sequence[Sample]) -> None: + self._samples = samples + + def __getitem__(self, index): + return self._samples[index] + + def __len__(self) -> int: + return len(self._samples) + + def __getattr__(self, name: str) -> Sequence[FieldValue]: + try: + return [s.fields[name] for s in self._samples] + except KeyError: + raise AttributeError(f"some samples have no field '{name}'") class BatchesABC(Iterable[Batch], metaclass=abc.ABCMeta): # pragma: no cover diff --git a/text2array/samples.py b/text2array/samples.py new file mode 100644 index 0000000..f02b351 --- /dev/null +++ b/text2array/samples.py @@ -0,0 +1,12 @@ +from typing import Mapping, Union +import abc + +FieldName = str +FieldValue = Union[str, float, int] + + +class Sample(metaclass=abc.ABCMeta): + @property + @abc.abstractmethod + def fields(self) -> Mapping[FieldName, FieldValue]: + pass From 5ac1e51f3e4e42598641b24256b2e9d0ed15aaac Mon Sep 17 00:00:00 2001 From: Kemal Kurniawan Date: Sun, 27 Jan 2019 13:09:43 +0700 Subject: [PATCH 045/162] Let DatasetABC.batch() return iterator of Batch --- tests/test_dataset.py | 27 +++++++----- tests/test_stream_dataset.py | 32 +++++++++----- text2array/datasets.py | 85 ++++++++++++++++++------------------ 3 
files changed, 82 insertions(+), 62 deletions(-) diff --git a/tests/test_dataset.py b/tests/test_dataset.py index c664c7b..727a0bd 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -1,8 +1,8 @@ -from collections.abc import Sequence +from collections.abc import Iterator, Sequence import pytest -from text2array import Batches, Dataset +from text2array import Batch, Dataset def test_init(samples): @@ -41,23 +41,30 @@ def test_immutable_seq(self, setup_rng, tuple_dataset): def test_batch(dataset): bs = dataset.batch(2) - assert isinstance(bs, Batches) - assert bs.batch_size == 2 - assert not bs.drop_last + assert isinstance(bs, Iterator) + bs_lst = list(bs) + assert len(bs_lst) == 3 + assert all(isinstance(b, Batch) for b in bs_lst) + assert list(bs_lst[0]) == [dataset[0], dataset[1]] + assert list(bs_lst[1]) == [dataset[2], dataset[3]] + assert list(bs_lst[2]) == [dataset[4]] def test_batch_exactly(dataset): bs = dataset.batch_exactly(2) - assert isinstance(bs, Batches) - assert bs.batch_size == 2 - assert bs.drop_last + assert isinstance(bs, Iterator) + bs_lst = list(bs) + assert len(bs_lst) == 2 + assert all(isinstance(b, Batch) for b in bs_lst) + assert list(bs_lst[0]) == [dataset[0], dataset[1]] + assert list(bs_lst[1]) == [dataset[2], dataset[3]] def test_batch_nonpositive_batch_size(dataset): with pytest.raises(ValueError) as exc: - dataset.batch(0) + next(dataset.batch(0)) assert 'batch size must be greater than 0' in str(exc.value) with pytest.raises(ValueError) as exc: - dataset.batch_exactly(0) + next(dataset.batch_exactly(0)) assert 'batch size must be greater than 0' in str(exc.value) diff --git a/tests/test_stream_dataset.py b/tests/test_stream_dataset.py index ac80f69..f121886 100644 --- a/tests/test_stream_dataset.py +++ b/tests/test_stream_dataset.py @@ -1,8 +1,8 @@ -from collections.abc import Iterable +from collections.abc import Iterable, Iterator import pytest -from text2array import StreamDataset, StreamBatches +from text2array import Batch, StreamDataset def test_init(stream): @@ -26,23 +26,35 @@ def test_can_be_iterated_twice(stream_dataset): def test_batch(stream_dataset): bs = stream_dataset.batch(2) - assert isinstance(bs, StreamBatches) - assert bs.batch_size == 2 - assert not bs.drop_last + assert isinstance(bs, Iterator) + bs_lst = list(bs) + assert len(bs_lst) == 3 + assert all(isinstance(b, Batch) for b in bs_lst) + dat = list(stream_dataset) + assert list(bs_lst[0]) == [dat[0], dat[1]] + assert list(bs_lst[1]) == [dat[2], dat[3]] + assert list(bs_lst[2]) == [dat[4]] def test_batch_exactly(stream_dataset): bs = stream_dataset.batch_exactly(2) - assert isinstance(bs, StreamBatches) - assert bs.batch_size == 2 - assert bs.drop_last + assert isinstance(bs, Iterator) + bs_lst = list(bs) + assert len(bs_lst) == 2 + assert all(isinstance(b, Batch) for b in bs_lst) + dat = list(stream_dataset) + assert list(bs_lst[0]) == [dat[0], dat[1]] + assert list(bs_lst[1]) == [dat[2], dat[3]] + + +# TODO add test for when batch size evenly divides length of samples def test_batch_nonpositive_batch_size(stream_dataset): with pytest.raises(ValueError) as exc: - stream_dataset.batch(0) + next(stream_dataset.batch(0)) assert 'batch size must be greater than 0' in str(exc.value) with pytest.raises(ValueError) as exc: - stream_dataset.batch_exactly(0) + next(stream_dataset.batch_exactly(0)) assert 'batch size must be greater than 0' in str(exc.value) diff --git a/text2array/datasets.py b/text2array/datasets.py index eae949c..b5872e5 100644 --- a/text2array/datasets.py 
+++ b/text2array/datasets.py @@ -4,18 +4,30 @@ import abc import random +from .samples import Sample -class DatasetABC(Iterable[int], metaclass=abc.ABCMeta): # pragma: no cover - @abc.abstractmethod - def batch(self, batch_size: int) -> 'BatchesABC': - pass +class DatasetABC(Iterable[Sample], metaclass=abc.ABCMeta): # pragma: no cover @abc.abstractmethod - def batch_exactly(self, batch_size: int) -> 'BatchesABC': + def batch(self, batch_size: int) -> Iterator['Batch']: pass + def batch_exactly(self, batch_size: int) -> Iterator['Batch']: + """Group the samples in the dataset into batches of exact size. + + If the number of samples is not divisible by ``batch_size``, the last + batch (whose length is less than ``batch_size``) is dropped. + + Args: + batch_size: Number of samples in each batch. -class Dataset(DatasetABC, Sequence[int]): + Returns: + The iterator of batches. + """ + return (b for b in self.batch(batch_size) if len(b) == batch_size) + + +class Dataset(DatasetABC, Sequence[Sample]): """A dataset that fits in memory (no streaming). Args: @@ -24,7 +36,7 @@ class Dataset(DatasetABC, Sequence[int]): :obj:`slice` object. """ - def __init__(self, samples: Sequence[int]) -> None: + def __init__(self, samples: Sequence[Sample]) -> None: if not isinstance(samples, SequenceABC): raise TypeError('"samples" is not a sequence') @@ -52,30 +64,21 @@ def shuffle(self) -> 'Dataset': self._shuffle_copy() return self - def batch(self, batch_size: int) -> 'Batches': + def batch(self, batch_size: int) -> Iterator['Batch']: """Group the samples in the dataset into batches. Args: batch_size: Maximum number of samples in each batch. Returns: - The batches. + The iterator of batches. """ - return Batches(self, batch_size) - - def batch_exactly(self, batch_size: int) -> 'Batches': - """Group the samples in the dataset into batches of exact size. - - If the length of ``samples`` is not divisible by ``batch_size``, the last - batch (whose length is less than ``batch_size``) is dropped. - - Args: - batch_size: Number of samples in each batch. + if batch_size <= 0: + raise ValueError('batch size must be greater than 0') - Returns: - The batches. - """ - return Batches(self, batch_size, drop_last=True) + for begin in range(0, len(self._samples), batch_size): + end = begin + batch_size + yield Batch(self._samples[begin:end]) def _shuffle_inplace(self) -> None: assert isinstance(self._samples, MutableSequenceABC) @@ -93,14 +96,14 @@ def _shuffle_copy(self) -> None: self._samples = shuf_samples -class StreamDataset(DatasetABC, Iterable[int]): +class StreamDataset(DatasetABC): """A dataset that streams its samples. Args: stream: Stream of examples the dataset should stream from. """ - def __init__(self, stream: Iterable[int]) -> None: + def __init__(self, stream: Iterable[Sample]) -> None: if not isinstance(stream, IterableABC): raise TypeError('"stream" is not iterable') @@ -109,31 +112,29 @@ def __init__(self, stream: Iterable[int]) -> None: def __iter__(self) -> Iterator[int]: return iter(self._stream) - def batch(self, batch_size: int) -> 'StreamBatches': + def batch(self, batch_size: int) -> Iterator['Batch']: """Group the samples in the dataset into batches. Args: batch_size: Maximum number of samples in each batch. Returns: - The batches. + The iterator of batches. """ - return StreamBatches(self, batch_size) - - def batch_exactly(self, batch_size: int) -> 'StreamBatches': - """Group the samples in the dataset into batches of exact size. 
+ if batch_size <= 0: + raise ValueError('batch size must be greater than 0') - If the length of ``samples`` is not divisible by ``batch_size``, the last - batch (whose length is less than ``batch_size``) is dropped. - - Args: - batch_size: Number of samples in each batch. - - Returns: - The batches. - """ - return StreamBatches(self, batch_size, drop_last=True) + it, exhausted = iter(self._stream), False + while not exhausted: + batch: list = [] + while not exhausted and len(batch) < batch_size: + try: + batch.append(next(it)) + except StopIteration: + exhausted = True + if batch: + yield Batch(batch) # Need to import here to avoid circular dependency -from .batches import BatchesABC, Batches, StreamBatches +from .batches import Batch From f50422f1860b0e8b1de82cc4b6d9c861e64d25db Mon Sep 17 00:00:00 2001 From: Kemal Kurniawan Date: Sun, 27 Jan 2019 13:25:04 +0700 Subject: [PATCH 046/162] Remove Batches stuff and refactor --- tests/test_batches.py | 70 ------------------- tests/test_stream_batches.py | 60 ---------------- text2array/__init__.py | 6 +- text2array/batches.py | 132 +---------------------------------- text2array/datasets.py | 17 ++--- text2array/samples.py | 1 + 6 files changed, 12 insertions(+), 274 deletions(-) delete mode 100644 tests/test_batches.py delete mode 100644 tests/test_stream_batches.py diff --git a/tests/test_batches.py b/tests/test_batches.py deleted file mode 100644 index a424226..0000000 --- a/tests/test_batches.py +++ /dev/null @@ -1,70 +0,0 @@ -from collections.abc import Sequence - -import numpy as np -import pytest - -from text2array import Batches - - -class TestInit: - def test_ok(self, dataset): - bs = Batches(dataset, 2) - assert bs.batch_size == 2 - assert not bs.drop_last - assert isinstance(bs, Sequence) - assert len(bs) == 3 - assert bs[0] == [dataset[0], dataset[1]] - assert bs[1] == [dataset[2], dataset[3]] - assert bs[2] == [dataset[4]] - - def test_kwargs(self, dataset): - bs = Batches(dataset, 2, drop_last=True) - assert bs.drop_last - assert len(bs) == 2 - assert bs[0] == [dataset[0], dataset[1]] - assert bs[1] == [dataset[2], dataset[3]] - - def test_nonpositive_batch_size(self, dataset): - with pytest.raises(ValueError) as exc: - Batches(dataset, 0) - assert 'batch size must be greater than 0' in str(exc.value) - - -@pytest.fixture -def batches(dataset): - return Batches(dataset, 2) - - -def test_getitem_negative_index(batches): - n = len(batches) - for i in range(n): - assert batches[-i - 1] == batches[n - i - 1] - - -def test_getitem_index_error(batches): - # index too large - with pytest.raises(IndexError) as exc: - batches[len(batches)] - assert 'index out of range' in str(exc.value) - - # index too small - with pytest.raises(IndexError) as exc: - batches[-len(batches) - 1] - assert 'index out of range' in str(exc.value) - - -@pytest.mark.skip -def test_to_arrays(batches): - ts = batches.to_arrays() - assert isinstance(ts, Sequence) - assert len(ts) == len(batches) - for i in range(len(ts)): - t, b = ts[i], batches[i] - - assert isinstance(t.x, np.ndarray) - assert t.x.dtype == np.int32 - assert t.x.tolist() == list(b.x) - - assert isinstance(t.y, np.ndarray) - assert t.y.dtype == np.int32 - assert t.y.tolist() == list(b.y) diff --git a/tests/test_stream_batches.py b/tests/test_stream_batches.py deleted file mode 100644 index 3c01dac..0000000 --- a/tests/test_stream_batches.py +++ /dev/null @@ -1,60 +0,0 @@ -from collections.abc import Iterable - -import numpy as np -import pytest - -from text2array import StreamBatches - - -class 
TestInit: - def test_ok(self, stream_dataset): - bs = StreamBatches(stream_dataset, 2) - assert bs.batch_size == 2 - assert not bs.drop_last - assert isinstance(bs, Iterable) - dat = list(stream_dataset) - assert list(bs) == [[dat[0], dat[1]], [dat[2], dat[3]], [dat[4]]] - - def test_kwargs(self, stream_dataset): - bs = StreamBatches(stream_dataset, 2, drop_last=True) - assert bs.drop_last - dat = list(stream_dataset) - assert list(bs) == [[dat[0], dat[1]], [dat[2], dat[3]]] - - def test_divisible_length(self, stream_dataset): - bs = StreamBatches(stream_dataset, 1) - dat = list(stream_dataset) - assert list(bs) == [[dat[0]], [dat[1]], [dat[2]], [dat[3]], [dat[4]]] - - def test_nonpositive_batch_size(self, stream_dataset): - with pytest.raises(ValueError) as exc: - StreamBatches(stream_dataset, 0) - assert 'batch size must be greater than 0' in str(exc.value) - - -@pytest.fixture -def stream_batches(stream_dataset): - return StreamBatches(stream_dataset, 2) - - -def test_can_be_iterated_twice(stream_batches): - bs_lst1 = list(stream_batches) - bs_lst2 = list(stream_batches) - assert len(bs_lst1) == len(bs_lst2) - assert len(bs_lst2) > 0 - - -@pytest.mark.skip -def test_to_arrays(stream_batches): - ts = stream_batches.to_arrays() - assert isinstance(ts, Iterable) - - bs = list(stream_batches) - for t, b in zip(ts, bs): - assert isinstance(t.x, np.ndarray) - assert t.x.dtype == np.int32 - assert t.x.tolist() == list(b.x) - - assert isinstance(t.y, np.ndarray) - assert t.y.dtype == np.int32 - assert t.y.tolist() == list(b.y) diff --git a/text2array/__init__.py b/text2array/__init__.py index a28768f..92ebb73 100644 --- a/text2array/__init__.py +++ b/text2array/__init__.py @@ -1,12 +1,10 @@ +from .batches import Batch from .datasets import Dataset, StreamDataset -from .batches import Batch, Batches, StreamBatches from .samples import Sample __all__ = [ Sample, + Batch, Dataset, StreamDataset, - Batch, - Batches, - StreamBatches, ] diff --git a/text2array/batches.py b/text2array/batches.py index d9e05d0..189e8d0 100644 --- a/text2array/batches.py +++ b/text2array/batches.py @@ -1,9 +1,5 @@ -from typing import Iterable, Iterator, Sequence -import abc +from typing import Sequence -import numpy as np - -from .datasets import Dataset, StreamDataset from .samples import FieldValue, Sample @@ -11,7 +7,7 @@ class Batch(Sequence[Sample]): def __init__(self, samples: Sequence[Sample]) -> None: self._samples = samples - def __getitem__(self, index): + def __getitem__(self, index) -> Sample: return self._samples[index] def __len__(self) -> int: @@ -22,127 +18,3 @@ def __getattr__(self, name: str) -> Sequence[FieldValue]: return [s.fields[name] for s in self._samples] except KeyError: raise AttributeError(f"some samples have no field '{name}'") - - -class BatchesABC(Iterable[Batch], metaclass=abc.ABCMeta): # pragma: no cover - @property - @abc.abstractmethod - def batch_size(self) -> int: - pass - - @property - @abc.abstractmethod - def drop_last(self) -> bool: - pass - - @abc.abstractmethod - def to_arrays(self) -> Iterable[np.ndarray]: - pass - - -class Batches(BatchesABC, Sequence[Batch]): - """A class to represent a sequence of minibatches. - - Args: - dataset: Dataset to make batches from. - batch_size: Maximum number of samples in each batch. - drop_last (optional): Whether to drop the last batch when ``batch_size`` does not - evenly divide the length of ``dataset``. 
- """ - - def __init__(self, dataset: Dataset, batch_size: int, drop_last: bool = False) -> None: - if batch_size <= 0: - raise ValueError('batch size must be greater than 0') - - self._dataset = dataset - self._bsize = batch_size - self._drop = drop_last - - @property - def batch_size(self) -> int: - return self._bsize - - @property - def drop_last(self) -> bool: - return self._drop - - def __getitem__(self, index): - if index >= len(self) or index < -len(self): - raise IndexError('index out of range') - if index < 0: - index += len(self) - - begin = index * self._bsize - end = begin + self._bsize - return self._dataset[begin:end] - - def __len__(self) -> int: - q, r = divmod(len(self._dataset), self._bsize) - return q + (1 if q > 0 and not self._drop else 0) - - def to_arrays(self) -> Sequence[np.ndarray]: - """Convert each minibatch into an ndarray. - - Returns: - The sequence of arrays. - """ - ts = [] - for b in self: - t = np.array(b, np.int32) - ts.append(t) - return ts - - -class StreamBatches(BatchesABC, Iterable[Batch]): - """A class to represent an iterable of minibatches. - - Args: - dataset: Dataset to make batches from. - batch_size: Maximum number of samples in each batch. - drop_last (optional): Whether to drop the last batch when ``batch_size`` does not - evenly divide the length of ``dataset``. - """ - - def __init__( - self, dataset: StreamDataset, batch_size: int, drop_last: bool = False) -> None: - if batch_size <= 0: - raise ValueError('batch size must be greater than 0') - - self._dataset = dataset - self._bsize = batch_size - self._drop = drop_last - - @property - def batch_size(self) -> int: - return self._bsize - - @property - def drop_last(self) -> bool: - return self._drop - - def __iter__(self) -> Iterator[Batch]: - it, exhausted = iter(self._dataset), False - while not exhausted: - batch: list = [] - while not exhausted and len(batch) < self._bsize: - try: - batch.append(next(it)) - except StopIteration: - exhausted = True - if len(batch) == self._bsize or (batch and not self._drop): - yield batch - - def to_arrays(self) -> Iterable[np.ndarray]: - """Convert each minibatch into an ndarray. - - Returns: - The iterable of arrays. - """ - return self._StreamArrays(self) - - class _StreamArrays(Iterable[np.ndarray]): - def __init__(self, bs: 'StreamBatches') -> None: - self._bs = bs - - def __iter__(self) -> Iterator[np.ndarray]: - return (np.array(b, np.int32) for b in self._bs) diff --git a/text2array/datasets.py b/text2array/datasets.py index b5872e5..14b0fb4 100644 --- a/text2array/datasets.py +++ b/text2array/datasets.py @@ -4,15 +4,16 @@ import abc import random +from .batches import Batch from .samples import Sample class DatasetABC(Iterable[Sample], metaclass=abc.ABCMeta): # pragma: no cover @abc.abstractmethod - def batch(self, batch_size: int) -> Iterator['Batch']: + def batch(self, batch_size: int) -> Iterator[Batch]: pass - def batch_exactly(self, batch_size: int) -> Iterator['Batch']: + def batch_exactly(self, batch_size: int) -> Iterator[Batch]: """Group the samples in the dataset into batches of exact size. 
If the number of samples is not divisible by ``batch_size``, the last @@ -42,7 +43,7 @@ def __init__(self, samples: Sequence[Sample]) -> None: self._samples = samples - def __getitem__(self, index): + def __getitem__(self, index) -> Sample: return self._samples[index] def __len__(self) -> int: @@ -64,7 +65,7 @@ def shuffle(self) -> 'Dataset': self._shuffle_copy() return self - def batch(self, batch_size: int) -> Iterator['Batch']: + def batch(self, batch_size: int) -> Iterator[Batch]: """Group the samples in the dataset into batches. Args: @@ -109,10 +110,10 @@ def __init__(self, stream: Iterable[Sample]) -> None: self._stream = stream - def __iter__(self) -> Iterator[int]: + def __iter__(self) -> Iterator[Sample]: return iter(self._stream) - def batch(self, batch_size: int) -> Iterator['Batch']: + def batch(self, batch_size: int) -> Iterator[Batch]: """Group the samples in the dataset into batches. Args: @@ -134,7 +135,3 @@ def batch(self, batch_size: int) -> Iterator['Batch']: exhausted = True if batch: yield Batch(batch) - - -# Need to import here to avoid circular dependency -from .batches import Batch diff --git a/text2array/samples.py b/text2array/samples.py index f02b351..4e42f57 100644 --- a/text2array/samples.py +++ b/text2array/samples.py @@ -5,6 +5,7 @@ FieldValue = Union[str, float, int] +# TODO properly ignore abstractmethod from coverage class Sample(metaclass=abc.ABCMeta): @property @abc.abstractmethod From d3fcf77d9b4cc948ca5e7999d21df909ef57eaa7 Mon Sep 17 00:00:00 2001 From: Kemal Kurniawan Date: Sun, 27 Jan 2019 13:29:24 +0700 Subject: [PATCH 047/162] Add tests for when batch size evenly divides number of samples --- tests/test_dataset.py | 8 ++++++++ tests/test_stream_dataset.py | 12 +++++++++--- 2 files changed, 17 insertions(+), 3 deletions(-) diff --git a/tests/test_dataset.py b/tests/test_dataset.py index 727a0bd..1f0e73a 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -50,6 +50,14 @@ def test_batch(dataset): assert list(bs_lst[2]) == [dataset[4]] +def test_batch_size_evenly_divides(dataset): + bs = dataset.batch(1) + bs_lst = list(bs) + assert len(bs_lst) == len(dataset) + for i in range(len(bs_lst)): + assert list(bs_lst[i]) == [dataset[i]] + + def test_batch_exactly(dataset): bs = dataset.batch_exactly(2) assert isinstance(bs, Iterator) diff --git a/tests/test_stream_dataset.py b/tests/test_stream_dataset.py index f121886..5dc8ee2 100644 --- a/tests/test_stream_dataset.py +++ b/tests/test_stream_dataset.py @@ -36,6 +36,15 @@ def test_batch(stream_dataset): assert list(bs_lst[2]) == [dat[4]] +def test_batch_size_evenly_divides(stream_dataset): + bs = stream_dataset.batch(1) + dat = list(stream_dataset) + bs_lst = list(bs) + assert len(bs_lst) == len(dat) + for i in range(len(bs_lst)): + assert list(bs_lst[i]) == [dat[i]] + + def test_batch_exactly(stream_dataset): bs = stream_dataset.batch_exactly(2) assert isinstance(bs, Iterator) @@ -47,9 +56,6 @@ def test_batch_exactly(stream_dataset): assert list(bs_lst[1]) == [dat[2], dat[3]] -# TODO add test for when batch size evenly divides length of samples - - def test_batch_nonpositive_batch_size(stream_dataset): with pytest.raises(ValueError) as exc: next(stream_dataset.batch(0)) From c4e255a8c38ad4fbbb3d3f4c2414d95c2b14feaa Mon Sep 17 00:00:00 2001 From: Kemal Kurniawan Date: Sun, 27 Jan 2019 13:36:51 +0700 Subject: [PATCH 048/162] Properly ignore abstract methods from coverage --- .coveragerc | 6 ++++++ text2array/datasets.py | 2 +- text2array/samples.py | 1 - 3 files changed, 7 insertions(+), 2 
deletions(-) create mode 100644 .coveragerc diff --git a/.coveragerc b/.coveragerc new file mode 100644 index 0000000..db4d06e --- /dev/null +++ b/.coveragerc @@ -0,0 +1,6 @@ +[report] +exclude_lines = + # Re-enable standard pragma + pragma: no cover + # abstract method + abstractmethod \ No newline at end of file diff --git a/text2array/datasets.py b/text2array/datasets.py index 14b0fb4..209e24c 100644 --- a/text2array/datasets.py +++ b/text2array/datasets.py @@ -8,7 +8,7 @@ from .samples import Sample -class DatasetABC(Iterable[Sample], metaclass=abc.ABCMeta): # pragma: no cover +class DatasetABC(Iterable[Sample], metaclass=abc.ABCMeta): @abc.abstractmethod def batch(self, batch_size: int) -> Iterator[Batch]: pass diff --git a/text2array/samples.py b/text2array/samples.py index 4e42f57..f02b351 100644 --- a/text2array/samples.py +++ b/text2array/samples.py @@ -5,7 +5,6 @@ FieldValue = Union[str, float, int] -# TODO properly ignore abstractmethod from coverage class Sample(metaclass=abc.ABCMeta): @property @abc.abstractmethod From adfc724a17d77f0a8ac8f453c4b95bb78a528dc7 Mon Sep 17 00:00:00 2001 From: Kemal Kurniawan Date: Sun, 27 Jan 2019 13:40:00 +0700 Subject: [PATCH 049/162] Measure branch coverage as well --- .coveragerc | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.coveragerc b/.coveragerc index db4d06e..2a69529 100644 --- a/.coveragerc +++ b/.coveragerc @@ -1,3 +1,6 @@ +[run] +branch = True + [report] exclude_lines = # Re-enable standard pragma From 436886bf6d975d3dc273f552f1e0aa301cba43ae Mon Sep 17 00:00:00 2001 From: Kemal Kurniawan Date: Sun, 27 Jan 2019 13:43:10 +0700 Subject: [PATCH 050/162] Change Sample class name to SampleABC for consistency --- tests/conftest.py | 4 ++-- text2array/__init__.py | 4 ++-- text2array/batches.py | 8 ++++---- text2array/datasets.py | 14 +++++++------- text2array/samples.py | 2 +- 5 files changed, 16 insertions(+), 16 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 5b2df53..0257b6a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,6 +1,6 @@ import pytest -from text2array import Dataset, Sample, StreamDataset +from text2array import Dataset, SampleABC, StreamDataset @pytest.fixture @@ -29,7 +29,7 @@ def stream_dataset(stream): return StreamDataset(stream) -class TestSample(Sample): +class TestSample(SampleABC): def __init__(self, x, y): self.x = x self.y = y diff --git a/text2array/__init__.py b/text2array/__init__.py index 92ebb73..2d2eaef 100644 --- a/text2array/__init__.py +++ b/text2array/__init__.py @@ -1,9 +1,9 @@ from .batches import Batch from .datasets import Dataset, StreamDataset -from .samples import Sample +from .samples import SampleABC __all__ = [ - Sample, + SampleABC, Batch, Dataset, StreamDataset, diff --git a/text2array/batches.py b/text2array/batches.py index 189e8d0..3f6e6db 100644 --- a/text2array/batches.py +++ b/text2array/batches.py @@ -1,13 +1,13 @@ from typing import Sequence -from .samples import FieldValue, Sample +from .samples import FieldValue, SampleABC -class Batch(Sequence[Sample]): - def __init__(self, samples: Sequence[Sample]) -> None: +class Batch(Sequence[SampleABC]): + def __init__(self, samples: Sequence[SampleABC]) -> None: self._samples = samples - def __getitem__(self, index) -> Sample: + def __getitem__(self, index) -> SampleABC: return self._samples[index] def __len__(self) -> int: diff --git a/text2array/datasets.py b/text2array/datasets.py index 209e24c..9f73873 100644 --- a/text2array/datasets.py +++ b/text2array/datasets.py @@ -5,10 +5,10 @@ import random 
from .batches import Batch -from .samples import Sample +from .samples import SampleABC -class DatasetABC(Iterable[Sample], metaclass=abc.ABCMeta): +class DatasetABC(Iterable[SampleABC], metaclass=abc.ABCMeta): @abc.abstractmethod def batch(self, batch_size: int) -> Iterator[Batch]: pass @@ -28,7 +28,7 @@ def batch_exactly(self, batch_size: int) -> Iterator[Batch]: return (b for b in self.batch(batch_size) if len(b) == batch_size) -class Dataset(DatasetABC, Sequence[Sample]): +class Dataset(DatasetABC, Sequence[SampleABC]): """A dataset that fits in memory (no streaming). Args: @@ -37,13 +37,13 @@ class Dataset(DatasetABC, Sequence[Sample]): :obj:`slice` object. """ - def __init__(self, samples: Sequence[Sample]) -> None: + def __init__(self, samples: Sequence[SampleABC]) -> None: if not isinstance(samples, SequenceABC): raise TypeError('"samples" is not a sequence') self._samples = samples - def __getitem__(self, index) -> Sample: + def __getitem__(self, index) -> SampleABC: return self._samples[index] def __len__(self) -> int: @@ -104,13 +104,13 @@ class StreamDataset(DatasetABC): stream: Stream of examples the dataset should stream from. """ - def __init__(self, stream: Iterable[Sample]) -> None: + def __init__(self, stream: Iterable[SampleABC]) -> None: if not isinstance(stream, IterableABC): raise TypeError('"stream" is not iterable') self._stream = stream - def __iter__(self) -> Iterator[Sample]: + def __iter__(self) -> Iterator[SampleABC]: return iter(self._stream) def batch(self, batch_size: int) -> Iterator[Batch]: diff --git a/text2array/samples.py b/text2array/samples.py index f02b351..1fee2ac 100644 --- a/text2array/samples.py +++ b/text2array/samples.py @@ -5,7 +5,7 @@ FieldValue = Union[str, float, int] -class Sample(metaclass=abc.ABCMeta): +class SampleABC(metaclass=abc.ABCMeta): @property @abc.abstractmethod def fields(self) -> Mapping[FieldName, FieldValue]: From 65fefaf0bac2151c842ae23440f102304f20ebe3 Mon Sep 17 00:00:00 2001 From: Kemal Kurniawan Date: Sun, 27 Jan 2019 13:45:36 +0700 Subject: [PATCH 051/162] Add docstring to Batch class --- text2array/batches.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/text2array/batches.py b/text2array/batches.py index 3f6e6db..195581d 100644 --- a/text2array/batches.py +++ b/text2array/batches.py @@ -4,6 +4,12 @@ class Batch(Sequence[SampleABC]): + """A class to represent a single batch.i + + Args: + samples: Sequence of samples this batch should contain. 
+ """ + def __init__(self, samples: Sequence[SampleABC]) -> None: self._samples = samples From 77d3099a3d4f9ff71ebd50940da98ead224909be Mon Sep 17 00:00:00 2001 From: Kemal Kurniawan Date: Sun, 27 Jan 2019 14:25:50 +0700 Subject: [PATCH 052/162] Implement Batch.to_array() --- tests/test_batch.py | 29 ++++++++++++++++++++++++++++- text2array/__init__.py | 3 ++- text2array/batches.py | 34 ++++++++++++++++++++++++++++++++-- text2array/samples.py | 4 ++-- 4 files changed, 64 insertions(+), 6 deletions(-) diff --git a/tests/test_batch.py b/tests/test_batch.py index 03013a3..38b6e2a 100644 --- a/tests/test_batch.py +++ b/tests/test_batch.py @@ -1,8 +1,9 @@ from collections.abc import Sequence +import numpy as np import pytest -from text2array import Batch +from text2array import Batch, BatchArray def test_init(samples): @@ -34,3 +35,29 @@ def test_getattr_invalid_name(batch): with pytest.raises(AttributeError) as exc: batch.z assert "some samples have no field 'z'" in str(exc.value) + + +def test_to_array(batch): + arr = batch.to_array() + assert isinstance(arr, BatchArray) + + assert isinstance(arr.x, np.ndarray) + assert arr.x.tolist() == list(batch.x) + assert isinstance(arr.y, np.ndarray) + assert arr.y.tolist() == list(batch.y) + + +def test_to_array_no_common_field_names(samples): + from text2array import SampleABC + + class FooSample(SampleABC): + @property + def fields(self): + return {'z': 10} + + samples.append(FooSample()) + batch = Batch(samples) + + with pytest.raises(RuntimeError) as exc: + batch.to_array() + assert 'some samples have no common field names with the others' in str(exc.value) diff --git a/text2array/__init__.py b/text2array/__init__.py index 2d2eaef..754319e 100644 --- a/text2array/__init__.py +++ b/text2array/__init__.py @@ -1,10 +1,11 @@ -from .batches import Batch +from .batches import Batch, BatchArray from .datasets import Dataset, StreamDataset from .samples import SampleABC __all__ = [ SampleABC, Batch, + BatchArray, Dataset, StreamDataset, ] diff --git a/text2array/batches.py b/text2array/batches.py index 195581d..c6b990f 100644 --- a/text2array/batches.py +++ b/text2array/batches.py @@ -1,6 +1,8 @@ -from typing import Sequence +from typing import Sequence, Set -from .samples import FieldValue, SampleABC +import numpy as np + +from .samples import FieldName, FieldValue, SampleABC class Batch(Sequence[SampleABC]): @@ -24,3 +26,31 @@ def __getattr__(self, name: str) -> Sequence[FieldValue]: return [s.fields[name] for s in self._samples] except KeyError: raise AttributeError(f"some samples have no field '{name}'") + + def to_array(self) -> 'BatchArray': + """Convert the batch into numpy array. + + Returns: + A :class:`BatchArray` object that has attribute names matching those of + the field names in ``samples``. The value of such attribute is an array + whose first dimension corresponds to the batch size as returned by + ``__len__``. 
+ """ + common: Set[FieldName] = set() + for s in self._samples: + if common: + common.intersection_update(s.fields) + else: + common = set(s.fields) + + if not common: + raise RuntimeError('some samples have no common field names with the others') + + arr = BatchArray() + for name in common: + setattr(arr, name, np.array(getattr(self, name))) + return arr + + +class BatchArray: + pass diff --git a/text2array/samples.py b/text2array/samples.py index 1fee2ac..22dfece 100644 --- a/text2array/samples.py +++ b/text2array/samples.py @@ -1,8 +1,8 @@ -from typing import Mapping, Union +from typing import Mapping import abc FieldName = str -FieldValue = Union[str, float, int] +FieldValue = int class SampleABC(metaclass=abc.ABCMeta): From bdce90a474ff7f43158e19800be32e80f5c9bc50 Mon Sep 17 00:00:00 2001 From: Kemal Kurniawan Date: Sun, 27 Jan 2019 14:28:00 +0700 Subject: [PATCH 053/162] Fix typo --- text2array/batches.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/text2array/batches.py b/text2array/batches.py index c6b990f..aca47d2 100644 --- a/text2array/batches.py +++ b/text2array/batches.py @@ -6,7 +6,7 @@ class Batch(Sequence[SampleABC]): - """A class to represent a single batch.i + """A class to represent a single batch. Args: samples: Sequence of samples this batch should contain. From b82c82550a64317701cd60b3c44ca26269762fd2 Mon Sep 17 00:00:00 2001 From: Kemal Kurniawan Date: Sun, 27 Jan 2019 14:48:45 +0700 Subject: [PATCH 054/162] Allow FieldValue to be float --- tests/conftest.py | 2 +- tests/test_batch.py | 6 ++++-- text2array/samples.py | 4 ++-- 3 files changed, 7 insertions(+), 5 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 0257b6a..014ee3a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -11,7 +11,7 @@ def setup_rng(): @pytest.fixture def samples(): - return [TestSample(i, i * i) for i in range(5)] + return [TestSample(i, (i + 1) / 3) for i in range(5)] @pytest.fixture diff --git a/tests/test_batch.py b/tests/test_batch.py index 38b6e2a..6f86551 100644 --- a/tests/test_batch.py +++ b/tests/test_batch.py @@ -28,7 +28,7 @@ def test_getattr(batch): assert isinstance(batch.y, Sequence) assert len(batch.y) == len(batch) for i in range(len(batch)): - assert batch.y[i] == batch[i].y + assert batch.y[i] == pytest.approx(batch[i].y) def test_getattr_invalid_name(batch): @@ -44,7 +44,9 @@ def test_to_array(batch): assert isinstance(arr.x, np.ndarray) assert arr.x.tolist() == list(batch.x) assert isinstance(arr.y, np.ndarray) - assert arr.y.tolist() == list(batch.y) + assert arr.y.shape[0] == len(batch) + for i in range(len(batch)): + assert arr.y[i] == pytest.approx(batch[i].y) def test_to_array_no_common_field_names(samples): diff --git a/text2array/samples.py b/text2array/samples.py index 22dfece..c87994a 100644 --- a/text2array/samples.py +++ b/text2array/samples.py @@ -1,8 +1,8 @@ -from typing import Mapping +from typing import Mapping, Union import abc FieldName = str -FieldValue = int +FieldValue = Union[float, int] class SampleABC(metaclass=abc.ABCMeta): From 371776660a140982e2f55496ebb7ccb142209e25 Mon Sep 17 00:00:00 2001 From: Kemal Kurniawan Date: Sun, 27 Jan 2019 15:17:22 +0700 Subject: [PATCH 055/162] Initial attempt to let FieldValue be str --- tests/conftest.py | 13 +++++++++---- tests/test_batch.py | 18 +++++++++++++++--- text2array/batches.py | 10 +++++++--- text2array/samples.py | 2 +- 4 files changed, 32 insertions(+), 11 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 014ee3a..0a1ca54 100644 --- 
a/tests/conftest.py +++ b/tests/conftest.py @@ -11,7 +11,7 @@ def setup_rng(): @pytest.fixture def samples(): - return [TestSample(i, (i + 1) / 3) for i in range(5)] + return [TestSample(i, (i + 1) / 3, f'word-{i}') for i in range(5)] @pytest.fixture @@ -30,16 +30,21 @@ def stream_dataset(stream): class TestSample(SampleABC): - def __init__(self, x, y): + def __init__(self, x, y, z): self.x = x self.y = y + self.z = z def __lt__(self, s): - return self.x < s.x or (self.x == s.x and self.y < s.y) + if self.x != s.x: + return self.x < s.x + if self.y != s.y: + return self.y < s.y + return self.z < s.z @property def fields(self): - return {'x': self.x, 'y': self.y} + return {'x': self.x, 'y': self.y, 'z': self.z} class Stream: diff --git a/tests/test_batch.py b/tests/test_batch.py index 6f86551..5195474 100644 --- a/tests/test_batch.py +++ b/tests/test_batch.py @@ -33,8 +33,8 @@ def test_getattr(batch): def test_getattr_invalid_name(batch): with pytest.raises(AttributeError) as exc: - batch.z - assert "some samples have no field 'z'" in str(exc.value) + batch.foo + assert "some samples have no field 'foo'" in str(exc.value) def test_to_array(batch): @@ -47,6 +47,18 @@ def test_to_array(batch): assert arr.y.shape[0] == len(batch) for i in range(len(batch)): assert arr.y[i] == pytest.approx(batch[i].y) + assert isinstance(arr.z, np.ndarray) + assert arr.z.tolist() == list(batch.z) + + +def test_to_array_with_stoi(batch): + stoi = {'word-0': 0, 'word-1': 1, 'word-2': 2, 'word-3': 3, 'word-4': 4} + arr = batch.to_array(stoi=stoi) + assert arr.z.dtype.name.startswith('int') + assert arr.z.tolist() == [stoi[s.z] for s in batch] + + +# TODO what if more than one fields are str? def test_to_array_no_common_field_names(samples): @@ -55,7 +67,7 @@ def test_to_array_no_common_field_names(samples): class FooSample(SampleABC): @property def fields(self): - return {'z': 10} + return {'foo': 10} samples.append(FooSample()) batch = Batch(samples) diff --git a/text2array/batches.py b/text2array/batches.py index aca47d2..25c11e3 100644 --- a/text2array/batches.py +++ b/text2array/batches.py @@ -1,4 +1,4 @@ -from typing import Sequence, Set +from typing import Mapping, Optional, Sequence, Set import numpy as np @@ -27,7 +27,7 @@ def __getattr__(self, name: str) -> Sequence[FieldValue]: except KeyError: raise AttributeError(f"some samples have no field '{name}'") - def to_array(self) -> 'BatchArray': + def to_array(self, stoi: Optional[Mapping[str, int]] = None) -> 'BatchArray': """Convert the batch into numpy array. 
Returns: @@ -48,7 +48,11 @@ def to_array(self) -> 'BatchArray': arr = BatchArray() for name in common: - setattr(arr, name, np.array(getattr(self, name))) + if type(getattr(self._samples[0], name)) == str and stoi is not None: + vs = [stoi[v] for v in getattr(self, name)] + else: + vs = getattr(self, name) + setattr(arr, name, np.array(vs)) return arr diff --git a/text2array/samples.py b/text2array/samples.py index c87994a..c147f5c 100644 --- a/text2array/samples.py +++ b/text2array/samples.py @@ -2,7 +2,7 @@ import abc FieldName = str -FieldValue = Union[float, int] +FieldValue = Union[float, int, str] class SampleABC(metaclass=abc.ABCMeta): From 2bdd1b1eefba8da4ecb5155cb47fce95f9a1a9d0 Mon Sep 17 00:00:00 2001 From: Kemal Kurniawan Date: Sun, 27 Jan 2019 15:32:47 +0700 Subject: [PATCH 056/162] Handle when there are more than one str fields --- tests/conftest.py | 11 +++++++---- tests/test_batch.py | 18 +++++++++++------- text2array/batches.py | 14 +++++++++++--- 3 files changed, 29 insertions(+), 14 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 0a1ca54..cf26b44 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -11,7 +11,7 @@ def setup_rng(): @pytest.fixture def samples(): - return [TestSample(i, (i + 1) / 3, f'word-{i}') for i in range(5)] + return [TestSample(i, (i + 1) / 3, f'word-{i}', f'token-{i}') for i in range(5)] @pytest.fixture @@ -30,21 +30,24 @@ def stream_dataset(stream): class TestSample(SampleABC): - def __init__(self, x, y, z): + def __init__(self, x, y, z, w): self.x = x self.y = y self.z = z + self.w = w def __lt__(self, s): if self.x != s.x: return self.x < s.x if self.y != s.y: return self.y < s.y - return self.z < s.z + if self.z != s.z: + return self.z < s.z + return self.w < s.w @property def fields(self): - return {'x': self.x, 'y': self.y, 'z': self.z} + return {'x': self.x, 'y': self.y, 'z': self.z, 'w': self.w} class Stream: diff --git a/tests/test_batch.py b/tests/test_batch.py index 5195474..6c798ca 100644 --- a/tests/test_batch.py +++ b/tests/test_batch.py @@ -51,14 +51,18 @@ def test_to_array(batch): assert arr.z.tolist() == list(batch.z) -def test_to_array_with_stoi(batch): - stoi = {'word-0': 0, 'word-1': 1, 'word-2': 2, 'word-3': 3, 'word-4': 4} - arr = batch.to_array(stoi=stoi) +def test_to_array_with_vocab(batch): + vocab = { + 'z': {v: i + for i, v in enumerate(batch.z)}, + 'w': {v: i + for i, v in enumerate(batch.w)}, + } + arr = batch.to_array(vocab=vocab) assert arr.z.dtype.name.startswith('int') - assert arr.z.tolist() == [stoi[s.z] for s in batch] - - -# TODO what if more than one fields are str? + assert arr.z.tolist() == [vocab['z'][s.z] for s in batch] + assert arr.w.dtype.name.startswith('int') + assert arr.w.tolist() == [vocab['w'][s.w] for s in batch] def test_to_array_no_common_field_names(samples): diff --git a/text2array/batches.py b/text2array/batches.py index 25c11e3..780261b 100644 --- a/text2array/batches.py +++ b/text2array/batches.py @@ -27,7 +27,7 @@ def __getattr__(self, name: str) -> Sequence[FieldValue]: except KeyError: raise AttributeError(f"some samples have no field '{name}'") - def to_array(self, stoi: Optional[Mapping[str, int]] = None) -> 'BatchArray': + def to_array(self, vocab: Optional['Vocab'] = None) -> 'BatchArray': """Convert the batch into numpy array. 
Returns: @@ -45,16 +45,24 @@ def to_array(self, stoi: Optional[Mapping[str, int]] = None) -> 'BatchArray': if not common: raise RuntimeError('some samples have no common field names with the others') + assert self._samples # if `common` isn't empty, neither is `samples` arr = BatchArray() for name in common: - if type(getattr(self._samples[0], name)) == str and stoi is not None: - vs = [stoi[v] for v in getattr(self, name)] + if type(getattr(self._samples[0], name)) == str and vocab is not None: + stoi = vocab.get(name) + if stoi is not None: + vs = [stoi[v] for v in getattr(self, name)] + else: + vs = getattr(self, name) else: vs = getattr(self, name) setattr(arr, name, np.array(vs)) return arr +Vocab = Mapping[FieldName, Mapping[str, int]] + + class BatchArray: pass From afe221f36d3795bab560c016e06f7eec05815495 Mon Sep 17 00:00:00 2001 From: Kemal Kurniawan Date: Sun, 27 Jan 2019 15:50:09 +0700 Subject: [PATCH 057/162] Add a todo --- text2array/batches.py | 1 + 1 file changed, 1 insertion(+) diff --git a/text2array/batches.py b/text2array/batches.py index 780261b..469ba46 100644 --- a/text2array/batches.py +++ b/text2array/batches.py @@ -28,6 +28,7 @@ def __getattr__(self, name: str) -> Sequence[FieldValue]: raise AttributeError(f"some samples have no field '{name}'") def to_array(self, vocab: Optional['Vocab'] = None) -> 'BatchArray': + # TODO fix docstring """Convert the batch into numpy array. Returns: From 1ef946f6877219e0fc9973531cf8efabdc0d19db Mon Sep 17 00:00:00 2001 From: Kemal Kurniawan Date: Sun, 27 Jan 2019 15:54:33 +0700 Subject: [PATCH 058/162] Use clearer field names --- tests/conftest.py | 24 ++++++++++++------------ tests/test_batch.py | 42 +++++++++++++++++++++--------------------- 2 files changed, 33 insertions(+), 33 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index cf26b44..56fec79 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -31,23 +31,23 @@ def stream_dataset(stream): class TestSample(SampleABC): def __init__(self, x, y, z, w): - self.x = x - self.y = y - self.z = z - self.w = w + self.int_ = x + self.float_ = y + self.str1 = z + self.str2 = w def __lt__(self, s): - if self.x != s.x: - return self.x < s.x - if self.y != s.y: - return self.y < s.y - if self.z != s.z: - return self.z < s.z - return self.w < s.w + if self.int_ != s.int_: + return self.int_ < s.int_ + if self.float_ != s.float_: + return self.float_ < s.float_ + if self.str1 != s.str1: + return self.str1 < s.str1 + return self.str2 < s.str2 @property def fields(self): - return {'x': self.x, 'y': self.y, 'z': self.z, 'w': self.w} + return {'int_': self.int_, 'float_': self.float_, 'str1': self.str1, 'str2': self.str2} class Stream: diff --git a/tests/test_batch.py b/tests/test_batch.py index 6c798ca..933738b 100644 --- a/tests/test_batch.py +++ b/tests/test_batch.py @@ -20,15 +20,15 @@ def batch(samples): def test_getattr(batch): - assert isinstance(batch.x, Sequence) - assert len(batch.x) == len(batch) + assert isinstance(batch.int_, Sequence) + assert len(batch.int_) == len(batch) for i in range(len(batch)): - assert batch.x[i] == batch[i].x + assert batch.int_[i] == batch[i].int_ - assert isinstance(batch.y, Sequence) - assert len(batch.y) == len(batch) + assert isinstance(batch.float_, Sequence) + assert len(batch.float_) == len(batch) for i in range(len(batch)): - assert batch.y[i] == pytest.approx(batch[i].y) + assert batch.float_[i] == pytest.approx(batch[i].float_) def test_getattr_invalid_name(batch): @@ -41,28 +41,28 @@ def test_to_array(batch): arr = 
batch.to_array() assert isinstance(arr, BatchArray) - assert isinstance(arr.x, np.ndarray) - assert arr.x.tolist() == list(batch.x) - assert isinstance(arr.y, np.ndarray) - assert arr.y.shape[0] == len(batch) + assert isinstance(arr.int_, np.ndarray) + assert arr.int_.tolist() == list(batch.int_) + assert isinstance(arr.float_, np.ndarray) + assert arr.float_.shape[0] == len(batch) for i in range(len(batch)): - assert arr.y[i] == pytest.approx(batch[i].y) - assert isinstance(arr.z, np.ndarray) - assert arr.z.tolist() == list(batch.z) + assert arr.float_[i] == pytest.approx(batch[i].float_) + assert isinstance(arr.str1, np.ndarray) + assert arr.str1.tolist() == list(batch.str1) def test_to_array_with_vocab(batch): vocab = { - 'z': {v: i - for i, v in enumerate(batch.z)}, - 'w': {v: i - for i, v in enumerate(batch.w)}, + 'str1': {v: i + for i, v in enumerate(batch.str1)}, + 'str2': {v: i + for i, v in enumerate(batch.str2)}, } arr = batch.to_array(vocab=vocab) - assert arr.z.dtype.name.startswith('int') - assert arr.z.tolist() == [vocab['z'][s.z] for s in batch] - assert arr.w.dtype.name.startswith('int') - assert arr.w.tolist() == [vocab['w'][s.w] for s in batch] + assert arr.str1.dtype.name.startswith('int') + assert arr.str1.tolist() == [vocab['str1'][s.str1] for s in batch] + assert arr.str2.dtype.name.startswith('int') + assert arr.str2.tolist() == [vocab['str2'][s.str2] for s in batch] def test_to_array_no_common_field_names(samples): From f83719b854826fe7501854311282bb52f7db7678 Mon Sep 17 00:00:00 2001 From: Kemal Kurniawan Date: Sun, 27 Jan 2019 16:39:04 +0700 Subject: [PATCH 059/162] Don't do vocab in to_array() and use Mapping instead of BatchArray --- tests/conftest.py | 24 +++++++++--------- tests/test_batch.py | 55 +++++++++++++++--------------------------- text2array/__init__.py | 3 +-- text2array/batches.py | 36 ++++++--------------------- 4 files changed, 40 insertions(+), 78 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 56fec79..dd9f155 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -31,23 +31,23 @@ def stream_dataset(stream): class TestSample(SampleABC): def __init__(self, x, y, z, w): - self.int_ = x - self.float_ = y - self.str1 = z - self.str2 = w + self.x = x + self.y = y + self.z = z + self.w = w def __lt__(self, s): - if self.int_ != s.int_: - return self.int_ < s.int_ - if self.float_ != s.float_: - return self.float_ < s.float_ - if self.str1 != s.str1: - return self.str1 < s.str1 - return self.str2 < s.str2 + if self.x != s.x: + return self.x < s.x + if self.y != s.y: + return self.y < s.y + if self.z != s.z: + return self.z < s.z + return self.w < s.w @property def fields(self): - return {'int_': self.int_, 'float_': self.float_, 'str1': self.str1, 'str2': self.str2} + return {'i': self.x, 'f': self.y, 's1': self.z, 's2': self.w} class Stream: diff --git a/tests/test_batch.py b/tests/test_batch.py index 933738b..cf69e44 100644 --- a/tests/test_batch.py +++ b/tests/test_batch.py @@ -1,9 +1,9 @@ -from collections.abc import Sequence +from collections.abc import Mapping, Sequence import numpy as np import pytest -from text2array import Batch, BatchArray +from text2array import Batch def test_init(samples): @@ -19,50 +19,34 @@ def batch(samples): return Batch(samples) -def test_getattr(batch): - assert isinstance(batch.int_, Sequence) - assert len(batch.int_) == len(batch) +def test_get(batch): + assert isinstance(batch.get('i'), Sequence) + assert len(batch.get('i')) == len(batch) for i in range(len(batch)): - assert batch.int_[i] 
== batch[i].int_ + assert batch.get('i')[i] == batch[i].fields['i'] - assert isinstance(batch.float_, Sequence) - assert len(batch.float_) == len(batch) + assert isinstance(batch.get('f'), Sequence) + assert len(batch.get('f')) == len(batch) for i in range(len(batch)): - assert batch.float_[i] == pytest.approx(batch[i].float_) + assert batch.get('f')[i] == pytest.approx(batch[i].fields['f']) -def test_getattr_invalid_name(batch): +def test_get_invalid_name(batch): with pytest.raises(AttributeError) as exc: - batch.foo + batch.get('foo') assert "some samples have no field 'foo'" in str(exc.value) def test_to_array(batch): arr = batch.to_array() - assert isinstance(arr, BatchArray) + assert isinstance(arr, Mapping) - assert isinstance(arr.int_, np.ndarray) - assert arr.int_.tolist() == list(batch.int_) - assert isinstance(arr.float_, np.ndarray) - assert arr.float_.shape[0] == len(batch) + assert isinstance(arr['i'], np.ndarray) + assert arr['i'].tolist() == list(batch.get('i')) + assert isinstance(arr['f'], np.ndarray) + assert arr['f'].shape[0] == len(batch) for i in range(len(batch)): - assert arr.float_[i] == pytest.approx(batch[i].float_) - assert isinstance(arr.str1, np.ndarray) - assert arr.str1.tolist() == list(batch.str1) - - -def test_to_array_with_vocab(batch): - vocab = { - 'str1': {v: i - for i, v in enumerate(batch.str1)}, - 'str2': {v: i - for i, v in enumerate(batch.str2)}, - } - arr = batch.to_array(vocab=vocab) - assert arr.str1.dtype.name.startswith('int') - assert arr.str1.tolist() == [vocab['str1'][s.str1] for s in batch] - assert arr.str2.dtype.name.startswith('int') - assert arr.str2.tolist() == [vocab['str2'][s.str2] for s in batch] + assert arr['f'][i] == pytest.approx(batch[i].fields['f']) def test_to_array_no_common_field_names(samples): @@ -73,8 +57,9 @@ class FooSample(SampleABC): def fields(self): return {'foo': 10} - samples.append(FooSample()) - batch = Batch(samples) + samples_ = list(samples) + samples_.append(FooSample()) + batch = Batch(samples_) with pytest.raises(RuntimeError) as exc: batch.to_array() diff --git a/text2array/__init__.py b/text2array/__init__.py index 754319e..2d2eaef 100644 --- a/text2array/__init__.py +++ b/text2array/__init__.py @@ -1,11 +1,10 @@ -from .batches import Batch, BatchArray +from .batches import Batch from .datasets import Dataset, StreamDataset from .samples import SampleABC __all__ = [ SampleABC, Batch, - BatchArray, Dataset, StreamDataset, ] diff --git a/text2array/batches.py b/text2array/batches.py index 469ba46..05c4f12 100644 --- a/text2array/batches.py +++ b/text2array/batches.py @@ -1,4 +1,4 @@ -from typing import Mapping, Optional, Sequence, Set +from typing import Mapping, Sequence, Set import numpy as np @@ -21,21 +21,18 @@ def __getitem__(self, index) -> SampleABC: def __len__(self) -> int: return len(self._samples) - def __getattr__(self, name: str) -> Sequence[FieldValue]: + def get(self, name: str) -> Sequence[FieldValue]: try: return [s.fields[name] for s in self._samples] except KeyError: raise AttributeError(f"some samples have no field '{name}'") - def to_array(self, vocab: Optional['Vocab'] = None) -> 'BatchArray': - # TODO fix docstring - """Convert the batch into numpy array. + def to_array(self) -> Mapping[FieldName, np.ndarray]: + """Convert the batch into :class:`np.ndarray`. Returns: - A :class:`BatchArray` object that has attribute names matching those of - the field names in ``samples``. 
The value of such attribute is an array - whose first dimension corresponds to the batch size as returned by - ``__len__``. + A mapping from field names to :class:`np.ndarray`s whose first + dimension corresponds to the batch size as returned by ``__len__``. """ common: Set[FieldName] = set() for s in self._samples: @@ -46,24 +43,5 @@ def to_array(self, vocab: Optional['Vocab'] = None) -> 'BatchArray': if not common: raise RuntimeError('some samples have no common field names with the others') - assert self._samples # if `common` isn't empty, neither is `samples` - arr = BatchArray() - for name in common: - if type(getattr(self._samples[0], name)) == str and vocab is not None: - stoi = vocab.get(name) - if stoi is not None: - vs = [stoi[v] for v in getattr(self, name)] - else: - vs = getattr(self, name) - else: - vs = getattr(self, name) - setattr(arr, name, np.array(vs)) - return arr - - -Vocab = Mapping[FieldName, Mapping[str, int]] - - -class BatchArray: - pass + return {name: np.array(self.get(name)) for name in common} From 595a9cfcc12730bd84a9a2bcdf36747457ad242c Mon Sep 17 00:00:00 2001 From: Kemal Kurniawan Date: Sun, 27 Jan 2019 17:00:45 +0700 Subject: [PATCH 060/162] Sample is just a mapping from field name to field value --- tests/conftest.py | 25 ++----------------------- tests/test_batch.py | 15 ++++----------- tests/test_dataset.py | 2 +- text2array/__init__.py | 4 ++-- text2array/batches.py | 14 +++++++------- text2array/datasets.py | 14 +++++++------- text2array/samples.py | 9 +-------- 7 files changed, 24 insertions(+), 59 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index dd9f155..150dcad 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,6 +1,6 @@ import pytest -from text2array import Dataset, SampleABC, StreamDataset +from text2array import Dataset, StreamDataset @pytest.fixture @@ -11,7 +11,7 @@ def setup_rng(): @pytest.fixture def samples(): - return [TestSample(i, (i + 1) / 3, f'word-{i}', f'token-{i}') for i in range(5)] + return [{'i': i, 'f': (i + 1) / 3} for i in range(5)] @pytest.fixture @@ -29,27 +29,6 @@ def stream_dataset(stream): return StreamDataset(stream) -class TestSample(SampleABC): - def __init__(self, x, y, z, w): - self.x = x - self.y = y - self.z = z - self.w = w - - def __lt__(self, s): - if self.x != s.x: - return self.x < s.x - if self.y != s.y: - return self.y < s.y - if self.z != s.z: - return self.z < s.z - return self.w < s.w - - @property - def fields(self): - return {'i': self.x, 'f': self.y, 's1': self.z, 's2': self.w} - - class Stream: def __init__(self, samples): self.samples = samples diff --git a/tests/test_batch.py b/tests/test_batch.py index cf69e44..3606cb4 100644 --- a/tests/test_batch.py +++ b/tests/test_batch.py @@ -23,12 +23,12 @@ def test_get(batch): assert isinstance(batch.get('i'), Sequence) assert len(batch.get('i')) == len(batch) for i in range(len(batch)): - assert batch.get('i')[i] == batch[i].fields['i'] + assert batch.get('i')[i] == batch[i]['i'] assert isinstance(batch.get('f'), Sequence) assert len(batch.get('f')) == len(batch) for i in range(len(batch)): - assert batch.get('f')[i] == pytest.approx(batch[i].fields['f']) + assert batch.get('f')[i] == pytest.approx(batch[i]['f']) def test_get_invalid_name(batch): @@ -46,19 +46,12 @@ def test_to_array(batch): assert isinstance(arr['f'], np.ndarray) assert arr['f'].shape[0] == len(batch) for i in range(len(batch)): - assert arr['f'][i] == pytest.approx(batch[i].fields['f']) + assert arr['f'][i] == pytest.approx(batch[i]['f']) def 
test_to_array_no_common_field_names(samples): - from text2array import SampleABC - - class FooSample(SampleABC): - @property - def fields(self): - return {'foo': 10} - samples_ = list(samples) - samples_.append(FooSample()) + samples_.append({'foo': 10}) batch = Batch(samples_) with pytest.raises(RuntimeError) as exc: diff --git a/tests/test_dataset.py b/tests/test_dataset.py index 1f0e73a..aa0abc3 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -30,7 +30,7 @@ def assert_shuffle(self, dataset): after = list(dataset) assert retval is dataset - assert before != after and sorted(before) == sorted(after) + assert before != after and len(before) == len(after) and all(x in after for x in before) def test_mutable_seq(self, setup_rng, dataset): self.assert_shuffle(dataset) diff --git a/text2array/__init__.py b/text2array/__init__.py index 2d2eaef..92ebb73 100644 --- a/text2array/__init__.py +++ b/text2array/__init__.py @@ -1,9 +1,9 @@ from .batches import Batch from .datasets import Dataset, StreamDataset -from .samples import SampleABC +from .samples import Sample __all__ = [ - SampleABC, + Sample, Batch, Dataset, StreamDataset, diff --git a/text2array/batches.py b/text2array/batches.py index 05c4f12..fea21ad 100644 --- a/text2array/batches.py +++ b/text2array/batches.py @@ -2,20 +2,20 @@ import numpy as np -from .samples import FieldName, FieldValue, SampleABC +from .samples import FieldName, FieldValue, Sample -class Batch(Sequence[SampleABC]): +class Batch(Sequence[Sample]): """A class to represent a single batch. Args: samples: Sequence of samples this batch should contain. """ - def __init__(self, samples: Sequence[SampleABC]) -> None: + def __init__(self, samples: Sequence[Sample]) -> None: self._samples = samples - def __getitem__(self, index) -> SampleABC: + def __getitem__(self, index) -> Sample: return self._samples[index] def __len__(self) -> int: @@ -23,7 +23,7 @@ def __len__(self) -> int: def get(self, name: str) -> Sequence[FieldValue]: try: - return [s.fields[name] for s in self._samples] + return [s[name] for s in self._samples] except KeyError: raise AttributeError(f"some samples have no field '{name}'") @@ -37,9 +37,9 @@ def to_array(self) -> Mapping[FieldName, np.ndarray]: common: Set[FieldName] = set() for s in self._samples: if common: - common.intersection_update(s.fields) + common.intersection_update(s) else: - common = set(s.fields) + common = set(s) if not common: raise RuntimeError('some samples have no common field names with the others') diff --git a/text2array/datasets.py b/text2array/datasets.py index 9f73873..209e24c 100644 --- a/text2array/datasets.py +++ b/text2array/datasets.py @@ -5,10 +5,10 @@ import random from .batches import Batch -from .samples import SampleABC +from .samples import Sample -class DatasetABC(Iterable[SampleABC], metaclass=abc.ABCMeta): +class DatasetABC(Iterable[Sample], metaclass=abc.ABCMeta): @abc.abstractmethod def batch(self, batch_size: int) -> Iterator[Batch]: pass @@ -28,7 +28,7 @@ def batch_exactly(self, batch_size: int) -> Iterator[Batch]: return (b for b in self.batch(batch_size) if len(b) == batch_size) -class Dataset(DatasetABC, Sequence[SampleABC]): +class Dataset(DatasetABC, Sequence[Sample]): """A dataset that fits in memory (no streaming). Args: @@ -37,13 +37,13 @@ class Dataset(DatasetABC, Sequence[SampleABC]): :obj:`slice` object. 
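(Editorial illustration, not part of the patch series: with this change a sample is just a mapping, so the whole pipeline runs on plain dicts. A minimal sketch assuming the Dataset and Batch APIs as they stand after this patch.)

    from text2array import Dataset

    ds = Dataset([{'i': i, 'f': i / 2} for i in range(5)])
    for b in ds.batch(2):   # yields three Batch objects of sizes 2, 2, 1
        arr = b.to_array()  # mapping from field name to ndarray
        print(arr['i'].tolist(), arr['f'].tolist())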
""" - def __init__(self, samples: Sequence[SampleABC]) -> None: + def __init__(self, samples: Sequence[Sample]) -> None: if not isinstance(samples, SequenceABC): raise TypeError('"samples" is not a sequence') self._samples = samples - def __getitem__(self, index) -> SampleABC: + def __getitem__(self, index) -> Sample: return self._samples[index] def __len__(self) -> int: @@ -104,13 +104,13 @@ class StreamDataset(DatasetABC): stream: Stream of examples the dataset should stream from. """ - def __init__(self, stream: Iterable[SampleABC]) -> None: + def __init__(self, stream: Iterable[Sample]) -> None: if not isinstance(stream, IterableABC): raise TypeError('"stream" is not iterable') self._stream = stream - def __iter__(self) -> Iterator[SampleABC]: + def __iter__(self) -> Iterator[Sample]: return iter(self._stream) def batch(self, batch_size: int) -> Iterator[Batch]: diff --git a/text2array/samples.py b/text2array/samples.py index c147f5c..cf99b88 100644 --- a/text2array/samples.py +++ b/text2array/samples.py @@ -1,12 +1,5 @@ from typing import Mapping, Union -import abc FieldName = str FieldValue = Union[float, int, str] - - -class SampleABC(metaclass=abc.ABCMeta): - @property - @abc.abstractmethod - def fields(self) -> Mapping[FieldName, FieldValue]: - pass +Sample = Mapping[FieldName, FieldValue] From 13970ac25c4de73741658b53e52c90708956911e Mon Sep 17 00:00:00 2001 From: Kemal Kurniawan Date: Sun, 27 Jan 2019 18:14:02 +0700 Subject: [PATCH 061/162] Allow FieldValue to be a sequence of floats or ints --- tests/conftest.py | 7 ++++++- tests/test_batch.py | 30 +++++++++++++++++++++++++++++- text2array/batches.py | 29 +++++++++++++++++++++++++++-- text2array/samples.py | 4 ++-- 4 files changed, 64 insertions(+), 6 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 150dcad..6489d0d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -11,7 +11,12 @@ def setup_rng(): @pytest.fixture def samples(): - return [{'i': i, 'f': (i + 1) / 3} for i in range(5)] + return [{ + 'i': i, + 'f': (i + 1) / 3, + 'is': list(range(i + 1)), + 'fs': [0.6 * (j + 1) for j in range(i + 1)] + } for i in range(5)] @pytest.fixture diff --git a/tests/test_batch.py b/tests/test_batch.py index 3606cb4..383f9a9 100644 --- a/tests/test_batch.py +++ b/tests/test_batch.py @@ -30,6 +30,17 @@ def test_get(batch): for i in range(len(batch)): assert batch.get('f')[i] == pytest.approx(batch[i]['f']) + assert isinstance(batch.get('is'), Sequence) + assert len(batch.get('is')) == len(batch) + for i in range(len(batch)): + assert list(batch.get('is')[i]) == list(batch[i]['is']) + + assert isinstance(batch.get('fs'), Sequence) + assert len(batch.get('fs')) == len(batch) + for i in range(len(batch)): + for f1, f2 in zip(batch.get('fs')[i], batch[i]['fs']): + assert f1 == pytest.approx(f2) + def test_get_invalid_name(batch): with pytest.raises(AttributeError) as exc: @@ -42,12 +53,29 @@ def test_to_array(batch): assert isinstance(arr, Mapping) assert isinstance(arr['i'], np.ndarray) + assert arr['i'].shape == (len(batch), ) assert arr['i'].tolist() == list(batch.get('i')) + assert isinstance(arr['f'], np.ndarray) - assert arr['f'].shape[0] == len(batch) + assert arr['f'].shape == (len(batch), ) for i in range(len(batch)): assert arr['f'][i] == pytest.approx(batch[i]['f']) + assert isinstance(arr['is'], np.ndarray) + maxlen = max(len(x) for x in batch.get('is')) + assert arr['is'].shape == (len(batch), maxlen) + for r, s in zip(arr['is'], batch): + assert r[:len(s['is'])].tolist() == list(s['is']) + assert all(c 
== 0 for c in r[len(s['is']):]) + + assert isinstance(arr['fs'], np.ndarray) + maxlen = max(len(x) for x in batch.get('fs')) + assert arr['fs'].shape == (len(batch), maxlen) + for r, s in zip(arr['fs'], batch): + for c, f in zip(r, s['fs']): + assert c == pytest.approx(f) + assert all(c == pytest.approx(0, abs=1e-7) for c in r[len(s['fs']):]) + def test_to_array_no_common_field_names(samples): samples_ = list(samples) diff --git a/text2array/batches.py b/text2array/batches.py index fea21ad..0fa4ed2 100644 --- a/text2array/batches.py +++ b/text2array/batches.py @@ -1,4 +1,5 @@ -from typing import Mapping, Sequence, Set +from collections.abc import Sequence as SequenceABC +from typing import Mapping, Sequence, Set, Union, cast import numpy as np @@ -43,5 +44,29 @@ def to_array(self) -> Mapping[FieldName, np.ndarray]: if not common: raise RuntimeError('some samples have no common field names with the others') + assert self._samples # if `common` isn't empty, neither is `_samples` - return {name: np.array(self.get(name)) for name in common} + arrs = {} + for name in common: + vs = self.get(name) + if isinstance(vs[0], SequenceABC): + vs = cast(Union[Sequence[Sequence[float]], Sequence[Sequence[int]]], vs) + maxlen = max(len(v) for v in vs) + vs = self._pad(vs, maxlen) + arrs[name] = np.array(vs) + + return arrs + + # TODO customize padding token + @staticmethod + def _pad( + vs: Union[Sequence[Sequence[float]], Sequence[Sequence[int]]], + maxlen: int, + ) -> Union[Sequence[Sequence[float]], Sequence[Sequence[int]]]: + res = [] + for v in vs: + v, n = list(v), len(v) + for _ in range(maxlen - n): + v.append(0) + res.append(v) + return res diff --git a/text2array/samples.py b/text2array/samples.py index cf99b88..81ad02c 100644 --- a/text2array/samples.py +++ b/text2array/samples.py @@ -1,5 +1,5 @@ -from typing import Mapping, Union +from typing import Mapping, Sequence, Union FieldName = str -FieldValue = Union[float, int, str] +FieldValue = Union[float, int, Sequence[float], Sequence[int]] Sample = Mapping[FieldName, FieldValue] From ea420905e064fdcff4f3275391927a47b2e65da8 Mon Sep 17 00:00:00 2001 From: Kemal Kurniawan Date: Sun, 27 Jan 2019 18:24:19 +0700 Subject: [PATCH 062/162] Allow custom padding value --- tests/test_batch.py | 8 ++++++++ text2array/batches.py | 11 +++++++---- 2 files changed, 15 insertions(+), 4 deletions(-) diff --git a/tests/test_batch.py b/tests/test_batch.py index 383f9a9..7b773d6 100644 --- a/tests/test_batch.py +++ b/tests/test_batch.py @@ -77,6 +77,14 @@ def test_to_array(batch): assert all(c == pytest.approx(0, abs=1e-7) for c in r[len(s['fs']):]) +def test_to_array_custom_padding(batch): + arr = batch.to_array(pad_with=1) + for r, s in zip(arr['is'], batch): + assert all(c == 1 for c in r[len(s['is']):]) + for r, s in zip(arr['fs'], batch): + assert all(c == pytest.approx(1) for c in r[len(s['fs']):]) + + def test_to_array_no_common_field_names(samples): samples_ = list(samples) samples_.append({'foo': 10}) diff --git a/text2array/batches.py b/text2array/batches.py index 0fa4ed2..aad621c 100644 --- a/text2array/batches.py +++ b/text2array/batches.py @@ -28,9 +28,12 @@ def get(self, name: str) -> Sequence[FieldValue]: except KeyError: raise AttributeError(f"some samples have no field '{name}'") - def to_array(self) -> Mapping[FieldName, np.ndarray]: + def to_array(self, pad_with: int = 0) -> Mapping[FieldName, np.ndarray]: """Convert the batch into :class:`np.ndarray`. + Args: + pad_with: Pad sequential field values with this number. 
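(Editorial illustration, not part of the patch series: how the padding added here behaves on a ragged sequence field, per the _pad helper in this diff. A sketch assuming the Batch API at this point in the series.)

    from text2array import Batch

    b = Batch([{'ws': [1, 2]}, {'ws': [3]}])
    b.to_array()['ws'].tolist()             # [[1, 2], [3, 0]], padded with 0 by default
    b.to_array(pad_with=9)['ws'].tolist()   # [[1, 2], [3, 9]]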
+ Returns: A mapping from field names to :class:`np.ndarray`s whose first dimension corresponds to the batch size as returned by ``__len__``. @@ -52,21 +55,21 @@ def to_array(self) -> Mapping[FieldName, np.ndarray]: if isinstance(vs[0], SequenceABC): vs = cast(Union[Sequence[Sequence[float]], Sequence[Sequence[int]]], vs) maxlen = max(len(v) for v in vs) - vs = self._pad(vs, maxlen) + vs = self._pad(vs, maxlen, with_=pad_with) arrs[name] = np.array(vs) return arrs - # TODO customize padding token @staticmethod def _pad( vs: Union[Sequence[Sequence[float]], Sequence[Sequence[int]]], maxlen: int, + with_: int = 0, ) -> Union[Sequence[Sequence[float]], Sequence[Sequence[int]]]: res = [] for v in vs: v, n = list(v), len(v) for _ in range(maxlen - n): - v.append(0) + v.append(with_) res.append(v) return res From ca2ebeb32dfbf789bced93babd35b1f98ba226ab Mon Sep 17 00:00:00 2001 From: Kemal Kurniawan Date: Sun, 27 Jan 2019 20:42:41 +0700 Subject: [PATCH 063/162] Add a todo --- text2array/samples.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/text2array/samples.py b/text2array/samples.py index 81ad02c..c7a79d5 100644 --- a/text2array/samples.py +++ b/text2array/samples.py @@ -3,3 +3,5 @@ FieldName = str FieldValue = Union[float, int, Sequence[float], Sequence[int]] Sample = Mapping[FieldName, FieldValue] + +# TODO handle when field value is a seq of seq of float/int From 752f9b8ac2572a0e89f5d5815fef834304f8fed9 Mon Sep 17 00:00:00 2001 From: Kemal Kurniawan Date: Mon, 28 Jan 2019 06:59:27 +0700 Subject: [PATCH 064/162] Refactor tests for Batch.to_array() into a class --- tests/test_batch.py | 92 +++++++++++++++++++++++---------------------- 1 file changed, 47 insertions(+), 45 deletions(-) diff --git a/tests/test_batch.py b/tests/test_batch.py index 7b773d6..4bb7a06 100644 --- a/tests/test_batch.py +++ b/tests/test_batch.py @@ -48,48 +48,50 @@ def test_get_invalid_name(batch): assert "some samples have no field 'foo'" in str(exc.value) -def test_to_array(batch): - arr = batch.to_array() - assert isinstance(arr, Mapping) - - assert isinstance(arr['i'], np.ndarray) - assert arr['i'].shape == (len(batch), ) - assert arr['i'].tolist() == list(batch.get('i')) - - assert isinstance(arr['f'], np.ndarray) - assert arr['f'].shape == (len(batch), ) - for i in range(len(batch)): - assert arr['f'][i] == pytest.approx(batch[i]['f']) - - assert isinstance(arr['is'], np.ndarray) - maxlen = max(len(x) for x in batch.get('is')) - assert arr['is'].shape == (len(batch), maxlen) - for r, s in zip(arr['is'], batch): - assert r[:len(s['is'])].tolist() == list(s['is']) - assert all(c == 0 for c in r[len(s['is']):]) - - assert isinstance(arr['fs'], np.ndarray) - maxlen = max(len(x) for x in batch.get('fs')) - assert arr['fs'].shape == (len(batch), maxlen) - for r, s in zip(arr['fs'], batch): - for c, f in zip(r, s['fs']): - assert c == pytest.approx(f) - assert all(c == pytest.approx(0, abs=1e-7) for c in r[len(s['fs']):]) - - -def test_to_array_custom_padding(batch): - arr = batch.to_array(pad_with=1) - for r, s in zip(arr['is'], batch): - assert all(c == 1 for c in r[len(s['is']):]) - for r, s in zip(arr['fs'], batch): - assert all(c == pytest.approx(1) for c in r[len(s['fs']):]) - - -def test_to_array_no_common_field_names(samples): - samples_ = list(samples) - samples_.append({'foo': 10}) - batch = Batch(samples_) - - with pytest.raises(RuntimeError) as exc: - batch.to_array() - assert 'some samples have no common field names with the others' in str(exc.value) +class TestToArray: + def test_ok(self, batch): + 
arr = batch.to_array() + assert isinstance(arr, Mapping) + + assert isinstance(arr['i'], np.ndarray) + assert arr['i'].shape == (len(batch), ) + assert arr['i'].tolist() == list(batch.get('i')) + + assert isinstance(arr['f'], np.ndarray) + assert arr['f'].shape == (len(batch), ) + for i in range(len(batch)): + assert arr['f'][i] == pytest.approx(batch[i]['f']) + + def test_seq(self, batch): + arr = batch.to_array() + + assert isinstance(arr['is'], np.ndarray) + maxlen = max(len(x) for x in batch.get('is')) + assert arr['is'].shape == (len(batch), maxlen) + for r, s in zip(arr['is'], batch): + assert r[:len(s['is'])].tolist() == list(s['is']) + assert all(c == 0 for c in r[len(s['is']):]) + + assert isinstance(arr['fs'], np.ndarray) + maxlen = max(len(x) for x in batch.get('fs')) + assert arr['fs'].shape == (len(batch), maxlen) + for r, s in zip(arr['fs'], batch): + for c, f in zip(r, s['fs']): + assert c == pytest.approx(f) + assert all(c == pytest.approx(0, abs=1e-7) for c in r[len(s['fs']):]) + + def test_custom_padding(self, batch): + arr = batch.to_array(pad_with=1) + for r, s in zip(arr['is'], batch): + assert all(c == 1 for c in r[len(s['is']):]) + for r, s in zip(arr['fs'], batch): + assert all(c == pytest.approx(1) for c in r[len(s['fs']):]) + + def test_no_common_field_names(self, samples): + samples_ = list(samples) + samples_.append({'foo': 10}) + batch = Batch(samples_) + + with pytest.raises(RuntimeError) as exc: + batch.to_array() + assert 'some samples have no common field names with the others' in str(exc.value) From 01e6c55da3358cf8856227d0adba66da8a96651d Mon Sep 17 00:00:00 2001 From: Kemal Kurniawan Date: Mon, 28 Jan 2019 10:13:07 +0700 Subject: [PATCH 065/162] Support recursive sequences for field values --- tests/test_batch.py | 144 ++++++++++++++++++++++++++++++++++++++++++ text2array/batches.py | 81 +++++++++++++++++------- text2array/samples.py | 8 +-- 3 files changed, 207 insertions(+), 26 deletions(-) diff --git a/tests/test_batch.py b/tests/test_batch.py index 4bb7a06..42587a7 100644 --- a/tests/test_batch.py +++ b/tests/test_batch.py @@ -80,6 +80,150 @@ def test_seq(self, batch): assert c == pytest.approx(f) assert all(c == pytest.approx(0, abs=1e-7) for c in r[len(s['fs']):]) + def test_seq_of_seq(self): + ss = [ + { + 'iss': [ + [1], + ] + }, + { + 'iss': [ + [1], + [1, 2], + ] + }, + { + 'iss': [ + [1], + [1, 2, 3], + [1, 2], + ] + }, + { + 'iss': [ + [1], + [1, 2], + [1, 2, 3], + [1], + ] + }, + { + 'iss': [ + [1], + [1, 2], + [1, 2, 3], + ] + }, + ] + b = Batch(ss) + arr = b.to_array() + + assert isinstance(arr['iss'], np.ndarray) + assert arr['iss'].shape == (5, 4, 3) + assert arr['iss'][0].tolist() == [ + [1, 0, 0], + [0, 0, 0], + [0, 0, 0], + [0, 0, 0], + ] + assert arr['iss'][1].tolist() == [ + [1, 0, 0], + [1, 2, 0], + [0, 0, 0], + [0, 0, 0], + ] + assert arr['iss'][2].tolist() == [ + [1, 0, 0], + [1, 2, 3], + [1, 2, 0], + [0, 0, 0], + ] + assert arr['iss'][3].tolist() == [ + [1, 0, 0], + [1, 2, 0], + [1, 2, 3], + [1, 0, 0], + ] + assert arr['iss'][4].tolist() == [ + [1, 0, 0], + [1, 2, 0], + [1, 2, 3], + [0, 0, 0], + ] + + def test_seq_of_seq_of_seq(self): + ss = [ + { + 'isss': [ + [[1], [1, 2]], + [[1], [1, 2], [1, 2]], + ] + }, + { + 'isss': [ + [[1, 2], [1]], + ] + }, + { + 'isss': [ + [[1, 2], [1, 2]], + [[1], [1, 2], [1]], + [[1, 2], [1]], + ] + }, + { + 'isss': [ + [[1]], + [[1], [1, 2]], + [[1], [1, 2]], + ] + }, + { + 'isss': [ + [[1]], + [[1], [1, 2]], + [[1], [1, 2], [1, 2]], + [[1], [1, 2], [1, 2]], + ] + }, + ] + b = Batch(ss) + arr = 
b.to_array() + + assert isinstance(arr['isss'], np.ndarray) + assert arr['isss'].shape == (5, 4, 3, 2) + assert arr['isss'][0].tolist() == [ + [[1, 0], [1, 2], [0, 0]], + [[1, 0], [1, 2], [1, 2]], + [[0, 0], [0, 0], [0, 0]], + [[0, 0], [0, 0], [0, 0]], + ] + assert arr['isss'][1].tolist() == [ + [[1, 2], [1, 0], [0, 0]], + [[0, 0], [0, 0], [0, 0]], + [[0, 0], [0, 0], [0, 0]], + [[0, 0], [0, 0], [0, 0]], + ] + assert arr['isss'][2].tolist() == [ + [[1, 2], [1, 2], [0, 0]], + [[1, 0], [1, 2], [1, 0]], + [[1, 2], [1, 0], [0, 0]], + [[0, 0], [0, 0], [0, 0]], + ] + assert arr['isss'][3].tolist() == [ + [[1, 0], [0, 0], [0, 0]], + [[1, 0], [1, 2], [0, 0]], + [[1, 0], [1, 2], [0, 0]], + [[0, 0], [0, 0], [0, 0]], + ] + assert arr['isss'][4].tolist() == [ + [[1, 0], [0, 0], [0, 0]], + [[1, 0], [1, 2], [0, 0]], + [[1, 0], [1, 2], [1, 2]], + [[1, 0], [1, 2], [1, 2]], + ] + def test_custom_padding(self, batch): arr = batch.to_array(pad_with=1) for r, s in zip(arr['is'], batch): diff --git a/text2array/batches.py b/text2array/batches.py index aad621c..31f5922 100644 --- a/text2array/batches.py +++ b/text2array/batches.py @@ -1,5 +1,6 @@ from collections.abc import Sequence as SequenceABC -from typing import Mapping, Sequence, Set, Union, cast +from functools import reduce +from typing import Any, List, Mapping, Sequence, Set, Union import numpy as np @@ -49,27 +50,63 @@ def to_array(self, pad_with: int = 0) -> Mapping[FieldName, np.ndarray]: raise RuntimeError('some samples have no common field names with the others') assert self._samples # if `common` isn't empty, neither is `_samples` - arrs = {} + arr = {} for name in common: - vs = self.get(name) - if isinstance(vs[0], SequenceABC): - vs = cast(Union[Sequence[Sequence[float]], Sequence[Sequence[int]]], vs) - maxlen = max(len(v) for v in vs) - vs = self._pad(vs, maxlen, with_=pad_with) - arrs[name] = np.array(vs) - - return arrs + data = self.get(name) + # Get max length for all depths, 1st elem is batch size + maxlens = self._get_maxlens(data) + # Get padding for all depths + paddings = self._get_paddings(maxlens, pad_with) + # Pad the data + data = self._pad(data, maxlens, paddings, 0) + + arr[name] = np.array(data) + + return arr + + @classmethod + def _get_maxlens(cls, data: Sequence[Any]) -> List[int]: + assert data + + # Base case + if not isinstance(data[0], SequenceABC): + return [len(data)] + + # Recursive case + maxlenss = [cls._get_maxlens(x) for x in data] + assert all(len(x) == len(maxlenss[0]) for x in maxlenss) + + maxlens = reduce(lambda ml1, ml2: [max(l1, l2) for l1, l2 in zip(ml1, ml2)], maxlenss) + maxlens.insert(0, len(data)) + return maxlens + + @classmethod + def _get_paddings(cls, maxlens: List[int], with_: int) -> List[Union[int, List[int]]]: + res: list = [with_] + for maxlen in reversed(maxlens[1:]): + res.append([res[-1] for _ in range(maxlen)]) + res.reverse() + return res - @staticmethod + @classmethod def _pad( - vs: Union[Sequence[Sequence[float]], Sequence[Sequence[int]]], - maxlen: int, - with_: int = 0, - ) -> Union[Sequence[Sequence[float]], Sequence[Sequence[int]]]: - res = [] - for v in vs: - v, n = list(v), len(v) - for _ in range(maxlen - n): - v.append(with_) - res.append(v) - return res + cls, + data: Sequence[Any], + maxlens: List[int], + paddings: List[Union[int, List[int]]], + depth: int, + ) -> Sequence[Any]: + assert data + assert len(maxlens) == len(paddings) + assert depth < len(maxlens) + + # Base case + if not isinstance(data[0], SequenceABC): + data_ = list(data) + # Recursive case + else: + data_ = 
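A minimal sketch of what the recursive padding introduced in this patch computes; the data
below is made up for illustration and is not part of the patch series:

    from text2array import Batch

    batch = Batch([{'iss': [[1], [1, 2]]}, {'iss': [[1, 2, 3]]}])
    arr = batch.to_array()
    # _get_maxlens finds the max length at every depth: [2, 2, 3] here, i.e.
    # 2 samples, at most 2 inner lists, inner lists of length at most 3.
    # Every sample is then padded out to shape (2, 3) and stacked into (2, 2, 3).
    print(arr['iss'].tolist())
    # -> [[[1, 0, 0], [1, 2, 0]], [[1, 2, 3], [0, 0, 0]]]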
From 46923711d280386b2218fed80816a9405a737fa2 Mon Sep 17 00:00:00 2001
From: Kemal Kurniawan
Date: Mon, 28 Jan 2019 10:25:38 +0700
Subject: [PATCH 066/162] Explicitly set up sequences in test_batch.py

---
 tests/conftest.py   |  7 +------
 tests/test_batch.py | 47 +++++++++++++++------------------------------
 2 files changed, 16 insertions(+), 38 deletions(-)

diff --git a/tests/conftest.py b/tests/conftest.py
index 6489d0d..150dcad 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -11,12 +11,7 @@ def setup_rng():

 @pytest.fixture
 def samples():
-    return [{
-        'i': i,
-        'f': (i + 1) / 3,
-        'is': list(range(i + 1)),
-        'fs': [0.6 * (j + 1) for j in range(i + 1)]
-    } for i in range(5)]
+    return [{'i': i, 'f': (i + 1) / 3} for i in range(5)]

 @pytest.fixture
diff --git a/tests/test_batch.py b/tests/test_batch.py
index 42587a7..1769f43 100644
--- a/tests/test_batch.py
+++ b/tests/test_batch.py
@@ -30,17 +30,6 @@ def test_get(batch):
     for i in range(len(batch)):
         assert batch.get('f')[i] == pytest.approx(batch[i]['f'])

-    assert isinstance(batch.get('is'), Sequence)
-    assert len(batch.get('is')) == len(batch)
-    for i in range(len(batch)):
-        assert list(batch.get('is')[i]) == list(batch[i]['is'])
-
-    assert isinstance(batch.get('fs'), Sequence)
-    assert len(batch.get('fs')) == len(batch)
-    for i in range(len(batch)):
-        for f1, f2 in zip(batch.get('fs')[i], batch[i]['fs']):
-            assert f1 == pytest.approx(f2)
-

 def test_get_invalid_name(batch):
     with pytest.raises(AttributeError) as exc:
@@ -62,23 +51,13 @@ def test_ok(self, batch):
         for i in range(len(batch)):
             assert arr['f'][i] == pytest.approx(batch[i]['f'])

-    def test_seq(self, batch):
-        arr = batch.to_array()
+    def test_seq(self):
+        ss = [{'is': [1, 2]}, {'is': [1]}, {'is': [1, 2, 3]}, {'is': [1, 2]}]
+        b = Batch(ss)
+        arr = b.to_array()

         assert isinstance(arr['is'], np.ndarray)
-        maxlen = max(len(x) for x in batch.get('is'))
-        assert arr['is'].shape == (len(batch), maxlen)
-        for r, s in zip(arr['is'], batch):
-            assert r[:len(s['is'])].tolist() == list(s['is'])
-            assert all(c == 0 for c in r[len(s['is']):])
-
-        assert isinstance(arr['fs'], np.ndarray)
-        maxlen = max(len(x) for x in batch.get('fs'))
-        assert arr['fs'].shape == (len(batch), maxlen)
-        for r, s in zip(arr['fs'], batch):
-            for c, f in zip(r, s['fs']):
-                assert c == pytest.approx(f)
-            assert all(c == pytest.approx(0, abs=1e-7) for c in r[len(s['fs']):])
+        assert arr['is'].tolist() == [[1, 2, 0], [1, 0, 0], [1, 2, 3], [1, 2, 0]]

     def test_seq_of_seq(self):
         ss = [
@@ -224,12 +203,16 @@ def test_seq_of_seq_of_seq(self):
             [[1, 0], [1, 2], [1, 2]],
         ]

-    def test_custom_padding(self, batch):
-        arr = batch.to_array(pad_with=1)
-        for r, s in zip(arr['is'], batch):
-            assert all(c == 1 for c in r[len(s['is']):])
-        for r, s in zip(arr['fs'], batch):
-            assert all(c == pytest.approx(1) for c in r[len(s['fs']):])
+    def test_custom_padding(self):
+        ss = [{'is': [1]}, {'is': [1, 2]}]
+        b = Batch(ss)
+        arr = b.to_array(pad_with=9)
+        assert arr['is'].tolist() == [[1, 9], [1, 2]]
+
+        ss = [{'iss': [[1, 2], [1]]}, {'iss': [[1]]}]
+        b = Batch(ss)
+        arr = b.to_array(pad_with=9)
+        assert arr['iss'].tolist() == [[[1, 2], [1, 9]], [[1, 9], [9, 9]]]

     def test_no_common_field_names(self, samples):
         samples_ = list(samples)

From bc0383ac7ca5c87b88200cf86b60050269e8c8b3 Mon Sep 17 00:00:00 2001
From: Kemal Kurniawan
Date: Mon, 28 Jan 2019 12:17:22 +0700
Subject: [PATCH 067/162] Add todos

---
 text2array/datasets.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/text2array/datasets.py b/text2array/datasets.py
index 209e24c..5590033 100644
--- a/text2array/datasets.py
+++ b/text2array/datasets.py
@@ -28,6 +28,7 @@ def batch_exactly(self, batch_size: int) -> Iterator[Batch]:
         return (b for b in self.batch(batch_size) if len(b) == batch_size)

+# TODO implement Vocab class
 class Dataset(DatasetABC, Sequence[Sample]):
     """A dataset that fits in memory (no streaming).

@@ -49,6 +50,7 @@ def __getitem__(self, index) -> Sample:
     def __len__(self) -> int:
         return len(self._samples)

+    # TODO implement shuffle_by
     def shuffle(self) -> 'Dataset':
         """Shuffle the dataset.

From 6d911dabbbeb307e6f4214df90f601d0bcde1f39 Mon Sep 17 00:00:00 2001
From: Kemal Kurniawan
Date: Mon, 28 Jan 2019 13:08:37 +0700
Subject: [PATCH 068/162] Let Dataset._shuffle_copy() use Dataset._shuffle_inplace()

---
 text2array/datasets.py | 6 ++----
 1 file changed, 2 insertions(+), 4 deletions(-)

diff --git a/text2array/datasets.py b/text2array/datasets.py
index 5590033..286359a 100644
--- a/text2array/datasets.py
+++ b/text2array/datasets.py
@@ -93,10 +93,8 @@ def _shuffle_inplace(self) -> None:
             self._samples[j] = temp

     def _shuffle_copy(self) -> None:
-        shuf_indices = list(range(len(self._samples)))
-        random.shuffle(shuf_indices)
-        shuf_samples = [self._samples[i] for i in shuf_indices]
-        self._samples = shuf_samples
+        self._samples = list(self._samples)
+        self._shuffle_inplace()

From bf637c8653e62445361dde59c356520b6c26d8c1 Mon Sep 17 00:00:00 2001
From: Kemal Kurniawan
Date: Mon, 28 Jan 2019 13:59:44 +0700
Subject: [PATCH 069/162] Implement Dataset.shuffle_by() method

---
 tests/test_dataset.py  | 68 ++++++++++++++++++++++++++++++++++--------
 text2array/datasets.py | 32 ++++++++++++++++++--
 2 files changed, 85 insertions(+), 15 deletions(-)

diff --git a/tests/test_dataset.py b/tests/test_dataset.py
index aa0abc3..f1df538 100644
--- a/tests/test_dataset.py
+++ b/tests/test_dataset.py
@@ -20,23 +20,65 @@ def test_init_samples_non_sequence():

 class TestShuffle:
-    @pytest.fixture
-    def tuple_dataset(self, samples):
-        return Dataset(tuple(samples))
-
-    def assert_shuffle(self, dataset):
+    def test_mutable_seq(self, setup_rng, dataset):
         before = list(dataset)
         retval = dataset.shuffle()
         after = list(dataset)
         assert retval is dataset
-        assert before != after and len(before) == len(after) and all(x in after for x in before)
-
-    def test_mutable_seq(self, setup_rng, dataset):
-        self.assert_shuffle(dataset)
-
-    def test_immutable_seq(self, setup_rng, tuple_dataset):
-        self.assert_shuffle(tuple_dataset)
+        assert_shuffled(before, after)
+
+    def test_immutable_seq(self, setup_rng, samples):
+        dat = Dataset(tuple(samples))
+        before = list(dat)
+        retval = dat.shuffle()
+        after = list(dat)
+        assert retval is dat
+        assert_shuffled(before, after)
+
+
+class TestShuffleBy:
+    @staticmethod
+    def make_dataset():
+        return Dataset([{
+            'is': [1, 2, 3]
+        }, {
+            'is': [1]
+        }, {
+            'is': [1, 2]
+        }, {
+            'is': [1, 2, 3, 4, 5]
+        }, {
+            'is': [1, 2, 3, 4]
+        }])
+
+    @staticmethod
+    def key(sample):
+        return len(sample['is'])
+
+    def test_ok(self, setup_rng):
+        dat = self.make_dataset()
+        before = list(dat)
+        retval = dat.shuffle_by(self.key)
+        after = list(dat)
+        assert retval is dat
+        assert_shuffled(before, after)
+
+    def test_zero_scale(self, setup_rng):
+        dat = self.make_dataset()
+        before = list(dat)
+        dat.shuffle_by(self.key, scale=0.)
+        after = list(dat)
+        assert sorted(before, key=self.key) == after
+
+    def test_negative_scale(self, setup_rng):
+        dat = self.make_dataset()
+        with pytest.raises(ValueError) as exc:
+            dat.shuffle_by(self.key, scale=-1)
+        assert 'scale cannot be less than 0' in str(exc.value)
+
+
+def assert_shuffled(before, after):
+    assert before != after and len(before) == len(after) and all(x in after for x in before)

 def test_batch(dataset):
diff --git a/text2array/datasets.py b/text2array/datasets.py
index 286359a..2e9cc23 100644
--- a/text2array/datasets.py
+++ b/text2array/datasets.py
@@ -1,8 +1,9 @@
 from collections.abc import \
     Iterable as IterableABC, MutableSequence as MutableSequenceABC, Sequence as SequenceABC
-from typing import Iterable, Iterator, Sequence
+from typing import Callable, Iterable, Iterator, Sequence
 import abc
 import random
+import statistics as stat

 from .batches import Batch
 from .samples import Sample
@@ -50,7 +51,6 @@ def __getitem__(self, index) -> Sample:
     def __len__(self) -> int:
         return len(self._samples)

-    # TODO implement shuffle_by
     def shuffle(self) -> 'Dataset':
         """Shuffle the dataset.

@@ -67,6 +67,34 @@ def shuffle(self) -> 'Dataset':
         self._shuffle_copy()
         return self

+    def shuffle_by(self, key: Callable[[Sample], int], scale: float = 1.) -> 'Dataset':
+        """Shuffle the dataset by the given key.
+
+        This method essentially performs noisy sorting. The samples in the dataset are
+        sorted ascending by the value of the given key, plus some random noise. This random
+        noise is drawn from ``Uniform(-z, z)``, where ``z`` equals ``scale`` times the
+        standard deviation of key values. This formulation means that ``scale`` regulates
+        how noisy the sorting is. The larger it is, the more noisy the sorting becomes, i.e.
+        it resembles random shuffling more closely. In an extreme case where ``scale=0``,
+        this method just sorts the samples by ``key``. This method is useful when working
+        with text data, where we want to shuffle the dataset and also minimize padding by
+        ensuring that sentences of similar lengths are not too far apart.
+
+        Args:
+            key: Callable to get the key value of a given sample.
+            scale: Value to regulate the noise of the sorting. Must not be negative.
+
+        Returns:
+            The dataset object itself (useful for chaining).
+        """
+        if scale < 0:
+            raise ValueError('scale cannot be less than 0')
+
+        std = stat.stdev(key(s) for s in self._samples)
+        z = scale * std
+        self._samples = sorted(self._samples, key=lambda s: key(s) + random.uniform(-z, z))
+        return self
+
     def batch(self, batch_size: int) -> Iterator[Batch]:
         """Group the samples in the dataset into batches.
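A minimal usage sketch of the shuffle_by() method added in this patch; the sample data is
made up for illustration:

    from text2array import Dataset

    dat = Dataset([{'is': list(range(n))} for n in (5, 1, 3, 9, 2)])
    # Noisy sort by length: scale=0 degenerates into plain sorting, while larger
    # scales make the order increasingly random. Samples of similar length stay
    # close together, so the resulting batches need less padding.
    dat.shuffle_by(lambda s: len(s['is']), scale=1.)
    for b in dat.batch(2):
        arr = b.to_array()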
From 53bce6ef91dc1b1636cc706cc167a985c5b4c1 Mon Sep 17 00:00:00 2001
From: Kemal Kurniawan
Date: Mon, 28 Jan 2019 14:35:39 +0700
Subject: [PATCH 070/162] Remove unnecessary comments in config files

---
 .flake8     | 1 -
 .style.yapf | 1 -
 2 files changed, 2 deletions(-)

diff --git a/.flake8 b/.flake8
index 60e3acd..dfec9d6 100644
--- a/.flake8
+++ b/.flake8
@@ -1,5 +1,4 @@
 # Config file for flake8
-# Symlink this to ~/.config/flake8

 [flake8]
 ignore = E,W  # let yapf handle stylistic issues
diff --git a/.style.yapf b/.style.yapf
index 6114d2c..13b6369 100644
--- a/.style.yapf
+++ b/.style.yapf
@@ -1,5 +1,4 @@
 # Config file for YAPF Python formatter
-# Symlink this to ~/.config/yapf/style

 [style]
 based_on_style = pep8

From 43e5edbb5fef7ca945f36f11b896f117cbdced13 Mon Sep 17 00:00:00 2001
From: Kemal Kurniawan
Date: Mon, 28 Jan 2019 16:20:55 +0700
Subject: [PATCH 071/162] Add a todo

---
 text2array/batches.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/text2array/batches.py b/text2array/batches.py
index 31f5922..2491be9 100644
--- a/text2array/batches.py
+++ b/text2array/batches.py
@@ -23,6 +23,7 @@ def __getitem__(self, index) -> Sample:
     def __len__(self) -> int:
         return len(self._samples)

+    # TODO rethink if this needs to be public
     def get(self, name: str) -> Sequence[FieldValue]:
         try:
             return [s[name] for s in self._samples]

From 51ae4c0c6246e4e41e0829f42fd5ba29a436c1ca Mon Sep 17 00:00:00 2001
From: Kemal Kurniawan
Date: Mon, 28 Jan 2019 16:27:01 +0700
Subject: [PATCH 072/162] Initial implementation of Vocab class

---
 tests/test_vocab.py    | 37 ++++++++++++++++++++++
 text2array/__init__.py |  2 ++
 text2array/vocab.py    | 72 ++++++++++++++++++++++++++++++++++++++++++
 3 files changed, 111 insertions(+)
 create mode 100644 tests/test_vocab.py
 create mode 100644 text2array/vocab.py

diff --git a/tests/test_vocab.py b/tests/test_vocab.py
new file mode 100644
index 0000000..4c69062
--- /dev/null
+++ b/tests/test_vocab.py
@@ -0,0 +1,37 @@
+from collections.abc import Mapping, Sequence
+
+from text2array import Dataset, Vocab
+
+
+def test_vocab_from_dataset(samples):
+    ss = [{
+        'w': 'three'
+    }, {
+        'w': 'two'
+    }, {
+        'w': 'one'
+    }, {
+        'w': 'two'
+    }, {
+        'w': 'three'
+    }, {
+        'w': 'three'
+    }]
+    for s, s_ in zip(ss, samples):
+        s.update(s_)
+
+    vocab = Vocab.from_dataset(Dataset(ss))
+    itos = '<pad> <unk> three two'.split()
+
+    assert isinstance(vocab.of('w'), Sequence)
+    assert len(vocab.of('w')) == len(itos)
+    for i in range(len(itos)):
+        vocab.of('w')[i] == itos[i]
+
+    assert isinstance(vocab.of('w').stoi, Mapping)
+    assert len(vocab.of('w').stoi) == len(vocab.of('w'))
+    assert set(vocab.of('w').stoi) == set(vocab.of('w'))
+    for i, s in enumerate(vocab.of('w')):
+        assert vocab.of('w').stoi[s] == i
+    assert vocab.of('w').stoi['foo'] == vocab.of('w').stoi['<unk>']
+    assert vocab.of('w').stoi['bar'] == vocab.of('w').stoi['<unk>']
diff --git a/text2array/__init__.py b/text2array/__init__.py
index 92ebb73..0fefa6d 100644
--- a/text2array/__init__.py
+++ b/text2array/__init__.py
@@ -1,10 +1,12 @@
 from .batches import Batch
 from .datasets import Dataset, StreamDataset
 from .samples import Sample
+from .vocab import Vocab

 __all__ = [
     Sample,
     Batch,
     Dataset,
     StreamDataset,
+    Vocab,
 ]
diff --git a/text2array/vocab.py b/text2array/vocab.py
new file mode 100644
index 0000000..25d7847
--- /dev/null
+++ b/text2array/vocab.py
@@ -0,0 +1,72 @@
+from collections import Counter
+from typing import Iterable, Mapping, Sequence
+
+from .datasets import Dataset
+from .samples import FieldName, FieldValue
+
+
+class Vocab:
+    def __init__(self, mapping: Mapping[FieldName, 'VocabEntry']) -> None:
+        self._map = mapping
+
+    @classmethod
+    def from_dataset(cls, dataset: Dataset) -> 'Vocab':
+        vals = cls._get_values(dataset, 'w')
+        return cls({'w': VocabEntry.from_iterable(vals)})
+
+    def of(self, name: str) -> 'VocabEntry':
+        return self._map[name]
+
+    @staticmethod
+    def _get_values(dat: Dataset, name: FieldName) -> Sequence[FieldValue]:
+        return [s[name] for s in dat]
+
+
+# TODO think if this class needs separate test cases
+class VocabEntry(Sequence[str]):
+    def __init__(self, strings: Sequence[str]) -> None:
+        self._itos = strings
+        self._stoi = _StringStore.from_itos(strings)
+
+    def __len__(self) -> int:
+        return len(self._itos)
+
+    def __getitem__(self, index) -> str:
+        return self._itos[index]
+
+    @property
+    def stoi(self) -> Mapping[str, int]:
+        return self._stoi
+
+    @classmethod
+    def from_iterable(cls, iterable: Iterable[str]) -> 'VocabEntry':
+        itos = ['<pad>', '<unk>']
+        c = Counter(iterable)
+        for v, f in c.most_common():
+            if f < 2:
+                break
+            itos.append(v)
+        return cls(itos)
+
+
+class _StringStore(Mapping[str, int]):
+    def __init__(self, mapping: Mapping[str, int]) -> None:
+        self._map = mapping
+
+    def __len__(self):
+        return len(self._map)
+
+    def __iter__(self):
+        return iter(self._map)
+
+    def __getitem__(self, s):
+        try:
+            return self._map[s]
+        except KeyError:
+            return 1
+
+    @classmethod
+    def from_itos(cls, itos: Sequence[str]) -> '_StringStore':
+        assert len(set(itos)) == len(itos), 'itos cannot have duplicate strings'
+        stoi = {s: i for i, s in enumerate(itos)}
+        return cls(stoi)

From 1a44a3118bc428d3052785bddb9842162c3a8976 Mon Sep 17 00:00:00 2001
From: Kemal Kurniawan
Date: Mon, 28 Jan 2019 17:51:25 +0700
Subject: [PATCH 073/162] Handle several fields that need a vocab

---
 tests/test_vocab.py   | 37 +++++++++++++++++++++++++++----------
 text2array/batches.py |  1 +
 text2array/vocab.py   | 18 +++++++++++++++---
 3 files changed, 43 insertions(+), 13 deletions(-)

diff --git a/tests/test_vocab.py b/tests/test_vocab.py
index 4c69062..52327d6 100644
--- a/tests/test_vocab.py
+++ b/tests/test_vocab.py
@@ -1,25 +1,36 @@
 from collections.abc import Mapping, Sequence

+import pytest
+
 from text2array import Dataset, Vocab

-def test_vocab_from_dataset(samples):
+def test_vocab_from_dataset():
     ss = [{
-        'w': 'three'
+        'i': 1,
+        'w': 'three',
+        't': 'three',
     }, {
-        'w': 'two'
+        'i': 2,
+        'w': 'two',
+        't': 'two',
     }, {
-        'w': 'one'
+        'i': 3,
+        'w': 'one',
+        't': 'one',
     }, {
-        'w': 'two'
+        'i': 4,
+        'w': 'two',
+        't': 'two',
     }, {
-        'w': 'three'
+        'i': 5,
+        'w': 'three',
+        't': 'three',
     }, {
-        'w': 'three'
+        'i': 6,
+        'w': 'three',
+        't': 'three',
     }]
-    for s, s_ in zip(ss, samples):
-        s.update(s_)
-
     vocab = Vocab.from_dataset(Dataset(ss))
     itos = '<pad> <unk> three two'.split()
@@ -35,3 +46,9 @@ def test_vocab_from_dataset():
         assert vocab.of('w').stoi[s] == i
     assert vocab.of('w').stoi['foo'] == vocab.of('w').stoi['<unk>']
     assert vocab.of('w').stoi['bar'] == vocab.of('w').stoi['<unk>']
+
+    assert isinstance(vocab.of('t'), Sequence)
+    assert isinstance(vocab.of('t').stoi, Mapping)
+    with pytest.raises(RuntimeError) as exc:
+        vocab.of('i')
+    assert "no vocabulary found for field name 'i'" in str(exc.value)
diff --git a/text2array/batches.py b/text2array/batches.py
index 2491be9..9c75c62 100644
--- a/text2array/batches.py
+++ b/text2array/batches.py
@@ -40,6 +40,7 @@ def to_array(self, pad_with: int = 0) -> Mapping[FieldName, np.ndarray]:
             A mapping from field names to :class:`np.ndarray`s whose first dimension
             corresponds to the batch size as returned by ``__len__``.
         """
+        # TODO just assume all samples have the same keys
         common: Set[FieldName] = set()
         for s in self._samples:
             if common:
diff --git a/text2array/vocab.py b/text2array/vocab.py
index 25d7847..b6ea7c5 100644
--- a/text2array/vocab.py
+++ b/text2array/vocab.py
@@ -11,16 +11,28 @@ def __init__(self, mapping: Mapping[FieldName, 'VocabEntry']) -> None:

     @classmethod
     def from_dataset(cls, dataset: Dataset) -> 'Vocab':
-        vals = cls._get_values(dataset, 'w')
-        return cls({'w': VocabEntry.from_iterable(vals)})
+        assert len(dataset) > 0
+        m = {
+            name: VocabEntry.from_iterable(cls._get_values(dataset, name))
+            for name, value in dataset[0].items()
+            if cls._needs_vocab(value)
+        }
+        return cls(m)

     def of(self, name: str) -> 'VocabEntry':
-        return self._map[name]
+        try:
+            return self._map[name]
+        except KeyError:
+            raise RuntimeError(f"no vocabulary found for field name '{name}'")

     @staticmethod
     def _get_values(dat: Dataset, name: FieldName) -> Sequence[FieldValue]:
         return [s[name] for s in dat]

+    @classmethod
+    def _needs_vocab(cls, val: FieldValue) -> bool:
+        return isinstance(val, str)

From 1e56fe5d674faa057c133182d9068b958b9d3b9c Mon Sep 17 00:00:00 2001
From: Kemal Kurniawan
Date: Mon, 28 Jan 2019 18:08:20 +0700
Subject: [PATCH 074/162] Refactor tests

---
 tests/test_vocab.py | 78 ++++++++++++++++++---------------------------
 1 file changed, 31 insertions(+), 47 deletions(-)

diff --git a/tests/test_vocab.py b/tests/test_vocab.py
index 52327d6..90881c5 100644
--- a/tests/test_vocab.py
+++ b/tests/test_vocab.py
@@ -5,50 +5,34 @@
 from text2array import Dataset, Vocab

-def test_vocab_from_dataset():
-    ss = [{
-        'i': 1,
-        'w': 'three',
-        't': 'three',
-    }, {
-        'i': 2,
-        'w': 'two',
-        't': 'two',
-    }, {
-        'i': 3,
-        'w': 'one',
-        't': 'one',
-    }, {
-        'i': 4,
-        'w': 'two',
-        't': 'two',
-    }, {
-        'i': 5,
-        'w': 'three',
-        't': 'three',
-    }, {
-        'i': 6,
-        'w': 'three',
-        't': 'three',
-    }]
-    vocab = Vocab.from_dataset(Dataset(ss))
-    itos = '<pad> <unk> three two'.split()
-
-    assert isinstance(vocab.of('w'), Sequence)
-    assert len(vocab.of('w')) == len(itos)
-    for i in range(len(itos)):
-        vocab.of('w')[i] == itos[i]
-
-    assert isinstance(vocab.of('w').stoi, Mapping)
-    assert len(vocab.of('w').stoi) == len(vocab.of('w'))
-    assert set(vocab.of('w').stoi) == set(vocab.of('w'))
-    for i, s in enumerate(vocab.of('w')):
-        assert vocab.of('w').stoi[s] == i
-    assert vocab.of('w').stoi['foo'] == vocab.of('w').stoi['<unk>']
-    assert vocab.of('w').stoi['bar'] == vocab.of('w').stoi['<unk>']
-
-    assert isinstance(vocab.of('t'), Sequence)
-    assert isinstance(vocab.of('t').stoi, Mapping)
-    with pytest.raises(RuntimeError) as exc:
-        vocab.of('i')
-    assert "no vocabulary found for field name 'i'" in str(exc.value)
+class TestFromDataset():
+    def test_ok(self):
+        dat = Dataset([{'w': 'c'}, {'w': 'b'}, {'w': 'a'}, {'w': 'b'}, {'w': 'c'}, {'w': 'c'}])
+        vocab = Vocab.from_dataset(dat)
+        itos = '<pad> <unk> c b'.split()
+
+        assert isinstance(vocab.of('w'), Sequence)
+        assert len(vocab.of('w')) == len(itos)
+        for i in range(len(itos)):
+            assert vocab.of('w')[i] == itos[i]
+
+        assert isinstance(vocab.of('w').stoi, Mapping)
+        assert len(vocab.of('w').stoi) == len(vocab.of('w'))
+        assert set(vocab.of('w').stoi) == set(vocab.of('w'))
+        for i, s in enumerate(vocab.of('w')):
+            assert vocab.of('w').stoi[s] == i
+
+        assert vocab.of('w').stoi['foo'] == vocab.of('w').stoi['<unk>']
+        assert vocab.of('w').stoi['bar'] == vocab.of('w').stoi['<unk>']
+
+    def test_has_vocab_for_all_str_fields(self):
+        dat = Dataset([{'w': 'b', 't': 'b'}, {'w': 'b', 't': 'b'}])
+        vocab = Vocab.from_dataset(dat)
+        assert isinstance(vocab.of('t'), Sequence)
+        assert isinstance(vocab.of('t').stoi, Mapping)
+
+    def test_no_vocab_for_non_str(self):
+        vocab = Vocab.from_dataset(Dataset([{'i': 10}, {'i': 20}]))
+        with pytest.raises(RuntimeError) as exc:
+            vocab.of('i')
+        assert "no vocabulary found for field name 'i'" in str(exc.value)
From 13d753471ac414b0aaa58b381219c64c6ab4592e Mon Sep 17 00:00:00 2001
From: Kemal Kurniawan
Date: Mon, 28 Jan 2019 18:13:57 +0700
Subject: [PATCH 075/162] Make Vocab a mapping

---
 tests/test_vocab.py | 28 +++++++++++++++-------------
 text2array/vocab.py | 22 ++++++++++++++--------
 2 files changed, 29 insertions(+), 21 deletions(-)

diff --git a/tests/test_vocab.py b/tests/test_vocab.py
index 90881c5..1339f4c 100644
--- a/tests/test_vocab.py
+++ b/tests/test_vocab.py
@@ -9,30 +9,32 @@ class TestFromDataset():
     def test_ok(self):
         dat = Dataset([{'w': 'c'}, {'w': 'b'}, {'w': 'a'}, {'w': 'b'}, {'w': 'c'}, {'w': 'c'}])
         vocab = Vocab.from_dataset(dat)
+        assert isinstance(vocab, Mapping)
+
         itos = '<pad> <unk> c b'.split()

-        assert isinstance(vocab.of('w'), Sequence)
-        assert len(vocab.of('w')) == len(itos)
+        assert isinstance(vocab['w'], Sequence)
+        assert len(vocab['w']) == len(itos)
         for i in range(len(itos)):
-            assert vocab.of('w')[i] == itos[i]
+            assert vocab['w'][i] == itos[i]

-        assert isinstance(vocab.of('w').stoi, Mapping)
-        assert len(vocab.of('w').stoi) == len(vocab.of('w'))
-        assert set(vocab.of('w').stoi) == set(vocab.of('w'))
-        for i, s in enumerate(vocab.of('w')):
-            assert vocab.of('w').stoi[s] == i
+        assert isinstance(vocab['w'].stoi, Mapping)
+        assert len(vocab['w'].stoi) == len(vocab['w'])
+        assert set(vocab['w'].stoi) == set(vocab['w'])
+        for i, s in enumerate(vocab['w']):
+            assert vocab['w'].stoi[s] == i

-        assert vocab.of('w').stoi['foo'] == vocab.of('w').stoi['<unk>']
-        assert vocab.of('w').stoi['bar'] == vocab.of('w').stoi['<unk>']
+        assert vocab['w'].stoi['foo'] == vocab['w'].stoi['<unk>']
+        assert vocab['w'].stoi['bar'] == vocab['w'].stoi['<unk>']

     def test_has_vocab_for_all_str_fields(self):
         dat = Dataset([{'w': 'b', 't': 'b'}, {'w': 'b', 't': 'b'}])
         vocab = Vocab.from_dataset(dat)
-        assert isinstance(vocab.of('t'), Sequence)
-        assert isinstance(vocab.of('t').stoi, Mapping)
+        assert isinstance(vocab['t'], Sequence)
+        assert isinstance(vocab['t'].stoi, Mapping)

     def test_no_vocab_for_non_str(self):
         vocab = Vocab.from_dataset(Dataset([{'i': 10}, {'i': 20}]))
         with pytest.raises(RuntimeError) as exc:
-            vocab.of('i')
+            vocab['i']
         assert "no vocabulary found for field name 'i'" in str(exc.value)
diff --git a/text2array/vocab.py b/text2array/vocab.py
index b6ea7c5..65e8aba 100644
--- a/text2array/vocab.py
+++ b/text2array/vocab.py
@@ -1,14 +1,26 @@
 from collections import Counter
-from typing import Iterable, Mapping, Sequence
+from typing import Iterable, Iterator, Mapping, Sequence

 from .datasets import Dataset
 from .samples import FieldName, FieldValue

-class Vocab:
+class Vocab(Mapping[FieldName, 'VocabEntry']):
     def __init__(self, mapping: Mapping[FieldName, 'VocabEntry']) -> None:
         self._map = mapping

+    def __len__(self) -> int:
+        return len(self._map)
+
+    def __iter__(self) -> Iterator[FieldName]:
+        return iter(self._map)
+
+    def __getitem__(self, name: FieldName) -> 'VocabEntry':
+        try:
+            return self._map[name]
+        except KeyError:
+            raise RuntimeError(f"no vocabulary found for field name '{name}'")
+
     @classmethod
     def from_dataset(cls, dataset: Dataset) -> 'Vocab':
         assert len(dataset) > 0
@@ -19,12 +31,6 @@ def from_dataset(cls, dataset: Dataset) -> 'Vocab':
         }
         return cls(m)

-    def of(self, name: str) -> 'VocabEntry':
-        try:
-            return self._map[name]
-        except KeyError:
-            raise RuntimeError(f"no vocabulary found for field name '{name}'")
-
     @staticmethod
     def _get_values(dat: Dataset, name: FieldName) -> Sequence[FieldValue]:
         return [s[name] for s in dat]

From 50f7d134bede9fa672ce3a410c8b32e6cd11017c Mon Sep 17 00:00:00 2001
From: Kemal Kurniawan
Date: Mon, 28 Jan 2019 18:15:58 +0700
Subject: [PATCH 076/162] Shorten variable names when the type is clear

---
 text2array/vocab.py | 20 ++++++++++----------
 1 file changed, 10 insertions(+), 10 deletions(-)

diff --git a/text2array/vocab.py b/text2array/vocab.py
index 65e8aba..55611cd 100644
--- a/text2array/vocab.py
+++ b/text2array/vocab.py
@@ -6,18 +6,18 @@

 class Vocab(Mapping[FieldName, 'VocabEntry']):
-    def __init__(self, mapping: Mapping[FieldName, 'VocabEntry']) -> None:
-        self._map = mapping
+    def __init__(self, m: Mapping[FieldName, 'VocabEntry']) -> None:
+        self._m = m

     def __len__(self) -> int:
-        return len(self._map)
+        return len(self._m)

     def __iter__(self) -> Iterator[FieldName]:
-        return iter(self._map)
+        return iter(self._m)

     def __getitem__(self, name: FieldName) -> 'VocabEntry':
         try:
-            return self._map[name]
+            return self._m[name]
         except KeyError:
             raise RuntimeError(f"no vocabulary found for field name '{name}'")
@@ -68,18 +68,18 @@ class _StringStore(Mapping[str, int]):
-    def __init__(self, mapping: Mapping[str, int]) -> None:
-        self._map = mapping
+    def __init__(self, m: Mapping[str, int]) -> None:
+        self._m = m

     def __len__(self):
-        return len(self._map)
+        return len(self._m)

     def __iter__(self):
-        return iter(self._map)
+        return iter(self._m)

     def __getitem__(self, s):
         try:
-            return self._map[s]
+            return self._m[s]
         except KeyError:
             return 1

From 6da51384956dde5f7533d0137ce66a59d38ab7a6 Mon Sep 17 00:00:00 2001
From: Kemal Kurniawan
Date: Mon, 28 Jan 2019 18:17:17 +0700
Subject: [PATCH 077/162] Complete type hints in _StringStore class

---
 text2array/vocab.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/text2array/vocab.py b/text2array/vocab.py
index 55611cd..1167cf3 100644
--- a/text2array/vocab.py
+++ b/text2array/vocab.py
@@ -71,13 +71,13 @@ class _StringStore(Mapping[str, int]):
     def __init__(self, m: Mapping[str, int]) -> None:
         self._m = m

-    def __len__(self):
+    def __len__(self) -> int:
         return len(self._m)

-    def __iter__(self):
+    def __iter__(self) -> Iterator[str]:
         return iter(self._m)

-    def __getitem__(self, s):
+    def __getitem__(self, s: str) -> int:
         try:
             return self._m[s]
         except KeyError:
             return 1

From 99fdf074dc79203765142a4a28ebcf66365ad3da Mon Sep 17 00:00:00 2001
From: Kemal Kurniawan
Date: Mon, 28 Jan 2019 19:23:52 +0700
Subject: [PATCH 078/162] Handle sequence of str field values

---
 tests/test_vocab.py | 17 +++++++++++++++++
 text2array/vocab.py | 15 ++++++++++++---
 2 files changed, 29 insertions(+), 3 deletions(-)

diff --git a/tests/test_vocab.py b/tests/test_vocab.py
index 1339f4c..a9ef36d 100644
--- a/tests/test_vocab.py
+++ b/tests/test_vocab.py
@@ -38,3 +38,20 @@ def test_no_vocab_for_non_str(self):
         with pytest.raises(RuntimeError) as exc:
             vocab['i']
         assert "no vocabulary found for field name 'i'" in str(exc.value)
+
+    def test_seq(self):
+        dat = Dataset([{'ws': ['a', 'c', 'c']}, {'ws': ['b', 'c']}, {'ws': ['b']}])
+        vocab = Vocab.from_dataset(dat)
+
+        itos = '<pad> <unk> c b'.split()
+
+        assert isinstance(vocab['ws'], Sequence)
+        assert len(vocab['ws']) == len(itos)
+        for i in range(len(itos)):
+            assert vocab['ws'][i] == itos[i]
+
+        assert isinstance(vocab['ws'].stoi, Mapping)
+        assert len(vocab['ws'].stoi) == len(vocab['ws'])
+        assert set(vocab['ws'].stoi) == set(vocab['ws'])
+        for i, s in enumerate(vocab['ws']):
+            assert vocab['ws'].stoi[s] == i
diff --git a/text2array/vocab.py b/text2array/vocab.py
index 1167cf3..6e99003 100644
--- a/text2array/vocab.py
+++ b/text2array/vocab.py
@@ -1,5 +1,6 @@
 from collections import Counter
-from typing import Iterable, Iterator, Mapping, Sequence
+from collections.abc import Sequence as SequenceABC
+from typing import Any, Iterable, Iterator, Mapping, Sequence

 from .datasets import Dataset
 from .samples import FieldName, FieldValue
@@ -25,7 +26,7 @@ def from_dataset(cls, dataset: Dataset) -> 'Vocab':
         assert len(dataset) > 0
         m = {
-            name: VocabEntry.from_iterable(cls._get_values(dataset, name))
+            name: VocabEntry.from_iterable(cls._flatten(cls._get_values(dataset, name)))
             for name, value in dataset[0].items()
             if cls._needs_vocab(value)
         }
@@ -37,7 +38,15 @@ def _get_values(dat: Dataset, name: FieldName) -> Sequence[FieldValue]:

     @classmethod
     def _needs_vocab(cls, val: FieldValue) -> bool:
-        return isinstance(val, str)
+        if not isinstance(val, SequenceABC):
+            return isinstance(val, str)
+        assert len(val) > 0
+        return isinstance(val[0], str)
+
+    @staticmethod
+    def _flatten(xss: Sequence[Sequence[Any]]) -> Iterator[Any]:
+        for xs in xss:
+            yield from xs

From 5f02c654bedeaad77e5d432fdaf06a8ac7e1317e Mon Sep 17 00:00:00 2001
From: Kemal Kurniawan
Date: Mon, 28 Jan 2019 20:22:42 +0700
Subject: [PATCH 079/162] Handle sequence of sequence field values

---
 tests/test_vocab.py | 23 +++++++++++++++++++++++
 text2array/vocab.py | 32 +++++++++++++++++++++++---------
 2 files changed, 46 insertions(+), 9 deletions(-)

diff --git a/tests/test_vocab.py b/tests/test_vocab.py
index a9ef36d..a29f5f7 100644
--- a/tests/test_vocab.py
+++ b/tests/test_vocab.py
@@ -55,3 +55,26 @@ def test_seq(self):
         assert set(vocab['ws'].stoi) == set(vocab['ws'])
         for i, s in enumerate(vocab['ws']):
             assert vocab['ws'].stoi[s] == i
+
+    def test_seq_of_seq(self):
+        dat = Dataset([{
+            'cs': [['c', 'd'], ['a', 'd']]
+        }, {
+            'cs': [['c'], ['b'], ['b', 'd']]
+        }, {
+            'cs': [['d', 'c']]
+        }])
+        vocab = Vocab.from_dataset(dat)
+
+        itos = '<pad> <unk> d c b'.split()
+
+        assert isinstance(vocab['cs'], Sequence)
+        assert len(vocab['cs']) == len(itos)
+        for i in range(len(itos)):
+            assert vocab['cs'][i] == itos[i]
+
+        assert isinstance(vocab['cs'].stoi, Mapping)
+        assert len(vocab['cs'].stoi) == len(vocab['cs'])
+        assert set(vocab['cs'].stoi) == set(vocab['cs'])
+        for i, s in enumerate(vocab['cs']):
+            assert vocab['cs'].stoi[s] == i
diff --git a/text2array/vocab.py b/text2array/vocab.py
index 6e99003..189bc4d 100644
--- a/text2array/vocab.py
+++ b/text2array/vocab.py
@@ -1,6 +1,6 @@
 from collections import Counter
 from collections.abc import Sequence as SequenceABC
-from typing import Any, Iterable, Iterator, Mapping, Sequence
+from typing import Iterable, Iterator, Mapping, Sequence

 from .datasets import Dataset
 from .samples import FieldName, FieldValue
@@ -38,15 +38,26 @@ def _get_values(dat: Dataset, name: FieldName) -> Sequence[FieldValue]:

     @classmethod
     def _needs_vocab(cls, val: FieldValue) -> bool:
-        if not isinstance(val, SequenceABC):
-            return isinstance(val, str)
-        assert len(val) > 0
-        return isinstance(val[0], str)
+        if isinstance(val, str):
+            return True
+        if isinstance(val, SequenceABC):
+            assert len(val) > 0
+            return cls._needs_vocab(val[0])
+        return False

-    @staticmethod
-    def _flatten(xss: Sequence[Sequence[Any]]) -> Iterator[Any]:
-        for xs in xss:
-            yield from xs
+    @classmethod
+    def _flatten(cls, xs):
+        if isinstance(xs, str):
+            yield xs
+            return
+
+        try:
+            iter(xs)
+        except TypeError:
+            yield xs
+        else:
+            for x in xs:
+                yield from cls._flatten(x)
@@ -86,6 +97,9 @@ def from_iterable(cls, iterable: Iterable[str]) -> 'VocabEntry':
             itos.append(v)
         return cls(itos)

+    def __repr__(self) -> str:
+        return f'{self.__class__.__name__}({self._itos!r})'

From 8206507c3e240a3aba73309a898d4362b187c7ea Mon Sep 17 00:00:00 2001
From: Kemal Kurniawan
Date: Mon, 28 Jan 2019 20:25:58 +0700
Subject: [PATCH 080/162] Improve coverage

---
 .coveragerc         | 4 +++-
 tests/test_vocab.py | 4 +++-
 text2array/vocab.py | 2 +-
 3 files changed, 7 insertions(+), 3 deletions(-)

diff --git a/.coveragerc b/.coveragerc
index 2a69529..f77eabf 100644
--- a/.coveragerc
+++ b/.coveragerc
@@ -6,4 +6,6 @@ exclude_lines =
     # Re-enable standard pragma
     pragma: no cover
     # abstract method
-    abstractmethod
\ No newline at end of file
+    abstractmethod
+    # debugging stuff
+    __repr__
diff --git a/tests/test_vocab.py b/tests/test_vocab.py
index a29f5f7..836914a 100644
--- a/tests/test_vocab.py
+++ b/tests/test_vocab.py
@@ -9,10 +9,12 @@ class TestFromDataset():
     def test_ok(self):
         dat = Dataset([{'w': 'c'}, {'w': 'b'}, {'w': 'a'}, {'w': 'b'}, {'w': 'c'}, {'w': 'c'}])
         vocab = Vocab.from_dataset(dat)
+
         assert isinstance(vocab, Mapping)
+        assert len(vocab) == 1
+        assert list(vocab) == ['w']

         itos = '<pad> <unk> c b'.split()
-
         assert isinstance(vocab['w'], Sequence)
         assert len(vocab['w']) == len(itos)
         for i in range(len(itos)):
diff --git a/text2array/vocab.py b/text2array/vocab.py
index 189bc4d..9e6ccb0 100644
--- a/text2array/vocab.py
+++ b/text2array/vocab.py
@@ -53,7 +53,7 @@ def _flatten(cls, xs):

         try:
             iter(xs)
-        except TypeError:
+        except TypeError:  # pragma: no cover
             yield xs
         else:
             for x in xs:

From 7a5215e6f0c151b3d30f4d93a330e73de7205e1b Mon Sep 17 00:00:00 2001
From: Kemal Kurniawan
Date: Mon, 28 Jan 2019 20:48:58 +0700
Subject: [PATCH 081/162] Relax so vocab can be built from just an iterable of
 samples

---
 tests/test_vocab.py | 24 ++++++++++++------------
 text2array/vocab.py | 21 +++++++++++++--------
 2 files changed, 25 insertions(+), 20 deletions(-)

diff --git a/tests/test_vocab.py b/tests/test_vocab.py
index 836914a..4b264ea 100644
--- a/tests/test_vocab.py
+++ b/tests/test_vocab.py
@@ -2,13 +2,13 @@

 import pytest

-from text2array import Dataset, Vocab
+from text2array import Vocab

-class TestFromDataset():
+class TestFromSamples():
     def test_ok(self):
-        dat = Dataset([{'w': 'c'}, {'w': 'b'}, {'w': 'a'}, {'w': 'b'}, {'w': 'c'}, {'w': 'c'}])
-        vocab = Vocab.from_dataset(dat)
+        ss = [{'w': 'c'}, {'w': 'b'}, {'w': 'a'}, {'w': 'b'}, {'w': 'c'}, {'w': 'c'}]
+        vocab = Vocab.from_samples(ss)

         assert isinstance(vocab, Mapping)
         assert len(vocab) == 1
@@ -30,20 +30,20 @@ def test_ok(self):
         assert vocab['w'].stoi['bar'] == vocab['w'].stoi['<unk>']

     def test_has_vocab_for_all_str_fields(self):
-        dat = Dataset([{'w': 'b', 't': 'b'}, {'w': 'b', 't': 'b'}])
-        vocab = Vocab.from_dataset(dat)
+        ss = [{'w': 'b', 't': 'b'}, {'w': 'b', 't': 'b'}]
+        vocab = Vocab.from_samples(ss)
         assert isinstance(vocab['t'], Sequence)
         assert isinstance(vocab['t'].stoi, Mapping)

     def test_no_vocab_for_non_str(self):
-        vocab = Vocab.from_dataset(Dataset([{'i': 10}, {'i': 20}]))
+        vocab = Vocab.from_samples([{'i': 10}, {'i': 20}])
         with pytest.raises(RuntimeError) as exc:
             vocab['i']
         assert "no vocabulary found for field name 'i'" in str(exc.value)

     def test_seq(self):
-        dat = Dataset([{'ws': ['a', 'c', 'c']}, {'ws': ['b', 'c']}, {'ws': ['b']}])
-        vocab = Vocab.from_dataset(dat)
+        ss = [{'ws': ['a', 'c', 'c']}, {'ws': ['b', 'c']}, {'ws': ['b']}]
+        vocab = Vocab.from_samples(ss)

         itos = '<pad> <unk> c b'.split()
@@ -59,14 +59,14 @@ def test_seq(self):
             assert vocab['ws'].stoi[s] == i

     def test_seq_of_seq(self):
-        dat = Dataset([{
+        ss = [{
             'cs': [['c', 'd'], ['a', 'd']]
         }, {
             'cs': [['c'], ['b'], ['b', 'd']]
         }, {
             'cs': [['d', 'c']]
-        }])
-        vocab = Vocab.from_dataset(dat)
+        }]
+        vocab = Vocab.from_samples(ss)

         itos = '<pad> <unk> d c b'.split()
diff --git a/text2array/vocab.py b/text2array/vocab.py
index 9e6ccb0..e921d1b 100644
--- a/text2array/vocab.py
+++ b/text2array/vocab.py
@@ -2,8 +2,7 @@
 from collections.abc import Sequence as SequenceABC
 from typing import Iterable, Iterator, Mapping, Sequence

-from .datasets import Dataset
-from .samples import FieldName, FieldValue
+from .samples import FieldName, FieldValue, Sample
@@ -21,19 +20,25 @@ def __getitem__(self, name: FieldName) -> 'VocabEntry':
         except KeyError:
             raise RuntimeError(f"no vocabulary found for field name '{name}'")

+    # TODO mention in docstring that samples must be able to be iterated twice
     @classmethod
-    def from_dataset(cls, dataset: Dataset) -> 'Vocab':
-        assert len(dataset) > 0
+    def from_samples(cls, samples: Iterable[Sample]) -> 'Vocab':
+        # TODO handle when samples is empty
         m = {
-            name: VocabEntry.from_iterable(cls._flatten(cls._get_values(dataset, name)))
-            for name, value in dataset[0].items()
+            name: VocabEntry.from_iterable(cls._flatten(cls._get_values(samples, name)))
+            for name, value in cls._head(samples).items()
             if cls._needs_vocab(value)
         }
         return cls(m)

+    @staticmethod
+    def _head(x):
+        return next(iter(x))
+
+    # TODO return iterable
     @staticmethod
-    def _get_values(dat: Dataset, name: FieldName) -> Sequence[FieldValue]:
-        return [s[name] for s in dat]
+    def _get_values(ss: Iterable[Sample], name: FieldName) -> Sequence[FieldValue]:
+        return [s[name] for s in ss]

From ca5e4b3bf8641a56cd4d4c79b730ea8b804393d5 Mon Sep 17 00:00:00 2001
From: Kemal Kurniawan
Date: Mon, 28 Jan 2019 20:52:26 +0700
Subject: [PATCH 082/162] Let Vocab._get_values return iterator to conserve
 memory

---
 text2array/vocab.py | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/text2array/vocab.py b/text2array/vocab.py
index e921d1b..a926e4d 100644
--- a/text2array/vocab.py
+++ b/text2array/vocab.py
@@ -36,10 +36,9 @@ def from_samples(cls, samples: Iterable[Sample]) -> 'Vocab':
     def _head(x):
         return next(iter(x))

-    # TODO return iterable
     @staticmethod
-    def _get_values(ss: Iterable[Sample], name: FieldName) -> Sequence[FieldValue]:
-        return [s[name] for s in ss]
+    def _get_values(ss: Iterable[Sample], name: FieldName) -> Iterator[FieldValue]:
+        return (s[name] for s in ss)

From 0e20f67477ad5321f65f9d8db9f87bc5e574e9cd Mon Sep 17 00:00:00 2001
From: Kemal Kurniawan
Date: Mon, 28 Jan 2019 20:54:29 +0700
Subject: [PATCH 083/162] Fill some type hints

---
 text2array/vocab.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/text2array/vocab.py b/text2array/vocab.py
index a926e4d..c5ce8d3 100644
--- a/text2array/vocab.py
+++ b/text2array/vocab.py
@@ -33,7 +33,7 @@ def from_samples(cls, samples: Iterable[Sample]) -> 'Vocab':
         return cls(m)

     @staticmethod
-    def _head(x):
+    def _head(x: Iterable[Sample]) -> Sample:
         return next(iter(x))
@@ -50,7 +50,7 @@ def _needs_vocab(cls, val: FieldValue) -> bool:
         return False

     @classmethod
-    def _flatten(cls, xs):
+    def _flatten(cls, xs) -> Iterator[str]:
         if isinstance(xs, str):
             yield xs
             return

From 1217173d7ccca22eaeb31c6db66e3d2b0f2b3d98 Mon Sep 17 00:00:00 2001
From: Kemal Kurniawan
Date: Mon, 28 Jan 2019 20:56:02 +0700
Subject: [PATCH 084/162] Simplify some code

---
 text2array/vocab.py | 10 +++-------
 1 file changed, 3 insertions(+), 7 deletions(-)

diff --git a/text2array/vocab.py b/text2array/vocab.py
index c5ce8d3..7b8115d 100644
--- a/text2array/vocab.py
+++ b/text2array/vocab.py
@@ -55,13 +55,9 @@ def _flatten(cls, xs) -> Iterator[str]:
             yield xs
             return

-        try:
-            iter(xs)
-        except TypeError:  # pragma: no cover
-            yield xs
-        else:
-            for x in xs:
-                yield from cls._flatten(x)
+        # must be an iterable, due to how we use this function
+        for x in xs:
+            yield from cls._flatten(x)

From 9860fac7ff5c96a86910a19879318740a89bb92e Mon Sep 17 00:00:00 2001
From: Kemal Kurniawan
Date: Mon, 28 Jan 2019 21:00:03 +0700
Subject: [PATCH 085/162] Handle when samples is empty

---
 tests/test_vocab.py | 4 ++++
 text2array/vocab.py | 8 ++++++--
 2 files changed, 10 insertions(+), 2 deletions(-)

diff --git a/tests/test_vocab.py b/tests/test_vocab.py
index 4b264ea..34369d4 100644
--- a/tests/test_vocab.py
+++ b/tests/test_vocab.py
@@ -80,3 +80,7 @@ def test_seq_of_seq(self):
         assert set(vocab['cs'].stoi) == set(vocab['cs'])
         for i, s in enumerate(vocab['cs']):
             assert vocab['cs'].stoi[s] == i
+
+    def test_empty_samples(self):
+        vocab = Vocab.from_samples([])
+        assert len(vocab) == 0
diff --git a/text2array/vocab.py b/text2array/vocab.py
index 7b8115d..ff5932f 100644
--- a/text2array/vocab.py
+++ b/text2array/vocab.py
@@ -24,10 +24,14 @@ def __getitem__(self, name: FieldName) -> 'VocabEntry':
     # TODO mention in docstring that samples must be able to be iterated twice
     @classmethod
     def from_samples(cls, samples: Iterable[Sample]) -> 'Vocab':
-        # TODO handle when samples is empty
+        try:
+            first = cls._head(samples)
+        except StopIteration:
+            return cls({})
+
         m = {
             name: VocabEntry.from_iterable(cls._flatten(cls._get_values(samples, name)))
-            for name, value in cls._head(samples).items()
+            for name, value in first.items()
             if cls._needs_vocab(value)
         }
         return cls(m)

From 0a0d3f31aa76e7246bb3801ae78b8cf41900a937 Mon Sep 17 00:00:00 2001
From: Kemal Kurniawan
Date: Mon, 28 Jan 2019 21:01:37 +0700
Subject: [PATCH 086/162] Add todos

---
 text2array/vocab.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/text2array/vocab.py b/text2array/vocab.py
index ff5932f..60cc493 100644
--- a/text2array/vocab.py
+++ b/text2array/vocab.py
@@ -85,6 +85,7 @@ def from_iterable(cls, iterable: Iterable[str]) -> 'VocabEntry':
         itos = ['<pad>', '<unk>']
         c = Counter(iterable)
         for v, f in c.most_common():
+            # TODO customize this min count
             if f < 2:
                 break
             itos.append(v)
@@ -108,6 +109,7 @@ def __getitem__(self, s: str) -> int:
         try:
             return self._m[s]
         except KeyError:
+            # TODO customize unk id
             return 1

From 31a8fdacc05c94bf6b11a1b38b9fed03f1c3d3f3 Mon Sep 17 00:00:00 2001
From: Kemal Kurniawan
Date: Mon, 28 Jan 2019 21:06:23 +0700
Subject: [PATCH 087/162] Turns an assertion into runtime check --- tests/test_vocab.py | 5 +++++ text2array/vocab.py | 3 ++- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/tests/test_vocab.py b/tests/test_vocab.py index 34369d4..c1e9536 100644 --- a/tests/test_vocab.py +++ b/tests/test_vocab.py @@ -84,3 +84,8 @@ def test_seq_of_seq(self): def test_empty_samples(self): vocab = Vocab.from_samples([]) assert len(vocab) == 0 + + def test_empty_field_values(self): + with pytest.raises(ValueError) as exc: + Vocab.from_samples([{'w': []}]) + assert 'field values must not be an empty sequence' in str(exc.value) diff --git a/text2array/vocab.py b/text2array/vocab.py index 60cc493..390f1ed 100644 --- a/text2array/vocab.py +++ b/text2array/vocab.py @@ -49,7 +49,8 @@ def _needs_vocab(cls, val: FieldValue) -> bool: if isinstance(val, str): return True if isinstance(val, SequenceABC): - assert len(val) > 0 + if not val: + raise ValueError('field values must not be an empty sequence') return cls._needs_vocab(val[0]) return False From d8ca75221ff499afd32c07b188fd0812413d90cb Mon Sep 17 00:00:00 2001 From: Kemal Kurniawan Date: Mon, 28 Jan 2019 21:14:38 +0700 Subject: [PATCH 088/162] Add todos --- text2array/vocab.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/text2array/vocab.py b/text2array/vocab.py index 390f1ed..28bddc9 100644 --- a/text2array/vocab.py +++ b/text2array/vocab.py @@ -68,6 +68,7 @@ def _flatten(cls, xs) -> Iterator[str]: # TODO think if this class needs separate test cases class VocabEntry(Sequence[str]): def __init__(self, strings: Sequence[str]) -> None: + # TODO maybe force strings to have no duplicates? self._itos = strings self._stoi = _StringStore.from_itos(strings) @@ -83,6 +84,7 @@ def stoi(self) -> Mapping[str, int]: @classmethod def from_iterable(cls, iterable: Iterable[str]) -> 'VocabEntry': + # TODO customize these tokens itos = ['', ''] c = Counter(iterable) for v, f in c.most_common(): From bafff97ec8e23fcb56a7bd09c3f882c5dcc875e3 Mon Sep 17 00:00:00 2001 From: Kemal Kurniawan Date: Mon, 28 Jan 2019 21:39:34 +0700 Subject: [PATCH 089/162] Write docstrings --- text2array/vocab.py | 39 ++++++++++++++++++++++++++++++++++++++- 1 file changed, 38 insertions(+), 1 deletion(-) diff --git a/text2array/vocab.py b/text2array/vocab.py index 28bddc9..9743ede 100644 --- a/text2array/vocab.py +++ b/text2array/vocab.py @@ -6,6 +6,17 @@ class Vocab(Mapping[FieldName, 'VocabEntry']): + """Vocabulary containing the mapping from string field values to their integer indices. + + A vocabulary does not hold the mapping directly, but rather it stores a mapping from + field names to :class:`VocabEntry` objects. These objects are the one actually holding + the str-to-int mapping for that particular field name. In other words, the actual + vocabulary is stored in :class:`VocabEntry` and namespaced by this :class:`Vocab` object. + + Args: + m: Mapping from :obj:`FieldName` to its :class:`VocabEntry`. + """ + def __init__(self, m: Mapping[FieldName, 'VocabEntry']) -> None: self._m = m @@ -21,9 +32,20 @@ def __getitem__(self, name: FieldName) -> 'VocabEntry': except KeyError: raise RuntimeError(f"no vocabulary found for field name '{name}'") - # TODO mention in docstring that samples must be able to be iterated twice @classmethod def from_samples(cls, samples: Iterable[Sample]) -> 'Vocab': + """Make an instance of this class from an iterable of samples. + + A vocabulary is only made for fields whose value is a string or a (nested) + sequence of strings. 
It is important that ``samples`` be a true iterable, i.e. + it can be iterated more than once. No exception is raised when this is violated. + + Args: + samples: Iterable of samples. + + Returns: + Vocabulary instance. + """ try: first = cls._head(samples) except StopIteration: @@ -67,6 +89,12 @@ def _flatten(cls, xs) -> Iterator[str]: # TODO think if this class needs separate test cases class VocabEntry(Sequence[str]): + """Vocabulary entry that holds the actual str-to-int/int-to-str mapping. + + Args: + strings: Sequence of distinct strings that serves as the int-to-str mapping. + """ + def __init__(self, strings: Sequence[str]) -> None: # TODO maybe force strings to have no duplicates? self._itos = strings @@ -80,10 +108,19 @@ def __getitem__(self, index) -> str: @property def stoi(self) -> Mapping[str, int]: + """The str-to-int mapping.""" return self._stoi @classmethod def from_iterable(cls, iterable: Iterable[str]) -> 'VocabEntry': + """Make an instance of this class from an iterable of strings. + + Args: + iterable: Iterable of strings. + + Returns: + Vocab entry instance. + """ # TODO customize these tokens itos = ['', ''] c = Counter(iterable) From b2113f10d2d4ea0e034dca88b20ccb05925ddaef Mon Sep 17 00:00:00 2001 From: Kemal Kurniawan Date: Mon, 28 Jan 2019 21:49:54 +0700 Subject: [PATCH 090/162] Customize min count --- tests/test_vocab.py | 5 +++++ text2array/vocab.py | 15 ++++++++++----- 2 files changed, 15 insertions(+), 5 deletions(-) diff --git a/tests/test_vocab.py b/tests/test_vocab.py index c1e9536..34242c3 100644 --- a/tests/test_vocab.py +++ b/tests/test_vocab.py @@ -89,3 +89,8 @@ def test_empty_field_values(self): with pytest.raises(ValueError) as exc: Vocab.from_samples([{'w': []}]) assert 'field values must not be an empty sequence' in str(exc.value) + + def test_custom_min_count(self): + ss = [{'w': 'c'}, {'w': 'b'}, {'w': 'a'}, {'w': 'b'}, {'w': 'c'}, {'w': 'c'}] + vocab = Vocab.from_samples(ss, min_count=3) + assert list(vocab['w']) == ['', '', 'c'] diff --git a/text2array/vocab.py b/text2array/vocab.py index 9743ede..07e748a 100644 --- a/text2array/vocab.py +++ b/text2array/vocab.py @@ -32,8 +32,9 @@ def __getitem__(self, name: FieldName) -> 'VocabEntry': except KeyError: raise RuntimeError(f"no vocabulary found for field name '{name}'") + # TODO limit vocab size @classmethod - def from_samples(cls, samples: Iterable[Sample]) -> 'Vocab': + def from_samples(cls, samples: Iterable[Sample], min_count: int = 2) -> 'Vocab': """Make an instance of this class from an iterable of samples. A vocabulary is only made for fields whose value is a string or a (nested) @@ -42,6 +43,8 @@ def from_samples(cls, samples: Iterable[Sample]) -> 'Vocab': Args: samples: Iterable of samples. + min_count: Remove from the vocabulary string field values occurring fewer + than this number of times. Returns: Vocabulary instance. @@ -52,7 +55,8 @@ def from_samples(cls, samples: Iterable[Sample]) -> 'Vocab': return cls({}) m = { - name: VocabEntry.from_iterable(cls._flatten(cls._get_values(samples, name))) + name: VocabEntry.from_iterable( + cls._flatten(cls._get_values(samples, name)), min_count=min_count) for name, value in first.items() if cls._needs_vocab(value) } @@ -112,11 +116,13 @@ def stoi(self) -> Mapping[str, int]: return self._stoi @classmethod - def from_iterable(cls, iterable: Iterable[str]) -> 'VocabEntry': + def from_iterable(cls, iterable: Iterable[str], min_count: int = 2) -> 'VocabEntry': """Make an instance of this class from an iterable of strings. 
        Args:
             iterable: Iterable of strings.
+            min_count: Remove from the vocabulary strings occurring fewer than this number
+                of times.

         Returns:
             Vocab entry instance.
         """
         # TODO customize these tokens
         itos = ['', '']
         c = Counter(iterable)
         for v, f in c.most_common():
-            # TODO customize this min count
-            if f < 2:
+            if f < min_count:
                 break
             itos.append(v)
         return cls(itos)

From f4a1a1d69faae674d0308ba80d6ac31a54dd95de Mon Sep 17 00:00:00 2001
From: Kemal Kurniawan
Date: Mon, 28 Jan 2019 22:04:21 +0700
Subject: [PATCH 091/162] Fix customizing min count for each field name

---
 tests/test_vocab.py | 25 ++++++++++++++++++++++---
 text2array/vocab.py | 20 ++++++++++++++------
 2 files changed, 36 insertions(+), 9 deletions(-)

diff --git a/tests/test_vocab.py b/tests/test_vocab.py
index 34242c3..5e057fc 100644
--- a/tests/test_vocab.py
+++ b/tests/test_vocab.py
@@ -90,7 +90,26 @@ def test_empty_field_values(self):
             Vocab.from_samples([{'w': []}])
         assert 'field values must not be an empty sequence' in str(exc.value)

-    def test_custom_min_count(self):
-        ss = [{'w': 'c'}, {'w': 'b'}, {'w': 'a'}, {'w': 'b'}, {'w': 'c'}, {'w': 'c'}]
-        vocab = Vocab.from_samples(ss, min_count=3)
+    def test_min_count(self):
+        ss = [{
+            'w': 'c',
+            't': 'c'
+        }, {
+            'w': 'b',
+            't': 'b'
+        }, {
+            'w': 'a',
+            't': 'a'
+        }, {
+            'w': 'b',
+            't': 'b'
+        }, {
+            'w': 'c',
+            't': 'c'
+        }, {
+            'w': 'c',
+            't': 'c'
+        }]
+        vocab = Vocab.from_samples(ss, ve_kwargs={'w': dict(min_count=3)})
         assert list(vocab['w']) == ['', '', 'c']
+        assert list(vocab['t']) == ['', '', 'c', 'b']

diff --git a/text2array/vocab.py b/text2array/vocab.py
index 07e748a..930d6e7 100644
--- a/text2array/vocab.py
+++ b/text2array/vocab.py
@@ -1,6 +1,6 @@
 from collections import Counter
 from collections.abc import Sequence as SequenceABC
-from typing import Iterable, Iterator, Mapping, Sequence
+from typing import Iterable, Iterator, Mapping, Optional, Sequence

 from .samples import FieldName, FieldValue, Sample

@@ -32,9 +32,13 @@ def __getitem__(self, name: FieldName) -> 'VocabEntry':
         except KeyError:
             raise RuntimeError(f"no vocabulary found for field name '{name}'")

-    # TODO limit vocab size
+    # TODO limit vocab size for each field name
     @classmethod
-    def from_samples(cls, samples: Iterable[Sample], min_count: int = 2) -> 'Vocab':
+    def from_samples(
+            cls,
+            samples: Iterable[Sample],
+            ve_kwargs: Optional[Mapping[FieldName, dict]] = None,
+    ) -> 'Vocab':
         """Make an instance of this class from an iterable of samples.

         A vocabulary is only made for fields whose value is a string or a (nested)
         sequence of strings. It is important that ``samples`` be a true iterable, i.e.
         it can be iterated more than once. No exception is raised when this is violated.

         Args:
             samples: Iterable of samples.
-            min_count: Remove from the vocabulary string field values occurring fewer
-                than this number of times.
+            ve_kwargs: Mapping from field names to dictionaries. Each dictionary is passed
+                as keyword arguments to the corresponding :meth:`VocabEntry.from_iterable`
+                call.

         Returns:
             Vocabulary instance.
@@ -54,9 +59,12 @@ def from_samples(cls, samples: Iterable[Sample], min_count: int = 2) -> 'Vocab': except StopIteration: return cls({}) + if ve_kwargs is None: + ve_kwargs = {} + m = { name: VocabEntry.from_iterable( - cls._flatten(cls._get_values(samples, name)), min_count=min_count) + cls._flatten(cls._get_values(samples, name)), **ve_kwargs.get(name, {})) for name, value in first.items() if cls._needs_vocab(value) } From eb4032190046fda68abfeb2448e536648cd27d80 Mon Sep 17 00:00:00 2001 From: Kemal Kurniawan Date: Mon, 28 Jan 2019 22:23:31 +0700 Subject: [PATCH 092/162] Add some todos --- tests/test_dataset.py | 1 + tests/test_vocab.py | 1 + 2 files changed, 2 insertions(+) diff --git a/tests/test_dataset.py b/tests/test_dataset.py index f1df538..4d3131c 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -37,6 +37,7 @@ def test_immutable_seq(self, setup_rng, samples): class TestShuffleBy: + # TODO make this a class variable @staticmethod def make_dataset(): return Dataset([{ diff --git a/tests/test_vocab.py b/tests/test_vocab.py index 5e057fc..14ad8b2 100644 --- a/tests/test_vocab.py +++ b/tests/test_vocab.py @@ -5,6 +5,7 @@ from text2array import Vocab +# TODO change so vocab['w'] returns the mapping, not sequence class TestFromSamples(): def test_ok(self): ss = [{'w': 'c'}, {'w': 'b'}, {'w': 'a'}, {'w': 'b'}, {'w': 'c'}, {'w': 'c'}] From 5e9b0b2d63f36d175199f7b34fd2dd2fdfae2de3 Mon Sep 17 00:00:00 2001 From: Kemal Kurniawan Date: Tue, 29 Jan 2019 05:37:52 +0700 Subject: [PATCH 093/162] Change the Vocab API; now it's a mapping to str-to-int mapping --- tests/test_vocab.py | 56 ++++++++---------------- text2array/vocab.py | 103 +++++++++++++------------------------------- 2 files changed, 50 insertions(+), 109 deletions(-) diff --git a/tests/test_vocab.py b/tests/test_vocab.py index 14ad8b2..95a32cf 100644 --- a/tests/test_vocab.py +++ b/tests/test_vocab.py @@ -1,11 +1,10 @@ -from collections.abc import Mapping, Sequence +from collections.abc import Mapping import pytest from text2array import Vocab -# TODO change so vocab['w'] returns the mapping, not sequence class TestFromSamples(): def test_ok(self): ss = [{'w': 'c'}, {'w': 'b'}, {'w': 'a'}, {'w': 'b'}, {'w': 'c'}, {'w': 'c'}] @@ -16,29 +15,24 @@ def test_ok(self): assert list(vocab) == ['w'] itos = ' c b'.split() - assert isinstance(vocab['w'], Sequence) + assert isinstance(vocab['w'], Mapping) assert len(vocab['w']) == len(itos) - for i in range(len(itos)): - assert vocab['w'][i] == itos[i] + assert list(vocab['w']) == itos + for i, w in enumerate(itos): + assert vocab['w'][w] == i - assert isinstance(vocab['w'].stoi, Mapping) - assert len(vocab['w'].stoi) == len(vocab['w']) - assert set(vocab['w'].stoi) == set(vocab['w']) - for i, s in enumerate(vocab['w']): - assert vocab['w'].stoi[s] == i - - assert vocab['w'].stoi['foo'] == vocab['w'].stoi[''] - assert vocab['w'].stoi['bar'] == vocab['w'].stoi[''] + assert vocab['w']['foo'] == vocab['w'][''] + assert vocab['w']['bar'] == vocab['w'][''] def test_has_vocab_for_all_str_fields(self): ss = [{'w': 'b', 't': 'b'}, {'w': 'b', 't': 'b'}] vocab = Vocab.from_samples(ss) - assert isinstance(vocab['t'], Sequence) - assert isinstance(vocab['t'].stoi, Mapping) + assert 't' in vocab def test_no_vocab_for_non_str(self): vocab = Vocab.from_samples([{'i': 10}, {'i': 20}]) - with pytest.raises(RuntimeError) as exc: + assert 'i' not in vocab + with pytest.raises(KeyError) as exc: vocab['i'] assert "no vocabulary found for field name 'i'" in str(exc.value) @@ -47,17 +41,11 @@ def 
test_seq(self):
         vocab = Vocab.from_samples(ss)

         itos = ' c b'.split()
-
-        assert isinstance(vocab['ws'], Sequence)
+        assert isinstance(vocab['ws'], Mapping)
         assert len(vocab['ws']) == len(itos)
-        for i in range(len(itos)):
-            assert vocab['ws'][i] == itos[i]
-
-        assert isinstance(vocab['ws'].stoi, Mapping)
-        assert len(vocab['ws'].stoi) == len(vocab['ws'])
-        assert set(vocab['ws'].stoi) == set(vocab['ws'])
-        for i, s in enumerate(vocab['ws']):
-            assert vocab['ws'].stoi[s] == i
+        assert list(vocab['ws']) == itos
+        for i, w in enumerate(itos):
+            assert vocab['ws'][w] == i

     def test_seq_of_seq(self):
         ss = [{
@@ -70,17 +58,11 @@ def test_seq_of_seq(self):
         vocab = Vocab.from_samples(ss)

         itos = ' d c b'.split()
-
-        assert isinstance(vocab['cs'], Sequence)
+        assert isinstance(vocab['cs'], Mapping)
         assert len(vocab['cs']) == len(itos)
-        for i in range(len(itos)):
-            assert vocab['cs'][i] == itos[i]
-
-        assert isinstance(vocab['cs'].stoi, Mapping)
-        assert len(vocab['cs'].stoi) == len(vocab['cs'])
-        assert set(vocab['cs'].stoi) == set(vocab['cs'])
-        for i, s in enumerate(vocab['cs']):
-            assert vocab['cs'].stoi[s] == i
+        assert list(vocab['cs']) == itos
+        for i, w in enumerate(itos):
+            assert vocab['cs'][w] == i

     def test_empty_samples(self):
         vocab = Vocab.from_samples([])
         assert len(vocab) == 0
@@ -111,6 +93,6 @@ def test_min_count(self):
             'w': 'c',
             't': 'c'
         }]
-        vocab = Vocab.from_samples(ss, ve_kwargs={'w': dict(min_count=3)})
+        vocab = Vocab.from_samples(ss, options={'w': dict(min_count=3)})
         assert list(vocab['w']) == ['', '', 'c']
         assert list(vocab['t']) == ['', '', 'c', 'b']

diff --git a/text2array/vocab.py b/text2array/vocab.py
index 930d6e7..23c6519 100644
--- a/text2array/vocab.py
+++ b/text2array/vocab.py
@@ -1,23 +1,24 @@
-from collections import Counter
+from collections import Counter, OrderedDict
 from collections.abc import Sequence as SequenceABC
-from typing import Iterable, Iterator, Mapping, Optional, Sequence
+from typing import Iterable, Iterator, Mapping, Optional

 from .samples import FieldName, FieldValue, Sample


-class Vocab(Mapping[FieldName, 'VocabEntry']):
-    """Vocabulary containing the mapping from string field values to their integer indices.
+class Vocab(Mapping[FieldName, Mapping[str, int]]):
+    """Namespaced vocabulary storing the mapping from field names to their actual vocabulary.

-    A vocabulary does not hold the mapping directly, but rather it stores a mapping from
-    field names to :class:`VocabEntry` objects. These objects are the ones actually holding
-    the str-to-int mapping for that particular field name. In other words, the actual
-    vocabulary is stored in :class:`VocabEntry` and namespaced by this :class:`Vocab` object.
+    A vocabulary does not hold the str-to-int mapping directly, but rather it stores a mapping
+    from field names to the corresponding str-to-int mappings. These mappings are the actual
+    vocabulary for that particular field name. In other words, the actual vocabulary for each
+    field name is namespaced by the field name and all of them are handled by this
+    :class:`Vocab` object.

     Args:
-        m: Mapping from :obj:`FieldName` to its :class:`VocabEntry`.
+        m: Mapping from :obj:`FieldName` to its str-to-int mapping.
""" - def __init__(self, m: Mapping[FieldName, 'VocabEntry']) -> None: + def __init__(self, m: Mapping[FieldName, Mapping[str, int]]) -> None: self._m = m def __len__(self) -> int: @@ -26,18 +27,18 @@ def __len__(self) -> int: def __iter__(self) -> Iterator[FieldName]: return iter(self._m) - def __getitem__(self, name: FieldName) -> 'VocabEntry': + def __getitem__(self, name: FieldName) -> Mapping[str, int]: try: return self._m[name] except KeyError: - raise RuntimeError(f"no vocabulary found for field name '{name}'") + raise KeyError(f"no vocabulary found for field name '{name}'") # TODO limit vocab size for each field name @classmethod def from_samples( cls, samples: Iterable[Sample], - ve_kwargs: Optional[Mapping[FieldName, dict]] = None, + options: Optional[Mapping[FieldName, dict]] = None, ) -> 'Vocab': """Make an instance of this class from an iterable of samples. @@ -47,9 +48,11 @@ def from_samples( Args: samples: Iterable of samples. - ve_kwargs: Mapping from field names to dictionaries. Each dictionary is passed - as keyword arguments to the corresponding :meth:`VocabEntry.from_iterable` - call. + options: Mapping from field names to dictionaries to control the creation of + the str-to-int mapping. Allowed dictionary keys are: + + * ``min_count`` - Exclude strings occurring fewer than this number of times + from the vocabulary. Returns: Vocabulary instance. @@ -59,12 +62,12 @@ def from_samples( except StopIteration: return cls({}) - if ve_kwargs is None: - ve_kwargs = {} + if options is None: + options = {} m = { - name: VocabEntry.from_iterable( - cls._flatten(cls._get_values(samples, name)), **ve_kwargs.get(name, {})) + name: _StringStore._from_iterable( + cls._flatten(cls._get_values(samples, name)), **options.get(name, {})) for name, value in first.items() if cls._needs_vocab(value) } @@ -99,55 +102,6 @@ def _flatten(cls, xs) -> Iterator[str]: yield from cls._flatten(x) -# TODO think if this class needs separate test cases -class VocabEntry(Sequence[str]): - """Vocabulary entry that holds the actual str-to-int/int-to-str mapping. - - Args: - strings: Sequence of distinct strings that serves as the int-to-str mapping. - """ - - def __init__(self, strings: Sequence[str]) -> None: - # TODO maybe force strings to have no duplicates? - self._itos = strings - self._stoi = _StringStore.from_itos(strings) - - def __len__(self) -> int: - return len(self._itos) - - def __getitem__(self, index) -> str: - return self._itos[index] - - @property - def stoi(self) -> Mapping[str, int]: - """The str-to-int mapping.""" - return self._stoi - - @classmethod - def from_iterable(cls, iterable: Iterable[str], min_count: int = 2) -> 'VocabEntry': - """Make an instance of this class from an iterable of strings. - - Args: - iterable: Iterable of strings. - min_count: Remove from the vocabulary strings occurring fewer than this number - of times. - - Returns: - Vocab entry instance. 
- """ - # TODO customize these tokens - itos = ['', ''] - c = Counter(iterable) - for v, f in c.most_common(): - if f < min_count: - break - itos.append(v) - return cls(itos) - - def __repr__(self) -> str: - return f'{self.__class__.__name__}({self._itos!r})' - - class _StringStore(Mapping[str, int]): def __init__(self, m: Mapping[str, int]) -> None: self._m = m @@ -166,7 +120,12 @@ def __getitem__(self, s: str) -> int: return 1 @classmethod - def from_itos(cls, itos: Sequence[str]) -> '_StringStore': - assert len(set(itos)) == len(itos), 'itos cannot have duplicate strings' - stoi = {s: i for i, s in enumerate(itos)} + def _from_iterable(cls, iterable: Iterable[str], min_count: int = 2) -> '_StringStore': + # TODO customize these tokens + stoi = OrderedDict([('', 0), ('', 1)]) + c = Counter(iterable) + for s, f in c.most_common(): + if f < min_count: + break + stoi[s] = len(stoi) return cls(stoi) From aff3a570820bae375e014fd5dc4044a11d8c8f0c Mon Sep 17 00:00:00 2001 From: Kemal Kurniawan Date: Tue, 29 Jan 2019 06:01:30 +0700 Subject: [PATCH 094/162] Allow customizing unknown token --- tests/test_vocab.py | 27 +++++++++++++++++++++++++++ text2array/vocab.py | 32 ++++++++++++++++++++++++-------- 2 files changed, 51 insertions(+), 8 deletions(-) diff --git a/tests/test_vocab.py b/tests/test_vocab.py index 95a32cf..f6e2bb6 100644 --- a/tests/test_vocab.py +++ b/tests/test_vocab.py @@ -96,3 +96,30 @@ def test_min_count(self): vocab = Vocab.from_samples(ss, options={'w': dict(min_count=3)}) assert list(vocab['w']) == ['', '', 'c'] assert list(vocab['t']) == ['', '', 'c', 'b'] + + def test_no_unk(self): + ss = [{ + 'w': 'c', + 't': 'c' + }, { + 'w': 'b', + 't': 'b' + }, { + 'w': 'a', + 't': 'a' + }, { + 'w': 'b', + 't': 'b' + }, { + 'w': 'c', + 't': 'c' + }, { + 'w': 'c', + 't': 'c' + }] + vocab = Vocab.from_samples(ss, options={'w': dict(unk=None)}) + assert list(vocab['w']) == ['', 'c', 'b'] + assert list(vocab['t']) == ['', '', 'c', 'b'] + with pytest.raises(KeyError) as exc: + vocab['w']['foo'] + assert "'foo' not found in vocabulary" in str(exc.value) diff --git a/text2array/vocab.py b/text2array/vocab.py index 23c6519..ecb0085 100644 --- a/text2array/vocab.py +++ b/text2array/vocab.py @@ -51,8 +51,11 @@ def from_samples( options: Mapping from field names to dictionaries to control the creation of the str-to-int mapping. Allowed dictionary keys are: - * ``min_count`` - Exclude strings occurring fewer than this number of times - from the vocabulary. + * min_count(:obj:`int`): Exclude strings occurring fewer than this number of + times from the vocabulary. + * unk(:obj:`str`): String to represent unknown strings with. If ``None``, + no unknown strings are expected. This means when querying the vocabulary + with such string, an error is raised. Returns: Vocabulary instance. 
@@ -103,8 +106,10 @@ def _flatten(cls, xs) -> Iterator[str]: class _StringStore(Mapping[str, int]): - def __init__(self, m: Mapping[str, int]) -> None: + def __init__(self, m: Mapping[str, int], unk_id: Optional[int] = None) -> None: + assert unk_id is None or unk_id >= 0 self._m = m + self._unk_id = unk_id def __len__(self) -> int: return len(self._m) @@ -116,16 +121,27 @@ def __getitem__(self, s: str) -> int: try: return self._m[s] except KeyError: - # TODO customize unk id - return 1 + if self._unk_id is not None: + return self._unk_id + raise KeyError(f"'{s}' not found in vocabulary") @classmethod - def _from_iterable(cls, iterable: Iterable[str], min_count: int = 2) -> '_StringStore': + def _from_iterable( + cls, + iterable: Iterable[str], + min_count: int = 2, + unk: Optional[str] = '', + ) -> '_StringStore': # TODO customize these tokens - stoi = OrderedDict([('', 0), ('', 1)]) + stoi = OrderedDict([('', 0)]) + if unk is not None: + stoi[unk] = len(stoi) + c = Counter(iterable) for s, f in c.most_common(): if f < min_count: break stoi[s] = len(stoi) - return cls(stoi) + + unk_id = None if unk is None else stoi[unk] + return cls(stoi, unk_id=unk_id) From 0f49d30764efd9e25402bfa64e43385218bb9e89 Mon Sep 17 00:00:00 2001 From: Kemal Kurniawan Date: Tue, 29 Jan 2019 06:12:02 +0700 Subject: [PATCH 095/162] Allow customizing padding token --- tests/test_vocab.py | 24 ++++++++++++++++++++++++ text2array/vocab.py | 25 ++++++++++++++++--------- 2 files changed, 40 insertions(+), 9 deletions(-) diff --git a/tests/test_vocab.py b/tests/test_vocab.py index f6e2bb6..7c5a838 100644 --- a/tests/test_vocab.py +++ b/tests/test_vocab.py @@ -123,3 +123,27 @@ def test_no_unk(self): with pytest.raises(KeyError) as exc: vocab['w']['foo'] assert "'foo' not found in vocabulary" in str(exc.value) + + def test_no_pad(self): + ss = [{ + 'w': 'c', + 't': 'c' + }, { + 'w': 'b', + 't': 'b' + }, { + 'w': 'a', + 't': 'a' + }, { + 'w': 'b', + 't': 'b' + }, { + 'w': 'c', + 't': 'c' + }, { + 'w': 'c', + 't': 'c' + }] + vocab = Vocab.from_samples(ss, options={'w': dict(pad=None)}) + assert list(vocab['w']) == ['', 'c', 'b'] + assert list(vocab['t']) == ['', '', 'c', 'b'] diff --git a/text2array/vocab.py b/text2array/vocab.py index ecb0085..6b9f509 100644 --- a/text2array/vocab.py +++ b/text2array/vocab.py @@ -1,6 +1,6 @@ from collections import Counter, OrderedDict from collections.abc import Sequence as SequenceABC -from typing import Iterable, Iterator, Mapping, Optional +from typing import Iterable, Iterator, Mapping, MutableMapping, Optional from .samples import FieldName, FieldValue, Sample @@ -42,8 +42,8 @@ def from_samples( ) -> 'Vocab': """Make an instance of this class from an iterable of samples. - A vocabulary is only made for fields whose value is a string or a (nested) - sequence of strings. It is important that ``samples`` be a true iterable, i.e. + A vocabulary is only made for fields whose value is a string token or a (nested) + sequence of string tokens. It is important that ``samples`` be a true iterable, i.e. it can be iterated more than once. No exception is raised when this is violated. Args: @@ -51,11 +51,16 @@ def from_samples( options: Mapping from field names to dictionaries to control the creation of the str-to-int mapping. Allowed dictionary keys are: - * min_count(:obj:`int`): Exclude strings occurring fewer than this number of + * min_count(:obj:`int`): Exclude tokens occurring fewer than this number of times from the vocabulary. 
- * unk(:obj:`str`): String to represent unknown strings with. If ``None``, - no unknown strings are expected. This means when querying the vocabulary - with such string, an error is raised. + * pad(:obj:`str`): String token to represent padding tokens. If ``None``, + no padding token is added to the vocabulary. Otherwise, it is the + first entry in the vocabulary (index is 0). + * unk(:obj:`str`): String token to represent unknown tokens with. If ``None``, + no unknown token is added to the vocabulary. This means when querying + the vocabulary with such token, an error is raised. Otherwise, it is the + first entry in the vocabulary *after* the padding token, if any (index is + either 0 or 1). Returns: Vocabulary instance. @@ -131,9 +136,11 @@ def _from_iterable( iterable: Iterable[str], min_count: int = 2, unk: Optional[str] = '', + pad: Optional[str] = '', ) -> '_StringStore': - # TODO customize these tokens - stoi = OrderedDict([('', 0)]) + stoi: MutableMapping[str, int] = OrderedDict() + if pad is not None: + stoi[pad] = len(stoi) if unk is not None: stoi[unk] = len(stoi) From 81ea45bedeb513651c7d56a63adbe0e913ee66ed Mon Sep 17 00:00:00 2001 From: Kemal Kurniawan Date: Tue, 29 Jan 2019 06:27:43 +0700 Subject: [PATCH 096/162] Allow customizing max vocab size --- tests/test_vocab.py | 24 ++++++++++++++++++++++++ text2array/vocab.py | 9 +++++++-- 2 files changed, 31 insertions(+), 2 deletions(-) diff --git a/tests/test_vocab.py b/tests/test_vocab.py index 7c5a838..d871131 100644 --- a/tests/test_vocab.py +++ b/tests/test_vocab.py @@ -147,3 +147,27 @@ def test_no_pad(self): vocab = Vocab.from_samples(ss, options={'w': dict(pad=None)}) assert list(vocab['w']) == ['', 'c', 'b'] assert list(vocab['t']) == ['', '', 'c', 'b'] + + def test_max_size(self): + ss = [{ + 'w': 'c', + 't': 'c' + }, { + 'w': 'b', + 't': 'b' + }, { + 'w': 'a', + 't': 'a' + }, { + 'w': 'b', + 't': 'b' + }, { + 'w': 'c', + 't': 'c' + }, { + 'w': 'c', + 't': 'c' + }] + vocab = Vocab.from_samples(ss, options={'w': dict(max_size=1)}) + assert list(vocab['w']) == ['', '', 'c'] + assert list(vocab['t']) == ['', '', 'c', 'b'] diff --git a/text2array/vocab.py b/text2array/vocab.py index 6b9f509..34a3751 100644 --- a/text2array/vocab.py +++ b/text2array/vocab.py @@ -33,7 +33,6 @@ def __getitem__(self, name: FieldName) -> Mapping[str, int]: except KeyError: raise KeyError(f"no vocabulary found for field name '{name}'") - # TODO limit vocab size for each field name @classmethod def from_samples( cls, @@ -61,6 +60,10 @@ def from_samples( the vocabulary with such token, an error is raised. Otherwise, it is the first entry in the vocabulary *after* the padding token, if any (index is either 0 or 1). + * max_size(:obj:`int`): Maximum size of the vocabulary, excluding the padding + and unknown tokens. If ``None``, no limit on the vocabulary size. + Otherwise, at most, only this number of most frequent tokens are included + in the vocabulary. Returns: Vocabulary instance. 
@@ -137,6 +140,7 @@ def _from_iterable( min_count: int = 2, unk: Optional[str] = '', pad: Optional[str] = '', + max_size: Optional[int] = None, ) -> '_StringStore': stoi: MutableMapping[str, int] = OrderedDict() if pad is not None: @@ -144,9 +148,10 @@ def _from_iterable( if unk is not None: stoi[unk] = len(stoi) + n = len(stoi) c = Counter(iterable) for s, f in c.most_common(): - if f < min_count: + if f < min_count or (max_size is not None and len(stoi) - n >= max_size): break stoi[s] = len(stoi) From c0982ef7b41c15dff5c84f6d3e2c9be02dcce51d Mon Sep 17 00:00:00 2001 From: Kemal Kurniawan Date: Tue, 29 Jan 2019 06:34:45 +0700 Subject: [PATCH 097/162] Make better docstrings for options --- text2array/vocab.py | 27 ++++++++++++++------------- 1 file changed, 14 insertions(+), 13 deletions(-) diff --git a/text2array/vocab.py b/text2array/vocab.py index 34a3751..884a9a5 100644 --- a/text2array/vocab.py +++ b/text2array/vocab.py @@ -50,20 +50,21 @@ def from_samples( options: Mapping from field names to dictionaries to control the creation of the str-to-int mapping. Allowed dictionary keys are: - * min_count(:obj:`int`): Exclude tokens occurring fewer than this number of - times from the vocabulary. - * pad(:obj:`str`): String token to represent padding tokens. If ``None``, + * ``min_count``(:obj:`int`): Exclude tokens occurring fewer than this number + of times from the vocabulary (default: 2). + * ``pad``(:obj:`str`): String token to represent padding tokens. If ``None``, no padding token is added to the vocabulary. Otherwise, it is the - first entry in the vocabulary (index is 0). - * unk(:obj:`str`): String token to represent unknown tokens with. If ``None``, - no unknown token is added to the vocabulary. This means when querying - the vocabulary with such token, an error is raised. Otherwise, it is the - first entry in the vocabulary *after* the padding token, if any (index is - either 0 or 1). - * max_size(:obj:`int`): Maximum size of the vocabulary, excluding the padding - and unknown tokens. If ``None``, no limit on the vocabulary size. - Otherwise, at most, only this number of most frequent tokens are included - in the vocabulary. + first entry in the vocabulary (index is 0) (default: ````). + * ``unk``(:obj:`str`): String token to represent unknown tokens with. If + ``None``, no unknown token is added to the vocabulary. This means when + querying the vocabulary with such token, an error is raised. Otherwise, + it is the first entry in the vocabulary *after* ``pad``, if any (index is + either 0 or 1) (default: ````). + * ``max_size``(:obj:`int`): Maximum size of the vocabulary, excluding ``pad`` + and ``unk``. If ``None``, no limit on the vocabulary size. Otherwise, at + most, only this number of most frequent tokens are included in the + vocabulary. Note that ``min_count`` also sets the maximum size implicitly. + So, the size is limited by whichever is smaller. (default: ``None``). Returns: Vocabulary instance. 
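Taken together, patches 090-097 give ``Vocab.from_samples`` a per-field ``options`` mapping with ``min_count``, ``pad``, ``unk``, and ``max_size`` keys. The following is a minimal usage sketch under these patches; the sample data and the ``'<unk>'`` token are hypothetical and passed explicitly, since the literal default tokens are elided in this patch text:

    from text2array import Vocab

    # Hypothetical samples; 'ws' holds sequences of string tokens with
    # counts a=1, b=3, c=4 across both samples.
    ss = [{'ws': ['a', 'b', 'b', 'c', 'c']}, {'ws': ['c', 'c', 'b']}]

    vocab = Vocab.from_samples(
        ss,
        options={
            'ws': dict(
                min_count=2,   # drop 'a', which occurs fewer than 2 times
                pad=None,      # no padding token for this field
                unk='<unk>',   # explicit unknown token (hypothetical value)
                max_size=10,   # keep at most the 10 most frequent tokens
            )
        },
    )
    assert list(vocab['ws']) == ['<unk>', 'c', 'b']
    assert vocab['ws']['zzz'] == vocab['ws']['<unk>']  # unseen strings map to unk

Because ``ss`` is a list, it can be iterated more than once, as ``from_samples`` requires.
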
From e61b4657520ff208e9212e46d8e59362fa394983 Mon Sep 17 00:00:00 2001 From: Kemal Kurniawan Date: Tue, 29 Jan 2019 06:37:25 +0700 Subject: [PATCH 098/162] Fix a todo --- text2array/datasets.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/text2array/datasets.py b/text2array/datasets.py index 2e9cc23..2ae4606 100644 --- a/text2array/datasets.py +++ b/text2array/datasets.py @@ -29,7 +29,7 @@ def batch_exactly(self, batch_size: int) -> Iterator[Batch]: return (b for b in self.batch(batch_size) if len(b) == batch_size) -# TODO implement Vocab class +# TODO implement mapping with a vocab class Dataset(DatasetABC, Sequence[Sample]): """A dataset that fits in memory (no streaming). From 3ba1ea081225992b8f397843156aba3d508bc993 Mon Sep 17 00:00:00 2001 From: Kemal Kurniawan Date: Tue, 29 Jan 2019 07:21:46 +0700 Subject: [PATCH 099/162] Implement Dataset.apply_vocab() --- tests/test_dataset.py | 47 ++++++++++++++++++++++++++++++++++++++++++ text2array/datasets.py | 37 +++++++++++++++++++++++++++++++-- 2 files changed, 82 insertions(+), 2 deletions(-) diff --git a/tests/test_dataset.py b/tests/test_dataset.py index 4d3131c..bfebb30 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -119,3 +119,50 @@ def test_batch_nonpositive_batch_size(dataset): with pytest.raises(ValueError) as exc: next(dataset.batch_exactly(0)) assert 'batch size must be greater than 0' in str(exc.value) + + +def test_apply_vocab(): + dat = Dataset([{ + 'w': 'a', + 'ws': ['a', 'b'], + 'cs': [['a', 'b'], ['b', 'a']], + 'i': 10, + 'j': 20 + }, { + 'w': 'b', + 'ws': ['a', 'a'], + 'cs': [['b', 'b'], ['b', 'a']], + 'i': 10, + 'j': 20 + }]) + vocab = { + 'w': { + 'a': 0, + 'b': 1 + }, + 'ws': { + 'a': 2, + 'b': 3 + }, + 'cs': { + 'a': 4, + 'b': 5 + }, + 'j': { + 20: 2 + } + } + dat = dat.apply_vocab(vocab) + assert list(dat) == [{ + 'w': 0, + 'ws': [2, 3], + 'cs': [[4, 5], [5, 4]], + 'i': 10, + 'j': 2 + }, { + 'w': 1, + 'ws': [2, 2], + 'cs': [[5, 5], [5, 4]], + 'i': 10, + 'j': 2 + }] diff --git a/text2array/datasets.py b/text2array/datasets.py index 2ae4606..7c53382 100644 --- a/text2array/datasets.py +++ b/text2array/datasets.py @@ -1,12 +1,12 @@ from collections.abc import \ Iterable as IterableABC, MutableSequence as MutableSequenceABC, Sequence as SequenceABC -from typing import Callable, Iterable, Iterator, Sequence +from typing import Callable, Iterable, Iterator, Mapping, Sequence import abc import random import statistics as stat from .batches import Batch -from .samples import Sample +from .samples import FieldName, FieldValue, Sample class DatasetABC(Iterable[Sample], metaclass=abc.ABCMeta): @@ -111,6 +111,30 @@ def batch(self, batch_size: int) -> Iterator[Batch]: end = begin + batch_size yield Batch(self._samples[begin:end]) + def apply_vocab( + self, + vocab: Mapping[FieldName, Mapping[FieldValue, FieldValue]], + ) -> 'Dataset': + """Apply a vocabulary to this dataset. + + Applying a vocabulary means mapping all the (nested) field values to the corresponding + values according to the mapping specified by the vocabulary. Field names that have + no entry in the vocabulary are ignored. + + Args: + vocab: Vocabulary to apply. + + Returns: + Dataset after application. 
+ """ + samples = [] + for s in self._samples: + s_ = {} + for name, val in s.items(): + s_[name] = self._apply(vocab[name], val) if name in vocab else val + samples.append(s_) + return Dataset(samples) + def _shuffle_inplace(self) -> None: assert isinstance(self._samples, MutableSequenceABC) n = len(self._samples) @@ -124,6 +148,15 @@ def _shuffle_copy(self) -> None: self._samples = list(self._samples) self._shuffle_inplace() + @classmethod + def _apply(cls, vocab: Mapping[FieldValue, FieldValue], val: FieldValue) -> FieldValue: + # TODO handle when val not in vocab + if isinstance(val, str): + return vocab[val] + if isinstance(val, SequenceABC): + return [cls._apply(vocab, v) for v in val] + return vocab[val] + class StreamDataset(DatasetABC): """A dataset that streams its samples. From 6342feec21339f33bfc51f87400422ce705a0ad4 Mon Sep 17 00:00:00 2001 From: Kemal Kurniawan Date: Tue, 29 Jan 2019 07:28:09 +0700 Subject: [PATCH 100/162] Handle when value is not in vocab --- tests/test_dataset.py | 7 +++++++ text2array/datasets.py | 13 ++++++++++--- 2 files changed, 17 insertions(+), 3 deletions(-) diff --git a/tests/test_dataset.py b/tests/test_dataset.py index bfebb30..3cb7064 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -166,3 +166,10 @@ def test_apply_vocab(): 'i': 10, 'j': 2 }] + + +def test_apply_vocab_value_not_in_vocab(): + dat = Dataset([{'w': 'a', 'i': 10}, {'w': 'b', 'i': 11}]) + vocab = {'w': {'a': 0}, 'i': {10: 1}} + dat = dat.apply_vocab(vocab) + assert list(dat) == [{'w': 0, 'i': 1}, {'w': 'b', 'i': 11}] diff --git a/text2array/datasets.py b/text2array/datasets.py index 7c53382..ad7c4b2 100644 --- a/text2array/datasets.py +++ b/text2array/datasets.py @@ -150,12 +150,19 @@ def _shuffle_copy(self) -> None: @classmethod def _apply(cls, vocab: Mapping[FieldValue, FieldValue], val: FieldValue) -> FieldValue: - # TODO handle when val not in vocab if isinstance(val, str): - return vocab[val] + try: + return vocab[val] + except KeyError: + return val + if isinstance(val, SequenceABC): return [cls._apply(vocab, v) for v in val] - return vocab[val] + + try: + return vocab[val] + except KeyError: + return val class StreamDataset(DatasetABC): From 885da681255ee85f0bbfda9187f02e4c5795bb66 Mon Sep 17 00:00:00 2001 From: Kemal Kurniawan Date: Tue, 29 Jan 2019 07:29:20 +0700 Subject: [PATCH 101/162] Add a todo --- tests/test_dataset.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/test_dataset.py b/tests/test_dataset.py index 3cb7064..1487523 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -173,3 +173,6 @@ def test_apply_vocab_value_not_in_vocab(): vocab = {'w': {'a': 0}, 'i': {10: 1}} dat = dat.apply_vocab(vocab) assert list(dat) == [{'w': 0, 'i': 1}, {'w': 'b', 'i': 11}] + + +# TODO add tests with actual vocabulary object From bf5c96dbb2ad100160bf1ba9c161890029bed532 Mon Sep 17 00:00:00 2001 From: Kemal Kurniawan Date: Tue, 29 Jan 2019 21:26:14 +0700 Subject: [PATCH 102/162] Raise error when value is not in vocab --- tests/test_dataset.py | 15 +++++++++++---- text2array/datasets.py | 12 +++--------- 2 files changed, 14 insertions(+), 13 deletions(-) diff --git a/tests/test_dataset.py b/tests/test_dataset.py index 1487523..9ce0da0 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -169,10 +169,17 @@ def test_apply_vocab(): def test_apply_vocab_value_not_in_vocab(): - dat = Dataset([{'w': 'a', 'i': 10}, {'w': 'b', 'i': 11}]) - vocab = {'w': {'a': 0}, 'i': {10: 1}} - dat = dat.apply_vocab(vocab) - assert list(dat) == [{'w': 0, 
'i': 1}, {'w': 'b', 'i': 11}] + dat = Dataset([{'w': 'a'}]) + vocab = {'w': {'b': 0}} + with pytest.raises(KeyError) as exc: + dat.apply_vocab(vocab) + assert "value 'a' not found in vocab" in str(exc.value) + + dat = Dataset([{'w': 10}]) + vocab = {'w': {11: 0}} + with pytest.raises(KeyError) as exc: + dat.apply_vocab(vocab) + assert "value 10 not found in vocab" in str(exc.value) # TODO add tests with actual vocabulary object diff --git a/text2array/datasets.py b/text2array/datasets.py index ad7c4b2..e9f7585 100644 --- a/text2array/datasets.py +++ b/text2array/datasets.py @@ -150,19 +150,13 @@ def _shuffle_copy(self) -> None: @classmethod def _apply(cls, vocab: Mapping[FieldValue, FieldValue], val: FieldValue) -> FieldValue: - if isinstance(val, str): + if isinstance(val, str) or not isinstance(val, SequenceABC): try: return vocab[val] except KeyError: - return val + raise KeyError(f'value {val!r} not found in vocab') - if isinstance(val, SequenceABC): - return [cls._apply(vocab, v) for v in val] - - try: - return vocab[val] - except KeyError: - return val + return [cls._apply(vocab, v) for v in val] class StreamDataset(DatasetABC): From 315fd499286e145f05bc82d66e50da2fc0e6a3da Mon Sep 17 00:00:00 2001 From: Kemal Kurniawan Date: Tue, 29 Jan 2019 21:26:55 +0700 Subject: [PATCH 103/162] Say "key error" instead of "value not in vocab" --- tests/test_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_dataset.py b/tests/test_dataset.py index 9ce0da0..fe2037e 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -168,7 +168,7 @@ def test_apply_vocab(): }] -def test_apply_vocab_value_not_in_vocab(): +def test_apply_vocab_key_error(): dat = Dataset([{'w': 'a'}]) vocab = {'w': {'b': 0}} with pytest.raises(KeyError) as exc: From d6ae6586586529e0eecb770df6f785fa8d1e73bc Mon Sep 17 00:00:00 2001 From: Kemal Kurniawan Date: Tue, 29 Jan 2019 21:40:20 +0700 Subject: [PATCH 104/162] Add a test with actual vocab object --- tests/test_dataset.py | 20 ++++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/tests/test_dataset.py b/tests/test_dataset.py index fe2037e..beffcdc 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -2,7 +2,7 @@ import pytest -from text2array import Batch, Dataset +from text2array import Batch, Dataset, Vocab def test_init(samples): @@ -182,4 +182,20 @@ def test_apply_vocab_key_error(): assert "value 10 not found in vocab" in str(exc.value) -# TODO add tests with actual vocabulary object +def test_apply_vocab_with_vocab_object(): + dat = Dataset([{ + 'ws': ['a', 'b'], + 'cs': [['a', 'c'], ['c', 'b', 'c']] + }, { + 'ws': ['b'], + 'cs': [['b']] + }]) + v = Vocab.from_samples(dat) + dat = dat.apply_vocab(v) + assert list(dat) == [{ + 'ws': [v['ws']['a'], v['ws']['b']], + 'cs': [[v['cs']['a'], v['cs']['c']], [v['cs']['c'], v['cs']['b'], v['cs']['c']]] + }, { + 'ws': [v['ws']['b']], + 'cs': [[v['cs']['b']]] + }] From 11ff9b84a8d95191a8a393352f2144a73c555f73 Mon Sep 17 00:00:00 2001 From: Kemal Kurniawan Date: Tue, 29 Jan 2019 21:53:15 +0700 Subject: [PATCH 105/162] Organize tests for Dataset.apply_vocab() in a class --- tests/test_dataset.py | 153 +++++++++++++++++++++--------------------- 1 file changed, 76 insertions(+), 77 deletions(-) diff --git a/tests/test_dataset.py b/tests/test_dataset.py index beffcdc..d1f47cd 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -121,81 +121,80 @@ def test_batch_nonpositive_batch_size(dataset): assert 'batch size must be greater than 0' in 
str(exc.value) -def test_apply_vocab(): - dat = Dataset([{ - 'w': 'a', - 'ws': ['a', 'b'], - 'cs': [['a', 'b'], ['b', 'a']], - 'i': 10, - 'j': 20 - }, { - 'w': 'b', - 'ws': ['a', 'a'], - 'cs': [['b', 'b'], ['b', 'a']], - 'i': 10, - 'j': 20 - }]) - vocab = { - 'w': { - 'a': 0, - 'b': 1 - }, - 'ws': { - 'a': 2, - 'b': 3 - }, - 'cs': { - 'a': 4, - 'b': 5 - }, - 'j': { - 20: 2 +class TestApplyVocab: + def test_ok(self): + dat = Dataset([{ + 'w': 'a', + 'ws': ['a', 'b'], + 'cs': [['a', 'b'], ['b', 'a']], + 'i': 10, + 'j': 20 + }, { + 'w': 'b', + 'ws': ['a', 'a'], + 'cs': [['b', 'b'], ['b', 'a']], + 'i': 10, + 'j': 20 + }]) + vocab = { + 'w': { + 'a': 0, + 'b': 1 + }, + 'ws': { + 'a': 2, + 'b': 3 + }, + 'cs': { + 'a': 4, + 'b': 5 + }, + 'j': { + 20: 2 + } } - } - dat = dat.apply_vocab(vocab) - assert list(dat) == [{ - 'w': 0, - 'ws': [2, 3], - 'cs': [[4, 5], [5, 4]], - 'i': 10, - 'j': 2 - }, { - 'w': 1, - 'ws': [2, 2], - 'cs': [[5, 5], [5, 4]], - 'i': 10, - 'j': 2 - }] - - -def test_apply_vocab_key_error(): - dat = Dataset([{'w': 'a'}]) - vocab = {'w': {'b': 0}} - with pytest.raises(KeyError) as exc: - dat.apply_vocab(vocab) - assert "value 'a' not found in vocab" in str(exc.value) - - dat = Dataset([{'w': 10}]) - vocab = {'w': {11: 0}} - with pytest.raises(KeyError) as exc: - dat.apply_vocab(vocab) - assert "value 10 not found in vocab" in str(exc.value) - - -def test_apply_vocab_with_vocab_object(): - dat = Dataset([{ - 'ws': ['a', 'b'], - 'cs': [['a', 'c'], ['c', 'b', 'c']] - }, { - 'ws': ['b'], - 'cs': [['b']] - }]) - v = Vocab.from_samples(dat) - dat = dat.apply_vocab(v) - assert list(dat) == [{ - 'ws': [v['ws']['a'], v['ws']['b']], - 'cs': [[v['cs']['a'], v['cs']['c']], [v['cs']['c'], v['cs']['b'], v['cs']['c']]] - }, { - 'ws': [v['ws']['b']], - 'cs': [[v['cs']['b']]] - }] + dat = dat.apply_vocab(vocab) + assert list(dat) == [{ + 'w': 0, + 'ws': [2, 3], + 'cs': [[4, 5], [5, 4]], + 'i': 10, + 'j': 2 + }, { + 'w': 1, + 'ws': [2, 2], + 'cs': [[5, 5], [5, 4]], + 'i': 10, + 'j': 2 + }] + + def test_key_error(self): + dat = Dataset([{'w': 'a'}]) + vocab = {'w': {'b': 0}} + with pytest.raises(KeyError) as exc: + dat.apply_vocab(vocab) + assert "value 'a' not found in vocab" in str(exc.value) + + dat = Dataset([{'w': 10}]) + vocab = {'w': {11: 0}} + with pytest.raises(KeyError) as exc: + dat.apply_vocab(vocab) + assert "value 10 not found in vocab" in str(exc.value) + + def test_with_vocab_object(self): + dat = Dataset([{ + 'ws': ['a', 'b'], + 'cs': [['a', 'c'], ['c', 'b', 'c']] + }, { + 'ws': ['b'], + 'cs': [['b']] + }]) + v = Vocab.from_samples(dat) + dat = dat.apply_vocab(v) + assert list(dat) == [{ + 'ws': [v['ws']['a'], v['ws']['b']], + 'cs': [[v['cs']['a'], v['cs']['c']], [v['cs']['c'], v['cs']['b'], v['cs']['c']]] + }, { + 'ws': [v['ws']['b']], + 'cs': [[v['cs']['b']]] + }] From 58ed68eb65db0dbe9ea942097a15e94fcd04ea40 Mon Sep 17 00:00:00 2001 From: Kemal Kurniawan Date: Wed, 30 Jan 2019 05:21:49 +0700 Subject: [PATCH 106/162] Refactor and make Dataset.apply_vocab() modify the dataset --- tests/test_dataset.py | 4 ++-- text2array/datasets.py | 24 +++++++++++++----------- 2 files changed, 15 insertions(+), 13 deletions(-) diff --git a/tests/test_dataset.py b/tests/test_dataset.py index d1f47cd..08116c7 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -153,7 +153,7 @@ def test_ok(self): 20: 2 } } - dat = dat.apply_vocab(vocab) + dat.apply_vocab(vocab) assert list(dat) == [{ 'w': 0, 'ws': [2, 3], @@ -190,7 +190,7 @@ def test_with_vocab_object(self): 'cs': [['b']] }]) v = 
Vocab.from_samples(dat) - dat = dat.apply_vocab(v) + dat.apply_vocab(v) assert list(dat) == [{ 'ws': [v['ws']['a'], v['ws']['b']], 'cs': [[v['cs']['a'], v['cs']['c']], [v['cs']['c'], v['cs']['b'], v['cs']['c']]] diff --git a/text2array/datasets.py b/text2array/datasets.py index e9f7585..22ef765 100644 --- a/text2array/datasets.py +++ b/text2array/datasets.py @@ -114,7 +114,7 @@ def batch(self, batch_size: int) -> Iterator[Batch]: def apply_vocab( self, vocab: Mapping[FieldName, Mapping[FieldValue, FieldValue]], - ) -> 'Dataset': + ) -> None: """Apply a vocabulary to this dataset. Applying a vocabulary means mapping all the (nested) field values to the corresponding @@ -123,17 +123,19 @@ def apply_vocab( Args: vocab: Vocabulary to apply. - - Returns: - Dataset after application. """ - samples = [] - for s in self._samples: - s_ = {} - for name, val in s.items(): - s_[name] = self._apply(vocab[name], val) if name in vocab else val - samples.append(s_) - return Dataset(samples) + ss = [] + for s_ in self._samples: + s = {} + for name, val in s_.items(): + try: + vb = vocab[name] + except KeyError: + s[name] = val + else: + s[name] = self._apply(vb, val) + ss.append(s) + self._samples = ss def _shuffle_inplace(self) -> None: assert isinstance(self._samples, MutableSequenceABC) From 083ba6a4089926704bb4baaae0e0dec71a1127db Mon Sep 17 00:00:00 2001 From: Kemal Kurniawan Date: Wed, 30 Jan 2019 06:03:15 +0700 Subject: [PATCH 107/162] Do apply vocab inplace if possible --- tests/test_dataset.py | 15 +++++++++++++++ text2array/datasets.py | 33 +++++++++++++++++++++------------ 2 files changed, 36 insertions(+), 12 deletions(-) diff --git a/tests/test_dataset.py b/tests/test_dataset.py index 08116c7..d42226b 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -198,3 +198,18 @@ def test_with_vocab_object(self): 'ws': [v['ws']['b']], 'cs': [[v['cs']['b']]] }] + + def test_immutable_seq(self): + ss = [{ + 'ws': ['a', 'b'], + 'cs': [['a', 'c'], ['c', 'b', 'c']] + }, { + 'ws': ['b'], + 'cs': [['b']] + }] + lstdat = Dataset(ss) + tpldat = Dataset(tuple(ss)) + v = Vocab.from_samples(ss) + lstdat.apply_vocab(v) + tpldat.apply_vocab(v) + assert list(lstdat) == list(tpldat) diff --git a/text2array/datasets.py b/text2array/datasets.py index 22ef765..c1d495f 100644 --- a/text2array/datasets.py +++ b/text2array/datasets.py @@ -124,18 +124,9 @@ def apply_vocab( Args: vocab: Vocabulary to apply. 
""" - ss = [] - for s_ in self._samples: - s = {} - for name, val in s_.items(): - try: - vb = vocab[name] - except KeyError: - s[name] = val - else: - s[name] = self._apply(vb, val) - ss.append(s) - self._samples = ss + if not isinstance(self._samples, MutableSequenceABC): + self._samples = list(self._samples) + self._apply_vocab_inplace(vocab) def _shuffle_inplace(self) -> None: assert isinstance(self._samples, MutableSequenceABC) @@ -150,6 +141,24 @@ def _shuffle_copy(self) -> None: self._samples = list(self._samples) self._shuffle_inplace() + def _apply_vocab_inplace( + self, + vocab: Mapping[FieldName, Mapping[FieldValue, FieldValue]], + ) -> None: + assert isinstance(self._samples, MutableSequenceABC) + + for i in range(len(self._samples)): + s = {} + for name, val in self._samples[i].items(): + try: + vb = vocab[name] + except KeyError: + s[name] = val + else: + s[name] = self._apply(vb, self._samples[i][name]) + + self._samples[i] = s + @classmethod def _apply(cls, vocab: Mapping[FieldValue, FieldValue], val: FieldValue) -> FieldValue: if isinstance(val, str) or not isinstance(val, SequenceABC): From 085c884e1feb1e524c08a061f8a0d3c97e83ac29 Mon Sep 17 00:00:00 2001 From: Kemal Kurniawan Date: Wed, 30 Jan 2019 06:04:30 +0700 Subject: [PATCH 108/162] Remove unnecessary Dataset._shuffle_copy() method --- text2array/datasets.py | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/text2array/datasets.py b/text2array/datasets.py index c1d495f..6cda107 100644 --- a/text2array/datasets.py +++ b/text2array/datasets.py @@ -61,10 +61,9 @@ def shuffle(self) -> 'Dataset': Returns: The dataset object itself (useful for chaining). """ - if isinstance(self._samples, MutableSequenceABC): - self._shuffle_inplace() - else: - self._shuffle_copy() + if not isinstance(self._samples, MutableSequenceABC): + self._samples = list(self._samples) + self._shuffle_inplace() return self def shuffle_by(self, key: Callable[[Sample], int], scale: float = 1.) 
-> 'Dataset': @@ -137,10 +136,6 @@ def _shuffle_inplace(self) -> None: self._samples[i] = self._samples[j] self._samples[j] = temp - def _shuffle_copy(self) -> None: - self._samples = list(self._samples) - self._shuffle_inplace() - def _apply_vocab_inplace( self, vocab: Mapping[FieldName, Mapping[FieldValue, FieldValue]], From aa667c2e0e65384b8ba0f354d466af68f98ff422 Mon Sep 17 00:00:00 2001 From: Kemal Kurniawan Date: Wed, 30 Jan 2019 06:05:25 +0700 Subject: [PATCH 109/162] Refactor a bit --- text2array/datasets.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/text2array/datasets.py b/text2array/datasets.py index 6cda107..4bb57e1 100644 --- a/text2array/datasets.py +++ b/text2array/datasets.py @@ -155,14 +155,14 @@ def _apply_vocab_inplace( self._samples[i] = s @classmethod - def _apply(cls, vocab: Mapping[FieldValue, FieldValue], val: FieldValue) -> FieldValue: + def _apply(cls, vb: Mapping[FieldValue, FieldValue], val: FieldValue) -> FieldValue: if isinstance(val, str) or not isinstance(val, SequenceABC): try: - return vocab[val] + return vb[val] except KeyError: raise KeyError(f'value {val!r} not found in vocab') - return [cls._apply(vocab, v) for v in val] + return [cls._apply(vb, v) for v in val] class StreamDataset(DatasetABC): From 63c910e01de641da3a09419c36478d40575ca8ba Mon Sep 17 00:00:00 2001 From: Kemal Kurniawan Date: Wed, 30 Jan 2019 06:13:39 +0700 Subject: [PATCH 110/162] Update docstring --- text2array/datasets.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/text2array/datasets.py b/text2array/datasets.py index 4bb57e1..3262a8f 100644 --- a/text2array/datasets.py +++ b/text2array/datasets.py @@ -118,7 +118,9 @@ def apply_vocab( Applying a vocabulary means mapping all the (nested) field values to the corresponding values according to the mapping specified by the vocabulary. Field names that have - no entry in the vocabulary are ignored. + no entry in the vocabulary are ignored. This method applies the vocabulary in-place + when the dataset holds a mutable sequence of samples. Otherwise, a mutable copy of + samples is made. Args: vocab: Vocabulary to apply. 
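At this point ``Dataset.apply_vocab`` mutates the stored samples, copying the sample sequence first only when it is immutable. A short sketch with hypothetical samples:

    from text2array import Dataset, Vocab

    dat = Dataset([{'ws': ['b', 'c']}, {'ws': ['c', 'b', 'c']}])
    vocab = Vocab.from_samples(dat)  # a Dataset can be iterated more than once
    dat.apply_vocab(vocab)           # in place, returns None
    # Each string is now an int, e.g. [{'ws': [3, 2]}, {'ws': [2, 3, 2]}],
    # assuming the default padding and unknown tokens take indices 0 and 1.

Fields that have no entry in the vocabulary pass through unchanged, and a value missing from its field's mapping raises a ``KeyError``.
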
From 22493d87070b35216c5f5fbcda669ecb1c01419d Mon Sep 17 00:00:00 2001 From: Kemal Kurniawan Date: Wed, 30 Jan 2019 06:37:01 +0700 Subject: [PATCH 111/162] Implement StreamDataset.apply_vocab() --- tests/conftest.py | 5 +++ tests/test_stream_dataset.py | 85 +++++++++++++++++++++++++++++++++++- text2array/datasets.py | 43 +++++++++++++++++- 3 files changed, 131 insertions(+), 2 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 150dcad..9465082 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -29,6 +29,11 @@ def stream_dataset(stream): return StreamDataset(stream) +@pytest.fixture +def stream_cls(): + return Stream + + class Stream: def __init__(self, samples): self.samples = samples diff --git a/tests/test_stream_dataset.py b/tests/test_stream_dataset.py index 5dc8ee2..e12f745 100644 --- a/tests/test_stream_dataset.py +++ b/tests/test_stream_dataset.py @@ -2,7 +2,7 @@ import pytest -from text2array import Batch, StreamDataset +from text2array import Batch, StreamDataset, Vocab def test_init(stream): @@ -64,3 +64,86 @@ def test_batch_nonpositive_batch_size(stream_dataset): with pytest.raises(ValueError) as exc: next(stream_dataset.batch_exactly(0)) assert 'batch size must be greater than 0' in str(exc.value) + + +class TestApplyVocab: + def test_ok(self, stream_cls): + dat = StreamDataset( + stream_cls([{ + 'w': 'a', + 'ws': ['a', 'b'], + 'cs': [['a', 'b'], ['b', 'a']], + 'i': 10, + 'j': 20 + }, { + 'w': 'b', + 'ws': ['a', 'a'], + 'cs': [['b', 'b'], ['b', 'a']], + 'i': 10, + 'j': 20 + }])) + vocab = { + 'w': { + 'a': 0, + 'b': 1 + }, + 'ws': { + 'a': 2, + 'b': 3 + }, + 'cs': { + 'a': 4, + 'b': 5 + }, + 'j': { + 20: 2 + } + } + dat.apply_vocab(vocab) + assert list(dat) == [{ + 'w': 0, + 'ws': [2, 3], + 'cs': [[4, 5], [5, 4]], + 'i': 10, + 'j': 2 + }, { + 'w': 1, + 'ws': [2, 2], + 'cs': [[5, 5], [5, 4]], + 'i': 10, + 'j': 2 + }] + + def test_key_error(self, stream_cls): + dat = StreamDataset(stream_cls([{'w': 'a'}])) + vocab = {'w': {'b': 0}} + dat.apply_vocab(vocab) + with pytest.raises(KeyError) as exc: + list(dat) + assert "value 'a' not found in vocab" in str(exc.value) + + dat = StreamDataset(stream_cls([{'w': 10}])) + vocab = {'w': {11: 0}} + dat.apply_vocab(vocab) + with pytest.raises(KeyError) as exc: + list(dat) + assert "value 10 not found in vocab" in str(exc.value) + + def test_with_vocab_object(self, stream_cls): + dat = StreamDataset( + stream_cls([{ + 'ws': ['a', 'b'], + 'cs': [['a', 'c'], ['c', 'b', 'c']] + }, { + 'ws': ['b'], + 'cs': [['b']] + }])) + v = Vocab.from_samples(dat) + dat.apply_vocab(v) + assert list(dat) == [{ + 'ws': [v['ws']['a'], v['ws']['b']], + 'cs': [[v['cs']['a'], v['cs']['c']], [v['cs']['c'], v['cs']['b'], v['cs']['c']]] + }, { + 'ws': [v['ws']['b']], + 'cs': [[v['cs']['b']]] + }] diff --git a/text2array/datasets.py b/text2array/datasets.py index 3262a8f..4227b95 100644 --- a/text2array/datasets.py +++ b/text2array/datasets.py @@ -181,7 +181,23 @@ def __init__(self, stream: Iterable[Sample]) -> None: self._stream = stream def __iter__(self) -> Iterator[Sample]: - return iter(self._stream) + try: + vocab = self._vocab + except AttributeError: + yield from iter(self._stream) + return + + for s_ in self._stream: + # TODO these lines occur in Dataset as well, refactor? + s = {} + for name, val in s_.items(): + try: + vb = vocab[name] + except KeyError: + s[name] = val + else: + s[name] = self._apply(vb, val) + yield s def batch(self, batch_size: int) -> Iterator[Batch]: """Group the samples in the dataset into batches. 
@@ -205,3 +221,28 @@ def batch(self, batch_size: int) -> Iterator[Batch]: exhausted = True if batch: yield Batch(batch) + + def apply_vocab(self, vocab: Mapping[FieldName, Mapping[FieldValue, FieldValue]]) -> None: + """Apply a vocabulary to this dataset. + + Applying a vocabulary means mapping all the (nested) field values to the corresponding + values according to the mapping specified by the vocabulary. Field names that have + no entry in the vocabulary are ignored. Note that since the dataset holds a stream of + samples, the actual application is delayed until the dataset is iterated. Therefore, + ``vocab`` must still exist when that happens. + + Args: + vocab: Vocabulary to apply. + """ + self._vocab = vocab + + # TODO put this in DatasetABC maybe? + @classmethod + def _apply(cls, vb: Mapping[FieldValue, FieldValue], val: FieldValue) -> FieldValue: + if isinstance(val, str) or not isinstance(val, SequenceABC): + try: + return vb[val] + except KeyError: + raise KeyError(f'value {val!r} not found in vocab') + + return [cls._apply(vb, v) for v in val] From b171f5959a43b0e7058067317693c90afc323511 Mon Sep 17 00:00:00 2001 From: Kemal Kurniawan Date: Wed, 30 Jan 2019 06:43:30 +0700 Subject: [PATCH 112/162] Move _apply() classmethod to DatasetABC._app_vb_to_val() --- text2array/datasets.py | 45 ++++++++++++++++-------------------------- 1 file changed, 17 insertions(+), 28 deletions(-) diff --git a/text2array/datasets.py b/text2array/datasets.py index 4227b95..e13d1a8 100644 --- a/text2array/datasets.py +++ b/text2array/datasets.py @@ -28,8 +28,21 @@ def batch_exactly(self, batch_size: int) -> Iterator[Batch]: """ return (b for b in self.batch(batch_size) if len(b) == batch_size) + @abc.abstractmethod + def apply_vocab(self, vocab: Mapping[FieldName, Mapping[FieldValue, FieldValue]]) -> None: + pass + + @classmethod + def _app_vb_to_val(cls, vb: Mapping[FieldValue, FieldValue], val: FieldValue) -> FieldValue: + if isinstance(val, str) or not isinstance(val, SequenceABC): + try: + return vb[val] + except KeyError: + raise KeyError(f'value {val!r} not found in vocab') + + return [cls._app_vb_to_val(vb, v) for v in val] + -# TODO implement mapping with a vocab class Dataset(DatasetABC, Sequence[Sample]): """A dataset that fits in memory (no streaming). @@ -110,10 +123,7 @@ def batch(self, batch_size: int) -> Iterator[Batch]: end = begin + batch_size yield Batch(self._samples[begin:end]) - def apply_vocab( - self, - vocab: Mapping[FieldName, Mapping[FieldValue, FieldValue]], - ) -> None: + def apply_vocab(self, vocab: Mapping[FieldName, Mapping[FieldValue, FieldValue]]) -> None: """Apply a vocabulary to this dataset. Applying a vocabulary means mapping all the (nested) field values to the corresponding @@ -152,20 +162,10 @@ def _apply_vocab_inplace( except KeyError: s[name] = val else: - s[name] = self._apply(vb, self._samples[i][name]) + s[name] = self._app_vb_to_val(vb, val) self._samples[i] = s - @classmethod - def _apply(cls, vb: Mapping[FieldValue, FieldValue], val: FieldValue) -> FieldValue: - if isinstance(val, str) or not isinstance(val, SequenceABC): - try: - return vb[val] - except KeyError: - raise KeyError(f'value {val!r} not found in vocab') - - return [cls._apply(vb, v) for v in val] - class StreamDataset(DatasetABC): """A dataset that streams its samples. 
@@ -196,7 +196,7 @@ def __iter__(self) -> Iterator[Sample]: except KeyError: s[name] = val else: - s[name] = self._apply(vb, val) + s[name] = self._app_vb_to_val(vb, val) yield s def batch(self, batch_size: int) -> Iterator[Batch]: @@ -235,14 +235,3 @@ def apply_vocab(self, vocab: Mapping[FieldName, Mapping[FieldValue, FieldValue]] vocab: Vocabulary to apply. """ self._vocab = vocab - - # TODO put this in DatasetABC maybe? - @classmethod - def _apply(cls, vb: Mapping[FieldValue, FieldValue], val: FieldValue) -> FieldValue: - if isinstance(val, str) or not isinstance(val, SequenceABC): - try: - return vb[val] - except KeyError: - raise KeyError(f'value {val!r} not found in vocab') - - return [cls._apply(vb, v) for v in val] From 5d5454bdd7be8535f371460aa82b82ebdbafd50c Mon Sep 17 00:00:00 2001 From: Kemal Kurniawan Date: Wed, 30 Jan 2019 06:53:27 +0700 Subject: [PATCH 113/162] Refactor applying vocab to a sample --- text2array/datasets.py | 41 +++++++++++++++++++---------------------- 1 file changed, 19 insertions(+), 22 deletions(-) diff --git a/text2array/datasets.py b/text2array/datasets.py index e13d1a8..d87548d 100644 --- a/text2array/datasets.py +++ b/text2array/datasets.py @@ -32,6 +32,22 @@ def batch_exactly(self, batch_size: int) -> Iterator[Batch]: def apply_vocab(self, vocab: Mapping[FieldName, Mapping[FieldValue, FieldValue]]) -> None: pass + @classmethod + def _app_vocab_to_sample( + cls, + vocab: Mapping[FieldName, Mapping[FieldValue, FieldValue]], + sample: Sample, + ) -> Sample: + s = {} + for name, val in sample.items(): + try: + vb = vocab[name] + except KeyError: + s[name] = val + else: + s[name] = cls._app_vb_to_val(vb, val) + return s + @classmethod def _app_vb_to_val(cls, vb: Mapping[FieldValue, FieldValue], val: FieldValue) -> FieldValue: if isinstance(val, str) or not isinstance(val, SequenceABC): @@ -153,18 +169,8 @@ def _apply_vocab_inplace( vocab: Mapping[FieldName, Mapping[FieldValue, FieldValue]], ) -> None: assert isinstance(self._samples, MutableSequenceABC) - for i in range(len(self._samples)): - s = {} - for name, val in self._samples[i].items(): - try: - vb = vocab[name] - except KeyError: - s[name] = val - else: - s[name] = self._app_vb_to_val(vb, val) - - self._samples[i] = s + self._samples[i] = self._app_vocab_to_sample(vocab, self._samples[i]) class StreamDataset(DatasetABC): @@ -187,17 +193,8 @@ def __iter__(self) -> Iterator[Sample]: yield from iter(self._stream) return - for s_ in self._stream: - # TODO these lines occur in Dataset as well, refactor? - s = {} - for name, val in s_.items(): - try: - vb = vocab[name] - except KeyError: - s[name] = val - else: - s[name] = self._app_vb_to_val(vb, val) - yield s + for s in self._stream: + yield self._app_vocab_to_sample(vocab, s) def batch(self, batch_size: int) -> Iterator[Batch]: """Group the samples in the dataset into batches. 
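In contrast to ``Dataset``, ``StreamDataset.apply_vocab`` is lazy: the vocabulary is applied sample by sample while iterating, so the ``vocab`` object must still exist at that time. A sketch with a hypothetical re-iterable stream, mirroring the ``Stream`` helper in ``tests/conftest.py``:

    from text2array import StreamDataset, Vocab

    class Stream:
        # Minimal re-iterable stream; a bare generator would already be
        # exhausted after Vocab.from_samples iterates it once.
        def __init__(self, samples):
            self.samples = samples

        def __iter__(self):
            return iter(self.samples)

    dat = StreamDataset(Stream([{'ws': ['b', 'c']}, {'ws': ['c', 'b', 'c']}]))
    vocab = Vocab.from_samples(dat)
    dat.apply_vocab(vocab)        # no work yet; mapping happens during iteration
    batches = list(dat.batch(2))  # samples are int-valued as they stream by
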
From ac7ad66d6b5f45c8a71e511b77993d66f733e344 Mon Sep 17 00:00:00 2001 From: Kemal Kurniawan Date: Wed, 30 Jan 2019 06:58:53 +0700 Subject: [PATCH 114/162] Refactor name a bit --- text2array/datasets.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/text2array/datasets.py b/text2array/datasets.py index d87548d..4480632 100644 --- a/text2array/datasets.py +++ b/text2array/datasets.py @@ -33,7 +33,7 @@ def apply_vocab(self, vocab: Mapping[FieldName, Mapping[FieldValue, FieldValue]] pass @classmethod - def _app_vocab_to_sample( + def _apply_vocab_to_sample( cls, vocab: Mapping[FieldName, Mapping[FieldValue, FieldValue]], sample: Sample, @@ -45,18 +45,22 @@ def _app_vocab_to_sample( except KeyError: s[name] = val else: - s[name] = cls._app_vb_to_val(vb, val) + s[name] = cls._apply_vb_to_val(vb, val) return s @classmethod - def _app_vb_to_val(cls, vb: Mapping[FieldValue, FieldValue], val: FieldValue) -> FieldValue: + def _apply_vb_to_val( + cls, + vb: Mapping[FieldValue, FieldValue], + val: FieldValue, + ) -> FieldValue: if isinstance(val, str) or not isinstance(val, SequenceABC): try: return vb[val] except KeyError: raise KeyError(f'value {val!r} not found in vocab') - return [cls._app_vb_to_val(vb, v) for v in val] + return [cls._apply_vb_to_val(vb, v) for v in val] class Dataset(DatasetABC, Sequence[Sample]): @@ -170,7 +174,7 @@ def _apply_vocab_inplace( ) -> None: assert isinstance(self._samples, MutableSequenceABC) for i in range(len(self._samples)): - self._samples[i] = self._app_vocab_to_sample(vocab, self._samples[i]) + self._samples[i] = self._apply_vocab_to_sample(vocab, self._samples[i]) class StreamDataset(DatasetABC): @@ -194,7 +198,7 @@ def __iter__(self) -> Iterator[Sample]: return for s in self._stream: - yield self._app_vocab_to_sample(vocab, s) + yield self._apply_vocab_to_sample(vocab, s) def batch(self, batch_size: int) -> Iterator[Batch]: """Group the samples in the dataset into batches. From faddc9826e2c0f993bccd90f400bbc21782e3281 Mon Sep 17 00:00:00 2001 From: Kemal Kurniawan Date: Wed, 30 Jan 2019 07:02:30 +0700 Subject: [PATCH 115/162] Add a todo --- text2array/vocab.py | 1 + 1 file changed, 1 insertion(+) diff --git a/text2array/vocab.py b/text2array/vocab.py index 884a9a5..02ca44d 100644 --- a/text2array/vocab.py +++ b/text2array/vocab.py @@ -69,6 +69,7 @@ def from_samples( Returns: Vocabulary instance. 
""" + # TODO don't waste the first try: first = cls._head(samples) except StopIteration: From 67e4e0fe00c157a0ad9c77aa3aad11aab064d8a9 Mon Sep 17 00:00:00 2001 From: Kemal Kurniawan Date: Wed, 30 Jan 2019 07:05:24 +0700 Subject: [PATCH 116/162] Update a todo --- text2array/batches.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/text2array/batches.py b/text2array/batches.py index 9c75c62..5592482 100644 --- a/text2array/batches.py +++ b/text2array/batches.py @@ -23,7 +23,7 @@ def __getitem__(self, index) -> Sample: def __len__(self) -> int: return len(self._samples) - # TODO rethink if this needs to be public + # TODO make this private def get(self, name: str) -> Sequence[FieldValue]: try: return [s[name] for s in self._samples] From bb3dd8889db7febeb4b8c0a1f0fc97ed41a89dab Mon Sep 17 00:00:00 2001 From: Kemal Kurniawan Date: Wed, 30 Jan 2019 20:50:12 +0700 Subject: [PATCH 117/162] Make Batch.get() private --- tests/test_batch.py | 20 +------------------- text2array/batches.py | 5 ++--- 2 files changed, 3 insertions(+), 22 deletions(-) diff --git a/tests/test_batch.py b/tests/test_batch.py index 1769f43..9c6b360 100644 --- a/tests/test_batch.py +++ b/tests/test_batch.py @@ -19,24 +19,6 @@ def batch(samples): return Batch(samples) -def test_get(batch): - assert isinstance(batch.get('i'), Sequence) - assert len(batch.get('i')) == len(batch) - for i in range(len(batch)): - assert batch.get('i')[i] == batch[i]['i'] - - assert isinstance(batch.get('f'), Sequence) - assert len(batch.get('f')) == len(batch) - for i in range(len(batch)): - assert batch.get('f')[i] == pytest.approx(batch[i]['f']) - - -def test_get_invalid_name(batch): - with pytest.raises(AttributeError) as exc: - batch.get('foo') - assert "some samples have no field 'foo'" in str(exc.value) - - class TestToArray: def test_ok(self, batch): arr = batch.to_array() @@ -44,7 +26,7 @@ def test_ok(self, batch): assert isinstance(arr['i'], np.ndarray) assert arr['i'].shape == (len(batch), ) - assert arr['i'].tolist() == list(batch.get('i')) + assert arr['i'].tolist() == [s['i'] for s in batch] assert isinstance(arr['f'], np.ndarray) assert arr['f'].shape == (len(batch), ) diff --git a/text2array/batches.py b/text2array/batches.py index 5592482..391b396 100644 --- a/text2array/batches.py +++ b/text2array/batches.py @@ -23,8 +23,7 @@ def __getitem__(self, index) -> Sample: def __len__(self) -> int: return len(self._samples) - # TODO make this private - def get(self, name: str) -> Sequence[FieldValue]: + def _get(self, name: str) -> Sequence[FieldValue]: try: return [s[name] for s in self._samples] except KeyError: @@ -54,7 +53,7 @@ def to_array(self, pad_with: int = 0) -> Mapping[FieldName, np.ndarray]: arr = {} for name in common: - data = self.get(name) + data = self._get(name) # Get max length for all depths, 1st elem is batch size maxlens = self._get_maxlens(data) # Get padding for all depths From 8557fe39a8ceb55c61a3113e75a05b8ce3fb534f Mon Sep 17 00:00:00 2001 From: Kemal Kurniawan Date: Wed, 30 Jan 2019 20:54:55 +0700 Subject: [PATCH 118/162] Assume all samples have the same fields --- tests/test_batch.py | 13 ++++--------- text2array/batches.py | 15 +++------------ 2 files changed, 7 insertions(+), 21 deletions(-) diff --git a/tests/test_batch.py b/tests/test_batch.py index 9c6b360..8f1ab36 100644 --- a/tests/test_batch.py +++ b/tests/test_batch.py @@ -33,6 +33,10 @@ def test_ok(self, batch): for i in range(len(batch)): assert arr['f'][i] == pytest.approx(batch[i]['f']) + def test_empty(self): + b = Batch([]) + 
assert not b.to_array() + def test_seq(self): ss = [{'is': [1, 2]}, {'is': [1]}, {'is': [1, 2, 3]}, {'is': [1, 2]}] b = Batch(ss) @@ -195,12 +199,3 @@ def test_custom_padding(self): b = Batch(ss) arr = b.to_array(pad_with=9) assert arr['iss'].tolist() == [[[1, 2], [1, 9]], [[1, 9], [9, 9]]] - - def test_no_common_field_names(self, samples): - samples_ = list(samples) - samples_.append({'foo': 10}) - batch = Batch(samples_) - - with pytest.raises(RuntimeError) as exc: - batch.to_array() - assert 'some samples have no common field names with the others' in str(exc.value) diff --git a/text2array/batches.py b/text2array/batches.py index 391b396..8676f32 100644 --- a/text2array/batches.py +++ b/text2array/batches.py @@ -39,20 +39,11 @@ def to_array(self, pad_with: int = 0) -> Mapping[FieldName, np.ndarray]: A mapping from field names to :class:`np.ndarray`s whose first dimension corresponds to the batch size as returned by ``__len__``. """ - # TODO just assume all samples have the same keys - common: Set[FieldName] = set() - for s in self._samples: - if common: - common.intersection_update(s) - else: - common = set(s) - - if not common: - raise RuntimeError('some samples have no common field names with the others') - assert self._samples # if `common` isn't empty, neither is `_samples` + if not self._samples: + return {} arr = {} - for name in common: + for name in self._samples[0].keys(): data = self._get(name) # Get max length for all depths, 1st elem is batch size maxlens = self._get_maxlens(data) From e668913091a4693259d050314485a08cca3af87a Mon Sep 17 00:00:00 2001 From: Kemal Kurniawan Date: Wed, 30 Jan 2019 20:58:30 +0700 Subject: [PATCH 119/162] Refactor a bit and add more test --- tests/test_batch.py | 6 ++++++ text2array/batches.py | 14 +++++++------- 2 files changed, 13 insertions(+), 7 deletions(-) diff --git a/tests/test_batch.py b/tests/test_batch.py index 8f1ab36..7c5bc6b 100644 --- a/tests/test_batch.py +++ b/tests/test_batch.py @@ -199,3 +199,9 @@ def test_custom_padding(self): b = Batch(ss) arr = b.to_array(pad_with=9) assert arr['iss'].tolist() == [[[1, 2], [1, 9]], [[1, 9], [9, 9]]] + + def test_missing_field(self): + b = Batch([{'a': 10}, {'b': 20}]) + with pytest.raises(KeyError) as exc: + b.to_array() + assert "some samples have no field 'a'" in str(exc.value) diff --git a/text2array/batches.py b/text2array/batches.py index 8676f32..375753a 100644 --- a/text2array/batches.py +++ b/text2array/batches.py @@ -23,12 +23,6 @@ def __getitem__(self, index) -> Sample: def __len__(self) -> int: return len(self._samples) - def _get(self, name: str) -> Sequence[FieldValue]: - try: - return [s[name] for s in self._samples] - except KeyError: - raise AttributeError(f"some samples have no field '{name}'") - def to_array(self, pad_with: int = 0) -> Mapping[FieldName, np.ndarray]: """Convert the batch into :class:`np.ndarray`. 
@@ -44,7 +38,7 @@ def to_array(self, pad_with: int = 0) -> Mapping[FieldName, np.ndarray]:
 
         arr = {}
         for name in self._samples[0].keys():
-            data = self._get(name)
+            data = self._get_values(name)
             # Get max length for all depths, 1st elem is batch size
             maxlens = self._get_maxlens(data)
             # Get padding for all depths
@@ -56,6 +50,12 @@ def to_array(self, pad_with: int = 0) -> Mapping[FieldName, np.ndarray]:
 
         return arr
 
+    def _get_values(self, name: str) -> Sequence[FieldValue]:
+        try:
+            return [s[name] for s in self._samples]
+        except KeyError:
+            raise KeyError(f"some samples have no field '{name}'")
+
     @classmethod
     def _get_maxlens(cls, data: Sequence[Any]) -> List[int]:
         assert data

From 564721c7f04a1899b6c8a640bb18f6fc0d81a27e Mon Sep 17 00:00:00 2001
From: Kemal Kurniawan
Date: Wed, 30 Jan 2019 21:00:40 +0700
Subject: [PATCH 120/162] Delete unused import

---
 text2array/batches.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/text2array/batches.py b/text2array/batches.py
index 375753a..72c8392 100644
--- a/text2array/batches.py
+++ b/text2array/batches.py
@@ -1,6 +1,6 @@
 from collections.abc import Sequence as SequenceABC
 from functools import reduce
-from typing import Any, List, Mapping, Sequence, Set, Union
+from typing import Any, List, Mapping, Sequence, Union
 
 import numpy as np
 
From 8500181831f98c173a33bb8ebcda1ed54bdca27c Mon Sep 17 00:00:00 2001
From: Kemal Kurniawan
Date: Wed, 30 Jan 2019 21:34:20 +0700
Subject: [PATCH 121/162] Add contains test for vocab

---
 tests/test_vocab.py | 2 ++
 text2array/vocab.py | 3 +++
 2 files changed, 5 insertions(+)

diff --git a/tests/test_vocab.py b/tests/test_vocab.py
index d871131..2fff116 100644
--- a/tests/test_vocab.py
+++ b/tests/test_vocab.py
@@ -21,7 +21,9 @@ def test_ok(self):
         for i, w in enumerate(itos):
             assert vocab['w'][w] == i
 
+        assert 'foo' not in vocab['w']
         assert vocab['w']['foo'] == vocab['w']['<unk>']
+        assert 'bar' not in vocab['w']
         assert vocab['w']['bar'] == vocab['w']['<unk>']
 
     def test_has_vocab_for_all_str_fields(self):
diff --git a/text2array/vocab.py b/text2array/vocab.py
index 02ca44d..d3c4174 100644
--- a/text2array/vocab.py
+++ b/text2array/vocab.py
@@ -135,6 +135,9 @@ def __getitem__(self, s: str) -> int:
             return self._unk_id
         raise KeyError(f"'{s}' not found in vocabulary")
 
+    def __contains__(self, s) -> bool:
+        return s in self._m
+
     @classmethod
     def _from_iterable(
         cls,

From 74807b441ce0f1a9f4fd432bc8523432a4e61e8d Mon Sep 17 00:00:00 2001
From: Kemal Kurniawan
Date: Wed, 30 Jan 2019 21:35:04 +0700
Subject: [PATCH 122/162] Add a todo

---
 tests/test_vocab.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/tests/test_vocab.py b/tests/test_vocab.py
index 2fff116..2a57f2f 100644
--- a/tests/test_vocab.py
+++ b/tests/test_vocab.py
@@ -6,6 +6,7 @@
 
 
 class TestFromSamples():
+    # TODO check all possible methods of a mapping?
     def test_ok(self):
         ss = [{'w': 'c'}, {'w': 'b'}, {'w': 'a'}, {'w': 'b'}, {'w': 'c'}, {'w': 'c'}]
         vocab = Vocab.from_samples(ss)

From 4e6f0f4531a6ace93fff38b8f988747be1879a5e Mon Sep 17 00:00:00 2001
From: Kemal Kurniawan
Date: Wed, 30 Jan 2019 22:26:20 +0700
Subject: [PATCH 123/162] Make Vocab.from_samples() accept iterator

---
 tests/test_vocab.py |  5 +++
 text2array/vocab.py | 81 ++++++++++++++++++---------------------------
 2 files changed, 38 insertions(+), 48 deletions(-)

diff --git a/tests/test_vocab.py b/tests/test_vocab.py
index 2a57f2f..8361abe 100644
--- a/tests/test_vocab.py
+++ b/tests/test_vocab.py
@@ -174,3 +174,8 @@ def test_max_size(self):
         vocab = Vocab.from_samples(ss, options={'w': dict(max_size=1)})
         assert list(vocab['w']) == ['<pad>', '<unk>', 'c']
         assert list(vocab['t']) == ['<pad>', '<unk>', 'c', 'b']
+
+    def test_iterator_is_passed(self):
+        ss = [{'ws': ['a', 'a']}, {'ws': ['b', 'b']}]
+        vocab = Vocab.from_samples(iter(ss))
+        assert 'a' in vocab['ws']
diff --git a/text2array/vocab.py b/text2array/vocab.py
index d3c4174..a2bfd7d 100644
--- a/text2array/vocab.py
+++ b/text2array/vocab.py
@@ -1,6 +1,6 @@
-from collections import Counter, OrderedDict
+from collections import Counter, OrderedDict, defaultdict
 from collections.abc import Sequence as SequenceABC
-from typing import Iterable, Iterator, Mapping, MutableMapping, Optional
+from typing import Counter as CounterT, Dict, Iterable, Iterator, Mapping, Optional
 
 from .samples import FieldName, FieldValue, Sample
 
@@ -48,7 +48,7 @@ def from_samples(
         Args:
             samples: Iterable of samples.
             options: Mapping from field names to dictionaries to control the creation of
-                the str-to-int mapping. Allowed dictionary keys are:
+                the str-to-int mapping. Recognized dictionary keys are:
 
                 * ``min_count``(:obj:`int`): Exclude tokens occurring fewer than this
                   number of times from the vocabulary (default: 2).
@@ -69,30 +69,40 @@ def from_samples(
         Returns:
             Vocabulary instance.
         """
-        # TODO don't waste the first
-        try:
-            first = cls._head(samples)
-        except StopIteration:
-            return cls({})
-
         if options is None:
             options = {}
 
-        m = {
-            name: _StringStore._from_iterable(
-                cls._flatten(cls._get_values(samples, name)), **options.get(name, {}))
-            for name, value in first.items()
-            if cls._needs_vocab(value)
-        }
-        return cls(m)
-
-    @staticmethod
-    def _head(x: Iterable[Sample]) -> Sample:
-        return next(iter(x))
+        counter: Dict[FieldName, CounterT[str]] = defaultdict(Counter)
+        for s in samples:
+            for name, value in s.items():
+                if cls._needs_vocab(value):
+                    counter[name].update(cls._flatten(value))
+
+        m = {}
+        for name, c in counter.items():
+            store: dict = OrderedDict()
+            opts = options.get(name, {})
+
+            # Padding and unknown tokens
+            pad = opts.get('pad', '<pad>')
+            unk = opts.get('unk', '<unk>')
+            if pad is not None:
+                store[pad] = len(store)
+            if unk is not None:
+                store[unk] = len(store)
+
+            min_count = opts.get('min_count', 2)
+            max_size = opts.get('max_size')
+            n = len(store)
+            for tok, freq in c.most_common():
+                if freq < min_count or (max_size is not None and len(store) - n >= max_size):
+                    break
+                store[tok] = len(store)
+
+            unk_id = None if unk is None else store[unk]
+            m[name] = _StringStore(store, unk_id=unk_id)
 
-    @staticmethod
-    def _get_values(ss: Iterable[Sample], name: FieldName) -> Iterator[FieldValue]:
-        return (s[name] for s in ss)
+        return cls(m)
 
     @classmethod
     def _needs_vocab(cls, val: FieldValue) -> bool:
@@ -137,28 +147,3 @@ def __getitem__(self, s: str) -> int:
 
     def __contains__(self, s) -> bool:
         return s in self._m
-
-    @classmethod
-    def _from_iterable(
-            cls,
-            iterable: Iterable[str],
-            min_count: int = 2,
-            unk: Optional[str] = '<unk>',
-            pad: Optional[str] = '<pad>',
-            max_size: Optional[int] = None,
-    ) -> '_StringStore':
-        stoi: MutableMapping[str, int] = OrderedDict()
-        if pad is not None:
-            stoi[pad] = len(stoi)
-        if unk is not None:
-            stoi[unk] = len(stoi)
-
-        n = len(stoi)
-        c = Counter(iterable)
-        for s, f in c.most_common():
-            if f < min_count or (max_size is not None and len(stoi) - n >= max_size):
-                break
-            stoi[s] = len(stoi)
-
-        unk_id = None if unk is None else stoi[unk]
-        return cls(stoi, unk_id=unk_id)

From 33fbe3864bbc5948c47630848580601364b4aeab Mon Sep 17 00:00:00 2001
From: Kemal Kurniawan
Date: Wed, 30 Jan 2019 22:52:46 +0700
Subject: [PATCH 124/162] Add more tests and refactor a bit

---
 tests/test_vocab.py | 55 +++++++++++++++++++++++----------------------
 1 file changed, 28 insertions(+), 27 deletions(-)

diff --git a/tests/test_vocab.py b/tests/test_vocab.py
index 8361abe..18f182e 100644
--- a/tests/test_vocab.py
+++ b/tests/test_vocab.py
@@ -6,7 +6,6 @@
 
 
 class TestFromSamples():
-    # TODO check all possible methods of a mapping?
     def test_ok(self):
         ss = [{'w': 'c'}, {'w': 'b'}, {'w': 'a'}, {'w': 'b'}, {'w': 'c'}, {'w': 'c'}]
         vocab = Vocab.from_samples(ss)
@@ -14,12 +13,15 @@ def test_ok(self):
         assert isinstance(vocab, Mapping)
         assert len(vocab) == 1
         assert list(vocab) == ['w']
+        with pytest.raises(KeyError):
+            vocab['ws']
 
         itos = '<pad> <unk> c b'.split()
         assert isinstance(vocab['w'], Mapping)
         assert len(vocab['w']) == len(itos)
         assert list(vocab['w']) == itos
         for i, w in enumerate(itos):
+            assert w in vocab['w']
             assert vocab['w'][w] == i
 
         assert 'foo' not in vocab['w']
@@ -30,11 +32,11 @@ def test_ok(self):
     def test_has_vocab_for_all_str_fields(self):
         ss = [{'w': 'b', 't': 'b'}, {'w': 'b', 't': 'b'}]
         vocab = Vocab.from_samples(ss)
+        assert 'w' in vocab
         assert 't' in vocab
 
     def test_no_vocab_for_non_str(self):
         vocab = Vocab.from_samples([{'i': 10}, {'i': 20}])
-        assert 'i' not in vocab
         with pytest.raises(KeyError) as exc:
             vocab['i']
         assert "no vocabulary found for field name 'i'" in str(exc.value)
@@ -42,13 +44,7 @@ def test_seq(self):
         ss = [{'ws': ['a', 'c', 'c']}, {'ws': ['b', 'c']}, {'ws': ['b']}]
         vocab = Vocab.from_samples(ss)
-
-        itos = '<pad> <unk> c b'.split()
-        assert isinstance(vocab['ws'], Mapping)
-        assert len(vocab['ws']) == len(itos)
-        assert list(vocab['ws']) == itos
-        for i, w in enumerate(itos):
-            assert vocab['ws'][w] == i
+        assert list(vocab['ws']) == '<pad> <unk> c b'.split()
 
     def test_seq_of_seq(self):
         ss = [{
@@ -59,17 +55,11 @@ def test_seq_of_seq(self):
             'cs': [['d', 'c']]
         }]
         vocab = Vocab.from_samples(ss)
-
-        itos = '<pad> <unk> d c b'.split()
-        assert isinstance(vocab['cs'], Mapping)
-        assert len(vocab['cs']) == len(itos)
-        assert list(vocab['cs']) == itos
-        for i, w in enumerate(itos):
-            assert vocab['cs'][w] == i
+        assert list(vocab['cs']) == '<pad> <unk> d c b'.split()
 
     def test_empty_samples(self):
         vocab = Vocab.from_samples([])
-        assert len(vocab) == 0
+        assert not vocab
 
     def test_empty_field_values(self):
         with pytest.raises(ValueError) as exc:
@@ -97,8 +87,8 @@ def test_min_count(self):
             't': 'c'
         }]
         vocab = Vocab.from_samples(ss, options={'w': dict(min_count=3)})
-        assert list(vocab['w']) == ['<pad>', '<unk>', 'c']
-        assert list(vocab['t']) == ['<pad>', '<unk>', 'c', 'b']
+        assert list(vocab['w']) == '<pad> <unk> c'.split()
+        assert list(vocab['t']) == '<pad> <unk> c b'.split()
 
     def test_no_unk(self):
         ss = [{
@@ -121,8 +111,8 @@ def test_no_unk(self):
             't': 'c'
         }]
         vocab = Vocab.from_samples(ss, options={'w': dict(unk=None)})
-        assert list(vocab['w']) == ['<pad>', 'c', 'b']
-        assert list(vocab['t']) == ['<pad>', '<unk>', 'c', 'b']
+        assert list(vocab['w']) == '<pad> c b'.split()
+        assert list(vocab['t']) == '<pad> <unk> c b'.split()
         with pytest.raises(KeyError) as exc:
             vocab['w']['foo']
         assert "'foo' not found in vocabulary" in str(exc.value)
@@ -148,8 +138,8 @@ def test_no_pad(self):
             't': 'c'
         }]
         vocab = Vocab.from_samples(ss, options={'w': dict(pad=None)})
-        assert list(vocab['w']) == ['<unk>', 'c', 'b']
-        assert list(vocab['t']) == ['<pad>', '<unk>', 'c', 'b']
+        assert list(vocab['w']) == '<unk> c b'.split()
+        assert list(vocab['t']) == '<pad> <unk> c b'.split()
 
     def test_max_size(self):
         ss = [{
@@ -172,10 +162,21 @@ def test_max_size(self):
             't': 'c'
         }]
         vocab = Vocab.from_samples(ss, options={'w': dict(max_size=1)})
-        assert list(vocab['w']) == ['<pad>', '<unk>', 'c']
-        assert list(vocab['t']) == ['<pad>', '<unk>', 'c', 'b']
+        assert list(vocab['w']) == '<pad> <unk> c'.split()
+        assert list(vocab['t']) == '<pad> <unk> c b'.split()
 
     def test_iterator_is_passed(self):
-        ss = [{'ws': ['a', 'a']}, {'ws': ['b', 'b']}]
+        ss = [{
+            'ws': ['b', 'c'],
+            'w': 'c'
+        }, {
+            'ws': ['c', 'b'],
+            'w': 'c'
+        }, {
+            'ws': ['c'],
+            'w': 'c'
+        }]
         vocab = Vocab.from_samples(iter(ss))
-        assert 'a' in vocab['ws']
+        assert 'b' in vocab['ws']
+        assert 'c' in vocab['ws']
+        assert 'c' in vocab['w']

From eb82df2df858f1582ed5c7a5b9d73bda2503a956 Mon Sep 17 00:00:00 2001
From: Kemal Kurniawan
Date: Wed, 30 Jan 2019 22:53:46 +0700
Subject: [PATCH 125/162] Add more todos

---
 tests/test_vocab.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/tests/test_vocab.py b/tests/test_vocab.py
index 18f182e..7b7cd0a 100644
--- a/tests/test_vocab.py
+++ b/tests/test_vocab.py
@@ -6,6 +6,7 @@
 
 
 class TestFromSamples():
+    # TODO think about what to test in each test case
     def test_ok(self):
         ss = [{'w': 'c'}, {'w': 'b'}, {'w': 'a'}, {'w': 'b'}, {'w': 'c'}, {'w': 'c'}]
         vocab = Vocab.from_samples(ss)
@@ -16,6 +17,7 @@ def test_ok(self):
         with pytest.raises(KeyError):
             vocab['ws']
 
+        # TODO for non-sequential field, maybe no padding?
         itos = '<pad> <unk> c b'.split()
         assert isinstance(vocab['w'], Mapping)
         assert len(vocab['w']) == len(itos)

From e677c658cb2894c1f62c0d2850f775f805d70415 Mon Sep 17 00:00:00 2001
From: Kemal Kurniawan
Date: Thu, 31 Jan 2019 05:31:27 +0700
Subject: [PATCH 126/162] Refactor tests to only check what's necessary

---
 tests/test_vocab.py | 63 +++++++++------------------------------------
 1 file changed, 12 insertions(+), 51 deletions(-)

diff --git a/tests/test_vocab.py b/tests/test_vocab.py
index 7b7cd0a..3247de2 100644
--- a/tests/test_vocab.py
+++ b/tests/test_vocab.py
@@ -6,7 +6,6 @@
 
 
 class TestFromSamples():
-    # TODO think about what to test in each test case
     def test_ok(self):
         ss = [{'w': 'c'}, {'w': 'b'}, {'w': 'a'}, {'w': 'b'}, {'w': 'c'}, {'w': 'c'}]
         vocab = Vocab.from_samples(ss)
@@ -34,8 +33,8 @@ def test_ok(self):
     def test_has_vocab_for_all_str_fields(self):
         ss = [{'w': 'b', 't': 'b'}, {'w': 'b', 't': 'b'}]
         vocab = Vocab.from_samples(ss)
-        assert 'w' in vocab
-        assert 't' in vocab
+        assert vocab.get('w') is not None
+        assert vocab.get('t') is not None
 
     def test_no_vocab_for_non_str(self):
         vocab = Vocab.from_samples([{'i': 10}, {'i': 20}])
@@ -89,59 +88,21 @@ def test_min_count(self):
             't': 'c'
         }]
         vocab = Vocab.from_samples(ss, options={'w': dict(min_count=3)})
-        assert list(vocab['w']) == '<pad> <unk> c'.split()
-        assert list(vocab['t']) == '<pad> <unk> c b'.split()
+        assert 'b' not in vocab['w']
+        assert 'b' in vocab['t']
 
     def test_no_unk(self):
-        ss = [{
-            'w': 'c',
-            't': 'c'
-        }, {
-            'w': 'b',
-            't': 'b'
-        }, {
-            'w': 'a',
-            't': 'a'
-        }, {
-            'w': 'b',
-            't': 'b'
-        }, {
-            'w': 'c',
-            't': 'c'
-        }, {
-            'w': 'c',
-            't': 'c'
-        }]
-        vocab = Vocab.from_samples(ss, options={'w': dict(unk=None)})
-        assert list(vocab['w']) == '<pad> c b'.split()
-        assert list(vocab['t']) == '<pad> <unk> c b'.split()
+        vocab = Vocab.from_samples([{'w': 'a', 't': 'a'}], options={'w': dict(unk=None)})
+        assert '<unk>' not in vocab['w']
+        assert '<unk>' in vocab['t']
 
         with pytest.raises(KeyError) as exc:
             vocab['w']['foo']
         assert "'foo' not found in vocabulary" in str(exc.value)
 
     def test_no_pad(self):
-        ss = [{
-            'w': 'c',
-            't': 'c'
-        }, {
-            'w': 'b',
-            't': 'b'
-        }, {
-            'w': 'a',
-            't': 'a'
-        }, {
-            'w': 'b',
-            't': 'b'
-        }, {
-            'w': 'c',
-            't': 'c'
-        }, {
-            'w': 'c',
-            't': 'c'
-        }]
-        vocab = Vocab.from_samples(ss, options={'w': dict(pad=None)})
-        assert list(vocab['w']) == '<unk> c b'.split()
-        assert list(vocab['t']) == '<pad> <unk> c b'.split()
+        vocab = Vocab.from_samples([{'w': 'a', 't': 'a'}], options={'w': dict(pad=None)})
+        assert '<pad>' not in vocab['w']
+        assert '<pad>' in vocab['t']
 
     def test_max_size(self):
         ss = [{
@@ -164,8 +125,8 @@ def test_max_size(self):
             't': 'c'
         }]
         vocab = Vocab.from_samples(ss, options={'w': dict(max_size=1)})
-        assert list(vocab['w']) == '<pad> <unk> c'.split()
-        assert list(vocab['t']) == '<pad> <unk> c b'.split()
+        assert 'b' not in vocab['w']
+        assert 'b' in vocab['t']
 
     def test_iterator_is_passed(self):
         ss = [{

From 7d299e1fa2d4634b0a8624af96d1f84ab6fee18f Mon Sep 17 00:00:00 2001
From: Kemal Kurniawan
Date: Thu, 31 Jan 2019 05:44:44 +0700
Subject: [PATCH 127/162] Add padding only for sequential fields

---
 tests/test_vocab.py |  5 ++---
 text2array/vocab.py | 11 ++++++++---
 2 files changed, 10 insertions(+), 6 deletions(-)

diff --git a/tests/test_vocab.py b/tests/test_vocab.py
index 3247de2..3691901 100644
--- a/tests/test_vocab.py
+++ b/tests/test_vocab.py
@@ -16,8 +16,7 @@ def test_ok(self):
         with pytest.raises(KeyError):
             vocab['ws']
 
-        # TODO for non-sequential field, maybe no padding?
-        itos = '<pad> <unk> c b'.split()
+        itos = '<unk> c b'.split()
         assert isinstance(vocab['w'], Mapping)
         assert len(vocab['w']) == len(itos)
         assert list(vocab['w']) == itos
@@ -100,7 +99,7 @@ def test_no_unk(self):
         assert "'foo' not found in vocabulary" in str(exc.value)
 
     def test_no_pad(self):
-        vocab = Vocab.from_samples([{'w': 'a', 't': 'a'}], options={'w': dict(pad=None)})
+        vocab = Vocab.from_samples([{'w': ['a'], 't': ['a']}], options={'w': dict(pad=None)})
         assert '<pad>' not in vocab['w']
         assert '<pad>' in vocab['t']
 
diff --git a/text2array/vocab.py b/text2array/vocab.py
index a2bfd7d..35b62bc 100644
--- a/text2array/vocab.py
+++ b/text2array/vocab.py
@@ -1,6 +1,6 @@
 from collections import Counter, OrderedDict, defaultdict
 from collections.abc import Sequence as SequenceABC
-from typing import Counter as CounterT, Dict, Iterable, Iterator, Mapping, Optional
+from typing import Counter as CounterT, Dict, Iterable, Iterator, Mapping, Optional, Set
 
 from .samples import FieldName, FieldValue, Sample
 
@@ -54,7 +54,9 @@ def from_samples(
                   of times from the vocabulary (default: 2).
                 * ``pad``(:obj:`str`): String token to represent padding tokens. If
                   ``None``, no padding token is added to the vocabulary. Otherwise, it is the
-                  first entry in the vocabulary (index is 0) (default: ``<pad>``).
+                  first entry in the vocabulary (index is 0). Note that if the field has no
+                  sequential values, no padding is added. String field values are *not*
+                  considered sequential (default: ``<pad>``).
                 * ``unk``(:obj:`str`): String token to represent unknown tokens with. If
                   ``None``, no unknown token is added to the vocabulary. This means when
                   querying the vocabulary with such token, an error is raised. Otherwise,
@@ -73,10 +75,13 @@ def from_samples(
             options = {}
 
         counter: Dict[FieldName, CounterT[str]] = defaultdict(Counter)
+        seqfield: Set[FieldName] = set()
         for s in samples:
            for name, value in s.items():
                 if cls._needs_vocab(value):
                     counter[name].update(cls._flatten(value))
+                if isinstance(value, SequenceABC) and not isinstance(value, str):
+                    seqfield.add(name)
 
         m = {}
         for name, c in counter.items():
@@ -86,7 +91,7 @@ def from_samples(
             # Padding and unknown tokens
             pad = opts.get('pad', '<pad>')
             unk = opts.get('unk', '<unk>')
-            if pad is not None:
+            if name in seqfield and pad is not None:
                 store[pad] = len(store)
             if unk is not None:
                 store[unk] = len(store)

From 7266d61579e51a7a21fa6d3ed6c54fdf136e1145 Mon Sep 17 00:00:00 2001
From: Kemal Kurniawan
Date: Thu, 31 Jan 2019 05:46:04 +0700
Subject: [PATCH 128/162] Add a todo

---
 text2array/vocab.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/text2array/vocab.py b/text2array/vocab.py
index 35b62bc..a2e3729 100644
--- a/text2array/vocab.py
+++ b/text2array/vocab.py
@@ -5,6 +5,7 @@
 from .samples import FieldName, FieldValue, Sample
 
 
+# TODO use typing classes for isinstance check
 class Vocab(Mapping[FieldName, Mapping[str, int]]):
     """Namespaced vocabulary storing the mapping from field names to their actual
     vocabulary.

From abd54e8f2931ebbf1add6902536112597a59566b Mon Sep 17 00:00:00 2001
From: Kemal Kurniawan
Date: Thu, 31 Jan 2019 05:53:22 +0700
Subject: [PATCH 129/162] Use typing classes for isinstance check

---
 text2array/batches.py  |  5 ++---
 text2array/datasets.py | 18 ++++++++----------
 text2array/vocab.py    |  9 ++++-----
 3 files changed, 14 insertions(+), 18 deletions(-)

diff --git a/text2array/batches.py b/text2array/batches.py
index 72c8392..49c2eb6 100644
--- a/text2array/batches.py
+++ b/text2array/batches.py
@@ -1,4 +1,3 @@
-from collections.abc import Sequence as SequenceABC
 from functools import reduce
 from typing import Any, List, Mapping, Sequence, Union
 
 import numpy as np
@@ -61,7 +60,7 @@ def _get_maxlens(cls, data: Sequence[Any]) -> List[int]:
         assert data
 
         # Base case
-        if not isinstance(data[0], SequenceABC):
+        if not isinstance(data[0], Sequence):
             return [len(data)]
 
         # Recursive case
@@ -93,7 +92,7 @@ def _pad(
         assert depth < len(maxlens)
 
         # Base case
-        if not isinstance(data[0], SequenceABC):
+        if not isinstance(data[0], Sequence):
             data_ = list(data)
         # Recursive case
         else:
diff --git a/text2array/datasets.py b/text2array/datasets.py
index 4480632..0ec7455 100644
--- a/text2array/datasets.py
+++ b/text2array/datasets.py
@@ -1,6 +1,4 @@
-from collections.abc import \
-    Iterable as IterableABC, MutableSequence as MutableSequenceABC, Sequence as SequenceABC
-from typing import Callable, Iterable, Iterator, Mapping, Sequence
+from typing import Callable, Iterable, Iterator, Mapping, MutableSequence, Sequence
 import abc
 import random
 import statistics as stat
@@ -54,7 +52,7 @@ def _apply_vb_to_val(
             vb: Mapping[FieldValue, FieldValue],
             val: FieldValue,
     ) -> FieldValue:
-        if isinstance(val, str) or not isinstance(val, SequenceABC):
+        if isinstance(val, str) or not isinstance(val, Sequence):
             try:
                 return vb[val]
             except KeyError:
@@ -73,7 +71,7 @@ class Dataset(DatasetABC, Sequence[Sample]):
     """
 
     def __init__(self, samples: Sequence[Sample]) -> None:
-        if not isinstance(samples, SequenceABC):
+        if not isinstance(samples, Sequence):
             raise TypeError('"samples" is not a sequence')
 
         self._samples = samples
@@ -94,7 +92,7 @@ def shuffle(self) -> 'Dataset':
 
         Returns:
             The dataset object itself (useful for chaining).
""" - if not isinstance(self._samples, MutableSequenceABC): + if not isinstance(self._samples, MutableSequence): self._samples = list(self._samples) self._shuffle_inplace() return self @@ -155,12 +153,12 @@ def apply_vocab(self, vocab: Mapping[FieldName, Mapping[FieldValue, FieldValue]] Args: vocab: Vocabulary to apply. """ - if not isinstance(self._samples, MutableSequenceABC): + if not isinstance(self._samples, MutableSequence): self._samples = list(self._samples) self._apply_vocab_inplace(vocab) def _shuffle_inplace(self) -> None: - assert isinstance(self._samples, MutableSequenceABC) + assert isinstance(self._samples, MutableSequence) n = len(self._samples) for i in range(n): j = random.randrange(n) @@ -172,7 +170,7 @@ def _apply_vocab_inplace( self, vocab: Mapping[FieldName, Mapping[FieldValue, FieldValue]], ) -> None: - assert isinstance(self._samples, MutableSequenceABC) + assert isinstance(self._samples, MutableSequence) for i in range(len(self._samples)): self._samples[i] = self._apply_vocab_to_sample(vocab, self._samples[i]) @@ -185,7 +183,7 @@ class StreamDataset(DatasetABC): """ def __init__(self, stream: Iterable[Sample]) -> None: - if not isinstance(stream, IterableABC): + if not isinstance(stream, Iterable): raise TypeError('"stream" is not iterable') self._stream = stream diff --git a/text2array/vocab.py b/text2array/vocab.py index a2e3729..755f3a4 100644 --- a/text2array/vocab.py +++ b/text2array/vocab.py @@ -1,11 +1,10 @@ from collections import Counter, OrderedDict, defaultdict -from collections.abc import Sequence as SequenceABC -from typing import Counter as CounterT, Dict, Iterable, Iterator, Mapping, Optional, Set +from typing import Counter as CounterT, Dict, Iterable, Iterator, Mapping, \ + Optional, Sequence, Set from .samples import FieldName, FieldValue, Sample -# TODO use typing classes for isinstance check class Vocab(Mapping[FieldName, Mapping[str, int]]): """Namespaced vocabulary storing the mapping from field names to their actual vocabulary. 
@@ -81,7 +80,7 @@ def from_samples( for name, value in s.items(): if cls._needs_vocab(value): counter[name].update(cls._flatten(value)) - if isinstance(value, SequenceABC) and not isinstance(value, str): + if isinstance(value, Sequence) and not isinstance(value, str): seqfield.add(name) m = {} @@ -114,7 +113,7 @@ def from_samples( def _needs_vocab(cls, val: FieldValue) -> bool: if isinstance(val, str): return True - if isinstance(val, SequenceABC): + if isinstance(val, Sequence): if not val: raise ValueError('field values must not be an empty sequence') return cls._needs_vocab(val[0]) From 6a3137d6e0ca4fff76d661be79ed9aff72e8381f Mon Sep 17 00:00:00 2001 From: Kemal Kurniawan Date: Thu, 31 Jan 2019 05:58:13 +0700 Subject: [PATCH 130/162] Change to class variable --- tests/test_dataset.py | 32 ++++++++++++++------------------ 1 file changed, 14 insertions(+), 18 deletions(-) diff --git a/tests/test_dataset.py b/tests/test_dataset.py index d42226b..ac17a27 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -37,27 +37,24 @@ def test_immutable_seq(self, setup_rng, samples): class TestShuffleBy: - # TODO make this a class variable - @staticmethod - def make_dataset(): - return Dataset([{ - 'is': [1, 2, 3] - }, { - 'is': [1] - }, { - 'is': [1, 2] - }, { - 'is': [1, 2, 3, 4, 5] - }, { - 'is': [1, 2, 3, 4] - }]) + dataset = Dataset([{ + 'is': [1, 2, 3] + }, { + 'is': [1] + }, { + 'is': [1, 2] + }, { + 'is': [1, 2, 3, 4, 5] + }, { + 'is': [1, 2, 3, 4] + }]) @staticmethod def key(sample): return len(sample['is']) def test_ok(self, setup_rng): - dat = self.make_dataset() + dat = self.dataset before = list(dat) retval = dat.shuffle_by(self.key) after = list(dat) @@ -65,16 +62,15 @@ def test_ok(self, setup_rng): assert_shuffled(before, after) def test_zero_scale(self, setup_rng): - dat = self.make_dataset() + dat = self.dataset before = list(dat) dat.shuffle_by(self.key, scale=0.) 
after = list(dat) assert sorted(before, key=self.key) == after def test_negative_scale(self, setup_rng): - dat = self.make_dataset() with pytest.raises(ValueError) as exc: - dat.shuffle_by(self.key, scale=-1) + self.dataset.shuffle_by(self.key, scale=-1) assert 'scale cannot be less than 0' in str(exc.value) From 1b7b5a258a156c1c2942df65c302caf60d9692e6 Mon Sep 17 00:00:00 2001 From: Kemal Kurniawan Date: Thu, 31 Jan 2019 06:22:36 +0700 Subject: [PATCH 131/162] Refactor dataset tests to use fixtures only when makes sense --- tests/test_dataset.py | 50 +++++++++++++----------------------- tests/test_stream_dataset.py | 16 +++++++----- 2 files changed, 27 insertions(+), 39 deletions(-) diff --git a/tests/test_dataset.py b/tests/test_dataset.py index ac17a27..c3ecb62 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -1,4 +1,4 @@ -from collections.abc import Iterator, Sequence +from typing import Iterator, Sequence import pytest @@ -8,7 +8,7 @@ def test_init(samples): dat = Dataset(samples) assert isinstance(dat, Sequence) - assert len(dat) == 5 + assert len(dat) == len(samples) for i in range(len(dat)): assert dat[i] == samples[i] @@ -37,21 +37,11 @@ def test_immutable_seq(self, setup_rng, samples): class TestShuffleBy: - dataset = Dataset([{ - 'is': [1, 2, 3] - }, { - 'is': [1] - }, { - 'is': [1, 2] - }, { - 'is': [1, 2, 3, 4, 5] - }, { - 'is': [1, 2, 3, 4] - }]) + dataset = Dataset([{'i': 3}, {'i': 1}, {'i': 2}, {'i': 5}, {'i': 4}]) @staticmethod def key(sample): - return len(sample['is']) + return sample['i'] def test_ok(self, setup_rng): dat = self.dataset @@ -78,15 +68,16 @@ def assert_shuffled(before, after): assert before != after and len(before) == len(after) and all(x in after for x in before) -def test_batch(dataset): - bs = dataset.batch(2) +def test_batch(): + dat = Dataset([{'i': 3}, {'i': 1}, {'i': 2}, {'i': 5}, {'i': 4}]) + bs = dat.batch(2) assert isinstance(bs, Iterator) bs_lst = list(bs) assert len(bs_lst) == 3 assert all(isinstance(b, Batch) for b in bs_lst) - assert list(bs_lst[0]) == [dataset[0], dataset[1]] - assert list(bs_lst[1]) == [dataset[2], dataset[3]] - assert list(bs_lst[2]) == [dataset[4]] + assert list(bs_lst[0]) == [dat[0], dat[1]] + assert list(bs_lst[1]) == [dat[2], dat[3]] + assert list(bs_lst[2]) == [dat[4]] def test_batch_size_evenly_divides(dataset): @@ -97,14 +88,15 @@ def test_batch_size_evenly_divides(dataset): assert list(bs_lst[i]) == [dataset[i]] -def test_batch_exactly(dataset): - bs = dataset.batch_exactly(2) +def test_batch_exactly(): + dat = Dataset([{'i': 3}, {'i': 1}, {'i': 2}, {'i': 5}, {'i': 4}]) + bs = dat.batch_exactly(2) assert isinstance(bs, Iterator) bs_lst = list(bs) assert len(bs_lst) == 2 assert all(isinstance(b, Batch) for b in bs_lst) - assert list(bs_lst[0]) == [dataset[0], dataset[1]] - assert list(bs_lst[1]) == [dataset[2], dataset[3]] + assert list(bs_lst[0]) == [dat[0], dat[1]] + assert list(bs_lst[1]) == [dat[2], dat[3]] def test_batch_nonpositive_batch_size(dataset): @@ -164,7 +156,7 @@ def test_ok(self): 'j': 2 }] - def test_key_error(self): + def test_value_not_in_vocab(self): dat = Dataset([{'w': 'a'}]) vocab = {'w': {'b': 0}} with pytest.raises(KeyError) as exc: @@ -195,14 +187,8 @@ def test_with_vocab_object(self): 'cs': [[v['cs']['b']]] }] - def test_immutable_seq(self): - ss = [{ - 'ws': ['a', 'b'], - 'cs': [['a', 'c'], ['c', 'b', 'c']] - }, { - 'ws': ['b'], - 'cs': [['b']] - }] + def test_immutable_seq(self, samples): + ss = samples lstdat = Dataset(ss) tpldat = Dataset(tuple(ss)) v = 
Vocab.from_samples(ss) diff --git a/tests/test_stream_dataset.py b/tests/test_stream_dataset.py index e12f745..5da076b 100644 --- a/tests/test_stream_dataset.py +++ b/tests/test_stream_dataset.py @@ -1,4 +1,4 @@ -from collections.abc import Iterable, Iterator +from typing import Iterable, Iterator import pytest @@ -24,13 +24,14 @@ def test_can_be_iterated_twice(stream_dataset): assert len(dat_lst2) > 0 -def test_batch(stream_dataset): - bs = stream_dataset.batch(2) +def test_batch(stream_cls): + stream_dat = StreamDataset(stream_cls([{'i': 3}, {'i': 1}, {'i': 2}, {'i': 5}, {'i': 4}])) + bs = stream_dat.batch(2) assert isinstance(bs, Iterator) bs_lst = list(bs) assert len(bs_lst) == 3 assert all(isinstance(b, Batch) for b in bs_lst) - dat = list(stream_dataset) + dat = list(stream_dat) assert list(bs_lst[0]) == [dat[0], dat[1]] assert list(bs_lst[1]) == [dat[2], dat[3]] assert list(bs_lst[2]) == [dat[4]] @@ -45,13 +46,14 @@ def test_batch_size_evenly_divides(stream_dataset): assert list(bs_lst[i]) == [dat[i]] -def test_batch_exactly(stream_dataset): - bs = stream_dataset.batch_exactly(2) +def test_batch_exactly(stream_cls): + stream_dat = StreamDataset(stream_cls([{'i': 3}, {'i': 1}, {'i': 2}, {'i': 5}, {'i': 4}])) + bs = stream_dat.batch_exactly(2) assert isinstance(bs, Iterator) bs_lst = list(bs) assert len(bs_lst) == 2 assert all(isinstance(b, Batch) for b in bs_lst) - dat = list(stream_dataset) + dat = list(stream_dat) assert list(bs_lst[0]) == [dat[0], dat[1]] assert list(bs_lst[1]) == [dat[2], dat[3]] From 2978228e0a64822d6be801e156201c7c2ea22d99 Mon Sep 17 00:00:00 2001 From: Kemal Kurniawan Date: Thu, 31 Jan 2019 06:33:09 +0700 Subject: [PATCH 132/162] Refactor test_batch.py --- tests/test_batch.py | 39 ++++++++++++++++++++++++++------------- 1 file changed, 26 insertions(+), 13 deletions(-) diff --git a/tests/test_batch.py b/tests/test_batch.py index 7c5bc6b..2763739 100644 --- a/tests/test_batch.py +++ b/tests/test_batch.py @@ -1,4 +1,4 @@ -from collections.abc import Mapping, Sequence +from typing import Mapping, Sequence import numpy as np import pytest @@ -14,24 +14,37 @@ def test_init(samples): assert b[i] == samples[i] -@pytest.fixture -def batch(samples): - return Batch(samples) - - class TestToArray: - def test_ok(self, batch): - arr = batch.to_array() + def test_ok(self): + ss = [{ + 'i': 4, + 'f': 0.67 + }, { + 'i': 2, + 'f': 0.89 + }, { + 'i': 3, + 'f': 0.23 + }, { + 'i': 5, + 'f': 0.11 + }, { + 'i': 3, + 'f': 0.22 + }] + b = Batch(ss) + arr = b.to_array() assert isinstance(arr, Mapping) + assert len(arr) == 2 + assert set(arr) == set(['i', 'f']) assert isinstance(arr['i'], np.ndarray) - assert arr['i'].shape == (len(batch), ) - assert arr['i'].tolist() == [s['i'] for s in batch] + assert arr['i'].shape == (len(b), ) + assert arr['i'].tolist() == [s['i'] for s in b] assert isinstance(arr['f'], np.ndarray) - assert arr['f'].shape == (len(batch), ) - for i in range(len(batch)): - assert arr['f'][i] == pytest.approx(batch[i]['f']) + assert arr['f'].shape == (len(b), ) + assert arr['f'].tolist() == [pytest.approx(s['f']) for s in b] def test_empty(self): b = Batch([]) From f8f69128c3fc21a76b43fd20c92c306a145a8656 Mon Sep 17 00:00:00 2001 From: Kemal Kurniawan Date: Thu, 31 Jan 2019 06:34:14 +0700 Subject: [PATCH 133/162] Use classes from typing instead of collections.abc --- tests/test_vocab.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_vocab.py b/tests/test_vocab.py index 3691901..71310bb 100644 --- a/tests/test_vocab.py +++ 
b/tests/test_vocab.py @@ -1,4 +1,4 @@ -from collections.abc import Mapping +from typing import Mapping import pytest From 29eee983c60f69bc76591cc3cbca011d2c1c5815 Mon Sep 17 00:00:00 2001 From: Kemal Kurniawan Date: Thu, 31 Jan 2019 06:52:19 +0700 Subject: [PATCH 134/162] Refactor datasets.py a bit --- text2array/datasets.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/text2array/datasets.py b/text2array/datasets.py index 0ec7455..c47f89c 100644 --- a/text2array/datasets.py +++ b/text2array/datasets.py @@ -62,7 +62,7 @@ def _apply_vb_to_val( class Dataset(DatasetABC, Sequence[Sample]): - """A dataset that fits in memory (no streaming). + """Dataset that fits all in memory (no streaming). Args: samples: Sequence of samples the dataset should contain. This sequence should @@ -90,7 +90,7 @@ def shuffle(self) -> 'Dataset': sequence, so subsequent shuffling will be done in-place. Returns: - The dataset object itself (useful for chaining). + This dataset object (useful for chaining). """ if not isinstance(self._samples, MutableSequence): self._samples = list(self._samples) @@ -115,7 +115,7 @@ def shuffle_by(self, key: Callable[[Sample], int], scale: float = 1.) -> 'Datase scale: Value to regulate the noise of the sorting. Must not be negative. Returns: - The dataset object itself (useful for chaining). + This dataset object (useful for chaining). """ if scale < 0: raise ValueError('scale cannot be less than 0') @@ -148,10 +148,10 @@ def apply_vocab(self, vocab: Mapping[FieldName, Mapping[FieldValue, FieldValue]] values according to the mapping specified by the vocabulary. Field names that have no entry in the vocabulary are ignored. This method applies the vocabulary in-place when the dataset holds a mutable sequence of samples. Otherwise, a mutable copy of - samples is made. + samples is made and the vocabulary is applied on it. Args: - vocab: Vocabulary to apply. + vocab: The vocabulary to apply. """ if not isinstance(self._samples, MutableSequence): self._samples = list(self._samples) @@ -176,10 +176,10 @@ def _apply_vocab_inplace( class StreamDataset(DatasetABC): - """A dataset that streams its samples. + """Dataset that streams its samples. Args: - stream: Stream of examples the dataset should stream from. + stream: Stream of samples the dataset should stream from. """ def __init__(self, stream: Iterable[Sample]) -> None: @@ -192,7 +192,7 @@ def __iter__(self) -> Iterator[Sample]: try: vocab = self._vocab except AttributeError: - yield from iter(self._stream) + yield from self._stream return for s in self._stream: @@ -231,6 +231,6 @@ def apply_vocab(self, vocab: Mapping[FieldName, Mapping[FieldValue, FieldValue]] ``vocab`` must still exist when that happens. Args: - vocab: Vocabulary to apply. + vocab: The vocabulary to apply. 
""" self._vocab = vocab From e29092426af341b29cfd850f375745245ed24aae Mon Sep 17 00:00:00 2001 From: Kemal Kurniawan Date: Thu, 31 Jan 2019 07:32:46 +0700 Subject: [PATCH 135/162] Handle when nesting depth is inconsistent --- tests/test_batch.py | 6 ++++++ text2array/batches.py | 19 ++++++++++++++----- 2 files changed, 20 insertions(+), 5 deletions(-) diff --git a/tests/test_batch.py b/tests/test_batch.py index 2763739..64a388c 100644 --- a/tests/test_batch.py +++ b/tests/test_batch.py @@ -218,3 +218,9 @@ def test_missing_field(self): with pytest.raises(KeyError) as exc: b.to_array() assert "some samples have no field 'a'" in str(exc.value) + + def test_inconsistent_depth(self): + b = Batch([{'ws': [1, 2]}, {'ws': [[1, 2], [3, 4]]}]) + with pytest.raises(ValueError) as exc: + b.to_array() + assert "field 'ws' has inconsistent nesting depth" in str(exc.value) diff --git a/text2array/batches.py b/text2array/batches.py index 49c2eb6..30a047d 100644 --- a/text2array/batches.py +++ b/text2array/batches.py @@ -37,13 +37,18 @@ def to_array(self, pad_with: int = 0) -> Mapping[FieldName, np.ndarray]: arr = {} for name in self._samples[0].keys(): - data = self._get_values(name) + values = self._get_values(name) + # Get max length for all depths, 1st elem is batch size - maxlens = self._get_maxlens(data) + try: + maxlens = self._get_maxlens(values) + except self._InconsistentDepthError: + raise ValueError(f"field '{name}' has inconsistent nesting depth") + # Get padding for all depths paddings = self._get_paddings(maxlens, pad_with) - # Pad the data - data = self._pad(data, maxlens, paddings, 0) + # Pad the values + data = self._pad(values, maxlens, paddings, 0) arr[name] = np.array(data) @@ -65,7 +70,8 @@ def _get_maxlens(cls, data: Sequence[Any]) -> List[int]: # Recursive case maxlenss = [cls._get_maxlens(x) for x in data] - assert all(len(x) == len(maxlenss[0]) for x in maxlenss) + if not all(len(x) == len(maxlenss[0]) for x in maxlenss): + raise cls._InconsistentDepthError maxlens = reduce(lambda ml1, ml2: [max(l1, l2) for l1, l2 in zip(ml1, ml2)], maxlenss) maxlens.insert(0, len(data)) @@ -101,3 +107,6 @@ def _pad( for _ in range(maxlens[depth] - len(data)): data_.append(paddings[depth]) return data_ + + class _InconsistentDepthError(Exception): + pass From da00bf6a2902c6415611736c9f67402bc60d68a2 Mon Sep 17 00:00:00 2001 From: Kemal Kurniawan Date: Thu, 31 Jan 2019 07:40:14 +0700 Subject: [PATCH 136/162] Handle when some values are strings --- tests/test_batch.py | 6 ++++++ text2array/batches.py | 4 ++-- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/tests/test_batch.py b/tests/test_batch.py index 64a388c..f7e333c 100644 --- a/tests/test_batch.py +++ b/tests/test_batch.py @@ -224,3 +224,9 @@ def test_inconsistent_depth(self): with pytest.raises(ValueError) as exc: b.to_array() assert "field 'ws' has inconsistent nesting depth" in str(exc.value) + + def test_str(self): + b = Batch([{'w': 'a'}, {'w': 'b'}]) + arr = b.to_array() + assert isinstance(arr['w'], np.ndarray) + assert arr['w'].tolist() == ['a', 'b'] diff --git a/text2array/batches.py b/text2array/batches.py index 30a047d..8d60522 100644 --- a/text2array/batches.py +++ b/text2array/batches.py @@ -65,7 +65,7 @@ def _get_maxlens(cls, data: Sequence[Any]) -> List[int]: assert data # Base case - if not isinstance(data[0], Sequence): + if isinstance(data[0], str) or not isinstance(data[0], Sequence): return [len(data)] # Recursive case @@ -98,7 +98,7 @@ def _pad( assert depth < len(maxlens) # Base case - if not 
isinstance(data[0], Sequence): + if isinstance(data[0], str) or not isinstance(data[0], Sequence): data_ = list(data) # Recursive case else: From 4b4f261b3e63158c2b911fb171bfa73d17303ffe Mon Sep 17 00:00:00 2001 From: Kemal Kurniawan Date: Thu, 31 Jan 2019 07:48:26 +0700 Subject: [PATCH 137/162] Refactor Batch.to_array()'s types and var names --- text2array/batches.py | 32 ++++++++++++++++---------------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/text2array/batches.py b/text2array/batches.py index 8d60522..7eba416 100644 --- a/text2array/batches.py +++ b/text2array/batches.py @@ -1,5 +1,5 @@ from functools import reduce -from typing import Any, List, Mapping, Sequence, Union +from typing import List, Mapping, Sequence, Union import numpy as np @@ -61,20 +61,20 @@ def _get_values(self, name: str) -> Sequence[FieldValue]: raise KeyError(f"some samples have no field '{name}'") @classmethod - def _get_maxlens(cls, data: Sequence[Any]) -> List[int]: - assert data + def _get_maxlens(cls, values: Sequence[FieldValue]) -> List[int]: + assert values # Base case - if isinstance(data[0], str) or not isinstance(data[0], Sequence): - return [len(data)] + if isinstance(values[0], str) or not isinstance(values[0], Sequence): + return [len(values)] # Recursive case - maxlenss = [cls._get_maxlens(x) for x in data] + maxlenss = [cls._get_maxlens(x) for x in values] if not all(len(x) == len(maxlenss[0]) for x in maxlenss): raise cls._InconsistentDepthError maxlens = reduce(lambda ml1, ml2: [max(l1, l2) for l1, l2 in zip(ml1, ml2)], maxlenss) - maxlens.insert(0, len(data)) + maxlens.insert(0, len(values)) return maxlens @classmethod @@ -88,25 +88,25 @@ def _get_paddings(cls, maxlens: List[int], with_: int) -> List[Union[int, List[i @classmethod def _pad( cls, - data: Sequence[Any], + values: Sequence[FieldValue], maxlens: List[int], paddings: List[Union[int, List[int]]], depth: int, - ) -> Sequence[Any]: - assert data + ) -> Sequence[FieldValue]: + assert values assert len(maxlens) == len(paddings) assert depth < len(maxlens) # Base case - if isinstance(data[0], str) or not isinstance(data[0], Sequence): - data_ = list(data) + if isinstance(values[0], str) or not isinstance(values[0], Sequence): + values_ = list(values) # Recursive case else: - data_ = [cls._pad(x, maxlens, paddings, depth + 1) for x in data] + values_ = [cls._pad(x, maxlens, paddings, depth + 1) for x in values] - for _ in range(maxlens[depth] - len(data)): - data_.append(paddings[depth]) - return data_ + for _ in range(maxlens[depth] - len(values)): + values_.append(paddings[depth]) + return values_ class _InconsistentDepthError(Exception): pass From 108e450d497fbe2ea7623c1b8769e2b315dfe021 Mon Sep 17 00:00:00 2001 From: Kemal Kurniawan Date: Thu, 31 Jan 2019 08:01:42 +0700 Subject: [PATCH 138/162] Add a todo --- setup.py | 1 + 1 file changed, 1 insertion(+) diff --git a/setup.py b/setup.py index 1826cf4..35c2048 100644 --- a/setup.py +++ b/setup.py @@ -3,6 +3,7 @@ readme = Path(__file__).resolve().parent / 'README.rst' +# TODO release v0.0.1 setup( name='text2array', version='0.0.1', From e4cd29e68b5040d6bd5c09fc8cebb05185f02c79 Mon Sep 17 00:00:00 2001 From: Kemal Kurniawan Date: Thu, 31 Jan 2019 22:06:42 +0700 Subject: [PATCH 139/162] Add numpy as dependency --- setup.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/setup.py b/setup.py index 35c2048..518e758 100644 --- a/setup.py +++ b/setup.py @@ -21,5 +21,8 @@ 'Programming Language :: Python :: 3.6', ], packages=find_packages(), + install_requires=[ + 'numpy 
~=1.16.0', + ], python_requires='>=3.6, <4', ) From 7436afa0074dd9dfdfb7c767e301764957589e98 Mon Sep 17 00:00:00 2001 From: Kemal Kurniawan Date: Thu, 31 Jan 2019 22:06:54 +0700 Subject: [PATCH 140/162] Write longer readme --- README.rst | 244 ++++++++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 242 insertions(+), 2 deletions(-) diff --git a/README.rst b/README.rst index 3373cfe..35592e4 100644 --- a/README.rst +++ b/README.rst @@ -1,4 +1,244 @@ text2array -^^^^^^^^^^ +========== -Convert your NLP text dataset to arrays! +*Convert your NLP text dataset to arrays!* + +**text2array** helps you process your NLP text dataset into Numpy ndarray objects that are +ready to use for e.g. neural network inputs. **text2array** handles data shuffling, +batching, padding, converting into arrays. Say goodbye to these tedious works! + +Installation +------------ + +**text2array** requires at least Python 3.6 and can be installed via pip:: + + pip install text2array + +Usage overview +-------------- + +.. code-block:: python + + >>> from text2array import Dataset, Vocab + >>> + >>> samples = [ + ... {'ws': ['john', 'talks']}, + ... {'ws': ['john', 'loves', 'mary']}, + ... {'ws': ['mary']} + ... ] + >>> + >>> # Create a Dataset + >>> dataset = Dataset(samples) + >>> len(dataset) + 3 + >>> dataset[1] + {'ws': ['john', 'loves', 'mary']} + >>> + >>> # Create a Vocab + >>> vocab = Vocab.from_samples(dataset) + >>> list(vocab['ws']) + ['', '', 'john', 'mary'] + >>> # 'talks' and 'loves' are out-of-vocabulary because they occur only once + >>> vocab['ws']['john'] + 2 + >>> vocab['ws']['talks'] # unknown word is mapped to '' + 1 + >>> + >>> # Applying vocab to the dataset + >>> list(dataset) + [{'ws': ['john', 'talks']}, {'ws': ['john', 'loves', 'mary']}, {'ws': ['mary']}] + >>> dataset.apply_vocab(vocab) + >>> list(dataset) + [{'ws': [2, 1]}, {'ws': [2, 1, 3]}, {'ws': [3]}] + >>> + >>> # Shuffle, create batches of size 2, convert to array + >>> batches = dataset.shuffle().batch(2) + >>> batch = next(batches) + >>> arr = batch.to_array() + >>> arr['ws'] + array([[2, 1], + [3, 0]]) + >>> batch = next(batches) + >>> arr = batch.to_array() + >>> arr['ws'] + array([[2, 1, 3]]) + +Detailed tutorial +----------------- + +Sample +++++++ + +``Sample`` is just a ``Mapping[FieldName, FieldValue]``, where ``FieldName = str`` and +``FieldValue = Union[float, int, Sequence['FieldValue']``. It is easiest to use a ``dict`` +to represent a sample, but you can essentially use any object you like as long as it +implements ``Mapping[FieldName, FieldValue]`` (which can be ensured by subclassing from +this type). + +Dataset ++++++++ + +There are actually 2 classes for a dataset. ``Dataset`` is what you'd use normally. It +implements ``Sequence[Sample]`` so it requires all the samples to fit in the memory. To +create a ``Dataset`` object, pass an object of type ``Sequence[Sample]`` as an argument. + +.. code-block:: python + + >>> from text2array import Dataset + >>> samples = [ + ... {'ws': ['john', 'talks']}, + ... {'ws': ['john', 'loves', 'mary']}, + ... {'ws': ['mary']} + ... ] + >>> + >>> # Create a Dataset + >>> dataset = Dataset(samples) + >>> len(dataset) + 3 + >>> dataset[1] + {'ws': ['john', 'loves', 'mary']} + +If the samples can't fit in the memory, use ``StreamDataset`` instead. It implements +``Iterable[Sample]`` and streams the samples one by one, only when iterated over. To +instantiate, pass an ``Iterable[Sample]`` object. + +.. 
code-block:: python + + >>> from text2array import StreamDataset + >>> samples = [ + ... {'ws': ['john', 'talks']}, + ... {'ws': ['john', 'loves', 'mary']}, + ... {'ws': ['mary']} + ... ] + >>> class Stream: + ... def __init__(self, seq): + ... self.seq = seq + ... def __iter__(self): + ... return iter(self.seq) + ... + >>> dataset = StreamDataset(Stream(samples)) # simulate a stream of samples + >>> list(dataset) + [{'ws': ['john', 'talks']}, {'ws': ['john', 'loves', 'mary']}, {'ws': ['mary']}] + +Note that because ``StreamDataset`` is an iterable, you can't ask for its length nor access +by index, but it can be iterated over. + +Shuffling dataset +^^^^^^^^^^^^^^^^^ + +``StreamDataset`` cannot be shuffled because shuffling requires all the elements to be +accessible by index. So, only ``Dataset`` can be shuffled. There are 2 ways to shuffle. +First, using ``shuffle`` method, which shuffles the dataset randomly without any +constraints. Second, using ``shuffle_by`` which accepts a ``Callable[[Sample], int]`` +and use that to shuffle by performing a noisy sorting:: + +.. code-block:: python + + >>> from text2array import Dataset + >>> samples = [ + ... {'ws': ['john', 'talks']}, + ... {'ws': ['john', 'loves', 'mary']}, + ... {'ws': ['mary']} + ... ] + >>> dataset = Dataset(samples) + >>> dataset.shuffle_by(lambda s: len(s['ws'])) + +The example above shuffles the dataset but also tries to keep samples with similar lengths +closer. This is useful for NLP where we want to shuffle but also minimize padding in each +batch. If a very short sample ends up in the same batch as a very long one, there would be +a lot of wasted entries for padding. Sorting noisily by length can help mitigate this issue. +This approach is inspired by `AllenNLP `_. Note that +both ``shuffle`` and ``shuffle_by`` returns the dataset object itself so method chaining +is possible. + +Batching dataset +^^^^^^^^^^^^^^^^ + +To split up a dataset into batches, use the ``batch`` method, which takes the batch size +as an argument. :: + +.. code-block:: python + + >>> from text2array import Dataset + >>> samples = [ + ... {'ws': ['john', 'talks']}, + ... {'ws': ['john', 'loves', 'mary']}, + ... {'ws': ['mary']} + ... ] + >>> dataset = Dataset(samples) + >>> for batch in dataset.batch(2): + ... print('batch:', list(batch)) + ... + batch: [{'ws': ['john', 'talks']}, {'ws': ['john', 'loves', 'mary']}] + batch: [{'ws': ['mary']}] + +The method returns an ``Iterator[Batch]`` object so it can be iterated only once. If you want +the batches to have exactly the same size, i.e. dropping the last one if it's smaller than +batch size, use ``batch_exactly`` instead. The two methods are also available for +``StreamDataset``. Note that before batching, you might want to map all those strings +into integer IDs first, which is explained in the next section. + +Applying vocabulary +^^^^^^^^^^^^^^^^^^^ + +A vocabulary should implement ``Mapping[FieldName, Mapping[FieldValue, FieldValue]]``. +Then, call ``apply_vocab`` method with the vocabulary as an argument. This is best +explained with an example. :: + +.. code-block:: python + + >>> from pprint import pprint + >>> from text2array import Dataset + >>> samples = [ + ... {'ws': ['john', 'talks'], 'i': 10, 'label': 'pos'}, + ... {'ws': ['john', 'loves', 'mary'], 'i': 20, 'label': 'pos'}, + ... {'ws': ['mary'], 'i': 30, 'label': 'neg'} + ... ] + >>> dataset = Dataset(samples) + >>> vocab = { + ... 'ws': {'john': 0, 'talks': 1, 'loves': 2, 'mary': 3}, + ... 'i': {10: 5, 20: 10, 30: 15} + ... 
} + >>> dataset.apply_vocab(vocab) + >>> pprint(list(dataset)) + [{'i': 5, 'label': 'pos', 'ws': [0, 1]}, + {'i': 10, 'label': 'pos', 'ws': [0, 2, 3]}, + {'i': 15, 'label': 'neg', 'ws': [3]}] + +Note that the vocabulary is only applied to fields whose name is contained in the +vocabulary. Although not shown above, the vocabulary application still works even if +the field value is a deeply nested sequence. + +Vocabulary +++++++++++ + +Creating a vocabulary object from scratch is tedious. So, it's common to learn the vocabulary +from a dataset. The ``Vocab`` class can be used for this purpose. + +.. code-block:: python + + >>> samples = [ + ... {'ws': ['john', 'talks'], 'i': 10, 'label': 'pos'}, + ... {'ws': ['john', 'loves', 'mary'], 'i': 20, 'label': 'pos'}, + ... {'ws': ['mary'], 'i': 30, 'label': 'neg'} + ... ] + >>> vocab = Vocab.from_samples(samples) + >>> list(vocab.keys()) + ['ws', 'label'] + >>> dict(vocab['ws']) + {'': 0, '': 1, 'john': 2, 'mary': 3} + >>> dict(vocab['label']) + {'': 0, 'pos': 1} + >>> vocab['ws']['john'], vocab['ws']['talks'] + (2, 1) + +There are several things to note: + +#. Vocabularies are only created for fields which contain ``str`` values. +#. Words that occur only once are not included in the vocabulary. +#. Non-sequence fields do not have a padding token in the vocabulary. +#. Out-of-vocabulary words are assigned a single ID for unknown words. + +``Vocab.from_samples`` actually accepts an ``Iterable[Sample]``, which means a ``Dataset`` +or a ``StreamDataset`` can be passed as well. See the docstring to see other arguments +that it accepts to customize vocabulary creation. From 87cad0ce7d153c3554510f4e2f9d93251f888d7c Mon Sep 17 00:00:00 2001 From: Kemal Kurniawan Date: Thu, 31 Jan 2019 22:31:21 +0700 Subject: [PATCH 141/162] Finish tutorial stuff on readme --- README.rst | 53 +++++++++++++++++++++++++++++++++++++++++++++++------ 1 file changed, 47 insertions(+), 6 deletions(-) diff --git a/README.rst b/README.rst index 35592e4..eb92005 100644 --- a/README.rst +++ b/README.rst @@ -130,7 +130,7 @@ Shuffling dataset accessible by index. So, only ``Dataset`` can be shuffled. There are 2 ways to shuffle. First, using ``shuffle`` method, which shuffles the dataset randomly without any constraints. Second, using ``shuffle_by`` which accepts a ``Callable[[Sample], int]`` -and use that to shuffle by performing a noisy sorting:: +and use that to shuffle by performing a noisy sorting. .. code-block:: python @@ -155,7 +155,7 @@ Batching dataset ^^^^^^^^^^^^^^^^ To split up a dataset into batches, use the ``batch`` method, which takes the batch size -as an argument. :: +as an argument. .. code-block:: python @@ -183,7 +183,7 @@ Applying vocabulary A vocabulary should implement ``Mapping[FieldName, Mapping[FieldValue, FieldValue]]``. Then, call ``apply_vocab`` method with the vocabulary as an argument. This is best -explained with an example. :: +explained with an example. .. code-block:: python @@ -207,7 +207,8 @@ explained with an example. :: Note that the vocabulary is only applied to fields whose name is contained in the vocabulary. Although not shown above, the vocabulary application still works even if -the field value is a deeply nested sequence. +the field value is a deeply nested sequence. Note that ``apply_vocab`` is available +for ``StreamDataset`` as well. Vocabulary ++++++++++ @@ -223,8 +224,8 @@ from a dataset. The ``Vocab`` class can be used for this purpose. ... {'ws': ['mary'], 'i': 30, 'label': 'neg'} ... 
+Vocabulary
+++++++++++
+
+Creating a vocabulary object from scratch is tedious. So, it's common to learn the vocabulary
+from a dataset. The ``Vocab`` class can be used for this purpose.
+
+.. code-block:: python
+
+   >>> samples = [
+   ...     {'ws': ['john', 'talks'], 'i': 10, 'label': 'pos'},
+   ...     {'ws': ['john', 'loves', 'mary'], 'i': 20, 'label': 'pos'},
+   ...     {'ws': ['mary'], 'i': 30, 'label': 'neg'}
+   ... ]
+   >>> vocab = Vocab.from_samples(samples)
+   >>> list(vocab.keys())
+   ['ws', 'label']
+   >>> dict(vocab['ws'])
+   {'<pad>': 0, '<unk>': 1, 'john': 2, 'mary': 3}
+   >>> dict(vocab['label'])
+   {'<unk>': 0, 'pos': 1}
+   >>> vocab['ws']['john'], vocab['ws']['talks']
+   (2, 1)
+
+There are several things to note:
+
+#. Vocabularies are only created for fields which contain ``str`` values.
+#. Words that occur only once are not included in the vocabulary.
+#. Non-sequence fields do not have a padding token in the vocabulary.
+#. Out-of-vocabulary words are assigned a single ID for unknown words.
+
+``Vocab.from_samples`` actually accepts an ``Iterable[Sample]``, which means a ``Dataset``
+or a ``StreamDataset`` can be passed as well. See the docstring for other arguments
+that customize vocabulary creation.

From 87cad0ce7d153c3554510f4e2f9d93251f888d7c Mon Sep 17 00:00:00 2001
From: Kemal Kurniawan
Date: Thu, 31 Jan 2019 22:31:21 +0700
Subject: [PATCH 141/162] Finish tutorial stuff on readme

---
 README.rst | 53 +++++++++++++++++++++++++++++++++++++++++++++++------
 1 file changed, 47 insertions(+), 6 deletions(-)

diff --git a/README.rst b/README.rst
index 35592e4..eb92005 100644
--- a/README.rst
+++ b/README.rst
@@ -130,7 +130,7 @@ Shuffling dataset
 accessible by index. So, only ``Dataset`` can be shuffled. There are two ways to shuffle.
 First, using the ``shuffle`` method, which shuffles the dataset randomly without any
 constraints. Second, using ``shuffle_by``, which accepts a ``Callable[[Sample], int]``
-and uses it to shuffle by performing a noisy sort::
+and uses it to shuffle by performing a noisy sort.
 
 .. code-block:: python
 
@@ -155,7 +155,7 @@ Batching dataset
 
 To split up a dataset into batches, use the ``batch`` method, which takes the batch size
-as an argument. ::
+as an argument.
 
 .. code-block:: python
 
@@ -183,7 +183,7 @@ Applying vocabulary
 
 A vocabulary should implement ``Mapping[FieldName, Mapping[FieldValue, FieldValue]]``.
 Then, call the ``apply_vocab`` method with the vocabulary as an argument. This is best
-explained with an example. ::
+explained with an example.
 
 .. code-block:: python
 
@@ -207,7 +207,8 @@ explained with an example. ::
 Note that the vocabulary is only applied to fields whose name appears in the
 vocabulary. Although not shown above, the vocabulary application still works even if
-the field value is a deeply nested sequence.
+the field value is a deeply nested sequence. Note that ``apply_vocab`` is available
+for ``StreamDataset`` as well.
 
 Vocabulary
 ++++++++++
@@ -223,8 +224,8 @@ from a dataset. The ``Vocab`` class can be used for this purpose.
 ...     {'ws': ['mary'], 'i': 30, 'label': 'neg'}
 ... ]
 >>> vocab = Vocab.from_samples(samples)
- >>> list(vocab.keys())
- ['ws', 'label']
+ >>> vocab.keys()
+ dict_keys(['ws', 'label'])
 >>> dict(vocab['ws'])
 {'<pad>': 0, '<unk>': 1, 'john': 2, 'mary': 3}
 >>> dict(vocab['label'])
 {'<unk>': 0, 'pos': 1}
@@ -242,3 +243,43 @@ There are several things to note:
 ``Vocab.from_samples`` actually accepts an ``Iterable[Sample]``, which means a ``Dataset``
 or a ``StreamDataset`` can be passed as well. See the docstring for other arguments
 that customize vocabulary creation.
+
+Batch
++++++
+
+Both the ``batch`` and ``batch_exactly`` methods return ``Iterator[Batch]`` where ``Batch``
+implements ``Sequence[Sample]``. This is true even for ``StreamDataset``. So, although
+all samples may not all fit in the memory, a batch of them should. Given a ``Batch``
+object, it can be converted into a Numpy ndarray by the ``to_array`` method. Note that normally
+you'd want to apply the vocabulary beforehand to ensure all values contain only ints or floats.
+
+.. code-block:: python
+
+   >>> from text2array import Dataset
+   >>> samples = [
+   ...     {'ws': ['john', 'talks'], 'i': 10, 'label': 'pos'},
+   ...     {'ws': ['john', 'loves', 'mary'], 'i': 20, 'label': 'pos'},
+   ...     {'ws': ['mary'], 'i': 30, 'label': 'neg'}
+   ... ]
+   >>> dataset = Dataset(samples)
+   >>> vocab = Vocab.from_samples(dataset)
+   >>> dict(vocab['ws'])
+   {'<pad>': 0, '<unk>': 1, 'john': 2, 'mary': 3}
+   >>> dict(vocab['label'])
+   {'<unk>': 0, 'pos': 1}
+   >>> dataset.apply_vocab(vocab)
+   >>> batches = dataset.batch(2)
+   >>> batch = next(batches)
+   >>> arr = batch.to_array()
+   >>> arr.keys()
+   dict_keys(['ws', 'i', 'label'])
+   >>> arr['ws']
+   array([[2, 1, 0],
+          [2, 1, 3]])
+   >>> arr['i']
+   array([10, 20])
+   >>> arr['label']
+   array([1, 1])
+
+Note that ``to_array`` returns a ``Mapping[FieldName, np.ndarray]`` object, and sequential
+fields are automatically padded.
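Since ``Batch`` implements ``Sequence[Sample]``, ordinary sequence operations work on a batch as well; continuing the session in the diff above (the exact output shown here is illustrative):

.. code-block:: python

   >>> len(batch)
   2
   >>> batch[0]
   {'ws': [2, 1], 'i': 10, 'label': 1}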
From e192126c769352ad0d1c3c77a908db4ac539df82 Mon Sep 17 00:00:00 2001
From: Kemal Kurniawan
Date: Thu, 31 Jan 2019 22:44:59 +0700
Subject: [PATCH 142/162] Add str as FieldValue as well

---
 text2array/samples.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/text2array/samples.py b/text2array/samples.py
index 79f8fd1..7e6f010 100644
--- a/text2array/samples.py
+++ b/text2array/samples.py
@@ -3,5 +3,5 @@
 # TODO remove these "type ignore" once mypy supports recursive types
 # see: https://github.com/python/mypy/issues/731
 FieldName = str
-FieldValue = Union[float, int, Sequence['FieldValue']]  # type: ignore
+FieldValue = Union[float, int, str, Sequence['FieldValue']]  # type: ignore
 Sample = Mapping[FieldName, FieldValue]  # type: ignore
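To make the recursive ``FieldValue`` type concrete, values like these all qualify after this change (an illustrative sample of my own, not taken from the library's tests):

.. code-block:: python

   sample = {
       'i': 10,                          # int
       'x': 0.5,                         # float
       'w': 'john',                      # str, newly allowed by this patch
       'ws': ['john', 'talks'],          # sequence of str
       'cs': [['j', 'o'], ['t', 'a']],   # arbitrarily nested sequences
   }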
From 22f8b6ec6537788c0967fedebb892bd7d881f1ef Mon Sep 17 00:00:00 2001
From: Kemal Kurniawan
Date: Thu, 31 Jan 2019 22:45:13 +0700
Subject: [PATCH 143/162] Fix readme

---
 README.rst | 33 +++++++++++++++++----------------
 1 file changed, 17 insertions(+), 16 deletions(-)

diff --git a/README.rst b/README.rst
index eb92005..6d37c60 100644
--- a/README.rst
+++ b/README.rst
@@ -5,7 +5,7 @@ text2array
 
 **text2array** helps you process your NLP text dataset into Numpy ndarray objects that
 are ready to use for e.g. neural network inputs. **text2array** handles data shuffling,
-batching, padding, converting into arrays. Say goodbye to this tedious work!
+batching, padding, and converting into arrays. Say goodbye to this tedious work!
 
 Installation
 ------------
@@ -70,15 +70,15 @@ Sample
 ++++++
 
 ``Sample`` is just a ``Mapping[FieldName, FieldValue]``, where ``FieldName = str`` and
-``FieldValue = Union[float, int, Sequence['FieldValue']]``. It is easiest to use a ``dict``
-to represent a sample, but you can essentially use any object you like as long as it
-implements ``Mapping[FieldName, FieldValue]`` (which can be ensured by subclassing from
-this type).
+``FieldValue = Union[float, int, str, Sequence['FieldValue']]``. It is easiest to use a
+``dict`` to represent a sample, but you can essentially use any object you like as long
+as it implements ``Mapping[FieldName, FieldValue]`` (which can be ensured by subclassing
+from this type).
 
 Dataset
 +++++++
 
-There are actually two classes for a dataset. ``Dataset`` is what you'd use normally. It
+There are actually two classes for a dataset. ``Dataset`` is what you would normally use. It
 implements ``Sequence[Sample]`` so it requires all the samples to fit in the memory. To
 create a ``Dataset`` object, pass an object of type ``Sequence[Sample]`` as an argument.
@@ -120,7 +120,7 @@ instantiate, pass an ``Iterable[Sample]`` object.
 >>> list(dataset)
 [{'ws': ['john', 'talks']}, {'ws': ['john', 'loves', 'mary']}, {'ws': ['mary']}]
 
-Note that because ``StreamDataset`` is an iterable, you can't ask for its length or access
+Because ``StreamDataset`` is an iterable, you can't ask for its length or access
 it by index, but it can be iterated over.
 
 Shuffling dataset
@@ -147,9 +147,9 @@ The example above shuffles the dataset but also tries to keep samples with simil
 closer. This is useful for NLP where we want to shuffle but also minimize padding in each
 batch. If a very short sample ends up in the same batch as a very long one, there would be
 a lot of wasted entries for padding. Sorting noisily by length can help mitigate this issue.
-This approach is inspired by `AllenNLP <https://allennlp.org>`_. Note that
-both ``shuffle`` and ``shuffle_by`` return the dataset object itself, so method chaining
-is possible.
+This approach is inspired by `AllenNLP <https://allennlp.org>`_. Both
+``shuffle`` and ``shuffle_by`` return the dataset object itself, so method chaining
+is possible. See the docstring for more details.
 
 Batching dataset
 ^^^^^^^^^^^^^^^^
@@ -174,8 +174,8 @@ as an argument.
 
 The method returns an ``Iterator[Batch]`` object, so it can be iterated over only once. If you want
 the batches to have exactly the same size, i.e., dropping the last one if it's smaller than
-batch size, use ``batch_exactly`` instead. The two methods are also available for
-``StreamDataset``. Note that before batching, you might want to map all those strings
+the batch size, use ``batch_exactly`` instead. The two methods are also available for
+``StreamDataset``. Before batching, you might want to map all those strings
 into integer IDs first, which is explained in the next section.
 
 Applying vocabulary
@@ -207,7 +207,7 @@ explained with an example.
 
 Note that the vocabulary is only applied to fields whose name appears in the
 vocabulary. Although not shown above, the vocabulary application still works even if
-the field value is a deeply nested sequence. Note that ``apply_vocab`` is available
+the field value is a deeply nested sequence. The ``apply_vocab`` method is available
 for ``StreamDataset`` as well.
 
 Vocabulary
@@ -218,6 +218,7 @@ from a dataset. The ``Vocab`` class can be used for this purpose.
 
 .. code-block:: python
 
+   >>> from text2array import Vocab
    >>> samples = [
    ...     {'ws': ['john', 'talks'], 'i': 10, 'label': 'pos'},
    ...     {'ws': ['john', 'loves', 'mary'], 'i': 20, 'label': 'pos'},
@@ -249,13 +250,13 @@ Batch
 
 Both the ``batch`` and ``batch_exactly`` methods return ``Iterator[Batch]`` where ``Batch``
 implements ``Sequence[Sample]``. This is true even for ``StreamDataset``. So, although
-all samples may not all fit in the memory, a batch of them should. Given a ``Batch``
-object, it can be converted into a Numpy ndarray by the ``to_array`` method. Note that normally
+samples may not all fit in the memory, a batch of them should. Given a ``Batch``
+object, it can be converted into a Numpy ndarray by the ``to_array`` method. Normally,
 you'd want to apply the vocabulary beforehand to ensure all values contain only ints or floats.
 
 .. code-block:: python
 
-   >>> from text2array import Dataset
+   >>> from text2array import Dataset, Vocab
    >>> samples = [
    ...     {'ws': ['john', 'talks'], 'i': 10, 'label': 'pos'},
    ...     {'ws': ['john', 'loves', 'mary'], 'i': 20, 'label': 'pos'},

From 411110d06784a7627e4a708f16743242c021b30b Mon Sep 17 00:00:00 2001
From: Kemal Kurniawan
Date: Fri, 1 Feb 2019 06:26:57 +0700
Subject: [PATCH 144/162] Complete readme

---
 README.rst | 33 ++++++++++++++++++++++++++++-----
 1 file changed, 28 insertions(+), 5 deletions(-)

diff --git a/README.rst b/README.rst
index 6d37c60..6d99701 100644
--- a/README.rst
+++ b/README.rst
@@ -12,10 +12,10 @@ Installation
 
 **text2array** requires at least Python 3.6 and can be installed via pip::
 
-   pip install text2array
+   $ pip install text2array
 
-Usage overview
---------------
+Overview
+--------
 
 .. code-block:: python
 
@@ -63,8 +63,8 @@ Usage overview
    >>> arr['ws']
    array([[2, 1, 3]])
 
-Detailed tutorial
------------------
+Tutorial
+--------
 
 Sample
 ++++++
@@ -284,3 +284,26 @@ you'd want to apply the vocabulary beforehand to ensure all values contain only
 
 Note that ``to_array`` returns a ``Mapping[FieldName, np.ndarray]`` object, and sequential
 fields are automatically padded.
+
+Contributing
+------------
+
+Pull requests are welcome! To start contributing, make sure to install all the dependencies.
+
+::
+
+   $ pip install -r requirements.txt
+
+Next, set up the pre-commit hook.
+
+::
+
+   $ ln -s ../../pre-commit.sh .git/hooks/pre-commit
+
+Tests and the linter can be run with ``pytest`` and ``flake8`` respectively. The latter also
+runs ``mypy`` for type checking.
+
+License
+-------
+
+MIT

From 250c93812a2903afcaeb82ef37332a70113844f2 Mon Sep 17 00:00:00 2001
From: Kemal Kurniawan
Date: Fri, 1 Feb 2019 06:27:06 +0700
Subject: [PATCH 145/162] Add license

---
 LICENSE.txt | 21 +++++++++++++++++++++
 1 file changed, 21 insertions(+)
 create mode 100644 LICENSE.txt

diff --git a/LICENSE.txt b/LICENSE.txt
new file mode 100644
index 0000000..4bc1693
--- /dev/null
+++ b/LICENSE.txt
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2019 Kemal Kurniawan
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.

From 042eba75bce9fe987e0d8f7ff22640444875bd4e Mon Sep 17 00:00:00 2001
From: Kemal Kurniawan
Date: Fri, 1 Feb 2019 06:32:18 +0700
Subject: [PATCH 146/162] Add Makefile

---
 Makefile | 11 +++++++++++
 1 file changed, 11 insertions(+)
 create mode 100644 Makefile

diff --git a/Makefile b/Makefile
new file mode 100644
index 0000000..d7e169f
--- /dev/null
+++ b/Makefile
@@ -0,0 +1,11 @@
+.PHONY := build upload test-upload
+
+build:
+	python setup.py bdist_wheel
+	python setup.py sdist
+
+upload: build
+	twine upload --skip-existing dist/*
+
+test-upload: build
+	twine upload -r testpypi --skip-existing dist/*

From 91a650dda5e0fe272d2fb7932274318da59e47fa Mon Sep 17 00:00:00 2001
From: Kemal Kurniawan
Date: Fri, 1 Feb 2019 06:32:25 +0700
Subject: [PATCH 147/162] Add manifest file

---
 MANIFEST.in | 2 ++
 1 file changed, 2 insertions(+)
 create mode 100644 MANIFEST.in

diff --git a/MANIFEST.in b/MANIFEST.in
new file mode 100644
index 0000000..408f26f
--- /dev/null
+++ b/MANIFEST.in
@@ -0,0 +1,2 @@
+include README.rst
+include LICENSE.txt
\ No newline at end of file

From 87e00862dda206a1d5f700754ab2839eb90ff3bb Mon Sep 17 00:00:00 2001
From: Kemal Kurniawan
Date: Fri, 1 Feb 2019 06:46:51 +0700
Subject: [PATCH 148/162] Delete unnecessary manifest file

---
 MANIFEST.in | 2 --
 1 file changed, 2 deletions(-)
 delete mode 100644 MANIFEST.in

diff --git a/MANIFEST.in b/MANIFEST.in
deleted file mode 100644
index 408f26f..0000000
--- a/MANIFEST.in
+++ /dev/null
@@ -1,2 +0,0 @@
-include README.rst
-include LICENSE.txt
\ No newline at end of file

From b71ebd8f22087ccbb72ea71d5b899a4a8fa4165b Mon Sep 17 00:00:00 2001
From: Kemal Kurniawan
Date: Fri, 1 Feb 2019 06:47:07 +0700
Subject: [PATCH 149/162] Fix .PHONY target

---
 Makefile | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/Makefile b/Makefile
index d7e169f..652948b 100644
--- a/Makefile
+++ b/Makefile
@@ -1,4 +1,4 @@
-.PHONY := build upload test-upload
+.PHONY: build upload test-upload
 
 build:
 	python setup.py bdist_wheel

From 70c9aab7434ea7fa31478cab1ceeb42e9da9644d Mon Sep 17 00:00:00 2001
From: Kemal Kurniawan
Date: Fri, 1 Feb 2019 06:47:45 +0700
Subject: [PATCH 150/162] Remove todo

---
 setup.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/setup.py b/setup.py
index 518e758..106dc94 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,6 @@
 
 readme = Path(__file__).resolve().parent / 'README.rst'
 
-# TODO release v0.0.1
 setup(
     name='text2array',
     version='0.0.1',

From 0ddb42e90922805cf69062082aa74fbbe1b84c72 Mon Sep 17 00:00:00 2001
From: Kemal Kurniawan
Date: Fri, 1 Feb 2019 07:19:55 +0700
Subject: [PATCH 151/162] Improve examples on readme and add nested fields for padding

---
 README.rst | 60 ++++++++++++++++++++++++++++++++++++++++++++++++++----
 1 file changed, 56 insertions(+), 4 deletions(-)

diff --git a/README.rst b/README.rst
index 6d99701..790064e 100644
--- a/README.rst
+++ b/README.rst
@@ -39,8 +39,12 @@ Overview
    >>> list(vocab['ws'])
    ['<pad>', '<unk>', 'john', 'mary']
    >>> # 'talks' and 'loves' are out-of-vocabulary because they occur only once
+   >>> 'john' in vocab['ws']
+   True
    >>> vocab['ws']['john']
    2
+   >>> 'talks' in vocab['ws']
+   False
    >>> vocab['ws']['talks']  # unknown word is mapped to '<unk>'
    1
    >>>
@@ -56,12 +60,12 @@ Overview
    >>> batch = next(batches)
    >>> arr = batch.to_array()
    >>> arr['ws']
-   array([[2, 1],
-          [3, 0]])
+   array([[3, 0, 0],
+          [2, 1, 3]])
    >>> batch = next(batches)
    >>> arr = batch.to_array()
    >>> arr['ws']
-   array([[2, 1, 3]])
+   array([[2, 1]])
 
 Tutorial
 --------
@@ -231,6 +235,8 @@ from a dataset. The ``Vocab`` class can be used for this purpose.
    {'<pad>': 0, '<unk>': 1, 'john': 2, 'mary': 3}
    >>> dict(vocab['label'])
    {'<unk>': 0, 'pos': 1}
+   >>> 'john' in vocab['ws'], 'talks' in vocab['ws']
+   (True, False)
    >>> vocab['ws']['john'], vocab['ws']['talks']
    (2, 1)
 
@@ -283,7 +289,53 @@ you'd want to apply the vocabulary beforehand to ensure all values contain only
    array([1, 1])
 
 Note that ``to_array`` returns a ``Mapping[FieldName, np.ndarray]`` object, and sequential
-fields are automatically padded.
+fields are automatically padded. One of the nice things is that the field can be deeply
+nested and the padding just works!
+
+.. code-block::
+
+   >>> from pprint import pprint
+   >>> from text2array import Dataset, Vocab
+   >>> samples = [
+   ...     {'ws': ['john', 'talks'], 'cs': [list('john'), list('talks')]},
+   ...     {'ws': ['john', 'loves', 'mary'], 'cs': [list('john'), list('loves'), list('mary')]},
+   ...     {'ws': ['mary'], 'cs': [list('mary')]}
+   ... ]
+   >>> dataset = Dataset(samples)
+   >>> vocab = Vocab.from_samples(dataset)
+   >>> dataset.apply_vocab(vocab)
+   >>> dict(vocab['ws'])
+   {'<pad>': 0, '<unk>': 1, 'john': 2, 'mary': 3}
+   >>> pprint(dict(vocab['cs']))
+   {'<pad>': 0,
+    '<unk>': 1,
+    'a': 3,
+    'h': 5,
+    'j': 4,
+    'l': 7,
+    'm': 9,
+    'n': 6,
+    'o': 2,
+    'r': 10,
+    's': 8,
+    'y': 11}
+   >>> batches = dataset.batch(2)
+   >>> batch = next(batches)
+   >>> arr = batch.to_array()
+   >>> arr['ws']
+   array([[2, 1, 0],
+          [2, 1, 3]])
+   >>> arr['cs']
+   array([[[ 4,  2,  5,  6,  0],
+           [ 1,  3,  7,  1,  8],
+           [ 0,  0,  0,  0,  0]],
+
+          [[ 4,  2,  5,  6,  0],
+           [ 7,  2,  1,  1,  8],
+           [ 9,  3, 10, 11,  0]]])
+
+So, you can go crazy and have a field representing a document hierarchically as paragraphs,
+sentences, words, and characters, and it will be padded correctly.
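As a quick sanity check on the nested example above, the padded character field comes out as a regular 3-D array of shape ``(batch size, max words, max chars)``; continuing that session:

.. code-block:: python

   >>> arr['cs'].shape
   (2, 3, 5)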
From 87c72f8a580a0ebe061372c0a62db014316e483b Mon Sep 17 00:00:00 2001
From: Kemal Kurniawan
Date: Fri, 1 Feb 2019 07:22:31 +0700
Subject: [PATCH 152/162] Fix unnamed language code block

---
 README.rst | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/README.rst b/README.rst
index 790064e..27849fc 100644
--- a/README.rst
+++ b/README.rst
@@ -292,7 +292,7 @@ Note that ``to_array`` returns a ``Mapping[FieldName, np.ndarray]`` object, and
 fields are automatically padded. One of the nice things is that the field can be deeply
 nested and the padding just works!
 
-.. code-block::
+.. code-block:: python
 
    >>> from pprint import pprint
    >>> from text2array import Dataset, Vocab

From 93d22e2ee4d46187908f199abc26f7a77c5b75a8 Mon Sep 17 00:00:00 2001
From: Kemal Kurniawan
Date: Fri, 1 Feb 2019 07:26:45 +0700
Subject: [PATCH 153/162] Add todos

---
 setup.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/setup.py b/setup.py
index 106dc94..38a9287 100644
--- a/setup.py
+++ b/setup.py
@@ -3,6 +3,8 @@
 
 readme = Path(__file__).resolve().parent / 'README.rst'
 
+# TODO setup travis
+# TODO coveralls.io
 setup(
     name='text2array',
     version='0.0.1',

From 2a7aa07980bac2cfeff121ec985f15367828f324 Mon Sep 17 00:00:00 2001
From: Kemal Kurniawan
Date: Fri, 1 Feb 2019 07:27:41 +0700
Subject: [PATCH 154/162] Add spacemacs badge

---
 README.rst | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/README.rst b/README.rst
index 27849fc..aa35f80 100644
--- a/README.rst
+++ b/README.rst
@@ -3,6 +3,9 @@ text2array
 
 *Convert your NLP text dataset to arrays!*
 
+..
image:: https://cdn.rawgit.com/syl20bnr/spacemacs/442d025779da2f62fc86c2082703697714db6514/assets/spacemacs-badge.svg + :target: http://spacemacs.org + **text2array** helps you process your NLP text dataset into Numpy ndarray objects that are ready to use for e.g. neural network inputs. **text2array** handles data shuffling, batching, padding, and converting into arrays. Say goodbye to these tedious works! From a99af40ca9d62da7af76fd3f4af998de6f5e4cea Mon Sep 17 00:00:00 2001 From: Kemal Kurniawan Date: Fri, 1 Feb 2019 10:03:09 +0700 Subject: [PATCH 155/162] Add .travis.yml --- .travis.yml | 3 +++ 1 file changed, 3 insertions(+) create mode 100644 .travis.yml diff --git a/.travis.yml b/.travis.yml new file mode 100644 index 0000000..95a3334 --- /dev/null +++ b/.travis.yml @@ -0,0 +1,3 @@ +language: python +python: 3.6 +script: pytest From 72555d514497f0bf4f3afa047de0a47f74eee2bf Mon Sep 17 00:00:00 2001 From: Kemal Kurniawan Date: Fri, 1 Feb 2019 10:11:21 +0700 Subject: [PATCH 156/162] Add coveralls setup to travis --- .travis.yml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.travis.yml b/.travis.yml index 95a3334..720a41b 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,3 +1,6 @@ language: python python: 3.6 script: pytest +after_success: + - pip install python-coveralls + - coveralls From 51c8a130c308dda4e613975688993f42c9bf169e Mon Sep 17 00:00:00 2001 From: Kemal Kurniawan Date: Fri, 1 Feb 2019 10:43:48 +0700 Subject: [PATCH 157/162] Fix installing dependencies in travis --- .travis.yml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.travis.yml b/.travis.yml index 720a41b..9326068 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,5 +1,8 @@ language: python python: 3.6 +install: + - pip install -r requirements.txt + - pip install . script: pytest after_success: - pip install python-coveralls From 031432a922cacfb8b1d5f925734a99bbe2df402f Mon Sep 17 00:00:00 2001 From: Kemal Kurniawan Date: Fri, 1 Feb 2019 10:48:59 +0700 Subject: [PATCH 158/162] Try to fix no coverage data found warning --- .travis.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.travis.yml b/.travis.yml index 9326068..9521f9d 100644 --- a/.travis.yml +++ b/.travis.yml @@ -2,7 +2,7 @@ language: python python: 3.6 install: - pip install -r requirements.txt - - pip install . + - pip install -e . script: pytest after_success: - pip install python-coveralls From 7bbb3208e7223405d924411e18f6ec3c178f4594 Mon Sep 17 00:00:00 2001 From: Kemal Kurniawan Date: Fri, 1 Feb 2019 10:53:54 +0700 Subject: [PATCH 159/162] Try to fix KeyError when running coveralls --- .travis.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.travis.yml b/.travis.yml index 9521f9d..86994eb 100644 --- a/.travis.yml +++ b/.travis.yml @@ -2,6 +2,8 @@ language: python python: 3.6 install: - pip install -r requirements.txt + # https://github.com/z4r/python-coveralls/issues/61 + - pip install coverage==4.4.2 - pip install -e . script: pytest after_success: From 8cc0f85081b92097680bae8d8a560ca9a36b65e1 Mon Sep 17 00:00:00 2001 From: Kemal Kurniawan Date: Fri, 1 Feb 2019 11:00:40 +0700 Subject: [PATCH 160/162] Change to coveralls --- .travis.yml | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/.travis.yml b/.travis.yml index 86994eb..9292bdc 100644 --- a/.travis.yml +++ b/.travis.yml @@ -2,10 +2,8 @@ language: python python: 3.6 install: - pip install -r requirements.txt - # https://github.com/z4r/python-coveralls/issues/61 - - pip install coverage==4.4.2 - pip install -e . 
script: pytest after_success: - - pip install python-coveralls + - pip install coveralls - coveralls From b3c267c20f993f3c2455873be1a225f5ececf6af Mon Sep 17 00:00:00 2001 From: Kemal Kurniawan Date: Fri, 1 Feb 2019 11:08:56 +0700 Subject: [PATCH 161/162] Add travis and coveralls badges [skip ci] --- README.rst | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/README.rst b/README.rst index aa35f80..59cb5a2 100644 --- a/README.rst +++ b/README.rst @@ -3,6 +3,12 @@ text2array *Convert your NLP text dataset to arrays!* +.. image:: https://travis-ci.org/kmkurn/text2array.svg?branch=master + :target: https://travis-ci.org/kmkurn/text2array + +.. image:: https://coveralls.io/repos/github/kmkurn/text2array/badge.svg?branch=master + :target: https://coveralls.io/github/kmkurn/text2array?branch=master + .. image:: https://cdn.rawgit.com/syl20bnr/spacemacs/442d025779da2f62fc86c2082703697714db6514/assets/spacemacs-badge.svg :target: http://spacemacs.org From 4396407e599e7d95c448eb9f2a82d4472571ad63 Mon Sep 17 00:00:00 2001 From: Kemal Kurniawan Date: Fri, 1 Feb 2019 11:12:38 +0700 Subject: [PATCH 162/162] Remove todos [skip ci] --- setup.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/setup.py b/setup.py index 38a9287..106dc94 100644 --- a/setup.py +++ b/setup.py @@ -3,8 +3,6 @@ readme = Path(__file__).resolve().parent / 'README.rst' -# TODO setup travis -# TODO coveralls.io setup( name='text2array', version='0.0.1',