### Load model data

In [1]:
import pickle
from pathlib import Path

# Exported model file, kept relative so the notebook is portable.
MODEL_PATH = Path("deployable_model.pkl")

# SECURITY: pickle.load can execute arbitrary code embedded in the file.
# Only load model files that come from a trusted source.
with MODEL_PATH.open("rb") as f:
    data = pickle.load(f)

# Fail early with a clear message if the pickle lacks the entries used below
# (presumably it is a dict written by the training code -- TODO confirm).
missing_keys = {"samples", "m"} - set(data)
if missing_keys:
    raise KeyError(f"{MODEL_PATH} is missing expected keys: {sorted(missing_keys)}")

samples = data["samples"]
m = data["m"]

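### Deployable Model

`DeployableModel.inference(component, t)` evaluates, for every stored sample,
$$f = u_{:,c}\,\exp\!\big(-(v_{:,c} + w\,m_{:,c})\,t / 10^{4}\big)$$
and summarises the resulting draws by their mean and 5th/95th percentiles.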

In [2]:
import jax.numpy as jnp

class DeployableModel:
    """Lightweight inference wrapper around pre-computed model samples.

    For an output column ``c`` and time ``t`` it evaluates, for every stored
    sample::

        f = u[:, c] * exp(-(v[:, c] + w @ m[:, c]) * t / TIME_SCALE)

    and summarises the draws by their mean and 5th/95th percentiles.
    """

    # Divisor applied to ``t`` before it enters the exponent (hard-coded as
    # 10_000 in the original notebook). Presumably it converts ``t`` into the
    # units the samples were fitted in -- TODO confirm against the training code.
    TIME_SCALE = 10_000

    def __init__(self, samples, m):
        """Store the samples and the ``m`` matrix used by :meth:`inference`.

        Args:
            samples: Mapping with keys ``"u"``, ``"v"`` and ``"w"``. ``u`` and
                ``v`` are indexed as ``[:, component]`` and ``w`` is
                matrix-multiplied with ``m[:, component]``. Presumably axis 0
                is the sample axis (1000 samples per the original notes --
                TODO confirm).
            m: 2-D array whose columns are selected by ``component``.
        """
        self.samples = samples
        self.m = m

    @staticmethod
    def _append_axes(x, n):
        """Return ``x`` with ``n`` trailing length-1 axes, for broadcasting over ``t``."""
        return x.reshape(x.shape + (1,) * n)

    def _check_component(self, component):
        """Raise IndexError if ``component`` is not a valid column index.

        If the samples are JAX arrays, out-of-range indices are clamped rather
        than raising, so a bad ``component`` would silently return the result
        for a different column.
        """
        n_components = min(
            self.samples["u"].shape[1],
            self.samples["v"].shape[1],
            self.m.shape[1],
        )
        if not -n_components <= component < n_components:
            raise IndexError(
                f"component {component} is out of range for {n_components} components"
            )

    def inference(self, component: int, t: "float | jnp.ndarray"):
        """Summarise the model output for one component at time(s) ``t``.

        Args:
            component: Column index selecting ``u[:, c]``, ``v[:, c]`` and
                ``m[:, c]``. Negative indices count from the end, as in NumPy.
            t: Scalar or array-like of times. It is divided by ``TIME_SCALE``
                before use. For array input, every time point is evaluated for
                every sample. The caller's object is never modified.

        Returns:
            dict with keys ``"mean"``, ``"5th_percentile"`` and
            ``"95th_percentile"``, each reduced over axis 0 (the sample axis).
            For 2-D ``u``/``v`` each entry has the shape of ``t`` (a 0-d array
            for scalar ``t``).

        Raises:
            IndexError: if ``component`` is out of range.
        """
        self._check_component(component)

        # Build a new array rather than `t /= ...`, which would divide a
        # caller's NumPy array in place (or fail for integer arrays).
        t_scaled = jnp.asarray(t) / self.TIME_SCALE
        n_time_axes = t_scaled.ndim

        u = self.samples["u"][:, component]
        v = self.samples["v"][:, component]
        w = self.samples["w"]

        rate = v + jnp.matmul(w, self.m[:, component])
        # Trailing singleton axes let each sample broadcast against every
        # entry of t; for scalar t this is a no-op (n_time_axes == 0).
        f = self._append_axes(u, n_time_axes) * jnp.exp(
            -self._append_axes(rate, n_time_axes) * t_scaled
        )

        return {
            "mean": jnp.mean(f, axis=0),
            "5th_percentile": jnp.percentile(f, 5, axis=0),
            "95th_percentile": jnp.percentile(f, 95, axis=0),
        }

In [3]:
# Wrap the loaded samples and m matrix in the inference helper.
model = DeployableModel(samples=samples, m=m)

In [4]:
# Summary statistics for component 10 at t = 4000 (t is rescaled inside inference).
model.inference(component=10, t=4000)

{'mean': Array(21.634851, dtype=float32),
 '5th_percentile': Array(18.315655, dtype=float32),
 '95th_percentile': Array(25.371998, dtype=float32)}