<a href="https://colab.research.google.com/github/chenyu313/Colaboratory_note/blob/main/%E4%BF%9D%E5%AD%98%E5%92%8C%E5%8A%A0%E8%BD%BD%E6%A3%80%E6%9F%A5%E7%82%B9.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

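## What is a checkpoint
As a model trains, its performance changes as it sees more data. Saving the model's state throughout training is a best practice: it gives you a version of the model, a checkpoint, at each key point during development. Once training is complete, use the checkpoint that corresponds to the best performance you observed during training.

Checkpoints also let you resume training from where it left off if the training process is interrupted.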

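## Checkpoint contents
A Lightning checkpoint contains a dump of the model's entire internal state. Unlike plain PyTorch, Lightning saves everything you need to restore a model, even in the most complex distributed training environments. A checkpoint contains:
* 16-bit scaling factor (if using 16-bit precision training)
* Current epoch
* Global step
* LightningModule's state_dict
* State of all optimizers
* State of all learning rate schedulers
* State of all callbacks (for stateful callbacks)
* State of the datamodule (for stateful datamodules)
* The hyperparameters (init arguments) with which the model was created
* The hyperparameters (init arguments) with which the datamodule was created
* State of loops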
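
To make the idea of "one file holding the whole training state" concrete, here is a minimal plain-PyTorch sketch that bundles model weights, optimizer state, and training counters into a single dictionary, serializes it, and restores it. It is not Lightning's own serialization code: the key names only mirror the ones Lightning uses, and the tiny `nn.Linear` model and the counter values are stand-ins chosen for illustration.

```python
import io

import torch
from torch import nn

# Stand-in model and optimizer (assumptions for illustration only).
model = nn.Linear(4, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=2e-4)

# Bundle everything needed to resume into a single dictionary.
checkpoint = {
    "epoch": 2,
    "global_step": 2580,
    "state_dict": model.state_dict(),
    "optimizer_states": [optimizer.state_dict()],
}

# Serialize to an in-memory buffer instead of a file on disk.
buffer = io.BytesIO()
torch.save(checkpoint, buffer)
buffer.seek(0)

# Restore into freshly created objects.
restored = torch.load(buffer)
new_model = nn.Linear(4, 2)
new_optimizer = torch.optim.Adam(new_model.parameters(), lr=1e-3)
new_model.load_state_dict(restored["state_dict"])
new_optimizer.load_state_dict(restored["optimizer_states"][0])

print(restored["epoch"], restored["global_step"])
# The learning rate comes from the saved optimizer state, not from the new constructor call.
print(new_optimizer.param_groups[0]["lr"])
```

Lightning performs this bookkeeping for you and adds the remaining items from the list (callbacks, loops, schedulers, and so on).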

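## Saving a checkpoint
Lightning automatically saves a checkpoint for you in the current working directory, containing the state of your last training epoch. This ensures that you can resume training if it is interrupted.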

In [None]:
! pip install --quiet "seaborn" "pytorch-lightning>=1.4, <2.0.0" "torchvision" "setuptools==67.4.0" "lightning>=2.0.0rc0" "ipython[notebook]==7.9.0" "pandas" "torchmetrics >=0.11.0" "torch>=1.8.1, <1.14.0" "torchmetrics>=0.7, <0.12"

In [2]:
import os

import lightning as L
import pandas as pd
import seaborn as sn
import torch
from IPython.display import display
from lightning.pytorch.loggers import CSVLogger
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, random_split
from torchmetrics import Accuracy
from torchvision import transforms
from torchvision.datasets import MNIST

PATH_DATASETS = os.environ.get("PATH_DATASETS", ".")
BATCH_SIZE = 256 if torch.cuda.is_available() else 64

In [28]:
class LitMNIST(L.LightningModule):
    # ⚡ Lightning model
    def __init__(self, data_dir=PATH_DATASETS, hidden_size=64, learning_rate=2e-4):
        super().__init__()

        # Store the init arguments as attributes
        self.data_dir = data_dir
        self.hidden_size = hidden_size
        self.learning_rate = learning_rate

        # Hardcode some dataset-specific attributes
        self.num_classes = 10
        self.dims = (1, 28, 28)
        channels, width, height = self.dims
        self.transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize((0.1307,), (0.3081,)),
            ]
        )

        # Define the PyTorch model
        self.model = nn.Sequential(
            nn.Flatten(),  # flatten each image into a vector
            nn.Linear(channels * width * height, hidden_size),
            nn.ReLU(),  # activation function
            nn.Dropout(0.1),  # dropout, to reduce overfitting
            nn.Linear(hidden_size, hidden_size),  # hidden layer
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_size, self.num_classes),
        )

        self.val_accuracy = Accuracy(task="multiclass", num_classes=10)
        self.test_accuracy = Accuracy(task="multiclass", num_classes=10)

    def forward(self, x):
        x = self.model(x)
        return F.log_softmax(x, dim=1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        preds = torch.argmax(logits, dim=1)
        self.val_accuracy.update(preds, y)

        # Calling self.log records the scalars with the Trainer's logger
        self.log("val_loss", loss, prog_bar=True)
        self.log("val_acc", self.val_accuracy, prog_bar=True)

    def test_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        preds = torch.argmax(logits, dim=1)
        self.test_accuracy.update(preds, y)

        # Calling self.log records the scalars with the Trainer's logger
        self.log("test_loss", loss, prog_bar=True)
        self.log("test_acc", self.test_accuracy, prog_bar=True)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        return optimizer

    ####################
    # DATA RELATED HOOKS
    ####################

    def prepare_data(self):
        # Download the dataset
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        # Assign the train/val datasets used by the dataloaders
        if stage == "fit" or stage is None:
            mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
            self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])

        # Assign the test dataset used by the dataloader
        if stage == "test" or stage is None:
            self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)

    def train_dataloader(self):
        return DataLoader(self.mnist_train, batch_size=BATCH_SIZE)

    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=BATCH_SIZE)

    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=BATCH_SIZE)

In [5]:
# Simply by using a Trainer you get automatic checkpointing
model = LitMNIST()
trainer = L.Trainer(
    accelerator="auto",
    devices=1,
    max_epochs=3,
    logger=CSVLogger(save_dir="logs/"),
)

INFO: GPU available: False, used: False
INFO: TPU available: False, using: 0 TPU cores
INFO: IPU available: False, using: 0 IPUs
INFO: HPU available: False, using: 0 HPUs


In [8]:
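# Use the default_root_dir argument to change the checkpoint path.
# Note: a logger is also configured here, and the logs later in this notebook show
# the checkpoints being written under logs/lightning_logs/ rather than some/path/.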
trainer = L.Trainer(
    accelerator="auto",
    devices=1,
    max_epochs=3,
    logger=CSVLogger(save_dir="logs/"),
    default_root_dir="some/path/",
)

INFO: GPU available: False, used: False
INFO: TPU available: False, using: 0 TPU cores
INFO: IPU available: False, using: 0 IPUs
INFO: HPU available: False, using: 0 HPUs


In [9]:
# Start training
trainer.fit(model)

  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
INFO: 
  | Name          | Type               | Params
-----------------------------------------------------
0 | model         | Sequential         | 55.1 K
1 | val_accuracy  | MulticlassAccuracy | 0     
2 | test_accuracy | MulticlassAccuracy | 0     
-----------------------------------------------------
55.1 K    Trainable params
0         Non-trainable params
55.1 K    Total params
0.220     Total estimated model params size (MB)
INFO: `Trainer.fit` stopped: `max_epochs=3` reached.


## Loading a checkpoint
To load a LightningModule along with its weights and hyperparameters, use the following method:

In [14]:
model = LitMNIST.load_from_checkpoint("/content/logs/lightning_logs/version_0/checkpoints/epoch=2-step=2580-v1.ckpt")
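# Switch to evaluation mode (disables dropout)
model.eval()
# The trainer has already reached max_epochs=3, so this call stops without further training (see the log below)
trainer.fit(model)
# Called without a model, so the trainer restores the weights from its checkpoint (see the log below)
trainer.test()
# Predict with the model
# y_hat = model(x)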

  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
INFO: 
  | Name          | Type               | Params
-----------------------------------------------------
0 | model         | Sequential         | 55.1 K
1 | val_accuracy  | MulticlassAccuracy | 0     
2 | test_accuracy | MulticlassAccuracy | 0     
-----------------------------------------------------
55.1 K    Trainable params
0         Non-trainable params
55.1 K    Total params
0.220     Total estimated model params size (MB)
INFO: `Trainer.fit` stopped: `max_epochs=3` reached.
INFO: Restoring states from the checkpoint path at logs/lightning_logs/version_0/checkpoints/epoch=2-step=2580-v1.ckpt
INFO: Loaded model weights from the checkpoint at logs/lightning_logs/version_0/checkpoints/epoch=2-step=2580-v1.ckpt

[{'test_loss': 0.11838669329881668, 'test_acc': 0.9635999798774719}]
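
The `state_dict` stored in a Lightning checkpoint names each tensor after its attribute path inside the LightningModule, so the layers held in `self.model` get keys that begin with `model.`. If only that inner network is needed in plain PyTorch, the prefix can be stripped before calling `load_state_dict`. The helper below is a hypothetical sketch (`strip_prefix` is not a Lightning function), and placeholder strings stand in for the real tensors.

```python
def strip_prefix(state_dict, prefix):
    """Return a new dict keeping only keys under `prefix`, with the prefix removed."""
    return {
        key[len(prefix):]: value
        for key, value in state_dict.items()
        if key.startswith(prefix)
    }


# Placeholder values stand in for the tensors of a real checkpoint["state_dict"].
lightning_state = {
    "model.1.weight": "W1",
    "model.1.bias": "b1",
    "model.4.weight": "W2",
}

inner_state = strip_prefix(lightning_state, "model.")
print(inner_state)
```

With a real checkpoint, `inner_state` could then be passed to `load_state_dict` on an `nn.Sequential` built with the same layers.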

## Saving hyperparameters
The LightningModule lets you automatically save all the hyperparameters passed to init simply by calling self.save_hyperparameters().

In [15]:
class LitMNIST(L.LightningModule):
    def __init__(self, learning_rate, another_parameter, *args, **kwargs):
        super().__init__()
        self.save_hyperparameters()

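The hyperparameters are saved under the "hyper_parameters" key of the checkpoint. The checkpoint inspected below was produced by the earlier LitMNIST, which does not call self.save_hyperparameters(), so that key is absent from its printed keys.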

In [22]:
#checkpoint = torch.load(checkpoint, map_location=lambda storage, loc: storage)
#print(checkpoint["hyper_parameters"])

In [25]:
CKPT_PATH="/content/logs/lightning_logs/version_0/checkpoints/epoch=2-step=2580-v1.ckpt"
checkpoint = torch.load(CKPT_PATH)
print(checkpoint.keys())

dict_keys(['epoch', 'global_step', 'pytorch-lightning_version', 'state_dict', 'loops', 'callbacks', 'optimizer_states', 'lr_schedulers'])


## Resuming training state
If you want to restore the full training state, not just load the weights, do the following:

In [36]:
model = LitMNIST()
trainer = L.Trainer(max_epochs=30)

# Automatically restores model, epoch, step, LR schedulers, etc...
trainer.fit(model, ckpt_path="/content/lightning_logs/version_0/checkpoints/epoch=25-step=22360.ckpt")

INFO: GPU available: False, used: False
INFO: TPU available: False, using: 0 TPU cores
INFO: IPU available: False, using: 0 IPUs
INFO: HPU available: False, using: 0 HPUs
INFO: Restoring states from the checkpoint path at /content/lightning_logs/version_0/checkpoints/epoch=25-step=22360.ckpt
INFO: 
  | Name          | Type               | Params
-----------------------------------------------------
0 | model         | Sequential         | 55.1 K
1 | val_accuracy  | MulticlassAccuracy | 0     
2 | test_accuracy | Mult

INFO: `Trainer.fit` stopped: `max_epochs=30` reached.


In [37]:
trainer.test()

INFO: Restoring states from the checkpoint path at /content/lightning_logs/version_1/checkpoints/epoch=29-step=25800.ckpt
INFO: Loaded model weights from the checkpoint at /content/lightning_logs/version_1/checkpoints/epoch=29-step=25800.ckpt

[{'test_loss': 0.0828084647655487, 'test_acc': 0.9764999747276306}]
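
The checkpoint files in the logs above are named like `epoch=25-step=22360.ckpt`, sometimes with a `-v1` suffix. When several such files exist, a small helper can pick the most advanced one to pass as `ckpt_path`. The functions below are a hypothetical sketch based only on the filenames seen in this notebook; a custom filename template would need a different pattern, and the first path in the sample list is made up for illustration.

```python
import re
from pathlib import PurePosixPath

# Matches names such as "epoch=2-step=2580.ckpt" or "epoch=2-step=2580-v1.ckpt".
CKPT_PATTERN = re.compile(r"epoch=(\d+)-step=(\d+)(?:-v(\d+))?\.ckpt$")


def parse_ckpt_name(path):
    """Return (epoch, step, version) parsed from a checkpoint path, or None."""
    match = CKPT_PATTERN.search(PurePosixPath(path).name)
    if match is None:
        return None
    epoch, step, version = match.groups()
    return int(epoch), int(step), int(version or 0)


def latest_checkpoint(paths):
    """Pick the path with the highest (step, version); ignore unparseable names."""
    parsed = [(parse_ckpt_name(p), p) for p in paths]
    candidates = [(info[1], info[2], p) for info, p in parsed if info is not None]
    return max(candidates)[2] if candidates else None


paths = [
    "logs/lightning_logs/version_0/checkpoints/epoch=2-step=2580.ckpt",  # illustrative
    "logs/lightning_logs/version_0/checkpoints/epoch=2-step=2580-v1.ckpt",
    "lightning_logs/version_0/checkpoints/epoch=25-step=22360.ckpt",
]
print(parse_ckpt_name(paths[1]))
print(latest_checkpoint(paths))
```

The returned path could then be used as `trainer.fit(model, ckpt_path=...)`.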

## References
https://lightning.ai/docs/pytorch/stable/common/checkpointing_basic.html#save-a-checkpoint