# LIGHTNING DATAMODULE

A datamodule is a shareable, reusable class that encapsulates the five steps involved in processing data in PyTorch:
1. Download / tokenize / process.
2. Clean and (maybe) save to disk.
3. Load inside Dataset.
4. Apply transforms (rotate, tokenize, etc…).
5. Wrap inside a DataLoader.

# Why do I need a DataModule?

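In normal PyTorch code, data cleaning and preparation are usually scattered across many files. This makes it impossible to share and reuse the exact splits and transforms across projects.

The LightningDataModule is a convenient way to manage data in PyTorch Lightning. It encapsulates the training, validation, test, and prediction dataloaders, as well as any necessary steps for data processing, downloads, and transformations. By using a LightningDataModule, you can easily develop dataset-agnostic models, hot-swap different datasets, and share data splits and transformations across projects.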



In [4]:
import lightning as L
import torch
from torch.utils.data import random_split, DataLoader

# Note - you must have torchvision installed for this example
from torchvision.datasets import MNIST
from torchvision import transforms

##### Here is a simple PyTorch example:

In [5]:
# regular PyTorch
test_data = MNIST(root="MNIST", train=False, download=True)
predict_data = MNIST(root="MNIST", train=False, download=True)
train_data = MNIST(root="MNIST", train=True, download=True)
train_data, val_data = random_split(train_data, [55000, 5000])

train_loader = DataLoader(train_data, batch_size=32)
val_loader = DataLoader(val_data, batch_size=32)
test_loader = DataLoader(test_data, batch_size=32)
predict_loader = DataLoader(predict_data, batch_size=32)
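
The `random_split` call above draws a different split on every run unless the random state is fixed, which is one reason exact splits are hard to share. As a minimal sketch (using a toy `TensorDataset` as a stand-in for MNIST, and a hypothetical helper named `split`), passing a seeded generator makes the split repeatable:

```python
import torch
from torch.utils.data import TensorDataset, random_split

# Toy stand-in for a real dataset: 100 samples holding the values 0..99.
dataset = TensorDataset(torch.arange(100))


def split(seed: int):
    """Hypothetical helper: split 80/20 with a freshly seeded generator."""
    generator = torch.Generator().manual_seed(seed)
    return random_split(dataset, [80, 20], generator=generator)


train_a, val_a = split(42)
train_b, val_b = split(42)

# The same seed yields the same sample indices in both runs.
same_split = train_a.indices == train_b.indices
print(same_split)
```

The DataModule version of this example uses the same idea by passing `torch.Generator().manual_seed(42)` to `random_split`.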

##### LightningDataModule

In [6]:
class MNISTDataModule(L.LightningDataModule):
    def __init__(self, data_dir: str = "MNIST", batch_size: int = 32):
        super().__init__()
        self.mnist_val = self.mnist_test = self.mnist_train = None
        self.mnist_predict = None
        self.data_dir = data_dir
        self.batch_size = batch_size

    def setup(self, stage: str):
        self.mnist_test = MNIST(self.data_dir, train=False)
        self.mnist_predict = MNIST(self.data_dir, train=False)
        mnist_full = MNIST(self.data_dir, train=True)
        self.mnist_train, self.mnist_val = torch.utils.data.random_split(
            mnist_full, [55000, 5000], generator=torch.Generator().manual_seed(42)
        )

    def train_dataloader(self):
        return DataLoader(self.mnist_train, batch_size=self.batch_size)

    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=self.batch_size)

    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=self.batch_size)

    def predict_dataloader(self):
        return DataLoader(self.mnist_predict, batch_size=self.batch_size)

    # def teardown(self, stage: str):
    #     # Used to clean-up when the run is finished
    #     None

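But now, as the complexity of your processing grows (transforms, multi-GPU training), you can let Lightning handle those details for you while keeping this dataset reusable, so you can share it with colleagues or use it in different projects.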

In [7]:
# mnist = MNISTDataModule(my_path)
# model = LitClassifier()
# 
# trainer = Trainer()
# trainer.fit(model, mnist)

Here is a more realistic, more complex DataModule that shows how much more reusable a datamodule can be.

In [8]:
import lightning as L
import torch
from torch.utils.data import random_split, DataLoader

# Note - you must have torchvision installed for this example
from torchvision.datasets import MNIST
from torchvision import transforms


class MNISTDataModule(L.LightningDataModule):
    def __init__(self, data_dir: str = "./"):
        super().__init__()
        self.data_dir = data_dir
        self.transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])

    def prepare_data(self):
        # download
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage: str):
        # Assign train/val datasets for use in dataloaders
        if stage == "fit":
            mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
            self.mnist_train, self.mnist_val = random_split(
                mnist_full, [55000, 5000], generator=torch.Generator().manual_seed(42)
            )

        # Assign test dataset for use in dataloader(s)
        if stage == "test":
            self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)

        if stage == "predict":
            self.mnist_predict = MNIST(self.data_dir, train=False, transform=self.transform)

    def train_dataloader(self):
        return DataLoader(self.mnist_train, batch_size=32)

    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=32)

    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=32)

    def predict_dataloader(self):
        return DataLoader(self.mnist_predict, batch_size=32)

# LightningDataModule API

To define a DataModule, the following methods are used to create the train/val/test/predict dataloaders:

- prepare_data (how to download, tokenize, etc…)
- setup (how to split, define dataset, etc…)
- train_dataloader
- val_dataloader
- test_dataloader
- predict_dataloader

## prepare_data

Downloading and saving data with multiple processes (distributed settings) can result in corrupted data. Lightning ensures that prepare_data() is called only within a single process on CPU, so you can safely add your downloading logic there. In multi-node training, the execution of this hook depends on prepare_data_per_node. setup() is called after prepare_data(), with a barrier in between that ensures all processes proceed to setup() only once the data is prepared and available for use.

In [10]:
import os
class MNISTDataModule_example(L.LightningDataModule):
    def prepare_data(self):
        # download
        MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor())
        MNIST(os.getcwd(), train=False, download=True, transform=transforms.ToTensor())

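##### WARNING:
**prepare_data is called from the main process. Assigning state here (e.g. self.x = y) is not recommended: the hook runs on a single process, so any state assigned here will not be available to the other processes.**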

## setup

There are also data operations you might want to perform on every GPU. Use setup() to do things like:
- count number of classes
- build vocabulary
- perform train/val/test splits
- create datasets
- apply transforms (defined explicitly in your datamodule)
- etc…

##### setup is called from every process across all the nodes. Setting state here is recommended.

##### teardown can be used to clean up the state. It is also called from every process across all the nodes.


# transfer_batch_to_device
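Override this hook if your DataLoader returns tensors wrapped in a custom data structure. See https://lightning.ai/docs/pytorch/stable/data/datamodule.html for details.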



In [None]:
def transfer_batch_to_device(self, batch, device, dataloader_idx):
    if isinstance(batch, CustomBatch):
        # move all tensors in your custom data structure to the device
        batch.samples = batch.samples.to(device)
        batch.targets = batch.targets.to(device)
    elif dataloader_idx == 0:
        # skip device transfer for the first dataloader or anything you wish
        pass
    else:
        batch = super().transfer_batch_to_device(batch, device, dataloader_idx)
    return batch
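
The hook above refers to `CustomBatch`, which is not defined in this notebook. A minimal sketch of what such a structure might look like follows; the dataclass, the `collate_custom` function, and the `move_custom_batch` helper are all hypothetical and only mirror the `samples`/`targets` attributes used above.

```python
from dataclasses import dataclass

import torch


@dataclass
class CustomBatch:
    """Hypothetical container holding one batch of samples and targets."""

    samples: torch.Tensor
    targets: torch.Tensor


def collate_custom(items):
    """Hypothetical collate_fn: turn a list of (sample, target) pairs into a CustomBatch."""
    samples = torch.stack([sample for sample, _ in items])
    targets = torch.tensor([target for _, target in items])
    return CustomBatch(samples, targets)


def move_custom_batch(batch: CustomBatch, device: torch.device) -> CustomBatch:
    """Move every tensor inside the custom structure to the given device."""
    batch.samples = batch.samples.to(device)
    batch.targets = batch.targets.to(device)
    return batch


items = [(torch.zeros(3), 0), (torch.ones(3), 1)]
batch = move_custom_batch(collate_custom(items), torch.device("cpu"))
print(batch.samples.shape, batch.targets.tolist())
```

A collate function like this could be passed to a `DataLoader` through its `collate_fn` argument, which is why a custom transfer hook becomes necessary: the default transfer logic does not know about the fields of an arbitrary class.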

# on_before_batch_transfer

Override to alter or apply batch augmentations to your batch before it is transferred to the device.

```python
def on_before_batch_transfer(self, batch, dataloader_idx):
    batch['x'] = transforms(batch['x'])
    return batch
```

# on_after_batch_transfer

Override to alter or apply batch augmentations to your batch after it is transferred to the device.

In [None]:
def on_after_batch_transfer(self, batch, dataloader_idx):
    batch['x'] = gpu_transforms(batch['x'])
    return batch

# load_state_dict

Called when loading a checkpoint. Implement it to restore the datamodule's state from the given datamodule state_dict.

# teardown

Called at the end of fit (train + validate), validate, test, or predict.

# prepare_data_per_node

If set to True, prepare_data() is called on LOCAL_RANK=0 of every node. If set to False, it is called only from NODE_RANK=0, LOCAL_RANK=0.

In [None]:
class LitDataModule(L.LightningDataModule):
    def __init__(self):
        super().__init__()
        self.prepare_data_per_node = True
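
The rule stated above can be written out as a small pure-Python function. This is only a restatement of that rule for illustration, not Lightning's internal code; `calls_prepare_data` is a hypothetical name.

```python
def calls_prepare_data(node_rank: int, local_rank: int, prepare_data_per_node: bool) -> bool:
    """Return True if the process with these ranks would run prepare_data()."""
    if prepare_data_per_node:
        # One call per node: the first process on each node.
        return local_rank == 0
    # One call overall: the first process on the first node.
    return node_rank == 0 and local_rank == 0


# Example cluster: 2 nodes with 2 processes each, as (node_rank, local_rank) pairs.
ranks = [(node, local) for node in range(2) for local in range(2)]

per_node = [r for r in ranks if calls_prepare_data(*r, prepare_data_per_node=True)]
global_only = [r for r in ranks if calls_prepare_data(*r, prepare_data_per_node=False)]

print(per_node)     # [(0, 0), (1, 0)]
print(global_only)  # [(0, 0)]
```

Setting the flag to True is useful when each node has its own local disk, since every node then needs its own copy of the data.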

# Using a DataModule

The recommended way to use a DataModule is simply:

In [None]:
dm = MNISTDataModule()
model = Model()  # placeholder for your own LightningModule
trainer = L.Trainer()
trainer.fit(model, datamodule=dm)
trainer.test(datamodule=dm)
trainer.validate(datamodule=dm)
trainer.predict(datamodule=dm)

If you need information from the dataset to build your model, run prepare_data and setup manually (Lightning ensures the methods run on the correct devices).

In [None]:
dm = MNISTDataModule()
dm.prepare_data()
dm.setup(stage="fit")

model = Model(num_classes=dm.num_classes, width=dm.width, vocab=dm.vocab)
trainer.fit(model, dm)

dm.setup(stage="test")
trainer.test(datamodule=dm)

You can access the datamodule currently used by a trainer via trainer.datamodule, and the dataloaders currently in use via the trainer properties train_dataloader, val_dataloaders, test_dataloaders, and predict_dataloaders.

# DataModules without Lightning

In [None]:
# download, etc...
dm = MNISTDataModule()
dm.prepare_data()

# splits/transforms
dm.setup(stage="fit")

# use data
for batch in dm.train_dataloader():
    ...

for batch in dm.val_dataloader():
    ...

dm.teardown(stage="fit")

# lazy load test data
dm.setup(stage="test")
for batch in dm.test_dataloader():
    ...

dm.teardown(stage="test")

# Hyperparameters in DataModules

In [None]:
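import lightning as L
from torch.utils.data import DataLoader


class CustomDataModule(L.LightningDataModule):
    def __init__(self, data_dir: str = "./", batch_size: int = 32):
        super().__init__()
        # store the __init__ arguments under self.hparams
        self.save_hyperparameters()

    def train_dataloader(self):
        # access the saved hyperparameters
        # (self.train_set is assumed to be assigned in setup())
        return DataLoader(self.train_set, batch_size=self.hparams.batch_size)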

# Save DataModule state

When a checkpoint is created, it asks every DataModule for their state. If your DataModule defines the state_dict and load_state_dict methods, the checkpoint will automatically track and restore your DataModules.

In [10]:
import lightning as L


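class LitDataModule(L.LightningDataModule):
    def __init__(self):
        super().__init__()
        # initialize the tracked value so state_dict() works before any update
        self.current_train_batch_index = 0

    def state_dict(self):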
        # track whatever you want here
        state = {"current_train_batch_index": self.current_train_batch_index}
        return state

    def load_state_dict(self, state_dict):
        # restore the state based on what you tracked in (def state_dict)
        self.current_train_batch_index = state_dict["current_train_batch_index"]