### Load model data

In [29]:
import pickle

# Deserialize the saved model artifacts (the samples and the M values).
# NOTE: pickle.load can execute arbitrary code while unpickling, so only
# open files that come from a trusted source.
with open("deployable_model.pkl", "rb") as f:
    data = pickle.load(f)

# Pull out the two entries that are handed to DeployableModel below.
samples, m = data["samples"], data["m"]

### Deployable Model

In [30]:
import jax.numpy as jnp

class DeployableModel:
    """
    Prediction wrapper around a set of stored samples and the matrix ``m``.

    attributes:
        samples: Mapping with the entries "u", "v" and "w" read by ``predict``.
        m: Array indexed as ``m[:, component]`` by ``predict``.
    """

    def __init__(self, samples, m):
        # Both inputs are stored as-is; no copy or validation is performed.
        self.samples = samples
        self.m = m

    def predict(self, component: int, t: float):
        """
        Evaluate f_i(t) for every stored sample and summarise the result.

        params:
            component: Integer component index (selects a column of u, v and m).
            t: Time in hours.
               NOTE(review): the notebook's caller passes timestamp / 10_000,
               so confirm which time scaling is actually expected here.
        returns:
            (mean, std) of the per-sample values of f_i(t).
        """
        draws = self.samples
        scale = draws["u"][:, component]    # expected shape (n_samples,)
        offset = draws["v"][:, component]   # expected shape (n_samples,)
        weights = draws["w"]                # expected shape (n_samples, 5)

        # Column of m belonging to the requested component.
        m_column = self.m[:, component]

        # Deterministic part of the enhanced model only: no noise is drawn
        # from the normal distribution because this is a prediction.
        exponent = -(offset + jnp.matmul(weights, m_column))
        values = scale * jnp.exp(exponent * t)

        # Summary statistics across the sample axis.
        return jnp.mean(values, axis=0), jnp.std(values, axis=0)

In [31]:
# Build the predictor from the samples and M values loaded above.
model = DeployableModel(samples=samples, m=m)

In [32]:
import pandas as pd

# Raw timestamps are divided by this factor before being passed to the model;
# the same factor converts them back to the raw scale for display.
TIME_SCALE = 10_000

# Fixed seed so the same five rows are drawn on every run of this cell.
SAMPLE_SEED = 0

data = pd.read_csv("./bml-component-data.csv")

# Scale time column
data["timestamp"] /= TIME_SCALE

# Sample 5 random rows for testing (seeded for reproducibility)
test_data = data.sample(5, random_state=SAMPLE_SEED)

for _label, row in test_data.iterrows():
    # Extract timestamp, efficiency value and component index from row
    t = float(row["timestamp"])
    y = float(row["efficiency"])
    component = int(row["index"])

    # Make prediction from component index and (scaled) timestamp t
    mean, std = model.predict(component, t)

    # Report the timestamp on its original, unscaled scale
    print(f"component: {component}, t: {t * TIME_SCALE:.0f}")
    print(f"actual: {y:.2f}, predicted_mean: {mean:.2f}, predicted_std: {std:.2f}")
    print()

component: 8, t: 2500
actual: 53.03, predicted_mean: 55.63, predicted_std: 2.29

component: 37, t: 3900
actual: 24.45, predicted_mean: 27.84, predicted_std: 1.14

component: 40, t: 1500
actual: 63.51, predicted_mean: 65.29, predicted_std: 0.95

component: 40, t: 100
actual: 80.65, predicted_mean: 85.44, predicted_std: 1.99

component: 19, t: 400
actual: 76.61, predicted_mean: 81.74, predicted_std: 2.44

