# Introduction
This notebook demonstrates how to use BigDL-Nano to accelerate training workloads in PyTorch or PyTorch Lightning applications.

### Prepare Environment
Before you start using the APIs provided by BigDL-Nano, make sure BigDL-Nano is correctly installed for PyTorch. If it is not, follow [this guide](../../../../../docs/readthedocs/source/doc/Nano/Overview/nano.md) to set up your environment.

This demo uses the pre-built CIFAR10 datamodule from lightning-bolts, which you can install as follows:
```bash
pip install lightning-bolts
```
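
Before running the rest of the notebook, it can help to confirm that both packages are importable. The snippet below is a minimal sketch that uses only the standard library; the helper name `is_installed` is hypothetical, and the module names `bigdl.nano.pytorch` and `pl_bolts` are taken from the imports used later in this notebook.

```python
import importlib.util


def is_installed(module_name: str) -> bool:
    """Return True if `module_name` can be located (hypothetical helper).

    For dotted names, parent packages are imported in the process.
    """
    try:
        return importlib.util.find_spec(module_name) is not None
    except (ImportError, ValueError):
        # Raised when a parent package (e.g. `bigdl`) is itself missing.
        return False


# Module names as imported later in this notebook.
for name in ("bigdl.nano.pytorch", "pl_bolts"):
    status = "found" if is_installed(name) else "MISSING"
    print(f"{name}: {status}")
```

If either line reports `MISSING`, revisit the installation steps above before continuing.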

### Load Cifar10 DataModule
Import the existing data module from bolts and modify the train and test transforms.
See [CIFAR10](https://www.cs.toronto.edu/~kriz/cifar.html) for an overview of the whole dataset.
By leveraging OpenCV and libjpeg-turbo, BigDL-Nano can accelerate computer vision data pipelines through drop-in replacements for torchvision's `datasets` and `transforms`.

In [None]:
import os
from pl_bolts.datamodules import CIFAR10DataModule
from pl_bolts.transforms.dataset_normalizations import cifar10_normalization
from bigdl.nano.pytorch.vision import transforms
DATA_PATH = os.environ.get('DATA_PATH', '.')
BATCH_SIZE = 64
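# Environment variables are strings, so bool("0") would be True; parse the flag explicitly.
DEV_RUN = os.environ.get('DEV_RUN', '').lower() not in ('', '0', 'false', 'no')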
train_transforms = transforms.Compose(
    [
        transforms.RandomCrop(32, 4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        cifar10_normalization()
    ]
)
test_transforms = transforms.Compose(
    [
        transforms.ToTensor(),
        cifar10_normalization()
    ]
)
cifar10_dm = CIFAR10DataModule(
    data_dir = DATA_PATH,
    batch_size = BATCH_SIZE,
    train_transforms = train_transforms,
    val_transforms = test_transforms,
    test_transforms = test_transforms
)


### Custom Model
Modify the pre-existing ResNet architecture from torchvision. It is designed for ImageNet images (224x224) as input, so it needs to be adapted for CIFAR10 images (32x32).
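
To see why the stem needs to change, the sketch below applies the standard convolution output-size formula to the first two layers of the network. It assumes the stock torchvision ResNet-18 stem (a 7x7 convolution with stride 2 and padding 3, followed by a 3x3 max-pool with stride 2 and padding 1) and compares it with the replacement used in this notebook (a 3x3 convolution with stride 1 and padding 1, and no pooling). The helper names are hypothetical.

```python
def conv_out(size: int, kernel: int, stride: int, padding: int) -> int:
    """Spatial output size of a conv/pool layer (dilation 1, floor rounding)."""
    return (size + 2 * padding - kernel) // stride + 1


def stem_output(size: int, conv: tuple, pool: tuple = None) -> int:
    """Feature-map size after the stem; `conv`/`pool` are (kernel, stride, padding)."""
    size = conv_out(size, *conv)
    if pool is not None:
        size = conv_out(size, *pool)
    return size


# Assumed stock torchvision ResNet-18 stem vs. the modified stem in create_model().
IMAGENET_STEM = dict(conv=(7, 2, 3), pool=(3, 2, 1))
CIFAR_STEM = dict(conv=(3, 1, 1), pool=None)

print(stem_output(224, **IMAGENET_STEM))  # 56: ImageNet input, stock stem
print(stem_output(32, **IMAGENET_STEM))   # 8: CIFAR10 input, stock stem
print(stem_output(32, **CIFAR_STEM))      # 32: CIFAR10 input, modified stem
```

With the stock stem, a 32x32 image would already be reduced to 8x8 before the first residual block; the modified stem keeps the full 32x32 resolution.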

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim.lr_scheduler import OneCycleLR
from torchvision.models import resnet18
from pytorch_lightning import LightningModule, seed_everything
from torchmetrics.functional import accuracy
seed_everything(7)
def create_model():
    model = resnet18(pretrained=False, num_classes=10)
    model.conv1 = nn.Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    model.maxpool = nn.Identity()
    return model

class LitResnet(LightningModule):
    def __init__(self, learning_rate=0.05, num_processes=1):
        super().__init__()

        self.save_hyperparameters()
        self.model = create_model()

    def forward(self, x):
        out = self.model(x)
        return F.log_softmax(out, dim=1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        self.log("train_loss", loss)
        return loss

    def evaluate(self, batch, stage=None):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        preds = torch.argmax(logits, dim=1)
        acc = accuracy(preds, y)

        if stage:
            self.log(f"{stage}_loss", loss, prog_bar=True)
            self.log(f"{stage}_acc", acc, prog_bar=True)

    def validation_step(self, batch, batch_idx):
        self.evaluate(batch, "val")

    def test_step(self, batch, batch_idx):
        self.evaluate(batch, "test")

    def configure_optimizers(self):
        optimizer = torch.optim.SGD(
            self.parameters(),
            lr=self.hparams.learning_rate,
            momentum=0.9,
            weight_decay=5e-4,
        )
        steps_per_epoch = 45000 // BATCH_SIZE // self.hparams.num_processes
        scheduler_dict = {
            "scheduler": OneCycleLR(
                optimizer,
                0.1,
                epochs=self.trainer.max_epochs,
                steps_per_epoch=steps_per_epoch,
            ),
            "interval": "step",
        }
        return {"optimizer": optimizer, "lr_scheduler": scheduler_dict}

### Train with Nano APIs
The PyTorch Trainer (`bigdl.nano.pytorch.Trainer`) is the place where we integrate most optimizations. It extends PyTorch Lightning's Trainer and has a few more parameters and methods specific to BigDL-Nano. The Trainer can be directly used to train a `LightningModule`.

In [None]:
from bigdl.nano.pytorch import Trainer
model = LitResnet()
model.datamodule = cifar10_dm
trainer = Trainer(max_epochs=30,
                  fast_dev_run=DEV_RUN) # if set, run a single batch as a quick smoke test
fit_time_basic = %timeit -n 1 -r 1 -o \
trainer.fit(model, datamodule=cifar10_dm)
metric_basic = trainer.test(model, datamodule=cifar10_dm)


Intel Extension for PyTorch (a.k.a. IPEX) extends PyTorch with optimizations for an extra performance boost on Intel hardware. BigDL-Nano integrates IPEX through the Trainer. Users can turn on IPEX by setting `use_ipex=True`.

In [None]:
model = LitResnet()
model.datamodule = cifar10_dm
trainer = Trainer(max_epochs=30, 
                  use_ipex=True,
                  fast_dev_run=DEV_RUN)
fit_time_ipex = %timeit -n 1 -r 1 -o \
trainer.fit(model, datamodule=cifar10_dm)
metric_ipex = trainer.test(model, datamodule=cifar10_dm)

Setting `use_ipex=True` applies optimizations at the Python frontend to the given model (`nn.Module`) and, optionally, to the given optimizer. These optimizations include conv+bn folding (for inference only), weight prepacking, and so on.
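
As an illustration of what conv+bn folding means, the sketch below folds a batch-norm layer into the preceding layer for a single scalar weight. It shows the general technique only and is not IPEX's implementation; the function name `fold_bn` and all numeric values are arbitrary placeholders chosen for the example.

```python
import math


def fold_bn(w, b, gamma, beta, mean, var, eps=1e-5):
    """Fold batch-norm statistics into the weight and bias of the preceding layer.

    BN(w*x + b) = gamma * (w*x + b - mean) / sqrt(var + eps) + beta
                = (w * scale) * x + ((b - mean) * scale + beta)
    """
    scale = gamma / math.sqrt(var + eps)
    return w * scale, (b - mean) * scale + beta


# Arbitrary illustrative values (not taken from any trained model).
w, b = 0.5, 0.1
gamma, beta, mean, var = 1.5, -0.2, 0.3, 4.0
x = 2.0

# Two-step computation: "convolution" followed by batch norm.
conv = w * x + b
unfused = gamma * (conv - mean) / math.sqrt(var + 1e-5) + beta

# One-step computation with the folded parameters.
w_folded, b_folded = fold_bn(w, b, gamma, beta, mean, var)
fused = w_folded * x + b_folded

print(math.isclose(unfused, fused))
```

Because the folded layer relies on fixed batch-norm statistics, this transformation only applies at inference time, which matches the note above.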

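To accelerate training further, increase the number of processes used for distributed training by setting `num_processes` on the Trainer. The same value is passed to `LitResnet` so that the learning-rate scheduler computes the correct number of steps per epoch; this example also raises the learning rate from the default 0.05 to 0.1.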
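
The effect of `num_processes` on the scheduler can be checked with plain arithmetic. The sketch below mirrors the `steps_per_epoch` expression from `configure_optimizers`, using the constants that appear in this notebook (45000 samples, batch size 64, 30 epochs); the standalone function is a hypothetical helper written for illustration.

```python
def steps_per_epoch(num_samples: int, batch_size: int, num_processes: int = 1) -> int:
    """Optimizer steps each process takes per epoch (same formula as in LitResnet)."""
    return num_samples // batch_size // num_processes


TRAIN_SAMPLES = 45000  # constant hard-coded in configure_optimizers
BATCH_SIZE = 64
MAX_EPOCHS = 30

for n in (1, 4):
    per_epoch = steps_per_epoch(TRAIN_SAMPLES, BATCH_SIZE, n)
    total = per_epoch * MAX_EPOCHS  # total steps the OneCycleLR schedule is built for
    print(f"num_processes={n}: {per_epoch} steps/epoch, {total} total steps")
```

With four processes, each process takes roughly a quarter as many steps per epoch, so the one-cycle schedule is compressed accordingly.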

In [None]:
model = LitResnet(learning_rate=0.1, num_processes=4)
model.datamodule = cifar10_dm
trainer = Trainer(max_epochs=30, 
                  num_processes=4,
                  fast_dev_run=DEV_RUN)
fit_time_dit = %timeit -n 1 -r 1 -o \
trainer.fit(model, datamodule=cifar10_dm)
metric_dit = trainer.test(model, datamodule=cifar10_dm)

Enable both distributed training and IPEX.

In [None]:
model = LitResnet(learning_rate=0.1, num_processes=4)
model.datamodule = cifar10_dm
trainer = Trainer(max_epochs=30, 
                  num_processes=4,
                  use_ipex=True,
                  fast_dev_run=DEV_RUN)
fit_time_dit_ipex = %timeit -n 1 -r 1 -o \
trainer.fit(model, datamodule=cifar10_dm)
metric_dit_ipex = trainer.test(model, datamodule=cifar10_dm)
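
The four runs above store their timings in `fit_time_basic`, `fit_time_ipex`, `fit_time_dit`, and `fit_time_dit_ipex`. One way to compare them is sketched below; the helper `speedups` is hypothetical, and the commented usage assumes the objects returned by `%timeit -o` expose an `average` attribute in seconds, as IPython's `TimeitResult` does.

```python
def speedups(seconds_by_run: dict, baseline: str = "basic") -> dict:
    """Return each run's speedup relative to the baseline run (baseline / run)."""
    base = seconds_by_run[baseline]
    return {name: base / secs for name, secs in seconds_by_run.items()}


# Usage inside the notebook, after all training cells have run:
# timings = {
#     "basic": fit_time_basic.average,
#     "ipex": fit_time_ipex.average,
#     "distributed": fit_time_dit.average,
#     "distributed+ipex": fit_time_dit_ipex.average,
# }
# for name, ratio in speedups(timings).items():
#     print(f"{name}: {ratio:.2f}x")
```

A ratio above 1 means the run finished faster than the baseline. The `metric_*` lists returned by `trainer.test` can be inspected alongside the timings to check that accuracy is preserved.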