In [30]:
import numpy as np
from numpy import linalg
import scipy as sp
# from scipy import linalg
import matplotlib.pyplot as plt
%matplotlib inline

In [155]:
# System for x_{t+1} = A @ Z(x_t) + B * u_t + w_t (see sim_forward).
A = np.identity(2)          # 2x2 identity applied to the lifted features Z(x)
B = np.array([0.0, 1.0])    # input gain: u_t is added to the second state component only


def Z(x):
    """Lifted (nonlinear) features of state ``x``: ``[x[1], x[0]**2]``."""
    return np.array([x[1], x[0] ** 2])


def u(Z, x):
    """Exploration input: a single standard-normal draw; ``Z`` and ``x`` are ignored."""
    return np.array([np.random.randn()])

def sim_forward(A, B, Z, u, x_0, sigma_w, n_steps):
    """Simulate x_{t+1} = A @ Z(x_t) + B * u_t + w_t for ``n_steps`` transitions.

    Parameters
    ----------
    A : array
        Matrix applied (``@``) to the feature vector Z(x_t).
    B : array
        Input gain, multiplied elementwise with u_t (numpy broadcasting).
    Z : callable
        ``Z(x) -> 1-D array`` of features of the state.
    u : callable
        ``u(Z, x) -> 1-D array`` input; receives the feature map and current state.
    x_0 : 1-D array
        Initial state.
    sigma_w : float or 1-D array
        Process-noise standard deviation. Each state component receives its
        own independent N(0, sigma_w**2) draw at every step.
    n_steps : int
        Number of transitions to simulate (0 is allowed).

    Returns
    -------
    xs : array, shape (n_steps + 1, len(x_0))
        Visited states, with ``xs[0] == x_0``.
    Zs : array, shape (n_steps, len(Z(x_0)))
        Features ``Z(xs[t])`` used for transition t.
    us : array, shape (n_steps, len(u(Z, x_0)))
        Inputs applied at each step.
    """
    n_x = len(x_0)
    xs = np.zeros((n_steps + 1, n_x))
    xs[0] = x_0
    # Probe Z and u once to size the buffers. NOTE: with a stochastic u this
    # extra call consumes one draw from the global RNG.
    Zs = np.zeros((n_steps, Z(x_0).shape[0]))
    us = np.zeros((n_steps, u(Z, x_0).shape[0]))
    for t in range(n_steps):
        us[t] = u(Z, xs[t])
        Zs[t] = Z(xs[t])
        # One independent noise sample per state component, so the components'
        # disturbances are not perfectly correlated.
        w = sigma_w * np.random.randn(n_x)
        xs[t + 1] = A @ Zs[t] + B * us[t] + w
    # Arrays are already 2-D; returning them directly also avoids np.stack
    # failing on the empty (n_steps == 0) buffers.
    return xs, Zs, us

In [161]:
# Generate n_runs rollouts of n_steps transitions from x_0 = 0 and flatten them
# (run-major order) into a regression dataset: (zs, us) -> next state.
np.random.seed(0)  # fix the global RNG so rollouts and fitted parameters are reproducible
n_runs = 25
n_steps = 10
sigma_w = 0.1      # process-noise standard deviation passed to sim_forward
x_0 = np.zeros(2)
xs = np.zeros((n_runs, n_steps+1, len(x_0)))
zs = np.zeros((n_runs, n_steps, Z(x_0).shape[0]))
us = np.zeros((n_runs, n_steps, u(Z, x_0).shape[0]))
for i in range(n_runs):
    xs[i], zs[i], us[i] = sim_forward(A, B, Z, u, x_0, sigma_w, n_steps)
# Regression targets are the successor states x_{t+1}, i.e. every state except
# each run's initial one; rows line up with the flattened zs / us below.
xs_reshaped = xs[:, 1:].reshape(n_runs*n_steps, -1)
zs = zs.reshape(n_runs*n_steps, -1)
us = us.reshape(n_runs*n_steps, -1)

In [77]:
# Inspect the over-complete design matrix [Zhat(x_t), u_t].
# NOTE(review): the output recorded below is from In [77] and has 6 columns, while
# the current Zhat basis cell (In [517]) yields 9 features + 1 input = 10 columns.
# It is stale; re-run after the basis cell.
Zhats

array([[ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         0.00000000e+00,  0.00000000e+00,  2.06567521e-01],
       [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         0.00000000e+00,  0.00000000e+00, -1.29558065e+00],
       [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         0.00000000e+00,  0.00000000e+00,  5.10770803e-01],
       [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         0.00000000e+00,  0.00000000e+00,  1.62314601e+00],
       [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         0.00000000e+00,  0.00000000e+00, -8.26241176e-01],
       [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         0.00000000e+00,  0.00000000e+00,  9.14260770e-01],
       [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         0.00000000e+00,  0.00000000e+00,  1.47438278e+00],
       [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         0.00000000e+00,  0.00000000e+00, -2.51777662e-01],
       [ 0.00000000e+00,  0.00000000e+00,  0.000

In [517]:
# Design matrix for the true features: one row [Z(x_t), u_t] per transition.
Zs = np.concatenate([zs, us], axis=1)


# Alternative candidate basis: polynomial features that include the true
# ones (x[1] and x[0]**2) plus extra monomials.
def Zhat(x):
    return np.array([
        x[0], x[1], x[0] * x[1],
        x[0] ** 2, x[1] ** 2,
        x[0] ** 3, x[0] ** 4, x[0] ** 5, x[0] ** 6,
    ])


# Evaluate Zhat on every pre-transition state (all runs, every step but the
# last) in run-major order, so rows line up with zs / us / xs_reshaped.
zhats = np.array(
    [Zhat(state) for state in xs[:, :-1].reshape(-1, xs.shape[-1])]
).reshape(n_runs * n_steps, -1)
Zhats = np.concatenate([zhats, us], axis=1)

In [179]:
# Least-squares estimate of [A | B] from x_{t+1} ≈ [A | B] @ [Z(x_t); u_t]:
# rows = next-state components, columns = Z features followed by u.
# lstsq solves min ||Zs @ Theta - xs_reshaped|| via SVD rather than inverting the
# normal-equation matrix Zs.T @ Zs (whose condition number is the square of Zs's),
# and it still returns a solution if Zs is rank-deficient.
linalg.lstsq(Zs, xs_reshaped, rcond=None)[0].T


array([[ 1.00000022e+00, -7.12883225e-08, -8.27746845e-04],
       [ 2.15819261e-07,  9.99999929e-01,  9.99172253e-01]])

In [518]:
# Least-squares fit on the over-complete basis. Rows of fake_parms follow the
# columns of Zhats (the Zhat features, then u); columns are next-state components.
# lstsq (SVD-based) replaces inv(Zhats.T @ Zhats): with monomials up to x[0]**6 the
# columns differ by many orders of magnitude, so the normal-equation matrix is badly
# conditioned and explicitly inverting it amplifies rounding error.
fake_parms = linalg.lstsq(Zhats, xs_reshaped, rcond=None)[0]
fake_parms

array([[-2.76670014e-03, -8.45530337e-03],
       [ 1.00000520e+00,  8.02200706e-03],
       [ 1.98333562e-07,  1.56895660e-05],
       [ 8.94061496e-05,  1.00010187e+00],
       [-3.14737550e-11,  3.08067663e-09],
       [-6.63905644e-07, -6.77616811e-07],
       [ 1.77673756e-09,  1.77949367e-09],
       [-1.67643591e-12, -1.67784176e-12],
       [ 4.98963478e-16,  4.99267403e-16],
       [ 1.62776475e-03,  1.00162754e+00]])

In [141]:
# Inspect the raw rollouts, shape (n_runs, n_steps + 1, state_dim).
# NOTE(review): the output below was recorded at In [141], before the simulation
# cell (In [161]) last ran, so the values may be stale. Re-run after simulating.
xs

array([[[ 0.00000000e+00,  0.00000000e+00],
        [ 0.00000000e+00, -3.93243171e-01],
        [-3.93243171e-01,  1.18857308e+00],
        [ 1.18857308e+00,  1.31976340e+00],
        [ 1.31976340e+00,  1.71569590e-01],
        [ 1.71569590e-01,  1.71741569e+00],
        [ 1.71741569e+00, -1.11175642e+00],
        [-1.11175642e+00,  1.52392777e+00],
        [ 1.52392777e+00,  1.21841198e+00],
        [ 1.21841198e+00,  3.06482545e+00],
        [ 3.06482545e+00,  1.40223573e+00]],

       [[ 0.00000000e+00,  0.00000000e+00],
        [ 0.00000000e+00, -3.90318843e-01],
        [-3.90318843e-01,  1.36426001e+00],
        [ 1.36426001e+00, -1.91148603e+00],
        [-1.91148603e+00,  1.45473248e+00],
        [ 1.45473248e+00,  3.54907193e+00],
        [ 3.54907193e+00,  3.08185888e+00],
        [ 3.08185888e+00,  1.35645876e+01],
        [ 1.35645876e+01,  7.21081178e+00],
        [ 7.21081178e+00,  1.83762350e+02],
        [ 1.83762350e+02,  5.05944948e+01]],

       [[ 0.00000000e+00,  0

Enforcing sparsity on the fitted coefficients is probably unnecessary here, but it may help when the basis is very large or there is not enough data to fit all of its coefficients.

In [516]:
# Spectral norm (largest singular value) of 2*W for a 100 x 40 matrix whose
# entries are i.i.d. U[0, 1) draws (np.random.rand).
W = np.random.rand(100, 40)
# Same quantity as max(sqrt(eigvals(4 * W.T @ W))), but computed from the SVD of
# 2*W. This avoids forming the Gram matrix (which squares the condition number)
# and the general eigensolver, whose round-off can yield tiny negative/complex
# eigenvalues that np.sqrt turns into NaN.
# NOTE(review): the sqrt(n) + sqrt(m) + sqrt(2 log(1/delta)) bound evaluated next
# holds for zero-mean, unit-variance Gaussian entries. U[0, 1) entries have mean
# 0.5 (and are scaled by 2 here), so the two numbers are not comparable.
# Presumably np.random.randn was intended — confirm.
sigma_max = linalg.norm(2 * W, 2)
sigma_max

62.712079119103414

In [471]:
# Upper bound on the largest singular value of an n_rows x n_cols matrix with
# i.i.d. standard Gaussian entries: sigma_max <= sqrt(n_rows) + sqrt(n_cols) + t,
# which fails with probability at most exp(-t**2 / 2). Choosing
# t = sqrt(2 log(1/delta)) makes the failure probability at most delta.
n_rows, n_cols = 100, 40   # should match the shape of W above
delta = 0.5                # allowed failure probability
gaussian_sv_bound = np.sqrt(n_rows) + np.sqrt(n_cols) + np.sqrt(2 * np.log(1 / delta))
gaussian_sv_bound

17.501965342852234