In [8]:
from glob import glob
import os
import mne
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from joblib import load

In [9]:
# Load the trained classifier. Despite the ``.h5`` extension, this file is read
# with ``joblib.load`` (a pickle-based format), not as HDF5.
# SECURITY: joblib/pickle loading can execute arbitrary code -- only load model
# files that come from a trusted source.
MODEL_PATH = 'trained_model.h5'
model = load(MODEL_PATH)

In [10]:
import mne
import numpy as np
from scipy import stats

def compute_epoch_features(epoch_data):
    """Compute the 13 summary statistics used as classifier features.

    Each statistic is reduced over the last (time) axis, and the results are
    concatenated along the last axis in this fixed order (the column layout
    must stay identical to what the trained model was fitted on):

        mean, std, peak-to-peak, variance, min, max, argmin, argmax,
        mean square, RMS, sum of absolute first differences, skewness,
        kurtosis (``scipy.stats.kurtosis`` default, i.e. Fisher/excess)

    Parameters
    ----------
    epoch_data : array_like
        At least 2-D samples with time on the last axis, e.g. the
        ``(n_epochs, n_channels, n_times)`` array from ``Epochs.get_data()``.

    Returns
    -------
    numpy.ndarray
        For 3-D input, shape ``(n_epochs, 13 * n_channels)``; columns are
        grouped by statistic, then by channel.
    """
    x = np.asarray(epoch_data)
    # Computed once and reused for RMS below.
    mean_square = np.mean(x ** 2, axis=-1)
    return np.concatenate((
        np.mean(x, axis=-1),
        np.std(x, axis=-1),
        np.ptp(x, axis=-1),
        np.var(x, axis=-1),
        np.min(x, axis=-1),
        np.max(x, axis=-1),
        np.argmin(x, axis=-1),
        np.argmax(x, axis=-1),
        mean_square,
        np.sqrt(mean_square),
        np.sum(np.abs(np.diff(x, axis=-1)), axis=-1),
        stats.skew(x, axis=-1),
        stats.kurtosis(x, axis=-1),
    ), axis=-1)


def extract_features(file_path, l_freq=0.5, h_freq=45, duration=5, overlap=1,
                     verbose=None):
    """Load an EDF recording and return one feature vector per epoch.

    Pipeline: read EDF (preloaded) -> EEG re-reference -> band-pass filter ->
    fixed-length epochs -> ``compute_epoch_features``.

    Parameters
    ----------
    file_path : path-like
        EDF file, passed to ``mne.io.read_raw_edf``.
    l_freq, h_freq : float
        Band-pass edges in Hz (defaults: 0.5-45 Hz).
    duration : float
        Epoch length in seconds.
    overlap : float
        Overlap between consecutive epochs in seconds.
    verbose : bool | str | int | None
        Forwarded to every MNE call. ``None`` keeps MNE's default logging;
        ``False`` or ``'ERROR'`` suppresses the long filter/epoching logs.

    Returns
    -------
    numpy.ndarray
        ``(n_epochs, 13 * n_channels)`` feature matrix, one row per epoch.
    """
    raw = mne.io.read_raw_edf(file_path, preload=True, verbose=verbose)
    # Default reference is the average reference.
    raw.set_eeg_reference(verbose=verbose)
    raw.filter(l_freq=l_freq, h_freq=h_freq, verbose=verbose)
    epochs = mne.make_fixed_length_epochs(
        raw, duration=duration, overlap=overlap, verbose=verbose
    )
    return compute_epoch_features(epochs.get_data())

def label_from_epoch_predictions(predictions):
    """Collapse per-epoch class predictions into one recording-level label.

    Class ``0`` maps to 'Healthy' and any other class to 'Schizophrenia'.
    The recording is labelled 'Schizophrenia' only when strictly more than
    half of its epochs are predicted non-zero, so a tie resolves to 'Healthy'.

    Parameters
    ----------
    predictions : array_like
        Per-epoch predicted classes (e.g. the output of ``estimator.predict``).

    Returns
    -------
    str
        'Healthy' or 'Schizophrenia'.

    Raises
    ------
    ValueError
        If ``predictions`` is empty (e.g. the recording yielded no epochs).
    """
    preds = np.asarray(predictions).ravel()
    if preds.size == 0:
        raise ValueError("No epoch predictions to aggregate.")
    n_positive = np.count_nonzero(preds != 0)
    return 'Schizophrenia' if 2 * n_positive > preds.size else 'Healthy'


def predict_mental_disorder_from_eeg(file_path):
    """Predict 'Healthy' / 'Schizophrenia' for a single EDF recording.

    Every fixed-length epoch is classified, and the recording-level label is
    the majority vote across all epochs. Voting over the whole recording means
    the decision no longer depends on just its first epoch.

    Parameters
    ----------
    file_path : path-like
        EDF file, forwarded to ``extract_features``.

    Returns
    -------
    str
        'Healthy' or 'Schizophrenia'.
    """
    features = extract_features(file_path)
    # ``model`` is the globally loaded classifier. Use the GridSearchCV best
    # estimator when present; otherwise assume ``model`` is itself a fitted
    # estimator.
    estimator = getattr(model, 'best_estimator_', model)
    predictions = estimator.predict(features)
    return label_from_epoch_predictions(predictions)

def final_pred(input):
    """Run the EEG classifier on one EDF file and print its label.

    ``input`` (which shadows the builtin of the same name) is handed straight
    to ``predict_mental_disorder_from_eeg``. The label is printed, not
    returned, so this function returns ``None``.
    """
    print("Predicted label:", predict_mental_disorder_from_eeg(input))


# Example usage:
# predicted_label = predict_mental_disorder_from_eeg(r'data\s06.edf')
# print("Predicted label:", predicted_label)



In [11]:
# Build the path portably instead of hard-coding a Windows backslash separator.
final_pred(os.path.join('data', 's09.edf'))

Extracting EDF parameters from C:\Users\Viraj Wadke\Desktop\Projects\minie_project_eeg\data\s09.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 296249  =      0.000 ...  1184.996 secs...
EEG channel type selected for re-referencing
Applying average reference.
Applying a custom ('EEG',) reference.
Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 0.5 - 45 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 0.50
- Lower transition bandwidth: 0.50 Hz (-6 dB cutoff frequency: 0.25 Hz)
- Upper passband edge: 45.00 Hz
- Upper transition bandwidth: 11.25 Hz (-6 dB cutoff frequency: 50.62 Hz)
- Filter length: 1651 samples (6.604 s)



[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=1)]: Done   1 out of   1 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   2 out of   2 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   3 out of   3 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   4 out of   4 | elapsed:    0.0s remaining:    0.0s


Not setting metadata
296 matching events found
No baseline correction applied
0 projection items activated
Using data from preloaded Raw for 296 events and 1250 original time points ...
0 bad epochs dropped


[Parallel(n_jobs=1)]: Done  19 out of  19 | elapsed:    0.1s finished


Predicted label: Schizophrenia


In [14]:
import streamlit as st

def _save_upload_to_temp_file(uploaded_file):
    """Write an uploaded file's bytes to a new temporary ``.edf`` file.

    The upload lives in memory, but ``extract_features`` hands its argument to
    ``mne.io.read_raw_edf``, which is given a filesystem path here because not
    every MNE release accepts file-like objects. ``delete=False`` is used so
    the file can be reopened by name (Windows forbids that while a
    ``NamedTemporaryFile`` is still open); the caller must remove it.

    Returns
    -------
    str
        Path of the temporary file.
    """
    import tempfile

    with tempfile.NamedTemporaryFile(suffix='.edf', delete=False) as tmp:
        # Streamlit's UploadedFile is a BytesIO subclass, so getvalue()
        # returns the full payload regardless of the current read position.
        tmp.write(uploaded_file.getvalue())
        return tmp.name


def main():
    """Streamlit UI: upload an EDF file and display the predicted label.

    NOTE: inside a Jupyter kernel ``__name__ == "__main__"`` is true, so the
    guard below calls ``main()`` without a Streamlit runtime and Streamlit
    only prints a warning suggesting ``streamlit run``. Export this code to a
    ``.py`` script and launch it with ``streamlit run`` to use the app.
    """
    st.title("Mental Disorder Prediction from EEG Data")

    # File uploader for EDF files
    st.sidebar.header("Upload EDF File")
    uploaded_file = st.sidebar.file_uploader("Upload EDF file", type=["edf"])

    if uploaded_file is not None:
        st.sidebar.text("File uploaded successfully!")

        # Display filename
        st.sidebar.write('Filename:', uploaded_file.name)

        # Display file details
        file_details = {"Filename": uploaded_file.name, "Filesize": uploaded_file.size}
        st.sidebar.write(file_details)

        # Button to trigger prediction
        if st.sidebar.button("Predict"):
            tmp_path = _save_upload_to_temp_file(uploaded_file)
            try:
                prediction = predict_mental_disorder_from_eeg(tmp_path)
            finally:
                # Always remove the temporary copy, even if prediction fails.
                os.remove(tmp_path)

            # Display prediction result
            st.write("Predicted Label:", prediction)

if __name__ == "__main__":
    main()


2024-04-04 00:16:03.905 
  command:

    streamlit run C:\Users\Viraj\anaconda3\Lib\site-packages\ipykernel_launcher.py [ARGUMENTS]


In [13]:
!pip install streamlit

Collecting streamlit
  Obtaining dependency information for streamlit from https://files.pythonhosted.org/packages/9b/ea/7219c01b5e92d02d2bc994a36245d99331cd66eb12d284707a2060a013d0/streamlit-1.32.2-py2.py3-none-any.whl.metadata
  Downloading streamlit-1.32.2-py2.py3-none-any.whl.metadata (8.5 kB)
Collecting altair<6,>=4.0 (from streamlit)
  Obtaining dependency information for altair<6,>=4.0 from https://files.pythonhosted.org/packages/46/30/2118537233fa72c1d91a81f5908a7e843a6601ccc68b76838ebc4951505f/altair-5.3.0-py3-none-any.whl.metadata
  Downloading altair-5.3.0-py3-none-any.whl.metadata (9.2 kB)
Collecting blinker<2,>=1.0.0 (from streamlit)
  Obtaining dependency information for blinker<2,>=1.0.0 from https://files.pythonhosted.org/packages/fa/2a/7f3714cbc6356a0efec525ce7a0613d581072ed6eb53eb7b9754f33db807/blinker-1.7.0-py3-none-any.whl.metadata
  Downloading blinker-1.7.0-py3-none-any.whl.metadata (1.9 kB)
Collecting toml<2,>=0.10.1 (from streamlit)
  Obtaining dependency inform