[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/juansensio/blog/blob/master/074_pytorch_lightning_optim/074_pytorch_lightning_optim.ipynb)

# Pytorch Lightning - Optimization

We continue talking about optimizing our PyTorch code. We have already covered many techniques we can use and, fortunately, most of them are already implemented in `Pytorch Lightning`, so we don't have to rack our brains too much.

```python
import os
from sklearn.model_selection import train_test_split
import pytorch_lightning as pl
import torch
from skimage import io
from torch.utils.data import DataLoader

class Dataset(torch.utils.data.Dataset):
    def __init__(self, images, labels):
        self.images = images  # list of image file paths
        self.labels = labels  # list of integer class ids

    def __len__(self):
        return len(self.images)

    def __getitem__(self, ix):
        # Read the multispectral image (H, W, bands) and keep bands 3, 2, 1
        # (assumed to be red, green and blue, in that order)
        img = io.imread(self.images[ix])[..., (3, 2, 1)]
        # Rescale to [0, 1] and move channels first: (C, H, W)
        img = torch.tensor(img / 4000, dtype=torch.float).clip(0, 1).permute(2, 0, 1)
        label = torch.tensor(self.labels[ix], dtype=torch.long)
        return img, label

class DataModule(pl.LightningDataModule):

    def __init__(self, path='./data', batch_size=1024, num_workers=20, test_size=0.2, random_state=42):
        super().__init__()
        self.path = path
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.test_size = test_size
        self.random_state = random_state

    def setup(self, stage=None):
        # Each class is a sub-folder of `path`; its label is the folder's index in sorted order
        self.classes = sorted(os.listdir(self.path))

        print("Generating images and labels ...")
        images, encoded = [], []
        for ix, label in enumerate(self.classes):
            _images = os.listdir(f'{self.path}/{label}')
            images += [f'{self.path}/{label}/{img}' for img in _images]
            encoded += [ix] * len(_images)
        print(f'Number of images: {len(images)}')

        # Stratified train / val split (same class proportions in both sets)
        print("Generating train / val splits ...")
        train_images, val_images, train_labels, val_labels = train_test_split(
            images,
            encoded,
            stratify=encoded,
            test_size=self.test_size,
            random_state=self.random_state
        )

        print("Training samples: ", len(train_labels))
        print("Validation samples: ", len(val_labels))

        self.train_ds = Dataset(train_images, train_labels)
        self.val_ds = Dataset(val_images, val_labels)

    def train_dataloader(self):
        return DataLoader(
            self.train_ds,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=True,
            pin_memory=True  # page-locked memory speeds up host -> GPU copies
        )

    def val_dataloader(self):
        return DataLoader(
            self.val_ds,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=False,
            pin_memory=True
        )
```

```python
dm = DataModule()
dm.setup()

# Grab one training batch to check the tensor shapes
imgs, labels = next(iter(dm.train_dataloader()))
imgs.shape, labels.shape
```

```
Generating images and labels ...
Number of images: 27000
Generating train / val splits ...
Training samples:  21600
Validation samples:  5400

(torch.Size([1024, 3, 64, 64]), torch.Size([1024]))
```

```python
import torch.nn.functional as F
import timm

class Model(pl.LightningModule):

    def __init__(self, n_outputs=10, prof=None):
        super().__init__()
        # Pretrained EfficientNet-B5 from timm with a new head of `n_outputs` classes
        self.model = timm.create_model('tf_efficientnet_b5', pretrained=True, num_classes=n_outputs)
        # Optional profiler object with a step() method (e.g. torch.profiler.profile)
        self.prof = prof

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        loss, acc = self.shared_step(batch)
        self.log('loss', loss)
        self.log('acc', acc, prog_bar=True)
        if self.prof is not None:
            self.prof.step()
        return loss

    def validation_step(self, batch, batch_idx):
        loss, acc = self.shared_step(batch)
        self.log('val_loss', loss, prog_bar=True)
        self.log('val_acc', acc, prog_bar=True)

    def shared_step(self, batch):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        # Keep accuracy as a tensor (no .item()) to avoid a GPU -> CPU sync on every step
        acc = (torch.argmax(y_hat, dim=1) == y).float().mean()
        return loss, acc

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters())
```

```python
model = Model()
dm = DataModule()

# Single GPU, native mixed precision (16-bit), 3 epochs
trainer = pl.Trainer(gpus=1, precision=16, max_epochs=3)
trainer.fit(model, dm)
```

```
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
Using native 16bit precision.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name  | Type         | Params
---------------------------------------
0 | model | EfficientNet | 28.4 M
---------------------------------------
28.4 M    Trainable params
0         Non-trainable params
28.4 M    Total params
113.445   Total estimated model params size (MB)

Generating images and labels ...
Number of images: 27000
Generating train / val splits ...
Training samples:  21600
Validation samples:  5400
```






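We can use a distributed strategy through the `accelerator` parameter. In this case we use the value `dp` for a `Data Parallel` strategy: each batch produced by the `DataLoader` is split across the available GPUs, which is why we double the batch size to 2048 so that each of the 2 GPUs still processes 1024 samples. You can see the other available strategies [here](https://pytorch-lightning.readthedocs.io/en/stable/advanced/multi_gpu.html).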
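As a rough illustration of how `dp` splits a batch, the sketch below uses `torch.chunk` on the CPU, which follows the same even-split rule that `DataParallel` applies when it scatters the input along dimension 0. The helper `split_for_dp` is hypothetical: it only mimics the split and does not move any data to a GPU.

```python
import torch

def split_for_dp(batch: torch.Tensor, n_gpus: int) -> list:
    """Hypothetical helper: split a batch along dim 0 into one chunk per GPU (torch.chunk semantics)."""
    return list(batch.chunk(n_gpus, dim=0))

# Sample indices as a lightweight stand-in for a batch of 2048 images
batch = torch.arange(2048)
sizes = [c.shape[0] for c in split_for_dp(batch, n_gpus=2)]
print(sizes)  # [1024, 1024]

# With an odd batch size the last GPU receives a slightly smaller chunk
print([c.shape[0] for c in split_for_dp(torch.arange(2049), n_gpus=2)])  # [1025, 1024]
```

Because the split happens inside every forward pass, `batch_size` in the `DataModule` is the total across GPUs, not the per-GPU value.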

```python
model = Model()
# With 'dp' each batch is split across the GPUs: 2048 / 2 = 1024 samples per GPU
dm = DataModule(batch_size=2048)
trainer = pl.Trainer(gpus=2, accelerator='dp', precision=16, max_epochs=3)
trainer.fit(model, dm)
```

```
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
Using native 16bit precision.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name  | Type         | Params
---------------------------------------
0 | model | EfficientNet | 28.4 M
---------------------------------------
28.4 M    Trainable params
0         Non-trainable params
28.4 M    Total params
113.445   Total estimated model params size (MB)

Generating images and labels ...
Number of images: 27000
Generating train / val splits ...
Training samples:  21600
Validation samples:  5400
```






Although training is slightly slower than with the equivalent pure `Pytorch` code, the flexibility and functionality that `Pytorch Lightning` provides can be worth it in most cases. You can see an example using `Distributed Data Parallel` [here](https://github.com/juansensio/blog/blob/master/071_pytorch_lightning_optim/ddp.py).
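To make the difference between `dp` and `ddp` concrete, here is a small back-of-the-envelope sketch in plain Python. It assumes the usual semantics: with `dp` the `DataLoader` batch is split across the GPUs, while with `ddp` every GPU runs its own process with a `DataLoader` of `batch_size` samples. The helper names are hypothetical, and the steps per epoch assume the last incomplete batch is kept (`drop_last=False`, the `DataLoader` default).

```python
import math
from typing import Tuple

def batch_sizes(strategy: str, batch_size: int, n_gpus: int) -> Tuple[int, int]:
    """Return (samples per GPU, samples per optimizer step) for a strategy.

    Assumption: 'dp' splits each DataLoader batch across the GPUs, while
    'ddp' launches one process per GPU, each loading `batch_size` samples.
    """
    if strategy == "dp":
        # One batch is scattered across the GPUs (assumes an even split)
        return batch_size // n_gpus, batch_size
    if strategy == "ddp":
        # Every process loads its own batch; gradients are averaged across GPUs
        return batch_size, batch_size * n_gpus
    raise ValueError(f"unknown strategy: {strategy!r}")

def steps_per_epoch(n_samples: int, samples_per_step: int) -> int:
    """Optimizer steps per epoch when the last incomplete batch is kept."""
    return math.ceil(n_samples / samples_per_step)

# Numbers from this post: 21600 training images
print(batch_sizes("dp", 2048, 2))    # (1024, 2048)
print(batch_sizes("ddp", 1024, 2))   # (1024, 2048)
print(steps_per_epoch(21600, 1024))  # 22 (single GPU, batch_size=1024)
print(steps_per_epoch(21600, 2048))  # 11 (dp, batch_size=2048)
```

The 22 steps for a single GPU with `batch_size=1024` match the 22 calls to `run_training_batch` in the profiler report further down.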

## Profiling

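`Pytorch Lightning` also gives us options for profiling our code in search of bottlenecks. The simplest one is passing `profiler='simple'` to the `Trainer`, which prints a report with the time spent in each action of the training loop when `fit` finishes.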

```python
model = Model()
dm = DataModule()

# profiler='simple' prints a timing report per action when fit() finishes
trainer = pl.Trainer(gpus=1, precision=16, max_epochs=1, profiler='simple')
trainer.fit(model, dm)
```

```
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
Using native 16bit precision.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name  | Type         | Params
---------------------------------------
0 | model | EfficientNet | 28.4 M
---------------------------------------
28.4 M    Trainable params
0         Non-trainable params
28.4 M    Total params
113.445   Total estimated model params size (MB)

Generating images and labels ...
Number of images: 27000
Generating train / val splits ...
Training samples:  21600
Validation samples:  5400
```

```
FIT Profiler Report

Action                        | Mean duration (s) | Num calls | Total time (s) | Percentage %
---------------------------------------------------------------------------------------------
Total                         | -                 | _         | 15.719         | 100 %
---------------------------------------------------------------------------------------------
run_training_epoch            | 13.826            | 1         | 13.826         | 87.956
run_training_batch            | 0.41584           | 22        | 9.1485         | 58.2
optimizer_step_and_closure_0  | 0.41555           | 22        | 9.1421         | 58.16
training_step_and_backward    | 0.24678           | 22        | 5.4291         | 34.538
backward
```




You can find more options, such as `profiler='advanced'` (based on Python's `cProfile`), [here](https://pytorch-lightning.readthedocs.io/en/stable/advanced/profiler.html).
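To see what the columns of the report mean, here is a hypothetical, stripped-down profiler in plain Python. It is not Lightning's `SimpleProfiler` implementation, just a sketch of the same idea: it times each named action with `time.perf_counter` and reports the mean duration, number of calls, total time and percentage of an overall total.

```python
import time
from collections import defaultdict
from contextlib import contextmanager

class TinyProfiler:
    """Hypothetical minimal profiler: records wall-clock durations per action name."""

    def __init__(self):
        self.durations = defaultdict(list)

    @contextmanager
    def profile(self, action):
        start = time.perf_counter()
        try:
            yield
        finally:
            # Store the elapsed time even if the profiled block raises
            self.durations[action].append(time.perf_counter() - start)

    def report(self, total):
        """Rows of (action, mean s, num calls, total s, % of `total`), slowest first."""
        rows = []
        for action, times in self.durations.items():
            action_total = sum(times)
            rows.append((action, action_total / len(times), len(times),
                         action_total, 100 * action_total / total))
        return sorted(rows, key=lambda row: row[3], reverse=True)

# Usage: a fake "epoch" of 3 "batches" that just sleep for 10 ms each
prof = TinyProfiler()
t0 = time.perf_counter()
with prof.profile("run_training_epoch"):
    for _ in range(3):
        with prof.profile("run_training_batch"):
            time.sleep(0.01)
total = time.perf_counter() - t0

rows = prof.report(total)
for row in rows:
    print("{:<20} | {:.4f} | {} | {:.4f} | {:.1f}".format(*row))
```

Nested actions overlap in time, which is why the percentages in such a report do not add up to 100 %: in the Lightning report above, `run_training_batch` runs inside `run_training_epoch`.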

## Summary

Optimizing our `Pytorch` code is very important, and we have many tools and techniques at our disposal to get the most out of our networks. `Pytorch Lightning` makes it much easier to apply these techniques transparently, without major changes to our code, whereas in plain `Pytorch` we would have to dig through the documentation and examples to take advantage of everything we have covered in recent posts (resulting in long, hard-to-understand code). Through the `Trainer` object we can easily define different distributed training strategies, and the `profiling` options help us find the weak points in our code so we can fix them.