# COGS 189 Group 12 
*"Investigating the Impact of Short-Form Video Content on Cognitive Focus Using EEG"*


In [None]:
import numpy as np
import mne
import matplotlib.pyplot as plt

# --------------------------------------------------------
# 1) YOUR CUSTOM CHANNEL NAMES (excluding GND and REF)
# --------------------------------------------------------
# Electrode labels as used by the headset, in the same order as the rows
# of the recorded data array.
channel_names_custom = ["F4", "O1", "O2", "T5", "P3", "Pz", "P4", "T6"]

# T5/T6 are the old 10-20 labels for the posterior temporal sites. In the
# modern 10-10 nomenclature used by MNE's montages they are P7/P8 -- NOT
# T7/T8, which are the modern names for the mid-temporal sites T3/T4.
# Mapping T5->T7 / T6->T8 would misplace these electrodes on the topomap.
rename_dict = {"T5": "P7", "T6": "P8"}
channel_names_mne = [rename_dict.get(ch, ch) for ch in channel_names_custom]

# All channels are EEG
channel_types = ["eeg"] * len(channel_names_mne)

# Example sampling rate (adjust to your actual data)
sfreq = 250.0

# Create MNE info
info = mne.create_info(ch_names=channel_names_mne, sfreq=sfreq, ch_types=channel_types)

# Attach a standard 10-20 montage so every channel gets a scalp position,
# which plot_topomap needs later.
montage = mne.channels.make_standard_montage("standard_1020")
info.set_montage(montage)

# --------------------------------------------------------
# 2) LOAD YOUR EEG DATA (shape = [n_channels, n_times])
#    Suppose it's from a single stage (e.g., Baseline)
# --------------------------------------------------------
data = np.load("my_baseline_data.npy")

# RawArray needs (n_channels, n_times). Accept a (n_times, n_channels)
# recording as well by transposing it, and fail early with a clear message
# on any other shape instead of letting RawArray raise a vaguer error.
n_channels = len(info["ch_names"])
if data.ndim != 2:
    raise ValueError(f"Expected 2-D EEG data, got shape {data.shape}")
if data.shape[0] != n_channels and data.shape[1] == n_channels:
    data = data.T
if data.shape[0] != n_channels:
    raise ValueError(
        f"EEG data has shape {data.shape}; expected {n_channels} channels "
        f"along one axis ({info['ch_names']})"
    )

# --------------------------------------------------------
# 3) Create a RawArray. If needed, do additional preprocessing.
# --------------------------------------------------------
# NOTE(review): MNE expects EEG values in volts. If the .npy file stores
# microvolts, scale by 1e-6 before this call -- confirm the recording units.
raw = mne.io.RawArray(data, info)

# Optional: e.g. filter 1-30 Hz
# raw.filter(1, 30)

# --------------------------------------------------------
# 4) Compute a measure for topoplot (e.g., average power in 1-30 Hz)
# --------------------------------------------------------
# mne.time_frequency.psd_welch was deprecated in MNE 1.2 and removed in
# later releases; Raw.compute_psd is its replacement. Only fall back to the
# old function on MNE versions that predate compute_psd.
if hasattr(raw, "compute_psd"):
    spectrum = raw.compute_psd(method="welch", fmin=1, fmax=30)
    psds, freqs = spectrum.get_data(return_freqs=True)
else:
    psds, freqs = mne.time_frequency.psd_welch(raw, fmin=1, fmax=30)
# psds is (n_channels, n_freqs); average across frequencies => shape (8,)
mean_psd = psds.mean(axis=1)

# --------------------------------------------------------
# 5) Plot topomap
# --------------------------------------------------------
fig, ax = plt.subplots(figsize=(5, 4))
im, _ = mne.viz.plot_topomap(mean_psd, raw.info, axes=ax, show=False)
# Without a colorbar the colour map has no scale and maps cannot be compared.
fig.colorbar(im, ax=ax, label="Mean PSD (1-30 Hz)")
ax.set_title("Baseline Power (1-30 Hz)")
plt.show()
