In [1]:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import os

In [2]:
def process_data(df):
    # process the data
    # the first 32 rows are the metadata
    metadata = df.iloc[:32,:2]
    metadata = metadata.T.reset_index(drop=True)
    metadata.columns = metadata.iloc[0]
    metadata = metadata.iloc[1:]

    # the rest is the data
    data = df.iloc[33:]
    # make the columns the first row
    data.columns = data.iloc[0]
    units = data.iloc[1]
    data = data.iloc[2:]
    data = data.reset_index(drop=True)
    data = data.replace('-', np.nan)
    # make a arm column
    arms = data[[x for x in data.columns if 'In zone' in x]]
    arms = arms.iloc[:,1:]
    # for each row, find the column with the max value
    arms = arms.idxmax(axis=1)
    arms = arms.str.replace('Center Stage', 'Arm0')
    arms = arms.fillna(-1)
    # convert to int
    converter = lambda x: int(x.replace('In zone(Arm', '').replace(' / center-point)', '')) if x != -1 else -1
    arms = arms.apply(converter).values
    # clean the data
    # cleaning rules:
    # if there is a stretch of -1s, with the same value before and after, replace the -1s with that value
    # if there is a stretch of -1s, with different values before and after, replace the -1s with the value with 0
    # if there is a stretch of -1s at the beginning or the end of the array, do nothing

    # loop through the array
    in_stretch = False
    stretch_start = 0
    for i in range(len(arms)):
        if arms[i] == -1 and not in_stretch:
            in_stretch = True
            stretch_start = i
        elif arms[i] == -1 and in_stretch:
            pass
        else:
            if in_stretch:
                if stretch_start == 0 or i == len(arms):
                    # at the beginning or the end (do nothing)
                    pass
                else:
                    # in the middle
                    if arms[stretch_start-1] == arms[i]:
                        arms[stretch_start:i] = arms[stretch_start-1]
                    else:
                        arms[stretch_start:i] = 0
                in_stretch = False
            else:
                pass

    # get encounter durations from arms
    times = data['Trial time'].values
    stretch_start = 0
    stretch_start_time = times[0]
    current_state = arms[0]


    encounter_states = []
    encounter_durations = []
    current_encounter_no = [0]
    for i in range(1,len(arms)):
        if arms[i] != current_state:
            # encounter ended
            encounter_states.append(current_state)
            encounter_durations.append(times[i] - stretch_start_time)
            current_state = arms[i]
            stretch_start_time = times[i]
            stretch_start = i
        else:
            pass
        current_encounter_no.append(len(encounter_states))

    # add the last state
    encounter_states.append(current_state)
    encounter_durations.append(times[-1] - stretch_start_time)

    # convert to numpy arrays
    encounter_states = np.array(encounter_states)
    encounter_durations = np.array(encounter_durations)
    current_encounter_no = np.array(current_encounter_no)
    current_encounter_duration = np.zeros(len(arms))
    for i in range(len(encounter_states)):
        current_encounter_duration[current_encounter_no == i] = encounter_durations[i]

    # set the arms with < 1s duration to 0
    arms[np.logical_and(np.logical_and(arms != -1, arms != 0), current_encounter_duration < 1)] = 0
    # recalculate the encounter durations
    stretch_start = 0
    stretch_start_time = times[0]
    current_state = arms[0]

    encounter_states = []
    encounter_durations = []
    current_encounter_no = [0]
    for i in range(1,len(arms)):
        if arms[i] != current_state:
            # encounter ended
            encounter_states.append(current_state)
            encounter_durations.append(times[i] - stretch_start_time)
            current_state = arms[i]
            stretch_start_time = times[i]
            stretch_start = i
        else:
            pass
        current_encounter_no.append(len(encounter_states))

    # add the last state
    encounter_states.append(current_state)
    encounter_durations.append(times[-1] - stretch_start_time)

    # convert to numpy arrays
    encounter_states = np.array(encounter_states)
    encounter_durations = np.array(encounter_durations)
    current_encounter_no = np.array(current_encounter_no)
    current_encounter_duration = np.zeros(len(arms))
    for i in range(len(encounter_states)):
        current_encounter_duration[current_encounter_no == i] = encounter_durations[i]

    # add the arms, encounter number and encounter duration to the data
    data['Arm'] = arms
    data['Encounter number'] = current_encounter_no
    data['Encounter duration'] = current_encounter_duration

    return data, metadata


In [3]:
# function to find novel sequences
def novel_sequence_finder(arm_sequence):
    novel_sequences = []
    novel_sequences_start = []
    novel_sequences_end = []
    past_set = set()
    index = 0
    i = 1
    while i < len(arm_sequence):
        new_set = set(arm_sequence[index:i])
        if len(new_set) > len(past_set):
            past_set = new_set
        else:
            # end of the sequence
            past_sequence = arm_sequence[index:i-1]
            novel_sequences.append(past_sequence)
            novel_sequences_start.append(index)
            novel_sequences_end.append(i-1)
            # get the repeated arm
            repeated_arm = arm_sequence[i-1]
            # get the last time the currently repeated arm was novel
            last_novel_index = np.where(past_sequence[::-1] == repeated_arm)[0][0]
            # reset the index to the last time the sequence was novel
            index = i-1-last_novel_index
            i = index
            # reset the past set
            past_set = set()
        i += 1
    # add the last sequence
    past_sequence = arm_sequence[index:]
    novel_sequences.append(past_sequence)
    novel_sequences_start.append(index)
    novel_sequences_end.append(len(arm_sequence)-1)

    # convert to numpy arrays
    novel_sequences_length = np.array([len(x) for x in novel_sequences])
    novel_sequences_start = np.array(novel_sequences_start)
    novel_sequences_end = np.array(novel_sequences_end)

    return novel_sequences, novel_sequences_length, novel_sequences_start, novel_sequences_end

# function to generate sequence from heuristic visitation
def generate_sequence_from_distribution(heuristic_visits, n_sequences, sequence_length):
    # print(heuristic_visits)
    heuristic_map = {
        0: -3, # 3rd right
        1: -2, # 2nd right
        2: -1, # 1st right
        3: 0, # same
        4: 1, # 1st left
        5: 2, # 2nd left
        6: 3, # 3rd left
        7: 4 # straight
    }
    heuristic_prob = np.array(heuristic_visits) / np.sum(heuristic_visits)
    sequences = []
    for i in range(n_sequences):
        sequence = []
        current_arm = np.random.choice(8)
        for j in range(sequence_length):
            next_heuristic = np.random.choice(8, p=heuristic_prob)
            sequence.append((current_arm + heuristic_map[next_heuristic]) % 8)
            current_arm = sequence[-1]
        sequences.append(sequence)
    # convert to numpy array and add 1 to the arms to make them 1-indexed
    sequences = np.array(sequences)+1
    return sequences

# function to generate sequence from heuristic transition matrix
def generate_sequence_from_transition_matrix(transition_matrix, n_sequences, sequence_length):
    heuristic_map = {
        0: -3, # 3rd right
        1: -2, # 2nd right
        2: -1, # 1st right
        3: 0, # same
        4: 1, # 1st left
        5: 2, # 2nd left
        6: 3, # 3rd left
        7: 4 # straight
    }
    # get the reduced array before the transition matrix is normalized
    reduced_array = transition_matrix.sum(axis=1)
    reduced_array = reduced_array / reduced_array.sum()
    # if any row is all zeros, replace with the reduced array
    for i in range(8):
        if transition_matrix[i].sum() == 0:
            transition_matrix[i] = reduced_array
    # normalize the transition matrix
    transition_matrix = transition_matrix / transition_matrix.sum(axis=1)[:, np.newaxis]
    # print(transition_matrix)
    sequences = []
    for i in range(n_sequences):
        sequence = []
        current_arm = np.random.choice(8)
        current_heuristic = np.random.choice(8, p=reduced_array)
        for j in range(sequence_length):
            # apply the heuristic to the current arm
            sequence.append((current_arm + heuristic_map[current_heuristic]) % 8)
            current_arm = sequence[-1]
            # get the next heuristic
            current_heuristic = np.random.choice(8, p=transition_matrix[current_heuristic])
        sequences.append(sequence)
    # convert to numpy array and add 1 to the arms to make them 1-indexed
    sequences = np.array(sequences)+1
    return sequences

In [4]:
def plot_data(data, metadata, session_id, savepath=None):
    # convert to arm sequence
    for i in range(len(data)):
        df = data[i]
        meta = metadata[i]
        mouse_id = meta['Mouse-ID'].values[0]

        # get the arm sequence
        arm_sequence = df['Arm'].values
        # remove the duplicates except where the arms before and after are the same
        diffs = np.concatenate(([1], np.diff(arm_sequence)))
        arm_sequence = arm_sequence[diffs != 0]
        # remove the -1s and 0s
        arm_sequence = arm_sequence[np.logical_and(arm_sequence != -1, arm_sequence != 0)]

        # make subplots
        fig = plt.figure(figsize=(8,15))

        ax = fig.subplots(6,1)
        # plot the arm sequence
        ax[0].scatter(df['Trial time'], df['Arm'], s=1, c=df['Arm'], cmap='tab10')
        ax[0].set_xlabel('Time')
        ax[0].set_ylabel('Arm')
        ax[0].set_xlim(0, df['Trial time'].values[-1])
        ax[0].set_ylim(-0.5, 8.5)
        ax[0].set_yticks(np.arange(0,9))
        ax[0].set_yticklabels(['Center']+[f'Arm {i}' for i in range(1,9)])
        # turn off the box and tick marks
        for spine in ax[0].spines.values():
            spine.set_visible(False)
        ax[0].tick_params(left=False)

        # turn off the other subplots
        ax[1].axis('off')
        ax[2].axis('off')
        ax[3].axis('off')
        ax[4].axis('off')
        ax[5].axis('off')
        


        # for each arm, how many times it was visited
        n_visits = np.zeros(8)
        for i in range(8):
            n_visits[i] = np.sum(arm_sequence == i+1)

        # plot the number of visits in a bar plot on the lower left
        ax = fig.add_subplot(6,2,3)
        ax.bar(np.arange(8)+1, n_visits)
        ax.axhline(np.mean(n_visits), color='black', linestyle='--')
        ax.set_xlabel('Arm')
        ax.set_ylabel('Number of visits')
        ax.set_xticks(np.arange(1,9))
        # turn off the box and tick marks
        for spine in ax.spines.values():
            spine.set_visible(False)
        ax.tick_params(left=False)
        
        # create arm to arm transition matrix
        transitions = np.zeros((8,8))
        for i in range(len(arm_sequence)-1):
            transitions[arm_sequence[i]-1, arm_sequence[i+1]-1] += 1

        # plot the transition matrix and show counts
        ax = fig.add_subplot(6,2,4)
        ax.imshow(transitions, cmap='Blues', vmin=0, vmax=8)
        for i in range(8):
            for j in range(8):
                plt.text(j, i, int(transitions[i, j]) if transitions[i, j] != 0 else '', ha='center', va='center', color='black')
        ax.set_xlabel('To Arm')
        ax.set_ylabel('From Arm')
        ax.set_xticks(np.arange(8))
        ax.set_xticklabels(np.arange(1,9))
        ax.set_yticks(np.arange(8))
        ax.set_yticklabels(np.arange(1,9))
        # turn off the box and tick marks
        for spine in ax.spines.values():
            spine.set_visible(False)
        ax.tick_params(left=False)
        ax.title.set_text('Arm to Arm Transitions')

        # turn to heuristic arm sequence
        heuristic_arm_sequence = []
        for i in range(1, len(arm_sequence)):
            # difference from the previous arm account for circularity
            shortest_distance = min(abs(arm_sequence[i] - arm_sequence[i-1]), 8 - abs(arm_sequence[i] - arm_sequence[i-1]))
            # get the sign of the difference
            sign_diff = np.sign(arm_sequence[i] - arm_sequence[i-1])
            sign_direction = +1 if min(abs(arm_sequence[i] - arm_sequence[i-1]), 8 - abs(arm_sequence[i] - arm_sequence[i-1])) == abs(arm_sequence[i] - arm_sequence[i-1]) else -1
            heuristic_arm_sequence.append(shortest_distance * sign_diff * sign_direction if shortest_distance != 4 else 4)
        heuristic_arm_sequence = np.array(heuristic_arm_sequence)

        # plot the heuristic arm sequence
        ax = fig.add_subplot(6,1,3)
        ax.scatter(np.arange(len(heuristic_arm_sequence)), heuristic_arm_sequence, s=100, c=heuristic_arm_sequence, cmap='RdBu', edgecolor='black', vmin=-3, vmax=3)
        ax.set_ylabel('Heuristic')
        ax.set_xlim(-1, len(heuristic_arm_sequence)+1)
        ax.set_xticks([])
        ax.set_ylim(-3.5, 4.5)
        ax.set_yticks([-3, -2, -1, 0, 1, 2, 3, 4])
        ax.set_yticklabels(['3rd Right', '2nd Right', '1st Right', 'Same', '1st Left', '2nd Left', '3rd Left','Straight'])
        # turn off the box and tick marks
        for spine in ax.spines.values():
            spine.set_visible(False)
        ax.tick_params(left=False)

        # for each heuristic, how many times it was utilized
        n_visits = np.zeros(8)
        for i in np.arange(-3, 5):
            n_visits[i+3] = np.sum(heuristic_arm_sequence == i)

        # plot the number of visits in a bar plot on the lower right
        ax = fig.add_subplot(6,2,7)
        ax.bar(np.arange(-3,5), n_visits)
        ax.set_xlabel('Heuristic Used')
        ax.set_ylabel('Number of occurrences')
        ax.set_xticks(np.arange(-3,5))
        ax.set_xticklabels(['3rd Right', '2nd Right', '1st Right', 'Same', '1st Left', '2nd Left', '3rd Left','Straight'], rotation=45)
        # turn off the box and tick marks
        for spine in ax.spines.values():
            spine.set_visible(False)
        ax.tick_params(left=False)

        # create heuristic to heuristic transition matrix
        transitions = np.zeros((8,8))
        for i in range(len(heuristic_arm_sequence)-1):
            transitions[heuristic_arm_sequence[i]+3, heuristic_arm_sequence[i+1]+3] += 1

        # plot the transition matrix and show counts
        ax = fig.add_subplot(6,2,8)
        ax.imshow(transitions, cmap='Blues', vmin=0, vmax=8)
        for i in range(8):
            for j in range(8):
                plt.text(j, i, int(transitions[i, j]) if transitions[i, j] != 0 else '', ha='center', va='center', color='black')
        ax.set_xlabel('To Heuristic')
        ax.set_ylabel('From Heuristic')
        ax.set_xticks(np.arange(8))
        ax.set_xticklabels(['3rd Right', '2nd Right', '1st Right', 'Same', '1st Left', '2nd Left', '3rd Left','Straight'], rotation=90)
        ax.set_yticks(np.arange(8))
        ax.set_yticklabels(['3rd Right', '2nd Right', '1st Right', 'Same', '1st Left', '2nd Left', '3rd Left','Straight'], rotation=0)
        # turn off the box and tick marks
        for spine in ax.spines.values():
            spine.set_visible(False)
        ax.tick_params(left=False)
        ax.title.set_text('Heuristic to Heuristic Transitions')

        # get the novel sequences
        novel_sequences, novel_sequences_length, novel_sequences_start, novel_sequences_end = novel_sequence_finder(arm_sequence)
        
        # plot the arm sequence (similar to the heuristic arm sequence)
        ax = fig.add_subplot(3,1,3)
        ax.scatter(np.arange(len(arm_sequence)), arm_sequence, s=100, c=arm_sequence, cmap='tab10', edgecolor='black', vmin=1, vmax=8)
        # add lines for novel sequences
        y_pos = 10
        for i in range(len(novel_sequences_start)):
            if i > 0 and novel_sequences_start[i] < novel_sequences_end[i-1]:
                y_pos += 1
            else:
                y_pos = 10
            ax.plot([novel_sequences_start[i], novel_sequences_end[i]-1], [y_pos, y_pos], color=plt.cm.RdBu(novel_sequences_length[i]/8), linewidth=3)
        ax.set_xlim(-1, len(arm_sequence)+1)
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_ylabel('Novel Sequences')
        # turn off the box and tick marks
        for spine in ax.spines.values():
            spine.set_visible(False)
        ax.tick_params(left=False)

        # generate heuristic distribution bases sequences
        seq_length = len(heuristic_arm_sequence)
        heuristic_dist_seqs = generate_sequence_from_distribution(n_visits, 100, seq_length)
        # get the novel sequence lengths
        novel_seq_len_dist = np.concatenate([novel_sequence_finder(x)[1] for x in heuristic_dist_seqs])

        # generate heuristic transition matrix based sequences
        heuristic_trans_seqs = generate_sequence_from_transition_matrix(transitions, 100, seq_length)
        # get the novel sequence lengths
        novel_seq_len_trans = np.concatenate([novel_sequence_finder(x)[1] for x in heuristic_trans_seqs]).flatten()

        # generate shuffled sequences
        shuffled_seqs = []
        for i in range(100):
            shuffled_seqs.append(np.random.permutation(arm_sequence))
        # get the novel sequence lengths
        novel_seq_len_shuffled = np.concatenate([novel_sequence_finder(x)[1] for x in shuffled_seqs]).flatten()


        # add a axis within the plot for the novel sequences histogram (top left of the novel sequences plot)
        ax = fig.add_axes([0.15, 0.2, 0.2, 0.05])
        values, counts = np.unique(novel_sequences_length, return_counts=True)
        counts = counts / np.sum(counts)
        ax.bar(values, counts, color='C0', label='Data')
        # add the mean line
        ax.axvline(np.mean(novel_sequences_length), color='black', linestyle='--')

        # plot the other novel sequence length distributions in the same axis
        values, counts = np.unique(novel_seq_len_dist, return_counts=True)
        counts = counts / np.sum(counts)
        ax.plot(values, counts, color='grey', label='H. Dist. Samples')
        values, counts = np.unique(novel_seq_len_trans, return_counts=True)
        counts = counts / np.sum(counts)
        ax.plot(values, counts, color='black', label='H. Trans. Samples')
        values, counts = np.unique(novel_seq_len_shuffled, return_counts=True)
        counts = counts / np.sum(counts)
        ax.plot(values, counts, color='red', label='Shuffled Samples')
    
        ax.set_xlabel('Novel Sequence Length')
        ax.set_ylabel('Frequency')
        ax.set_xticks(np.arange(min(values), max(values)+1))
        ax.set_xticklabels(np.arange(min(values), max(values)+1))
        # turn off the box and tick marks
        for spine in ax.spines.values():
            spine.set_visible(False)
        ax.tick_params(left=False)
        # make the background transparent
        ax.patch.set_alpha(0.0)

        # add legend
        ax.legend(loc='center right', bbox_to_anchor=(2.2, 0.5), frameon=False)



        # add suptitle
        fig.suptitle(f'Mouse {mouse_id} {session_id} Sequence Analysis', fontsize=16)
        plt.tight_layout()
        if savepath is not None:
            # create the directory if it doesn't exist
            if not os.path.exists(savepath):
                os.makedirs(savepath)
            plt.savefig(os.path.join(savepath, f'{mouse_id}_{session_id}.png'))
            plt.close()
        else:
            plt.show()


# Baseline Day

In [5]:
baseline_folder = '../data/pilot_data/2-2-24/'
# find xlsx files in the folder
data_files = os.listdir(baseline_folder)
data_files = [f for f in data_files if f.endswith('.xlsx')]
# sort the files and keep only last 4
data_files.sort()
data_files = data_files[-4:]
print(data_files)

['Raw data-8arm_pilot-Trial     2.xlsx', 'Raw data-8arm_pilot-Trial     3.xlsx', 'Raw data-8arm_pilot-Trial     4.xlsx', 'Raw data-8arm_pilot-Trial     5.xlsx']


In [6]:
# load the data
dfs = [pd.read_excel(baseline_folder + f, engine='openpyxl') for f in data_files]

In [7]:
metadata = []
data = []
for df in dfs:
    d, m = process_data(df)
    metadata.append(m)
    data.append(d)

  arms = arms.idxmax(axis=1)
  arms = arms.idxmax(axis=1)


In [8]:
# plot the data
plot_data(data, metadata, 'baseline', savepath='../processed_data/pilot_data/')

  plt.tight_layout()
  plt.tight_layout()
  plt.tight_layout()
  plt.tight_layout()


# Day 1 (No Weighing Boat)

In [9]:
day1_folder = '../data/pilot_data/2-5-24/'
# find xlsx files in the folder
data_files = os.listdir(day1_folder)
data_files = [f for f in data_files if f.endswith('.xlsx')]
# sort the files and keep only last 4
data_files.sort()
data_files = data_files[-4:]
print(data_files)

['Raw data-8arm_pilot-Trial     7.xlsx', 'Raw data-8arm_pilot-Trial     8.xlsx', 'Raw data-8arm_pilot-Trial     9.xlsx', 'Raw data-8arm_pilot-Trial    10.xlsx']


In [10]:
# load the data
dfs = [pd.read_excel(day1_folder + f, engine='openpyxl') for f in data_files]

In [11]:
metadata = []
data = []
for df in dfs:
    d, m = process_data(df)
    metadata.append(m)
    data.append(d)

  arms = arms.idxmax(axis=1)


In [12]:
# plot the data
plot_data(data, metadata, 'day1', savepath='../processed_data/pilot_data/')

  plt.tight_layout()
  plt.tight_layout()
  plt.tight_layout()
  plt.tight_layout()


# Day 2 (Weighing Boat added)

In [13]:
day2_folder = '../data/pilot_data/2-6-24/'
# find xlsx files in the folder
data_files = os.listdir(day2_folder)
data_files = [f for f in data_files if f.endswith('.xlsx')]
# sort the files and keep only last 4
data_files.sort()
data_files = data_files[-4:]
print(data_files)

['Raw data-8arm_pilot-Trial    11.xlsx', 'Raw data-8arm_pilot-Trial    12.xlsx', 'Raw data-8arm_pilot-Trial    13.xlsx', 'Raw data-8arm_pilot-Trial    14.xlsx']


In [14]:
# load the data
dfs = [pd.read_excel(day2_folder + f, engine='openpyxl') for f in data_files]

In [15]:
metadata = []
data = []
for df in dfs:
    d, m = process_data(df)
    metadata.append(m)
    data.append(d)

  arms = arms.idxmax(axis=1)
  arms = arms.idxmax(axis=1)


In [16]:
# plot the data
plot_data(data, metadata, 'day2', savepath='../processed_data/pilot_data/')

  plt.tight_layout()
  plt.tight_layout()
  plt.tight_layout()
  plt.tight_layout()


# Day 3 (Weighing Boat present)

In [17]:
day3_folder = '../data/pilot_data/2-7-24/'
# find xlsx files in the folder
data_files = os.listdir(day3_folder)
data_files = [f for f in data_files if f.endswith('.xlsx')]
# sort the files and keep only last 4
data_files.sort()
data_files = data_files[-4:]
print(data_files)

['Raw data-8arm_pilot-Trial    15.xlsx', 'Raw data-8arm_pilot-Trial    16.xlsx', 'Raw data-8arm_pilot-Trial    17.xlsx', 'Raw data-8arm_pilot-Trial    18.xlsx']


In [18]:
# load the data
dfs = [pd.read_excel(day3_folder + f, engine='openpyxl') for f in data_files]

In [19]:
metadata = []
data = []
for df in dfs:
    d, m = process_data(df)
    metadata.append(m)
    data.append(d)

  arms = arms.idxmax(axis=1)
  arms = arms.idxmax(axis=1)
  arms = arms.idxmax(axis=1)


In [20]:
# plot the data
plot_data(data, metadata, 'day3', savepath='../processed_data/pilot_data/')

  plt.tight_layout()
  plt.tight_layout()
  plt.tight_layout()
  plt.tight_layout()


# Day 4 (Weighing Boat present)

In [21]:
day4_folder = '../data/pilot_data/2-8-24/'
# find xlsx files in the folder
data_files = os.listdir(day4_folder)
data_files = [f for f in data_files if f.endswith('.xlsx')]
# sort the files and keep only last 4
data_files.sort()
data_files = data_files[-4:]
print(data_files)

['Raw data-8arm_pilot-Trial    19.xlsx', 'Raw data-8arm_pilot-Trial    20.xlsx', 'Raw data-8arm_pilot-Trial    21.xlsx', 'Raw data-8arm_pilot-Trial    22.xlsx']


In [22]:
# load the data
dfs = [pd.read_excel(day4_folder + f, engine='openpyxl') for f in data_files]

In [23]:
metadata = []
data = []
for df in dfs:
    d, m = process_data(df)
    metadata.append(m)
    data.append(d)

  arms = arms.idxmax(axis=1)
  arms = arms.idxmax(axis=1)
  arms = arms.idxmax(axis=1)


In [24]:
# plot the data
plot_data(data, metadata, 'day4', savepath='../processed_data/pilot_data/')

  plt.tight_layout()
  plt.tight_layout()
  plt.tight_layout()
  plt.tight_layout()


# Day 5 (Remove after visiting all 4 rewarded arms) - Session 1

In [33]:
day5_ses1_folder = '../data/pilot_data/2-9-24-1/'
# find xlsx files in the folder
data_files = os.listdir(day5_ses1_folder)
data_files = [f for f in data_files if f.endswith('.xlsx')]
# sort the files and keep only last 4
data_files.sort()
data_files = data_files[-4:]
print(data_files)

['Raw data-8arm_pilot-Trial    23.xlsx', 'Raw data-8arm_pilot-Trial    24.xlsx', 'Raw data-8arm_pilot-Trial    25.xlsx', 'Raw data-8arm_pilot-Trial    26.xlsx']


In [34]:
# load the data
dfs = [pd.read_excel(day5_ses1_folder + f, engine='openpyxl') for f in data_files]

In [35]:
metadata = []
data = []
for df in dfs:
    d, m = process_data(df)
    metadata.append(m)
    data.append(d)

  arms = arms.idxmax(axis=1)
  arms = arms.idxmax(axis=1)
  arms = arms.idxmax(axis=1)
  arms = arms.idxmax(axis=1)


In [36]:
# plot the data
plot_data(data, metadata, 'day5-ses1', savepath='../processed_data/pilot_data/')

  plt.tight_layout()
  plt.tight_layout()
  plt.tight_layout()
  plt.tight_layout()


# Day 5 (Remove after visiting all 4 rewarded arms) - Session 2

In [37]:
day5_ses2_folder = '../data/pilot_data/2-9-24-2/'
# find xlsx files in the folder
data_files = os.listdir(day5_ses2_folder)
data_files = [f for f in data_files if f.endswith('.xlsx')]
# sort the files and keep only last 4
data_files.sort()
data_files = data_files[-4:]
print(data_files)

['Raw data-8arm_pilot-Trial    27.xlsx', 'Raw data-8arm_pilot-Trial    28.xlsx', 'Raw data-8arm_pilot-Trial    29.xlsx', 'Raw data-8arm_pilot-Trial    30.xlsx']


In [38]:
# load the data
dfs = [pd.read_excel(day5_ses2_folder + f, engine='openpyxl') for f in data_files]

In [39]:
metadata = []
data = []
for df in dfs:
    d, m = process_data(df)
    metadata.append(m)
    data.append(d)

  arms = arms.idxmax(axis=1)


In [40]:
# plot the data
plot_data(data, metadata, 'day5-ses2', savepath='../processed_data/pilot_data/')

  plt.tight_layout()
  plt.tight_layout()
  plt.tight_layout()
  plt.tight_layout()
