In [1]:
import mne.time_frequency
from ieeg.viz.utils import chan_grid
from ieeg.viz.parula import parula_map
from ieeg.io import get_data, update, get_bad_chans
import os
import matplotlib.pyplot as plt

In [None]:
### put back in code for plotting the subject

### plot averaged wavelet for all stimulus and response

In [3]:
def chan_grid(inst: Signal, n_cols: int = 10, n_rows: int = 6,
              plot_func: callable = None, picks: list[str | int] = None,
              size: tuple[int, int] = (8, 12), show: bool = True, **kwargs
              ) -> list[plt.Figure]:
    """Plot a grid of the channels of a Signal object

    Parameters
    ----------
    size
    inst : Signal
        The Signal object to plot
    n_cols : int, optional
        Number of columns in the grid, by default 10
    n_rows : int, optional
        Number of rows in the grid, by default the minimum number of rows
    plot_func : callable, optional
        The function to use to plot the channels, by default inst.plot()
    picks : list[Union[str, int]], optional
        The channels to plot, by default all
    size : tuple[int, int], optional
        The size of the figure, by default (8, 12)
    show : bool, optional
        Whether to show the figure, by default True

    Returns
    -------
    list[plt.Figure]
        The figures containing the grid
    """

    # spec, size=(20, 10), vmin=-2, vmax=2, cmap=parula_map, show=False

    if n_rows is None:
        n_rows = int(np.ceil(len(inst.ch_names) / n_cols))
    if plot_func is None:
        plot_func = inst.plot
    if picks is None:
        chans = inst.ch_names
    elif isinstance(picks[0], str):
        chans = picks
    elif isinstance(picks[0], int):
        chans = [inst.ch_names[i] for i in picks]
    else:
        raise TypeError("picks must be a list of str or int")

    per_fig = n_cols * n_rows
    numfigs = int(np.ceil(len(chans) / per_fig))
    figs = []
    for i in range(numfigs):
        fig, axs = plt.subplots(nrows=n_rows, ncols=n_cols, frameon=False,
                                figsize=size)

        select = partial(_onclick_select, inst=inst, axs=fig.axes)
        text_spec = dict(fontsize=12, weight="extra bold")

        for j, chan in enumerate(chans[i * per_fig:(i + 1) * per_fig]):
            if j + 1 % n_cols == 0 or i == len(chans) - 1:
                bar = True
            else:
                bar = False
            if "colorbar" in plot_func.__code__.co_varnames:
                kwargs["colorbar"] = bar
            ax = axs.flatten()[j]
            plot_func(picks=[chan], axes=ax, show=False, **kwargs)
            ax.set_title(chan, pad=0, **text_spec)
            ax.tick_params(axis='both', which='major', labelsize=7,
                            direction="in")
            ax.set_xlabel("")
            ax.set_ylabel("")
            gc.collect()
        fig.supxlabel("Time (s)", **text_spec)
        fig.supylabel("Frequency (Hz)", **text_spec)
        if i == numfigs - 1:
            while j + 1 < n_cols * n_rows:
                j += 1
                ax = axs.flatten()[j]
                ax.axis("off")
        fig.canvas.mpl_connect("button_press_event", select)
        fig.tight_layout()
        figs.append(fig)
        if show:
            figs[i].show()


NameError: name 'Signal' is not defined

do a list of subjects

In [7]:
# Description: Check channels for outliers and remove them

HOME = os.path.expanduser("~")

# get box directory depending on OS
if os.name == 'nt': # windows
    LAB_root = os.path.join(HOME, "Box", "CoganLab")
else: # mac
    LAB_root = os.path.join(HOME, "Library", "CloudStorage", "Box-Box", "CoganLab")

layout = get_data("GlobalLocal", root=LAB_root)

subjects = ['D0065', 'D0069', 'D0071', 'D0077', 'D0090', 'D0094', 'D0100', 'D0102', 'D0103']
for sub in subjects:
    # Load the data
    TASK = "GlobalLocal"
    subj = sub
    output_name = 'Stimulus_fixationCrossBase_0.2sec'
    layout = get_data("GlobalLocal", root=LAB_root)


    fig_path = os.path.join(layout.root, 'derivatives', 'spec', 'wavelet', 'figs')

    filename = os.path.join(layout.root, 'derivatives', 'spec', 'wavelet', subj, f'{output_name}-tfr.h5')
    print("Filename:", filename)
    spec = mne.time_frequency.read_tfrs(filename)[0]
    info_file = os.path.join(layout.root, spec.info['subject_info']['files'][0])

    # Check channels for outliers and remove them
    all_bad = get_bad_chans(info_file)
    spec.info.update(bads=[b for b in all_bad if b in spec.ch_names])

    # Plotting
    figs = chan_grid(spec, size=(20, 10), vmin=-2, vmax=2, cmap=parula_map, show=False)
    for i, f in enumerate(figs):
        fig_name = f'{subj}_{output_name}_{i+1}.jpg'
        fig_pathname = os.path.join(fig_path, fig_name)
        f.savefig(fig_pathname, bbox_inches='tight')
        print("Saved figure:", fig_name)

# update(spec, layout, "bad")


# fig_path = os.path.join(layout.root, 'derivatives', 'spec', 'wavelet', 'figs')

# filename = os.path.join(layout.root, 'derivatives', 'spec', 'wavelet', subj, 'Response_0.2secBeforeStimOnsetBase-tfr.h5')
# print("Filename:", filename)
# spec = mne.time_frequency.read_tfrs(filename)[0]
# info_file = os.path.join(layout.root, spec.info['subject_info']['files'][0])

# # Check channels for outliers and remove them
# # all_bad = get_bad_chans(info_file)
# # spec.info.update(bads=[b for b in all_bad if b in spec.ch_names])

# # Plotting
# figs = chan_grid(spec, size=(20, 10), vmin=-2, vmax=2, cmap=parula_map, show=False)
# for i, f in enumerate(figs):
#     fig_name = f'{subj}_Response_fullTrialBase_{i+1}.jpg'
#     fig_pathname = os.path.join(fig_path, fig_name)
#     f.savefig(fig_pathname, bbox_inches='tight')
#     print("Saved figure:", fig_name)


Filename: C:\Users\jz421\Box\CoganLab\BIDS-1.1_GlobalLocal\BIDS\derivatives\spec\wavelet\D0065\Stimulus_fixationCrossBase_0.2sec-tfr.h5
Reading C:\Users\jz421\Box\CoganLab\BIDS-1.1_GlobalLocal\BIDS\derivatives\spec\wavelet\D0065\Stimulus_fixationCrossBase_0.2sec-tfr.h5 ...
No baseline correction applied
No baseline correction applied
No baseline correction applied
No baseline correction applied
No baseline correction applied
No baseline correction applied
No baseline correction applied
No baseline correction applied
No baseline correction applied
No baseline correction applied
No baseline correction applied
No baseline correction applied
No baseline correction applied
No baseline correction applied
No baseline correction applied
No baseline correction applied
No baseline correction applied
No baseline correction applied
No baseline correction applied
No baseline correction applied
No baseline correction applied
No baseline correction applied
No baseline correction applied
No baseline c

  fig, axs = plt.subplots(nrows=n_rows, ncols=n_cols, frameon=False,


No baseline correction applied
No baseline correction applied
No baseline correction applied
No baseline correction applied
No baseline correction applied
No baseline correction applied
No baseline correction applied
No baseline correction applied
No baseline correction applied
No baseline correction applied
No baseline correction applied
No baseline correction applied
No baseline correction applied
No baseline correction applied
No baseline correction applied
No baseline correction applied
No baseline correction applied
No baseline correction applied
No baseline correction applied
No baseline correction applied
No baseline correction applied
No baseline correction applied
No baseline correction applied
No baseline correction applied
No baseline correction applied
No baseline correction applied
No baseline correction applied
No baseline correction applied
No baseline correction applied
No baseline correction applied
No baseline correction applied
No baseline correction applied
No basel

### just one subject

In [6]:
# Description: Check channels for outliers and remove them

HOME = os.path.expanduser("~")

# get box directory depending on OS
if os.name == 'nt': # windows
    LAB_root = os.path.join(HOME, "Box", "CoganLab")
else: # mac
    LAB_root = os.path.join(HOME, "Library", "CloudStorage", "Box-Box", "CoganLab")

layout = get_data("GlobalLocal", root=LAB_root)


# Load the data
TASK = "GlobalLocal"
subj = "D0065"
output_name = 'Stimulus_fixationCrossBase_0.2sec'
layout = get_data("GlobalLocal", root=LAB_root)


fig_path = os.path.join(layout.root, 'derivatives', 'spec', 'wavelet', 'figs')

filename = os.path.join(layout.root, 'derivatives', 'spec', 'wavelet', subj, f'{output_name}-tfr.h5')
print("Filename:", filename)
spec = mne.time_frequency.read_tfrs(filename)[0]
info_file = os.path.join(layout.root, spec.info['subject_info']['files'][0])

# Check channels for outliers and remove them
all_bad = get_bad_chans(info_file)
spec.info.update(bads=[b for b in all_bad if b in spec.ch_names])

# Plotting
figs = chan_grid(spec, size=(20, 10), vmin=-2, vmax=2, cmap=parula_map, show=False)
for i, f in enumerate(figs):
    fig_name = f'{subj}_{output_name}_{i+1}.jpg'
    fig_pathname = os.path.join(fig_path, fig_name)
    f.savefig(fig_pathname, bbox_inches='tight')
    print("Saved figure:", fig_name)

# update(spec, layout, "bad")


# fig_path = os.path.join(layout.root, 'derivatives', 'spec', 'wavelet', 'figs')

# filename = os.path.join(layout.root, 'derivatives', 'spec', 'wavelet', subj, 'Response_0.2secBeforeStimOnsetBase-tfr.h5')
# print("Filename:", filename)
# spec = mne.time_frequency.read_tfrs(filename)[0]
# info_file = os.path.join(layout.root, spec.info['subject_info']['files'][0])

# # Check channels for outliers and remove them
# # all_bad = get_bad_chans(info_file)
# # spec.info.update(bads=[b for b in all_bad if b in spec.ch_names])

# # Plotting
# figs = chan_grid(spec, size=(20, 10), vmin=-2, vmax=2, cmap=parula_map, show=False)
# for i, f in enumerate(figs):
#     fig_name = f'{subj}_Response_fullTrialBase_{i+1}.jpg'
#     fig_pathname = os.path.join(fig_path, fig_name)
#     f.savefig(fig_pathname, bbox_inches='tight')
#     print("Saved figure:", fig_name)


Filename: C:\Users\jz421\Box\CoganLab\BIDS-1.1_GlobalLocal\BIDS\derivatives\spec\wavelet\D0063\Stimulus_fixationCrossBase_0.2sec-tfr.h5
Reading C:\Users\jz421\Box\CoganLab\BIDS-1.1_GlobalLocal\BIDS\derivatives\spec\wavelet\D0063\Stimulus_fixationCrossBase_0.2sec-tfr.h5 ...
No baseline correction applied
No baseline correction applied
No baseline correction applied
No baseline correction applied
No baseline correction applied
No baseline correction applied
No baseline correction applied
No baseline correction applied
No baseline correction applied
No baseline correction applied
No baseline correction applied
No baseline correction applied
No baseline correction applied
No baseline correction applied
No baseline correction applied
No baseline correction applied
No baseline correction applied
No baseline correction applied
No baseline correction applied
No baseline correction applied
No baseline correction applied
No baseline correction applied
No baseline correction applied
No baseline c

In [None]:
chans