In [1]:
# Notebook-wide IPython setup: inline figures plus the reloading and profiling
# extensions used while developing the code under src/.
%matplotlib inline
%reload_ext autoreload
%reload_ext line_profiler
# snakeviz supplies the %%snakeviz cell magic used further down to profile
# minimum_phase_decomposition.
%load_ext snakeviz
# Mode 2: re-import edited modules before each cell runs.
%autoreload 2
# NOTE(review): %qtconsole opens an external console attached to this kernel;
# it is an interactive debugging aid and is not needed for Restart & Run All.
%qtconsole

In [2]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

from src.spectral.transforms import Multitaper
from src.spectral.connectivity import Connectivity

In [3]:
def simulate_MVAR(coefficients, noise_covariance=None, n_time_samples=100,
                  n_trials=1, n_burnin_samples=100):
    '''Simulate trials of a multivariate autoregressive (MVAR) process.

    Every sample starts as Gaussian innovation noise; for each lag
    ``k`` the product ``coefficients[k] @ sample[t - (k + 1)]`` is then
    added to it.

    Parameters
    ----------
    coefficients : array, shape (n_lags, n_signals, n_signals)
        ``coefficients[k, i, j]`` weights signal ``j`` at lag ``k + 1`` in
        the update of signal ``i``.
    noise_covariance : array, shape (n_signals, n_signals), optional
        Covariance of the innovation noise. Defaults to the identity.
    n_time_samples : int, optional
        Number of samples returned per trial.
    n_trials : int, optional
        Number of independent trials to simulate.
    n_burnin_samples : int, optional
        Number of leading samples that are simulated and then discarded.

    Returns
    -------
    time_series : array, shape (n_time_samples, n_trials, n_signals)

    Notes
    -----
    Noise is drawn from the global ``np.random`` state; seed it beforehand
    for reproducible output.

    '''
    coefficients = np.asarray(coefficients)
    n_lags, n_signals, _ = coefficients.shape
    if noise_covariance is None:
        noise_covariance = np.eye(n_signals)

    n_total_samples = n_time_samples + n_burnin_samples
    time_series = np.random.multivariate_normal(
        np.zeros((n_signals,)), noise_covariance,
        size=(n_total_samples, n_trials))

    # The first `n_lags` samples have no complete history and stay pure noise.
    for time_ind in range(n_lags, n_total_samples):
        for lag_ind in range(n_lags):
            # Column vectors, one per trial: (n_trials, n_signals, 1).
            lagged = time_series[time_ind - (lag_ind + 1), ..., np.newaxis]
            # Drop only the trailing column axis. A bare `.squeeze()` would
            # also remove the signal axis when n_signals == 1 and break the
            # in-place add for n_trials > 1.
            time_series[time_ind] += np.matmul(
                coefficients[lag_ind], lagged)[..., 0]
    return time_series[n_burnin_samples:, ...]

In [8]:
def baccala_example4():
    '''Simulate the five-signal, two-lag network defined below.

    Reference: Baccalá, L.A., and Sameshima, K. (2001). Partial directed
    coherence: a new concept in neural structure determination. Biological
    Cybernetics 84, 463–474.

    Returns
    -------
    time_series : array
        Output of `simulate_MVAR` for 3000 samples and 500 trials, after a
        500-sample burn-in.
    sampling_frequency : int
    '''
    sampling_frequency = 200
    n_time_samples, n_lags, n_signals = 3000, 2, 5
    quarter_root_two = 0.25 * np.sqrt(2)

    # (lag index, target signal, source signal) -> weight
    nonzero_weights = {
        (0, 0, 0): 0.95 * np.sqrt(2),
        (1, 0, 0): -0.9025,
        (0, 1, 0): -0.50,
        (1, 2, 1): 0.40,
        (0, 3, 2): -0.50,
        (0, 3, 3): quarter_root_two,
        (0, 3, 4): quarter_root_two,
        (0, 4, 3): -quarter_root_two,
        (0, 4, 4): quarter_root_two,
    }
    coefficients = np.zeros((n_lags, n_signals, n_signals))
    for index, weight in nonzero_weights.items():
        coefficients[index] = weight

    # None makes simulate_MVAR fall back to identity noise covariance.
    time_series = simulate_MVAR(
        coefficients,
        noise_covariance=None,
        n_time_samples=n_time_samples,
        n_trials=500,
        n_burnin_samples=500,
    )
    return time_series, sampling_frequency

In [9]:
# Simulate the example network and build the spectral objects that the
# following cells inspect.
# NOTE(review): no random seed is set in the visible cells, so each run draws
# a different realisation of the simulated signals.
time_series, sampling_frequency = baccala_example4()
time_halfbandwidth_product = 2

m = Multitaper(time_series,
               sampling_frequency=sampling_frequency,
               time_halfbandwidth_product=time_halfbandwidth_product,
               start_time=0)
# `c` is read by the next cell to pull out the Fourier coefficients and the
# cross-spectral matrix.
c = Connectivity.from_multitaper(m)

In [10]:
# Pull the intermediate arrays out of the Connectivity object so that
# minimum_phase_decomposition can be run on them directly.
from src.spectral.minimum_phase_decomposition import minimum_phase_decomposition
fourier_coefficients = c.fourier_coefficients
# NOTE(review): _cross_spectral_matrix and _expectation are underscore-prefixed
# members of Connectivity, so this cell depends on its internals.
cross_spectral_matrix = c._cross_spectral_matrix
expected_cross_spectral_matrix = c._expectation(cross_spectral_matrix)

# Shape check: per the recorded output, _expectation removes the axes of
# length 500 and 3 (presumably trials and tapers - TODO confirm).
print(fourier_coefficients.shape)
print(cross_spectral_matrix.shape)
print(expected_cross_spectral_matrix.shape)

(1, 500, 3, 3000, 5)
(1, 500, 3, 3000, 5, 5)
(1, 3000, 5, 5)


In [11]:
%%snakeviz
# Profile one full call of minimum_phase_decomposition on the expected
# cross-spectral matrix; the result is kept in minimum_phase_factor.
minimum_phase_factor = minimum_phase_decomposition(expected_cross_spectral_matrix)

 
*** Profile stats marshalled to file '/var/folders/rt/nhwr2l2937n0f8g854zq3s6w0000gn/T/tmpkc1vshzr'. 


In [20]:
# Set up the decomposition state by hand using the module's private helpers,
# so that individual steps can be inspected and timed in the cells below.
from src.spectral.minimum_phase_decomposition import _get_intial_conditions, _get_linear_predictor, _get_causal_signal

# NOTE(review): the matrix has shape (1, 3000, 5, 5) per the recorded output,
# so shape[0] == 1 here; the leading axis is presumably time windows rather
# than samples - TODO confirm.
n_time_points = expected_cross_spectral_matrix.shape[0]
n_signals = expected_cross_spectral_matrix.shape[-1]
I = np.eye(n_signals)
is_converged = np.zeros(n_time_points, dtype=bool)
# This overwrites the minimum_phase_factor produced by the profiling cell.
# np.zeros allocates float64 by default.
minimum_phase_factor = np.zeros(expected_cross_spectral_matrix.shape)
# Fill the whole array with the helper's initial estimate via slice assignment.
minimum_phase_factor[..., :, :, :] = _get_intial_conditions(
    expected_cross_spectral_matrix)

In [21]:
# Shape check: the initial factor matches the cross-spectral matrix's shape.
minimum_phase_factor.shape

(1, 3000, 5, 5)

In [22]:
# Counter for the manual stepping below.
# NOTE(review): no visible cell reads `iteration` afterwards.
iteration = 0

In [24]:
# One manual step: snapshot the current factor, then compute the linear
# predictor from it, the cross-spectral matrix and the identity matrix.
# NOTE(review): old_minimum_phase_factor is not read by any visible cell.
old_minimum_phase_factor = minimum_phase_factor.copy()
linear_predictor = _get_linear_predictor(
    minimum_phase_factor, expected_cross_spectral_matrix, I)

In [25]:
# Shape check on the linear predictor.
linear_predictor.shape

(1, 3000, 5, 5)

In [27]:
# Inspect the values with the length-1 leading axis removed.
# NOTE(review): this displays the whole array; showing a small slice would keep
# the recorded output readable.
linear_predictor.squeeze()

array([[[ 1.16524030 +0.00000000e+00j, -0.07732229 +0.00000000e+00j,
         -0.15809898 +0.00000000e+00j,  0.16823778 +0.00000000e+00j,
         -0.03468379 +0.00000000e+00j],
        [-0.07732229 +0.00000000e+00j,  1.45478822 +0.00000000e+00j,
          0.17067192 +0.00000000e+00j, -0.11268733 +0.00000000e+00j,
          0.11759439 +0.00000000e+00j],
        [-0.15809898 +0.00000000e+00j,  0.17067192 +0.00000000e+00j,
          1.64488043 +0.00000000e+00j, -0.29674803 +0.00000000e+00j,
          0.18882200 +0.00000000e+00j],
        [ 0.16823778 +0.00000000e+00j, -0.11268733 +0.00000000e+00j,
         -0.29674803 +0.00000000e+00j,  1.90260400 +0.00000000e+00j,
         -0.07378509 +0.00000000e+00j],
        [-0.03468379 +0.00000000e+00j,  0.11759439 +0.00000000e+00j,
          0.18882200 +0.00000000e+00j, -0.07378509 +0.00000000e+00j,
          1.90131507 +0.00000000e+00j]],

       [[ 1.17656884 +9.35758708e-19j, -0.08048302 -1.45945151e-02j,
         -0.17868577 -9.83297215e-03j, 

In [28]:
# Precompute the helper's output once so that the timing cells below can
# compare matmul alone against matmul plus a fresh _get_causal_signal call.
causal_signal = _get_causal_signal(linear_predictor)

In [29]:
# Shape check on the causal signal.
causal_signal.shape

(1, 3000, 5, 5)

In [39]:
# Timing: matmul including a fresh _get_causal_signal call on every loop.
%timeit np.matmul(minimum_phase_factor, _get_causal_signal(linear_predictor))

10.3 ms ± 1.31 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [40]:
# Timing: same as the previous cell, but with both operands squeezed first
# (drops the length-1 leading axis).
%timeit np.matmul(minimum_phase_factor.squeeze(), _get_causal_signal(linear_predictor).squeeze())

8.73 ms ± 242 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [38]:
# Timing: matmul alone, using the precomputed causal_signal.
%timeit np.matmul(minimum_phase_factor, causal_signal)

5.19 ms ± 122 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [37]:
# Timing: matmul alone on squeezed operands, to see whether the length-1
# leading axis costs anything.
%timeit np.matmul(minimum_phase_factor.squeeze(), causal_signal.squeeze())

5.06 ms ± 147 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [42]:
# Boolean mask marking which axes of minimum_phase_factor have length 1.
np.array(minimum_phase_factor.shape) == 1

array([ True, False, False, False], dtype=bool)

In [55]:
# Wrap the cross-spectral matrix in a dask array so the decomposition can be
# timed on chunked input.
import dask.array as da

# NOTE(review): `(1000)` is a parenthesised int, not a tuple; presumably dask
# applies it as the chunk size along every axis - confirm this is the intent.
e = da.from_array(expected_cross_spectral_matrix, chunks=(1000))

In [50]:
# Baseline timing: full decomposition on the in-memory NumPy array.
%timeit minimum_phase_decomposition(expected_cross_spectral_matrix)

742 ms ± 79.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [56]:
# Same decomposition on the dask-wrapped array, for comparison with the NumPy
# baseline.
%timeit minimum_phase_decomposition(e)

918 ms ± 67.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [53]:
# Run the decomposition on the dask-wrapped input and check the output shape.
# NOTE(review): this cell's execution count (53) is lower than that of the cell
# defining `e` (55), so it was run against an earlier definition of `e`;
# re-run top to bottom to confirm.
f = minimum_phase_decomposition(e)
f.shape

(1, 3000, 5, 5)