In [None]:
!pip install git+https://github.com/johmedr/dempy.git
!pip install git+https://github.com/johmedr/torchkf.git

In [139]:
from dempy import *
import torch
import numpy as np
import corrts as ct

import plotly.graph_objs as go
import plotly.express as px
from plotly.subplots import make_subplots

import torchkf as tk

In [147]:
GLOBAL = dotdict()
GLOBAL.dt = 1/64.
GLOBAL.nT = 512
GLOBAL.sv = 1/32.
GLOBAL.sw = 1/32.
GLOBAL.V  = np.exp(4) 
GLOBAL.W  = np.exp(8) 
GLOBAL.x0 = [0.9, 0.8, 30]
GLOBAL.x1 = [1,1,16]
GLOBAL.xs = [[12,13,16], [8,0,30], [0, 10, 24], [5, 5,8], [17,2,35], [5, 15,16]]

GLOBAL.states_order = 6

In [233]:
def lorenz(x, v, P):
    x0 = 18 * (x[1] - x[0])
    x1 = 46.92 * x[0] - 2 * x[2] * x[0] - x[1]
    x2 = 2 * x[0] * x[1] - 4 * x[2]
    
    return np.array([x0, x1, x2]) * GLOBAL.dt


def obs(x, v, P): 
    return np.array([x[0] + x[1] + x[2]])


models = [GaussianModel(
    n=3, 
    f=lorenz, 
    g=obs, 
    x=np.array(GLOBAL.x0), 
    sv=GLOBAL.sv, 
    sw=GLOBAL.sw, 
    pE=np.array([1]),
    pC=np.array([0]),
    W=np.array([GLOBAL.W] * 3), 
    V=np.array([GLOBAL.V]),
    xP = np.eye(3),
    
)]

hdm = HierarchicalGaussianModel(*models, use_numerical_derivatives=False)

Compiling derivatives, it might take some time... 
  Compiling f... f() ok. (compiled in 0.01s)
  Compiling g... g() ok. (compiled in 0.00s)
Done. 


In [234]:
gen = DEMInversion(hdm, states_embedding_order=3).generate(GLOBAL.nT)
y   = gen.v[:, 0, :1]

plot_dem_generate(hdm, gen);

  0%|          | 0/512 [00:00<?, ?it/s]

In [235]:
decmodel = hdm.copy()

decmodel[0]['x']  = np.array(GLOBAL.x1)

deminv = DEMInversion(decmodel, states_embedding_order=GLOBAL.states_order)
dec    = deminv.run(y, nE=1, nD=1, nM=1)

timestep:   0%|          | 0/512 [00:00<?, ?it/s]

In [236]:
plot_dem_states(decmodel, dec, gen);

In [237]:
from plotly.subplots import make_subplots
fig   = make_subplots(cols=2, rows=2)#, column_titles=['$(x, y)$', '$(x,z)$'])
color = px.colors.qualitative.Pastel[2]

for i in range(2): 
    fig.add_scatter(x=dec.qU.x[:, 0,0], y=dec.qU.x[:, 0,1], line_color=color, 
                    line_width=1, opacity=1, legendgroup='estimated', name='Estimated', showlegend=i==0, row=i+1, col=1)
    fig.add_scatter(x=dec.qU.x[:, 0,0], y=dec.qU.x[:, 0, 2], line_color=color, 
                    line_width=1, opacity=1, legendgroup='estimated', showlegend=False, row=i+1, col=2)

    fig.add_scatter(x=gen.x[:, 0, 0], y=gen.x[:, 0, 1], line_color='black', line_dash='dash', 
                    line_width=1, legendgroup='realized', name='Realized', showlegend=i==0, opacity=0.7, row=i+1, col=1)
    fig.add_scatter(x=gen.x[:, 0, 0], y=gen.x[:, 0, 2], line_color='black', line_dash='dash', 
                    line_width=1, legendgroup='realized', showlegend=False, row=i+1, col=2)
    
    fig.add_scatter(mode='markers', x=[.9], y=[.8], marker_color='black', showlegend=False, marker_size=10, row=i+1, col=1)
    fig.add_scatter(mode='markers', x=[.9], y=[30.], marker_color='black', showlegend=False, marker_size=10, row=i+1, col=2)

    fig.add_scatter(mode='markers', x=[hdm[0].x[0]], y=[hdm[0].x[1]], showlegend=False, marker_color=color, marker_size=10, row=i+1, col=1)
    fig.add_scatter(mode='markers', x=[hdm[0].x[0]], y=[hdm[0].x[2]], showlegend=False, marker_color=color, marker_size=10, row=i+1, col=2)

fig.update_xaxes(title_text='$x$', title_standoff=0, mirror=True)
fig.update_yaxes(title_standoff=0, mirror=True)

fig.update_xaxes(range=(-16, 16), row=1, col=1)
fig.update_yaxes(range=(-20, 20), title_text='$y$', row=1, col=1)

fig.update_xaxes(range=(-16, 16), row=1, col=2)
fig.update_yaxes(range=(5, 40), title_text='$z$', row=1, col=2)

fig.update_xaxes(range=(-5, 5), row=2, col=1)
fig.update_yaxes(range=(-5, 5), title_text='$y$', row=2, col=1)

fig.update_xaxes(range=(-5, 5), row=2, col=2)
fig.update_yaxes(range=(12, 27), title_text='$z$', row=2, col=2)
fig.update_layout(template='simple_white', width=800, height=800, 
                 legend=dict(orientation="h",
                    yanchor="bottom",
                    y=1.01,
                    xanchor="left",
                    x=0.01)
)
fig.show()
# fig.write_image('/home/yop/Images/manuscript/DEM/lorenz_inversion_2d_zoom.png')
# fig.write_image('/home/yop/Images/manuscript/DEM/lorenz_inversion_2d_zoom.pdf')


In [238]:
trajs_dem = []
for j, initial_states in enumerate(GLOBAL.xs):
    hdm[0]['x'] = np.array(initial_states)
    deminv = DEMInversion(hdm, states_embedding_order=GLOBAL.states_order)
    dec    = deminv.run(y, nD=1, nE=1, nM=1)
    trajs_dem.append(dec)

timestep:   0%|          | 0/512 [00:00<?, ?it/s]

timestep:   0%|          | 0/512 [00:00<?, ?it/s]

timestep:   0%|          | 0/512 [00:00<?, ?it/s]

timestep:   0%|          | 0/512 [00:00<?, ?it/s]

timestep:   0%|          | 0/512 [00:00<?, ?it/s]

timestep:   0%|          | 0/512 [00:00<?, ?it/s]

In [239]:

def plot_trajs(trajs, get_x, show=False):
    fig   = make_subplots(cols=2, rows=2)#, column_titles=['$(x, y)$', '$(x,z)$'])
    for j, initial_states in enumerate(GLOBAL.xs):
        if j == 4: continue 
        dec = trajs[j]

        x = get_x(dec)

        color = px.colors.qualitative.Pastel[j+1]

        for i in range(2): 
            fig.add_scatter(x=x[:, 0], y=x[:, 1], line_color=color, 
                            line_width=1, opacity=.8, legendgroup=f'{initial_states}', name=f'{initial_states}', showlegend=i==0, row=i+1, col=1)
            fig.add_scatter(x=x[:, 0], y=x[:, 2], line_color=color, 
                            line_width=1, opacity=.8, legendgroup=f'{initial_states}', showlegend=False, row=i+1, col=2)

    for i in range(2): 
        fig.add_scatter(x=gen.x[:, 0, 0], y=gen.x[:, 0, 1], line_color='black', line_dash='dash', 
                        line_width=1, legendgroup='realized', name='Realized', showlegend=i==0, opacity=0.7, row=i+1, col=1)
        fig.add_scatter(x=gen.x[:, 0, 0], y=gen.x[:, 0, 2], line_color='black', line_dash='dash', 
                        line_width=1, legendgroup='realized', showlegend=False, row=i+1, col=2)

        fig.add_scatter(mode='markers', x=[.9], y=[.8], marker_color='black', showlegend=False, marker_size=10, row=i+1, col=1)
        fig.add_scatter(mode='markers', x=[.9], y=[30.], marker_color='black', showlegend=False, marker_size=10, row=i+1, col=2)

    for j, initial_states in enumerate(GLOBAL.xs):
        if j == 4: continue 

        dec = trajs[j]
        x = get_x(dec)
        
        color = px.colors.qualitative.Pastel[j+1]

        for i in range(2): 

            fig.add_scatter(mode='markers', x=[x[0,0].item()], y=[x[0,1].item()], legendgroup=f'{initial_states}',
                            showlegend=False, marker_color=color, marker_size=10, row=i+1, col=1)
            fig.add_scatter(mode='markers', x=[x[0,0].item()], y=[x[0,2].item()], legendgroup=f'{initial_states}',
                            showlegend=False, marker_color=color, marker_size=10, row=i+1, col=2)

    fig.update_xaxes(title_text='$x$', title_standoff=0, mirror=True)
    fig.update_yaxes(title_standoff=0, mirror=True)

    fig.update_xaxes(range=(-16, 16), row=1, col=1)
    fig.update_yaxes(range=(-20, 20), title_text='$y$', row=1, col=1)

    fig.update_xaxes(range=(-16, 16), row=1, col=2)
    fig.update_yaxes(range=(5, 40), title_text='$z$', row=1, col=2)

    fig.update_xaxes(range=(-5, 5), row=2, col=1)
    fig.update_yaxes(range=(-5, 5), title_text='$y$', row=2, col=1)

    fig.update_xaxes(range=(-5, 5), row=2, col=2)
    fig.update_yaxes(range=(12, 27), title_text='$z$', row=2, col=2)
    fig.update_layout(template='simple_white', width=800, height=800, 
                     legend=dict(
                          orientation='h',
                        yanchor="bottom",
                        y=1.01,
                        xanchor="left",
                        x=0.01)
    )
    if show:
        fig.show()
    return fig

In [240]:
fig = plot_trajs(trajs_dem, lambda dec: dec.qU.x[:, 0]);
fig.update_layout(title_text='Generalised Filtering')

In [241]:
def florenz(x: torch.Tensor):
    x1 = (18 * x[..., 1] - 18 * x[..., 0])
    x2 = (46.92 * x[..., 0] - 2 * x[..., 2] * x[..., 0] - x[..., 1])
    x3 = (2 * x[..., 0] * x[..., 1] - 4 * x[..., 2])
    return torch.stack([x1, x2, x3], dim=-1)

def Jlorenz(x):
    return torch.autograd.functional.jacobian(florenz, x, create_graph=False)

def lorenz_discrete(x, dt=GLOBAL.dt):
    shape = x.shape
    
    dfdx  = Jlorenz(x).reshape((3,3))
    f     = florenz(x).reshape((3,1))
    
    A = torch.cat([f*dt , dfdx*dt], dim=-1)
    B = torch.zeros((1,4))
    J = torch.cat([B, A], 0)
    dx = torch.matrix_exp(J)
    
    dx = dx[1:, 0]
    
    return x + dx.reshape(shape)


g = tk.LinearTransform(torch.ones(1,3))
f = tk.LinearizedTransform(lorenz_discrete)

In [242]:
model_parameters = dict(
    fwd_transform=f, 
    obs_transform=g, 
    state_dim=3, 
    obs_dim=1, 
    observation_noise_cov=(1./GLOBAL.V) * torch.eye(1), 
    process_noise_cov=(1./GLOBAL.W) * torch.eye(3), 
    initial_state_mean=torch.tensor(GLOBAL.x0)
)

In [243]:
y   = torch.from_numpy(gen.v[:, 0, :1].astype(np.float32))

trajs_ekf = []
for j, initial_states in enumerate(GLOBAL.xs):
    model_parameters['initial_state_mean'] = torch.tensor([float(v) for v in initial_states])
    fit_model = tk.GaussianStateSpaceModel(**model_parameters)
    dec = fit_model.filter(y, backward_pass=False)
    trajs_ekf.append(dec)

Filter: 100%|████████████████████████████████| 512/512 [00:02<00:00, 191.36it/s]
Filter: 100%|████████████████████████████████| 512/512 [00:02<00:00, 192.99it/s]
Filter: 100%|████████████████████████████████| 512/512 [00:02<00:00, 193.60it/s]
Filter: 100%|████████████████████████████████| 512/512 [00:02<00:00, 195.04it/s]
Filter: 100%|████████████████████████████████| 512/512 [00:02<00:00, 194.15it/s]
Filter: 100%|████████████████████████████████| 512/512 [00:02<00:00, 192.94it/s]


In [244]:
fig = plot_trajs(trajs_ekf, lambda dec: dec['x_prev'].mean.squeeze().numpy());
fig.update_layout(title_text='EKF')

In [245]:
y   = torch.from_numpy(gen.v[:, 0, :1].astype(np.float32))

trajs_ekf_rts = []
for j, initial_states in enumerate(GLOBAL.xs):
    model_parameters['initial_state_mean'] = torch.tensor([float(v) for v in initial_states])
    fit_model = tk.GaussianStateSpaceModel(**model_parameters)
    dec = fit_model.filter(y, backward_pass=True)
    trajs_ekf_rts.append(dec)

Filter: 100%|████████████████████████████████| 512/512 [00:02<00:00, 171.68it/s]
Smooth: 100%|███████████████████████████████| 511/511 [00:00<00:00, 4063.26it/s]
Filter: 100%|████████████████████████████████| 512/512 [00:02<00:00, 193.14it/s]
Smooth: 100%|███████████████████████████████| 511/511 [00:00<00:00, 3883.28it/s]
Filter: 100%|████████████████████████████████| 512/512 [00:02<00:00, 192.49it/s]
Smooth: 100%|███████████████████████████████| 511/511 [00:00<00:00, 4384.65it/s]
Filter: 100%|████████████████████████████████| 512/512 [00:02<00:00, 191.71it/s]
Smooth: 100%|███████████████████████████████| 511/511 [00:00<00:00, 4177.67it/s]
Filter: 100%|████████████████████████████████| 512/512 [00:02<00:00, 192.47it/s]
Smooth: 100%|███████████████████████████████| 511/511 [00:00<00:00, 3240.24it/s]
Filter: 100%|████████████████████████████████| 512/512 [00:02<00:00, 193.87it/s]
Smooth: 100%|███████████████████████████████| 511/511 [00:00<00:00, 4164.78it/s]


In [246]:
fig = plot_trajs(trajs_ekf_rts, lambda dec: dec['x_prev'].mean.squeeze().numpy());
fig.update_layout(title_text='EKF+RTS')

In [272]:
x = tk.Gaussian(torch.tensor(GLOBAL.x0), torch.eye(3))
tk.UnscentedTransform(lorenz_discrete)(x).mean

tensor([[0., 0., 0.]])

In [247]:
model_parameters['fwd_transform'] = tk.UnscentedTransform(lorenz_discrete)

In [248]:
y   = torch.from_numpy(gen.v[:, 0, :1].astype(np.float32))

trajs_ukf = []
for j, initial_states in enumerate(GLOBAL.xs):
    model_parameters['initial_state_mean'] = torch.tensor([float(v) for v in initial_states])
    fit_model = tk.GaussianStateSpaceModel(**model_parameters)
    dec = fit_model.filter(y, backward_pass=False)
    trajs_ukf.append(dec)

Filter: 100%|█████████████████████████████████| 512/512 [00:05<00:00, 99.25it/s]
Filter: 100%|█████████████████████████████████| 512/512 [00:05<00:00, 99.80it/s]
Filter: 100%|████████████████████████████████| 512/512 [00:05<00:00, 100.08it/s]
Filter: 100%|████████████████████████████████| 512/512 [00:05<00:00, 101.40it/s]
Filter: 100%|████████████████████████████████| 512/512 [00:05<00:00, 101.10it/s]
Filter: 100%|█████████████████████████████████| 512/512 [00:05<00:00, 98.06it/s]


In [249]:
fig = plot_trajs(trajs_ukf, lambda dec: dec['x_prev'].mean.squeeze().numpy());
fig.update_layout(title_text='UKF')

In [231]:
y   = torch.from_numpy(gen.v[:, 0, :1].astype(np.float32))

trajs_ukf_rts = []
for j, initial_states in enumerate(GLOBAL.xs):
    model_parameters['initial_state_mean'] = torch.tensor([float(v) for v in initial_states])
    fit_model = tk.GaussianStateSpaceModel(**model_parameters)
    dec = fit_model.filter(y, backward_pass=True)
    trajs_ukf_rts.append(dec)

Filter: 100%|████████████████████████████████| 512/512 [00:05<00:00, 100.52it/s]
Smooth: 100%|███████████████████████████████| 511/511 [00:00<00:00, 3667.30it/s]
Filter: 100%|████████████████████████████████| 512/512 [00:05<00:00, 101.57it/s]
Smooth: 100%|███████████████████████████████| 511/511 [00:00<00:00, 4080.00it/s]
Filter: 100%|████████████████████████████████| 512/512 [00:05<00:00, 100.21it/s]
Smooth: 100%|███████████████████████████████| 511/511 [00:00<00:00, 3982.48it/s]
Filter: 100%|████████████████████████████████| 512/512 [00:05<00:00, 101.94it/s]
Smooth: 100%|███████████████████████████████| 511/511 [00:00<00:00, 3756.00it/s]
Filter: 100%|█████████████████████████████████| 512/512 [00:05<00:00, 98.56it/s]
Smooth: 100%|███████████████████████████████| 511/511 [00:00<00:00, 3821.79it/s]
Filter: 100%|████████████████████████████████| 512/512 [00:05<00:00, 101.69it/s]
Smooth: 100%|███████████████████████████████| 511/511 [00:00<00:00, 3839.98it/s]


In [232]:
fig = plot_trajs(trajs_ukf_rts, lambda dec: dec['x_prev'].mean.squeeze().numpy());
fig.update_layout(title_text='UKF')