In [None]:
import pymc as pm
import arviz as az
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# --- Simulate grouped linear data ------------------------------------------
# Every group g follows  y = group_intercept[g] - group_trend[g] * x + noise,
# which is the same sign convention the model cell uses for `mu`.
SEED = 42  # fixed so that Restart & Run All regenerates the identical dataset
rng = np.random.default_rng(SEED)

n_groups = 3
# True per-group parameters: scattered (sd 0.1) around intercept 0 and trend 1.
group_intercept = 0.0 + rng.normal(0, 0.1, n_groups)
group_trend = 1.0 + rng.normal(0, 0.1, n_groups)

# All groups share the same 11 evenly spaced x locations on [-1, 1].
x = np.linspace(-1, 1, 11)

# Collect one frame per group and concatenate once at the end; growing `df`
# with pd.concat inside the loop re-copies the accumulated rows every pass.
group_frames = []
for i in np.arange(n_groups):
    y_mu = group_intercept[i] - group_trend[i]*x
    y = rng.normal(y_mu, 0.01)  # small observation noise (sd 0.01)

    new_df = pd.DataFrame({'x': x, 'y': y, 'group': i})
    group_frames.append(new_df)

df = pd.concat(group_frames, ignore_index = True)

# Running row id; the model cell uses this column as its 'observation' coordinate.
df['observation'] = np.arange(len(df))

In [53]:
MODEL_SEED = 42  # pins every sampling call in this cell so a re-run reproduces the same draws

with pm.Model() as model:
    # Coordinates are registered as mutable so that the prediction cell can
    # later replace the 'observation' coordinate with new labels.
    model.add_coord('group', df['group'].unique(), mutable = True)
    model.add_coord('observation', df['observation'], mutable = True)

    # NOTE: `x` and `y` here (and `group_intercept` / `group_trend` below)
    # rebind the names of the NumPy arrays created by the data-simulation
    # cell. After this cell runs they refer to model variables, not to the
    # simulated ground truth.
    x = pm.MutableData('x', df['x'], dims = 'observation')
    y = pm.MutableData('y', df['y'], dims = 'observation')
    group_idx = pm.MutableData('group_idx', df['group'], dims = 'observation')

    # Population-level priors.
    intercept = pm.Normal('intercept', 0.0, 1.0)
    trend = pm.HalfNormal('trend', 1.0)
    error = pm.HalfNormal('error', 1.0)
    
    # Group-level parameters. The group intercepts are centred on `intercept`.
    # NOTE(review): HalfNormal's positional argument is its scale (sigma), so
    # `trend` sets the spread of the group trends, not their centre; confirm
    # that this is the intended hierarchy.
    group_intercept = pm.Normal('group_intercept', intercept, 1.0, dims = 'group')
    group_trend = pm.HalfNormal('group_trend', trend, dims = 'group')
    
    # Expected value per observation; the slope enters with a negative sign,
    # matching how the data were simulated.
    mu = pm.Deterministic('mu', group_intercept[group_idx] - group_trend[group_idx]*x, dims = 'observation')

    likelihood = pm.Normal('likelihood', mu, error, observed = y, dims = 'observation')

    print('Sample posterior...')
    inference_data = pm.sample(random_seed = MODEL_SEED)

    # Prior and posterior predictive draws are merged into the same
    # InferenceData object as additional groups.
    print('Sample prior predictive...')
    inference_data.extend(pm.sample_prior_predictive(random_seed = MODEL_SEED))

    print('Sample posterior predictive...')
    inference_data.extend(pm.sample_posterior_predictive(inference_data, random_seed = MODEL_SEED))

Sample posterior...


Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [intercept, trend, error, group_intercept, group_trend]


Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 8 seconds.
Sampling: [error, group_intercept, group_trend, intercept, likelihood, trend]


Sample prior predictive...
Sample posterior predictive...


Sampling: [likelihood]


In [44]:
PREDICTION_SEED = 42  # makes the predictive draws below reproducible across runs

# Two x values outside the training range [-1, 1], both assigned to the
# highest-numbered group.
new_x = np.array([-2.0, 2.0])
# Build the integer group index directly instead of filling a float array
# and casting it afterwards.
new_group_idx = np.full(new_x.shape, df['group'].max(), dtype = int)
# Fresh observation labels that continue after the last training label.
new_observation = df['observation'].max() + np.arange(len(new_x)) + 1

with model:
    # Swap in the new predictors and resize the 'observation' coordinate.
    # `y` is deliberately not replaced and keeps its training length.
    # NOTE(review): the recorded output has one prediction per new x, so the
    # likelihood appears to follow the 'observation' dim rather than `y`;
    # confirm this on the PyMC version in use.
    pm.set_data(new_data = {'x': new_x,
                            'group_idx': new_group_idx},
                coords = {'observation': new_observation})
    
    # Despite its name, the result is not an InferenceData object: with
    # return_inferencedata = False the call returns a mapping that the next
    # cell indexes by variable name ('likelihood').
    pred_inference_data = pm.sample_posterior_predictive(inference_data,
                                                         return_inferencedata = False,
                                                         predictions = True,
                                                         random_seed = PREDICTION_SEED)

Sampling: [likelihood]


In [51]:
# Mean predicted outcome for each new x: averaging over axis 0 collapses the
# leading axis of the predictive draws (presumably the flattened sample axis;
# confirm for the PyMC version in use).
np.mean(pred_inference_data['likelihood'], axis = 0)

array([ 2.31704926, -2.23234269])