## Stochastic Volatility Model

この例題は、以下のサイトで紹介されているコードを参考に作成しています。

- http://num.pyro.ai/en/stable/examples/stochastic_volatility.html
- https://docs.pymc.io/notebooks/stochastic_volatility.html



## Install Packages

In [None]:
!pip install numpyro

【重要】パッケージのインストール完了後に、ランタイムを再起動して下さい！

## Import Packages

In [None]:
import numpyro
import numpyro.distributions as dist

import jax
import jax.numpy as jnp

import arviz as az

import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
import seaborn as sns

from numpyro.examples.datasets import SP500, load_dataset

In [None]:
# Use a larger default font for every figure drawn in this notebook.
plt.rcParams.update({'font.size': 14})

In [None]:
# Select the CPU backend for NumPyro/JAX.
numpyro.set_platform('cpu')
# Request 4 host devices. Presumably this is so that each of the 4 MCMC chains
# configured further down can run on its own device -- confirm against the
# NumPyro documentation for set_host_device_count.
numpyro.set_host_device_count(4)

## Load Dataset

S&P 500 の daily log return のデータを読み込む。

In [None]:
# load_dataset returns a pair; only its second element (the fetch callable)
# is used here, the first one is discarded. Shuffling is disabled.
_, fetch = load_dataset(SP500, shuffle=False)
# Calling fetch() yields the date labels and the matching return values.
dates, returns = fetch()

In [None]:
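# Alternative: read the same data from a local CSV file instead of load_dataset.
#data = pd.read_csv('SP500.csv', index_col='DATE')
#dates = data.index.values  # 'DATE' is the index here, not a column
#returns = data['VALUE'].values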

In [None]:
# Convert the raw date labels to pandas datetimes so they can serve as the
# time axis in the plots.
dates = pd.to_datetime(dates)

In [None]:
# Plot the raw daily log returns over time.
fig, ax = plt.subplots(figsize=(10, 4))

ax.plot(dates, returns)
ax.set_title('S&P 500')
ax.set_ylabel('daily log returns');

## Define Model & Inference

In [None]:
# Convert the returns to a JAX array before handing them to the model.
returns = jnp.array(returns)

In [None]:
def model(returns):
    """Stochastic volatility model for a series of returns.

    Sample sites, in order:
      * 'sigma' -- step size of the latent random walk, Exponential(50) prior.
      * 's'     -- latent Gaussian random walk with one step per observation;
                   its exponential is used as the observation scale.
      * 'nu'    -- degrees of freedom of the Student-t likelihood,
                   Exponential(0.1) prior.
      * 'r'     -- observed returns, Student-t with location 0.

    Args:
        returns: one-dimensional sequence of observed returns.
    """
    n_obs = len(returns)

    # Prior on how far the latent path may move between consecutive steps.
    walk_scale = numpyro.sample('sigma', dist.Exponential(50))

    # Latent path; exponentiated below so the observation scale is positive.
    latent_walk = dist.GaussianRandomWalk(scale=walk_scale, num_steps=n_obs)
    log_vol = numpyro.sample('s', latent_walk)

    dof = numpyro.sample('nu', dist.Exponential(0.1))

    # Likelihood of the observed returns given the latent path.
    likelihood = dist.StudentT(df=dof, loc=0.0, scale=jnp.exp(log_vol))
    numpyro.sample('r', likelihood, obs=returns)
    

In [None]:
# NUTS kernel for the model, with a target acceptance probability of 0.95.
nuts = numpyro.infer.NUTS(model, target_accept_prob=0.95)

# 4 chains, each with 1000 warmup iterations followed by 1000 kept samples.
mcmc = numpyro.infer.MCMC(
    nuts,
    num_warmup=1000,
    num_samples=1000,
    num_chains=4,
)

# Fixed PRNG key so the run is reproducible.
rng_key = jax.random.PRNGKey(1)
mcmc.run(rng_key, returns)

# Posterior samples, indexed by sample-site name (e.g. trace['s']).
trace = mcmc.get_samples()

In [None]:
# Wrap the MCMC results in an ArviZ data object for the summary table and
# trace plots.
idata = az.from_numpyro(mcmc)

In [None]:
# Summary table restricted to 'nu' and 'sigma'; the latent path 's' is left out.
az.summary(idata, var_names=['nu', 'sigma'])

In [None]:
# Trace plots for 'nu' and 'sigma'.
az.plot_trace(idata, var_names=['nu', 'sigma'])
# Add horizontal and vertical spacing between the subplots of the current figure.
plt.gcf().subplots_adjust(wspace=0.5, hspace=0.5)

## Check Result

In [None]:
# Overlay sampled volatility paths on the observed returns.
fig, ax = plt.subplots(figsize=(10, 6))

ax.plot(dates, returns)

# Keep every 20th posterior draw of the latent path 's'. The transpose is
# presumably (time, draw), so that each draw becomes one red line -- confirm
# against the shape of trace['s'].
thinned_log_vol = trace['s'][::20]
ax.plot(dates, jnp.exp(thinned_log_vol.T), 'r', alpha=0.5)

ax.legend(['returns', 'volatility']);