In [1]:
from torchkf import *
import torch
import numpy as np
import plotly.graph_objs as go
import plotly.express as px
import pykeos as pk

pyunicorn: Package netCDF4 could not be loaded. Some functionality in class Data might not be available!
pyunicorn: Package netCDF4 could not be loaded. Some functionality in class NetCDFDictionary might not be available!


In [162]:
dt = 1.


def _bump_readout(x):
    """Gaussian bump exp(-0.25 * (x[..., 0] - 12)^2) of the first component along the last axis.

    The result keeps a trailing length-1 axis (via ``unsqueeze(-1)``) so it has
    one output per leading index.
    """
    return torch.exp(-0.25 * (x[..., 0].unsqueeze(-1) - 12.) ** 2)


# Level 0 -- input generator. Assuming LinearTransform maps x -> A @ x, the
# forward matrix below makes x[0] advance by dt * x[1] each step while x[1]
# stays fixed; the single observation is a Gaussian bump of x[0].
# NOTE(review): if these covariances end up in float32 (torch's default dtype),
# exp(-128) ~ 2.6e-56 underflows to exactly 0, i.e. zero covariance matrices --
# confirm that a deterministic level is what is intended.
input_gen = GaussianSystem(
    state_dim=2,
    obs_dim=1,
    initial_state_mean=torch.tensor([0., 1.]),
    initial_state_cov=np.exp(-128) * torch.eye(2),
    fwd_transform=LinearTransform(torch.tensor([[1., dt],
                                                [0., 1.]])),
    process_noise_cov=np.exp(-128) * torch.eye(2),
    obs_transform=LinearizedTransform(_bump_readout),
    obs_noise_cov=np.exp(-6) * torch.eye(1),
)

# Level 1 -- convolution model: a 2-D linear state driven by a 1-D input and
# read out through four linear channels.
# Presumably the first two columns of the 2x3 forward matrix act on the state
# and the last column on the input -- TODO confirm against torchkf.
conv_fwd_matrix = torch.tensor([[-0.25,  1.00, 1.00],
                                [-0.50, -0.25, 0.00]])
conv_obs_matrix = torch.tensor([[0.1250,  0.1633],
                                [0.1250,  0.0676],
                                [0.1250, -0.0676],
                                [0.1250, -0.1633]])
conv_model = GaussianSystem(
    state_dim=2,
    obs_dim=4,
    input_dim=1,
    initial_state_mean=torch.zeros(2),
    initial_state_cov=np.exp(-64) * torch.eye(2),
    fwd_transform=LinearTransform(conv_fwd_matrix),
    process_noise_cov=np.exp(-16) * torch.eye(2),
    obs_transform=LinearTransform(conv_obs_matrix),
    obs_noise_cov=np.exp(-8) * torch.eye(4),
)

In [163]:
# Stack the levels top-down: input_gen first, conv_model second.
# Presumably each level's output drives the next level's input (input_gen
# emits one value and conv_model takes a 1-D input) -- confirm in torchkf.
systems = [
    input_gen,   # level 0
    conv_model,  # level 1
]
hdm = HierarchicalDynamicalModel(systems)

In [164]:
# Seed the global RNGs so the sampled trajectory (and everything filtered from
# it later) is reproducible under Restart & Run All.
# NOTE(review): assumes hdm.sample draws from torch's or numpy's global RNG.
torch.manual_seed(0)
np.random.seed(0)

n_steps = 32
traj = hdm.sample(n_steps)
# Level-1 observations: the data that is later passed to the deconvolution filter.
measurements = traj[1]['y']

px.line(y=traj[0]['y'][:, 0], title='input_gen output').show();
px.line(y=[traj[1]['x'][:, i] for i in range(2)], title='conv_model hidden states').show();
px.line(y=[traj[1]['y'][:, i] for i in range(4)], title='conv_model observations').show();



In [171]:
# Deconvolution model: replace the true input generator with a naive 1-D prior
# on the input, and reuse conv_model's forward/observation matrices for level 1
# with a looser initial-state covariance.
# NOTE(review): with identity dynamics and process noise exp(-128), the naive
# input is essentially held constant over time, which may stop the filter from
# tracking a time-varying input -- consider a larger process noise.
naive_input_sys = GaussianSystem(
    state_dim=1,
    obs_dim=1,
    fwd_transform=LinearTransform(torch.eye(1)),
    initial_state_mean=torch.tensor([0.]),
    initial_state_cov=np.exp(-1) * torch.eye(1),
    obs_transform=LinearTransform(torch.eye(1)),
    process_noise_cov=np.exp(-128) * torch.eye(1),
    obs_noise_cov=np.exp(-6) * torch.eye(1),
)
deconv_model = GaussianSystem(
    state_dim=2,
    obs_dim=4,
    input_dim=1,
    fwd_transform=LinearTransform(torch.tensor([[-0.25,  1.00, 1.00],
                                                [-0.50, -0.25, 0.00]])),
    obs_transform=LinearTransform(torch.tensor([[0.1250,  0.1633],
                                                [0.1250,  0.0676],
                                                [0.1250, -0.0676],
                                                [0.1250, -0.1633]])),
    initial_state_mean=torch.zeros(2),
    initial_state_cov=np.exp(-4) * torch.eye(2),
    process_noise_cov=np.exp(-16) * torch.eye(2),
    # Observation noise must be obs_dim x obs_dim = 4x4, matching conv_model.
    obs_noise_cov=np.exp(-8) * torch.eye(4),
)
deconv_systems = [naive_input_sys, deconv_model]
# Build the filter from the deconvolution levels, not from `systems` (the true
# generative model); otherwise the filter gets the ground-truth input dynamics.
hdm_deconv = HierarchicalDynamicalModel(deconv_systems)

In [172]:
# Prepend a length-1 leading axis (presumably the batch axis expected by
# filter) and run the hierarchical filter over the measured responses.
batched_measurements = measurements[None]
deconv_traj = hdm_deconv.filter(batched_measurements)

Filter: 100%|██████████| 32/32 [00:00<00:00, 215.39it/s]


In [173]:
def add_realized(fig, series, legendgroups=None):
    """Overlay realized (ground-truth) trajectories on ``fig`` as dashed lines.

    Args:
        fig: plotly figure to draw on (modified in place).
        series: sequence of 1-D sequences, one per component; component ``i``
            is drawn in the i-th colour of plotly's T10 palette.
        legendgroups: optional legendgroup name for each series; ``None``
            adds the traces without a legendgroup.

    Returns:
        ``fig``, for chaining.
    """
    palette = px.colors.qualitative.T10
    for i, values in enumerate(series):
        extra = {} if legendgroups is None else {'legendgroup': legendgroups[i]}
        fig.add_scatter(y=values, line_color=palette[i % len(palette)], line_dash='dash',
                        name='realized', opacity=0.7, **extra)
    return fig


def show_styled(fig, title, y_title='states (a.u.)'):
    """Apply the shared 600x600 white layout with a centred title and axis labels, then show ``fig``."""
    fig.update_layout(template='plotly_white',
                      title={'text': title, 'x': 0.5, 'xanchor': 'center'},
                      height=600, width=600)
    fig.update_yaxes(title_text=y_title)
    fig.update_xaxes(title_text='time')
    fig.show()


# Level 0: filtered prior predictions of the input vs. the realized input.
fig = plot_traj(deconv_traj[0]['y_prior'])
add_realized(fig, [traj[0]['y'][:, 0]])
show_styled(fig, 'causal states')

# Level 1: posterior hidden states vs. the realized hidden states.
fig = plot_traj(deconv_traj[1]['x_post'])
add_realized(fig, [traj[1]['x'][:, i] for i in range(2)],
             legendgroups=[f'x[{i}]' for i in range(2)])
show_styled(fig, 'hidden states')

# Level 1: prior predicted observations vs. the realized observations.
# NOTE(review): legendgroups reuse 'x[i]' names for these observation traces;
# confirm they match the group names plot_traj assigns.
fig = plot_traj(deconv_traj[1]['y_prior'])
add_realized(fig, [traj[1]['y'][:, i] for i in range(4)],
             legendgroups=[f'x[{i}]' for i in range(4)])
show_styled(fig, 'response', y_title='response (a.u.)')