In [1]:
import ast
import plotly.graph_objects
import inspectify
import math
import os
import pandas
import sqlite3
import subprocess


def load_policy_performance_data(db_path, xaxis_choice, yaxis_choice):
  conn = sqlite3.connect(db_path)
  cursor = conn.cursor()

  # Execute SQL query based on x-axis choice
  cursor.execute(f'SELECT policy_id, {xaxis_choice} FROM CONSTRUCTED_POLICIES WHERE policy_id >= 1')
  training_data = cursor.fetchall()

  try:
    cursor.execute('SELECT policy_id, num_training_episodes, num_total_function_evaluations, num_total_timesteps FROM CONSTRUCTED_POLICIES WHERE policy_id >= 1')
    rows = cursor.fetchall()
    policy_id_to_x_values = {policy_id: {column_name: value for column_name, value in zip(['num_training_episodes', 'num_total_function_evaluations', 'num_total_timesteps'], row)}
                 for policy_id, *row in rows}
  except sqlite3.OperationalError:
    cursor.execute('SELECT policy_id, num_training_episodes FROM CONSTRUCTED_POLICIES WHERE policy_id >= 1')
    rows = cursor.fetchall()
    policy_id_to_x_values = {policy_id: {'num_training_episodes': num_training_episodes}
                 for policy_id, num_training_episodes in rows}

  avg_function_evaluations, std_dev_evaluations = [], []
  for policy_id, _ in training_data:
    cursor.execute(f'SELECT {yaxis_choice} FROM EVALUATION_EPISODES WHERE policy_id = ?', (policy_id,))
    evaluations = [e[0] for e in cursor.fetchall()]
    assert evaluations[0] is not None, f"The database {db_path} has a table EVALUATION_EPISODES with a cell of column {yaxis_choice} that is NULL!"
    num_evaluation_episodes = len(evaluations)  # Assuming number of episodes is length of evaluations

    avg_evaluations = sum(evaluations) / len(evaluations) if evaluations else None
    std_dev = math.sqrt(sum((e - avg_evaluations) ** 2 for e in evaluations) / len(evaluations)) if evaluations else 0
    avg_function_evaluations.append(avg_evaluations if avg_evaluations is not None else 0)
    std_dev_evaluations.append(std_dev)

  policy_ids, num_training_timesteps_or_num_training_fes = zip(*training_data) if training_data else ([], [])

  cursor.execute(f'SELECT {yaxis_choice} FROM EVALUATION_EPISODES WHERE policy_id = -1')
  baseline_evaluations = [e[0] for e in cursor.fetchall()]
  baseline_avg_length = sum(baseline_evaluations) / len(baseline_evaluations)
  baseline_variance = sum((e - baseline_avg_length) ** 2 for e in baseline_evaluations) / (len(baseline_evaluations) - 1) if len(baseline_evaluations) > 1 else 0
  baseline_std_dev = math.sqrt(baseline_variance)

  baseline_upper_bound = [baseline_avg_length + baseline_std_dev] * len(num_training_timesteps_or_num_training_fes)
  baseline_lower_bound = [baseline_avg_length - baseline_std_dev] * len(num_training_timesteps_or_num_training_fes)

  data = [
    plotly.graph_objects.Scatter(x=num_training_timesteps_or_num_training_fes, y=avg_function_evaluations, mode='lines+markers', name='#FEs until optimum', line=dict(color='blue', width=4)),
    plotly.graph_objects.Scatter(x=num_training_timesteps_or_num_training_fes, y=[avg + std for avg, std in zip(avg_function_evaluations, std_dev_evaluations)], mode='lines', line=dict(color='rgba(173,216,230,0.2)'), name='Upper Bound (Mean + Std. Dev.)'),
    plotly.graph_objects.Scatter(x=num_training_timesteps_or_num_training_fes, y=[avg - std for avg, std in zip(avg_function_evaluations, std_dev_evaluations)], mode='lines', fill='tonexty', line=dict(color='rgba(173,216,230,0.2)'), name='Lower Bound (Mean - Std. Dev.)'),
    plotly.graph_objects.Scatter(x=[min(num_training_timesteps_or_num_training_fes), max(num_training_timesteps_or_num_training_fes)] if num_training_timesteps_or_num_training_fes else [0], y=[baseline_avg_length, baseline_avg_length], mode='lines', name='Theory: √(𝑛/(𝑛 − 𝑓(𝑥)))', line=dict(color='orange', width=2, dash='dash')),
    plotly.graph_objects.Scatter(x=num_training_timesteps_or_num_training_fes, y=baseline_upper_bound, mode='lines', line=dict(color='rgba(255, 165, 0, 0.2)'), name='Upper Bound (Baseline Variance)'),
    plotly.graph_objects.Scatter(x=num_training_timesteps_or_num_training_fes, y=baseline_lower_bound, mode='lines', fill='tonexty', line=dict(color='rgba(255, 165, 0, 0.2)'), name='Lower Bound (Baseline Variance)'),
  ]

  conn.close()
  return data

def policy_performance(db_path, xaxis_choice, yaxis_choice):
  data = load_policy_performance_data(db_path, xaxis_choice, yaxis_choice)
  # Define the layout with larger dimensions and enhanced appearance
  layout = plotly.graph_objects.Layout(
    titlefont=dict(size=24),  # Bigger title font size
    xaxis=dict(
      title=xaxis_choice.replace('_', ' ').title(),
      titlefont=dict(size=18),  # Bigger axis title font size
      tickfont=dict(size=14),  # Bigger tick labels font size
      gridcolor='lightgrey',  # Grid color
      gridwidth=2,  # Grid line width
    ),
    yaxis=dict(
      title='#FEs until optimum',
      titlefont=dict(size=18),  # Bigger axis title font size
      tickfont=dict(size=14),  # Bigger tick labels font size
      gridcolor='lightgrey',  # Grid color
      gridwidth=2,  # Grid line width
    ),
    font=dict(family='Courier New, monospace', size=18, color='RebeccaPurple'),
    paper_bgcolor='rgba(0,0,0,0)',
    plot_bgcolor='rgba(245, 245, 245, 1)',
    width=1100,  # Width of the figure
    height=600,  # Height of the figure
    margin=dict(l=50, r=50, b=100, t=100, pad=4),  # Margins to prevent cutoff
    showlegend=False,  # This will remove the legend
  )

  fig = plotly.graph_objects.Figure(data=data, layout=layout)
  fig.show()

In [3]:
policy_performance("abdominal.db", "num_total_timesteps", "num_function_evaluations")

In [4]:
policy_performance("merged.db", "num_total_timesteps", "num_function_evaluations")

In [10]:
import plotly.graph_objects
import math
import sqlite3


def load_policy_performance_data(db_path, xaxis_choice, yaxis_choice):
  conn = sqlite3.connect(db_path)
  cursor = conn.cursor()

  # Execute SQL query based on x-axis choice
  cursor.execute(f'SELECT db_path, policy_id, {xaxis_choice} FROM CONSTRUCTED_POLICIES WHERE policy_id >= 1')
  training_data = cursor.fetchall()
  print(training_data)

  try:
    cursor.execute('SELECT policy_id, num_training_episodes, num_total_function_evaluations, num_total_timesteps FROM CONSTRUCTED_POLICIES WHERE policy_id >= 1')
    rows = cursor.fetchall()
    policy_id_to_x_values = {policy_id: {column_name: value for column_name, value in zip(['num_training_episodes', 'num_total_function_evaluations', 'num_total_timesteps'], row)}
                 for policy_id, *row in rows}
  except sqlite3.OperationalError:
    cursor.execute('SELECT policy_id, num_training_episodes FROM CONSTRUCTED_POLICIES WHERE policy_id >= 1')
    rows = cursor.fetchall()
    policy_id_to_x_values = {policy_id: {'num_training_episodes': num_training_episodes}
                 for policy_id, num_training_episodes in rows}

  avg_function_evaluations, std_dev_evaluations = [], []
  for policy_id, _ in training_data:
    cursor.execute(f'SELECT {yaxis_choice} FROM EVALUATION_EPISODES WHERE policy_id = ?', (policy_id,))
    evaluations = [e[0] for e in cursor.fetchall()]
    assert evaluations[0] is not None, f"The database {db_path} has a table EVALUATION_EPISODES with a cell of column {yaxis_choice} that is NULL!"
    num_evaluation_episodes = len(evaluations)  # Assuming number of episodes is length of evaluations

    avg_evaluations = sum(evaluations) / len(evaluations) if evaluations else None
    std_dev = math.sqrt(sum((e - avg_evaluations) ** 2 for e in evaluations) / len(evaluations)) if evaluations else 0
    avg_function_evaluations.append(avg_evaluations if avg_evaluations is not None else 0)
    std_dev_evaluations.append(std_dev)

  policy_ids, num_training_timesteps_or_num_training_fes = zip(*training_data) if training_data else ([], [])

  cursor.execute(f'SELECT {yaxis_choice} FROM EVALUATION_EPISODES WHERE policy_id = -1')
  baseline_evaluations = [e[0] for e in cursor.fetchall()]
  baseline_avg_length = sum(baseline_evaluations) / len(baseline_evaluations)
  baseline_variance = sum((e - baseline_avg_length) ** 2 for e in baseline_evaluations) / (len(baseline_evaluations) - 1) if len(baseline_evaluations) > 1 else 0
  baseline_std_dev = math.sqrt(baseline_variance)

  baseline_upper_bound = [baseline_avg_length + baseline_std_dev] * len(num_training_timesteps_or_num_training_fes)
  baseline_lower_bound = [baseline_avg_length - baseline_std_dev] * len(num_training_timesteps_or_num_training_fes)

  data = [
    plotly.graph_objects.Scatter(x=num_training_timesteps_or_num_training_fes, y=avg_function_evaluations, mode='lines+markers', name='#FEs until optimum', line=dict(color='blue', width=4)),
    plotly.graph_objects.Scatter(x=num_training_timesteps_or_num_training_fes, y=[avg + std for avg, std in zip(avg_function_evaluations, std_dev_evaluations)], mode='lines', line=dict(color='rgba(173,216,230,0.2)'), name='Upper Bound (Mean + Std. Dev.)'),
    plotly.graph_objects.Scatter(x=num_training_timesteps_or_num_training_fes, y=[avg - std for avg, std in zip(avg_function_evaluations, std_dev_evaluations)], mode='lines', fill='tonexty', line=dict(color='rgba(173,216,230,0.2)'), name='Lower Bound (Mean - Std. Dev.)'),
    plotly.graph_objects.Scatter(x=[min(num_training_timesteps_or_num_training_fes), max(num_training_timesteps_or_num_training_fes)] if num_training_timesteps_or_num_training_fes else [0], y=[baseline_avg_length, baseline_avg_length], mode='lines', name='Theory: √(𝑛/(𝑛 − 𝑓(𝑥)))', line=dict(color='orange', width=2, dash='dash')),
    plotly.graph_objects.Scatter(x=num_training_timesteps_or_num_training_fes, y=baseline_upper_bound, mode='lines', line=dict(color='rgba(255, 165, 0, 0.2)'), name='Upper Bound (Baseline Variance)'),
    plotly.graph_objects.Scatter(x=num_training_timesteps_or_num_training_fes, y=baseline_lower_bound, mode='lines', fill='tonexty', line=dict(color='rgba(255, 165, 0, 0.2)'), name='Lower Bound (Baseline Variance)'),
  ]

  conn.close()
  return data

def policy_performance(db_path, xaxis_choice, yaxis_choice):
  data = load_policy_performance_data(db_path, xaxis_choice, yaxis_choice)
  # Define the layout with larger dimensions and enhanced appearance
  layout = plotly.graph_objects.Layout(
    titlefont=dict(size=24),  # Bigger title font size
    xaxis=dict(
      title=xaxis_choice.replace('_', ' ').title(),
      titlefont=dict(size=18),  # Bigger axis title font size
      tickfont=dict(size=14),  # Bigger tick labels font size
      gridcolor='lightgrey',  # Grid color
      gridwidth=2,  # Grid line width
    ),
    yaxis=dict(
      title='#FEs until optimum',
      titlefont=dict(size=18),  # Bigger axis title font size
      tickfont=dict(size=14),  # Bigger tick labels font size
      gridcolor='lightgrey',  # Grid color
      gridwidth=2,  # Grid line width
    ),
    font=dict(family='Courier New, monospace', size=18, color='RebeccaPurple'),
    paper_bgcolor='rgba(0,0,0,0)',
    plot_bgcolor='rgba(245, 245, 245, 1)',
    width=1100,  # Width of the figure
    height=600,  # Height of the figure
    margin=dict(l=50, r=50, b=100, t=100, pad=4),  # Margins to prevent cutoff
    showlegend=False,  # This will remove the legend
  )

  fig = plotly.graph_objects.Figure(data=data, layout=layout)
  fig.show()



policy_performance("merged.db", "num_total_timesteps", "num_function_evaluations")

[(1, 0), (2, 4000), (3, 8000), (4, 12000), (5, 16000), (6, 20000), (7, 24000), (8, 28000), (9, 32000), (10, 36000), (11, 40000), (12, 44000), (13, 48000), (14, 52000), (15, 56000), (16, 60000), (17, 64000), (18, 68000), (19, 72000), (20, 76000), (21, 80000), (22, 84000), (23, 88000), (24, 92000), (25, 96000), (26, 100000), (27, 104000), (28, 108000), (29, 112000), (30, 116000), (31, 120000), (32, 124000), (33, 128000), (34, 132000), (35, 136000), (36, 140000), (37, 144000), (38, 148000), (39, 152000), (40, 156000), (41, 160000), (42, 164000), (43, 168000), (44, 172000), (45, 176000), (46, 180000), (47, 184000), (48, 188000), (49, 192000), (50, 196000), (51, 200000), (52, 204000), (53, 208000), (54, 212000), (55, 216000), (56, 220000), (57, 224000), (58, 228000), (59, 232000), (60, 236000), (61, 240000), (62, 244000), (63, 248000), (64, 252000), (65, 256000), (66, 260000), (67, 264000), (68, 268000), (69, 272000), (70, 276000), (71, 280000), (72, 284000), (73, 288000), (74, 292000), (75