In [1]:
from torchkf import *
import torch
from torch import distributions as td
from torch import nn
import numpy as np
import plotly.graph_objs as go
import plotly.express as px
import pykeos as pk
# from easystate import EasyState

pyunicorn: Package netCDF4 could not be loaded. Some functionality in class Data might not be available!
pyunicorn: Package netCDF4 could not be loaded. Some functionality in class NetCDFDictionary might not be available!


In [2]:
class JansenRit(nn.Module):
    """Jansen-Rit neural-mass model advanced by one forward-Euler step per call.

    The twelve model constants are stored as *log*-values in ``self.params`` so
    that ``torch.exp`` keeps them positive while they are optimised. Every 0-d
    parameter also gets a ``Gaussian`` prior centred on its initial linear-scale
    value; ``prior_log_prob`` sums those priors.

    Args:
        dt: Euler step used by ``forward``.
        prior_scale: entry of each prior's 1x1 covariance matrix (a variance on
            the linear, i.e. exp, scale).
    """

    # Initial linear-scale values; they are stored as logs in self.params.
    _DEFAULT_PARAMS = dict(
        A=3.25,
        B=22.,
        a_inv=10.,  # aka Te
        b_inv=20.,  # aka Ti
        C=135.,
        C1rep=1.,   # ode uses C1 = C1rep * C, likewise for C2..C4
        C2rep=0.8,
        C3rep=0.25,
        C4rep=0.25,
        vmax=5.,
        v0=6.,
        r=0.56,
    )

    def __init__(self, dt=0.001, prior_scale=np.exp(-4)):
        super().__init__()

        # float32 log-values: optimising in log-space keeps every constant positive.
        self.params = nn.ParameterDict({
            name: nn.Parameter(torch.tensor(np.log(value).astype(np.float32), requires_grad=True))
            for name, value in self._DEFAULT_PARAMS.items()
        })

        # Priors act on the linear scale, i.e. they score exp(log-param).
        # Only 0-d parameters get a prior; prior_log_prob skips any others.
        self.priors = dict()
        for name, param in self.params.items():
            if param.dim() == 0:
                self.priors[name] = Gaussian(
                    torch.tensor([np.exp(param.detach().item())]),
                    torch.tensor([[prior_scale]]),
                )

        self.dt = dt

    def prior_log_prob(self):
        """Sum of prior log-densities of exp(param) over every parameter that has a prior.

        Returns 0 when there are no priors. Iterating ``self.priors`` (rather than
        all params) avoids a KeyError for parameters that __init__ gave no prior.
        """
        prob = 0
        for name, prior in self.priors.items():
            value = torch.exp(self.params[name].expand(prior.event_shape))
            # Out-of-place accumulation: `+=` on a tensor would mutate it in place.
            prob = prob + prior.log_prob(value)
        return prob

    def ode(self, x):
        """Time derivatives of the six Jansen-Rit states.

        Args:
            x: tensor whose last axis holds the six states x0..x5 followed by the
                external input p (seven entries in total).

        Returns:
            Tensor whose last axis holds the six derivatives; x3..x5 are the time
            derivatives of x0..x2.
        """
        x0, x1, x2, x3, x4, x5, p = x.unbind(-1)
        A, B, a_inv, b_inv, C, C1rep, C2rep, C3rep, C4rep, vmax, v0, r = (
            torch.exp(self.params[k])
            for k in ('A', 'B', 'a_inv', 'b_inv', 'C', 'C1rep', 'C2rep', 'C3rep', 'C4rep', 'vmax', 'v0', 'r')
        )
        # Rates from the inverse constants (presumably ms -> 1/s; original note: "Convert to seconds").
        a, b = 1e3 / a_inv, 1e3 / b_inv
        C1, C2, C3, C4 = [Crep * C for Crep in (C1rep, C2rep, C3rep, C4rep)]

        def sigm(v):
            # Sigmoid saturating at vmax, with midpoint v0 and slope r.
            return vmax / (1. + torch.exp(r * (v0 - v)))

        return torch.stack([
            x3, x4, x5,
            A * a * sigm(x1 - x2) - 2 * a * x3 - a ** 2 * x0,
            A * a * (p + C2 * sigm(C1 * x0)) - 2 * a * x4 - a ** 2 * x1,
            B * b * C4 * sigm(C3 * x0) - 2 * b * x5 - b ** 2 * x2,
        ], dim=-1)

    def forward(self, x):
        """One forward-Euler step.

        Drops the input channel (the last entry of x) and returns the next six
        states. Defined as ``forward`` so ``nn.Module.__call__`` (and its hooks)
        dispatch to it; ``model(x)`` works exactly as before.
        """
        return x[..., :-1] + self.dt * self.ode(x)
    
class Ueda:
    """Forced cubic (Ueda) oscillator advanced by one forward-Euler step per call.

    The state is (x, y, z), where z is a clock (dz/dt = 1) that drives the forcing:
        dx/dt = y
        dy/dt = -x**3 - damping * y + forcing * sin(z)
        dz/dt = 1

    Args:
        dt: Euler step.
        damping: linear damping coefficient (default 0.05, the original constant).
        forcing: forcing amplitude (default 7.5, the original constant).
    """

    def __init__(self, dt=0.001, damping=0.05, forcing=7.5):
        self.dt = dt
        self.damping = damping
        self.forcing = forcing

    def ode(self, x):
        """Time derivatives of (x, y, z); the last axis of x must have length 3."""
        pos, vel, phase = x.unbind(-1)
        return torch.stack([
            vel,
            -pos ** 3 - self.damping * vel + self.forcing * torch.sin(phase),
            # ones_like keeps the input's dtype/device; torch.ones(shape) always
            # produced a default-dtype CPU tensor.
            torch.ones_like(vel),
        ], dim=-1)

    def __call__(self, x):
        """One forward-Euler step: x + dt * ode(x)."""
        return x + self.dt * self.ode(x)
    
class Rossler:
    """Rossler system advanced by one forward-Euler step per call.

        dx/dt = -y - z
        dy/dt = x + a * y
        dz/dt = b + z * (x - c)

    Args:
        dt: Euler step.
        a, b, c: system coefficients; the defaults (0.1, 0.1, 14) are the
            original hard-coded values.
    """

    def __init__(self, dt=0.001, a=0.1, b=0.1, c=14.):
        self.dt = dt
        self.a = a
        self.b = b
        self.c = c

    def ode(self, x):
        """Time derivatives of (x, y, z); the last axis of x must have length 3."""
        px_, py_, pz_ = x.unbind(-1)
        return torch.stack([
            - py_ - pz_,
            px_ + self.a * py_,
            self.b + pz_ * (px_ - self.c),
        ], dim=-1)

    def __call__(self, x):
        """One forward-Euler step: x + dt * ode(x)."""
        return x + self.dt * self.ode(x)

In [102]:
jansen_rit = JansenRit(0.002)
ueda = Ueda(0.008)

# Generative ("ground truth") input level: a Ueda oscillator read out through
# the affine map 40 * x[0] + 240.
input_model = GaussianSystem(
    state_dim=3,
    obs_dim=1,
    fwd_transform=LinearizedTransform(ueda),
    obs_transform=LinearTransform(torch.tensor([[40., 0., 0.]]), torch.tensor([240.])),
    initial_state_mean=torch.tensor([2., 2., 0.]),
    # NOTE(review): np.exp(-128) ~ 2.6e-56 is far below the smallest float32
    # (~1.4e-45); if this product is float32 (like torch.eye(3)), the initial
    # covariance is exactly zero, i.e. singular. Confirm this is intended.
    initial_state_cov=np.exp(-128) * torch.eye(3),
    process_noise_cov=np.exp(-16) * torch.eye(3),
    obs_noise_cov=np.exp(3) * torch.eye(1),
)

# Generative cortical level: Jansen-Rit with a 1-d input, observed through x1 - x2.
cortical_model = GaussianSystem(
    fwd_transform=LinearizedTransform(jansen_rit),
    obs_transform=LinearTransform(torch.tensor([[0., 1., -1., 0., 0., 0.]], dtype=torch.float32)),
    input_dim=1,
    state_dim=6,
    obs_dim=1,
    obs_noise_cov=np.exp(2) * torch.eye(1),
    process_noise_cov=np.exp(2) * torch.eye(6),
    initial_state_mean=torch.zeros(6),
    initial_state_cov=np.exp(-6) * torch.eye(6),
)

In [103]:
# Two-level generative hierarchy: the Ueda input model, then the Jansen-Rit cortical model.
state_space = HierarchicalDynamicalModel(systems=[input_model, cortical_model]) 

In [104]:
# Forecast 1500 steps from the generative hierarchy (presumably without
# conditioning on data) and plot each level's output.
traj = state_space.blind_forecast(1500)
plot_traj(traj[0]['y'][None]).show()
plot_traj(traj[1]['y'][None]).show()

In [76]:
# Simulate 1800 steps of synthetic data from the generative hierarchy.
# Seeded so the filtering/plotting cells below see the same sample on every
# Restart & Run All (assumes `sample` draws from torch's global RNG).
torch.manual_seed(0)
y = state_space.sample(1800)
px.line(y=y[0]['y'][:, 0].detach()).show()
px.line(y=y[1]['y'][:, 0].detach()).show()

In [82]:
# Hierarchy used for inference (filtering). Same structure as the generative
# hierarchy, but with different settings:
#   input level:    initial mean 0 (vs [2, 2, 0]), initial cov e^-1 (vs e^-128),
#                   process noise e^-8 (vs e^-16), obs noise e^2 (vs e^3);
#   cortical level: obs noise e^4 (vs e^2).
# Both hierarchies share the same `ueda` and `jansen_rit` instances, so any
# update to jansen_rit's parameters affects both.
input_model_ = GaussianSystem(
    state_dim=3,
    obs_dim=1,
    fwd_transform=LinearizedTransform(ueda),
    obs_transform=LinearTransform(torch.tensor([[40., 0., 0.]]), torch.tensor([240.])),
    initial_state_mean=torch.tensor([0., 0., 0.]),
    initial_state_cov=np.exp(-1) * torch.eye(3),
    process_noise_cov=np.exp(-8) * torch.eye(3),
    obs_noise_cov=np.exp(2) * torch.eye(1),
)

cortical_model_ = GaussianSystem(
    fwd_transform=LinearizedTransform(jansen_rit),
    obs_transform=LinearTransform(torch.tensor([[0., 1., -1., 0., 0., 0.]], dtype=torch.float32)),
    input_dim=1,
    state_dim=6,
    obs_dim=1,
    obs_noise_cov=np.exp(4) * torch.eye(1),
    process_noise_cov=np.exp(2) * torch.eye(6),
    initial_state_mean=torch.zeros(6),
    initial_state_cov=np.exp(-6) * torch.eye(6),
)

dec_state_space = HierarchicalDynamicalModel(systems=[input_model_, cortical_model_])

In [91]:
# Filter the simulated cortical output (level 1 observations) with the inference hierarchy.
# NOTE(review): the recorded run of this cell stopped at step 40/1800 with an all-NaN
# posterior covariance (PositiveDefinite constraint error); the plotting cells below
# use a filter_traj from some other run. TODO investigate before relying on it.
filter_traj = dec_state_space.filter(y[1]['y'][None], backward_pass=False)

Filter:   2%|▏         | 40/1800 [00:00<00:20, 86.10it/s] 


ValueError: Expected parameter covariance_matrix (Tensor of shape (1, 6, 6)) of distribution Gaussian(loc: torch.Size([1, 6]), covariance_matrix: torch.Size([1, 6, 6])) to satisfy the constraint PositiveDefinite(), but found invalid values:
tensor([[[nan, nan, nan, nan, nan, nan],
         [nan, nan, nan, nan, nan, nan],
         [nan, nan, nan, nan, nan, nan],
         [nan, nan, nan, nan, nan, nan],
         [nan, nan, nan, nan, nan, nan],
         [nan, nan, nan, nan, nan, nan]]])

In [109]:
# Filtered posterior of the first two Ueda states, with the simulated
# ground-truth states overlaid as dashed lines.
fig = plot_traj(Gaussian(filter_traj[0]['x_post'].mean[..., :2], filter_traj[0]['x_post'].covariance_matrix[..., :2, :2]))
for i in range(2):
    # .detach() as in the other plot cells: plotly cannot convert a tensor that requires grad.
    fig.add_scatter(y=y[0]['x'][:, i].detach(), line_color=px.colors.qualitative.T10[i], line_dash='dash', name=f'x[{i}]')
fig.update_layout(template='plotly_white',
        title={
         'text':'hidden states (ueda)',
         'x':0.5,
         'xanchor': 'center'}, height=600, width=600)
fig.update_yaxes(title_text='states (a.u.)', range=(-15, 15))
fig.update_xaxes(title_text='time')
fig.show()

In [107]:
# One-step-ahead predicted Ueda output vs. the simulated observation (dashed).
fig = plot_traj(filter_traj[0]['y_prior'])
# The dashed series is the observation y (not a state), so it is labelled y[0].
fig.add_scatter(y=y[0]['y'][:, 0].detach(), line_color=px.colors.qualitative.T10[0], line_dash='dash', name='y[0]')
fig.update_layout(template='plotly_white',
        title={
         'text':'ueda output',
         'x':0.5,
         'xanchor': 'center'}, height=600, width=600)
fig.update_yaxes(title_text='states (hz)', range=(0, 450))
fig.update_xaxes(title_text='time')
fig.show()

In [55]:

# Filtered posterior of the first three Jansen-Rit states, with the simulated
# ground-truth states overlaid as dashed lines.
fig = plot_traj(
    Gaussian(
        filter_traj[1]['x_post'].mean[..., :3],
        filter_traj[1]['x_post'].covariance_matrix[..., :3, :3],
    )
)
for i in range(3):
    fig.add_scatter(
        y=y[1]['x'][:, i].detach(),
        line_color=px.colors.qualitative.T10[i],
        line_dash='dash',
        name=f'x[{i}]',
    )
# plotly's update_* methods return the figure, so the layout calls chain.
(
    fig.update_layout(
        template='plotly_white',
        title=dict(text='causal states', x=0.5, xanchor='center'),
        height=600,
        width=800,
    )
    .update_yaxes(title_text='states (hz)', range=(-50, 50))
    .update_xaxes(title_text='time')
)
fig.show()

In [108]:

# One-step-ahead predicted Jansen-Rit output vs. the simulated observation (dashed).
fig = plot_traj(filter_traj[1]['y_prior'])
# The dashed series is the observation y (not a state), so it is labelled y[0].
fig.add_scatter(y=y[1]['y'][:, 0].detach(), line_color=px.colors.qualitative.T10[0], line_dash='dash', name='y[0]')
fig.update_layout(template='plotly_white',
        title={
         'text':'jansen-rit output',
         'x':0.5,
         'xanchor': 'center'}, height=600, width=600)
fig.update_yaxes(title_text='states (hz)', range=(-70, 70))
fig.update_xaxes(title_text='time')
fig.show()

In [25]:
# Plain SGD over the Jansen-Rit log-parameters only.
optim = torch.optim.SGD(list(jansen_rit.parameters()), lr=1e-3)

In [26]:
# Fit the Jansen-Rit parameters by minimising the negative complete-data
# log-likelihood minus the parameter log-prior, plotting the one-step-ahead
# prediction against the data at every iteration.
# TODO(review): `eeg_sig` is not defined anywhere in this notebook; it must be
# created in an earlier cell for Restart & Run All to work.
# TODO(review): the recorded run failed with "AttributeError: 'list' object has no
# attribute 'keys'". `filter` on the hierarchy returns a per-level container
# (indexed filter_traj[-1] below), which complete_data_likelihood apparently does
# not accept in this form. Confirm the expected argument.
for i in range(10):
    # The filter pass runs without autograd; gradients come only from complete_data_likelihood.
    with torch.no_grad():
        filter_traj = state_space.filter(eeg_sig, backward_pass=False)
    fig = plot_traj(filter_traj[-1]['y_prior'])
    fig.add_scatter(y=eeg_sig[0, :, 0], mode='lines')
    fig.show()

    optim.zero_grad()

    fe = state_space.complete_data_likelihood(eeg_sig, filter_traj)
    loss = -fe['ll'].mean() - jansen_rit.prior_log_prob()
    loss.backward()
    optim.step()

    # One "name: value" line per parameter, on the linear (exp) scale. Joining
    # avoids the leading space that print(*items) inserts before each line.
    print('\n'.join(f'{k}: {v.detach().exp().tolist()}' for k, v in jansen_rit.params.items()))
    print(fe)

Filter: 100%|██████████| 1000/1000 [00:07<00:00, 135.04it/s]


AttributeError: 'list' object has no attribute 'keys'

In [548]:
# NOTE(review): the output saved under this cell is an {'ll', 'll_x', 'll_y'} dict
# (the shape of `fe`), not a forecast. It appears stale; re-run to refresh.
state_space.blind_forecast(500)

{'ll': tensor([6824.3711, 6243.5498, 5818.5684, 6428.8584], grad_fn=<AddBackward0>),
 'll_x': tensor([2378.6331, 1907.2921, 1601.6952, 2107.6965], grad_fn=<AddBackward0>),
 'll_y': tensor([4445.7383, 4336.2578, 4216.8730, 4321.1616])}

In [477]:
# Debug: inspect gradients of the Jansen-Rit log-parameters.
# TODO(review): `output_scale` is not defined anywhere in this notebook (NameError on a
# fresh kernel). Only a cell's last expression is displayed, so this line's value is discarded.
[v.grad for v in jansen_rit.params.values()], output_scale.grad
# TODO(review): filter_traj is indexed by level elsewhere (filter_traj[0], [1], [-1]);
# if it is a list, `.values()` will raise AttributeError.
[v.mean.grad if isinstance(v, Gaussian) else v.grad for v in filter_traj.values()]

([tensor(-3308.4021),
  tensor(-755.3165),
  tensor(-183204.4062),
  tensor(176.2482),
  tensor(389.9962),
  tensor(-48.0745),
  tensor(-1009.1649),
  tensor(-8064.1157),
  tensor(5933.9941),
  tensor(4848.9268),
  tensor(8120.3218),
  tensor(-2740.9648)],
 None)

In [375]:
# Current Jansen-Rit parameters on the linear scale (they are stored as logs), one per line.
# '\n'.join avoids the leading space print(*items) put before every line after the first.
print('\n'.join(f'{k}: {v.detach().exp().tolist()}' for k, v in jansen_rit.params.items()))

A: 2.4594669342041016
 B: 18.949174880981445
 C: 87.189697265625
 C1rep: 0.991181492805481
 C2rep: 0.6542659997940063
 C3rep: 0.23125238716602325
 C4rep: 0.2153315395116806
 a_inv: 10.571099281311035
 b_inv: 22.291934967041016
 r: 0.5475118160247803
 v0: 6.398193359375
 vmax: 3.522102117538452



In [278]:
# Debug: inspect the prior attached to a_inv.
# NOTE(review): the saved output shows a LogNormal, while JansenRit.__init__ builds
# Gaussian priors, so the output appears stale.
jansen_rit.priors['a_inv'].log_prob

<bound method TransformedDistribution.log_prob of LogNormal()>

In [275]:
# Display the likelihood terms from the last training-loop iteration (`fe` is set in that loop).
fe

{'ll': tensor([86054.9062], grad_fn=<AddBackward0>),
 'll_x': tensor([63631.9297], grad_fn=<AddBackward0>),
 'll_y': tensor([22422.9727])}

In [46]:
# TODO(review): `g` and `eeg_sig` are not defined anywhere in this notebook. filter_traj is
# indexed by level elsewhere (filter_traj[0], [1]), and the filter calls use
# backward_pass=False, so a 'x_backward' entry is presumably absent. This cell will
# fail on a fresh kernel.
g.fit(filter_traj['x_backward'].mean, eeg_sig)

In [202]:
# Raw stored (log-scale) parameter values, one per line. Formatting the Parameter itself
# printed a multi-line "Parameter containing:" repr; .tolist() works for any parameter shape.
print('\n'.join(f'{k}: {v.detach().tolist()}' for k, v in jansen_rit.params.items()))

A: -44.28509521484375
 B: 20.21320915222168
 C: 133.59674072265625
 Crep: Parameter containing:
tensor([  -8.1617, -131.3714, -140.6812, -156.9876], requires_grad=True)
 a_inv: 5837.216796875
 b_inv: 1486.7198486328125
 r: -11.286907196044922
 v0: 12.293356895446777
 vmax: -24.009294509887695



Parameter containing:
tensor(276.1747, requires_grad=True)

In [42]:
# TODO(review): filter_traj is indexed by level elsewhere (filter_traj[0], [1]), and the
# filter calls use backward_pass=False, so a 'x_backward' entry is presumably absent.
# This cell will fail on a fresh kernel.
plot_traj(filter_traj['x_backward'])