In [1]:
import mne
import numpy as np
from pymatreader import read_mat
import matplotlib.pyplot as plt
import random
import pandas as pd

In [6]:
import sys
import numpy as np
import matplotlib.pyplot as plt
from time import sleep


%matplotlib
class Annotator:
    """Interactive rating of AB- vs ICA-corrected EEG segments around saccades.

    For one patient ``pt`` the annotator loads three versions of the resting
    recording (processed/uncorrected, ICA-corrected and AB-corrected) plus the
    manually marked 'saccade' annotations. :meth:`annotate_channel` then shows,
    for every saccade, a short window of one channel. The AB trace is always
    labelled 'Alg 1' and the ICA trace 'Alg 2'; only their colours (red/blue)
    are shuffled per window. The rater presses one key for the blue trace,
    then one for the red trace:

    * ``o`` -- overcorrected
    * ``u`` -- undercorrected
    * ``e`` -- effectively corrected

    ``left`` removes the last rating and ``enter`` confirms once both traces
    are rated. Each confirmed window is mapped back to the AB/ICA columns,
    appended to ``self.df``, and the whole table is rewritten to
    ``{path_out}{pt}_{channel}.csv``.

    Parameters
    ----------
    pt
        Patient identifier, formatted into every input/output file name.
    path_data : str
        Directory (with trailing separator) holding the ``.fif``/``.mat`` data.
    path_annot : str
        Directory (with trailing separator) holding
        ``manual_annot_AB_ICA_{pt}.csv``.
    path_out : str
        Directory (with trailing separator) where rating CSVs are written.
    """

    # Column order of the rating table (``self.df``) and its CSV export.
    _COLUMNS = ('channel', 'onset', 'offset', 'AB', 'ICA')
    # Traces are rated in this order; names are shown on screen, codes are the
    # matplotlib colour codes used when plotting.
    _RATING_NAMES = ('blue', 'red')
    _RATING_CODES = ('b', 'r')

    def __init__(self, pt, path_data='/Users/lina_01/Desktop/ab_ica/pt_AB_ICA_raw_files/',
                 path_annot='/Users/lina_01/Desktop/ab_ica/',
                 path_out='/Users/lina_01/Desktop/'):
        self.channel = None      # channel currently being rated
        self.onset = None        # onset of the saccade currently shown
        self.offset = None       # onset + duration of that saccade
        self.path_data = path_data
        self.path_annot = path_annot
        self.path_out = path_out
        self.pt = pt
        self.color_dict = None   # {'AB': colour code, 'ICA': colour code} of current window
        self.df = pd.DataFrame(columns=list(self._COLUMNS))
        self.fig = None
        self.ax = None
        self.text = None         # status text artist drawn on the axes
        self.wait = True         # set to False once the current window is confirmed
        self.responses = []      # rating keys for the current window, blue first
        self.labels = {'o': 'overcorrected', 'u': 'undercorrected', 'e': 'effectively corrected'}
        self.get_times()
        self.get_raws()

    def get_raws(self):
        """Load the three versions of the patient's resting recording.

        * ``rest_raw_processed`` -- ``{path_data}{pt}_rest_raw_processed.fif``
        * ``rest_raw_ICA`` -- ``{path_data}{pt}_rest_raw_ica.fif``
        * ``rest_raw_AB`` -- the ``['datatosave_out']['OutData']`` entry of
          ``{path_data}{pt}_AB_outdata.mat`` wrapped in an ``mne.io.RawArray``
          that reuses the processed recording's ``info``.
        """
        prefix = f'{self.path_data}{self.pt}'
        self.rest_raw_processed = mne.io.read_raw_fif(f'{prefix}_rest_raw_processed.fif')
        self.rest_raw_ICA = mne.io.read_raw_fif(f'{prefix}_rest_raw_ica.fif')
        ab_outdata = read_mat(f'{prefix}_AB_outdata.mat')['datatosave_out']['OutData']
        # RawArray expects (n_channels, n_times); assumes OutData uses the same
        # channel order and sampling rate as the processed recording -- TODO confirm.
        self.rest_raw_AB = mne.io.RawArray(ab_outdata, self.rest_raw_processed.info)

    def annotate_channel(self, channel):
        """Step through every saccade window of ``channel`` and collect ratings.

        Blocks (polling the GUI every 0.1 s) until each window is confirmed
        with ``enter``; each confirmed window is saved immediately. Closing the
        figure stops the session early; the unconfirmed window is not saved.
        """
        self.channel = channel
        self.fig, self.ax = plt.subplots(1, 1, figsize=(10, 5))
        plt.ion()
        plt.show()

        # Connect the key handler exactly once: connecting it per window stacks
        # handlers, so every key press would be recorded several times.
        cid = self.fig.canvas.mpl_connect('key_press_event', self.on_press)
        try:
            for onset, offset in zip(self.times['onset'], self.times['offset']):
                self.onset, self.offset = onset, offset
                self.responses = []
                self.wait = True
                self.visual_insp()
                while self.wait:
                    if not plt.fignum_exists(self.fig.number):
                        return  # figure closed by the user: abort the session
                    plt.pause(0.1)
                self.save_responses()
        finally:
            self.fig.canvas.mpl_disconnect(cid)
            self.responses = []
            self.wait = True

    def get_times(self):
        """Load the manually marked saccades into ``self.times``.

        Reads ``{path_annot}manual_annot_AB_ICA_{pt}.csv`` with
        ``mne.read_annotations`` and keeps only annotations whose description
        is 'saccade'. ``self.times`` maps 'onset', 'duration' and 'offset'
        (= onset + duration) to float arrays, which are empty when the file
        contains no saccades.
        """
        annot_file = f"{self.path_annot}manual_annot_AB_ICA_{self.pt}.csv"
        annotations = mne.read_annotations(annot_file)
        # Single pass; dtype=float keeps the arrays numeric and all keys present
        # even when there are no saccades.
        saccades = [annot for annot in annotations if annot['description'] == 'saccade']
        onset = np.array([annot['onset'] for annot in saccades], dtype=float)
        duration = np.array([annot['duration'] for annot in saccades], dtype=float)
        self.times = {'onset': onset, 'duration': duration, 'offset': onset + duration}

    def visual_insp(self, win_len=1.0):
        """Draw the current saccade window of ``self.channel`` on ``self.ax``.

        The window is ``win_len`` long (passed to MNE ``get_data`` as
        ``tmin``/``tmax``, i.e. seconds) and centred on the midpoint of
        ``self.onset``..``self.offset``. The AB ('Alg 1') and ICA ('Alg 2')
        traces get red/blue in random order, recorded in ``self.color_dict``.
        The uncorrected trace is labelled 'original', and the saccade span is
        shaded.
        """
        self.ax.clear()
        midpoint = (self.offset + self.onset) / 2
        start, stop = midpoint - win_len / 2, midpoint + win_len / 2

        def segment(raw):
            # Samples of the picked channel between start and stop (squeezed).
            return raw.get_data(picks=self.channel, tmin=start, tmax=stop).squeeze()

        y_AB = segment(self.rest_raw_AB)
        y_ICA = segment(self.rest_raw_ICA)
        y_raw = segment(self.rest_raw_processed)
        # Approximate time axis spanning the window; assumes all three
        # recordings return the same number of samples -- TODO confirm.
        x = np.linspace(start, stop, len(y_AB))
        ab_color, ica_color = np.random.choice(['r', 'b'], 2, replace=False)
        self.ax.plot(x, y_AB, c=ab_color, label='Alg 1')
        self.ax.plot(x, y_ICA, c=ica_color, label='Alg 2')
        self.ax.plot(x, y_raw, label='original')
        self.color_dict = {'AB': ab_color, 'ICA': ica_color}
        self.text = self.ax.text(0.3, 0.9, 'annotating blue', transform=self.ax.transAxes,
                                 ha='left', va='top')
        self.ax.legend()
        self.ax.axvspan(self.onset, self.offset, color="blue", alpha=0.1)

    def get_annotation_str(self):
        """Return the status text for the ratings entered so far.

        One line per rated trace (``'<colour>: <label>'``), followed by either
        the next trace to annotate or the prompt to confirm.
        """
        lines = [f'{name}: {self.labels[key]}'
                 for name, key in zip(self._RATING_NAMES, self.responses)]
        n_rated = len(self.responses)
        if n_rated < len(self._RATING_NAMES):
            lines.append(f'annotating {self._RATING_NAMES[n_rated]}')
        else:
            lines.append('hit enter to confirm')
        return '\n'.join(lines)

    def on_press(self, event):
        """Handle a key press on the figure.

        ``o``/``u``/``e`` add a rating (ignored once both traces are rated),
        ``left`` removes the last rating, and ``enter`` ends the wait for the
        current window once both ratings are present. Other keys are ignored.
        """
        sys.stdout.flush()
        key = event.key
        if key in self.labels:
            # A third rating would belong to no trace, so it is dropped.
            if len(self.responses) < len(self._RATING_NAMES):
                self.responses.append(key)
                self.text.set_text(self.get_annotation_str())
        elif key == 'enter':
            if len(self.responses) == len(self._RATING_NAMES):
                self.wait = False
        elif key == 'left':
            self.responses = self.responses[:-1]
            self.text.set_text(self.get_annotation_str())

    def save_responses(self):
        """Store the confirmed ratings of the current window and rewrite the CSV.

        ``self.responses`` holds the blue rating, then the red one.
        ``self.color_dict`` maps them back to the AB and ICA columns. The whole
        table is written to ``{path_out}{pt}_{channel}.csv``.
        """
        by_color = dict(zip(self._RATING_CODES, self.responses))
        record = {'channel': self.channel, 'onset': self.onset, 'offset': self.offset,
                  'AB': by_color[self.color_dict['AB']],
                  'ICA': by_color[self.color_dict['ICA']]}
        row = pd.DataFrame([record], columns=list(self._COLUMNS))
        # DataFrame.append was removed in pandas 2.0, and concatenating onto an
        # empty frame is deprecated, so the first row simply becomes the table.
        if self.df.empty:
            self.df = row
        else:
            self.df = pd.concat([self.df, row], ignore_index=True)
        self.df.to_csv(f'{self.path_out}{self.pt}_{self.channel}.csv')
    


Using matplotlib backend: MacOSX
