Notebook and code created by Bruna Lopes and Bruno Amorim, inspired by the alphacsc examples.

All the source code can be found at https://github.com/brunaafl/TS_CSC_MEG.

The functions load_data, separate_sleep_stages, find_peaks, display_topomap, display_ffts and display_atoms are all implemented in the repository of this project. To be able to import them, you can clone our repository.

git clone https://github.com/brunaafl/TS_CSC_MEG

cd TS_CSC_MEG/

# Functions

In [None]:
import mne
import numpy as np
import scipy
from matplotlib import pyplot as plt


# Band table: each key appears to be the upper edge (in Hz) of the band named
# by its value (find_peaks compares a peak frequency against these keys).
# NOTE(review): sleep spindles are usually placed around 11-16 Hz, so labelling
# the 30-100 Hz range 'Spindle' looks questionable -- confirm the intended name.
rhythms = {4:'Delta',
           8:'Theta',
           12:'Alpha-Mu',
           30:'Beta',
           100: 'Spindle'}


def display_atom(model, i_atom, info, sfreq=150):
    """Plot the spatial map, waveform and power spectrum of a single atom.

    Parameters
    ----------
    model : object
        Fitted model exposing ``u_hat_`` (spatial maps) and ``v_hat_``
        (temporal waveforms), both indexed by atom.
    i_atom : int
        Index of the atom to display.
    info : mne.Info
        Passed unchanged to ``mne.viz.plot_topomap`` together with the
        spatial map.
    sfreq : float
        Sampling frequency in Hz, used to build the time and frequency axes.
    """
    n_plots = 3
    figsize = (n_plots * 5, 5.5)
    fig, axes = plt.subplots(1, n_plots, figsize=figsize, squeeze=False)

    # Plot the spatial map of the learned atom using mne topomap
    ax = axes[0, 0]
    u_hat = model.u_hat_[i_atom]
    mne.viz.plot_topomap(u_hat, info, axes=ax, show=False)
    ax.set(title='Learned spatial pattern')

    # Plot the temporal pattern of the learned atom
    ax = axes[0, 1]
    v_hat = model.v_hat_[i_atom]
    t = np.arange(v_hat.size) / sfreq
    ax.plot(t, v_hat)
    ax.set(xlabel='Time (sec)', title='Learned temporal waveform')
    ax.grid(True)

    # Plot the psd of the time atom
    ax = axes[0, 2]
    psd = np.abs(np.fft.rfft(v_hat)) ** 2
    # rfftfreq returns the exact bin frequencies for even and odd lengths;
    # linspace(0, sfreq / 2, len(psd)) is only exact for even lengths.
    frequencies = np.fft.rfftfreq(v_hat.size, d=1.0 / sfreq)
    ax.semilogy(frequencies, psd)
    ax.set(xlabel='Frequencies (Hz)', title='Power Spectral Density')
    ax.grid(True)
    ax.set_xlim(0, 30)

    plt.tight_layout()
    plt.show()


def display_atoms(model, n_atoms, rows, columns, sfreq, savefig="atoms_somato"):
    """Plot the temporal waveform of the first ``n_atoms`` atoms on a grid.

    Parameters
    ----------
    model : object
        Fitted model exposing ``v_hat_``, indexed by atom.
    n_atoms : int
        Number of atoms to draw.
    rows, columns : int
        Grid shape; ``rows * columns`` must be at least ``n_atoms``.
    sfreq : float
        Sampling frequency in Hz, used for the time axis.
    savefig : str
        Base name of the PDF written to ``../figures/``.

    Raises
    ------
    ValueError
        If the grid has fewer panels than ``n_atoms``.
    """
    n_panels = rows * columns
    if n_panels < n_atoms:
        raise ValueError("The grid size (rows x columns) must be at least equal to n_atoms")

    fig, axes = plt.subplots(rows, columns,
                             figsize=(columns * 5, rows * 5.5),
                             squeeze=False)

    # Row-major flattening: panel k is the one at (k // columns, k % columns).
    for k, panel in enumerate(axes.ravel()):
        if k >= n_atoms:
            # Unused trailing panels are hidden.
            panel.axis('off')
            continue

        waveform = model.v_hat_[k]
        time_axis = np.arange(waveform.size) / sfreq

        panel.plot(time_axis, waveform)
        panel.set(xlabel='Time (sec)', title=f'Atom {k + 1}')
        panel.grid(True)

    plt.tight_layout()
    plt.savefig(f"../figures/{savefig}.pdf", dpi=300)
    plt.show()


def display_ffts(model, n_atoms, rows, columns, sfreq, savefig="topomap_ffts"):
    """Plot the power spectrum of the first ``n_atoms`` atoms on a grid.

    Parameters
    ----------
    model : object
        Fitted model exposing ``v_hat_``, indexed by atom.
    n_atoms : int
        Number of atoms to draw.
    rows, columns : int
        Grid shape; ``rows * columns`` must be at least ``n_atoms``.
    sfreq : float
        Sampling frequency in Hz, used for the frequency axis.
    savefig : str
        Base name of the PDF written to ``../figures/``.

    Raises
    ------
    ValueError
        If the grid has fewer panels than ``n_atoms``.
    """
    if rows * columns < n_atoms:
        raise ValueError("The grid size (rows x columns) must be at least equal to n_atoms")

    figsize = (columns * 5, rows * 5.5)
    fig, axes = plt.subplots(rows, columns, figsize=figsize, squeeze=False)

    for i_atom in range(n_atoms):
        row = i_atom // columns
        col = i_atom % columns
        ax = axes[row, col]

        v_hat = model.v_hat_[i_atom]
        psd = np.abs(np.fft.rfft(v_hat)) ** 2
        # rfftfreq returns the exact bin frequencies for even and odd lengths;
        # linspace(0, sfreq / 2, len(psd)) is only exact for even lengths.
        frequencies = np.fft.rfftfreq(v_hat.size, d=1.0 / sfreq)
        ax.semilogy(frequencies, psd)
        ax.set(xlabel='Frequencies (Hz)', title=f'Atom {i_atom + 1}')
        ax.grid(True)
        # Only the 0-30 Hz range is shown.
        ax.set_xlim(0, 30)

    # Hide the panels left over when the grid is larger than n_atoms.
    for i in range(n_atoms, rows * columns):
        row = i // columns
        col = i % columns
        axes[row, col].axis('off')

    plt.tight_layout()
    plt.savefig(f"../figures/{savefig}.pdf", dpi=300)
    plt.show()

def display_topomap(model, n_atoms, rows, columns, info, savefig="topomap_somato"):
    """Plot the spatial map of the first ``n_atoms`` atoms on a grid.

    Parameters
    ----------
    model : object
        Fitted model exposing ``u_hat_``, indexed by atom.
    n_atoms : int
        Number of atoms to draw.
    rows, columns : int
        Grid shape; ``rows * columns`` must be at least ``n_atoms``.
    info : mne.Info
        Passed unchanged to ``mne.viz.plot_topomap`` with each spatial map.
    savefig : str
        Base name of the PDF written to ``../figures/``.

    Raises
    ------
    ValueError
        If the grid has fewer panels than ``n_atoms``.
    """
    n_panels = rows * columns
    if n_panels < n_atoms:
        raise ValueError("The grid size (rows x columns) must be at least equal to n_atoms")

    fig, axes = plt.subplots(rows, columns,
                             figsize=(columns * 5, rows * 5.5),
                             squeeze=False)

    # Row-major flattening: panel k is the one at (k // columns, k % columns).
    for k, panel in enumerate(axes.ravel()):
        if k >= n_atoms:
            # Unused trailing panels are hidden.
            panel.axis('off')
            continue

        spatial_map = model.u_hat_[k]
        mne.viz.plot_topomap(spatial_map, info, axes=panel, show=False)
        panel.set(title=f'Atom {k + 1}')

    plt.tight_layout()
    plt.savefig(f"../figures/{savefig}.pdf", dpi=300)
    plt.show()



In [None]:
from os.path import join
from copy import deepcopy

import mne
import numpy as np
from joblib import Memory
from scipy.signal.windows import tukey

from alphacsc.utils.config import ALPHACSC_CACHE_DIR

# On-disk joblib cache, stored in alphacsc's cache directory; it backs the
# @mem.cache decorator on load_data below.
mem = Memory(location=ALPHACSC_CACHE_DIR, verbose=0)

## Adaptation of the code at alphacsc
@mem.cache(ignore=['n_jobs'])
def load_data(dataset="somato", n_splits=10, sfreq=None, epoch=None, channels=None,
              filter_params=(2., None), return_array=True, n_jobs=1):
    """Load and prepare the somato or sleep dataset for multiCSC

    Parameters
    ----------
    dataset : str in {'somato', 'sleep'}
        Dataset to load.
    n_splits : int
        Split the signal in n_split signals of same length before returning it.
        If epoch is provided, the signal is instead splitted according to the
        epochs and this option is not followed.
    sfreq : float
        Sampling frequency of the signal. The data are resampled to match it.
    epoch : tuple or None
        If set to a tuple, extract epochs from the raw data, using
        t_min=epoch[0] and t_max=epoch[1]. Else, use the raw signal, divided
        in n_splits chunks.
    channels : list of str or None
        If given, only these channels are kept (in this order) before
        filtering; the final channel-type selection is still applied
        afterwards. Events are extracted before this selection.
    filter_params : tuple of length 2
        Boundaries of filtering, e.g. (2, None), (30, 40), (None, 40).
    return_array : boolean
        If True, return an NumPy array, instead of mne objects.
    n_jobs : int
        Number of jobs that can be used for preparing (filtering) the data.

    Returns
    -------
    X : array, shape (n_splits, n_channels, n_times)
        The loaded dataset, tapered with a Tukey window along time and
        divided by its global standard deviation. With ``epoch`` set, the
        first axis indexes epochs instead of splits. If ``return_array`` is
        False, the ``mne.Epochs`` object (when ``epoch`` is set) or the raw
        object is returned instead.
    info : dict
        Deep copy of the MNE information about recording settings, with an
        extra ``'temp'`` entry holding ``event_id``, ``events``, ``subject``,
        ``subjects_dir``, ``cov``, ``file_bem`` and ``file_trans``.

    Raises
    ------
    ValueError
        If ``dataset`` is neither 'somato' nor 'sleep'.
    """

    if dataset == 'somato':
        pick_types_epoch = dict(meg='grad', eeg=False, eog=True, stim=False)
        pick_types_final = dict(meg='grad', eeg=False, eog=False, stim=False)

        data_path = mne.datasets.somato.data_path()
        subjects_dir = None
        file_name = join(data_path, 'sub-01', 'meg',
                         'sub-01_task-somato_meg.fif')
        raw = mne.io.read_raw_fif(file_name, preload=True)

        raw_copy = raw.copy()

        # Keep a copy for event extraction
        raw_stim = raw.copy()
        raw_stim.pick_types(meg=False, stim=True)  # Only keep stim channels for event detection

        # Extract events from stim channels
        event_id = {'somato': 1}
        events = mne.find_events(raw_stim, stim_channel='STI 014')
        events = mne.pick_events(events, include=list(event_id.values()))

        # Notch out 50 Hz and 100 Hz.
        raw_copy.notch_filter(np.arange(50, 101, 50), n_jobs=n_jobs)

        baseline = (None, 0)

    elif dataset == 'sleep':
        pick_types_epoch = dict(meg=False, eeg=True, eog=True, stim=False)
        pick_types_final = dict(meg=False, eeg=True, eog=False, stim=False)

        # Load the sleep PhysioNet dataset
        subject = 1
        subjects_dir = None
        [data_fetch] = mne.datasets.sleep_physionet.age.fetch_data(subjects=[subject], recording=[1])

        raw = mne.io.read_raw_edf(data_fetch[0],stim_channel="Event marker",infer_types=True,preload=True,)

        annot_train = mne.read_annotations(data_fetch[1])
        print(annot_train)
        raw.set_annotations(annot_train, emit_warning=False)

        raw_copy = raw.copy()

        # Only the annotations listed here are turned into events below;
        # stages 3 and 4 share the same id.
        annotation_event_id = {
            "Sleep stage W": 1,
            "Sleep stage 1": 2,
            "Sleep stage 2": 3,
            "Sleep stage 3": 4,
            "Sleep stage 4": 4,
            "Sleep stage R": 5,
        }

        # Crop the annotations to the window from 30 min before the second
        # annotation onset to 30 min after the second-to-last one.
        annot_train.crop(annot_train[1]["onset"] - 30 * 60, annot_train[-2]["onset"] + 30 * 60)
        raw_copy.set_annotations(annot_train, emit_warning=False)

        # Extract events based on annotations
        events, _ = mne.events_from_annotations(raw_copy, event_id=annotation_event_id, chunk_duration=30.0)

        # create a new event_id that unifies stages 3 and 4
        event_id = {
            "Sleep stage W": 1,
            "Sleep stage 1": 2,
            "Sleep stage 2": 3,
            "Sleep stage 3/4": 4,
            "Sleep stage R": 5,
        }

        baseline = None

    else:
        # The exception used to be built but never raised, so an unknown name
        # only failed later on an undefined variable.
        raise ValueError("Dataset must be 'somato' or 'sleep'")

    if channels is not None:
        raw_copy = raw_copy.pick_channels(channels, ordered=True)

    # Dipole fit information
    cov = None  # see below
    file_trans = None
    file_bem = None

    raw_copy.filter(*filter_params, n_jobs=n_jobs)

    # Now pick final channel types for the main raw object
    raw_copy.pick_types(**pick_types_final)

    if dataset == 'somato':
        # compute the covariance matrix for somato
        picks_cov = mne.pick_types(raw_copy.info, **pick_types_epoch)
        epochs_cov = mne.Epochs(raw_copy, events, event_id, tmin=-4, tmax=0,
                                picks=picks_cov, baseline=baseline,
                                reject=dict(grad=4000e-13),
                                preload=True)
        epochs_cov.pick_types(**pick_types_final)
        cov = mne.compute_covariance(epochs_cov)

    if epoch:
        t_min, t_max = epoch
        print(events)
        picks = mne.pick_types(raw_copy.info, **pick_types_epoch)
        epochs = mne.Epochs(raw_copy, events, event_id, t_min, t_max, picks=picks,
                            baseline=baseline,preload=True)
        epochs.pick_types(**pick_types_final)
        info = epochs.info

        print(epochs)
        if sfreq is not None:
            epochs = epochs.resample(sfreq, npad='auto', n_jobs=n_jobs)

        if return_array:
            X = epochs.get_data()

    else:
        # Express event onsets relative to the start of the recording.
        events[:, 0] -= raw_copy.first_samp
        # Channel and channel-type selection were already applied above;
        # repeating pick_channels here could fail for a requested channel
        # (e.g. a stim channel) that the type selection has since dropped.
        info = raw_copy.info

        if sfreq is not None:
            raw_copy, events = raw_copy.resample(sfreq, events=events, npad='auto',
                                       n_jobs=n_jobs)

        if return_array:
            X = raw_copy.get_data()
            n_channels, n_times = X.shape
            # Drop the trailing samples that do not fill a whole split.
            n_times = n_times // n_splits
            X = X[:, :n_splits * n_times]
            X = X.reshape(n_channels, n_splits, n_times).swapaxes(0, 1)

    # Deep copy before modifying info to avoid issues when saving EvokedArray
    info = deepcopy(info)
    event_info = dict(event_id=event_id,
                      events=events,
                      subject=dataset,
                      subjects_dir=subjects_dir,
                      cov=cov,
                      file_bem=file_bem,
                      file_trans=file_trans)

    info['temp'] = event_info

    if return_array:
        n_splits, n_channels, n_times = X.shape
        # Taper every segment along time, then normalise by the global std.
        X *= tukey(n_times, alpha=0.1)[None, None, :]
        X /= np.std(X)
        return X, info
    elif epoch:
        # Return the Epochs object, not the (t_min, t_max) tuple.
        return epochs, info
    else:
        return raw_copy, info


def separate_sleep_stages(X, info):
    """Group the epochs of ``X`` by sleep stage.

    Events are walked in order; the k-th event whose id appears in
    ``event_id`` is paired with ``X[k]``. Events with an unknown id are
    skipped and do not consume a row of ``X``.

    Parameters
    ----------
    X : array, shape (n_epochs, n_channels, n_times)
        Epoched data.
    info : mapping
        Must provide ``info['temp']['event_id']`` (stage name -> id) and
        ``info['temp']['events']`` (rows whose third entry is the event id).

    Returns
    -------
    X_stage : dict
        Maps each stage that has at least one epoch to an array of shape
        (n_epochs_stage * n_channels, 1, n_times): the epochs of the stage
        are concatenated along the channel axis, so every channel of every
        epoch becomes its own single-channel trial. Stages without epochs
        are omitted.

    Raises
    ------
    IndexError
        If there are more matching events than rows in ``X``.
    """
    event_id = info['temp']['event_id']  # Mapping of sleep stages
    events = info['temp']['events']  # Event data

    # Inverse lookup (id -> stage name). setdefault keeps the first name when
    # several names share an id, as a linear scan in mapping order would.
    name_by_id = {}
    for name, sid in event_id.items():
        name_by_id.setdefault(sid, name)

    data_by_stage = {stage: [] for stage in event_id}

    # j counts matched events only, i.e. it is the row of X for this event.
    j = 0
    for _, _, stage_id in events:
        stage_name = name_by_id.get(stage_id)
        if stage_name is None:
            continue  # Skip if the event ID does not match any stage
        data_by_stage[stage_name].append(X[j])
        j += 1

    return {
        stage: np.concatenate(chunks, axis=0)[:, None, :]
        for stage, chunks in data_by_stage.items()
        if chunks
    }

# Detect neural rhythms

The pattern of our brain waves can change drastically depending on our mental state or the task being executed. One example is sleep, where the brain enters different rhythmic patterns, related to its level of activity, depending on the sleep stage. These patterns have distinct frequencies and can therefore be distinguished by spectral analysis, for example by looking at their periodogram or spectrogram and detecting the frequencies with the most power.

However, this approach fails to distinguish between Mu and Alpha waves, since their peak frequencies lie in the same band, between 8 and 12 Hz, even though they represent very distinct mental states.

*Alpha waves* are generated in the thalamus and in the occipital lobe, and are in general found in different sleep stages; they might also be related to visual memory and the perception of mistakes.

*Mu waves* are found in motor cortex regions and are generated by pyramidal neurons. While motor actions are performed or perceived, a decrease or suppression of mu waves can be detected, which is called desynchronization.

In [None]:
import mne
import numpy as np
import scipy
from matplotlib import pyplot as plt


In [None]:
# Let us first define the parameters of our model.

# Target sampling frequency in Hz; also used to size the atoms below.
sfreq = 150.

# Define the shape of the dictionary
n_atoms = 25
n_times_atom = int(round(sfreq * 1.0))  # atom length in samples: 1 s at sfreq


In [None]:
# Next, we define the parameters for multivariate CSC

from alphacsc import BatchCDL

# Keyword arguments of the dictionary learner, grouped by purpose.
cdl_params = dict(
    # Shape of the dictionary
    n_atoms=n_atoms,
    n_times_atom=n_times_atom,
    # Request a rank1 dictionary with unit norm temporal and spatial maps
    rank1=True,
    uv_constraint='separate',
    # Initialize the dictionary with random chunk from the data
    D_init='chunk',
    # rescale the regularization parameter to be 20% of lambda_max
    lmbd_max="scaled",
    reg=.2,
    # Number of iteration for the alternate minimization and cvg threshold
    n_iter=100,
    eps=1e-4,
    # solver for the z-step
    solver_z="lgcd",
    solver_z_kwargs={'tol': 1e-2, 'max_iter': 1000},
    # solver for the d-step
    solver_d='alternate_adaptive',
    solver_d_kwargs={'max_iter': 300},
    # Technical parameters
    verbose=1,
    random_state=0,
    n_jobs=6,
)
cdl = BatchCDL(**cdl_params)



In [None]:
# Loading the somatosensory dataset as epochs; t_lim is handed to load_data as
# the (t_min, t_max) window of each epoch.

t_lim = (-2, 4)
X, info = load_data(dataset='somato', epoch=t_lim, sfreq=sfreq)


In [None]:
# Fit the model and learn rank1 atoms on the epoched somato array loaded above
cdl.fit(X)


In [None]:
# Band table read by find_peaks: each key is an exclusive upper bound in Hz
# and the first key larger than the peak frequency gives the label, so the
# keys must stay in ascending order.
rhythms = {4:'Delta',
           8:'Theta',
           12:'Alpha-Mu',
           30:'Beta'}

def find_peaks(model, n_atoms, n=5, figure=False, rows=1, columns=1):
    """Label each atom with a rhythm band and list its strongest channels.

    For every atom, the power spectrum of its temporal waveform is computed
    and restricted to frequencies up to 30 Hz. The frequency of the
    highest-power bin is compared with the keys of ``rhythms`` (upper band
    edges) and the first band whose edge is above it is assigned.

    This function reads the module-level names ``sfreq`` (sampling
    frequency), ``info`` (for channel names and topomaps) and ``rhythms``.

    Parameters
    ----------
    model : object
        Fitted model exposing ``u_hat_`` and ``v_hat_``, indexed by atom.
    n_atoms : int
        Number of atoms to analyse.
    n : int
        Number of top spectral bins printed and of channels reported.
    figure : bool
        If True, draw one topomap per labelled atom and save the figure to
        ``../figures/waves_per_region.pdf``.
    rows, columns : int
        Grid shape of the figure; only used when ``figure`` is True.

    Returns
    -------
    main_rhythm : dict
        Maps the atom index to ``{"rhythm": label, "channels": names}``.
        Atoms whose peak frequency is not below any band edge are absent.

    Raises
    ------
    ValueError
        If ``figure`` is True and the grid has fewer panels than ``n_atoms``.
    """
    if figure:
        # Same guard as the display_* helpers; without it a grid that is too
        # small only fails later with an IndexError while plotting.
        if rows * columns < n_atoms:
            raise ValueError("The grid size (rows x columns) must be at least equal to n_atoms")
        figsize = (columns * 5, rows * 5.5)
        fig, axes = plt.subplots(rows, columns, figsize=figsize, squeeze=False)

    main_rhythm = {}

    for i_atom in range(n_atoms):

        print(f"Atom {i_atom+1}")

        v_hat = model.v_hat_[i_atom]
        u_hat = model.u_hat_[i_atom]
        psd = np.abs(np.fft.rfft(v_hat)) ** 2
        # Exact bin frequencies for both even and odd waveform lengths.
        frequencies = np.fft.rfftfreq(v_hat.size, d=1.0 / sfreq)

        mask = frequencies <= 30
        frequencies = frequencies[mask]
        psd = psd[mask]

        # Frequencies of the n highest-power bins, strongest first.
        peaks_idx = np.argsort(psd)[-n:][::-1]
        peaks_freq = frequencies[peaks_idx]

        print(peaks_freq)

        # First band whose upper edge is above the dominant frequency.
        label = next((name for upper, name in rhythms.items()
                      if peaks_freq[0] < upper), None)
        if label is None:
            continue

        print(f"    {label} wave")

        # n most relevant channels
        # NOTE(review): this ranks by signed weight, so strongly negative
        # channels are never selected -- confirm whether the ranking should
        # use np.abs(u_hat) instead.
        idx_sorted = np.argpartition(u_hat, -n)[-n:]

        # most relevant channels
        channels = np.array(info.ch_names)[idx_sorted]

        print(f"    {n} most relevant channels:")
        print(f'    {channels}')

        main_rhythm[i_atom] = {"rhythm": label,
                               "channels": channels}

        if figure:
            row = i_atom // columns
            col = i_atom % columns
            ax = axes[row, col]

            mne.viz.plot_topomap(u_hat, info, axes=ax, show=False)
            ax.set(title=f'Atom {i_atom + 1} - Rhythm {label}',)

    if figure:
        # Hide the panels left over when the grid is larger than n_atoms.
        for i in range(n_atoms, rows * columns):
            row = i // columns
            col = i % columns
            axes[row, col].axis('off')
        plt.tight_layout()
        plt.savefig("../figures/waves_per_region.pdf", dpi=300)
        plt.show()

    return main_rhythm

# Label the atoms of the fitted model and draw them on a 5 x 5 grid.
main_rhythm = find_peaks(cdl,n_atoms, figure=True, rows=5, columns=5)


# Testing rank-1 hypothesis - Comparing regions

The paper imposes a rank-1 constraint on the dictionary of patterns $D$. This constraint simplifies the optimization by turning the multivariate problem into a two-step univariate one: instead of computing the gradient over a $C \times P$ matrix, it is computed over two univariate arrays $u_k \in \mathbb{R}^C$ and $v_k \in \mathbb{R}^P$. With this constraint, the problem becomes marginally convex in each of the variables $v_k$ and $u_k$, and can be solved by standard projected gradient descent.

This constraint, however, imposes an implicit hypothesis that, at each time, you can consider that there is a single source region on the brain which generates the signal, that is then reproduced in all other locations of the brain with different intensities.


In [None]:
import mne
import copy

from dtw import dtw

import pandas as pd
import seaborn as sns


# Let us first define the parameters of our model.

# Target sampling frequency in Hz; also used to size the atoms below.
sfreq = 150.

# Define the shape of the dictionary
n_atoms = 25
n_times_atom = int(round(sfreq * 1.0))  # atom length in samples: 1 s at sfreq


In [None]:
# Next, we define the parameters for multivariate CSC

# Monkey-patch scipy.signal.tukey to point to the correct function
# NOTE(review): presumably needed because a dependency still looks the window
# up as scipy.signal.tukey -- confirm against the installed versions.
scipy.signal.tukey = scipy.signal.windows.tukey

# Keyword arguments of the dictionary learner, grouped by purpose.
cdl_params = dict(
    # Shape of the dictionary
    n_atoms=n_atoms,
    n_times_atom=n_times_atom,
    # Request a rank1 dictionary with unit norm temporal and spatial maps
    rank1=True,
    uv_constraint='separate',
    # Initialize the dictionary with random chunk from the data
    D_init='chunk',
    # rescale the regularization parameter to be 20% of lambda_max
    lmbd_max="scaled",
    reg=.2,
    # Number of iteration for the alternate minimization and cvg threshold
    n_iter=100,
    eps=1e-4,
    # solver for the z-step
    solver_z="lgcd",
    solver_z_kwargs={'tol': 1e-2, 'max_iter': 1000},
    # solver for the d-step
    solver_d='alternate_adaptive',
    solver_d_kwargs={'max_iter': 300},
    # Technical parameters
    verbose=1,
    random_state=0,
    n_jobs=6,
)
cdl = BatchCDL(**cdl_params)



Now, we can choose the two regions of the brain that are the most uncorrelated. We can do that by computing the correlation matrix between channels and splitting the channels into groups of high mutual correlation.

In [None]:
# Here, we load the data from the somato-sensory dataset and preprocess them
# in epochs. The epochs are selected around the stim, starting 2 seconds
# before and finishing 4 seconds after.

t_lim = (-2, 4)

X, info= load_data(dataset='somato', epoch=t_lim, sfreq=sfreq)

# Split the channels into two groups by position in the channel list: the
# first n_split channels versus the remaining ones.
# NOTE(review): no correlation is computed here and the split index is
# hard-coded -- confirm it matches the intended correlated regions.
n_split = 110
n_channels = len(info['ch_names'])

all_channels = info['ch_names']
# The stim channel 'STI 014' is appended to both groups.
# NOTE(review): load_data extracts the events before picking channels, so
# this may be unnecessary -- confirm.
channels_1 = all_channels[:n_split] + ['STI 014']
channels_2 = all_channels[n_split:] + ['STI 014']

# Reload the same epochs restricted to each channel group.
X1, info1= load_data(dataset='somato', epoch=t_lim, sfreq=sfreq,channels=channels_1)
X2, info2= load_data(dataset='somato', epoch=t_lim, sfreq=sfreq,channels=channels_2)


First, let's see the atoms found when all data is used for solving the optimization problem

In [None]:
# Learn rank-1 atoms on all channels; a deep copy is fitted so that the
# template model cdl itself is not the object being fitted.
cdl_all = copy.deepcopy(cdl)
cdl_all.fit(X)

In [None]:
# display all: waveforms, spectra and spatial maps on 5 x 5 grids; the
# figures are saved under each function's default file name.
display_atoms(cdl_all, n_atoms, 5, 5, sfreq)
display_ffts(cdl_all, n_atoms, 5, 5, sfreq)
display_topomap(cdl_all, n_atoms, 5, 5, info)

Now, let's separate into regions

In [None]:
# Define the shape of the dictionary
n_atoms = 10
n_times_atom = int(round(sfreq * 1.0))  # 1000. ms

# Keyword arguments of the dictionary learner, grouped by purpose.
cdl_params = dict(
    # Shape of the dictionary
    n_atoms=n_atoms,
    n_times_atom=n_times_atom,
    # Request a rank1 dictionary with unit norm temporal and spatial maps
    rank1=True,
    uv_constraint='separate',
    # Initialize the dictionary with random chunk from the data
    D_init='chunk',
    # rescale the regularization parameter to be 20% of lambda_max
    lmbd_max="scaled",
    reg=.2,
    # Number of iteration for the alternate minimization and cvg threshold
    n_iter=100,
    eps=1e-4,
    # solver for the z-step
    solver_z="lgcd",
    solver_z_kwargs={'tol': 1e-2, 'max_iter': 1000},
    # solver for the d-step
    solver_d='alternate_adaptive',
    solver_d_kwargs={'max_iter': 300},
    # Technical parameters
    verbose=1,
    random_state=0,
    n_jobs=6,
)
cdl = BatchCDL(**cdl_params)



In [None]:
# Learn rank-1 atoms for each separate part of the brain

# Separate the problem into 2 different
# Each region gets its own deep copy, so both start from the same settings
# and are fitted independently.
cdl_1 = copy.deepcopy(cdl)
cdl_2 = copy.deepcopy(cdl)

cdl_1.fit(X1)
cdl_2.fit(X2)


We can see that, even in two different regions, there are few region-specific patterns, and most of the atoms found in each region can be related to each other and to the atoms found when all regions of the brain are used to train the CSC model.

In [None]:
# Display every atom learned on each region (waveforms, spectra and spatial
# maps on 2 x 5 grids), saving the figures with a "_1" / "_2" suffix.

display_atoms(cdl_1, n_atoms, 2, 5, sfreq, savefig="atoms_somato_1")
display_ffts(cdl_1, n_atoms, 2, 5, sfreq, savefig = "topomap_ffts_1")
display_topomap(cdl_1, n_atoms, 2, 5, info1, savefig = "topomap_somato_1")

display_atoms(cdl_2, n_atoms, 2, 5, sfreq, savefig="atoms_somato_2")
display_ffts(cdl_2, n_atoms, 2, 5, sfreq, savefig = "topomap_ffts_2")
display_topomap(cdl_2, n_atoms, 2, 5, info2, savefig = "topomap_somato_2")


In [None]:
def distance(v_hat_1, v_hat_2, n1, n2):
    """Return the table of DTW distances between two sets of waveforms.

    Parameters
    ----------
    v_hat_1, v_hat_2 : sequence of 1D arrays
        Waveforms of the first and second set.
    n1, n2 : int
        Number of waveforms taken from ``v_hat_1`` and ``v_hat_2``.

    Returns
    -------
    table : array, shape (n1, n2)
        ``table[i, j]`` is the DTW distance between ``v_hat_1[i]`` and
        ``v_hat_2[j]``.
    """
    table = np.zeros(shape=(n1, n2))

    # Every (i, j) pair is computed: the two sets differ, so the table is not
    # symmetric and the lower triangle cannot be mirrored from the upper one.
    for i in range(n1):
        for j in range(n2):
            table[i, j] = dtw(v_hat_1[i], v_hat_2[j]).distance
    return table

In [None]:
# Compare the atoms found in the two regions

# DTW distance between every atom of region 1 (rows) and region 2 (columns)
v_hat_1, v_hat_2 = cdl_1.v_hat_, cdl_2.v_hat_
table = distance(v_hat_1, v_hat_2, n_atoms, n_atoms)

# Labels reused for the heatmap of the distance table
columns = [f"Atom {i + 1}" for i in range(n_atoms)]

# Position of the smallest distance, i.e. the most similar pair of atoms
row, col = np.unravel_index(np.argmin(table), table.shape)
min_distance = table[row, col]
atom_row, atom_col = v_hat_1[row], v_hat_2[col]

# Plot the most similar atoms side by side
fig, axes = plt.subplots(1, 2, figsize=(11, 5), squeeze=False)
t = np.arange(atom_row.size) / sfreq

ax1, ax2 = axes[0]
for panel, waveform, atom_idx in ((ax1, atom_row, row), (ax2, atom_col, col)):
    panel.plot(t, waveform)
    panel.set(xlabel='Time (sec)', title=f'Atom {atom_idx + 1}')
    panel.grid(True)

plt.tight_layout()
plt.savefig("../figures/most_similar_atoms.pdf", dpi=300)
plt.show()



In [None]:
# Plot the distance between atoms

# Rows come from the first axis of the table and columns from the second;
# both axes carry the same "Atom k" labels.
table_df = pd.DataFrame(table, index=columns, columns=columns)

fig, ax = plt.subplots()
sns.heatmap(table_df, annot=True, cmap="YlGnBu", linewidths=0.5, ax=ax)
plt.xticks(rotation=45)
plt.savefig("../figures/distance_atoms.pdf", dpi=300)
# Show the plot
plt.show()


# Sleep dataset

It is important to test the robustness of the model by seeing it work on a different dataset. We chose a sleep dataset since each raw trial can be segmented in epochs corresponding to a specific sleep stage.

Since each sleep stage produces a specific behaviour in the brain, the stages are in general associated with different frequency bands. A wave from a stage-1 state is characterized by relatively low frequencies (alpha, ranging from 8-12 Hz, and theta, from 4-7 Hz), while deeper sleep, in stages 2 and 3, mostly shows even lower frequencies in the delta band (0.5-4 Hz). During REM sleep, however, waves present patterns similar to the ones observed in wakefulness.

It could be interesting to try to find the most relevant patterns and peak frequencies from each sleep stage by using the CSC approach developed by the paper.

In [None]:
# Here, we load the data from a sleep stage dataset from one single subject.
# We separate the resultant epochs into their stage so we can analyse if there is
# some pattern that the CSC model can recognize to distinguish between stages.

# Target sampling frequency in Hz
sfreq = 100

# Epoch window: starts at the event and stops one sample short of 30 s.
t_lim = (0, 30 - 1/sfreq)
X, info = load_data(dataset='sleep', epoch=t_lim, sfreq=sfreq)

# Dictionary: stage name -> data of that stage (stages without epochs are
# left out by separate_sleep_stages).
X_stages = separate_sleep_stages(X, info)

# Same names as the event_id keys built by load_data for the sleep dataset.
stages = ['Sleep stage W','Sleep stage 1','Sleep stage 2','Sleep stage 3/4','Sleep stage R']


In [None]:
# Next, we define the parameters for multivariate CSC

from alphacsc import BatchCDL


def _build_cdl(n_atoms, n_times_atom):
    """Return an unfitted BatchCDL; only the dictionary shape varies."""
    return BatchCDL(
        # Shape of the dictionary
        n_atoms=n_atoms,
        n_times_atom=n_times_atom,
        # Request a rank1 dictionary with unit norm temporal and spatial maps
        rank1=True, uv_constraint='separate',
        # Initialize the dictionary with random chunk from the data
        D_init='chunk',
        # rescale the regularization parameter to be 20% of lambda_max
        lmbd_max="scaled", reg=.2,
        # Number of iteration for the alternate minimization and cvg threshold
        n_iter=100, eps=1e-4,
        # solver for the z-step
        solver_z="lgcd", solver_z_kwargs={'tol': 1e-2, 'max_iter': 1000},
        # solver for the d-step
        solver_d='alternate_adaptive', solver_d_kwargs={'max_iter': 300},
        # Technical parameters
        verbose=1, random_state=0, n_jobs=6)


# First, for solving the optimization problem using data from all stages
# Define the shape of the dictionary
n_atoms = 10
n_times_atom = int(round(sfreq * 1.0))  # 1000. ms
cdl = _build_cdl(n_atoms, n_times_atom)

# Here, we define for each of the stages individually, so we search for less atoms
n_atoms_stage = 4
cdl_stage = _build_cdl(n_atoms_stage, n_times_atom)


In [None]:
# Fit the model and learn rank1 atoms for all
cdl.fit(X)

###############################################################################
# display all
# Sleep-specific file names are passed explicitly: the defaults
# ("atoms_somato", "topomap_ffts") would overwrite the figures saved for the
# somato dataset.
display_atoms(cdl, n_atoms, 2, 5, sfreq, savefig="atoms_sleep")
display_ffts(cdl, n_atoms, 2, 5, sfreq, savefig="ffts_sleep")
#display_topomap(cdl, n_atoms, 5, 5, info)


In [None]:
# Now, we can see the patterns found for each different sleep stage

rhythms = {4:'Delta',
           8:'Theta',
           12:'Alpha-Mu',
           30:'Beta'}

for stage in stages:

    # separate_sleep_stages only returns the stages that have at least one
    # epoch, so a stage from the fixed list may be missing.
    if stage not in X_stages:
        print(f"No epochs for {stage}; skipping")
        continue
    X_stage = X_stages[stage]

    # Fresh copy of the template model for every stage
    cdl_stage_ = copy.deepcopy(cdl_stage)
    cdl_stage_.fit(X_stage)

    # File-name-safe tag, e.g. "Sleep stage 3/4" -> "3-4"; one file per stage
    # instead of every stage overwriting the same default file.
    tag = stage.replace("Sleep stage ", "").replace("/", "-")

    # display found atoms
    display_atoms(cdl_stage_, n_atoms_stage, 1, 4, sfreq, savefig=f"atoms_sleep_{tag}")
    display_ffts(cdl_stage_, n_atoms_stage, 1, 4, sfreq, savefig=f"ffts_sleep_{tag}")

    # NOTE(review): the former print(main_rhythm) was removed here because
    # main_rhythm is never recomputed in this loop and still held the result
    # of the somato analysis; call find_peaks on cdl_stage_ if a per-stage
    # rhythm summary is wanted.