In [None]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import glob
import math 
import pandas as pd
from datetime import datetime
import os

%matplotlib inline
%matplotlib notebook

sns.set_context('talk')

In [None]:
# Root folder holding the per-animal metadata tables (csv/<animal>.csv) and the
# behaviour recordings (behavior_data/<animal>/*.npy). Override it with the
# EYEBLINK_DATA_PATH environment variable. os.path.join(..., '') guarantees a
# trailing separator, which later cells rely on because they build paths by
# plain string concatenation.
data_path = os.path.join(os.environ.get('EYEBLINK_DATA_PATH', '/home/bhanu/work/hrishikeshn/'), '')
animal_ids = ['G394', 'G396', 'G404', 'G405', 'G492', 'G493', 'G506', 'G508']
# Six colour-blind-safe colours for seven protocols: the seventh entry (Hr6)
# reuses the first colour, so Hr6 and SoAn1 are drawn in the same colour.
colors = sns.color_palette("colorblind", 6)
colors.append(colors[0])
protocols = ['SoAn1', 'An1', 'An2', 'An3', 'Hr7', 'All3', 'Hr6']
# Human-readable protocol descriptions used as legend labels.
# Hr6 has no entry here.
protocol_info = {
                'SoAn1' : '250ms',
                'An1'   : '350ms',
                'An2'   : '450ms',
                'An3'   : '550ms',
                'Hr7'   : '250ms-550ms interleaved',
                'All3'  : 'extinction'
                }
color_dict = dict(zip(protocols, colors))

In [None]:
def read_behavior_data(session_file):
    """Load one session's behaviour record from a pickled ``.npy`` file.

    The file is expected to hold a single pickled Python object (the session
    dict), which is unwrapped with ``.item()``.

    Parameters
    ----------
    session_file : str
        Path to the ``*_behavior_data.npy`` file.

    Returns
    -------
    object or None
        The stored object, or ``None`` (after printing a message) when the
        file does not exist.
    """
    try:
        # NOTE(review): allow_pickle=True can execute arbitrary code while
        # loading; only point this at trusted recordings.
        loaded = np.load(session_file, allow_pickle=True)
    except FileNotFoundError:
        print(session_file + ' file does not exist')
        return None
    else:
        return loaded.item()

In [None]:
def plot_eye_blink_traces(ax, data):
    """Draw every trial's blink-response trace on ``ax`` against its frame index.

    Parameters
    ----------
    ax : matplotlib Axes (or any object with a compatible ``plot`` method)
    data : dict
        Session dict whose ``'blink_response'`` entry holds one trace per trial.

    Returns
    -------
    None
    """
    traces = data['blink_response']
    for trace in traces:
        frame_index = np.arange(len(trace))
        ax.plot(frame_index, trace)

In [None]:
def get_learning_stats(data, threshold=2):
    """Return the percentage of trials that show a conditioned response (CR).

    A trial counts as a CR when any blink-response sample inside its trace
    window exceeds ``threshold``. The trace window runs from the first frame
    whose ``trial_phase`` is 3 up to, but not including, the last such frame.

    Parameters
    ----------
    data : dict
        Session dict (as returned by ``read_behavior_data``) providing
        ``'blink_response'`` (one trace per trial) and ``'trial_phase'``
        (one per-frame phase sequence per trial).
    threshold : float, optional
        Amplitude a sample must exceed to count as a response. Defaults to 2,
        the value previously hard-coded here.

    Returns
    -------
    float
        CR percentage over the trials that have a trace phase. Returns ``nan``
        when no trial could be scored, or when a ``ValueError`` is raised while
        scoring (a message is printed in that case).
    """
    conditioned_response = []
    try:
        for t, br in enumerate(data['blink_response']):
            try:
                trace_frames = np.where(np.array(data['trial_phase'][t]) == 3)[0]
                trace_start = trace_frames[0]
                trace_end = trace_frames[-1]
            except IndexError:
                # The trial has no phase record, or it never reached phase 3.
                # Only IndexError is caught: the former bare `except:` also
                # swallowed KeyboardInterrupt and unrelated errors.
                continue
            # NOTE(review): trace_end is the *last* trace frame and slicing is
            # end-exclusive, so that frame is never scored. Confirm this is intended.
            window = np.array(br[trace_start:trace_end])
            conditioned_response.append(1 if np.any(window > threshold) else 0)

        if conditioned_response:
            return sum(conditioned_response) * 100 / len(conditioned_response)
        return math.nan

    except ValueError:
        print("returning nan")
        return math.nan
        

In [None]:
def calc_cr_peak_timing(data, threshold=2):
    """Score conditioned responses (CRs) on probe trials and time their peaks.

    For each probe trial, the scored window is the set of frames whose
    ``trial_phase`` is 3, minus the last two of them. A trial is a CR when any
    blink sample in that window exceeds ``threshold``. For CR trials, the peak
    latency is measured from CS onset to the largest sample in the window. CS
    onset is the first frame whose ``trial_phase`` is 2.

    Parameters
    ----------
    data : dict
        Session dict with ``'blink_response'``, ``'trial_phase'``,
        ``'arduino_timestamp'`` (per-frame timestamps per trial; subtracting
        two must give a ``timedelta``) and ``'probe_flag'``. ``'probe_flag'``
        may be a per-trial sequence or a single session-level value.
    threshold : float, optional
        Amplitude a sample must exceed to count as a response (default 2).

    Returns
    -------
    learnt_percent : float
        CR percentage over the scored trials, or ``nan`` if none were scored.
    conditioned_response : list of int
        1/0 per scored trial.
    peak_times : list of float
        Peak latency in milliseconds, one entry per CR trial.
    """
    from datetime import timedelta

    conditioned_response = []
    peak_times = []
    probe_flags = data['probe_flag']
    # Comparing the whole probe_flag to 0 inside the trial loop never matches
    # when it is a per-trial list, so non-probe trials were never skipped.
    # Index it per trial when it is a sequence; otherwise keep the
    # session-level comparison.
    per_trial_probe = np.ndim(probe_flags) > 0
    for t, br in enumerate(data['blink_response']):
        probe = probe_flags[t] if per_trial_probe else probe_flags
        if probe == 0:
            continue
        phases = np.asarray(data['trial_phase'][t])
        # Trace-phase frames, excluding the final two.
        trace_frames = np.flatnonzero(phases == 3)[:-2]
        if trace_frames.size == 0:
            # Nothing to score. Skipped, as get_learning_stats does for
            # trials without a trace phase.
            continue
        window = np.asarray(br)[trace_frames]
        if np.any(window > threshold):
            conditioned_response.append(1)
            timestamps = data['arduino_timestamp'][t]
            cs_onset = timestamps[int(np.flatnonzero(phases == 2)[0])]
            peak_t = timestamps[int(trace_frames[np.argmax(window)])]
            # Dividing two timedeltas gives an exact float ratio and, unlike
            # .seconds/.microseconds, also accounts for .days.
            peak_times.append((peak_t - cs_onset) / timedelta(milliseconds=1))
        else:
            conditioned_response.append(0)

    if conditioned_response:
        learnt_percent = sum(conditioned_response) * 100 / len(conditioned_response)
    else:
        learnt_percent = math.nan
    return learnt_percent, conditioned_response, peak_times

        

In [None]:
# Per-animal box plots of conditioned-response peak latency (ms after CS
# onset, from calc_cr_peak_timing): one box per usable session, each box
# coloured by that session's training protocol.
EXCLUDED_CODES = ('All1', 'All4', 'So2', 'error')  # skipped via substring match on the file path
MIN_SESSION_BYTES = 4000000  # NOTE(review): smaller files are assumed to be incomplete recordings; confirm

for animal_id in animal_ids:
    print(animal_id)
    metadata = pd.read_csv(data_path + 'csv/' + animal_id + '.csv', delimiter=',',
                           dtype={'date': object})
    # Keyed by metadata row position; each value is one session's list of CR peak latencies.
    peak_time_dict = {}
    sess_codes = []
    for s in range(len(metadata)):
        behav_code = metadata['behaviour_code'].iloc[s]
        session_file = (data_path + 'behavior_data/' + animal_id + '/' + animal_id
                        + '_' + behav_code
                        + '_' + str(metadata['behaviour_session_number'].iloc[s])
                        + '_behavior_data.npy')
        # The isfile test comes first so getsize is never called on a missing file.
        if (not os.path.isfile(session_file)
                or any(code in session_file for code in EXCLUDED_CODES)
                or os.path.getsize(session_file) < MIN_SESSION_BYTES):
            continue
        data = read_behavior_data(session_file)
        _, _, peak_times = calc_cr_peak_timing(data)
        peak_time_dict[s] = peak_times
        sess_codes.append(behav_code)

    if not peak_time_dict:
        print('no usable sessions for ' + animal_id)
        continue

    f, ax = plt.subplots(figsize=(12, 5))
    # Pass a concrete list of samples rather than a dict view.
    b_plot = ax.boxplot(list(peak_time_dict.values()), patch_artist=True)
    # Boxes sit at x = 1..n; pin the ticks there before relabelling them.
    ax.set_xticks(range(1, len(peak_time_dict) + 1))
    # NOTE(review): tick labels are metadata row positions, not behaviour_session_number.
    ax.set_xticklabels(list(peak_time_dict.keys()))
    ax.set_title(animal_id)
    ax.set_xlabel('Session_number')
    ax.set_ylabel('Time of peak eyeblink (ms)')

    for patch, sc in zip(b_plot['boxes'], sess_codes):
        # Protocols missing from color_dict fall back to grey instead of raising KeyError.
        patch.set_facecolor(color_dict.get(sc, 'lightgray'))
      

        
    


# Summary: performance score (percentage of trials with a conditioned response) per session

In [None]:
# Per-animal performance score (percentage of trials with a conditioned
# response, from get_learning_stats) against session number (upi), with one
# line per protocol. Each animal's figure is saved as
# <animal_id>_performance_score.svg. A standalone legend, built from
# LEGEND_ANIMAL's lines, is saved as legend_performance_score.svg.
EXCLUDED_CODES = ('All1', 'All4', 'So2', 'error')  # skipped via substring match on the file path
MIN_SESSION_BYTES = 4000000  # NOTE(review): smaller files are assumed to be incomplete recordings; confirm
LEGEND_ANIMAL = 'G405'

lines = []
proto_list = []

for animal_id in animal_ids:
    print(animal_id)
    # The metadata table is per animal, so read it once, not once per session.
    metadata = pd.read_csv(data_path + 'csv/' + animal_id + '.csv', delimiter=',',
                           dtype={'date': object})
    session_files = sorted(glob.glob(data_path + 'behavior_data/' + animal_id + '/*.npy'))
    learning_percentage = []
    upi = []
    sess_codes = []
    for session_file in session_files:
        if (any(code in session_file for code in EXCLUDED_CODES)
                or os.path.getsize(session_file) < MIN_SESSION_BYTES):
            continue
        # File names look like <animal>_<code>_<session number>_behavior_data.npy
        name_parts = os.path.basename(session_file).split('_')
        behav_code = name_parts[1]
        behav_sess_num = int(name_parts[2])
        upi_match = metadata.loc[(metadata['behaviour_code'] == behav_code)
                                 & (metadata['behaviour_session_number'] == behav_sess_num), 'upi']
        # Look the row up before appending anything, so the three lists stay aligned.
        if upi_match.empty:
            print(session_file + ' has no metadata row; skipping')
            continue
        data = read_behavior_data(session_file)
        learning_percentage.append(get_learning_stats(data))
        sess_codes.append(behav_code)
        upi.append(upi_match.iloc[0])

    # One stable argsort keeps scores and protocol codes aligned even when
    # two sessions share a upi. Separate sorted(zip(...)) calls break ties
    # differently.
    order = np.argsort(upi, kind='stable')
    sorted_learning_percentage = np.array(learning_percentage)[order]
    sorted_sess_codes = np.array(sess_codes)[order]
    sorted_upi = np.array(upi)[order]

    fig, ax = plt.subplots()
    # Visit protocols in order of first appearance. Iterating a set would make
    # the line and legend order change between kernel restarts.
    for sc in dict.fromkeys(sorted_sess_codes):
        indices = np.where(sorted_sess_codes == sc)[0]
        line, = ax.plot(sorted_upi[indices], sorted_learning_percentage[indices],
                        c=color_dict.get(sc), marker='o')
        line.set_label(sc)
        # Hr6 is left out of the legend; protocol_info has no description for it.
        if (animal_id == LEGEND_ANIMAL) and (sc != 'Hr6'):
            lines.append(line)
            proto_list.append(sc)
    ax.set_title(animal_id)
    ax.set_xlabel('session number')
    ax.set_ylabel('performance score')
    ax.set_xlim(0, 45)
    ax.set_xticks(np.arange(0, 45, 10))
    ax.set_ylim(0, 100)
    ax.set_yticks(np.arange(0, 101, 25))
    fig.savefig('%s_performance_score.svg' % animal_id)

legend_fig, legend_ax = plt.subplots()
legend_ax.legend(handles=lines, labels=[protocol_info.get(sc, sc) for sc in proto_list])
legend_fig.savefig('legend_performance_score.svg')
      

        
    


## 