In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import numpy as np
import pandas as pd
import scipy

In [None]:
from scipy.fft import fft, ifft, fftfreq

In [None]:
from ssvepcca.utils import load_mat_data_array
from ssvepcca.pipelines import test_fit_predict, k_fold_predict
from ssvepcca.learners import (
    CCASingleComponent, FilterbankCCA, CCAFixedCoefficients, CCAMultiComponent, AlternativeFBCCA,
    CCASpatioTemporal, FBCCAFixedCoefficients, CCASpatioTemporalFixed, FBSpatioTemporalCCA,
    FBSpatioTemporalCCAFixed
)
from ssvepcca.parameters import electrode_list_fbcca

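## Configuration

Evaluation window (`start_time_index`, `stop_time_index`) passed to every learner below, and the subject data file used for all runs.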

In [None]:
# Sample-index window handed to every learner as start_time_index / stop_time_index.
# NOTE(review): units are samples of input_data's time axis — the sampling rate is not
# visible here, so confirm against load_mat_data_array before converting to seconds.
start_time_index = 125
window_num_samples = 500
stop_time_index = start_time_index + window_num_samples

In [None]:
# Presumably the recording for subject 2 — the file name suggests it; confirm against the dataset docs.
subject_mat_path = "dataset_chines/S2.mat"
input_data = load_mat_data_array(subject_mat_path)

In [None]:
# Display the dimensions of the loaded array (np.shape returns the object's own .shape).
np.shape(input_data)

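## Tests

Each cell below builds one learner on the shared electrode list and time window, then evaluates it on `input_data` with either `test_fit_predict` or `k_fold_predict`. The `*Fixed*` variants use `k_fold_predict`. The following cell displays element `[2]` of the result. The commented-out asserts record reference values for `[2][0]`.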

In [None]:
# Single-component CCA on the FBCCA electrode set, evaluated with test_fit_predict.
cca_single_component_learner = CCASingleComponent(
    electrodes_name=electrode_list_fbcca,
    start_time_index=start_time_index,
    stop_time_index=stop_time_index,
)
result_cca_single_component = test_fit_predict(input_data, cca_single_component_learner)

In [None]:
# Show element 2 of the value returned by test_fit_predict.
# NOTE(review): its [0] entry is what the disabled reference check compares — presumably a
# correct-prediction count; confirm against ssvepcca.pipelines.
result_cca_single_component[2]

In [None]:
# Disabled reference check for CCASingleComponent; uncomment to assert the recorded value.
# assert result_cca_single_component[2][0] == 225

In [None]:
# CCA with fixed coefficients, evaluated with k_fold_predict rather than test_fit_predict.
cca_fixed_coefficients_learner = CCAFixedCoefficients(
    electrodes_name=electrode_list_fbcca,
    start_time_index=start_time_index,
    stop_time_index=stop_time_index,
)
result_cca_fixed_coefficients = k_fold_predict(input_data, cca_fixed_coefficients_learner)

In [None]:
# Show element 2 of the value returned by k_fold_predict (same index as the test_fit_predict cells).
result_cca_fixed_coefficients[2]

In [None]:
# Disabled reference check for CCAFixedCoefficients; uncomment to assert the recorded value.
# assert result_cca_fixed_coefficients[2][0] == 236

In [None]:
# Multi-component CCA using the first 3 components, evaluated with test_fit_predict.
cca_multi_component_learner = CCAMultiComponent(
    electrodes_name=electrode_list_fbcca,
    start_time_index=start_time_index,
    stop_time_index=stop_time_index,
    num_components=3,
)
result_cca_fusion = test_fit_predict(input_data, cca_multi_component_learner)

In [None]:
# Show element 2 of the multi-component CCA evaluation result.
result_cca_fusion[2]

In [None]:
# Disabled reference check for CCAMultiComponent (num_components=3); uncomment to assert the recorded value.
# assert result_cca_fusion[2][0] == 221

In [None]:
# Filter-bank CCA: 5 harmonics, 10 sub-bands from 8 up to 88 (fb_fundamental_freq / fb_upper_bound_freq).
filterbank_cca_learner = FilterbankCCA(
    electrodes_name=electrode_list_fbcca,
    start_time_index=start_time_index,
    stop_time_index=stop_time_index,
    num_harmonics=5,
    fb_num_subband=10,
    fb_fundamental_freq=8,
    fb_upper_bound_freq=88,
)
result_cca_filter_bank = test_fit_predict(input_data, filterbank_cca_learner)

In [None]:
# Show element 2 of the filter-bank CCA evaluation result.
result_cca_filter_bank[2]

In [None]:
# Disabled reference check for FilterbankCCA; uncomment to assert the recorded value.
# assert result_cca_filter_bank[2][0] == 238

In [None]:
# Filter-bank CCA with fixed coefficients (same filter-bank settings as FilterbankCCA),
# evaluated with k_fold_predict.
fb_cca_fixed_coefficients_learner = FBCCAFixedCoefficients(
    electrodes_name=electrode_list_fbcca,
    start_time_index=start_time_index,
    stop_time_index=stop_time_index,
    num_harmonics=5,
    fb_num_subband=10,
    fb_fundamental_freq=8,
    fb_upper_bound_freq=88,
)
result_fb_cca_fixed_coefficients = k_fold_predict(input_data, fb_cca_fixed_coefficients_learner)

In [None]:
# Show element 2 of the fixed-coefficient filter-bank CCA result from k_fold_predict.
result_fb_cca_fixed_coefficients[2]

In [None]:
# Disabled reference check for FBCCAFixedCoefficients; uncomment to assert the recorded value.
# assert result_fb_cca_fixed_coefficients[2][0] == 240

In [None]:
# Alternative filter-bank CCA. Note the smaller configuration than FilterbankCCA:
# 3 harmonics and 3 sub-bands (same 8..88 bounds).
alternative_fbcca_learner = AlternativeFBCCA(
    electrodes_name=electrode_list_fbcca,
    start_time_index=start_time_index,
    stop_time_index=stop_time_index,
    num_harmonics=3,
    fb_num_subband=3,
    fb_fundamental_freq=8,
    fb_upper_bound_freq=88,
)
result_alt_cca_filter_bank = test_fit_predict(input_data, alternative_fbcca_learner)

In [None]:
# Disabled reference check for AlternativeFBCCA; uncomment to assert the recorded value.
# (Unlike the other sections, this check sits before its display cell.)
# assert result_alt_cca_filter_bank[2][0] == 232

In [None]:
# Show element 2 of the alternative filter-bank CCA evaluation result.
result_alt_cca_filter_bank[2]

In [None]:
# Spatio-temporal CCA with window_gap=0 and window_length=9 (3 harmonics).
cca_spatio_temporal_learner = CCASpatioTemporal(
    electrodes_name=electrode_list_fbcca,
    start_time_index=start_time_index,
    stop_time_index=stop_time_index,
    num_harmonics=3,
    window_gap=0,
    window_length=9,
)
result_cca_spatio_temporal = test_fit_predict(input_data, cca_spatio_temporal_learner)

In [None]:
# Show element 2 of the spatio-temporal CCA result (gap 0, length 9).
result_cca_spatio_temporal[2]

In [None]:
# Disabled reference check for CCASpatioTemporal (gap 0, length 9); uncomment to assert the recorded value.
# assert result_cca_spatio_temporal[2][0] == 234

In [None]:
# Spatio-temporal CCA variant: window_gap=3, window_length=1 (3 harmonics).
cca_spatio_temporal_gap3_len1_learner = CCASpatioTemporal(
    electrodes_name=electrode_list_fbcca,
    start_time_index=start_time_index,
    stop_time_index=stop_time_index,
    num_harmonics=3,
    window_gap=3,
    window_length=1,
)
result_cca_spatio_temporal2 = test_fit_predict(input_data, cca_spatio_temporal_gap3_len1_learner)

In [None]:
# Show element 2 of the spatio-temporal CCA result (gap 3, length 1); no reference value recorded.
result_cca_spatio_temporal2[2]

In [None]:
# Spatio-temporal CCA variant: window_gap=10, window_length=9 (3 harmonics).
cca_spatio_temporal_gap10_len9_learner = CCASpatioTemporal(
    electrodes_name=electrode_list_fbcca,
    start_time_index=start_time_index,
    stop_time_index=stop_time_index,
    num_harmonics=3,
    window_gap=10,
    window_length=9,
)
result_cca_spatio_temporal3 = test_fit_predict(input_data, cca_spatio_temporal_gap10_len9_learner)

In [None]:
# Show element 2 of the spatio-temporal CCA result (gap 10, length 9); no reference value recorded.
result_cca_spatio_temporal3[2]

In [None]:
# Spatio-temporal CCA with window_gap=0 and window_length=0.
# NOTE(review): its recorded reference value (225) equals CCASingleComponent's, so this
# presumably degenerates to plain single-component CCA — confirm in ssvepcca.learners.
cca_spatio_temporal_zero_learner = CCASpatioTemporal(
    electrodes_name=electrode_list_fbcca,
    start_time_index=start_time_index,
    stop_time_index=stop_time_index,
    num_harmonics=3,
    window_gap=0,
    window_length=0,
)
result_cca_spatio_temporal_zero = test_fit_predict(input_data, cca_spatio_temporal_zero_learner)

In [None]:
# Show element 2 of the zero-window spatio-temporal CCA result.
result_cca_spatio_temporal_zero[2]

In [None]:
# Disabled reference check for CCASpatioTemporal (gap 0, length 0); the recorded value matches
# the CCASingleComponent reference. Uncomment to assert it.
# assert result_cca_spatio_temporal_zero[2][0] == 225

In [None]:
# Fixed spatio-temporal CCA (gap 0, length 9, 3 harmonics), evaluated with k_fold_predict.
cca_spatio_temporal_fixed_learner = CCASpatioTemporalFixed(
    electrodes_name=electrode_list_fbcca,
    start_time_index=start_time_index,
    stop_time_index=stop_time_index,
    num_harmonics=3,
    window_gap=0,
    window_length=9,
)
result_cca_spatio_temporal_fixed = k_fold_predict(input_data, cca_spatio_temporal_fixed_learner)

In [None]:
# Show element 2 of the fixed spatio-temporal CCA result from k_fold_predict; no reference value recorded.
result_cca_spatio_temporal_fixed[2]

In [None]:
# Filter-bank spatio-temporal CCA with window_gap=0 and window_length=0.
# NOTE(review): its recorded reference value (238) equals FilterbankCCA's, so this presumably
# reduces to plain filter-bank CCA — confirm in ssvepcca.learners.
fb_spatio_temporal_zero_learner = FBSpatioTemporalCCA(
    electrodes_name=electrode_list_fbcca,
    start_time_index=start_time_index,
    stop_time_index=stop_time_index,
    num_harmonics=5,
    window_gap=0,
    window_length=0,
    fb_num_subband=10,
    fb_fundamental_freq=8,
    fb_upper_bound_freq=88,
)
result_cca_filter_bank_zero = test_fit_predict(input_data, fb_spatio_temporal_zero_learner)

In [None]:
# Show element 2 of the zero-window filter-bank spatio-temporal CCA result.
result_cca_filter_bank_zero[2]

In [None]:
# Disabled reference check for FBSpatioTemporalCCA (gap 0, length 0); the recorded value matches
# the FilterbankCCA reference. Uncomment to assert it.
# assert result_cca_filter_bank_zero[2][0] == 238

In [None]:
# Filter-bank spatio-temporal CCA: window_gap=0, window_length=9 (5 harmonics, 10 sub-bands).
fb_spatio_temporal_gap0_len9_learner = FBSpatioTemporalCCA(
    electrodes_name=electrode_list_fbcca,
    start_time_index=start_time_index,
    stop_time_index=stop_time_index,
    num_harmonics=5,
    window_gap=0,
    window_length=9,
    fb_num_subband=10,
    fb_fundamental_freq=8,
    fb_upper_bound_freq=88,
)
result_cca_filter_bank_0_9 = test_fit_predict(input_data, fb_spatio_temporal_gap0_len9_learner)

In [None]:
# Show element 2 of the filter-bank spatio-temporal CCA result (gap 0, length 9); no reference value recorded.
result_cca_filter_bank_0_9[2]

In [None]:
# Filter-bank spatio-temporal CCA: window_gap=5, window_length=1 (5 harmonics, 10 sub-bands).
fb_spatio_temporal_gap5_len1_learner = FBSpatioTemporalCCA(
    electrodes_name=electrode_list_fbcca,
    start_time_index=start_time_index,
    stop_time_index=stop_time_index,
    num_harmonics=5,
    window_gap=5,
    window_length=1,
    fb_num_subband=10,
    fb_fundamental_freq=8,
    fb_upper_bound_freq=88,
)
result_cca_filter_bank_5_1 = test_fit_predict(input_data, fb_spatio_temporal_gap5_len1_learner)

In [None]:
# Show element 2 of the filter-bank spatio-temporal CCA result (gap 5, length 1).
result_cca_filter_bank_5_1[2]

In [None]:
# Disabled reference check for FBSpatioTemporalCCA (gap 5, length 1); uncomment to assert the recorded value.
# assert result_cca_filter_bank_5_1[2][0] == 235

In [None]:
# Filter-bank spatio-temporal CCA: window_gap=2, window_length=1 (5 harmonics, 10 sub-bands).
fb_spatio_temporal_gap2_len1_learner = FBSpatioTemporalCCA(
    electrodes_name=electrode_list_fbcca,
    start_time_index=start_time_index,
    stop_time_index=stop_time_index,
    num_harmonics=5,
    window_gap=2,
    window_length=1,
    fb_num_subband=10,
    fb_fundamental_freq=8,
    fb_upper_bound_freq=88,
)
result_cca_filter_bank_2_1 = test_fit_predict(input_data, fb_spatio_temporal_gap2_len1_learner)

In [None]:
# Show element 2 of the filter-bank spatio-temporal CCA result (gap 2, length 1).
result_cca_filter_bank_2_1[2]

In [None]:
# Disabled reference check for FBSpatioTemporalCCA (gap 2, length 1); uncomment to assert the recorded value.
# assert result_cca_filter_bank_2_1[2][0] == 236

In [None]:
# Fixed filter-bank spatio-temporal CCA (gap 2, length 1, 5 harmonics, 10 sub-bands),
# evaluated with k_fold_predict.
fb_spatio_temporal_fixed_gap2_len1_learner = FBSpatioTemporalCCAFixed(
    electrodes_name=electrode_list_fbcca,
    start_time_index=start_time_index,
    stop_time_index=stop_time_index,
    num_harmonics=5,
    window_gap=2,
    window_length=1,
    fb_num_subband=10,
    fb_fundamental_freq=8,
    fb_upper_bound_freq=88,
)
result_cca_filter_bank_fixed_2_1 = k_fold_predict(input_data, fb_spatio_temporal_fixed_gap2_len1_learner)

In [None]:
# Show element 2 of the fixed filter-bank spatio-temporal CCA result from k_fold_predict.
result_cca_filter_bank_fixed_2_1[2]

In [None]:
# Disabled reference check for FBSpatioTemporalCCAFixed (gap 2, length 1); uncomment to assert the recorded value.
# assert result_cca_filter_bank_fixed_2_1[2][0] == 240

## develop