### SparklyRGT Template: Baseline and Acquisition Analysis 

**Requirements**
* The data must be an Excel file from MEDPC2XL (trial-by-trial data)
* The data, sparklyRGT.py file, and this notebook must all be in the same folder

**Getting started: Please make a copy of this (sparklyRGT_template_2) for each analysis**
- Refer to sparklyRGT_documentation for function information
- Note: depending on your analysis, you will only have to complete certain sections of the sparklyRGT_documentation
- Note: feel free to create a personal template once you've become comfortable - this is just an example

In [1]:
import os
os.chdir('..\\sparklyRGT_tutorial')
import sparklyRGT as rgt 
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator
import scipy.stats as stats
import seaborn as sns
import pingouin as pg
pd.options.mode.chained_assignment = None
pd.set_option('display.max_rows',100)
%load_ext autoreload
%autoreload 2

I am being executed!


***

# 1) Load data into Python



In [2]:
#list of MEDPC2XL Excel file(s) to load; relative names are resolved against the working directory set by os.chdir in the import cell
file_names = ['BH07_raw_free_S29-30.xlsx'] 

#NOTE(review): presumably returns a single trial-by-trial dataframe covering every listed file - confirm in sparklyRGT_documentation
df = rgt.load_data(file_names)

#preview the first five rows to check that the file loaded as expected
df.head()

Unnamed: 0,MSN,StartDate,StartTime,Subject,Group,Box,Experiment,Comment,Session,Trial,...,Pun_Persev_H5,Pun_HeadEntry,Pun_Dur,Premature_Resp,Premature_Hole,Rew_Persev_H1,Rew_Persev_H2,Rew_Persev_H3,Rew_Persev_H4,Rew_Persev_H5
0,rGT_A-cue,2020-10-09,11:01:00,25,0.0,1,0.0,,29,1.0,...,3,3,30,0,0,0,0,0,0,0
1,rGT_A-cue,2020-10-09,11:01:00,25,0.0,1,0.0,,29,2.1,...,0,0,0,1,5,0,0,0,0,0
2,rGT_A-cue,2020-10-09,11:01:00,25,0.0,1,0.0,,29,2.0,...,3,2,30,0,0,0,0,0,0,0
3,rGT_A-cue,2020-10-09,11:01:00,25,0.0,1,0.0,,29,3.0,...,0,0,0,0,0,0,0,0,0,0
4,rGT_A-cue,2020-10-09,11:01:00,25,0.0,1,0.0,,29,4.0,...,2,2,30,0,0,0,0,0,0,0


***
# 2A) Baseline & Acquisition Analysis


In [3]:
#subject numbers belonging to each group (edit these lists for your own cohort)
control_group = [1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16] #In this example: Tg negative rats

exp_group = [17,18,19,20,21,22,23,24] #In this example: Tg positive rats

third_group = [25,26,27,28,29,30,31,32]

#display name for each group; fill in the empty strings before plotting
#NOTE(review): keys appear to be positions in group_list (0 = control_group, 1 = exp_group, 2 = third_group) - confirm in sparklyRGT_documentation
group_names = {0: '',
              1: '',
              2: ''} #entry 2 added so that every group in group_list has a name

group_list = [control_group, exp_group, third_group]

title = '' #for plotting

startsess = 29 #first session you would like to include in figures
endsess = 30 #last session you would like to include in figures

## Data cleaning

### Check session numbers for each rat

In [4]:
#list the sessions recorded for each Subject/StartDate so that mis-numbered or extra sessions can be spotted before analysis
rgt.check_sessions(df)

Subject  StartDate   Session
1        2020-10-09  29         131.1
         2020-10-10  30         124.0
2        2020-10-09  29          76.1
         2020-10-10  30          81.0
3        2020-10-09  29          49.0
         2020-10-10  30          45.0
4        2020-10-09  29         103.0
         2020-10-10  30          97.0
5        2020-10-09  29          68.1
         2020-10-10  30          69.0
6        2020-10-09  29          88.0
         2020-10-10  30          75.0
7        2020-10-09  28          53.0
         2020-10-10  29          65.0
         2020-10-13  30          56.1
8        2020-10-09  29         124.0
         2020-10-10  30         121.0
9        2020-10-09  29          62.0
         2020-10-10  30          61.0
11       2020-10-09  29         132.0
         2020-10-10  30         136.1
12       2020-10-09  29          54.0
         2020-10-10  30          72.0
13       2020-10-09  29          67.0
         2020-10-10  30          60.0
14       2020-10-09  

### Drop/edit session numbers

In [5]:
#remove session 28 from the data (in the visible check_sessions output, subject 7 has a session 28)
#NOTE(review): subject 7's session 28 shares its StartDate (2020-10-09) with the other rats' session 29 - confirm that dropping, rather than renumbering, is what you want
df2 = rgt.drop_sessions(df, [28])

### Check that you dropped/edited the desired session(s)

In [None]:
# rgt.check_sessions(df2) 

## Data processing

### Calculate variables for each rat


In [6]:
#NOTE(review): df is passed here although df2 was created by drop_sessions above; unless drop_sessions edits df in place, session 28 is still included - confirm
df_sum = rgt.get_summary_data(df) #pass df2 instead of df if you did any session editing above
# df_sum 

In [54]:
#build the long-format summary (the displayed result has one row per Subject x Session)
#NOTE(review): pass the same trial-level dataframe (df or df2) that was given to get_summary_data
df_long = rgt.get_long_summary_data(df, df_sum)
# df_long

In [55]:
def get_risk_status_long(df_long_sum, sessions = None): 
    """Adds 'mean_risk' and 'risk_status' columns to long-format summary data.

    df_long_sum: long summary dataframe with 'Subject', 'Session' and 'risk' columns
    sessions: list of session numbers to average over; if not passed, all sessions in df_long_sum are used

    For each subject, mean_risk is the mean of 'risk' across the selected sessions; the same value is
    written to every row of that subject. risk_status is 1 when mean_risk > 0, 2 when mean_risk < 0,
    and NaN when mean_risk is exactly 0 or could not be computed (subject has no rows in the selected sessions).

    The columns are added to df_long_sum in place, and the same dataframe is returned."""

    #resolve the default here so that this cell does not depend on a helper defined further down the notebook
    if sessions is None:
        sessions = list(df_long_sum['Session'].unique()) #all sessions

    #per-subject mean risk over the selected sessions (indexed by Subject)
    in_sessions = df_long_sum['Session'].isin(sessions)
    mean_risk_by_subject = df_long_sum.loc[in_sessions].groupby('Subject')['risk'].mean()

    #aligning on Subject (rather than appending to a list) stays correct when rows are unsorted
    #or when subjects do not all have the same number of sessions
    df_long_sum['mean_risk'] = df_long_sum['Subject'].map(mean_risk_by_subject)

    #reset first so that re-running the cell never leaves stale values behind
    df_long_sum['risk_status'] = np.nan
    df_long_sum.loc[df_long_sum['mean_risk'] > 0, 'risk_status'] = 1
    df_long_sum.loc[df_long_sum['mean_risk'] < 0, 'risk_status'] = 2
    return df_long_sum

In [56]:
#add the mean_risk and risk_status columns; no sessions are passed, so every session in df_long is used
df_long = get_risk_status_long(df_long)
df_long

Unnamed: 0,Subject,Session,P1,P2,P3,P4,risk,collect_lat,choice_lat,omit,trial,prem,mean_risk,risk_status
0,1,29,90.839695,0.000000,8.396947,0.763359,81.679389,1.077168,0.638321,0,131.1,26.404494,73.904211,1.0
1,1,30,83.064516,0.000000,16.935484,0.000000,66.129032,1.286471,0.655323,0,124.0,24.848485,73.904211,1.0
2,2,29,9.333333,65.333333,10.666667,14.666667,49.333333,1.607407,1.057733,1,76.1,30.275229,56.148148,1.0
3,2,30,4.938272,76.543210,0.000000,18.518519,62.962963,1.387458,1.249012,0,81.0,33.606557,56.148148,1.0
4,3,29,2.173913,8.695652,56.521739,32.608696,-78.260870,1.023333,2.833261,3,49.0,3.921569,-77.502528,2.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
57,30,30,0.775194,93.798450,5.426357,0.000000,89.147287,0.998922,0.593953,0,129.0,15.131579,94.573643,1.0
58,31,29,7.272727,83.636364,3.636364,5.454545,81.818182,1.424286,2.057091,0,55.0,12.698413,80.796731,1.0
59,31,30,2.247191,87.640449,7.865169,2.247191,79.775281,1.575735,2.679663,1,90.0,10.891089,80.796731,1.0
60,32,29,0.000000,47.619048,7.936508,44.444444,-4.761905,0.976216,3.270794,12,75.0,6.250000,1.587302,1.0


In [57]:
def get_group_long(df_long, group_list):
    """Adds a 'group' column to df_long based on each row's Subject.

    df_long: long summary dataframe with a 'Subject' column
    group_list: list of groups, where each group is a list of subject numbers

    group == 1 represents the first group passed to group_list, group == 2 the second, and so on
    (any number of groups is supported). If a subject appears in more than one group, the last group
    containing it wins. Subjects that are in none of the groups get NaN.

    The column is added to df_long in place (stored as floats), and the same dataframe is returned."""
    #build one subject -> group number lookup instead of testing every row against every group
    subject_to_group = {}
    for group_num, group in enumerate(group_list, start=1):
        for subject in group:
            subject_to_group[subject] = group_num

    #astype(float) keeps the column type the same whether or not some subjects are unassigned
    df_long['group'] = df_long['Subject'].map(subject_to_group).astype(float)
    return df_long

#adds the 'group' column to df_long in place; the returned dataframe is only displayed, not reassigned
get_group_long(df_long, group_list)

Unnamed: 0,Subject,Session,P1,P2,P3,P4,risk,collect_lat,choice_lat,omit,trial,prem,mean_risk,risk_status,group
0,1,29,90.839695,0.000000,8.396947,0.763359,81.679389,1.077168,0.638321,0,131.1,26.404494,73.904211,1.0,1.0
1,1,30,83.064516,0.000000,16.935484,0.000000,66.129032,1.286471,0.655323,0,124.0,24.848485,73.904211,1.0,1.0
2,2,29,9.333333,65.333333,10.666667,14.666667,49.333333,1.607407,1.057733,1,76.1,30.275229,56.148148,1.0,1.0
3,2,30,4.938272,76.543210,0.000000,18.518519,62.962963,1.387458,1.249012,0,81.0,33.606557,56.148148,1.0,1.0
4,3,29,2.173913,8.695652,56.521739,32.608696,-78.260870,1.023333,2.833261,3,49.0,3.921569,-77.502528,2.0,1.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
57,30,30,0.775194,93.798450,5.426357,0.000000,89.147287,0.998922,0.593953,0,129.0,15.131579,94.573643,1.0,3.0
58,31,29,7.272727,83.636364,3.636364,5.454545,81.818182,1.424286,2.057091,0,55.0,12.698413,80.796731,1.0,3.0
59,31,30,2.247191,87.640449,7.865169,2.247191,79.775281,1.575735,2.679663,1,90.0,10.891089,80.796731,1.0,3.0
60,32,29,0.000000,47.619048,7.936508,44.444444,-4.761905,0.976216,3.270794,12,75.0,6.250000,1.587302,1.0,3.0


## Run ANOVA on selected variables

In [46]:
def filt_vars(df_long_sum, variables = None, task = None): 
    """Takes in long df, variables (list object), and task... and outputs/updates variables

    If variables is passed (a list, pandas Index, or array of column names), it is returned unchanged.
    If variables is not passed, the default outcome columns are selected by position from df_long_sum:
    columns 2 to 24 when task == 'choiceRGT', otherwise columns 2 to 11."""
    #'is None' (not '== None'): comparing an Index or array to None is element-wise and cannot be used in an if
    if variables is None: #then run all variables
        if task == 'choiceRGT':
            variables = df_long_sum.columns[2:25] #from 'P1_C' to 'pref'
        else:
            variables = df_long_sum.columns[2:12] #from 'P1' to 'prem'
    return variables

In [47]:
def filt_sess(df_long_sum, sessions = None):   
    """Takes in long df and sessions (list object)... and outputs/updates sessions

    If sessions is passed, it is returned unchanged.
    If sessions is not passed, a list of every unique value in df_long_sum.Session is returned.

    Note: this function only resolves the list of sessions. It does not filter df_long_sum, so callers
    that want to restrict their data must select the rows themselves."""
    #'is None' (not '!= None'): comparing an array of sessions to None is element-wise and cannot be used in an if
    if sessions is None:
        sessions = list(df_long_sum.Session.unique()) #all sessions
    return sessions

In [48]:
def mixed_anova2(df_long_sum, bsf, variables = None, sessions = None, show_df = None, task = None):
    """Runs a mixed ANOVA (within = Session, between = bsf) on each outcome variable.

    df_long_sum: long-summary data with the between-subjects factor of interest as a separate column (dataframe)
    bsf: name of the between-subjects factor column (string)
    variables: outcome variables to test (list); if not passed, the defaults chosen by filt_vars are used
    sessions: session numbers to include (list, at least 2); if not passed, all sessions in df_long_sum are used
    show_df: variables whose full ANOVA table should be printed (list); if not passed, no tables are printed
    task: passed on to filt_vars to choose the default variables

    A variable is called unstable when any of the first three p-values in its ANOVA table is below 0.05.
    The function prints the requested tables, then the list of unstable variables and the list of stable variables.
    Nothing is returned."""
    
    #filter rows by sessions, and select columns by variables
    variables = filt_vars(df_long_sum, variables, task)
    sessions = filt_sess(df_long_sum, sessions)
    df_long_sum = df_long_sum[df_long_sum['Session'].isin(sessions)] #so the ANOVA only sees the requested sessions
    if show_df is None: #default: print no tables (avoids 'var in None' raising a TypeError)
        show_df = []
        
    #run anova
    unstable_list = []
    stable_list = []
    for var in variables: 
        res = pg.mixed_anova(df_long_sum, dv=var, within='Session', subject='Subject', between = bsf)
        if (var in show_df):
            print(f'{res}{var}')
        #first three rows by position: in the printed tables these are the between factor, Session and Interaction
        pvals = res['p-unc'].iloc[:3]
        if (pvals < 0.05).any():
            unstable_list.append(var)
        else: 
            stable_list.append(var)
            
    #report unstable and stable list
    print(f'Unstable list: {unstable_list}\nStable list: {stable_list}')

In [49]:
#mixed ANOVA with 'group' as the between-subjects factor, on the default variables; print the full tables for collect_lat and P1 only
mixed_anova2(df_long, "group", show_df = ['collect_lat', 'P1'])

        Source         SS  DF1  DF2         MS         F     p-unc       np2  \
0        group  49.485990    2   28  24.742995  0.047486  0.953701  0.003380   
1      Session   0.720645    1   28   0.720645  0.056485  0.813871  0.002013   
2  Interaction   2.265940    2   28   1.132970  0.088804  0.915282  0.006303   

   eps  
0  NaN  
1  1.0  
2  NaN  P1
        Source        SS  DF1  DF2        MS         F     p-unc       np2  \
0        group  1.985461    2   28  0.992731  4.116795  0.027080  0.227236   
1      Session  0.000181    1   28  0.000181  0.010282  0.919955  0.000367   
2  Interaction  0.018738    2   28  0.009369  0.531856  0.593331  0.036599   

   eps  
0  NaN  
1  1.0  
2  NaN  collect_lat
Unstable list: ['collect_lat']
Stable list: ['P1', 'P2', 'P3', 'P4', 'risk', 'choice_lat', 'omit', 'trial', 'prem']


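Alternative output format: the function below reports each variable alongside its p-value in a DataFrame (one DataFrame for unstable variables and one for stable variables).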

In [50]:
def check_stability2(df_long_sum, variables = None, sessions = None, task = None): 
    """Runs a repeated-measures ANOVA (within = Session) on each outcome variable and reports p-values.

    df_long_sum: long-summary data (dataframe)
    variables: outcome variables to test (list); if not passed, the defaults chosen by filt_vars are used
    sessions: session numbers to include (list, at least 2); if not passed, all sessions in df_long_sum are used
    task: passed on to filt_vars to choose the default variables

    A variable is called unstable when the p-value in the first row of its ANOVA table is below 0.05.
    The function prints the unstable df and the stable df, each with the columns 'variable' and 'p-value'.
    Nothing is returned."""
    
    #filter rows by sessions, and select columns by variables
    variables = filt_vars(df_long_sum, variables, task)
    sessions = filt_sess(df_long_sum, sessions)
    df_long_sum = df_long_sum[df_long_sum['Session'].isin(sessions)] #so the ANOVA only sees the requested sessions
    
    #run anova: collect one (variable, p-value) pair per variable
    unstable_rows = []
    stable_rows = []
    for var in variables: #for each variable, run a RM anova
        res = pg.rm_anova(dv=var, within='Session', subject='Subject', data=df_long_sum, detailed=True)
        pval = res['p-unc'].iloc[0] #first row by position (presumably the Session effect - confirm in the pingouin docs)
        if pval < 0.05: 
            unstable_rows.append((var, pval))
        else: 
            stable_rows.append((var, pval))
            
    #one row per variable, so that the results run down the dataframe (as opposed to across)
    result_columns = ['variable', 'p-value']
    unstable_df = pd.DataFrame(unstable_rows, columns=result_columns)
    stable_df = pd.DataFrame(stable_rows, columns=result_columns)
            
    #print unstable and stable dataframes
    print(f'unstable df: {unstable_df}\nStable df: {stable_df}')

In [51]:
#run the repeated-measures stability check on the default variables and all sessions in df_long
check_stability2(df_long, variables = None)

unstable df: Empty DataFrame
Columns: [variable, p-value]
Index: []
Stable df:       variable   p-value
0           P1  0.807947
1           P2  0.304521
2           P3  0.154876
3           P4  0.594188
4         risk  0.239622
5  collect_lat  0.918632
6   choice_lat  0.884529
7         omit  0.099792
8        trial  0.336109
9         prem  0.547884
