In [None]:
import sys
from pathlib import Path

# Make the directory above the working directory (the repo root when running from
# notebooks/) importable, so project modules such as `definitions` resolve.
project_root = Path.cwd().resolve().parent
if all(entry != str(project_root) for entry in sys.path):
    sys.path[:0] = [str(project_root)]

import numpy as np
import pandas as pd
from definitions import data_root
import matplotlib.pyplot as plt
import seaborn as sns
import os
from core.plotting import set_plotting_defaults
import matplotlib.gridspec as gridspec
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
from matplotlib.colors import ListedColormap

from tqdm.notebook import tqdm

from definitions import paper_fig_dir, data_root
from scipy.stats import vonmises

In [None]:
# Base line colours for critic loss, angular error and embedding loss curves.
critic_loss_color, angle_diff_color, embedding_loss_color = ('#A94850', '#44123F', '#6067B6')

# Environment: n_actions discrete actions; action k sits at angle 2*pi*k/n_actions (radians).
n_actions = 24
action_angles = np.linspace(start=0, stop=2 * np.pi, num=n_actions, endpoint=False)

In [None]:
# Results directory for the 180° phase-2 relative-angle condition, plus the shared colour palette.
data_dir_180 = Path(data_root).joinpath("doubleAdaptationExp_interleaved", "phase2relangle_180deg")
palette = sns.color_palette(palette='Paired')

In [None]:
# Load every action_hist*.csv in a directory and stack them into one DataFrame.
def get_action_hist_df(data_dir):
    """Concatenate all ``action_hist*.csv`` files in ``data_dir`` into one DataFrame.

    Each file contributes its rows plus a ``seed`` column parsed from the
    ``seed_<int>`` token in its filename. Files are read in (seed, filename)
    order so the result does not depend on directory-listing order.

    Parameters
    ----------
    data_dir : str or pathlib.Path
        Directory containing the CSV files.

    Returns
    -------
    pandas.DataFrame
        All rows from all matching files, with a fresh RangeIndex.

    Raises
    ------
    ValueError
        If no matching file exists, or a matching filename has no ``seed_<int>`` token.
    """
    import re

    data_dir = Path(data_dir)  # accept str as well as Path

    seeded_files = []
    for path in data_dir.iterdir():
        if not (path.name.startswith("action_hist") and path.name.endswith(".csv")):
            continue
        # Regex handles both "..._seed_3_xyz.csv" and "..._seed_3.csv".
        match = re.search(r"seed_(\d+)", path.name)
        if match is None:
            raise ValueError(f"Cannot parse a seed from filename {path.name!r}")
        seeded_files.append((int(match.group(1)), path.name, path))

    if not seeded_files:
        raise ValueError(f"No action_hist*.csv files found in {data_dir}")

    df_list = []
    # Sort numerically by seed (then name) for a reproducible row order.
    for seed, _, path in sorted(seeded_files, key=lambda item: item[:2]):
        df = pd.read_csv(path)
        df['seed'] = seed
        df_list.append(df)
    return pd.concat(df_list, ignore_index=True)

a_hist_df_180 = get_action_hist_df(data_dir=data_dir_180)  # action histories, 180° condition

In [None]:
fig, ax = plt.subplots(figsize=(6, 6), subplot_kw=dict(projection='polar'))

# Action-choice histograms for the 180° condition: one bar series per column of
# a_hist_df_180, coloured palette[k] for conds[k].
conds = ['baseline_phase1', 'post_phase1', 'baseline_phase2', 'post_phase2']
# Labels are attached to the bars themselves, so legend entries always match
# the colour of the series they describe.
cond_labels = {
    'baseline_phase1': 'Phase 1, before adaptation',
    'post_phase1': 'Phase 1, after adaptation',
    'baseline_phase2': 'Phase 2, before adaptation',
    'post_phase2': 'Phase 2, after adaptation',
}
bar_width = 2 * np.pi / len(action_angles)  # one wedge per discrete action

for k, cond in enumerate(conds):
    acts_history = a_hist_df_180[cond].dropna().to_numpy(dtype=int)
    # bincount with minlength gives a zero for never-chosen actions instead of
    # requiring every action to appear at least once.
    counts = np.bincount(acts_history, minlength=len(action_angles))
    assert len(counts) == len(action_angles), f"{cond}: action index out of range"
    proportions = counts / counts.sum()
    ax.bar(action_angles, proportions, width=bar_width, alpha=0.7,
           edgecolor='black', linewidth=0.5, color=palette[k], label=cond_labels[cond])

ax.set_theta_zero_location('E')  # 0° at the right (East)
ax.set_theta_direction(1)        # positive angles counter-clockwise
ax.set_title('RL Agent Action Distribution\n(Proportion of Total Actions)', pad=20)
ax.set_ylabel('Proportion', labelpad=30)
ax.tick_params(axis='both', which='major', labelsize=12)
ax.legend(loc='upper left', bbox_to_anchor=(0.1, 1.1))

fig.tight_layout()


In [None]:
fig, ax = plt.subplots(figsize=(6, 6), subplot_kw={'projection': 'polar'})  # polar canvas for the KDE below

def plot_ahist_kde_polar(a_hist_df, ax=None, conds='all', only_post=False,
                         kappa=20, n_points=200, angles=None, colors=None):
    """Plot a von Mises kernel density of chosen action angles on a polar axis.

    For each condition column of ``a_hist_df`` (integer action indices into
    ``angles``), the curve is the average of von Mises kernels centred on the
    chosen actions' angles, evaluated on an evenly spaced grid over [0, 2*pi].

    Parameters
    ----------
    a_hist_df : pandas.DataFrame
        One column per condition holding action indices.
    ax : matplotlib Axes, optional
        Polar axes to draw on; a new 6x6 polar figure is created when None.
    conds : 'all', str or sequence of str
        Condition columns to plot. 'all' means baseline_phase1, post_phase1,
        baseline_phase2, post_phase2, in that order.
    only_post : bool
        If True, skip baseline_* conditions.
    kappa : float
        Von Mises concentration (higher = sharper kernel). Default 20.
    n_points : int
        Number of grid points on [0, 2*pi]. Default 200.
    angles : array-like, optional
        Angle (radians) of each action index; defaults to module-level ``action_angles``.
    colors : sequence, optional
        Colours indexed by a condition's position in the default order (or by
        its position in ``conds`` for non-default names); defaults to ``palette``.

    Returns
    -------
    The axes that were drawn on.
    """
    if ax is None:
        _, ax = plt.subplots(figsize=(6, 6), subplot_kw=dict(projection='polar'))
    angles = np.asarray(action_angles if angles is None else angles)
    colors = palette if colors is None else colors

    default_conds = ['baseline_phase1', 'post_phase1', 'baseline_phase2', 'post_phase2']
    if isinstance(conds, str):
        conds = default_conds if conds == 'all' else [conds]

    # Closed grid (includes 2*pi) so the filled curve joins up.
    theta_smooth = np.linspace(0, 2 * np.pi, n_points)
    # kernels[j] is the von Mises pdf centred on action j's angle.
    kernels = vonmises.pdf(theta_smooth[np.newaxis, :], kappa, loc=angles[:, np.newaxis])

    for pos, cond in enumerate(conds):
        if only_post and cond.startswith('baseline'):
            continue
        acts = a_hist_df[cond].dropna().to_numpy(dtype=int)
        if acts.size == 0:
            continue  # nothing to plot; avoids 0/0
        # Actions are discrete, so averaging one kernel per sample equals a
        # count-weighted average of one kernel per action (much cheaper).
        counts = np.bincount(acts, minlength=len(angles))
        if counts.size > len(angles):
            raise IndexError(f"{cond}: action index {acts.max()} out of range for {len(angles)} angles")
        density = counts @ kernels / acts.size

        color = colors[default_conds.index(cond) if cond in default_conds else pos]
        ax.plot(theta_smooth, density, linewidth=2.5, alpha=0.8, color=color)
        ax.fill(theta_smooth, density, alpha=0.3, color=color)
    return ax

plot_ahist_kde_polar(a_hist_df=a_hist_df_180, ax=ax)  # all four conditions, 180° run

In [None]:
# Experiment parameters: rotation (deg) applied in each adaptation phase, and the
# number of consecutive episodes averaged into one learning-curve point.
phase_1_rotation = -30
phase_2_rotation = 30

downsample_factor = 1000


def _load_binned_runlog(fp, seed, target, phase, n_drop_bins):
    """Load one run-log CSV and average it in bins of ``downsample_factor`` episodes.

    Returns None if ``fp`` does not exist. Adds seed/target/phase columns before
    binning, then drops the last ``n_drop_bins`` bins.
    """
    if not fp.exists():
        return None
    runlog = pd.read_csv(fp)
    runlog['seed'] = seed
    runlog['target'] = target
    runlog['phase'] = phase
    runlog['episode_bin'] = runlog.episode // downsample_factor
    # numeric_only avoids a TypeError on pandas >= 2 if the log has text columns.
    binned = runlog.groupby('episode_bin').mean(numeric_only=True)
    # Drop trailing bins; the final bin may contain fewer episodes than the others.
    # NOTE(review): the default drops TWO bins (as the original code did); confirm intent.
    return binned.iloc[:-n_drop_bins] if n_drop_bins else binned


def get_runlogs_df(data_dir, relangle=180, phase_1_angle=135, seeds=range(10), n_drop_bins=2):
    """Load and bin the phase-1 and phase-2 run logs for every seed in ``data_dir``.

    Expected filenames:
    ``interleaved_run_log_phase{1,2}_seed_<seed>_target_<angle>_rotation_<rot>.csv``
    with phase-1 target ``phase_1_angle`` and phase-2 target ``phase_1_angle + relangle``.
    Missing files are skipped.

    Parameters
    ----------
    data_dir : str or pathlib.Path
        Directory whose path contains ``relangle_<relangle>deg``.
    relangle : int
        Phase-2 target offset (deg) from the phase-1 target.
    phase_1_angle : int
        Phase-1 target angle (deg) as written in the filenames. Default 135.
    seeds : iterable of int
        Seeds to look for. Default 0-9.
    n_drop_bins : int
        Number of trailing episode bins to discard per run. Default 2.

    Returns
    -------
    (phase1_df, phase2_df) : tuple of pandas.DataFrame
        Concatenated binned logs; an empty DataFrame for a phase with no files.

    Raises
    ------
    ValueError
        If ``data_dir`` does not correspond to ``relangle``.
    """
    data_dir = Path(data_dir)
    # Match the full "relangle_<deg>deg" token: a bare substring test would let
    # relangle=0 pass for a "..._180deg" directory.
    if f"relangle_{relangle}deg" not in str(data_dir):
        raise ValueError(f"data_dir {data_dir} does not match relangle={relangle}")

    phase_2_angle = phase_1_angle + relangle

    phase1_runlog_dfs = []
    phase2_runlog_dfs = []
    for seed in tqdm(seeds):
        fp_runlog_p1 = data_dir / f'interleaved_run_log_phase1_seed_{seed}_target_{phase_1_angle}_rotation_{phase_1_rotation}.csv'
        fp_runlog_p2 = data_dir / f'interleaved_run_log_phase2_seed_{seed}_target_{phase_2_angle}_rotation_{phase_2_rotation}.csv'

        runlog_p1 = _load_binned_runlog(fp_runlog_p1, seed, phase_1_angle, 1, n_drop_bins)
        if runlog_p1 is not None:
            phase1_runlog_dfs.append(runlog_p1)

        runlog_p2 = _load_binned_runlog(fp_runlog_p2, seed, phase_2_angle, 2, n_drop_bins)
        if runlog_p2 is not None:
            phase2_runlog_dfs.append(runlog_p2)

    phase1_df = pd.concat(phase1_runlog_dfs, ignore_index=True) if phase1_runlog_dfs else pd.DataFrame()
    phase2_df = pd.concat(phase2_runlog_dfs, ignore_index=True) if phase2_runlog_dfs else pd.DataFrame()

    print(f"Loaded {len(phase1_runlog_dfs)} Phase 1 and {len(phase2_runlog_dfs)} Phase 2 runlogs")
    return phase1_df, phase2_df

phase1_df_180, phase2_df_180 = get_runlogs_df(data_dir_180, relangle=180)  # binned learning curves, 180° run

In [None]:
# Two shades per quantity so phase 1 and phase 2 can be told apart in one panel.
# Phase 1 reuses the base colours (same hex values as critic_loss_color etc.).
critic_loss_p1, angle_diff_p1, embedding_loss_p1 = '#A94850', '#44123F', '#6067B6'

# Phase 2 uses lighter variants of the same hues (red, purple, blue).
critic_loss_p2, angle_diff_p2, embedding_loss_p2 = '#D47A83', '#7A4875', '#9098D4'


In [None]:
# Learning curves for the 180° condition: embedding loss (left column) and angular
# error (right column) across re-training, with phase 1 and phase 2 in different shades.
fig, axes = plt.subplots(1, 2)

panels = [
    # (axis, run-log column, y-label, phase-1 colour, phase-2 colour)
    (axes[0], 'nll_loss', 'Embedding loss', embedding_loss_p1, embedding_loss_p2),
    (axes[1], 'angle_diff', 'Angular error (deg)', angle_diff_p1, angle_diff_p2),
]
for panel_ax, column, ylabel, color_p1, color_p2 in panels:
    sns.lineplot(data=phase1_df_180, x='episode', y=column, ax=panel_ax,
                 color=color_p1, label='Phase 1')
    sns.lineplot(data=phase2_df_180, x='episode', y=column, ax=panel_ax,
                 color=color_p2, label='Phase 2')
    panel_ax.ticklabel_format(style='scientific', axis='x', scilimits=(4, 4))
    panel_ax.set_ylabel(ylabel)
    panel_ax.spines['top'].set_visible(False)
    panel_ax.spines['right'].set_visible(False)

fig.tight_layout()

# Next: the 0° condition, where phase 1 and phase 2 share the same target angle and the agent must learn opposite rotations (-30° in phase 1, +30° in phase 2).

In [None]:
data_dir_0 = Path(data_root).joinpath("doubleAdaptationExp_interleaved", "phase2relangle_0deg")  # 0° condition

In [None]:
a_hist_df_0 = get_action_hist_df(data_dir=data_dir_0)  # action histories, 0° condition

In [None]:
plot_ahist_kde_polar(a_hist_df=a_hist_df_0, only_post=True)  # post-adaptation densities only

In [None]:
phase1_df_0, phase2_df_0 = get_runlogs_df(data_dir=data_dir_0, relangle=0)  # binned learning curves, 0° run

In [None]:
phase2_df_0.head()  # preview only; avoid dumping the whole binned frame

In [None]:
# Angular error across re-training for the 0° (same-target) condition.
fig, ax = plt.subplots()

# Labels make the two phases distinguishable beyond colour alone.
sns.lineplot(data=phase1_df_0, x='episode', y='angle_diff', ax=ax, color=palette[1], label='Phase 1')
sns.lineplot(data=phase2_df_0, x='episode', y='angle_diff', ax=ax, color=palette[3], label='Phase 2')
ax.ticklabel_format(style='scientific', axis='x', scilimits=(4, 4))
ax.set_ylabel('Angular error (deg)')
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)

# Final figure 

In [None]:
set_plotting_defaults(font_size=9)
fig = plt.figure(figsize=(10, 5))

# 2x3 grid; the first column (polar densities) is wider than the others.
gs = gridspec.GridSpec(2, 3, figure=fig,
                       width_ratios=[1.4, 1, 1],  # relative column widths
                       height_ratios=[1, 1],      # relative row heights
                       hspace=.6,                 # vertical spacing
                       wspace=0.3)                # horizontal spacing

# Row 0: 0° separation (interference condition); row 1: 180° separation.
ax1 = fig.add_subplot(gs[0, 0], projection='polar')
ax2 = fig.add_subplot(gs[0, 1])
ax4 = fig.add_subplot(gs[1, 0], projection='polar')
ax5 = fig.add_subplot(gs[1, 1])
# Column 2 is left blank in this layout; no axes are created there so no empty
# frames are drawn.

# Left column: action-angle densities.
plot_ahist_kde_polar(a_hist_df_0, only_post=True, ax=ax1)
plot_ahist_kde_polar(a_hist_df_180, only_post=False, ax=ax4)

# Middle column: angular error across re-training, same styling in both rows.
for curve_ax, df_p1, df_p2 in [(ax2, phase1_df_0, phase2_df_0),
                               (ax5, phase1_df_180, phase2_df_180)]:
    sns.lineplot(data=df_p1, x='episode', y='angle_diff', ax=curve_ax, color=palette[1])
    sns.lineplot(data=df_p2, x='episode', y='angle_diff', ax=curve_ax, color=palette[3])
    curve_ax.ticklabel_format(style='scientific', axis='x', scilimits=(4, 4))
    curve_ax.set_ylabel('Angular error (deg)')
    curve_ax.spines['top'].set_visible(False)
    curve_ax.spines['right'].set_visible(False)

plt.tight_layout()


In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

def circular_mean(angles):
    """Return the mean direction of ``angles`` (radians), in [-pi, pi].

    Computed as the angle of the average unit vector (mean cosine, mean sine).
    """
    mean_cos = np.mean(np.cos(angles))
    mean_sin = np.mean(np.sin(angles))
    return np.arctan2(mean_sin, mean_cos)

def circular_distance(angle1, angle2):
    """Return ``angle1 - angle2`` (radians) wrapped to [-pi, pi], i.e. the signed shortest arc."""
    delta = angle1 - angle2
    wrapped = np.arctan2(np.sin(delta), np.cos(delta))
    return wrapped

# Angular separations (deg) between the phase-1 and phase-2 targets to analyse.
angular_seps = np.arange(0, 181, 30)


def _adaptation_deg(seed_data, phase):
    """Absolute shift (deg) of the circular-mean action angle from baseline to post for one phase."""
    baseline_mean = circular_mean(action_angles[seed_data[f'baseline_phase{phase}'].to_numpy()])
    post_mean = circular_mean(action_angles[seed_data[f'post_phase{phase}'].to_numpy()])
    return np.degrees(np.abs(circular_distance(post_mean, baseline_mean)))


# One record per (separation, seed).
all_results = []

for sep in tqdm(angular_seps, desc='angular separation'):
    data_dir_sep = Path(data_root) / "doubleAdaptationExp_interleaved" / f"phase2relangle_{sep}deg"
    a_hist_df_cur = get_action_hist_df(data_dir_sep)

    for seed, seed_data in a_hist_df_cur.groupby('seed'):
        adaptation_1_deg = _adaptation_deg(seed_data, 1)
        adaptation_2_deg = _adaptation_deg(seed_data, 2)
        all_results.append({
            'angular_separation': sep,
            'seed': seed,
            'adaptation_phase1_deg': adaptation_1_deg,
            'adaptation_phase2_deg': adaptation_2_deg,
            'mean_adaptation_deg': np.mean([adaptation_1_deg, adaptation_2_deg])
        })

results_df = pd.DataFrame(all_results)

# Mean and SEM (across seeds) of the per-seed mean adaptation, per separation.
summary = results_df.groupby('angular_separation')['mean_adaptation_deg'].agg(['mean', 'sem']).reset_index()

print(summary)
print(f"\nCorrelation between separation and adaptation: {results_df['angular_separation'].corr(results_df['mean_adaptation_deg']):.3f}")

In [None]:
# Bar chart of mean adaptation per angular separation (error bars: SEM across seeds).
sns.set_theme(style="whitegrid", font_scale=1.1)

fig, ax = plt.subplots(figsize=(10, 6))

# NOTE(review): seaborn >= 0.13 warns when `palette` is passed without `hue`;
# its suggested replacement is hue='angular_separation', legend=False.
sns.barplot(
    data=results_df,
    x='angular_separation',
    y='mean_adaptation_deg',
    errorbar='se',
    capsize=0.15,
    palette='viridis',
    edgecolor='black',
    linewidth=1.2,
    ax=ax
)

ax.set_xlabel('Angular Separation (°)')
ax.set_ylabel('Mean Adaptation Amount (°)')
ax.set_title('Adaptation Amount vs Angular Separation Between Targets')

sns.despine(left=True)
plt.tight_layout()
plt.show()
# The per-separation summary table and correlation are printed by the cell that
# builds results_df; they are not repeated here.


In [None]:
# Interference index relative to the 180° separation, treated here as the
# no-interference reference: 0 = as much adaptation as at 180°, 1 = no adaptation,
# negative = more adaptation than at 180°.
reference_sep = 180
reference_rows = results_df['angular_separation'] == reference_sep
if not reference_rows.any():
    raise ValueError(f"No results at the {reference_sep}° reference separation")
baseline_adaptation = results_df.loc[reference_rows, 'mean_adaptation_deg'].mean()
# A zero or NaN reference would silently fill the column with inf/NaN.
if not baseline_adaptation > 0:
    raise ValueError(f"Reference adaptation must be positive, got {baseline_adaptation}")

results_df['interference'] = 1 - (results_df['mean_adaptation_deg'] / baseline_adaptation)

sns.set_theme(style="whitegrid", font_scale=1.1)

fig, ax = plt.subplots(figsize=(10, 6))

sns.barplot(
    data=results_df,
    x='angular_separation',
    y='interference',
    errorbar='se',
    capsize=0.15,
    palette='rocket_r',  # colour follows bar position (separation), not interference value
    edgecolor='black',
    linewidth=1.2,
    ax=ax
)

# Reference line at 0 (no interference relative to the 180° baseline).
ax.axhline(0, color='gray', linestyle='--', linewidth=1, alpha=0.7)

ax.set_xlabel('Angular Separation (°)')
ax.set_ylabel('Interference Index')
ax.set_title('Interference Between Dual Adaptation Targets')

# Annotation naming the reference condition.
ax.text(
    0.98, 0.95,
    f'baseline: {reference_sep}° separation',
    transform=ax.transAxes,
    ha='right', va='top',
    fontsize=9,
    color='gray',
    style='italic'
)

sns.despine(left=True)
plt.tight_layout()
plt.show()

summary = results_df.groupby('angular_separation')['interference'].agg(['mean', 'sem']).reset_index()
print(summary)
print(f"\nCorrelation between separation and interference: {results_df['angular_separation'].corr(results_df['interference']):.3f}")


In [None]:
# Evenly spaced action directions (radians) and one colour per action.
num_actions = 24

actions = np.linspace(0, 2 * np.pi, num_actions, endpoint=False)

# twilight is a cyclic colormap (its values at 0 and 1 coincide), so sample it
# without the endpoint; otherwise the last action gets the same colour as the first.
action_colors = plt.cm.twilight(np.linspace(0, 1, num_actions, endpoint=False))

In [None]:
set_plotting_defaults(font_size=9)
fig = plt.figure(figsize=(10, 5))

# 2x3 grid: polar action densities | learning curves | interference (spans both rows).
gs = gridspec.GridSpec(2, 3, figure=fig,
                       width_ratios=[1.4, 1, 1],
                       height_ratios=[1, 1],
                       hspace=0.6,
                       wspace=0.35)

# Row 0: 0° separation; row 1: 180° separation.
ax1 = fig.add_subplot(gs[0, 0], projection='polar')
ax2 = fig.add_subplot(gs[0, 1])
ax4 = fig.add_subplot(gs[1, 0], projection='polar')
ax5 = fig.add_subplot(gs[1, 1])

# Spanning subplot for the right column (both rows)
ax_right = fig.add_subplot(gs[:, 2])

# --- Left column: polar action densities ---
plot_ahist_kde_polar(a_hist_df_0, only_post=True, ax=ax1)
plot_ahist_kde_polar(a_hist_df_180, only_post=False, ax=ax4)

# --- Middle column: learning curves, same styling in both rows ---
for curve_ax, df_p1, df_p2 in [(ax2, phase1_df_0, phase2_df_0),
                               (ax5, phase1_df_180, phase2_df_180)]:
    sns.lineplot(data=df_p1, x='episode', y='angle_diff', ax=curve_ax, color=palette[1])
    sns.lineplot(data=df_p2, x='episode', y='angle_diff', ax=curve_ax, color=palette[3])
    curve_ax.ticklabel_format(style='scientific', axis='x', scilimits=(4, 4))
    curve_ax.set_ylabel('Angular error (deg)')
    curve_ax.spines['top'].set_visible(False)
    curve_ax.spines['right'].set_visible(False)

# --- Right column: interference index vs angular separation (mean ± SEM across seeds) ---
sns.lineplot(
    data=results_df,
    x='angular_separation',
    y='interference',
    errorbar='se',
    marker='o',
    markersize=6,
    color='#c44e52',
    ax=ax_right
)

# Zero-interference reference. Interference is on the y-axis, so the reference
# is a horizontal line at y = 0.
ax_right.axhline(0, color='gray', linestyle='--', linewidth=1, alpha=0.7)

ax_right.set_xlabel('Angular Separation (°)')
ax_right.set_ylabel('Interference Index')
ax_right.set_title('Interference')
ax_right.spines['top'].set_visible(False)
ax_right.spines['right'].set_visible(False)

plt.tight_layout()
# NOTE(review): paper_fig_dir is imported at the top of the notebook but not used;
# confirm whether the figure should be saved there instead of the working directory.
fig.savefig('fig4.pdf')
plt.show()