## Mixture of Linear Regressions

## Install Package

In [None]:
!pip install numpyro

【重要】パッケージのインストール完了後に、ランタイムを再起動して下さい！

## Import Packages

In [None]:
import numpyro
import numpyro.distributions as dist

import arviz as az

import jax
import jax.numpy as jnp

import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
import seaborn as sns

In [None]:
# Use a larger default font size for every figure in this notebook.
plt.rcParams.update({'font.size': 14})

In [None]:
# Run JAX computations on the CPU backend.
numpyro.set_platform('cpu')
# Expose 4 host (CPU) devices so the 4 MCMC chains requested below (num_chains=4)
# can run in parallel. Per the NumPyro docs this only takes effect at the start of
# the program, i.e. before JAX initializes its backend, so keep this cell above
# any JAX computation.
numpyro.set_host_device_count(4)

## Generate Data

In [None]:
# Synthetic data: two linear regimes pooled together without labels.
#   component 0: y = 1 * x + 0 + noise   (N1 points)
#   component 1: y = 2 * x + 1 + noise   (N2 points)
# Noise is Gaussian with standard deviation NOISE_SD in both components.
# A fixed seed makes the data, and therefore every downstream result, reproducible.
SEED = 0
N1, N2 = 40, 60
NOISE_SD = 0.1

rng = np.random.default_rng(SEED)

x1 = rng.standard_normal(N1)
e1 = rng.standard_normal(N1) * NOISE_SD
y1 = (1 * x1 + 0) + e1

x2 = rng.standard_normal(N2)
e2 = rng.standard_normal(N2) * NOISE_SD
y2 = (2 * x2 + 1) + e2

# Pooled data set: the first N1 points come from component 0, the rest from component 1.
x = np.concatenate([x1, x2])
y = np.concatenate([y1, y2])

In [None]:
# Scatter plot of the pooled data; the component labels are not shown.
fig, ax = plt.subplots(figsize=(8, 6))
ax.scatter(x, y)
ax.set(xlabel='x', ylabel='y', title='Data Generated');

## Define Model & Inference

In [None]:
def model(x, y=None, num_clusters=2, min_weight=0.25):
    """Mixture of linear regressions with explicit discrete component labels.

    Each data point i receives a latent label ``s[i]`` drawn from
    ``Categorical(p)``. Its mean is ``a[s[i]] * x[i] + b[s[i]]``, and all
    components share a single noise scale ``sd``.

    Args:
        x: array of inputs, one entry per data point.
        y: observed outputs aligned with ``x``, or ``None`` to leave the
            ``obs`` site unobserved (e.g. for predictive sampling).
        num_clusters: number of regression components K (default 2).
        min_weight: every mixture weight in ``p`` must exceed this value
            (default 0.25). It must satisfy ``0 <= min_weight < 1 / num_clusters``;
            otherwise no weight vector on the simplex satisfies the constraint.

    Raises:
        ValueError: if ``num_clusters < 2`` or ``min_weight`` is outside
            ``[0, 1 / num_clusters)``.
    """
    if num_clusters < 2:
        raise ValueError(f"num_clusters must be >= 2, got {num_clusters}")
    if not 0 <= min_weight < 1 / num_clusters:
        raise ValueError(
            f"min_weight must lie in [0, 1/num_clusters) = [0, {1 / num_clusters:.3g}), "
            f"got {min_weight}"
        )

    num_data = len(x)

    # Per-component slope and intercept with Normal(0, 10) priors.
    a = numpyro.sample('a', dist.Normal(0, 10), sample_shape=(num_clusters,))
    b = numpyro.sample('b', dist.Normal(0, 10), sample_shape=(num_clusters,))

    # Mixture weights (flat Dirichlet prior) and one component label per data point.
    p = numpyro.sample('p', dist.Dirichlet(jnp.ones(num_clusters)))
    s = numpyro.sample('s', dist.Categorical(p), sample_shape=(num_data,))

    # Cluster-size constraint: zero density unless every weight exceeds min_weight.
    numpyro.factor('cond_cluster_size', jnp.where(jnp.min(p) > min_weight, 0.0, -jnp.inf))

    # Select each point's slope/intercept by its label.
    mu = a[s] * x + b[s]

    sd = numpyro.sample('sd', dist.HalfNormal(5))

    # Cluster-order constraint: zero density unless the slopes are strictly
    # increasing (a[0] < a[1] < ...). This makes component labels identifiable
    # across draws. For K=2 it is exactly the condition a[0] < a[1].
    numpyro.factor('cond_cluster_order', jnp.where(jnp.all(jnp.diff(a) > 0), 0.0, -jnp.inf))

    numpyro.sample('obs', dist.Normal(mu, sd), obs=y)

In [None]:
# NUTS moves the continuous sites (a, b, p, sd). DiscreteHMCGibbs wraps it and
# alternates those moves with updates of the discrete component labels `s`.
nuts = numpyro.infer.NUTS(model, target_accept_prob=0.95)
kernel = numpyro.infer.DiscreteHMCGibbs(nuts)

mcmc = numpyro.infer.MCMC(
    kernel,
    num_chains=4,
    num_warmup=3000,
    num_samples=1000,
)

rng_key = jax.random.PRNGKey(0)
mcmc.run(rng_key, x, y=y)

# Posterior draws as a dict of arrays (chains concatenated by default), plus an
# ArviZ InferenceData view of the same run for plotting and summaries.
trace = mcmc.get_samples()
idata = az.from_numpyro(mcmc)

In [None]:
# Trace plots for the continuous parameters (the labels `s` are inspected separately).
az.plot_trace(idata, var_names=['a', 'b', 'p', 'sd'])
# NOTE(review): if ArviZ created this figure with a constrained layout engine,
# Matplotlib ignores subplots_adjust (emitting a warning). Confirm the spacing changes.
trace_fig = plt.gcf()
trace_fig.subplots_adjust(wspace=0.5, hspace=0.5)

In [None]:
# Posterior summary statistics for the continuous parameters.
summary_var_names = ['a', 'b', 'p', 'sd']
az.summary(idata, var_names=summary_var_names)

## Check Latent Variables

In [None]:
# Posterior draws of the component labels `s`: one row per draw (all chains
# concatenated by get_samples), one column per data point.
fig, ax = plt.subplots(figsize=(10, 6))

sns.heatmap(trace['s'], cmap='jet', ax=ax)

ax.set_xlabel('data index')
ax.set_ylabel('posterior draw')
ax.set_title('Latent Variables');

In [None]:
# Posterior mean of each label s[i]. With two components s is 0 or 1, so this is
# the fraction of draws assigning point i to component 1. Given the model's
# a[0] < a[1] ordering constraint, component 1 is the steeper line.
fig, ax = plt.subplots(figsize=(8, 4))

ax.plot(trace['s'].mean(axis=0))

ax.set_xlabel('data index')
ax.set_ylabel('posterior P(s = 1)')
ax.set_ylim(-0.05, 1.05)
ax.set_title('Mean Value of Latent Variables');

## Check Estimated Models

In [None]:
# Posterior means of each component's slope and intercept. Averaging over draws
# is meaningful because the model's ordering constraint (a[0] < a[1]) prevents
# the component labels from swapping between draws.
a_estimated, b_estimated = (trace[site].mean(axis=0) for site in ('a', 'b'))

In [None]:
# Evaluate each fitted line on a 50-point grid spanning the observed inputs, so
# the lines cover every data point rather than a fixed [-2, 2] window.
x_new = np.linspace(x.min(), x.max(), 50)

y0_new = a_estimated[0] * x_new + b_estimated[0]
y1_new = a_estimated[1] * x_new + b_estimated[1]

In [None]:
# Data overlaid with the regression lines built from the posterior-mean
# coefficients, one line per component.
fig, ax = plt.subplots(figsize=(8, 6))

ax.scatter(x, y, marker='x', label='data')

for k, y_line in enumerate([y0_new, y1_new]):
    slope, intercept = float(a_estimated[k]), float(b_estimated[k])
    ax.plot(x_new, y_line, label=f'component {k}: y = {slope:.2f}x {intercept:+.2f}')

ax.set_xlabel('x')
ax.set_ylabel('y')
ax.legend()
ax.set_title('Estimated Linear Models');