Add attributes to stimuli, add category attribute to CAT2000
This also makes sure that external datasets are not left
in half-finished states if setup fails.

Signed-off-by: Matthias Kümmerer <matthias@matthias-k.org>
matthias-k committed Apr 22, 2020
1 parent 0009e68 commit c9b9f16
Showing 9 changed files with 1,010 additions and 775 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
@@ -1,5 +1,10 @@
 # Changelog
 
+* 0.2.20 (unpublished):
+  * Stimuli now support attributes, just like Fixations. The CAT2000 train and test
+    datasets now have the stimulus categories as an attribute.
+  * Failure to download and set up a dataset will no longer result in leftover
+    dataset files that keep pysaliency from trying again.
 * 0.2.19:
   * added pytorch implementation for optimization of similarity metric as alternative
     to tensorflow implementation which still uses tensorflow 1.x
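
From the user side, the first 0.2.20 entry above looks roughly like this (a sketch, not part of the commit; the attribute name 'category' follows the commit title and is assumed, and the download location is an example path):

import pysaliency

# downloads and sets up the CAT2000 train dataset on first use (example location)
stimuli, fixations = pysaliency.external_datasets.get_cat2000_train(location='datasets')

# per-stimulus category labels, exposed via the new attributes mechanism
print(stimuli.attributes['category'][:10])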
91 changes: 75 additions & 16 deletions pysaliency/datasets.py
@@ -838,7 +838,8 @@ class Stimuli(Sequence):
         A `Stimulus` instance for each stimulus. Mainly for caching.
     """
-    def __init__(self, stimuli):
+    __attributes__ = []
+    def __init__(self, stimuli, attributes=None):
         self.stimuli = stimuli
         self.shapes = [s.shape for s in self.stimuli]
         self.sizes = LazyList(lambda n: (self.shapes[n][0], self.shapes[n][1]),
@@ -848,14 +849,23 @@ def __init__(self, stimuli):
                              pickle_cache=True)
         self.stimulus_objects = [StimuliStimulus(self, n) for n in range(len(self.stimuli))]
 
+        if attributes is not None:
+            assert isinstance(attributes, dict)
+            self.attributes = attributes
+            self.__attributes__ = list(attributes.keys())
+        else:
+            self.attributes = {}
+
     def __len__(self):
         return len(self.stimuli)
 
     def __getitem__(self, index):
         if isinstance(index, slice):
-            return ObjectStimuli([self.stimulus_objects[i] for i in range(len(self))[index]])
+            attributes = {key: value[index] for key, value in self.attributes.items()}
+            return ObjectStimuli([self.stimulus_objects[i] for i in range(len(self))[index]], attributes=attributes)
         elif isinstance(index, list):
-            return ObjectStimuli([self.stimulus_objects[i] for i in index])
+            attributes = {key: value[index] for key, value in self.attributes.items()}
+            return ObjectStimuli([self.stimulus_objects[i] for i in index], attributes=attributes)
         else:
             return self.stimulus_objects[index]

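The `__getitem__` changes above propagate attributes through slices and index lists. A minimal sketch (made-up images and labels; attribute values are numpy arrays here so that list indexing works):

import numpy as np

images = [np.zeros((32, 32, 3), dtype=np.uint8) for _ in range(3)]
stimuli = Stimuli(images, attributes={'category': np.array([0, 1, 1])})

print(stimuli[0:2].attributes['category'])     # slice indexing -> array([0, 1])
print(stimuli[[0, 2]].attributes['category'])  # list indexing  -> array([0, 1])
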
@@ -865,34 +875,49 @@ def to_hdf5(self, target, verbose=False, compression='gzip', compression_opts=9)
         """
 
         target.attrs['type'] = np.string_('Stimuli')
-        target.attrs['version'] = np.string_('1.0')
+        target.attrs['version'] = np.string_('1.1')
 
         for n, stimulus in enumerate(tqdm(self.stimuli, disable=not verbose)):
             target.create_dataset(str(n), data=stimulus, compression=compression, compression_opts=compression_opts)
 
+        for attribute_name, attribute_value in self.attributes.items():
+            target.create_dataset(attribute_name, data=attribute_value)
+        target.attrs['__attributes__'] = np.string_(json.dumps(self.__attributes__))
+
         target.attrs['size'] = len(self)
 
     @classmethod
     @hdf5_wrapper(mode='r')
     def read_hdf5(cls, source):
-        """ Read train fixations from hdf5 file or hdf5 group """
+        """ Read stimuli from hdf5 file or hdf5 group """
 
         data_type = decode_string(source.attrs['type'])
         data_version = decode_string(source.attrs['version'])
 
         if data_type != 'Stimuli':
             raise ValueError("Invalid type! Expected 'Stimuli', got", data_type)
 
-        if data_version != '1.0':
-            raise ValueError("Invalid version! Expected '1.0', got", data_version)
+        if data_version not in ['1.0', '1.1']:
+            raise ValueError("Invalid version! Expected '1.0' or '1.1', got", data_version)
 
         size = source.attrs['size']
         stimuli = []
 
         for n in range(size):
             stimuli.append(source[str(n)][...])
 
-        stimuli = cls(stimuli=stimuli)
+        if data_version < '1.1':
+            __attributes__ = []
+        else:
+            json_attributes = source.attrs['__attributes__']
+            if not isinstance(json_attributes, string_types):
+                json_attributes = json_attributes.decode('utf8')
+            __attributes__ = json.loads(json_attributes)
+
+        attributes = {attribute: source[attribute][...] for attribute in __attributes__}
+
+        stimuli = cls(stimuli=stimuli, attributes=attributes)
+
 
         return stimuli

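A round-trip sketch for the new version-1.1 layout (assuming, as with pysaliency's other HDF5 helpers, that `hdf5_wrapper` also accepts a plain file path):

import numpy as np

images = [np.zeros((32, 32, 3), dtype=np.uint8) for _ in range(3)]
stimuli = Stimuli(images, attributes={'category': np.array([0, 1, 1])})

stimuli.to_hdf5('stimuli.hdf5')               # writes version '1.1' plus one dataset per attribute
restored = Stimuli.read_hdf5('stimuli.hdf5')  # version '1.0' files still load, with empty attributes
print(restored.attributes['category'])        # array([0, 1, 1])
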
@@ -901,7 +926,7 @@ class ObjectStimuli(Stimuli):
     """
     This Stimuli class is mainly used for slicing of other stimuli objects.
     """
-    def __init__(self, stimulus_objects):
+    def __init__(self, stimulus_objects, attributes=None):
         self.stimulus_objects = stimulus_objects
         self.stimuli = LazyList(lambda n: self.stimulus_objects[n].stimulus_data,
                                 length = len(self.stimulus_objects))
@@ -912,6 +937,14 @@ def __init__(self, stimulus_objects):
         self.stimulus_ids = LazyList(lambda n: self.stimulus_objects[n].stimulus_id,
                                      length = len(self.stimulus_objects))
 
+        if attributes is not None:
+            assert isinstance(attributes, dict)
+            self.attributes = attributes
+            self.__attributes__ = list(attributes.keys())
+        else:
+            self.attributes = {}
+
+
     def read_hdf5(self, target):
         raise NotImplementedError()

Expand All @@ -920,7 +953,7 @@ class FileStimuli(Stimuli):
"""
Manage a list of stimuli that are saved as files.
"""
def __init__(self, filenames, cache=True, shapes=None):
def __init__(self, filenames, cache=True, shapes=None, attributes=None):
"""
Create a stimuli object that reads it's stimuli from files.
@@ -967,6 +1000,13 @@ def __init__(self, filenames, cache=True, shapes=None):
         self.sizes = LazyList(lambda n: (self.shapes[n][0], self.shapes[n][1]),
                               length = len(self.stimuli))
 
+        if attributes is not None:
+            assert isinstance(attributes, dict)
+            self.attributes = attributes
+            self.__attributes__ = list(attributes.keys())
+        else:
+            self.attributes = {}
+
     def load_stimulus(self, n):
         return imread(self.filenames[n])

@@ -977,7 +1017,8 @@ def __getitem__(self, index):
         if isinstance(index, list):
             filenames = [self.filenames[i] for i in index]
             shapes = [self.shapes[i] for i in index]
-            return type(self)(filenames=filenames, shapes=shapes)
+            attributes = {key: [value[i] for i in index] for key, value in self.attributes.items()}
+            return type(self)(filenames=filenames, shapes=shapes, attributes=attributes)
         else:
             return self.stimulus_objects[index]

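For `FileStimuli`, list indexing rebuilds the attribute values per selected file. A sketch with hypothetical filenames (the image files must actually exist, since shapes are read lazily from disk; attribute values are plain lists, which the list comprehension above supports):

stimuli = FileStimuli(['img0.png', 'img1.png', 'img2.png'],
                      attributes={'category': ['indoor', 'outdoor', 'indoor']})

subset = stimuli[[0, 2]]              # forces shape lookup for the selected files
print(subset.filenames)               # ['img0.png', 'img2.png']
print(subset.attributes['category'])  # ['indoor', 'indoor']
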
@@ -987,7 +1028,7 @@ def to_hdf5(self, target):
         """
 
         target.attrs['type'] = np.string_('FileStimuli')
-        target.attrs['version'] = np.string_('2.0')
+        target.attrs['version'] = np.string_('2.1')
 
         import h5py
         # make sure everything is unicode
@@ -1014,6 +1055,10 @@ def to_hdf5(self, target):
         for n, shape in enumerate(self.shapes):
             shape_dataset[n] = np.array(shape)
 
+        for attribute_name, attribute_value in self.attributes.items():
+            target.create_dataset(attribute_name, data=attribute_value)
+        target.attrs['__attributes__'] = np.string_(json.dumps(self.__attributes__))
+
         target.attrs['size'] = len(self)
 
     @classmethod
@@ -1027,8 +1072,9 @@ def read_hdf5(cls, source, cache=True):
         if data_type != 'FileStimuli':
             raise ValueError("Invalid type! Expected 'Stimuli', got", data_type)
 
-        if data_version not in ['1.0', '2.0']:
-            raise ValueError("Invalid version! Expected '1.0' or '2.0', got", data_version)
+        valid_versions = ['1.0', '2.0', '2.1']
+        if data_version not in valid_versions:
+            raise ValueError("Invalid version! Expected one of {}, got {}".format(', '.join(valid_versions), data_version))
 
         encoded_filenames = source['filenames'][...]

@@ -1041,7 +1087,17 @@

         shapes = [list(shape) for shape in source['shapes'][...]]
 
-        stimuli = cls(filenames=filenames, cache=cache, shapes=shapes)
+        if data_version < '2.1':
+            __attributes__ = []
+        else:
+            json_attributes = source.attrs['__attributes__']
+            if not isinstance(json_attributes, string_types):
+                json_attributes = json_attributes.decode('utf8')
+            __attributes__ = json.loads(json_attributes)
+
+        attributes = {attribute: source[attribute][...] for attribute in __attributes__}
+
+        stimuli = cls(filenames=filenames, cache=cache, shapes=shapes, attributes=attributes)
 
         return stimuli

@@ -1063,7 +1119,10 @@ def create_subset(stimuli, fixations, stimuli_indices):
 
 
 def concatenate_stimuli(stimuli):
-    return ObjectStimuli(sum([s.stimulus_objects for s in stimuli], []))
+    attributes = {}
+    for key in stimuli[0].attributes.keys():
+        attributes[key] = concatenate_attributes(s.attributes[key] for s in stimuli)
+    return ObjectStimuli(sum([s.stimulus_objects for s in stimuli], []), attributes=attributes)
 
 
 def concatenate_attributes(attributes):
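
`concatenate_stimuli` takes the attribute keys from the first element, so it implicitly assumes all inputs carry the same attributes. A sketch (made-up data; `concatenate_attributes` is assumed to return one concatenated array):

import numpy as np

a = Stimuli([np.zeros((16, 16, 3), dtype=np.uint8)], attributes={'category': np.array([0])})
b = Stimuli([np.ones((16, 16, 3), dtype=np.uint8)], attributes={'category': np.array([1])})

combined = concatenate_stimuli([a, b])
print(len(combined))                    # 2
print(combined.attributes['category'])  # array([0, 1])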
