# In [1]:

from scipy.stats import sem
import os
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages
from PyPDF2 import PdfFileReader, PdfFileWriter, PdfWriter, PdfReader
from scipy.interpolate import interp1d
from datetime import date
from statistics import mean 
import math
def save_image(filename):
    """Save every currently open matplotlib figure as one page of a PDF.

    Args:
        filename: Output path without extension; ``'.pdf'`` is appended.

    Each figure becomes one page, in the order returned by
    ``plt.get_fignums()``, rendered at 300 dpi.
    """
    # The context manager guarantees the PDF is finalised and its file handle
    # released even if rendering one of the figures raises.
    with PdfPages(filename + '.pdf') as pdf:
        for fig_num in plt.get_fignums():
            plt.figure(fig_num).savefig(pdf, format='pdf', dpi=300)

def _port_lick_times(events, port):
    """Return the entry times recorded for ``port`` as a 1-D float array.

    ``events`` stores a single event as a bare float and several events as a
    sequence; a missing key means the port was never entered, which yields an
    empty array.
    """
    if port not in events:
        return np.empty(0)
    return np.atleast_1d(np.asarray(events[port], dtype=float)).ravel()


def _lick_window(lick_times_s, center_ms, pre_delay, post_delay):
    """Return a binary 1 kHz lick train around ``center_ms``.

    Sample ``j`` of the returned array (length ``pre_delay + post_delay``) is 1
    when at least one lick, rounded to the nearest millisecond, falls at
    millisecond ``int(center_ms) - pre_delay + j``. Non-finite lick times are
    ignored.
    """
    window = np.zeros(pre_delay + post_delay)
    times = np.asarray(lick_times_s, dtype=float)
    times = times[np.isfinite(times)]
    # Lick times are converted to ms (the same x1000 scaling the session code
    # applies to the choice/outcome times) and shifted into window coordinates.
    offsets = np.round(times * 1000).astype(int) - (int(center_ms) - pre_delay)
    offsets = offsets[(offsets >= 0) & (offsets < window.size)]
    window[offsets] = 1
    return window


def _alignment_centers_ms(flash_ms, choice_ms, outcome_ms):
    """Return the five alignment times (ms) for one trial, NaN when unavailable.

    Order: 1st flash, 3rd flash, 4th flash, choice window, outcome.
    """
    def flash(k):
        return flash_ms[k] if len(flash_ms) > k else np.nan

    return [flash(0), flash(2), flash(3), choice_ms, outcome_ms]


def _accumulate_session(session_raw, outcome, outcome_time_ms, category,
                        stim_seqs, pre_delay, post_delay):
    """Sum lick trains per (port, outcome, category, alignment) for one session.

    Returns:
        curve: array (3 ports, 2 outcomes, 2 categories, 5 alignments,
            pre_delay + post_delay samples) of summed binary lick trains.
        count: int array (2, 2, 5) with the number of trials added to each cell.

    Outcome index: 0 = 'Reward', 1 = 'Punish' (other outcomes are skipped).
    Category index: 0 = ``category < 500``, 1 = ``category > 500``; exactly 500
    (or NaN) falls in neither and is skipped. A window is only added when its
    alignment time is finite and starts at or after trial time 0.
    """
    curve = np.zeros([3, 2, 2, 5, pre_delay + post_delay])
    count = np.zeros([2, 2, 5], dtype=int)
    outcome_index = {'Reward': 0, 'Punish': 1}
    ports = ('Port1In', 'Port2In', 'Port3In')

    for trial in range(session_raw['nTrials']):
        oi = outcome_index.get(outcome[trial])
        if oi is None:
            continue
        if category[trial] < 500:
            ci = 0
        elif category[trial] > 500:
            ci = 1
        else:
            continue

        trial_data = session_raw['RawEvents']['Trial'][trial]
        licks = [_port_lick_times(trial_data['Events'], port) for port in ports]
        choice_ms = np.multiply(trial_data['States']['WindowChoice'][0], 1000)
        # Row 1 of stim_seq is used directly as ms: the original divided by 1000
        # and multiplied back, which only introduced float round-off.
        flash_ms = np.asarray(stim_seqs[trial], dtype=float)[1, :]
        centers = _alignment_centers_ms(flash_ms, choice_ms, outcome_time_ms[trial])

        for a, center in enumerate(centers):
            # Skip missing events and windows that would start before trial
            # time 0 (no licks are recorded there for this trial).
            if not np.isfinite(center) or int(center) < pre_delay:
                continue
            for p, port_licks in enumerate(licks):
                curve[p, oi, ci, a] += _lick_window(port_licks, center, pre_delay, post_delay)
            count[oi, ci, a] += 1
    return curve, count


def _plot_session(curve, count, session_date, row_names, alignments, pre_delay):
    """Build the per-session figure: one subfigure row per outcome/category.

    Each panel shows the trial-averaged lick train of the left, center and
    right ports, summed over a 400-sample sliding window, with a dashed line
    at the alignment time.
    """
    fig = plt.figure(constrained_layout=True, figsize=(20, 30))
    fig.suptitle(session_date)
    kernel = np.ones(400)
    # (port index in curve, legend label, colour); port 1 is drawn first.
    traces = [(1, 'center', 'gray'), (0, 'left', 'r'), (2, 'right', 'limegreen')]

    for r, subfig in enumerate(fig.subfigures(nrows=len(row_names), ncols=1)):
        oi, ci = divmod(r, 2)  # row order: rewarded/punished x short/long
        subfig.suptitle(row_names[r])
        axs = subfig.subplots(nrows=1, ncols=len(alignments))
        for a, ax in enumerate(axs):
            n = count[oi, ci, a]
            for port, label, color in traces:
                if n > 0:
                    mean = curve[port, oi, ci, a] / n
                else:
                    # No trials: plot nothing instead of dividing by zero.
                    mean = np.full(curve.shape[-1], np.nan)
                smoothed = np.convolve(mean, kernel, 'same')
                x = np.arange(smoothed.size) / 1000
                ax.plot(x, smoothed, label=label, color=color)
            ax.vlines(pre_delay / 1000, 0, 0.5, linestyle='--')
            ax.set_title('aligned with ' + alignments[a] + ' (' + str(n) + ')')
            if a == 0:
                ax.legend(loc='upper left', bbox_to_anchor=(1, 1), ncol=1)
    return fig


def _write_session_pdf(fig, subject, session_date, session_index,
                       output_dir_onedrive, output_dir_local):
    """Save ``fig`` as a page PDF locally and copy it into the OneDrive folder."""
    output_pdf_dir = output_dir_onedrive + subject + '/'
    output_pdf_pages_dir = output_dir_local + subject + '/_alingment/alingment_' + session_date + '/'
    os.makedirs(output_pdf_dir, exist_ok=True)
    os.makedirs(output_pdf_pages_dir, exist_ok=True)

    page_path = output_pdf_pages_dir + subject + session_date + '_alingment' + str(session_index) + '.pdf'
    # Save only this figure; dumping every open figure could put an unrelated
    # figure on page 0, which is the page copied below.
    fig.savefig(page_path, format='pdf', dpi=300)

    writer = PdfWriter()
    # A path argument makes PdfReader load the file into memory, so no input
    # handle has to stay open until writer.write().
    writer.add_page(PdfReader(page_path).pages[0])
    with open(output_pdf_dir + subject + '_' + session_date + '_alingment' + '.pdf', 'wb') as out:
        writer.write(out)


def run(subject_session_data, output_dir_onedrive, output_dir_local, *,
        pre_delay=300, post_delay=3000):
    """Plot event-aligned lick rates for every session of one subject.

    For each session, trials are grouped by outcome ('Reward'/'Punish') and by
    ``isi_post_emp`` (< 500 vs > 500). Binary 1 kHz lick trains of the three
    ports are averaged in windows around the 1st/3rd/4th flash, the choice
    window and the outcome. One figure per session is written to
    ``output_dir_local`` and copied to ``output_dir_onedrive``.

    Args:
        subject_session_data: dict with keys 'subject', 'dates', 'raw',
            'outcomes', 'outcomes_time', 'isi_post_emp' and 'stim_seq'.
            Event times in 'raw' and 'outcomes_time' are scaled by 1000 to ms;
            presumably they are in seconds from trial start (TODO confirm).
        output_dir_onedrive: prefix (with trailing separator) for merged PDFs.
        output_dir_local: prefix (with trailing separator) for per-page PDFs.
        pre_delay: window length before each alignment time, in ms.
        post_delay: window length after each alignment time, in ms.
    """
    alignments = ['1st flash', '3rd flash', '4th flash', 'choice window', 'outcome']
    row_names = ['rewarded short', 'rewarded long', 'punished short', 'punished long']
    subject = subject_session_data['subject']
    raw_data = subject_session_data['raw']
    outcomes = subject_session_data['outcomes']
    outcomes_time = subject_session_data['outcomes_time']
    categories = subject_session_data['isi_post_emp']
    stim_seqs = subject_session_data['stim_seq']

    for i, session_date in enumerate(subject_session_data['dates']):
        curve, count = _accumulate_session(
            raw_data[i], outcomes[i], np.multiply(outcomes_time[i], 1000),
            categories[i], stim_seqs[i], pre_delay, post_delay)
        fig = _plot_session(curve, count, session_date, row_names, alignments, pre_delay)
        try:
            _write_session_pdf(fig, subject, session_date, i,
                               output_dir_onedrive, output_dir_local)
        finally:
            plt.close(fig)

        


    


