## Structural Time Series Model

構造時系列モデルを使って、以下のサイトで紹介されている電子機器の生産高のデータをモデル化します。

https://www.statsmodels.org/stable/examples/notebooks/generated/stl_decomposition.html


## Install Packages

In [None]:
!pip install numpyro

【重要】パッケージのインストール完了後に、ランタイムを再起動して下さい！

## Import Packages

In [None]:
import numpyro
import numpyro.distributions as dist

import jax
import jax.numpy as jnp

import arviz as az

import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
import seaborn as sns

from numpyro.contrib.control_flow import scan

from statsmodels.datasets import elec_equip
from dateutil.relativedelta import relativedelta

In [None]:
# Use a larger default font size for every figure in this notebook.
plt.rcParams.update({'font.size': 14})

In [None]:
# Run on CPU and request 4 host devices (the MCMC run later in this notebook
# uses num_chains=4).
# NOTE(review): presumably these calls must happen before JAX initializes its
# backend — confirm against the NumPyro documentation.
numpyro.set_platform('cpu')
numpyro.set_host_device_count(4)

## Load Data

In [None]:
# Load the electrical-equipment dataset bundled with statsmodels in its pandas form.
data = elec_equip.load_pandas().data

In [None]:
# Plot the raw series.
plt.figure(figsize=(10, 4))
plt.plot(data)

# The trailing semicolon suppresses the notebook's echo of the return value.
plt.title('Production of Electrical Equipment in EU');

## Define Model & Inference

In [None]:
# Time axis (the pandas index) and the observations flattened into a 1-D JAX array.
t_obs = data.index
y_obs = jnp.asarray(data.to_numpy().ravel())

In [None]:
def fourier_basis(num_steps, num_basis, period):
    """Build a matrix of cosine and sine waves evaluated at integer time steps.

    Args:
        num_steps: Number of time steps (rows); steps are 0 .. num_steps - 1.
        num_basis: Number of harmonics; harmonic k has angular frequency
            2 * pi * k / period for k = 1 .. num_basis.
        period: Length of one full cycle of the first harmonic, in steps.

    Returns:
        Array of shape (num_steps, 2 * num_basis): the cosine columns for all
        harmonics first, followed by the sine columns.
    """
    harmonics = jnp.arange(1, num_basis + 1)
    omega = 2 * jnp.pi * harmonics / period

    steps = jnp.arange(num_steps)

    # Broadcast (num_basis,) against (num_steps, 1) to get one phase per
    # (time step, harmonic) pair.
    phase = omega * steps[:, None]

    return jnp.concatenate((jnp.cos(phase), jnp.sin(phase)), axis=1)

In [None]:
def seasonality_comp(name, num_steps, num_basis, period, coef_name='r'):
    """Sample a seasonal component as a weighted sum of Fourier basis waves.

    Args:
        name: Name of the deterministic site under which the resulting
            seasonal series is recorded.
        num_steps: Number of time steps to generate.
        num_basis: Number of harmonics; 2 * num_basis coefficients are sampled.
        period: Length of one seasonal cycle, in time steps.
        coef_name: Name of the sample site holding the Fourier coefficients.
            Defaults to 'r', the name this function has always used. Pass a
            distinct name per call when one model contains several seasonal
            components, so the coefficient sites do not share a name.

    Returns:
        The seasonal component, one value per time step.
    """

    # Cosine/sine waves, shape (num_steps, 2 * num_basis).
    basis = fourier_basis(num_steps, num_basis, period)

    # One Normal(0, 10) weight per basis column.
    coef = numpyro.sample(
        coef_name, dist.Normal(0, 10), sample_shape=(2 * num_basis,)
    )

    # Weighted sum of the waves, recorded so it appears in the posterior samples.
    return numpyro.deterministic(name, jnp.dot(basis, coef))

In [None]:
def gaussian_random_walk(name, num_steps, scale=1.0):
    """Sample a Gaussian random walk of length ``num_steps`` starting from 0.

    Args:
        name: Name of the sample site used for every step of the walk.
        num_steps: Number of steps to generate.
        scale: Standard deviation of each Normal increment.

    Returns:
        The stacked per-step values returned by ``scan``.
    """

    def step(previous, _):
        current = numpyro.sample(name, dist.Normal(previous, scale))
        # Carry the new state forward and also emit it as this step's output.
        return current, current

    time_index = jnp.arange(num_steps)
    _, path = scan(step, 0.0, time_index)

    return path

In [None]:
def model(y_obs, future=0):
    """Structural time series model: smooth trend plus Fourier seasonality.

    Args:
        y_obs: Observed series; its length sets the number of fitted steps.
        future: Number of extra steps generated beyond the observations.
            When positive, the last ``future`` values of ``y`` are recorded
            under the deterministic site ``'y_pred'``.
    """
    
    num_steps = len(y_obs)
    
    #
    # System Model
    #
    
    # Constant offset added to the trend.
    c = numpyro.sample('c', dist.HalfNormal(10))
    
    # Scale applied to the unit-scale random walk below.
    sd_b = numpyro.sample('sd_b', dist.HalfNormal(10))
    
    # Random walk covering the observed and the future steps, scaled by sd_b.
    b = gaussian_random_walk('b', num_steps + future) * sd_b
    
    # Trend = offset + cumulative sum of the random walk.
    u = numpyro.deterministic('smooth_trend', c + jnp.cumsum(b))
    
    # Seasonal component with period 12 built from 5 harmonics.
    s = seasonality_comp('seasonality', num_steps + future, period=12, num_basis=5)
        
    # Noise-free signal: trend plus seasonality.
    v = u + s
    
    #
    # Observation Model
    #
    
    # Observation noise standard deviation.
    sd_y = numpyro.sample('sd_y', dist.HalfNormal(10))

    # NOTE(review): y_obs has num_steps entries while v has num_steps + future.
    # This appears to rely on scan sampling the steps that have no conditioned
    # value from the site's distribution — confirm against the NumPyro scan docs.
    with numpyro.handlers.condition(data={'y':y_obs}):
                
        def observer_fn(carry, x):
            
            # One observation per step, centred on the noise-free signal.
            y = numpyro.sample('y', dist.Normal(x, sd_y))
            
            return carry, y
        
        _, y = scan(observer_fn, None, v)
        

    if future > 0:
        
        # Expose only the trailing `future` steps as the forecast.
        numpyro.deterministic('y_pred', y[-future:])

In [None]:
# NUTS kernel; 4 chains, each with 1000 warmup and 1000 retained draws.
nuts = numpyro.infer.NUTS(model, target_accept_prob=0.95)
mcmc = numpyro.infer.MCMC(nuts, num_warmup=1000, num_samples=1000, num_chains=4)

# Fit on the observed series only (future keeps its default of 0).
mcmc.run(jax.random.PRNGKey(1), y_obs)

# Posterior draws keyed by site name; reused by the plots and by Predictive below.
trace = mcmc.get_samples()

In [None]:
# Convert the MCMC result into an ArviZ object for the summary and trace plots below.
idata = az.from_numpyro(mcmc)

In [None]:
# Summary table restricted to the three scalar parameters.
az.summary(idata, var_names=['c', 'sd_b', 'sd_y'])

In [None]:
# Trace plots for the three scalar parameters.
az.plot_trace(idata, var_names=['c', 'sd_b', 'sd_y'])
# Widen the gaps between subplots of the current figure.
plt.gcf().subplots_adjust(wspace=0.5, hspace=0.5)

## Check Estimated Trend and Seasonality

In [None]:
# Posterior draws of the 'smooth_trend' deterministic site.
smooth_trend = trace['smooth_trend']

In [None]:
# Mean and 5th/95th percentiles of the trend, reduced over the leading axis.
mu = smooth_trend.mean(axis=0)
pi = jnp.percentile(smooth_trend, jnp.array([5, 95]), axis=0)

In [None]:
# Overlay the estimated trend (mean and 5-95 percentile band) on the observations.
plt.figure(figsize=(10, 4))

plt.plot(t_obs, y_obs, '.', color='C1', label='observed')
plt.plot(t_obs, mu, '-.', label='estimated')
# pi[0] is the 5th percentile, pi[1] the 95th.
plt.fill_between(t_obs, pi[0, :], pi[1, :], alpha=0.3)

plt.title('Smooth Trend')
plt.legend();

In [None]:
# Posterior draws of the 'seasonality' deterministic site.
seasonality = trace['seasonality']

In [None]:
# Mean and 5th/95th percentiles of the seasonal component, reduced over the leading axis.
mu = seasonality.mean(axis=0)
pi = jnp.percentile(seasonality, jnp.array([5, 95]), axis=0)

In [None]:
# Plot the estimated seasonal component with its 5-95 percentile band.
plt.figure(figsize=(10, 4))

plt.plot(t_obs, mu, '-.')
plt.fill_between(t_obs, pi[0, :], pi[1, :], alpha=0.3)

plt.title('Seasonal Component');

## Check Prediction

In [None]:
# Forecast horizon in time steps (the forecast dates below are generated monthly).
future = 24

In [None]:
# Re-run the model with the posterior draws, extended by `future` extra steps.
# NOTE(review): the draws in `trace` cover only the observed steps; this
# presumably relies on Predictive/scan sampling the additional steps — confirm
# against the NumPyro documentation.
predictive = numpyro.infer.Predictive(model, trace)
ppc_samples = predictive(jax.random.PRNGKey(2), y_obs, future=future)

In [None]:
# Forecast draws recorded by the model under the 'y_pred' site.
y_pred = ppc_samples['y_pred']
# Month-start timestamps for the horizon, beginning one month after the last observation.
t_pred = pd.date_range(t_obs[-1] + relativedelta(months=1), periods=future, freq='MS')

In [None]:
# Mean and 5th/95th percentiles of the forecast, reduced over the leading axis.
mu_pred = y_pred.mean(axis=0)
pi_pred = jnp.percentile(y_pred, jnp.array([5, 95]), axis=0)

In [None]:
# Show the observed series followed by the forecast mean and its 5-95 percentile band.
plt.figure(figsize=(10, 4))

plt.plot(t_obs, y_obs, '-')
plt.plot(t_pred, mu_pred, '-.')
plt.fill_between(t_pred, pi_pred[0, :], pi_pred[1, :], alpha=0.3)

plt.title('Observed Data / Prediction');

In [None]:
# Forecast band together with a handful of individual forecast draws.
plt.figure(figsize=(10, 4))

plt.fill_between(t_pred, pi_pred[0, :], pi_pred[1, :], alpha=0.2);
# First 10 draws, transposed so each draw is plotted as its own line over t_pred.
plt.plot(t_pred, y_pred[:10,].T)

plt.title('Prediction');