In [1]:
# torchkf is star-imported first so the explicit imports below take precedence on any name clash.
# NOTE(review): presumably the source of GaussianModel / HierarchicalGaussianModel / DEMInversion;
# prefer importing those names explicitly.
from torchkf import *
from pathlib import Path

import numpy as np
import torch

import plotly.express as px
import plotly.graph_objs as go
from plotly.subplots import make_subplots
import pykeos as pk

pyunicorn: Package netCDF4 could not be loaded. Some functionality in class Data might not be available!
pyunicorn: Package netCDF4 could not be loaded. Some functionality in class NetCDFDictionary might not be available!


In [7]:
# Parameter vector handed to GaussianModel as `pE`.
# Entries 0-6 are the flow coefficients read by `lorenz` (P[0] ... P[6]);
# the last three are the observation weights read by `obs` (P[-3:]).
pE = np.array([
    18., 18., 46.92, 2., 1., 2., 4.,  # flow coefficients
    1., 1., 1.,                       # observation weights
])

def lorenz(x, v, P):
    """Lorenz-type flow used as the hidden-state equation `f` of the generative model.

    Parameters
    ----------
    x : indexable of length 3
        Current hidden state (in this notebook a (3, 1) column, so each
        component is a length-1 array).
    v : any
        Unused here; accepted because the model callback takes (x, v, P).
    P : indexable
        Parameter vector; only P[0] ... P[6] are read.

    Returns
    -------
    np.ndarray
        The three derivatives stacked along the first axis, divided by 128.
    """
    first, second, third = x[0], x[1], x[2]

    d_first = P[0] * second - P[1] * first
    d_second = P[2] * first - P[3] * third * first - P[4] * second
    d_third = P[5] * first * second - P[6] * third

    flow = np.array([d_first, d_second, d_third])
    return flow / 128

def obs(x, v, P):
    """Observation function `g`: project the hidden state onto the last three parameters.

    Parameters
    ----------
    x : array-like
        Hidden state as a (3, 1) column (as built for GaussianModel in this notebook).
    v : any
        Unused here; accepted because the model callback takes (x, v, P).
    P : array-like
        Parameter vector; its last three entries are the observation weights.

    Returns
    -------
    np.ndarray
        ``x.T @ P[-3:]`` wrapped in an array (shape (1,) for a (3, 1) state).
    """
    weights = P[-3:]
    projected = x.T @ weights
    return np.array(projected)

# Generative model: one Gaussian level whose hidden states follow `lorenz` and whose
# response is `obs`. The generated states/responses are what every later cell inverts and plots.
# NOTE(review): W and V are built as exp(16) per state and exp(0) for the response; presumably
# precisions (SPM convention) — confirm against GaussianModel.

# Seed before generating so the data (and the manuscript figures derived from it) are reproducible.
# Both RNGs are seeded because it is not visible here which one `generate` draws from.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

models = [GaussianModel(
    f=lorenz,
    g=obs,
    x=np.array([0.9, 0.8, 30])[:, None], pE=pE, sv=1/8., sw=1/8., m=1,
    n=3, W=np.array([[np.exp(16)] * 3]), V=np.array([[np.exp(0)]]),
)]

nT = 1024  # number of generated time steps
hdm = HierarchicalGaussianModel(*models)
gen = DEMInversion(hdm, states_embedding_order=3).generate(nT)
y   = gen.v[:, 0, :1, 0]  # first channel of the generated `v`; used as the data to invert

Compiling derivatives, it might take some time... Done. 


  0%|          | 0/1024 [00:00<?, ?it/s]

In [8]:
# Generated data over time: one trace per hidden state of `gen`, followed by the observed response `y`.
px.line(y=list(gen.x[:, 0, :, 0].T) + list(y.T))

In [9]:
# Invert the generated data `y`, starting the states from a deliberately wrong guess.
hdm[0]['x']  = np.array([[12., 13., 16.]])  # float initial guess; the data were generated from (0.9, 0.8, 30)
# Presumably the prior parameter covariance (SPM naming): exp(-128) on the diagonal
# effectively pins every parameter to its prior value pE.
hdm[0]['pC'] = np.eye(pE.size) * np.exp(-128)
# Assign V absolutely (its construction value exp(0) scaled by exp(-4)) rather than `*= np.exp(-4)`:
# the in-place update compounded on every re-run of this cell (exp(-4), exp(-8), ...).
hdm[0]['V']  = np.array([[np.exp(0)]]) * np.exp(-4)

deminv = DEMInversion(hdm, states_embedding_order=12)
dec    = deminv.run(y, nD=1, nE=1, nM=1)

timestep:   0%|          | 0/1024 [00:00<?, ?it/s]

In [10]:
# Estimation error over time: DEM state estimates (dec.qU.x) minus the generated hidden
# states (gen.x), one trace per state component.
px.line(y=list((dec.qU.x[:, 0] - gen.x[:, 0, :, 0]).T))

In [11]:
# Overlay the estimated (dec) and generated (gen) state trajectories with pykeos,
# dropping the first 128 samples of each (presumably the initial transient).
burn_in = 128
fig = pk.SysWrapper(dec.qU.x[burn_in:, 0, :]).plot(show=False)
realized_sys = pk.SysWrapper(gen.x[burn_in:, 0, :, 0])
realized_sys.plot(fig=fig, show=False)
fig.update_layout(template='simple_white')

In [28]:
# Two time-series figures, each plotting the DEM estimate first and then its reference:
#   1) state estimates (dec.qU.x) vs. generated hidden states (gen.x);
#   2) response estimates (dec.qU.y) vs. the inverted data `y`.
comparisons = (
    (dec.qU.x[:, 0, :], gen.x[:, 0, :, 0]),
    (dec.qU.y[:, 0, :], y[:, :]),
)
for estimate, reference in comparisons:
    px.line(y=list(estimate.T) + list(reference.T)).show()

In [29]:
# Phase-plane view of the DEM estimate (`dec`) against the generated trajectory (`gen`).
# Columns: (x, y) and (x, z) planes; rows: full attractor (top) and a zoom near the origin (bottom).
# Markers: the true initial state used to generate `gen` (black) and the inversion's initial
# guess hdm[0].x (coloured).
color   = px.colors.qualitative.Pastel[2]
true_x0 = (.9, .8, 30.)   # initial state given to GaussianModel for data generation
panels  = ((1, 1), (2, 2))  # (subplot column, index of the state on the vertical axis); x is always state 0
axis_layout = {           # (row, col) -> (x range, y range, y-axis title)
    (1, 1): ((-16, 16), (-20, 20), '$y$'),
    (1, 2): ((-16, 16), (5, 40),   '$z$'),
    (2, 1): ((-5, 5),   (-5, 5),   '$y$'),
    (2, 2): ((-5, 5),   (12, 27),  '$z$'),
}
fig_dir = Path('/home/yop/Images/manuscript/DEM')  # NOTE(review): machine-specific output directory

fig = make_subplots(cols=2, rows=2)
for row in (1, 2):
    # Trace order within each subplot: estimate, realized, true start, estimated start.
    for col, k in panels:
        fig.add_scatter(x=dec.qU.x[:, 0, 0], y=dec.qU.x[:, 0, k], line_color=color,
                        line_width=1, opacity=1, legendgroup='estimated', name='Estimated',
                        showlegend=(row == 1 and col == 1), row=row, col=col)
    for col, k in panels:
        # Same opacity in both columns so the realized trajectory looks identical in each panel.
        fig.add_scatter(x=gen.x[:, 0, 0, 0], y=gen.x[:, 0, k, 0], line_color='black', line_dash='dash',
                        line_width=1, legendgroup='realized', name='Realized',
                        showlegend=(row == 1 and col == 1), opacity=0.7, row=row, col=col)
    for col, k in panels:
        fig.add_scatter(mode='markers', x=[true_x0[0]], y=[true_x0[k]], marker_color='black',
                        showlegend=False, marker_size=10, row=row, col=col)
    for col, k in panels:
        fig.add_scatter(mode='markers', x=[hdm[0].x[0].item()], y=[hdm[0].x[k].item()],
                        showlegend=False, marker_color=color, marker_size=10, row=row, col=col)

fig.update_xaxes(title_text='$x$', title_standoff=0, mirror=True)
fig.update_yaxes(title_standoff=0, mirror=True)
for (row, col), (x_range, y_range, y_title) in axis_layout.items():
    fig.update_xaxes(range=x_range, row=row, col=col)
    fig.update_yaxes(range=y_range, title_text=y_title, row=row, col=col)

fig.update_layout(template='simple_white', width=800, height=800,
                  legend=dict(orientation="h", yanchor="bottom", y=1.01, xanchor="left", x=0.01))
fig.show()

# Make sure the target directory exists before exporting.
fig_dir.mkdir(parents=True, exist_ok=True)
for ext in ('png', 'pdf'):
    fig.write_image(fig_dir / f'lorenz_inversion_2d_zoom.{ext}')


In [36]:
# Re-run the inversion from several initial state guesses to check how sensitive the estimate
# is to the starting point. Results go into `trajs`, in the same order as `initial_guesses`.
initial_guesses = ([12, 13, 16], [8, 0, 30], [0, 10, 24], [5, 5, 8], [17, 2, 35], [5, 15, 16])

# Restored in `finally` so cells that read hdm[0].x afterwards do not silently see the last guess.
x_guess_before = hdm[0]['x']
trajs = []
try:
    for guess in initial_guesses:
        hdm[0]['x'] = torch.tensor(guess, dtype=torch.float64)
        # NOTE(review): embedding order 4 here vs 12 in the single-start inversion — confirm intended.
        # Local call chain (no `deminv`/`dec` rebinding) so the single-start results stay intact.
        trajs.append(DEMInversion(hdm, states_embedding_order=4).run(y, nD=1, nE=1, nM=1))
finally:
    hdm[0]['x'] = x_guess_before

timestep:   0%|          | 0/1024 [00:00<?, ?it/s]

timestep:   0%|          | 0/1024 [00:00<?, ?it/s]

timestep:   0%|          | 0/1024 [00:00<?, ?it/s]

timestep:   0%|          | 0/1024 [00:00<?, ?it/s]

timestep:   0%|          | 0/1024 [00:00<?, ?it/s]

timestep:   0%|          | 0/1024 [00:00<?, ?it/s]

In [57]:
# NOTE(review): scratch sanity check; nothing in this notebook references its output — candidate for deletion.
np.full(6, 1.)

array([1., 1., 1., 1., 1., 1.])

In [56]:
# Phase-plane view of the inversions started from several initial guesses (`trajs`, one entry per
# guess, same order as below) against the generated trajectory (`gen`).
# Columns: (x, y) and (x, z) planes; rows: full attractor (top) and a zoom near the origin (bottom).
initial_guesses = ([12, 13, 16], [8, 0, 30], [0, 10, 24], [5, 5, 8], [17, 2, 35], [5, 15, 16])
excluded = {4}             # guess [17, 2, 35] is left out of this figure — NOTE(review): reason not recorded
true_x0  = (.9, .8, 30.)   # initial state given to GaussianModel for data generation
panels   = ((1, 1), (2, 2))  # (subplot column, index of the state on the vertical axis); x is always state 0
axis_layout = {            # (row, col) -> (x range, y range, y-axis title)
    (1, 1): ((-16, 16), (-20, 20), '$y$'),
    (1, 2): ((-16, 16), (5, 40),   '$z$'),
    (2, 1): ((-5, 5),   (-5, 5),   '$y$'),
    (2, 2): ((-5, 5),   (12, 27),  '$z$'),
}
fig_dir = Path('/home/yop/Images/manuscript/DEM')  # NOTE(review): machine-specific output directory

# (original index, guess, inversion result); the index keeps colours stable when a guess is excluded.
shown = [(j, guess, trajs[j]) for j, guess in enumerate(initial_guesses) if j not in excluded]

fig = make_subplots(cols=2, rows=2)

# 1) Estimated trajectories, one colour per initial guess.
for j, guess, traj_dec in shown:
    color = px.colors.qualitative.Pastel[j + 1]
    for row in (1, 2):
        for col, k in panels:
            fig.add_scatter(x=traj_dec.qU.x[:, 0, 0], y=traj_dec.qU.x[:, 0, k], line_color=color,
                            line_width=1, opacity=.8, legendgroup=f'{guess}', name=f'{guess}',
                            showlegend=(row == 1 and col == 1), row=row, col=col)

# 2) Generated ("realized") trajectory and its true initial state, drawn over the estimates.
for row in (1, 2):
    for col, k in panels:
        # Same opacity in both columns so the realized trajectory looks identical in each panel.
        fig.add_scatter(x=gen.x[:, 0, 0, 0], y=gen.x[:, 0, k, 0], line_color='black', line_dash='dash',
                        line_width=1, legendgroup='realized', name='Realized',
                        showlegend=(row == 1 and col == 1), opacity=0.7, row=row, col=col)
        fig.add_scatter(mode='markers', x=[true_x0[0]], y=[true_x0[k]], marker_color='black',
                        showlegend=False, marker_size=10, row=row, col=col)

# 3) First state estimate of each inversion, drawn on top and grouped with its trajectory.
for j, guess, traj_dec in shown:
    color = px.colors.qualitative.Pastel[j + 1]
    x_start = traj_dec.qU.x[0, 0]
    for row in (1, 2):
        for col, k in panels:
            fig.add_scatter(mode='markers', x=[x_start[0].item()], y=[x_start[k].item()],
                            legendgroup=f'{guess}', showlegend=False, marker_color=color,
                            marker_size=10, row=row, col=col)

fig.update_xaxes(title_text='$x$', title_standoff=0, mirror=True)
fig.update_yaxes(title_standoff=0, mirror=True)
for (row, col), (x_range, y_range, y_title) in axis_layout.items():
    fig.update_xaxes(range=x_range, row=row, col=col)
    fig.update_yaxes(range=y_range, title_text=y_title, row=row, col=col)

fig.update_layout(template='simple_white', width=800, height=800,
                  legend=dict(orientation='h', yanchor="bottom", y=1.01, xanchor="left", x=0.01))
fig.show()

# Make sure the target directory exists before exporting.
fig_dir.mkdir(parents=True, exist_ok=True)
for ext in ('png', 'pdf'):
    fig.write_image(fig_dir / f'lorenz_inversion_2d_zoom_multiple.{ext}')


In [33]:
# (x, y) phase plane of every inversion in `trajs` (one entry per initial guess, same order as
# below) against the generated trajectory. TODO: compare with values from the SPM toolbox.
initial_guesses = ([12, 13, 16], [8, 0, 30], [0, 10, 24], [5, 5, 8], [17, 2, 35], [5, 15, 16])

fig = go.Figure()
for i, (guess, traj_dec) in enumerate(zip(initial_guesses, trajs)):
    color = px.colors.qualitative.Pastel[i + 1]
    # Each guess gets its own inversion and its own start marker. Reading the globals `dec` and
    # hdm[0].x here drew the same trajectory and the same marker for every colour.
    fig.add_scatter(x=traj_dec.qU.x[:, 0, 0], y=traj_dec.qU.x[:, 0, 1], line_color=color, opacity=0.7,
                    name=f'{guess}', legendgroup=f'{guess}')
    fig.add_scatter(mode='markers', x=[guess[0]], y=[guess[1]], marker_color=color, marker_size=10,
                    legendgroup=f'{guess}', showlegend=False)

fig.add_scatter(x=gen.x[:, 0, 0, 0], y=gen.x[:, 0, 1, 0], line_color='black', line_dash='dash', name='Realized')
fig.add_scatter(mode='markers', x=[.9], y=[.8], marker_color='black', marker_size=10, showlegend=False)
fig.update_xaxes(range=(-20, 20))
fig.update_yaxes(range=(-20, 20))
fig.update_layout(template='simple_white', width=600, height=600)

In [None]:
# EKF attractor reconstruction in the (x, y) plane, one filter run per initial state mean.
# NOTE(review): never executed in this notebook (In [None]); `state_space`, `model_parameters`,
# `GaussianStateSpaceModel` and `traj` are not defined in any cell here — TODO confirm their origin.
initial_guesses = ([12, 13, 16], [8, 0, 30], [0, 10, 24], [5, 5, 8], [17, 2, 35], [5, 15, 16])

fig = go.Figure()
for i, initial_states in enumerate(initial_guesses):
    color = px.colors.qualitative.Pastel[i + 1]
    group = f'{i + 1}'

    # Local name: assigning to `y` here overwrote the observations inverted by the DEM cells.
    # NOTE(review): a fresh sample is drawn per guess (the alternative `traj['y'][None]`, i.e. the
    # data behind the 'Reference' curve, was commented out) — confirm which is intended.
    y_sample = state_space.sample(512)['y'][None]

    # Per-guess parameter set instead of mutating the shared `model_parameters` dict,
    # which previously stayed set to the last guess after the loop.
    fit_model = GaussianStateSpaceModel(**{
        **model_parameters,
        'initial_state_mean': torch.tensor([float(v) for v in initial_states]),
    })
    traj_filt = fit_model.filter(y_sample, backward_pass=True)

    pk.SysWrapper(ts=traj_filt['x_prev'].mean[0, :, :2]).plot(
        fig=fig, show=False, line_color=color, name=f'{initial_states}', legendgroup=group, opacity=0.7)
    fig.add_scatter(x=[initial_states[0]], y=[initial_states[1]], mode='markers',
                    marker_color=color, marker_size=10, legendgroup=group, showlegend=False)

pk.SysWrapper(ts=traj['x'].mean[:, :2]).plot(fig=fig, show=False, line_color='black', line_dash='dash',
                                               name='Reference', opacity=0.8)
fig.add_scatter(x=[0.9], y=[0.8], mode='markers', marker_color='black', marker_size=10, showlegend=False)

fig.update_xaxes(range=(-20, 24))
fig.update_yaxes(range=(-20, 24))
fig.update_layout(height=600, width=600, template='simple_white',
                  title={'text': 'attractor reconstruction (EKF)', 'x': 0.5, 'xanchor': 'center'})
fig.show()

In [None]:
# Ten rounds of: filter with a backward pass, refit A (consecutive backward means) and H
# (backward means -> measurements), refit the state-space parameters, plot, and log the
# complete-data likelihood. Looks EM-like — NOTE(review): confirm.
# NOTE(review): never executed here (In [None]); `state_space`, `measurements`, `A`, `H`,
# `plot_traj` and `ll` are not defined in this notebook — TODO confirm their origin.
batched = measurements[None]  # same as measurements[None, ...]: prepend a batch axis
for em_round in range(10):
    trajectory = state_space.filter(batched, backward_pass=True)
    backward_mean = trajectory['x_backward'].mean

    A.fit(backward_mean[:, :-1], backward_mean[:, 1:])
    H.fit(backward_mean, batched)
    state_space.fit_params(batched, trajectory)

    plot_traj(trajectory['x_backward']).show()
    plot_traj(trajectory['y_prior']).show()

    ll.append(state_space.complete_data_likelihood(batched, trajectory))
    print(ll[-1])