In [None]:
import itertools
import json
import os
import sys
from datetime import datetime
from pathlib import Path

import numpy as np
import pandas as pd
from matplotlib import pyplot as plt

import bbcpy

%matplotlib inline

import numpy as np
import scipy as sp
from matplotlib import pyplot as plt

import bci_minitoolbox as bci

In [None]:
# Make the project's sources (../src relative to the notebook) importable, adding the
# directory to the import path only once.
module_path = os.path.abspath('../src/')
if module_path not in sys.path:
    sys.path.append(module_path)

In [None]:
from data.smr_datamodule import SMR_Data, normalize

In [None]:
# Raw-data root. Set the SMR_DATA_DIR environment variable to point elsewhere instead of
# editing the notebook; the default is the original local Windows location.
data_dir = Path(os.environ.get("SMR_DATA_DIR", "D:\\SMR\\"))
task_name = "2D"
# subject -> sessions to load; "all" presumably selects every session (see SMR_Data)
subject_sessions_dict = {"S4": "all"}
loading_data_mode = "within_subject"
# Epoch interval handed to SMR_Data; looks like "start:end:resolution" — TODO confirm.
ival = "2s:10s:1ms"
# Frequency band [low, high], presumably in Hz — TODO confirm against SMR_Data.
bands = [8, 13]
chans = "*"  # channel selection pattern; "*" presumably selects all channels
fallback_neighbors = 4
transform = None
normalize_dict = {"norm_type": "std", "norm_axis": 0}

In [None]:
def plot_data_dist(eeg_data, noisy_channels):
    """Draw one vertical boxplot per EEG channel, highlighting noisy channels in red.

    Parameters
    ----------
    eeg_data : array-like with a ``.chans`` attribute
        Epoched data whose channel axis is axis 1 (presumably (trials, channels, time)).
        ``eeg_data.chans`` supplies the x tick labels.
    noisy_channels : iterable of numbers
        1-based channel positions to highlight; floats such as ``3.0`` are accepted.
    """
    n_chans = eeg_data.shape[1]
    # Move the channel axis last before flattening so that column k holds every sample of
    # channel k. Reshaping (trials, channels, time) directly to (-1, n_chans) would fill each
    # column with runs of consecutive time samples that mix different channels.
    samples_per_channel = np.moveaxis(np.asarray(eeg_data), 1, -1).reshape(-1, n_chans)
    noisy = set(noisy_channels)  # 3 in {3.0} is True, so float indices work

    noisy_color = 'red'
    default_color = 'lightblue'

    # Create vertical boxplots
    plt.figure(figsize=(20, 10))
    boxplots = plt.boxplot(samples_per_channel, vert=True, patch_artist=True, labels=eeg_data.chans)

    # boxplot() returns one box/median/flier artist per column and two whiskers/caps per
    # column, all in column order; channel positions are counted from 1.
    for channel, patch in enumerate(boxplots['boxes'], start=1):
        patch.set_facecolor(noisy_color if channel in noisy else default_color)

    for median_line in boxplots['medians']:
        median_line.set_color('black')

    for channel in range(1, n_chans + 1):
        lower, upper = 2 * (channel - 1), 2 * (channel - 1) + 1
        if channel in noisy:
            boxplots['whiskers'][lower].set_color(noisy_color)
            boxplots['whiskers'][upper].set_color(noisy_color)
            boxplots['caps'][lower].set_color(noisy_color)
            boxplots['caps'][upper].set_color(noisy_color)
        else:
            boxplots['whiskers'][lower].set_color('black')
            boxplots['whiskers'][upper].set_color('black')

    for channel, flier in enumerate(boxplots['fliers'], start=1):
        marker_color = noisy_color if channel in noisy else 'blue'
        flier.set_markerfacecolor(marker_color)
        flier.set_markeredgecolor(marker_color)

    plt.title('Boxplot of EEG Data for Each Channel (Noisy channels in red)')
    plt.xlabel('Channel Name')
    plt.ylabel('EEG Data Value')
    plt.xticks(rotation=90)
    plt.grid(True, which='both', linestyle='--', linewidth=0.5)
    plt.tight_layout()
    plt.show()

In [None]:
import warnings

import numpy as np
import scipy as sp
from matplotlib import pyplot as plt


def map(eegdata, v=None, clim='minmax', cb_label='', colorbar=True, senspos=True, aspect='equal',
        extent=None):
    '''
    Plot an interpolated scalp topography (one value per channel) into the current axes.

    NOTE: this name shadows Python's built-in ``map`` for every later cell of the notebook.

    Usage:
        map(eegdata, v, clim='minmax', cb_label='')
    Parameters:
        eegdata: object exposing ``eegdata.chans.mnt``, an array of channel coordinates whose
                 columns 0 and 1 are used as x and y
        v:   a 1D vector (channels); required despite the ``None`` default
        clim: limits of color code, either
          'minmax' to use the minimum and maximum of the data
          'sym' to make limits symmetrical around zero, or
          a two element vector giving specific values
        cb_label: label for the colorbar
        colorbar: draw a colorbar when True
        senspos: mark the sensor positions with '+' when True
        aspect: aspect ratio passed to ``plt.imshow``
        extent: [xmin, xmax, ymin, ymax] of both the interpolation grid and the image;
          defaults to the square [-m, m] x [-m, m] with m = max(max x, max y, 1)
    Raises:
        ValueError: if ``v`` is None or ``clim`` is an unknown string
    '''
    if v is None:
        raise ValueError('map() needs a value vector `v` with one entry per channel')
    v = np.asarray(v)
    mnt = np.asarray(eegdata.chans.mnt)
    # Drop sensors whose x or y position is undefined (NaN), together with their values.
    defined = ~np.isnan(mnt[:, :2]).any(axis=1)
    if not defined.all():
        v = v[defined]
        mnt = mnt[defined, :]
        warnings.warn('Some sensor positions are undefined and thus excluded from plotting.')

    if extent is None:
        maxi = np.max([mnt[:, 0].max(), mnt[:, 1].max(), 1])
        extent = np.array([-1, 1, -1, 1]) * maxi
    # Interpolate on exactly the region imshow displays. Sampling x over [-maxx, maxx] and
    # y over [-maxy, maxy] while drawing over `extent` would stretch the map away from the
    # sensor markers whenever maxx != maxy or a custom extent is given.
    xi, yi = np.meshgrid(np.linspace(extent[0], extent[1], 100),
                         np.linspace(extent[2], extent[3], 100))
    rbf = sp.interpolate.Rbf(mnt[:, 0], mnt[:, 1], v, function='linear')
    zi = rbf(xi, yi)

    # mask area outside of the scalp (pixels outside the circle inscribed in the 100x100 grid)
    a, b, n, r = 50, 50, 100, 50
    mask_y, mask_x = np.ogrid[-a:n - a, -b:n - b]
    mask = mask_x * mask_x + mask_y * mask_y >= r * r
    zi[mask] = np.nan

    # isinstance() check: comparing an ndarray clim with a string would be elementwise.
    if isinstance(clim, str):
        if clim == 'minmax':
            vmin, vmax = v.min(), v.max()
        elif clim == 'sym':
            vmax = np.absolute(v).max()
            vmin = -vmax
        else:
            raise ValueError(f"unknown clim {clim!r}; use 'minmax', 'sym' or [vmin, vmax]")
    else:
        vmin, vmax = clim[0], clim[1]

    plt.imshow(zi, vmin=vmin, vmax=vmax, aspect=aspect, origin='lower', extent=extent, cmap='jet')
    if colorbar:
        plt.colorbar(shrink=.5, label=cb_label)
    if senspos:
        # constant colour: vmin/vmax would be ignored by scatter here
        plt.scatter(mnt[:, 0], mnt[:, 1], c='k', marker='+')
    plt.axis('off')

In [None]:
def scalpmap(mnt, v, clim='minmax', cb_label='', ax=None, colorbar=True):
    '''
    Plot an interpolated scalp topography of one value per channel.

    Usage:
        scalpmap(mnt, v, clim='minmax', cb_label='')
        scalpmap(mnt, v, clim='sym', ax=ax, colorbar=False)
    Parameters:
        mnt: a 2D array of channel coordinates (channels x 2); should lie within [-1, 1]
             because the interpolation grid and image extent span [-1, 1] x [-1, 1]
        v:   a 1D vector (channels)
        clim: limits of color code, either
          'minmax' to use the minimum and maximum of the data
          'sym' to make limits symmetrical around zero, or
          a two element vector giving specific values
        cb_label: label for the colorbar
        ax: matplotlib Axes to draw into; defaults to the current axes
        colorbar: add a colorbar next to ``ax`` when True (the previous, default behaviour)
    Returns:
        The AxesImage holding the interpolated map.
    Raises:
        ValueError: if ``clim`` is an unknown string
    '''
    mnt = np.asarray(mnt)
    v = np.asarray(v)
    if ax is None:
        ax = plt.gca()

    # interpolate between channels
    xi, yi = np.meshgrid(np.linspace(-1, 1, 100), np.linspace(-1, 1, 100))
    rbf = sp.interpolate.Rbf(mnt[:, 0], mnt[:, 1], v, function='linear')
    zi = rbf(xi, yi)

    # mask area outside of the scalp (pixels outside the circle inscribed in the 100x100 grid)
    a, b, n, r = 50, 50, 100, 50
    mask_y, mask_x = np.ogrid[-a:n - a, -b:n - b]
    mask = mask_x * mask_x + mask_y * mask_y >= r * r
    zi[mask] = np.nan

    # isinstance() check: comparing an ndarray clim with a string would be elementwise.
    if isinstance(clim, str):
        if clim == 'minmax':
            vmin, vmax = v.min(), v.max()
        elif clim == 'sym':
            vmax = np.absolute(v).max()
            vmin = -vmax
        else:
            raise ValueError(f"unknown clim {clim!r}; use 'minmax', 'sym' or [vmin, vmax]")
    else:
        vmin, vmax = clim[0], clim[1]

    image = ax.imshow(zi, vmin=vmin, vmax=vmax, origin='lower', extent=[-1, 1, -1, 1], cmap='jet')
    if colorbar:
        ax.figure.colorbar(image, ax=ax, shrink=.5, label=cb_label)
    # constant colour: vmin/vmax would be ignored by scatter here
    ax.scatter(mnt[:, 0], mnt[:, 1], c='k', marker='+')
    ax.axis('off')
    return image

from bokeh.models import ColorBar, LinearColorMapper, BasicTicker
from bokeh.plotting import figure

def bokeh_scalpmap(mnt, v, clim='minmax', cb_label='', plot_width=250, plot_height=250):
    '''
    Usage:
        bokeh_scalpmap(mnt, v, clim='minmax', cb_label='')
    Parameters:
        mnt: a 2D array of channel coordinates (channels x 2); should lie within [-1, 1]
             because the interpolation grid and plot ranges span [-1, 1] x [-1, 1]
        v:   a 1D vector (channels)
        clim: limits of color code, either
          'minmax' to use the minimum and maximum of the data
          'sym' to make limits symmetrical around zero, or
          a two element vector giving specific values
        cb_label: title shown on the colour bar (empty string = no title)
        plot_width, plot_height: figure size in pixels
    Returns:
        Bokeh figure object with the scalpmap plotted. ``p.renderers[0]`` is the image
        renderer and ``p.renderers[1]`` the sensor markers.
    Raises:
        ValueError: if ``clim`` is an unknown string
    '''
    mnt = np.asarray(mnt)
    v = np.asarray(v)

    # interpolate between channels
    xi, yi = np.meshgrid(np.linspace(-1, 1, 100), np.linspace(-1, 1, 100))
    rbf = sp.interpolate.Rbf(mnt[:, 0], mnt[:, 1], v, function='linear')
    zi = rbf(xi, yi)

    # mask area outside of the scalp (pixels outside the circle inscribed in the 100x100 grid)
    a, b, n, r = 50, 50, 100, 50
    mask_y, mask_x = np.ogrid[-a:n - a, -b:n - b]
    mask = mask_x * mask_x + mask_y * mask_y >= r * r
    zi[mask] = np.nan

    # isinstance() check: comparing an ndarray clim with a string would be elementwise.
    if isinstance(clim, str):
        if clim == 'minmax':
            vmin, vmax = v.min(), v.max()
        elif clim == 'sym':
            vmax = np.absolute(v).max()
            vmin = -vmax
        else:
            raise ValueError(f"unknown clim {clim!r}; use 'minmax', 'sym' or [vmin, vmax]")
    else:
        vmin, vmax = clim[0], clim[1]

    mapper = LinearColorMapper(palette="Viridis256", low=vmin, high=vmax)

    p = figure(width=plot_width, height=plot_height, x_range=(-1, 1), y_range=(-1, 1), tools="")
    # Keep the image as the first renderer: callers access it via p.renderers[0].
    p.image(image=[zi], x=-1, y=-1, dw=2, dh=2, color_mapper=mapper)
    # scatter() replaces circle(size=...), which Bokeh 3.4 deprecated; the default marker is a circle.
    p.scatter(mnt[:, 0], mnt[:, 1], color='black', size=5)

    # Add a color bar (cb_label used to be accepted but never shown)
    color_bar = ColorBar(color_mapper=mapper, ticker=BasicTicker(), title=cb_label or None,
                         label_standoff=12, border_line_color=None, location=(0, 0))
    p.add_layout(color_bar, 'right')

    p.axis.visible = False
    p.grid.visible = False

    return p


In [None]:
def normalize_positions(mnt):
    """Min-max scale channel x/y coordinates into [-1, 1].

    Columns 0 (x) and 1 (y) are rescaled independently so that each spans exactly
    [-1, 1]; any further columns are returned unchanged. A column whose values are all
    equal (zero range) is mapped to 0 instead of producing NaNs.

    A new float array is returned and the input is left untouched. Writing into `mnt` in
    place would alter the caller's array — including the montage behind a slice such as
    `epo.chans.mnt[:, :2]` when that slice is a view — and would truncate integer input.
    """
    scaled = np.array(mnt, dtype=float)  # np.array copies by default
    for col in (0, 1):
        lo = scaled[:, col].min()
        span = scaled[:, col].max() - lo
        if span:
            scaled[:, col] = 2 * (scaled[:, col] - lo) / span - 1
        else:
            scaled[:, col] = 0.0
    return scaled

# With noisy channels (noisy-channel processing disabled)

In [None]:
# Data module with process_noisy_channels=False (noisy channels presumably left as recorded).
smr_datamodule_nn = SMR_Data(
    data_dir=data_dir,
    task_name=task_name,
    subject_sessions_dict=subject_sessions_dict,
    loading_data_mode=loading_data_mode,
    ival=ival,
    bands=bands,
    chans=chans,
    fallback_neighbors=fallback_neighbors,
    transform=transform,
    normalize=normalize_dict,
    process_noisy_channels=False,
)

subjects_sessions_path_dict = smr_datamodule_nn.collect_subject_sessions(subject_sessions_dict)
subject_data_dict, subjects_info_dict = smr_datamodule_nn.load_subjects_sessions(
    subjects_sessions_path_dict)

# Only the first loaded subject is analysed in this notebook.
subject_name = list(subject_data_dict)[0]
loaded_subject_sessions = subject_data_dict[subject_name]
loaded_subject_sessions_info = subjects_info_dict[subject_name]["sessions_info"]

# Concatenate the subject's sessions (FIXME : forced trials are not used)
valid_trials = smr_datamodule_nn.append_sessions(
    loaded_subject_sessions, loaded_subject_sessions_info)

mrk_classname = valid_trials.className

In [None]:
# Normalize the trials with the settings from normalize_dict; the second return value
# presumably holds the fitted normalization parameters.
valid_trials_norm, norm_params_valid = normalize(
    valid_trials,
    norm_type=normalize_dict["norm_type"],
    axis=normalize_dict["norm_axis"],
)

In [None]:

# Work on the normalized trials and map the x/y sensor positions into [-1, 1] for scalpmap().
epo = valid_trials_norm
normalized_mnt = normalize_positions(epo.chans.mnt[:, :2])
# Half-second windows whose onsets advance in 500 ms steps (3000 ms ... 6500 ms).
start_times = np.arange(3000, 7000, 500)
ival = [list(window) for window in zip(start_times, start_times + 500)]
mrk_classname = valid_trials.className

In [None]:
from matplotlib.animation import FuncAnimation
%matplotlib inline
# Generate intervals with 100ms windows
start_times = np.arange(3000, 7300, 40)
ival = [[start, start + 100] for start in start_times]

rows = 1  # All classes will be plotted in one row
cols = len(mrk_classname)

fig, axs = plt.subplots(rows, cols, figsize=(20, 5))

def update(interval_idx):
    for klass_idx, klass_name in enumerate(mrk_classname):
        ax = axs[klass_idx]
        ax.clear()
        plt.sca(ax)  # Set the current axis to 'ax' so that scalpmap plots in the correct location

        start, end = ival[interval_idx]
        indices = (epo.t >= start) & (epo.t <= end)
        mean = np.mean(epo[:, :, indices][epo.y == klass_idx ,:, :], axis=(0, 2))
        scalpmap(epo.chans.mnt[:,:2], mean, clim='sym')
        ax.set_title(f'Task: {klass_name}, Time: {start}-{end} ms')

    plt.tight_layout()

ani = FuncAnimation(fig, update, frames=len(ival), repeat=True)

In [None]:
from bokeh.io import push_notebook, show, output_notebook
from bokeh.layouts import row
from time import sleep

output_notebook()

# Interpolation grid and circular scalp mask, identical to the ones in bokeh_scalpmap().
grid_x, grid_y = np.meshgrid(np.linspace(-1, 1, 100), np.linspace(-1, 1, 100))
mask_y, mask_x = np.ogrid[-50:50, -50:50]
outside_scalp = mask_x * mask_x + mask_y * mask_y >= 50 * 50

# Set up a row of plots, one per class, starting from an all-zero map.
plots = [bokeh_scalpmap(normalized_mnt, np.zeros(normalized_mnt.shape[0])) for _ in mrk_classname]
# renderers[0] is the image renderer created by bokeh_scalpmap().
plot_sources = [p.renderers[0].data_source for p in plots]
# The all-zero start leaves every colour mapper at low == high == 0, so update() must
# reset the limits on each frame or the maps stay a single colour.
color_mappers = [p.renderers[0].glyph.color_mapper for p in plots]

layout = row(*plots)
handle = show(layout, notebook_handle=True)


def update(interval_idx):
    """Push the masked class-wise scalp maps of window ``ival[interval_idx]`` to the plots."""
    start, end = ival[interval_idx]
    indices = (epo.t >= start) & (epo.t <= end)
    for klass_idx in range(len(mrk_classname)):
        mean = np.mean(epo[:, :, indices][epo.y == klass_idx, :, :], axis=(0, 2))

        rbf = sp.interpolate.Rbf(normalized_mnt[:, 0], normalized_mnt[:, 1], mean, function='linear')
        zi = rbf(grid_x, grid_y)
        zi[outside_scalp] = np.nan  # keep the circular scalp shape on every frame

        limit = float(np.absolute(mean).max())  # symmetric colour limits
        plot_sources[klass_idx].data.update(image=[zi])
        color_mappers[klass_idx].low = -limit
        color_mappers[klass_idx].high = limit

    push_notebook(handle=handle)


for _ in range(5):  # Loop through the data 5 times
    for interval_idx in range(len(ival)):
        update(interval_idx)
        sleep(0.5)  # Pause for a short time between updates


In [None]:
from bokeh.io import push_notebook, show, output_notebook
from bokeh.plotting import figure
from bokeh.layouts import gridplot
from bokeh.models import ColorBar, LinearColorMapper
from bokeh.palettes import Viridis256 as palette
from time import sleep

output_notebook()

# Generate intervals with 100ms windows, one starting every 40ms
start_times = np.arange(3000, 7300, 40)
ival = [[start, start + 100] for start in start_times]

# Interpolation grid and circular scalp mask (same geometry as scalpmap/bokeh_scalpmap).
grid_x, grid_y = np.meshgrid(np.linspace(-1, 1, 100), np.linspace(-1, 1, 100))
mask_y, mask_x = np.ogrid[-50:50, -50:50]
outside_scalp = mask_x * mask_x + mask_y * mask_y >= 50 * 50

# One figure per class, each holding a single image renderer that update() refreshes in
# place. Calling p.image() inside update() would add another glyph on every frame.
plots, image_sources, color_mappers = [], [], []
for klass_name in mrk_classname:
    mapper = LinearColorMapper(palette=palette, low=-1, high=1)
    p = figure(width=250, height=250, title=f"Task: {klass_name}",
               x_range=(-1, 1), y_range=(-1, 1), tools="")
    image = p.image(image=[np.zeros(grid_x.shape)], x=-1, y=-1, dw=2, dh=2, color_mapper=mapper)
    p.scatter(normalized_mnt[:, 0], normalized_mnt[:, 1], color='black', size=5)
    p.add_layout(ColorBar(color_mapper=mapper), 'right')
    p.axis.visible = False
    p.grid.visible = False
    plots.append(p)
    image_sources.append(image.data_source)
    color_mappers.append(mapper)


# The function to update the plots
def update(interval_idx):
    """Show the class-wise scalp maps of window ``ival[interval_idx]`` in the grid."""
    start, end = ival[interval_idx]
    indices = (epo.t >= start) & (epo.t <= end)
    for klass_idx, klass_name in enumerate(mrk_classname):
        plots[klass_idx].title.text = f"Task: {klass_name}, Time: {start}-{end} ms"

        # Class mean per channel over the window, interpolated onto the scalp grid
        mean = np.mean(epo[:, :, indices][epo.y == klass_idx, :, :], axis=(0, 2))
        rbf = sp.interpolate.Rbf(normalized_mnt[:, 0], normalized_mnt[:, 1], mean, function='linear')
        zi = rbf(grid_x, grid_y)
        zi[outside_scalp] = np.nan

        limit = float(np.absolute(mean).max())  # symmetric colour limits
        image_sources[klass_idx].data.update(image=[zi])
        color_mappers[klass_idx].low = -limit
        color_mappers[klass_idx].high = limit

    push_notebook(handle=handle)


# Display the plots
grid = gridplot([plots])
handle = show(grid, notebook_handle=True)

In [None]:
# Plays the grid shown by the previous cell: each call to that cell's `update` pushes one
# time window to the Bokeh plots.
# Loop through the time intervals, updating the plots for each interval
for interval_idx in range(len(ival)):
    update(interval_idx)
    sleep(0.5)  # Pause for a short time between updates

In [None]:
plot_data_dist(valid_trials, noisy_channels=[3.0, 5.0, 8.0, 23.0, 41.0, 62.0])  # raw trials, noisy channels kept

In [None]:
plot_data_dist(valid_trials_norm, noisy_channels=[3.0, 5.0, 8.0, 23.0, 41.0, 62.0])  # normalized trials

In [None]:

# Back to the normalized trials with [-1, 1] sensor positions for the scalp-map grid.
epo = valid_trials_norm
normalized_mnt = normalize_positions(epo.chans.mnt[:, :2])
# Half-second windows whose onsets advance in 500 ms steps (3000 ms ... 6500 ms).
start_times = np.arange(3000, 7000, 500)
ival = [list(window) for window in zip(start_times, start_times + 500)]

In [None]:
rows = len(mrk_classname)
cols = len(ival)

plt.figure(figsize=(18, 7))
for klass_idx, klass_name in enumerate(mrk_classname):
    for interval_idx, [start, end] in enumerate(ival):
        plot_idx = klass_idx * cols + interval_idx + 1
        ax = plt.subplot(rows, cols, plot_idx)

        indices = (epo.t >= start) & (epo.t <= end)
        mean = np.mean(epo[:, :, indices][epo.y == klass_idx, :, :], axis=(0, 2))
        scalpmap(normalized_mnt, mean, clim='sym', cb_label='')

        # Set title for the plots in the top row to represent the interval
        if klass_idx == 0:
            ax.set_title(f'{start}-{end} ms')
        # Label each row with its class. scalpmap() turns the axis off, which also hides
        # y-labels, so the name is drawn as an annotation left of the first column.
        if interval_idx == 0:
            ax.annotate(klass_name, xy=(-0.2, 0.5), xycoords="axes fraction", fontsize=12,
                        ha="center", va="center")

plt.suptitle("EEG Potentials for Different Tasks and Time Intervals", fontsize=16, y=1.02)
plt.tight_layout()
plt.show()

In [None]:
# Switch to the un-normalized trials (loaded with process_noisy_channels=False) for the next grid.
epo = valid_trials

In [None]:
rows = len(mrk_classname)
cols = len(ival)

plt.figure(figsize=(18, 7))
for klass_idx, klass_name in enumerate(mrk_classname):
    for interval_idx, [start, end] in enumerate(ival):
        plot_idx = klass_idx * cols + interval_idx + 1
        ax = plt.subplot(rows, cols, plot_idx)

        indices = (epo.t >= start) & (epo.t <= end)
        mean = np.mean(epo[:, :, indices][epo.y == klass_idx, :, :], axis=(0, 2))
        scalpmap(normalized_mnt, mean, clim='sym', cb_label='')

        # Set title for the plots in the top row to represent the interval
        if klass_idx == 0:
            ax.set_title(f'{start}-{end} ms')
        # Label each row with its class. scalpmap() turns the axis off, which also hides
        # y-labels, so the name is drawn as an annotation left of the first column.
        if interval_idx == 0:
            ax.annotate(klass_name, xy=(-0.2, 0.5), xycoords="axes fraction", fontsize=12,
                        ha="center", va="center")

plt.suptitle("EEG Potentials for Different Tasks and Time Intervals", fontsize=16, y=1.02)
plt.tight_layout()
plt.show()

# Without noisy channels (noisy-channel processing enabled)

In [None]:
# The global `ival` has been reassigned by the plotting cells above (it now holds lists of
# [start, end] windows), so the loading interval from the configuration cell is spelled out
# here. Keep it identical to that cell so both datasets are epoched the same way.
smr_datamodule_wn = SMR_Data(data_dir=data_dir,
                             task_name=task_name,
                             subject_sessions_dict=subject_sessions_dict,
                             loading_data_mode=loading_data_mode,
                             ival="2s:10s:1ms",
                             bands=bands,
                             chans=chans,
                             fallback_neighbors=fallback_neighbors,
                             transform=transform,
                             normalize=normalize_dict,
                             process_noisy_channels=True)

subjects_sessions_path_dict = smr_datamodule_wn.collect_subject_sessions(subject_sessions_dict)
subject_data_dict, subjects_info_dict = smr_datamodule_wn.load_subjects_sessions(subjects_sessions_path_dict)

# Only the first loaded subject is analysed in this notebook.
subject_name = list(subject_data_dict.keys())[0]
loaded_subject_sessions = subject_data_dict[subject_name]
loaded_subject_sessions_info = subjects_info_dict[subject_name]["sessions_info"]

# append the sessions (FIXME : forced trials are not used)
valid_trials_1 = smr_datamodule_wn.append_sessions(loaded_subject_sessions,
                                                   loaded_subject_sessions_info)
# Own name for this dataset's normalization parameters so the first dataset's
# `norm_params_valid` is not overwritten.
valid_trials_norm_1, norm_params_valid_1 = normalize(valid_trials_1,
                                                     norm_type=normalize_dict["norm_type"],
                                                     axis=normalize_dict["norm_axis"])

In [None]:
plot_data_dist(valid_trials_1, noisy_channels=[3.0, 5.0, 8.0, 23.0, 41.0, 62.0])  # raw trials, noisy channels processed

In [None]:
plot_data_dist(valid_trials_norm_1, noisy_channels=[3.0, 5.0, 8.0, 23.0, 41.0, 62.0])  # normalized trials

In [None]:
# Class names of the second (noisy-channel-processed) dataset, used by the plots below.
mrk_classname = valid_trials_norm_1.className

In [None]:
# Display the class names.
mrk_classname

In [None]:
# Display the epochs' time axis (compared against millisecond windows elsewhere, so presumably ms).
valid_trials_norm_1.t

## Plotting ERPs 

In [None]:
# Store given information in variables. Subsequent code should only refer to these variables and not
# contain the constants.
ival = [-100, 1000]
ref_ival = [-100, 0]
chans = ['C3', 'C4']
mrk_classname = valid_trials_norm_1.className

epo = valid_trials_norm_1

# NOTE(review): no baseline correction is applied yet; `ival` and `ref_ival` are not used
# below. The other cells analyse windows from 3000 ms on, so a [-100, 0] ms reference
# window may not exist in `epo.t` — confirm the time axis before adding a baseline step.


def get_mean(chan, klass):
    """Average of channel `chan` over all trials whose label `epo.y` equals `klass`."""
    return np.mean(epo[epo.y == klass, epo.chans.index(chan), :], axis=0)


def plot_channel(channel):
    """Plot the class-wise trial averages of `channel` against `epo.t`."""
    for klass_idx, klass_name in enumerate(mrk_classname):
        mean_chan = get_mean(channel, klass_idx)
        plt.plot(epo.t, mean_chan.squeeze(), label=f'{channel}({klass_name})')


plt.figure(figsize=(16, 4))

for channel in chans:
    plot_channel(channel)

plt.xlabel('time  [ms]')
# The trials were std-normalized (normalize_dict["norm_type"]), so amplitudes are not in µV.
plt.ylabel('normalized potential  [a.u.]')
plt.legend()
plt.show()

## Scalp Topographies of ERPs

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import ipywidgets as widgets
from IPython.display import display

In [None]:
# Time windows [start, end] in ms selectable with the slider below, each listed once
# (repeated windows would only produce identical slider positions).
ival = [[3000, 3200], [3800, 4300], [5000, 5200], [6800, 7300]]

rows = len(mrk_classname)
cols = len(ival)

In [None]:
# Min-max scale the x/y sensor positions of `epo` (valid_trials_norm_1 here) into [-1, 1] for scalpmap().
normalized_mnt = normalize_positions(epo.chans.mnt[:,:2])

In [None]:
def plot_eeg_intervals(interval_idx=0):
    """Plot one scalp map per class for the time window ``ival[interval_idx]``.

    The panels sit in a single row, one per class. Placing them in column `interval_idx`
    of a classes x len(ival) grid would leave every other column empty and make the panels
    jump sideways as the slider moves. Reads the notebook globals ``epo``, ``ival``,
    ``mrk_classname`` and ``normalized_mnt``.
    """
    start, end = ival[interval_idx]
    indices = (epo.t >= start) & (epo.t <= end)
    n_classes = len(mrk_classname)

    plt.figure(figsize=(18, 7))
    for klass_idx, klass_name in enumerate(mrk_classname):
        plt.subplot(1, n_classes, klass_idx + 1)
        mean = np.mean(epo[:, :, indices][epo.y == klass_idx, :, :], axis=(0, 2))
        scalpmap(normalized_mnt, mean, clim='sym', cb_label='')
        plt.title(f'Task: {klass_name}')  # Add title for each subplot
    plt.suptitle(f'{start}-{end} ms')  # show which window is selected
    plt.tight_layout()  # Adjust layout for better spacing
    plt.show()


widgets.interactive(plot_eeg_intervals, interval_idx=(0, len(ival) - 1))

In [None]:
# Sanity check: per-channel mean over all trials of class 0 and all time samples.
np.mean(epo[:, :, :][epo.y == 0,:,:], axis=(0, 2))

In [None]:
# Re-assign the class names of the noisy-channel-processed dataset (same value as above).
mrk_classname = valid_trials_norm_1.className

In [None]:
# Half-second windows whose onsets advance in 500 ms steps (3000 ms ... 6500 ms).
start_times = np.arange(3000, 7000, 500)
ival = [list(window) for window in zip(start_times, start_times + 500)]


In [None]:

epo = valid_trials_1
rows = len(mrk_classname)
cols = len(ival)
# Normalize the sensor positions explicitly: the raw `epo.chans.mnt` only holds [-1, 1]
# coordinates if an earlier normalize_positions() call happened to write into it.
scalp_positions = normalize_positions(epo.chans.mnt[:, :2])

plt.figure(figsize=(18, 7))
for klass_idx, klass_name in enumerate(mrk_classname):
    for interval_idx, [start, end] in enumerate(ival):
        plot_idx = klass_idx * cols + interval_idx + 1
        ax = plt.subplot(rows, cols, plot_idx)

        indices = (epo.t >= start) & (epo.t <= end)
        mean = np.mean(epo[:, :, indices][epo.y == klass_idx, :, :], axis=(0, 2))
        scalpmap(scalp_positions, mean, clim='sym', cb_label='')

        # Set title for the plots in the top row to represent the interval
        if klass_idx == 0:
            ax.set_title(f'{start}-{end} ms')
        # Label each row with its class. scalpmap() turns the axis off, which also hides
        # y-labels, so the name is drawn as an annotation left of the first column.
        if interval_idx == 0:
            ax.annotate(klass_name, xy=(-0.2, 0.5), xycoords="axes fraction", fontsize=12,
                        ha="center", va="center")

plt.suptitle("EEG Potentials for Different Tasks and Time Intervals", fontsize=16, y=1.02)
plt.tight_layout()
plt.show()

In [None]:
# Shape of the current epochs (epo = valid_trials_1 after the previous cell).
epo.shape

In [None]:
# Half-second windows whose onsets advance in 500 ms steps (3000 ms ... 7500 ms).
start_times = np.arange(3000, 8000, 500)
ival = [list(window) for window in zip(start_times, start_times + 500)]

In [None]:
# Inspect the window list built in the previous cell.
ival

In [None]:
epo = valid_trials_norm_1
rows = len(mrk_classname)
cols = len(ival)
# Normalize the sensor positions explicitly: the raw `epo.chans.mnt` only holds [-1, 1]
# coordinates if an earlier normalize_positions() call happened to write into it.
scalp_positions = normalize_positions(epo.chans.mnt[:, :2])

plt.figure(figsize=(18, 7))
for klass_idx, klass_name in enumerate(mrk_classname):
    for interval_idx, [start, end] in enumerate(ival):
        plot_idx = klass_idx * cols + interval_idx + 1
        ax = plt.subplot(rows, cols, plot_idx)

        indices = (epo.t >= start) & (epo.t <= end)
        mean = np.mean(epo[:, :, indices][epo.y == klass_idx, :, :], axis=(0, 2))
        scalpmap(scalp_positions, mean, clim='sym', cb_label='')

        # Set title for the plots in the top row to represent the interval
        if klass_idx == 0:
            ax.set_title(f'{start}-{end} ms')
        # Display the task name on the left-most plots (scalpmap() hides the axis labels)
        if interval_idx == 0:
            ax.annotate(klass_name, xy=(-0.2, 0.5), xycoords="axes fraction", fontsize=12,
                        ha="center", va="center")

plt.suptitle("EEG Potentials for Different Tasks and Time Intervals", fontsize=16, y=1.02)
plt.tight_layout()
plt.show()

In [None]:
from matplotlib.animation import FuncAnimation
%matplotlib notebook
# Generate intervals with 100ms windows and 40ms steps
start_times = np.arange(3000, 7300, 40)
ival = [[start, start + 100] for start in start_times]

rows = 1  # Only one row since we're plotting all classes horizontally
cols = len(mrk_classname)  # One column for each class

fig, axs = plt.subplots(rows, cols, figsize=(20, 5), squeeze=False)

def draw(interval_idx):
    for klass_idx, klass_name in enumerate(mrk_classname):
        ax = axs[0, klass_idx]
        ax.clear()

        start, end = ival[interval_idx]
        indices = (epo.t >= start) & (epo.t <= end)
        mean = np.mean(epo[:, :, indices][epo.y == klass_idx ,:, :], axis=(0, 2))
        scalpmap(epo.chans.mnt[:,:2], mean, clim='sym', ax=ax)
        ax.set_title(f'Task: {klass_name}, Time: {start}-{end} ms')

    plt.tight_layout()
ani = FuncAnimation(fig, draw, frames=len(ival), repeat=True)
plt.show()

In [None]:
# Display any matplotlib figures that are still open.
plt.show()