# In [None]:
import mne
import numpy as np
from joblib import load 
# --- ICA artifact removal ----------------------------------------------------
# Placeholder for the recording delivered by the acquisition SDK.
# NOTE(review): as plain Python, `{live_eeg_data}` is a set literal (NameError
# unless the name already exists), so presumably it is a template token to be
# substituted. The code below expects an MNE Raw-like object that exposes
# `.info['sfreq']` — confirm the SDK provides one.
live_eeg_data ={live_eeg_data}#placeholder for live_eeg_data that will come from sdk
# Frontal EEG channels used as stand-ins for dedicated EOG electrodes when
# searching for eye-movement components.
eog_channels = ['F7', 'F8', 'AF3', 'AF4']

# Decompose into 12 independent components (fixed random_state for reproducible fits).
# NOTE(review): MNE recommends fitting ICA on high-pass filtered (~1 Hz) data;
# this fits on the data as delivered — confirm it was filtered upstream.
ica = mne.preprocessing.ICA(n_components=12, random_state=97, max_iter=800)
ica.fit(live_eeg_data)
# Components that find_bads_eog associates with the frontal channels are
# marked for exclusion.
eog_indices, _ = ica.find_bads_eog(live_eeg_data, ch_name=eog_channels)
ica.exclude = eog_indices
# Reconstruct the signal without the excluded components. Per the MNE docs,
# apply() cleans its input in place and returns it, so `processed_live_data`
# and `live_eeg_data` refer to the same object afterwards.
processed_live_data = ica.apply(live_eeg_data)

# --- Epoching: slice the cleaned recording into fixed-length, overlapping windows.
epoch_duration = 1  # window length in seconds
overlap = 0.5  # fraction of each window shared with the next window

sfreq = processed_live_data.info['sfreq']
# Hop between consecutive window onsets, in seconds (0.5 s for the values above).
duration = epoch_duration - (epoch_duration * overlap)
# make_fixed_length_events takes `duration` as the window length and `overlap`
# in seconds, and requires 0 <= overlap < duration. Passing the hop (0.5 s) as
# the duration together with a 0.5 s overlap is rejected by MNE. Pass the full
# window length plus the overlap converted to seconds instead; the resulting
# events are spaced `duration` seconds apart.
events = mne.make_fixed_length_events(
    processed_live_data,
    duration=epoch_duration,
    start=0,
    stop=None,
    overlap=epoch_duration * overlap,
)
left_epochs = mne.Epochs(
    processed_live_data,
    events=events,
    event_id=None,  # keep every event created above
    tmin=0,
    # NOTE(review): MNE includes the tmax sample, so each epoch spans
    # epoch_duration seconds plus one sample (tmax = epoch_duration - 1 / sfreq
    # would give exactly epoch_duration). Left unchanged so the features match
    # what the saved model was trained on — confirm against the training code.
    tmax=epoch_duration,
    baseline=None,
    preload=True,
)


# EEG frequency bands used for feature extraction, in Hz, as (fmin, fmax)
# pairs; both edges are treated as inclusive by the band masks. Insertion order
# (Delta -> Gamma) is the order in which bands are iterated.
freq_bands = dict(
    Delta=(0.5, 4),
    Theta=(4, 8),
    Alpha=(8, 13),
    Beta=(13, 30),
    Gamma=(30, 40),
)
def compute_avg_band_amplitudes(epochs, sfreq=None, bands=None):
    """Return one averaged spectral-amplitude feature per epoch.

    For every epoch, the FFT magnitude of each channel is computed once (DC bin
    excluded). For each frequency band, the mean magnitude of the bins with
    ``fmin <= f <= fmax`` is taken. The epoch's feature is the mean of those
    band amplitudes over all channels and all bands.

    Parameters
    ----------
    epochs : mne.Epochs or array-like
        Iterable of epochs, each of shape (n_channels, n_times); a 1-D epoch is
        treated as a single channel. An ``mne.Epochs`` object works directly.
    sfreq : float, optional
        Sampling frequency in Hz. Defaults to ``epochs.info['sfreq']``, so it
        must be supplied when ``epochs`` is a plain array.
    bands : dict of str -> (float, float), optional
        Band name -> (fmin, fmax) in Hz. Defaults to the module-level
        ``freq_bands``.

    Returns
    -------
    list of float
        One value per epoch. A band containing no FFT bin (for example, one
        above the Nyquist frequency) contributes NaN, which propagates to that
        epoch's value.
    """
    if sfreq is None:
        # Read the rate from the epochs themselves rather than from a global
        # recording, so the function is self-contained.
        sfreq = epochs.info['sfreq']
    if bands is None:
        bands = freq_bands

    avg_band_amplitudes = []
    for epoch in epochs:
        data = np.atleast_2d(np.asarray(epoch))  # (n_channels, n_times)
        n_times = data.shape[-1]
        # One real FFT per epoch, covering all channels at once, instead of a
        # full FFT per channel per band. Every positive-frequency bin up to
        # Nyquist is kept, minus DC. The previous version kept only the first
        # ceil(n_times / 4) bins, i.e. frequencies below ~sfreq / 4. That
        # silently truncated or emptied high bands, such as the 30-40 Hz Gamma
        # band whenever sfreq < 160 Hz.
        freqs = np.fft.rfftfreq(n_times, d=1.0 / sfreq)[1:]
        magnitudes = np.abs(np.fft.rfft(data, axis=-1))[:, 1:]
        band_amplitudes = []
        for fmin, fmax in bands.values():
            in_band = (freqs >= fmin) & (freqs <= fmax)
            if in_band.any():
                band_amplitudes.append(magnitudes[:, in_band].mean(axis=1))
            else:
                # No bin falls inside this band: mark it missing explicitly
                # instead of relying on np.mean([]) (also NaN, but warns).
                band_amplitudes.append(np.full(data.shape[0], np.nan))
        # NOTE(review): averaging across channels *and* bands collapses the
        # spectrum to a single scalar per epoch. This is kept because the saved
        # classifier presumably expects exactly this one feature — confirm.
        avg_band_amplitudes.append(np.mean(band_amplitudes))
    return avg_band_amplitudes
# --- Classification ----------------------------------------------------------
# One averaged band-amplitude value per epoch.
live_avg_band_amplitudes = compute_avg_band_amplitudes(left_epochs)
X = np.array(live_avg_band_amplitudes)
# NOTE(review): joblib.load unpickles the file and can execute arbitrary code;
# only load model files from a trusted location. The relative path is resolved
# against the current working directory.
lda_classifier = load('lda_model.joblib')
# Shape the features as (n_epochs, n_features). With one scalar per epoch this
# is (n_epochs, 1) — presumably what the model was trained on; confirm.
X_live = X.reshape(X.shape[0], -1)
predictions_live = lda_classifier.predict(X_live)
print("Live Predictions:", predictions_live)