-
Notifications
You must be signed in to change notification settings - Fork 2.2k
/
datamodule.py
130 lines (120 loc) · 6.54 KB
/
datamodule.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
import pytorch_lightning as pl
from horovod.spark.common import constants
from horovod.spark.data_loaders.pytorch_data_loaders import (
PytorchInfiniteAsyncDataLoader,
PytorchInmemAsyncDataLoader)
from petastorm import TransformSpec, make_reader, make_batch_reader
# HDFS driver name handed to every Petastorm reader created in this module
# (passed as the ``hdfs_driver`` argument); the value comes from
# ``horovod.spark.common.constants``.
PETASTORM_HDFS_DRIVER = constants.PETASTORM_HDFS_DRIVER
class PetastormDataModule(pl.LightningDataModule):
    """Default DataModule for Lightning Estimator.

    Opens Petastorm readers over a training directory (and optionally a
    validation directory) in :meth:`setup`, wraps them in asynchronous PyTorch
    data loaders in :meth:`train_dataloader` / :meth:`val_dataloader`, and
    releases both in :meth:`teardown`.

    Args:
        train_dir: Location of the training dataset, passed to the reader factory.
        val_dir: Location of the validation dataset; only read when ``has_val``.
        num_train_epochs: Passed as ``num_epochs`` to the in-memory data loader.
        has_val: Whether a validation reader/data loader should be created.
        train_batch_size: Batch size of the training data loader.
        val_batch_size: Batch size of the validation data loader.
        shuffle_size: Shuffling queue capacity for the training data loader.
            With ``inmemory_cache_all`` it only toggles shuffling (``> 0``).
        num_reader_epochs: Passed as ``num_epochs`` to the Petastorm readers.
        reader_pool_type: Passed as ``reader_pool_type`` to the Petastorm readers.
        reader_worker_count: Passed as ``workers_count`` to the Petastorm readers.
        transform_spec: Optional transformation; when truthy it is wrapped in a
            ``TransformSpec`` and applied by the readers.
        inmemory_cache_all: Use ``PytorchInmemAsyncDataLoader`` instead of
            ``PytorchInfiniteAsyncDataLoader``.
        cur_shard: Shard index read by this module's readers.
        shard_count: Total number of shards.
        schema_fields: Passed as ``schema_fields`` to the Petastorm readers.
        storage_options: Passed as ``storage_options`` to the Petastorm readers.
        steps_per_epoch_train: Passed as ``limit_step_per_epoch`` to the
            training data loader.
        steps_per_epoch_val: Passed as ``limit_step_per_epoch`` to the
            validation data loader.
        verbose: Print progress messages and forward ``verbose`` to the loaders.
        **kwargs: Accepted for call-site compatibility and ignored.
    """

    def __init__(self, train_dir: str, val_dir: str, num_train_epochs: int = 1, has_val: bool = True,
                 train_batch_size: int = 32, val_batch_size: int = 32, shuffle_size: int = 1000,
                 num_reader_epochs=None, reader_pool_type: str = "process", reader_worker_count: int = 2,
                 transform_spec=None, inmemory_cache_all=False,
                 cur_shard: int = 0, shard_count: int = 1, schema_fields=None, storage_options=None,
                 steps_per_epoch_train: int = 1, steps_per_epoch_val: int = 1, verbose=True, **kwargs):
        super().__init__()
        self.train_dir = train_dir
        self.val_dir = val_dir
        self.num_train_epochs = num_train_epochs
        self.has_val = has_val
        self.train_batch_size = train_batch_size
        self.val_batch_size = val_batch_size
        self.shuffle_size = shuffle_size
        self.num_reader_epochs = num_reader_epochs
        self.reader_pool_type = reader_pool_type
        self.reader_worker_count = reader_worker_count
        self.transform_spec = transform_spec
        self.inmemory_cache_all = inmemory_cache_all
        self.cur_shard = cur_shard
        self.shard_count = shard_count
        self.schema_fields = schema_fields
        self.storage_options = storage_options
        self.steps_per_epoch_train = steps_per_epoch_train
        self.steps_per_epoch_val = steps_per_epoch_val
        self.verbose = verbose

    def _open_reader(self, reader_factory, dataset_dir, transform_spec):
        """Create one Petastorm reader over ``dataset_dir`` with this module's settings."""
        return reader_factory(dataset_dir, num_epochs=self.num_reader_epochs,
                              # Honor the pool configuration given to the constructor;
                              # it used to be stored but never forwarded.
                              reader_pool_type=self.reader_pool_type,
                              workers_count=self.reader_worker_count,
                              cur_shard=self.cur_shard, shard_count=self.shard_count,
                              hdfs_driver=PETASTORM_HDFS_DRIVER,
                              schema_fields=self.schema_fields,
                              storage_options=self.storage_options,
                              # Without this the user's transformation is silently dropped.
                              transform_spec=transform_spec)

    def setup(self, stage=None):
        """Open the train (and, if ``has_val``, validation) Petastorm readers.

        Only acts when ``stage`` is ``'fit'`` or ``None``.
        """
        # Assign train/val datasets for use in dataloaders
        if stage == 'fit' or stage is None:
            transform_spec = TransformSpec(self.transform_spec) if self.transform_spec else None
            # In general, make_batch_reader is faster than make_reader for reading the dataset.
            # However, we found out that make_reader performs data transformations much faster than
            # make_batch_reader with parallel worker processes. Therefore, the default reader
            # we choose is make_batch_reader unless there are data transformations.
            if transform_spec:
                reader_factory = make_reader
            else:
                reader_factory = make_batch_reader

            self.train_reader = self._open_reader(reader_factory, self.train_dir, transform_spec)
            if self.has_val:
                self.val_reader = self._open_reader(reader_factory, self.val_dir, transform_spec)

    @staticmethod
    def _stop_reader(reader):
        """Stop and join ``reader``; a missing (``None``) reader is ignored."""
        if reader is not None:
            reader.stop()
            reader.join()

    def teardown(self, stage=None):
        """Close the async data loaders and, unless cached in memory, the readers.

        Only acts when ``stage`` is ``'fit'`` or ``None``. Loaders or readers
        that were never created (e.g. ``setup`` or ``train_dataloader`` did not
        run because of an earlier failure) are skipped instead of raising
        ``AttributeError`` and masking the original error.
        """
        if stage == "fit" or stage is None:
            if self.verbose:
                print("Tear down: closing async dataloaders")
            train_dl = getattr(self, 'train_dl', None)
            if train_dl is not None:
                train_dl.close_async_loader()
            if self.has_val:
                val_dl = getattr(self, 'val_dl', None)
                if val_dl is not None:
                    val_dl.close_async_loader()
            if not self.inmemory_cache_all:
                # Reader was loaded once and stopped for inmemory dataloader.
                if self.verbose:
                    print("Tear down: closing petastorm readers")
                self._stop_reader(getattr(self, 'train_reader', None))
                if self.has_val:
                    self._stop_reader(getattr(self, 'val_reader', None))
            if self.verbose:
                print("Tear down: async dataloaders closed.")

    def _build_dataloader(self, reader, batch_size, name, steps_per_epoch,
                          shuffle_size, queue_row_budget):
        """Create the async data loader matching ``inmemory_cache_all``.

        ``queue_row_budget // batch_size`` is one of the two upper bounds on the
        number of batches kept in the async loader queue (the other being a
        quarter of ``steps_per_epoch``).
        """
        kwargs = dict(reader=reader, batch_size=batch_size,
                      name=name,
                      limit_step_per_epoch=steps_per_epoch,
                      verbose=self.verbose)
        if self.inmemory_cache_all:
            # Use inmem dataloader
            dataloader_class = PytorchInmemAsyncDataLoader
            kwargs['shuffle'] = shuffle_size > 0
            kwargs['num_epochs'] = self.num_train_epochs
        else:
            dataloader_class = PytorchInfiniteAsyncDataLoader
            kwargs['shuffling_queue_capacity'] = shuffle_size
            # To avoid loading too much data in memory, need to calculate the queue size
            # dynamically, and limit the data loaded in queue.
            # Add 1 in size for storing the None in the end of each epoch.
            kwargs['async_loader_queue_size'] = max(
                1, min(queue_row_budget // batch_size, steps_per_epoch // 4)) + 1
        return dataloader_class(**kwargs)

    def train_dataloader(self):
        """Build, remember (``self.train_dl``) and return the training data loader."""
        if self.verbose:
            print("Setup train dataloader")
        self.train_dl = self._build_dataloader(
            reader=self.train_reader, batch_size=self.train_batch_size,
            name="train dataloader", steps_per_epoch=self.steps_per_epoch_train,
            shuffle_size=self.shuffle_size, queue_row_budget=100000)
        return self.train_dl

    def val_dataloader(self):
        """Build, remember (``self.val_dl``) and return the validation data loader.

        Returns ``None`` when the module was created with ``has_val=False``.
        Validation data is never shuffled.
        """
        if not self.has_val:
            return None
        if self.verbose:
            print("setup val dataloader")
        self.val_dl = self._build_dataloader(
            reader=self.val_reader, batch_size=self.val_batch_size,
            name="val dataloader", steps_per_epoch=self.steps_per_epoch_val,
            shuffle_size=0, queue_row_budget=10000)
        return self.val_dl