In [None]:
#From BH06 latin square analysis

def check_groups(df):
    """Print the highest ``Trial`` value for every (Subject, Group) pair.

    The row-display limit is lifted only while printing, so the whole table
    is shown and the caller's ``display.max_rows`` setting is left as it was.

    Args:
        df: DataFrame with ``Subject``, ``Group`` and ``Trial`` columns.
    """
    max_trials = df.groupby(['Subject', 'Group'])['Trial'].max()
    # option_context restores the previous setting on exit. The earlier code
    # overwrote it with df.Subject.max(), a NumPy integer that the pandas
    # option validator may reject and that is unrelated to the old setting.
    with pd.option_context('display.max_rows', None):
        print(max_trials)
    
def get_risk_status_vehicle(df1):
    """Classify every row of ``df1`` as optimal or risky from ``risk1`` alone.

    Only the ``risk1`` column is consulted (for a Latin square, the saline
    dose). Rows with ``risk1 > 0`` get ``risk_status`` 1 and are listed as
    optimal; rows with ``risk1 < 0`` get ``risk_status`` 2 and are listed as
    risky. Rows where neither comparison holds are left untouched.

    Args:
        df1: DataFrame with a ``risk1`` column; it is modified in place.

    Returns:
        Tuple ``(df1, risky, optimal)`` where the two lists hold the index
        labels of the risky and optimal rows, in index order.
    """
    optimal_subjects, risky_subjects = [], []
    for subject in df1.index:
        score = df1.at[subject, 'risk1']
        if score > 0:
            status, bucket = 1, optimal_subjects
        elif score < 0:
            status, bucket = 2, risky_subjects
        else:
            # Neither positive nor negative: no status is assigned.
            continue
        df1.at[subject, 'risk_status'] = status
        bucket.append(subject)
    return df1, risky_subjects, optimal_subjects

def ls_bar_plot(figure_group, group_means, sem):
    """Draw a grouped bar chart of P1-P4 choice for one group at four doses.

    Columns whose names start with '1', '2', '3' and '4' supply the values
    for the Vehicle, Dose 1, Dose 2 and Dose 3 bars respectively; the same
    column selection on ``sem`` supplies the error bars.

    Args:
        figure_group: Row label to plot; also used as the title prefix.
        group_means: DataFrame of means, indexed by group, with string
            column names.
        sem: DataFrame of error values laid out like ``group_means``.
    """
    def values_for(frame, prefix):
        # Values of the plotted row, restricted to one dose's columns.
        wanted = [name for name in frame.columns if name.startswith(prefix)]
        return list(frame.loc[figure_group, wanted])

    tick_labels = ['P1', 'P2', 'P3', 'P4']
    # (column prefix, legend label, error-bar colour, bar colour);
    # a bar colour of None leaves the bar on matplotlib's default colour cycle.
    series = [
        ('1', 'Vehicle', 'C0', None),
        ('2', 'Dose 1', 'C1', None),
        ('3', 'Dose 2', 'C3', 'C3'),
        ('4', 'Dose 3', 'C2', 'C2'),
    ]
    heights = [values_for(group_means, prefix) for prefix, *_ in series]

    x = np.arange(len(tick_labels)) * 3  # tick locations, one cluster per label
    width = 0.5  # width of a single bar
    offsets = [-width / 2 - width, -width / 2, width / 2, width / 2 + width]

    fig, ax = plt.subplots(figsize=(10, 5))
    for (prefix, label, ecolor, color), offset, bar_heights in zip(series, offsets, heights):
        bar_style = {} if color is None else {'color': color}
        ax.bar(x + offset, bar_heights, width, label=label,
               yerr=values_for(sem, prefix), capsize=8, ecolor=ecolor,
               **bar_style)

    # Axis labels, title and tick labels.
    ax.set_ylabel('% Choice', fontweight='bold', fontsize=22)
    ax.set_title(figure_group + ': P1-P4', fontweight='bold', fontsize=24, pad=20)
    ax.set_ylim(bottom=0, top=85)
    ax.set_xticks(x)
    ax.set_xticklabels(tick_labels)
    # Hide the top/right spines and thicken the remaining two.
    for side, linewidth in (('right', 0), ('top', 0), ('left', 2), ('bottom', 2)):
        ax.spines[side].set_linewidth(linewidth)
    ax.legend()

In [None]:
#from BH06 baseline analysis - acquisition - by session

#this is the same as the previous load data, except that there is an option to reset the session numbers starting from 1
def load_data(fnames, reset_sessions=False):
    """Read several Excel files and stack them into one DataFrame.

    Args:
        fnames: Iterable of Excel file paths, read with ``pd.read_excel``.
        reset_sessions: When True, the ``Session`` values of each file are
            renumbered 1, 2, 3, ... in order of first appearance within
            that file.

    Returns:
        A single DataFrame holding the rows of every file, in file order,
        with a fresh 0..n-1 index.

    Raises:
        ValueError: If ``fnames`` is empty.
    """
    frames = []
    for fname in fnames:
        frame = pd.read_excel(fname)
        if reset_sessions:
            # Build the old -> new mapping from the untouched values and
            # apply it in one step. Renumbering in place, one session at a
            # time, merges sessions whenever an old label equals a number
            # that was already assigned (e.g. sessions 0, 1, 2).
            sessions = frame['Session'].dropna().unique()
            renumber = {old: new for new, old in enumerate(sessions, start=1)}
            frame['Session'] = frame['Session'].map(renumber)
        frames.append(frame)

    if not frames:
        raise ValueError('load_data needs at least one file name')
    # DataFrame.append was removed in pandas 2.0; concat is the replacement.
    return pd.concat(frames, ignore_index=True)

def remove_subject(df, subs):
    """Drop every row belonging to the given subjects, in place.

    Args:
        df: DataFrame with a ``Subject`` column; it is modified in place.
        subs: Iterable of subject identifiers to remove.

    Returns:
        The same DataFrame object. Its index is reset afterwards and the
        previous index is kept as a new column (``reset_index`` default).
    """
    for subject in subs:
        is_subject = (df['Subject'] == subject).to_numpy()
        df.drop(df.index[is_subject], inplace=True)
    df.reset_index(inplace=True)
    return df

def edit_sess(df, orig_sess, new_sess, subs='all'):
    """Replace session numbers, for all subjects or a chosen subset.

    Each value in ``orig_sess`` is replaced by the value at the same position
    in ``new_sess``. All replacements are computed from the original session
    values, so a new number that matches a later entry of ``orig_sess`` is
    not replaced a second time.

    Args:
        df: DataFrame with ``Session`` (and, when ``subs`` is given,
            ``Subject``) columns; it is modified in place.
        orig_sess: Sequence of session values to replace.
        new_sess: Sequence of replacement values, aligned with ``orig_sess``.
        subs: ``'all'`` to edit every row, or an iterable of subject
            identifiers whose rows should be edited.

    Returns:
        The same DataFrame object.

    Raises:
        ValueError: If ``new_sess`` has fewer entries than ``orig_sess``.
    """
    if len(new_sess) < len(orig_sess):
        raise ValueError('new_sess must supply a value for every entry of orig_sess')

    # If a session is listed twice, the first replacement wins.
    replacement = {}
    for old, new in zip(orig_sess, new_sess):
        replacement.setdefault(old, new)

    selected = df['Session'].isin(list(replacement))
    # The isinstance guard keeps an array-like `subs` from being compared
    # element-wise against the string.
    if not (isinstance(subs, str) and subs == 'all'):
        selected &= df['Subject'].isin(subs)

    if selected.any():
        df.loc[selected, 'Session'] = df.loc[selected, 'Session'].map(replacement)
    return df



In [None]:
#from rgt variant baseline analysis

#allows you to load in data from multiple projects, and assigns unique subject numbers 
def load_multiple_data(fnames, reset_sessions=False):
    """Read Excel files from several projects and stack them into one frame.

    Subject numbers are shifted so they stay distinct across files: 100 is
    added for the first file, 200 for the second, and so on.

    Args:
        fnames: Iterable of Excel file paths, read with ``pd.read_excel``.
        reset_sessions: When True, the ``Session`` values of each file are
            renumbered 1, 2, 3, ... in order of first appearance within
            that file.

    Returns:
        A single DataFrame holding the rows of every file, in file order,
        with a fresh 0..n-1 index.

    Raises:
        ValueError: If ``fnames`` is empty.
    """
    frames = []
    for position, fname in enumerate(fnames, start=1):
        frame = pd.read_excel(fname)
        frame['Subject'] += 100 * position
        if reset_sessions:
            # Build the old -> new mapping from the untouched values and
            # apply it in one step. Renumbering in place, one session at a
            # time, merges sessions whenever an old label equals a number
            # that was already assigned (e.g. sessions 0, 1, 2).
            sessions = frame['Session'].dropna().unique()
            renumber = {old: new for new, old in enumerate(sessions, start=1)}
            frame['Session'] = frame['Session'].map(renumber)
        frames.append(frame)

    if not frames:
        raise ValueError('load_multiple_data needs at least one file name')
    # DataFrame.append was removed in pandas 2.0; concat is the replacement.
    return pd.concat(frames, ignore_index=True)

#takes the same data as rgt_plot, but creates a bar plot instead of a line plot
def rgt_bar_plot(variable, startsess, endsess, group_names, title, scores, sem, var_title=None):
    """Plot one bar per group: the mean of a variable over a session range.

    For every row of ``scores`` the columns named ``variable + <session>``
    for sessions ``startsess``..``endsess`` (inclusive) are averaged to give
    the bar height; the same columns of ``sem`` are averaged for the error
    bar. Any number of groups is supported.

    Args:
        variable: Column-name prefix of the variable to plot.
        startsess: First session number to include.
        endsess: Last session number to include.
        group_names: Accepted for signature compatibility; not used here.
        title: Accepted for signature compatibility; not used here.
        scores: DataFrame of group means, one row per group.
        sem: DataFrame of error values, indexed like ``scores``.
        var_title: Plot title; defaults to ``variable``.

    Note:
        Updates the global matplotlib ``rcParams`` (font size, figure size).
        The y-axis label ('% Premature') and limits (0-50) are fixed.
    """
    if var_title is None:
        var_title = variable
    plt.rcParams.update({'font.size': 20})
    plt.rcParams["figure.figsize"] = (15, 8)

    # Column selection does not depend on the group, so do it once.
    session_cols = {variable + str(s) for s in range(startsess, endsess + 1)}
    score_cols = [col for col in scores.columns if col in session_cols]
    sem_cols = [col for col in sem.columns if col in session_cols]

    # No fixed-size buffers: the earlier five-slot lists raised IndexError
    # as soon as `scores` had more than five groups.
    for i, group in enumerate(scores.index):
        height = scores.loc[group, score_cols].mean()
        error = sem.loc[group, sem_cols].mean()
        plt.bar(i, height, yerr=error, capsize=8, width=.7,
                color=['C' + str(i)], label=group)

    ax = plt.gca()
    plt.xticks([])
    plt.xlabel('Task', labelpad=20)
    ax.set_ylabel('% Premature')
    ax.set_ylim(0, 50)
    ax.set_title(var_title, fontweight='bold', fontsize=22, pad=20)
    # Hide the top/right spines and thicken the remaining two.
    ax.spines['right'].set_linewidth(0)
    ax.spines['top'].set_linewidth(0)
    ax.spines['left'].set_linewidth(2)
    ax.spines['bottom'].set_linewidth(2)
    ax.legend()