# PyTorch Lightning DataModules ⚡

With the release of `pytorch-lightning` version 0.9.0, we have included a new class called `LightningDataModule` to help you decouple data-related hooks from your `LightningModule`.

This notebook walks you through how to start using DataModules.

The most up-to-date documentation on DataModules can be found [here](https://pytorch-lightning.readthedocs.io/en/latest/datamodules.html).

---

  - Give us a ⭐  [on Github](https://www.github.com/PytorchLightning/pytorch-lightning/)
  - Check out [the documentation](https://pytorch-lightning.readthedocs.io/en/latest/)
  - Join us [on Slack](https://join.slack.com/t/pytorch-lightning/shared_invite/zt-f6bl2l0l-JYMK3tbAgAmGRrlNr00f1A)

### Setup
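Lightning is easy to install with `pip`. This notebook is written against the 0.9.0 API (for example `pl.TrainResult` and `pl.EvalResult`, which later releases deprecated and then removed), so the cell below pins that version.

In [None]:
! pip install pytorch-lightning==0.9.0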


# Introduction

First, we'll go over a regular `LightningModule` implementation that does not use a `LightningDataModule`.

In [None]:
import pytorch_lightning as pl
from pytorch_lightning.metrics.functional import accuracy
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import random_split, DataLoader

# Note - you must have torchvision installed for this example
from torchvision.datasets import MNIST, CIFAR10
from torchvision import transforms

## Defining the LitMNIST Model

Below, we reuse a `LightningModule` from our hello world tutorial that classifies MNIST handwritten digits.

Unfortunately, we have hardcoded dataset-specific items within the model (the number of classes, the input shape, the normalization transform, and the download and split logic), which ties it to MNIST data. 😢

This is fine if you never plan to train or evaluate the model on other datasets. However, trying the same architecture on a different dataset would then mean editing the model class itself.
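One example of that coupling is the hardcoded `transforms.Normalize((0.1307,), (0.3081,))` in the cell below. `ToTensor` scales 8-bit pixels to `[0.0, 1.0]`, and `Normalize` then computes `(x - mean) / std` per channel; 0.1307 and 0.3081 are the commonly used MNIST mean and standard deviation, so they only make sense for MNIST. The helper `normalize_pixel` is a hypothetical plain-Python sketch of that arithmetic for a single grayscale pixel, not torchvision's tensor API.

```python
# Hypothetical sketch of what ToTensor followed by
# Normalize((0.1307,), (0.3081,)) does to one grayscale MNIST pixel.
MNIST_MEAN = 0.1307  # commonly used MNIST mean (dataset-specific!)
MNIST_STD = 0.3081   # commonly used MNIST std (dataset-specific!)


def normalize_pixel(raw: int, mean: float = MNIST_MEAN, std: float = MNIST_STD) -> float:
    """Scale an 8-bit pixel to [0, 1] like ToTensor, then standardize like Normalize."""
    if not 0 <= raw <= 255:
        raise ValueError("expected an 8-bit pixel value in [0, 255]")
    x = raw / 255.0          # ToTensor: uint8 -> float in [0.0, 1.0]
    return (x - mean) / std  # Normalize: (x - mean) / std


for raw in (0, 128, 255):
    print(raw, round(normalize_pixel(raw), 4))
```

A different dataset needs different statistics (or none at all), which is why these values belong with the data rather than inside the model.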

In [None]:
class LitMNIST(pl.LightningModule):

    def __init__(self, hidden_size=64, learning_rate=2e-4, data_dir='./', batch_size=32):

        super().__init__()

        self.hidden_size = hidden_size
        self.learning_rate = learning_rate
        self.data_dir = data_dir
        self.batch_size = batch_size

        # We hardcode dataset-specific stuff here.
        num_classes = 10
        channels, height, width = (1, 28, 28)
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
        ])

        # Build model
        self.model = nn.Sequential(
            nn.Flatten(),
            nn.Linear(channels * width * height, hidden_size),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_size, num_classes)
        )

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        return optimizer

    def forward(self, x):
        x = self.model(x)
        return F.log_softmax(x, dim=1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        return pl.TrainResult(loss)

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        preds = torch.argmax(logits, dim=1)
        acc = accuracy(preds, y)
        result = pl.EvalResult(checkpoint_on=loss)
        result.log('val_loss', loss, prog_bar=True)
        result.log('val_acc', acc, prog_bar=True)
        return result

    ####################
    # DATA RELATED HOOKS
    ####################

    def prepare_data(self):
        # download
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage=None):

        # Assign train/val datasets for use in dataloaders
        if stage == 'fit' or stage is None:
            mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
            self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])

        # Assign test dataset for use in dataloader(s)
        if stage == 'test' or stage is None:
            self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)

    def train_dataloader(self):
        # shuffle the training set every epoch
        return DataLoader(self.mnist_train, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=self.batch_size)

    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=self.batch_size)

In [None]:
model = LitMNIST()
# use one GPU if available, otherwise train on CPU
trainer = pl.Trainer(max_epochs=1, gpus=int(torch.cuda.is_available()), progress_bar_refresh_rate=20)
trainer.fit(model)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type       | Params
-------------------------------------
0 | model | Sequential | 55 K  


Saving latest checkpoint..






# Using DataModules

DataModules are a way of decoupling data-related hooks from the `LightningModule` so you can develop dataset-agnostic models.
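To make "data-related hooks" concrete, here is a simplified, hypothetical driver that calls the same hook names a `LightningDataModule` defines (`prepare_data`, `setup`, and the `*_dataloader` methods) in a plausible order. It is a plain-Python sketch, not Lightning's implementation: a real `Trainer` also runs sanity checks, handles multiple processes, and may call hooks in a different order or more than once. Lightning's documentation describes `prepare_data` as running in a single process (so it should only download or write to disk) and `setup` as running in every process (so it is where datasets get assigned to `self`).

```python
# Hypothetical, simplified driver showing how a trainer might use a
# DataModule's hooks. NOT Lightning's real implementation.
calls = []


class ToyDataModule:
    """Stand-in exposing the same hook names as LightningDataModule."""

    def prepare_data(self):
        # download / write to disk only; don't assign state here
        calls.append("prepare_data")

    def setup(self, stage=None):
        # build datasets for the given stage ('fit' or 'test')
        calls.append(f"setup({stage!r})")

    def train_dataloader(self):
        calls.append("train_dataloader")
        return [("x0", 0), ("x1", 1)]  # placeholder batches

    def val_dataloader(self):
        calls.append("val_dataloader")
        return [("x2", 2)]

    def test_dataloader(self):
        calls.append("test_dataloader")
        return [("x3", 3)]


def toy_fit(dm):
    dm.prepare_data()
    dm.setup("fit")
    return list(dm.train_dataloader()), list(dm.val_dataloader())


def toy_test(dm):
    # prepare_data is skipped here for brevity: toy_fit already "downloaded"
    dm.setup("test")
    return list(dm.test_dataloader())


toy_dm = ToyDataModule()
toy_fit(toy_dm)
toy_test(toy_dm)
print(calls)
```

Because the model never touches these hooks directly, swapping `ToyDataModule` for another object with the same methods leaves the driver unchanged, which is the point of the decoupling.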

## Defining The MNISTDataModule

Here, we move the dataset-specific parts of the `LitMNIST` model into a `LightningDataModule`.

In [None]:
class MNISTDataModule(pl.LightningDataModule):
    def __init__(self, data_dir='./', batch_size=32):

        super().__init__()

        self.data_dir = data_dir
        self.batch_size = batch_size

        # We hardcode dataset-specific stuff here.
        self.num_classes = 10
        self.dims = (1, 28, 28)
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
        ])

    def prepare_data(self):
        # download
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage=None):

        # Assign train/val datasets for use in dataloaders
        if stage == 'fit' or stage is None:
            mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
            self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])

        # Assign test dataset for use in dataloader(s)
        if stage == 'test' or stage is None:
            self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)

    def train_dataloader(self):
        # shuffle the training set every epoch
        return DataLoader(self.mnist_train, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=self.batch_size)

    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=self.batch_size)
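The `[55000, 5000]` passed to `random_split` in `setup` is tied to the size of the MNIST training set (60,000 images): `random_split` expects lengths that sum exactly to `len(dataset)` and raises a `ValueError` otherwise. A small hypothetical helper like `split_lengths` below (not part of Lightning or PyTorch) derives the lengths from a validation size instead of hardcoding both numbers.

```python
from typing import List


def split_lengths(total: int, val_size: int) -> List[int]:
    """Hypothetical helper: [train_len, val_len] that sum exactly to `total`."""
    if not 0 < val_size < total:
        raise ValueError("val_size must be strictly between 0 and total")
    return [total - val_size, val_size]


print(split_lengths(60000, 5000))  # MNIST training set   -> [55000, 5000]
print(split_lengths(50000, 5000))  # CIFAR10 training set -> [45000, 5000]
```

In a DataModule you could pass the result straight to `random_split(full_dataset, split_lengths(len(full_dataset), 5000))`.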

## Defining the dataset-agnostic `LitModel`

Below, we define the same network as the `LitMNIST` model from earlier.

However, this time the input shape and number of classes are constructor arguments, so the model can be trained on whatever data we'd like 🔥.

In [None]:
class LitModel(pl.LightningModule):
    def __init__(self, channels, height, width, num_classes, hidden_size=64, learning_rate=2e-4):

        super().__init__()

        self.hidden_size = hidden_size
        self.learning_rate = learning_rate

        # Build model
        self.model = nn.Sequential(
            nn.Flatten(),
            nn.Linear(channels * width * height, hidden_size),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_size, num_classes)
        )

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        return optimizer

    def forward(self, x):
        x = self.model(x)
        return F.log_softmax(x, dim=1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        return pl.TrainResult(loss)

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        preds = torch.argmax(logits, dim=1)
        acc = accuracy(preds, y)
        result = pl.EvalResult(checkpoint_on=loss)
        result.log('val_loss', loss, prog_bar=True)
        result.log('val_acc', acc, prog_bar=True)
        return result

## Training the `LitModel` using the `MNISTDataModule`

Now, we initialize and train the `LitModel` using the `MNISTDataModule`'s configuration settings and dataloaders.

In [None]:
dm = MNISTDataModule()
model = LitModel(*dm.size(), dm.num_classes)
# use one GPU if available, otherwise train on CPU
trainer = pl.Trainer(max_epochs=1, gpus=int(torch.cuda.is_available()), progress_bar_refresh_rate=20)
trainer.fit(model, dm)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type       | Params
-------------------------------------
0 | model | Sequential | 55 K  


Saving latest checkpoint..
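The training cell above builds the model with `LitModel(*dm.size(), dm.num_classes)`. Here `dm.size()` returns the `dims` tuple the DataModule set in `__init__`, and the `*` unpacks it into the first three parameters of `LitModel`. The plain-Python sketch below shows that mapping without Lightning; `FakeMNISTDataModule` and `describe_model_args` are hypothetical stand-ins, not library classes.

```python
class FakeMNISTDataModule:
    """Hypothetical stand-in exposing only what the training cell uses."""

    num_classes = 10
    dims = (1, 28, 28)  # (channels, height, width)

    def size(self):
        # mirrors how the notebook uses dm.size(): return the dims tuple
        return self.dims


def describe_model_args(channels, height, width, num_classes):
    """Same leading parameters as LitModel; reports what it would receive."""
    return {
        "channels": channels,
        "height": height,
        "width": width,
        "num_classes": num_classes,
        "in_features": channels * height * width,  # input size of the first nn.Linear
    }


fake_dm = FakeMNISTDataModule()
args = describe_model_args(*fake_dm.size(), fake_dm.num_classes)
print(args)
```

Since the model only reads these two values, any DataModule that sets `dims` and `num_classes` can be paired with `LitModel` without changing its code.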






## Defining the CIFAR10 DataModule

To really prove our model can train on multiple datasets, let's define a DataModule for CIFAR10.

In [None]:
class CIFAR10DataModule(pl.LightningDataModule):
    def __init__(self, data_dir='./', batch_size=32):

        super().__init__()

        self.data_dir = data_dir
        self.batch_size = batch_size

        # We hardcode dataset-specific stuff here.
        self.num_classes = 10
        self.dims = (3, 32, 32)
        self.transform = transforms.Compose([
            transforms.ToTensor()
        ])

    def prepare_data(self):
        # download
        CIFAR10(self.data_dir, train=True, download=True)
        CIFAR10(self.data_dir, train=False, download=True)

    def setup(self, stage=None):

        # Assign train/val datasets for use in dataloaders
        if stage == 'fit' or stage is None:
            cifar_full = CIFAR10(self.data_dir, train=True, transform=self.transform)
            self.cifar_train, self.cifar_val = random_split(cifar_full, [45000, 5000])

        # Assign test dataset for use in dataloader(s)
        if stage == 'test' or stage is None:
            self.cifar_test = CIFAR10(self.data_dir, train=False, transform=self.transform)

    def train_dataloader(self):
        # shuffle the training set every epoch
        return DataLoader(self.cifar_train, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.cifar_val, batch_size=self.batch_size)

    def test_dataloader(self):
        return DataLoader(self.cifar_test, batch_size=self.batch_size)

In [None]:
dm = CIFAR10DataModule()
model = LitModel(*dm.size(), dm.num_classes)
# use one GPU if available, otherwise train on CPU
trainer = pl.Trainer(max_epochs=1, gpus=int(torch.cuda.is_available()), progress_bar_refresh_rate=20)
trainer.fit(model, dm)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
CUDA_VISIBLE_DEVICES: [0]


Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./cifar-10-python.tar.gz


Extracting ./cifar-10-python.tar.gz to ./
Files already downloaded and verified



  | Name  | Type       | Params
-------------------------------------
0 | model | Sequential | 201 K 


Saving latest checkpoint..
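The two model summaries report 55 K parameters for MNIST and 201 K for CIFAR10, even though `LitModel` is unchanged. The difference comes entirely from the first `nn.Linear`, whose input size is `channels * height * width` taken from the DataModule. The sketch below recounts the parameters by hand, using the fact that `nn.Linear(in, out)` has `in * out` weights plus `out` biases while `Flatten`, `ReLU`, and `Dropout` have none; the function names are illustrative only.

```python
def linear_params(in_features: int, out_features: int) -> int:
    """Weights plus biases of one nn.Linear layer."""
    return in_features * out_features + out_features


def litmodel_params(dims, num_classes: int, hidden_size: int = 64) -> int:
    """Parameter count of LitModel's nn.Sequential for the given input dims."""
    channels, height, width = dims
    in_features = channels * height * width
    return (
        linear_params(in_features, hidden_size)   # Flatten -> hidden
        + linear_params(hidden_size, hidden_size) # hidden -> hidden
        + linear_params(hidden_size, num_classes) # hidden -> classes
    )


print(litmodel_params((1, 28, 28), 10))  # 55050  -> "55 K" in the MNIST summary
print(litmodel_params((3, 32, 32), 10))  # 201482 -> "201 K" in the CIFAR10 summary
```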




