The pattern of our brain waves can change drastically depending on our mental state or on the task being executed. Sleep is one example: depending on the sleep stage, the brain enters different rhythmic patterns related to its level of activity. These patterns have distinct frequencies and can therefore be distinguished by spectral analysis, for instance by computing a periodogram or spectrogram and detecting the frequencies with the most power.
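
As a minimal sketch of that spectral approach, the snippet below builds a synthetic signal (a strong 10 Hz rhythm, a weaker 20 Hz rhythm and noise, all assumptions for illustration), estimates its power spectrum with Welch's averaged periodogram (`scipy.signal.welch`, our choice of estimator rather than necessarily the one used in the paper) and reads off the frequency with the most power.

```python
import numpy as np
from scipy.signal import welch

# Synthetic 10 s recording sampled at 150 Hz (the rate used in this notebook).
sfreq = 150.0
t = np.arange(0, 10, 1 / sfreq)
rng = np.random.default_rng(0)

# Hypothetical signal: strong 10 Hz rhythm + weaker 20 Hz rhythm + white noise.
signal = (2.0 * np.sin(2 * np.pi * 10 * t)
          + 0.5 * np.sin(2 * np.pi * 20 * t)
          + 0.3 * rng.standard_normal(t.size))

# Welch periodogram: average over 2 s Hann-windowed segments -> 0.5 Hz bins.
freqs, psd = welch(signal, fs=sfreq, nperseg=int(2 * sfreq))

# The frequency bin carrying the most power.
peak_freq = freqs[np.argmax(psd)]
print(f"Peak frequency: {peak_freq:.1f} Hz")  # Peak frequency: 10.0 Hz
```

Longer segments give finer frequency bins but fewer segments to average, so the choice of `nperseg` trades resolution against variance.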

However, this approach cannot distinguish between mu and alpha waves, since their peak frequencies lie in the same band, between 8 and 12 Hz, even though the two rhythms reflect very different mental states.
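
To make the limitation concrete, here is a small sketch of a purely frequency-based classifier. The band edges mirror the `rhythms` mapping used later in this notebook (upper edges 4, 8, 12 and 30 Hz); the two peak values are hypothetical. Both an occipital alpha peak and a central mu peak land in the same band, so only spatial information can separate them.

```python
import bisect

# Upper band edges in Hz and band names, mirroring the notebook's `rhythms` dict.
EDGES = [4, 8, 12, 30]
NAMES = ["Delta", "Theta", "Alpha-Mu", "Beta"]


def band_of(peak_freq):
    """Return the band whose half-open interval [lower, upper) contains peak_freq."""
    i = bisect.bisect_right(EDGES, peak_freq)
    return NAMES[i] if i < len(NAMES) else "above Beta"


# Two hypothetical peak frequencies: occipital alpha and central mu.
occipital_alpha_peak = 10.0
central_mu_peak = 10.5
print(band_of(occipital_alpha_peak), band_of(central_mu_peak))  # Alpha-Mu Alpha-Mu
```

Using `bisect_right` reproduces the strict `peak < upper_edge` test, so a peak sitting exactly on an edge (for example 12 Hz) falls into the higher band.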

*Alpha waves* are generated in the thalamus and the occipital lobe. They are generally found in several sleep stages and may be related to visual memory and to the perception of mistakes.

*Mu waves* are recorded over motor cortex regions and are generated by pyramidal neurons. When a motor action is performed or perceived, mu waves decrease or are suppressed, a phenomenon called desynchronization.
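
Desynchronization is commonly quantified as the relative change in band power with respect to a baseline, ERD% = (P_event - P_baseline) / P_baseline x 100, where negative values indicate a power decrease. The sketch below applies this to a synthetic 10 Hz rhythm whose amplitude halves after 2 s; the signal, the 8 to 12 Hz Butterworth band-pass and the rest/movement split are all illustrative assumptions, not part of the somatosensory dataset.

```python
import numpy as np
from scipy.signal import butter, filtfilt

sfreq = 150.0
t = np.arange(0, 4, 1 / sfreq)  # 2 s "rest" followed by 2 s "movement"
rng = np.random.default_rng(1)

# Hypothetical mu rhythm at 10 Hz whose amplitude drops by half after t = 2 s.
amplitude = np.where(t < 2, 1.0, 0.5)
x = amplitude * np.sin(2 * np.pi * 10 * t) + 0.1 * rng.standard_normal(t.size)

# Zero-phase band-pass in the 8-12 Hz band, then square for instantaneous power.
b, a = butter(4, [8, 12], btype="bandpass", fs=sfreq)
power = filtfilt(b, a, x) ** 2

# ERD% = (P_event - P_baseline) / P_baseline * 100; negative = desynchronization.
p_baseline = power[t < 2].mean()
p_event = power[t >= 2].mean()
erd_percent = (p_event - p_baseline) / p_baseline * 100
print(f"ERD: {erd_percent:.0f} %")
```

Halving the amplitude quarters the power, so the printed value should be in the neighbourhood of -75 %; the filter's smoothing around the transition shifts it somewhat.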

In this notebook, we try to detect other brain patterns in the somatosensory dataset used in the paper. We then use the location information carried by the spatial (channel) maps of the atoms to map each pattern to the region where it was generated, which can give more information about the type of wave and about the activity being performed.
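
A rank-1 multichannel atom is the outer product of a spatial map `u` (one weight per channel) and a temporal waveform `v`, so `u` alone says where the pattern is strongest. The sketch below builds such an atom from made-up values (the channel names `CH0`...`CH7` and the weights are hypothetical) and picks the most relevant channels by absolute weight, since `u v^T = (-u)(-v)^T` makes the sign of `u` arbitrary.

```python
import numpy as np

n_channels, n_times_atom, sfreq = 8, 150, 150.0
# Hypothetical channel names, for illustration only.
ch_names = [f"CH{i}" for i in range(n_channels)]

# Spatial map u (unit norm): energy concentrated on channels 2 and 5.
u = np.array([0.05, 0.1, 0.9, 0.05, 0.1, -0.4, 0.0, 0.05])
u /= np.linalg.norm(u)

# Temporal waveform v (unit norm): a 1 s, 10 Hz oscillation.
t = np.arange(n_times_atom) / sfreq
v = np.sin(2 * np.pi * 10 * t)
v /= np.linalg.norm(v)

# Rank-1 atom: every channel carries the same waveform, scaled by u.
D = np.outer(u, v)  # shape (n_channels, n_times_atom)

# The two channels where the atom is strongest, ranked by |u|.
top = np.argsort(np.abs(u))[::-1][:2]
print([ch_names[i] for i in top])  # ['CH2', 'CH5']
```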

In [None]:
import mne
import numpy as np
import scipy
from matplotlib import pyplot as plt

from src.mne_data import load_data

In [None]:
# Let us first define the parameters of our model.

sfreq = 150.

# Define the shape of the dictionary
n_atoms = 25
n_times_atom = int(round(sfreq * 1.0))  # 150 samples = 1000 ms


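With 1 s atoms sampled at 150 Hz, the spectrum of each temporal pattern is evaluated on a 1 Hz grid from 0 to 75 Hz, which bounds how precisely peak frequencies (and thus band boundaries such as 12 Hz) can be resolved. A short check with `np.fft.rfftfreq` also shows that the simpler `np.linspace(0, sfreq / 2, ...)` construction only matches the true bin frequencies when the atom length is even; the odd length 151 is a hypothetical counter-example.

```python
import numpy as np

sfreq = 150.0
n_times_atom = int(round(sfreq * 1.0))  # 150 samples = 1 s

# Bin spacing of the rFFT of one atom: sfreq / n_times_atom.
resolution = sfreq / n_times_atom
freqs = np.fft.rfftfreq(n_times_atom, d=1.0 / sfreq)
print(len(freqs), resolution)  # 76 1.0

# linspace matches rfftfreq for an even length, but not for an odd one,
# because the last rFFT bin then lies below the Nyquist frequency.
matches_even = np.allclose(freqs, np.linspace(0, sfreq / 2, n_times_atom // 2 + 1))
odd = 151
matches_odd = np.allclose(np.fft.rfftfreq(odd, d=1.0 / sfreq),
                          np.linspace(0, sfreq / 2, odd // 2 + 1))
print(matches_even, matches_odd)  # True False
```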
In [None]:
# Next, we define the parameters for multivariate CSC

from alphacsc import BatchCDL
cdl = BatchCDL(
    # Shape of the dictionary
    n_atoms=n_atoms,
    n_times_atom=n_times_atom,
    # Request a rank1 dictionary with unit norm temporal and spatial maps
    rank1=True, uv_constraint='separate',
    # Initialize the dictionary with random chunks from the data
    D_init='chunk',
    # rescale the regularization parameter to be 20% of lambda_max
    lmbd_max="scaled", reg=.2,
    # Number of iterations for the alternate minimization and convergence threshold
    n_iter=100, eps=1e-4,
    # solver for the z-step
    solver_z="lgcd", solver_z_kwargs={'tol': 1e-2, 'max_iter': 1000},
    # solver for the d-step
    solver_d='alternate_adaptive', solver_d_kwargs={'max_iter': 300},
    # Technical parameters
    verbose=1, random_state=0, n_jobs=6)



In [None]:
# Loading the somatosensory dataset

t_lim = (-2, 4)
X, info = load_data(dataset='somato', epoch=t_lim, sfreq=sfreq)


In [None]:
# Fit the model and learn rank1 atoms
cdl.fit(X)


In [None]:
rhythms = {4:'Delta',
           8:'Theta',
           12:'Alpha-Mu',
           30:'Beta'}

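def find_peaks(model, n_atoms, n=5, figure=False, rows=1, columns=1):
    """Classify each atom by the peak frequency of its temporal pattern.

    For every atom, the power spectrum of ``model.v_hat_[i_atom]`` is computed,
    the dominant frequency (up to 30 Hz) is mapped to a rhythm in ``rhythms``,
    and the ``n`` channels with the largest spatial weights in
    ``model.u_hat_[i_atom]`` are reported (``n`` is also the number of spectral
    peaks printed). Optionally, the spatial maps are plotted as topomaps on a
    ``rows`` x ``columns`` grid. Relies on the global ``sfreq`` and ``info``.
    """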

    if figure:
        figsize = (columns * 5, rows * 5.5)
        fig, axes = plt.subplots(rows, columns, figsize=figsize, squeeze=False)

    main_rhythm = {}

    for i_atom in range(n_atoms):

        print(f"Atom {i_atom+1}")

        v_hat = model.v_hat_[i_atom]
        u_hat = model.u_hat_[i_atom]
        psd = np.abs(np.fft.rfft(v_hat)) ** 2
        # exact rFFT bin frequencies (also correct for odd atom lengths)
        frequencies = np.fft.rfftfreq(len(v_hat), d=1.0 / sfreq)

        # keep only the range of interest (up to 30 Hz)
        mask = frequencies <= 30
        frequencies = frequencies[mask]
        psd = psd[mask]

        # the n frequencies with the most power, strongest first
        peaks_idx = np.argsort(psd)[-n:][::-1]
        peaks_freq = frequencies[peaks_idx]

        print(peaks_freq)

        for v in rhythms.keys():
            if peaks_freq[0]<v:

                print(f"    {rhythms[v]} wave")

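                # n most relevant channels, ranked by absolute spatial weight
                # (the sign of a rank-1 spatial map is arbitrary)
                weights = np.abs(u_hat)
                idx_sorted = np.argpartition(weights, -n)[-n:]
                idx_sorted = idx_sorted[np.argsort(weights[idx_sorted])[::-1]]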

                # most relevant channels
                channels = np.array(info.ch_names)[idx_sorted]

                print(f"    {n} most relevant channels:")
                print(f'    {channels}')

                main_rhythm[i_atom] = {"rhythm": rhythms[v],
                                       "channels": channels}

                if figure:
                    row = i_atom // columns
                    col = i_atom % columns
                    ax = axes[row, col]

                    mne.viz.plot_topomap(u_hat, info, axes=ax, show=False)
                    ax.set(title=f'Atom {i_atom + 1} - Rhythm {rhythms[v]}',)

                break

    if figure:
        for i in range(n_atoms, rows * columns):
            row = i // columns
            col = i % columns
            axes[row, col].axis('off')
        plt.tight_layout()
        plt.savefig("../figures/waves_per_region.pdf", dpi=300)
        plt.show()

    return main_rhythm

main_rhythm = find_peaks(cdl, n_atoms, figure=True, rows=5, columns=5)
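
The channel-ranking step in `find_peaks` relies on `np.argpartition`, which only guarantees that the last `n` indices hold the `n` largest values, in no particular order. The sketch below, on hypothetical weights for six channels, shows how a follow-up `argsort` turns that selection into a ranking, and why ranking the raw values rather than their absolute values can drop a strongly weighted channel whose weight happens to be negative.

```python
import numpy as np

# Hypothetical spatial weights for 6 channels (signs are arbitrary in a rank-1 atom).
u_hat = np.array([0.1, -0.7, 0.3, 0.6, -0.05, 0.2])
n = 3

# argpartition: indices of the n largest |weights|, in unspecified order.
top_unordered = np.argpartition(np.abs(u_hat), -n)[-n:]

# Sort those n indices by weight to rank them from strongest to weakest.
top_ranked = top_unordered[np.argsort(np.abs(u_hat)[top_unordered])[::-1]]
print(top_ranked)  # [1 3 2]

# Ranking raw values instead misses channel 1, the strongest in magnitude.
print(sorted(np.argpartition(u_hat, -n)[-n:].tolist()))  # [2, 3, 5]
```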
