# PyTorch Lightning
A deep learning training framework built on PyTorch. It wraps the training workflow so you write less code.
Core components: `LightningModule` and `Trainer`.

## LightningModule
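`LightningModule` organizes PyTorch code into six sections:
- Computation (`__init__`)
- Training loop (`training_step`)
- Validation loop (`validation_step`)
- Test loop (`test_step`)
- Prediction loop (`predict_step`)
- Optimizers and learning-rate schedulers (`configure_optimizers`)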

```python
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F


class LitModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.l1 = nn.Linear(28 * 28, 10)

    def forward(self, x):
        return torch.relu(self.l1(x.view(x.size(0), -1)))

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.02)
```

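The model above subclasses `pl.LightningModule` and is trained through a `Trainer` object:

```python
import os

from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST

train_loader = DataLoader(MNIST(os.getcwd(), download=True, transform=transforms.ToTensor()))
trainer = pl.Trainer(max_epochs=1)
model = LitModel()

trainer.fit(model, train_dataloaders=train_loader)
```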

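`LightningModule` provides many useful methods. The core ones you need to know are:

|Name|Description|
|----|-----------|
|`__init__`|Define the computation|
|`forward`|Used for inference only|
|`training_step`|The complete training step|
|`validation_step`|The complete validation step|
|`test_step`|The complete test step|
|`predict_step`|The complete prediction step|
|`configure_optimizers`|Define the optimizers and learning-rate schedulers|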

### Training
Enable the training loop by overriding the `training_step()` method:

```python
class LitClassifier(pl.LightningModule):
    def __init__(self, model):
        super().__init__()
        self.model = model

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.model(x)
        loss = F.cross_entropy(y_hat, y)
        return loss  # return the loss, or a dict that contains the key "loss"
```

Use the `self.log()` method to log training metrics:

```python
def training_step(self, batch, batch_idx):
    x, y = batch
    y_hat = self.model(x)
    loss = F.cross_entropy(y_hat, y)

    # logs metrics for each training_step,
    # and the average across the epoch, to the progress bar and logger
    self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
    return loss
```

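Lightning collects the output of every step into a list and passes it to `training_epoch_end(self, training_step_outputs)`, which makes epoch-level computations convenient. This hook belongs to Lightning 1.x; it was removed in Lightning 2.0.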

```python
def training_step(self, batch, batch_idx):
    x, y = batch
    y_hat = self.model(x)
    loss = F.cross_entropy(y_hat, y)
    preds = ...
    return {"loss": loss, "other_stuff": preds}


def training_epoch_end(self, training_step_outputs):
    # each element is the dict returned by training_step, so pick out the predictions
    all_preds = torch.stack([out["other_stuff"] for out in training_step_outputs])
    ...
```

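For multi-device training (for example with the DP strategy), you can override `training_step_end(self, batch_parts)` to run computations after each step on the outputs gathered from all devices (Lightning 1.x only; removed in 2.0). The pseudocode below shows the order in which Lightning calls the hooks during `fit()`: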

```python
# Pseudocode (Lightning 1.x, not runnable): hook call order inside trainer.fit()
def fit(self):
    if global_rank == 0:
        # prepare data is called on GLOBAL_ZERO only
        prepare_data()

    configure_callbacks()

    with parallel(devices):
        # devices can be GPUs, TPUs, ...
        train_on_device(model)


def train_on_device(model):
    # called PER DEVICE
    setup("fit")
    configure_optimizers()
    on_fit_start()

    # the sanity check runs here

    on_train_start()
    for epoch in epochs:
        fit_loop()
    on_train_end()

    on_fit_end()
    teardown("fit")


def fit_loop():
    on_train_epoch_start()

    for batch in train_dataloader():
        on_train_batch_start()

        on_before_batch_transfer()
        transfer_batch_to_device()
        on_after_batch_transfer()

        training_step()

        on_before_zero_grad()
        optimizer_zero_grad()

        on_before_backward()
        backward()
        on_after_backward()

        on_before_optimizer_step()
        configure_gradient_clipping()
        optimizer_step()

        on_train_batch_end()

        if should_check_val:
            val_loop()
    # end training epoch
    training_epoch_end()

    on_train_epoch_end()


def val_loop():
    on_validation_model_eval()  # calls `model.eval()`
    torch.set_grad_enabled(False)

    on_validation_start()
    on_validation_epoch_start()

    val_outs = []
    for batch_idx, batch in enumerate(val_dataloader()):
        on_validation_batch_start(batch, batch_idx)

        batch = on_before_batch_transfer(batch)
        batch = transfer_batch_to_device(batch)
        batch = on_after_batch_transfer(batch)

        out = validation_step(batch, batch_idx)

        on_validation_batch_end(batch, batch_idx)
        val_outs.append(out)

    validation_epoch_end(val_outs)

    on_validation_epoch_end()
    on_validation_end()

    # set up for train
    on_validation_model_train()  # calls `model.train()`
    torch.set_grad_enabled(True)
```
    

### Validation
Enable the validation loop by overriding the `validation_step()` method:

```python
class LitModel(pl.LightningModule):
    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.model(x)
        loss = F.cross_entropy(y_hat, y)
        self.log("val_loss", loss)
        # no return value is needed if you don't process validation_step's output
```

If you need to process the return values, the pattern is the same as for `training_step()`:

```python
def validation_step(self, batch, batch_idx):
    x, y = batch
    y_hat = self.model(x)
    loss = F.cross_entropy(y_hat, y)
    pred = ...
    return pred


def validation_epoch_end(self, validation_step_outputs):
    all_preds = torch.stack(validation_step_outputs)
    ...
```

### Testing
Testing works like validation: override the `test_step()` method.

### Inference
By default, `predict_step()` calls `forward()`. You can change this behavior by overriding `predict_step()`, for example to run Monte Carlo dropout:

```python
class LitMCdropoutModel(pl.LightningModule):
    def __init__(self, model, mc_iteration):
        super().__init__()
        self.model = model
        self.dropout = nn.Dropout()
        self.mc_iteration = mc_iteration

    def predict_step(self, batch, batch_idx):
        x, _ = batch  # assumes each batch is an (inputs, targets) pair
        # enable Monte Carlo Dropout
        self.dropout.train()

        # take average of `self.mc_iteration` iterations
        pred = torch.vstack([self.dropout(self.model(x)).unsqueeze(0) for _ in range(self.mc_iteration)]).mean(dim=0)
        return pred
```

`predict()` can be called in two ways:

```python
from pytorch_lightning import Trainer

# call after training
trainer = Trainer()
trainer.fit(model)

# automatically loads the best weights from the previous run
predictions = trainer.predict(dataloaders=predict_dataloader)

# or call with pretrained model
model = MyLightningModule.load_from_checkpoint(PATH)
trainer = Trainer()
predictions = trainer.predict(model, dataloaders=test_dataloader)
```

To run inference directly, add a `forward()` method to the `LightningModule`.
In this case you must call `eval()` and use the `torch.no_grad()` context manager yourself.

```python
class Autoencoder(pl.LightningModule):
    def forward(self, x):
        return self.decoder(x)


model = Autoencoder()
model.eval()
with torch.no_grad():
    reconstruction = model(embedding)
```

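In production there may be multiple models. A common pattern is a `LightningModule` that defines the task and receives the model as an argument: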

```python
import pytorch_lightning as pl
import torch
import torch.nn.functional as F
from torchmetrics.functional import accuracy


class ClassificationTask(pl.LightningModule):
    def __init__(self, model):
        super().__init__()
        self.model = model  # the model passed in

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.model(x)
        loss = F.cross_entropy(y_hat, y)
        return loss

    def validation_step(self, batch, batch_idx):
        loss, acc = self._shared_eval_step(batch, batch_idx)
        metrics = {"val_acc": acc, "val_loss": loss}
        self.log_dict(metrics)
        return metrics

    def test_step(self, batch, batch_idx):
        loss, acc = self._shared_eval_step(batch, batch_idx)
        metrics = {"test_acc": acc, "test_loss": loss}
        self.log_dict(metrics)
        return metrics

    def _shared_eval_step(self, batch, batch_idx):  # evaluation logic shared by validation_step and test_step
        x, y = batch
        y_hat = self.model(x)
        loss = F.cross_entropy(y_hat, y)
        # torchmetrics >= 0.11 requires `task` (and `num_classes` for multiclass)
        acc = accuracy(y_hat, y, task="multiclass", num_classes=y_hat.size(1))
        return loss, acc

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        x, y = batch
        y_hat = self.model(x)
        return y_hat

    def configure_optimizers(self):
        return torch.optim.Adam(self.model.parameters(), lr=0.02)
```

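### Saving hyperparameters
Calling `save_hyperparameters()` in the `__init__` method of a `LightningModule` stores the arguments passed to `__init__` in `self.hparams`. These hyperparameters are also saved in the model checkpoint, which makes restoring the model after training much easier. If the logger supports it, the contents of `self.hparams` are logged automatically.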

```python
from pytorch_lightning import LightningModule


class LitMNIST(LightningModule):
    def __init__(self, layer_1_dim=128, learning_rate=1e-2):
        super().__init__()
        # call this to save (layer_1_dim=128, learning_rate=1e-2) to the checkpoint
        self.save_hyperparameters()

        # equivalent
        self.save_hyperparameters("layer_1_dim", "learning_rate")

        # Now possible to access layer_1_dim from hparams
        self.hparams.layer_1_dim
```

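By default, every argument passed to the `__init__` constructor is treated as a hyperparameter of the `LightningModule`. Some arguments should not be stored, however (for example, objects that cannot be serialized), and you can exclude them explicitly.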

```python
from pytorch_lightning import LightningModule


class LitMNIST(LightningModule):
    def __init__(self, loss_fx, generator_network, layer_1_dim=128):
        super().__init__()
        self.layer_1_dim = layer_1_dim
        self.loss_fx = loss_fx

        # call this to save only (layer_1_dim=128) to the checkpoint
        self.save_hyperparameters("layer_1_dim")

        # equivalent
        self.save_hyperparameters(ignore=["loss_fx", "generator_network"])
```

#### `load_from_checkpoint`
A `LightningModule` whose hyperparameters were saved with `save_hyperparameters()` can be restored easily with `load_from_checkpoint()`. Any excluded arguments must be supplied separately when restoring.

```python
# to load, specify the other args
model = LitMNIST.load_from_checkpoint(PATH, loss_fx=torch.nn.SomeOtherLoss, generator_network=MyGenerator())
```