# Preprocess

## 1. Load data

In [None]:
%matplotlib widget
import os
import numpy as np
import matplotlib.pyplot as plt
import mne

cnt_file = r"D:\bcmi\exp\eeg_cnt_file\yinhao_20210323.cnt"

subj_name = cnt_file[cnt_file.rindex('\\')+1 : cnt_file.rindex('_')]
print(subj_name)

EOG_channels = ['VEO', 'HEO']
unused_channels = ['M1', 'M2']

raw = mne.io.read_raw_cnt(cnt_file, eog=EOG_channels)
raw.info['bads'].extend(EOG_channels)
raw.info['bads'].extend(unused_channels)

In [None]:
# Recording metadata (channels, sampling rate, bad channels, ...).
raw.info

In [None]:
# All channel names (the list includes M1, M2, VEO and HEO).
raw.info['ch_names']

In [None]:
# number of time points (samples) in the recording
raw.n_times

In [None]:
# number of channels
len(raw.ch_names)

In [None]:
# Montage stored with the recording, if any.
# NOTE(review): this may be None when the file carries no electrode positions — confirm.
m = raw.get_montage()
m

## 2. Filtering & Downsample
No AC power-line interference was observed, so no notch filter is applied.

In [None]:
# Filtering and resampling operate on data held in memory.
raw.load_data()

# Band-pass 1-75 Hz, then downsample to 200 Hz (the new Nyquist frequency of
# 100 Hz stays above the 75 Hz cutoff). No events array is passed, so only the
# continuous data itself is resampled here.
raw.filter(l_freq=1, h_freq=75)
raw = raw.resample(200)

## 3. Mark bad channels

In [None]:
# Browse the filtered data (40 s per page, 66 channels per page) to spot bad channels by eye.
raw.plot(duration=40, n_channels=66)

In [None]:
# Channels judged bad by visual inspection of the plot above. They are added to,
# not substituted for, the EOG/unused channels already in raw.info['bads'].
bad_channels = ['C3', 'C1']
# Add only names that are not already listed, so re-running this cell does not
# create duplicate entries in info['bads'].
new_bads = [ch for ch in bad_channels if ch not in raw.info['bads']]
raw.info['bads'].extend(new_bads)
# The plot after band-pass filtering is much better for this inspection:
# raw.plot(duration=40, n_channels=66)

In [None]:
# Check that info['bads'] now also lists the manually marked channels.
raw.info

## 4. Set average reference (TODO: confirm this referencing choice)

In [None]:
# Re-reference the EEG channels to their common average. The return value is not
# used; the plot below relies on `raw` itself being modified.
# NOTE(review): MNE documents that channels in info['bads'] are left out of the
# average, so C3/C1 (interpolated only later) do not contribute — confirm this
# ordering (reference before interpolation) is intended.
raw.set_eeg_reference(ref_channels='average')
fig2 = raw.plot(duration=40, n_channels=66)

## 5. Repair bad channels

In [None]:
# List the montage files shipped with MNE, for comparison with the custom cap
# file loaded in the next cell.
montage_dir = os.path.join(os.path.dirname(mne.__file__), 'channels', 'data', 'montages')

for header_line in ('\nBUILT-IN MONTAGE FILES', '======================'):
    print(header_line)
builtin_montage_files = sorted(os.listdir(montage_dir))
print(builtin_montage_files)

In [None]:
from collections import OrderedDict

# Electrode positions of the SynAmps2 Quik-Cap 64, exported as a .DAT text file.
montage_fpath = r"D:\bcmi\EMBC\montages\Scan-SynAmps2-Quik-Cap64\SynAmps2 Quik-Cap64.DAT"

# Labels in the file that are not scalp electrodes and are ignored.
skipped_labels = {'Centroid', 'EKG', 'EMG', 'REF.'}

with open(montage_fpath, 'r') as fid:
    lines = fid.readlines()

# Each non-empty line is "<label> <coordinates...>". The fiducials are labelled
# 'Nasion', 'Left' (LPA) and 'Right' (RPA); every other label is an electrode.
# NOTE(review): make_dig_montage expects positions in meters — confirm the units
# used in this .DAT file.
ch_names, poss = list(), list()
nasion = lpa = rpa = None
for line in lines:
    items = line.split()
    if not items:
        # tolerate blank lines (e.g. a trailing empty line at end of file)
        continue
    label = items[0]
    if label in skipped_labels:
        continue
    pos = np.array([float(item) for item in items[1:]])
    if label == 'Nasion':
        nasion = pos
    elif label == 'Left':
        lpa = pos
    elif label == 'Right':
        rpa = pos
    else:
        # upper-cased, presumably to match the recording's channel names (e.g. 'C3')
        ch_names.append(label.upper())
        poss.append(pos)

electrodes = OrderedDict(zip(ch_names, poss))
extended_1020_montage = mne.channels.make_dig_montage(electrodes, nasion, lpa, rpa)

print(extended_1020_montage)
print(len(extended_1020_montage.ch_names))
extended_1020_montage.ch_names

In [None]:
# Visualize the custom montage: a 3-D view (camera rotated with view_init) and a
# 2-D topomap with electrode names.
fig_3d = extended_1020_montage.plot(kind='3d')
fig_3d.gca().view_init(azim=70, elev=15)
extended_1020_montage.plot(kind='topomap', show_names=True)

In [None]:
# Attach the electrode positions to the recording; recording channels missing
# from the montage only trigger a warning (on_missing='warn'), not an error.
raw.set_montage(extended_1020_montage, on_missing='warn')

In [None]:
# Interpolate the bad EEG channels (here C3 and C1) from the montage positions set
# above; the unused and EOG channels are excluded from interpolation.
# NOTE(review): reset_bads=False keeps the interpolated channels listed in
# info['bads']. Steps that use only good channels by default (e.g. ICA.fit with
# picks=None, per the MNE docs) will then skip C3/C1, so the ICA cleaning below
# may not be applied to them even though they are exported later — confirm intended.
raw = raw.interpolate_bads(reset_bads=False, exclude=unused_channels+EOG_channels)

raw.plot(duration=40, n_channels=66)

## 6. Remove artifacts with ICA
### 6.1 Visualize EOG artifact

In [None]:
from mne.preprocessing import create_eog_epochs

# Average epochs time-locked to EOG events (detected on VEO/HEO) to see the ocular
# artifact. The baseline runs from the start of the epoch to -0.2 s.
eog_evoked = (
    create_eog_epochs(raw, ch_name=EOG_channels)
    .average()
    .apply_baseline(baseline=(None, -0.2))
)
eog_evoked.plot_joint()

### 6.2 Visualize ECG artifact

In [None]:
from mne.preprocessing import create_ecg_epochs

# Average epochs time-locked to heartbeats to see the cardiac artifact, with a
# baseline from the start of the epoch to -0.2 s.
# NOTE(review): no ECG channel is declared for this recording; MNE may be unable
# to build a synthetic ECG signal from EEG-only data — confirm this cell runs.
ecg_evoked = (
    create_ecg_epochs(raw)
    .average()
    .apply_baseline(baseline=(None, -0.2))
)
ecg_evoked.plot_joint()

### 6.3 ICA
Decompose the data into independent components.

In [None]:
from mne.preprocessing import ICA

# Fit a 5-component ICA; the fixed random_state makes the decomposition
# reproducible between runs.
ica = ICA(n_components=5, max_iter='auto', random_state=97).fit(raw)

# Component time courses over `raw` (already band-pass filtered in section 2).
ica.plot_sources(raw)

Exclude artifact components.

In [None]:
# Component indices to remove, chosen from the source plots above.
ica.exclude = [0]

# ica.apply() modifies the object it is given in place, so it is applied to a
# copy; `raw` stays unchanged for the before/after comparison.
reconst_raw = ica.apply(raw.copy())

# Before and after artifact removal.
raw.plot(duration=40, n_channels=66)
reconst_raw.plot(duration=40, n_channels=66)

In [None]:
# Free the pre-ICA data; the rest of the notebook only uses `reconst_raw`.
# (Earlier cells that reference `raw` cannot be re-run after this.)
del raw

In [None]:
# Sensor layout of the cleaned recording, with channel names.
reconst_raw.plot_sensors(show_names=True)

## 7. Extract trigger events, epoch the data, and drop bad epochs

In [None]:
# Convert the recording's trigger annotations into an events array plus a
# {description: event code} mapping.
events, event_id = mne.events_from_annotations(reconst_raw)

In [None]:
type(events), type(event_id)

In [None]:
# Each trigger description is the image number
# (see the PsychoPy experiment code).
event_id.keys()

In [None]:
# number of distinct trigger descriptions
len(event_id.keys())

In [None]:
# event code assigned to the description '1' and its type
event_id['1'], type(event_id['1'])

In [None]:
events.shape

In [None]:
# Each row: (sample number, unused middle column, event code)
events[:10]

In [None]:
import numpy as np

# Only the start trigger of each image is needed. Triggers are presumably
# recorded in (start, end) pairs, so keep every other event starting with the
# first. Unlike a fixed-size mask, this works for any number of image trials.
if len(events) % 2 != 0:
    raise ValueError(
        f"expected start/end trigger pairs, got an odd number of events ({len(events)})"
    )
events = events[::2]

events.shape, events[:10]

In [None]:
# Every channel except the unused and EOG ones. list.remove raises ValueError
# if one of them is missing from the recording.
useful_channels = list(reconst_raw.ch_names)
for excluded_ch in unused_channels + EOG_channels:
    useful_channels.remove(excluded_ch)

# One epoch per start trigger, from 0.3 s before to 20 s after it. Channels are
# picked explicitly by name.
epochs = mne.Epochs(reconst_raw, events, tmin=-0.3, tmax=20, picks=useful_channels)


In [None]:
# Summary of the created epochs.
epochs

In [None]:
# Mapping of event descriptions (image numbers as strings) to event codes.
print(epochs.event_id)

# ERP

In [None]:
# Epoch(s) recorded for image 1.
epochs['1'].plot(n_channels=62)

In [None]:
# Power spectral density of the image-1 epoch(s).
epochs['1'].plot_psd()

In [None]:
# Band power over the scalp for the image-1 epoch(s).
epochs['1'].plot_psd_topomap()

In [None]:
# Evoked responses per emotion category: average the epochs of each image, then
# combine the per-image evokeds of each category with equal weights.
#
# `labels` (image number -> 0 negative / 1 neutral / 2 positive) is built by the
# "process labels" cell in the "Export npy" section further down. That cell must
# run first, otherwise this cell would fail partway through with a bare NameError.
if 'labels' not in globals():
    raise NameError(
        "`labels` is not defined: run the 'process labels' cell "
        "(Export npy section) before this one."
    )

# label value -> per-image Evoked objects of that category
evokeds_by_label = {0: [], 1: [], 2: []}
for n in range(1, 91):
    ev = epochs[str(n)].average()
    if labels[n] in evokeds_by_label:
        evokeds_by_label[labels[n]].append(ev)

## Negative / Neutral / Positive Evoked (one per image)
negative_evokeds = evokeds_by_label[0]
neutral_evokeds = evokeds_by_label[1]
positive_evokeds = evokeds_by_label[2]

negative_evoked = mne.combine_evoked(negative_evokeds, weights='equal')
neutral_evoked = mne.combine_evoked(neutral_evokeds, weights='equal')
positive_evoked = mne.combine_evoked(positive_evokeds, weights='equal')

print(negative_evoked)
print(neutral_evoked)
print(positive_evoked)

In [None]:
# Butterfly plots of each category's evoked response, with channels colored by
# location. NOTE(review): `window_title` is not accepted by every MNE release —
# confirm it works with the installed version.
negative_evoked.plot(spatial_colors=True, window_title='Negative Evoked')
neutral_evoked.plot(spatial_colors=True, window_title='Neutral Evoked')
positive_evoked.plot(spatial_colors=True, window_title='Positive Evoked')

In [None]:
# Scalp topographies at 5 evenly spaced times from 0.05 s to 20 s
# (0.05, 5.0375, 10.025, 15.0125 and 20.0 s).
times = np.linspace(0.05, 20, 5)
negative_evoked.plot_topomap(times=times, colorbar=True)
neutral_evoked.plot_topomap(times=times, colorbar=True)
positive_evoked.plot_topomap(times=times, colorbar=True)

In [None]:
# Combined butterfly + topomap view of each category.
negative_evoked.plot_joint()
neutral_evoked.plot_joint()
positive_evoked.plot_joint()

In [None]:
def custom_func(x):
    """Reduce ``x`` by taking its maximum along axis 1.

    Used as a ``combine`` callable for ``mne.viz.plot_compare_evokeds``; axis 1
    is presumably the channel axis of the array MNE passes in — confirm.
    """
    return np.max(x, axis=1)


# Compare the three categories under several ways of collapsing channels.
combine_methods = ('mean', 'median', 'gfp', custom_func)
for method in combine_methods:
    mne.viz.plot_compare_evokeds(
        [negative_evoked, neutral_evoked, positive_evoked], combine=method
    )

In [None]:
# Channel-by-time images of each category's evoked response.
negative_evoked.plot_image()
neutral_evoked.plot_image()
positive_evoked.plot_image()

# Export npy

In [None]:
# Shape of the data for image 1: (epochs, channels, samples).
epochs['1'].get_data().shape

In [None]:
# process labels: map each image number to its emotion category, read from the
# PsychoPy export (0 = negative, 1 = neutral, 2 = positive).
import csv

labels = {}

psyfile = r"D:\bcmi\exp\psychopy_export\yinhao.csv"

# Category strings as written in the CSV: '负向' = negative, '中性' = neutral,
# '正向' = positive.
category_to_label = {'负向': 0, '中性': 1, '正向': 2}

# NOTE(review): no encoding is given, so the file is decoded with the platform's
# default encoding; matching the Chinese category strings depends on that being
# the CSV's actual encoding — confirm.
with open(psyfile, 'r', newline='') as psyf:
    reader = csv.DictReader(psyf)
    for row in reader:
        img_name = row['imageName']
        # The image number is the file name without its extension ("12.<ext>" -> 12).
        img_no = int(img_name[:img_name.rindex('.')])
        category = row[' category']  # the column header has a leading space
        if category not in category_to_label:
            # Fail loudly rather than silently storing a stale or undefined label.
            raise ValueError(
                f"unknown category {category!r} for image {img_name!r}"
            )
        labels[img_no] = category_to_label[category]

print(labels)

In [None]:
# Number of images per emotion category. Any label other than 0 or 1 is counted
# as positive.
label_values = list(labels.values())
negative_count = sum(1 for v in label_values if v == 0)
neutral_count = sum(1 for v in label_values if v == 1)
positive_count = len(label_values) - negative_count - neutral_count

positive_count, neutral_count, negative_count

In [None]:
# For this subject
# (number of images, 6 (1 broadband + 5 frequency bands), number of channels, sample points)
# (90, 6, 62, 4000): 20 s per image at 200 Hz
#
# each 20 s trial is further cut into 2 s (400-sample) slices
# (90, 6, 10, 62, 400)
#
# final layout saved to disk
# (bands, number of cases, channels, sample points)
# (6, 90*10, 62, 400)

bands = ['all', 'delta', 'theta', 'alpha', 'beta', 'gamma']
lows = [1, 1, 4, 8, 14, 31]
highs = [75, 4, 8, 14, 31, 50]

n_images = 90        # event descriptions '1'..'90'
trial_seconds = 20   # stimulus period exported per image
slice_seconds = 2    # length of each exported slice

epochs.load_data()

sfreq = epochs.info['sfreq']
trial_len = int(round(trial_seconds * sfreq))  # 4000 samples at 200 Hz
slice_len = int(round(slice_seconds * sfreq))  # 400 samples at 200 Hz
n_slices = trial_len // slice_len              # 10
# Epochs start at tmin = -0.3 s. Start at the sample closest to t = 0 so the
# exported trial is the stimulus period, rather than the first epoch sample,
# which would include 0.3 s of pre-stimulus data and cut off the trial's end.
onset = int(np.argmin(np.abs(epochs.times)))

# NOTE(review): each band is filtered per 20 s epoch rather than on the
# continuous recording, which can introduce edge effects at low frequencies.
# 6 * (1, n_images * n_slices, 62, slice_len)
datas = []
for band, low, high in zip(bands, lows, highs):
    # n_images * (n_slices, 62, slice_len)
    for_this_band = []
    filtered_epochs = epochs.copy().filter(l_freq=low, h_freq=high)
    for img in range(1, n_images + 1):
        # (1, 62, trial_len): exactly one epoch is expected per image
        img_data = filtered_epochs[str(img)].get_data()[:, :, onset:onset + trial_len]
        if img_data.shape[0] != 1 or img_data.shape[2] != trial_len:
            raise ValueError(
                f"band {band!r}, image {img}: expected 1 epoch with {trial_len} "
                f"samples from t=0, got data of shape {img_data.shape}"
            )
        # n_slices * (1, 62, slice_len), stacked along axis 0 -> (n_slices, 62, slice_len)
        slices = [img_data[:, :, slice_len * s : slice_len * (s + 1)] for s in range(n_slices)]
        for_this_band.append(np.concatenate(slices, axis=0))

    datas.append(np.expand_dims(np.concatenate(for_this_band, axis=0), axis=0))

subj_data = np.concatenate(datas, axis=0)

# n_images * n_slices labels, in the same image-major order as the data rows
labs = [labels[img] for img in range(1, n_images + 1) for _ in range(n_slices)]
subj_label = np.array(labs)

print(subj_data.shape, subj_label.shape)
print(subj_data.dtype, subj_label.dtype)

out_dir = './npydata'
os.makedirs(out_dir, exist_ok=True)

np.save(os.path.join(out_dir, '{}_data.npy'.format(subj_name)), subj_data)
np.save(os.path.join(out_dir, '{}_label.npy'.format(subj_name)), subj_label)
