In [None]:
import os
import torch
import mlflow
import numpy as np
import pandas as pd
from torch import nn
from pprint import pprint
import onnxruntime as ort
import plotly.express as px
from hydra import compose, initialize

# Run the rest of the notebook from the notebook's parent directory, where the `src`
# package lives, so the import below works. The directory is remembered on the first
# run so that re-executing this cell does not keep walking further up the tree.
if 'PROJECT_ROOT' not in globals():
    PROJECT_ROOT = os.path.abspath('..')
os.chdir(PROJECT_ROOT)
from src.utils import train_setup, evaluate

In [None]:
# Regression metrics reported for each model below, keyed by the name shown in the output
metrics = dict(mse=nn.MSELoss(), mae=nn.L1Loss())

# Prefer the first GPU when CUDA is available, otherwise run on the CPU
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

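## Model Inference with MLflow

Load the PyTorch model referenced by `mlflow.model_uri` in `config/inference.yaml` and evaluate it on the test split.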

In [None]:
with initialize(version_base=None, config_path='../config'):
    # Compose config/inference.yaml (provides the MLflow URIs used below)
    config = compose(config_name='inference.yaml')

    # Load the PyTorch model at config.mlflow.model_uri from the MLflow tracking server
    mlflow.set_tracking_uri(uri=config.mlflow.uri)
    model = mlflow.pytorch.load_model(config.mlflow.model_uri).to(device)
    # Switch dropout/batch-norm style layers to inference behaviour; idempotent, so this is
    # harmless if `evaluate` also does it (NOTE(review): not visible from here)
    model.eval()

    # Set-up datasets and loaders; only the train/test splits are used below
    datasets, loaders, _ = train_setup(config, device)
    train_dataset, val_dataset, test_dataset = datasets
    train_loader, val_loader, test_loader = loaders

# Evaluate on the test split; outputs are passed through the training dataset's
# inverse target transform (presumably back to the original target scale)
test_metrics, test_outputs = evaluate(model, test_loader, device, metrics, train_dataset.inverse_transform_targets, config)
test_df = pd.DataFrame({k: np.squeeze(v) for k, v in test_outputs.items()} | {'dt': test_dataset.index}).set_index('dt')
pprint(test_metrics)

# Plot actual vs. predicted values. Renaming the columns before plotting gives both the
# legend and the hover labels readable names and lets plotly colour each series directly.
plot_df = test_df.rename(columns={'targets': 'Actual', 'predictions': 'Prediction'}).reset_index()
fig = px.line(
    plot_df,
    x='dt',
    y=['Actual', 'Prediction'],
    color_discrete_map={'Actual': 'black', 'Prediction': 'blue'},
    title='Model Prediction on Test Dataset',
    template='plotly_white',
)
fig.update_layout(xaxis_title='Date', yaxis_title='', legend_title_text='')
fig.show()

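## Model Inference with ONNX Runtime

Evaluate the exported ONNX model on the same test split with ONNX Runtime. Its metrics should closely match the MLflow results above.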

In [None]:
with initialize(version_base=None, config_path='../config'):
    # Compose config/inference.yaml again so this section can run on its own
    config = compose(config_name='inference.yaml')

    # Set-up datasets and loaders; only the datasets are used below
    datasets, loaders, _ = train_setup(config, device)
    train_dataset, val_dataset, test_dataset = datasets
    train_loader, val_loader, test_loader = loaders

# Load the exported ONNX model. Since onnxruntime 1.9, builds that offer more than the CPU
# provider (e.g. onnxruntime-gpu) require `providers` to be passed explicitly.
model_path = os.path.join(config.logging.model_dir, config.logging._exp_name, config.logging.model_file)
preferred_providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] if device.type == 'cuda' else ['CPUExecutionProvider']
available_providers = ort.get_available_providers()
ort_session = ort.InferenceSession(
    model_path,
    providers=[provider for provider in preferred_providers if provider in available_providers],
)

# Materialise the test split, indexing each sample once (item 0 = features, item 1 = targets)
test_samples = [test_dataset[idx] for idx in range(len(test_dataset))]
test_features = torch.stack([sample[0] for sample in test_samples], dim=0)
test_targets = torch.stack([sample[1] for sample in test_samples], dim=0)

# ONNX Runtime consumes CPU numpy arrays; the exported graph's input is named 'features'.
# NOTE(review): assumes the dataset's feature dtype matches the exported input dtype.
onnx_outputs = ort_session.run(None, {'features': test_features.cpu().numpy()})
# train_setup receives `device`, so dataset tensors may live on it; put the predictions on
# the targets' device so the inverse transform and the losses never mix CPU and CUDA tensors.
test_predictions = torch.from_numpy(onnx_outputs[0]).to(test_targets.device)

# Undo the target transform with the training dataset's inverse transform, then score.
# torch losses take (input=prediction, target=actual).
test_targets = train_dataset.inverse_transform_targets(test_targets)
test_predictions = train_dataset.inverse_transform_targets(test_predictions)
test_metrics = {name: metric_fn(test_predictions, test_targets).item()
                for name, metric_fn in metrics.items()}

# reshape(-1) instead of squeeze() so a single-sample test set still yields a 1-D column
test_df = pd.DataFrame({
    'targets': test_targets.reshape(-1).cpu().numpy(),
    'predictions': test_predictions.reshape(-1).cpu().numpy(),
    'dt': test_dataset.index,
}).set_index('dt')
pprint(test_metrics)

# Plot actual vs. predicted values. Renaming the columns before plotting gives both the
# legend and the hover labels readable names and lets plotly colour each series directly.
plot_df = test_df.rename(columns={'targets': 'Actual', 'predictions': 'Prediction'}).reset_index()
fig = px.line(
    plot_df,
    x='dt',
    y=['Actual', 'Prediction'],
    color_discrete_map={'Actual': 'black', 'Prediction': 'blue'},
    title='Model Prediction on Test Dataset',
    template='plotly_white',
)
fig.update_layout(xaxis_title='Date', yaxis_title='', legend_title_text='')
fig.show()