Skip to content

Commit

Permalink
Spark/Lightning: add reader_worker_count and reader_pool_type (horovo…
Browse files Browse the repository at this point in the history
…d#3612)

Signed-off-by: Lee Yang <leey@nvidia.com>
  • Loading branch information
chongxiaoc authored and leewyang committed Aug 5, 2022
1 parent fb4d6eb commit 5481ace
Showing 1 changed file with 4 additions and 0 deletions.
4 changes: 4 additions & 0 deletions horovod/spark/lightning/datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@ def setup(self, stage=None):
reader_factory = make_batch_reader

self.train_reader = reader_factory(self.train_dir, num_epochs=self.num_reader_epochs,
reader_pool_type=self.reader_pool_type,
workers_count=self.reader_worker_count,
cur_shard=self.cur_shard, shard_count=self.shard_count,
hdfs_driver=PETASTORM_HDFS_DRIVER,
schema_fields=self.schema_fields,
Expand All @@ -72,6 +74,8 @@ def setup(self, stage=None):
**reader_factory_kwargs)
if self.has_val:
self.val_reader = reader_factory(self.val_dir, num_epochs=self.num_reader_epochs,
reader_pool_type=self.reader_pool_type,
workers_count=self.reader_worker_count,
cur_shard=self.cur_shard, shard_count=self.shard_count,
hdfs_driver=PETASTORM_HDFS_DRIVER,
schema_fields=self.schema_fields,
Expand Down

0 comments on commit 5481ace

Please sign in to comment.