diff --git a/.coveragerc b/.coveragerc new file mode 100644 index 0000000..f77eabf --- /dev/null +++ b/.coveragerc @@ -0,0 +1,11 @@ +[run] +branch = True + +[report] +exclude_lines = + # Re-enable standard pragma + pragma: no cover + # abstract method + abstractmethod + # debugging stuff + __repr__ diff --git a/.flake8 b/.flake8 new file mode 100644 index 0000000..dfec9d6 --- /dev/null +++ b/.flake8 @@ -0,0 +1,5 @@ +# Config file for flake8 + +[flake8] +ignore = E,W # let yapf handle stylistic issues +show-source = True diff --git a/.style.yapf b/.style.yapf new file mode 100644 index 0000000..13b6369 --- /dev/null +++ b/.style.yapf @@ -0,0 +1,8 @@ +# Config file for YAPF Python formatter + +[style] +based_on_style = pep8 +coalesce_brackets = True +column_limit = 96 +split_before_first_argument = True +split_complex_comprehension = True diff --git a/.travis.yml b/.travis.yml new file mode 100644 index 0000000..9292bdc --- /dev/null +++ b/.travis.yml @@ -0,0 +1,9 @@ +language: python +python: 3.6 +install: + - pip install -r requirements.txt + - pip install -e . +script: pytest +after_success: + - pip install coveralls + - coveralls diff --git a/LICENSE.txt b/LICENSE.txt new file mode 100644 index 0000000..4bc1693 --- /dev/null +++ b/LICENSE.txt @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2019 Kemal Kurniawan + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..652948b --- /dev/null +++ b/Makefile @@ -0,0 +1,11 @@ +.PHONY: build upload test-upload + +build: + python setup.py bdist_wheel + python setup.py sdist + +upload: build + twine upload --skip-existing dist/* + +test-upload: build + twine upload -r testpypi --skip-existing dist/* diff --git a/README.rst b/README.rst index 908da3f..59cb5a2 100644 --- a/README.rst +++ b/README.rst @@ -1,4 +1,370 @@ -text2tensor -^^^^^^^^^^^ +text2array +========== -Convert your NLP text dataset to (batched) tensors! +*Convert your NLP text dataset to arrays!* + +.. image:: https://travis-ci.org/kmkurn/text2array.svg?branch=master + :target: https://travis-ci.org/kmkurn/text2array + +.. image:: https://coveralls.io/repos/github/kmkurn/text2array/badge.svg?branch=master + :target: https://coveralls.io/github/kmkurn/text2array?branch=master + +.. image:: https://cdn.rawgit.com/syl20bnr/spacemacs/442d025779da2f62fc86c2082703697714db6514/assets/spacemacs-badge.svg + :target: http://spacemacs.org + +**text2array** helps you process your NLP text dataset into Numpy ndarray objects that are +ready to use for e.g. 
neural network inputs. **text2array** handles data shuffling, +batching, padding, and converting into arrays. Say goodbye to this tedious work! + +Installation +------------ + +**text2array** requires at least Python 3.6 and can be installed via pip:: + + $ pip install text2array + +Overview +-------- + +.. code-block:: python + + >>> from text2array import Dataset, Vocab + >>> + >>> samples = [ + ... {'ws': ['john', 'talks']}, + ... {'ws': ['john', 'loves', 'mary']}, + ... {'ws': ['mary']} + ... ] + >>> + >>> # Create a Dataset + >>> dataset = Dataset(samples) + >>> len(dataset) + 3 + >>> dataset[1] + {'ws': ['john', 'loves', 'mary']} + >>> + >>> # Create a Vocab + >>> vocab = Vocab.from_samples(dataset) + >>> list(vocab['ws']) + ['<pad>', '<unk>', 'john', 'mary'] + >>> # 'talks' and 'loves' are out-of-vocabulary because they occur only once + >>> 'john' in vocab['ws'] + True + >>> vocab['ws']['john'] + 2 + >>> 'talks' in vocab['ws'] + False + >>> vocab['ws']['talks'] # unknown words are mapped to '<unk>' + 1 + >>> + >>> # Applying vocab to the dataset + >>> list(dataset) + [{'ws': ['john', 'talks']}, {'ws': ['john', 'loves', 'mary']}, {'ws': ['mary']}] + >>> dataset.apply_vocab(vocab) + >>> list(dataset) + [{'ws': [2, 1]}, {'ws': [2, 1, 3]}, {'ws': [3]}] + >>> + >>> # Shuffle, create batches of size 2, convert to array + >>> batches = dataset.shuffle().batch(2) + >>> batch = next(batches) + >>> arr = batch.to_array() + >>> arr['ws'] + array([[3, 0, 0], + [2, 1, 3]]) + >>> batch = next(batches) + >>> arr = batch.to_array() + >>> arr['ws'] + array([[2, 1]]) + +Tutorial +-------- + +Sample +++++++ + +``Sample`` is just a ``Mapping[FieldName, FieldValue]``, where ``FieldName = str`` and +``FieldValue = Union[float, int, str, Sequence['FieldValue']]``. It is easiest to use a +``dict`` to represent a sample, but you can essentially use any object you like as long +as it implements ``Mapping[FieldName, FieldValue]`` (which can be ensured by subclassing +from this type). + +Dataset ++++++++ + +There are two classes for a dataset. ``Dataset`` is what you would normally use. It +implements ``Sequence[Sample]`` so it requires all the samples to fit in memory. To +create a ``Dataset`` object, pass an object of type ``Sequence[Sample]`` as an argument. + +.. code-block:: python + + >>> from text2array import Dataset + >>> samples = [ + ... {'ws': ['john', 'talks']}, + ... {'ws': ['john', 'loves', 'mary']}, + ... {'ws': ['mary']} + ... ] + >>> + >>> # Create a Dataset + >>> dataset = Dataset(samples) + >>> len(dataset) + 3 + >>> dataset[1] + {'ws': ['john', 'loves', 'mary']} + +If the samples can't fit in memory, use ``StreamDataset`` instead. It implements +``Iterable[Sample]`` and streams the samples one by one, only when iterated over. To +instantiate, pass an ``Iterable[Sample]`` object. + +.. code-block:: python + + >>> from text2array import StreamDataset + >>> samples = [ + ... {'ws': ['john', 'talks']}, + ... {'ws': ['john', 'loves', 'mary']}, + ... {'ws': ['mary']} + ... ] + >>> class Stream: + ... def __init__(self, seq): + ... self.seq = seq + ... def __iter__(self): + ... return iter(self.seq) + ... + >>> dataset = StreamDataset(Stream(samples)) # simulate a stream of samples + >>> list(dataset) + [{'ws': ['john', 'talks']}, {'ws': ['john', 'loves', 'mary']}, {'ws': ['mary']}] + +Because ``StreamDataset`` is an iterable, you can't ask for its length or access its samples +by index, but you can iterate over it.
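For instance, a stream can be backed by a lazily read file so only one sample is in memory at a time. The sketch below assumes the samples are stored one JSON object per line; the ``JsonLinesStream`` class and the ``samples.jsonl`` file name are hypothetical and not part of **text2array** itself.

.. code-block:: python

    >>> import json
    >>> from text2array import StreamDataset
    >>> class JsonLinesStream:
    ...     def __init__(self, path):
    ...         self.path = path
    ...     def __iter__(self):
    ...         # re-opening the file on each iteration lets the dataset be iterated repeatedly
    ...         with open(self.path) as f:
    ...             for line in f:
    ...                 yield json.loads(line)  # one sample (a dict of fields) per line
    ...
    >>> dataset = StreamDataset(JsonLinesStream('samples.jsonl'))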
+ +Shuffling dataset +^^^^^^^^^^^^^^^^^ + +``StreamDataset`` cannot be shuffled because shuffling requires all the elements to be +accessible by index. So, only ``Dataset`` can be shuffled. There are two ways to shuffle. +First, using the ``shuffle`` method, which shuffles the dataset randomly without any +constraints. Second, using ``shuffle_by``, which accepts a ``Callable[[Sample], int]`` +and uses it to shuffle by performing a noisy sort. + +.. code-block:: python + + >>> from text2array import Dataset + >>> samples = [ + ... {'ws': ['john', 'talks']}, + ... {'ws': ['john', 'loves', 'mary']}, + ... {'ws': ['mary']} + ... ] + >>> dataset = Dataset(samples) + >>> dataset.shuffle_by(lambda s: len(s['ws'])) + +The example above shuffles the dataset but also tries to keep samples with similar lengths +closer. This is useful for NLP where we want to shuffle but also minimize padding in each +batch. If a very short sample ends up in the same batch as a very long one, there would be +a lot of wasted entries for padding. Sorting noisily by length can help mitigate this issue. +This approach is inspired by `AllenNLP `_. Both +``shuffle`` and ``shuffle_by`` return the dataset object itself, so method chaining +is possible. See the docstring for more details. + +Batching dataset +^^^^^^^^^^^^^^^^ + +To split up a dataset into batches, use the ``batch`` method, which takes the batch size +as an argument. + +.. code-block:: python + + >>> from text2array import Dataset + >>> samples = [ + ... {'ws': ['john', 'talks']}, + ... {'ws': ['john', 'loves', 'mary']}, + ... {'ws': ['mary']} + ... ] + >>> dataset = Dataset(samples) + >>> for batch in dataset.batch(2): + ... print('batch:', list(batch)) + ... + batch: [{'ws': ['john', 'talks']}, {'ws': ['john', 'loves', 'mary']}] + batch: [{'ws': ['mary']}] + +The method returns an ``Iterator[Batch]`` object so it can be iterated only once. If you want +the batches to have exactly the same size, i.e. dropping the last one if it's smaller than +the batch size, use ``batch_exactly`` instead. The two methods are also available for +``StreamDataset``. Before batching, you might want to map all those strings +into integer IDs first, which is explained in the next section. + +Applying vocabulary +^^^^^^^^^^^^^^^^^^^ + +A vocabulary should implement ``Mapping[FieldName, Mapping[FieldValue, FieldValue]]``. +Then, call the ``apply_vocab`` method with the vocabulary as an argument. This is best +explained with an example. + +.. code-block:: python + + >>> from pprint import pprint + >>> from text2array import Dataset + >>> samples = [ + ... {'ws': ['john', 'talks'], 'i': 10, 'label': 'pos'}, + ... {'ws': ['john', 'loves', 'mary'], 'i': 20, 'label': 'pos'}, + ... {'ws': ['mary'], 'i': 30, 'label': 'neg'} + ... ] + >>> dataset = Dataset(samples) + >>> vocab = { + ... 'ws': {'john': 0, 'talks': 1, 'loves': 2, 'mary': 3}, + ... 'i': {10: 5, 20: 10, 30: 15} + ... } + >>> dataset.apply_vocab(vocab) + >>> pprint(list(dataset)) + [{'i': 5, 'label': 'pos', 'ws': [0, 1]}, + {'i': 10, 'label': 'pos', 'ws': [0, 2, 3]}, + {'i': 15, 'label': 'neg', 'ws': [3]}] + +Note that the vocabulary is only applied to fields whose name is contained in the +vocabulary. Although not shown above, the vocabulary application still works even if +the field value is a deeply nested sequence. The ``apply_vocab`` method is available +for ``StreamDataset`` as well. + +Vocabulary +++++++++++ + +Creating a vocabulary object from scratch is tedious. So, it's common to learn the vocabulary +from a dataset.
The ``Vocab`` class can be used for this purpose. + +.. code-block:: python + + >>> from text2array import Vocab + >>> samples = [ + ... {'ws': ['john', 'talks'], 'i': 10, 'label': 'pos'}, + ... {'ws': ['john', 'loves', 'mary'], 'i': 20, 'label': 'pos'}, + ... {'ws': ['mary'], 'i': 30, 'label': 'neg'} + ... ] + >>> vocab = Vocab.from_samples(samples) + >>> vocab.keys() + dict_keys(['ws', 'label']) + >>> dict(vocab['ws']) + {'<pad>': 0, '<unk>': 1, 'john': 2, 'mary': 3} + >>> dict(vocab['label']) + {'<unk>': 0, 'pos': 1} + >>> 'john' in vocab['ws'], 'talks' in vocab['ws'] + (True, False) + >>> vocab['ws']['john'], vocab['ws']['talks'] + (2, 1) + +There are several things to note: + +#. Vocabularies are only created for fields which contain ``str`` values. +#. Words that occur only once are not included in the vocabulary. +#. Non-sequence fields do not have a padding token in the vocabulary. +#. Out-of-vocabulary words are all mapped to a single unknown-word ID. + +``Vocab.from_samples`` actually accepts an ``Iterable[Sample]``, which means a ``Dataset`` +or a ``StreamDataset`` can be passed as well. See the docstring for the other arguments +that it accepts to customize vocabulary creation. + +Batch ++++++ + +Both the ``batch`` and ``batch_exactly`` methods return ``Iterator[Batch]`` where ``Batch`` +implements ``Sequence[Sample]``. This is true even for ``StreamDataset``. So, although +samples may not all fit in memory, a batch of them should. Given a ``Batch`` +object, it can be converted into Numpy ndarrays with the ``to_array`` method. Normally, +you'd want to apply the vocabulary beforehand to ensure all values contain only ints or floats. + +.. code-block:: python + + >>> from text2array import Dataset, Vocab + >>> samples = [ + ... {'ws': ['john', 'talks'], 'i': 10, 'label': 'pos'}, + ... {'ws': ['john', 'loves', 'mary'], 'i': 20, 'label': 'pos'}, + ... {'ws': ['mary'], 'i': 30, 'label': 'neg'} + ... ] + >>> dataset = Dataset(samples) + >>> vocab = Vocab.from_samples(dataset) + >>> dict(vocab['ws']) + {'<pad>': 0, '<unk>': 1, 'john': 2, 'mary': 3} + >>> dict(vocab['label']) + {'<unk>': 0, 'pos': 1} + >>> dataset.apply_vocab(vocab) + >>> batches = dataset.batch(2) + >>> batch = next(batches) + >>> arr = batch.to_array() + >>> arr.keys() + dict_keys(['ws', 'i', 'label']) + >>> arr['ws'] + array([[2, 1, 0], + [2, 1, 3]]) + >>> arr['i'] + array([10, 20]) + >>> arr['label'] + array([1, 1]) + +Note that ``to_array`` returns a ``Mapping[FieldName, np.ndarray]`` object, and sequential +fields are automatically padded. One of the nice things is that the field can be deeply +nested and the padding just works! + +.. code-block:: python + + >>> from pprint import pprint + >>> from text2array import Dataset, Vocab + >>> samples = [ + ... {'ws': ['john', 'talks'], 'cs': [list('john'), list('talks')]}, + ... {'ws': ['john', 'loves', 'mary'], 'cs': [list('john'), list('loves'), list('mary')]}, + ... {'ws': ['mary'], 'cs': [list('mary')]} + ...
] + >>> dataset = Dataset(samples) + >>> vocab = Vocab.from_samples(dataset) + >>> dataset.apply_vocab(vocab) + >>> dict(vocab['ws']) + {'<pad>': 0, '<unk>': 1, 'john': 2, 'mary': 3} + >>> pprint(dict(vocab['cs'])) + {'<pad>': 0, + '<unk>': 1, + 'a': 3, + 'h': 5, + 'j': 4, + 'l': 7, + 'm': 9, + 'n': 6, + 'o': 2, + 'r': 10, + 's': 8, + 'y': 11} + >>> batches = dataset.batch(2) + >>> batch = next(batches) + >>> arr = batch.to_array() + >>> arr['ws'] + array([[2, 1, 0], + [2, 1, 3]]) + >>> arr['cs'] + array([[[ 4, 2, 5, 6, 0], + [ 1, 3, 7, 1, 8], + [ 0, 0, 0, 0, 0]], + + [[ 4, 2, 5, 6, 0], + [ 7, 2, 1, 1, 8], + [ 9, 3, 10, 11, 0]]]) + +So, you can go crazy and have a field representing a document hierarchically as paragraphs, +sentences, words, and characters, and it will be padded correctly. + +Contributing +------------ + +Pull requests are welcome! To start contributing, make sure to install all the dependencies. + +:: + + $ pip install -r requirements.txt + +Next, set up the pre-commit hook. + +:: + + $ ln -s ../../pre-commit.sh .git/hooks/pre-commit + +Tests and the linter can be run with ``pytest`` and ``flake8`` respectively. The latter also +runs ``mypy`` for type checking. + +License +------- + +MIT diff --git a/pytest.ini b/pytest.ini new file mode 100644 index 0000000..43e269b --- /dev/null +++ b/pytest.ini @@ -0,0 +1,2 @@ +[pytest] +addopts = --cov text2array --cov-report term-missing --cov-report html diff --git a/requirements.txt b/requirements.txt index f0fecb0..48581a9 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,11 +1,20 @@ +atomicwrites==1.2.1 attrs==18.2.0 certifi==2018.11.29 +coverage==4.5.2 flake8==3.6.0 flake8-mypy==17.8.0 mccabe==0.6.1 +more-itertools==5.0.0 mypy==0.660 mypy-extensions==0.4.1 +numpy==1.16.0 +pluggy==0.8.1 +py==1.7.0 pycodestyle==2.4.0 pyflakes==2.0.0 +pytest==4.1.1 +pytest-cov==2.6.1 +six==1.12.0 typed-ast==1.2.0 yapf==0.25.0 diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..106dc94 --- /dev/null +++ b/setup.py @@ -0,0 +1,27 @@ +from pathlib import Path +from setuptools import setup, find_packages + +readme = Path(__file__).resolve().parent / 'README.rst' + +setup( + name='text2array', + version='0.0.1', + description='Convert your NLP text data to arrays!', + long_description=readme.read_text(), + url='https://github.com/kmkurn/text2array', + author='Kemal Kurniawan', + author_email='kemal@kkurniawan.com', + license='MIT', + classifiers=[ + 'Development Status :: 3 - Alpha', + 'Intended Audience :: Developers', + 'Intended Audience :: Science/Research', + 'License :: OSI Approved :: MIT License', + 'Programming Language :: Python :: 3.6', + ], + packages=find_packages(), + install_requires=[ + 'numpy ~=1.16.0', + ], + python_requires='>=3.6, <4', +) diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..9465082 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,42 @@ +import pytest + +from text2array import Dataset, StreamDataset + + +@pytest.fixture +def setup_rng(): + import random + random.seed(42) + + +@pytest.fixture +def samples(): + return [{'i': i, 'f': (i + 1) / 3} for i in range(5)] + + +@pytest.fixture +def dataset(samples): + return Dataset(samples) + + +@pytest.fixture +def stream(samples): + return Stream(samples) + + +@pytest.fixture +def stream_dataset(stream): + return StreamDataset(stream) + + +@pytest.fixture +def stream_cls(): + return Stream + + +class Stream: + def __init__(self, samples): + self.samples = samples + + def __iter__(self): + yield from self.samples diff --git
a/tests/test_batch.py b/tests/test_batch.py new file mode 100644 index 0000000..f7e333c --- /dev/null +++ b/tests/test_batch.py @@ -0,0 +1,232 @@ +from typing import Mapping, Sequence + +import numpy as np +import pytest + +from text2array import Batch + + +def test_init(samples): + b = Batch(samples) + assert isinstance(b, Sequence) + assert len(b) == len(samples) + for i in range(len(b)): + assert b[i] == samples[i] + + +class TestToArray: + def test_ok(self): + ss = [{ + 'i': 4, + 'f': 0.67 + }, { + 'i': 2, + 'f': 0.89 + }, { + 'i': 3, + 'f': 0.23 + }, { + 'i': 5, + 'f': 0.11 + }, { + 'i': 3, + 'f': 0.22 + }] + b = Batch(ss) + arr = b.to_array() + assert isinstance(arr, Mapping) + assert len(arr) == 2 + assert set(arr) == set(['i', 'f']) + + assert isinstance(arr['i'], np.ndarray) + assert arr['i'].shape == (len(b), ) + assert arr['i'].tolist() == [s['i'] for s in b] + + assert isinstance(arr['f'], np.ndarray) + assert arr['f'].shape == (len(b), ) + assert arr['f'].tolist() == [pytest.approx(s['f']) for s in b] + + def test_empty(self): + b = Batch([]) + assert not b.to_array() + + def test_seq(self): + ss = [{'is': [1, 2]}, {'is': [1]}, {'is': [1, 2, 3]}, {'is': [1, 2]}] + b = Batch(ss) + arr = b.to_array() + + assert isinstance(arr['is'], np.ndarray) + assert arr['is'].tolist() == [[1, 2, 0], [1, 0, 0], [1, 2, 3], [1, 2, 0]] + + def test_seq_of_seq(self): + ss = [ + { + 'iss': [ + [1], + ] + }, + { + 'iss': [ + [1], + [1, 2], + ] + }, + { + 'iss': [ + [1], + [1, 2, 3], + [1, 2], + ] + }, + { + 'iss': [ + [1], + [1, 2], + [1, 2, 3], + [1], + ] + }, + { + 'iss': [ + [1], + [1, 2], + [1, 2, 3], + ] + }, + ] + b = Batch(ss) + arr = b.to_array() + + assert isinstance(arr['iss'], np.ndarray) + assert arr['iss'].shape == (5, 4, 3) + assert arr['iss'][0].tolist() == [ + [1, 0, 0], + [0, 0, 0], + [0, 0, 0], + [0, 0, 0], + ] + assert arr['iss'][1].tolist() == [ + [1, 0, 0], + [1, 2, 0], + [0, 0, 0], + [0, 0, 0], + ] + assert arr['iss'][2].tolist() == [ + [1, 0, 0], + [1, 2, 3], + [1, 2, 0], + [0, 0, 0], + ] + assert arr['iss'][3].tolist() == [ + [1, 0, 0], + [1, 2, 0], + [1, 2, 3], + [1, 0, 0], + ] + assert arr['iss'][4].tolist() == [ + [1, 0, 0], + [1, 2, 0], + [1, 2, 3], + [0, 0, 0], + ] + + def test_seq_of_seq_of_seq(self): + ss = [ + { + 'isss': [ + [[1], [1, 2]], + [[1], [1, 2], [1, 2]], + ] + }, + { + 'isss': [ + [[1, 2], [1]], + ] + }, + { + 'isss': [ + [[1, 2], [1, 2]], + [[1], [1, 2], [1]], + [[1, 2], [1]], + ] + }, + { + 'isss': [ + [[1]], + [[1], [1, 2]], + [[1], [1, 2]], + ] + }, + { + 'isss': [ + [[1]], + [[1], [1, 2]], + [[1], [1, 2], [1, 2]], + [[1], [1, 2], [1, 2]], + ] + }, + ] + b = Batch(ss) + arr = b.to_array() + + assert isinstance(arr['isss'], np.ndarray) + assert arr['isss'].shape == (5, 4, 3, 2) + assert arr['isss'][0].tolist() == [ + [[1, 0], [1, 2], [0, 0]], + [[1, 0], [1, 2], [1, 2]], + [[0, 0], [0, 0], [0, 0]], + [[0, 0], [0, 0], [0, 0]], + ] + assert arr['isss'][1].tolist() == [ + [[1, 2], [1, 0], [0, 0]], + [[0, 0], [0, 0], [0, 0]], + [[0, 0], [0, 0], [0, 0]], + [[0, 0], [0, 0], [0, 0]], + ] + assert arr['isss'][2].tolist() == [ + [[1, 2], [1, 2], [0, 0]], + [[1, 0], [1, 2], [1, 0]], + [[1, 2], [1, 0], [0, 0]], + [[0, 0], [0, 0], [0, 0]], + ] + assert arr['isss'][3].tolist() == [ + [[1, 0], [0, 0], [0, 0]], + [[1, 0], [1, 2], [0, 0]], + [[1, 0], [1, 2], [0, 0]], + [[0, 0], [0, 0], [0, 0]], + ] + assert arr['isss'][4].tolist() == [ + [[1, 0], [0, 0], [0, 0]], + [[1, 0], [1, 2], [0, 0]], + [[1, 0], [1, 2], [1, 2]], + [[1, 0], [1, 2], [1, 2]], + ] + + def 
test_custom_padding(self): + ss = [{'is': [1]}, {'is': [1, 2]}] + b = Batch(ss) + arr = b.to_array(pad_with=9) + assert arr['is'].tolist() == [[1, 9], [1, 2]] + + ss = [{'iss': [[1, 2], [1]]}, {'iss': [[1]]}] + b = Batch(ss) + arr = b.to_array(pad_with=9) + assert arr['iss'].tolist() == [[[1, 2], [1, 9]], [[1, 9], [9, 9]]] + + def test_missing_field(self): + b = Batch([{'a': 10}, {'b': 20}]) + with pytest.raises(KeyError) as exc: + b.to_array() + assert "some samples have no field 'a'" in str(exc.value) + + def test_inconsistent_depth(self): + b = Batch([{'ws': [1, 2]}, {'ws': [[1, 2], [3, 4]]}]) + with pytest.raises(ValueError) as exc: + b.to_array() + assert "field 'ws' has inconsistent nesting depth" in str(exc.value) + + def test_str(self): + b = Batch([{'w': 'a'}, {'w': 'b'}]) + arr = b.to_array() + assert isinstance(arr['w'], np.ndarray) + assert arr['w'].tolist() == ['a', 'b'] diff --git a/tests/test_dataset.py b/tests/test_dataset.py new file mode 100644 index 0000000..c3ecb62 --- /dev/null +++ b/tests/test_dataset.py @@ -0,0 +1,197 @@ +from typing import Iterator, Sequence + +import pytest + +from text2array import Batch, Dataset, Vocab + + +def test_init(samples): + dat = Dataset(samples) + assert isinstance(dat, Sequence) + assert len(dat) == len(samples) + for i in range(len(dat)): + assert dat[i] == samples[i] + + +def test_init_samples_non_sequence(): + with pytest.raises(TypeError) as exc: + Dataset(10) + assert '"samples" is not a sequence' in str(exc.value) + + +class TestShuffle: + def test_mutable_seq(self, setup_rng, dataset): + before = list(dataset) + retval = dataset.shuffle() + after = list(dataset) + assert retval is dataset + assert_shuffled(before, after) + + def test_immutable_seq(self, setup_rng, samples): + dat = Dataset(tuple(samples)) + before = list(dat) + retval = dat.shuffle() + after = list(dat) + assert retval is dat + assert_shuffled(before, after) + + +class TestShuffleBy: + dataset = Dataset([{'i': 3}, {'i': 1}, {'i': 2}, {'i': 5}, {'i': 4}]) + + @staticmethod + def key(sample): + return sample['i'] + + def test_ok(self, setup_rng): + dat = self.dataset + before = list(dat) + retval = dat.shuffle_by(self.key) + after = list(dat) + assert retval is dat + assert_shuffled(before, after) + + def test_zero_scale(self, setup_rng): + dat = self.dataset + before = list(dat) + dat.shuffle_by(self.key, scale=0.) 
+ after = list(dat) + assert sorted(before, key=self.key) == after + + def test_negative_scale(self, setup_rng): + with pytest.raises(ValueError) as exc: + self.dataset.shuffle_by(self.key, scale=-1) + assert 'scale cannot be less than 0' in str(exc.value) + + +def assert_shuffled(before, after): + assert before != after and len(before) == len(after) and all(x in after for x in before) + + +def test_batch(): + dat = Dataset([{'i': 3}, {'i': 1}, {'i': 2}, {'i': 5}, {'i': 4}]) + bs = dat.batch(2) + assert isinstance(bs, Iterator) + bs_lst = list(bs) + assert len(bs_lst) == 3 + assert all(isinstance(b, Batch) for b in bs_lst) + assert list(bs_lst[0]) == [dat[0], dat[1]] + assert list(bs_lst[1]) == [dat[2], dat[3]] + assert list(bs_lst[2]) == [dat[4]] + + +def test_batch_size_evenly_divides(dataset): + bs = dataset.batch(1) + bs_lst = list(bs) + assert len(bs_lst) == len(dataset) + for i in range(len(bs_lst)): + assert list(bs_lst[i]) == [dataset[i]] + + +def test_batch_exactly(): + dat = Dataset([{'i': 3}, {'i': 1}, {'i': 2}, {'i': 5}, {'i': 4}]) + bs = dat.batch_exactly(2) + assert isinstance(bs, Iterator) + bs_lst = list(bs) + assert len(bs_lst) == 2 + assert all(isinstance(b, Batch) for b in bs_lst) + assert list(bs_lst[0]) == [dat[0], dat[1]] + assert list(bs_lst[1]) == [dat[2], dat[3]] + + +def test_batch_nonpositive_batch_size(dataset): + with pytest.raises(ValueError) as exc: + next(dataset.batch(0)) + assert 'batch size must be greater than 0' in str(exc.value) + + with pytest.raises(ValueError) as exc: + next(dataset.batch_exactly(0)) + assert 'batch size must be greater than 0' in str(exc.value) + + +class TestApplyVocab: + def test_ok(self): + dat = Dataset([{ + 'w': 'a', + 'ws': ['a', 'b'], + 'cs': [['a', 'b'], ['b', 'a']], + 'i': 10, + 'j': 20 + }, { + 'w': 'b', + 'ws': ['a', 'a'], + 'cs': [['b', 'b'], ['b', 'a']], + 'i': 10, + 'j': 20 + }]) + vocab = { + 'w': { + 'a': 0, + 'b': 1 + }, + 'ws': { + 'a': 2, + 'b': 3 + }, + 'cs': { + 'a': 4, + 'b': 5 + }, + 'j': { + 20: 2 + } + } + dat.apply_vocab(vocab) + assert list(dat) == [{ + 'w': 0, + 'ws': [2, 3], + 'cs': [[4, 5], [5, 4]], + 'i': 10, + 'j': 2 + }, { + 'w': 1, + 'ws': [2, 2], + 'cs': [[5, 5], [5, 4]], + 'i': 10, + 'j': 2 + }] + + def test_value_not_in_vocab(self): + dat = Dataset([{'w': 'a'}]) + vocab = {'w': {'b': 0}} + with pytest.raises(KeyError) as exc: + dat.apply_vocab(vocab) + assert "value 'a' not found in vocab" in str(exc.value) + + dat = Dataset([{'w': 10}]) + vocab = {'w': {11: 0}} + with pytest.raises(KeyError) as exc: + dat.apply_vocab(vocab) + assert "value 10 not found in vocab" in str(exc.value) + + def test_with_vocab_object(self): + dat = Dataset([{ + 'ws': ['a', 'b'], + 'cs': [['a', 'c'], ['c', 'b', 'c']] + }, { + 'ws': ['b'], + 'cs': [['b']] + }]) + v = Vocab.from_samples(dat) + dat.apply_vocab(v) + assert list(dat) == [{ + 'ws': [v['ws']['a'], v['ws']['b']], + 'cs': [[v['cs']['a'], v['cs']['c']], [v['cs']['c'], v['cs']['b'], v['cs']['c']]] + }, { + 'ws': [v['ws']['b']], + 'cs': [[v['cs']['b']]] + }] + + def test_immutable_seq(self, samples): + ss = samples + lstdat = Dataset(ss) + tpldat = Dataset(tuple(ss)) + v = Vocab.from_samples(ss) + lstdat.apply_vocab(v) + tpldat.apply_vocab(v) + assert list(lstdat) == list(tpldat) diff --git a/tests/test_stream_dataset.py b/tests/test_stream_dataset.py new file mode 100644 index 0000000..5da076b --- /dev/null +++ b/tests/test_stream_dataset.py @@ -0,0 +1,151 @@ +from typing import Iterable, Iterator + +import pytest + +from text2array import Batch, StreamDataset, 
Vocab + + +def test_init(stream): + dat = StreamDataset(stream) + assert isinstance(dat, Iterable) + assert list(dat) == list(stream) + + +def test_init_stream_non_iterable(): + with pytest.raises(TypeError) as exc: + StreamDataset(5) + assert '"stream" is not iterable' in str(exc.value) + + +def test_can_be_iterated_twice(stream_dataset): + dat_lst1 = list(stream_dataset) + dat_lst2 = list(stream_dataset) + assert len(dat_lst1) == len(dat_lst2) + assert len(dat_lst2) > 0 + + +def test_batch(stream_cls): + stream_dat = StreamDataset(stream_cls([{'i': 3}, {'i': 1}, {'i': 2}, {'i': 5}, {'i': 4}])) + bs = stream_dat.batch(2) + assert isinstance(bs, Iterator) + bs_lst = list(bs) + assert len(bs_lst) == 3 + assert all(isinstance(b, Batch) for b in bs_lst) + dat = list(stream_dat) + assert list(bs_lst[0]) == [dat[0], dat[1]] + assert list(bs_lst[1]) == [dat[2], dat[3]] + assert list(bs_lst[2]) == [dat[4]] + + +def test_batch_size_evenly_divides(stream_dataset): + bs = stream_dataset.batch(1) + dat = list(stream_dataset) + bs_lst = list(bs) + assert len(bs_lst) == len(dat) + for i in range(len(bs_lst)): + assert list(bs_lst[i]) == [dat[i]] + + +def test_batch_exactly(stream_cls): + stream_dat = StreamDataset(stream_cls([{'i': 3}, {'i': 1}, {'i': 2}, {'i': 5}, {'i': 4}])) + bs = stream_dat.batch_exactly(2) + assert isinstance(bs, Iterator) + bs_lst = list(bs) + assert len(bs_lst) == 2 + assert all(isinstance(b, Batch) for b in bs_lst) + dat = list(stream_dat) + assert list(bs_lst[0]) == [dat[0], dat[1]] + assert list(bs_lst[1]) == [dat[2], dat[3]] + + +def test_batch_nonpositive_batch_size(stream_dataset): + with pytest.raises(ValueError) as exc: + next(stream_dataset.batch(0)) + assert 'batch size must be greater than 0' in str(exc.value) + + with pytest.raises(ValueError) as exc: + next(stream_dataset.batch_exactly(0)) + assert 'batch size must be greater than 0' in str(exc.value) + + +class TestApplyVocab: + def test_ok(self, stream_cls): + dat = StreamDataset( + stream_cls([{ + 'w': 'a', + 'ws': ['a', 'b'], + 'cs': [['a', 'b'], ['b', 'a']], + 'i': 10, + 'j': 20 + }, { + 'w': 'b', + 'ws': ['a', 'a'], + 'cs': [['b', 'b'], ['b', 'a']], + 'i': 10, + 'j': 20 + }])) + vocab = { + 'w': { + 'a': 0, + 'b': 1 + }, + 'ws': { + 'a': 2, + 'b': 3 + }, + 'cs': { + 'a': 4, + 'b': 5 + }, + 'j': { + 20: 2 + } + } + dat.apply_vocab(vocab) + assert list(dat) == [{ + 'w': 0, + 'ws': [2, 3], + 'cs': [[4, 5], [5, 4]], + 'i': 10, + 'j': 2 + }, { + 'w': 1, + 'ws': [2, 2], + 'cs': [[5, 5], [5, 4]], + 'i': 10, + 'j': 2 + }] + + def test_key_error(self, stream_cls): + dat = StreamDataset(stream_cls([{'w': 'a'}])) + vocab = {'w': {'b': 0}} + dat.apply_vocab(vocab) + with pytest.raises(KeyError) as exc: + list(dat) + assert "value 'a' not found in vocab" in str(exc.value) + + dat = StreamDataset(stream_cls([{'w': 10}])) + vocab = {'w': {11: 0}} + dat.apply_vocab(vocab) + with pytest.raises(KeyError) as exc: + list(dat) + assert "value 10 not found in vocab" in str(exc.value) + + def test_with_vocab_object(self, stream_cls): + dat = StreamDataset( + stream_cls([{ + 'ws': ['a', 'b'], + 'cs': [['a', 'c'], ['c', 'b', 'c']] + }, { + 'ws': ['b'], + 'cs': [['b']] + }])) + v = Vocab.from_samples(dat) + dat.apply_vocab(v) + assert list(dat) == [{ + 'ws': [v['ws']['a'], v['ws']['b']], + 'cs': [[v['cs']['a'], v['cs']['c']], [v['cs']['c'], v['cs']['b'], v['cs']['c']]] + }, { + 'ws': [v['ws']['b']], + 'cs': [[v['cs']['b']]] + }] diff --git a/tests/test_vocab.py b/tests/test_vocab.py new file mode 100644 index 0000000..71310bb --- 
/dev/null +++ b/tests/test_vocab.py @@ -0,0 +1,144 @@ +from typing import Mapping + +import pytest + +from text2array import Vocab + + +class TestFromSamples: + def test_ok(self): + ss = [{'w': 'c'}, {'w': 'b'}, {'w': 'a'}, {'w': 'b'}, {'w': 'c'}, {'w': 'c'}] + vocab = Vocab.from_samples(ss) + + assert isinstance(vocab, Mapping) + assert len(vocab) == 1 + assert list(vocab) == ['w'] + with pytest.raises(KeyError): + vocab['ws'] + + itos = '<unk> c b'.split() + assert isinstance(vocab['w'], Mapping) + assert len(vocab['w']) == len(itos) + assert list(vocab['w']) == itos + for i, w in enumerate(itos): + assert w in vocab['w'] + assert vocab['w'][w] == i + + assert 'foo' not in vocab['w'] + assert vocab['w']['foo'] == vocab['w']['<unk>'] + assert 'bar' not in vocab['w'] + assert vocab['w']['bar'] == vocab['w']['<unk>'] + + def test_has_vocab_for_all_str_fields(self): + ss = [{'w': 'b', 't': 'b'}, {'w': 'b', 't': 'b'}] + vocab = Vocab.from_samples(ss) + assert vocab.get('w') is not None + assert vocab.get('t') is not None + + def test_no_vocab_for_non_str(self): + vocab = Vocab.from_samples([{'i': 10}, {'i': 20}]) + with pytest.raises(KeyError) as exc: + vocab['i'] + assert "no vocabulary found for field name 'i'" in str(exc.value) + + def test_seq(self): + ss = [{'ws': ['a', 'c', 'c']}, {'ws': ['b', 'c']}, {'ws': ['b']}] + vocab = Vocab.from_samples(ss) + assert list(vocab['ws']) == '<pad> <unk> c b'.split() + + def test_seq_of_seq(self): + ss = [{ + 'cs': [['c', 'd'], ['a', 'd']] + }, { + 'cs': [['c'], ['b'], ['b', 'd']] + }, { + 'cs': [['d', 'c']] + }] + vocab = Vocab.from_samples(ss) + assert list(vocab['cs']) == '<pad> <unk> d c b'.split() + + def test_empty_samples(self): + vocab = Vocab.from_samples([]) + assert not vocab + + def test_empty_field_values(self): + with pytest.raises(ValueError) as exc: + Vocab.from_samples([{'w': []}]) + assert 'field values must not be an empty sequence' in str(exc.value) + + def test_min_count(self): + ss = [{ + 'w': 'c', + 't': 'c' + }, { + 'w': 'b', + 't': 'b' + }, { + 'w': 'a', + 't': 'a' + }, { + 'w': 'b', + 't': 'b' + }, { + 'w': 'c', + 't': 'c' + }, { + 'w': 'c', + 't': 'c' + }] + vocab = Vocab.from_samples(ss, options={'w': dict(min_count=3)}) + assert 'b' not in vocab['w'] + assert 'b' in vocab['t'] + + def test_no_unk(self): + vocab = Vocab.from_samples([{'w': 'a', 't': 'a'}], options={'w': dict(unk=None)}) + assert '<unk>' not in vocab['w'] + assert '<unk>' in vocab['t'] + with pytest.raises(KeyError) as exc: + vocab['w']['foo'] + assert "'foo' not found in vocabulary" in str(exc.value) + + def test_no_pad(self): + vocab = Vocab.from_samples([{'w': ['a'], 't': ['a']}], options={'w': dict(pad=None)}) + assert '<pad>' not in vocab['w'] + assert '<pad>' in vocab['t'] + + def test_max_size(self): + ss = [{ + 'w': 'c', + 't': 'c' + }, { + 'w': 'b', + 't': 'b' + }, { + 'w': 'a', + 't': 'a' + }, { + 'w': 'b', + 't': 'b' + }, { + 'w': 'c', + 't': 'c' + }, { + 'w': 'c', + 't': 'c' + }] + vocab = Vocab.from_samples(ss, options={'w': dict(max_size=1)}) + assert 'b' not in vocab['w'] + assert 'b' in vocab['t'] + + def test_iterator_is_passed(self): + ss = [{ + 'ws': ['b', 'c'], + 'w': 'c' + }, { + 'ws': ['c', 'b'], + 'w': 'c' + }, { + 'ws': ['c'], + 'w': 'c' + }] + vocab = Vocab.from_samples(iter(ss)) + assert 'b' in vocab['ws'] + assert 'c' in vocab['ws'] + assert 'c' in vocab['w'] diff --git a/text2array/__init__.py b/text2array/__init__.py new file mode 100644 index 0000000..0fefa6d --- /dev/null +++ b/text2array/__init__.py @@ -0,0 +1,12 @@ +from .batches import Batch +from .datasets import Dataset,
StreamDataset +from .samples import Sample +from .vocab import Vocab + +__all__ = [ + Sample, + Batch, + Dataset, + StreamDataset, + Vocab, +] diff --git a/text2array/batches.py b/text2array/batches.py new file mode 100644 index 0000000..7eba416 --- /dev/null +++ b/text2array/batches.py @@ -0,0 +1,112 @@ +from functools import reduce +from typing import List, Mapping, Sequence, Union + +import numpy as np + +from .samples import FieldName, FieldValue, Sample + + +class Batch(Sequence[Sample]): + """A class to represent a single batch. + + Args: + samples: Sequence of samples this batch should contain. + """ + + def __init__(self, samples: Sequence[Sample]) -> None: + self._samples = samples + + def __getitem__(self, index) -> Sample: + return self._samples[index] + + def __len__(self) -> int: + return len(self._samples) + + def to_array(self, pad_with: int = 0) -> Mapping[FieldName, np.ndarray]: + """Convert the batch into :class:`np.ndarray`. + + Args: + pad_with: Pad sequential field values with this number. + + Returns: + A mapping from field names to :class:`np.ndarray`s whose first + dimension corresponds to the batch size as returned by ``__len__``. + """ + if not self._samples: + return {} + + arr = {} + for name in self._samples[0].keys(): + values = self._get_values(name) + + # Get max length for all depths, 1st elem is batch size + try: + maxlens = self._get_maxlens(values) + except self._InconsistentDepthError: + raise ValueError(f"field '{name}' has inconsistent nesting depth") + + # Get padding for all depths + paddings = self._get_paddings(maxlens, pad_with) + # Pad the values + data = self._pad(values, maxlens, paddings, 0) + + arr[name] = np.array(data) + + return arr + + def _get_values(self, name: str) -> Sequence[FieldValue]: + try: + return [s[name] for s in self._samples] + except KeyError: + raise KeyError(f"some samples have no field '{name}'") + + @classmethod + def _get_maxlens(cls, values: Sequence[FieldValue]) -> List[int]: + assert values + + # Base case + if isinstance(values[0], str) or not isinstance(values[0], Sequence): + return [len(values)] + + # Recursive case + maxlenss = [cls._get_maxlens(x) for x in values] + if not all(len(x) == len(maxlenss[0]) for x in maxlenss): + raise cls._InconsistentDepthError + + maxlens = reduce(lambda ml1, ml2: [max(l1, l2) for l1, l2 in zip(ml1, ml2)], maxlenss) + maxlens.insert(0, len(values)) + return maxlens + + @classmethod + def _get_paddings(cls, maxlens: List[int], with_: int) -> List[Union[int, List[int]]]: + res: list = [with_] + for maxlen in reversed(maxlens[1:]): + res.append([res[-1] for _ in range(maxlen)]) + res.reverse() + return res + + @classmethod + def _pad( + cls, + values: Sequence[FieldValue], + maxlens: List[int], + paddings: List[Union[int, List[int]]], + depth: int, + ) -> Sequence[FieldValue]: + assert values + assert len(maxlens) == len(paddings) + assert depth < len(maxlens) + + # Base case + if isinstance(values[0], str) or not isinstance(values[0], Sequence): + values_ = list(values) + # Recursive case + else: + values_ = [cls._pad(x, maxlens, paddings, depth + 1) for x in values] + + for _ in range(maxlens[depth] - len(values)): + values_.append(paddings[depth]) + return values_ + + class _InconsistentDepthError(Exception): + pass diff --git a/text2array/datasets.py b/text2array/datasets.py new file mode 100644 index 0000000..c47f89c --- /dev/null +++ b/text2array/datasets.py @@ -0,0 +1,236 @@ +from typing import Callable, Iterable, Iterator, Mapping, MutableSequence, Sequence +import abc 
+import random +import statistics as stat + +from .batches import Batch +from .samples import FieldName, FieldValue, Sample + + +class DatasetABC(Iterable[Sample], metaclass=abc.ABCMeta): + @abc.abstractmethod + def batch(self, batch_size: int) -> Iterator[Batch]: + pass + + def batch_exactly(self, batch_size: int) -> Iterator[Batch]: + """Group the samples in the dataset into batches of exact size. + + If the number of samples is not divisible by ``batch_size``, the last + batch (whose length is less than ``batch_size``) is dropped. + + Args: + batch_size: Number of samples in each batch. + + Returns: + The iterator of batches. + """ + return (b for b in self.batch(batch_size) if len(b) == batch_size) + + @abc.abstractmethod + def apply_vocab(self, vocab: Mapping[FieldName, Mapping[FieldValue, FieldValue]]) -> None: + pass + + @classmethod + def _apply_vocab_to_sample( + cls, + vocab: Mapping[FieldName, Mapping[FieldValue, FieldValue]], + sample: Sample, + ) -> Sample: + s = {} + for name, val in sample.items(): + try: + vb = vocab[name] + except KeyError: + s[name] = val + else: + s[name] = cls._apply_vb_to_val(vb, val) + return s + + @classmethod + def _apply_vb_to_val( + cls, + vb: Mapping[FieldValue, FieldValue], + val: FieldValue, + ) -> FieldValue: + if isinstance(val, str) or not isinstance(val, Sequence): + try: + return vb[val] + except KeyError: + raise KeyError(f'value {val!r} not found in vocab') + + return [cls._apply_vb_to_val(vb, v) for v in val] + + +class Dataset(DatasetABC, Sequence[Sample]): + """Dataset that fits all in memory (no streaming). + + Args: + samples: Sequence of samples the dataset should contain. This sequence should + support indexing by a positive/negative index of type :obj:`int` or a + :obj:`slice` object. + """ + + def __init__(self, samples: Sequence[Sample]) -> None: + if not isinstance(samples, Sequence): + raise TypeError('"samples" is not a sequence') + + self._samples = samples + + def __getitem__(self, index) -> Sample: + return self._samples[index] + + def __len__(self) -> int: + return len(self._samples) + + def shuffle(self) -> 'Dataset': + """Shuffle the dataset. + + This method shuffles in-place if ``samples`` is a mutable sequence. + Otherwise, a copy is made and then shuffled. This copy is a mutable + sequence, so subsequent shuffling will be done in-place. + + Returns: + This dataset object (useful for chaining). + """ + if not isinstance(self._samples, MutableSequence): + self._samples = list(self._samples) + self._shuffle_inplace() + return self + + def shuffle_by(self, key: Callable[[Sample], int], scale: float = 1.) -> 'Dataset': + """Shuffle the dataset by the given key. + + This method essentially performs noisy sorting. The samples in the dataset are + sorted ascending by the value of the given key, plus some random noise. This random + noise is drawn from ``Uniform(-z, z)``, where ``z`` equals ``scale`` times the + standard deviation of key values. This formulation means that ``scale`` regulates + how noisy the sorting is. The larger it is, the more noisy the sorting becomes, i.e. + it resembles random shuffling more closely. In an extreme case where ``scale=0``, + this method just sorts the samples by ``key``. This method is useful when working + with text data, where we want to shuffle the dataset and also minimize padding by + ensuring that sentences of similar lengths are not too far apart. + + Args: + by: Callable to get the key value of a given sample. + scale: Value to regulate the noise of the sorting. 
Must not be negative. + + Returns: + This dataset object (useful for chaining). + """ + if scale < 0: + raise ValueError('scale cannot be less than 0') + + std = stat.stdev(key(s) for s in self._samples) + z = scale * std + self._samples = sorted(self._samples, key=lambda s: key(s) + random.uniform(-z, z)) + return self + + def batch(self, batch_size: int) -> Iterator[Batch]: + """Group the samples in the dataset into batches. + + Args: + batch_size: Maximum number of samples in each batch. + + Returns: + The iterator of batches. + """ + if batch_size <= 0: + raise ValueError('batch size must be greater than 0') + + for begin in range(0, len(self._samples), batch_size): + end = begin + batch_size + yield Batch(self._samples[begin:end]) + + def apply_vocab(self, vocab: Mapping[FieldName, Mapping[FieldValue, FieldValue]]) -> None: + """Apply a vocabulary to this dataset. + + Applying a vocabulary means mapping all the (nested) field values to the corresponding + values according to the mapping specified by the vocabulary. Field names that have + no entry in the vocabulary are ignored. This method applies the vocabulary in-place + when the dataset holds a mutable sequence of samples. Otherwise, a mutable copy of + samples is made and the vocabulary is applied on it. + + Args: + vocab: The vocabulary to apply. + """ + if not isinstance(self._samples, MutableSequence): + self._samples = list(self._samples) + self._apply_vocab_inplace(vocab) + + def _shuffle_inplace(self) -> None: + assert isinstance(self._samples, MutableSequence) + n = len(self._samples) + for i in range(n): + j = random.randrange(n) + temp = self._samples[i] + self._samples[i] = self._samples[j] + self._samples[j] = temp + + def _apply_vocab_inplace( + self, + vocab: Mapping[FieldName, Mapping[FieldValue, FieldValue]], + ) -> None: + assert isinstance(self._samples, MutableSequence) + for i in range(len(self._samples)): + self._samples[i] = self._apply_vocab_to_sample(vocab, self._samples[i]) + + +class StreamDataset(DatasetABC): + """Dataset that streams its samples. + + Args: + stream: Stream of samples the dataset should stream from. + """ + + def __init__(self, stream: Iterable[Sample]) -> None: + if not isinstance(stream, Iterable): + raise TypeError('"stream" is not iterable') + + self._stream = stream + + def __iter__(self) -> Iterator[Sample]: + try: + vocab = self._vocab + except AttributeError: + yield from self._stream + return + + for s in self._stream: + yield self._apply_vocab_to_sample(vocab, s) + + def batch(self, batch_size: int) -> Iterator[Batch]: + """Group the samples in the dataset into batches. + + Args: + batch_size: Maximum number of samples in each batch. + + Returns: + The iterator of batches. + """ + if batch_size <= 0: + raise ValueError('batch size must be greater than 0') + + it, exhausted = iter(self._stream), False + while not exhausted: + batch: list = [] + while not exhausted and len(batch) < batch_size: + try: + batch.append(next(it)) + except StopIteration: + exhausted = True + if batch: + yield Batch(batch) + + def apply_vocab(self, vocab: Mapping[FieldName, Mapping[FieldValue, FieldValue]]) -> None: + """Apply a vocabulary to this dataset. + + Applying a vocabulary means mapping all the (nested) field values to the corresponding + values according to the mapping specified by the vocabulary. Field names that have + no entry in the vocabulary are ignored. Note that since the dataset holds a stream of + samples, the actual application is delayed until the dataset is iterated. 
Therefore, + ``vocab`` must still exist when that happens. + + Args: + vocab: The vocabulary to apply. + """ + self._vocab = vocab diff --git a/text2array/samples.py b/text2array/samples.py new file mode 100644 index 0000000..7e6f010 --- /dev/null +++ b/text2array/samples.py @@ -0,0 +1,7 @@ +from typing import Mapping, Sequence, Union + +# TODO remove these "type ignore" once mypy supports recursive types +# see: https://github.com/python/mypy/issues/731 +FieldName = str +FieldValue = Union[float, int, str, Sequence['FieldValue']] # type: ignore +Sample = Mapping[FieldName, FieldValue] # type: ignore diff --git a/text2array/vocab.py b/text2array/vocab.py new file mode 100644 index 0000000..755f3a4 --- /dev/null +++ b/text2array/vocab.py @@ -0,0 +1,154 @@ +from collections import Counter, OrderedDict, defaultdict +from typing import Counter as CounterT, Dict, Iterable, Iterator, Mapping, \ + Optional, Sequence, Set + +from .samples import FieldName, FieldValue, Sample + + +class Vocab(Mapping[FieldName, Mapping[str, int]]): + """Namespaced vocabulary storing the mapping from field names to their actual vocabulary. + + A vocabulary does not hold the str-to-int mapping directly, but rather it stores a mapping + from field names to the corresponding str-to-int mappings. These mappings are the actual + vocabulary for that particular field name. In other words, the actual vocabulary for each + field name is namespaced by the field name and all of them are handled by this :class:`Vocab` + object. + + Args: + m: Mapping from :obj:`FieldName` to its str-to-int mapping. + """ + + def __init__(self, m: Mapping[FieldName, Mapping[str, int]]) -> None: + self._m = m + + def __len__(self) -> int: + return len(self._m) + + def __iter__(self) -> Iterator[FieldName]: + return iter(self._m) + + def __getitem__(self, name: FieldName) -> Mapping[str, int]: + try: + return self._m[name] + except KeyError: + raise KeyError(f"no vocabulary found for field name '{name}'") + + @classmethod + def from_samples( + cls, + samples: Iterable[Sample], + options: Optional[Mapping[FieldName, dict]] = None, + ) -> 'Vocab': + """Make an instance of this class from an iterable of samples. + + A vocabulary is only made for fields whose value is a string token or a (nested) + sequence of string tokens. It is important that ``samples`` be a true iterable, i.e. + it can be iterated more than once. No exception is raised when this is violated. + + Args: + samples: Iterable of samples. + options: Mapping from field names to dictionaries to control the creation of + the str-to-int mapping. Recognized dictionary keys are: + + * ``min_count`` (:obj:`int`): Exclude tokens occurring fewer than this number + of times from the vocabulary (default: 2). + * ``pad`` (:obj:`str`): String token to represent padding tokens. If ``None``, + no padding token is added to the vocabulary. Otherwise, it is the + first entry in the vocabulary (index is 0). Note that if the field has no + sequential values, no padding is added. String field values are *not* + considered sequential (default: ``<pad>``). + * ``unk`` (:obj:`str`): String token to represent unknown tokens with. If + ``None``, no unknown token is added to the vocabulary. This means when + querying the vocabulary with such a token, an error is raised. Otherwise, + it is the first entry in the vocabulary *after* ``pad``, if any (index is + either 0 or 1) (default: ``<unk>``). + * ``max_size`` (:obj:`int`): Maximum size of the vocabulary, excluding ``pad`` + and ``unk``.
If ``None``, there is no limit on the vocabulary size. Otherwise, + only this number of the most frequent tokens are included in the + vocabulary. Note that ``min_count`` also limits the size implicitly. + So, the size is limited by whichever is smaller. (default: ``None``). + + Returns: + Vocabulary instance. + """ + if options is None: + options = {} + + counter: Dict[FieldName, CounterT[str]] = defaultdict(Counter) + seqfield: Set[FieldName] = set() + for s in samples: + for name, value in s.items(): + if cls._needs_vocab(value): + counter[name].update(cls._flatten(value)) + if isinstance(value, Sequence) and not isinstance(value, str): + seqfield.add(name) + + m = {} + for name, c in counter.items(): + store: dict = OrderedDict() + opts = options.get(name, {}) + + # Padding and unknown tokens + pad = opts.get('pad', '<pad>') + unk = opts.get('unk', '<unk>') + if name in seqfield and pad is not None: + store[pad] = len(store) + if unk is not None: + store[unk] = len(store) + + min_count = opts.get('min_count', 2) + max_size = opts.get('max_size') + n = len(store) + for tok, freq in c.most_common(): + if freq < min_count or (max_size is not None and len(store) - n >= max_size): + break + store[tok] = len(store) + + unk_id = None if unk is None else store[unk] + m[name] = _StringStore(store, unk_id=unk_id) + + return cls(m) + + @classmethod + def _needs_vocab(cls, val: FieldValue) -> bool: + if isinstance(val, str): + return True + if isinstance(val, Sequence): + if not val: + raise ValueError('field values must not be an empty sequence') + return cls._needs_vocab(val[0]) + return False + + @classmethod + def _flatten(cls, xs) -> Iterator[str]: + if isinstance(xs, str): + yield xs + return + + # must be an iterable, due to how we use this function + for x in xs: + yield from cls._flatten(x) + + +class _StringStore(Mapping[str, int]): + def __init__(self, m: Mapping[str, int], unk_id: Optional[int] = None) -> None: + assert unk_id is None or unk_id >= 0 + self._m = m + self._unk_id = unk_id + + def __len__(self) -> int: + return len(self._m) + + def __iter__(self) -> Iterator[str]: + return iter(self._m) + + def __getitem__(self, s: str) -> int: + try: + return self._m[s] + except KeyError: + if self._unk_id is not None: + return self._unk_id + raise KeyError(f"'{s}' not found in vocabulary") + + def __contains__(self, s) -> bool: + return s in self._m
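As a quick usage sketch of the vocabulary options implemented above (not part of the diff; the sample data below is made up for illustration), this shows how ``min_count``, ``pad``, and ``unk`` interact:

.. code-block:: python

    >>> from text2array import Vocab
    >>> samples = [{'ws': ['a', 'b', 'b']}, {'ws': ['b', 'c']}]
    >>> # 'a' and 'c' occur only once, so the default min_count=2 drops them
    >>> vocab = Vocab.from_samples(samples)
    >>> list(vocab['ws'])
    ['<pad>', '<unk>', 'b']
    >>> vocab['ws']['a'] == vocab['ws']['<unk>']  # dropped/unseen tokens fall back to <unk>
    True
    >>> vocab = Vocab.from_samples(samples, options={'ws': dict(unk=None)})
    >>> '<unk>' in vocab['ws']
    False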