
added sort stream #336

Merged (11 commits) on Feb 25, 2015
blocks/datasets/streams.py (27 additions, 0 deletions)
@@ -314,6 +314,33 @@ def _cache(self):
cache.extend(data)


class SortMapping(object):
    """Callable class for creating sorting mappings.

    Instances of this class can be passed as the `mapping` argument of
    the :class:`DataStreamMapping` constructor.

    Parameters
    ----------
    key : callable
        The mapping that returns the value to sort on. Its input will be
        a tuple that contains a single data point for each source.
    reverse : bool, optional
        If ``True``, the sort order is reversed. Defaults to ``False``.

    """
    def __init__(self, key, reverse=False):
        self.key = key
Member:

Do you still need these two assignments?

Contributor Author:

I thought it might be handy for debugging to see which arguments were used to construct the stream, but they are not necessary. I could also make the SortMapping instance an attribute instead, because it contains the same information.

Member:

A user can always access it through e.g. stream.mapping.key, right? stream.mapping is a class instance in this case, so I'd just get rid of them, but your call.

Contributor Author:

Now that I think about it, all this class does is apply the mapper. Perhaps it's cleaner to use DataStreamMapping directly and have a couple of common mappings available in the module. What do you think?

Member:

Fewer SLOCs sounds good to me. We'll just need to make sure to add the mappings to the See also section of DataStreamMapping or something so that people will find them, but that can go on the lengthy doc to-do list if you want.

        self.reverse = reverse

    def __call__(self, x):
        values = [self.key(i) for i in zip(*x)]
        indices = [i for (v, i) in
Contributor:

Sorry for being picky, but sorting indices is not really Pythonic. In Python, copying a reference to an object, i.e. a = b, is always a lightweight operation, regardless of whether the object is an int or a huge matrix. Therefore, I would prefer something like this:

def __call__(self, batch):
    with_keys = [(example, self.key(example)) for example in zip(*batch)]
    with_keys.sort(key=lambda pair: pair[1])
    return [example for example, _ in with_keys]

Another thing is precomputing keys. In general it is a nice thing to do, but I quickly checked, and it seems like sorted is smart enough to do it on its own. I could not find it in the docs explicitly, but I would trust that such a function must be optimized. That simplifies our job to the following:

def __call__(self, batch):
    return list(sorted(zip(*batch), key=self.key, reverse=self.reverse))

Did I miss anything?
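As an aside (not part of the PR), the key-precomputation behavior can be checked empirically with a short standalone snippet: counting invocations shows that `sorted` evaluates the key function once per element rather than once per comparison.

```python
# Count how often `sorted` evaluates its key function.
calls = []

def counting_key(x):
    calls.append(x)  # record every invocation
    return x

result = sorted([3, 1, 2], key=counting_key)
assert result == [1, 2, 3]
assert len(calls) == 3  # one key evaluation per element, not per comparison
```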

                   sorted(((v, i) for (i, v) in enumerate(values)),
                          reverse=self.reverse)]
        return tuple([[i[j] for j in indices] for i in x])
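Outside the diff, the behavior of SortMapping.__call__ can be sketched as a standalone function (illustrative only; the name `sort_batch` is hypothetical and Blocks is not imported): it argsorts the examples by key and applies the same permutation to every source.

```python
import operator

def sort_batch(batch, key=operator.itemgetter(0), reverse=False):
    """Reorder every source in `batch` by `key` applied to its examples."""
    # Transpose the sources into examples and compute the sort key of each.
    values = [key(example) for example in zip(*batch)]
    # Argsort: pair each key with its position, sort, and keep the positions.
    indices = [i for v, i in sorted(((v, i) for i, v in enumerate(values)),
                                    reverse=reverse)]
    # Reindex every source with the same permutation.
    return tuple([source[j] for j in indices] for source in batch)

print(sort_batch(([3, 1, 2], ['c', 'a', 'b'])))
# ([1, 2, 3], ['a', 'b', 'c'])
```

Both sources are reordered together, which is exactly what the multisource test below relies on.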


class BatchDataStream(DataStreamWrapper):
    """Creates minibatches from data streams providing single examples.
tests/datasets/test_datasets.py (34 additions, 1 deletion)
@@ -1,14 +1,15 @@
from collections import OrderedDict

import numpy
import operator
import theano
from six.moves import zip
from nose.tools import assert_raises

from blocks.datasets import ContainerDataset
from blocks.datasets.streams import (
    CachedDataStream, DataStream, DataStreamMapping, BatchDataStream,
    PaddingDataStream, DataStreamFilter, ForceFloatX, SortMapping)
from blocks.datasets.schemes import BatchSizeScheme, ConstantScheme

floatX = theano.config.floatX
@@ -42,6 +43,38 @@ def test_data_stream_mapping():
    assert list(wrapper2.get_epoch_iterator()) == list(zip(data, data_doubled))


def test_data_stream_mapping_sort():
    data = [[1, 2, 3],
            [2, 3, 1],
            [3, 2, 1]]
    data_sorted = [[1, 2, 3]] * 3
    data_sorted_rev = [[3, 2, 1]] * 3
    stream = ContainerDataset(data).get_default_stream()
    wrapper1 = DataStreamMapping(stream,
                                 mapping=SortMapping(operator.itemgetter(0)))
    assert list(wrapper1.get_epoch_iterator()) == list(zip(data_sorted))
    wrapper2 = DataStreamMapping(stream, SortMapping(lambda x: -x[0]))
    assert list(wrapper2.get_epoch_iterator()) == list(zip(data_sorted_rev))
    wrapper3 = DataStreamMapping(stream, SortMapping(operator.itemgetter(0),
                                                     reverse=True))
    assert list(wrapper3.get_epoch_iterator()) == list(zip(data_sorted_rev))


def test_data_stream_mapping_multisource():
    data_dict = {'x': [[1, 2, 3], [2, 3, 1], [3, 2, 1]],
                 'y': [[6, 5, 4], [6, 5, 4], [6, 5, 4]]}
    data = OrderedDict()
    data['x'] = data_dict['x']
    data['y'] = data_dict['y']
    data_sorted = [([1, 2, 3], [6, 5, 4]),
                   ([1, 2, 3], [4, 6, 5]),
                   ([1, 2, 3], [4, 5, 6])]
    stream = ContainerDataset(data).get_default_stream()
    wrapper = DataStreamMapping(stream,
                                mapping=SortMapping(operator.itemgetter(0)))
    assert list(wrapper.get_epoch_iterator()) == data_sorted
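As an aside (not part of the PR), the second row of data_sorted can be verified by hand: sorting x = [2, 3, 1] yields the permutation (2, 0, 1), and applying that same permutation to y = [6, 5, 4] gives [4, 6, 5]. A quick standalone check:

```python
x = [2, 3, 1]
y = [6, 5, 4]
# Positions of x's elements in sorted order (a plain argsort).
order = sorted(range(len(x)), key=lambda i: x[i])
assert order == [2, 0, 1]
assert [x[i] for i in order] == [1, 2, 3]
assert [y[i] for i in order] == [4, 6, 5]  # y follows x's permutation
```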


def test_data_stream_filter():
    data = [1, 2, 3]
    data_filtered = [1, 3]