In [None]:
from dempy import *
import sympy
import torch
from torch import distributions as td
from torch import nn
import numpy as np
import plotly.graph_objs as go
import plotly.express as px
import pykeos as pk
# from easystate import EasyState

In [None]:
class KeyAccess:
    """Name-based view over a positionally indexed container.

    An instance is built from an ordered collection of key names. Calling it
    with a container (``accessor(container)``) binds that container and
    returns the accessor itself, after which ``accessor['name']`` reads or
    writes ``container[position_of_name]``. Any key that is not a known name
    (an int, a slice, a list of indices, ...) is forwarded to the container
    unchanged.
    """

    def __init__(self, keys):
        """Record the position of every name in ``keys``.

        Args:
            keys: iterable of hashable names; when a name is repeated, the
                position of its first occurrence is used.
        """
        # Single pass; ``setdefault`` keeps the first position of a repeated
        # name, which is what ``keys.index(name)`` would have returned.
        self._keys = {}
        for position, name in enumerate(keys):
            self._keys.setdefault(name, position)
        self._it = None

    def __call__(self, it):
        """Bind the container ``it`` and return ``self`` so calls can be chained."""
        self._it = it
        return self

    def _resolve(self, key):
        """Translate a known name into its position; return any other key as is."""
        try:
            return self._keys[key]
        except (KeyError, TypeError):
            # KeyError: not a registered name.
            # TypeError: unhashable key (e.g. a list of indices), which can
            # never be a registered name and must reach the container untouched.
            return key

    def __getitem__(self, key):
        """Return ``container[position]`` for a known name, else ``container[key]``."""
        index = self._resolve(key)
        try:
            return self._it[index]
        except KeyError:
            # Mapping-like container that does not know the resolved position:
            # retry with the raw key.
            return self._it[key]

    def __setitem__(self, key, value):
        """Assign ``value`` at the position of a known name, else at ``key`` itself."""
        index = self._resolve(key)
        try:
            self._it[index] = value
        except KeyError:
            self._it[key] = value

In [None]:
from collections import OrderedDict
from functools import partial

class JansenRit:
    """Jansen-Rit style neural-mass model: symbolic flow, observer and priors.

    The twelve model parameters are stored as a column vector of natural
    logarithms (``params``); ``fsymb`` exponentiates them again, so every
    parameter is positive by construction.

    Attributes set by ``__init__``:
        params, pE: (12, 1) array of log-parameters (``pE`` is the same array object).
        pC: length-12 vector filled with ``exp(-32)``.
        V:  length-``m`` vector filled with ``exp(8)``.
        W:  length-``n`` vector filled with ``exp(16)``.
        m, n, l, p: sizes 1, 6, 1 and 12 — ``fsymb`` returns 6 rows and ``g``
            returns a 1x1 output; presumably m = inputs, n = states,
            l = outputs, p = parameters (confirm against the consumer of these fields).
    """

    def __init__(self, dt=1e-3):
        """Create the model.

        Args:
            dt: factor applied to the flow returned by ``fsymb``. Defaults to
                ``1e-3``, the value that was previously hard-coded.
        """
        self._params_init = OrderedDict(A=3.25, B=22., a_inv=10., b_inv=20., C=135., C1rep=1., C2rep=0.8,
                        C3rep=0.25, C4rep=0.25, vmax=5., v0=6., r=0.56)
        self._params_keys = [*self._params_init.keys()]
        self._dt = dt

        # Log-transform the initial values and store them as a (12, 1) column.
        self.params = np.log(np.array([*self._params_init.values()]).reshape((-1, 1)))

        self.m = 1
        self.p = len(self._params_init)
        self.l = 1
        self.n = 6

        # pE aliases params (no copy is made).
        self.pE = self.params
        self.pC = np.ones(len(self._params_init)) * np.exp(-32)
        self.V  = np.exp(8) * np.ones(self.m)
        self.W  = np.exp(16) * np.ones(self.n)

    @staticmethod
    def sigmoid(x, vmax, v0, r):
        """Numeric sigmoid ``vmax / (1 + exp(r * (v0 - x)))``."""
        return vmax / (1. + np.exp(r * (v0 - x)))

    def fsymb(self, x, v, p):
        """Build the symbolic flow of the model, multiplied by ``self._dt``.

        Args:
            x: 6x1 symbolic state column.
            v: 1x1 symbolic input column.
            p: 12x1 symbolic column of log-parameters, in the order of
                ``self._params_init``.

        Returns:
            ``sympy.Matrix`` with 6 rows: the time derivative of each state
            times ``self._dt``.
        """
        # Work on the first (only) column of each argument.
        x, v, p = x[:, 0], v[:, 0], p[:, 0]
        # Parameters arrive in log space; exponentiate to get the actual values.
        A,B,a_inv,b_inv,C,C1rep,C2rep,C3rep,C4rep,vmax,v0,r = [sympy.exp(_) for _ in p]

        def sigm(x):
            # Symbolic counterpart of ``JansenRit.sigmoid``.
            return vmax / (1. + sympy.exp(r * (v0 - x)))

        # Rate constants from the inverse constants (1e3 / value).
        a, b = 1e3/a_inv, 1e3/b_inv
        # Connectivity constants are expressed as fractions of C.
        C1, C2, C3, C4 = C*C1rep, C*C2rep, C*C3rep, C*C4rep

        # States 0..2 integrate states 3..5 (second-order equations written
        # as a first-order system).
        dx0 = x[3]
        dx1 = x[4]
        dx2 = x[5]
        dx3 = A * a * sigm(x[1] - x[2]) - 2 * a * x[3] - x[0] * a**2
        dx4 = A * a * (v[0] + C2 * sigm(C1 * x[0])) - 2 * a * x[4] - x[1] * a ** 2
        dx5 = B * b * C4 * sigm(C3 * x[0]) - 2 * b * x[5] -  x[2] * b ** 2

        return sympy.Matrix([dx0, dx1, dx2, dx3, dx4, dx5]) * self._dt

    def g(self, x, v, p):
        """Observer: ``x[1] - x[2]`` reshaped to ``(1, 1)``; ``v`` and ``p`` are unused."""
        return (x[1] - x[2]).reshape((1, 1))

In [None]:
# Symbolic column vectors handed to JansenRit.fsymb:
# x has 6 rows, v has 1 row and p has 12 rows.
x, v, p = (
    sympy.MatrixSymbol(label, rows, 1)
    for label, rows in (('x', 6), ('v', 1), ('p', 12))
)

In [None]:
# Compile the symbolic Jansen-Rit flow into a numpy-backed callable func(x, v, p).
jrit = JansenRit()
func = sympy.lambdify(
    (x, v, p),
    jrit.fsymb(x, v, p),
    modules='numpy',
)

In [None]:
class Ueda:
    """Three-state forced oscillator with a cubic restoring term.

    With states ``(x0, x1, x2)`` the flow is::

        dx0 = x1
        dx1 = -x0**3 - 0.05 * x1 + 7.5 * sin(x2)
        dx2 = 1

    so ``x2`` acts as the phase of the sinusoidal forcing.
    """

    def __init__(self, dt=1e-3):
        """Create the model.

        Args:
            dt: factor applied to the flow returned by ``f``. Defaults to
                ``1e-3``, the value that was previously hard-coded.
        """
        self._dt = dt

    def f(self, x, v, p):
        """Return the flow at state ``x`` multiplied by ``self._dt``.

        Args:
            x: array-like whose first axis has (at least) the three states;
                both ``(3,)`` and ``(3, 1)`` inputs work, and the output has
                the same layout.
            v, p: unused.

        NOTE(review): states are taken along the first axis, matching ``g``;
        confirm this is the layout the caller passes.
        """
        x = np.asarray(x)
        pos, vel, phase = x[0], x[1], x[2]
        flow = np.stack([
            vel,
            - pos ** 3 - 0.05 * vel + 7.5 * np.sin(phase),
            # The phase advances at unit rate; ones_like keeps the shapes
            # stackable (a bare scalar would not be).
            np.ones_like(phase),
        ], axis=0)
        return flow * self._dt

    def g(self, x, v, p):
        """Observer: ``40 * x[0] + 240`` wrapped in a length-1 array; ``v`` and ``p`` are unused."""
        return np.array([40 * x[0] + 240.])
    
# class Rossler: 
#     def __init__(self, dt=0.001): 
#         self.dt = dt
        
#     def ode(self, x): 
#         x,y,z = [x[...,i] for i in range(x.shape[-1])]
#         return torch.stack([
#             - y - z,
#             x + 0.1 * y,
#             0.1 + z * (x - 14.)
#         ], dim=-1)

#     def __call__(self, x): 
#         return x + self.dt * self.ode(x)

In [None]:
# Wrap the Jansen-Rit definition (symbolic flow, observer, sizes and priors)
# in a GaussianModel.
# NOTE(review): `jrit` is rebound from the JansenRit instance to the
# GaussianModel here, so every later `jrit.<attr>` reads from the
# GaussianModel rather than from JansenRit.
jrit = JansenRit()
jrit = GaussianModel(
    fsymb=jrit.fsymb, 
    g=jrit.g,
    m=jrit.m, 
    n=jrit.n, 
    l=jrit.l,
    p=jrit.p,
    pE=jrit.pE,
    pC=jrit.pC,
    V=jrit.V, 
    W=jrit.W
)
# NOTE(review): the Ueda() instance is discarded by the very next statement,
# and the GaussianModel below is built entirely from `jrit` attributes (with
# `jrit.fsymb` passed as `f`). This looks like a copy-paste slip — confirm the
# intended flow, observer and sizes for the Ueda level.
ueda = Ueda()
ueda = GaussianModel(
    f=jrit.fsymb, 
    g=jrit.g,
    m=jrit.m, 
    n=jrit.n, 
    l=jrit.l,
    p=jrit.p,
    pE=jrit.pE,
    pC=jrit.pC,
    V=jrit.V, 
    W=jrit.W
)

# Hierarchical model with the Jansen-Rit level only; `ueda` is not passed in.
hdm = HierarchicalGaussianModel(jrit)

In [None]:
# Generate synthetic data from the hierarchical model.
# Presumably 1000 is the number of time samples — confirm against DEMInversion.generate.
gen = DEMInversion(hdm).generate(1000)

In [None]:
# Line plot of the generated series gen.v[:, 0, 0, 0].
# NOTE(review): the trailing ';' suppresses the cell output, so the figure is
# built but not displayed — confirm this is intended.
px.line(y=gen.v[:, 0, 0, 0]);

In [None]:
# Invert the model on the generated series; the trailing `None` index appends
# a length-1 axis to gen.v[:, 0, 0, 0]. The result is only shown as cell output.
DEMInversion(hdm).run(gen.v[:, 0, 0, 0, None])

In [None]:
# NOTE(review): both constructors receive one positional argument (presumably
# the step size dt) — confirm that the JansenRit and Ueda classes in scope
# accept it.
jansen_rit = JansenRit(0.001)
ueda = Ueda(0.008)

# Level 0: 3-state system driven by `ueda`, used to generate data.
# NOTE(review): GaussianSystem, LinearizedTransform and LinearTransform are not
# defined in the visible cells; presumably they come from the `dempy` star
# import — verify.
input_model = GaussianSystem(
    state_dim=3, 
    obs_dim=1, 
    fwd_transform=LinearizedTransform(ueda), 
    # Presumably (weight, bias): 40 * x[0] + 240, the same map as Ueda.g.
    obs_transform=LinearTransform(torch.tensor([[40., 0., 0.]]), torch.tensor([240.])), 
    initial_state_mean=torch.tensor([2., 2., 0.]), 
    initial_state_cov=np.exp(-128) * torch.eye(3),
    process_noise_cov=np.exp(-32) * torch.eye(3), 
    obs_noise_cov=np.exp(2) * torch.eye(1),
)

# Level 1: 6-state system driven by `jansen_rit`, with a 1-dimensional input.
cortical_model = GaussianSystem(
    fwd_transform=LinearizedTransform(jansen_rit), 
    # Row [0, 1, -1, 0, 0, 0] selects x[1] - x[2], the same map as JansenRit.g.
    obs_transform=LinearTransform(torch.tensor([[0., 1., -1., 0., 0., 0.]])), 
    input_dim=1,
    state_dim=6, 
    obs_dim=1, 
    obs_noise_cov=np.exp(2) * torch.eye(1), 
    process_noise_cov=np.exp(2) * torch.eye(6), 
    initial_state_mean=torch.zeros(6), 
    initial_state_cov=np.exp(-6) * torch.eye(6),
)

In [None]:
# Stack the two generative systems, in the order [input_model, cortical_model].
state_space = HierarchicalDynamicalModel(systems=[input_model, cortical_model]) 

In [None]:
# Forecast with the generative model and plot the 'y' entry of each level.
# Presumably 4500 is the number of steps — confirm against blind_forecast.
# `[None]` prepends a length-1 axis before plotting.
traj = state_space.blind_forecast(4500)
# plot_traj(traj[2]['y'][None])
plot_traj(traj[0]['y'][None]).show()
plot_traj(traj[1]['y'][None]).show()

In [None]:
# Sample from the generative model (presumably 4500 steps — confirm against
# sample) and plot the first output channel of each level. `y` is reused by
# the filtering and plotting cells that follow.
y = state_space.sample(4500)
px.line(y=y[0]['y'][:, 0].detach()).show()
px.line(y=y[1]['y'][:, 0].detach()).show()

In [None]:

# Decoding copy of the level-0 system. Compared with `input_model`, only the
# initial state differs: mean [1, 3, 0] instead of [2, 2, 0], and covariance
# exp(0) instead of exp(-128).
input_model_ = GaussianSystem(
    state_dim=3, 
    obs_dim=1, 
    fwd_transform=LinearizedTransform(ueda), 
    obs_transform=LinearTransform(torch.tensor([[40., 0., 0.]]), torch.tensor([240.])), 
    initial_state_mean=torch.tensor([1., 3., 0.]), 
    initial_state_cov=np.exp(0) * torch.eye(3),
    process_noise_cov=np.exp(-32) * torch.eye(3), 
    obs_noise_cov=np.exp(2) * torch.eye(1),
)

# Decoding copy of the level-1 system. Compared with `cortical_model`, the
# observation noise covariance is exp(4) instead of exp(2).
cortical_model_ = GaussianSystem(
    fwd_transform=LinearizedTransform(jansen_rit), 
    obs_transform=LinearTransform(torch.FloatTensor([[0., 1., -1., 0., 0., 0.]])), 
    input_dim=1,
    state_dim=6, 
    obs_dim=1, 
    obs_noise_cov=np.exp(4) * torch.eye(1), 
    process_noise_cov=np.exp(2) * torch.eye(6), 
    initial_state_mean=torch.zeros(6), 
    initial_state_cov=np.exp(-6) * torch.eye(6),
)

# Hierarchical model that the later cell uses for filtering.
dec_state_space = HierarchicalDynamicalModel(systems=[input_model_, cortical_model_]) 

In [None]:
# Filter the sampled level-1 output (with a leading length-1 axis added by
# `[None]`) using the decoding model; backward_pass=True presumably also
# produces the 'x_backward' entries read by the plotting cells — confirm.
filter_traj = dec_state_space.filter(y[1]['y'][None], backward_pass=True)

In [None]:
# 'x_backward' estimate of the first two level-0 states, with the sampled
# states overlaid as dashed lines.
fig = plot_traj(Gaussian(filter_traj[0]['x_backward'].mean[..., :2], filter_traj[0]['x_backward'].covariance_matrix[..., :2, :2]))
for i in range(2):
    # detach() for consistency with the other cells that plot sampled tensors;
    # it is a no-op when the tensor does not track gradients.
    fig.add_scatter(y=y[0]['x'][:, i].detach(), line_color=px.colors.qualitative.T10[i], line_dash='dash', name=f'x[{i}]')
fig.update_layout(template='plotly_white',
        title={
         'text':'hidden states (ueda)',
         'x':0.5,
         'xanchor': 'center'}, height=600, width=800)
fig.update_yaxes(title_text='states (a.u.)', range=(-15, 15))
fig.update_xaxes(title_text='time')
fig.show()

In [None]:
# Instantiate pykeos' Rossler system.
# NOTE(review): `r` is not used in any of the visible cells.
r=pk.Rossler()

In [None]:
# Plot the first two level-0 states via pykeos: the 'x_backward' mean first,
# then the sampled states on the same figure.
fig = pk.SysWrapper(filter_traj[0]['x_backward'].mean[0, :, :2].detach()).plot(line_color=px.colors.qualitative.T10[1], opacity=0.7, show=False, )
# detach() mirrors the call above so both wrappers receive gradient-free
# tensors; it is a no-op when the tensor does not track gradients.
pk.SysWrapper(y[0]['x'][:, :2].detach()).plot(fig=fig, line_color=px.colors.qualitative.T10[0], opacity=0.7, show=False)
fig.update_layout(template='simple_white', width=800, height=600)
fig

In [None]:
# Level-0 'y_prior' trajectory with the sampled output overlaid as a dashed line.
fig = plot_traj(filter_traj[0]['y_prior'])
fig.add_scatter(
    y=y[0]['y'][:, 0].detach(),
    name='x[0]',
    line_dash='dash',
    line_color=px.colors.qualitative.T10[0],
)
fig.update_layout(
    title=dict(text='ueda output', x=0.5, xanchor='center'),
    template='plotly_white',
    width=600,
    height=600,
)
fig.update_xaxes(title_text='time')
fig.update_yaxes(title_text='states (hz)', range=(0, 450))
fig.show()

In [None]:

# 'x_backward' estimate of the first three level-1 states, with the sampled
# states overlaid as dashed lines.
fig = plot_traj(
    Gaussian(
        filter_traj[1]['x_backward'].mean[..., :3],
        filter_traj[1]['x_backward'].covariance_matrix[..., :3, :3],
    )
)
for i in range(3):
    fig.add_scatter(
        y=y[1]['x'][:, i].detach(),
        name=f'x[{i}]',
        line_dash='dash',
        line_color=px.colors.qualitative.T10[i],
    )
fig.update_layout(
    title=dict(text='causal states', x=0.5, xanchor='center'),
    template='plotly_white',
    width=800,
    height=600,
)
fig.update_xaxes(title_text='time')
fig.update_yaxes(title_text='states (hz)', range=(-75, 75))
fig.show()

In [None]:

# Level-1 'y_prior' trajectory with the sampled output overlaid as a dashed line.
fig = plot_traj(filter_traj[1]['y_prior'])
fig.add_scatter(
    y=y[1]['y'][:, 0].detach(),
    name=f'x[{0}]',
    opacity=0.7,
    line_dash='dash',
    line_color=px.colors.qualitative.T10[1],
)
fig.update_layout(
    title=dict(text='jansen-rit output', x=0.5, xanchor='center'),
    template='plotly_white',
    width=800,
    height=600,
)
fig.update_xaxes(title_text='time')
fig.update_yaxes(title_text='states (hz)', range=(-70, 70))
fig.show()