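### Data

Load the per-measurement component data, then reshape it into one padded row per CX component (plus a mask marking real observations) for the model below.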

In [29]:
import pandas as pd
import numpy as np

# Load the raw measurements, rescale the time column (divide by 10000) and drop the
# redundant "index" column, all in one chain so re-running the cell is idempotent.
data = (
    pd.read_csv("./bml-component-data.csv")
    .assign(timestamp=lambda frame: frame["timestamp"] / 10000)
    .drop(columns=["index"])
)

In [30]:
def pad(array, target_length, fill_value=-1):
    """Right-pad a 1D array to ``target_length`` with a constant value.

    Args:
        array: 1D numpy array to pad.
        target_length: length of the returned array; must be >= ``array.shape[0]``.
        fill_value: constant written into the padded tail. Defaults to -1, the
            padding value this notebook has always used.

    Returns:
        A new array of length ``target_length``: the entries of ``array`` followed by
        ``target_length - len(array)`` copies of ``fill_value``.

    Raises:
        ValueError: if ``array`` is already longer than ``target_length``.
    """
    n_pad = target_length - array.shape[0]
    # np.pad would also fail here, but with an opaque "negative values" message.
    if n_pad < 0:
        raise ValueError(
            f"array of length {array.shape[0]} is longer than target_length={target_length}"
        )
    return np.pad(array, (0, n_pad), mode="constant", constant_values=fill_value)

In [31]:
# Collect all values from each CX into one row and turn each cell into an np.ndarray.
# NOTE: groupby sorts by ID (and drops missing IDs), so row i of every array below is the
# i-th component in sorted-ID order; any per-component covariates must use the same order.
data_grouped = data.groupby(["ID"]).agg(list).map(np.array)

# Each row is exactly one CX
x = list(data_grouped["timestamp"])
y = list(data_grouped["efficiency"])

# Rows are ragged (components have different numbers of observations), so pad to a rectangle
n_obs = np.array([len(row) for row in y])
max_obs = int(n_obs.max())

x_padded = np.array([pad(row, max_obs) for row in x])
y_padded = np.array([pad(row, max_obs) for row in y])

# Mask: True for actual values, False for padded values. Built from the row lengths rather
# than by comparing against the -1 padding value, so a genuine timestamp equal to -1 can
# never be mistaken for padding.
mask = np.arange(max_obs) < n_obs[:, np.newaxis]

In [32]:
# Generate M1 to M5 arrays. Each M value is the same for all measurements within a CX component,
# so one row per component is enough.
# The x/y cell builds its rows with groupby(["ID"]), which sorts by ID and drops missing IDs.
# drop_duplicates alone keeps CSV file order, which would misalign M with x_padded/y_padded
# whenever the file isn't already sorted by ID, so apply the same ordering here.
M_df = (
    data
    .dropna(subset=["ID"])
    .drop_duplicates(subset=["ID"])
    .sort_values("ID")
)

# M[0] is all M1 values, M[1] is all M2 values etc.; column i belongs to the i-th component.
M = M_df[["M1", "M2", "M3", "M4", "M5"]].to_numpy().T

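### Deployable Model

Each CX component $i$ is modelled as an exponential decay in (scaled) time, $f_i(t) = u_i \exp\!\left(-(v_i + \mathbf{w}^\top \mathbf{m}_i)\,t\right)$, with observations $y \sim \mathcal{N}(f_i(t), \sigma)$, where $\sigma$ is the noise standard deviation and $\mathbf{m}_i$ holds the component's M1–M5 values.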

In [42]:
import numpyro as npr
import numpyro.distributions as dist
import jax.numpy as jnp
from jax import random
from numpyro.infer import MCMC, NUTS

class DeployableModel:
    """Bayesian exponential-decay model of CX-component efficiency, fitted with NUTS.

    For component i at scaled time t the model is

        f_i(t) = u_i * exp(-(v_i + w . m[:, i]) * t),    y ~ Normal(f_i(t), sigma)

    with per-component ``u_i`` / ``v_i``, one weight ``w_k`` per covariate row of ``m``,
    and ``sigma`` the observation-noise standard deviation.

    Attributes:
        samples: posterior samples from ``MCMC.get_samples()`` (sites "sigma", "u", "v", "w").
        m: covariate matrix of shape (n_covariates, n_components); column i is component i.
    """

    # Raw timestamps are divided by this factor before use. ``inference`` applies it to its
    # ``t``; it must match the scaling applied to the training times passed to ``from_data``.
    TIME_SCALE = 10_000

    def __init__(self, samples, m):
        self.samples = samples
        self.m = m

    @classmethod
    def from_data(cls, t, y, m, mask, num_warmup=500, num_samples=1000, seed=0):
        """Fit the model with NUTS and wrap the posterior samples.

        Args:
            t: (n_components, n_obs) array of scaled observation times, padded per row.
            y: (n_components, n_obs) array of observed efficiencies, padded like ``t``.
            m: (n_covariates, n_components) covariate matrix.
            mask: boolean (n_components, n_obs) array; True marks real (non-padded) entries.
            num_warmup: number of NUTS warm-up iterations.
            num_samples: number of posterior draws to keep.
            seed: seed for the ``jax.random.PRNGKey`` driving the sampler.

        Returns:
            A ``DeployableModel`` holding the posterior samples and ``m``.
        """
        def model(t, y, m, mask):
            # Observation-noise scale: numpyro's Normal takes a standard deviation, not a variance.
            sigma = npr.sample("sigma", dist.HalfNormal(1))

            # Per-component amplitude u_i and baseline decay rate v_i.
            with npr.plate("cx-component", t.shape[0]):
                u = npr.sample("u", dist.Normal(90, 10))
                v = npr.sample("v", dist.Normal(5, 5))

            # One weight per M covariate, shared by all components.
            with npr.plate("m", m.shape[0]):
                w = npr.sample("w", dist.Normal(0, 1))

            # Negative decay rate per component, shape (n_components,).
            m_sum = -(v + jnp.matmul(w, m))
            # Broadcast over each component's observation times -> (n_components, n_obs).
            f = u[:, jnp.newaxis] * jnp.exp(m_sum[:, jnp.newaxis] * t)

            # Padded entries are excluded from the likelihood via the mask.
            with npr.plate("observations", t.shape[1]):
                with npr.handlers.mask(mask=mask):
                    npr.sample("obs", dist.Normal(f, sigma), y)

        mcmc = MCMC(NUTS(model), num_warmup=num_warmup, num_samples=num_samples)
        mcmc.run(random.PRNGKey(seed), t=t, y=y, m=m, mask=mask)

        return cls(mcmc.get_samples(), m)

    def inference(self, component: int, t):
        """Summarise the posterior of f_component(t) at raw (unscaled) time(s) ``t``.

        Args:
            component: index of the component (column of ``m`` / of the ``u``, ``v`` samples).
            t: raw time, either a scalar or an array of times; divided by ``TIME_SCALE``.

        Returns:
            dict with "mean", "5th_percentile" and "95th_percentile" of f(t) over the posterior
            draws. Each value has the shape of ``t`` (a 0-d array for scalar ``t``).

        Raises:
            IndexError: if ``component`` is out of range. JAX indexing clamps out-of-range
                indices instead of raising, so this is checked explicitly.
        """
        n_components = self.samples["u"].shape[1]
        if not -n_components <= component < n_components:
            raise IndexError(
                f"component {component} out of range for {n_components} components"
            )

        t = jnp.asarray(t) / self.TIME_SCALE
        u = self.samples["u"][:, component]  # (n_draws,)
        v = self.samples["v"][:, component]  # (n_draws,)
        w = self.samples["w"]                # (n_draws, n_covariates)

        m_sum = -(v + jnp.matmul(w, self.m[:, component]))  # (n_draws,)
        # Append one trailing axis per dimension of t so per-draw parameters broadcast
        # against an array-valued t; for scalar t this is a no-op reshape.
        draw_shape = (-1,) + (1,) * t.ndim
        f = u.reshape(draw_shape) * jnp.exp(m_sum.reshape(draw_shape) * t)

        results = {
            "mean": jnp.mean(f, axis=0),
            "5th_percentile": jnp.percentile(f, 5, axis=0),
            "95th_percentile": jnp.percentile(f, 95, axis=0),
        }

        return results

In [43]:
# Fit the posterior on the padded per-component arrays; padded entries are masked out.
model = DeployableModel.from_data(t=x_padded, y=y_padded, m=M, mask=mask)

  mcmc.run(rng_key, t=t, y=y, m=m, mask=mask)
sample: 100%|██████████| 1500/1500 [00:10<00:00, 139.42it/s, 255 steps of size 1.98e-02. acc. prob=0.93]


In [44]:
# Posterior summary for component index 10 at raw time 4000.
model.inference(component=10, t=4000)

{'mean': Array(21.530245, dtype=float32),
 '5th_percentile': Array(17.882442, dtype=float32),
 '95th_percentile': Array(25.311098, dtype=float32)}