In [44]:
import pandas as pd
import glob
import matplotlib.pyplot as plt

In [48]:
# Every recorded measure has one column per recording site, named
# "<Measure>_<Site>". The order below (measure-major, then site) is the column
# order used when selecting features from each CSV.
columns = [
    f"{measure}_{site}"
    for measure in ("Delta", "Theta", "Alpha", "Beta", "Gamma", "RAW", "HSI")
    for site in ("TP9", "AF7", "AF8", "TP10")
]

# One figure per measure, listing that measure's per-site columns.
# HSI is deliberately listed before RAW so figures appear in that order.
channel_groups = {
    measure: [name for name in columns if name.startswith(f"{measure}_")]
    for measure in ("Delta", "Theta", "Alpha", "Beta", "Gamma", "HSI", "RAW")
}

# One figure per single column: each column maps to a one-element group.
channel_groups_individual = {name: [name] for name in columns}

In [None]:
def load_df(path, feature_columns=None):
    """Load one recording CSV and add a per-recording relative time column.

    Parameters
    ----------
    path : str, path-like or file-like
        Anything accepted by ``pd.read_csv``. The CSV must contain a
        ``TimeStamp`` column plus every name in ``feature_columns``.
    feature_columns : sequence of str, optional
        Columns to keep, in this order. Defaults to the notebook-level
        ``columns`` list.

    Returns
    -------
    pandas.DataFrame
        ``feature_columns`` followed by ``NormalizedTimeStamp``: the time
        elapsed since the earliest parseable ``TimeStamp`` in this file.
        Rows whose ``TimeStamp`` cannot be parsed are kept, with ``NaT``.

    Raises
    ------
    KeyError
        If ``TimeStamp`` or any requested feature column is missing.
    """
    if feature_columns is None:
        feature_columns = columns

    raw = pd.read_csv(path)
    # errors="coerce": unparseable timestamps become NaT instead of raising.
    timestamps = pd.to_datetime(raw["TimeStamp"], errors="coerce")
    # Subtract the file's own start so that different recordings share the
    # same relative time axis.
    normalized = timestamps - timestamps.min()
    # Select first, then assign: this yields a fresh frame rather than a
    # slice of ``raw``, so later mutation cannot trigger SettingWithCopyWarning.
    return raw[list(feature_columns)].assign(NormalizedTimeStamp=normalized)


# Load every "Read" recording and summarise all of them at each relative time
# offset (time since that recording's own first sample).
EEG_GLOB = "../ufal_emmt/preprocessed-data/eeg/Read/*.csv"

# glob() returns files in filesystem-dependent order; sort for reproducibility.
eeg_paths = sorted(glob.glob(EEG_GLOB))
if not eeg_paths:
    # Without this guard pd.concat([]) fails with an unhelpful
    # "No objects to concatenate" error.
    raise FileNotFoundError(f"No EEG CSV files matched {EEG_GLOB!r}")

dfs = [load_df(path) for path in eeg_paths]

# Concatenate and group once, then derive every statistic from that groupby.
# NOTE(review): groups are formed on the *exact* NormalizedTimeStamp value, so
# recordings only pool at offsets that match exactly. If sample times are not
# aligned across recordings, consider flooring/resampling the offsets first.
grouped = pd.concat(dfs).groupby("NormalizedTimeStamp")
mean_df = grouped.mean()
median_df = grouped.median()
min_df = grouped.min()
max_df = grouped.max()

In [None]:
# Summary-statistics plot: one figure per channel group.
def plot_channel_groups_with_stats(mean_df, median_df, min_df, max_df, channel_groups):
    """Plot per-timestamp summary statistics for each group of channels.

    One figure is drawn per entry of ``channel_groups``. For every channel of
    that group present in ``mean_df`` it shows:

    * the min-max envelope as a shaded band,
    * the mean as a solid line,
    * the median as a dashed line.

    All three use the same colour, so a channel's traces can be matched up.
    Channels missing from ``mean_df`` are skipped.

    Parameters
    ----------
    mean_df, median_df, min_df, max_df : pandas.DataFrame
        Aggregates sharing one index. ``mean_df.index`` must support
        ``.total_seconds()`` (a timedelta index, e.g. the result of
        ``groupby("NormalizedTimeStamp")``).
    channel_groups : dict[str, list[str]]
        Group name (used in the title) -> channel columns drawn in its figure.
    """
    # All aggregates share the same index, so a single x-axis serves every
    # channel and group; compute it once.
    seconds = mean_df.index.total_seconds()

    for group, channels in channel_groups.items():
        fig, ax = plt.subplots(figsize=(15, 8))

        for channel in channels:
            if channel not in mean_df.columns:
                continue

            # Draw the mean first and reuse its colour for the envelope and
            # median. Otherwise line and fill colours come from separately
            # advancing cycles and stop matching once a group has >1 channel.
            (mean_line,) = ax.plot(
                seconds, mean_df[channel], label=f"{channel} mean", linewidth=2
            )
            color = mean_line.get_color()

            ax.fill_between(
                seconds,
                min_df[channel],
                max_df[channel],
                color=color,
                alpha=0.2,
                label=f"{channel} range",
            )
            ax.plot(
                seconds,
                median_df[channel],
                color=color,
                label=f"{channel} median",
                linestyle="--",
                linewidth=1.5,
            )

        ax.set_xlabel("Time (s)")
        ax.set_ylabel("Signal Value")
        ax.set_title(f"{group} Channels Statistics Over Time")
        ax.grid(True, alpha=0.3)

        # De-duplicate legend entries by label (a channel listed twice would
        # otherwise appear twice); skip the legend when nothing was drawn.
        handles, labels = ax.get_legend_handles_labels()
        if handles:
            unique = dict(zip(labels, handles))
            ax.legend(unique.values(), unique.keys(), loc="best")

        fig.tight_layout()
        plt.show()


def plot_features_vs_time(df):
    """Draw one line plot per numeric column of ``df`` against relative time.

    Parameters
    ----------
    df : pandas.DataFrame
        A single recording (e.g. as returned by ``load_df``). It must contain a
        ``NormalizedTimeStamp`` column. A timedelta-typed time column is
        plotted in seconds; any other dtype is plotted as-is.
    """
    timestamp_col = "NormalizedTimeStamp"

    x = df[timestamp_col]
    xlabel = "Timestamp"
    if x.dtype.kind == "m":
        # Plot timedeltas as float seconds, the same x unit used by
        # plot_channel_groups_with_stats (index.total_seconds()).
        x = x.dt.total_seconds()
        xlabel = "Time (s)"

    # Keep only numeric columns, and never plot the time column against itself
    # (relevant if it is already numeric rather than timedelta).
    feature_columns = df.select_dtypes(include=["number"]).columns.drop(
        timestamp_col, errors="ignore"
    )

    for column in feature_columns:
        fig, ax = plt.subplots(figsize=(12, 5))
        ax.plot(x, df[column], label=column)
        ax.set_xlabel(xlabel)
        ax.set_ylabel(column)
        ax.set_title(f"{column} vs Time")
        ax.legend()
        ax.grid(True)
        plt.show()


def plot_channel_groups(df, channel_groups):
    """Plot the raw per-sample traces of each channel group, one figure per group.

    Parameters
    ----------
    df : pandas.DataFrame
        A single recording containing a ``NormalizedTimeStamp`` column (e.g. as
        returned by ``load_df``). A timedelta-typed time column is converted
        to seconds to match the "Time (s)" axis label.
    channel_groups : dict[str, list[str]]
        Group name (used in the title) -> channel columns drawn in its figure.
        Channels missing from ``df`` are skipped.
    """
    time = df["NormalizedTimeStamp"]
    if time.dtype.kind == "m":
        # The axis is labelled in seconds, so plot seconds rather than raw
        # timedelta64 values.
        time = time.dt.total_seconds()

    for group, channels in channel_groups.items():
        fig, ax = plt.subplots(figsize=(12, 5))
        for channel in channels:
            if channel in df.columns:
                ax.plot(time, df[channel], label=channel)
        ax.set_xlabel("Time (s)")
        ax.set_ylabel("Signal Value")
        ax.set_title(f"{group} Channels vs Time")
        # Only build a legend when at least one channel was drawn.
        if ax.get_legend_handles_labels()[0]:
            ax.legend()
        ax.grid(True)
        plt.show()


# Draw one statistics figure per individual column.
plot_channel_groups_with_stats(
    mean_df=mean_df,
    median_df=median_df,
    min_df=min_df,
    max_df=max_df,
    channel_groups=channel_groups_individual,
)