In [None]:
import math
import os
import os.path as op
import sys
from pathlib import Path

# Add project root to path so imports work from notebooks/ directory
project_root = Path().resolve().parent
if str(project_root) not in sys.path:
    sys.path.insert(0, str(project_root))

import matplotlib.cm as cm
import matplotlib.colors as mcolors
import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
from scipy import signal
from tqdm.notebook import tqdm

# Project-local modules (importable thanks to the sys.path tweak above)
from definitions import paper_fig_dir, paper_model_path
from core.plotting import set_plotting_defaults
from core.continuous_env import ReachTask
from core.agent import ACLearningAgentWithEmbedding
from core.config import config

set_plotting_defaults()

# Target angle (degrees) used for the single-target example plots in the main
# paper; the supplement shows the other targets.
example_angle = 135

# Line colours shared by every training-curve plot in this notebook.
critic_loss_color, angle_diff_color, embedding_loss_color = '#A94850', '#44123F', '#6067B6'

# Discrete action set: `num_actions` evenly spaced directions on [0, 2*pi),
# each with its own colour sampled from the cyclic `twilight` colormap.
num_actions = 24
actions = np.linspace(0.0, 2.0 * np.pi, num=num_actions, endpoint=False)
action_colors = plt.cm.twilight(np.linspace(0.0, 1.0, num=num_actions))

# Plot training data for all targets

In [None]:
# Cache file holding the decimated training curves for every (angle, seed) run.
all_data_fn = 'all_losses_and_errors_all_angles_and_seeds.csv'
# Every raw NumPy array found in the figure directory, in os.listdir order.
all_npy_files = list(filter(lambda name: name.endswith('.npy'), os.listdir(paper_fig_dir)))

In [None]:
# Full paths of the raw per-run arrays, split by metric (file-name prefix).
# Note: neither list is referenced elsewhere in this notebook; the loading
# cell below builds its paths directly from (angle, seed).
angle_diff_fns, critic_loss_fns = (
    [op.join(paper_fig_dir, fname) for fname in all_npy_files if fname.startswith(prefix)]
    for prefix in ('angle_differences', 'critic_loss')
)

In [None]:
# Sweep layout: 10 random seeds x 24 target angles (0..345 degrees in 15-degree steps).
seeds = [*range(10)]
angles = [*range(0, 360, 15)]
# Decimation factor applied to the per-episode training curves.
window_size = 500

In [None]:
# Build (or load from the on-disk cache) one long-format table with the
# decimated training curves of every (target angle, seed) run.
# Columns: episode, critic_loss, angle_difference, target_angle, seed,
# abs_angle_difference.
all_data_path = op.join(paper_fig_dir, all_data_fn)

if not os.path.exists(all_data_path):
    dfs = []
    for angle in tqdm(angles):
        for seed in tqdm(seeds, leave=False):
            critic_loss = np.load(op.join(paper_fig_dir, f'critic_loss_target_{angle}_seed_{seed}.npy'))
            angle_diff = np.load(op.join(paper_fig_dir, f'angle_differences_target_{angle}_seed_{seed}.npy'))
            # The episode index is decimated with the same filter so the x-axis
            # stays aligned with the down-sampled curves.
            episode = np.arange(len(critic_loss))
            # NOTE(review): signal.decimate defaults to an order-8 Chebyshev IIR
            # filter; SciPy recommends chaining several calls for factors > 13,
            # and window_size is 500.
            # NOTE(review): the *signed* error is low-pass filtered here and abs()
            # is taken afterwards, so positive and negative errors within a window
            # partially cancel. Confirm this is intended for the "absolute
            # angular error" panels.
            df = pd.DataFrame({
                'episode': signal.decimate(episode, window_size),
                'critic_loss': signal.decimate(critic_loss, window_size),
                'angle_difference': signal.decimate(angle_diff, window_size),
            })
            df['target_angle'] = angle
            df['seed'] = seed
            dfs.append(df)
    all_df = pd.concat(dfs, ignore_index=True)
    all_df['abs_angle_difference'] = all_df.angle_difference.abs()
    all_df.to_csv(all_data_path)
else:
    print(f'loading pre-processed from disk: {all_data_fn}')
    # to_csv above writes the RangeIndex as the first column; read it back as
    # the index so the cached frame has the same columns as a fresh build
    # (otherwise a spurious "Unnamed: 0" column appears).
    all_df = pd.read_csv(all_data_path, index_col=0)

In [None]:
# Quick look at the combined per-(target, seed) training table; the rich
# notebook display truncates long frames to their head and tail.
all_df

In [None]:
# Critic loss for every target angle, one panel per target (mean ± SEM over seeds).
set_plotting_defaults(font_size=10)
fig, axes = plt.subplots(4, 6, figsize=(9, 5), sharey='all', sharex='all')
axes = axes.ravel()

# Display-name copy of the column so seaborn uses it as the y-axis label;
# the Figure 1 panel also plots this column via df_example.
all_df['Critic loss'] = all_df.critic_loss

for panel_idx, target in enumerate(all_df.target_angle.unique()):
    panel = axes[panel_idx]
    target_rows = all_df.loc[all_df.target_angle == target]
    sns.lineplot(data=target_rows, x='episode', y='Critic loss', ax=panel,
                 color=critic_loss_color, errorbar='se')
    panel.set_title(f'target {target}°')

# Common y-range; x tick labels expressed in units of 1e5 episodes.
for panel in axes:
    panel.set_ylim([0, 0.5])
    panel.ticklabel_format(style='scientific', axis='x', scilimits=(5, 5))

plt.tight_layout()
plt.savefig(op.join(paper_fig_dir, 'critic_loss_all_targets.pdf'))
plt.savefig(op.join(paper_fig_dir, 'critic_loss_all_targets.png'))

In [None]:
# Absolute angular error for every target angle, one panel per target
# (mean ± SEM over seeds).
set_plotting_defaults(font_size=10)
fig, axes = plt.subplots(4, 6, figsize=(9, 5), sharey='all', sharex='all')
axes = axes.ravel()

# Display-name copy of the column so seaborn uses it as the y-axis label;
# the single-target example and Figure 1 also plot this column.
all_df['Angular error'] = all_df.abs_angle_difference

for panel_idx, target in enumerate(all_df.target_angle.unique()):
    panel = axes[panel_idx]
    target_rows = all_df.loc[all_df.target_angle == target]
    sns.lineplot(data=target_rows, x='episode', y='Angular error', ax=panel,
                 color=angle_diff_color, errorbar='se')
    panel.set_title(f'target {target}°')

# x tick labels expressed in units of 1e5 episodes.
for panel in axes:
    panel.ticklabel_format(style='scientific', axis='x', scilimits=(5, 5))

plt.tight_layout()

plt.savefig(op.join(paper_fig_dir, 'angle_diffs_all_targets.pdf'))
plt.savefig(op.join(paper_fig_dir, 'angle_diffs_all_targets.png'))

In [None]:
# Single-target example. `example_angle` comes from the configuration cell at
# the top and is deliberately not re-assigned here. Re-assigning it would
# silently override the config value, so changing the example there would not
# update this subset.
set_plotting_defaults(font_size=7)

# Create the display-name columns on the subset itself, so this cell (and the
# Figure 1 panels that reuse df_example) doesn't depend on the all-targets
# grid cells having run first.
df_example = all_df.loc[all_df.target_angle == example_angle].assign(**{
    'Critic loss': lambda d: d.critic_loss,
    'Angular error': lambda d: d.abs_angle_difference,
})

fig2, axs2 = plt.subplots(1, 1, figsize=(1.5, 1.5))
sns.lineplot(data=df_example, x='episode', y='Angular error', errorbar='se',
             color=angle_diff_color, ax=axs2);

# Supervised losses 

In [None]:
# Supervised (encoder-decoder) pre-training loss, logged every `log_interval`
# episodes for each seed. The per-seed files are gathered into one
# long-format table (Episode, Loss, Seed) for seaborn.
n_episodes = 1500000
log_interval = 10000

sl_episodes = np.arange(0, n_episodes, log_interval)
# Same seed set as the RL runs (Fig. 1 reports SEM across these seeds).
supervised_losses = [np.load(os.path.join(paper_model_path, f'losses_seed_{seed}.npy')) for seed in seeds]

seed_frames = []
for loss_seed, losses in zip(seeds, supervised_losses):
    # presumably a 1-D array with one loss value per logged episode - TODO confirm
    losses = np.asarray(losses)
    # Each file must cover every logged episode; extra trailing entries are
    # ignored explicitly instead of by accident.
    if len(losses) < len(sl_episodes):
        raise ValueError(
            f'losses_seed_{loss_seed}.npy has {len(losses)} entries, '
            f'expected at least {len(sl_episodes)} (n_episodes / log_interval)'
        )
    seed_frames.append(pd.DataFrame({
        'Episode': sl_episodes,
        'Loss': losses[:len(sl_episodes)],
        'Seed': loss_seed,
    }))

sl_df = pd.concat(seed_frames, ignore_index=True)

In [None]:
# Preview of the supervised-loss panel (Fig. 1D): mean ± SEM across seeds.
# errorbar='se' matches the caption; seaborn's default would be a 95% CI.
fig, ax = plt.subplots()
sns.lineplot(data=sl_df, x='Episode', y='Loss', color=embedding_loss_color, errorbar='se', ax=ax);

# Action histogram

In [None]:
# For the example target, plot the action histogram of one example run
# (seed 0) at a single saved checkpoint episode.

# Index of the discrete action pointing at the example target. actions[i] is
# i * 360 / num_actions degrees, so compute the index on that grid directly
# instead of comparing rounded floats.
action_step_deg = 360 / num_actions
trained_angle_ind = int(round(example_angle / action_step_deg)) % num_actions
if not np.isclose(trained_angle_ind * action_step_deg, example_angle % 360):
    raise ValueError(f'example_angle={example_angle} is not on the {num_actions}-action grid')

ep = 190000  # checkpoint episode whose action histogram was saved to disk

# presumably a 1-D array of chosen action indices - TODO confirm
action_hist = np.load(os.path.join(paper_fig_dir, f'action_hist_episode_{ep}_target_{example_angle}_seed_0.npy'))

fig, axs = plt.subplots(1, 1, figsize=(3.2, 3.2))

# One unit-wide bin per action index, so bar i gets action_colors[i].
bins = np.arange(0, num_actions + 1)
n, bins, patches = axs.hist(action_hist, bins=bins, density=True)
for patch, color in zip(patches, action_colors):
    patch.set_facecolor(color)
axs.set_xlabel('Action Index')
axs.set_ylabel('Frequency')
# Bin i spans [i, i + 1), so the target action's bar is centred at i + 0.5.
axs.axvline(trained_angle_ind + .5, color='k', linestyle='--')
axs.spines['right'].set_visible(False)
axs.spines['top'].set_visible(False)


# Embedding plots

In [None]:
# Next we want to plot the embeddings. For this, we load the fully trained
# agent for the example target.
seed = 0

model_weights_fn = f'fully_trained_policy_model_one_target_seed_{seed}_weight_decay_0.0001_tanh_policy_mean_target_{example_angle}_n_actions_{num_actions}.pth'
full_model_load_path = os.path.join(paper_model_path, model_weights_fn)

# The weights file is selected with this notebook's num_actions, while the env
# and agent are built from `config` (which has its own 'num_actions' entry).
# Fail loudly if the two disagree instead of loading mismatched weights.
assert config['num_actions'] == num_actions, (
    f"config['num_actions']={config['num_actions']} does not match "
    f"num_actions={num_actions} used to select {model_weights_fn}"
)

env = ReachTask(config)
agent = ACLearningAgentWithEmbedding(env, config, full_model_load_path=full_model_load_path, f_plastic=False, g_plastic=False)


In [None]:
# Embedding-space view of the loaded agent, with its policy drawn in the colour
# of the optimal action.
optimal_action = agent.find_optimal_action_ind()
cmap = plt.get_cmap('twilight')
colors = cmap(np.linspace(0, 1, config['num_actions']))

fig, ax = plt.subplots(figsize=(2, 2))
agent.plot_embeddings_and_pi_i(colors[optimal_action], cmap, ax, s=20)

In [None]:
# Task-space (Cartesian) view of the loaded agent's actions on a fresh figure.
fig, ax = plt.subplots(1, 1)
agent.plot_actions_task_space(ax)

# Plot Figure 1

In [None]:
# TODO: for this particular seed, embedding looks a bit bad. pick another one.
# Assemble Figure 1 on a 2 x 4 grid:
#   row 1: (A), (B) left empty, presumably for schematics added outside this
#          notebook; (C) task space; (D) supervised loss
#   row 2: (E) embedding space; (F) critic loss; (G) angular error;
#          (H) action histogram
set_plotting_defaults(font_size=9)


def _add_action_colorbar(parent_ax, n_actions):
    """Attach an action-index colorbar just outside the right edge of ``parent_ax``.

    Parameters
    ----------
    parent_ax : matplotlib Axes the colorbar is anchored to.
    n_actions : number of discrete actions; the ``twilight`` colour scale
        spans action indices ``0 .. n_actions - 1``.

    Returns
    -------
    The created ``Colorbar``.
    """
    cax = inset_axes(parent_ax,
                     width="5%",     # 5% of the parent axes' width
                     height="100%",  # full height of the parent axes
                     loc='center left',
                     bbox_to_anchor=(1.02, 0., 1, 1),  # just outside the right edge
                     bbox_transform=parent_ax.transAxes,
                     borderpad=0)
    norm = mcolors.Normalize(vmin=0, vmax=n_actions - 1)
    sm = cm.ScalarMappable(cmap=plt.cm.twilight, norm=norm)
    sm.set_array([])
    cbar = parent_ax.figure.colorbar(sm, cax=cax)
    cbar.set_label('Action Index')
    return cbar


def _hide_top_right_spines(ax):
    """Remove the top and right spines of ``ax``."""
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)


fig = plt.figure(figsize=(11, 4))

# The colorbars are inset axes, so the grid needs no extra column for them.
gs_main = gridspec.GridSpec(2, 4, figure=fig,
                            width_ratios=[1.7, 1, 1.3, 1],  # unequal: column 0 is widest
                            wspace=0.5, hspace=0.5)

axes = np.empty((2, 4), dtype=object)
for row in range(2):
    for col in range(4):
        axes[row, col] = fig.add_subplot(gs_main[row, col])

# Square panels for the two 2-D scatter plots and the angular-error curve.
axes[1, 0].set_aspect('equal', adjustable='box')
axes[0, 2].set_aspect('equal', adjustable='box')
axes[1, 2].set_box_aspect(1)

################################################## first row ###########################################
axes[0, 0].set_axis_off()
axes[0, 1].set_axis_off()

# (C) actions in Cartesian task space
agent.plot_actions_task_space(axes[0, 2], s=20)
axes[0, 2].set_title('Task space', fontsize=9)
_add_action_colorbar(axes[0, 2], config['num_actions'])

# (D) supervised loss, mean ± SEM across seeds (errorbar='se' to match the
# caption; seaborn's default is a 95% CI). The y-label is set before plotting
# because seaborn only fills in axis labels that are still empty.
_hide_top_right_spines(axes[0, 3])
axes[0, 3].set_ylabel('Supervised loss')
sns.lineplot(data=sl_df, x='Episode', y='Loss', color=embedding_loss_color, errorbar='se', ax=axes[0, 3])

################################################## second row ##########################################
# (E) embedding space with the policy distribution and a sample from it
colors = plt.cm.twilight(np.linspace(0, 1, config['num_actions']))
agent.plot_embeddings_and_pi_i(colors[optimal_action], cmap, axes[1, 0], s=20)
axes[1, 0].set_xlabel('Embedding dim. 1')
axes[1, 0].set_ylabel('Embedding dim. 2')
axes[1, 0].set_title('Embedding space', fontsize=9)
_add_action_colorbar(axes[1, 0], config['num_actions'])

# (F) critic loss and (G) absolute angular error for the example target,
# mean ± SEM across seeds; x tick labels in units of 1e5 episodes.
for curve_ax, y_col, curve_color in ((axes[1, 1], 'Critic loss', critic_loss_color),
                                     (axes[1, 2], 'Angular error', angle_diff_color)):
    sns.lineplot(data=df_example, x='episode', y=y_col, errorbar='se', color=curve_color, ax=curve_ax)
    curve_ax.ticklabel_format(style='scientific', axis='x', scilimits=(5, 5))
    _hide_top_right_spines(curve_ax)

# (H) action histogram: one unit-wide bin per action index, coloured per action
_, _, patches = axes[1, 3].hist(action_hist, bins=np.arange(0, num_actions + 1), density=True)
for patch, color in zip(patches, action_colors):
    patch.set_facecolor(color)
axes[1, 3].set_xlabel('Action Index')
axes[1, 3].set_ylabel('Frequency')
# Dashed line through the centre of the target action's bin.
axes[1, 3].axvline(trained_angle_ind + .5, color='#CBCBCB', linestyle='--')
_hide_top_right_spines(axes[1, 3])

fig.savefig(op.join(paper_fig_dir, 'fig1.pdf'))
fig.savefig(op.join(paper_fig_dir, 'fig1.png'))

**Figure 1**. Model architecture and training. **(A)** The action encoder-decoder network learns to predict taken actions via a learned embedding space, *e*. The encoder network, *g*, projects into the embedding space, and the decoder network, *f*, learns to map from embeddings to actions. **(B)** Policy network illustration. The policy network learns an *internal policy* in the embedding space and uses the learned action decoder, *f*, to map it back into the environment's action space. **(C)** Task illustration. Subjects reach from the center to one of 24 target angles. **(D)** Loss of the supervised encoder-decoder network during supervised pre-training. **(E)** Representation of each of the 24 reaches in the model's learned 2-dimensional embedding space. Each point indicates the model's middle-layer activity in response to a particular $(s_t, s_{t+1})$ pair. The *x* with shaded area represents the agent's learned internal policy (the circle's radius corresponds to the policy's standard deviation), and the *+* indicates a single sample from the policy. **(F)** Loss of the critic network (squared TD error) during training. **(G)** Absolute angular error of the agent's reaches, in degrees. **(H)** Learned action distribution. The dashed line marks the optimal action. In D, F and G, shaded areas show the SEM across 10 random seeds.