In [None]:
import matplotlib.pyplot as plt
import h5py
import numpy as np

import tensorflow as tf
from tensorboard.backend.event_processing.event_accumulator import EventAccumulator
from scipy.signal import filtfilt
import seaborn as sns
from scipy.signal import find_peaks

plt.rcParams['svg.fonttype'] = 'none' # To get editable text in Illustrator

# Find the per-episode evaluation log files (named exactly 'logs_<N>.hdf5') in a directory.
def get_hdf5_files(directory):
    """Return paths of the files named exactly ``logs_<N>.hdf5`` in *directory*.

    Only names that match the whole pattern are accepted, so e.g. ``logs_3.hdf5.bak``
    is ignored. The result is sorted by the integer ``N``, so callers that stack
    per-episode results get a reproducible episode order (``os.listdir`` order is
    arbitrary).

    Parameters
    ----------
    directory : str
        Directory to scan (not recursive).

    Returns
    -------
    list of str
        ``os.path.join(directory, name)`` for every matching file, ordered by ``N``.
    """
    import os
    import re

    pattern = re.compile(r'logs_(\d+)\.hdf5')
    indexed = []
    for name in os.listdir(directory):
        match = pattern.fullmatch(name)
        if match:
            indexed.append((int(match.group(1)), os.path.join(directory, name)))
    return [path for _, path in sorted(indexed)]

def get_action_energy(impulse_factor, angle_factor, actions_file='../actions_all_bouts_with_null.h5'):
    """Return the energy cost of every action, in action-id order.

    Cost = impulse * impulse_factor + |angle| * angle_factor, where for each action
    ``mean[0]`` is treated as the bout distance and ``mean[1]`` as the bout angle.

    Parameters
    ----------
    impulse_factor, angle_factor : float
        Weights of the impulse and turn-angle terms (taken from the env config).
    actions_file : str, optional
        HDF5 file with the action definitions, read via ``get_actions``.

    Returns
    -------
    numpy.ndarray
        One energy value per action.
    """
    *_, action_means = get_actions(actions_file)
    action_means = np.asarray(action_means)
    distances = action_means[:, 0]
    angles = action_means[:, 1]
    # NOTE(review): the x10 factor and 0.3445... constant look like a unit conversion
    # from bout distance to impulse -- confirm against the environment code.
    impulses = (distances * 10) * 0.34452532909386484
    return impulses * impulse_factor + np.abs(angles) * angle_factor

def get_exploratory_energy(hdf5_file):
    """Per-step energy spent on non-capture actions, plus energy trace and null mask.

    Returns
    -------
    explore_cost : numpy.ndarray of float
        Same shape as the episode's ``action`` dataset. Holds the action's energy cost
        at each step, or NaN where the action is a capture swim ('LCS' / 'SCS').
    energy : numpy.ndarray
        Column 1 of ``internal_state`` (the fish's energy).
    null_mask : numpy.ndarray of bool
        True where the action id is 20 (the null action).
    """
    action_names, _, _, _, _ = get_actions('../actions_all_bouts_with_null.h5')

    with h5py.File(hdf5_file, 'r') as h5:
        env_cfg = dict(h5['env_variables'].attrs)
        chosen = h5['action'][:]
        state = h5['internal_state'][:]

    cost_per_action = get_action_energy(env_cfg['energy_impulse_factor'], env_cfg['energy_angle_factor'])

    explore_cost = np.full(chosen.shape, np.nan)
    capture_swims = ('LCS', 'SCS')
    for idx, name in enumerate(action_names):
        if name in capture_swims:
            continue  # capture swims are not counted as exploration
        explore_cost[chosen == idx] = cost_per_action[idx]

    return explore_cost, state[:, 1], chosen == 20

def get_reward_split(hdf5_file):
    """Decompose one episode's reward into its components.

    Each component is the sum over the episode of the relevant signal times its
    reward factor from ``env_variables``. The energy component only counts energy
    decreases (negative step-to-step changes of ``internal_state[:, 1]``).

    Returns
    -------
    dict
        Keys 'salt_reward', 'prey_reward', 'pred_reward', 'energy_reward', and
        'episode_return' (the value stored in the file).
    """
    with h5py.File(hdf5_file, 'r') as h5:
        cfg = dict(h5['env_variables'].attrs)
        survived_predator = h5['event_survived_predator'][:]
        captured_by_predator = h5['event_captured_by_predator'][:]
        consumed_prey = h5['event_consumed_prey'][:]
        state = h5['internal_state'][:]
        stored_return = h5['episode_return'][()]

    energy = state[:, 1]
    energy_delta = np.diff(energy, prepend=energy[0])
    total_energy_loss = energy_delta[energy_delta < 0].sum()

    return {
        'salt_reward': np.sum(state[:, 2] * cfg['reward_salt_factor']),
        'prey_reward': np.sum(consumed_prey * cfg['reward_consumption']),
        'pred_reward': np.sum(survived_predator * cfg['reward_predator_avoidance']
                              + captured_by_predator * cfg['reward_predator_caught']),
        'energy_reward': total_energy_loss * cfg['reward_energy_use_factor'],
        'episode_return': stored_return,
    }
    
def trig_actions(actions, trigs, seq_start, seq_end, fill_value=np.nan):
    """Cut a fixed window of ``actions`` around every True entry of ``trigs``.

    Row ``i`` holds ``actions[ind + seq_start : ind + seq_end]`` for the i-th trigger
    index ``ind``. Column ``j`` therefore always corresponds to offset
    ``seq_start + j`` relative to the trigger, including when the window is clipped at
    either end of the series. Positions outside ``actions`` are set to ``fill_value``.
    The default is NaN rather than 0, because 0 is itself a valid action id.

    Parameters
    ----------
    actions : 1-D array
        Action id per timestep.
    trigs : 1-D bool array
        Trigger mask indexed like ``actions``.
    seq_start, seq_end : int
        Window bounds relative to the trigger (``seq_start`` inclusive, ``seq_end`` exclusive).
    fill_value : float, optional
        Value used for window positions that fall outside ``actions``.

    Returns
    -------
    numpy.ndarray of float, shape (n_triggers, seq_end - seq_start)
    """
    trig_ind = np.flatnonzero(trigs)
    n_steps = len(actions)
    seqs = np.full((len(trig_ind), seq_end - seq_start), fill_value, dtype=float)
    for row, ind in enumerate(trig_ind):
        win_start = ind + seq_start
        win_end = ind + seq_end
        src_start = max(0, win_start)
        src_end = min(n_steps, win_end)
        if src_end <= src_start:
            continue
        # Shift into the row so samples keep their position relative to the trigger.
        seqs[row, src_start - win_start:src_end - win_start] = actions[src_start:src_end]
    return seqs

def trig_fish_frame(fish_x, fish_y, fish_ori, obj_x, obj_y, trigs, offset=1, lim=17):
    """Object positions in the fish's egocentric frame at triggered timesteps.

    The trigger at step ``t + offset`` is paired with the fish/object state at step
    ``t``. With the default ``offset=1``, that is the state one step before the
    triggered action. Each fish→object vector is rotated by ``-fish_ori`` and kept
    only if both rotated coordinates lie strictly inside ``(-lim, lim)``. NaN
    positions are dropped by that test.

    Parameters
    ----------
    fish_x, fish_y, fish_ori : 1-D arrays, one value per timestep.
    obj_x, obj_y : 2-D arrays, (timesteps, n_objects).
    trigs : 1-D bool array, same length as ``fish_x``.
    offset : int, optional
        Non-negative lag between state and trigger. ``0`` is supported.
    lim : float, optional
        Half-width of the square window kept around the fish.

    Returns
    -------
    numpy.ndarray, shape (n_kept, 2)
    """
    if offset < 0:
        raise ValueError('offset must be non-negative')
    # Use explicit lengths rather than x[:-offset]: x[:-0] is empty.
    n = len(fish_x) - offset
    sel = np.asarray(trigs)[offset:]

    fx = np.asarray(fish_x)[:n][sel]
    fy = np.asarray(fish_y)[:n][sel]
    ori = np.asarray(fish_ori)[:n][sel]
    ox = np.asarray(obj_x)[:n, :][sel, :]
    oy = np.asarray(obj_y)[:n, :][sel, :]

    # Fish->object vectors, shape (n_triggers, n_objects).
    dx = ox - fx[:, np.newaxis]
    dy = oy - fy[:, np.newaxis]

    # Rotate by -ori into the fish reference frame.
    c = np.cos(-ori)[:, np.newaxis]
    s = np.sin(-ori)[:, np.newaxis]
    vectors = np.stack([c * dx - s * dy, s * dx + c * dy], axis=-1)

    mask = (np.abs(vectors[..., 0]) < lim) & (np.abs(vectors[..., 1]) < lim)
    return vectors[mask]

def read_events_file(events_file, tag):
    """Load a scalar tensor series from a TensorBoard events file.

    Returns
    -------
    numpy.ndarray, shape (n_events, 2)
        Column 0 is the step, column 1 the tensor value as a float.
    """
    accumulator = EventAccumulator(events_file, size_guidance={'tensors': 0})
    accumulator.Reload()
    rows = []
    for _wall_time, step, tensor_proto in accumulator.Tensors(tag):
        rows.append([step, float(tf.make_ndarray(tensor_proto))])
    return np.array(rows)

def get_turn_streaks(dir, length=25, turn_threshold=0.18, margin=500):
    """Cumulative turn-direction persistence after each left/right switch.

    For every episode in *dir*:
    - Keep only steps whose heading change exceeds ``turn_threshold`` in absolute value.
    - Find the points where the turn sign flips.
    - For each flip away from the walls, with at least ``length`` later turns, record
      the cumulative sum of the following ``length - 1`` turn signs, each multiplied
      by the sign of the switching turn.

    Parameters
    ----------
    dir : str
        Directory containing ``logs_<N>.hdf5`` episodes.
    length : int, optional
        Number of columns in the output (including the leading zero).
    turn_threshold : float, optional
        Minimum |heading change| for a step to count as a turn.
    margin : float, optional
        Switch points closer than this to any arena edge are ignored (same units as
        ``fish_x`` / ``arena_width``).

    Returns
    -------
    streaks : numpy.ndarray, shape (n_streaks, length)
        Column 0 is 0. Returns an empty (0, length) array when no streak qualifies.
    eligible_turns : int
        Total number of steps that passed ``turn_threshold``.
    """
    streaks = []
    eligible_turns = 0
    for file in get_hdf5_files(dir):
        with h5py.File(file, 'r') as f:
            ori = f['fish_angle'][:]
            fish_x = f['fish_x'][:]
            fish_y = f['fish_y'][:]
            cfg = dict(f['env_variables'].attrs)
        turn_angle = np.diff(ori)
        is_turn = np.abs(turn_angle) > turn_threshold
        eligible_turns += np.sum(is_turn)
        # position after each turn
        turn_x = fish_x[1:][is_turn]
        turn_y = fish_y[1:][is_turn]
        turn_dir = np.where(turn_angle[is_turn] > 0, 1, -1)
        switch_points = np.flatnonzero(turn_dir[1:] != turn_dir[:-1]) + 1
        x_max = cfg['arena_width'] - margin
        y_max = cfg['arena_height'] - margin
        for s in switch_points:
            away_from_walls = margin < turn_x[s] < x_max and margin < turn_y[s] < y_max
            if away_from_walls and s + length < len(turn_dir):
                streaks.append(np.cumsum(turn_dir[s + 1:s + length] * turn_dir[s]))
    # reshape keeps the array 2-D even when no streaks were found, so hstack works
    streaks = np.asarray(streaks, dtype=float).reshape(-1, length - 1)
    streaks = np.hstack((np.zeros((streaks.shape[0], 1)), streaks))
    return streaks, eligible_turns

def get_reward_composition(dir_env):
    """Collect per-episode reward components and episode lengths for one directory.

    Returns
    -------
    rewards : dict of numpy.ndarray
        One array per reward key of ``get_reward_split``, one entry per episode file.
    durations : numpy.ndarray
        ``episode_length`` of each episode, in the same order.
    """
    reward_keys = ('salt_reward', 'prey_reward', 'pred_reward', 'energy_reward', 'episode_return')
    collected = {key: [] for key in reward_keys}
    lengths = []
    for path in get_hdf5_files(dir_env):
        split = get_reward_split(path)
        for key in split:
            collected[key].append(np.array(split[key]))
        with h5py.File(path, 'r') as h5:
            lengths.append(h5['episode_length'][()])
    lengths = np.array(lengths)
    rewards = {key: np.array(values) for key, values in collected.items()}
    return rewards, lengths

def get_dirs(training_env, eval_env, n_agents=8,
             base_dir='/media/gf/DBIO_Bianco_Lab8/temp_data_storage/gcp_output'):
    """Evaluation-log directories for every trained agent of one training environment.

    Parameters
    ----------
    training_env : str
        Environment the agents were trained in (e.g. 'normal', 'sparse').
    eval_env : str
        Evaluation condition (e.g. 'normal', 'sparse', 'empty').
    n_agents : int, optional
        Number of agents, numbered 1..n_agents.
    base_dir : str, optional
        Root of the training output tree.

    Returns
    -------
    list of str
        ``{base_dir}/stage2_{training_env}_nohdf5/stage2_{i}/logs/eval_{eval_env}/``
        for i = 1..n_agents.
    """
    return [
        f'{base_dir}/stage2_{training_env}_nohdf5/stage2_{agent}/logs/eval_{eval_env}/'
        for agent in range(1, n_agents + 1)
    ]

def get_actions(actions_file):
    """Load action definitions from an HDF5 file with one group per action.

    Each group provides 'mean' and 'cov' datasets and 'is_turn', 'is_capture',
    'color' and 'id' attributes. Plot colours are adjusted by name:
    - names containing '_L' are scaled by 1.1;
    - names containing '_R' are scaled by 0.9;
    - names containing 'Null' are set to black;
    - every colour is then clipped to [0, 1].
    Actions are returned sorted by their 'id' attribute.

    Returns
    -------
    (names, colors, action_dicts, is_turn_flags, means)
        Parallel lists, one entry per action, in id order.
    """
    records = []
    with h5py.File(actions_file, 'r') as h5:
        for name in h5.keys():
            grp = h5[name]
            color = grp.attrs['color']
            if '_L' in name:
                color *= 1.1
            if '_R' in name:
                color *= 0.9
            if 'Null' in name:
                color = np.array([0, 0, 0])
            entry = {
                'name': name,
                'mean': grp['mean'][:],
                'cov': grp['cov'][:],
                'is_turn': grp.attrs['is_turn'],
                'is_capture': grp.attrs['is_capture'],
                'color': np.clip(color, 0, 1),
            }
            records.append((grp.attrs['id'], entry))

    action_types = [entry for _, entry in sorted(records, key=lambda rec: rec[0])]
    names = [entry['name'] for entry in action_types]
    colors = [entry['color'] for entry in action_types]
    turn_flags = [entry['is_turn'] for entry in action_types]
    means = [entry['mean'] for entry in action_types]
    return names, colors, action_types, turn_flags, means

def get_trig_space(dir, lim=17, n_actions=21):
    """Egocentric prey and predator positions preceding each action, pooled over episodes.

    For every action id, ``trig_fish_frame`` is applied to all episodes in *dir*. The
    resulting fish-frame (x, y) positions of prey and of the predator are stacked.

    Parameters
    ----------
    dir : str
        Directory containing ``logs_<N>.hdf5`` episodes.
    lim : float, optional
        Half-width of the window passed to ``trig_fish_frame``.
    n_actions : int, optional
        Number of action ids (0..n_actions-1) to collect.

    Returns
    -------
    trig_space_prey, trig_space_pred : list of numpy.ndarray
        One (n_points, 2) array per action id.
    """
    # Collect chunks and concatenate once, instead of re-concatenating every iteration.
    prey_chunks = [[np.zeros((0, 2))] for _ in range(n_actions)]
    pred_chunks = [[np.zeros((0, 2))] for _ in range(n_actions)]
    for file in get_hdf5_files(dir):
        with h5py.File(file, 'r') as f:
            ori = f['fish_angle'][:]
            fish_x = f['fish_x'][:]
            fish_y = f['fish_y'][:]
            # single predator -> make it a 1-object column so it matches the prey layout
            pred_x = f['predator_x'][:][:, np.newaxis]
            pred_y = f['predator_y'][:][:, np.newaxis]
            action = f['action'][:]
            prey_x = f['prey_x'][:]
            prey_y = f['prey_y'][:]
        for action_id in range(n_actions):
            is_action = action == action_id
            prey_pts = trig_fish_frame(fish_x, fish_y, ori, prey_x, prey_y, is_action, lim=lim)
            prey_chunks[action_id].append(np.reshape(prey_pts, (-1, 2)))
            pred_pts = trig_fish_frame(fish_x, fish_y, ori, pred_x, pred_y, is_action, lim=lim)
            pred_chunks[action_id].append(np.reshape(pred_pts, (-1, 2)))
    trig_space_prey = [np.concatenate(chunks, axis=0) for chunks in prey_chunks]
    trig_space_pred = [np.concatenate(chunks, axis=0) for chunks in pred_chunks]
    return trig_space_prey, trig_space_pred

def get_sequences(dir, steps_pre=20, steps_post=5):
    """Action sequences around predator-avoidance, capture, predator-onset and prey events.

    Windows run from ``-steps_pre`` to ``steps_post - 1`` steps around each event and
    are built by ``trig_actions``. A predator onset is a step where ``predator_x``
    becomes > 0 after not being > 0. Onsets closer than
    ``max(steps_pre, steps_post)`` steps to the previous onset are dropped.

    Returns
    -------
    avoid_seq, caught_seq, pred_seq, consume_seq : numpy.ndarray
        Each of shape (n_events, steps_pre + steps_post), pooled over all episodes.
        Arrays are empty (0 rows) when *dir* has no episodes.
    """
    width = steps_pre + steps_post
    avoid_seq, caught_seq, pred_seq, consume_seq = [], [], [], []
    for file in get_hdf5_files(dir):
        with h5py.File(file, 'r') as f:
            survived = f['event_survived_predator'][:]
            captured = f['event_captured_by_predator'][:]
            consumed = f['event_consumed_prey'][:]
            pred_x = f['predator_x'][:]
            action = f['action'][:]

        # Leading False so a predator present at step 0 is not an onset; the index of
        # each positive diff is then the first step with the predator present.
        present = np.concatenate(([False], pred_x > 0))
        onsets = np.flatnonzero(np.diff(present.astype(int)) > 0)
        gaps = np.concatenate(([np.inf], np.diff(onsets)))
        onsets = onsets[gaps > max(steps_pre, steps_post)]
        pred_trig = np.zeros(len(pred_x), dtype=bool)  # same length as the episode
        pred_trig[onsets] = True

        avoid_seq.append(trig_actions(action, survived > 0, -steps_pre, steps_post))
        consume_seq.append(trig_actions(action, consumed > 0, -steps_pre, steps_post))
        pred_seq.append(trig_actions(action, pred_trig, -steps_pre, steps_post))
        caught_seq.append(trig_actions(action, captured > 0, -steps_pre, steps_post))

    def _stack(chunks):
        # np.concatenate([]) raises; return an empty, correctly-shaped array instead.
        return np.concatenate(chunks) if chunks else np.zeros((0, width))

    return _stack(avoid_seq), _stack(caught_seq), _stack(pred_seq), _stack(consume_seq)

def plot_episode(file, arrows=False, prey_density=False):
    """Plot one episode's trajectory on a new 12x12 figure, coloured by action.

    Positions are divided by 10 before plotting. The figure also shows:
    - red markers where the predator appears (``predator_x`` switches to > 0);
    - green markers where prey is consumed;
    - a shaded band for the dark fraction of the arena;
    - an L-shaped '2 cm' scale bar.

    Parameters
    ----------
    file : str
        Path to a ``logs_<N>.hdf5`` episode.
    arrows : bool, optional
        Draw each step as an arrow instead of a line segment.
    prey_density : bool, optional
        Underlay a 2-D KDE of the initial prey positions.
    """
    _, action_colors, _, _, _ = get_actions('../actions_all_bouts_with_null.h5')
    with h5py.File(file, 'r') as f:
        cfg = dict(f['env_variables'].attrs)
        fish_x = f['fish_x'][:] / 10
        fish_y = f['fish_y'][:] / 10
        prey_x = f['prey_x'][:]
        prey_y = f['prey_y'][:]
        pred_x = f['predator_x'][:]
        actions = f['action'][:]
        event_consumed_prey = f['event_consumed_prey'][:]

    plt.figure(figsize=(12, 12))
    w = cfg['arena_width'] / 10
    h = cfg['arena_height'] / 10
    if prey_density:
        # 2D KDE of initial prey locations
        sns.kdeplot(x=prey_x[0, :] / 10, y=prey_y[0, :] / 10, fill=True, cmap="Blues", thresh=0.5,
                    bw_adjust=0.7, levels=30, alpha=0.5, antialiased=True)

    for step in range(1, len(actions)):
        color = action_colors[actions[step]]
        if arrows:
            plt.arrow(fish_x[step - 1], fish_y[step - 1],
                      (fish_x[step] - fish_x[step - 1]) * 0.8, (fish_y[step] - fish_y[step - 1]) * 0.8,
                      head_width=1, head_length=1, fc=color, ec=color, linewidth=1, zorder=1)
        else:
            plt.plot(fish_x[step - 1:step + 1], fish_y[step - 1:step + 1], color=color, linewidth=1.5, zorder=1)

    # Only absent -> present transitions; np.diff on a bool array would also mark
    # the step where the predator disappears.
    present = pred_x > 0
    onsets = np.flatnonzero(~present[:-1] & present[1:]) + 1
    consumed = event_consumed_prey == 1
    plt.scatter(fish_x[onsets], fish_y[onsets], c='red', label='Predator appears', s=40, edgecolors='k', zorder=10)
    plt.scatter(fish_x[consumed], fish_y[consumed], c='green', label='Prey consumed', s=40, edgecolors='k', zorder=10)
    # semi-transparent band for the dark part of the arena (cfg['arena_dark_fraction'])
    plt.fill_between([0, w], 0, h * cfg['arena_dark_fraction'], color='k', alpha=0.5)
    plt.plot([10, 10, 30], [200, 220, 220], '-k', linewidth=2, label='2 cm')
    plt.xlim(0, w)
    plt.ylim(0, h)
    plt.xticks([])
    plt.yticks([])
    ax = plt.gca()
    ax.set_aspect('equal', adjustable='box')
    ax.invert_yaxis()




In [None]:
# Inspect the relative energy cost of each action, using the energy factors of one example episode.
hdf5_file = '/media/gf/DBIO_Bianco_Lab8/temp_data_storage/gcp_output/stage2_sparse_nohdf5/stage2_1/logs/eval_sparse/logs_15.hdf5'
action_names, action_colors, action_types, action_is_turn, action_means = get_actions('../actions_all_bouts_with_null.h5')

with h5py.File(hdf5_file, 'r') as f:
    cfg = dict(f['env_variables'].attrs)
    action = f['action'][:]
    internal_state = f['internal_state'][:]

action_energy = get_action_energy(cfg['energy_impulse_factor'], cfg['energy_angle_factor'])

# One horizontal bar per action, drawn at the height of its energy cost in the action's colour.
plt.figure(figsize=(6, 4))
for energy_level, bar_color in zip(action_energy, action_colors):
    plt.plot([0, 1], [energy_level, energy_level], color=bar_color, linewidth=6)
plt.xlim(0, 1)
plt.ylim(0, action_energy.max() * 1.1)
plt.ylabel('Action energy (a.u.)')
plt.xticks([])
plt.yticks([])
plt.show()

In [None]:
# Conditional exploratory energy (without fixing energy levels).
# For each agent, pool all timesteps, split them into energy deciles, and compute per decile:
# - the mean energy spent on non-capture actions;
# - the fraction of null actions.

n_agents = 8
n_energy_bins = 10

conditional_explore_en = {
    'normal': np.zeros((n_agents, n_energy_bins)),
    'sparse': np.zeros((n_agents, n_energy_bins)),
}

conditional_null_fraction = {
    'normal': np.zeros((n_agents, n_energy_bins)),
    'sparse': np.zeros((n_agents, n_energy_bins)),
}

for env in ['normal', 'sparse']:
    dirs = get_dirs(env, 'empty')

    for dataset in range(len(dirs)):
        print(f'Processing dataset {dataset+1}/{len(dirs)}')
        all_explore_en = []
        all_energy = []
        all_nulls = []
        for file in get_hdf5_files(dirs[dataset]):
            explore_en, energy, nulls = get_exploratory_energy(file)
            all_explore_en.append(explore_en)
            all_energy.append(energy)
            all_nulls.append(nulls)
        all_explore_en = np.concatenate(all_explore_en)
        all_energy = np.concatenate(all_energy)
        all_nulls = np.concatenate(all_nulls)

        energy_edges = np.percentile(all_energy, np.linspace(0, 100, n_energy_bins + 1))
        mean_explore_en = []
        null_fractions = []
        for i in range(n_energy_bins):
            lower_ok = all_energy >= energy_edges[i]
            # The last decile must include its upper edge (the maximum energy). A strict '<'
            # would drop the maximum and every timestep tied with it, e.g. at full energy.
            if i == n_energy_bins - 1:
                upper_ok = all_energy <= energy_edges[i + 1]
            else:
                upper_ok = all_energy < energy_edges[i + 1]
            in_bin = lower_ok & upper_ok
            n_in_bin = np.sum(in_bin)
            mean_explore_en.append(np.nanmean(all_explore_en[in_bin]))
            null_fractions.append(np.sum(all_nulls[in_bin]) / n_in_bin if n_in_bin > 0 else np.nan)

        conditional_explore_en[env][dataset] = np.array(mean_explore_en)
        conditional_null_fraction[env][dataset] = np.array(null_fractions)

In [None]:
# Energy-decile curves: thin lines = individual agents, thick lines = mean across agents.
# Each of the 10 bins spans one decile of energy, so plot them at the decile centres
# (5, 15, ..., 95 %). np.linspace(0, 100, 10) would misplace every bin.
bins = np.linspace(5, 95, 10)

plt.subplot(1, 2, 1)
plt.plot(bins, conditional_explore_en['normal'].T, color='C0', alpha=0.3)
plt.plot(bins, conditional_explore_en['sparse'].T, color='C1', alpha=0.3)
plt.plot(bins, np.mean(conditional_explore_en['normal'], axis=0), '-o', color='C0', linewidth=3, label='Dense')
plt.plot(bins, np.mean(conditional_explore_en['sparse'], axis=0), '-o', color='C1', linewidth=3, label='Patchy')
plt.xlabel('Energy percentile (%)')
plt.ylabel('Exploratory energy (a.u.)')
plt.yticks([])
plt.legend()

plt.subplot(1, 2, 2)
plt.plot(bins, conditional_null_fraction['normal'].T, color='C0', alpha=0.3)
plt.plot(bins, conditional_null_fraction['sparse'].T, color='C1', alpha=0.3)
plt.plot(bins, np.mean(conditional_null_fraction['normal'], axis=0), '-o', color='C0', linewidth=3, label='Dense')
plt.plot(bins, np.mean(conditional_null_fraction['sparse'], axis=0), '-o', color='C1', linewidth=3, label='Patchy')
plt.xlabel('Energy percentile (%)')
plt.ylabel('Null action fraction')


In [None]:
# Conditional exploratory energy with the energy level held fixed by the evaluation condition:
# one mean exploratory-energy value per agent for each (training env, low/high energy) pair.

conditional_explore_en = {
    f'{train_env}_{test_env}': np.zeros(8)
    for train_env in ('normal', 'sparse')
    for test_env in ('empty_low_energy', 'empty_high_energy')
}
for train_env in ['normal', 'sparse']:
    for test_env in ['empty_low_energy', 'empty_high_energy']:
        dirs = get_dirs(train_env, test_env)
        condition = f'{train_env}_{test_env}'

        for dataset, agent_dir in enumerate(dirs):
            print(f'Processing dataset {dataset+1}/{len(dirs)}')
            per_episode = [get_exploratory_energy(path)[0] for path in get_hdf5_files(agent_dir)]
            conditional_explore_en[condition][dataset] = np.nanmean(np.concatenate(per_episode))


In [None]:
%matplotlib qt
# Low vs. high fixed energy: faint lines per agent, bold lines for the across-agent mean.
plt.figure(figsize=(4, 6))
env_styles = (('normal', 'C0', 'Dense'), ('sparse', 'C1', 'Patchy'))
for prefix, colour, _ in env_styles:
    low = conditional_explore_en[f'{prefix}_empty_low_energy']
    high = conditional_explore_en[f'{prefix}_empty_high_energy']
    plt.plot([0, 1], [low, high], '-o', alpha=0.3, color=colour)
for prefix, colour, legend_label in env_styles:
    low_mean = np.mean(conditional_explore_en[f'{prefix}_empty_low_energy'])
    high_mean = np.mean(conditional_explore_en[f'{prefix}_empty_high_energy'])
    plt.plot([0, 1], [low_mean, high_mean], '-o', linewidth=3, label=legend_label, color=colour)
plt.xticks([0, 1], ['Low', 'High'])
plt.yticks([])

plt.ylabel('Exploratory energy output (a.u.)')


In [None]:
# prey and predator locations, triggered on actions
%matplotlib qt

# Egocentric prey positions (dots) and predator positions (diamonds) preceding selected
# actions, for one agent. Plot x/y are the fish-frame y/x coordinates, divided by 10.
agent_id = 6
env = 'normal'
lim = 17
fig, ax = plt.subplots(1, 1, figsize=(15, 15))
dirs = get_dirs(env, env)
print(f'Env: {env}, Agent: {agent_id+1}')
t_prey, t_pred = get_trig_space(dirs[agent_id], lim=lim*10)
ax.scatter(0, 0, s=50, c='k', label='Fish', marker='x')

# (action id, colour, legend label)
prey_actions = [
    (6, 'b', 'Left J turn'),
    (5, 'r', 'Right J turn'),
    (19, 'g', 'Left HAT'),
    (18, 'm', 'Right HAT'),
    (0, 'c', 'Short Capture'),
    (1, 'orange', 'Long Capture'),
]
pred_actions = [
    (4, 'b', 'Left O-bend'),
    (3, 'r', 'Right O-bend'),
    (14, 'g', 'Left LLC'),
    (13, 'm', 'Right LLC'),
]
for action_id, colour, label in prey_actions:
    pts = t_prey[action_id] / 10
    ax.scatter(pts[:, 1], pts[:, 0], s=2, alpha=0.05, c=colour, label=label)
for action_id, colour, label in pred_actions:
    pts = t_pred[action_id] / 10
    ax.scatter(pts[:, 1], pts[:, 0], s=12, alpha=0.4, c=colour, label=label, marker='d')

ax.set_xlim(-lim, lim)
ax.set_ylim(-lim, lim)
ax.set_xticks([])
ax.set_yticks([])
# right-angle scale marker (3 units per arm) near the lower-left corner
corner_coord = -lim + 3
ax.plot([corner_coord, corner_coord + 3], [corner_coord, corner_coord], '-k', linewidth=2)
ax.plot([corner_coord, corner_coord], [corner_coord, corner_coord + 3], '-k', linewidth=2)
ax.set_aspect('equal', adjustable='box')

# Legend in a separate figure, with larger, fully opaque markers.
fig_legend, ax_legend = plt.subplots(1, 1, figsize=(3, 3))
handles, labels = ax.get_legend_handles_labels()
legend = ax_legend.legend(handles, labels, loc='center', frameon=False, markerscale=2)
# Matplotlib 3.7 renamed `legendHandles` to `legend_handles`; the old name was removed in 3.9.
legend_handles = getattr(legend, 'legend_handles', None)
if legend_handles is None:
    legend_handles = legend.legendHandles
for handle in legend_handles:
    handle.set_alpha(1)
ax_legend.axis('off')
fig_legend.tight_layout()
plt.show()

In [None]:

# Example trajectories: a patchy (sparse) episode with its initial prey density, and a dense (normal) episode.
example_episodes = [
    ('/media/gf/DBIO_Bianco_Lab8/temp_data_storage/gcp_output/stage2_sparse_nohdf5/stage2_5/logs/eval_sparse/logs_17.hdf5', {'prey_density': True}),
    ('/media/gf/DBIO_Bianco_Lab8/temp_data_storage/gcp_output/stage2_normal_nohdf5/stage2_5/logs/eval_normal/logs_20.hdf5', {}),
]
for episode_path, plot_kwargs in example_episodes:
    plot_episode(episode_path, **plot_kwargs)


In [46]:
# Reward compositions for every (training env, eval env) pair.
# Each component array has one row per agent and one column per episode.
combined_rewards = {}
combined_durations = {}
for training_env in ['normal', 'sparse']:
    for eval_env in ['normal', 'sparse']:
        print(f"Processing training env: {training_env}, eval env: {eval_env}")
        per_agent_rewards = {
            key: [] for key in ('salt_reward', 'prey_reward', 'pred_reward', 'energy_reward', 'episode_return')
        }
        per_agent_durations = []
        for agent_dir in get_dirs(training_env, eval_env):
            print(f"Processing dir: {agent_dir}")
            dir_rewards, durations = get_reward_composition(agent_dir)
            for key, bucket in per_agent_rewards.items():
                bucket.append(dir_rewards[key])
            per_agent_durations.append(durations)
        condition = f'{training_env}_{eval_env}'
        combined_rewards[condition] = {key: np.array(values) for key, values in per_agent_rewards.items()}
        combined_durations[condition] = np.array(per_agent_durations)


Processing training env: normal, eval env: normal
Processing dir: /media/gf/DBIO_Bianco_Lab8/temp_data_storage/gcp_output/stage2_normal_nohdf5/stage2_1/logs/eval_normal/
Processing dir: /media/gf/DBIO_Bianco_Lab8/temp_data_storage/gcp_output/stage2_normal_nohdf5/stage2_2/logs/eval_normal/
Processing dir: /media/gf/DBIO_Bianco_Lab8/temp_data_storage/gcp_output/stage2_normal_nohdf5/stage2_3/logs/eval_normal/
Processing dir: /media/gf/DBIO_Bianco_Lab8/temp_data_storage/gcp_output/stage2_normal_nohdf5/stage2_4/logs/eval_normal/
Processing dir: /media/gf/DBIO_Bianco_Lab8/temp_data_storage/gcp_output/stage2_normal_nohdf5/stage2_5/logs/eval_normal/
Processing dir: /media/gf/DBIO_Bianco_Lab8/temp_data_storage/gcp_output/stage2_normal_nohdf5/stage2_6/logs/eval_normal/
Processing dir: /media/gf/DBIO_Bianco_Lab8/temp_data_storage/gcp_output/stage2_normal_nohdf5/stage2_7/logs/eval_normal/
Processing dir: /media/gf/DBIO_Bianco_Lab8/temp_data_storage/gcp_output/stage2_normal_nohdf5/stage2_8/logs/eva

In [44]:
np.shape(combined_rewards['sparse_sparse']['prey_reward'])  # (agents, episodes per agent)

(8, 100)

In [48]:
reward_type = 'prey_reward'
ratio = 0.5

# Express every reward component per timestep by dividing by each episode's length.
normalized_combined_rewards = {
    condition: {component: values / combined_durations[condition] for component, values in rewards.items()}
    for condition, rewards in combined_rewards.items()
}



In [51]:

rewards_to_show = normalized_combined_rewards


def _agent_mean_sem(condition):
    """Per-agent mean and SEM over episodes of ``reward_type`` for one train_eval condition."""
    values = rewards_to_show[condition][reward_type]
    # ddof=1: the standard error uses the sample standard deviation.
    sem = np.std(values, axis=1, ddof=1) / np.sqrt(values.shape[1])
    return np.mean(values, axis=1), sem


nn_mean, nn_sem = _agent_mean_sem('normal_normal')
sn_mean, sn_sem = _agent_mean_sem('sparse_normal')
ns_mean, ns_sem = _agent_mean_sem('normal_sparse')
ss_mean, ss_sem = _agent_mean_sem('sparse_sparse')

# One line per agent (mean ± SEM) from dense to patchy evaluation, plus the across-agent mean.
plt.figure(figsize=(4,6))
for agent in range(len(nn_mean)):
    plt.errorbar([0, 1], [nn_mean[agent]*ratio, ns_mean[agent]*ratio], yerr=[nn_sem[agent]*ratio, ns_sem[agent]*ratio], fmt='-o', color='C0', alpha=0.5)
    plt.errorbar([0, 1], [sn_mean[agent]*ratio, ss_mean[agent]*ratio], yerr=[sn_sem[agent]*ratio, ss_sem[agent]*ratio], fmt='-o', color='C1', alpha=0.5)
plt.plot([0, 1], [np.mean(nn_mean)*ratio, np.mean(ns_mean)*ratio], '-o', color='C0', label='Trained normal')
plt.plot([0, 1], [np.mean(sn_mean)*ratio, np.mean(ss_mean)*ratio], '-o', color='C1', label='Trained sparse')
plt.xticks([0, 1], ['Dense', 'Patchy'])
# The plotted quantity is `reward_type` per timestep scaled by `ratio`, not the episode return.
plt.ylabel(f'{reward_type} per timestep (x{ratio})')
# plt.title('Episode return by training and eval environment')
# plt.legend()
# plt.tight_layout()

Text(0, 0.5, 'Episode return')

In [None]:
# Turn-streak statistics (per-agent mean and SEM over streaks) for the sparse- and normal-trained agents.


def _summarise_streaks(agent_dirs):
    """Per-agent mean and SEM of the turn streaks returned by get_turn_streaks."""
    means, sems = [], []
    for agent_dir in agent_dirs:
        print(f"Processing dir: {agent_dir}")
        agent_streaks, n_eligible = get_turn_streaks(agent_dir)
        print(f"Eligible turns: {n_eligible}")
        means.append(np.mean(agent_streaks, axis=0))
        sems.append(np.std(agent_streaks, axis=0) / np.sqrt(agent_streaks.shape[0]))
    return np.array(means), np.array(sems)


dirs = get_dirs('sparse', 'empty_high_energy')
mean_sparse_streaks, sem_sparse_streaks = _summarise_streaks(dirs)

dirs = get_dirs('normal', 'empty_high_energy')
mean_normal_streaks, sem_normal_streaks = _summarise_streaks(dirs)



In [None]:
# Turn-chain plot: per-agent cumulative turn sign after a direction switch (dense, patchy),
# then each agent's value 10 steps after the switch.


def _plot_streak_panel(means, sems, title):
    """Draw one panel of per-agent mean ± SEM streaks; return each agent's value at step 10."""
    at_step10 = np.zeros(len(dirs))
    for dataset in range(len(dirs)):
        agent_mean = means[dataset]
        agent_sem = sems[dataset]
        at_step10[dataset] = agent_mean[10]
        plt.plot(agent_mean, '-o', label=f'Ag {dataset+1}', markersize=3)
        plt.fill_between(np.arange(len(agent_mean)), agent_mean - agent_sem, agent_mean + agent_sem, alpha=0.3)
    plt.axhline(0, color='k', linestyle='--', label='Chance')
    plt.xlabel('Timestep after switch')
    plt.ylabel('Cumulative turn angle sign')
    plt.title(title)
    plt.ylim(-0.05, 1.5)
    return at_step10


plt.figure()
plt.subplot(1, 3, 1)
step10_means_normal = _plot_streak_panel(mean_normal_streaks, sem_normal_streaks, 'Dense Env')
plt.subplot(1, 3, 2)
step10_means_sparse = _plot_streak_panel(mean_sparse_streaks, sem_sparse_streaks, 'Patchy Env')
plt.legend(fontsize=8)

plt.subplot(1, 3, 3)
plt.plot([0, 1], [step10_means_normal, step10_means_sparse], '-o', label='Mean')
plt.xticks([0, 1], ['Dense', 'Patchy'])
plt.title('Cumulative turn angle at t=10')


In [None]:
# Proportion of each action around predator onset (left column) and prey consumption
# (right column), for selected agents. Left/right variants of an action are merged.
dirs = get_dirs('normal', 'normal')
action_names, action_colors, action_types, action_is_turn, _ = get_actions('../actions_all_bouts_with_null.h5')
agents = [5, 6]
steps_pre, steps_post = 15, 15
timebase = np.arange(-steps_pre, steps_post)
# squeeze=False keeps `ax` 2-D, so ax[ii, ss] also works when only one agent is listed.
fig, ax = plt.subplots(len(agents), 2, figsize=(15, 20), squeeze=False)
for ii, agent_id in enumerate(agents):
    avoid_seq, caught_seq, pred_seq, consume_seq = get_sequences(dirs[agent_id], steps_pre=steps_pre, steps_post=steps_post)
    for ss, seq in enumerate([pred_seq, consume_seq]):
        for action_id, action_name in enumerate(action_names):
            if '_L' in action_name:
                continue  # counted together with its '_R' counterpart below
            elif '_R' in action_name:
                this_name = action_name.split('_')[0]
                # Assumes the matching '_L' action has id action_id + 1 (consistent with the
                # left/right ids used in the trig-space plot) -- confirm against the actions file.
                action_props = np.mean((seq == action_id) | (seq == action_id + 1), axis=0)
            else:
                this_name = action_name
                action_props = np.mean(seq == action_id, axis=0)
            if np.max(action_props) < 0.1:
                continue  # skip rarely used actions
            ax[ii, ss].plot(timebase, action_props, '-o', label=this_name, color=action_colors[action_id], markersize=3, linewidth=2)
        ax[ii, ss].axvline(0, color='k', linestyle='--', linewidth=1)  # event time
        ax[ii, ss].set_ylim(0, 0.8)
        ax[ii, ss].set_xlim(-steps_pre, steps_post)
        ax[ii, ss].set_xlabel('Timestep')
        ax[ii, ss].set_ylabel('Proportion of actions')


In [None]:
tuple(seq.shape for seq in (avoid_seq, caught_seq, pred_seq, consume_seq))  # shapes for the last agent processed above

In [None]:
import glob
def smooth(scalars: list[float], weight: float) -> list[float]:
    """Bias-corrected exponential moving average, as used by TensorBoard's scalar charts.

    EMA implementation according to
    https://github.com/tensorflow/tensorboard/blob/34877f15153e1a2087316b9952c931807a122aa7/tensorboard/components/vz_line_chart2/line-chart.ts#L699

    Args:
        scalars: Values to smooth (any iterable of numbers, e.g. a list or 1-D array).
        weight: Smoothing factor, expected in [0, 1]. 0 returns the input values
            unchanged; values closer to 1 smooth more strongly.

    Returns:
        A list with one smoothed value per input value. Non-finite inputs (NaN/inf)
        are passed through as-is and do not update the running average, so a single
        bad point no longer turns every later value into NaN.
    """
    import math

    last = 0.0
    num_acc = 0  # number of finite values folded into the average so far
    smoothed = []
    for next_val in scalars:
        if not math.isfinite(next_val):
            smoothed.append(next_val)
            continue
        last = last * weight + (1 - weight) * next_val
        num_acc += 1
        # The running average starts at 0, which biases early values towards 0;
        # dividing by (1 - weight**n) removes that bias.
        debias_weight = 1 if weight == 1 else 1 - weight**num_acc
        smoothed.append(last / debias_weight)
    return smoothed
# Load the actor's 'actor/EpisodeReturn' curve for stage-2 agents 1-8 trained in the
# dense ('normal') and patchy ('sparse') environments, then plot EMA-smoothed curves
# per agent (thin) and the across-agent mean (thick).
N_POINTS = 5000  # logged points per agent kept for the across-agent mean
LOG_ROOT = '/media/gf/DBIO_Bianco_Lab8/temp_data_storage/gcp_output'


def _stage2_log_dirs(env):
    """Return the 'logs/' directories (with trailing slash) of stage-2 agents 1-8 trained in `env`."""
    return [f'{LOG_ROOT}/stage2_{env}_nohdf5/stage2_{i}/logs/' for i in range(1, 9)]


def _load_episode_returns(log_dirs):
    """Read 'actor/EpisodeReturn' from the first (sorted) actor_0 events file of each log dir."""
    returns = []
    for log_dir in log_dirs:
        pattern = log_dir + 'training/actor_0/events*'
        # sorted() makes the pick deterministic if several events files exist.
        matches = sorted(glob.glob(pattern))
        if not matches:
            raise FileNotFoundError(f'No TensorBoard events file matches {pattern}')
        returns.append(read_events_file(matches[0], 'actor/EpisodeReturn'))
    return returns


def _pad_with_nan(values, length=N_POINTS):
    """Return `values` as a float array right-padded with NaN up to `length` (never truncated)."""
    values = np.asarray(values, dtype=float)
    if len(values) < length:
        values = np.concatenate((values, np.full(length - len(values), np.nan)))
    return values


training_env = 'normal'
dirs = _stage2_log_dirs(training_env)
training_episode_return_normal = _load_episode_returns(dirs)

training_env = 'sparse'
dirs = _stage2_log_dirs(training_env)
training_episode_return_sparse = _load_episode_returns(dirs)

plt.figure(figsize=(6, 4))
all_normal = []
all_sparse = []
for ret_normal, ret_sparse in zip(training_episode_return_normal, training_episode_return_sparse):
    # Column 0 is used as the step axis and column 1 as the return value.
    # The first 10 points are dropped, presumably while the EMA is still settling.
    smoothed_normal = _pad_with_nan(smooth(ret_normal[:, 1], 0.99)[10:])
    smoothed_sparse = _pad_with_nan(smooth(ret_sparse[:, 1], 0.99)[10:])
    steps_normal = _pad_with_nan(ret_normal[10:, 0])
    steps_sparse = _pad_with_nan(ret_sparse[10:, 0])
    all_normal.append(smoothed_normal[:N_POINTS])
    all_sparse.append(smoothed_sparse[:N_POINTS])
    plt.plot(steps_normal/1e6, smoothed_normal, color='C0', alpha=0.3)
    plt.plot(steps_sparse/1e6, smoothed_sparse, color='C1', alpha=0.3)

all_normal = np.array(all_normal)
all_sparse = np.array(all_sparse)
# NaN padding lets agents with shorter logs drop out of the mean instead of biasing it.
mean_normal = np.nanmean(all_normal, axis=0)
mean_sparse = np.nanmean(all_sparse, axis=0)
# NOTE(review): the mean is taken point-by-point and drawn against the last agent's
# step values, which assumes all agents logged at the same steps — confirm.
plt.plot(steps_normal[:N_POINTS]/1e6, mean_normal, color='C0', label='Trained dense', linewidth=3)
plt.plot(steps_sparse[:N_POINTS]/1e6, mean_sparse, color='C1', label='Trained patchy', linewidth=3)
plt.xlabel('Training steps (millions)')
plt.ylabel('Episode return')
plt.legend()
         

In [None]:
# Quick check: padded step-array shapes and the dense-environment mean return curve.
(steps_sparse.shape, steps_normal.shape, mean_normal)

In [None]:
# Length of each agent's stored return curve, dense then patchy.
[len(curve) for curve in all_normal], [len(curve) for curve in all_sparse]

In [None]:
# One EMA-smoothed episode-return curve per log directory.
# NOTE(review): `training_episode_return` is not built by the loading cell above
# (which creates the *_normal / *_sparse lists); it must come from another cell — confirm.
plt.figure()
for ds, episode_return in enumerate(training_episode_return):
    smoothed = smooth(episode_return[:, 1], 0.99)[10:]
    plt.plot(episode_return[10:, 0]/1e6, smoothed, label=f'Dir {ds+1}')
plt.xlabel('Actor Steps (millions)')
plt.ylabel('Episode Return (EMA smoothed)')
plt.legend()

In [None]:
# Display `mean_streaks` (defined elsewhere in the notebook) as an image.
plt.imshow(mean_streaks)

In [None]:
# pre_avoid = np.concatenate(pre_avoid_seq)
# plt.imshow((pre_avoid==3) | (pre_avoid==4))
%matplotlib qt
plt.figure()
from matplotlib.colors import ListedColormap
pre_con = np.concatenate(pre_consume_seq)
pre_avoid = np.concatenate(pre_avoid_seq)
my_cmap = ListedColormap(np.array(action_colors))
time_base = np.arange(pre_con.shape[1]) - 20
plt.subplot(2, 2, 1)
plt.imshow(pre_con, aspect='auto', cmap=my_cmap, extent=[time_base[0], time_base[-1], 0, pre_con.shape[0]], interpolation='nearest')
plt.subplot(2, 2, 3)
plt.plot(time_base, np.mean((pre_con==3) | (pre_con==4), axis=0), label='O-bend', color=action_colors[3])
plt.plot(time_base, np.mean((pre_con==5) | (pre_con==6), axis=0), label='J turn', color=action_colors[5])
plt.plot(time_base, np.mean((pre_con==0) | (pre_con==1), axis=0), label='Capture', color=action_colors[0])
plt.plot(time_base, np.mean((pre_con==18) | (pre_con==19), axis=0), label='HAT', color=action_colors[18])
plt.xlim([-20, 5])
plt.legend()
plt.xlabel('Time steps relative to prey capture')
plt.ylabel('Proportion of actions')
plt.subplot(2, 2, 2)
plt.imshow(pre_avoid, aspect='auto', cmap=my_cmap, extent=[time_base[0], time_base[-1], 0, pre_avoid.shape[0]], interpolation='nearest')
plt.subplot(2, 2, 4)
plt.plot(time_base, np.mean((pre_avoid==3) | (pre_avoid==4), axis=0), label='O-bend', color=action_colors[3])
plt.plot(time_base, np.mean((pre_avoid==5) | (pre_avoid==6), axis=0), label='J turn', color=action_colors[5])
plt.plot(time_base, np.mean((pre_avoid==0) | (pre_avoid==1), axis=0), label='Capture', color=action_colors[0])
plt.plot(time_base, np.mean((pre_avoid==18) | (pre_avoid==19), axis=0), label='HAT', color=action_colors[18])
plt.xlim([-20, 5])
plt.xlabel('Time steps relative to predator avoidance')
plt.ylabel('Proportion of actions')


In [None]:
# Pool each action's recorded object positions across episodes (divided by 10,
# presumably to reach the mm units used in the axis labels) and overlay them
# in the fish's frame, with the fish marked by a black cross at the origin.
qq0, qq1, qq3, qq4, qq5, qq6, qq18, qq19 = (
    np.concatenate(q, axis=0)/10 for q in (q0, q1, q3, q4, q5, q6, q18, q19)
)

# (points, colour, label, marker size, alpha, marker) in drawing order — later rows draw on top.
# A marker of None means scatter's default marker.
scatter_specs = [
    (qq0, 'orange', 'Short Capture', 2, 0.05, None),
    (qq1, 'c', 'Long Capture', 2, 0.05, None),
    (qq5, 'r', 'Right J turn', 2, 0.05, None),
    (qq6, 'b', 'Left J turn', 2, 0.05, None),
    (qq3, 'r', 'Right O-bend', 8, 0.3, 'd'),
    (qq4, 'b', 'Left O-bend', 8, 0.3, 'd'),
    (qq18, 'm', 'Right HAT', 2, 0.05, None),
    (qq19, 'g', 'Left HAT', 2, 0.05, None),
]

plt.figure()
for pts, colour, name, size, opacity, mark in scatter_specs:
    # Column 1 is plotted on x and column 0 on y.
    plt.scatter(pts[:, :, 1], pts[:, :, 0], s=size, alpha=opacity, color=colour, label=name, marker=mark)
plt.scatter(0, 0, s=100, c='k', marker='x')

plt.xlim(-17, 17)
plt.ylim(-10, 15)
plt.xlabel('Object X (fish frame, mm)')
plt.ylabel('Object Y (fish frame, mm)')


In [None]:
# Load the action catalogue from the HDF5 file (one group per action: Gaussian
# mean/cov datasets plus is_turn/is_capture/color/id attributes) and order it by id.
action_types = []
ids = []
with h5py.File('../actions_all_bouts_with_null.h5', 'r') as f:
    for group_name, group in f.items():
        # Work on a float copy so the scaling below cannot fail on an integer-typed attribute.
        color = np.asarray(group.attrs['color'], dtype=float)
        if '_L' in group_name:
            color = color * 1.1  # left variants: colour scaled up (brighter)
        if '_R' in group_name:
            color = color * 0.9  # right variants: colour scaled down (darker)
        action = {
            'name': group_name,
            'mean': group['mean'][:],
            'cov': group['cov'][:],
            'is_turn': group.attrs['is_turn'],
            'is_capture': group.attrs['is_capture'],
            'color': np.clip(color, 0, 1),
        }
        ids.append(group.attrs['id'])
        action_types.append(action)

# Sort actions by their stored id. Presumably ids run 0..N-1 so that list position
# equals action id (other cells index action_colors by id) — confirm.
action_types = [x for _, x in sorted(zip(ids, action_types), key=lambda pair: pair[0])]
action_names = [a['name'] for a in action_types]
action_colors = [a['color'] for a in action_types]


In [None]:
%matplotlib qt
a_prop = np.array(actions_proportion)
# show as stacked bar plot
labels = [str(i) for i in range(a_prop.shape[1])]
x = np.arange(a_prop.shape[0])
fig, ax = plt.subplots(1, 1, figsize=(6, 6))
bottom = np.zeros(a_prop.shape[0])
# set a color sequence

for i in range(a_prop.shape[1]):
    ax.bar(x, a_prop[:, i], bottom=bottom, label=action_names[i], color=action_colors[i])
    bottom += a_prop[:, i]
ax.set_xlabel('Agent')
ax.set_ylabel('Proportion')
ax.set_title('Action Proportions Across Agents')
handles, labels = ax.get_legend_handles_labels()
ax.legend(handles[::-1], labels[::-1], loc='upper left', bbox_to_anchor=(1.02, 1), fontsize=8)

bottom = np.zeros(a_prop.shape[0])
fig, ax = plt.subplots(1, 1, figsize=(6, 6))

for i, reward_type in enumerate(['prey_reward', 'pred_reward']):
    ax.bar(x, np.mean(all_rewards[reward_type], axis=1), label=reward_type, bottom=bottom)
    bottom += np.mean(all_rewards[reward_type], axis=1)
bottom = np.zeros(a_prop.shape[0])
for i, reward_type in enumerate(['salt_reward', 'energy_reward']):
    ax.bar(x, np.mean(all_rewards[reward_type], axis=1), label=reward_type, bottom=bottom)
    bottom += np.mean(all_rewards[reward_type], axis=1)

# make a line plot of episode return on the same axis
#ax.plot(x, np.mean(all_rewards['episode_return'], axis=1), color='k', marker='o', label='Total Return')
# plot error bars for episode return
ax.errorbar(x, np.mean(all_rewards['episode_return'], axis=1), yerr=np.std(all_rewards['episode_return'], axis=1), color='k', fmt='o', capsize=3, label='Total Return')
# ax[1].scatter(x, np.mean(-1*all_rewards['energy_reward'], axis=1), color='k', marker='o', label='Energy Cost')
# ax[1].scatter(x, np.mean(-1*all_rewards['salt_reward'], axis=1), color='grey', marker='o', label='Salt Cost')

ax.set_xlabel('Agent')
ax.set_ylabel('Reward')
ax.legend(loc='upper left', bbox_to_anchor=(1.02, 1))
# bold horizontal line at y=0
ax.axhline(0, color='k', linewidth=2)
plt.tight_layout()
plt.show()
