Lots of changes to make reading and writing by chunks work properly.
kboone committed May 17, 2019
1 parent a0d1ed9 commit 8e3433a
Showing 6 changed files with 406 additions and 146 deletions.
5 changes: 4 additions & 1 deletion avocado/__init__.py
@@ -1,5 +1,5 @@
from .settings import settings
from .utils import logger, AvocadoException
from .utils import *

from .astronomical_object import *
from .dataset import *
@@ -9,4 +9,7 @@

from . import plasticc

# Expose the load method of Dataset
load = Dataset.load

__all__ = ['Dataset', 'AstronomicalObject']
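
With `load` now exposed at the package level, datasets can be loaded straight from the `avocado` namespace. A minimal usage sketch, assuming a dataset named 'plasticc_train' already exists in the data directory (the dataset name is illustrative):

import avocado

# Load only the first of 100 chunks; equivalent to avocado.Dataset.load(...).
dataset = avocado.load('plasticc_train', chunk=0, num_chunks=100)
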
23 changes: 13 additions & 10 deletions avocado/augment.py
@@ -334,18 +334,19 @@ def _resample_light_curve(self, reference_object, augmented_metadata):
observations['flux'] = new_fluxes
observations['flux_error'] = np.sqrt(new_fluxvars)

# Update the brightness of the new observations.
if reference_redshift == 0:
# Adjust brightness for galactic objects.
adjust_mag = np.random.normal(0, 0.5)
# adjust_mag = np.random.lognormal(-1, 0.5)
adjust_scale = 10**(-0.4*adjust_mag)
else:
# Update the brightness of the new observations. If the
# 'augment_brightness' key is in the metadata, we add that in
# magnitudes to the augmented object.
augment_brightness = augmented_metadata.get(
'augment_brightness', 0)
adjust_scale = 10**(-0.4*augment_brightness)

if reference_redshift != 0:
# Adjust brightness for extragalactic objects. We simply follow
# the Hubble diagram.
delta_distmod = (self.cosmology.distmod(reference_redshift) -
self.cosmology.distmod(new_redshift)).value
adjust_scale = 10**(0.4*delta_distmod)
adjust_scale *= 10**(0.4*delta_distmod)

observations['flux'] *= adjust_scale
observations['flux_error'] *= adjust_scale
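
The new logic first converts the optional 'augment_brightness' magnitude offset into a multiplicative flux scale, and then, for extragalactic objects only, folds in the change in distance modulus between the reference and new redshifts. A standalone sketch of that calculation with illustrative numbers (the cosmology and values below are assumptions, not taken from the repository settings):

from astropy.cosmology import FlatLambdaCDM

cosmology = FlatLambdaCDM(H0=70, Om0=0.3)    # illustrative cosmology

augment_brightness = 0.5                     # magnitude offset, illustrative
reference_redshift, new_redshift = 0.2, 0.4  # illustrative redshifts

# Convert the magnitude offset into a multiplicative flux scale.
adjust_scale = 10**(-0.4 * augment_brightness)

# Extragalactic objects additionally follow the Hubble diagram.
if reference_redshift != 0:
    delta_distmod = (cosmology.distmod(reference_redshift)
                     - cosmology.distmod(new_redshift)).value
    adjust_scale *= 10**(0.4 * delta_distmod)

# adjust_scale then multiplies both the fluxes and the flux errors.
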
@@ -457,7 +458,9 @@ def augment_dataset(self, augment_name, dataset, num_augments,
if augmented_object is not None:
augmented_objects.append(augmented_object)

augmented_dataset = Dataset.from_objects(augment_name,
augmented_objects)
augmented_dataset = Dataset.from_objects(
augment_name, augmented_objects, chunk=dataset.chunk,
num_chunks=dataset.num_chunks
)

return augmented_dataset
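
Because `Dataset.from_objects` now carries `chunk` and `num_chunks` over from the input dataset, augmentation can be run one chunk at a time and each write ends up in the right part of the output file. A rough end-to-end sketch, assuming a PLAsTiCC-style augmentor (the augmentor class and dataset names are illustrative):

import avocado

num_chunks = 100
augmentor = avocado.plasticc.PlasticcAugmentor()  # assumed augmentor class

for chunk in range(num_chunks):
    dataset = avocado.load('plasticc_train', chunk=chunk,
                           num_chunks=num_chunks)
    augmented = augmentor.augment_dataset('plasticc_augment', dataset,
                                          num_augments=10)
    # chunk/num_chunks were copied from `dataset`, so each call appends to
    # the correct slice of the augmented dataset's HDF5 file.
    augmented.write()
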
129 changes: 50 additions & 79 deletions avocado/dataset.py
@@ -6,7 +6,7 @@
from sklearn.model_selection import StratifiedKFold

from .astronomical_object import AstronomicalObject
from .utils import logger, AvocadoException, write_dataframes
from .utils import logger, AvocadoException, write_dataframe, read_dataframes
from .settings import settings

class Dataset():
@@ -26,8 +26,14 @@ class Dataset():
objects : list
A list of :class:`AstronomicalObject` instances. Either this or
observations can be specified but not both.
chunk : int (optional)
If the dataset was loaded in chunks, this indicates the chunk number.
num_chunks : int (optional)
If the dataset was loaded in chunks, this is the total number of chunks
used.
"""
def __init__(self, name, metadata, observations=None, objects=None):
def __init__(self, name, metadata, observations=None, objects=None,
chunk=None, num_chunks=None):
"""Create a new Dataset from a set of metadata and observations"""
# Make copies of everything so that we don't mess anything up.
metadata = metadata.copy()
@@ -36,6 +42,8 @@ def __init__(self, name, metadata, observations=None, objects=None):

self.name = name
self.metadata = metadata
self.chunk = chunk
self.num_chunks = num_chunks

if observations is None:
self.objects = objects
@@ -97,15 +105,16 @@ def get_raw_features_path(self, tag=None):
return features_path

@classmethod
def load(cls, name, metadata_only=False, chunk=None, num_chunks=None):
def load(cls, name, metadata_only=False, chunk=None, num_chunks=None,
**kwargs):
"""Load a dataset that has been saved in HDF5 format in the data
directory.
For an example of how to create such a dataset, see
`scripts/download_plasticc.py`.
The dataset can optionally be loaded in chunks. To do this, pass chunk
and num_chunks to this method.
and num_chunks to this method. See `read_dataframes` for details.
Parameters
----------
@@ -119,6 +128,8 @@ def load(cls, name, metadata_only=False, chunk=None, num_chunks=None):
number to load. This is a zero-based index.
num_chunks : int (optional)
The total number of chunks to use.
**kwargs
Additional arguments to `read_dataframes`
Returns
-------
@@ -131,68 +142,21 @@
if not os.path.exists(data_path):
raise AvocadoException("Couldn't find dataset %s!" % name)

if chunk is None:
# Load the full dataset
metadata = pd.read_hdf(data_path, 'metadata')
else:
# Load only part of the dataset.
if num_chunks is None:
raise AvocadoException(
"num_chunks must be specified to load the data in chunks!"
)

if chunk < 0 or chunk >= num_chunks:
raise AvocadoException(
"chunk must be in range [0, num_chunks)!"
)


# Use some pandas tricks to figure out which range of the indexes
# we want.
with pd.HDFStore(data_path, 'r') as store:
index = store.get_storer('metadata').table.colindexes['index']
num_rows = index.nelements

# Inclusive indices
start_idx = chunk * num_rows // num_chunks
end_idx = (chunk + 1) * num_rows // num_chunks - 1

# Use the HDF5 index to figure out the object_ids of the rows
# that we are interested in.
start_object_id = index.read_sorted(start_idx, start_idx+1)[0]
end_object_id = index.read_sorted(end_idx, end_idx+1)[0]

start_object_id = start_object_id.decode().strip()
end_object_id = end_object_id.decode().strip()

match_str = (
"(index >= '%s') & (index <= '%s')"
% (start_object_id, end_object_id)
)
metadata = pd.read_hdf(data_path, 'metadata', mode='r',
where=match_str)

if metadata_only:
observations = None
elif chunk is not None:
# Load only the observations for this chunk
match_str = (
"(object_id >= '%s') & (object_id <= '%s')"
% (start_object_id, end_object_id)
)
observations = pd.read_hdf(data_path, 'observations',
where=match_str)
keys = ['metadata']
else:
# Load all observations
observations = pd.read_hdf(data_path, 'observations')
keys = ['metadata', 'observations']

dataframes = read_dataframes(data_path, keys, chunk=chunk,
num_chunks=num_chunks, **kwargs)

# Create a Dataset object
dataset = cls(name, metadata, observations)
dataset = cls(name, *dataframes, chunk=chunk, num_chunks=num_chunks)

return dataset
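
The chunk-selection logic removed above now lives in `read_dataframes` in `avocado.utils`. A condensed sketch of what such a helper could look like, modeled on the removed code (an illustration, not necessarily the exact implementation in utils.py):

import pandas as pd

def read_dataframes(data_path, keys, chunk=None, num_chunks=None):
    """Illustrative chunked HDF5 reader following the code removed above."""
    if chunk is None:
        return [pd.read_hdf(data_path, key) for key in keys]

    if num_chunks is None or not 0 <= chunk < num_chunks:
        raise ValueError("chunk must be in range [0, num_chunks)")

    # Map the chunk number to an inclusive range of object_ids using the
    # sorted HDF5 index of the metadata table.
    with pd.HDFStore(data_path, 'r') as store:
        index = store.get_storer('metadata').table.colindexes['index']
        num_rows = index.nelements
        start_idx = chunk * num_rows // num_chunks
        end_idx = (chunk + 1) * num_rows // num_chunks - 1
        start_id = index.read_sorted(start_idx, start_idx + 1)[0].decode().strip()
        end_id = index.read_sorted(end_idx, end_idx + 1)[0].decode().strip()

    dataframes = []
    for key in keys:
        column = 'index' if key == 'metadata' else 'object_id'
        match_str = ("(%s >= '%s') & (%s <= '%s')"
                     % (column, start_id, column, end_id))
        dataframes.append(pd.read_hdf(data_path, key, where=match_str))
    return dataframes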

@classmethod
def from_objects(cls, name, objects):
def from_objects(cls, name, objects, **kwargs):
"""Load a dataset from a list of AstronomicalObject instances.
Parameters
@@ -201,6 +165,8 @@
A list of AstronomicalObject instances.
name : str
The name of the dataset.
**kwargs
Additional arguments to pass to Dataset()
Returns
-------
@@ -212,7 +178,7 @@
metadata.set_index('object_id', inplace=True)

# Load the new dataset.
dataset = cls(name, metadata, objects=objects)
dataset = cls(name, metadata, objects=objects, **kwargs)

return dataset

@@ -318,7 +284,7 @@ def _get_object(self, index=None, category=None, object_id=None, **kwargs):
=======
astronomical_object : AstronomicalObject
The object that was retrieved.
kwargs : dict
**kwargs
Additional arguments passed to the function that weren't used.
"""
return self.get_object(index, category, object_id), kwargs
@@ -371,8 +337,8 @@ def write(self, **kwargs):
Parameters
----------
kwargs : kwargs
Additional arguments to be passed to `utils.write_dataframes`
**kwargs
Additional arguments to be passed to `utils.write_dataframe`
"""
# Pull out the observations from every object
observations = []
@@ -382,11 +348,14 @@
observations.append(object_observations)
observations = pd.concat(observations, ignore_index=True, sort=False)

write_dataframes(
self.path,
[self.metadata, observations],
['metadata', 'observations'],
**kwargs
write_dataframe(
self.path, self.metadata, 'metadata', chunk=self.chunk,
num_chunks=self.num_chunks, **kwargs
)

write_dataframe(
self.path, observations, 'observations', index_chunk_column=False,
chunk=self.chunk, num_chunks=self.num_chunks, **kwargs
)
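
On the write side, `write_dataframe` presumably appends each chunk's rows to the same HDF5 file so the full dataset is assembled across many partial writes. A very rough sketch of the idea (an assumption about utils.py, not its actual code; the `index_chunk_column` option is accepted but not modeled here):

import pandas as pd

def write_dataframe(path, dataframe: pd.DataFrame, key, chunk=None,
                    num_chunks=None, index_chunk_column=True, **kwargs):
    """Illustrative chunk-aware writer, not the real avocado.utils helper."""
    # Table format is required so that later chunks can be appended and so
    # that readers can select rows with where= queries.
    append = chunk is not None and chunk > 0
    dataframe.to_hdf(path, key, mode='a', format='table', append=append,
                     **kwargs)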

def extract_raw_features(self, featurizer):
@@ -466,32 +435,28 @@ def write_raw_features(self, tag=None, **kwargs):
tag : str (optional)
The tag for this version of the features. By default, this will use
settings['features_tag'].
kwargs : kwargs
Additional arguments to be passed to `utils.write_dataframes`
**kwargs
Additional arguments to be passed to `utils.write_dataframe`
"""
raw_features_path = self.get_raw_features_path(tag=tag)

write_dataframes(
write_dataframe(
raw_features_path,
[self.raw_features],
['raw_features'],
self.raw_features,
'raw_features',
chunk=self.chunk,
num_chunks=self.num_chunks,
**kwargs
)

def load_raw_features(self, tag=None):
"""Load the raw features from disk.
Note: This method does not currently support reading by chunks. If
this is called on a chunked dataset, all of the raw features will be
loaded!
Parameters
----------
tag : str (optional)
The version of the raw features to use. By default, this will use
settings['features_tag'].
kwargs : kwargs
Additional arguments to be passed to `utils.write_dataframes`
Returns
-------
@@ -500,6 +465,12 @@ def load_raw_features(self, tag=None):
"""
raw_features_path = self.get_raw_features_path(tag=tag)

self.raw_features = pd.read_hdf(raw_features_path)
self.raw_features = read_dataframe(
raw_features_path,
'raw_features',
chunk=self.chunk,
num_chunks=self.num_chunks,
**kwargs
)

return self.raw_features
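
With both the features writer and reader now chunk-aware, feature extraction can likewise be run one chunk at a time. A short usage sketch, assuming a PLAsTiCC-style featurizer (the featurizer class and dataset names are illustrative):

import avocado

featurizer = avocado.plasticc.PlasticcFeaturizer()  # assumed featurizer class
num_chunks = 100

for chunk in range(num_chunks):
    dataset = avocado.load('plasticc_augment', chunk=chunk,
                           num_chunks=num_chunks)
    dataset.extract_raw_features(featurizer)
    # With chunk/num_chunks set on the dataset, each chunk's features are
    # written to (and later read back from) its own slice of the features file.
    dataset.write_raw_features()
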
2 changes: 1 addition & 1 deletion avocado/plasticc.py
@@ -178,7 +178,7 @@ def _augment_redshift(self, reference_object, augmented_metadata):
augmented_metadata['host_specz'] = aug_redshift
augmented_metadata['host_photoz'] = aug_photoz
augmented_metadata['host_photoz_error'] = aug_photoz_error
augmented_metadata['distmod'] = aug_distmod
augmented_metadata['augment_brightness'] = 0.

def _augment_metadata(self, reference_object):
"""Generate new metadata for the augmented object.
