In [23]:
import os
import numpy as np
import matplotlib.pyplot as plt
import mne
import mne.time_frequency as m
from typing import Optional
import matplotlib as mpt
import sys
import os
current_dir = os.path.dirname(os.path.abspath('./'))
if not current_dir in sys.path:
    sys.path.append(current_dir)

from utils.visualize import plot_tfr

def display_spectra(
    data: np.array,
    tfa: mne.time_frequency.AverageTFR,
    fig: mpt.figure.Figure,
    ax: mpt.axes.Axes,
    title: Optional[str] = '',
    cmap: Optional[str | mpt.colors.Colormap] = 'RdBu_r',
    **kwargs
) -> None:
    imshow = ax.imshow(
        data.T,
        aspect='auto',
        origin='lower',
        extent=(tfa.times[0], tfa.times[-1], tfa.freqs[0], tfa.freqs[-1]),
        cmap=cmap,
        **kwargs
    )
    ax.set_title(title)
    ax.set_yticks(np.linspace(tfa.freqs[0], tfa.freqs[-1], tfa.freqs.shape[0]))
    ax.set_yticklabels(np.round(tfa.freqs, 2))
    cbar = fig.colorbar(
        imshow,
        orientation='horizontal',
        ax=ax,
        pad=.1,
        aspect=50
    )
    sig = data.mean(1)
    ax.plot(tfa.times, 87*((sig - sig.min())/(sig.max() - sig.min())) + 3, '#555555')
    ax_additional = ax.twinx()
    ax_additional.set_yticklabels([
        str(i)[:3] + str(i)[-4:]
        for i in np.linspace(sig.min(), sig.max(), len(ax_additional.get_yticks()))
    ])


In [4]:
mne.set_log_level('ERROR')

In [5]:
root = '../'
subjects_dir = os.path.join(root, 'data', 'subjects')
method = 'morlet'
# method = 'multitaper'
conds = ['hits', 'miss']
tfa_cond1 = []
tfa_cond2 = []
any_tfa = None
bl = (-.8, .2)
# bl = (3.25, 3.75)

for subject_name in os.listdir(subjects_dir):
    subject_dir = os.path.join(subjects_dir, subject_name)
    tfa_dir = os.path.join(subject_dir, 'tfa')
    for tfa_file in os.listdir(tfa_dir):
        if method in tfa_file:
            tfa = m.read_tfrs(os.path.join(tfa_dir, tfa_file))

            if isinstance(tfa, list) and len(tfa) == 1:
                tfa = tfa[0]
            elif isinstance(tfa, list):
                raise OSError(f'Several TFRAverage objects contained in {tfa_file}')

            tfa.apply_baseline(bl)

            if any_tfa is None:
                any_tfa = tfa.copy()

            if conds[0] in tfa_file:
                tfa_cond1.append(tfa.data)
            elif conds[1] in tfa_file:
                tfa_cond2.append(tfa.data)

assert len(tfa_cond1) == len(tfa_cond2), 'Data for conditions are inconsistent'

tfa_cond1 = np.array(tfa_cond1)
tfa_cond2 = np.array(tfa_cond2)

In [6]:
# As mne.stats.spatio_temporal_cluster_test requires the shape: n_observations x n_p x n_q x n_verticles
tfa_cond1 = np.transpose(tfa_cond1, (0, 3, 2, 1)) # shape: subs, chnls, freqs, time-points
tfa_cond2 = np.transpose(tfa_cond2, (0, 3, 2, 1))

In [7]:
tfa_cond1_ave = tfa_cond1.copy().mean(-1)
tfa_cond2_ave = tfa_cond2.copy().mean(-1)
tfa_diff_ave = tfa_cond1_ave - tfa_cond2_ave

In [105]:
from typing import Callable


def rowwise(
    data: np.ndarray,
    fun: Callable
):
    out = []
    for row in data:
        out.append(fun(row))

    return np.array(out)#, mean

row_min_max = lambda row: (row - row.min())/(row.max() - row.min())

In [112]:
# average TFR zscore normalized
#! Check

from scipy.stats import zscore

%matplotlib qt

fig, axs = plt.subplots(1, 3)

data = tfa_cond1_ave.mean(0).T
data[:, :15] = 0
data[:, -15:] = 0
data[:, 15:-15] = zscore(data[:, 15:-15], axis=1)
display_spectra(
    data.T,
    any_tfa,
    fig,
    axs[0],
    'HITS'
)
data = tfa_cond2_ave.mean(0).T
data[:, :15] = 0
data[:, -15:] = 0
data[:, 15:-15] = zscore(data[:, 15:-15], axis=1)
display_spectra(
    data.T,
    any_tfa,
    fig,
    axs[1],
    'MISS'
)
data = tfa_diff_ave.mean(0).T
data[:, :15] = 0
data[:, -15:] = 0
data[:, 15:-15] = zscore(data[:, 15:-15], axis=1)
display_spectra(
    data.T,
    any_tfa,
    fig,
    axs[2],
    'DIFF'
)
plt.subplots_adjust(
    left=.075,
    bottom=None,
    right=None,
    top=None,
    wspace=.4,
    hspace=None
)
fig.set_size_inches(17, 5)
plt.show()

  ax_additional.set_yticklabels([


In [113]:
def random_walk(n: int) -> np.ndarray:
    out = list()
    r = 0
    for i in range(n):
        r += np.random.random()*2 - 1
        out.append(r)
    return np.array(out)

fig, (ax1, ax2, ax3, ax4) = plt.subplots(1, 4)
x = np.arange(100)
y = random_walk(100)
ax1.plot(x, y)
ax1.axhline(0, color='r', linestyle='--')
ax2.plot(x, y - y.mean())
ax2.axhline(0, color='r', linestyle='--')
print(y.std())
ax3.plot(x, y/y.std())
ax3.axhline(0, color='r', linestyle='--')
ax4.plot(x, (y - y.mean())/y.std())
ax4.axhline(0, color='r', linestyle='--')
fig.set_size_inches(10, 5)

2.8431238130713967


In [110]:
# average TFR min-max normalized

%matplotlib qt

fig, axs = plt.subplots(1, 3)

data = tfa_cond1_ave.mean(0).T
data[:, :15] = 0
data[:, -15:] = 0
data[:, 15:-15] = rowwise(data[:, 15:-15], row_min_max)
display_spectra(
    data.T,
    any_tfa,
    fig,
    axs[0],
    'HITS'
)
data = tfa_cond2_ave.mean(0).T
data[:, :15] = 0
data[:, -15:] = 0
data[:, 15:-15] = rowwise(data[:, 15:-15], row_min_max)
display_spectra(
    data.T,
    any_tfa,
    fig,
    axs[1],
    'MISS'
)
data = tfa_diff_ave.mean(0).T
data[:, :15] = 0
data[:, -15:] = 0
data[:, 15:-15] = rowwise(data[:, 15:-15], row_min_max)
display_spectra(
    data.T,
    any_tfa,
    fig,
    axs[2],
    'DIFF'
)
plt.subplots_adjust(
    left=.075,
    bottom=None,
    right=None,
    top=None,
    wspace=.4,
    hspace=None
)
fig.set_size_inches(17, 5)
plt.show()

  ax_additional.set_yticklabels([


In [243]:
# average TFR non-normalized

%matplotlib qt

fig, axs = plt.subplots(1, 3)

display_spectra(
    tfa_cond1_ave.mean(0),
    any_tfa,
    fig,
    axs[0],
    'HITS'
)
display_spectra(
    tfa_cond2_ave.mean(0),
    any_tfa,
    fig,
    axs[1],
    'MISS'
)
display_spectra(
    tfa_diff_ave.mean(0),
    any_tfa,
    fig,
    axs[2],
    'DIFF'
)
plt.subplots_adjust(
    left=.075,
    bottom=None,
    right=None,
    top=None,
    wspace=.4,
    hspace=None
)
fig.set_size_inches(17, 5)
plt.show()

  ax_additional.set_yticklabels([


In [89]:
# dirty code for average TFR

%matplotlib qt

fig, axs = plt.subplots(1, 3)



hits = axs[0].imshow(
    tfa_cond1_ave.mean(0).T,
    aspect='auto',
    origin='lower',
    extent=(any_tfa.times[0], any_tfa.times[-1], any_tfa.freqs[0], any_tfa.freqs[-1]),
    cmap='RdBu_r'
)
axs[0].set_title('HITS')
axs[0].set_yticks(np.linspace(any_tfa.freqs[0], any_tfa.freqs[-1], any_tfa.freqs.shape[0]))
axs[0].set_yticklabels(np.round(any_tfa.freqs, 2))
cbar1 = fig.colorbar(
    hits,
    orientation='horizontal',
    ax=axs[0],
    pad=.1,
    aspect=50
)
sig = tfa_cond1_ave.mean(0).mean(1)
axs[0].plot(any_tfa.times, 87*((sig - sig.min())/(sig.max() - sig.min())) + 3, '#555555')
ax_additional = axs[0].twinx()
sigmin = sig.min()
sigmax = sig.max()
ax_additional.set_yticklabels([
    str(i)[:3] + str(i)[-4:]
    for i in np.linspace(sigmin, sigmax, len(ax_additional.get_yticks()))
])

miss = axs[1].imshow(
    tfa_cond2_ave.mean(0).T,
    aspect='auto',
    origin='lower',
    extent=(any_tfa.times[0], any_tfa.times[-1], any_tfa.freqs[0], any_tfa.freqs[-1]),
    cmap='RdBu_r'
)
axs[1].set_title('MISS')
axs[1].set_yticks(np.linspace(any_tfa.freqs[0], any_tfa.freqs[-1], any_tfa.freqs.shape[0]))
axs[1].set_yticklabels(np.round(any_tfa.freqs, 2))
cbar2 = fig.colorbar(
    miss,
    orientation='horizontal',
    ax=axs[1],
    pad=.1,
    aspect=50
)
sig = tfa_cond2_ave.mean(0).mean(1)
axs[1].plot(any_tfa.times, 87*((sig - sig.min())/(sig.max() - sig.min())) + 3, '#555555')
ax_additional = axs[1].twinx()
sigmin = sig.min()
sigmax = sig.max()
ax_additional.set_yticklabels([
    str(i)[:3] + str(i)[-4:]
    for i in np.linspace(sigmin, sigmax, len(ax_additional.get_yticks()))
])

diff = axs[2].imshow(
    tfa_diff_ave.mean(0).T,
    aspect='auto',
    origin='lower',
    extent=(any_tfa.times[0], any_tfa.times[-1], any_tfa.freqs[0], any_tfa.freqs[-1]),
    cmap='RdBu_r'
)
axs[2].set_title('HITS - MISS')
axs[2].set_yticks(np.linspace(any_tfa.freqs[0], any_tfa.freqs[-1], any_tfa.freqs.shape[0]))
axs[2].set_yticklabels(np.round(any_tfa.freqs, 2))
cbar3 = fig.colorbar(
    diff,
    orientation='horizontal',
    ax=axs[2],
    pad=.1,
    aspect=50
)
sig = tfa_diff_ave.mean(0).mean(1)
axs[2].plot(any_tfa.times, 87*((sig - sig.min())/(sig.max() - sig.min())) + 3, '#555555')
ax_additional = axs[2].twinx()
sigmin = sig.min()
sigmax = sig.max()
ax_additional.set_yticklabels([
    str(i)[:3] + str(i)[-4:]
    for i in np.linspace(sigmin, sigmax, len(ax_additional.get_yticks()))
])
plt.subplots_adjust(
    left=.075,
    bottom=None,
    right=None,
    top=None,
    wspace=.4,
    hspace=None
)
fig.set_size_inches(17, 5)
plt.show()

  ax_additional.set_yticklabels([
  ax_additional.set_yticklabels([
  ax_additional.set_yticklabels([


In [193]:
# cluster statistics

import scipy

mne.stats.permutation_cluster_test #?
F_obs, clusters, cluster_p_values, H0 = \
    mne.stats.permutation_cluster_test(
        [tfa_cond1_ave, tfa_cond2_ave],
        out_type='mask', #? why mask
        n_permutations=2000,
        stat_fun=lambda *a: scipy.stats.ttest_rel(*a)[0], #? null hypothesis that two related or repeated samples have identical average (expected) values.
        n_jobs=10,
        threshold=3
    )

In [192]:
freqs = any_tfa.freqs
times = any_tfa.times

fig, ax = plt.subplots(1, 1, figsize=(6, 4))
fig.subplots_adjust(0.12, 0.08, 0.96, 0.94, 0.2, 0.43)
# F_obs = tfa_diff_ave.mean(0)

F_obs_plot = np.nan * np.ones_like(F_obs)
for c, p_val in zip(clusters, cluster_p_values):
    if p_val <= 0.5:
        F_obs_plot[c] = F_obs[c]# * signs[c]

ax.imshow(
    F_obs.T,
    extent=[times[0], times[-1], freqs[0], freqs[-1]],
    aspect='auto', origin='lower',
    cmap='gray'
)
max_F = np.nanmax(abs(F_obs_plot))
ax.imshow(
    F_obs_plot.T,
    extent=[times[0], times[-1], freqs[0], freqs[-1]],
    aspect='auto', origin='lower', cmap='RdBu_r',
    vmin=-max_F, vmax=max_F, interpolation = 'none'
)

ax.set_xlabel('Time (ms)')
ax.set_ylabel('Frequency (Hz)')


Text(0, 0.5, 'Frequency (Hz)')