# In [None]:
import os
import numpy as np
import mne
import matplotlib.pyplot as plt

# Generate per-band EEG topomap images from eight single-channel recordings.
#
# Workflow: prompt for one NumPy file per channel, keep the first half of the
# (shortest) recording, cut it into fixed-size segments, and for every segment
# and frequency band save a topomap PNG of the mean Welch PSD per channel.
# Images are written to '<experience>_<state>_<person>_<band>/' directories
# relative to the current working directory.

mne.set_log_level('warning')

# One file per channel, requested in this fixed order.  The order matters:
# row k of every segment is plotted at the position of channels_of_interest[k].
channels_of_interest = ['Fp1', 'Fp2', 'C3', 'C4', 'P3', 'P4', 'O1', 'O2']
channel_files = [
    input(f"Enter the file path for channel {ch}: ")
    for ch in channels_of_interest
]
# NOTE(review): assumes each file holds a 1-D array of samples - confirm.
channel_data = [np.load(file) for file in channel_files]

# Recordings may differ in length; everything below is sized from the shortest.
min_length = min(len(ch) for ch in channel_data)

# Use only the first half of the array.
# (To process the second half instead, slice with ch[half_length:].)
half_length = min_length // 2
channel_data = [ch[:half_length] for ch in channel_data]

segment_size = 250  # samples per topomap image
fs = 500  # sampling frequency handed to MNE; 250 samples at 500 Hz = 0.5 s

# Channel metadata with standard 10-20 electrode positions, needed by
# plot_topomap to place each channel's value on the scalp.
info = mne.create_info(channels_of_interest, sfreq=fs, ch_types='eeg')
montage = mne.channels.make_standard_montage('standard_1020')
info.set_montage(montage)

# These three answers are used verbatim in directory and file names.
experience_level = input("Enter experience level (E for Experienced, N for Novice): ")
state = input("Enter state (R for Rest, M for Meditation): ")
person_id = input("Enter person ID: ")

# band key -> (display name, lower edge in Hz, upper edge in Hz)
bands = {
    'D': ('Delta', 0.1, 3.99),
    'T': ('Theta', 4, 7.99),
    'A': ('Alpha', 8, 13),
    'B': ('Beta', 14, 30),
    'G': ('Gamma', 30.01, 80)
}

# Only complete segments are processed; a trailing partial segment is dropped.
num_segments = half_length // segment_size

# Create every band directory up front (also when there are no segments).
for band_key in bands:
    os.makedirs(f'{experience_level}_{state}_{person_id}_{band_key}', exist_ok=True)

for i in range(num_segments):
    # Build and clean the (n_channels, segment_size) array once per segment;
    # it does not depend on the band.
    start = i * segment_size
    segment = np.array([ch[start:start + segment_size] for ch in channel_data])
    if np.isnan(segment).any():
        # Only triggered by NaNs, but nan_to_num then also replaces +/-inf.
        segment = np.nan_to_num(segment)

    for band_key, (band_name, fmin, fmax) in bands.items():
        output_dir = f'{experience_level}_{state}_{person_id}_{band_key}'

        # NOTE(review): with n_fft=128 the frequency resolution is
        # fs / 128 ~= 3.9 Hz, and with MNE's default windowing a 250-sample
        # segment likely contributes only its first 128 samples - confirm
        # this is intended before changing it (it would alter the images).
        psds, freqs = mne.time_frequency.psd_array_welch(
            segment, sfreq=fs, fmin=fmin, fmax=fmax, n_fft=128
        )
        # One value per channel: mean PSD over the band's frequency bins.
        band_power = np.mean(psds, axis=1)

        fig, ax = plt.subplots(figsize=(5, 5))
        try:
            # Bare colour map only: no head outline, contours or sensor dots.
            mne.viz.plot_topomap(
                band_power, info, axes=ax, show=False, cmap='viridis',
                outlines=None, contours=0, sensors=False
            )
            ax.set_axis_off()
            # Segment numbers in file names are 1-based.
            filename = f"{experience_level}_{state}_{person_id}_{i+1}_{band_key}.png"
            image_path = os.path.join(output_dir, filename)
            fig.savefig(image_path, bbox_inches='tight', pad_inches=0)
        finally:
            # Always release the figure, even if plotting or saving fails.
            plt.close(fig)


print("All topomap images saved.")
