add customized data loader (#2923)
* add customized data loader

Signed-off-by: Peng Zhang <pengz@uber.com>

* fix test

Signed-off-by: Peng Zhang <pengz@uber.com>

* add callbacks in integration test

Signed-off-by: Peng Zhang <pengz@uber.com>

* data_loader

Signed-off-by: Peng Zhang <pengz@uber.com>

* make the data loader interface more general.

Signed-off-by: Peng Zhang <pengz@uber.com>
irasit committed May 30, 2021
1 parent 4d62420 commit 52d0b27
Showing 11 changed files with 519 additions and 108 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).
## [Unreleased] - YYYY-MM-DD

### Added
- Custom Spark data loader interface. ([#2938](https://github.com/horovod/horovod/issues/2938))

- Added support to supply a logger and associated parameter to control the frequency of logging. ([#2926](https://github.com/horovod/horovod/pull/2926))

2 changes: 2 additions & 0 deletions docs/pytorch.rst
@@ -139,3 +139,5 @@ Start the training job and specify the number of workers on the command line as
You can find an example of using the PyTorch Lightning trainer with the Horovod backend in the `pytorch_lightning_mnist.py script <../examples/pytorch/pytorch_lightning_mnist.py>`__.

See the PyTorch Lightning `docs <https://pytorch-lightning.readthedocs.io/en/stable/multi_gpu.html#horovod>`_ for more details.

A PyTorch Lightning based Spark estimator is also available; an example is in `pytorch_lightning_spark_mnist.py <../examples/spark/pytorch/pytorch_lightning_spark_mnist.py>`__.
2 changes: 2 additions & 0 deletions docs/spark.rst
@@ -96,6 +96,8 @@ logging (for Tensorboard) using the Estimator ``Store`` abstraction. Stores are
artifacts including intermediate representations of the training data. Horovod natively supports stores for HDFS
and local filesystems.

A `Petastorm <https://github.com/uber/petastorm/blob/master/petastorm/pytorch.py#L258>`__ based data loader is used by default, but users can define a custom data loader by overriding the ``BaseDataLoader`` interface. An async data loader mixin can also be added on top of the data loader.
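
For illustration, a minimal sketch of such a custom loader is shown below. The ``ListDataLoader`` and ``AsyncListDataLoader`` names are hypothetical and only demonstrate the intended composition, with the mixin listed first so that its ``__iter__()`` takes precedence:

.. code-block:: python

    from horovod.data import BaseDataLoader, AsyncDataLoaderMixin

    class ListDataLoader(BaseDataLoader):
        """Hypothetical loader that yields batches from an in-memory list."""

        def __init__(self, batches):
            self.batches = batches

        def __len__(self):
            return len(self.batches)

        def _iterate(self):
            # BaseDataLoader.__iter__() consumes this generator and applies
            # _process_batch() to every batch before yielding it.
            for batch in self.batches:
                yield batch

    class AsyncListDataLoader(AsyncDataLoaderMixin, ListDataLoader):
        # Prefetches batches on a background thread. Pass the loader's own
        # arguments as keywords, e.g. AsyncListDataLoader(batches=my_batches),
        # so they are not mistaken for async_loader_queue_size.
        pass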

End-to-end example
------------------
`keras_spark_rossmann_estimator.py script <../examples/spark/keras/keras_spark_rossmann_estimator.py>`__ provides
31 changes: 21 additions & 10 deletions examples/spark/pytorch/pytorch_lightning_spark_mnist.py
@@ -95,8 +95,7 @@ def __init__(self):
self.fc2 = nn.Linear(50, 10)

def forward(self, x):
#raise RuntimeError("x shape is {}".format(x.shape))
x = x.float()
x = x.float().reshape((-1, 1, 28, 28))
x = F.relu(F.max_pool2d(self.conv1(x), 2))
x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
x = x.view(-1, 320)
@@ -112,18 +111,18 @@ def training_step(self, batch, batch_nb):
x, y = batch['features'], batch['label']
y_hat = self(x)
loss = F.nll_loss(y_hat, y.long())
tensorboard_logs = {'train_loss': loss}
return {'loss': loss, 'log': tensorboard_logs}
self.log('train_loss', loss)
return loss

def validation_step(self, batch, batch_nb):
x, y = batch['features'], batch['label']
y_hat = self(x)
return {'val_loss': F.nll_loss(y_hat, y.long())}
loss = F.nll_loss(y_hat, y.long())
self.log('val_loss', loss)

def validation_epoch_end(self, outputs):
avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
tensorboard_logs = {'val_loss': avg_loss}
return {'avg_val_loss': avg_loss, 'log': tensorboard_logs}
avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean() if len(outputs) > 0 else float('inf')
self.log('avg_val_loss', avg_loss)

model = Net()

@@ -156,12 +155,24 @@ def on_train_epoch_end(self, trainer, model, unused=None):
self.train_epcoh_end_counter += 1

def on_train_end(self, trainer, model):
print('Training ends')
assert self.epcoh_end_counter == 2 * epochs
print("Training ends:"
f"self.epcoh_end_counter={self.epcoh_end_counter}, "
f"self.train_epcoh_end_counter={self.train_epcoh_end_counter}")
assert self.train_epcoh_end_counter == epochs

callbacks = [MyDummyCallback()]

# added EarlyStopping and ModelCheckpoint
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
callbacks.append(ModelCheckpoint(dirpath=args.work_dir))

from pytorch_lightning.callbacks.early_stopping import EarlyStopping
callbacks.append(EarlyStopping(monitor='val_loss',
min_delta=0.00,
patience=3,
verbose=True,
mode='max'))

torch_estimator = hvd.TorchEstimator(backend=backend,
store=store,
model=model,
19 changes: 19 additions & 0 deletions horovod/data/__init__.py
@@ -0,0 +1,19 @@
# Copyright 2019 Uber Technologies, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from horovod.data.data_loader_base import (
BaseDataLoader,
AsyncDataLoaderMixin
)
132 changes: 132 additions & 0 deletions horovod/data/data_loader_base.py
@@ -0,0 +1,132 @@
# Copyright 2019 Uber Technologies, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from queue import Queue, Empty
from threading import Thread, Event


class BaseDataLoader(object):
def __len__(self):
"""
Length of the batches to be loaded.
"""
raise NotImplementedError()

def _iterate(self):
"""
Interface for the implementation of batch iteration.
"""
raise NotImplementedError()

def __iter__(self):
"""
Start iteration and get batches.
"""
for batch in self._iterate():
yield self._process_batch(batch)

def _process_batch(self, batch):
"""
Hook to modify the batch before output. Will be overwritten by the trainer to reshape the data
as needed. Please do not override it.
"""
return batch


class AsyncDataLoaderMixin(object):
"""
Async mixin on top of an implementation of BaseDataLoader. It contains a separate thread
which reads batches from self._iterate() and pushes them into the queue. The self.__iter__() function
will pop the batches from the queue.
If async_loader_queue_size is set to 0, the data loader will not work in async mode.
For example:
class PytorchAsyncDataLoader(AsyncDataLoaderMixin, PytorchDataLoader):
"""

def __init__(self, async_loader_queue_size=64, *args, **kwargs):
"""
Initialize the async data loader. This needs to be called from the __init__() of the implementation.
"""
self.async_loader_queue_size = async_loader_queue_size
super().__init__(*args, **kwargs)

print(f"Apply the AsyncDataLoaderMixin on top of the data loader, async_loader_queue_size={async_loader_queue_size}. ")

if self.async_loader_queue_size > 0:
self.finished_event = Event()
self.queue = Queue(self.async_loader_queue_size)
self.thread = Thread(target=self._async_worker)
self.thread.daemon = True
self.started = False

def __del__(self):
self._close_async_loader()
s = super()
if hasattr(s, "__del__"):
s.__del__(self)

def _close_async_loader(self):
"""
Close the async data loader.
"""
print("Closing the AsyncDataLoaderMixin.")
if self.async_loader_queue_size > 0 and self.started:
self.finished_event.set()
try:
# Free buffer to allow worker to retry
self.queue.get_nowait()
except Empty:
pass
self.thread.join()

def _async_worker(self):
"""
Worker thread target that loads data asynchronously.
The user needs to implement self._iterate() to read the data.
"""
try:
while not self.finished_event.is_set():
for batch in self._iterate():
if self.finished_event.is_set():
break
self.queue.put(batch)
self.queue.put(None)
except Exception as ex:
self.queue.put(ex)
self.queue.put(None)
finally:
self.queue.put(None)

def __iter__(self):
"""
Override __iter__() to iterate data asynchronously and produce batches.
Batches are produced from the queue, which is filled by self._iterate() in the worker thread.
"""

print("Start generating batches from async data loader.")
if self.async_loader_queue_size > 0:
if not self.started:
self.started = True
self.thread.start()
while True:
batch = self.queue.get()
if batch is None:
break
if isinstance(batch, Exception):
raise batch
yield self._process_batch(batch)
else:
for batch in self._iterate():
yield self._process_batch(batch)
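
The `_process_batch` hook above is meant to be overwritten by the trainer rather than by user subclasses. A hedged sketch of what such a replacement could look like is shown here; the helper name and the dict-of-tensors batch layout are assumptions for illustration, not Horovod code.

def attach_reshape_hook(loader, shape):
    # Hypothetical helper: replace _process_batch on the loader instance so
    # that __iter__() reshapes every batch before yielding it. Assumes each
    # batch is a dict mapping column names to tensors.
    def _process_batch(batch):
        return {name: tensor.reshape(shape) for name, tensor in batch.items()}
    loader._process_batch = _process_batch
    return loader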
19 changes: 19 additions & 0 deletions horovod/spark/data_loaders/__init__.py
@@ -0,0 +1,19 @@
# Copyright 2019 Uber Technologies, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from horovod.spark.data_loaders.pytorch_data_loaders import (
PytorchDataLoader,
PytorchAsyncDataLoader
)
53 changes: 53 additions & 0 deletions horovod/spark/data_loaders/pytorch_data_loaders.py
@@ -0,0 +1,53 @@
# Copyright 2019 Uber Technologies, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from petastorm.pytorch import BatchedDataLoader
from horovod.data import BaseDataLoader, AsyncDataLoaderMixin


class PytorchDataLoader(BaseDataLoader):
def __init__(self, reader, batch_size, shuffling_queue_capacity):
self.reader = reader
self.batch_size = batch_size
self.shuffling_queue_capacity = shuffling_queue_capacity
print(f"Initializing petastorm dataloader with batch_size {batch_size}"
f" and shuffling_queue_capacity {shuffling_queue_capacity}")

def __len__(self):
return len(self.reader)

def _iterate(self):
# Reset the reader if needed.
if self.reader.last_row_consumed:
print(f"Resetting Petastorm reader for {self.reader.dataset.paths}")
self.reader.reset()

# Re-create the data loader for each iteration. This is needed because there may be
# some leftover data from the last epoch which can cause petastorm's BatchedDataLoader
# to fail to start a new iteration. To work around the issue, we have to re-create the
# data loader each time a new iteration starts.
data_loader = BatchedDataLoader(
self.reader,
batch_size=self.batch_size,
shuffling_queue_capacity=self.shuffling_queue_capacity,
)

for batch in data_loader:
yield batch


class PytorchAsyncDataLoader(AsyncDataLoaderMixin, PytorchDataLoader):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
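
A hedged usage sketch of these loaders follows; the Parquet path is a placeholder, and the snippet only illustrates how the classes are meant to be constructed and iterated, not how the Spark estimator wires them in.

from petastorm import make_batch_reader
from horovod.spark.data_loaders import PytorchDataLoader, PytorchAsyncDataLoader

# Placeholder dataset URL; any Parquet store readable by petastorm works here.
with make_batch_reader('file:///tmp/train_data.parquet', num_epochs=1) as reader:
    loader = PytorchDataLoader(reader, batch_size=32, shuffling_queue_capacity=1000)
    for batch in loader:
        # Each batch is typically a dict mapping column names to tensors.
        pass

    # The async variant prefetches batches on a background thread once iterated.
    # Pass the loader's arguments as keywords so the mixin's first positional
    # parameter (async_loader_queue_size) is not shadowed.
    async_loader = PytorchAsyncDataLoader(reader=reader,
                                          batch_size=32,
                                          shuffling_queue_capacity=1000)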
