In [5]:
from json import encoder

import torch.utils.data as data
import torch
import torch.nn as nn
from torchvision import datasets
import torchvision.transforms as transforms
import os
import torch.nn.functional as F
from torchvision import transforms
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
import lightning as L

In [6]:
# Load data sets
transform = transforms.ToTensor()
train_set = datasets.MNIST(root="MNIST", download=True, train=True, transform=transform)
test_set = datasets.MNIST(root="MNIST", download=True, train=False, transform=transform)

In [7]:
# PyTorch
class Encoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.l1 = nn.Sequential(nn.Linear(28 * 28, 64), nn.ReLU(), nn.Linear(64, 3))

    def forward(self, x):
        return self.l1(x)


class Decoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.l1 = nn.Sequential(nn.Linear(3, 64), nn.ReLU(), nn.Linear(64, 28 * 28))

    def forward(self, x):
        return self.l1(x)

# Define the test loop

In [35]:
# PyTorch-Lightning
class LitAutoEncoder(L.LightningModule):
    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def training_step(self, batch, batch_idx):
        # training_step defines the train loop.
        x, y = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = F.mse_loss(x_hat, x)
        return loss
    
    def test_step(self, batch, batch_idx):
        # this is the test loop
        x, y = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        test_loss = F.mse_loss(x_hat, x)
        self.log("test_loss", test_loss)
        
    def validation_step(self, batch, batch_idx):
        # this is the validation loop
        x, y = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        val_loss = F.mse_loss(x_hat, x)
        self.log("val_loss", val_loss)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer
    
    

In [9]:
from torch.utils.data import DataLoader

# initialize the Trainer
trainer = L.Trainer()
model = LitAutoEncoder(Encoder(), Decoder())

# test the model
trainer.test(model, dataloaders=DataLoader(test_set))

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
C:\Users\alpha\.conda\envs\vp\Lib\site-packages\lightning\pytorch\trainer\connectors\logger_connector\logger_connector.py:75: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `lightning.pytorch` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default
You are using a CUDA device ('NVIDIA GeForce RTX 4070 Ti SUPER') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_

Testing: |                                                                                                    …

IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)

IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)

IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)

IOPub message rate exceeded.
The Jupyter serve

# Add a validation loop

In [10]:
# use a0% of training data for validation
train_set_size = int(len(train_set) * 0.8)
valid_set_size = len(train_set) - train_set_size

# split the train set into two
seed = torch.Generator().manual_seed(42)
train_set, valid_set = data.random_split(train_set, [train_set_size, valid_set_size], generator=seed)

In [16]:
from torch.utils.data import DataLoader

train_loader = DataLoader(train_set,num_workers=0, batch_size=64, shuffle=True)
valid_loader = DataLoader(valid_set,num_workers=0, batch_size=64, shuffle=True)
model = LitAutoEncoder(Encoder(), Decoder())

# train with both splits
trainer = L.Trainer(accelerator='gpu', max_epochs=3)
trainer.fit(model, train_loader, valid_loader)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name    | Type    | Params
------------------------------------
0 | encoder | Encoder | 50.4 K
1 | decoder | Decoder | 51.2 K
------------------------------------
101 K     Trainable params
0         Non-trainable params
101 K     Total params
0.407     Total estimated model params size (MB)


Sanity Checking: |                                                                                            …

Training: |                                                                                                   …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

`Trainer.fit` stopped: `max_epochs=3` reached.


In [12]:
torch.cuda.is_available()


True

# What is a checkpoint?
When a model is training, the performance changes as it continues to see more data. It is a best practice to save the state of a model throughout the training process. This gives you a version of the model, a checkpoint, at each key point during the development of the model. Once training has completed, use the checkpoint that corresponds to the best performance you found during the training process.

Checkpoints also enable your training to resume from where it was in case the training process is interrupted.

PyTorch Lightning checkpoints are fully usable in plain PyTorch.

> 当一个模型正在训练时，性能随着它继续看到更多的数据而改变。在整个培训过程中保存模型的状态是一种最佳实践。这将在模型开发期间的每个关键点为您提供模型的一个版本，即一个检查点。一旦培训完成，使用与培训过程中发现的最佳性能相对应的检查点。检查点还可以使您的培训从原来的地方恢复，以防培训过程中断。火炬闪电检查点在普通火炬中是完全可用的。

# Contents of a checkpoint

A Lightning checkpoint contains a dump of the model’s entire internal state. Unlike plain PyTorch, Lightning saves everything you need to restore a model even in the most complex distributed training environments.

> “闪电”检查点包含模型整个内部状态的转储。与普通的 PyTorch 不同，即使在最复杂的分布式培训环境中，“闪电”也会保存恢复模型所需的所有内容。

Inside a Lightning checkpoint you’ll find:
- 16-bit scaling factor (if using 16-bit precision training) / 16位缩放因子
- Current epoch
- Global step
- LightningModule’s state_dict
- State of all optimizers
- State of all learning rate schedulers
- State of all callbacks (for stateful callbacks)
- State of datamodule (for stateful datamodules)
- The hyperparameters (init arguments) with which the model was created
- The hyperparameters (init arguments) with which the datamodule was created
- State of Loops

# Save a checkpoint
# simply by using the Trainer you get automatic checkpointing
`trainer = Trainer()`
# saves checkpoints to 'some/path/' at every epoch end
`trainer = Trainer(default_root_dir="some/path/")`

In [18]:
trainer = L.Trainer(accelerator='gpu', max_epochs=3, default_root_dir='./ckpts')
trainer.fit(model, train_loader, valid_loader)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Missing logger folder: ckpts\lightning_logs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name    | Type    | Params
------------------------------------
0 | encoder | Encoder | 50.4 K
1 | decoder | Decoder | 51.2 K
------------------------------------
101 K     Trainable params
0         Non-trainable params
101 K     Total params
0.407     Total estimated model params size (MB)


Sanity Checking: |                                                                                            …

Training: |                                                                                                   …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

`Trainer.fit` stopped: `max_epochs=3` reached.


In [23]:

model = LitAutoEncoder(Encoder(), Decoder()).load_from_checkpoint("ckpts/lightning_logs/version_0/checkpoints/epoch=2-step=2250.ckpt")
# disable randomness, dropout, etc...
model.eval()


TypeError: The classmethod `LitAutoEncoder.load_from_checkpoint` cannot be called on an instance. Please call it on the class type and make sure the return value is used.

In [24]:
checkpoint = torch.load("ckpts/lightning_logs/version_0/checkpoints/epoch=2-step=2250.ckpt", map_location=lambda storage, loc: storage)
print(checkpoint["hyper_parameters"])

KeyError: 'hyper_parameters'

# Save hyperparameters

In [26]:
class MyLightningModule(L.LightningModule):
    def __init__(self, learning_rate, another_parameter, *args, **kwargs):
        super().__init__()
        self.save_hyperparameters()

In [27]:
checkpoint = torch.load(checkpoint, map_location=lambda storage, loc: storage)
print(checkpoint["hyper_parameters"])
# {"learning_rate": the_value, "another_parameter": the_other_value}

AttributeError: 'dict' object has no attribute 'seek'. You can only torch.load from a file that is seekable. Please pre-load the data into a buffer like io.BytesIO and try to load from it instead.

In [28]:
model = MyLightningModule.load_from_checkpoint("/path/to/checkpoint.ckpt")
print(model.learning_rate)

FileNotFoundError: [Errno 2] No such file or directory: 'C:/path/to/checkpoint.ckpt'

# Initialize with other parameters
> 在某些情况下，我们还可以将整个 PyTorch 模块传递给 _ _ init _ _ 方法，由于它们的大小，您不希望将它们保存为超参数。如果您没有调用 self. save _ hyper旦()或者通过 save _ hyper旦忽略参数(忽略 = ...) ，那么在调用 load _ from _ check 方法时，您必须传递缺少的位置参数或关键字参数:

In [29]:
# if you train and save the model like this it will use these values when loading
# the weights. But you can overwrite this
LitModel(in_dim=32, out_dim=10)

# uses in_dim=32, out_dim=10
model = LitModel.load_from_checkpoint(PATH)

# uses in_dim=128, out_dim=10
model = LitModel.load_from_checkpoint(PATH, in_dim=128, out_dim=10)

NameError: name 'LitModel' is not defined

In [30]:
class LitAutoencoder(L.LightningModule):
    def __init__(self, encoder, decoder):
        ...

    ...


model = LitAutoEncoder.load_/from_checkpoint(PATH, encoder=encoder, decoder=decoder)

AttributeError: type object 'LitAutoEncoder' has no attribute 'load_'

# nn.Module from checkpoint

In [31]:
checkpoint = torch.load("ckpts/lightning_logs/version_0/checkpoints/epoch=2-step=2250.ckpt")
print(checkpoint.keys())

dict_keys(['epoch', 'global_step', 'pytorch-lightning_version', 'state_dict', 'loops', 'callbacks', 'optimizer_states', 'lr_schedulers'])


In [32]:
autoencoder = LitAutoEncoder(Encoder(), Decoder())

In [33]:
checkpoint = torch.load("ckpts/lightning_logs/version_0/checkpoints/epoch=2-step=2250.ckpt")
encoder_weights = {k: v for k, v in checkpoint["state_dict"].items() if k.startswith("encoder.")}
decoder_weights = {k: v for k, v in checkpoint["state_dict"].items() if k.startswith("decoder.")}

# EarlyStopping Callback
To enable it:

- Import EarlyStopping callback.
- Log the metric you want to monitor using log() method.
- Init the callback, and set monitor to the logged metric of your choice.
- Set the mode based on the metric needs to be monitored.
- Pass the EarlyStopping callback to the Trainer callbacks flag.

In [38]:
from lightning.pytorch.callbacks.early_stopping import EarlyStopping

model = LitAutoEncoder.load_from_checkpoint("ckpts/lightning_logs/version_0/checkpoints/epoch=2-step=2250.ckpt",encoder=Encoder(), decoder=Decoder())

early_stop_callback = EarlyStopping(
    monitor="val_loss",
    min_delta=0.00,
    patience=3,
    verbose=False,
    mode="min")
trainer = L.Trainer(callbacks=early_stop_callback)
trainer.fit(model, train_loader, valid_loader)



GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name    | Type    | Params
------------------------------------
0 | encoder | Encoder | 50.4 K
1 | decoder | Decoder | 51.2 K
------------------------------------
101 K     Trainable params
0         Non-trainable params
101 K     Total params
0.407     Total estimated model params size (MB)


Sanity Checking: |                                                                                            …

Training: |                                                                                                   …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

# Additional parameters that stop training at extreme points:
- **stopping_threshold**: Stops training immediately once the monitored quantity reaches this threshold. It is useful when we know that going beyond a certain optimal value does not further benefit us.
- **divergence_threshold**: Stops training as soon as the monitored quantity becomes worse than this threshold. When reaching a value this bad, we believes the model cannot recover anymore and it is better to stop early and run with different initial conditions.
- **check_finite**: When turned on, it stops training if the monitored metric becomes NaN or infinite.
- **check_on_train_epoch_end**: When turned on, it checks the metric at the end of a training epoch. Use this only when you are monitoring any metric logged within training-specific hooks on epoch-level.
> - **stopping_threshold**: 一旦监控的量达到这个阈值，立即停止训练。当我们知道超过某个最优值不会再带来任何好处时，这非常有用。
> - **divergence_threshold**: 一旦监控的量变得比这个阈值更糟糕，就停止训练。当达到这样糟糕的值时，我们认为模型无法再恢复，最好提前停止并用不同的初始条件重新运行。
> - **check_finite**: 打开时，如果监控的指标变为 NaN 或无穷大，则停止训练。
> - **check_on_train_epoch_end**: 打开时，它会在每个训练 epoch 结束时检查指标。仅当您监控在 epoch 级别的训练特定钩子中记录的任何指标时，才使用此选项。

In case you need early stopping in a different part of training, subclass EarlyStopping and change where it is called:
如果你需要在训练的不同部分提前停止，子类 EarlyStops 可以改变它的名称:

In [None]:
class MyEarlyStopping(EarlyStopping):
    def on_validation_end(self, trainer, pl_module):
        # override this to disable early stopping at the end of val loop
        pass

    def on_train_end(self, trainer, pl_module):
        # instead, do it at the end of training loop
        self._run_early_stopping_check(trainer)