In [1]:
import os
from sklearn import datasets
import numpy as np

import torch
import pytorch_lightning as pl
from pytorch_lightning import LightningModule, Trainer, LightningDataModule
from torchmetrics import Accuracy
from torch import nn
from torch.nn import functional as F

from torch.utils.data import DataLoader, random_split
from torchvision import transforms

from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks.early_stopping import EarlyStopping

import mlflow

from omegaconf import DictConfig, ListConfig, OmegaConf
import hydra

モデルの読み込み

In [2]:
# 1. 従来通りのModelを作成する

class IrisNet(nn.Module):
    def __init__(self, cfg):
        super().__init__()

        self.x1   = nn.Linear(in_features=4, out_features=cfg.model.hidden_size )
        self.act1 = nn.ReLU()
        self.x2   = nn.Linear(in_features=cfg.model.hidden_size, out_features=3)
        self.act2 = nn.Softmax(dim=1)
    
    def forward(self, x):
        x = self.x1(x)
        x = self.act1(x)
        x = self.x2(x)
        x = self.act2(x)
        return x

In [3]:
#2. train/valid stepを設定する、plmoduleを作成する

class PLIrisModel(pl.LightningModule):
    def __init__(self, cfg: DictConfig, experiment_name="test1"):
        super().__init__()
        self.cfg     = cfg

        self.net     = IrisNet(cfg=cfg)
        self.mtrics  = Accuracy()

        ### MLFlow ###
        #self.writer = MlflowWriter(experiment_name=experiment_name)
        #self.writer.create_new_run()
        #self.writer.log_params_from_omegaconf_dict(cfg)

    def forward(self, x):
        return self.net(x.float())

    def training_step(self, batch, batch_idx):
        x, y = batch
        pred = self(x)
        loss = F.nll_loss(pred, y)
        batch_loss = loss * x.size(0)
        return {"loss": loss, "y": y, "pred": pred.detach(), "batch_loss": batch_loss.detach()}
    
    def training_epoch_end(self, train_step_outputs):
        preds      = torch.cat( [trn["pred"] for trn in train_step_outputs], dim=0 )
        ys         = torch.cat( [trn["y"] for trn in train_step_outputs], dim=0 )
        epoch_loss = sum( [trn["batch_loss"] for trn in train_step_outputs] ) / ys.size(0)

        acc = self.mtrics(preds, ys)
        print('-------- Current Epoch {} --------'.format(self.current_epoch + 1))
        print('train Loss: {:.4f} train Acc: {:.4f}'.format(epoch_loss, acc))

        ### MLFlow ###
        #self.writer.log_metric("trn_loss", float(epoch_loss) )
        #self.writer.log_metric("trn_acc",  float(acc))

    def validation_step(self, batch, batch_idx):
        x, y = batch
        pred = self(x)
        loss = F.nll_loss(pred, y)
        batch_loss = loss * x.size(0)
        return {"y": y, "pred": pred.detach(), "batch_loss": batch_loss.detach()}
    
    def validation_epoch_end(self, valid_step_outputs):
        preds      = torch.cat( [val["pred"] for val in valid_step_outputs], dim=0 )
        ys         = torch.cat( [val["y"] for val in valid_step_outputs], dim=0 )
        epoch_loss = sum( [val["batch_loss"] for val in valid_step_outputs] ) / ys.size(0)

        acc = self.mtrics(preds, ys)
        print('-------- Current Epoch {} --------'.format(self.current_epoch + 1))
        print('valid Loss: {:.4f} valid Acc: {:.4f}'.format(epoch_loss, acc))

        ### for CallBacks ###
        self.log("val_loss", epoch_loss)
        self.log("val_acc", acc)
        
        ### MLFlow ###
        #self.writer.log_metric("val_loss", float(epoch_loss) )
        #self.writer.log_metric("val_acc",  float(acc))
    
    def configure_optimizers(self):
        lr         = self.cfg.optim.lr
        optim_name = self.cfg.optim.optim_name
        optimizer  = getattr(torch.optim, optim_name)(self.parameters(), lr=lr)
        return optimizer

In [4]:
# check_pointからのモデルの読み込み
model_path = "models/best-checkpoint.ckpt"
cfg = OmegaConf.load("configs/config.yaml")

model = PLIrisModel.load_from_checkpoint(model_path, cfg=cfg)

In [5]:
sample_input = torch.randn((1, 4))

torch_pred = model(sample_input)
print("torch model prediction is", torch_pred)


torch model prediction is tensor([[0.3982, 0.5937, 0.0080]], grad_fn=<SoftmaxBackward0>)


ONNX への変換

In [6]:
import torch.onnx

def ExportONNX(model, dummy_input):
    model.eval()

    torch.onnx.export(
        model,
        dummy_input,
        "./models/SampleModel.onnx",
        export_params=True,
        opset_version=10,
        do_constant_folding=True,
        input_names = ['input'],  
        output_names = ['output'], 
        dynamic_axes={'input' : {0 : 'batch_size'},    # variable lenght axes
                      'output' : {0 : 'batch_size'}}
        )   
    print(" ") 
    print('Model has been converted to ONNX') 

In [7]:
ExportONNX(model, dummy_input=sample_input)

 
Model has been converted to ONNX


ONNXモデルの読み込み、実行テスト

In [8]:
onnx_model_path = "models/SampleModel.onnx"

import onnxruntime as ort 
ort_session = ort.InferenceSession( onnx_model_path )

# Inputの名前を、モデルの変換時のものと一致させる
ort_inputs = { "input": list(sample_input.numpy().tolist()) } 
#ort_inputs = { "input": [[0.1, 0.2, 0, 1]] } #普通の数字のリストのInputでも推論可能

onnx_pred = ort_session.run( None, ort_inputs )[0]
# Noneで、モデルのすべてのOutputを取得する
# Noneではなく、ONNX変換時に指定したoutput名を入れればそれを取得可能

print("ONNX_model_prediction is", onnx_pred)

ONNX_model_prediction is [[0.3982159  0.5937354  0.00804876]]
