**Outline**

The purpose of this script is to manually identify bad channels in the raw data of individual subjects. Bad channels are recorded in a 'bad_channels.txt' file (one per task), which is read by later scripts. Note that when we read the raw data, we apply basic preprocessing (filtering and automatic channel rejection) to mimic the behaviour of later scripts before manual channel rejection occurs.

**Import packages**

In [None]:
import mne
import os
import matplotlib.pyplot as plt
import numpy as np
from ipywidgets import *
import scipy.stats as sstats
from sklearn import linear_model
import pandas as pd
import copy



# Silence MNE's informational output; only errors are reported
mne.set_log_level(verbose='ERROR')

**Define file paths, data prefix, and filtering parameters**

In [None]:
# Directory layout: processed data lives under <projectDir>/proc_data
projectDir = '../'
dataDir = os.path.join(projectDir, 'proc_data')

# Task prefix, and the filename stem of the raw recording read for each subject
rawPre = "buttonPress"
raw_fstem = f"{rawPre}-trans-raw.fif"

# Filtering: notch frequencies (Hz; presumably 60 Hz mains and harmonics)
# and the band-pass edges (Hz) used by load_raw()
notchFreqs = [60, 120, 180]
l_freq1 = 3
h_freq1 = 150

# FFT length used when computing amplitude spectra
n_fft = 2000

**Define functions to read and preprocess raw data for a single subject**

In [None]:
def ampSpec(raw, n_fft):
	"""Return the amplitude spectrum of the good channels of ``raw``.

	The PSD is computed with ``raw.compute_psd`` (channels listed in
	``info['bads']`` are excluded), square-rooted and multiplied by 1e15
	(intended as fT, per the original comment).

	Parameters
	----------
	raw : mne.io.Raw
		Data to analyse. It is not modified.
	n_fft : int
		FFT length passed to ``compute_psd``.

	Returns
	-------
	ampSpecData : ndarray
		One row per good channel, one column per frequency bin.
	freq : ndarray
		Frequency (Hz) of each column of ``ampSpecData``.
	"""
	spectrum = raw.compute_psd(n_fft=n_fft, exclude='bads')
	freq = spectrum.freqs
	# sqrt(PSD) -> amplitude; 1e15 rescales to femto units
	return np.sqrt(spectrum.get_data()) * 1e15, freq


def load_raw(subject, hi_band=(120, 145), z_thresh=2):
	"""Read one subject's raw data, filter it, and auto-flag noisy channels.

	Mirrors the preprocessing of later scripts:
	  1. read ``<dataDir>/<subject>/meg/<raw_fstem>`` and keep 'mag' channels;
	  2. mark the last three channels (the reference array) as bad so they
	     are excluded from the PSD;
	  3. notch filter at ``notchFreqs`` and band-pass ``l_freq1``-``h_freq1``;
	  4. mark as bad every good channel whose mean amplitude inside
	     ``hi_band`` has an absolute z-score above ``z_thresh``.

	Parameters
	----------
	subject : str
		Name of the subject folder inside ``dataDir``.
	hi_band : tuple of float, optional
		Open interval (low, high), in Hz, used for outlier detection.
	z_thresh : float, optional
		Absolute z-score above which a channel is marked bad.

	Returns
	-------
	raw_filt : mne.io.Raw
		The filtered data with ``info['bads']`` updated.
	"""
	raw_fname = os.path.join(dataDir, subject, 'meg', raw_fstem)

	# Read the data (preloaded so it can be filtered)
	raw_filt = mne.io.read_raw(raw_fname, preload=True).pick('mag')

	# The last 3 channels are the reference array: mark them bad so they are
	# excluded from the PSD. Skip any already marked to avoid duplicates.
	refChannels = raw_filt.info['ch_names'][-3:]
	newBads = [ch for ch in refChannels if ch not in raw_filt.info['bads']]
	raw_filt.info['bads'].extend(newBads)

	# Apply filtering (the unfiltered data is not needed, so no copy is made)
	raw_filt = raw_filt.notch_filter(notchFreqs)
	raw_filt = raw_filt.filter(l_freq=l_freq1, h_freq=h_freq1)

	# ampSpec() drops bad channels, so its rows correspond to the GOOD channels
	# in ch_names order, not to ch_names itself. Map rows back via this list
	# (indexing ch_names directly breaks if the file already had bads).
	goodChans = [ch for ch in raw_filt.info['ch_names'] if ch not in raw_filt.info['bads']]
	ampSpec_filt, freq = ampSpec(raw_filt, n_fft)

	# Flag channels with outlying mean amplitude inside hi_band
	# (default 120-145 Hz, just below the 150 Hz low-pass edge)
	inBand = (freq > hi_band[0]) & (freq < hi_band[1])
	hiFreqAmp = np.mean(ampSpec_filt[:, inBand], axis=1)
	z_scores = np.abs(sstats.zscore(hiFreqAmp))
	hiChans = [goodChans[i] for i in np.flatnonzero(z_scores > z_thresh)]
	raw_filt.info['bads'].extend(hiChans)

	return raw_filt

def referenceArrayRegression(raw_filter, opmChannels, sensorChannels, refChannels):
	"""Regress reference-array signals out of the sensor channels.

	Each sensor is fitted by ordinary least squares (with intercept) on the
	reference channels. The fitted slope terms (coefficients times reference
	signals) are subtracted; the intercept is not removed.

	Parameters
	----------
	raw_filter : mne.io.Raw
		Preloaded raw data. It is not modified.
	opmChannels : sequence of int
		Indices of the OPM channels. Kept for backward compatibility; the
		regression only uses ``sensorChannels`` and ``refChannels``.
	sensorChannels : sequence of int
		Indices, into the full channel list of ``raw_filter``, of the
		channels to clean.
	refChannels : sequence of int
		Indices, into the full channel list, of the reference channels
		(any number of them).

	Returns
	-------
	raw_regressed : mne.io.RawArray
		New raw object with cleaned sensor data. All other channels, the
		first sample and the annotations are copied from ``raw_filter``.
	"""
	# Read the data once. Sensor/reference indices address the full channel
	# list, the same indexing used when writing the cleaned data back below.
	allData = raw_filter.get_data()
	sensorData = allData[sensorChannels, :]
	referenceData = allData[refChannels, :]

	# OLS with intercept for every sensor at once. Centring both sides removes
	# the intercept; then solve (times x refs) @ coefs = (times x sensors).
	refCentred = referenceData - referenceData.mean(axis=1, keepdims=True)
	sensCentred = sensorData - sensorData.mean(axis=1, keepdims=True)
	coefs = np.linalg.lstsq(refCentred.T, sensCentred.T, rcond=None)[0]  # (refs, sensors)

	# Remove only the reference-driven (slope) component from each sensor
	allData[sensorChannels, :] = sensorData - coefs.T @ referenceData

	# Preserve time alignment and annotations of the input recording
	raw_regressed = mne.io.RawArray(allData, raw_filter.info, first_samp=raw_filter.first_samp)
	raw_regressed.set_annotations(raw_filter.annotations)

	return raw_regressed



**Round 1 (pre-HFC) of channel rejection**

Add any obvious bad channels to the txt file. Be fairly conservative here, because homogeneous field correction (HFC), which is applied after this step, can clear up a lot of noise.

In [None]:
# Interactive plots not appearing? Close the current figure first...
plt.close(fig=None)

# ...then re-run the plotting cell

In [None]:
# Plot raw data for a given subject
%matplotlib widget

subject = 'mnsbp005'

# Load preprocessed raw data
raw_filt = load_raw(subject)

# Plot the data in an interactive window. Note that you should be able to modify the size of the window by pulling on the grey triangle in the bottom right corner  
plot = raw_filt.copy().plot(scalings=dict(mag=15e-12), use_opengl=True, duration=10)

**Plot PSD**

In [None]:
# Per-channel (non-averaged) PSD from 0 to 200 Hz, spatially coloured
raw_filt.plot_psd(
	fmin=0, fmax=200,
	average=False, spatial_colors=True,
	show=False,
)

**Mark bad channels for this subject**

In [None]:
# Channels judged bad by eye from the plots above
bads = ["C5"]

# Add only channels not already marked, so re-running this cell (as the
# round-2 workflow asks) does not append duplicate entries.
raw_filt.info['bads'].extend([ch for ch in bads if ch not in raw_filt.info['bads']])

raw_filt.info['bads']

**Apply reference array regression and HFC**

In [None]:
# Magnetometer channel indices; the last three are the reference array
# (the same channels load_raw() marks bad), the rest are sensors.
opmIndices = mne.channel_indices_by_type(raw_filt.info)['mag']
numChannels = len(opmIndices)
numSensors = numChannels - 3
referenceIndices = opmIndices[-3:]
sensorIndices = opmIndices[:numSensors]

# Regress the reference-array signals out of the sensors (on a copy)
raw_regressed = referenceArrayRegression(
	raw_filt.copy(), opmIndices, sensorIndices, referenceIndices
)

# First-order HFC: raw_hfc keeps the projectors added but unapplied;
# raw_hfcApplied is a further copy with them applied.
raw_hfc = raw_regressed.copy()
projs = mne.preprocessing.compute_proj_hfc(raw_hfc.info, order=1, exclude='bads')
raw_hfc.add_proj(projs)
raw_hfcApplied = raw_hfc.copy().apply_proj(verbose="error")

**Round 2 (post-HFC)**

Mark any remaining bad channels that HFC did not correct. Then go back, add them to 'bads', re-run HFC, and re-check.

In [None]:
# Browse (a copy of) the HFC-cleaned data to find channels HFC did not fix
plot = raw_hfcApplied.copy().plot(duration=10, use_opengl=True, scalings={'mag': 15e-12})

In [None]:
# Per-channel PSD (0-200 Hz) of the HFC-cleaned data
raw_hfcApplied.copy().plot_psd(show=False, spatial_colors=True, average=False, fmax=200, fmin=0)