In [1]:
import numpy as np
import torch
from torchdiffeq import odeint

In [2]:
F = torch.rand(2, 2) # transition


# define noise as present in KF
Q = torch.eye(2) * torch.pow(torch.randn(2), 2)
# if we assume white noise PSD and covariance are the same
Qc = Q

interval = torch.Tensor([0, 0.3])
t = interval[1]

def discretize(F, Q, Qc, t):
    A = torch.matrix_exp(F * t)

    n = F.shape[0]

    # matrix fraction decomposition
    M = torch.zeros(2*n, 2*n)
    M[:n,:n] = F
    # Formula says L @ Qc @ L but L is identity
    M[:n,n:] = Qc
    M[n:,n:] = -F.T
    M = torch.matrix_exp(M * t) @ torch.cat([torch.zeros(n, n), torch.eye(n, n)])
    C, D = M[:n], M[n:]
    L = C @ torch.inverse(D)

    return A, L

A, L = discretize(F, Q, Qc, t)
print(L)

def discrete_step(m, P):
    m = A @ m
    P = A @ P @ A.T + L
    return m, P

def ode_func(t, state):
    m, P = state
    dm = F @ m
    dP = F @ P + P @ F.T + Q
    return dm, dP

tensor([[0.1150, 0.0023],
        [0.0023, 0.0329]])


In [3]:
m, P = torch.rand(2, 1), torch.rand(2, 2)
solution = odeint(ode_func, (m, P), interval)
m1, P1 = solution[0][1], solution[1][1]
print('ODE solution')
print('m:', m1.flatten())
print('P:', P1.flatten())

m2, P2 = discrete_step(m, P)
print('Discretized solution')
print('m:', m2.flatten())
print('P:', P2.flatten())

assert torch.isclose(m1, m2).all() 
assert torch.isclose(P1, P2).all()

ODE solution
m: tensor([0.7024, 0.6069])
P: tensor([1.1578, 0.3395, 1.0126, 0.0768])
Discretized solution
m: tensor([0.7024, 0.6069])
P: tensor([1.1578, 0.3395, 1.0126, 0.0768])
