diff --git a/docs/source/introduction_guide.rst b/docs/source/introduction_guide.rst index a16ba08818..e3e91bcb1f 100644 --- a/docs/source/introduction_guide.rst +++ b/docs/source/introduction_guide.rst @@ -425,7 +425,7 @@ Here's an example for logistic regression # use any numpy or sklearn dataset X, y = load_iris(return_X_y=True) - dm = SklearnDataModule(X, y) + dm = SklearnDataModule(X, y, batch_size=12) # build model model = LogisticRegression(input_dim=4, num_classes=3) @@ -434,7 +434,7 @@ Here's an example for logistic regression trainer = pl.Trainer(tpu_cores=8, precision=16) trainer.fit(model, dm.train_dataloader(), dm.val_dataloader()) - trainer.test(test_dataloaders=dm.test_dataloader(batch_size=12)) + trainer.test(test_dataloaders=dm.test_dataloader()) Any input will be flattened across all dimensions except the first one (batch). This means images, sound, etc... work out of the box. diff --git a/pl_bolts/datamodules/cityscapes_datamodule.py b/pl_bolts/datamodules/cityscapes_datamodule.py index a0d623253e..4789268888 100644 --- a/pl_bolts/datamodules/cityscapes_datamodule.py +++ b/pl_bolts/datamodules/cityscapes_datamodule.py @@ -111,8 +111,8 @@ def train_dataloader(self): """ Cityscapes train set """ - transforms = self.train_transforms or self.default_transforms() - target_transforms = self.target_transforms or self.default_target_transforms() + transforms = self.train_transforms or self._default_transforms() + target_transforms = self.target_transforms or self._default_target_transforms() dataset = Cityscapes(self.data_dir, split='train', @@ -136,8 +136,8 @@ def val_dataloader(self): """ Cityscapes val set """ - transforms = self.val_transforms or self.default_transforms() - target_transforms = self.target_transforms or self.default_target_transforms() + transforms = self.val_transforms or self._default_transforms() + target_transforms = self.target_transforms or self._default_target_transforms() dataset = Cityscapes(self.data_dir, split='val', @@ -161,8 +161,8 @@ def test_dataloader(self): """ Cityscapes test set """ - transforms = self.test_transforms or self.default_transforms() - target_transforms = self.target_transforms or self.default_target_transforms() + transforms = self.test_transforms or self._default_transforms() + target_transforms = self.target_transforms or self._default_target_transforms() dataset = Cityscapes(self.data_dir, split='test', @@ -181,7 +181,7 @@ def test_dataloader(self): ) return loader - def default_transforms(self): + def _default_transforms(self): cityscapes_transforms = transform_lib.Compose([ transform_lib.ToTensor(), transform_lib.Normalize( @@ -191,7 +191,7 @@ def default_transforms(self): ]) return cityscapes_transforms - def default_target_transforms(self): + def _default_target_transforms(self): cityscapes_target_trasnforms = transform_lib.Compose([ transform_lib.ToTensor(), transform_lib.Lambda(lambda t: t.squeeze()) diff --git a/pl_bolts/datamodules/kitti_datamodule.py b/pl_bolts/datamodules/kitti_datamodule.py index 54ba6a0bcf..c07acbd5eb 100644 --- a/pl_bolts/datamodules/kitti_datamodule.py +++ b/pl_bolts/datamodules/kitti_datamodule.py @@ -74,14 +74,8 @@ def __init__( self.num_workers = num_workers self.seed = seed - self.default_transforms = transforms.Compose([ - transforms.ToTensor(), - transforms.Normalize(mean=[0.35675976, 0.37380189, 0.3764753], - std=[0.32064945, 0.32098866, 0.32325324]) - ]) - # split into train, val, test - kitti_dataset = KittiDataset(self.data_dir, transform=self.default_transforms) + kitti_dataset = KittiDataset(self.data_dir, transform=self._default_transforms()) val_len = round(val_split * len(kitti_dataset)) test_len = round(test_split * len(kitti_dataset)) @@ -111,3 +105,11 @@ def test_dataloader(self): shuffle=False, num_workers=self.num_workers) return loader + + def _default_transforms(self): + kitti_transforms = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize(mean=[0.35675976, 0.37380189, 0.3764753], + std=[0.32064945, 0.32098866, 0.32325324]) + ]) + return kitti_transforms diff --git a/pl_bolts/datamodules/sklearn_datamodule.py b/pl_bolts/datamodules/sklearn_datamodule.py index bd05d81c90..d64652ecd5 100644 --- a/pl_bolts/datamodules/sklearn_datamodule.py +++ b/pl_bolts/datamodules/sklearn_datamodule.py @@ -118,22 +118,22 @@ class SklearnDataModule(LightningDataModule): >>> from pl_bolts.datamodules import SklearnDataModule ... >>> X, y = load_boston(return_X_y=True) - >>> loaders = SklearnDataModule(X, y) + >>> loaders = SklearnDataModule(X, y, batch_size=32) ... >>> # train set - >>> train_loader = loaders.train_dataloader(batch_size=32) + >>> train_loader = loaders.train_dataloader() >>> len(train_loader.dataset) 355 >>> len(train_loader) 11 >>> # validation set - >>> val_loader = loaders.val_dataloader(batch_size=32) + >>> val_loader = loaders.val_dataloader() >>> len(val_loader.dataset) 100 >>> len(val_loader) 3 >>> # test set - >>> test_loader = loaders.test_dataloader(batch_size=32) + >>> test_loader = loaders.test_dataloader() >>> len(test_loader.dataset) 51 >>> len(test_loader) @@ -150,12 +150,14 @@ def __init__( num_workers=2, random_state=1234, shuffle=True, + batch_size: int = 16, *args, **kwargs, ): super().__init__(*args, **kwargs) self.num_workers = num_workers + self.batch_size = batch_size # shuffle x and y if shuffle and _SKLEARN_AVAILABLE: @@ -193,10 +195,10 @@ def _init_datasets(self, X, y, x_val, y_val, x_test, y_test): self.val_dataset = SklearnDataset(x_val, y_val) self.test_dataset = SklearnDataset(x_test, y_test) - def train_dataloader(self, batch_size: int = 16): + def train_dataloader(self): loader = DataLoader( self.train_dataset, - batch_size=batch_size, + batch_size=self.batch_size, shuffle=True, num_workers=self.num_workers, drop_last=True, @@ -204,10 +206,10 @@ def train_dataloader(self, batch_size: int = 16): ) return loader - def val_dataloader(self, batch_size: int = 16): + def val_dataloader(self): loader = DataLoader( self.val_dataset, - batch_size=batch_size, + batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers, drop_last=True, @@ -215,10 +217,10 @@ def val_dataloader(self, batch_size: int = 16): ) return loader - def test_dataloader(self, batch_size: int = 16): + def test_dataloader(self): loader = DataLoader( self.test_dataset, - batch_size=batch_size, + batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers, drop_last=True, diff --git a/pl_bolts/datamodules/ssl_imagenet_datamodule.py b/pl_bolts/datamodules/ssl_imagenet_datamodule.py index 06bcf77ce1..ee50f4b091 100644 --- a/pl_bolts/datamodules/ssl_imagenet_datamodule.py +++ b/pl_bolts/datamodules/ssl_imagenet_datamodule.py @@ -26,6 +26,7 @@ def __init__( data_dir, meta_dir=None, num_workers=16, + batch_size: int = 32, *args, **kwargs, ): @@ -39,6 +40,7 @@ def __init__( self.data_dir = data_dir self.num_workers = num_workers self.meta_dir = meta_dir + self.batch_size = batch_size @property def num_classes(self): @@ -74,7 +76,7 @@ def prepare_data(self): UnlabeledImagenet.generate_meta_bins(path) """) - def train_dataloader(self, batch_size, num_images_per_class=-1, add_normalize=False): + def train_dataloader(self, num_images_per_class=-1, add_normalize=False): transforms = self._default_transforms() if self.train_transforms is None else self.train_transforms dataset = UnlabeledImagenet(self.data_dir, @@ -84,7 +86,7 @@ def train_dataloader(self, batch_size, num_images_per_class=-1, add_normalize=Fa transform=transforms) loader = DataLoader( dataset, - batch_size=batch_size, + batch_size=self.batch_size, shuffle=True, num_workers=self.num_workers, drop_last=True, @@ -92,7 +94,7 @@ def train_dataloader(self, batch_size, num_images_per_class=-1, add_normalize=Fa ) return loader - def val_dataloader(self, batch_size, num_images_per_class=50, add_normalize=False): + def val_dataloader(self, num_images_per_class=50, add_normalize=False): transforms = self._default_transforms() if self.val_transforms is None else self.val_transforms dataset = UnlabeledImagenet(self.data_dir, @@ -102,14 +104,14 @@ def val_dataloader(self, batch_size, num_images_per_class=50, add_normalize=Fals transform=transforms) loader = DataLoader( dataset, - batch_size=batch_size, + batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers, pin_memory=True ) return loader - def test_dataloader(self, batch_size, num_images_per_class, add_normalize=False): + def test_dataloader(self, num_images_per_class, add_normalize=False): transforms = self._default_transforms() if self.test_transforms is None else self.test_transforms dataset = UnlabeledImagenet(self.data_dir, @@ -119,7 +121,7 @@ def test_dataloader(self, batch_size, num_images_per_class, add_normalize=False) transform=transforms) loader = DataLoader( dataset, - batch_size=batch_size, + batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers, drop_last=True, diff --git a/pl_bolts/datamodules/stl10_datamodule.py b/pl_bolts/datamodules/stl10_datamodule.py index b1ee3058a8..3842725d30 100644 --- a/pl_bolts/datamodules/stl10_datamodule.py +++ b/pl_bolts/datamodules/stl10_datamodule.py @@ -105,7 +105,7 @@ def train_dataloader(self): """ Loads the 'unlabeled' split minus a portion set aside for validation via `unlabeled_val_split`. """ - transforms = self.default_transforms() if self.train_transforms is None else self.train_transforms + transforms = self._default_transforms() if self.train_transforms is None else self.train_transforms dataset = STL10(self.data_dir, split='unlabeled', download=False, transform=transforms) train_length = len(dataset) @@ -132,7 +132,7 @@ def train_dataloader_mixed(self): batch_size: the batch size transforms: a sequence of transforms """ - transforms = self.default_transforms() if self.train_transforms is None else self.train_transforms + transforms = self._default_transforms() if self.train_transforms is None else self.train_transforms unlabeled_dataset = STL10(self.data_dir, split='unlabeled', @@ -170,7 +170,7 @@ def val_dataloader(self): batch_size: the batch size transforms: a sequence of transforms """ - transforms = self.default_transforms() if self.val_transforms is None else self.val_transforms + transforms = self._default_transforms() if self.val_transforms is None else self.val_transforms dataset = STL10(self.data_dir, split='unlabeled', download=False, transform=transforms) train_length = len(dataset) @@ -202,7 +202,7 @@ def val_dataloader_mixed(self): batch_size: the batch size transforms: a sequence of transforms """ - transforms = self.default_transforms() if self.val_transforms is None else self.val_transforms + transforms = self._default_transforms() if self.val_transforms is None else self.val_transforms unlabeled_dataset = STL10(self.data_dir, split='unlabeled', download=False, @@ -237,7 +237,7 @@ def test_dataloader(self): batch_size: the batch size transforms: the transforms """ - transforms = self.default_transforms() if self.test_transforms is None else self.test_transforms + transforms = self._default_transforms() if self.test_transforms is None else self.test_transforms dataset = STL10(self.data_dir, split='test', download=False, transform=transforms) loader = DataLoader( @@ -251,7 +251,7 @@ def test_dataloader(self): return loader def train_dataloader_labeled(self): - transforms = self.default_transforms() if self.val_transforms is None else self.val_transforms + transforms = self._default_transforms() if self.val_transforms is None else self.val_transforms dataset = STL10(self.data_dir, split='train', download=False, transform=transforms) train_length = len(dataset) @@ -268,7 +268,7 @@ def train_dataloader_labeled(self): return loader def val_dataloader_labeled(self): - transforms = self.default_transforms() if self.val_transforms is None else self.val_transforms + transforms = self._default_transforms() if self.val_transforms is None else self.val_transforms dataset = STL10(self.data_dir, split='train', download=False, @@ -288,7 +288,7 @@ def val_dataloader_labeled(self): ) return loader - def default_transforms(self): + def _default_transforms(self): data_transforms = transform_lib.Compose([ transform_lib.ToTensor(), stl10_normalization() diff --git a/pl_bolts/models/regression/linear_regression.py b/pl_bolts/models/regression/linear_regression.py index 33012d56ef..8e576d7793 100644 --- a/pl_bolts/models/regression/linear_regression.py +++ b/pl_bolts/models/regression/linear_regression.py @@ -131,9 +131,6 @@ def cli_main(): 'You want to use `sklearn` which is not installed yet, install it with `pip install sklearn`.' ) from err - X, y = load_boston(return_X_y=True) # these are numpy arrays - loaders = SklearnDataModule(X, y) - # args parser = ArgumentParser() parser = LinearRegression.add_model_specific_args(parser) @@ -144,9 +141,13 @@ def cli_main(): model = LinearRegression(input_dim=13, l1_strength=1, l2_strength=1) # model = LinearRegression(**vars(args)) + # data + X, y = load_boston(return_X_y=True) # these are numpy arrays + loaders = SklearnDataModule(X, y, batch_size=args.batch_size) + # train trainer = pl.Trainer.from_argparse_args(args) - trainer.fit(model, loaders.train_dataloader(args.batch_size), loaders.val_dataloader(args.batch_size)) + trainer.fit(model, loaders.train_dataloader(), loaders.val_dataloader()) if __name__ == '__main__': diff --git a/pl_bolts/models/regression/logistic_regression.py b/pl_bolts/models/regression/logistic_regression.py index e4ee2cd3bd..ea9f1dcc24 100644 --- a/pl_bolts/models/regression/logistic_regression.py +++ b/pl_bolts/models/regression/logistic_regression.py @@ -137,9 +137,6 @@ def cli_main(): 'You want to use `sklearn` which is not installed yet, install it with `pip install sklearn`.' ) from err - X, y = load_iris(return_X_y=True) - loaders = SklearnDataModule(X, y) - # args parser = ArgumentParser() parser = LogisticRegression.add_model_specific_args(parser) @@ -150,9 +147,13 @@ def cli_main(): # model = LogisticRegression(**vars(args)) model = LogisticRegression(input_dim=4, num_classes=3, l1_strength=0.01, learning_rate=0.01) + # data + X, y = load_iris(return_X_y=True) + loaders = SklearnDataModule(X, y, batch_size=args.batch_size) + # train trainer = pl.Trainer.from_argparse_args(args) - trainer.fit(model, loaders.train_dataloader(args.batch_size), loaders.val_dataloader(args.batch_size)) + trainer.fit(model, loaders.train_dataloader(), loaders.val_dataloader()) if __name__ == '__main__':