In [None]:
import mne
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf

from mne.datasets import sample

# Apply matplotlib's 'dark_background' style sheet to the figures drawn below.
plt.style.use('dark_background')

In [None]:
# Resolve the files of the MNE "sample" dataset used in this notebook.
# The paths are built with f-strings rather than `+` so the cell works whether
# `sample.data_path()` returns a str or a pathlib.Path (`Path + str` raises
# TypeError).
data_path = sample.data_path()
raw_fname = f'{data_path}/MEG/sample/sample_audvis_filt-0-40_raw.fif'
event_fname = f'{data_path}/MEG/sample/sample_audvis_filt-0-40_raw-eve.fif'
raw = mne.io.read_raw_fif(raw_fname)

# Keep only the EEG and EOG channels, then load the signal into memory.
raw.pick_types(meg=False, eeg=True, eog=True).load_data()

# This particular dataset already has an average reference projection added;
# remove it for the sake of this example.
raw.set_eeg_reference([])

# Per-channel-type rejection thresholds handed to mne.Epochs through `reject`.
reject = dict(eeg=180e-6, eog=150e-6)
# Epoch only the 'left/auditory' events (code 1), from tmin to tmax around each event.
event_id, tmin, tmax = {'left/auditory': 1}, -0.2, 0.5
events = mne.read_events(event_fname)
epochs_params = dict(events=events, event_id=event_id, tmin=tmin, tmax=tmax,
                     reject=reject)

erps = mne.Epochs(raw, **epochs_params)
erps.load_data()
# Use the public accessor instead of the private `_data` attribute to obtain
# the epochs array that the cells below index into.
epochs = erps.get_data()

In [None]:
def dist(f1, f2):
    """Return the natural log of the summed squared difference of two arrays.

    Both inputs must have the same ``.shape``; an ``AssertionError`` is raised
    otherwise.
    """
    assert f1.shape == f2.shape
    residual = f1 - f2
    squared_error = np.sum(residual ** 2)
    return np.log(squared_error)

def plot(*functions):
    """Draw every given signal against ``erps.times`` on the current figure.

    Each argument is flattened to 1-D before plotting. After all signals are
    drawn, a horizontal and a vertical reference line are added at their
    default positions.
    """
    for signal in functions:
        time_axis = erps.times
        trace = signal.flatten()
        plt.plot(time_axis, trace)
    plt.axhline()
    plt.axvline()

In [None]:
# Reference signal: average over epochs (axis 0), keep channel 0, and reshape
# to a single row so it can be compared with other (1, n) signals via dist().
clean = epochs.mean(axis=0)[0].reshape((1, -1))

# One single epoch to compare against. Note it is epoch index 17; the name
# `epoch0` is kept because the cells below refer to it.
epoch0 = epochs[17]

# Uniform averaging weights, one per row of the epoch. The count is taken from
# the data instead of being hard-coded to 60, so `weights @ epoch0` remains a
# valid matrix product whatever the number of rows is.
n_channels = epoch0.shape[0]
weights = np.ones((1, n_channels)) / n_channels

# Overlay the reference with the mean of the first five rows of the single epoch.
plot(clean, epoch0[:5].mean(axis=0))

In [None]:
# Baseline: distance between the reference signal and an all-zero signal of
# the same shape, i.e. the log of the summed squares of `clean`.
dist(f1=clean, f2=np.zeros_like(clean))

In [None]:
# Distance between the reference signal and the `weights`-weighted combination
# of the rows of the single epoch.
dist(clean, np.matmul(weights, epoch0))

In [None]:
# Distance between the reference signal and the mean of the first eight rows
# of the single epoch, given a leading axis so its shape matches `clean`.
dist(clean, epoch0[:8].mean(axis=0)[np.newaxis, :])