In [22]:
%load_ext autoreload
%autoreload 2
%matplotlib qt5
import sys
import os
# Make the repository root (two levels up) importable for `utils`, `set_global_params`
# and `save_to_excel`. os.path.join builds '..\..' on Windows and '../..' elsewhere;
# the old literal '..\..' also contained an invalid escape sequence.
sys.path.append(os.path.join(os.pardir, os.pardir))
import matplotlib.pyplot as plt
import matplotlib
from utils.post_processing_utils import *
from utils.behaviour_utils import CalculateRBiasWindow
from scipy import stats
from utils.plotting import calculate_error_bars, multi_conditions_plot
from set_global_params import processed_data_path, state_change_mice, figure_directory, reproduce_figures_path, spreadsheet_path
from utils.stats import cohen_d_paired
from save_to_excel import save_figure_data_to_excel

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


# Copy behaviour files into the reproduction folder (skip this section if you already have the reproduction data)

In [14]:
# In case the repro data files don't exist, copy them over from the raw data path.
# (Don't run this if you only have the repro data and not the full data set - it will raise an error.)
repro_dir = os.path.join(reproduce_figures_path, 'ED_fig6', 'state_change_behaviour')
mice = state_change_mice['tail'] + state_change_mice['Nacc']
exp_type = 'state change white noise'
# The experiment table is loaded without any per-mouse argument, so load and
# filter it once instead of re-reading it on every loop iteration.
all_experiments = remove_unsuitable_recordings(get_all_experimental_records())
for mouse_id in mice:
    experiment_to_process = all_experiments[(all_experiments['experiment_notes'] == exp_type) & (all_experiments['mouse_id'] == mouse_id)]
    copy_behaviour_to_folder_mouse_name(experiment_to_process, source_dir=processed_data_path, target_dir=repro_dir)

# Main analysis: behaviour before and after the state change

In [15]:
def moving_average(x, window=50):
    """Trailing moving average that excludes the current element.

    Entry ``i`` of the result is the mean of ``x[i - window:i]``, i.e. of the
    ``window`` values strictly before index ``i``. The first ``window`` entries
    have no full history and are NaN.

    Args:
        x: values to smooth; assumes a 1-D numpy array (callers pass ``.values``).
        window: number of preceding values averaged at each position.

    Returns:
        float array with the same length as ``x``.
    """
    smoothed = np.full(len(x), np.nan)
    for end in range(window, len(x)):
        smoothed[end] = np.mean(x[end - window:end])
    return smoothed

In [17]:
# Folder holding the per-mouse behaviour files for the ED Fig. 6 state-change
# analysis (the copy cell above fills it when working from the full data set).
repro_dir = os.path.join(
    reproduce_figures_path,
    'ED_fig6',
    'state_change_behaviour',
)

In [18]:
pre_pc = []
post_pc = []
mice = state_change_mice['tail'] + state_change_mice['Nacc']
moving_avs = []
exp_type = 'state change white noise'
# The experiment table is the same for every mouse: load and filter it once.
all_experiments = remove_unsuitable_recordings(get_all_experimental_records())
for mouse_id in mice:
    experiment_to_process = all_experiments[(all_experiments['experiment_notes'] == exp_type) & (all_experiments['mouse_id'] == mouse_id)]
    fiber_side = experiment_to_process['fiber_side'].values[0]
    # Left fibre -> trial type 7, any other side -> trial type 1 (same mapping the
    # previous np.where lookup into ['left', 'right'] produced).
    fiber_side_numeric = 7 if fiber_side == 'left' else 1
    print(fiber_side, fiber_side_numeric)
    trial_data = open_experiment_just_behaviour(experiment_to_process, root_dir=repro_dir)
    # Outcome code 3 is scored as incorrect (0) for the performance measure.
    trial_data.loc[trial_data['Trial outcome'] == 3, 'Trial outcome'] = 0
    only_contra = trial_data.loc[trial_data['Trial type'] == fiber_side_numeric]
    red_trial_data = only_contra[only_contra['State name'] == 'TrialStart']
    trial_num = red_trial_data['Trial num']
    # Pre-change window: trials 100-149; post-change window: trials 150-200 (inclusive).
    # `red_trial_data` already holds only this trial type, so no further filtering is
    # needed. The old re-derivation of the trial type from the correct post trials
    # always matched it, and crashed when a mouse had no correct post-change trial.
    pre_trials = red_trial_data[(trial_num >= 100) & (trial_num < 150)]
    post_trials = red_trial_data[(trial_num >= 150) & (trial_num <= 200)]
    post_pc.append(np.mean(post_trials['Trial outcome'].values) * 100)
    pre_pc.append(np.mean(pre_trials['Trial outcome'].values) * 100)
    moving_avs.append(moving_average(red_trial_data['Trial outcome'].values, window=20))

right 1
left 7
right 1
right 1
right 1
right 1
left 7
left 7
left 7
right 1
left 7
right 1
left 7


In [19]:
moving_avs = []
response_times = []
missed_trials = []
performance = []

mice = state_change_mice['tail'] + state_change_mice['Nacc']
exp_type = 'state change white noise'
# The experiment table is the same for every mouse: load and filter it once.
all_experiments = remove_unsuitable_recordings(get_all_experimental_records())
for mouse_id in mice:
    experiment_to_process = all_experiments[(all_experiments['experiment_notes'] == exp_type) & (all_experiments['mouse_id'] == mouse_id)]
    fiber_side = experiment_to_process['fiber_side'].values[0]
    # +1 for a right fibre, -1 otherwise (same mapping as the previous np.where lookup).
    # NOTE(review): presumably CalculateRBiasWindow returns a rightward bias, so this
    # sign turns it into the '% bias to ipsi side' shown in the bias plot - confirm.
    side = 1 if fiber_side == 'right' else -1
    trial_data = open_experiment_just_behaviour(experiment_to_process, root_dir=repro_dir)
    trial_starts = trial_data['State name'] == 'TrialStart'
    missed_outcome = trial_data['Trial outcome'] == 3
    # 1.0 for a missed trial (outcome code 3), 0.0 otherwise. This must be taken
    # BEFORE the recode below. The previous version recoded 3 -> 0 first, so its
    # later 3 -> 1 step never matched and the missed-trial trace was always 0 %.
    missed_flags = missed_outcome[trial_starts].astype(float).values
    # For performance and bias, missed trials count as incorrect.
    trial_data.loc[missed_outcome, 'Trial outcome'] = 0
    red_trial_data = trial_data[trial_starts]
    response_trial_data = trial_data[trial_data['State name'] == 'WaitForResponse']
    moving_avs.append(CalculateRBiasWindow(red_trial_data['First response'].reset_index(drop=True), red_trial_data['First choice correct'].reset_index(drop=True), 20) * side * 100)
    response_times.append(moving_average(response_trial_data['Time end'].values - response_trial_data['Time start'].values, window=20))
    missed_trials.append(moving_average(missed_flags, window=20) * 100)
    performance.append(moving_average(red_trial_data['Trial outcome'].values, window=20) * 100)




In [27]:
def plot_rolling_mean(var, mice, moving_avs, y_axis_label='% bias to big reward side',
                      state_change_trial=150, save_dir=None):
    """Plot each mouse's rolling trace plus the across-mouse mean +/- SEM.

    Args:
        var: name of the plotted variable; used only in the PDF filename when
            ``save_dir`` is given.
        mice: mouse IDs; passed to ``align_mulitple_mice_moving_avs`` to size the
            aligned array.
        moving_avs: one 1-D trace per mouse. All traces are truncated to the
            shortest one.
        y_axis_label: y-axis label.
        state_change_trial: x position of the vertical state-change marker.
        save_dir: if given, also save the figure as
            ``'<var> developing over trials.pdf'`` in this folder.

    Returns:
        The matplotlib Figure, so callers can export its plotted data.
    """
    import os
    import warnings

    all_mice, min_num_trials = align_mulitple_mice_moving_avs(mice, moving_avs)
    # The rolling-window warm-up is NaN for every mouse, so nanmean warns
    # "Mean of empty slice" there. Those positions are meant to stay NaN.
    with warnings.catch_warnings():
        warnings.simplefilter('ignore', category=RuntimeWarning)
        mean_trace = np.nanmean(all_mice, axis=0)
    error_bar_lower, error_bar_upper = calculate_error_bars(mean_trace, all_mice,
                                                            error_bar_method='sem')

    matplotlib.rc('font', size=8)
    matplotlib.rcParams['pdf.fonttype'] = 42
    matplotlib.rcParams['font.sans-serif'] = 'Arial'
    # Previously this key was only read (a no-op); assign it so the Arial
    # sans-serif setting above is what gets used.
    matplotlib.rcParams['font.family'] = 'sans-serif'

    fig, axs = plt.subplots(1, 1, figsize=[2.5, 2])
    for i, mouse_trace in enumerate(moving_avs):
        axs.plot(mouse_trace[:min_num_trials], alpha=0.5, c='gray', lw=0.5, label=f'mouse {i}')
    axs.plot(mean_trace, c='#5e8c89', lw=1, label='mean')
    axs.fill_between(np.arange(min_num_trials), error_bar_lower, error_bar_upper, alpha=0.5,
                     facecolor='#7FB5B5', linewidth=0)

    axs.axvline(state_change_trial, color='k', label='state change onset marker')
    axs.set_xlabel('trial number')
    axs.set_ylabel(y_axis_label)
    axs.spines['right'].set_visible(False)
    axs.spines['top'].set_visible(False)
    fig.tight_layout()
    if save_dir is not None:
        fig.savefig(os.path.join(save_dir, '{} developing over trials.pdf'.format(var)))

    plt.show()
    return fig

In [10]:
def calc_per_mouse_mean(mouse_moving_avs):
    """Average one mouse's session traces, position by position.

    Every session is truncated to the shortest session before averaging, and
    NaNs are ignored (``np.nanmean``).

    Args:
        mouse_moving_avs: non-empty sequence of 1-D traces, one per session.

    Returns:
        1-D float array whose length equals the shortest session.

    Raises:
        ValueError: if ``mouse_moving_avs`` is empty.
    """
    shortest = min(len(session) for session in mouse_moving_avs)
    stacked = np.full((len(mouse_moving_avs), shortest), np.nan)
    for row, session in enumerate(mouse_moving_avs):
        stacked[row] = session[:shortest]
    return np.nanmean(stacked, axis=0)

In [11]:
def align_mulitple_mice_moving_avs(mice, moving_avs):
    """Stack per-mouse traces into one NaN-initialised 2-D array.

    All traces are truncated to the length of the shortest one.

    Args:
        mice: sequence of mouse IDs; only its length is used (number of rows).
        moving_avs: non-empty sequence of 1-D traces, one per mouse, in the same
            order as ``mice``.

    Returns:
        ``(all_mice, min_num_trials)``:
        - ``all_mice`` has shape ``(len(mice), min_num_trials)``.
        - Row ``i`` holds the first ``min_num_trials`` values of ``moving_avs[i]``.
        - Rows without a trace stay NaN.

    Raises:
        ValueError: if ``moving_avs`` is empty.
        IndexError: if there are more traces than mice.
    """
    min_num_trials = min(len(m) for m in moving_avs)
    all_mice = np.full((len(mice), min_num_trials), np.nan)
    for i, mouse_data in enumerate(moving_avs):
        all_mice[i, :] = mouse_data[:min_num_trials]
    return all_mice, min_num_trials

In [28]:
bias_xl = os.path.join(spreadsheet_path, 'ED_fig6', 'ED_fig6T_bias_state_change.xlsx')
bias_fig = plot_rolling_mean(
    'bias',
    mice,
    moving_avs,
    y_axis_label='% bias to ipsi side',
)
# Export the plotted data only once; an existing spreadsheet is left untouched.
bias_already_saved = os.path.exists(bias_xl)
if not bias_already_saved:
    save_figure_data_to_excel(bias_fig, bias_xl)

  error_bar_lower, error_bar_upper = calculate_error_bars(np.nanmean(all_mice, axis=0),
  axs.plot(np.nanmean(all_mice,axis=0), c='#5e8c89', lw=1, label='mean')


Data has been saved to S:\projects\APE_data_francesca_for_paper\spreadsheets_for_nature\ED_fig6\ED_fig6T_bias_state_change.xlsx


In [29]:
response_time_xl = os.path.join(spreadsheet_path, 'ED_fig6', 'ED_fig6S_response_time_state_change.xlsx')
response_time_fig = plot_rolling_mean(
    'response time',
    mice,
    response_times,
    y_axis_label='response time (s)',
)
# Export the plotted data only once; an existing spreadsheet is left untouched.
response_time_already_saved = os.path.exists(response_time_xl)
if not response_time_already_saved:
    save_figure_data_to_excel(response_time_fig, response_time_xl)

  error_bar_lower, error_bar_upper = calculate_error_bars(np.nanmean(all_mice, axis=0),
  axs.plot(np.nanmean(all_mice,axis=0), c='#5e8c89', lw=1, label='mean')


Data has been saved to S:\projects\APE_data_francesca_for_paper\spreadsheets_for_nature\ED_fig6\ED_fig6S_response_time_state_change.xlsx


In [34]:
performance_xl = os.path.join(spreadsheet_path, 'ED_fig6', 'ED_fig6R_performance_state_change.xlsx')
performance_fig = plot_rolling_mean(
    'performance',
    mice,
    performance,
    y_axis_label='% correct',
)
# Export the plotted data only once; an existing spreadsheet is left untouched.
performance_already_saved = os.path.exists(performance_xl)
if not performance_already_saved:
    save_figure_data_to_excel(performance_fig, performance_xl)

  error_bar_lower, error_bar_upper = calculate_error_bars(np.nanmean(all_mice, axis=0),
  axs.plot(np.nanmean(all_mice,axis=0), c='#5e8c89', lw=1, label='mean')


In [31]:
# Missed-trial trace. Unlike the other rolling-mean panels, this one is not exported to a spreadsheet.
missed_trials_fig = plot_rolling_mean(
    'missed trials',
    mice,
    missed_trials,
    y_axis_label='% missed trials',
)

  error_bar_lower, error_bar_upper = calculate_error_bars(np.nanmean(all_mice, axis=0),
  axs.plot(np.nanmean(all_mice,axis=0), c='#5e8c89', lw=1, label='mean')


In [37]:
# One row per mouse: mean % correct before (trials 100-149) and after (trials 150-200) the state change.
behavioural_change = {
    'mouse': mice,
    'pre performance': pre_pc,
    'post performance': post_pc,
}
behavioural_change_df = pd.DataFrame(behavioural_change)

In [38]:
# Rows: 'pre performance' / 'post performance'; columns: one per mouse.
df_for_plot = behavioural_change_df.set_index('mouse').T

In [41]:
comparison_csv = os.path.join(
    spreadsheet_path, 'ED_fig6', 'ED_fig6O_pre_post_state_change.csv'
)
# Write only once; an existing export is never overwritten.
csv_already_written = os.path.exists(comparison_csv)
if not csv_already_written:
    # Transposing back gives one row per mouse with pre/post performance columns.
    df_for_plot.transpose().to_csv(comparison_csv)


In [16]:
from utils.plotting import output_significance_stars_from_pval
matplotlib.rc('font', size=7)
matplotlib.rcParams['pdf.fonttype'] = 42
matplotlib.rcParams['font.sans-serif'] = 'Arial'
# Previously this key was only read (a no-op); assign it so the Arial sans-serif setting applies.
matplotlib.rcParams['font.family'] = 'sans-serif'

fig, ax = plt.subplots(figsize=[1.5, 1.5])
multi_conditions_plot(ax, df_for_plot, mean_line_color='#7FB5B5', mean_linewidth=0, show_err_bar=False)
ax.set_xticks([0, 1])
ax.set_xticklabels(['Tone', 'WN'])
ax.set_ylabel('Performance (%)')

# Paired t-test across mice: pre vs post state-change performance.
pre_data = df_for_plot.T['pre performance']
post_data = df_for_plot.T['post performance']
stat, pval = stats.ttest_rel(pre_data, post_data)

# Significance bar just above the highest data point (nanmax so a missing value cannot make it NaN).
y = np.nanmax(df_for_plot.to_numpy()) + 2
h = 2
ax.plot([0, 0, 1, 1], [y, y + h, y + h, y], c='k', lw=1)
stars = output_significance_stars_from_pval(pval)
ax.text(.5, y + h, stars, ha='center', fontsize=10)
# Lay out only after the bar and stars exist, so they are included in the layout.
fig.tight_layout()
plt.show()



Text(0.5, 104.0, '***')

In [17]:
# Per-mouse performance before and after the state change (Series indexed by mouse).
pre_data = df_for_plot.loc['pre performance']
post_data = df_for_plot.loc['post performance']
stat, pval = stats.ttest_rel(pre_data, post_data)
pval

0.00019880481396597073

In [18]:
# Effect size for the paired pre/post comparison (the helper also prints it).
effect_size = cohen_d_paired(pre_data, post_data)
effect_size

cohen d:  1.4607844475894787


1.4607844475894787

In [46]:
from scipy.stats import shapiro

In [47]:
# The paired t-test assumes the pre - post differences are roughly normal; check that here.
differences = pre_data.sub(post_data)
shapiro(differences)

ShapiroResult(statistic=0.9323903918266296, pvalue=0.3661174774169922)