In [1]:
import torch 
import numpy as np
from torchkf import *
from pprint import pprint
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import logging

np.set_printoptions(linewidth=160, precision=2)

#### Generate some data
The assumed model is: 
\begin{align} 
    y &= \theta_1 x \\ \dot{x} &= \theta_2 x + \theta_3 v 
\end{align}
where 
\begin{align} 
    \theta_1 = \begin{bmatrix} 
        0.1250 & 0.1633 \\
        0.1250 & 0.0676 \\ 
        0.1250 & -0.0676 \\ 
        0.1250 & -0.1633 
     \end{bmatrix} &&
     \theta_2 = \begin{bmatrix} 
         -0.25 & 1.00 \\
         -0.50 & -0.25 
     \end{bmatrix} && 
     \theta_3 = \begin{bmatrix} 
         1 \\ 0
     \end{bmatrix} 
\end{align}

We generate the data with $v = \exp\left(\frac{1}{4} (t - 12)^2\right)$. 

In [2]:
theta1 = torch.tensor([[0.125,  0.1633], 
                       [0.125,  0.0676], 
                       [0.125, -0.0676], 
                       [0.125, -0.1633]])
theta2 = torch.tensor([[-0.25,  1.00],
                       [-0.50, -0.25]])
theta3 = torch.tensor([[1.], 
                       [0.]])
pE = torch.cat([theta1.reshape((-1,)), theta2.reshape((-1,)), theta3.reshape((-1,))])

In [41]:
nps = (torch.numel(theta1),torch.numel(theta2),torch.numel(theta3))
models = [
    GaussianModel(
        g=lambda x, v, P: P[:nps[0]].reshape(theta1.shape) @ x, 
        f=lambda x, v, P: P[nps[0]:nps[0] + nps[1]].reshape(theta2.shape) @ x \
                        + P[-nps[2]:].reshape(theta3.shape) @ v,
        n=2, sv=1./2,sw=1./2,
        V=torch.tensor([np.exp(32.)]), 
        W=torch.tensor([np.exp(32.)]), 
        pE=pE, pC=torch.ones_like(pE) * np.exp(-32)
    ), 
    GaussianModel(l=1, V=torch.tensor([np.exp(32.)]))
]
genmodel = HierarchicalGaussianModel(*models)

In [56]:
nT = 32
t  = np.arange(1, nT+1)  
u  = torch.tensor(np.exp(-(t - 12)**2/4)).unsqueeze(-1)
gen = DEMInversion(genmodel, states_embedding_order=7).generate(nT, u)
y   = gen.v[:,0,:4,0]
px.line(y=[*y.T] + [*gen.x[:, 0, :, 0].T])

In [57]:
nps = (torch.numel(theta1),torch.numel(theta2),torch.numel(theta3))
models = [
    GaussianModel(
        g=lambda x, v, P: P[:nps[0]].reshape(theta1.shape) @ x, 
        f=lambda x, v, P: P[nps[0]:nps[0] + nps[1]].reshape(theta2.shape) @ x \
                        + P[-nps[2]:].reshape(theta3.shape) @ v,
        n=2, sv=1./2,sw=1./2,
        V=torch.tensor([np.exp(8.)]), 
        W=torch.tensor([np.exp(16.)]), 
        pE=pE, pC=torch.ones_like(pE) * np.exp(-32)
    ), 
    GaussianModel(l=1, V=torch.tensor([np.exp(0.)]))
]
decmodel = HierarchicalGaussianModel(*models)

In [58]:
deminv  = DEMInversion(decmodel, states_embedding_order=12)
results = deminv.run(y, nD=1, nE=2, nM=2, K=1, td=1)

Li:  tensor(-inf)
Ai:  tensor(-inf)
Fi:  tensor(-inf)
reg:  tensor(2.0880e-14)
mh:  tensor([])
dp:  tensor(2.8074e-11)
qp:  tensor(2.8074e-11)
Li:  tensor(-inf)
Ai:  tensor(-inf)
mh:  tensor([])
dp:  tensor(0.)
qp:  tensor(0.)


In [59]:
fig = make_subplots(rows=2, cols=2) 
for i in range(4): 
    fig.add_scatter(y=y[:, i], row=1, col=1, showlegend=False, line_color=px.colors.qualitative.T10[i])
    fig.add_scatter(y=results.qU.y[:, 0, i], row=1, col=1, showlegend=False, line_dash='dash', line_color=px.colors.qualitative.T10[i])
fig.add_scatter(y=results.qU.v[:, 0, 0], row=2, col=1, showlegend=False, line_color=px.colors.qualitative.T10[0])
fig.add_scatter(y=gen.v[:, 0, -1, 0], row=2, col=1, showlegend=False, line_dash='dash',line_color=px.colors.qualitative.T10[0])

for i in range(2): 
    fig.add_scatter(y=results.qU.x[:, 0, i], row=1, col=2, showlegend=False, line_color=px.colors.qualitative.T10[i])
    fig.add_scatter(y=gen.x[:, 0, i, 0], row=1, col=2, showlegend=False, line_color=px.colors.qualitative.T10[i], line_dash='dash')
    
fig.update_layout(height=800, width=800, template='simple_white')

In [9]:
x = results.qU.x.reshape((results.qU.x.shape[0], -1))
v = results.qU.v.reshape((results.qU.v.shape[0], -1))
xv = torch.cat([x, v], dim=1)

In [10]:
traj = Gaussian(xv, results.qU.c)[None, ...]
plot_traj(traj,n_states=9)

ValueError: Expected parameter covariance_matrix (Tensor of shape (32, 18, 18)) of distribution Gaussian(loc: torch.Size([32, 18]), covariance_matrix: torch.Size([32, 18, 18])) to satisfy the constraint PositiveDefinite(), but found invalid values:
tensor([[[ 2.7842e-03,  1.8979e-04, -3.9823e-04,  ...,  3.4155e-04,
           1.7678e-04,  0.0000e+00],
         [ 1.8979e-04,  1.6529e-03,  1.4166e-03,  ..., -6.5880e-05,
           1.1614e-04,  0.0000e+00],
         [-3.9823e-04,  1.4166e-03,  1.9212e-03,  ..., -1.4897e-04,
           3.8533e-03,  0.0000e+00],
         ...,
         [ 3.4155e-04, -6.5880e-05, -1.4897e-04,  ...,  8.8476e-03,
           2.4655e-05,  0.0000e+00],
         [ 1.7678e-04,  1.1614e-04,  3.8533e-03,  ...,  2.4655e-05,
           5.5157e-01,  0.0000e+00],
         [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00]],

        [[ 2.7842e-03,  1.8979e-04, -3.9823e-04,  ...,  3.4155e-04,
           1.7678e-04,  0.0000e+00],
         [ 1.8979e-04,  1.6529e-03,  1.4166e-03,  ..., -6.5880e-05,
           1.1614e-04,  0.0000e+00],
         [-3.9823e-04,  1.4166e-03,  1.9212e-03,  ..., -1.4897e-04,
           3.8533e-03,  0.0000e+00],
         ...,
         [ 3.4155e-04, -6.5880e-05, -1.4897e-04,  ...,  8.8476e-03,
           2.4655e-05,  0.0000e+00],
         [ 1.7678e-04,  1.1614e-04,  3.8533e-03,  ...,  2.4655e-05,
           5.5157e-01,  0.0000e+00],
         [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00]],

        [[ 2.7842e-03,  1.8979e-04, -3.9823e-04,  ...,  3.4155e-04,
           1.7678e-04,  0.0000e+00],
         [ 1.8979e-04,  1.6529e-03,  1.4166e-03,  ..., -6.5880e-05,
           1.1614e-04,  0.0000e+00],
         [-3.9823e-04,  1.4166e-03,  1.9212e-03,  ..., -1.4897e-04,
           3.8533e-03,  0.0000e+00],
         ...,
         [ 3.4155e-04, -6.5880e-05, -1.4897e-04,  ...,  8.8476e-03,
           2.4655e-05,  0.0000e+00],
         [ 1.7678e-04,  1.1614e-04,  3.8533e-03,  ...,  2.4655e-05,
           5.5157e-01,  0.0000e+00],
         [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00]],

        ...,

        [[ 2.7842e-03,  1.8979e-04, -3.9823e-04,  ...,  3.4155e-04,
           1.7678e-04,  0.0000e+00],
         [ 1.8979e-04,  1.6529e-03,  1.4166e-03,  ..., -6.5880e-05,
           1.1614e-04,  0.0000e+00],
         [-3.9823e-04,  1.4166e-03,  1.9212e-03,  ..., -1.4897e-04,
           3.8533e-03,  0.0000e+00],
         ...,
         [ 3.4155e-04, -6.5880e-05, -1.4897e-04,  ...,  8.8476e-03,
           2.4655e-05,  0.0000e+00],
         [ 1.7678e-04,  1.1614e-04,  3.8533e-03,  ...,  2.4655e-05,
           5.5157e-01,  0.0000e+00],
         [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00]],

        [[ 2.7842e-03,  1.8979e-04, -3.9823e-04,  ...,  3.4155e-04,
           1.7678e-04,  0.0000e+00],
         [ 1.8979e-04,  1.6529e-03,  1.4166e-03,  ..., -6.5880e-05,
           1.1614e-04,  0.0000e+00],
         [-3.9823e-04,  1.4166e-03,  1.9212e-03,  ..., -1.4897e-04,
           3.8533e-03,  0.0000e+00],
         ...,
         [ 3.4155e-04, -6.5880e-05, -1.4897e-04,  ...,  8.8476e-03,
           2.4655e-05,  0.0000e+00],
         [ 1.7678e-04,  1.1614e-04,  3.8533e-03,  ...,  2.4655e-05,
           5.5157e-01,  0.0000e+00],
         [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00]],

        [[ 2.7842e-03,  1.8979e-04, -3.9823e-04,  ...,  3.4155e-04,
           1.7678e-04,  0.0000e+00],
         [ 1.8979e-04,  1.6529e-03,  1.4166e-03,  ..., -6.5880e-05,
           1.1614e-04,  0.0000e+00],
         [-3.9823e-04,  1.4166e-03,  1.9212e-03,  ..., -1.4897e-04,
           3.8533e-03,  0.0000e+00],
         ...,
         [ 3.4155e-04, -6.5880e-05, -1.4897e-04,  ...,  8.8476e-03,
           2.4655e-05,  0.0000e+00],
         [ 1.7678e-04,  1.1614e-04,  3.8533e-03,  ...,  2.4655e-05,
           5.5157e-01,  0.0000e+00],
         [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00]]])

In [None]:
results.qP