# SAVING AND LOADING CHECKPOINTS (BASIC)

## What is a checkpoint?
When a model is training, the performance changes as it continues to see more data. It is a best practice to save the state of a model throughout the training process. This gives you a version of the model, a checkpoint, at each key point during the development of the model. Once training has completed, use the checkpoint that corresponds to the best performance you found during the training process.
Checkpoints also enable your training to resume from where it was in case the training process is interrupted.
PyTorch Lightning checkpoints are fully usable in plain PyTorch.
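
The idea of keeping one checkpoint per key point and then picking the best one can be sketched in plain PyTorch. This is a minimal illustration, not Lightning's own mechanism: the toy regression data, the `epoch=N.pt` file names, and the `val_loss` key are all assumptions made for the example.

```python
import os
import tempfile

import torch
import torch.nn as nn

torch.manual_seed(0)

# Toy regression data standing in for a real dataset (hypothetical).
x = torch.randn(256, 4)
y = x.sum(dim=1, keepdim=True)
x_train, y_train = x[:192], y[:192]
x_val, y_val = x[192:], y[192:]

model = nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.05)
ckpt_dir = tempfile.mkdtemp()

best_loss, best_path = float("inf"), None
for epoch in range(5):
    # One full-batch training step per "epoch".
    model.train()
    optimizer.zero_grad()
    loss = nn.functional.mse_loss(model(x_train), y_train)
    loss.backward()
    optimizer.step()

    # Measure performance on held-out data.
    model.eval()
    with torch.no_grad():
        val_loss = nn.functional.mse_loss(model(x_val), y_val).item()

    # Save a checkpoint for this epoch and remember the best one so far.
    path = os.path.join(ckpt_dir, f"epoch={epoch}.pt")
    torch.save({"epoch": epoch, "val_loss": val_loss, "state_dict": model.state_dict()}, path)
    if val_loss < best_loss:
        best_loss, best_path = val_loss, path

# After training, restore the weights of the best checkpoint.
best = torch.load(best_path)
model.load_state_dict(best["state_dict"])
print(best["epoch"], best["val_loss"])
```

The same files could also be used to continue an interrupted run, since each one records the epoch it was written at.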


In [18]:
# Bring in the code from the previous chapter:
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import lightning as L

# Load data sets
transform = transforms.ToTensor()
train_set = datasets.MNIST(root="MNIST", download=True, train=True, transform=transform)
test_set = datasets.MNIST(root="MNIST", download=True, train=False, transform=transform)
train_loader = DataLoader(train_set, num_workers=0, batch_size=64, shuffle=True)
test_loader = DataLoader(test_set, num_workers=0, batch_size=64, shuffle=False)  # keep evaluation order fixed

# PyTorch
class Encoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.l1 = nn.Sequential(nn.Linear(28 * 28, 64), nn.ReLU(), nn.Linear(64, 3))

    def forward(self, x):
        return self.l1(x)


class Decoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.l1 = nn.Sequential(nn.Linear(3, 64), nn.ReLU(), nn.Linear(64, 28 * 28))

    def forward(self, x):
        return self.l1(x)

# PyTorch-Lightning
class LitAutoEncoder(L.LightningModule):
    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        # Example input array for logging and model tracing
        self.example_input_array = torch.rand(16, 1, 28, 28)
    
    def forward(self, x):
        # Define the forward pass
        x = x.view(x.size(0), -1)  # Flatten the input
        z = self.encoder(x)
        x_hat = self.decoder(z)
        return x_hat

    def training_step(self, batch, batch_idx):
        # training_step defines the train loop.
        x, y = batch
        x_hat = self(x)  # Use the forward method
        loss = F.mse_loss(x_hat, x.view(x.size(0), -1))
        return loss
    
    def test_step(self, batch, batch_idx):
        # this is the test loop
        x, y = batch
        x_hat = self(x)  # Use the forward method
        test_loss = F.mse_loss(x_hat, x.view(x.size(0), -1))
        self.log("test_loss", test_loss)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer
    
    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        x, _ = batch  # Split inputs from labels; assumes batch is (inputs, targets)
        x = x.view(x.size(0), -1)  # Flatten the input in case the DataLoader has not already done so
        z = self.encoder(x)  # Encode
        x_hat = self.decoder(z)  # Decode
        return x_hat

In [14]:
from lightning.pytorch.callbacks import DeviceStatsMonitor

# encoder/decoder were not saved as hyperparameters, so they must be passed again
model = LitAutoEncoder.load_from_checkpoint(
    "ckpts/lightning_logs/version_0/checkpoints/epoch=2-step=2250.ckpt",
    encoder=Encoder(),
    decoder=Decoder(),
)
trainer = L.Trainer(accelerator="gpu", max_epochs=3, callbacks=[DeviceStatsMonitor()], default_root_dir="ckpts")
trainer.fit(model, train_loader, test_loader)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name    | Type    | Params | In sizes  | Out sizes
------------------------------------------------------------
0 | encoder | Encoder | 50.4 K | [16, 784] | [16, 3]  
1 | decoder | Decoder | 51.2 K | [16, 3]   | [16, 784]
------------------------------------------------------------
101 K     Trainable params
0         Non-trainable params
101 K     Total params
0.407     Total estimated model params size (MB)


`Trainer.fit` stopped: `max_epochs=3` reached.


In [19]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = LitAutoEncoder.load_from_checkpoint(
    "ckpts/lightning_logs/version_0/checkpoints/epoch=2-step=2250.ckpt",
    encoder=Encoder(),
    decoder=Decoder(),
).to(device)
model.eval()  # switch off training-only behaviour before inference

# Predict with the model on a random batch
x = torch.rand(32, 1, 28, 28).to(device)
with torch.no_grad():
    y_hat = model(x)
print(y_hat.size())

torch.Size([32, 784])


In [20]:
predictions = trainer.predict(model, test_loader)
# print(predictions)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]



In [22]:
# Use PyTorch as normal
checkpoint = torch.load("ckpts/lightning_logs/version_0/checkpoints/epoch=2-step=2250.ckpt")
model.load_state_dict(checkpoint["state_dict"])
model.eval()

LitAutoEncoder(
  (encoder): Encoder(
    (l1): Sequential(
      (0): Linear(in_features=784, out_features=64, bias=True)
      (1): ReLU()
      (2): Linear(in_features=64, out_features=3, bias=True)
    )
  )
  (decoder): Decoder(
    (l1): Sequential(
      (0): Linear(in_features=3, out_features=64, bias=True)
      (1): ReLU()
      (2): Linear(in_features=64, out_features=784, bias=True)
    )
  )
)

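## Enable distributed inference

With `predict_step` defined, prediction can be spread across several devices; a `BasePredictionWriter` callback writes the results of each process to disk.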

```python
import os

import torch
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import BasePredictionWriter
from lightning.pytorch.demos.boring_classes import BoringModel


class CustomWriter(BasePredictionWriter):
    def __init__(self, output_dir, write_interval):
        super().__init__(write_interval)
        self.output_dir = output_dir
        os.makedirs(output_dir, exist_ok=True)  # torch.save does not create missing directories

    def write_on_epoch_end(self, trainer, pl_module, predictions, batch_indices):
        # This creates N (number of processes) files in `output_dir`, each containing
        # the predictions of its respective rank
        torch.save(predictions, os.path.join(self.output_dir, f"predictions_{trainer.global_rank}.pt"))

        # Optionally, also save `batch_indices` to know which data index
        # each prediction belongs to
        torch.save(batch_indices, os.path.join(self.output_dir, f"batch_indices_{trainer.global_rank}.pt"))


# Alternatively, set `write_interval="batch"` and override `write_on_batch_end`
# to save predictions at batch level
pred_writer = CustomWriter(output_dir="pred_path", write_interval="epoch")
trainer = Trainer(accelerator="gpu", strategy="ddp", devices=8, callbacks=[pred_writer])
model = BoringModel()
# return_predictions=False avoids gathering all predictions in memory
trainer.predict(model, return_predictions=False)
```
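
Once every rank has written its files, the results usually need to be merged back into dataset order. The sketch below shows one way this might be done. It assumes, purely for illustration, that each `predictions_{rank}.pt` holds a list of per-batch tensors and each `batch_indices_{rank}.pt` holds a matching list of per-batch index lists; the layout written by a real run can be nested differently (for example per dataloader), so inspect the files first. The files are simulated here so the example runs on its own.

```python
import os
import tempfile

import torch

out_dir = tempfile.mkdtemp()
world_size = 2

# Simulate the files two ranks would write (assumed layout, see above).
# Sample i has the value i, and the fake "prediction" is 10 * value.
data = torch.arange(8, dtype=torch.float32).unsqueeze(1)
for rank in range(world_size):
    idx = list(range(rank, 8, world_size))  # interleaved split across ranks
    batches = [idx[:2], idx[2:]]            # two batches per rank
    preds = [data[b] * 10 for b in batches]
    torch.save(preds, os.path.join(out_dir, f"predictions_{rank}.pt"))
    torch.save(batches, os.path.join(out_dir, f"batch_indices_{rank}.pt"))

# Merge: concatenate every rank's predictions and collect the matching indices.
all_preds, all_idx = [], []
for rank in range(world_size):
    preds = torch.load(os.path.join(out_dir, f"predictions_{rank}.pt"))
    batches = torch.load(os.path.join(out_dir, f"batch_indices_{rank}.pt"))
    all_preds.append(torch.cat(preds))
    all_idx.extend(i for batch in batches for i in batch)

# Reorder so that row k holds the prediction for dataset sample k.
order = torch.tensor(all_idx).argsort()
merged = torch.cat(all_preds)[order]
print(merged.squeeze(1).tolist())
# [0.0, 10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0]
```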

## Save hyperparameters

The LightningModule lets you automatically save all the hyperparameters passed to `__init__` simply by calling `self.save_hyperparameters()`.

```python
import lightning as L


class MyLightningModule(L.LightningModule):
    def __init__(self, learning_rate, another_parameter, *args, **kwargs):
        super().__init__()
        # Records learning_rate, another_parameter, etc. in self.hparams
        self.save_hyperparameters()
```

The hyperparameters are saved under the `hyper_parameters` key of the checkpoint:

```python
# CKPT_PATH is the path to the checkpoint file
checkpoint = torch.load(CKPT_PATH, map_location=lambda storage, loc: storage)
print(checkpoint["hyper_parameters"])
# {"learning_rate": the_value, "another_parameter": the_other_value}
```

A LightningModule restored from the checkpoint can also access the saved hyperparameters:

```python
model = MyLightningModule.load_from_checkpoint("/path/to/checkpoint.ckpt")
print(model.hparams.learning_rate)  # values stored by save_hyperparameters() live in `hparams`
```

## Initialize with other parameters

If you called `self.save_hyperparameters()` in the `__init__` method of the LightningModule, you can override the saved values and initialize the model with different hyperparameters.

```python
# If you train and save the model like this, loading the checkpoint reuses
# these values by default, but you can override them.
LitModel(in_dim=32, out_dim=10)

# uses in_dim=32, out_dim=10
model = LitModel.load_from_checkpoint(PATH)

# uses in_dim=128, out_dim=10
model = LitModel.load_from_checkpoint(PATH, in_dim=128, out_dim=10)
```
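
To make the "saved values unless overridden" behaviour concrete, here is a minimal plain-PyTorch sketch of the idea. It is not Lightning's implementation: `TinyNet`, `load_model`, and the `dropout` hyperparameter are hypothetical names introduced only for this example, and the checkpoint is a hand-built dictionary that mirrors the `hyper_parameters` and `state_dict` keys.

```python
import os
import tempfile

import torch
import torch.nn as nn


class TinyNet(nn.Module):  # hypothetical model, for illustration only
    def __init__(self, in_dim, out_dim, dropout=0.1):
        super().__init__()
        self.layer = nn.Linear(in_dim, out_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        return self.dropout(self.layer(x))


# Store the constructor arguments next to the weights.
saved_hparams = {"in_dim": 32, "out_dim": 10, "dropout": 0.1}
model = TinyNet(**saved_hparams)
path = os.path.join(tempfile.mkdtemp(), "tiny.ckpt")
torch.save({"hyper_parameters": saved_hparams, "state_dict": model.state_dict()}, path)


def load_model(path, **overrides):
    """Rebuild the model from saved hyperparameters; keyword arguments win."""
    checkpoint = torch.load(path, map_location="cpu")
    hparams = {**checkpoint["hyper_parameters"], **overrides}
    net = TinyNet(**hparams)
    net.load_state_dict(checkpoint["state_dict"])
    return net, hparams


same, same_hparams = load_model(path)
changed, changed_hparams = load_model(path, dropout=0.5)
print(same_hparams)     # {'in_dim': 32, 'out_dim': 10, 'dropout': 0.1}
print(changed_hparams)  # {'in_dim': 32, 'out_dim': 10, 'dropout': 0.5}
```

In this sketch, overriding a hyperparameter that changes a layer's shape (such as `in_dim`) would make `load_state_dict` fail, because the saved weights would no longer fit the rebuilt layers.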

In some cases, you may pass entire PyTorch modules to the `__init__` method, which you do not want to save as hyperparameters because of their size. If you did not call `self.save_hyperparameters()`, or you excluded parameters via `save_hyperparameters(ignore=...)`, you must pass the missing positional or keyword arguments when calling `load_from_checkpoint`:

```python
class LitAutoEncoder(L.LightningModule):
    def __init__(self, encoder, decoder):
        ...

    ...


model = LitAutoEncoder.load_from_checkpoint(PATH, encoder=encoder, decoder=decoder)
```

## nn.Module from checkpoint

Lightning checkpoints are fully compatible with plain torch nn.Modules.

```python
checkpoint = torch.load(CKPT_PATH)
print(checkpoint.keys())
```

In [10]:
checkpoint = torch.load("ckpts/lightning_logs/version_0/checkpoints/epoch=2-step=2250.ckpt")
print(checkpoint.keys())

dict_keys(['epoch', 'global_step', 'pytorch-lightning_version', 'state_dict', 'loops', 'callbacks', 'optimizer_states', 'lr_schedulers'])
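
The `state_dict` entry is an ordinary PyTorch state dict whose keys are prefixed with the attribute names used in the LightningModule (for example `encoder.`). One way to load only the encoder weights into a plain `nn.Module` is to filter on that prefix and strip it. The sketch below builds a stand-in checkpoint dictionary instead of reading a real file, and `AutoEncoder` is a hypothetical plain-PyTorch substitute for `LitAutoEncoder`.

```python
import torch
import torch.nn as nn


class Encoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.l1 = nn.Sequential(nn.Linear(28 * 28, 64), nn.ReLU(), nn.Linear(64, 3))

    def forward(self, x):
        return self.l1(x)


class AutoEncoder(nn.Module):  # hypothetical stand-in for LitAutoEncoder
    def __init__(self):
        super().__init__()
        self.encoder = Encoder()
        self.decoder = nn.Sequential(nn.Linear(3, 64), nn.ReLU(), nn.Linear(64, 28 * 28))


# Stand-in for `checkpoint = torch.load(CKPT_PATH)`.
checkpoint = {"state_dict": AutoEncoder().state_dict()}

# Keep only the encoder entries and drop the "encoder." prefix.
prefix = "encoder."
encoder_weights = {
    key[len(prefix):]: value
    for key, value in checkpoint["state_dict"].items()
    if key.startswith(prefix)
}

encoder = Encoder()
encoder.load_state_dict(encoder_weights)
encoder.eval()

with torch.no_grad():
    z = encoder(torch.rand(4, 28 * 28))
print(z.shape)  # torch.Size([4, 3])
```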


## Resume training state

If you want to restore the full training state rather than just the weights, do the following:

```python
model = LitModel()
trainer = Trainer()

# automatically restores model, epoch, step, LR schedulers, etc...
trainer.fit(model, ckpt_path="some/path/to/my_checkpoint.ckpt")
```
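
To see why restoring more than the weights matters, the following plain-PyTorch sketch saves the model, the optimizer state, and the epoch counter, then resumes in fresh objects and compares the result with an uninterrupted run. The checkpoint keys (`epoch`, `state_dict`, `optimizer_state`) and the `train` helper are assumptions chosen for this example, not Lightning's exact format.

```python
import os
import tempfile

import torch
import torch.nn as nn

torch.manual_seed(0)
x, y = torch.randn(64, 4), torch.randn(64, 1)  # toy data (hypothetical)


def train(model, optimizer, start_epoch, end_epoch):
    """Run one full-batch update per epoch in [start_epoch, end_epoch)."""
    for _ in range(start_epoch, end_epoch):
        optimizer.zero_grad()
        nn.functional.mse_loss(model(x), y).backward()
        optimizer.step()


model = nn.Linear(4, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
train(model, optimizer, 0, 2)

# Save everything needed to continue: weights, optimizer state, progress.
path = os.path.join(tempfile.mkdtemp(), "resume.ckpt")
torch.save(
    {"epoch": 2, "state_dict": model.state_dict(), "optimizer_state": optimizer.state_dict()},
    path,
)

# Resume in fresh objects, as after an interrupted run.
checkpoint = torch.load(path)
resumed = nn.Linear(4, 1)
resumed_opt = torch.optim.Adam(resumed.parameters(), lr=1e-2)
resumed.load_state_dict(checkpoint["state_dict"])
resumed_opt.load_state_dict(checkpoint["optimizer_state"])
train(resumed, resumed_opt, checkpoint["epoch"], 4)

# Reference: the original run continued without interruption.
train(model, optimizer, 2, 4)
print(torch.allclose(resumed.weight, model.weight))
```

Without the optimizer state, Adam's running averages would restart from zero and the resumed run would diverge from the uninterrupted one.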