In [1]:
import os
from torch import optim, nn, utils, Tensor
from torch.utils.data import DataLoader

from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
import torchvision.transforms as transforms
import lightning.pytorch as pl

from lightning.pytorch.tuner import Tuner


In [2]:
class MNISTDataModule(pl.LightningDataModule):
    """Minimal MNIST data module exposing train/val/test/predict dataloaders.

    Args:
        data_dir: Directory that holds (or will receive) the MNIST files.
        batch_size: Batch size used by every dataloader returned by this module.
    """

    def __init__(self, data_dir: str = "path/to/dir", batch_size: int = 32):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        # MNIST yields PIL images unless a transform is supplied, and the default
        # DataLoader collate function cannot stack PIL images into a batch.
        self.transform = ToTensor()

    def prepare_data(self):
        """Download both MNIST splits into ``data_dir`` if they are not there yet."""
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage: str):
        """Build every dataset regardless of ``stage``.

        The 60,000-sample training set is split at random into 55,000 training
        and 5,000 validation samples; the test split backs both the test and
        the predict dataloaders.
        """
        self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)
        self.mnist_predict = MNIST(self.data_dir, train=False, transform=self.transform)
        mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
        self.mnist_train, self.mnist_val = utils.data.random_split(mnist_full, [55000, 5000])

    def train_dataloader(self):
        """Return the dataloader over the training subset."""
        return DataLoader(self.mnist_train, batch_size=self.batch_size)

    def val_dataloader(self):
        """Return the dataloader over the validation subset."""
        return DataLoader(self.mnist_val, batch_size=self.batch_size)

    def test_dataloader(self):
        """Return the dataloader over the MNIST test split."""
        return DataLoader(self.mnist_test, batch_size=self.batch_size)

    def predict_dataloader(self):
        """Return the dataloader used for prediction (also the MNIST test split)."""
        return DataLoader(self.mnist_predict, batch_size=self.batch_size)

    def teardown(self, stage: str):
        """Hook for clean-up when the run is finished; intentionally a no-op."""
        ...

In [3]:
# Point the data module at the shared "datasets" folder two levels above this notebook.
my_path = os.path.join(os.path.join("..", ".."), "datasets")
print("my_path: ", my_path)

# Constructing the module only records its settings.
mnist = MNISTDataModule(my_path)

# Template for the training step (no model or trainer is defined in this notebook yet):
# model = LitClassifier()
# trainer = Trainer()
# trainer.fit(model, mnist)

my_path:  ..\..\datasets


In [4]:
class MNISTDataModule(pl.LightningDataModule):
    """MNIST data module with download, stage-aware setup and normalized tensors.

    Args:
        data_dir: Directory that holds (or will receive) the MNIST files.
        batch_size: Batch size used by every dataloader returned by this module.
            Defaults to 32, the value that was previously hard-coded.
    """

    def __init__(self, data_dir: str = "./", batch_size: int = 32):
        super().__init__()
        self.data_dir = data_dir
        # Kept as an attribute so it can be overridden per instance (and found
        # by tooling that looks for a ``batch_size`` attribute).
        self.batch_size = batch_size
        # Convert to tensors, then normalize with mean 0.1307 / std 0.3081
        # (the commonly quoted MNIST pixel statistics).
        self.transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])

    def prepare_data(self):
        """Download both MNIST splits into ``data_dir`` if they are not there yet.

        Only one-off, on-disk work belongs here (for an NLP task this would be
        the place to tokenize the text and save the result to disk); no
        datasets are assigned to ``self`` in this hook.
        """
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage: str):
        """Create the datasets required by ``stage``.

        Args:
            stage: One of ``"fit"``, ``"validate"``, ``"test"`` or ``"predict"``.
        """
        # A standalone validation run needs ``mnist_val`` just like fitting does.
        # The split is random, so it is drawn only once per instance: re-splitting
        # on a later setup call could move training samples into validation.
        if stage in ("fit", "validate") and not hasattr(self, "mnist_val"):
            mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
            self.mnist_train, self.mnist_val = utils.data.random_split(mnist_full, [55000, 5000])

        # Assign test dataset for use in dataloader(s)
        if stage == "test":
            self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)

        # Prediction reuses the MNIST test split.
        if stage == "predict":
            self.mnist_predict = MNIST(self.data_dir, train=False, transform=self.transform)

    def train_dataloader(self):
        """Return the dataloader over the 55,000-sample training subset."""
        return DataLoader(self.mnist_train, batch_size=self.batch_size)

    def val_dataloader(self):
        """Return the dataloader over the 5,000-sample validation subset."""
        return DataLoader(self.mnist_val, batch_size=self.batch_size)

    def test_dataloader(self):
        """Return the dataloader over the MNIST test split."""
        return DataLoader(self.mnist_test, batch_size=self.batch_size)

    def predict_dataloader(self):
        """Return the dataloader used for prediction."""
        return DataLoader(self.mnist_predict, batch_size=self.batch_size)

In [5]:
# Build the data module over the dataset directory chosen earlier in the notebook.
dm = MNISTDataModule(my_path)
# Template: how the module would be handed to a Trainer for each entry point.
# `Model` and `trainer` are not defined in this notebook, so these stay commented out.
# model = Model()
# trainer.fit(model, datamodule=dm)
# trainer.test(datamodule=dm)
# trainer.validate(datamodule=dm)
# trainer.predict(datamodule=dm)

In [6]:
# Drive the data module by hand instead of through a Trainer:
# construct it, run the download step, then build the "fit"-stage datasets.
dm = MNISTDataModule(my_path)
dm.prepare_data()
dm.setup(stage="fit")

# NOTE(review): MNISTDataModule as defined in this notebook never sets
# `num_classes`, `width` or `vocab`; the template below needs those attributes
# (and a `Model` / `trainer`) before it can be enabled — confirm.
# model = Model(num_classes=dm.num_classes, width=dm.width, vocab=dm.vocab)
# trainer.fit(model, dm)

# Template: the test datasets are created by a separate setup call for the "test" stage.
# dm.setup(stage="test")
# trainer.test(datamodule=dm)