In [None]:
# Loading the desired already trained checkpoint

from src import *

# Checkpoint of the trained run to restore; change this path to evaluate/export another run.
CHECKPOINT_PATH = 'checkpoints/006-val_loss=0.14715-epoch=7.ckpt'
module = MNISTModule.load_from_checkpoint(CHECKPOINT_PATH)
# Show the restored network so its architecture can be inspected before exporting.
module.mlp

In [None]:
# Evaluate the model before exporting it

import torch 

# Rebuild the data module from the hyperparameters stored with the loaded module, so that
# evaluation uses the same data configuration as training.
datamodule_kwargs = module.hparams['datamodule']
dm = MNISTDataModule(**datamodule_kwargs)
dm.setup()

def torch_eval():
    """Compute the PyTorch model's accuracy on the validation split.

    Puts ``module`` in eval mode, runs ``module.predict`` on every batch of
    ``dm.val_dataloader()`` without gradients, and thresholds the outputs at 0.5.

    Returns:
        Fraction of correctly classified validation samples as a Python float,
        or NaN if the validation loader yields no batches.
    """
    module.eval()
    pred_chunks, label_chunks = [], []
    with torch.no_grad():
        for imgs, batch_labels in dm.val_dataloader():
            # NOTE(review): assumes predict() returns probabilities (threshold 0.5) — confirm.
            is_positive = module.predict(imgs) > 0.5
            # Flatten so a trailing singleton dim, e.g. (B, 1), cannot broadcast against (B,) labels.
            pred_chunks.append(is_positive.cpu().long().reshape(-1))
            label_chunks.append(batch_labels.cpu().reshape(-1))
    if not pred_chunks:
        # Nothing to score: return NaN (like the mean of an empty tensor) instead of failing in cat.
        return float('nan')
    # Concatenate once at the end instead of re-allocating the tensors on every batch.
    preds = torch.cat(pred_chunks)
    labels = torch.cat(label_chunks)
    return (preds == labels).float().mean().item()

torch_eval()  # validation accuracy of the PyTorch model (reference value for onnx_eval below)

In [None]:
# PyTorch Lightning makes it easy to export the model to ONNX format

import os

ONNX_PATH = 'models/binary_classifier_3.onnx'
# The exporter does not create missing directories, so make sure the target folder exists.
os.makedirs(os.path.dirname(ONNX_PATH), exist_ok=True)

# Dummy uint8 image batch used only to trace the graph. randint's upper bound is exclusive,
# so 256 is needed to cover the full 0-255 pixel range.
input_sample = torch.randint(0, 256, (1, 28, 28), dtype=torch.uint8)
module.to_onnx(
    ONNX_PATH,  # path of the model file to write
    input_sample,  # example input used to trace the graph
    export_params=True,  # store the trained parameters inside the model file
    opset_version=11,  # may need changing depending on the ops used by the model
    input_names=['input'],  # input name to use in production
    output_names=['output'],  # output name to use in production
    dynamic_axes={  # allow different batch sizes at inference time
        'input': {0: 'batch_size'},
        'output': {0: 'batch_size'},
    },
)

In [None]:
# Exporting to ONNX lets us train the model in Python and then run it with ONNX Runtime
# in a different environment, such as JavaScript, Android, or iOS.
# ONNX Runtime can be installed with $ pip install onnxruntime
# It is also lighter than PyTorch or TensorFlow, which makes it easier to deploy, e.g. inside a container image.

import onnxruntime as ort 
import numpy as np

# Since onnxruntime 1.9, builds with several execution providers (e.g. GPU builds) require the
# providers to be chosen explicitly; the CPU provider is available in every build.
ort_session = ort.InferenceSession(
    'models/binary_classifier_3.onnx', providers=['CPUExecutionProvider']
)

# Smoke test with a random uint8 batch of 10 images, checking that the dynamic batch axis works.
# randint's upper bound is exclusive, so 256 covers the whole 0-255 pixel range.
ort_inputs = {
    "input": np.random.randint(0, 256, (10, 28, 28), dtype=np.uint8),
}

ort_output = ort_session.run(['output'], ort_inputs)
assert ort_output[0].shape[0] == 10, "batch dimension of the exported model is not dynamic"
ort_output[0].shape

In [None]:
# Validate the exported ONNX model: its validation accuracy should match torch_eval()

def sigmoid(x):
    """Numerically stable logistic sigmoid, ``1 / (1 + exp(-x))``.

    Computed as ``exp(-logaddexp(0, -x))``, which is mathematically identical but never
    overflows ``np.exp`` (no RuntimeWarning) for large-magnitude logits.

    Args:
        x: Scalar or NumPy array of logits.

    Returns:
        Values in [0, 1], element-wise, with the same shape as ``x``.
    """
    return np.exp(-np.logaddexp(0, -x))
    
def onnx_eval():
    """Compute the accuracy of the exported ONNX model on the validation split.

    Every batch of ``dm.val_dataloader()`` is fed to ``ort_session``. The returned outputs
    are passed through ``sigmoid`` and thresholded at 0.5.

    Returns:
        Fraction of correctly classified validation samples, or NaN if the
        validation loader yields no batches.
    """
    pred_chunks, label_chunks = [], []
    for imgs, batch_labels in dm.val_dataloader():
        ort_inputs = {"input": imgs.numpy()}
        logits = ort_session.run(['output'], ort_inputs)[0]
        # NOTE(review): assumes the exported forward() returns logits (hence the sigmoid) — confirm.
        # Flatten so a trailing singleton dim, e.g. (B, 1), cannot broadcast against (B,) labels.
        pred_chunks.append((sigmoid(logits) > 0.5).astype(int).reshape(-1))
        label_chunks.append(batch_labels.numpy().reshape(-1))
    if not pred_chunks:
        # Nothing to score: return NaN (like the mean of an empty array) instead of failing in concatenate.
        return float('nan')
    # Concatenate once at the end instead of growing the arrays batch by batch.
    preds = np.concatenate(pred_chunks)
    labels = np.concatenate(label_chunks)
    return (preds == labels).astype(float).mean()

onnx_eval()  # validation accuracy of the ONNX model; expected to closely match torch_eval()

In [None]:
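# Versioning models

# Track the models/ folder with DVC and upload it to the DVC remote:
# $ dvc add models
# $ dvc push
# (On another machine, download the tracked model files with:)
# $ dvc pull models.dvc

# Commit the models.dvc pointer file to git, push it, and tag the release:
# $ git add .
# $ git commit -m "new model"
# $ git push
# $ git tag -a v3 -m "version 3"
# $ git push origin --tags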