In [1]:
import torch 
import numpy as np
from torchkf import *
from pprint import pprint
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import logging

np.set_printoptions(linewidth=160, precision=2)

#### Generate some data
The assumed model is: 
\begin{align} 
    y &= \theta_1 x \\ \dot{x} &= \theta_2 x + \theta_3 v 
\end{align}
where 
\begin{align} 
    \theta_1 = \begin{bmatrix} 
        0.1250 & 0.1633 \\
        0.1250 & 0.0676 \\ 
        0.1250 & -0.0676 \\ 
        0.1250 & -0.1633 
     \end{bmatrix} &&
     \theta_2 = \begin{bmatrix} 
         -0.25 & 1.00 \\
         -0.50 & -0.25 
     \end{bmatrix} && 
     \theta_3 = \begin{bmatrix} 
         1 \\ 0
     \end{bmatrix} 
\end{align}

We generate the data with $v = \exp\left(\frac{1}{4} (t - 12)^2\right)$. 

In [11]:
theta1 = np.array([[0.125,  0.1633], 
                       [0.125,  0.0676], 
                       [0.125, -0.0676], 
                       [0.125, -0.1633]])
theta2 = np.array([[-0.25,  1.00],
                       [-0.50, -0.25]])
theta3 = np.array([[1.], [0.]])
pE = np.concatenate([theta1.reshape((-1,)), theta2.reshape((-1,)), theta3.reshape((-1,))])

In [12]:
nps = (theta1.size,theta2.size,theta3.size)
models = [
    GaussianModel(
        g=lambda x, v, P: P[:nps[0]].reshape(theta1.shape) @ x, 
        f=lambda x, v, P: P[nps[0]:nps[0] + nps[1]].reshape(theta2.shape) @ x \
                        + P[-nps[2]:].reshape(theta3.shape) @ v,
        n=2, sv=1./2,sw=1./2,
        V=np.array([np.exp(8.)]), 
        W=np.array([np.exp(16.)]), 
        pE=pE, pC=np.ones_like(pE) * np.exp(-64)
    ), 
    GaussianModel(l=1, V=np.array([np.exp(32.)]))
]
genmodel = HierarchicalGaussianModel(*models)

Compiling derivatives, it might take some time... f() ok ... g() ok ... Done. 


In [13]:
nT = 32
t  = np.arange(1, nT+1)  
u  = (np.exp(-(t - 12)**2/4))[:, None]
gen = DEMInversion(genmodel, states_embedding_order=4).generate(nT, u)
y   = gen.v[:,0,:4,0]
px.line(y=[*y.T] + [*gen.x[:, 0, :, 0].T])

  0%|          | 0/32 [00:00<?, ?it/s]

In [14]:
decmodel = genmodel
decmodel[1].V = np.ones((1,1))

In [15]:
deminv  = DEMInversion(decmodel, states_embedding_order=4)
results = deminv.run(y, nD=1, nE=1, nM=1, K=1, td=1)

[array([[1.6e-28, 0.0e+00, 0.0e+00, 0.0e+00, 0.0e+00, 0.0e+00, 0.0e+00, 0.0e+00, 0.0e+00, 0.0e+00, 0.0e+00, 0.0e+00, 0.0e+00, 0.0e+00],
       [0.0e+00, 1.6e-28, 0.0e+00, 0.0e+00, 0.0e+00, 0.0e+00, 0.0e+00, 0.0e+00, 0.0e+00, 0.0e+00, 0.0e+00, 0.0e+00, 0.0e+00, 0.0e+00],
       [0.0e+00, 0.0e+00, 1.6e-28, 0.0e+00, 0.0e+00, 0.0e+00, 0.0e+00, 0.0e+00, 0.0e+00, 0.0e+00, 0.0e+00, 0.0e+00, 0.0e+00, 0.0e+00],
       [0.0e+00, 0.0e+00, 0.0e+00, 1.6e-28, 0.0e+00, 0.0e+00, 0.0e+00, 0.0e+00, 0.0e+00, 0.0e+00, 0.0e+00, 0.0e+00, 0.0e+00, 0.0e+00],
       [0.0e+00, 0.0e+00, 0.0e+00, 0.0e+00, 1.6e-28, 0.0e+00, 0.0e+00, 0.0e+00, 0.0e+00, 0.0e+00, 0.0e+00, 0.0e+00, 0.0e+00, 0.0e+00],
       [0.0e+00, 0.0e+00, 0.0e+00, 0.0e+00, 0.0e+00, 1.6e-28, 0.0e+00, 0.0e+00, 0.0e+00, 0.0e+00, 0.0e+00, 0.0e+00, 0.0e+00, 0.0e+00],
       [0.0e+00, 0.0e+00, 0.0e+00, 0.0e+00, 0.0e+00, 0.0e+00, 1.6e-28, 0.0e+00, 0.0e+00, 0.0e+00, 0.0e+00, 0.0e+00, 0.0e+00, 0.0e+00],
       [0.0e+00, 0.0e+00, 0.0e+00, 0.0e+00, 0.0e+00, 0

timestep:   0%|          | 0/32 [00:00<?, ?it/s]

In [16]:
fig = make_subplots(rows=2, cols=2) 

fig.add_scatter(y=results.qU.v[:, 0, 0], row=2, col=1, showlegend=True, legendgroup='estimated', name='Estimated', line_color=px.colors.qualitative.T10[0])
fig.add_scatter(y=gen.v[:, 0, -1, 0], row=2, col=1, showlegend=True, legendgroup='realized', name='Realized', line_dash='dash',line_color=px.colors.qualitative.T10[0])

for i in range(4): 
    fig.add_scatter(y=results.qU.y[:, 0, i], row=1, col=1,legendgroup='estimated', showlegend=False, line_color=px.colors.qualitative.T10[i])
    fig.add_scatter(y=y[:, i], row=1, col=1, legendgroup='realized',showlegend=False, line_dash='dash', line_color=px.colors.qualitative.T10[i])

for i in range(2): 
    fig.add_scatter(y=results.qU.x[:, 0, i], row=1, col=2, legendgroup='estimated',showlegend=False, line_color=px.colors.qualitative.T10[i])
    fig.add_scatter(y=gen.x[:, 0, i, 0], row=1, col=2, legendgroup='realized',showlegend=False, line_color=px.colors.qualitative.T10[i], line_dash='dash')
    
fig.update_layout(height=800, width=800, template='simple_white')

In [54]:
nps   = (theta1.size,theta2.size,theta3.size)
ip    = [0, 10] 
P     = pE.copy()
P[ip] = 0
pC    = np.zeros_like(pE)
pC[ip]= np.exp(8)

models = [
    GaussianModel(
        g=lambda x, v, P: P[:nps[0]].reshape(theta1.shape) @ x, 
        f=lambda x, v, P: P[nps[0]:nps[0] + nps[1]].reshape(theta2.shape) @ x \
                        + P[-nps[2]:].reshape(theta3.shape) @ v,
        n=2, sv=1./2,sw=1./2,
#         V=np.array([np.exp(8.)]), 
#         W=np.array([np.exp(16.)]), 
        Q=[np.eye(4)], R=[np.eye(2)],
#         hE=np.array([.]), hC=np.array([[np.exp(-8)]]),
#         gE=np.array([0.]), gC=np.array([[np.exp(-16)]]),
        pE=P, pC=pC, 
    ), 
    GaussianModel(l=1, V=np.array([np.exp(0.)]))
]
decdualmodel = HierarchicalGaussianModel(*models)

Compiling derivatives, it might take some time... f() ok ... g() ok ... Done. 


In [55]:
deminv  = DEMInversion(decdualmodel, states_embedding_order=4)
# deminv.logger.setLevel('INFO')
results = deminv.run(y, u, nD=1, nE=128, nM=8, K=1., Emin=0, tol=np.finfo(np.float64).eps)
px.line(y=[results.F]).update_layout(template='simple_white', height=400, width=800)

[array([[2980.96,    0.  ],
       [   0.  , 2980.96]])]


E-step (F = -inf):   0%|          | 0/128 [00:00<?, ?it/s]

  M-step:   0%|          | 0/8 [00:00<?, ?it/s]

timestep:   0%|          | 0/32 [00:00<?, ?it/s]

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)



In [36]:
pE[0], results.qP.P[0], pE[10], results.qP.P[10]

(0.125, array([0.11]), -0.5, array([-0.45]))

In [37]:
fig = make_subplots(rows=2, cols=2) 

fig.add_scatter(y=results.qU.v[:, 0, 0], row=2, col=1, showlegend=True, legendgroup='estimated', name='Estimated', line_color=px.colors.qualitative.T10[0])
fig.add_scatter(y=gen.v[:, 0, -1, 0], row=2, col=1, showlegend=True, legendgroup='realized', name='Realized', line_dash='dash',line_color=px.colors.qualitative.T10[0])

for i in range(4): 
    fig.add_scatter(y=results.qU.y[:, 0, i], row=1, col=1,legendgroup='estimated', showlegend=False, line_color=px.colors.qualitative.T10[i])
    fig.add_scatter(y=y[:, i], row=1, col=1, legendgroup='realized',showlegend=False, line_dash='dash', line_color=px.colors.qualitative.T10[i])

for i in range(2): 
    fig.add_scatter(y=results.qU.x[:, 0, i], row=1, col=2, legendgroup='estimated',showlegend=False, line_color=px.colors.qualitative.T10[i])
    fig.add_scatter(y=gen.x[:, 0, i, 0], row=1, col=2, legendgroup='realized',showlegend=False, line_color=px.colors.qualitative.T10[i], line_dash='dash')
    
fig.update_layout(height=800, width=800, template='simple_white')

In [38]:
import pstats
p = pstats.Stats('../profresults')
p.strip_dirs().sort_stats("cumulative").print_stats()

Fri Jul  8 01:27:29 2022    ../profresults

         24654184 function calls (23711294 primitive calls) in 30.688 seconds

   Ordered by: cumulative time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
   1980/1    0.009    0.000   30.711   30.711 {built-in method builtins.exec}
        1    0.006    0.006   30.710   30.710 dem_lorenz.py:1(<module>)
        1    5.260    5.260   22.195   22.195 dem.py:127(run)
     4096    4.992    0.001   13.423    0.003 dem_de.py:17(dem_eval_err_diff)
        1    0.357    0.357    5.736    5.736 dem.py:732(generate)
      187    0.004    0.000    4.099    0.022 __init__.py:1(<module>)
        1    3.254    3.254    3.318    3.318 dem_z.py:8(dem_z)
   638993    2.305    0.000    2.532    0.000 dem_structs.py:17(kron)
     8193    0.085    0.000    2.389    0.000 dem_dx.py:74(compute_dx)
815285/766106    1.546    0.000    2.386    0.000 {built-in method numpy.core._multiarray_umath.implement_array_function}
     8193    2.107 

        1    0.000    0.000    0.001    0.001 dis.py:1(<module>)
        1    0.000    0.000    0.001    0.001 line.py:1(<module>)
     8678    0.001    0.000    0.001    0.000 numbers.py:2289(__index__)
     1827    0.001    0.000    0.001    0.000 __init__.py:2825(__getattr__)
      481    0.001    0.000    0.001    0.000 enum.py:75(__setitem__)
        1    0.000    0.000    0.001    0.001 _basic.py:7(<module>)
      748    0.000    0.000    0.001    0.000 numbers.py:1222(_as_mpf_val)
    22272    0.001    0.000    0.001    0.000 ndim_array.py:270(shape)
      184    0.001    0.000    0.001    0.000 <frozen importlib._bootstrap_external>:1459(__init__)
     1648    0.001    0.000    0.001    0.000 operations.py:429(make_args)
       61    0.000    0.000    0.001    0.000 pyparsing.py:3292(<listcomp>)
        1    0.000    0.000    0.001    0.001 _pocketfft.py:1(<module>)
        1    0.000    0.000    0.001    0.001 signal.py:1(<module>)
       70    0.000    0.000    0.001    0.000

      270    0.000    0.000    0.000    0.000 weakref.py:382(__getitem__)
        2    0.000    0.000    0.000    0.000 std.py:1262(close)
       69    0.000    0.000    0.000    0.000 add.py:680(<lambda>)
     1100    0.000    0.000    0.000    0.000 version.py:344(local)
     2464    0.000    0.000    0.000    0.000 inspect.py:2551(kind)
     1532    0.000    0.000    0.000    0.000 <frozen importlib._bootstrap_external>:834(create_module)
        1    0.000    0.000    0.000    0.000 selecting.py:1(<module>)
        1    0.000    0.000    0.000    0.000 conventions.py:1(<module>)
      136    0.000    0.000    0.000    0.000 numbers.py:2222(__mod__)
      228    0.000    0.000    0.000    0.000 function.py:144(<listcomp>)
        1    0.000    0.000    0.000    0.000 gmpyfinitefield.py:1(<module>)
     1019    0.000    0.000    0.000    0.000 permutations.py:1057(array_form)
        1    0.000    0.000    0.000    0.000 integerring.py:1(<module>)
     1734    0.000    0.000    0.000

        1    0.000    0.000    0.000    0.000 unparser.py:1(<module>)
        1    0.000    0.000    0.000    0.000 extensions.py:1(<module>)
      257    0.000    0.000    0.000    0.000 hmac.py:18(<genexpr>)
        1    0.000    0.000    0.000    0.000 conv.py:301(Conv2d)
        2    0.000    0.000    0.000    0.000 __init__.py:559(__init__)
       21    0.000    0.000    0.000    0.000 __init__.py:218(_acquireLock)
      150    0.000    0.000    0.000    0.000 traceback.py:243(__init__)
        1    0.000    0.000    0.000    0.000 _polybase.py:18(ABCPolyBase)
       46    0.000    0.000    0.000    0.000 __init__.py:2381(_is_unpacked_egg)
       92    0.000    0.000    0.000    0.000 {built-in method math.log}
        1    0.000    0.000    0.000    0.000 sparse_adam.py:1(<module>)
        1    0.000    0.000    0.000    0.000 identification.py:1(<module>)
       22    0.000    0.000    0.000    0.000 _typing.py:292(<genexpr>)
        3    0.000    0.000    0.000    0.000 numba_.

        1    0.000    0.000    0.000    0.000 offsets.py:1(<module>)
       51    0.000    0.000    0.000    0.000 decorators.py:72(call_highest_priority)
       45    0.000    0.000    0.000    0.000 ctx_iv.py:394(<lambda>)
        1    0.000    0.000    0.000    0.000 conv.py:591(_ConvTransposeNd)
        1    0.000    0.000    0.000    0.000 matmul.py:22(MatMul)
        1    0.000    0.000    0.000    0.000 embedding_ops.py:67(Embedding)
        1    0.000    0.000    0.000    0.000 polyutils.py:170(<listcomp>)
        1    0.000    0.000    0.000    0.000 sqfreetools.py:1(<module>)
        2    0.000    0.000    0.000    0.000 {method 'newbyteorder' of 'numpy.generic' objects}
       10    0.000    0.000    0.000    0.000 decorator.py:127(doctest_depends_on)
        2    0.000    0.000    0.000    0.000 _ops.py:75(__getattr__)
       13    0.000    0.000    0.000    0.000 __init__.py:1224(__init__)
        1    0.000    0.000    0.000    0.000 loss.py:1429(TripletMarginWithDistance

        1    0.000    0.000    0.000    0.000 free_groups.py:114(FreeGroup)
        1    0.000    0.000    0.000    0.000 indexed.py:124(Indexed)
        7    0.000    0.000    0.000    0.000 pyparsing.py:4197(__str__)
        1    0.000    0.000    0.000    0.000 homomorphisms.py:18(ModuleHomomorphism)
        1    0.000    0.000    0.000    0.000 rnn.py:896(RNNCell)
        1    0.000    0.000    0.000    0.000 typing.py:1561(SupportsIndex)
        1    0.000    0.000    0.000    0.000 _datasource.py:536(Repository)
        1    0.000    0.000    0.000    0.000 _xlwt.py:21(XlwtWriter)
        1    0.000    0.000    0.000    0.000 __init__.py:1619(_register)
        1    0.000    0.000    0.000    0.000 orderings.py:64(ProductOrder)
        1    0.000    0.000    0.000    0.000 blocks.py:1681(NDArrayBackedExtensionBlock)
        1    0.000    0.000    0.000    0.000 flags.py:4(Flags)
        5    0.000    0.000    0.000    0.000 enum.py:676(_generate_next_value_)
       22    0.000   

        1    0.000    0.000    0.000    0.000 boolalg.py:377(BooleanFalse)
        1    0.000    0.000    0.000    0.000 integer.py:520(Int32Dtype)
        1    0.000    0.000    0.000    0.000 _datasource.py:99(__init__)
        1    0.000    0.000    0.000    0.000 ast.py:1711(Print)
        1    0.000    0.000    0.000    0.000 version.py:349(_cmp)
        1    0.000    0.000    0.000    0.000 __init__.py:162(remove_shim)
        3    0.000    0.000    0.000    0.000 six.py:880(add_metaclass)
        4    0.000    0.000    0.000    0.000 utils.py:135(disable_on_exception)
        1    0.000    0.000    0.000    0.000 formal.py:1697(FormalPowerSeriesInverse)
        1    0.000    0.000    0.000    0.000 __init__.py:40(__init__)
        1    0.000    0.000    0.000    0.000 loss.py:1345(TripletMarginLoss)
        1    0.000    0.000    0.000    0.000 util.py:182(Finalize)
        1    0.000    0.000    0.000    0.000 swa_utils.py:170(SWALR)
        1    0.000    0.000    0.000    0.00

        1    0.000    0.000    0.000    0.000 queue.py:255(_PySimpleQueue)
        1    0.000    0.000    0.000    0.000 plistlib.py:178(UID)
        1    0.000    0.000    0.000    0.000 hooks.py:64(BackwardHook)
        1    0.000    0.000    0.000    0.000 sympy_parser.py:1231(_T)
        1    0.000    0.000    0.000    0.000 python_parser.py:1087(FixedWidthReader)
        1    0.000    0.000    0.000    0.000 _optional.py:53(get_version)
        1    0.000    0.000    0.000    0.000 factor_.py:2203(udivisor_sigma)
        1    0.000    0.000    0.000    0.000 context.py:197(get_start_method)
        4    0.000    0.000    0.000    0.000 function.py:295(_iter_filter)
        2    0.000    0.000    0.000    0.000 index_tricks.py:754(__init__)
        1    0.000    0.000    0.000    0.000 locks.py:515(BoundedSemaphore)
        2    0.000    0.000    0.000    0.000 core.py:878(__init__)
        1    0.000    0.000    0.000    0.000 boolalg.py:1465(Exclusive)
        1    0.000    0.000

        1    0.000    0.000    0.000    0.000 argparse.py:749(ArgumentTypeError)
        1    0.000    0.000    0.000    0.000 pyparsing.py:4364(Suppress)
        1    0.000    0.000    0.000    0.000 polyerrors.py:130(ComputationFailed)
        1    0.000    0.000    0.000    0.000 _arrays.py:68(csc_array)
        1    0.000    0.000    0.000    0.000 stata.py:497(ValueLabelTypeMismatch)
        1    0.000    0.000    0.000    0.000 context.py:286(ForkServerProcess)
        1    0.000    0.000    0.000    0.000 expressions.py:486(VariableNode)
        1    0.000    0.000    0.000    0.000 socket.py:210(_GiveupOnSendfile)
        1    0.000    0.000    0.000    0.000 indexed.py:120(IndexException)
        1    0.000    0.000    0.000    0.000 function.py:100(PoleError)
        1    0.000    0.000    0.000    0.000 _special_inputs.py:32(Range)
        1    0.000    0.000    0.000    0.000 _script.py:302(_CachedForward)
        1    0.000    0.000    0.000    0.000 dispatcher.py:10(MDNot

<pstats.Stats at 0x7fefd2f847f0>

In [None]:
import sympy
from numba import jit

In [None]:
p = np.empty_like(theta1, dtype=sympy.Symbol)
p.flat = sympy.symbols(f'p0:{p.size}')
p = sympy.Matrix(p)

In [None]:
np.asarray(dfdp(np.ones_like(x).flat, theta1.flatten())).shape

In [None]:
p.flat()

In [None]:
x  = sympy.MatrixSymbol('x', models[0].n, 1)
v  = sympy.MatrixSymbol('v', models[0].m, 1)
pE = sympy.MatrixSymbol('p', models[0].p, 1)

In [None]:
from sympy.utilities.autowrap import autowrap
import sympy

In [None]:
x  = np.array(sympy.symbols(f'x0:{models[0].n}')).reshape((models[0].n, 1))
v  = np.array(sympy.symbols(f'v0:{models[0].m}')).reshape((models[0].m, 1))
p  = np.array(sympy.symbols(f'p0:{models[0].p}')).reshape((models[0].p, 1))
q  = np.array(sympy.symbols(f'q0:{models[0].p}')).reshape((models[0].p, 1))
u  = np.array(sympy.symbols(f'u0:{models[0].p**2}')).reshape((models[0].p, models[0].p))

In [None]:
o = autowrap(op)

In [None]:
p = sympy.MatrixSymbol('p', 14, 1)
u = sympy.MatrixSymbol('u', 14, 14)
q = sympy.MatrixSymbol('q', 14, 1)

In [None]:
sympy.Matrix(u @ q) + 

In [None]:
op = sympy.Matrix(p + u @ q)

In [None]:
fc = autowrap(op, args=(p, u, q))

In [None]:
op = sympy.MutableDenseMatrix(p + u @ q)

In [None]:
np.asarray(p)

In [None]:
func = sympy.lambdify((p, u, q), op, 'numpy', cse=True)

In [None]:
import itertools

class autowrapnd: 
    def __init__(self, expr, *args, **kwargs): 
        self._func = autowrap(expr, *args, **kwargs)
    
    def __call__(self, *args): 
        return self._func(*itertools.chain.from_iterable(np.array(_).flat for _ in args))

In [None]:
np.ndarray.flat

In [None]:
func = autowrapnd(op)

In [None]:
args = [*(torch.randn(*_.shape) for _ in (p, u, q))]

In [None]:
func(*args)

In [None]:
_inn = autowrap(op)
func = lambda p, u, q: _inn(*p.flat, *u.flat, *q.flat)

In [None]:
%timeit (lambda p, u, q: p + u @ q)(*(np.random.randn(*_.shape) for _ in (p, u, q))) 

In [None]:
ins = [*(np.random.randn(*_.shape) for _ in (p, u, q))]
inp = [*(torch.randn(*_.shape) for _ in (p, u, q))]

In [None]:
%timeit func(*ins)

In [None]:
%timeit func(*inp)

In [None]:
sympy.Matrix(lambda x, v, p: models[0].f(x, v, p + u @ q))

In [None]:
np.prod((3,4,5))

In [None]:
import math

class cdotdict(dotdict):
    """callable dot dict""" 
    def __call__(self, *args, **kwargs):
        return dotdict({
            k: v(*args, **kwargs) if callable(v) else v
            for k, v in self.items()
        })

def compute_sym_df_d2f(func, *dims, input_keys=None, wrt=None, cast_to=np.ndarray):
    """ 
    Use symbolic differentiation to compute jacobian and hessian of a function of 3 vectors. 
     - func: if the function to differentiate (must return a vector, ie a tensor (l, ...) where ... are empty or 1's)
     - dims: a list of tuple containing the dimensions of each argument
    Returns: (df, d2f) where: 
     - df.dx, df.dv, and df.dp contains the jacobians wrt each argument
     - d2f.dx.dx, ... contains the ndim-hessians wrt each pair of arguments
    """
    if cast_to == np.ndarray: 
        cast = lambda x: np.array(x, dtype=np.float64)
    elif cast_to == torch.tensor: 
        cast = lambda x: torch.from_numpy(np.array(x, dtype=np.float64))
    elif callable(cast_to):
        cast = cast_to
    else: raise NotImplementedError()
        
    if input_keys is None: 
        import string
        input_keys = string.ascii_lowercase[:len(dims)]
    else: 
        assert(len(dims) == len(input_keys))
    
    if wrt is None:
        wrt = input_keys
    else: 
        assert(all(_ in input_keys for _ in wrt))
        
    wrt = [f'd{k}' for k in wrt]
    dims = [(dim,) if isinstance(dim, int) else dim for dim in dims]
    
    # Squeeze column vectors
    squeezedims = [dim if len(dim) == 1 or dim[1] != 1 else (dim[0],) for dim in dims]
    
    # compute flat dimension 
    flatdims = [math.prod(dim) for dim in dims]
    
    # create flat variables
    flatvar = [
        (f'd{k}', np.array(sympy.symbols(f'{k}0:{n}'))) 
        for k, n in zip(input_keys, flatdims)
    ]
    
    var = [
        (f'd{k}', np.array(sympy.symbols(f'{k}0:{n}')).reshape(dim))
        for k, n, dim in zip(input_keys, flatdims, dims)
    ]
    
    args = [v[1] for v in var]

    fxvp = sympy.Matrix(func(*args))
    l = fxvp.shape[0]

    dfsymb  = dotdict({
        d: fxvp.jacobian(sym)
        for d, sym in flatvar
        if d in wrt
    })

    df  = cdotdict()
    d2f = cdotdict()

    for i, (d1, sym1) in enumerate(flatvar):
        if d1 not in wrt: continue 
            
        if d1 not in d2f.keys():
            d2f[d1] = cdotdict()
            
        for j, (d2, sym2) in enumerate(flatvar): 
            if j < i: continue
            if d2 not in wrt: continue 
                
            if d2 not in d2f.keys(): 
                d2f[d2] = cdotdict()

            h  = sympy.MutableDenseNDimArray((dfsymb[d1].reshape(l * sym1.shape[0], 1).jacobian(sym2)))
            h  = h.reshape(l, sym1.shape[0], sym2.shape[0])
            ht = sympy.permutedims(h, (0, 2, 1))
            
            h  = sympy.Matrix(h.reshape(l, sym1.shape[0]*sym2.shape[0]))
            ht = sympy.Matrix(h.reshape(l, sym1.shape[0]*sym2.shape[0]))
            
            if len(h.free_symbols) > 0: 

                d2f[d1][d2] = lambda *_args, _func=autowrapnd(h), _target_shape=(l, *squeezedims[i], *squeezedims[j]):\
                    cast(_func(*_args)).reshape(_target_shape)
                d2f[d2][d1] = lambda *_args, _func=autowrapnd(ht), _target_shape=(l, *squeezedims[j], *squeezedims[i]):\
                    cast(_func(*_args)).reshape(_target_shape)
            else:
                d2f[d1][d2] = lambda *_args, _symb=h, _target_shape=(l, *squeezedims[i], *squeezedims[j]):\
                    cast(_symb).reshape(_target_shape)
                d2f[d2][d1] = lambda *_args, _symb=ht, _target_shape=(l, *squeezedims[j], *squeezedims[i]):\
                    cast(_symb).reshape(_target_shape)
                
        J = dfsymb[d1]
        if len(J.free_symbols) > 0:
            df[d1] = lambda *_args, _func=autowrapnd(J), _target_shape=(l, *squeezedims[i]): \
                cast(_func(*_args)).reshape(_target_shape)
        else: 
            df[d1] = lambda *_args, _symb=J, _target_shape=(l, *squeezedims[i]): \
                cast(_symb).reshape(_target_shape)
        
    return df, d2f
    

In [None]:
import sympy

In [None]:
autowrapnd(

In [None]:
df, d2f = compute_sym_df_d2f(
    lambda x, v, p, u, q: models[0].f(x, v, p + u @ q), models[0].n, models[0].m, models[0].p, (models[0].p, models[0].p), models[0].p, 
    input_keys='xvpur', wrt='xvp', cast_to=lambda x:x)

In [None]:
print(df.dx(x, v, p, u, q))

In [None]:
df  = [sympy.Matrix(f(x,v,r)).jacobian(_) for _ in (x, v, r)]
df[2] = df[2] @ u
d2f = [_.jacobian(p) @ u for _ in df]

dF  = [sympy.Matrix(F(x,v,p,u,q)).jacobian(_) for _ in (x, v, q)]
d2F = [_.jacobian(q) for _ in dF]
df, d2f, dF, d2F

In [None]:
sympy.Matrix(p + u@q).jacobian(q) == sympy.Matrix(u)

In [None]:
x, v, p, u, q = genmodel[0].x, genmodel[1].v, genmodel[0].pE, torch.diag(genmodel[0].pE[:, 0]), genmodel[0].pE, 
d2finst = d2f(x, v, p, u, q)
dfinst  = df(x, v, p, u, q)

In [None]:
{k: v.shape for k, v in dfinst.items()}, {k: {k_: v_.shape for k_, v_ in v.items()} for k, v in d2finst.items()},

In [None]:
d2f.dv.dp.shape

In [None]:
dfdx.free_symbols

In [None]:
sympy.lambdify((p,), sympy.ImmutableDenseNDimArray(sympy.diff(dfdv, p)), 'numpy')(p)

In [None]:
dfdx = sympy.Matrix(models[0].f(x, v, p)).jacobian(x)
dfdx = sympy.diff(sympy.Matrix(models[0].f(x, v, p)), x.reshape(2))
dfdx

In [None]:
dfdx = sympy.Matrix(models[0].f(x, v, p)).jacobian(x)
dfdv = sympy.Matrix(models[0].f(x, v, p)).jacobian(v)
dfdp = sympy.Matrix(models[0].f(x, v, p)).jacobian(p)
dfdx, dfdv, dfdp

In [None]:
J = sympy.lambdify((x, p), (p @ x).jacobian(x.flat()))
J(np.ones_like(x.flat()), theta1.numpy().flat)

In [None]:
dfdp = sympy.utilities.lambdify((x, p), dp, 'numpy')

In [None]:
dx = sympy.diff(p @ x, x)
dp = sympy.ImmutableDenseNDimArray(sympy.diff(p @ x, p))
dx, dp

In [None]:
p = sympy.Matrix(theta1) 
x = sympy.Matrix(sympy.symbols('x0:2'))
dx = sympy.diff(p @ x, x)
dx.shape

In [None]:
x = results.qU.x.reshape((results.qU.x.shape[0], -1))
v = results.qU.v.reshape((results.qU.v.shape[0], -1))
xv = torch.cat([x, v], dim=1)

In [None]:
traj = Gaussian(xv, results.qU.c)[None, ...]
plot_traj(traj,n_states=9)

In [None]:
results.qP