## Imports and settings

In [None]:
import numpy as np
import pandas as pd
import plotly.graph_objects as go
from plotly.subplots import make_subplots

import itertools
from functools import partial
from tqdm import tqdm
from datetime import datetime
import os

In [None]:
from generate_games import generate_delegation_games_with_alignment_bounds
from evaluate_games import get_stat_nash, get_stat_general
import regret_bounds
from alignment import AlignmentMetric, EPIC, MAX

## Plotting helpers

In [None]:
def render_simple_bound_surface(welfare_regret:pd.Series, total_misalignment:pd.Series, bound_fn:regret_bounds.SimpleBound):
    """Evaluate ``bound_fn`` on a regular grid so it can be drawn as a surface.

    The grid has 50 points along each axis and spans
    [0, max(welfare_regret)] in x and [0, max(total_misalignment)] in y.

    Returns:
        ``(xx, yy, zz)``: the two ``np.meshgrid`` coordinate arrays and
        ``bound_fn`` evaluated on them.
    """
    n_samples = 50
    regret_axis = np.linspace(0, welfare_regret.max(), num=n_samples)
    misalignment_axis = np.linspace(0, total_misalignment.max(), num=n_samples)
    regret_grid, misalignment_grid = np.meshgrid(regret_axis, misalignment_axis)
    return regret_grid, misalignment_grid, bound_fn(regret_grid, misalignment_grid)

### Plotly

#### Specs

In [None]:
# Single colour used for the bound surface.
MUTED_BLUE = '#1f77b4'


def fixed_colorscale(c):
    """Return a two-stop plotly colorscale that maps every value to the colour ``c``."""
    stops = (0, 1)
    return [(stop, c) for stop in stops]

In [None]:
# Figure-level layout: no legend and no margin around the plot area.
layout_spec = {'showlegend': False, 'margin': {'l': 0, 'r': 0, 'b': 0, 't': 0}}

# Per-axis styling; the same dict object is shared by all three scene axes.
axis_spec = {'title_font_size': 12, 'dtick': 1, 'tickfont_size': 10}
axes_spec = {
    'xaxis': axis_spec,
    'yaxis': axis_spec,
    'zaxis': axis_spec,
    'xaxis_title_text': "Agents' welfare regret",
    'yaxis_title_text': 'Total agent misalignment',
    'zaxis_title_text': "Principals' welfare regret",
}

# Slightly flattened z axis relative to x and y.
aspect_spec = {'x': 1, 'y': 1, 'z': 0.8}

# 3D scene settings: aspect ratio plus the axis styling above.
# An orthographic projection can optionally be requested with
# camera_projection_type='orthographic'.
scene_spec = {'aspectratio': aspect_spec, **axes_spec}

# Surface lighting parameters for the bound surface.
lighting_spec = {
    'diffuse': 1,
    'fresnel': 5,
    'ambient': 0.5,
    'roughness': 1.,
    'specular': 0.5,
}

# Camera eye positions; the multiview figure uses one per scene.
cameras = [
    {'x': 1.8, 'y': -1.5, 'z': .8},
    {'x': 2.2, 'y': .75, 'z': .8},
]

#### Plotting

In [None]:
def make_simple_figure(stats, bound_fn, epsilon=0.02):
    """Plot principals' welfare regret against agents' welfare regret and misalignment.

    Each row of ``stats`` becomes one 3D scatter point at
    ``(welfare_regret, total_misalignment, principals_welfare_regret)``.
    Points are coloured by ``err``: the gap between ``bound_fn`` and the
    observed principals' welfare regret, or simply the negated regret when
    ``bound_fn`` is None. When ``bound_fn`` is given, it is also drawn as a
    translucent surface over the sampled region.

    Args:
        stats: DataFrame with ``welfare_regret``, ``total_misalignment`` and
            ``principals_welfare_regret`` columns.
        bound_fn: callable ``(welfare_regret, total_misalignment) -> bound``,
            or None to draw the scatter only.
        epsilon: vertical offset added to the bound surface so that markers
            lying exactly on the bound are not hidden beneath it.

    Returns:
        A plotly figure with a single 3D scene.
    """
    if bound_fn is None:
        err = -stats.principals_welfare_regret
    else:
        err = bound_fn(stats.welfare_regret, stats.total_misalignment) - stats.principals_welfare_regret

    fig = make_subplots(specs=[[{'is_3d':True}]])

    scatter = go.Scatter3d(
        x=stats.welfare_regret,
        y=stats.total_misalignment,
        z=stats.principals_welfare_regret,
        mode='markers',
        marker=dict(
            color=err,
            colorscale='Greens_r',
            size=1.5
        ))
    # add_trace(row=, col=) is the supported replacement for the deprecated append_trace.
    fig.add_trace(scatter, row=1, col=1)

    if bound_fn is not None:
        xx, yy, zz = render_simple_bound_surface(stats.welfare_regret, stats.total_misalignment, bound_fn)
        # adjust surface up a little to accommodate scatter point radius
        surface = go.Surface(x=xx, y=yy, z=zz+epsilon, opacity=0.8, colorscale=fixed_colorscale(MUTED_BLUE), showscale=False, lighting=lighting_spec)
        fig.add_trace(surface, row=1, col=1)

    fig.update_layout(scene=scene_spec, scene_camera_eye=cameras[0], **layout_spec)

    return fig

In [None]:
def make_multiview_simple_figure(stats, bound_fn, epsilon=0.02):
    """Show the ``make_simple_figure`` plot from two camera angles side by side.

    Args:
        stats: DataFrame with ``welfare_regret``, ``total_misalignment`` and
            ``principals_welfare_regret`` columns.
        bound_fn: callable ``(welfare_regret, total_misalignment) -> bound``,
            or None to draw the scatter only.
        epsilon: vertical offset of the bound surface, forwarded to
            ``make_simple_figure``.

    Returns:
        A plotly figure with two 3D scenes in one row. The left scene uses
        ``cameras[0]`` and the right scene uses ``cameras[1]``.
    """
    single = make_simple_figure(stats, bound_fn, epsilon=epsilon)
    fig = make_subplots(1, 2, specs=[[{'is_3d':True}, {'is_3d':True}]], horizontal_spacing=0)
    # add_traces validates its input into new trace objects owned by `fig`, so
    # the same source traces can populate both scenes without first
    # deep-copying the whole (potentially very large) figure.
    fig.add_traces(data=single.data, rows=1, cols=1)
    fig.add_traces(data=single.data, rows=1, cols=2)
    fig.update_layout(
        scene=scene_spec,
        scene2=scene_spec,
        scene_camera_eye=cameras[0],
        scene2_camera_eye=cameras[1],
        **layout_spec,
        )
    return fig

## Data generation

Generate a configurable number of games and compute their stats, or alternatively read in the stats saved by a previous run.

In [None]:
# Falsey value means generate afresh and save; a 'YYYY-mm-dd' string means
# reread the stats saved on that date.
READ_DATA = False
# Date component of the log path: the requested date when rereading,
# otherwise today's local date.
todaystr = READ_DATA or datetime.today().strftime('%Y-%m-%d')

In [None]:
# Alignment metric used for game generation, evaluation and the bound
# constants (EPIC is the other metric imported from `alignment`).
alignment_metric: AlignmentMetric = MAX

In [None]:
# Number of games to generate.
n_games = 100_000
# One entry per player; the player count is derived from its length.
m = [1] * 4 # [1 for _ in range(n_players)]
n_players = len(m)
# Passed to the game generator as n_outcomes.
# NOTE(review): the original hint says 2**n_players, which would be 16 for 4
# players, not 4 -- confirm which is intended.
d_u = 4

max_epic = 1

# None means use Nash; only working for 2-agent right now
max_welfare_regret = alignment_metric.constants.K_m * n_players

In [None]:
# Per-game evaluator: the general evaluator capped at max_welfare_regret when a
# cap is configured, otherwise the Nash-based evaluator.
if max_welfare_regret is not None:
    get_stat = partial(
        get_stat_general,
        am=alignment_metric,
        max_welfare_regret=max_welfare_regret,
        use_agents=False,
    )
else:
    get_stat = partial(get_stat_nash, am=alignment_metric)

In [None]:
# Identifier for this configuration: metric name, player count, the non-unit
# entries of m joined by 'x' (empty when all are 1), outcome count, and the
# number of games in thousands.
GENERAL_IDENT = '_'.join([
    alignment_metric.name,
    f'{n_players}p',
    'x'.join(str(m_) for m_ in m if m_ != 1),
    f'{d_u}u',
    f'{n_games//1000}k',
])
# Directory where this run's stats are written to / read from.
GENERAL_LOG_PATH = f'logs/{todaystr}/{GENERAL_IDENT}/'
print(GENERAL_LOG_PATH)

In [None]:
# Load the stats from a previous run, or generate n_games fresh games,
# evaluate each with get_stat, and save the resulting table.
if READ_DATA:
    stats = pd.read_csv(GENERAL_LOG_PATH + 'data.csv', index_col=0)
else:
    games = generate_delegation_games_with_alignment_bounds(
        n_players=n_players,
        n_outcomes=d_u,
        m=m,
        max_epic=max_epic,
        am=alignment_metric)
    # islice caps the (possibly unbounded) game stream at n_games; tqdm shows progress.
    stats = pd.DataFrame(tqdm(
        itertools.islice(map(get_stat, games), n_games),
        total=n_games))
    # Create the log directory only once there is data to save, and allow it to
    # exist already, so rerunning this cell (e.g. after an interrupted run) does
    # not fail with FileExistsError. A rerun overwrites data.csv at this path.
    os.makedirs(GENERAL_LOG_PATH, exist_ok=True)
    stats.to_csv(GENERAL_LOG_PATH + 'data.csv')
stats

## Produce plots

Plot the principals' welfare regret against cooperation failure (the agents' welfare regret, WR) and total misalignment (the agents' distance d_A from their principals).

Also plot the theoretical bounds.

In [None]:
# Constants of the chosen alignment metric, needed by the miscalibration-sensitive bound.
K_m, K_d = alignment_metric.constants.K_m, alignment_metric.constants.K_d

# Both bounds are capped at the same maximum welfare regret used when evaluating games.
default_simple_bound = regret_bounds.bound_principals_welfare_regret_simple(
    cap=max_welfare_regret,
)
miscal_bound = regret_bounds.bound_principals_welfare_regret_miscalibrated(
    ms=m,
    K_m=K_m,
    K_d=K_d,
    cap=max_welfare_regret,
)

### Plotly

In [None]:
# miscalibration-sensitive bound
fig = make_simple_figure(stats, miscal_bound)
fig.show()

### Explicitly checking bounds

In [None]:
# if stats is produced from miscalibrated games, we should expect (at least some) failures of the simple bound
# Number of games whose principals' welfare regret exceeds the simple bound.
# If stats came from miscalibrated games, expect (at least some) violations.
stats.principals_welfare_regret[
    stats.principals_welfare_regret
    > default_simple_bound(stats.welfare_regret, stats.total_misalignment)
].count()

In [None]:
# the miscalibration-sensitive bound should always be satisfied
# Number of games whose principals' welfare regret exceeds the
# miscalibration-sensitive bound; this is expected to be zero.
stats.principals_welfare_regret[
    stats.principals_welfare_regret
    > miscal_bound(stats.welfare_regret, stats.total_misalignment)
].count()