
Commit 1caf8be

Datamodule (Lightning-AI#2668)
* ✨ Add copy of pl_bolts datamodule to lightning
* ✨ add datamodule to necessary init files
* 🚧 add datamodule property to LightningModule
* 🚧 .
* 🎨 Let DataModule do its own thing
* 🚧 add back setup and run both hooks implicitly
* 🚧 .
* 🐛 fix add_argparse_args
* 💄 apply black formatting and isort
* 📝 docstrings
* 📝 .
* 📝 .
* 🐛 overwrite cls prepare_data instead of instance
* 📝 .
* ✅ add some tests
* Update datamodule.py
* Update datamodule.py
* Update datamodule.py

Co-authored-by: William Falcon <waf2107@columbia.edu>
1 parent 938ec5a commit 1caf8be

8 files changed: +589 / -145 lines


pytorch_lightning/__init__.py

Lines changed: 8 additions & 4 deletions
@@ -7,8 +7,10 @@
 __copyright__ = 'Copyright (c) 2018-2020, %s.' % __author__
 __homepage__ = 'https://github.com/PyTorchLightning/pytorch-lightning'
 # this has to be simple string, see: https://github.com/pypa/twine/issues/522
-__docs__ = "PyTorch Lightning is the lightweight PyTorch wrapper for ML researchers." \
-           " Scale your models. Write less boilerplate."
+__docs__ = (
+    "PyTorch Lightning is the lightweight PyTorch wrapper for ML researchers."
+    " Scale your models. Write less boilerplate."
+)
 __long_docs__ = """
 Lightning is a way to organize your PyTorch code to decouple the science code from the engineering.
 It's more of a style-guide than a framework.
@@ -47,10 +49,11 @@
 
 if __LIGHTNING_SETUP__:
     import sys  # pragma: no-cover
+
     sys.stdout.write(f'Partial import of `{__name__}` during the build process.\n')  # pragma: no-cover
     # We are not importing the rest of the lightning during the build process, as it may not be compiled yet
 else:
-    from pytorch_lightning.core import LightningModule, data_loader
+    from pytorch_lightning.core import LightningDataModule, LightningModule, data_loader
     from pytorch_lightning.callbacks import Callback
     from pytorch_lightning.trainer import Trainer
     from pytorch_lightning.utilities.seed import seed_everything
@@ -59,13 +62,14 @@
 
 __all__ = [
     'Trainer',
+    'LightningDataModule',
     'LightningModule',
     'Callback',
     'data_loader',
     'seed_everything',
     'metrics',
     'EvalResult',
-    'TrainResult'
+    'TrainResult',
 ]
 
 # necessary for regular bolts imports. Skip exception since bolts is not always installed

pytorch_lightning/core/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -336,8 +336,9 @@ def training_step(self, batch, batch_idx):
 
 """
 
+from pytorch_lightning.core.datamodule import LightningDataModule
 from pytorch_lightning.core.decorators import data_loader
 from pytorch_lightning.core.lightning import LightningModule
 
-__all__ = ['LightningModule', 'data_loader']
+__all__ = ['LightningDataModule', 'LightningModule', 'data_loader']
 # __call__ = __all__
pytorch_lightning/core/datamodule.py (new file)

Lines changed: 314 additions & 0 deletions
@@ -0,0 +1,314 @@

import inspect
from abc import abstractmethod
from argparse import ArgumentParser, Namespace
from typing import Any, List, Tuple, Union

from torch.utils.data import DataLoader

from pytorch_lightning.utilities import parsing, rank_zero_only, rank_zero_warn


class _DataModuleWrapper(type):
    def __call__(cls, *args, **kwargs):
        """A wrapper for LightningDataModule that:

        1. Runs user defined subclass's __init__
        2. Assures prepare_data() runs on rank 0
        """

        # Wrap cls's prepare_data function with rank_zero_only
        cls.prepare_data = rank_zero_only(cls.prepare_data)

        # Get instance of LightningDataModule by mocking its __init__ via __call__
        obj = type.__call__(cls, *args, **kwargs)

        return obj

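The metaclass above intercepts instantiation: each time a subclass is constructed, its prepare_data is re-bound to a rank_zero_only-wrapped version before __init__ runs (the "overwrite cls prepare_data instead of instance" commit in the message above). A self-contained toy analogue, not part of the committed file, with a made-up run_once decorator standing in for rank_zero_only:

def run_once(fn):
    # toy stand-in for rank_zero_only: silently skip repeat calls
    def wrapped(*args, **kwargs):
        if not getattr(wrapped, '_called', False):
            wrapped._called = True
            return fn(*args, **kwargs)
    return wrapped


class _Wrapper(type):
    def __call__(cls, *args, **kwargs):
        # decorate the class attribute, then build the instance as usual
        cls.prepare_data = run_once(cls.prepare_data)
        return type.__call__(cls, *args, **kwargs)


class Demo(metaclass=_Wrapper):
    def prepare_data(self):
        print('downloading...')


d = Demo()
d.prepare_data()  # prints 'downloading...'
d.prepare_data()  # no-op: the wrapper already ran once
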
class LightningDataModule(object, metaclass=_DataModuleWrapper):  # pragma: no cover
    """
    A DataModule standardizes the training, val, test splits, data preparation and transforms.
    The main advantage is consistent data splits, data preparation and transforms across models.

    Example::

        class MyDataModule(LightningDataModule):
            def __init__(self):
                super().__init__()
            def prepare_data(self):
                # download, split, etc...
                # only called on 1 GPU/TPU in distributed
            def setup(self):
                # make assignments here (val/train/test split)
                # called on every process in DDP
            def train_dataloader(self):
                train_split = Dataset(...)
                return DataLoader(train_split)
            def val_dataloader(self):
                val_split = Dataset(...)
                return DataLoader(val_split)
            def test_dataloader(self):
                test_split = Dataset(...)
                return DataLoader(test_split)

    A DataModule implements 5 key methods:

    * **prepare_data** (things to do on 1 GPU/TPU, not on every GPU/TPU in distributed mode).
    * **setup** (things to do on every accelerator in distributed mode).
    * **train_dataloader** the training dataloader.
    * **val_dataloader** the val dataloader(s).
    * **test_dataloader** the test dataloader(s).

    This allows you to share a full dataset without explaining how to download,
    split, transform and process the data.
    """

    name: str = ...

    def __init__(self, train_transforms=None, val_transforms=None, test_transforms=None):
        super().__init__()
        self._train_transforms = train_transforms
        self._val_transforms = val_transforms
        self._test_transforms = test_transforms
        self.dims = ()

    @property
    def train_transforms(self):
        """
        Optional transforms (or collection of transforms) you can apply to the train dataset.
        """
        return self._train_transforms

    @train_transforms.setter
    def train_transforms(self, t):
        self._train_transforms = t

    @property
    def val_transforms(self):
        """
        Optional transforms (or collection of transforms) you can apply to the validation dataset.
        """
        return self._val_transforms

    @val_transforms.setter
    def val_transforms(self, t):
        self._val_transforms = t

    @property
    def test_transforms(self):
        """
        Optional transforms (or collection of transforms) you can apply to the test dataset.
        """
        return self._test_transforms

    @test_transforms.setter
    def test_transforms(self, t):
        self._test_transforms = t

    def size(self, dim=None) -> Union[Tuple, int]:
        """
        Return the dimension of each input either as a tuple or list of tuples.
        """
        if dim is not None:
            return self.dims[dim]

        return self.dims

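A quick sketch, not part of the commit, of how dims and size() interact; the subclass and shape are hypothetical:

class _ShapeDemo(LightningDataModule):
    def __init__(self):
        super().__init__()
        self.dims = (1, 28, 28)  # e.g. a single-channel 28x28 image


dm = _ShapeDemo()
assert dm.size() == (1, 28, 28)  # no argument: the full shape tuple
assert dm.size(0) == 1           # index a single dimension
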
    @abstractmethod
    def prepare_data(self, *args, **kwargs):
        """
        Use this to download and prepare data.
        In distributed (GPU, TPU), this will only be called once.

        .. warning:: Do not assign anything to the datamodule in this step since this will only be called on 1 GPU.

        Pseudocode::

            dm.prepare_data()
            dm.setup()

        Example::

            def prepare_data(self):
                download_imagenet()
                clean_imagenet()
                cache_imagenet()
        """

    @abstractmethod
    def setup(self, *args, **kwargs):
        """
        Use this to load your data from file, split it, etc. You are safe to make state assignments here.
        This hook is called on every process when using DDP.

        Example::

            def setup(self):
                data = load_data(...)
                self.train_ds, self.val_ds, self.test_ds = split_data(data)
        """

    @abstractmethod
    def train_dataloader(self, *args, **kwargs) -> DataLoader:
        """
        Implement a PyTorch DataLoader for training.

        Return:
            Single PyTorch :class:`~torch.utils.data.DataLoader`.

        Note:
            Lightning adds the correct sampler for distributed and arbitrary hardware.
            There is no need to set it yourself.

        Example::

            def train_dataloader(self):
                dataset = MNIST(root=PATH, train=True, transform=transforms.ToTensor(), download=False)
                loader = torch.utils.data.DataLoader(dataset=dataset)
                return loader

        """
        rank_zero_warn('`train_dataloader` must be implemented to be used with the Lightning Trainer')

    @abstractmethod
    def val_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]:
        r"""
        Implement a PyTorch DataLoader for validation.

        Return:
            Single PyTorch :class:`~torch.utils.data.DataLoader`.

        Note:
            Lightning adds the correct sampler for distributed and arbitrary hardware.
            There is no need to set it yourself.

        Note:
            You can also return a list of DataLoaders.

        Example::

            def val_dataloader(self):
                dataset = MNIST(root=PATH, train=False, transform=transforms.ToTensor(), download=False)
                loader = torch.utils.data.DataLoader(dataset=dataset, shuffle=False)
                return loader
        """

    @abstractmethod
    def test_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]:
        r"""
        Implement a PyTorch DataLoader for testing.

        Return:
            Single PyTorch :class:`~torch.utils.data.DataLoader`.

        Note:
            Lightning adds the correct sampler for distributed and arbitrary hardware.
            There is no need to set it yourself.

        Note:
            You can also return a list of DataLoaders.

        Example::

            def test_dataloader(self):
                dataset = MNIST(root=PATH, train=False, transform=transforms.ToTensor(), download=False)
                loader = torch.utils.data.DataLoader(dataset=dataset, shuffle=False)
                return loader
        """

    @classmethod
    def add_argparse_args(cls, parent_parser: ArgumentParser) -> ArgumentParser:
        r"""Extends an existing argparse parser with the default `LightningDataModule` arguments.
        """
        parser = ArgumentParser(parents=[parent_parser], add_help=False)
        added_args = [x.dest for x in parser._actions]

        blacklist = ['kwargs']
        depr_arg_names = blacklist + added_args
        depr_arg_names = set(depr_arg_names)

        allowed_types = (str, float, int, bool)

        # TODO: get "help" from docstring :)
        for arg, arg_types, arg_default in (
            at for at in cls.get_init_arguments_and_types() if at[0] not in depr_arg_names
        ):
            arg_types = [at for at in allowed_types if at in arg_types]
            if not arg_types:
                # skip arguments with unsupported types
                continue
            arg_kwargs = {}
            if bool in arg_types:
                arg_kwargs.update(nargs="?")
                # if the only arg type is bool
                if len(arg_types) == 1:
                    # redefine the type for the ArgumentParser as needed
                    def use_type(x):
                        return bool(parsing.str_to_bool(x))

                else:
                    # filter out bool, as we need the more general of the types
                    use_type = [at for at in arg_types if at is not bool][0]
            else:
                use_type = arg_types[0]

            if arg_default == inspect._empty:
                arg_default = None

            parser.add_argument(
                f'--{arg}',
                dest=arg,
                default=arg_default,
                type=use_type,
                help=f'autogenerated by plb.{cls.__name__}',
                **arg_kwargs,
            )

        return parser

    @classmethod
    def from_argparse_args(cls, args: Union[Namespace, ArgumentParser], **kwargs):
        """
        Create an instance from CLI arguments.

        Args:
            args: The parser or namespace to take arguments from. Only known arguments will be
                parsed and passed to the :class:`LightningDataModule`.
            **kwargs: Additional keyword arguments that may override ones in the parser or namespace.
                These must be valid DataModule arguments.

        Example::

            parser = ArgumentParser(add_help=False)
            parser = LightningDataModule.add_argparse_args(parser)
            args = parser.parse_args()
            module = LightningDataModule.from_argparse_args(args)

        """
        if isinstance(args, ArgumentParser):
            args = cls.parse_argparser(args)
        params = vars(args)

        # we only want to pass in valid DataModule args, the rest may be user specific
        valid_kwargs = inspect.signature(cls.__init__).parameters
        datamodule_kwargs = dict((name, params[name]) for name in valid_kwargs if name in params)
        datamodule_kwargs.update(**kwargs)

        return cls(**datamodule_kwargs)

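A sketch of the intended argparse round trip, not part of the committed file; BoringDataModule and its parameters are made up for illustration. add_argparse_args exposes the typed __init__ parameters (discovered via get_init_arguments_and_types, below) as CLI flags, and from_argparse_args filters the parsed namespace back down to valid constructor arguments:

from argparse import ArgumentParser


class BoringDataModule(LightningDataModule):
    def __init__(self, data_dir: str = './data', batch_size: int = 32):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size


parser = ArgumentParser(add_help=False)
parser = BoringDataModule.add_argparse_args(parser)
args = parser.parse_args(['--data_dir', '/tmp/data', '--batch_size', '64'])

# entries in the namespace that are not __init__ parameters are dropped here
dm = BoringDataModule.from_argparse_args(args)
assert dm.batch_size == 64
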
    @classmethod
    def get_init_arguments_and_types(cls) -> List[Tuple[str, Tuple, Any]]:
        r"""Scans the DataModule signature and returns argument names, types and default values.

        Returns:
            List with tuples of 3 values:
            (argument name, set with argument types, argument default value).
        """
        datamodule_default_params = inspect.signature(cls.__init__).parameters
        name_type_default = []
        for arg in datamodule_default_params:
            arg_type = datamodule_default_params[arg].annotation
            arg_default = datamodule_default_params[arg].default
            try:
                arg_types = tuple(arg_type.__args__)
            except AttributeError:
                arg_types = (arg_type,)

            name_type_default.append((arg, arg_types, arg_default))

        return name_type_default
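
Taken together, a sketch of the workflow this class standardizes. The MNIST module below is illustrative, not part of the commit; it assumes torchvision is installed, and the hook names and call order follow the docstrings above:

from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from torchvision.datasets import MNIST


class MNISTDataModule(LightningDataModule):
    def __init__(self, data_dir: str = './data', batch_size: int = 32):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.dims = (1, 28, 28)

    def prepare_data(self):
        # download only; the metaclass guarantees this runs on rank 0
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)

    def setup(self):
        # state assignments; runs on every process under DDP
        full = MNIST(self.data_dir, train=True, transform=transforms.ToTensor())
        self.mnist_train, self.mnist_val = random_split(full, [55000, 5000])
        self.mnist_test = MNIST(self.data_dir, train=False, transform=transforms.ToTensor())

    def train_dataloader(self):
        return DataLoader(self.mnist_train, batch_size=self.batch_size)

    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=self.batch_size)

    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=self.batch_size)


dm = MNISTDataModule()
dm.prepare_data()  # download once
dm.setup()         # split/assign on every process
train_loader = dm.train_dataloader()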

0 commit comments