In [1]:
%matplotlib notebook
from kalman_experiments.model_selection import fit_kf_parameters
from kalman_experiments import SSPE
from kalman_experiments.kalman import PerturbedP1DMatsudaKF
from kalman_experiments.models import MatsudaParams, SingleRhythmModel, collect, gen_ar_noise_coefficients, ArNoise
import numpy as np
# Setup oscillatioins model and generate oscillatory signal
# sim = SSPE.gen_sine_w_pink(1, 1000)
# a = gen_ar_noise_coefficients(alpha=1.5, order=20)
# kf = PerturbedP1DMatsudaKF(MatsudaParams(A=0.9, freq=1, sr=1000), q_s=2, psi=a, r_s=1)
# kf = fit_kf_parameters(sim.data, kf)


In [2]:
SRATE = 1000
DURATION = 100
FREQ_GT = 6
SIGNAL_SIGMA_GT = np.sqrt(10)
NOISE_SIGMA_GT = 1
A_GT = 0.99
ALPHA = 1.5
NOISE_AR_ORDER = 1000

mp = MatsudaParams(A_GT, FREQ_GT, SRATE)
oscillation_model = SingleRhythmModel(mp, sigma=SIGNAL_SIGMA_GT)
gt_states = collect(oscillation_model, DURATION * SRATE)

noise_model = ArNoise(x0=np.random.rand(NOISE_AR_ORDER), alpha=ALPHA, order=NOISE_AR_ORDER, s=NOISE_SIGMA_GT)
noise_sim = collect(noise_model, DURATION * SRATE)

data = np.real(gt_states) + noise_sim

In [5]:
a = gen_ar_noise_coefficients(alpha=ALPHA, order=30)
kf = PerturbedP1DMatsudaKF(MatsudaParams(A=0.99, freq=6, sr=SRATE), q_s=np.sqrt(10), psi=a, r_s=1, lambda_=0)

kf = fit_kf_parameters(data[:10000], kf, tol=1e-4)

Fitting KF parameters:   0%|                                    | 0/800 [00:00<?, ?it/s]

nll =  12347.756944608347


Fitting KF parameters:   0%|                            | 1/800 [00:01<25:03,  1.88s/it]

9013019.62513106028 8906748.684192098908 320335.12724912330037 9013751.430019968619
Amp=0.98884809864154812897, f=5.7240774108912404, q_s=3.1672978481018098803, r_s=1.0005425960768228537
nll =  12362.080527407423


Fitting KF parameters:   0%|                            | 2/800 [00:03<24:26,  1.84s/it]

9017933.396095070708 8911480.869730648806 320359.50769262318508 9018665.996711789691
Amp=0.9888337973977258694, f=5.721473227749647, q_s=3.1701762615976376578, r_s=1.0009417817585703836
nll =  12370.363747142432


Fitting KF parameters:   0%|                            | 3/800 [00:05<24:00,  1.81s/it]

9020722.541482900888 8914166.625159754476 320372.7028675776863 9021455.607625303109
Amp=0.9888256489231590132, f=5.719984989866871, q_s=3.1718143450690762589, r_s=1.0012583600752248923
nll =  12375.155011511099


Fitting KF parameters:   0%|▏                           | 4/800 [00:07<27:28,  2.07s/it]

9022280.14822387504 8915666.180335946536 320379.441430305757 9023013.48734586844
Amp=0.9888210634200840503, f=5.7191432179364305, q_s=3.1727340919267646587, r_s=1.0015272124402628282
nll =  12377.925952962487


Fitting KF parameters:   1%|▏                           | 5/800 [00:09<26:15,  1.98s/it]

9023124.232222657874 8916478.487235404895 320382.44647436005877 9023857.731738570106
Amp=0.98881854132951629464, f=5.718675832534735, q_s=3.173237814983617613, r_s=1.0017684367335656993
nll =  12379.52751006887


Fitting KF parameters:   1%|▏                           | 6/800 [00:11<25:44,  1.95s/it]

9023555.077921664197 8916892.7769691378235 320383.2903680031708 9024288.671744697199
Amp=0.98881721420862602, f=5.718425198272429, q_s=3.173500613624486551, r_s=1.0019936421469136988
nll =  12380.452090096847


Fitting KF parameters:   1%|▏                           | 7/800 [00:13<25:30,  1.93s/it]

9023746.689720963915 8917076.647846417844 320382.88246539512386 9024480.338844628077
Amp=0.98881657906422798347, f=5.718300003478403, q_s=3.173623908992911783, r_s=1.0022095442599583309
nll =  12380.984751961247


Fitting KF parameters:   1%|▏                           | 7/800 [00:15<29:09,  2.21s/it]

9023799.870579449167 8917127.191092524968 320381.7491711483142 9024533.551799468613
Amp=0.98881634455640554204, f=5.718247364235327, q_s=3.1736664460742780383, r_s=1.0024200309980686851





In [4]:
from collections import deque
ar_states = deque(np.random.randn(len(kf.psi)))
gen_noise = []
for i in range(100000):
    next_state = np.array(ar_states) @ kf.psi + np.random.randn()
    gen_noise.append(next_state)
    ar_states.pop()
    ar_states.appendleft(next_state)
    
from scipy.signal import welch
import matplotlib.pyplot as plt

noise_model = ArNoise(x0=np.random.rand(30), alpha=1.5, order=30, s=1)
noise_sim = collect(noise_model, 10 * 1000)
freqs, psd = welch(gen_noise, nperseg=1000)
_, psd_data = welch(sim.data, nperseg=1000)
_, psd_ar_noise = welch(noise_sim, nperseg=1000)
plt.loglog(freqs, psd, label="fitted ar")
plt.loglog(freqs, psd_data, label="data")
plt.loglog(freqs, psd_ar_noise, label="true ar")
plt.loglog(freqs, [1/f**1.5 if f else 1000000 for f in freqs ], label="1/f")
plt.legend()
plt.show()
    
    

NameError: name 'sim' is not defined

In [None]:
kf.M

In [None]:
kf.q_s

In [None]:
kf.r_s