## Imports and settings

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import plotly.graph_objects as go
from plotly.subplots import make_subplots

import itertools
from functools import partial
from tqdm import tqdm
from datetime import datetime
import os

In [None]:
from generate_games import generate_delegation_games_with_alignment_bounds
from evaluate_games import get_stat_nash, get_stat_general
import regret_bounds

## Plotting helpers

In [None]:
def render_simple_bound_surface(welfare_regret: pd.Series, total_misalignment: pd.Series,
                                bound_fn: 'regret_bounds.SimpleBound', num: int = 50):
    """Evaluate a regret bound on a regular grid covering the observed data.

    The grid spans ``[0, welfare_regret.max()]`` along x and
    ``[0, total_misalignment.max()]`` along y, with ``num`` points per axis.

    Args:
        welfare_regret: observed agents' welfare regret; only its maximum is used.
        total_misalignment: observed total agent misalignment; only its maximum is used.
        bound_fn: called once as ``bound_fn(xx, yy)`` with the two 2-D meshgrid
            arrays (presumably vectorised over them — confirm against regret_bounds).
        num: number of grid points per axis (default 50, the previously
            hard-coded resolution).

    Returns:
        ``(xx, yy, zz)``: ``xx`` and ``yy`` are ``(num, num)`` meshgrid arrays and
        ``zz`` is whatever ``bound_fn(xx, yy)`` returns.
    """
    # The bound_fn annotation is quoted so defining this helper does not need
    # regret_bounds to be importable.
    x = np.linspace(0, welfare_regret.max(), num=num)
    y = np.linspace(0, total_misalignment.max(), num=num)
    xx, yy = np.meshgrid(x, y)
    zz = bound_fn(xx, yy)
    return xx, yy, zz

### Pyplot

In [None]:
plt.rc('figure', dpi=200)  # higher-resolution figures; matplotlib's default figure dpi is 100

In [None]:
def plot_stats(stats, bound_fn, view=(10, 30), ax=None, label=None, cmap=None):
    """Draw a 3D matplotlib scatter of principals' welfare regret, optionally with a bound surface.

    Each row of ``stats`` is plotted at
    ``(welfare_regret, total_misalignment, principals_welfare_regret)``.
    Points are coloured by ``bound_fn(welfare_regret, total_misalignment) -
    principals_welfare_regret`` (the slack under the bound), or by
    ``-principals_welfare_regret`` when ``bound_fn`` is ``None``.

    Args:
        stats: table with ``welfare_regret``, ``total_misalignment`` and
            ``principals_welfare_regret`` columns (the DataFrame built in the
            data-generation cell).
        bound_fn: callable ``bound_fn(welfare_regret, total_misalignment)``, or
            ``None`` to colour by negated regret and skip the bound surface.
        view: ``(elev, azim)`` forwarded to ``ax.view_init``.
        ax: 3D axes to draw on; when ``None``, obtained via
            ``plt.subplot(projection='3d')``.
        label: label attached to the scatter.
        cmap: scatter colormap; ``'Greens_r'`` when ``None``.
    """
    ax = plt.subplot(projection='3d') if ax is None else ax
    ax.view_init(*view)

    if bound_fn is None:
        err = -stats.principals_welfare_regret
    else:
        err = bound_fn(stats.welfare_regret, stats.total_misalignment) - stats.principals_welfare_regret

    ax.scatter(
        stats.welfare_regret,
        stats.total_misalignment,
        stats.principals_welfare_regret,
        c=err,
        cmap='Greens_r' if cmap is None else cmap,
        s=0.5,
        label=label)

    if bound_fn is not None:
        xx, yy, zz = render_simple_bound_surface(stats.welfare_regret, stats.total_misalignment, bound_fn)
        ax.plot_surface(xx, yy, zz, alpha=0.8)

    # Raw strings: '\m', '\s' and '\h' are invalid escape sequences in ordinary
    # string literals (SyntaxWarning since Python 3.12). The label text is unchanged.
    ax.set_xlabel(r'$\mathrm{WR}$')
    ax.set_ylabel(r'$\sum d_A$')
    ax.set_zlabel(r'$\hat\mathrm{WR}$')

In [None]:
def plot_stats_multiview(stats, bound_fn):
    """Draw ``plot_stats`` twice, side by side, from two camera angles.

    Both 3D panels are created on the current figure (1 row, 2 columns)
    before anything is plotted into them.
    """
    # (elev, azim) per panel:
    #   (20, 10)  -> bound surface seen from top and behind
    #   (20, -40) -> bound surface seen at front and top
    camera_views = [(20, 10), (20, -40)]
    panels = [plt.subplot(1, 2, position, projection='3d') for position in (1, 2)]
    for panel, camera_view in zip(panels, camera_views):
        plot_stats(stats, bound_fn=bound_fn, view=camera_view, ax=panel)

### Plotly

In [None]:
MUTED_BLUE = '#1f77b4'  # same hex as matplotlib's 'tab:blue'


def fixed_colorscale(c):
    """Return a two-stop Plotly colorscale that maps every value to the single colour ``c``."""
    return [(stop, c) for stop in (0, 1)]

In [None]:
# Shared Plotly layout settings: no legend and no outer margins.
layout_spec = {'showlegend': False, 'margin': {'l': 0, 'r': 0, 'b': 0, 't': 0}}

# Tick/title styling reused (as the same dict object) for the x, y and z scene axes;
# dtick=1 places a tick every unit.
axis_spec = {'title_font_size': 10, 'dtick': 1, 'tickfont_size': 10}

axes_spec = {f'{dim}axis': axis_spec for dim in 'xyz'}
axes_spec.update(
    xaxis_title_text='Cooperation failure',
    yaxis_title_text='Total agent misalignment',
    zaxis_title_text='Principals\' welfare regret',
)

# Slightly flattened z axis relative to x and y.
aspect_spec = {'x': 1, 'y': 1, 'z': 0.8}

# Optional render choice: add camera_projection_type='orthographic' to scene_spec.
scene_spec = {'aspectratio': aspect_spec, **axes_spec}

# Lighting applied to the bound surfaces.
lighting_spec = {
    'diffuse': 1,
    'fresnel': 5,
    'ambient': 0.5,
    'roughness': 1.0,
    'specular': 0.5,
}

In [None]:
def make_simple_figure(stats, bound_fn):
    """Build a single-scene Plotly 3D figure of the regret statistics.

    Each row of ``stats`` becomes a marker at
    ``(welfare_regret, total_misalignment, principals_welfare_regret)``,
    coloured by the slack ``bound_fn(...) - principals_welfare_regret`` (or by
    ``-principals_welfare_regret`` when ``bound_fn`` is ``None``). When a bound
    is given, it is also drawn as a single-colour surface.

    Args:
        stats: table with ``welfare_regret``, ``total_misalignment`` and
            ``principals_welfare_regret`` columns.
        bound_fn: callable ``bound_fn(welfare_regret, total_misalignment)``, or ``None``.

    Returns:
        The Plotly figure, with ``scene_spec`` and ``layout_spec`` applied.
    """
    if bound_fn is None:
        err = -stats.principals_welfare_regret
    else:
        err = bound_fn(stats.welfare_regret, stats.total_misalignment) - stats.principals_welfare_regret

    # {'type': 'scene'} is the documented way to request a 3D subplot; the old
    # 'is_3d' flag is only converted by make_subplots for backward compatibility.
    fig = make_subplots(specs=[[{'type': 'scene'}]])

    scatter = go.Scatter3d(
        x=stats.welfare_regret,
        y=stats.total_misalignment,
        z=stats.principals_welfare_regret,
        mode='markers',
        marker=dict(
            color=err,
            colorscale='Greens_r',
            size=1.5
        ))

    # add_trace(row=, col=) replaces the deprecated Figure.append_trace.
    fig.add_trace(scatter, row=1, col=1)

    if bound_fn is not None:
        xx, yy, zz = render_simple_bound_surface(stats.welfare_regret, stats.total_misalignment, bound_fn)
        surface = go.Surface(x=xx, y=yy, z=zz, opacity=0.8, colorscale=fixed_colorscale(MUTED_BLUE), showscale=False, lighting=lighting_spec)
        fig.add_trace(surface, row=1, col=1)

    fig.update_layout(scene=scene_spec, **layout_spec)

    return fig

In [None]:
def make_multiview_simple_figure(stats, bound_fn):
    """Build a 1x2 Plotly figure showing the same 3D plot from two cameras.

    The traces of ``make_simple_figure(stats, bound_fn)`` are added to both
    scenes. Both scenes use ``scene_spec`` and differ only in their camera eye.

    Returns:
        The combined Plotly figure.
    """
    fig1 = make_simple_figure(stats, bound_fn)
    # Separate copy of the single-view figure whose traces feed the second scene.
    fig2 = go.Figure(fig1)
    # {'type': 'scene'} requests 3D subplots; 'is_3d' is the legacy spelling
    # that make_subplots only converts for backward compatibility.
    fig = make_subplots(1, 2, specs=[[{'type': 'scene'}, {'type': 'scene'}]], horizontal_spacing=0)
    fig.add_traces(data=fig1.data, rows=1, cols=1)
    fig.add_traces(data=fig2.data, rows=1, cols=2)
    fig.update_layout(
        scene=scene_spec,
        scene2=scene_spec,
        scene_camera_eye=dict(x=1.8, y=-1.5, z=.8),
        scene2_camera_eye=dict(x=2.2, y=.75, z=.8),
        **layout_spec,
        )
    return fig

## Data generation

Generate a configurable number of games and compute their stats, or alternatively read in the stats saved by a previous run (controlled by `READ_DATA` below).

In [None]:
# Falsy: generate the games afresh and save them under today's date.
# 'YYYY-mm-dd' string: re-read the stats saved on that date.
READ_DATA = False
todaystr = READ_DATA or datetime.today().strftime('%Y-%m-%d')

In [None]:
K_m = np.sqrt(2.0)  # upper bound on the range of a 2-standardised utility

In [None]:
n_games = 20_000
m = [2, 1, 1, 1]  # one entry per player (alternative: [1] * n_players)
n_players = len(m)
d_u = 4  # number of outcomes per game (alternative: 2**n_players)

max_epic = 1

# Welfare-regret cap; None means use Nash (only working for 2-agent right now).
max_welfare_regret = n_players * K_m

In [None]:
# Per-game statistics function: the Nash variant when there is no welfare-regret
# cap, otherwise the general variant bound to that cap (with use_agents=False).
if max_welfare_regret is None:
    get_stat = get_stat_nash
else:
    get_stat = partial(get_stat_general, max_welfare_regret=max_welfare_regret, use_agents=False)

In [None]:
# Run identifier: player count, the non-unit entries of m joined by 'x',
# number of outcomes d_u, and the game count in thousands.
GENERAL_IDENT = 'general_{}p_{}_{}u_{}k'.format(
    n_players,
    'x'.join(map(str, filter(lambda m_: m_ != 1, m))),
    d_u,
    n_games // 1000,
)
# Output directory for this run; the trailing slash lets file names be appended directly.
GENERAL_LOG_PATH = 'logs/{}/{}/'.format(todaystr, GENERAL_IDENT)
print(GENERAL_LOG_PATH)

In [None]:
if READ_DATA:
    # Reuse the stats saved by an earlier run (first CSV column is the index).
    stats = pd.read_csv(GENERAL_LOG_PATH + 'data.csv', index_col=0)
else:
    # Lazily map get_stat over the generated games and keep the first n_games
    # (presumably the generator is unbounded — islice is what stops it).
    stats = pd.DataFrame(tqdm(
        itertools.islice(map(
                get_stat,
                generate_delegation_games_with_alignment_bounds(
                    n_players=n_players,
                    n_outcomes=d_u,
                    m=m,
                    max_epic=max_epic)),
            n_games),
        total=n_games))
    # exist_ok=True: re-running with the same date and configuration reuses the
    # directory instead of raising FileExistsError after the expensive generation
    # (which would otherwise lose the freshly computed stats).
    os.makedirs(GENERAL_LOG_PATH, exist_ok=True)
    stats.to_csv(GENERAL_LOG_PATH + 'data.csv')
stats

## Produce plots

Plot the principals' welfare regret against cooperation failure (the agents' welfare regret, WR) and total misalignment (the sum over agents of $d_A$, each agent's misalignment from its principal).

The theoretical regret bounds are also drawn, as surfaces over the same axes.

In [None]:
# Simple bound: expected to be violated by some games generated with miscalibration.
default_simple_bound = regret_bounds.bound_principals_welfare_regret_simple(
    cap=max_welfare_regret,
)
# Miscalibration-sensitive bound, parameterised by the per-player m and K_m.
miscal_bound = regret_bounds.bound_principals_welfare_regret_miscalibrated(
    ms=m,
    K_m=K_m,
    cap=max_welfare_regret,
)

### Pyplot

In [None]:
# Two-view matplotlib render with the simple bound, saved into this run's log directory.
plot_stats_multiview(stats, bound_fn=default_simple_bound)
plt.savefig(f'{GENERAL_LOG_PATH}{GENERAL_IDENT}_render_with_simple_bound.png', bbox_inches='tight')

In [None]:
# Two-view matplotlib render with the miscalibration-sensitive bound, saved alongside.
plot_stats_multiview(stats, bound_fn=miscal_bound)
plt.savefig(f'{GENERAL_LOG_PATH}{GENERAL_IDENT}_render_with_bound.png', bbox_inches='tight')

### Plotly

In [None]:
# miscalibration-sensitive bound
# Interactive two-camera Plotly view with the miscalibration-sensitive bound surface.
fig = make_multiview_simple_figure(stats, bound_fn=miscal_bound)
fig.show()

### Explicitly checking bounds

In [None]:
# if stats is produced from miscalibrated games, we should expect (at least some) failures of the simple bound
stats[stats.principals_welfare_regret > default_simple_bound(stats.welfare_regret, stats.total_misalignment)].principals_welfare_regret.count()

In [None]:
# the miscalibration-sensitive bound should always be satisfied
stats[stats.principals_welfare_regret > miscal_bound(stats.welfare_regret, stats.total_misalignment)].principals_welfare_regret.count()