(interactive_hgf)=
# Interactive visualization of the Hierarchical Gaussian Filter

In [1]:
from ghgf.model import HGF
import jax.numpy as jnp
from ghgf import load_data
import matplotlib.pyplot as plt
import ipywidgets as widgets

%matplotlib widget

No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


## Binary HGF

In [2]:
# Load the example binary time series through ghgf's `load_data`; this is the
# input fed to both binary HGF models below.
binaryseries = load_data("binary")

### 2-levels

In [3]:
_, axs = plt.subplots(nrows=3, figsize=(16, 7), sharex=True)

# `axs` is pinned with `widgets.fixed` so this slider keeps drawing on the
# figure created in this cell. Reading the global `axs` inside the callback
# would pick up whichever figure a later cell assigned to that name.
@widgets.interact(omega_2=(-3.5, 0.0, .1), axs=widgets.fixed(axs))
def update(omega_2=-2.0, axs=axs):
    """Build a 2-level binary HGF, feed it `binaryseries` and plot it.

    Parameters
    ----------
    omega_2 :
        Slider value, passed as the level-2 entry of `omega`.
    axs :
        Axes cleared and then handed to `plot_trajectories`. Bound to this
        cell's figure when the widget is created; not exposed as a control.
    """
    # wipe the previous trajectories before redrawing with the new value
    for ax in axs:
        ax.cla()

    # create the binary HGF, feed it the data and plot the trajectories
    HGF(
        n_levels=2,
        model_type="binary",
        initial_mu={"1": .0, "2": .5},
        initial_pi={"1": .0, "2": 1e4},
        omega={"1": None, "2": omega_2},
        rho={"1": None, "2": 0.0},
        kappas={"1": None},
        eta0=0.0,
        eta1=1.0,
        pihat=jnp.inf,
        verbose=False,
    ).input_data(input_data=binaryseries).plot_trajectories(axs=axs)
    plt.show()

interactive(children=(FloatSlider(value=-2.0, description='omega_2', max=0.0, min=-3.5), Output()), _dom_class…

### 3-levels

In [4]:
_, axs = plt.subplots(nrows=4, figsize=(16, 7), sharex=True)

# `axs` is pinned with `widgets.fixed` so these sliders keep drawing on the
# figure created in this cell. Reading the global `axs` inside the callback
# would pick up whichever figure a later cell assigned to that name.
@widgets.interact(
    omega_2=(-10.0, 5.0, .1),
    omega_3=(-10.0, 5.0, .1),
    axs=widgets.fixed(axs),
)
def update(omega_2=-3.0, omega_3=-3.0, axs=axs):
    """Build a 3-level binary HGF, feed it `binaryseries` and plot it.

    Parameters
    ----------
    omega_2, omega_3 :
        Slider values, passed as the level-2 and level-3 entries of `omega`.
    axs :
        Axes cleared and then handed to `plot_trajectories`. Bound to this
        cell's figure when the widget is created; not exposed as a control.
    """
    # wipe the previous trajectories before redrawing with the new values
    for ax in axs:
        ax.cla()

    # create the binary HGF, feed it the data and plot the trajectories
    HGF(
        n_levels=3,
        model_type="binary",
        initial_mu={"1": .0, "2": .5, "3": 0.},
        initial_pi={"1": .0, "2": 1e4, "3": 1e1},
        omega={"1": None, "2": omega_2, "3": omega_3},
        rho={"1": None, "2": 0.0, "3": 0.0},
        kappas={"1": None, "2": 1.0},
        eta0=0.0,
        eta1=1.0,
        pihat=jnp.inf,
        verbose=False,
    ).input_data(input_data=binaryseries).plot_trajectories(axs=axs)
    plt.show()

interactive(children=(FloatSlider(value=-3.0, description='omega_2', max=5.0, min=-10.0), FloatSlider(value=-3…

## Continuous HGF

In [5]:
# Load the example continuous time series (the exchange rate data) through
# ghgf's `load_data`; this is the input fed to both continuous HGF models below.
timeserie = load_data("continuous")

### 2-levels

In [6]:
_, axs = plt.subplots(nrows=3, figsize=(16, 7), sharex=True)

# `axs` is pinned with `widgets.fixed` so these sliders keep drawing on the
# figure created in this cell. Reading the global `axs` inside the callback
# would pick up whichever figure a later cell assigned to that name.
@widgets.interact(
    omega_1=(-15.0, 5.0, .1),
    omega_2=(-15.0, 5.0, .1),
    axs=widgets.fixed(axs),
)
def update(omega_1=-11.0, omega_2=-3.0, axs=axs):
    """Build a 2-level continuous HGF, feed it `timeserie` and plot it.

    Parameters
    ----------
    omega_1, omega_2 :
        Slider values, passed as the level-1 and level-2 entries of `omega`.
    axs :
        Axes cleared and then handed to `plot_trajectories`. Bound to this
        cell's figure when the widget is created; not exposed as a control.
    """
    # wipe the previous trajectories before redrawing with the new values
    for ax in axs:
        ax.cla()

    # create the continuous HGF, feed it the data and plot the trajectories
    HGF(
        n_levels=2,
        model_type="continuous",
        initial_mu={"1": 1.04, "2": 0.0},
        initial_pi={"1": 1e4, "2": 1e1},
        omega={"1": omega_1, "2": omega_2},
        rho={"1": 0.0, "2": 0.0},
        kappas={"1": 1.0},
        verbose=False,
    ).input_data(input_data=timeserie).plot_trajectories(axs=axs)
    plt.show()

interactive(children=(FloatSlider(value=-11.0, description='omega_1', max=5.0, min=-15.0), FloatSlider(value=-…

### 3-levels

In [7]:
_, axs = plt.subplots(nrows=4, figsize=(16, 7), sharex=True)

# `axs` is pinned with `widgets.fixed` so these sliders keep drawing on the
# figure created in this cell, even if another cell is (re)run afterwards and
# assigns a different figure to the global name `axs`.
@widgets.interact(
    omega_1=(-15.0, 5.0, .1),
    omega_2=(-15.0, 5.0, .1),
    omega_3=(-15.0, 5.0, .1),
    axs=widgets.fixed(axs),
)
def update(omega_1=-11.0, omega_2=-3.0, omega_3=-3.0, axs=axs):
    """Build a 3-level continuous HGF, feed it `timeserie` and plot it.

    Parameters
    ----------
    omega_1, omega_2, omega_3 :
        Slider values, passed as the level-1, level-2 and level-3 entries of
        `omega`.
    axs :
        Axes cleared and then handed to `plot_trajectories`. Bound to this
        cell's figure when the widget is created; not exposed as a control.
    """
    # wipe the previous trajectories before redrawing with the new values
    for ax in axs:
        ax.cla()

    # create the continuous HGF, feed it the data and plot the trajectories
    HGF(
        n_levels=3,
        model_type="continuous",
        initial_mu={"1": 1.04, "2": 0.0, "3": 0.0},
        initial_pi={"1": 1e4, "2": 1e1, "3": 1e1},
        omega={"1": omega_1, "2": omega_2, "3": omega_3},
        rho={"1": 0.0, "2": 0.0, "3": 0.0},
        kappas={"1": 1.0, "2": 1.0},
        verbose=False,
    ).input_data(input_data=timeserie).plot_trajectories(axs=axs)
    plt.show()

interactive(children=(FloatSlider(value=-11.0, description='omega_1', max=5.0, min=-15.0), FloatSlider(value=-…