In [None]:
import mne
from pathlib import Path
import pandas as pd
from brainflow.board_shim import BoardShim, BoardIds
import numpy as np
import matplotlib.pyplot as plt
from brainflow.data_filter import DataFilter, FilterTypes, AggOperations
from sklearn.datasets import load_iris
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.decomposition import PCA
# Disable MNE logging output
mne.set_log_level('ERROR')

class Study:
    """Turn a recorded CSV session into per-epoch band-power features.

    The pipeline (see ``run_study``) reads channels from a CSV file, wraps
    them in an MNE ``Raw`` object, band-pass and notch filters the signal,
    cuts it into epochs using an events file, and computes band powers per
    epoch with BrainFlow. The resulting feature matrix (``dataset_x``) and
    label vector (``dataset_y``) can be visualised with ``get_pca`` or
    ``get_lda``.
    """

    # Frequency bins (Hz) used as features. 10 Hz bins inspired by:
    # Saponas, T. S., Tan, D. S., Morris, D. & Balakrishnan, R. Demonstrating
    # the feasibility of using forearm electromyography for muscle-computer
    # interfaces. 515–524 (2008) doi:10.1145/1357054.1357138.
    BANDS = (
        (1, 10), (10, 20), (20, 30), (30, 40), (40, 50),
        (50, 60), (60, 70), (70, 80), (80, 90),
    )
    # One label per entry of BANDS, so bar plots always line up with the
    # values returned by get_band_powers.
    BAND_LABELS = tuple(f"({low}-{high}Hz)" for low, high in BANDS)
    # Palette used by the scatter plots when the caller supplies none.
    DEFAULT_COLORS = ('red', 'green', 'blue', 'purple')

    def __init__(self, csv_file, event_file, event_labels_file, sfreq=200):
        """Store the input paths and initialise empty pipeline state.

        Args:
            csv_file: Path of the CSV recording.
            event_file: Path of the comma-separated integer events file.
            event_labels_file: Path of the comma-separated event label file.
            sfreq: Sampling frequency in Hz used for the Raw object and for
                the band-power computation.
        """
        self.csv_file = csv_file
        self.event_file = event_file
        self.event_labels_file = event_labels_file
        self.sfreq = sfreq
        self.dataset_x = []
        self.dataset_y = []
        self.raw = None
        self.filtered = None
        self.epochs_filtered = None
        self.events = None
        self.lda = None
        # Populated by run_study; initialised here so that calling a plotting
        # method too early gives a clear error instead of an AttributeError.
        self.event_labels = None
        self.event_dict = None

    def _require_dataset(self, caller):
        """Raise RuntimeError when ``caller`` is used before features exist."""
        if self.event_labels is None or len(self.dataset_x) == 0:
            raise RuntimeError(
                f"{caller}() needs extracted features; call run() first."
            )

    def _scatter_by_label(self, projected, colors, title, xlabel, ylabel):
        """Scatter-plot a 1- or 2-column projection, one colour per label."""
        if colors is None:
            colors = self.DEFAULT_COLORS
        labels = np.asarray(self.dataset_y)
        xs = projected[:, 0]
        # A single-component projection is drawn along a flat line so the
        # plot still works when only one component is available.
        if projected.shape[1] > 1:
            ys = projected[:, 1]
        else:
            ys = np.zeros(len(projected))

        plt.figure(figsize=(8, 6))
        for index, name in enumerate(self.event_labels):
            mask = labels == index
            # Cycle through the palette so that no class is silently dropped
            # when there are more labels than colours.
            color = colors[index % len(colors)]
            plt.scatter(xs[mask], ys[mask], alpha=0.7, color=color, label=name, edgecolor='k')

        plt.title(title)
        plt.xlabel(xlabel)
        plt.ylabel(ylabel)
        plt.legend(loc='best')
        plt.grid(True)
        plt.tight_layout()
        plt.show()

    def get_lda(self, n_components=2, colors=None):
        """Fit an LDA on the extracted features and plot the projection.

        Args:
            n_components: Requested number of discriminant components. It is
                capped at ``min(n_features, n_classes - 1)``, the maximum
                scikit-learn accepts.
            colors: Sequence of matplotlib colours, one per class (cycled if
                shorter). Defaults to ``DEFAULT_COLORS``.

        Raises:
            RuntimeError: If no features have been extracted yet.
        """
        self._require_dataset("get_lda")
        features = np.asarray(self.dataset_x)
        labels = np.asarray(self.dataset_y)
        max_components = min(features.shape[1], len(np.unique(labels)) - 1)
        lda = LinearDiscriminantAnalysis(n_components=min(n_components, max_components))
        X_lda = lda.fit_transform(features, labels)
        # Keep the fitted model so it can be reused after plotting.
        self.lda = lda
        self._scatter_by_label(
            X_lda, colors, 'LDA: Dimensionality Reduction (2D)', 'LD1', 'LD2'
        )

    def filter_data(self, raw, l_freq=1.0, h_freq=95.0, notch_freq=60):
        """Band-pass and notch filter a copy of ``raw``.

        Args:
            raw: MNE Raw object; it is copied, not modified.
            l_freq: Lower band-pass edge in Hz.
            h_freq: Upper band-pass edge in Hz.
            notch_freq: Frequency (Hz) removed by the notch filter.

        Returns:
            The filtered copy, which is also stored in ``self.filtered``.
        """
        filtered = raw.copy().filter(l_freq=l_freq, h_freq=h_freq)
        filtered = filtered.notch_filter(freqs=notch_freq)
        self.filtered = filtered
        return filtered

    def load_events(self, file_path):
        """Load a comma-separated integer events table.

        ``ndmin=2`` keeps a file containing a single event as one row instead
        of collapsing it to a 1-D vector.
        """
        return np.loadtxt(file_path, delimiter=',', dtype=int, ndmin=2)

    def raw_to_epochs(self, raw, event_file, event_dict, tmin=-0.1, tmax=2.0, baseline=None, reject=None):
        """Cut ``raw`` into epochs around the events listed in ``event_file``.

        Args:
            raw: MNE Raw object to epoch.
            event_file: Path of the events file read by ``load_events``.
            event_dict: Mapping of event name to integer event id.
            tmin: Epoch start relative to the event, in seconds.
            tmax: Epoch end relative to the event, in seconds.
            baseline: Passed through to ``mne.Epochs``.
            reject: Passed through to ``mne.Epochs``.

        Returns:
            Tuple ``(epochs, events)``.
        """
        events = self.load_events(event_file)
        epochs = mne.Epochs(
            raw, events, event_id=event_dict, tmin=tmin, tmax=tmax,
            preload=True, baseline=baseline, reject=reject,
        )
        return epochs, events

    def csv_to_dataframe(self, file, n_channels=15):
        """Read the columns named "0" .. "n_channels - 1" and transpose them.

        Returns:
            DataFrame with one row per selected CSV column.
        """
        column_names = [str(i) for i in range(n_channels)]
        return pd.read_csv(file, usecols=column_names).transpose()

    def df_to_raw(self, df, sfreq=200, ch_types='emg'):
        """Wrap a channels-by-samples DataFrame in an MNE RawArray.

        Args:
            df: DataFrame with one row per channel.
            sfreq: Sampling frequency in Hz.
            ch_types: Accepted for backward compatibility but currently
                overridden: every channel is created with type 'eeg'.

        Returns:
            The created ``mne.io.RawArray``.
        """
        channel_names = [str(i) for i in range(df.shape[0])]
        # NOTE(review): the ch_types argument is deliberately not used here;
        # the original behaviour (all channels typed 'eeg') is preserved.
        ch_types = ['eeg'] * len(channel_names)

        info = mne.create_info(ch_names=channel_names, sfreq=sfreq, ch_types=ch_types)
        raw = mne.io.RawArray(df, info)
        raw.load_data()
        return raw

    def csv_to_raw(self, file, start_channel=0, end_channel=15):
        """Load rows ``start_channel:end_channel`` of the CSV as a Raw object.

        The Raw object is created with ``self.sfreq`` so that it agrees with
        the sampling rate used for the band-power computation.
        """
        df = self.csv_to_dataframe(file)
        df = df.iloc[start_channel:end_channel, :]
        return self.df_to_raw(df, sfreq=self.sfreq)

    def get_band_powers(self, epoch):
        """Return the average power of ``epoch`` in each bin of ``BANDS``.

        Args:
            epoch: 2-D array passed to BrainFlow; row 0 is analysed.

        Returns:
            The band-power averages reported by
            ``DataFilter.get_custom_band_powers`` (standard deviations are
            discarded).
        """
        avgs, _stddevs = DataFilter.get_custom_band_powers(
            epoch, list(self.BANDS), [0], self.sfreq, False
        )
        return avgs

    def get_band_power_by_epoch(self, epochs, epoch_index):
        """Compute band powers for one epoch of a 3-D epochs array.

        The selected epoch is flattened into a single row before being handed
        to ``get_band_powers``.
        """
        selected_epoch = epochs[epoch_index, :, :].reshape(1, -1)
        return self.get_band_powers(selected_epoch)

    def add_data_and_labels(self, epochs, event_id, label):
        """Append band-power features for every epoch of one event.

        Args:
            epochs: MNE Epochs object.
            event_id: Key used to select epochs (``epochs[event_id]``);
                ``run_study`` passes the event name here.
            label: Value appended to ``dataset_y`` for each selected epoch.
        """
        epochs_data = epochs[event_id].get_data()
        for epoch_index in range(epochs_data.shape[0]):
            self.dataset_x.append(self.get_band_power_by_epoch(epochs_data, epoch_index))
            self.dataset_y.append(label)

    def plot_band_powers(self, avgs, title="Average Band Powers"):
        """Draw a bar chart of band-power averages, one bar per ``BANDS`` bin."""
        plt.figure(figsize=(10, 6))
        # Labels are derived from BANDS so they always match the number of
        # values produced by get_band_powers.
        plt.bar(list(self.BAND_LABELS), avgs)
        plt.xticks(rotation=45, ha='right')
        plt.ylabel('Power')
        plt.title(title)
        plt.tight_layout()
        plt.show()

    def get_pca(self, n_components=2, colors=None):
        """Fit a PCA on the extracted features and plot the projection.

        Args:
            n_components: Number of principal components to compute.
            colors: Sequence of matplotlib colours, one per class (cycled if
                shorter). Defaults to ``DEFAULT_COLORS``.

        Raises:
            RuntimeError: If no features have been extracted yet.
        """
        self._require_dataset("get_pca")
        pca = PCA(n_components=n_components)
        X_pca = pca.fit_transform(np.asarray(self.dataset_x))
        self._scatter_by_label(
            X_pca, colors, 'PCA: Dimensionality Reduction (2D)', 'PC1', 'PC2'
        )

    def plot_tfr(self, epochs, event, baseline=(-0.1, 0)):
        """Plot a Morlet time-frequency representation for one event.

        Args:
            epochs: MNE Epochs object.
            event: Event name used to select epochs.
            baseline: Baseline window (seconds) used for the log-ratio plot.
        """
        # NOTE(review): run_study builds epochs with tmin=0.0, so the default
        # baseline window (-0.1, 0) starts before the epoch — confirm that
        # this is intended.
        fmax = 100
        freqs = np.logspace(*np.log10([1, fmax]), num=fmax)
        n_cycles = freqs / 2.0

        tfr = epochs[event].compute_tfr(
            "morlet", n_cycles=n_cycles, return_itc=False, freqs=freqs,
            average=True, use_fft=True,
        )
        tfr.plot(title=event, baseline=baseline, mode="logratio")

    def run_study(self, csv_file, event_file, plot_epochs=False):
        """Run the full pipeline: load, filter, epoch, extract features, plot.

        Args:
            csv_file: Path of the CSV recording.
            event_file: Path of the events file.
            plot_epochs: When true, plot the filtered epochs.
        """
        # Start from empty lists so the method can be run more than once;
        # the datasets are converted to arrays at the end of each run.
        self.dataset_x = []
        self.dataset_y = []

        self.raw = self.csv_to_raw(csv_file, 1, 2)
        self.filtered = self.filter_data(self.raw)
        # ndmin=1 keeps a file with a single label iterable.
        self.event_labels = np.loadtxt(self.event_labels_file, delimiter=',', dtype=str, ndmin=1)
        print(self.event_labels)
        self.event_dict = {str(label): i for i, label in enumerate(self.event_labels)}

        epochs_filtered, events = self.raw_to_epochs(
            self.filtered, event_file, self.event_dict, tmin=0.0, tmax=2.0
        )
        self.epochs_filtered = epochs_filtered
        self.events = events
        if plot_epochs:
            print("Plotting epochs")
            self.epochs_filtered.plot(scalings=dict(eeg='1e3', emg='1e5'), events=events)

        for event_name, event_id in self.event_dict.items():
            self.add_data_and_labels(self.epochs_filtered, event_name, event_id)

        self.dataset_x = np.array(self.dataset_x)
        self.dataset_y = np.array(self.dataset_y)

        for event in self.event_labels:
            self.plot_tfr(self.epochs_filtered, event)

    def run(self, plot_epochs=False):
        """Run the pipeline on the files given to the constructor."""
        self.run_study(self.csv_file, self.event_file, plot_epochs)

# Build the study from its recording, events, and event-label files, run the
# whole pipeline, then show the PCA projection of the extracted features.
study = Study(
    csv_file="data/study5_data.csv",
    event_file="events/study5.txt",
    event_labels_file="events/study5_event_labels.txt",
)
study.run(plot_epochs=False)
study.get_pca()






TypeError: Study.__init__() got an unexpected keyword argument 'setup_type'