In [1]:
from torchkf import *
import torch
from torch import distributions as td
from torch import nn
import numpy as np
import plotly.graph_objs as go
import plotly.express as px
import pykeos as pk
# from easystate import EasyState

pyunicorn: Package netCDF4 could not be loaded. Some functionality in class Data might not be available!
pyunicorn: Package netCDF4 could not be loaded. Some functionality in class NetCDFDictionary might not be available!


In [2]:
class JansenRit(nn.Module): 
    
    def __init__(self, dt=0.001, prior_scale=np.exp(-4)): 
        super().__init__()
        
        self.params = nn.ParameterDict(
            {k:nn.Parameter(torch.tensor(np.log(v).astype(np.float32), requires_grad=True))
            for k, v in dict(
                A=3.25, 
                B=22., 
                a_inv=10., # aka Te
                b_inv=20., # ala Ti
                C=135., 
                C1rep=1., 
                C2rep=0.8, 
                C3rep=0.25, 
                C4rep=0.25, 
                vmax=5., 
                v0=6., 
                r=0.56,
            ).items()
        })

        
        self.priors = dict()
        for k, v in self.params.items(): 
            if len(v.shape) == 0: 
#                 self.priors[k] = td.LogNormal(torch.log(torch.tensor([v.detach().item()]) - prior_scale / 2.), torch.tensor([[prior_scale]]))
                self.priors[k] = Gaussian(torch.tensor([np.exp(v.detach().item())]), torch.tensor([[prior_scale]]))

                
        self.dt = dt
    
    def prior_log_prob(self): 
        prob = 0
        for k in self.params.keys(): 
            prob += self.priors[k].log_prob(torch.exp(self.params[k].expand(self.priors[k].event_shape)))
        return prob
    
    def ode(self, x): 
        x0, x1, x2, x3, x4, x5, p = [x[..., i] for i in range(x.shape[-1])]
        A,B,a_inv,b_inv,C,C1rep,C2rep,C3rep,C4rep,vmax,v0,r = (torch.exp(self.params[k]) for k in ('A','B','a_inv','b_inv','C','C1rep','C2rep','C3rep','C4rep','vmax','v0','r'))
        a, b = 1e3/a_inv, 1e3/b_inv # Convert to seconds
        C1,C2,C3,C4 = [Crep * C for Crep in [C1rep,C2rep,C3rep,C4rep]]

        def sigm(x): 
            return vmax / (1. + torch.exp(r * (v0 - x)))
                                            
        return torch.stack([
            x3, x4, x5, 
            A * a * sigm(x1 - x2) - 2. * a * x3 - a**2. * x0,
            A * a * (p + C2  * sigm(C1 * x0)) - 2. * a * x4 - a**2. * x1, 
            B * b * C4 * sigm(C3 * x0) - 2. * b * x5 - b ** 2. * x2
        ], dim=-1)
    
    def __call__(self, x): 
        return x[..., :-1] + self.dt * self.ode(x)
    
class Ueda: 
    def __init__(self, dt=0.001): 
        self.dt = dt
        
    def ode(self, x): 
        x,y,z = [x[...,i] for i in range(x.shape[-1])]
        return torch.stack([
            y,
            - x ** 3 - 0.05 * y + 7.5 * torch.sin(z),
            torch.ones(y.shape)
        ], dim=-1)

    def __call__(self, x): 
        return x + self.dt * self.ode(x)
    
class Rossler: 
    def __init__(self, dt=0.001): 
        self.dt = dt
        
    def ode(self, x): 
        x,y,z = [x[...,i] for i in range(x.shape[-1])]
        return torch.stack([
            - y - z,
            x + 0.1 * y,
            0.1 + z * (x - 14.)
        ], dim=-1)

    def __call__(self, x): 
        return x + self.dt * self.ode(x)

In [6]:
jansen_rit = JansenRit(0.001)
ueda = Ueda(0.008)

input_model = GaussianSystem(
    state_dim=3, 
    obs_dim=1, 
    fwd_transform=LinearizedTransform(ueda), 
    obs_transform=LinearTransform(torch.tensor([[40., 0., 0.]]), torch.tensor([240.])), 
    initial_state_mean=torch.tensor([2., 2., 0.]), 
    initial_state_cov=np.exp(-128) * torch.eye(3),
    process_noise_cov=np.exp(-32) * torch.eye(3), 
    obs_noise_cov=np.exp(2) * torch.eye(1),
)

cortical_model = GaussianSystem(
    fwd_transform=LinearizedTransform(jansen_rit), 
    obs_transform=LinearTransform(torch.tensor([[0., 1., -1., 0., 0., 0.]])), 
    input_dim=1,
    state_dim=6, 
    obs_dim=1, 
    obs_noise_cov=np.exp(2) * torch.eye(1), 
    process_noise_cov=np.exp(2) * torch.eye(6), 
    initial_state_mean=torch.zeros(6), 
    initial_state_cov=np.exp(-6) * torch.eye(6),
)

In [7]:
state_space = HierarchicalDynamicalModel(systems=[input_model, cortical_model]) 

In [8]:
traj = state_space.blind_forecast(4500)
# plot_traj(traj[2]['y'][None])
plot_traj(traj[0]['y'][None]).show()
plot_traj(traj[1]['y'][None]).show()

AttributeError: module 'torch' has no attribute 'concat'

In [None]:
y = state_space.sample(4500)
px.line(y=y[0]['y'][:, 0].detach()).show()
px.line(y=y[1]['y'][:, 0].detach()).show()

In [None]:

input_model_ = GaussianSystem(
    state_dim=3, 
    obs_dim=1, 
    fwd_transform=LinearizedTransform(ueda), 
    obs_transform=LinearTransform(torch.tensor([[40., 0., 0.]]), torch.tensor([240.])), 
    initial_state_mean=torch.tensor([1., 3., 0.]), 
    initial_state_cov=np.exp(0) * torch.eye(3),
    process_noise_cov=np.exp(-32) * torch.eye(3), 
    obs_noise_cov=np.exp(2) * torch.eye(1),
)

cortical_model_ = GaussianSystem(
    fwd_transform=LinearizedTransform(jansen_rit), 
    obs_transform=LinearTransform(torch.FloatTensor([[0., 1., -1., 0., 0., 0.]])), 
    input_dim=1,
    state_dim=6, 
    obs_dim=1, 
    obs_noise_cov=np.exp(4) * torch.eye(1), 
    process_noise_cov=np.exp(2) * torch.eye(6), 
    initial_state_mean=torch.zeros(6), 
    initial_state_cov=np.exp(-6) * torch.eye(6),
)

dec_state_space = HierarchicalDynamicalModel(systems=[input_model_, cortical_model_]) 

In [None]:
filter_traj = dec_state_space.filter(y[1]['y'][None], backward_pass=True)

In [None]:
fig = plot_traj(Gaussian(filter_traj[0]['x_backward'].mean[..., :2], filter_traj[0]['x_backward'].covariance_matrix[..., :2, :2]))
for i in range(2):     
    fig.add_scatter(y=y[0]['x'][:, i], line_color=px.colors.qualitative.T10[i], line_dash='dash', name=f'x[{i}]')
fig.update_layout(template='plotly_white', 
        title={
         'text':'hidden states (ueda)',
         'x':0.5,
         'xanchor': 'center'}, height=600, width=800)
fig.update_yaxes(title_text='states (a.u.)', range=(-15, 15))
fig.update_xaxes(title_text='time')
fig.show()

In [None]:
r=pk.Rossler()

In [None]:
fig = pk.SysWrapper(filter_traj[0]['x_backward'].mean[0, :, :2].detach()).plot(line_color=px.colors.qualitative.T10[1], opacity=0.7, show=False, )
pk.SysWrapper(y[0]['x'][:, :2]).plot(fig=fig, line_color=px.colors.qualitative.T10[0], opacity=0.7, show=False)
fig.update_layout(template='simple_white', width=800, height=600)
fig

In [None]:
fig = plot_traj(filter_traj[0]['y_prior'])
fig.add_scatter(y=y[0]['y'][:, 0].detach(), line_color=px.colors.qualitative.T10[0], line_dash='dash', name='x[0]')
fig.update_layout(template='plotly_white', 
        title={
         'text':'ueda output',
         'x':0.5,
         'xanchor': 'center'}, height=600, width=600)
fig.update_yaxes(title_text='states (hz)', range=(0, 450))
fig.update_xaxes(title_text='time')
fig.show()

In [None]:

fig = plot_traj(Gaussian(filter_traj[1]['x_backward'].mean[..., :3], filter_traj[1]['x_backward'].covariance_matrix[..., :3, :3]))
for i in range(3): 
    fig.add_scatter(y=y[1]['x'][:, i].detach(), line_color=px.colors.qualitative.T10[i], line_dash='dash', name=f'x[{i}]')
fig.update_layout(template='plotly_white', 
        title={
         'text':'causal states',
         'x':0.5,
         'xanchor': 'center'}, height=600, width=800)
fig.update_yaxes(title_text='states (hz)', range=(-75, 75))
fig.update_xaxes(title_text='time')
fig.show()

In [None]:

fig = plot_traj(filter_traj[1]['y_prior'])
fig.add_scatter(y=y[1]['y'][:, 0].detach(), line_color=px.colors.qualitative.T10[1], line_dash='dash', name=f'x[{0}]', opacity=0.7)
fig.update_layout(template='plotly_white', 
        title={
         'text':'jansen-rit output',
         'x':0.5,
         'xanchor': 'center'}, height=600, width=800)
fig.update_yaxes(title_text='states (hz)', range=(-70, 70))
fig.update_xaxes(title_text='time')
fig.show()