In [351]:
import matplotlib.pyplot as plt
import h5py
import numpy as np

import tensorflow as tf
from tensorboard.backend.event_processing.event_accumulator import EventAccumulator
from scipy.signal import filtfilt


# find all files in the directory that start with 'logs_' and end with '.hdf5'
def get_hdf5_files(directory):
    """Return the paths of all ``logs_<number>.hdf5`` files in ``directory``.

    A file name is accepted only if it consists entirely of ``logs_``, one or
    more digits and the ``.hdf5`` extension, so leftovers such as
    ``logs_3.hdf5.bak`` are skipped. The result is ordered by the numeric
    index because ``os.listdir`` returns entries in arbitrary order.

    Args:
        directory: Directory to scan (not recursive).

    Returns:
        List of ``os.path.join(directory, name)`` strings, sorted by the
        integer value of the digits in the file name.
    """
    import os
    import re
    pattern = re.compile(r'logs_(\d+)\.hdf5')
    indexed = []
    for file in os.listdir(directory):
        # fullmatch (not match) so that trailing characters disqualify the name
        found = pattern.fullmatch(file)
        if found:
            indexed.append((int(found.group(1)), file))
    indexed.sort()
    return [os.path.join(directory, name) for _, name in indexed]

def get_reward_split(hdf5_file):
    """Split one logged episode into its reward components.

    Reads the event traces, the internal-state matrix, the stored episode
    return and the reward factors (attributes of ``env_variables``) from an
    HDF5 log and recomputes each reward term from them.

    Args:
        hdf5_file: Path of the HDF5 episode log.

    Returns:
        dict with keys ``salt_reward``, ``prey_reward``, ``pred_reward``,
        ``energy_reward`` (each recomputed here) and ``episode_return``
        (the value stored in the file).
    """
    with h5py.File(hdf5_file, 'r') as log:
        env_cfg = dict(log['env_variables'].attrs)
        survived = log['event_survived_predator'][:]
        captured = log['event_captured_by_predator'][:]
        consumed = log['event_consumed_prey'][:]
        state = log['internal_state'][:]
        total_return = log['episode_return'][()]

    # Column 2 of internal_state drives the salt term, column 1 the energy term.
    salt_trace = state[:, 2]
    energy_trace = state[:, 1]

    # Step-to-step energy change; prepending the first sample makes the first
    # change zero. Only decreases contribute to the energy term.
    energy_delta = np.diff(energy_trace, prepend=energy_trace[0])
    energy_spent = np.sum(energy_delta[energy_delta < 0])

    predator_terms = (survived * env_cfg['reward_predator_avoidance'] +
                      captured * env_cfg['reward_predator_caught'])

    return {
        'salt_reward': np.sum(salt_trace * env_cfg['reward_salt_factor']),
        'prey_reward': np.sum(consumed * env_cfg['reward_consumption']),
        'pred_reward': np.sum(predator_terms),
        'energy_reward': energy_spent * env_cfg['reward_energy_use_factor'],
        'episode_return': total_return,
    }
    
def trig_actions(actions, trigs, seq_start, seq_end, fill_value=0):
    """Cut fixed-length action windows around every trigger.

    For each index ``t`` where ``trigs`` is truthy, the window
    ``actions[t + seq_start : t + seq_end]`` is copied into one output row.
    Column ``k`` of every row always corresponds to time ``t + seq_start + k``:
    samples that would fall before the first or after the last element of
    ``actions`` are left at ``fill_value`` instead of shifting the rest of
    the window.

    Args:
        actions: 1-D sequence of action values.
        trigs: Boolean (or truthy) sequence marking trigger positions.
        seq_start: Window start relative to the trigger (typically negative).
        seq_end: Window end relative to the trigger (exclusive).
        fill_value: Value used for samples outside the recording. The default
            ``0`` keeps the historical output; pass ``np.nan`` when ``0`` is
            itself a valid action id.

    Returns:
        Float array of shape ``(n_triggers, seq_end - seq_start)``.
    """
    actions = np.asarray(actions)
    trig_ind = np.where(trigs)[0]
    seqs = np.full((len(trig_ind), seq_end - seq_start), fill_value, dtype=float)
    n_samples = actions.shape[0]
    for i, ind in enumerate(trig_ind):
        window_start = ind + seq_start
        start = max(0, window_start)
        end = min(n_samples, ind + seq_end)
        if end <= start:
            # window lies completely outside the recording
            continue
        # Offset inside the row so clipped windows stay aligned to the trigger.
        col = start - window_start
        seqs[i, col:col + (end - start)] = actions[start:end]
    return seqs

def trig_fish_frame(fish_x, fish_y, fish_ori, obj_x, obj_y, trigs, offset=1):
    """Object positions relative to the fish, in the fish's own frame, at triggers.

    A trigger at time step ``t`` selects the fish pose and object positions at
    step ``t - offset``. The fish-to-object vectors are then rotated by
    ``-fish_ori``, so component 0 of the result lies along the direction given
    by ``fish_ori`` (for an angle measured from +x towards +y) and component 1
    is perpendicular to it.

    Args:
        fish_x, fish_y: 1-D arrays of fish position per time step.
        fish_ori: 1-D array of fish orientation in radians per time step.
        obj_x, obj_y: 2-D arrays ``(time, n_objects)`` of object positions.
        trigs: Boolean array over time marking the steps of interest.
        offset: Non-negative number of steps between the sampled pose and the
            trigger. ``0`` samples the pose at the trigger step itself.

    Returns:
        Array of shape ``(n_selected, n_objects, 2)``.

    Raises:
        ValueError: If ``offset`` is negative.
    """
    if offset < 0:
        raise ValueError("offset must be non-negative")

    trigs = np.asarray(trigs)[offset:]
    # Explicit stop index: ``x[:-offset]`` would be empty for offset == 0.
    n_keep = max(len(fish_x) - offset, 0)

    fish_x = np.asarray(fish_x)[:n_keep][trigs]
    fish_y = np.asarray(fish_y)[:n_keep][trigs]
    fish_ori = np.asarray(fish_ori)[:n_keep][trigs]
    obj_x = np.asarray(obj_x)[:n_keep, :][trigs, :]
    obj_y = np.asarray(obj_y)[:n_keep, :][trigs, :]

    # wrap the orientation to (-pi, pi]
    fish_ori = np.arctan2(np.sin(fish_ori), np.cos(fish_ori))
    fish_prey_vectors = np.stack([obj_x - fish_x[:, np.newaxis],
                                  obj_y - fish_y[:, np.newaxis]], axis=-1)

    # rotate the vectors to the fish reference frame
    c, s = np.cos(-fish_ori), np.sin(-fish_ori)
    # R has shape (2, 2, n_selected); R.T is (n_selected, 2, 2) with each
    # matrix transposed, so the batched product computes R_t @ v per vector.
    R = np.array([[c, -s], [s, c]])
    fish_prey_vectors = fish_prey_vectors @ R.T

    return fish_prey_vectors

def read_events_file(events_file, tag):
    """Load all values logged under ``tag`` from a TensorBoard events file.

    Args:
        events_file: Path handed to ``EventAccumulator``.
        tag: Tensor tag to extract.

    Returns:
        Array with one row per logged event: column 0 is the second field of
        the event (the step), column 1 the tensor converted to a Python float.
    """
    # size_guidance of 0 for tensors: do not subsample the stored events
    accumulator = EventAccumulator(events_file, size_guidance={'tensors': 0})
    accumulator.Reload()

    rows = []
    for _first_field, step, tensor_proto in accumulator.Tensors(tag):
        value = float(tf.make_ndarray(tensor_proto))
        rows.append([step, value])
    return np.array(rows)


In [441]:
import glob

# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
env_type = 'eval_normal'
dirs = ['/home/asaph/gcp_output/stage2_normal_nohdf5/stage2_1/logs/',
        '/home/asaph/gcp_output/stage2_normal_nohdf5/stage2_2/logs/',
        '/home/asaph/gcp_output/stage2_normal_nohdf5/stage2_3/logs/',
        '/home/asaph/gcp_output/stage2_normal_nohdf5/stage2_4/logs/',
        '/home/asaph/gcp_output/stage2_normal_nohdf5/stage2_5/logs/',
        '/home/asaph/gcp_output/stage2_normal_nohdf5/stage2_6/logs/',
        '/home/asaph/gcp_output/stage2_normal_nohdf5/stage2_7/logs/',
        '/home/asaph/gcp_output/stage2_normal_nohdf5/stage2_8/logs/',

]

REWARD_KEYS = ('salt_reward', 'prey_reward', 'pred_reward', 'energy_reward', 'episode_return')
N_ACTION_BINS = 21          # minimum length of the per-agent action histogram
REFERENCE_AGENT_INDEX = 5   # only episodes of dirs[5] feed the q* / pre_*_seq analyses
EARLY_STEPS = 350           # streak analysis uses the first 350 orientation samples
STREAK_HORIZON = 25         # a switch at step s contributes turn_angle[s+1:s+25]
POS_MIN, POS_MAX = 500, 3500  # switches are kept only strictly inside this square
                              # (presumably to stay away from the arena walls)
EVENT_WINDOW = (-20, 5)     # action window around consumption / avoidance events

# ---------------------------------------------------------------------------
# Accumulators (one entry per agent unless noted otherwise)
# ---------------------------------------------------------------------------
mean_streaks = []
sem_streaks = []
actions_proportion = []
training_episode_return = []
# Fish-frame object positions at selected actions (reference agent only).
q5 = []
q6 = []
q3 = []
q4 = []
q0 = []
q1 = []
q18 = []
q19 = []
pre_avoid_seq = []
pre_consume_seq = []
all_rewards = {k: [] for k in REWARD_KEYS}

# action id -> list collecting prey / predator positions when that action occurs
prey_triggered = {0: q0, 1: q1, 5: q5, 6: q6}
predator_triggered = {3: q3, 4: q4, 18: q18, 19: q19}

for agent_idx, log_dir in enumerate(dirs):
    agent_rewards = {k: [] for k in REWARD_KEYS}

    # --- training curve ----------------------------------------------------
    training_logs_dir = log_dir + 'training/actor_0/'
    # sorted() makes the choice deterministic if several events files exist
    events_candidates = sorted(glob.glob(training_logs_dir + 'events*'))
    if not events_candidates:
        raise FileNotFoundError(f"No TensorBoard events file found in {training_logs_dir}")
    events_file = events_candidates[0]
    training_episode_return.append(read_events_file(events_file, 'actor/EpisodeReturn'))

    # --- evaluation episodes -----------------------------------------------
    files = get_hdf5_files(log_dir + env_type + '/')
    if not files:
        raise FileNotFoundError(f"No evaluation logs found in {log_dir + env_type + '/'}")
    is_reference_agent = agent_idx == REFERENCE_AGENT_INDEX
    streaks = []
    actions = []
    for file in files:
        this_rewards = get_reward_split(file)
        for k in this_rewards.keys():
            agent_rewards[k].append(this_rewards[k])

        with h5py.File(file, 'r') as f:
            ori = f['fish_angle'][:]
            fish_x = f['fish_x'][:]
            fish_y = f['fish_y'][:]
            action = f['action'][:]
            if is_reference_agent:
                prey_x = f['prey_x'][:]
                prey_y = f['prey_y'][:]
                # predator traces are 1-D; add an object axis to match the prey layout
                pred_x = f['predator_x'][:][:, np.newaxis]
                pred_y = f['predator_y'][:][:, np.newaxis]
                event_survived_predator = f['event_survived_predator'][:]
                event_consumed_prey = f['event_consumed_prey'][:]

        if is_reference_agent:
            for action_id, collected in prey_triggered.items():
                collected.append(trig_fish_frame(fish_x, fish_y, ori, prey_x, prey_y, action == action_id))
            for action_id, collected in predator_triggered.items():
                collected.append(trig_fish_frame(fish_x, fish_y, ori, pred_x, pred_y, action == action_id))
            pre_avoid_seq.append(trig_actions(action, event_survived_predator > 0, *EVENT_WINDOW))
            pre_consume_seq.append(trig_actions(action, event_consumed_prey > 0, *EVENT_WINDOW))

        actions.append(action)

        # --- turn-direction streaks ----------------------------------------
        early_ori = ori[:EARLY_STEPS]
        turn_angle = np.diff(early_ori)
        # For a shuffled control, replace turn_angle with
        # np.random.permutation(turn_angle), or turn_dir with
        # np.random.choice([-1, 1], size=len(turn_dir)).
        turn_dir = (turn_angle > 0) * 2 - 1

        # indices at which the turn direction differs from the previous step
        switch_points = np.where(turn_dir[1:] != turn_dir[:-1])[0] + 1
        for s in switch_points:
            inside = (POS_MIN < fish_x[s] < POS_MAX) and (POS_MIN < fish_y[s] < POS_MAX)
            if inside and s + STREAK_HORIZON < len(turn_dir):
                this_dir = turn_dir[s]
                # cumulative turning after the switch, signed so that positive
                # values mean "kept turning in the new direction"
                streaks.append(np.cumsum(turn_angle[s + 1:s + STREAK_HORIZON] * this_dir))

    # Histogram over ALL episodes of this agent (the concatenated array, not
    # just the last episode's ``action``).
    actions = np.concatenate(actions)
    actions_counts = np.bincount(actions, minlength=N_ACTION_BINS)
    actions_proportion.append(actions_counts / np.sum(actions_counts))
    for k in agent_rewards.keys():
        all_rewards[k].append(np.array(agent_rewards[k]))

    # reshape keeps a well-defined (0, n) array when no streak qualified
    streaks = np.array(streaks).reshape(-1, STREAK_HORIZON - 1)
    # add a column of zeros to the beginning of streaks
    streaks = np.hstack((np.zeros((streaks.shape[0], 1)), streaks))
    mean_streak = np.mean(streaks, axis=0)
    sem_streak = np.std(streaks, axis=0) / np.sqrt(streaks.shape[0])
    mean_streaks.append(mean_streak)
    sem_streaks.append(sem_streak)
    print(f"Dir: {log_dir}, Num streaks: {streaks.shape[0]}")

mean_streaks = np.array(mean_streaks)
sem_streaks = np.array(sem_streaks)
for k in all_rewards.keys():
    all_rewards[k] = np.array(all_rewards[k])

Dir: /home/asaph/gcp_output/stage2_normal_nohdf5/stage2_1/logs/, Num streaks: 9816
Dir: /home/asaph/gcp_output/stage2_normal_nohdf5/stage2_2/logs/, Num streaks: 10078
Dir: /home/asaph/gcp_output/stage2_normal_nohdf5/stage2_3/logs/, Num streaks: 10074
Dir: /home/asaph/gcp_output/stage2_normal_nohdf5/stage2_4/logs/, Num streaks: 10433
Dir: /home/asaph/gcp_output/stage2_normal_nohdf5/stage2_5/logs/, Num streaks: 12001
Dir: /home/asaph/gcp_output/stage2_normal_nohdf5/stage2_6/logs/, Num streaks: 10693
Dir: /home/asaph/gcp_output/stage2_normal_nohdf5/stage2_7/logs/, Num streaks: 11436
Dir: /home/asaph/gcp_output/stage2_normal_nohdf5/stage2_8/logs/, Num streaks: 10131


In [371]:
def smooth(scalars: list[float], weight: float) -> list[float]:
    """
    Debiased exponential moving average (EMA), implemented according to
    https://github.com/tensorflow/tensorboard/blob/34877f15153e1a2087316b9952c931807a122aa7/tensorboard/components/vz_line_chart2/line-chart.ts#L699

    Non-finite samples (NaN / inf) are copied to the output unchanged and do
    not update the running average, so a single bad sample cannot poison every
    later value.

    Args:
        scalars: Sequence of values to smooth.
        weight: Smoothing factor; 0 returns the input, values close to 1
            smooth heavily. With ``weight == 1`` no de-biasing is applied.

    Returns:
        List with one smoothed value per input sample.
    """
    import math

    last = 0
    smoothed = []
    num_acc = 0
    for next_val in scalars:
        if not math.isfinite(next_val):
            smoothed.append(next_val)
            continue
        last = last * weight + (1 - weight) * next_val
        num_acc += 1
        # de-bias: the running average starts at 0 and is therefore biased
        # towards 0 for the first samples
        debias_weight = 1
        if weight != 1:
            debias_weight = 1 - weight**num_acc
        smoothed.append(last / debias_weight)

    return smoothed

In [405]:
# Load the action catalogue: one HDF5 group per action.
ids = []
action_types = []
with h5py.File('../actions_all_bouts_with_null.h5', 'r') as f:
    for group_name in f.keys():
        group = f[group_name]

        # Scale the colour by 1.1 for names containing '_L' and by 0.9 for
        # names containing '_R', then keep it inside [0, 1].
        color = group.attrs['color']
        if '_L' in group_name:
            color *= 1.1
        if '_R' in group_name:
            color *= 0.9
        color = np.clip(color, 0, 1)

        action_entry = {
            'name': group_name,
            'mean': group['mean'][:],
            'cov': group['cov'][:],
            'is_turn': group.attrs['is_turn'],
            'is_capture': group.attrs['is_capture'],
            'color': color,
        }
        ids.append(group.attrs['id'])
        action_types.append(action_entry)

# Order the catalogue by the 'id' attribute (stable for equal ids) so that
# list position can be used as an index by id.
id_order = sorted(range(len(ids)), key=lambda pos: ids[pos])
action_types = [action_types[pos] for pos in id_order]
action_names = [entry['name'] for entry in action_types]
action_colors = [entry['color'] for entry in action_types]


In [383]:
# Smoothed training return of every agent against actor steps.
# Column 0 of each run holds the step, column 1 the episode return.
EMA_WEIGHT = 0.99
SKIP_FIRST = 10   # drop the first points of every curve

plt.figure()
# iterate over whatever was loaded instead of assuming exactly 8 runs
for ds, run in enumerate(training_episode_return):
    smoothed = smooth(run[:, 1], EMA_WEIGHT)[SKIP_FIRST:]
    plt.plot(run[SKIP_FIRST:, 0]/1e6, smoothed, label=f'Dir {ds+1}')
plt.legend(fontsize=8)  # the per-run labels were set but never displayed
plt.xlabel('Actor Steps (millions)')
plt.ylabel('Episode Return (EMA smoothed)')

Text(0, 0.5, 'Episode Return (EMA smoothed)')

In [438]:
# Event-triggered action analysis.
# Top row: raster of the actions around each event (one row per event).
# Bottom row: fraction of events showing each action class at each time step.
from matplotlib.colors import ListedColormap

pre_con = np.concatenate(pre_consume_seq)
pre_avoid = np.concatenate(pre_avoid_seq)
my_cmap = ListedColormap(np.array(action_colors))
# column k of the sequences corresponds to time step k - 20 relative to the event
time_base = np.arange(pre_con.shape[1]) - 20

# (legend label, pair of action ids pooled together); the colour is taken
# from the first id of each pair
action_groups = [
    ('O-bend', (3, 4)),
    ('J turn', (5, 6)),
    ('Capture', (0, 1)),
    ('HAT', (18, 19)),
]
# (sequences, x-axis label, draw legend?) -- one subplot column each
panels = [
    (pre_con, 'Time steps relative to prey capture', True),
    (pre_avoid, 'Time steps relative to predator avoidance', False),
]

plt.figure()
for col, (seqs, x_label, show_legend) in enumerate(panels):
    plt.subplot(2, 2, col + 1)
    # vmin/vmax pin action id i to colour i; without them imshow rescales to
    # the data range and ids get the wrong colour whenever the smallest or
    # largest id is absent. The extent is widened by half a step so that
    # pixel centres sit on the integer time steps used by the line plots.
    plt.imshow(seqs, aspect='auto', cmap=my_cmap,
               vmin=-0.5, vmax=len(action_colors) - 0.5,
               extent=[time_base[0] - 0.5, time_base[-1] + 0.5, 0, seqs.shape[0]],
               interpolation='nearest')

    plt.subplot(2, 2, col + 3)
    for group_label, (first_id, second_id) in action_groups:
        in_group = (seqs == first_id) | (seqs == second_id)
        plt.plot(time_base, np.mean(in_group, axis=0), label=group_label, color=action_colors[first_id])
    plt.xlim([-20, 5])
    if show_legend:
        plt.legend()
    plt.xlabel(x_label)
    plt.ylabel('Proportion of actions')


Text(0, 0.5, 'Proportion of actions')

In [452]:
# Object positions in the fish frame at the moment selected actions occur.
# Component 1 of each vector goes on the x-axis, component 0 on the y-axis.
qq0, qq1, qq3, qq4, qq5, qq6 = (
    np.concatenate(chunks, axis=0) for chunks in (q0, q1, q3, q4, q5, q6)
)

plt.figure()

# (vectors, marker size, alpha, colour, label), drawn in this order
scatter_layers = [
    (qq0, 2, 0.05, 'orange', 'Short Capture'),
    (qq1, 2, 0.05, 'c', 'Long Capture'),
    (qq5, 2, 0.05, 'r', 'Right J turn'),
    (qq6, 2, 0.05, 'b', 'Left J turn'),
    (qq3, 3, 0.3, 'm', 'Right O-bend'),
    (qq4, 3, 0.3, 'g', 'Left O-bend'),
]
for vecs, marker_size, opacity, colour, layer_label in scatter_layers:
    plt.scatter(vecs[:, :, 1], vecs[:, :, 0], s=marker_size, alpha=opacity, color=colour, label=layer_label)

# the fish itself sits at the origin of its own frame
plt.scatter(0, 0, s=100, c='k', marker='x')

plt.xlim(-200, 200)
plt.ylim(-100, 250)


(-100.0, 250.0)

In [407]:
%matplotlib qt
a_prop = np.array(actions_proportion)   # rows: agents, columns: action ids
x = np.arange(a_prop.shape[0])

# --- Figure 1: stacked action proportions per agent --------------------------
fig, ax = plt.subplots(1, 1, figsize=(6, 6))
bottom = np.zeros(a_prop.shape[0])
for action_id, proportions in enumerate(a_prop.T):
    ax.bar(x, proportions, bottom=bottom, label=action_names[action_id], color=action_colors[action_id])
    bottom += proportions
ax.set_xlabel('Agent')
ax.set_ylabel('Proportion')
ax.set_title('Action Proportions Across Agents')
# reverse the entries so the legend reads top-down like the stacked bars
handles, labels = ax.get_legend_handles_labels()
ax.legend(handles[::-1], labels[::-1], loc='upper left', bbox_to_anchor=(1.02, 1), fontsize=8)

# --- Figure 2: mean reward components per agent -------------------------------
fig, ax = plt.subplots(1, 1, figsize=(6, 6))

# Two independent stacks, each restarting from zero: prey + predator terms,
# then salt + energy terms.
reward_stacks = (
    ['prey_reward', 'pred_reward'],
    ['salt_reward', 'energy_reward'],
)
for stack in reward_stacks:
    bottom = np.zeros(a_prop.shape[0])
    for reward_type in stack:
        mean_reward = np.mean(all_rewards[reward_type], axis=1)
        ax.bar(x, mean_reward, label=reward_type, bottom=bottom)
        bottom += mean_reward

# total episode return: mean over episodes with the standard deviation as error bar
mean_return = np.mean(all_rewards['episode_return'], axis=1)
std_return = np.std(all_rewards['episode_return'], axis=1)
ax.errorbar(x, mean_return, yerr=std_return, color='k', fmt='o', capsize=3, label='Total Return')

ax.set_xlabel('Agent')
ax.set_ylabel('Reward')
ax.legend(loc='upper left', bbox_to_anchor=(1.02, 1))
# bold horizontal line at y=0
ax.axhline(0, color='k', linewidth=2)
plt.tight_layout()
plt.show()


In [408]:
# Mean cumulative turn angle after a change of turn direction, one curve per
# agent, with a +/- SEM band.
# Create a dedicated figure: without it the curves are drawn onto whatever
# figure is current (with an interactive backend, the previous cell's axes).
fig, ax = plt.subplots()
for dataset, (mean_streak, sem_streak) in enumerate(zip(mean_streaks, sem_streaks)):
    steps = np.arange(len(mean_streak))
    ax.plot(steps, mean_streak, '-o', label=f'Ag {dataset+1}')
    ax.fill_between(steps,
                    mean_streak - sem_streak,
                    mean_streak + sem_streak, alpha=0.3)

ax.axhline(0, color='k', linestyle='--', label='Chance')
ax.set_xlabel('Timestep after switch')
ax.set_ylabel('Cumulative Turn Angle (rad)')
ax.set_title('Cumulative Turn Angle After Direction Switch')
ax.legend(fontsize=8)

<matplotlib.legend.Legend at 0x7b4f223a0a60>