In [1]:
from pathlib import Path

import numpy as np
import ezmsg.core as ez
import matplotlib.pyplot as plt

from ezmsg.util.messages.axisarray import AxisArray
from ezmsg.util.terminate import TerminateOnTotal
from ezmsg.util.messagereplay import (
    MessageReplay, 
    MessageReplaySettings,
    MessageCollector
)

from ezmsg.ssvep.spectralstats import SpectralStats, SpectralStatsSettings

from typing import List, Any

In [2]:
class OfflineStatsSettings(ez.Settings):
    """Settings bundle for the OfflineStats collection.

    Holds one settings object per configurable child unit. The collection's
    ``configure`` hands each one to the matching unit.
    """

    # Configuration for the MessageReplay source (e.g. which file to replay).
    replay_settings: MessageReplaySettings
    # Configuration for the SpectralStats analysis unit.
    stats_settings: SpectralStatsSettings


class OfflineStats(ez.Collection):
    """Offline pipeline: replay a file, compute spectral stats, collect results.

    Wiring::

        REPLAY.OUTPUT_MESSAGE    -> STATS.INPUT_SAMPLE
        STATS.OUTPUT_STATS       -> COLLECTOR.INPUT_MESSAGE
        COLLECTOR.OUTPUT_MESSAGE -> TERM.INPUT_MESSAGE
        REPLAY.OUTPUT_TOTAL      -> TERM.INPUT_TOTAL

    TERM receives both the collected messages and REPLAY's total. Presumably
    it stops the run once those counts match (confirm against TerminateOnTotal).
    The collected messages are exposed through ``output``.
    """

    SETTINGS: OfflineStatsSettings

    # Child units; they are connected to each other in network().
    REPLAY = MessageReplay()
    STATS = SpectralStats()
    COLLECTOR = MessageCollector()
    TERM = TerminateOnTotal()

    def configure(self) -> None:
        """Hand each bundled sub-settings object to the unit that uses it."""
        cfg = self.SETTINGS
        self.REPLAY.apply_settings(cfg.replay_settings)
        self.STATS.apply_settings(cfg.stats_settings)
        # COLLECTOR and TERM are not configured here.

    def network(self) -> ez.NetworkDefinition:
        """Return the (source stream, destination stream) edges of the graph."""
        replay, stats, collector, term = (
            self.REPLAY, self.STATS, self.COLLECTOR, self.TERM
        )
        # Data path: raw samples -> statistics -> collector.
        data_path = (
            (replay.OUTPUT_MESSAGE, stats.INPUT_SAMPLE),
            (stats.OUTPUT_STATS, collector.INPUT_MESSAGE),
        )
        # Termination path: TERM sees collected messages and the replay total.
        termination = (
            (collector.OUTPUT_MESSAGE, term.INPUT_MESSAGE),
            (replay.OUTPUT_TOTAL, term.INPUT_TOTAL),
        )
        return data_path + termination

    @property
    def output(self) -> List[Any]:
        """Messages gathered by the COLLECTOR unit."""
        collector = self.COLLECTOR
        return collector.messages

In [None]:
# Offline analysis: replay every *.txt recording of one subject through
# OfflineStats and plot the resulting statistics, one figure per file.
data_dir = Path.home() / 'ssvep_data'
subject_dir = data_dir / 'SSVEP_S5'

# Shared analysis parameters. They are used both to configure SpectralStats
# and to read/plot its output, so the two sides cannot drift apart.
TIME_AXIS = 'time'
FREQ_AXIS = 'freq'
FREQ_MAX_HZ = 50.0
INTEGRATION_TIME = 4.0  # passed to SpectralStatsSettings; presumably seconds -- confirm
# (family-wise alpha, colour of its threshold line)
P_THRESHOLDS = ((0.05, 'red'), (0.01, 'orange'), (0.001, 'green'))

# glob() order is filesystem-dependent; sort so figures appear in a stable order.
data_fnames = sorted(subject_dir.glob('*.txt'))
if not data_fnames:
    # Fail loudly instead of silently producing no figures.
    raise FileNotFoundError(f'No *.txt recordings found in {subject_dir}')

for data_fname in data_fnames:

    settings = OfflineStatsSettings(
        replay_settings = MessageReplaySettings(
            filename = data_fname
        ),
        stats_settings = SpectralStatsSettings(
            time_axis = TIME_AXIS,
            freq_axis = FREQ_AXIS,
            freq_range = slice(0.0, FREQ_MAX_HZ),
            integration_time = INTEGRATION_TIME,
            # Presumably disables SpectralStats' own correction; a Bonferroni-style
            # threshold is drawn on the plot below instead.
            multiple_comparisons = False,
        )
    )

    system = OfflineStats(settings)

    # force_single_process keeps every unit in this process, presumably so
    # that `system.output` can be read back after the run.
    ez.run(system, force_single_process = True)

    if not system.output:
        raise RuntimeError(f'OfflineStats produced no output for {data_fname}')
    # Only the last collected statistics message is plotted.
    stats: AxisArray = system.output[-1]

    # Rebuild frequency values from the axis' linear gain/offset.
    axis = stats.get_axis(FREQ_AXIS)
    axis_idx = stats.get_axis_idx(FREQ_AXIS)
    freqs = (np.arange(stats.shape[axis_idx]) * axis.gain) + axis.offset
    # Bonferroni: each alpha is divided by the total number of tested values.
    num_tests = np.prod(stats.shape)

    fig, ax = plt.subplots()
    with stats.view2d(FREQ_AXIS) as view:
        # Rows follow FREQ_AXIS (plotted against `freqs`); each column is drawn
        # as one channel trace.
        for ch in range(view.shape[1]):
            ax.plot(freqs, view[:, ch], label = f'Ch{ch+1}')

    for thresh, color in P_THRESHOLDS:
        xthresh = -np.log10(thresh / num_tests)
        ax.axhline(xthresh, color = color)
        ax.annotate(f'p = {thresh}', (1, xthresh), va = 'bottom', color = color)

    ax.legend()
    ax.set_xlabel('Freq (Hz)')
    ax.set_ylabel(r'$-\log_{10}(p)$')
    ax.grid(True, which = 'both')
    ax.set_xlim(0, FREQ_MAX_HZ)
    ax.set_title(data_fname.stem)

    # Render this file's figure now. With the inline backend this also closes
    # it, so figures don't pile up across iterations (matplotlib warns past 20).
    plt.show()