Fast loading of chunks from large datasets.
kboone committed May 14, 2019
1 parent 2172c79 commit 9f75f67
Showing 2 changed files with 39 additions and 15 deletions.
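
In practice, this lets Dataset.load pull one slice of a large dataset at a time instead of reading the whole file into memory. A minimal usage sketch (the dataset name and chunk counts are hypothetical, and the import path assumes the Dataset class defined in avocado/dataset.py below):

    from avocado.dataset import Dataset

    # Hypothetical dataset name; load only chunk 3 out of 100 roughly equal chunks.
    # Both the metadata and the observations are restricted to that range of object_ids.
    dataset = Dataset.load('plasticc_test', chunk=3, num_chunks=100)
    print(len(dataset.metadata))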
48 changes: 36 additions & 12 deletions avocado/dataset.py
@@ -89,32 +89,52 @@ def load(cls, dataset_name, metadata_only=False, chunk=None,
         if not os.path.exists(data_path):
             raise AvocadoException("Couldn't find dataset %s!" % dataset_name)

-        metadata = pd.read_hdf(data_path, 'metadata')
-
-        if chunk is not None:
+        if chunk is None:
+            # Load the full dataset
+            metadata = pd.read_hdf(data_path, 'metadata')
+        else:
             # Load only part of the dataset.
             if num_chunks is None:
                 raise AvocadoException(
                     "num_chunks must be specified to load the data in chunks!"
                 )

-            # Ensure that the metadata is sorted by the index or we will run
-            # into major issues.
-            metadata.sort_index(inplace=True)
+            if chunk < 0 or chunk > num_chunks:
+                raise AvocadoException(
+                    "chunk must be in range [0, num_chunks)!"
+                )
+
+
+            # Use some pandas tricks to figure out which range of the indexes
+            # we want.
+            with pd.HDFStore(data_path) as store:
+                index = store.get_storer('metadata').table.colindexes['index']
+                num_rows = index.nelements

-            start_idx = chunk * len(metadata) // num_chunks
-            end_idx = (chunk + 1) * len(metadata) // num_chunks
+                # Inclusive indices
+                start_idx = chunk * num_rows // num_chunks
+                end_idx = (chunk + 1) * num_rows // num_chunks - 1

-            start_object_id = metadata.index[start_idx]
-            end_object_id = metadata.index[end_idx]
+                # Use the HDF5 index to figure out the object_ids of the rows
+                # that we are interested in.
+                start_object_id = index.read_sorted(start_idx, start_idx+1)[0]
+                end_object_id = index.read_sorted(end_idx, end_idx+1)[0]

-            metadata = metadata.iloc[start_idx:end_idx]
+            start_object_id = start_object_id.decode().strip()
+            end_object_id = end_object_id.decode().strip()
+
+            match_str = (
+                "(index >= '%s') & (index <= '%s')"
+                % (start_object_id, end_object_id)
+            )
+            metadata = pd.read_hdf(data_path, 'metadata', where=match_str)
+
         if metadata_only:
             observations = None
         elif chunk is not None:
             # Load only the observations for this chunk
             match_str = (
-                "(object_id >= '%s') & (object_id < '%s')"
+                "(object_id >= '%s') & (object_id <= '%s')"
                 % (start_object_id, end_object_id)
             )
             observations = pd.read_hdf(data_path, 'observations',
@@ -126,6 +146,10 @@
         # Create a Dataset object
         dataset = Dataset(dataset_name, metadata, observations)

+        # Label folds if we have a full dataset with fold information
+        if chunk is None and 'category' in dataset.metadata:
+            dataset.label_folds()
+
         return dataset


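The core of the speed-up is that the chunk boundaries are computed from the HDF5 file's own index rather than from a fully loaded metadata table. A condensed sketch of that trick on its own, assuming the metadata was written with format='table' so that PyTables keeps a queryable index (as the download script below now does); the helper name is only for illustration:

    import pandas as pd

    def load_metadata_chunk(data_path, chunk, num_chunks):
        """Illustrative helper: read only the metadata rows in the given chunk."""
        with pd.HDFStore(data_path) as store:
            # PyTables column index over the table's object_id index.
            index = store.get_storer('metadata').table.colindexes['index']
            num_rows = index.nelements

            # Inclusive row range covered by this chunk.
            start_idx = chunk * num_rows // num_chunks
            end_idx = (chunk + 1) * num_rows // num_chunks - 1

            # Boundary object_ids, read in sorted order straight from the index.
            start_object_id = index.read_sorted(start_idx, start_idx + 1)[0].decode().strip()
            end_object_id = index.read_sorted(end_idx, end_idx + 1)[0].decode().strip()

        # The where clause is evaluated by PyTables, so only matching rows come off disk.
        match_str = "(index >= '%s') & (index <= '%s')" % (start_object_id, end_object_id)
        return pd.read_hdf(data_path, 'metadata', where=match_str)

The observations table is then filtered with the same object_id bounds on its indexed object_id data column, so the metadata and observations of a chunk stay aligned.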
6 changes: 3 additions & 3 deletions scripts/download_plasticc.py
@@ -137,7 +137,7 @@ def preprocess_observations(input_path, output_path, chunk_size=10**6):
desc=" %s" % os.path.basename(input_path)):
chunk = update_plasticc_observations(chunk)
chunk.to_hdf(output_path, 'observations', mode='a', append=True,
data_columns=['object_id'])
format='table', data_columns=['object_id'])


if __name__ == "__main__":
@@ -174,15 +174,15 @@ def preprocess_observations(input_path, output_path, chunk_size=10**6):
         rawdir, 'plasticc_train_metadata.csv.gz')
     train_metadata = pd.read_csv(raw_train_metadata_path)
     train_metadata = update_plasticc_metadata(train_metadata)
-    train_metadata.to_hdf(train_path, 'metadata', mode='a',
+    train_metadata.to_hdf(train_path, 'metadata', mode='a', format='table',
                           data_columns=['object_id'])

     print("Preprocessing test metadata...")
     raw_test_metadata_path = os.path.join(
         rawdir, 'plasticc_test_metadata.csv.gz')
     test_metadata = pd.read_csv(raw_test_metadata_path)
     test_metadata = update_plasticc_metadata(test_metadata)
-    test_metadata.to_hdf(test_path, 'metadata', mode='a',
+    test_metadata.to_hdf(test_path, 'metadata', mode='a', format='table',
                          data_columns=['object_id'])

     print("Preprocessing training observations...")
