In [1]:
from torchkf import *
import torch
import numpy as np

import plotly.graph_objs as go
import plotly.express as px
import pykeos as pk

pyunicorn: Package netCDF4 could not be loaded. Some functionality in class Data might not be available!
pyunicorn: Package netCDF4 could not be loaded. Some functionality in class NetCDFDictionary might not be available!


In [10]:
# Prior expectation of the ten model parameters. `lorenz` reads entries 0-6;
# the `g` lambdas used in this notebook read the last three (`P[-3:]`).
pE = torch.tensor([18., 18., 46.92, 2., 1., 2., 4., 1., 1., 1.])


def lorenz(x: torch.Tensor, v, P):
    """Evaluate the parameterised Lorenz flow at state ``x``.

    Args:
        x: state tensor; only ``x[0]``, ``x[1]`` and ``x[2]`` are read.
        v: unused here; kept so the signature is ``f(x, v, P)``.
        P: parameter tensor; it is squeezed and entries 0-6 are read.

    Returns:
        The three derivative components stacked along dim 0, divided by 64.
    """
    theta = P.squeeze()
    s0, s1, s2 = x[0], x[1], x[2]
    flow = [
        theta[0] * s1 - theta[1] * s0,
        theta[2] * s0 - theta[3] * s2 * s0 - theta[4] * s1,
        theta[5] * s0 * s1 - theta[6] * s2,
    ]
    return torch.stack(flow, dim=0) / 64.

# Seed the global RNGs so the synthetic data set is identical on every
# Restart & Run All. This assumes `generate` draws its random fluctuations
# from the torch and/or numpy global generators -- TODO confirm against torchkf.
torch.manual_seed(0)
np.random.seed(0)

# Single-level model used to *generate* data: Lorenz flow `f`, and an observer
# `g` that combines the states with the last three parameters (`P[-3:]`).
# NOTE(review): sv/sw/W/V are passed straight to torchkf; their exact meaning
# (smoothness / precision vs. variance) is not visible here -- confirm.
models = [GaussianModel(
    f=lorenz,
    g=lambda x, v, P: (P[-3:].T @ x).unsqueeze(-1),
    x=torch.tensor([0.9,0.8,30]).unsqueeze(-1), pE=pE, sv=1/8., sw=1/8.,
    n=3, W=torch.tensor([np.exp(16)] * 3), V=torch.tensor([np.exp(0)]),
)]

nT = 512  # length passed to `generate`
hdm = HierarchicalGaussianModel(*models)
gen = DEMInversion(hdm, states_embedding_order=3).generate(nT)
# Keep index 0 on axes 1 and 3 of `gen.v`; this 2-D slice is the data that the
# inversion cells below are run on.
y   = gen.v[:, 0, :, 0]

In [11]:
# One line per row of the transposed `gen.x` slice, followed by one line per
# row of the transposed data `y`.
px.line(
    y=list(gen.x[:, 0, :, 0].T) + list(y.T),
)

In [31]:
# Model used for *inversion*: same flow and observer as the generating cell,
# but started from x = (1, 1, 16) and with V = np.exp(2) (the generating cell
# uses np.exp(0)).
# NOTE(review): np.exp(-128) is roughly 2.6e-56; if the default torch dtype is
# float32 here, `pC` is exactly zero -- confirm torchkf accepts that.
models = [
    GaussianModel(
        f=lorenz,
        g=lambda x, v, P: (P[-3:].T @ x).unsqueeze(-1),
        x=torch.tensor([1., 1., 16.]).unsqueeze(-1),
        pE=pE,
        sv=1/8.,
        sw=1/8.,
        n=3,
        W=torch.tensor([np.exp(16)] * 3),
        V=torch.tensor([np.exp(2)]),
        pC=torch.ones(pE.shape) * np.exp(-128),
    )
]
hdm = HierarchicalGaussianModel(*models)
deminv = DEMInversion(hdm, states_embedding_order=4)
# Invert the generated data `y`; the result is read as `dec.qU.x` / `dec.qU.y`
# by the plotting cells that follow.
dec = deminv.run(y, nD=1, nE=1, nM=1)

timestep:   0%|          | 0/512 [00:00<?, ?it/s]

In [36]:
# Difference between the inverted `dec.qU.x[:, 0]` and the generated
# `gen.x[:, 0, :, 0]`, drawn as one line per row of the transposed difference.
px.line(
    y=list((dec.qU.x[:, 0] - gen.x[:, 0, :, 0]).T),
)

In [32]:
# First figure: inverted `dec.qU.x` traces followed by the generated `gen.x`
# traces.
px.line(
    y=list(dec.qU.x[:, 0, :].T) + list(gen.x[:, 0, :, 0].T),
).show()
# Second figure: `dec.qU.y` traces followed by the data `y` that were inverted.
px.line(
    y=list(dec.qU.y[:, 0, :].T) + list(y[:, :].T),
).show()

In [14]:
# Attractor reconstruction with DEM: invert the same data `y` from several
# initial states and draw the first two components of each inverted
# trajectory against the generated one.
fig = go.Figure()
# NOTE (from the original author): take values from the SPM toolboxes.
for i, initial_states in enumerate(([12,13,16], [8,0,30], [0, 10, 24], [5, 5,8], [17,2,35],  [5, 15,16])):
    # One colour and one legend group per initial state, so a trajectory and
    # its start marker toggle together (same convention as the EKF cell).
    color = px.colors.qualitative.Pastel[i + 1]
    group = f'{i+1}'

    models = [GaussianModel(
        f=lorenz,
        g=lambda x, v, P: (P[-3:].T @ x).unsqueeze(-1),
        # torch.tensor(..., dtype=torch.float64) builds the same float64 tensor
        # as the legacy torch.DoubleTensor(...) constructor.
        x=torch.tensor(initial_states, dtype=torch.float64).unsqueeze(-1), pE=pE, sv=1/8., sw=1/8.,
        n=3, W=torch.tensor([np.exp(16)] * 3), V=torch.tensor([np.exp(0)]),
        pC=torch.ones(pE.shape) * np.exp(-128),
    )]
    hdm    = HierarchicalGaussianModel(*models)
    deminv = DEMInversion(hdm, states_embedding_order=6)
    dec    = deminv.run(y, nD=4, nE=1, nM=1)

    # Inverted trajectory, projected on components 0 and 1.
    fig.add_scatter(x=dec.qU.x[:, 0, 0], y=dec.qU.x[:, 0, 1], line_color=color, opacity=0.7,
                    name=f'{initial_states}', legendgroup=group)
    # Marker at the first two entries of the model's `x`.
    fig.add_scatter(mode='markers', x=[hdm[0].x[0].item()], y=[hdm[0].x[1].item()], marker_color=color, marker_size=10,
                    legendgroup=group, showlegend=False)

# Generated trajectory and the initial state it was generated from.
fig.add_scatter(x=gen.x[:, 0, 0, 0], y=gen.x[:, 0, 1, 0], line_color='black', line_dash='dash', name='Reference')
fig.add_scatter(mode='markers', x=[.9], y=[.8], marker_color='black', marker_size=10, showlegend=False)
fig.update_layout(template='simple_white', width=600, height=600,
                  xaxis_title='x[0]', yaxis_title='x[1]',
                  title={'text': 'attractor reconstruction (DEM)', 'x': 0.5, 'xanchor': 'center'})

timestep:   0%|          | 0/512 [00:00<?, ?it/s]

timestep:   0%|          | 0/512 [00:00<?, ?it/s]

timestep:   0%|          | 0/512 [00:00<?, ?it/s]

timestep:   0%|          | 0/512 [00:00<?, ?it/s]

timestep:   0%|          | 0/512 [00:00<?, ?it/s]

timestep:   0%|          | 0/512 [00:00<?, ?it/s]

In [15]:
# Restrict both axes to the same window. Each update call returns the figure,
# so the chained expression is still the figure and is rendered as cell output.
fig.update_xaxes(range=(-20, 20)).update_yaxes(range=(-20, 20))

In [None]:
# Attractor reconstruction with the state-space filter (figure title: EKF):
# one filter run per initial state, each drawn on a shared figure.
# NOTE(review): this cell has never been executed (In [None]) and relies on
# `state_space`, `model_parameters`, `GaussianStateSpaceModel` and `traj`,
# none of which is defined in the visible part of the notebook -- confirm they
# exist before this cell on a fresh kernel.
fig = go.Figure()
# y = traj['y'][None]

for i, initial_states in enumerate(([12,13,16], [8,0,30], [0, 10, 24], [5, 5,8], [17,2,35],  [5, 15,16])):
    # NOTE(review): this rebinds the notebook-level `y` (the data the DEM cells
    # invert) and draws a new sample on every iteration, so each run filters a
    # different data set -- confirm this is intended rather than the
    # commented-out `traj['y']` assignment near the top of this cell.
    y = state_space.sample(512)['y'][None]
    
    # `model_parameters` is modified in place; only the initial state mean
    # changes between runs.
    model_parameters['initial_state_mean'] = torch.tensor([float(v) for v in initial_states])
    fit_model = GaussianStateSpaceModel(**model_parameters)
    traj_filt = fit_model.filter(y, backward_pass=True)
    
    # Plot the first two components of the `x_prev` mean for batch entry 0,
    # then mark the initial state in the same colour and legend group.
    pk.SysWrapper(ts=traj_filt['x_prev'].mean[0, :, :2]).plot(fig=fig, show=False, 
                                                              line_color=px.colors.qualitative.Pastel[i+1], name=f'{initial_states}', legendgroup=f'{i+1}', opacity=0.7);
    fig.add_scatter(x=[initial_states[0]],y=[initial_states[1]], mode='markers', 
                    marker_color=px.colors.qualitative.Pastel[i+1], marker_size=10, legendgroup=f'{i+1}', showlegend=False)
    
# pk.SysWrapper(ts=traj_filt['x_prev'].mean[0, :, :2]).plot(fig=fig, show=False);

# Reference trajectory (first two components of `traj['x'].mean`) and a marker
# at (0.9, 0.8).
pk.SysWrapper(ts=traj['x'].mean[:, :2]).plot(fig=fig, show=False, line_color='black', line_dash='dash', name=f'Reference', opacity=0.8);
fig.add_scatter(x=[0.9],y=[0.8], mode='markers', marker_color='black', marker_size=10, showlegend=False)

# Common axis window and a centred title.
fig.update_xaxes(range=(-20,24))
fig.update_yaxes(range=(-20,24))
fig.update_layout(height=600, width=600, template='simple_white', title={'text':'attractor reconstruction (EKF)',
         'x':0.5,
         'xanchor': 'center'})
fig.show()

In [None]:
# Ten passes that alternate a filter run (with backward pass) with refitting
# `A`, `H` and the state-space parameters on the resulting trajectory; each
# pass plots two trajectory entries and records the complete-data likelihood.
# NOTE(review): looks like an EM-style loop -- confirm. This cell has never
# been executed (In [None]); `state_space`, `measurements`, `A`, `H`,
# `plot_traj` and `ll` are not defined in the visible part of the notebook.
for i in range(10): 
    # `measurements[None, ...]` adds a leading axis before filtering.
    trajectory = state_space.filter(measurements[None, ...], backward_pass=True)
    # Fit `A` on consecutive pairs of the `x_backward` mean (steps [:-1] -> [1:]).
    A.fit(trajectory['x_backward'].mean[:, :-1], trajectory['x_backward'].mean[:, 1:])
    # Fit `H` from the `x_backward` mean to the measurements.
    H.fit(trajectory['x_backward'].mean, measurements[None])
    state_space.fit_params(measurements[None], trajectory)    
    # These plots read `trajectory` after `fit_params` has been called with it.
    plot_traj(trajectory['x_backward']).show()
    plot_traj(trajectory['y_prior']).show()
    
    # Track the likelihood per pass and print the latest value.
    ll.append(state_space.complete_data_likelihood(measurements[None], trajectory))
    print(ll[-1])