This new notebook incorporates recent feedback from Dr. B.

In [None]:
# imports
import scipy
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
from scipy.io import loadmat
from scipy.signal import butter, sosfilt, welch, freqz, sosfreqz, filtfilt, lfilter
from scipy.fft import rfft, rfftfreq, irfft
from typing import List, Tuple

In [None]:
# Set to True to process every subject; leave False to run only a handful of
# subjects, which is enough for quickly testing the notebook.
use_all_subjects = False
n_test_subjects = 4

# Subject IDs are zero-padded strings because the loaders build file names
# such as 'subj-01.mat' and 'frolich_extract_01_256_hz.mat' from them.
# Emotion subjects are numbered 01-35, but subject 22 is missing, so skip it.
all_emotion_subjects = [f'{i:02d}' for i in range(1, 36) if i != 22]
# NOTE(review): assumes the old `cue_subj_list` (1-12) is the Frolich subject range — confirm.
all_frolich_subjects = [f'{i:02d}' for i in range(1, 13)]

if use_all_subjects:
    emotion_subj_list = all_emotion_subjects
    frolich_subj_list = all_frolich_subjects
else:
    emotion_subj_list = all_emotion_subjects[:n_test_subjects]
    frolich_subj_list = all_frolich_subjects[:n_test_subjects]

In [None]:
def make_windows_from_signal(signal, window_length):
    """Cut a signal into non-overlapping windows of ``window_length`` samples.

    A 1-D signal, or a 2-D signal with a singleton dimension ((1, T) or
    (T, 1)), is treated as a single trace. Any other 2-D signal is treated as
    (n_channels, T). Trailing samples that do not fill a whole window are
    dropped, and the windows of all channels are stacked together (all of
    channel 0's windows first, then channel 1's, and so on).

    Args:
        signal: 1-D array, or 2-D array laid out as (n_channels, T).
        window_length: number of samples per window; must be positive.

    Returns:
        Array of shape (n_windows_total, window_length). It has zero rows when
        the signal is shorter than one window.

    Raises:
        ValueError: if ``window_length`` is not positive.
    """
    if window_length <= 0:
        raise ValueError(f"window_length must be positive, got {window_length}")
    signal = np.asarray(signal)
    if signal.ndim == 1 or np.min(signal.shape) == 1:
        # Flatten single-trace inputs into one row. Using len() here would be
        # wrong for a (1, T) row vector, whose len() is 1, not T.
        signal = signal.reshape(1, -1)
    n_samples = signal.shape[1]
    usable = (n_samples // window_length) * window_length
    # Row-major reshape: because ``usable`` is a multiple of ``window_length``,
    # every window lies entirely inside a single channel.
    return signal[:, :usable].reshape((-1, window_length))


def make_psds(windows, fs=1, return_freq=False, nfft=None):
    """Squared-magnitude spectrum (unnormalised periodogram) of each window.

    Args:
        windows: 2-D array, one window per row.
        fs: sampling rate, used only to build the frequency axis.
        return_freq: if True, also return the bin frequencies.
        nfft: FFT length; defaults to the window length (rows are zero-padded
            or truncated by ``rfft`` when it differs).

    Returns:
        ``|rfft(windows, n=nfft, axis=1)| ** 2``, or
        ``(frequencies, that_array)`` when ``return_freq`` is True.
    """
    samples_per_window = windows.shape[1]
    n_fft = samples_per_window if nfft is None else nfft
    # Build the frequency axis first (only if requested), then the spectra.
    freqs = rfftfreq(n_fft, 1 / fs) if return_freq else None
    power = np.abs(rfft(windows, axis=1, n=n_fft)) ** 2
    return (freqs, power) if return_freq else power


def make_cmm_filters(f_targets, f_sources):
    """Build one CMM frequency-response magnitude per source PSD.

    For each source PSD p_i, select the target PSD p_j* closest to it. Closeness
    is the squared Euclidean distance between square-rooted PSDs; on ties the
    first target wins. The gain is then sqrt(p_j*) / sqrt(p_i), which maps the
    source's amplitude spectrum onto that target's.

    Args:
        f_targets: target PSDs, as an (n_targets, n_freqs) array or a list of
            equal-length 1-D arrays. A single 1-D PSD is treated as one target.
        f_sources: source PSDs, as an (n_sources, n_freqs) array or a list of
            equal-length 1-D arrays. A single 1-D PSD is treated as one source.

    Returns:
        Float array of shape (n_sources, n_freqs). Bins where the source PSD
        is zero get gain 0 rather than inf/nan, so downstream irfft taps stay
        finite. An empty ``f_sources`` gives an empty float array.

    Raises:
        ValueError: if there are no targets, or if targets and sources have
            different numbers of frequency bins.
    """
    sources = np.asarray(f_sources, dtype=float)
    if sources.size == 0:
        return np.zeros_like(sources)
    # Work in float so integer PSD inputs do not truncate the gains.
    sqrt_sources = np.sqrt(np.atleast_2d(sources))
    sqrt_targets = np.sqrt(np.atleast_2d(np.asarray(f_targets, dtype=float)))
    if sqrt_targets.size == 0:
        raise ValueError("f_targets must contain at least one PSD")
    if sqrt_targets.shape[1] != sqrt_sources.shape[1]:
        raise ValueError(
            f"targets have {sqrt_targets.shape[1]} frequency bins but sources "
            f"have {sqrt_sources.shape[1]}"
        )

    # Pairwise squared distances between sqrt-PSDs: (n_sources, n_targets).
    sq_dists = np.sum(
        (sqrt_sources[:, None, :] - sqrt_targets[None, :, :]) ** 2, axis=-1
    )
    nearest = np.argmin(sq_dists, axis=1)
    matched = sqrt_targets[nearest]  # (n_sources, n_freqs)

    gains = np.zeros_like(sqrt_sources)
    # Skip zero-power source bins: the gain there is undefined, and inf would
    # turn the whole impulse response into NaN after irfft.
    np.divide(matched, sqrt_sources, out=gains, where=sqrt_sources > 0)
    return gains


def spectral_filtering(x, freq_response_magnitude, type='lfilter', pad_frac=0.5):
    """Filter every row of ``x`` with an FIR built from a magnitude response.

    The FIR is derived from ``irfft`` of the (real, non-negative) magnitude.
    ``irfft`` returns a zero-phase, circularly symmetric impulse response,
    with negative-time taps stored at the end of the array. It is
    ``fftshift``-ed into a contiguous, linear-phase FIR centred at
    ``len(b) // 2``. Without the shift, the realised response would match the
    design only at DFT bin centres and would be badly wrong between them.

    Args:
        x: signal laid out as (channels/ICs, time), or a 1-D trace.
        freq_response_magnitude: desired gain per rfft bin.
        type: 'lfilter' (single pass; the FIR's delay is compensated so the
            output is aligned with ``x``) or 'filtfilt' (forward-backward, so
            the square root of the magnitude is used for each pass). The name
            shadows the builtin but is kept for caller compatibility.
        pad_frac: prepend ``int(T * pad_frac)`` zeros before filtering; they
            are dropped from the output.

    Returns:
        Filtered array with the same shape as ``x``.

    Raises:
        ValueError: for an unknown ``type`` or a negative ``pad_frac``.
    """
    if type not in ('lfilter', 'filtfilt'):
        raise ValueError(f"type must be 'lfilter' or 'filtfilt', got {type!r}")
    if pad_frac < 0:
        raise ValueError(f"pad_frac must be non-negative, got {pad_frac}")

    x = np.asarray(x)
    was_1d = x.ndim == 1
    x2d = np.atleast_2d(x)
    n_channels, n_times = x2d.shape
    n_pad = int(n_times * pad_frac)
    lead = np.zeros((n_channels, n_pad), dtype=x2d.dtype)

    if type == 'filtfilt':
        # filtfilt applies b twice, so each pass gets sqrt of the target gain.
        # Forward-backward filtering cancels the linear-phase delay of b.
        b = np.fft.fftshift(irfft(np.sqrt(freq_response_magnitude)))
        padded = np.hstack((lead, x2d))
        filtered = filtfilt(b, 1, padded, axis=1)[:, n_pad:]
    else:
        b = np.fft.fftshift(irfft(freq_response_magnitude))
        # The centred FIR delays its output by len(b) // 2 samples. Append that
        # many zeros so the tail of x is flushed, then read the output
        # shifted by the delay.
        delay = len(b) // 2
        tail = np.zeros((n_channels, delay), dtype=x2d.dtype)
        padded = np.hstack((lead, x2d, tail))
        filtered = lfilter(b, 1, padded, axis=1)[:, n_pad + delay:]
    return filtered[0] if was_1d else filtered


def cmm_filter_signals_to_signals(source_signals, target_signals, window_length, pad_frac=0, filtering_type='filtfilt'):
    """Filter each source recording towards its closest target's average spectrum.

    Each recording's average spectrum is the mean over all of its
    non-overlapping windows (all channels pooled) of the per-window
    squared-magnitude spectrum. ``make_cmm_filters`` turns the source/target
    spectra into one gain curve per source, which ``spectral_filtering`` then
    applies to that source.

    Args:
        source_signals: list of 2-D arrays, (channels/ICs, time).
        target_signals: list of 2-D arrays, (channels/ICs, time).
        window_length: samples per window used to estimate the spectra.
        pad_frac: forwarded to ``spectral_filtering``.
        filtering_type: forwarded to ``spectral_filtering`` as its ``type``.

    Returns:
        List of filtered source signals, in the same order as ``source_signals``.
    """
    def average_spectrum(recording):
        # Mean over every window of every channel of one recording.
        windowed = make_windows_from_signal(recording, window_length)
        return np.mean(make_psds(windowed), axis=0)

    source_spectra = [average_spectrum(rec) for rec in source_signals]
    target_spectra = [average_spectrum(rec) for rec in target_signals]
    gain_curves = make_cmm_filters(target_spectra, source_spectra)
    return [
        spectral_filtering(rec, gains, filtering_type, pad_frac)
        for rec, gains in zip(source_signals, gain_curves)
    ]

In [None]:
# plotting and util functions



In [None]:
# load data
emotion_filepath = Path('../data/emotion_256/raw_data_and_IC_labels')
frolich_filepath = Path('../data/frolich_256/frolich_extract_256_hz')

# Load each subject's .mat file and keep only the array stored under the
# dataset's key ('data' for emotion, 'X' for Frolich), preserving subject order.
emotion_data = [
    loadmat(emotion_filepath / f'subj-{subj}.mat')['data']
    for subj in emotion_subj_list
]
frolich_data = [
    loadmat(frolich_filepath / f'frolich_extract_{subj}_256_hz.mat')['X']
    for subj in frolich_subj_list
]

In [None]:
# compute all filters

In [None]:
# filter data

In [None]:
# plot results