In [None]:
import sys
from pathlib import Path

# Notebooks live in notebooks/, one level below the project root: put the root first
# on sys.path (only once) so project packages such as core/ and adaptation/ import cleanly.
project_root = Path().resolve().parent
_root_entry = str(project_root)
if _root_entry not in sys.path:
    sys.path.insert(0, _root_entry)

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt 
import seaborn as sns
from core.plotting import set_plotting_defaults
from definitions import paper_fig_dir
import os
import os.path as op
from adaptation.adaptation_generalization_test import make_generalization_plot
from core.continuous_env import ReachTask
from core.agent import ACLearningAgentWithEmbedding
from core.config import config as pre_cfg
from adaptation.config import config as post_cfg
from tqdm.notebook import tqdm
from definitions import paper_model_path, paper_fig_dir
import distinctipy
import torch
import random

In [None]:
import matplotlib.gridspec as gridspec
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
from matplotlib.colors import ListedColormap
# plt.rcParams['text.usetex'] = True

# Line colours for the different loss / error curves.
critic_loss_color, angle_diff_color, embedding_loss_color = '#A94850', '#44123F', '#6067B6'

# One visually distinct pastel colour per discrete action: used per-bar by the action
# histograms and, wrapped as a colormap, by the decoder (f) output plots.
# NOTE(review): distinctipy's colour search may be random and this cell runs before the RNG
# seeds are fixed, so the palette may differ between runs — confirm whether that matters.
n_acts = 24
colors = distinctipy.get_colors(n_acts, pastel_factor=0.7)
cmap = ListedColormap(colors)


In [None]:
# Fix the seed of every RNG in play (we load trained models and sample actions from their policies).
SEED = 0
for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    seed_fn(SEED)


In [None]:
# Peek at the figure directory; sorted because os.listdir order is arbitrary, so the
# displayed listing is reproducible across runs and filesystems.
sorted(os.listdir(paper_fig_dir))

In [None]:
# Example run shown first: seed 0, adaptation target 135, -30 rotation.
n_episodes = 100_000  # run-log episodes kept when all runs are loaded below
seed = 0
example_target = 135
fp = Path(paper_fig_dir).joinpath(
    f'generalization_stats_seed_{seed}_target_{example_target}_rotation_-30.csv'
)
genstats = pd.read_csv(fp)

In [None]:
# Generalization plot for the single example run loaded above (seed 0, target 135).
make_generalization_plot(genstats)

In [None]:
# Load the generalization stats and adaptation run logs of every (seed, target) run.
# Run logs are truncated to the first `n_episodes` episodes and averaged within bins of
# `downsample_factor` episodes to keep the learning curves light.
n_seeds = 10
downsample_factor = 10000  # loop-invariant, hoisted out of the loop
dfs = []
runlog_dfs = []
for seed in tqdm(range(n_seeds)):
    for target in tqdm(range(0, 360, 15), leave=False):
        fp = Path(paper_fig_dir) / f'generalization_stats_seed_{seed}_target_{target}_rotation_-30.csv'
        fp_runlog = Path(paper_fig_dir) / f'adaptation_run_log_seed_{seed}_target_{target}.csv'
        genstats = pd.read_csv(fp).assign(seed=seed, target=target)
        runlog = pd.read_csv(fp_runlog).assign(seed=seed, target=target)
        # .copy(): the boolean-indexed slice would otherwise trigger SettingWithCopyWarning
        # when the bin column is added below
        runlog = runlog[runlog.episode < n_episodes].copy()
        runlog['episode_bin'] = runlog.episode // downsample_factor
        # numeric_only=True: pandas >= 2 raises on non-numeric columns instead of dropping them
        runlog_ds = runlog.groupby('episode_bin').mean(numeric_only=True)
        dfs.append(genstats)
        runlog_dfs.append(runlog_ds)

In [None]:
# Combine all runs into long-format frames tagged by seed/target.
all_gen_data, all_runlogs = pd.concat(dfs), pd.concat(runlog_dfs)

# Distinct adaptation targets, in order of first appearance.
all_targets = all_gen_data['target'].unique()

In [None]:
# Generalization after adaptation: angular error vs. angle from the trained target (mean ± SD).
fig, axs = plt.subplots()
sns.lineplot(ax=axs, data=all_gen_data, errorbar='sd', x='angle from target', y='angular error')
axs.set(ylabel='angular error', xlim=[-90, 180])
# bottom=35 > top=-15, so the y axis is inverted (larger errors drawn lower)
axs.set_ylim([35, -15])

In [None]:
# Same runs, summarised with the rotation-generalization measure instead of raw angular error.
fig, axs = plt.subplots()
sns.lineplot(ax=axs, data=all_gen_data, errorbar='sd', x='angle from target', y='rotation generalization')
axs.set(ylabel='Rotation generalization')
axs.set_xlim([-90, 180])

## Plot the action histograms before and after adaptation

In [None]:
# Pre-adaptation policy: for every training seed, load the fully trained (unrotated) model and
# record 10000 action indices sampled from its policy at the environment's current position.
example_reach_angle = 135
n_acts = 24

env_norot = ReachTask(pre_cfg)
act_hist_pre_dict = {}
for train_seed in range(10):
    model_weights_dir = paper_model_path
    model_weights_fn = (
        f'fully_trained_policy_model_one_target_seed_{train_seed}'
        f'_weight_decay_0.0001_tanh_policy_mean_target_{example_reach_angle}_n_actions_{n_acts}.pth'
    )
    full_model_load_path = os.path.join(model_weights_dir, model_weights_fn)

    # all *_plastic flags off: the loaded model is only evaluated, not trained
    pre_agent = ACLearningAgentWithEmbedding(
        env_norot, pre_cfg,
        full_model_load_path=full_model_load_path,
        fg_load_path=None,
        g_plastic=False, f_plastic=False,
        actor_plastic=False, critic_plastic=False,
    )

    act_hist_pre = []
    for _ in range(10000):
        state_tensor = pre_agent.env.get_features(pre_agent.env.current_xy)
        action_ind, sample_emb, mean_emb, log_std_emb = pre_agent.select_action(state_tensor)
        act_hist_pre.append(action_ind)
    act_hist_pre_dict[train_seed] = act_hist_pre

In [None]:
# Largest episode value left in the (binned and averaged) run logs.
all_runlogs.episode.max()


In [None]:
def plot_action_distribution(action_history, ax=None, optimal_action_pre=None, optimal_action_post=None,
                             n_actions=None, palette=None, **kwargs):
    """Plot a density-normalised histogram of discrete action indices, one bar per action.

    Parameters
    ----------
    action_history : array-like
        Sampled action indices (integers in ``[0, n_actions)``).
    ax : matplotlib Axes, optional
        Axes to draw on; defaults to the current axes (``plt.gca()``).
    optimal_action_pre, optimal_action_post : int, optional
        If given, mark these action indices with a solid (pre) and dashed (post)
        vertical line through the centre of their bar.
    n_actions : int, optional
        Number of actions / bars; defaults to the notebook-global ``n_acts``.
    palette : sequence of colours, optional
        Per-bar face colours; defaults to the notebook-global ``colors``.
    **kwargs
        Forwarded to ``ax.hist``.

    Returns
    -------
    The axes that were drawn on.
    """
    if ax is None:
        ax = plt.gca()
    if n_actions is None:
        n_actions = n_acts
    if palette is None:
        palette = colors

    # integer edges 0..n_actions give each action index its own unit-width bar [k, k+1)
    edges = np.arange(0, n_actions + 1)
    _, _, patches = ax.hist(action_history, bins=edges, density=True, **kwargs)
    for patch, color in zip(patches, palette):
        patch.set_facecolor(color)

    ax.set_xlabel('Action Index')
    ax.set_ylabel('Frequency')
    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)

    # +0.5 centres each marker on its action's bar
    if optimal_action_pre is not None:
        ax.axvline(x=optimal_action_pre + .5, color='black', linestyle='-', lw=1)
    if optimal_action_post is not None:
        ax.axvline(x=optimal_action_post + .5, color='black', linestyle='--', lw=1)
    return ax


In [None]:
# Pre-adaptation action distribution, pooled over all training seeds.
fig, axs = plt.subplots(figsize=(1.5, 1.5))

# Optimal action index for the unrotated and the -30° rotated task.
# NOTE(review): uses `pre_agent` as left by the sampling loop (last training seed); presumably the
# optimal action depends only on the environment, not on the seed — confirm.
optimal_action_pre = pre_agent.find_optimal_action_ind(rotation_angle=0)
optimal_action_post = pre_agent.find_optimal_action_ind(rotation_angle=np.radians(-30))

# Pool the samples of all seeds instead of averaging them element-wise across seeds: action indices
# are categorical labels (the mean of actions 0 and 23 is 11.5, an action neither seed chose).
# Every seed contributes 10000 samples, so the pooled density histogram equals the mean of the
# per-seed histograms. (Variable name kept because the main-figure cell reads it.)
mean_act_hist_pre = np.array(list(act_hist_pre_dict.values())).ravel()

plot_action_distribution(mean_act_hist_pre, ax=axs, optimal_action_pre=optimal_action_pre, optimal_action_post=optimal_action_post, alpha=1.)

In [None]:
# Post-adaptation policy: for every training seed, load the model adapted to the -30 rotation and
# record 10000 action indices sampled from its policy at the environment's current position.
act_hist_post_dict = {}
env_rot = ReachTask(post_cfg)

for train_seed in tqdm(range(10)):

    model_weights_dir = paper_model_path
    # n_actions comes from n_acts (as in the pre-adaptation filename) instead of a hard-coded 24;
    # these filenames carry the target with a trailing '.0'
    model_weights_fn = (
        f'post_adaptation_model_seed_{train_seed}_rotation_-30_temp_2.5_weight_decay_0.0001'
        f'_tanh_policy_mean_n_actions_{n_acts}_target_{example_reach_angle}.0.pth'
    )

    full_model_load_path = os.path.join(model_weights_dir, model_weights_fn)

    # NOTE(review): the agent is built with pre_cfg although env_rot uses post_cfg — confirm intended.
    post_agent = ACLearningAgentWithEmbedding(env_rot, pre_cfg,
                                              full_model_load_path=full_model_load_path,
                                              fg_load_path=None,
                                              g_plastic=False,
                                              f_plastic=False,
                                              actor_plastic=False,
                                              critic_plastic=False)

    act_hist_post = []
    for i in tqdm(range(10000), leave=False):
        state_tensor = post_agent.env.get_features(post_agent.env.current_xy)
        action_ind, sample_emb, mean_emb, log_std_emb = post_agent.select_action(state_tensor)
        act_hist_post.append(action_ind)

    act_hist_post_dict[train_seed] = act_hist_post


In [None]:
# Post-adaptation action distribution, pooled over all training seeds.
fig, ax = plt.subplots(figsize=(1.5, 1.5))
# Pool the samples of all seeds instead of averaging them element-wise across seeds: action indices
# are categorical labels (the mean of actions 0 and 23 is 11.5, an action neither seed chose).
# Every seed contributes 10000 samples, so the pooled density histogram equals the mean of the
# per-seed histograms. (Variable name kept because the main-figure cell reads it.)
mean_act_hist_post = np.array(list(act_hist_post_dict.values())).ravel()
plot_action_distribution(mean_act_hist_post, ax=ax, optimal_action_pre=optimal_action_pre,
                         optimal_action_post=optimal_action_post, alpha=1.)

# Plot the action decoder ($f$) outputs

In [None]:
from core.plotting import plot_f_output

In [None]:
# Training seed whose pre- and post-adaptation models are used for the single-seed decoder (f) plots.
example_train_seed = 6

In [None]:
# Reload the pre-adaptation model of the example training seed for the decoder (f) plots.
model_weights_dir = paper_model_path
model_weights_fn = (
    f'fully_trained_policy_model_one_target_seed_{example_train_seed}'
    f'_weight_decay_0.0001_tanh_policy_mean_target_{example_reach_angle}_n_actions_{n_acts}.pth'
)
full_model_load_path = os.path.join(model_weights_dir, model_weights_fn)

# all *_plastic flags off: the loaded model is only evaluated, not trained
pre_agent = ACLearningAgentWithEmbedding(env_norot, pre_cfg,
                                         full_model_load_path=full_model_load_path,
                                         fg_load_path=None,
                                         g_plastic=False,
                                         f_plastic=False,
                                         actor_plastic=False,
                                         critic_plastic=False)

In [None]:
# Decoder (f) output of the example pre-adaptation model, on square axes.
fig, ax = plt.subplots(figsize=(3.5, 3.5))
pre_agent.plot_f_output(ax=ax, cbar=True, cmap=cmap)
ax.set(aspect='equal');

In [None]:
# Reload the post-adaptation model of the example training seed for the decoder (f) plots.
model_weights_dir = paper_model_path
# n_actions comes from n_acts (as in the pre-adaptation filename) instead of a hard-coded 24
model_weights_fn = (
    f'post_adaptation_model_seed_{example_train_seed}_rotation_-30_temp_2.5_weight_decay_0.0001'
    f'_tanh_policy_mean_n_actions_{n_acts}_target_{example_reach_angle}.0.pth'
)

full_model_load_path = os.path.join(model_weights_dir, model_weights_fn)

# NOTE(review): the agent is built with pre_cfg although env_rot uses post_cfg — confirm intended.
post_agent = ACLearningAgentWithEmbedding(env_rot, pre_cfg,
                                          full_model_load_path=full_model_load_path,
                                          fg_load_path=None,
                                          g_plastic=False,
                                          f_plastic=False,
                                          actor_plastic=False,
                                          critic_plastic=False)
    

In [None]:
# Decoder (f) output of the example post-adaptation model, on square axes.
fig, ax = plt.subplots(figsize=(3.5, 3.5))
post_agent.plot_f_output(ax=ax, cbar=True, cmap=cmap)
ax.set(aspect='equal');

# Plot the Krakauer et al. (2000) data

In [None]:
# Krakauer et al. (2000) generalization data, extracted into a header-less CSV.
# The path is tried relative to the working directory first and, failing that, relative to the
# project root: the notebook normally runs from notebooks/ (see the sys.path setup cell), where
# the bare relative path does not point at the project's adaptation/ folder.
_krakauer_rel_path = Path('adaptation') / 'extracted_data_krakauer2000.csv'
data_path = _krakauer_rel_path if _krakauer_rel_path.exists() else project_root / _krakauer_rel_path


def load_krakauer_data(data_path):
    """Load the extracted Krakauer et al. CSV and give its columns readable names.

    Columns 0-3 become 'Target direction (°)', 'Percent adaptation', 'cat' and
    'Participant'; any further columns keep their integer labels.
    """
    df = pd.read_csv(data_path, header=None)
    return df.rename(columns={0: 'Target direction (°)', 1: 'Percent adaptation', 2: 'cat', 3: 'Participant'})


krakauer_df = load_krakauer_data(data_path)


In [None]:


# Krakauer et al. data on its own: one marker style per participant, no legend.
participant_markers = ['^', 'v', 'o', 's']
sns.lineplot(
    data=krakauer_df,
    x='Target direction (°)',
    y='Percent adaptation',
    style='Participant',
    markers=participant_markers,
    dashes=False,
    legend=False,
    color=angle_diff_color,
    markersize=10,
)


# Plot the main figure (Fig. 3)

In [None]:
# NOTE(review): leftover inspection cell — `sample_emb` is whatever the last `select_action`
# call in the sampling loops above left behind; consider removing before publishing.
sample_emb

In [None]:
# Output directory for Fig. 3; exist_ok makes this idempotent and avoids the check-then-create race.
save_dir = op.join(paper_fig_dir, 'fig3')
os.makedirs(save_dir, exist_ok=True)

In [None]:
# Main figure (Fig. 3): 2x4 panel grid whose lower-left cell is subdivided further.
set_plotting_defaults(font_size=9)
fig = plt.figure(figsize=(11, 4))

# Outer 2x4 grid; the first column is twice as wide as the other three.
gs_main = gridspec.GridSpec(2, 4, figure=fig,
                            width_ratios=[2, 1, 1, 1],
                            wspace=0.4, hspace=0.6)


def _despine(ax):
    """Hide the top and right spines of `ax`."""
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)


# Axes are created top row first, then the lower-left split, then the rest of the bottom row.
axes = np.empty((2, 4), dtype=object)
for col in range(4):
    axes[0, col] = fig.add_subplot(gs_main[0, col])

# Lower-left cell: learning curve (left) next to a column holding the pre-/post-adaptation
# action histograms stacked on top of each other (right).
gs_lower_left = gridspec.GridSpecFromSubplotSpec(1, 2, subplot_spec=gs_main[1, 0],
                                                 wspace=.6, width_ratios=[1.5, 1])
axes[1, 0] = fig.add_subplot(gs_lower_left[0, 0])
gs_lower_left_right = gridspec.GridSpecFromSubplotSpec(2, 1, subplot_spec=gs_lower_left[0, 1],
                                                       hspace=0.3)
ax_lower_left_bottom = fig.add_subplot(gs_lower_left_right[1, 0])
ax_lower_left_top = fig.add_subplot(gs_lower_left_right[0, 0], sharex=ax_lower_left_bottom)

for col in range(1, 4):
    axes[1, col] = fig.add_subplot(gs_main[1, col])

# Panels [0, 0] and [0, 2] are left empty here.
# NOTE(review): presumably reserved for panels added outside matplotlib — confirm.
axes[0, 0].set_axis_off()
axes[0, 2].set_axis_off()

# ---- First row ------------------------------------------------------------------------------
# Behavioural data (Krakauer et al.): one marker style per participant.
sns.lineplot(
    data=krakauer_df,
    x='Target direction (°)',
    y='Percent adaptation',
    style='Participant',
    markers=['^', 'v', 'o', 's'],
    color=angle_diff_color,
    dashes=False,
    markersize=6,
    legend=False,
    ax=axes[0, 1],
)
_despine(axes[0, 1])
axes[0, 1].set_title('Data (Krakauer et al.)')

# Runs adapted at the example target (shared by both learning-curve panels).
example_runlogs = all_runlogs[all_runlogs.target == example_target]

# Embedding loss across re-training.
sns.lineplot(data=example_runlogs, x='episode', y='nll_loss', ax=axes[0, 3],
             color=embedding_loss_color)
axes[0, 3].ticklabel_format(style='scientific', axis='x', scilimits=(4, 4))
_despine(axes[0, 3])
axes[0, 3].set_ylabel('Embedding loss')

# ---- Second row -----------------------------------------------------------------------------
# Angular error across re-training.
sns.lineplot(data=example_runlogs, x='episode', y='angle_diff', ax=axes[1, 0], color=angle_diff_color)
axes[1, 0].ticklabel_format(style='scientific', axis='x', scilimits=(4, 4))
axes[1, 0].set_ylabel('Angular error (deg)')
_despine(axes[1, 0])

# Action distributions before (top) and after (bottom) adaptation. plot_action_distribution
# already sets the 'Action Index' / 'Frequency' labels and hides the top/right spines.
plot_action_distribution(mean_act_hist_pre, ax=ax_lower_left_top,
                         optimal_action_pre=optimal_action_pre, optimal_action_post=optimal_action_post)
# the x axis is shared with the panel below, so hide this panel's x ticks and label
ax_lower_left_top.tick_params(labelbottom=False, bottom=False)
ax_lower_left_top.set_xlabel('')
plot_action_distribution(mean_act_hist_post, ax=ax_lower_left_bottom,
                         optimal_action_pre=optimal_action_pre, optimal_action_post=optimal_action_post)
for hist_ax in (ax_lower_left_top, ax_lower_left_bottom):
    hist_ax.set_ylim([0, 0.5])
    hist_ax.set_yticks([0, 0.5])

# Decoder f output before and after adaptation.
pre_agent.plot_f_output(ax=axes[1, 1], cbar=True, cmap=cmap)
axes[1, 1].set_title('$f$ pre adaptation', fontsize=10)
post_agent.plot_f_output(ax=axes[1, 2], cbar=True, cmap=cmap)
axes[1, 2].set_title('$f$ post adaptation', fontsize=10)

# Model generalization: angular error vs. angle from the trained target (mean ± SD over runs).
sns.lineplot(data=all_gen_data, x='angle from target', y='angular error', errorbar='sd',
             ax=axes[1, 3], color=angle_diff_color)
axes[1, 3].set_ylabel('Angular error (°)')
axes[1, 3].set_xlabel('Target direction (°)')
axes[1, 3].set_xlim([-90, 180])
axes[1, 3].set_ylim([35, -15])  # bottom > top: inverted y axis
_despine(axes[1, 3])
axes[1, 3].set_title('Model')

fig.savefig(op.join(save_dir, 'fig3.pdf'))
fig.savefig(op.join(save_dir, 'fig3.png'))

In [None]:
# Show the path the main figure PDF was saved to.
op.join(save_dir,'fig3.pdf')

# Plot the angular-error generalization for all trained reach angles

In [None]:
# Supplementary: angular-error generalization curve separately for every trained reach target.
set_plotting_defaults(font_size=10)
fig, axes = plt.subplots(4, 6, figsize=(9, 5), sharey='all', sharex='all')
axes = axes.flatten()

for i, target in enumerate(all_targets):
    ax = axes[i]
    # only the runs adapted at this target, so each panel shows its own curve
    target_data = all_gen_data[all_gen_data.target == target]
    sns.lineplot(data=target_data, x='angle from target', y='angular error', errorbar='sd',
                 ax=ax, color=angle_diff_color)
    ax.set_ylabel('angular error °')
    ax.set_xlabel('Angle from target °')
    ax.set_xlim([-90, 180])
    ax.set_ylim([35, -15])  # bottom > top: inverted y axis
    ax.set_title(f'target {target}°')

plt.tight_layout()
# Path(...) so the join works whether paper_fig_dir is a str or a Path
plt.savefig(Path(paper_fig_dir) / 'supp_angularerr_alltargets.pdf')

# Plot the rotation-generalization measure (as defined in the original experimental paper) for all trained reach angles

In [None]:
# Supplementary: rotation-generalization measure separately for every trained reach target.
fig, axes = plt.subplots(4, 6, figsize=(9, 5), sharey='all', sharex='all')
axes = axes.flatten()
# TODO: maybe nicer color for this
# TODO: is this metric computed correctly?
for i, target in enumerate(all_targets):
    ax = axes[i]
    # only the runs adapted at this target, so each panel shows its own curve
    target_data = all_gen_data[all_gen_data.target == target]
    sns.lineplot(data=target_data, x='angle from target', y='rotation generalization', errorbar='sd',
                 ax=ax, color='k')
    ax.set_ylabel('Rotation generalization')
    ax.set_xlim([-90, 180])
    ax.set_ylim([-10, 140])
    ax.set_title(f'target {target}')

plt.tight_layout()
# Path(...) so the join works whether paper_fig_dir is a str or a Path
plt.savefig(Path(paper_fig_dir) / 'rotation_generalization_alltargets.pdf')

In [None]:
# Sanity check that the supplementary figure was written; Path() accepts paper_fig_dir as str or Path.
os.path.exists(Path(paper_fig_dir) / 'rotation_generalization_alltargets.pdf')

In [None]:
# Show the figure output directory.
paper_fig_dir