In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from core.config import config
from pathlib import Path
import matplotlib.colors as mcolors
import numpy as np

In [None]:
# Box-plot summary statistics digitised from Park et al. (2025) with WebPlotDigitizer.
# One row per box: (label, lower whisker, Q1, median, Q3, upper whisker).
# The rows are turned into the dict layout that matplotlib's `Axes.bxp` consumes.
_BXP_KEYS = ('label', 'whislo', 'q1', 'med', 'q3', 'whishi')
_PARK_ROWS = (
    ('Selection',
     0.09794628751974715, 0.23854660347551346, 0.33491311216429703,
     0.48183254344391785, 0.819905213270142),
    ('Specification',
     0., 0.001579778830963592, 0.03317535545023692,
     0.05687203791469191, 0.11216429699842022),
    ('STR',
     0.025276461295418582, 0.03317535545023692, 0.053712480252764545,
     0.06635071090047384, 0.08372827804107427),
    ('MOp',
     0.020537124802527625, 0.050552922590837164, 0.07898894154818313,
     0.13270142180094785, 0.19273301737756712),
    ('ACA',
     0.047393364928909984, 0.06477093206951023, 0.09004739336492883,
     0.11848341232227479, 0.13112164296998408),
)
boxes = [dict(zip(_BXP_KEYS, row)) for row in _PARK_ROWS]



In [None]:
# Load the simulation results: one CSV per target separation (30° and 165°).

# Coerce to Path *before* joining so this works whether the config stores the
# directory as a str or as a Path (str / str would raise TypeError).
sim_data_dir = Path(config['fig_dir'])
df_small = pd.read_csv(sim_data_dir / 'two_target_angles_delta30.csv')
df_large = pd.read_csv(sim_data_dir / 'two_target_angles_delta165.csv')

# ignore_index=True gives the combined table unique row labels; without it the
# row labels of both files are kept as-is and can repeat, which makes
# label-based lookups such as df.loc[0] ambiguous.
df = pd.concat([df_small, df_large], ignore_index=True)

# Human-readable condition label used for the hue/legend in the figure.
# Rows whose delta_degrees is neither 30 nor 165 get NaN here.
df['condition'] = df['delta_degrees'].map({30: 'Small Δ (30°)', 165: 'Large Δ (165°)'})


In [None]:
# Long format: each input row yields one row per angle measure, so the policy
# and embedding angles can be drawn side by side in a single seaborn call.
_angle_names = {
    'policy_normalized': 'Policy',
    'embedding_normalized': 'Embedding',
}
# Plot order for the hue levels: Large Δ first, then Small Δ.
_condition_order = ['Large Δ (165°)', 'Small Δ (30°)']

df_melted = pd.melt(
    df,
    id_vars=['seed', 'delta_degrees', 'condition'],
    value_vars=list(_angle_names),
    var_name='angle_type',
    value_name='normalized_angle',
).assign(
    # Replace the raw column names with display labels.
    angle_type=lambda long: long['angle_type'].map(_angle_names),
    # Ordered categorical fixes the order in which the conditions are drawn.
    condition=lambda long: pd.Categorical(
        long['condition'],
        categories=_condition_order,
        ordered=True,
    ),
)


In [None]:
# One row of three panels; the width ratios 2:3:4 make the panels progressively wider.
fig, axs = plt.subplots(
    nrows=1,
    ncols=3,
    figsize=(5.5, 2.75),
    gridspec_kw=dict(width_ratios=[2, 3, 4]),
)


def darken_color(color, factor=0.6):
    """Return `color` as an (r, g, b) tuple with each channel scaled by `factor`.

    `color` is anything `matplotlib.colors.to_rgb` accepts. A factor of 1
    leaves the colour unchanged and a factor of 0 yields black.
    """
    red, green, blue = mcolors.to_rgb(color)
    return (red * factor, green * factor, blue * factor)

# Fill colour for each condition, plus a slightly darker shade (80 % brightness)
# of the same hue used for medians, whiskers and caps.
small_delta_color, large_delta_color = '#711D4F', '#E59DB3'
small_delta_dark, large_delta_dark = (
    darken_color(base_color, 0.8)
    for base_color in (small_delta_color, large_delta_color)
)


def _style_panel(ax, title, ylabel):
    """Apply the axis styling shared by all three panels.

    Hides the bottom, top and right spines, rotates the x tick labels by 45°,
    fixes the y-range to [0, 1.05], sets the y-label and title, and turns the
    grid off.
    """
    for side in ('bottom', 'top', 'right'):
        ax.spines[side].set_visible(False)
    # axis='x' matters: tick_params defaults to axis='both', which would also
    # tilt the numeric y tick labels by 45°.
    ax.tick_params(axis='x', bottom=True, labelbottom=True, rotation=45)
    ax.set_ylim(0, 1.05)
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.grid(False)


def _draw_summary_boxes(ax, stats, face_color, line_color):
    """Draw boxes from precomputed statistics and colour them.

    `stats` is a list of dicts in the format expected by `Axes.bxp`.
    Boxes are filled and outlined with `face_color`; medians, whiskers and
    caps are drawn in `line_color`. Returns the artist dict from `Axes.bxp`.
    """
    artists = ax.bxp(
        stats,
        showfliers=False,
        patch_artist=True,
        widths=0.4,
        medianprops=dict(linewidth=2),
        whiskerprops=dict(linewidth=2),
        capprops=dict(linewidth=2),
    )
    for patch in artists['boxes']:
        patch.set_facecolor(face_color)
        patch.set_edgecolor(face_color)
    # No flier artists exist (showfliers=False), so only these groups need colouring.
    for group in ('medians', 'whiskers', 'caps'):
        for line in artists[group]:
            line.set_color(line_color)
    return artists


# Left panel - Encoding models (Selection, Specification)
ax = axs[0]
boxes_encoding = boxes[:2]  # Selection, Specification
bp = _draw_summary_boxes(ax, boxes_encoding, small_delta_color, small_delta_dark)
_style_panel(ax, title='Encoding models', ylabel='Intertarget rotation (norm)')

# Middle panel - Neural data (STR, MOp, ACA)
ax = axs[1]
boxes_neural = boxes[2:]  # STR, MOp, ACA
bp2 = _draw_summary_boxes(ax, boxes_neural, small_delta_color, small_delta_dark)
_style_panel(ax, title='Neural data (Park et al.)', ylabel='Intertarget rotation (norm)')
# Replace the region acronyms from the data with display names (same order as boxes_neural).
ax.set_xticklabels(['Striatum', 'M1', 'M2/ACA'])

# Right panel - RL model
ax = axs[2]

# Note: seaborn returns the Axes it drew on, so bp3 is the same object as ax.
bp3 = sns.boxplot(
    data=df_melted,
    x='angle_type',
    y='normalized_angle',
    hue='condition',
    ax=ax,
    width=0.6,
    # Order follows the categorical order of `condition`: Large Δ, then Small Δ.
    palette=[large_delta_color, small_delta_color],
    linewidth=3
)

# Outline every box in its own fill colour.
for patch in ax.patches:
    patch.set_edgecolor(patch.get_facecolor())

# Recolour every line artist that has at least two points (whiskers, caps and
# medians alike) with the darker shade of the condition it belongs to. The
# condition is inferred from the line's mean x position: the two categories sit
# at x=0 and x=1, and this assumes seaborn dodges the first hue level (Large Δ)
# to the left and the second (Small Δ) to the right of each category centre.
for line in ax.lines:
    xdata = line.get_xdata()
    if len(xdata) < 2:
        continue
    x_pos = np.mean(xdata)
    on_large_delta_side = x_pos < 0 or 0.5 <= x_pos < 1
    line.set_color(large_delta_dark if on_large_delta_side else small_delta_dark)

_style_panel(ax, title='Our RL model', ylabel='Intertarget angle (norm)')
ax.set_xlabel('')
ax.legend(title='', loc='right', frameon=False, fontsize='small')

plt.tight_layout()
plt.savefig('combined_boxplot.pdf', bbox_inches='tight')
plt.show()