Skip to content

Commit

Permalink
Add flags from Lightning-Universe#388 to API
Browse files Browse the repository at this point in the history
  • Loading branch information
chris-clem committed Dec 17, 2020
1 parent 716dedf commit 54114c5
Show file tree
Hide file tree
Showing 5 changed files with 86 additions and 47 deletions.
18 changes: 13 additions & 5 deletions pl_bolts/datamodules/binary_mnist_datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,11 @@ def __init__(
val_split: Union[int, float] = 0.2,
num_workers: int = 16,
normalize: bool = False,
seed: int = 42,
batch_size: int = 32,
seed: int = 42,
shuffle: bool = False,
pin_memory: bool = False,
drop_last: bool = False,
*args: Any,
**kwargs: Any,
) -> None:
Expand All @@ -60,8 +63,12 @@ def __init__(
val_split: Percent (float) or number (int) of samples to use for the validation split
num_workers: How many workers to use for loading data
normalize: If true applies image normalization
seed: Seed to fix the validation split
batch_size: How many samples per batch to load
seed: Random seed to be used for train/val/test splits
shuffle: If true shuffles the train data every epoch
pin_memory: If true, the data loader will copy Tensors into CUDA pinned memory before
returning them
drop_last: If true drops the last incomplete batch
"""

if not _TORCHVISION_AVAILABLE:
Expand All @@ -70,14 +77,15 @@ def __init__(
)

super().__init__(
dataset_cls=BinaryMNIST,
dims=(1, 28, 28),
data_dir=data_dir,
val_split=val_split,
num_workers=num_workers,
normalize=normalize,
seed=seed,
batch_size=batch_size,
seed=seed,
shuffle=shuffle,
pin_memory=pin_memory,
drop_last=drop_last,
*args,
**kwargs,
)
Expand Down
35 changes: 18 additions & 17 deletions pl_bolts/datamodules/cifar10_datamodule.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional, Sequence, Union
from typing import Any, Optional, Sequence, Union

from pl_bolts.datamodules.vision_datamodule import VisionDataModule
from pl_bolts.datasets.cifar10_dataset import TrialCIFAR10
Expand Down Expand Up @@ -64,19 +64,26 @@ def __init__(
val_split: Union[int, float] = 0.2,
num_workers: int = 16,
normalize: bool = False,
seed: int = 42,
batch_size: int = 32,
*args,
**kwargs,
):
seed: int = 42,
shuffle: bool = False,
pin_memory: bool = False,
drop_last: bool = False,
*args: Any,
**kwargs: Any,
) -> None:
"""
Args:
data_dir: Where to save/load the data
val_split: Percent (float) or number (int) of samples to use for the validation split
num_workers: How many workers to use for loading data
normalize: If true applies image normalization
seed: Seed to fix the validation split
batch_size: How many samples per batch to load
seed: Random seed to be used for train/val/test splits
shuffle: If true shuffles the train data every epoch
pin_memory: If true, the data loader will copy Tensors into CUDA pinned memory before
returning them
drop_last: If true drops the last incomplete batch
"""

if not _TORCHVISION_AVAILABLE:
Expand All @@ -89,24 +96,18 @@ def __init__(
val_split=val_split,
num_workers=num_workers,
normalize=normalize,
seed=seed,
batch_size=batch_size,
seed=seed,
shuffle=shuffle,
pin_memory=pin_memory,
drop_last=drop_last,
*args,
**kwargs,
)

@property
def num_samples(self) -> int:
len_dataset = 50_000
# todo: clean up with parent methods
if isinstance(self.val_split, int):
train_len = len_dataset - self.val_split
elif isinstance(self.val_split, float):
val_len = int(self.val_split * len_dataset)
train_len = len_dataset - val_len
else:
raise ValueError(f'Unsupported type {type(self.val_split)}')

train_len, _ = self._get_splits(len_dataset=50_000)
return train_len

@property
Expand Down
24 changes: 17 additions & 7 deletions pl_bolts/datamodules/fashion_mnist_datamodule.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional, Union
from typing import Any, Optional, Union

from pl_bolts.datamodules.vision_datamodule import VisionDataModule
from pl_bolts.utils import _TORCHVISION_AVAILABLE
Expand Down Expand Up @@ -50,19 +50,26 @@ def __init__(
val_split: Union[int, float] = 0.2,
num_workers: int = 16,
normalize: bool = False,
seed: int = 42,
batch_size: int = 32,
*args,
**kwargs,
):
seed: int = 42,
shuffle: bool = False,
pin_memory: bool = False,
drop_last: bool = False,
*args: Any,
**kwargs: Any,
) -> None:
"""
Args:
data_dir: Where to save/load the data
val_split: Percent (float) or number (int) of samples to use for the validation split
num_workers: How many workers to use for loading data
normalize: If true applies image normalization
seed: Seed to fix the validation split
batch_size: How many samples per batch to load
seed: Random seed to be used for train/val/test splits
shuffle: If true shuffles the train data every epoch
pin_memory: If true, the data loader will copy Tensors into CUDA pinned memory before
returning them
drop_last: If true drops the last incomplete batch
"""

if not _TORCHVISION_AVAILABLE:
Expand All @@ -75,8 +82,11 @@ def __init__(
val_split=val_split,
num_workers=num_workers,
normalize=normalize,
seed=seed,
batch_size=batch_size,
seed=seed,
shuffle=shuffle,
pin_memory=pin_memory,
drop_last=drop_last,
*args,
**kwargs,
)
Expand Down
24 changes: 17 additions & 7 deletions pl_bolts/datamodules/mnist_datamodule.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional, Union
from typing import Any, Optional, Union

from pl_bolts.datamodules.vision_datamodule import VisionDataModule
from pl_bolts.utils import _TORCHVISION_AVAILABLE
Expand Down Expand Up @@ -49,19 +49,26 @@ def __init__(
val_split: Union[int, float] = 0.2,
num_workers: int = 16,
normalize: bool = False,
seed: int = 42,
batch_size: int = 32,
*args,
**kwargs,
):
seed: int = 42,
shuffle: bool = False,
pin_memory: bool = False,
drop_last: bool = False,
*args: Any,
**kwargs: Any,
) -> None:
"""
Args:
data_dir: Where to save/load the data
val_split: Percent (float) or number (int) of samples to use for the validation split
num_workers: How many workers to use for loading data
normalize: If true applies image normalization
seed: Seed to fix the validation split
batch_size: How many samples per batch to load
seed: Random seed to be used for train/val/test splits
shuffle: If true shuffles the train data every epoch
pin_memory: If true, the data loader will copy Tensors into CUDA pinned memory before
returning them
drop_last: If true drops the last incomplete batch
"""

if not _TORCHVISION_AVAILABLE:
Expand All @@ -74,8 +81,11 @@ def __init__(
val_split=val_split,
num_workers=num_workers,
normalize=normalize,
seed=seed,
batch_size=batch_size,
seed=seed,
shuffle=shuffle,
pin_memory=pin_memory,
drop_last=drop_last,
*args,
**kwargs,
)
Expand Down
32 changes: 21 additions & 11 deletions pl_bolts/datamodules/vision_datamodule.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import os
from abc import abstractmethod
from typing import List, Optional, Tuple, Union
from typing import Any, List, Optional, Union

import torch
from pytorch_lightning import LightningDataModule
Expand All @@ -10,9 +10,9 @@
class VisionDataModule(LightningDataModule):

EXTRA_ARGS = {}
name: str = ""
#: Dataset class to use
DATASET_CLASS = ...
name: str = ""
#: A tuple describing the shape of the data
DIMS: tuple = ...

Expand All @@ -22,19 +22,26 @@ def __init__(
val_split: Union[int, float] = 0.2,
num_workers: int = 16,
normalize: bool = False,
seed: int = 42,
batch_size: int = 32,
*args,
**kwargs,
):
seed: int = 42,
shuffle: bool = False,
pin_memory: bool = False,
drop_last: bool = False,
*args: Any,
**kwargs: Any,
) -> None:
"""
Args:
data_dir: Where to save/load the data
val_split: Percent (float) or number (int) of samples to use for the validation split
num_workers: How many workers to use for loading data
normalize: If true applies image normalization
seed: Seed to fix the validation split
batch_size: How many samples per batch to load
seed: Random seed to be used for train/val/test splits
shuffle: If true shuffles the train data every epoch
pin_memory: If true, the data loader will copy Tensors into CUDA pinned memory before
returning them
drop_last: If true drops the last incomplete batch
"""

super().__init__(*args, **kwargs)
Expand All @@ -43,8 +50,11 @@ def __init__(
self.val_split = val_split
self.num_workers = num_workers
self.normalize = normalize
self.seed = seed
self.batch_size = batch_size
self.seed = seed
self.shuffle = shuffle
self.pin_memory = pin_memory
self.drop_last = drop_last

def prepare_data(self):
"""
Expand Down Expand Up @@ -110,7 +120,7 @@ def default_transforms(self):

def train_dataloader(self) -> DataLoader:
""" The train dataloader """
return self._data_loader(self.dataset_train, shuffle=True)
return self._data_loader(self.dataset_train, shuffle=self.shuffle)

def val_dataloader(self) -> DataLoader:
""" The val dataloader """
Expand All @@ -126,6 +136,6 @@ def _data_loader(self, dataset: Dataset, shuffle: bool = False) -> DataLoader:
batch_size=self.batch_size,
shuffle=shuffle,
num_workers=self.num_workers,
drop_last=True,
pin_memory=True,
drop_last=self.drop_last,
pin_memory=self.pin_memory
)

0 comments on commit 54114c5

Please sign in to comment.