---
title: "Run simulation"
format: html
jupyter: python3
---

# Run simulation 

In [None]:
import numpy as np
import pandas as pd
import seaborn as sns

from array_to_df import using_multiindex


# Run flags: when True, the guarded cell does not call run_simulation
# (results are then loaded from previously written parquet files).
SKIP_SIMULATION = True
SKIP_HM_SIMULATION = True
SKIP_HM_N_TEACHER = True

# Number of replications per parameter combination (used as np.arange(R) in the grids)
R = 10

R_myopic = 1400  # Base reward for myopic solution
R_optimal = 2200  # Base reward for optimal solution
R_random = 600  # Base reward for random solutions (not referenced in the code visible here)
sigma = 200  # Standard deviation for Gaussian noise

In [None]:
def grid_dict(params):
    """Expand a dict of value lists into the flattened full Cartesian grid.

    Returns a dict with the same keys (in the same order); each value is a
    1-D array holding that parameter's value for every grid point. With
    ``indexing="ij"`` the first key varies slowest and the last key fastest.
    """
    names = list(params)
    axes = [params[name] for name in names]
    mesh = np.meshgrid(*axes, indexing="ij")
    return dict(zip(names, (component.ravel() for component in mesh)))

In [None]:
def individual_learning(q_opt, d_myopic, d_optimal, k):
    """Simulate individual exploration and flag agents that found the optimal solution.

    Every agent has ``k`` trials. Each trial is spent on the optimal solution
    with probability ``q_opt``; a single such trial discovers it with
    probability ``d_optimal``.

    Parameters
    ----------
    q_opt : ndarray
        Per-agent probability that a trial targets the optimal solution.
        Its shape determines the shape of the random draw and of the result.
    d_myopic : ndarray
        Per-row discovery rate of the myopic solution. Kept for interface
        compatibility only: "myopic discovered" and "nothing discovered" both
        map to 0 in the result, so the value does not influence the output.
    d_optimal : ndarray
        Per-row discovery rate of the optimal solution; broadcast across the
        agents of a row via ``d_optimal[:, np.newaxis]``.
    k : int or ndarray
        Number of trials per agent.

    Returns
    -------
    ndarray
        Array with the shape and dtype of ``q_opt``: 1 where the agent
        discovered the optimal solution, 0 otherwise.
    """
    # Effective number of trials spent on the optimal solution
    k_opt = np.random.binomial(k, q_opt)

    # Probability of at least one successful discovery within k_opt trials
    d_opt_total = 1.0 - np.power(1.0 - d_optimal[:, np.newaxis], k_opt)

    u = np.random.rand(*q_opt.shape)
    discovered_optimal = u < d_opt_total

    # The former "discovered myopic" branch wrote 0 into an all-zero array and
    # was therefore a no-op; only the optimal discovery changes the outcome.
    return np.where(discovered_optimal, 1, np.zeros_like(q_opt))

def demonstration(K_demonstration, q_opt):
    """Split ``K_demonstration`` demonstrations into optimal and myopic ones.

    The number of optimal demonstrations is binomially distributed with
    success probability ``q_opt``; the remainder are myopic.

    Returns the pair ``(n_optimal, n_myopic)``.
    """
    optimal_count = np.random.binomial(K_demonstration, q_opt)
    myopic_count = K_demonstration - optimal_count
    return optimal_count, myopic_count

def get_rewards(n_optimal, n_myopic):
    """Return noisy rewards for the given counts of optimal and myopic demonstrations.

    The expected reward weights the counts with the module-level base rewards
    ``R_optimal`` and ``R_myopic``; zero-mean Gaussian noise with standard
    deviation ``sigma`` is added element-wise.
    """
    expected = n_optimal * R_optimal + n_myopic * R_myopic
    return expected + np.random.normal(0, sigma, size=expected.shape)

def get_teacher_indices(prev_rewards, n_teacher):
    """Pick a teacher from the previous generation for every agent.

    For each row ``m`` of ``prev_rewards`` and each agent, ``n_teacher[m]``
    distinct candidates are drawn at random and the candidate with the
    highest previous reward is chosen.

    Returns an integer array with one teacher index per (row, agent).
    """
    n_rep = prev_rewards.shape[0]
    n_agents = prev_rewards.shape[1]
    chosen = np.zeros((n_rep, n_agents), dtype=int)
    for rep in range(n_rep):
        rep_rewards = prev_rewards[rep]
        for agent in range(n_agents):
            # Random pool of candidate teachers, sampled without replacement
            candidates = np.random.choice(n_agents, n_teacher[rep], replace=False)
            # Position (within the pool) of the best-rewarded candidate
            best_in_pool = np.argmax(rep_rewards[candidates])
            chosen[rep, agent] = candidates[best_in_pool]
    return chosen


def social_learning(q_opt_post, teacher_indices, lambda_):
    """Copy each agent's teacher value with per-row success probability ``lambda_``.

    ``teacher_indices[m, n]`` selects, within row ``m`` of ``q_opt_post``, the
    value agent ``n`` tries to learn. Agents whose learning attempt fails
    receive 0.
    """
    n_rep = q_opt_post.shape[0]
    n_agents = q_opt_post.shape[1]
    row_selector = np.arange(n_rep)[:, np.newaxis]
    teacher_values = q_opt_post[row_selector, teacher_indices]
    success = np.random.rand(n_rep, n_agents) < lambda_[:, np.newaxis]
    return np.where(success, teacher_values, 0)

def run_simulation(d, G, N_gen, N_mach, K_demo):
    """Run the multi-generation learning simulation for every parameter row in ``d``.

    Parameters
    ----------
    d : dict of equal-length 1-D arrays (one entry per simulated parameter row).
        Keys read here: 'replication', 'condition', 'K_human', 'K_machine',
        'q_opt_human', 'q_opt_machine', 'd_optimal_log', 'd_myopic',
        'n_teacher', 'lambda', 'epsilon'.
    G : number of generations.
    N_gen : number of agents per generation.
    N_mach : number of agent slots occupied by machines (the first ``N_mach`` slots).
    K_demo : number of demonstrations per agent.

    Returns
    -------
    pandas.DataFrame combining the parameter rows with the per
    (rep, gen, agent) rewards and demonstration counts, plus derived columns.
    """
    M = len(d['replication'])
    # Machines occupy the first N_mach agent slots: only in generation 0 for
    # 'human-machine', in every generation for 'human-machine-all'.
    is_machine = np.zeros((M, G, N_gen), dtype=bool)
    is_machine[d['condition'] == 'human-machine',0,:N_mach] = True
    is_machine[d['condition'] == 'human-machine-all',:,:N_mach] = True

    # Prints the overall fraction of machine slots
    print(is_machine.mean())

    # Trials per agent: human default, overridden for machine slots
    K_ind = np.ones((M, G, N_gen), dtype=int) * d['K_human'][:,np.newaxis,np.newaxis]
    K_ind = np.where(is_machine, d['K_machine'][:,np.newaxis,np.newaxis], K_ind)

    # Probability that a trial targets the optimal solution: human default, overridden for machine slots
    q_opt = np.ones((M, G, N_gen)) * d['q_opt_human'][:,np.newaxis,np.newaxis]
    q_opt = np.where(is_machine, d['q_opt_machine'][:,np.newaxis,np.newaxis], q_opt)


    # Result buffers, filled one generation at a time in the loop below
    n_optimal = np.empty((M, G, N_gen))
    n_myopic = np.empty((M, G, N_gen))

    rewards = np.empty((M, G, N_gen))

    cond = d['condition']  # NOTE(review): unused below
    d_optimal_log = d['d_optimal_log']
    d_opt = np.power(10, d_optimal_log)  # discovery rate is stored as log10
    d_myo = d['d_myopic']
    n_t = d['n_teacher']
    l = d['lambda']
    e = d['epsilon']  # NOTE(review): unused below

    for g in range(G):
        # print("Individual learning")
        q_opt_ind = individual_learning(q_opt[:,g], d_myo, d_opt, K_ind[:,g])
        if g > 0:
            # print("Select teacher")
            # Teachers are chosen by the previous generation's rewards; at this
            # point q_opt_post still holds the previous generation's outcome.
            teacher_indices = get_teacher_indices(rewards[:,g-1], n_t)
            q_opt_social = social_learning(q_opt_post, teacher_indices, l)
            # An agent ends up with the better of its individual and social outcome
            q_opt_post = np.maximum(q_opt_ind, q_opt_social)
        else:
            # First generation has nobody to learn from
            q_opt_post = q_opt_ind
        # print("Demonstration")
        n_optimal[:,g], n_myopic[:,g] = demonstration(K_demo, q_opt_post)
        rewards[:,g] = get_rewards(n_optimal[:,g], n_myopic[:,g])

    # using_multiindex is a project helper; presumably it returns a long-format
    # frame with columns 'rep', 'gen', 'agent' and the value column — the
    # merges below rely on that (TODO confirm).
    rewards_df = using_multiindex(rewards, ['rep', 'gen', 'agent'], value_name='rewards')
    n_optimal_df = using_multiindex(n_optimal, ['rep', 'gen', 'agent'], value_name='optimal')
    n_myopic_df = using_multiindex(n_myopic, ['rep', 'gen', 'agent'], value_name='myopic')


    # 'rep' is the positional id of a parameter row (unique per row of d),
    # not the replicate number stored in 'replication'.
    meta_df = pd.DataFrame(d)
    meta_df['rep'] = np.arange(len(meta_df))

    df = meta_df.merge(rewards_df, on='rep')
    df = df.merge(n_optimal_df, on=['rep', 'gen', 'agent'])
    df = df.merge(n_myopic_df, on=['rep', 'gen', 'agent'])
    # True for all agents of a (rep, gen) in which at least one optimal demonstration occurred
    df['discovered'] = (df.groupby(['rep', 'gen'])['optimal'].transform('sum') > 0)
    df['learnability difficulty'] = 1 - df['lambda']
    df['discoverability difficulty'] = -df['d_optimal_log']
    df['myopic bias'] = 1 - df['q_opt_human'] * 2
    # Log-odds of a trial targeting the myopic rather than the optimal solution
    df['myopic bias intensity'] = np.log((1 - df['q_opt_human']) / df['q_opt_human'])
    return df

In [None]:
# Experiment 1: how do the exploration horizon (K_human) and the myopic bias
# (q_opt_human) affect the discovery rate of the optimal solution?

N_gen = 100  # Number of agents per generation
N_mach = 0  # Number of machines
G = 100  # Number of generations
K_demo = 10  # Number of demonstrations per agent

# Trials per agent: 10 log-spaced values from 10^1 to 10^4, truncated to integers
K_human = np.logspace(1, 4, 10, base=10, endpoint=True).astype(int)
K_machine = [0] # not used

# Probability that a trial targets the optimal solution: 2^-8 ... 2^-1
q_opt_human = np.logspace(-8, -1, 8, base=2, endpoint=True)
q_opt_machine = [0] # not used


epsilon = [1] # Probability of non-random exploration (read by run_simulation but not used in its computation)
n_teacher = [10] # Number of agents observed during social learning

d_myopic = [0.5]  # Discoverability rate for myopic solutions
d_optimal_log_ = np.linspace(0, -6, 10)  # log10 of the discoverability rate for the optimal solution

lambda_ = [0.9]  # Social learning rate
conditions = ['human']


# One entry per parameter; grid_dict expands this into the full Cartesian grid
grid_d = {
    'condition': conditions,
    'd_myopic': d_myopic,
    'd_optimal_log': d_optimal_log_,
    'lambda': lambda_,
    'n_teacher': n_teacher,
    'replication': np.arange(R),
    'epsilon': epsilon,
    'q_opt_human': q_opt_human,
    'K_human': K_human,
    'K_machine': K_machine,
    'q_opt_machine': q_opt_machine,
}

d = grid_dict(grid_d)
if not SKIP_SIMULATION:
    df = run_simulation(d, G, N_gen, N_mach, K_demo)
    # Collapse the agent dimension: one row per parameter row and generation
    df_agg = df.groupby(['rep','replication', 'gen', 'K_human', 'discoverability difficulty', 'myopic bias intensity']).agg(
        {
            'discovered': 'max',
            'optimal': 'mean',
            'myopic': 'mean',
            'rewards': 'mean',
        }
    ).sort_index()
    df_agg.to_parquet('../data/abm_v2/bias_discoverability.parquet')

In [None]:
# Load the aggregated results of the bias/discoverability experiment from disk
df_agg = pd.read_parquet('../data/abm_v2/bias_discoverability.parquet')

In [None]:
# Sanity check: number of distinct 'rep' values in the loaded aggregate
df_agg.reset_index()['rep'].nunique()

In [None]:
import matplotlib.pyplot as plt

# Highest discoverability difficulty at which the optimal solution was found,
# per generation, trial budget, bias level and replicate.
# The replicate index is 'replication'. 'rep' is a unique id per parameter row
# (and therefore already fixes the difficulty), so grouping by 'rep' would
# leave one row per group and turn .max() into a no-op.
max_discovered = (
    df_agg[df_agg['discovered']]
    .reset_index()
    .groupby(['gen', 'K_human', 'myopic bias intensity', 'replication'])['discoverability difficulty']
    .max()
    .reset_index()
)

# Average the per-replicate maxima (pivot_table aggregates with the mean)
m = max_discovered.pivot_table(index=['gen', 'K_human'],
                               columns='myopic bias intensity', values='discoverability difficulty')

# Descending order on both axes for display
m = m.sort_index(axis=0, ascending=False)
m = m.sort_index(axis=1, ascending=False)

# Zero-based generation indices shown in the three panels
generations = [0, 9, 99]

# Three plots next to each other, plus a fourth axis for the shared colorbar
fig, axs = plt.subplots(1, 4, figsize=(18, 5), gridspec_kw={"width_ratios": [1, 1, 1, 0.08]})

# Format the x-ticks (one decimal place)
tick_labels = [f"{x:.1f}" for x in m.columns]

# NOTE(review): the colour scale is capped at vmax=4 — confirm this is still
# appropriate for the range of difficulties present in the data.
for ax, gen in zip(axs[:3], generations):
    # Request one tick per row/column so the label lists set below always match
    hm = sns.heatmap(m.loc[gen], ax=ax, cbar=False, cmap='viridis', vmin=0, vmax=4,
                     xticklabels=True, yticklabels=True)
    ax.set_xlabel('Myopic Bias Intensity')
    ax.set_xticklabels(tick_labels)
    ax.set_title(f'Generation {gen + 1}')

# Put the colorbar in its own (fourth) axis so all three heatmaps have equal size
cbar = fig.colorbar(hm.collections[0], cax=axs[3])
# Remove border around the colorbar
cbar.outline.set_visible(False)
# Set new label for the colorbar
cbar.ax.set_ylabel('Maximal Discovered Solution Complexity', rotation=-90, va="bottom")

axs[0].set_ylabel('Trials per Agent')

# Label only exact powers of ten on the y-axis of the first plot.
# Comparing against the rounded exponent also catches logs that land just
# below an integer; the braces keep multi-digit exponents inside the superscript.
yvals = m.loc[generations[0]].index.values
exponents = np.log10(yvals)
mask = np.isclose(exponents, np.round(exponents), rtol=0, atol=1e-8)
axs[0].set_yticklabels(
    [rf"$10^{{{int(round(exponent))}}}$" if is_power_of_ten else ''
     for exponent, is_power_of_ten in zip(exponents, mask)]
)

# Remove the y-axis label and ticks for plot 2 and 3
for ax in axs[1:3]:
    ax.set_ylabel('')
    ax.set_yticks([])

# Only the bar itself and its automatically generated y tick labels remain on the colorbar axis
axs[3].set_xticks([])

In [None]:
import matplotlib.pyplot as plt

# Discoverability difficulties (rows) and zero-based generations (columns) to display
dd_values = [0, 2, 4, 6]
generations = [0, 9, 99]

# Mean of the 'discovered' flag per cell (pivot_table aggregates with the mean)
m = df_agg.reset_index().pivot_table(index=['gen', 'discoverability difficulty', 'K_human'],
                               columns='myopic bias intensity', values='discovered')

# Ensure MultiIndex is lexicographically sorted to avoid PerformanceWarning
m = m.sort_index(axis=0)
# Sort columns in descending order for display
m = m.sort_index(axis=1, ascending=False)

# Reset index once to avoid MultiIndex access issues
m_reset = m.reset_index()

# Grid: one row per dd_value, 3 generation columns + 1 colorbar column.
# squeeze=False keeps axs two-dimensional even for a single dd_value.
fig, axs = plt.subplots(len(dd_values), 4, figsize=(18, 5*len(dd_values)),
                        gridspec_kw={"width_ratios": [1, 1, 1, 0.08]},
                        squeeze=False)

# Format the x-ticks (one decimal place)
tick_labels = [f"{x:.1f}" for x in m.columns]

# Loop through dd_values (rows) and generations (columns)
for row_idx, dd_val in enumerate(dd_values):
    # The difficulties are floats derived from a linspace; match with a
    # tolerance instead of exact equality.
    dd_match = np.isclose(m_reset['discoverability difficulty'].to_numpy(), dd_val)
    for col_idx, gen in enumerate(generations):
        gen_match = (m_reset['gen'] == gen).to_numpy()
        data_df = m_reset[gen_match & dd_match]
        data_df = data_df.set_index('K_human')[m.columns]
        # Largest trial budget on top
        data_df = data_df.sort_index(ascending=False)
        ax = axs[row_idx, col_idx]

        # Create heatmap - convert DataFrame to numpy array for clean 2D structure
        hm = sns.heatmap(data_df.values, ax=ax, cbar=False, cmap='viridis', vmin=0, vmax=1,
                         xticklabels=data_df.columns, yticklabels=data_df.index)

        # Set labels
        if col_idx == 0:
            ax.set_ylabel(f'Solution complexity == {dd_val}\nTrials per Agent')
            # Label only exact powers of ten; braces keep multi-digit
            # exponents inside the superscript.
            exponents = np.log10(data_df.index.values)
            mask = np.isclose(exponents, np.round(exponents), rtol=0, atol=1e-8)
            ax.set_yticklabels(
                [rf"$10^{{{int(round(exponent))}}}$" if is_power_of_ten else ''
                 for exponent, is_power_of_ten in zip(exponents, mask)]
            )
        else:
            ax.set_ylabel('')
            ax.set_yticks([])

        if row_idx == len(dd_values) - 1:
            ax.set_xlabel('Myopic Bias Intensity')
            ax.set_xticklabels(tick_labels)
        else:
            ax.set_xlabel('')
            ax.set_xticklabels([])

        # Set titles
        if row_idx == 0:
            ax.set_title(f'Generation {gen+1}')

# Put the colorbar in the last row, rightmost column
last_row_idx = len(dd_values) - 1
cbar = fig.colorbar(hm.collections[0], cax=axs[last_row_idx, 3])
# Remove border around the colorbar
cbar.outline.set_visible(False)
# Set new label for the colorbar
cbar.ax.set_ylabel('Probability of discovery', rotation=-90, va="bottom")
axs[last_row_idx, 3].set_xticks([])

# Hide empty colorbar axes in all other rows
for row_idx in range(last_row_idx):
    axs[row_idx, 3].set_visible(False)

In [None]:
# Flatten the index levels of the aggregate into regular columns
df_ = df_agg.reset_index()

In [None]:
# Inspect the mean of the remaining columns per (gen, K_human, bias, difficulty) cell
df_.groupby(['gen', 'K_human', 'myopic bias intensity', 'discoverability difficulty']).mean()

In [None]:
# Cast the discovery flag to float (presumably so it can be plotted as a rate below)
df_['discovered'] = df_['discovered'].astype(float)

In [None]:
# Inspect the distinct myopic bias intensity values
df_['myopic bias intensity'].unique()

In [None]:
# Inspect the distinct discoverability difficulty values
df_['discoverability difficulty'].unique()

In [None]:
# Row filter: every third bias level, every third trial budget, difficulty 4 only
w = (
    df_['myopic bias intensity'].isin(df_['myopic bias intensity'].unique()[::3])
    & df_['K_human'].isin(df_['K_human'].unique()[::3]) 
    & df_['discoverability difficulty'].isin([4])
)

# Convert K_human to strings after the mask has been built.
# Note: this modifies df_ in place.
df_['K_human'] = df_['K_human'].astype(str)

# Discovery flag over generations, one line per K_human, one panel per bias level
sns.relplot(
    data=df_[w], x='gen', 
    y='discovered', 
    hue='K_human', 
    col='myopic bias intensity', 
    kind='line')

In [None]:
# Experiment 2: humans-only vs. human-machine populations, varying the machine's
# trial budget (K_machine) and bias (q_opt_machine) across discoverability and
# social learning rates.

N_gen = 100  # Number of agents per generation
N_mach = 5  # Number of machines
G = 20  # Number of generations
K_demo = 1  # Number of demonstrations per agent

K_human = [10]
K_machine = [10, 100, 1000, 10000]

q_opt_human = [0.01]
q_opt_machine = [0.01, 0.5]

epsilon = [1] # Success in demonstration learned solution (read by run_simulation but not used in its computation)
n_teacher = [10] # Number of agents observed during social learning

d_myopic = [0.5]  # Discoverability rate for myopic solutions
d_optimal_log_ = np.linspace(0, -6, 10)  # log10 of the discoverability rate for the optimal solution

lambda_ = np.linspace(0.0, 0.5, 10)  # Social learning rate
conditions = ['human', 'human-machine']


# One entry per parameter; grid_dict expands this into the full Cartesian grid
grid_d = {
    'condition': conditions,
    'd_myopic': d_myopic,
    'd_optimal_log': d_optimal_log_,
    'lambda': lambda_,
    'n_teacher': n_teacher,
    'replication': np.arange(R),
    'epsilon': epsilon,
    'q_opt_human': q_opt_human,
    'K_human': K_human,
    'K_machine': K_machine,
    'q_opt_machine': q_opt_machine,
}

d = grid_dict(grid_d)
if not SKIP_HM_SIMULATION:
    df = run_simulation(d, G, N_gen, N_mach, K_demo)
    # Keep only the final generation
    gen_max = df.reset_index()['gen'].max()
    df = df[df['gen'] == gen_max]
    # Collapse the agent dimension: one row per parameter row
    df_agg = df.groupby(['rep','replication', 'K_machine', 'q_opt_machine', 'discoverability difficulty', 'learnability difficulty', 'condition']).agg(
        {
            'discovered': 'max',
            'optimal': 'mean',
            'myopic': 'mean',
            'rewards': 'mean',
        }
    ).sort_index()
    df_agg.to_parquet('../data/abm_v2/human_machine_discoverability.parquet')
else:
    # Reuse the results written by an earlier run
    df_agg = pd.read_parquet('../data/abm_v2/human_machine_discoverability.parquet')

In [None]:
# Parameter dimensions that identify one cell of the uplift table
dimensions = ['K_machine', 'q_opt_machine', 'discoverability difficulty', 'learnability difficulty']

# Mean reward per parameter cell and condition (pivot_table aggregates with the mean)
df_delta = df_agg.reset_index().pivot_table(index=dimensions, columns='condition', values='rewards')
# Uplift = reward of the human-machine population minus reward of the humans-only population
df_delta['machine-uplift'] = df_delta['human-machine'] - df_delta['human']

# Wide layout of the uplift with learnability difficulty as columns
df_delta_m = df_delta.reset_index().pivot_table(index=dimensions[:-1], columns=dimensions[-1], values='machine-uplift')

In [None]:
# Inspect the log10 discoverability grid of the current experiment
d_optimal_log_

In [None]:
# Machine uplift visualization
# X-axis: learnability difficulty
# Y-axis: discoverability difficulty
# Columns: machine K (K_machine)
# Rows: machine bias (q_opt_machine)

import matplotlib.pyplot as plt

# Mapping dictionary for machine bias labels
bias_label_map = {
    0.5: 'No Machine bias',
    0.01: 'Human Like Machine bias',
}

# Long-format uplift table: one row per parameter combination
df_uplift_viz = df_delta.reset_index()

# Get unique values for rows and columns
q_opt_machine_vals = sorted(df_uplift_viz['q_opt_machine'].unique())
K_machine_vals = sorted(df_uplift_viz['K_machine'].unique())

# Set K_human value for calculating exploration advantage
K_human_base = 10

# Create a grid: rows = q_opt_machine, columns = K_machine
n_rows = len(q_opt_machine_vals)
n_cols = len(K_machine_vals)

# Create figure with subplots (add one column for colorbar).
# squeeze=False keeps axs two-dimensional even when there is a single row.
fig, axs = plt.subplots(n_rows, n_cols + 1, figsize=(6*(n_cols+1), 5*n_rows),
                        gridspec_kw={"width_ratios": [1]*n_cols + [0.08]},
                        squeeze=False)

# Shared colour scale: lower bound fixed at 0, upper bound is the largest
# observed uplift. nanmax ignores cells where a condition is missing.
# NOTE(review): with vmin=0, negative uplift is not distinguishable from
# zero in the plot — confirm this is intended.
uplift_values = df_uplift_viz['machine-uplift'].values
vmin = 0
vmax = np.nanmax(uplift_values)

# Difficulty values that get a y tick label
labelled_difficulties = [0, 2, 4, 6]

# Create heatmaps for each combination
for row_idx, q_opt_val in enumerate(q_opt_machine_vals):
    for col_idx, K_val in enumerate(K_machine_vals):
        # Filter data for this combination
        data_subset = df_uplift_viz[
            (df_uplift_viz['q_opt_machine'] == q_opt_val) &
            (df_uplift_viz['K_machine'] == K_val)
        ]

        # Create pivot table for heatmap
        heatmap_data = data_subset.pivot_table(
            index='discoverability difficulty',
            columns='learnability difficulty',
            values='machine-uplift'
        )

        # Sort indices for proper display
        heatmap_data = heatmap_data.sort_index(axis=0, ascending=False)
        heatmap_data = heatmap_data.sort_index(axis=1, ascending=False)

        ax = axs[row_idx, col_idx]

        hm = sns.heatmap(heatmap_data, ax=ax, cbar=False, cmap='RdBu_r',
                         center=0, vmin=vmin, vmax=vmax,
                         xticklabels=heatmap_data.columns, yticklabels=heatmap_data.index)

        # Columns are sorted descending; inverting the axis shows them ascending left to right
        ax.invert_xaxis()

        # Set labels
        if col_idx == 0:
            # Map q_opt_machine values to row labels using the mapping dictionary
            row_label = bias_label_map.get(q_opt_val, f'q_opt_machine = {q_opt_val:.2f}')
            ax.set_ylabel(f'{row_label}\nDiscoverability Difficulty')
            # Format the y-ticks - only show 0, 2, 4, 6 (matched with a
            # tolerance because the difficulties are floats)
            yvals = heatmap_data.index.values
            mask = np.isclose(yvals[:, np.newaxis], labelled_difficulties).any(axis=1)
            ax.set_yticklabels(
                [f"{int(round(v))}" if show else '' for v, show in zip(yvals, mask)]
            )
        else:
            ax.set_ylabel('')
            ax.set_yticks([])

        if row_idx == n_rows - 1:
            ax.set_xlabel('Learnability Difficulty')
            # Format the x-ticks - select 4 equally spaced with 2 significant digits
            xvals = heatmap_data.columns.values
            n_x = len(xvals)
            if n_x > 4:
                # Select 4 equally spaced indices
                step = max(1, (n_x - 1) // 3)
                selected_indices = {0, step, 2*step, n_x-1}
                tick_labels = [f"{x:.2g}" if i in selected_indices else ''
                               for i, x in enumerate(xvals)]
            else:
                tick_labels = [f"{x:.2g}" for x in xvals]
            ax.set_xticklabels(tick_labels)
        else:
            ax.set_xlabel('')
            ax.set_xticklabels([])

        # Set title with column labels
        if row_idx == 0:
            # Calculate exploration advantage programmatically
            advantage_ratio = K_val / K_human_base
            col_label = f'Machine exploration advantage {int(advantage_ratio)}x'
            ax.set_title(col_label)

# Add colorbar in the rightmost column
cbar = fig.colorbar(hm.collections[0], cax=axs[0, n_cols])
cbar.outline.set_visible(False)
cbar.ax.set_ylabel('Machine Uplift', rotation=-90, va="bottom")
axs[0, n_cols].set_xticks([])

# Hide empty colorbar axes in other rows
for row_idx in range(1, n_rows):
    axs[row_idx, n_cols].set_visible(False)

plt.tight_layout()

In [None]:
# Experiment 3: random (n_teacher = 1) vs. selective (n_teacher = 10) social
# learning, comparing humans-only populations with machines present in the first
# generation only ('human-machine') or in every generation ('human-machine-all').

N_gen = 100  # Number of agents per generation
N_mach = 5  # Number of machines
G = 20  # Number of generations
K_demo = 1  # Number of demonstrations per agent

K_human = [10]
K_machine = [10000]

q_opt_human = [0.01]
q_opt_machine = [0.01]

epsilon = [1] # Success in demonstration learned solution (read by run_simulation but not used in its computation)
n_teacher = [1, 10] # Number of agents observed during social learning

d_myopic = [0.5]  # Discoverability rate for myopic solutions
d_optimal_log_ = np.linspace(0, -6, 10)  # log10 of the discoverability rate for the optimal solution

lambda_ = np.linspace(0.0, 0.5, 10)  # Social learning rate
conditions = ['human', 'human-machine', 'human-machine-all']


# One entry per parameter; grid_dict expands this into the full Cartesian grid
grid_d = {
    'condition': conditions,
    'd_myopic': d_myopic,
    'd_optimal_log': d_optimal_log_,
    'lambda': lambda_,
    'n_teacher': n_teacher,
    'replication': np.arange(R),
    'epsilon': epsilon,
    'q_opt_human': q_opt_human,
    'K_human': K_human,
    'K_machine': K_machine,
    'q_opt_machine': q_opt_machine,
}

d = grid_dict(grid_d)
if not SKIP_HM_N_TEACHER:
    df = run_simulation(d, G, N_gen, N_mach, K_demo)
    # Keep only the final generation
    gen_max = df.reset_index()['gen'].max()
    df = df[df['gen'] == gen_max]
    # Collapse the agent dimension: one row per parameter row
    df_agg = df.groupby(['rep','replication', 'n_teacher', 'discoverability difficulty', 'learnability difficulty', 'condition']).agg(
        {
            'discovered': 'max',
            'optimal': 'mean',
            'myopic': 'mean',
            'rewards': 'mean',
        }
    ).sort_index()
    df_agg.to_parquet('../data/abm_v2/human_machine_n_teacher.parquet')
else:
    # Reuse the results written by an earlier run
    df_agg = pd.read_parquet('../data/abm_v2/human_machine_n_teacher.parquet')

In [None]:
# Parameter dimensions that identify one cell of the uplift table
dimensions = ['n_teacher', 'discoverability difficulty', 'learnability difficulty']

# Mean reward per parameter cell and condition (pivot_table aggregates with the mean)
df_delta = df_agg.reset_index().pivot_table(index=dimensions, columns='condition', values='rewards')
# Uplift relative to the humans-only population:
# 'temporal' uses the 'human-machine' condition, 'permanent' uses 'human-machine-all'
df_delta['temporal'] = df_delta['human-machine'] - df_delta['human']
df_delta['permanent'] = df_delta['human-machine-all'] - df_delta['human']

# Move the two uplift columns into the index, yielding a long-format Series
df_delta = df_delta[['temporal', 'permanent']].stack(0)

# Rename the stacked (last) index level and name the values
new_names = list(df_delta.index.names)
new_names[-1] = "machine persistence"
df_delta.index = df_delta.index.set_names(new_names)
df_delta.name = "machine-uplift"

In [None]:
# Machine uplift visualization for n_teacher
# X-axis: learnability difficulty
# Y-axis: discoverability difficulty
# Columns: n_teacher
# Rows: machine persistence

import matplotlib.pyplot as plt

# Mapping dictionary for n_teacher labels
n_teacher_label_map = {
    1: 'Random Social Learning',
    10: 'Selective Social Learning',
}

# Get unique values for rows and columns
df_uplift_viz = df_delta.reset_index()
machine_persistence_vals = sorted(df_uplift_viz['machine persistence'].unique())
n_teacher_vals = sorted(df_uplift_viz['n_teacher'].unique())

# Create a grid: rows = machine persistence, columns = n_teacher
n_rows = len(machine_persistence_vals)
n_cols = len(n_teacher_vals)

# Create figure with subplots (add one column for colorbar).
# squeeze=False keeps axs two-dimensional even when there is a single row.
fig, axs = plt.subplots(n_rows, n_cols + 1, figsize=(6*(n_cols+1), 5*n_rows),
                        gridspec_kw={"width_ratios": [1]*n_cols + [0.08]},
                        squeeze=False)

# After reset_index, the Series values become the last column of the frame
value_col = df_uplift_viz.columns[-1]

# Shared colour scale: lower bound fixed at 0, upper bound is the largest
# observed uplift. nanmax ignores missing values.
# NOTE(review): with vmin=0, negative uplift is not distinguishable from
# zero in the plot — confirm this is intended.
uplift_values = df_uplift_viz[value_col].values
vmin = 0
vmax = np.nanmax(uplift_values)

# Difficulty values that get a y tick label
labelled_difficulties = [0, 2, 4, 6]

# Create heatmaps for each combination
for row_idx, persistence_val in enumerate(machine_persistence_vals):
    for col_idx, n_teacher_val in enumerate(n_teacher_vals):
        # Filter data for this combination
        data_subset = df_uplift_viz[
            (df_uplift_viz['machine persistence'] == persistence_val) &
            (df_uplift_viz['n_teacher'] == n_teacher_val)
        ]

        # Create pivot table for heatmap
        heatmap_data = data_subset.pivot_table(
            index='discoverability difficulty',
            columns='learnability difficulty',
            values=value_col
        )

        # Sort indices for proper display
        heatmap_data = heatmap_data.sort_index(axis=0, ascending=False)
        heatmap_data = heatmap_data.sort_index(axis=1, ascending=False)

        ax = axs[row_idx, col_idx]

        hm = sns.heatmap(heatmap_data, ax=ax, cbar=False, cmap='RdBu_r',
                         center=0, vmin=vmin, vmax=vmax,
                         xticklabels=heatmap_data.columns, yticklabels=heatmap_data.index)

        # Columns are sorted descending; inverting the axis shows them ascending left to right
        ax.invert_xaxis()

        # Set labels
        if col_idx == 0:
            ax.set_ylabel(f'{persistence_val} machine presence\nDiscoverability Difficulty')
            # Format the y-ticks - only show 0, 2, 4, 6 (matched with a
            # tolerance because the difficulties are floats)
            yvals = heatmap_data.index.values
            mask = np.isclose(yvals[:, np.newaxis], labelled_difficulties).any(axis=1)
            ax.set_yticklabels(
                [f"{int(round(v))}" if show else '' for v, show in zip(yvals, mask)]
            )
        else:
            ax.set_ylabel('')
            ax.set_yticks([])

        if row_idx == n_rows - 1:
            ax.set_xlabel('Learnability Difficulty')
            # Format the x-ticks - select 4 equally spaced with 2 significant digits
            xvals = heatmap_data.columns.values
            n_x = len(xvals)
            if n_x > 4:
                # Select 4 equally spaced indices
                step = max(1, (n_x - 1) // 3)
                selected_indices = {0, step, 2*step, n_x-1}
                tick_labels = [f"{x:.2g}" if i in selected_indices else ''
                               for i, x in enumerate(xvals)]
            else:
                tick_labels = [f"{x:.2g}" for x in xvals]
            ax.set_xticklabels(tick_labels)
        else:
            ax.set_xlabel('')
            ax.set_xticklabels([])

        # Set title with column labels
        if row_idx == 0:
            col_label = n_teacher_label_map.get(n_teacher_val, f'n_teacher = {n_teacher_val}')
            ax.set_title(col_label)

# Add colorbar in the rightmost column
cbar = fig.colorbar(hm.collections[0], cax=axs[0, n_cols])
cbar.outline.set_visible(False)
cbar.ax.set_ylabel('Machine Uplift', rotation=-90, va="bottom")
axs[0, n_cols].set_xticks([])

# Hide empty colorbar axes in other rows
for row_idx in range(1, n_rows):
    axs[row_idx, n_cols].set_visible(False)

plt.tight_layout()