Copyright 2024 DeepMind Technologies Limited.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

In [None]:
#@title Imports

import io
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotnine as gg
import requests

import warnings


In [None]:
#@title Utility to load data from Cloud bucket.

def import_data(path: str, timeout: float = 60.0) -> pd.DataFrame:
  """Load a feather file from the public Cloud bucket.

  Args:
    path: Path of the file relative to the bucket's `data` directory, e.g.
      `'summary_df.feather'`.
    timeout: Seconds `requests` waits for the server (connect and read) before
      giving up, so a stalled download cannot hang the notebook forever.

  Returns:
    The dataframe stored in the feather file.

  Raises:
    requests.HTTPError: If the server answers with an error status.
  """
  base_path = 'https://storage.googleapis.com/statistical_discrimination/data'
  full_path = f'{base_path}/{path}'
  response = requests.get(full_path, timeout=timeout)
  # Fail with a clear HTTP error instead of handing an error page to the
  # feather reader, which would fail with an obscure parsing error.
  response.raise_for_status()
  with io.BytesIO(response.content) as f:
    return pd.read_feather(f)


In [None]:
#@title Names for conditions.

# Display names of the experimental conditions, in plotting order. They are
# used as the categorical levels of the `Condition` axis, so they must match
# the values of the `arch` column of the summary data.
names = ["No memory", "Baseline"] + [
    "$\\beta$ = 0.02",
    "$\\beta$ = 0.005",
    "$\\beta$ = 0.002",
    "$\\beta$ = 0.0",
]


In [None]:
# Download both dataframes (feather files) from the Cloud bucket, in order.
summary_df, analytic_df = (
    import_data(feather_path)
    for feather_path in ('summary_df.feather', 'analytic_df.feather'))

# Figure 3

In [None]:
# Display the summary dataframe.
summary_df

In [None]:
# @title Utils for plots.

# Color used for each condition wherever conditions are drawn. Insertion order
# matters: iterating this dict (e.g. to build legend entries) follows it.
empirical_colors = dict([
    ('Random', '#333333'),
    ('No memory', '#b6cbdd'),
    ('Baseline', '#497dab'),
    ('$\\beta$ = 0.02', 'tab:brown'),
    ('$\\beta$ = 0.005', 'tab:orange'),
    ('$\\beta$ = 0.002', 'tab:red'),
    ('$\\beta$ = 0.0', 'tab:pink'),
])
# Color per number of races; keys are the string form of `k`.
race_colors = dict(zip(('2', '8'), ('#49ab7d', '#ab497d')))
# Palette per supported `fill` column: the column name selects the mapping
# from that column's values to colors.
palette = {}
palette['Condition'] = empirical_colors
palette['# Races'] = race_colors


def make_discrimination_plot(display_name: str, df: pd.DataFrame):
  """Make a discrimination vs community bias plot.

  Draws a jittered scatter of the discrimination index (`discr`) against the
  color-strategy correlation (`corr`) with a linear fit, one facet per
  (number of races, race) pair, prints it and saves it as an SVG file.

  Args:
    display_name: The name of the plot, this must be one of the values in
      `names` which also correspond to the `arch` column in `df`. It selects
      the rows to plot and the point color (from `empirical_colors`), and
      names the output file (spaces replaced by underscores, lower-cased,
      `.svg` suffix).
    df: The dataframe containing data to plot. Must contain `arch`, `k`,
      `race`, `corr` and `discr` columns.
  """
  # Branch copy of summary dataframe, restricted to the requested condition.
  plot_df = df.loc[df.arch == display_name].copy()
  # Format labels for `race` and `k` variables.
  plot_df['race_label'] = plot_df['race'].map('{} race'.format).str.capitalize()
  plot_df['k_label'] = plot_df['k'].map('{} races'.format).str.capitalize()
  # Size the figure by the facet rows actually drawn for this condition rather
  # than by every `k` present in the unfiltered dataframe.
  n_facet_rows = max(plot_df['k_label'].nunique(), 1)
  # Visualize.
  plot = (
      gg.ggplot(data = plot_df)
      + gg.geom_abline(intercept=0, slope=0, colour='#555555', size=1)
      + gg.aes(x = 'corr',
               y = 'discr')
      + gg.geom_jitter(colour=empirical_colors[display_name],
                       width = 0.05,
                       height = 0.05,
                       show_legend = False)
      + gg.geom_smooth(method = 'lm',
                       color = 'red')
      + gg.facet_grid('k_label ~ race_label')
      + gg.labs(x = 'Color-strategy correlation',
                y = 'Discrimination index')
      + gg.ylim(-0.6, 0.6)
      + gg.theme_bw()
      + gg.theme(figure_size = [3, 1.5 * n_facet_rows],
                axis_title = gg.element_text(size = 10),
                axis_text = gg.element_text(size = 8),
                strip_background = gg.element_rect(fill = 'white'),
                strip_text = gg.element_text(size = 8)))
  print(plot)
  # Output svg file.
  dataname = f'{display_name.replace(" ", "_").lower()}'
  with warnings.catch_warnings():
    # Silence UserWarnings raised while saving the figure.
    warnings.filterwarnings('ignore', category = UserWarning)
    plot_name = f'{dataname}.svg'
    plot.save(plot_name, format = 'svg')


def get_discrimination_index_for_run(array: np.ndarray) -> np.ndarray:
  """Get the discrimination index for a single run (set of agents).

  The index is the color-based term (|purple - teal| for cooperators plus the
  same for defectors) minus the strategy-based term (|cooperator - defector|
  for purple plus the same for teal), divided by the total of the four
  entries. An all-zero input yields 0.

  Args:
    array: The array to get the discrimination index from. The array must be
      indexed by the 4 possible types of partners: purple cooperator, purple
      defector, teal cooperator, and teal defector. A second, optional, index
      can index over agents in the run.

  Returns:
    The discrimination index for the run. If multiple agents were given, the
    returned value is indexed by agent, otherwise is a scalar.
  """
  strategy_term = np.abs(array[0] - array[1]) + np.abs(array[2] - array[3])
  color_term = np.abs(array[0] - array[2]) + np.abs(array[1] - array[3])
  # Sum over the partner-type axis only, so each agent is normalized by its
  # own total rather than by the total of the whole run. For 1-D input this is
  # the plain sum.
  total = np.sum(array, axis=0)
  # Avoid division by zero: an agent with no interactions gets index 0.
  denominator = np.where(total == 0, 1, total)
  return (color_term - strategy_term) / denominator


def bootstrap_95(series: pd.Series,
                 bootstrap_count: int = 5000,
                 verbose: bool = False) -> tuple[float, float]:
  """Bootstraps the discrimination index for a series.

  This function is used to calculate 95% confidence intervals via bootstrap on
  the participation counts in a run. Each element of the provided series must
  hold the participation counts of one run (i.e. number of times paired with
  purple cooperator, purple defector, teal cooperator, and teal defector). In
  every bootstrap iteration the counts of each run are resampled from a
  multinomial distribution with the run's total and observed proportions, the
  discrimination index of every resampled run is computed, and the indices are
  averaged over runs.

  Runs whose counts sum to zero have no proportions to resample from; they
  contribute a discrimination index of 0, which is what
  `get_discrimination_index_for_run` returns for an all-zero input.

  Args:
    series: The series to bootstrap. Must contain participation counts.
    bootstrap_count: The number of bootstrap iterations to run. Must be
      positive.
    verbose: Whether to print the bootstrap confidence intervals together with
      the mean, standard deviation and standard error of the observed
      (non-resampled) per-run discrimination indices.

  Returns:
    The bootstrapped discrimination index for the series as a pair of values
    corresponding to the 2.5% and 97.5% percentiles of the discrimination index.

  Raises:
    ValueError: If `bootstrap_count` is not positive.
  """
  if bootstrap_count < 1:
    raise ValueError(
        f'bootstrap_count must be positive, got {bootstrap_count}.')
  runs = [np.asarray(counts) for counts in series]
  # Totals and resampling probabilities do not change between iterations, so
  # they are computed once. `None` marks a run with no interactions.
  totals = [counts.sum() for counts in runs]
  probabilities = [
      counts / total if total != 0 else None
      for counts, total in zip(runs, totals)
  ]
  bootstrapped = []
  for _ in range(bootstrap_count):
    sampled = [
        get_discrimination_index_for_run(np.random.multinomial(total, probs))
        if probs is not None else 0.0
        for total, probs in zip(totals, probabilities)
    ]
    bootstrapped.append(np.mean(sampled))
  srt = sorted(bootstrapped)
  srt_btstrp_2_5 = srt[int(0.025 * len(srt))]
  srt_btstrp_97_5 = srt[int(0.975 * len(srt))]
  if verbose:
    # Summary statistics of the observed discrimination indices (not of the raw
    # count arrays), so they are comparable with the bootstrapped percentiles.
    observed = np.array(
        [get_discrimination_index_for_run(counts) for counts in runs])
    mean = np.mean(observed)
    std = np.std(observed)
    std_err = std / np.sqrt(len(observed))
    print('mean', mean, 'std.dev', std,
          'std.err', std_err,
          'mean + std.err', mean + std_err,
          'bootstrapped 2.5%-ile', srt_btstrp_2_5,
          'bootstrapped 97.5%-ile', srt_btstrp_97_5,
          "series", list(series.values))
  return srt_btstrp_2_5, srt_btstrp_97_5


def average_discrimination_over_bias(df: pd.DataFrame,
                                     fill: str = 'Condition',
                                     bootstrap: bool = True):
  """Makes a plot for the average discrimination across community bias.

  Bars show the mean discrimination index (`discr`) per condition, with one
  facet per (number of races, race) pair. The plot is printed and saved as
  `avg_discr_vs_percept.svg`.

  Args:
    df: The dataframe with summary data. Must contain `arch`, `k`, `race` and
      `discr` columns, plus `raw_discr_counts` when bootstrapping.
    fill: The fill variable to use for filling the bars of the plot:
      `'Condition'` or `'# Races'` (the keys of `palette`).
    bootstrap: If True, error bars span the bootstrapped 2.5% and 97.5%
      percentiles of the discrimination index (see `bootstrap_95`). Otherwise
      they span the mean +/- one standard error of `discr`.
  """
  grouped = df.groupby(['arch', 'k', 'race'])
  # Average only numeric columns: non-numeric ones such as `raw_discr_counts`
  # cannot be averaged, and pandas >= 2 raises instead of dropping them.
  plot_df = grouped.mean(numeric_only=True)
  if bootstrap:
    tmp = grouped['raw_discr_counts'].aggregate(bootstrap_95)
    plot_df['discr_2_5ci'] = tmp.apply(lambda x: x[0])
    plot_df['discr_97_5ci'] = tmp.apply(lambda x: x[1])
  else:
    # Standard error of the mean (population std / sqrt(n)). Despite the
    # column names, these bounds are mean +/- 1 SE, not a 95% interval.
    tmp = grouped['discr'].aggregate(
        lambda x: np.std(x.values) / np.sqrt(len(x.values))
    )
    plot_df['discr_2_5ci'] = plot_df['discr'] - tmp
    plot_df['discr_97_5ci'] = plot_df['discr'] + tmp
  plot_df = plot_df.reset_index()
  # Format labels for `race` and `k` variables.
  plot_df['race_label'] = plot_df['race'].map('{} race'.format).str.capitalize()
  plot_df['k_label'] = plot_df['k'].map('{} races'.format)
  # Both supported fill columns must exist so `fill='# Races'` works too.
  plot_df['# Races'] = plot_df['k'].map('{}'.format)
  plot_df['Condition'] = pd.Categorical(
      plot_df['arch'],
      categories=names,
      ordered=True)
  # Visualize.
  plot = (
      gg.ggplot(data = plot_df)
      + gg.aes(x = 'Condition',
               y = 'discr', fill=fill)
      + gg.geom_abline(intercept=0, slope=0, colour='#555555', size=1)
      + gg.geom_bar(stat='identity')
      + gg.geom_errorbar(gg.aes(ymin='discr_2_5ci',
                                ymax='discr_97_5ci'),
      )
      + gg.facet_grid('k_label ~ race_label')
      + gg.labs(x = 'Condition',
                y = 'Avg. discr. index')
      + gg.theme_bw()
      + gg.theme(
          figure_size = [2.2 * len(plot_df.race.unique()),
                         2.5 * len(df.k.unique())],
          axis_title = gg.element_text(size=10),
          axis_text = gg.element_text(size=8),
          strip_background = gg.element_rect(fill = 'white'),
          strip_text = gg.element_text(size=9),
          axis_text_x = gg.element_text(angle=45, vjust=1, hjust=1),
      )
      + gg.scale_fill_manual(palette[fill])
  )

  print(plot)
  # Output SVG file.
  dataname = 'avg_discr_vs_percept'
  with warnings.catch_warnings():
    # Silence UserWarnings raised while saving the figure.
    warnings.filterwarnings('ignore', category = UserWarning)
    plot_name = f'{dataname}.svg'
    plot.save(plot_name, format = 'svg')


def make_reward_plot(df: pd.DataFrame, fill: str = 'Condition'):
  """Makes a plot showing the rewards for all conditions.

  Bars show the mean episode reward (`rwd`) at the last race of each
  (condition, number of races) group. Error bars span +/- one standard error.
  The plot is printed and saved as `avg_rwd_vs_percept.svg`.

  Args:
    df: The dataframe with summary data. Must contain `arch`, `k`, `race` and
      `rwd` columns.
    fill: The fill variable to use for filling the bars of the plot:
      `'Condition'` or `'# Races'` (the keys of `palette`).
  """
  grouped = df.loc[df['race'] == 'last'].groupby(
      ['arch', 'k', 'race'])
  # Average only numeric columns: non-numeric ones such as `raw_discr_counts`
  # cannot be averaged, and pandas >= 2 raises instead of dropping them.
  plot_df = grouped.mean(numeric_only=True)
  # Standard error of the mean (sample std / sqrt(n)).
  plot_df['rwd_std'] = grouped['rwd'].std() / np.sqrt(grouped['rwd'].count())
  plot_df = plot_df.reset_index()
  # Format labels for `k` variable.
  plot_df['k_label'] = plot_df['k'].map('{} races'.format)
  plot_df['# Races'] = plot_df['k'].map('{}'.format)
  plot_df['Condition'] = pd.Categorical(
      plot_df['arch'],
      categories=names,
      ordered=True)
  # Visualize.
  plot = (
      gg.ggplot(data = plot_df)
      + gg.aes(x = 'Condition',
               y = 'rwd', fill=fill)
      + gg.geom_col(stat='identity', position='dodge')
      # Dodge the error bars like the bars, so each one stays on its own bar
      # when several bars share a condition (e.g. with fill='# Races').
      + gg.geom_errorbar(gg.aes(ymin='rwd - rwd_std',
                                ymax='rwd + rwd_std'),
                         position=gg.position_dodge(0.9))
      + gg.labs(x = 'Condition',
                y = 'Episode reward')
      + gg.theme_bw()
      + gg.theme(
          figure_size = [3.8, 3],
          axis_title = gg.element_text(size=10),
          axis_text = gg.element_text(size=8),
          strip_background = gg.element_rect(fill = 'white'),
          strip_text = gg.element_text(size=9),
          axis_text_x = gg.element_text(angle=45, vjust=1, hjust=1),
      )
      + gg.scale_fill_manual(palette[fill])
  )

  print(plot)
  # Output SVG file.
  dataname = 'avg_rwd_vs_percept'
  with warnings.catch_warnings():
    # Silence UserWarnings raised while saving the figure.
    warnings.filterwarnings('ignore', category = UserWarning)
    plot_name = f'{dataname}.svg'
    plot.save(plot_name, format = 'svg')


def make_attribute_plot_all(attribute: str, df: pd.DataFrame,
                            fill: str = 'Condition'):
  """Makes a plot of awareness or stickiness for all conditions.

  Bars show the mean of `attribute` at the last race of each (condition,
  number of races) group. Error bars span +/- one standard error, and a
  horizontal reference line is drawn in the 'Random' color. The plot is
  printed and saved as `<attribute>.svg`.

  Args:
    attribute: Column to plot: `'awareness_prob'` or `'stickiness_prob'`.
    df: The dataframe with summary data. Must contain `arch`, `k`, `race`,
      `raw_discr_counts` and `attribute` columns.
    fill: The fill variable to use for filling the bars of the plot:
      `'Condition'` or `'# Races'` (the keys of `palette`).
  """
  label = {
      'stickiness_prob': 'Stickiness',
      'awareness_prob': 'Awareness',
  }
  # Reference level per attribute, drawn as a line in the 'Random' color.
  random = {
      'stickiness_prob': 0.2,
      'awareness_prob': 0.5,
  }
  # Stickiness & Awareness only make sense for computation at the last race.
  # The `race` key and the count arrays are not averaged.
  plot_df = df.loc[df['race'] == 'last'].drop(
      columns=['raw_discr_counts', 'race'])
  grouped = plot_df.groupby(['arch', 'k'])
  # Average only numeric columns; pandas >= 2 raises on non-numeric ones.
  plot_df = grouped.mean(numeric_only=True)
  # Standard error of the mean (sample std / sqrt(n)).
  plot_df[f'{attribute}_std'] = (
      grouped[attribute].std() / np.sqrt(grouped[attribute].count()))
  plot_df = plot_df.reset_index()
  # Format labels for `k` variable.
  plot_df['# Races'] = plot_df['k'].map('{}'.format)
  plot_df['Condition'] = pd.Categorical(
      plot_df['arch'],
      categories=names,
      ordered=True)
  # Visualize.
  plot = (
      gg.ggplot(data = plot_df)
      + gg.aes(x='Condition', y=attribute, fill=fill)
      + gg.geom_col(stat='identity', position='dodge')
      + gg.geom_errorbar(gg.aes(ymin=f'{attribute}-{attribute}_std',
                                ymax=f'{attribute}+{attribute}_std'),
                         position=gg.position_dodge(0.9))
      + gg.labs(x = 'Condition',
                y = label[attribute])
      + gg.theme_bw()
      + gg.theme(figure_size = [3.8, 3],
                axis_title = gg.element_text(size = 10),
                axis_text = gg.element_text(size = 8),
                axis_text_x = gg.element_text(angle=45, vjust=1, hjust=1),
                strip_background = gg.element_rect(fill='white'),
                strip_text = gg.element_text(size = 10))
      + gg.geom_abline(intercept=random[attribute], slope=0, size=1, colour=empirical_colors['Random'], show_legend=True)
      + gg.labs(colour="")
      + gg.scale_fill_manual(palette[fill])
  )

  print(plot)
  # Output SVG file.
  dataname = f'{attribute}'
  with warnings.catch_warnings():
    # Silence UserWarnings raised while saving, like the other plot helpers.
    warnings.filterwarnings('ignore', category = UserWarning)
    plot_name = f'{dataname}.svg'
    plot.save(plot_name, format = 'svg')


In [None]:
# Discrimination vs color-strategy correlation for the Baseline condition,
# 8-race episodes only.
make_discrimination_plot('Baseline', summary_df.loc[summary_df.k == 8])

In [None]:
# Same plot for the beta = 0.0 condition, 8-race episodes only.
make_discrimination_plot('$\\beta$ = 0.0', summary_df.loc[summary_df.k == 8])

In [None]:
# Average discrimination per condition with bootstrapped error bars.
# Saved as avg_discr_vs_percept.svg.
average_discrimination_over_bias(summary_df.loc[summary_df.k == 8], bootstrap=True)

In [None]:
# Same plot with standard-error bars. It is saved to the same file name
# (avg_discr_vs_percept.svg), so it replaces the bootstrapped version on disk.
average_discrimination_over_bias(summary_df.loc[summary_df.k == 8], bootstrap=False)

# Analytical model & Figure 4

In [None]:
# Last-race awareness per condition, 8-race episodes (saved as
# awareness_prob.svg).
make_attribute_plot_all('awareness_prob', summary_df.loc[summary_df.k == 8])

In [None]:
# Last-race stickiness per condition, 8-race episodes (saved as
# stickiness_prob.svg).
make_attribute_plot_all('stickiness_prob', summary_df.loc[summary_df.k == 8])

In [None]:
# Last-race episode reward per condition, 8-race episodes (saved as
# avg_rwd_vs_percept.svg).
make_reward_plot(summary_df.loc[summary_df.k == 8])

In [None]:
#@title Create empirical CSV to overlay

# Mean last-race awareness and stickiness of each (condition, k) pair, plus
# 'Random' reference points, serialized as CSV in `empirical_csv` for the
# analytic figure overlay.
plot_df = summary_df.copy()
# Stickiness & Awareness only make sense for computation at the last race.
plot_df = plot_df[plot_df['race'] == 'last'].drop(columns=['race'])
grouped = plot_df.groupby(['arch', 'k'])
# Average only numeric columns; non-numeric ones such as `raw_discr_counts`
# cannot be averaged, and pandas >= 2 raises instead of dropping them.
plot_df = grouped.mean(numeric_only=True)
plot_df = plot_df.reset_index()
plot_df['arch_label'] = pd.Series(plot_df['arch'], dtype="category")
# 'Random' reference points (awareness 0.5, stickiness 0.2) for k = 2 and 8.
random_rows = pd.DataFrame({
    'arch_label': 'Random',
    'awareness_prob': 0.5,
    'stickiness_prob': 0.2,
    'k': [2, 8],
})
# Concatenating with a fresh index numbers the reference rows correctly
# regardless of how many (condition, k) rows the summary produced.
empirical_csv = pd.concat(
    [plot_df[['arch_label', 'awareness_prob', 'stickiness_prob', 'k']],
     random_rows],
    ignore_index=True,
).to_csv()

In [None]:
#@title Plot Analytical figure

# Contour maps of the analytic model's (negated) last-race discrimination over
# a stickiness x awareness grid, one panel per number of races, with the
# empirical conditions from `empirical_csv` overlaid as points.
plot_df = analytic_df.copy()
del plot_df['race_counts']  # Not needed, prevents aggregation
grouped_df = plot_df.groupby(['k', 'awareness', 'stickiness'])
grouped_df = grouped_df.mean(numeric_only=True)
grouped_df = grouped_df.reset_index()
grouped_df = grouped_df.loc[(grouped_df.k == 2) | (grouped_df.k == 8)]
ks = sorted(set(grouped_df.k))

# Fixed window of the simulated grid shown in the figure: the 2nd-5th
# stickiness values and the 5th-8th awareness values.
Xs = sorted(set(grouped_df.stickiness))[1:5]
Ys = sorted(set(grouped_df.awareness))[4:8]
Xv, Yv = np.meshgrid(Xs, Ys)

# `squeeze=False` always returns a 2-D array of axes, so `axes[j]` below also
# works when only one value of `k` is present.
fig, axes = plt.subplots(1, len(ks), tight_layout=True, squeeze=False)
axes = axes[0]
fig.set_size_inches(3*len(ks) + 1.5, 3)

# Raw string: `\o` is not a valid escape sequence in a normal string literal.
axes[0].set_ylabel(r'Awareness $\omega$')

csv_io = io.StringIO(empirical_csv)
empirical_df = pd.read_csv(csv_io).reset_index(drop=True)

# Legend entries, one per condition color; they do not depend on `k`, so they
# are built once.
handles = [
    plt.Line2D([0], [0], marker='o', color='w', markeredgecolor='#555555',
               markerfacecolor=color, label=condition, markersize=8)
    for condition, color in empirical_colors.items()
]

for j, k in enumerate(ks):
  k_df = grouped_df[grouped_df.k == k]
  # Mean `last_race_discr` at one (stickiness, awareness) grid point; NaN if
  # that point is absent from the data.
  last_fn = np.vectorize(
      lambda x, y: k_df[(k_df.awareness == y)
                        & (k_df.stickiness == x)].last_race_discr.mean())

  empirical_subset = empirical_df.loc[empirical_df.k == k]
  # The contour shows the negated `last_race_discr`.
  last_v = -last_fn(Xv, Yv)

  axes[j].set_title(f'${k}$ races')

  axes[j].set_xlabel('Stickiness $s$')

  cs = axes[j].contourf(Xv, Yv, last_v,
                        extend='both')
  fig.colorbar(cs, ax=axes[j], format='%.1f')

  axes[j].scatter(empirical_subset.stickiness_prob,
                  empirical_subset.awareness_prob,
                  c=empirical_subset.arch_label.map(empirical_colors),
                  edgecolors='#555555')

axes[-1].legend(title='Condition', handles=handles,
                bbox_to_anchor=(1.4, 1),
                loc='upper left')

dataname = 'analytic_w_overlay'
plot_name = f'{dataname}.svg'
fig.savefig(plot_name, format = 'svg')


# Supplementary Material Figures

In [None]:
# Rewards for all conditions and both numbers of races, with bars filled by
# the number of races. Saved as avg_rwd_vs_percept.svg, the same file name as
# the 8-race reward plot, which it replaces on disk.
make_reward_plot(summary_df, fill='# Races')

In [None]:
# Average discrimination for all numbers of races, standard-error bars.
# Saved as avg_discr_vs_percept.svg, replacing any earlier file of that name.
average_discrimination_over_bias(summary_df, bootstrap=False)

In [None]:
# Discrimination vs color-strategy correlation per condition, all numbers of
# races. File names are derived from the condition name, so these replace the
# 8-race-only files of the same conditions.
make_discrimination_plot('Baseline', summary_df)

In [None]:
make_discrimination_plot('No memory', summary_df)

In [None]:
make_discrimination_plot('$\\beta$ = 0.0', summary_df)

In [None]:
make_discrimination_plot('$\\beta$ = 0.002', summary_df)

In [None]:
make_discrimination_plot('$\\beta$ = 0.005', summary_df)

In [None]:
make_discrimination_plot('$\\beta$ = 0.02', summary_df)

In [None]:
# @title Run regressions
import statsmodels.api as sm


# For every (number of races, race, condition) group, fit an OLS regression of
# the discrimination index (`discr`) on the color-strategy correlation
# (`corr`), and collect the slope, its 95% confidence interval and the
# intercept into `discr_results_df`.

# Initialize constituent lists for the results dataframe.
overall_num_races = []
race_num = []
inters = []
archs = []
all_coeffs = []
all_ci_95_lower = []
all_ci_95_upper = []
# Group dataframe by number of races, current race number and condition.
summary_df_groups = summary_df.groupby(['k', 'race', 'arch'])
# Iterate through groups.
for (k, race, arch), grouped_df in summary_df_groups:
  # Set up predictors (with an intercept column) and outcome variable.
  X = grouped_df['corr']
  X = sm.add_constant(X)
  Y = grouped_df['discr']
  # Fit model.
  model = sm.OLS(Y, X)
  res = model.fit()
  # Extract coefficient and confidence intervals (default alpha=0.05 -> 95%).
  coeff = res.params['corr']
  # `add_constant` skips the intercept column when `corr` is already constant,
  # in which case there is no fitted intercept to report.
  inter = res.params['const'] if 'const' in res.params else 0.0
  ci_95_lower, ci_95_upper = res.conf_int().loc['corr']
  # Record this group's params and results.
  overall_num_races.append(k)
  race_num.append(race)
  archs.append(arch)
  all_coeffs.append(coeff)
  all_ci_95_lower.append(ci_95_lower)
  all_ci_95_upper.append(ci_95_upper)
  inters.append(inter)
discr_results_df = pd.DataFrame({
    'Description': archs,
    'Race_Num': race_num,
    'Total_Num_Races': overall_num_races,
    'Coeff': all_coeffs,
    'CI_95%_low': all_ci_95_lower,
    'CI_95%_upp': all_ci_95_upper,
    'Intercept': inters,
}).sort_values(['Description', 'Total_Num_Races']).reset_index(drop = True)
discr_results_df

In [None]:
#@title Crown decay plot {run: 'auto'}

# Simulates a trace that is pulled towards 1 by an exponential moving average
# (rate `alpha`) on every rowing step, and decays by a factor `1 - beta` on
# every step. The trace is plotted against the `min_thresh`/`max_thresh`
# levels and a 0/1 step marking the race phase (0) and the following partner
# phase (1).

# Good values are:
# * 0.0 (no decay)
# * 0.002 (mid association decay)
# * 0.005 (mid reward decay)
# * 0.02 (early reward decay)

alpha = 0.2
min_thresh = 0.3
max_thresh = 0.6

cooldown = 2
race_duration = 225
partner_duration = 75
duration = race_duration + partner_duration

beta = 0.002 #@param {type:"slider", min:0, max:0.1, step:0.001}

xs = list(range(duration))
# 13 rowing actions, each followed by `cooldown` idle steps, then idle until
# the end of the episode.
rowing = ([1] + [0]*cooldown)*13
rowing += [0] * (duration - len(rowing))
ys = [0]
# Simulate only the first `duration` steps so `ys` lines up with `xs` even if
# the rowing schedule is longer than the episode.
for r in rowing[:duration]:
  y = ys[-1]
  if r:
    y = alpha * r + (1-alpha) * y
  y *= 1 - beta
  ys.append(y)

# Drop the initial value so that `ys[t]` is the trace after step `t`.
ys = ys[1:]

plt.plot(xs, ys)
plt.plot(xs, [min_thresh]*duration)
plt.plot(xs, [max_thresh]*duration);
plt.plot(xs, [0]*race_duration + [1]*partner_duration);