# Layer PyTorch Lightning

[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/layerai/examples/blob/main/tutorials/best_practices.ipynb) [![Layer Examples Github](https://badgen.net/badge/icon/github?icon=github&label)](https://github.com/layerai/examples/tree/main/tutorials/best-practices)

In [None]:
# Install (or upgrade to the newest release of) the `layer` and `lightning` packages.
!pip install layer --upgrade
!pip install lightning --upgrade

In [None]:
from torchvision.datasets import MNIST
from torchvision import transforms
from torch.utils.data import DataLoader, random_split

# Pre-processing: convert each image to a tensor, then normalise it with
# mean 0.1307 and standard deviation 0.3081.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ]
)

# Download MNIST into ./MNIST (if needed) and apply the transform on access.
dataset = MNIST(root="./MNIST", download=True, transform=transform)

# Random split: 55,000 training / 1,000 test / 4,000 validation samples.
training_set, test_set, validation_set = random_split(
    dataset, [55000, 1000, 4000]
)

# Training batches are reshuffled; validation batches are not.
training_loader = DataLoader(training_set, shuffle=True, batch_size=64)
validation_loader = DataLoader(validation_set, batch_size=64)
# No batch_size is given here, so DataLoader's default applies.
test_loader = DataLoader(test_set)

In [None]:
import torch
from torch.nn import Linear, CrossEntropyLoss, functional as F
from torch.optim import Adam
from torchmetrics.functional import accuracy
from pytorch_lightning import LightningModule

class MNIST_LitModule(LightningModule):
    """Fully connected MNIST classifier: two hidden ReLU layers and a linear output.

    Training, validation and test steps share one helper that computes the
    predictions, the cross-entropy loss and the accuracy for a batch, and each
    step logs its loss and accuracy under a ``train_``/``val_``/``test_`` prefix.
    """

    def __init__(self, n_classes=10, n_layer_1=128, n_layer_2=256, lr=1e-3):
        """Create the layers and the loss, and record the hyper-parameters.

        Args:
            n_classes: size of the output layer (number of target classes).
            n_layer_1: width of the first hidden layer.
            n_layer_2: width of the second hidden layer.
            lr: learning rate handed to the Adam optimizer.
        """
        super().__init__()

        # mnist images are (1, 28, 28) (channels, height, width); they are
        # flattened to 28 * 28 features before reaching the first layer
        self.layer_1 = Linear(28 * 28, n_layer_1)
        self.layer_2 = Linear(n_layer_1, n_layer_2)
        self.layer_3 = Linear(n_layer_2, n_classes)

        # loss
        self.loss = CrossEntropyLoss()

        # optimizer parameters
        self.lr = lr

        # save the constructor arguments to self.hparams
        # (self.hparams.n_classes is read again when computing the accuracy)
        self.save_hyperparameters()

    def forward(self, x):
        """Map a batch of images to one row of class logits per sample.

        Args:
            x: tensor whose first dimension is the batch; every sample must
                hold 28 * 28 values once flattened, e.g. shape (b, 1, 28, 28).

        Returns:
            Tensor of raw (un-normalised) scores with one column per class.
        """
        # (b, 1, 28, 28) -> (b, 1*28*28)
        x = x.view(x.size(0), -1)

        # 2 x (linear + relu), then a final linear layer producing the logits
        x = F.relu(self.layer_1(x))
        x = F.relu(self.layer_2(x))
        return self.layer_3(x)

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss/accuracy; return the loss."""
        _, loss, acc = self._get_preds_loss_accuracy(batch)

        # Log loss and metric
        self.log('train_loss', loss)
        self.log('train_accuracy', acc)

        return loss

    def validation_step(self, batch, batch_idx):
        """Compute and log the validation loss/accuracy; return the predictions."""
        preds, loss, acc = self._get_preds_loss_accuracy(batch)

        # Log loss and metric
        self.log('val_loss', loss)
        self.log('val_accuracy', acc)

        # Let's return preds to use it in a custom callback
        return preds

    def test_step(self, batch, batch_idx):
        """Compute and log the test loss/accuracy."""
        _, loss, acc = self._get_preds_loss_accuracy(batch)

        # Log loss and metric
        self.log('test_loss', loss)
        self.log('test_accuracy', acc)

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters using ``self.lr``."""
        return Adam(self.parameters(), lr=self.lr)

    def _get_preds_loss_accuracy(self, batch):
        """Shared train/valid/test logic.

        Args:
            batch: an ``(inputs, targets)`` pair.

        Returns:
            Tuple ``(preds, loss, acc)``: arg-max class per sample, the
            cross-entropy loss of the logits, and the batch accuracy.
        """
        x, y = batch
        logits = self(x)
        preds = torch.argmax(logits, dim=1)
        loss = self.loss(logits, y)
        # Recent torchmetrics releases require the task to be stated
        # explicitly; multiclass accuracy also needs the number of classes.
        acc = accuracy(
            preds, y, task="multiclass", num_classes=self.hparams.n_classes
        )
        return preds, loss, acc

In [None]:
# Build the classifier with two 64-unit hidden layers; n_classes and lr keep their defaults.
model = MNIST_LitModule(n_layer_1=64, n_layer_2=64)

In [None]:
from pytorch_lightning.loggers import LayerLogger
from pytorch_lightning import Trainer

# Get your API KEY here:
# https://app.layer.ai/me/settings/developer
# Replace the '[API_KEY]' placeholder below with that key before running.
layer_logger = LayerLogger(project_name='MNIST', api_key='[API_KEY]')

In [None]:
from pytorch_lightning.callbacks import Callback
 
class LogPredictionsCallback(Callback):
    """Log the first image of every validation batch through ``layer_logger``."""

    def __init__(self):
        super().__init__()
        # Counter passed as ``step`` to log_image; it is incremented after
        # every logged image and never reset by this class.
        self.step = 0

    def on_validation_batch_end(
        self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0):
        """Called when the validation batch ends.

        ``dataloader_idx`` defaults to 0 so the hook works whether or not the
        trainer supplies it (newer Lightning releases omit it when there is a
        single validation dataloader).
        """
        # Only the inputs are needed; the targets are ignored.
        x, _ = batch
        image = x[0]

        layer_logger.log_image(key='sample_image', image=image, step=self.step)
        self.step += 1


log_predictions_callback = LogPredictionsCallback()

In [None]:
def check_accuracy(model, loader=None):
    """Return True when ``model`` labels more than 90% of the samples correctly.

    Args:
        model: callable mapping a batch of inputs to per-class scores of
            shape (batch, n_classes).
        loader: iterable of ``(inputs, targets)`` batches exposing a
            ``dataset`` attribute with a length. Defaults to the
            module-level ``test_loader``.

    Returns:
        bool: whether the fraction of correct predictions exceeds 0.90.
    """
    import torch  # local import keeps this cell self-contained

    if loader is None:
        loader = test_loader

    n_correct = 0
    # Evaluation only: no gradients are needed.
    with torch.no_grad():
        for inputs, targets in loader:
            # Keep predictions 1-D so they compare element-wise with the
            # targets; a (batch, 1) tensor would broadcast against (batch,)
            # into a (batch, batch) matrix and over-count the matches.
            predicted = model(inputs).argmax(dim=1)
            n_correct += (predicted == targets).sum().item()

    return n_correct / len(loader.dataset) > 0.90

# Training without decorators

In [None]:
import layer
from layer.decorators.assertions import assert_true

def train(params):
    """Fit the module-level ``model`` on the MNIST loaders and return it.

    Args:
        params: mapping with an ``"epochs"`` entry, used as the trainer's
            ``max_epochs``.

    Returns:
        The module-level ``model`` after ``trainer.fit`` has run.
    """
    trainer = Trainer(
        logger=layer_logger,
        callbacks=[log_predictions_callback],
        max_epochs=params["epochs"],
        # "auto" lets Lightning pick the available accelerator instead of
        # failing on machines without a GPU.
        accelerator="auto",
        devices=1,
    )

    trainer.fit(model, training_loader, validation_loader)
    return model

# Settings for this run: a single epoch.
params = {"epochs": 1}

# Apply the wrappers by hand, innermost first: the accuracy assertion wraps
# ``train``, and layer.model("pl_model") wraps the asserted function.
asserted_func = assert_true(check_accuracy)(train)
train_func = layer.model("pl_model")(asserted_func)

model = train_func(params)

# Training with decorators

In [None]:
import layer

@layer.model("pl_model")
@assert_true(check_accuracy)
def train(params):
    """Fit the module-level ``model`` on the MNIST loaders and return it.

    Args:
        params: mapping with an ``"epochs"`` entry, used as the trainer's
            ``max_epochs``.

    Returns:
        The module-level ``model`` after ``trainer.fit`` has run.
    """
    trainer = Trainer(
        logger=layer_logger,
        callbacks=[log_predictions_callback],
        max_epochs=params["epochs"],
        # "auto" lets Lightning pick the available accelerator instead of
        # failing on machines without a GPU.
        accelerator="auto",
        devices=1,
    )

    trainer.fit(model, training_loader, validation_loader)
    return model

# Settings for this run: a single epoch.
params = {"epochs": 1}
train(params)