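# Optimization

Goal: tune the parameters of the artifact removal tool (ART) — e.g. number of clusters, FD and SSA thresholds — on labelled Temple University EEG artifacts, scoring each cleaned window against clean reference data with the EEG quality index (EQI).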

## Import libraries

In [1]:
# Default libraries
import numpy as np
import scipy.optimize as optimize
import matplotlib.pyplot as plt
import mne

# Custom libraries
import Functions.data_tools as data_tools
from Functions.temple_data import TempleData
from Functions.artifact_removal_tool import ART
from Functions import eeg_quality_index as eqi

## Import data

In [2]:
# Recording to analyse. The commented-out line is meant as the alternative
# recording "With artifacts", but it currently points to the same file as the
# active path — TODO replace it with the intended recording.
# file = r"Data\Temple\edf\01_tcp_ar\002\00000254\s005_2010_11_15\00000254_s005_t000.edf" # With artifacts
edf_path = r"Data\Temple\edf\01_tcp_ar\002\00000254\s005_2010_11_15\00000254_s005_t000.edf" # NO artifacts
file = edf_path
temple = TempleData(file)

Extracting EDF parameters from c:\Users\danie\Documents\Projects\art-eqi-p300-validation\Data\Temple\edf\01_tcp_ar\002\00000254\s005_2010_11_15\00000254_s005_t000.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...


In [3]:
# Read the artifact labels that match this recording's montage, then pull every
# "eyem"-labelled segment (presumably eye-movement artifacts — confirm against
# the Temple label definitions).
# NOTE(review): `chans` / `data` are not used later in this notebook; the
# single-run cell reloads the artifacts with a fixed window length.
artifacts_file = fr"Data\\Temple\\csv\\labels_{temple.montage_type}.csv"
temple.get_artifacts_from_csv(artifacts_file)
chans, data = temple.get_artifact_type_data(artifact_type="eyem")

## Single run example

In [20]:

# Settings
# - Temple data
artifact_type = "eyem"  # Type of artifact to process
test_percentage = 20    # Percentage of artifacts held out as test set [%]
split_seed = 42         # Seed for a reproducible optimization/test split
window_length = 5       # Length of window data [sec]

window_samples = int(window_length * temple.srate)

# - Optimization search grid (start, stop, step); presumably meant for
#   scipy.optimize.brute — not used yet in this cell
n_clusters = slice(1, 6, 1)
fd_threshold = slice(1, 3, 0.2)
ssa_threshold = slice(0.005, 0.02, 0.005)

# Get artifacts: channel lists and fixed-length data windows per artifact
artifacts_file = fr"Data\\Temple\\csv\\labels_{temple.montage_type}.csv"
temple.get_artifacts_from_csv(artifacts_file)
[artifacts_chans, eyem_artifacts] = temple.get_artifact_type_data(
    artifact_type = artifact_type,
    window_length = window_length,
    )

# Get clean data and average it over axis 0 (presumably the window axis), giving
# a (channels, samples) reference indexed below as clean_avg[rows, :]
clean = temple.get_clean_data(window_length)
clean_avg = np.mean(clean, axis=0)

# Separate optimization and test sets
[i_optim, i_test] = data_tools.split_list(
    lst = list(artifacts_chans),
    test_percentage = test_percentage,
    seed = split_seed
    )
# NOTE(review): the scoring loop below still runs over *all* artifacts rather
# than only the optimization split — TODO decide which set is intended.

# Create artifact removal tool
art = ART(window_length = window_length)

# EEG quality index, one entry per (artifact, channel slot)
eqi_total = np.zeros(eyem_artifacts.shape[:2])
eqi_window = window_samples // 10  # EQI scoring window [samples]
eqi_slide = window_samples // 20   # EQI scoring step [samples]

for a, artifact in enumerate(artifacts_chans):
    # Recording-row indices of this artifact's channels. They are identical for
    # every channel of the artifact, so compute them once per artifact.
    subset_chans = [temple.ch_names.index(ch) for ch in artifact if ch in temple.ch_names]

    for c, chan in enumerate(artifact):
        test_eeg = art.remove_artifacts(
            eyem_artifacts[a, c, :],
            srate = temple.srate
            )

        # Store one score per channel; writing the whole row `eqi_total[a, :]`
        # here would overwrite every earlier channel of this artifact.
        # Assumes eqi.scoring(...)[0] is a scalar total score — TODO confirm.
        eqi_total[a, c] = eqi.scoring(
            clean_eeg = clean_avg[subset_chans, :],
            test_eeg = test_eeg,
            srate_clean = temple.srate,
            srate_test = temple.srate,
            window = eqi_window,
            slide = eqi_slide
        )[0]
    

In [16]:
# Scratch check: re-score only the last (artifact, channel) pair from the
# single-run loop. `subset_chans` and `test_eeg` are whatever that loop's final
# iteration left behind, so this cell depends on the loop having run first.
scoring_kwargs = dict(
    clean_eeg=clean_avg[subset_chans, :],
    test_eeg=test_eeg,
    srate_clean=temple.srate,
    srate_test=temple.srate,
    window=int(window_samples // 10),
    slide=int(window_samples // 20),
)
a = eqi.scoring(**scoring_kwargs)[0]

In [11]:
# Show the EEG quality index matrix from the single-run cell
# (rows: artifacts, columns: channel slots).
# NOTE(review): the saved output below comes from In [11], i.e. before the last
# run of the single-run cell (In [20]) — re-run to refresh it.
eqi_total

array([[0., 0., 0., 0., 0., 2., 0., 0.]])

In [None]:
# Averaged clean data for the channels of the last artifact processed in the
# single-run loop (`subset_chans` holds their row indices in the recording).
fig, ax = plt.subplots(len(subset_chans), 1, sharex=True, squeeze=False)
# Sample times [sec]: sample k sits at k / srate. (linspace with its default
# endpoint=True would stretch the spacing to duration / (n - 1).)
t = np.arange(clean_avg.shape[1]) / temple.srate

for row, chan_idx in enumerate(subset_chans):
    # Index clean_avg by the channel's recording row, not its position in the subset.
    ax[row, 0].plot(t, clean_avg[chan_idx, :])
    ax[row, 0].set_ylabel(temple.ch_names[chan_idx])
ax[-1, 0].set_xlabel("Time [sec]")
fig.suptitle("Average clean data");

In [None]:
# Raw data of one extracted artifact, one subplot per channel slot.
artifact_idx = 2  # which extracted artifact to inspect
n_slots = eyem_artifacts.shape[1]
# Own time axis so this cell does not depend on `t` from another cell.
t_artifact = np.arange(eyem_artifacts.shape[-1]) / temple.srate  # [sec]

fig, ax = plt.subplots(n_slots, 1, sharex=True, squeeze=False)
for slot in range(n_slots):
    ax[slot, 0].plot(t_artifact, eyem_artifacts[artifact_idx, slot, :])
ax[-1, 0].set_xlabel("Time [sec]")
fig.suptitle(f"{artifact_type} artifact #{artifact_idx}");

In [None]:
%matplotlib qt
import mne

mne_data = mne.io.read_raw_edf(file)
mne_data.plot()

In [None]:
# Manual extraction of annotated artifacts (appears to mirror
# `temple.get_artifact_type_data`): for every `artifact_type` annotation lasting
# at least `window_length` seconds, keep its channel names and the first
# `window_length` seconds of data on those channels.
artifact_type = "eyem"
window_samples = int(window_length * temple.srate)

artifact_chans = []  # one channel-name list per kept artifact
artifact_data = []   # one (channels, window_samples) block per kept artifact

if artifact_type in temple.artifacts:
    for artifact in temple.artifacts[artifact_type].values():
        # start_end is compared with window_length, so it is in seconds
        start_sec, end_sec = artifact["start_end"][0], artifact["start_end"][1]
        if end_sec - start_sec < window_length:
            continue  # too short to fill a full window

        # Rows of the recording for this artifact's channels. np.isin keeps the
        # recording's channel order, not the order listed in artifact["chans"].
        art_chans = artifact["chans"]
        chans_mask = np.isin(temple.ch_names, art_chans)

        # Assumes temple.data is a (channels, samples) array, as unpacked by
        # np.shape(temple.data) in the original cell — TODO confirm.
        start = int(start_sec * temple.srate)
        artifact_chans.append(art_chans)
        artifact_data.append(temple.data[chans_mask, start:start + window_samples])

            