In [1]:
from torchkf import *
import torch
from torch import distributions as td
from torch import nn
import numpy as np
import plotly.graph_objs as go
import plotly.express as px
import pykeos as pk
from easystate import EasyState

pyunicorn: Package netCDF4 could not be loaded. Some functionality in class Data might not be available!
pyunicorn: Package netCDF4 could not be loaded. Some functionality in class NetCDFDictionary might not be available!


In [2]:
e = EasyState('eeg-joints-subj-4-run-1')
epochs_eeg = e['epochs_eeg_no_csd'].load_data()
eeg_sig = epochs_eeg.resample(500).get_data()[:, epochs_eeg.ch_names.index('C1')] * 1e3
eeg_sig = torch.FloatTensor(eeg_sig[:4, :1000, None])

EasyState:: Loading state from /home/yop/.easystate/eeg-joints-subj-4-run-1
Loading data for 22 events and 4701 original time points ...


In [3]:
class JansenRit(nn.Module): 
    
    def __init__(self, dt=0.001, prior_scale=np.exp(-4)): 
        super().__init__()
        
        self.params = nn.ParameterDict(
            {k:nn.Parameter(torch.tensor(np.log(v).astype(np.float32), requires_grad=True))
            for k, v in dict(
                A=3.25, 
                B=22., 
                a_inv=10., 
                b_inv=20., 
                C=135., 
                C1rep=1., 
                C2rep=0.8, 
                C3rep=0.25, 
                C4rep=0.25, 
                vmax=5., 
                v0=6., 
                r=0.56,
            ).items()
        })

        
        self.priors = dict()
        for k, v in self.params.items(): 
            if len(v.shape) == 0: 
#                 self.priors[k] = td.LogNormal(torch.log(torch.tensor([v.detach().item()]) - prior_scale / 2.), torch.tensor([[prior_scale]]))
                self.priors[k] = Gaussian(torch.tensor([np.exp(v.detach().item())]), torch.tensor([[prior_scale]]))

                
        self.dt = dt
    
    def prior_log_prob(self): 
        prob = 0
        for k in self.params.keys(): 
            prob += self.priors[k].log_prob(torch.exp(self.params[k].expand(self.priors[k].event_shape)))
        return prob
    
    def ode(self, x, p): 
        x0, x1, x2, x3, x4, x5 = [x[..., i] for i in range(x.shape[-1])]
        A,B,a_inv,b_inv,C,C1rep,C2rep,C3rep,C4rep,vmax,v0,r = (torch.exp(self.params[k]) for k in ('A','B','a_inv','b_inv','C','C1rep','C2rep','C3rep','C4rep','vmax','v0','r'))
        a, b = 1e3/a_inv, 1e3/b_inv # Convert to seconds
        C1,C2,C3,C4 = [Crep * C for Crep in [C1rep,C2rep,C3rep,C4rep]]

        def sigm(x): 
            return vmax / (1. + torch.exp(r * (v0 - x)))
                                            
        return torch.stack([
            x3, x4, x5, 
            A * a * sigm(x1 - x2) - 2 * a * x3 - a**2 * x0,
            A * a * (p + C2  * sigm(C1 * x0)) - 2 * a * x4 - a**2 * x1, 
            B * b * C4 * sigm(C3 * x0) - 2 * b * x5 - b ** 2 * x2
        ], dim=-1)
    
    def __call__(self, x, p=200): 
        return x + self.dt * self.ode(x, p)

In [15]:
jansen_rit = JansenRit(0.002)
f = LinearizedTransform(jansen_rit)
output_scale = torch.tensor(1e-3, requires_grad=False)
g = LinearTransform(torch.exp(output_scale) * torch.FloatTensor([[0., 1., -1., 0., 0., 0.]]))

In [16]:
state_space = GaussianStateSpaceModel(fwd_transform=f, obs_transform=g, state_dim=6, obs_dim=1, observation_noise_cov=0.1 * torch.eye(1), process_noise_cov=10 * torch.eye(6)) 

In [17]:
traj = state_space.blind_forecast(2000)
plot_traj(traj['y'][None])

In [18]:
optim = torch.optim.SGD(params=[*jansen_rit.parameters()], lr=1e-3)

In [13]:
jansen_rit.params['C']

Parameter containing:
tensor(-658.5491, requires_grad=True)

In [19]:
for i in range(10): 
    with torch.no_grad(): 
        filter_traj = state_space.filter(eeg_sig, backward_pass=True)
    plot_traj(filter_traj['x_backward']).show()
    fig = plot_traj(filter_traj['y_prior'])
    fig.add_scatter(y=eeg_sig[0, :, 0], mode='lines')
    fig.show()

    if i < 1: 
        with torch.no_grad():
            state_space.fit_params(eeg_sig, filter_traj)
#         print(state_space._parameters) 

    optim.zero_grad()

    fe = state_space.complete_data_likelihood(eeg_sig, filter_traj)
    loss = -fe['ll'].mean()-jansen_rit.prior_log_prob()
    loss.backward()
    optim.step()

    print(*(f'{k}: {np.exp(v.detach())}\n' for k, v in jansen_rit.params.items()))
    print(fe)

Filter: 100%|██████████| 1000/1000 [00:08<00:00, 113.85it/s]
Smooth: 100%|██████████| 1000/1000 [00:00<00:00, 4222.22it/s]


100%|██████████| 1000/1000 [00:01<00:00, 638.94it/s]
Filter:   1%|          | 11/1000 [00:00<00:09, 103.33it/s]

A: 4.451501846313477
 B: 21.89960289001465
 C: 134.36947631835938
 C1rep: 0.9995848536491394
 C2rep: 0.8064599633216858
 C3rep: 0.24807946383953094
 C4rep: 0.2488592118024826
 a_inv: 26279.06640625
 b_inv: 19.937780380249023
 r: 0.2094164937734604
 v0: 2.261995315551758
 vmax: 6.794973373413086

{'ll': tensor([13.8052, 13.2884, 10.1221, 13.3905], grad_fn=<AddBackward0>), 'll_x': tensor([7.9563, 7.6482, 4.5618, 7.7314], grad_fn=<DivBackward0>), 'll_y': tensor([5.8488, 5.6402, 5.5602, 5.6591])}


Filter:   2%|▏         | 22/1000 [00:00<00:10, 97.57it/s] 


ValueError: Expected parameter covariance_matrix (Tensor of shape (4, 6, 6)) of distribution MultivariateNormal(loc: torch.Size([4, 6]), covariance_matrix: torch.Size([4, 6, 6])) to satisfy the constraint PositiveDefinite(), but found invalid values:
tensor([[[3.2610e-02, 2.4853e-01, 2.5322e-01, 6.4573e-03, 8.3311e+00,
          1.2080e+01],
         [2.4853e-01, 3.5398e+00, 3.5356e+00, 4.8916e-02, 1.0618e+02,
          1.0235e+02],
         [2.5322e-01, 3.5356e+00, 3.5345e+00, 4.9349e-02, 1.0548e+02,
          1.0376e+02],
         [6.4573e-03, 4.8916e-02, 4.9349e-02, 2.4775e-01, 2.0017e+00,
          2.4092e+00],
         [8.3311e+00, 1.0618e+02, 1.0548e+02, 2.0017e+00, 5.1487e+03,
          4.5584e+03],
         [1.2080e+01, 1.0235e+02, 1.0376e+02, 2.4092e+00, 4.5584e+03,
          5.4822e+03]],

        [[3.2640e-02, 2.4886e-01, 2.5348e-01, 6.4620e-03, 8.3332e+00,
          1.2024e+01],
         [2.4886e-01, 3.5454e+00, 3.5405e+00, 4.8973e-02, 1.0626e+02,
          1.0192e+02],
         [2.5348e-01, 3.5405e+00, 3.5387e+00, 4.9393e-02, 1.0554e+02,
          1.0329e+02],
         [6.4620e-03, 4.8973e-02, 4.9393e-02, 2.4775e-01, 2.0018e+00,
          2.3984e+00],
         [8.3332e+00, 1.0626e+02, 1.0554e+02, 2.0018e+00, 5.1476e+03,
          4.5392e+03],
         [1.2024e+01, 1.0192e+02, 1.0329e+02, 2.3984e+00, 4.5392e+03,
          5.4342e+03]],

        [[3.2665e-02, 2.4892e-01, 2.5356e-01, 6.4655e-03, 8.3342e+00,
          1.2054e+01],
         [2.4892e-01, 3.5428e+00, 3.5381e+00, 4.8960e-02, 1.0613e+02,
          1.0207e+02],
         [2.5356e-01, 3.5381e+00, 3.5366e+00, 4.9385e-02, 1.0542e+02,
          1.0346e+02],
         [6.4655e-03, 4.8960e-02, 4.9385e-02, 2.4775e-01, 2.0014e+00,
          2.4035e+00],
         [8.3342e+00, 1.0613e+02, 1.0542e+02, 2.0014e+00, 5.1436e+03,
          4.5448e+03],
         [1.2054e+01, 1.0207e+02, 1.0346e+02, 2.4035e+00, 4.5448e+03,
          5.4535e+03]],

        [[3.2665e-02, 2.4887e-01, 2.5348e-01, 6.4655e-03, 8.3332e+00,
          1.2019e+01],
         [2.4887e-01, 3.5421e+00, 3.5371e+00, 4.8954e-02, 1.0613e+02,
          1.0178e+02],
         [2.5348e-01, 3.5371e+00, 3.5352e+00, 4.9372e-02, 1.0540e+02,
          1.0314e+02],
         [6.4655e-03, 4.8954e-02, 4.9372e-02, 2.4775e-01, 2.0013e+00,
          2.3969e+00],
         [8.3332e+00, 1.0613e+02, 1.0540e+02, 2.0013e+00, 5.1435e+03,
          4.5344e+03],
         [1.2019e+01, 1.0178e+02, 1.0314e+02, 2.3969e+00, 4.5344e+03,
          5.4263e+03]]])

In [548]:
state_space.blind_forecast(500)

{'ll': tensor([6824.3711, 6243.5498, 5818.5684, 6428.8584], grad_fn=<AddBackward0>),
 'll_x': tensor([2378.6331, 1907.2921, 1601.6952, 2107.6965], grad_fn=<AddBackward0>),
 'll_y': tensor([4445.7383, 4336.2578, 4216.8730, 4321.1616])}

In [477]:
[v.grad for v in jansen_rit.params.values()], output_scale.grad
[v.mean.grad if isinstance(v, Gaussian) else v.grad for v in filter_traj.values()]

([tensor(-3308.4021),
  tensor(-755.3165),
  tensor(-183204.4062),
  tensor(176.2482),
  tensor(389.9962),
  tensor(-48.0745),
  tensor(-1009.1649),
  tensor(-8064.1157),
  tensor(5933.9941),
  tensor(4848.9268),
  tensor(8120.3218),
  tensor(-2740.9648)],
 None)

In [375]:
print(*(f'{k}: {np.exp(v.detach())}\n' for k, v in jansen_rit.params.items()))

A: 2.4594669342041016
 B: 18.949174880981445
 C: 87.189697265625
 C1rep: 0.991181492805481
 C2rep: 0.6542659997940063
 C3rep: 0.23125238716602325
 C4rep: 0.2153315395116806
 a_inv: 10.571099281311035
 b_inv: 22.291934967041016
 r: 0.5475118160247803
 v0: 6.398193359375
 vmax: 3.522102117538452



In [278]:
jansen_rit.priors['a_inv'].log_prob

<bound method TransformedDistribution.log_prob of LogNormal()>

In [275]:
fe

{'ll': tensor([86054.9062], grad_fn=<AddBackward0>),
 'll_x': tensor([63631.9297], grad_fn=<AddBackward0>),
 'll_y': tensor([22422.9727])}

In [46]:
g.fit(filter_traj['x_backward'].mean, eeg_sig)

In [202]:
print(*(f'{k}: {v}\n' for k, v in jansen_rit.params.items()))

A: -44.28509521484375
 B: 20.21320915222168
 C: 133.59674072265625
 Crep: Parameter containing:
tensor([  -8.1617, -131.3714, -140.6812, -156.9876], requires_grad=True)
 a_inv: 5837.216796875
 b_inv: 1486.7198486328125
 r: -11.286907196044922
 v0: 12.293356895446777
 vmax: -24.009294509887695



Parameter containing:
tensor(276.1747, requires_grad=True)

In [42]:
plot_traj(filter_traj['x_backward'])