# PyTorch Lightning V1.2.0 - DeepSpeed, Pruning, Quantization, SWA
Reference:
* https://medium.com/pytorch/pytorch-lightning-v1-2-0-43a032ade82b

In [None]:
import torch
from pl_bolts.datamodules import CIFAR10DataModule, ImagenetDataModule
from pytorch_lightning import Trainer, LightningModule
from pytorch_lightning.callbacks import ModelPruning, QuantizationAwareTraining
from pytorch_lightning.profiler.profilers import PyTorchProfiler
from torch.optim import Optimizer

## Pruning

In [None]:
# Build the pruning callback ("l1_unstructured" setting) and hand it to the Trainer.
pruning_callback = ModelPruning("l1_unstructured")
trainer = Trainer(callbacks=[pruning_callback])

## Quantization

In [None]:
class RegressionModel(LightningModule):
    """Small fully connected model: 16 inputs -> 64 -> 64 -> 1 output.

    Each hidden ``Linear`` layer is followed by a ``ReLU`` kept as its own
    attribute (``layer_0``/``layer_0a``, ``layer_1``/``layer_1a``), so the
    pairs can be addressed by name.
    """

    def __init__(self):
        super().__init__()
        hidden_size = 64
        # Attribute names and creation order are kept as-is: other code refers
        # to these layers by name.
        self.layer_0 = torch.nn.Linear(16, hidden_size)
        self.layer_0a = torch.nn.ReLU()
        self.layer_1 = torch.nn.Linear(hidden_size, hidden_size)
        self.layer_1a = torch.nn.ReLU()
        self.layer_end = torch.nn.Linear(hidden_size, 1)

    def forward(self, x):
        """Apply the layers in order and return the output of ``layer_end``."""
        hidden = self.layer_0a(self.layer_0(x))
        hidden = self.layer_1a(self.layer_1(hidden))
        return self.layer_end(hidden)

In [None]:
qcb = QuantizationAwareTraining(
    # specification of quant estimation quality
    observer_type='histogram',
    # specify which layers shall be merged together to increase efficiency
    modules_to_fuse=[(f'layer_{i}', f'layer_{i}a') for i in range(2)],
    input_compatible=False,
)

trainer = Trainer(callbacks=[qcb])
qmodel = RegressionModel()
# `...` is a placeholder: pass your own dataloaders / datamodule here
trainer.fit(qmodel, ...)

# take a sample data batch, for example from your test dataloader
# (`my_dataloader` is a placeholder for your own dataloader factory);
# Python 3 iterators are advanced with the `next()` builtin, not a `.next()` method
batch = next(iter(my_dataloader()))
# using the fully quantized model, you need to apply the quantization layer to the input
qmodel(qmodel.quant(batch[0]))

# converting model to torchscript
tsmodel = qmodel.to_torchscript()
# even the converted model preserves the created quantization layer, which you can/should use
tsmodel(tsmodel.quant(batch[0]))

## Stochastic Weight Averaging 

In [None]:
# Turn on Stochastic Weight Averaging through the Trainer flag.
trainer = Trainer(stochastic_weight_avg=True)

In [None]:
from pytorch_lightning.callbacks import StochasticWeightAveraging
# Enable Stochastic Weight Averaging by passing the callback (default settings) explicitly.
swa_callback = StochasticWeightAveraging()
trainer = Trainer(callbacks=[swa_callback])

## Finetuning

In [None]:
from pytorch_lightning.callbacks import BaseFinetuning

class MyBackboneFinetuning(BaseFinetuning):
    """Freeze ``pl_module.backbone`` up front and unfreeze it at a chosen epoch.

    Args:
        unfreeze_backbone_at_epoch: Epoch at which the backbone is unfrozen.
        train_bn: Forwarded as ``train_bn`` to ``freeze`` and
            ``unfreeze_and_add_param_group``.
        backbone_lr: Learning rate passed for the backbone's parameter group
            when it is unfrozen.
    """

    def __init__(self, unfreeze_backbone_at_epoch: int = 5, train_bn: bool = True, backbone_lr: float = 1e-5):
        # Let the base callback initialise its own state before adding ours.
        super().__init__()
        self._unfreeze_backbone_at_epoch = unfreeze_backbone_at_epoch
        self._train_bn = train_bn
        self._backbone_lr = backbone_lr

    def freeze_before_training(self, pl_module: LightningModule):
        """Freeze the module's ``backbone`` attribute."""
        self.freeze(pl_module.backbone, train_bn=self._train_bn)

    def finetune_function(self, pl_module: LightningModule, epoch: int, optimizer: Optimizer, opt_idx: int):
        """Called when each epoch starts; unfreezes the backbone at the configured epoch."""
        # Use the underscore-prefixed attributes that ``__init__`` actually sets.
        if epoch == self._unfreeze_backbone_at_epoch:
            self.unfreeze_and_add_param_group(
                pl_module.backbone,
                optimizer,
                lr=self._backbone_lr,
                train_bn=self._train_bn,
            )

# Register the finetuning callback, built with its default arguments, on the Trainer.
finetuning_callback = MyBackboneFinetuning()
trainer = Trainer(callbacks=[finetuning_callback])

## PyTorch Geometric integration

In [None]:
# ! pip install torch-sparse -f https://data.pyg.org/whl/torch-1.9.1+cu102.html
# ! pip install torch-geometric
import os.path as osp

import pytorch_lightning as pl
import torch
import torch.nn.functional as F
from torch.nn import BatchNorm1d
from torchmetrics import Accuracy

from torch_geometric import seed_everything
from torch_geometric.data import LightningNodeData
from torch_geometric.datasets import Reddit
from torch_geometric.nn import GraphSAGE

In [None]:
# NOTE(review): `Reddit` is imported from torch_geometric.datasets, yet it is bound to
# `datamodule` and passed as `datamodule=` below, while the imported `LightningNodeData`
# is never used — presumably the dataset was meant to be wrapped in it. TODO confirm.
datamodule = Reddit('data/Reddit')
# NOTE(review): `GraphSAGE` is the class imported from torch_geometric.nn; verify that
# this two-argument call matches its signature and that the resulting object is
# something `trainer.fit` accepts (it presumably expects a LightningModule).
model = GraphSAGE(datamodule.num_features, datamodule.num_classes)

# Train for at most 10 epochs, requesting 2 GPUs with the 'ddp' accelerator.
trainer = Trainer(gpus=2, accelerator='ddp', max_epochs=10)
trainer.fit(model, datamodule=datamodule)

## New Accelerator/plugins API

In [None]:
# Select the accelerator by name and request 16-bit precision directly on the Trainer.
trainer = Trainer(gpus=1, accelerator="ddp_spawn", precision=16)

In [None]:
# Pass in a plugin
from pytorch_lightning.plugins import NativeMixedPrecisionPlugin

# Configure 16-bit native mixed precision through an explicit plugin instance.
# `device` names a device *type* for autocast ("cuda" or "cpu"); an indexed
# string such as "cuda[0]" is not a valid device type.
plugins = [NativeMixedPrecisionPlugin(precision=16, device="cuda")]
trainer = Trainer(gpus=1, accelerator='ddp_spawn', plugins=plugins)