# Visualize EEG data with ME and PC1 method

##### Import library

In [65]:
import os
import numpy as np
from scipy.io import loadmat
import scipy.io as sio
import matplotlib.pyplot as plt
from scipy.fft import fft
import mne
import os
from sklearn.decomposition import PCA


##### Define Initial param

In [6]:
# Define parameters
in_dir = r'D:\Imputed_file'    # EEG data folder
miscDir = r'C:\Users\napat\Documents\GitHub\BCC_2024\Code\Misc'     # Misc folder (if necessary)
fs = 125  # Sampling frequency (Hz)
sec_start = 15  # Start of the epoch in seconds
epoch_len_sec = 4 * 60  # Length of the epoch (4 minutes)
epoch_len_samp = fs * epoch_len_sec  # Number of samples in the epoch

In [4]:
# Function to apply median DC correction
def median_dccorrect_all_trials(data):
    return data - np.median(data, axis=1, keepdims=True)

##### Load song into python

In [5]:

trial_mean_data = []
# Loop through the 10 songs
for i in range(1, 2):
    song_idx = i + 20
    curr_fn = os.path.join(in_dir, f'song{song_idx}_Imputed.mat')
    print(f'Loading {curr_fn}...')
    
    # Load the MAT file
    mat_data = loadmat(curr_fn)
    
    # Extract the EEG data for the current song (replace 'dataXX' with the proper key)
    data_key = f'data{song_idx}'
    tempX_0 = mat_data[data_key]
    
    # Apply DC correction
    tempX_dc = median_dccorrect_all_trials(tempX_0)
    
    # Extract the epoch (chan x T x participant)
    tempX_epoch = tempX_dc[:, sec_start * fs : sec_start * fs + epoch_len_samp, :]
    
    # Take the mean of the data across participants (axis 2)
    trial_mean = np.mean(tempX_epoch, axis=2)
    
    # Apply another median DC correction on the mean data across participants
    trial_mean_corrected = median_dccorrect_all_trials(trial_mean)
    
    # Append or store the result in the trial_mean_data array (or list)
    trial_mean_data.append(trial_mean_corrected)

# Convert the list of trial means to a NumPy array
trial_mean_data = np.array(trial_mean_data)  # Shape: (10 songs, channels, time)

Loading D:\Imputed_file\song21_Imputed.mat...
Loading D:\Imputed_file\song22_Imputed.mat...
Loading D:\Imputed_file\song23_Imputed.mat...
Loading D:\Imputed_file\song24_Imputed.mat...
Loading D:\Imputed_file\song25_Imputed.mat...
Loading D:\Imputed_file\song26_Imputed.mat...
Loading D:\Imputed_file\song27_Imputed.mat...
Loading D:\Imputed_file\song28_Imputed.mat...
Loading D:\Imputed_file\song29_Imputed.mat...
Loading D:\Imputed_file\song30_Imputed.mat...


##### Load the tempos

In [29]:
# Load the tempos
tempo_data = sio.loadmat(os.path.join(miscDir, 'tempoHz.mat'))
tempoHz = tempo_data['tempoHz'].flatten()

# Creating a tempo matrix
tempoMatrix = np.tile(tempoHz, (6, 1)) * np.tile(2.0 ** np.arange(-2, 4).reshape(-1, 1), (1, 10))

# Get the number of channels, time points, and songs
nChan, T, nSongs = trial_mean_data.shape
xax = np.arange(T) / (T / fs)  # x-axis for plotting
xl = [0, 15]  # xlim for plotting

In [13]:
# Analysis 1: Uniform spatial filter (mean of all channels)
channel_means_per_song = np.mean(trial_mean_data, axis=0)  # T x song

# Analysis 2: PCA on concatenated data
# Concatenate the trialMeanData across all songs (axis=2) -> (channels x time points)
concat_trial_means = np.concatenate([trial_mean_data[:, :, i] for i in range(nSongs)], axis=1)

# Perform PCA using SVD
U, S, Vt = np.linalg.svd(concat_trial_means, full_matrices=False)
# Extract the first principal component (PC1) by multiplying the first singular vector
concat_pc1 = np.dot(U.T[0], concat_trial_means)

# Reshape PC1 into (time x songs) to match the song structure
pc1 = concat_pc1.reshape(-1, 10)

# Compute FFT for each song's channel mean and PC1
CH = np.abs(np.fft.fft(channel_means_per_song, axis=0))  # FFT of channel means per song
PC1 = np.abs(np.fft.fft(pc1, axis=0))  # FFT of PC1

In [71]:
# Function to plot frequency domain for each channel (FFT of means)
def plot_mean_channels_frequency_domain(data, fs, song_idx=1):
    T = data.shape[1]
    freq_axis = np.fft.fftfreq(T, d=1/fs)[:T // 2]
    mean_channels = np.mean(data, axis=2)  # Mean across trials
    fft_data = np.abs(fft(mean_channels, axis=1))[:, :T // 2]

    # Filter to 0-15 Hz range
    valid_freqs = freq_axis <= 15
    freq_axis = freq_axis[valid_freqs]
    fft_data = fft_data[:, valid_freqs]
    plt.figure(figsize=(10, 6))
    for ch in range(fft_data.shape[0]):
        plt.plot(freq_axis, fft_data[ch, :],color='black', linewidth=5, zorder=1)
        
    plt.axvline(x=tempoHz[0] / 4, color='blue', linestyle='--', linewidth=2, label='1/4x Tempo', zorder=0)
    plt.axvline(x=tempoHz[0] / 2, color='orange', linestyle='--', linewidth=2, label='1/2x Tempo', zorder=0)
    plt.axvline(x=tempoHz[0], color='green', linestyle='--', linewidth=2, label='Tempo', zorder=0)
    plt.axvline(x=tempoHz[0] * 2, color='red', linestyle='--', linewidth=2, label='2x Tempo', zorder=0)
    plt.axvline(x=tempoHz[0] * 4, color='purple', linestyle='--', linewidth=2, label='4x Tempo', zorder=0)
    plt.axvline(x=tempoHz[0] * 8, color='brown', linestyle='--', linewidth=2, label='8x Tempo', zorder=0)
        
    plt.title(f'Mean Frequency-Domain Signal (0-15 Hz) for All Channels (Song {song_idx})')
    plt.xlabel('Frequency (Hz)')
    plt.ylabel('Magnitude')
    plt.legend(loc='upper right')
    plt.grid(True)
    plt.show()



# Function to plot PC1 (Frequency Domain)
def plot_pc1_frequency_domain(data, fs, song_idx=1):
    pca = PCA(n_components=1)
    T = data.shape[1]
    freq_axis = np.fft.fftfreq(T, d=1/fs)[:T // 2]
    pca_data = pca.fit_transform(data.reshape(data.shape[0], -1).T).T[0]  # First PC (PC1)
    fft_pca = np.abs(fft(pca_data))[:T // 2]

    # Filter to 0-15 Hz range
    valid_freqs = freq_axis <= 15
    freq_axis = freq_axis[valid_freqs]
    fft_pca = fft_pca[valid_freqs]

    plt.figure(figsize=(10, 6))
    plt.plot(freq_axis, fft_pca, color='black', linewidth=5, zorder=1)
    
    plt.axvline(x=tempoHz[0] / 4, color='blue', linestyle='--', linewidth=2, label='1/4x Tempo', zorder=0)
    plt.axvline(x=tempoHz[0] / 2, color='orange', linestyle='--', linewidth=2, label='1/2x Tempo', zorder=0)
    plt.axvline(x=tempoHz[0], color='green', linestyle='--', linewidth=2, label='Tempo', zorder=0)
    plt.axvline(x=tempoHz[0] * 2, color='red', linestyle='--', linewidth=2, label='2x Tempo', zorder=0)
    plt.axvline(x=tempoHz[0] * 4, color='purple', linestyle='--', linewidth=2, label='4x Tempo', zorder=0)
    plt.axvline(x=tempoHz[0] * 8, color='brown', linestyle='--', linewidth=2, label='8x Tempo', zorder=0)

    plt.title(f'PC1 Frequency-Domain Signal (0-15 Hz) (Song {song_idx})')
    plt.xlabel('Frequency (Hz)')
    plt.ylabel('Magnitude')
    plt.grid(True)
    plt.show()

In [76]:
song_idx = 21
data_key = f'data{song_idx}'
tempX_0 = mat_data[data_key]

tempX_dc = median_dccorrect_all_trials(tempX_0)
tempX_epoch = tempX_dc[:, sec_start * fs:(sec_start * fs + epoch_len_samp), :]

plot_mean_channels_frequency_domain(tempX_epoch, fs, song_idx=song_idx)
plot_pc1_frequency_domain(tempX_epoch, fs, song_idx=song_idx)

KeyError: 'data21'