# A colab for "What Matters In On-Policy Reinforcement Learning? A Large-Scale Empirical Study" paper 
https://arxiv.org/abs/2006.05990


Experiment results are stored in a JSON file (`mujoco123_data.json`) available in a Google Cloud Storage bucket.
This colab downloads the file, loads it, and generates the plots.

In [None]:
#@title Authentication to access a datadump from Google Cloud Storage bucket.
# Runs the Colab authentication flow before any bucket access is attempted.
# `google.colab` is imported here, so this cell only works where that package
# is available.
from google.colab import auth
auth.authenticate_user()

In [None]:
#@title Copy the data
# Copies the data file into the current working directory; the "Load the data"
# cell opens it from there under the same file name.
!gsutil cp gs://seed_rl_external_data_release/mujoco123_data.json .

Copying gs://seed_rl_external_data_release/mujoco123_data.dump...
\ [1 files][370.2 MiB/370.2 MiB]                                                
Operation completed over 1 objects/370.2 MiB.                                    


In [None]:
#@title Load the data
import io
import json

import pandas as pd

# Parse the downloaded dump. `tmp` is a dict of raw entries that the loop
# further down in this cell converts one by one.
with open('mujoco123_data.json' , 'r') as f:
  tmp = json.load(f)

# Filled by the loop below: integer id -> entry whose first element is the
# decoded DataFrame. generate_report() later unpacks each entry as
# (frame, params, experiment_name, num_seeds, num_wus).
loaded_data = {}

def convert_data(x):
  """Turns a JSON-decoded mapping into a pandas Series with integer labels.

  Args:
    x: Mapping whose keys can be parsed by `int` (JSON object keys are always
      strings).

  Returns:
    A `pd.Series` holding the values of `x`, indexed by the parsed keys.
  """
  return pd.Series({int(key): value for key, value in x.items()})

# Every entry of `tmp` is a list whose first element is a JSON-serialised
# DataFrame. Decode it, convert the two curve columns, and store the entry
# under its integer id.
for k, v in tmp.items():
  loaded_data[int(k)] = v
  # Passing a literal JSON string to pd.read_json is deprecated since
  # pandas 2.1; wrapping it in StringIO works on old and new versions alike.
  df = pd.read_json(io.StringIO(v[0]))
  print(f'starting to process {k} {len(df)}')
  # Both columns hold string-keyed dicts after decoding; convert_data() turns
  # each into a Series with an integer index.
  df['_eval_'] = df['_eval_'].map(convert_data)
  df['_eval_original_avg'] = df['_eval_original_avg'].map(convert_data)
  # `v` is the very list stored above, so this replaces the JSON string with
  # the decoded frame inside loaded_data as well.
  loaded_data[int(k)][0] = df

In [None]:
#@title Utils

#@title Overwrite dict

# Display-name table used when rendering plots and LaTeX.
# Keys are raw identifiers as they appear in the data (experiment names, gin
# choice column names, and individual choice values); values are the text to
# show instead. Lookups go through `overwrite_dict.get(name, name)`, so any
# identifier without an entry is shown unchanged.
# NOTE(review): `\choicet{...}` looks like a LaTeX macro defined by the
# document that consumes the generated output -- confirm.
overwrite_dict = {
    # Experiment names.
    'final_losses': '\\texttt{Policy Losses}',
    'final_time': '\\texttt{Time}',
    'final_optimize': '\\texttt{Optimizers}',
    'final_arch': '\\texttt{Networks architecture}',
    'final_arch2': '\\texttt{Networks architecture}',
    'final_stability': '\\texttt{Normalization and clipping}',
    'final_advantages': '\\texttt{Advantage Estimation}',
    'final_setup': '\\texttt{Training setup}',
    'final_regularizer': '\\texttt{Regularizers}',
    # Batch mode: the choice itself followed by its possible values.
    '_gin.study_design.choice_value_batch_mode': '\\choicet{batchhandling}',
    'repeat': 'Fixed trajectories',
    'shuffle': 'Shuffle trajectories',
    'split': 'Shuffle transitions',
    'split_with_advantage_recomputation': 'Shuffle transitions (recompute advantages)',

    # Epochs, number of environments, batch and step sizes.
    '_gin.study_design.choice_value_epochs_per_step': '\\choicet{numepochsperstep}',
    '_gin.study_design.choice_value_num_actors_in_learner': '\\choicet{numenvs}',
    '_gin.study_design.choice_value_batch_size_transitions': '\\choicet{batchsize}',
    '_gin.study_design.choice_value_step_size_transitions': '\\choicet{stepsize}',

    # Value loss.
    '_gin.study_design.choice_value_value_loss': '\\choicet{valueloss}',

    '_gin.study_design.choice_value_ppo_style_value_clipping_epsilon': '\\choicet{ppovalueclip}',

    # Advantage estimation and its sub-choices.
    '_gin.study_design.choice_value_sub_advantage_estimator_gae_gae_lambda': '\\choicet{gaelambda}',
    '_gin.study_design.choice_value_advantage_estimator': '\\choicet{advantageestimator}',
    '_gin.study_design.choice_value_sub_advantage_estimator_n_step_n': '\\choicet{nstep}',
    '_gin.study_design.choice_value_sub_value_loss_huber_delta': '\\choicet{huberdelta}',
    '_gin.study_design.choice_value_sub_advantage_estimator_v_trace_lambda': '\\choicet{vtraceaelambda}',
    '_gin.study_design.choice_value_sub_advantage_estimator_v_trace_max_importance_weight_in_advantage_estimation': '\\choicet{vtraceaecrho}',


    '_gin.study_design.choice_value_discount_factor': '\\choicet{discount}',

    '_gin.study_design.choice_value_frame_skip': '\\choicet{frameskip}',

    '_gin.study_design.choice_value_handle_abandoned_episodes_properly': '\\choicet{handleabandon}',

    # Maps to the same display name as the Adam sub-choice learning rate below.
    '_gin.study_design.choice_value_learning_rate': '\\choicet{adamlr}',
    # Losses.

    '_gin.study_design.choice_value_standard_policy_losses': '\\choicet{policyloss}',
    'repeat_positive_advantages':'RPA',
    '_gin.study_design.choice_value_sub_standard_policy_losses_v_trace_vtrace_max_importance_weight':'\\choicet{vtracelossrho}',
    '_gin.study_design.choice_value_sub_standard_policy_losses_ppo_ppo_epsilon':'\\choicet{ppoepsilon}',
    '_gin.study_design.choice_value_sub_standard_policy_losses_v_mpo_vmpo_e_n':'\\choicet{vmpoeps}',
    '_gin.study_design.choice_value_sub_standard_policy_losses_awr_awr_beta':'\\choicet{awrbeta}',
    '_gin.study_design.choice_value_sub_standard_policy_losses_awr_awr_w_max':'\\choicet{awrw}',

    # Optimizer.
    '_gin.study_design.choice_value_learning_rate_decay': '\\choicet{lrdecay}',
    '_gin.study_design.choice_value_optimizer':'\\choicet{optimizer}',
    '_gin.study_design.choice_value_sub_optimizer_adam_momentum':'\\choicet{adammom}',
    '_gin.study_design.choice_value_sub_optimizer_adam_epsilon':'\\choicet{adameps}',
    '_gin.study_design.choice_value_sub_optimizer_adam_learning_rate':'\\choicet{adamlr}',
    '_gin.study_design.choice_value_sub_optimizer_rmsprop_centered':'\\choicet{rmscent}',
    '_gin.study_design.choice_value_sub_optimizer_rmsprop_momentum': '\\choicet{rmsmom}',
    '_gin.study_design.choice_value_sub_optimizer_rmsprop_epsilon': '\\choicet{rmseps}',
    '_gin.study_design.choice_value_sub_optimizer_rmsprop_learning_rate': '\\choicet{rmslr}',

    # Architecture. Keys starting with '@' are choice values, not choices.
    '_gin.study_design.choice_value_action_postprocessing': '\\choicet{actionpost}',
    '_gin.study_design.choice_value_last_kernel_init_value_scaling': '\\choicet{valueinit}',
    '_gin.study_design.choice_value_std_independent_of_input': '\\choicet{stdind}',
    '_gin.study_design.choice_value_last_kernel_init_policy_scaling': '\\choicet{policyinit}',
    '_gin.study_design.choice_value_scale_function': '\\choicet{stdtransform}',
    '@safe_exp': 'exp',
    '@tf.nn.softplus': 'softplus',
    '_gin.study_design.choice_value_initial_action_std': '\\choicet{initialstd}',
    '_gin.study_design.choice_value_initializer': '\\choicet{init}',
    '@GlorotNormal()': 'Glorot normal',
    '@GlorotUniform()': 'Glorot uniform',
    '@Orthogonal()': 'Orthogonal',
    '@he_normal()': 'He normal',
    '@he_uniform()': 'He uniform',
    '@lecun_normal()': 'LeCun normal',
    '@lecun_uniform()': 'LeCun uniform',
    '@orthogonal_gain_sqrt2()': 'Orthogonal(gain=1.41)',
    '_gin.study_design.choice_value_policy_and_value_function_network': '\\choicet{mlpshared}',
    '_gin.study_design.choice_value_sub_policy_and_value_function_network_shared_policy_mlp_width': '\\choicet{sharedwidth}',
    '_gin.study_design.choice_value_sub_policy_and_value_function_network_shared_policy_mlp_depth': '\\choicet{shareddepth}',
    '_gin.study_design.choice_value_sub_policy_and_value_function_network_shared_baseline_cost': '\\choicet{baselinecost}',
    '_gin.study_design.choice_value_minimum_action_std': '\\choicet{minstd}',
    '_gin.study_design.choice_value_activation': '\\choicet{activation}',
    '@swish': 'Swish',
    '@tf.nn.elu': 'ELU',
    '@tf.nn.leaky_relu': 'Leaky ReLU',
    '@tf.nn.relu': 'ReLU',
    '@tf.nn.sigmoid': 'Sigmoid',
    '@tf.nn.tanh': 'Tanh',
    # The width/depth choices appear both as sub-choices and as top-level
    # choices; both spellings map to the same display names.
    '_gin.study_design.choice_value_sub_policy_and_value_function_network_separate_policy_mlp_width': '\\choicet{policywidth}',
    '_gin.study_design.choice_value_sub_policy_and_value_function_network_separate_policy_mlp_depth':'\\choicet{policydepth}',
    '_gin.study_design.choice_value_sub_policy_and_value_function_network_separate_value_mlp_width':'\\choicet{valuewidth}',
    '_gin.study_design.choice_value_sub_policy_and_value_function_network_separate_value_mlp_depth':'\\choicet{valuedepth}',
    '_gin.study_design.choice_value_policy_mlp_width':'\\choicet{policywidth}',
    '_gin.study_design.choice_value_value_mlp_depth':'\\choicet{valuedepth}',
    '_gin.study_design.choice_value_value_mlp_width':'\\choicet{valuewidth}',
    '_gin.study_design.choice_value_policy_mlp_depth':'\\choicet{policydepth}',

    # Normalization and clipping.
    '_gin.study_design.choice_value_ppo_epsilon': '\\choicet{ppoepsilon}',
    '_gin.study_design.choice_value_input_normalization': '\\choicet{norminput}',
    'Avg (comp=False)': 'Average',
    '_gin.study_design.choice_value_gradient_clipping': '\\choicet{clipgrad}',
    '_gin.study_design.choice_value_normalize_advantages': '\\choicet{normadv}',
    '_gin.study_design.choice_value_reward_normalization': '\\choicet{normreward}',
    'Popart Avg (comp=False)': 'Average',
    '_gin.study_design.choice_value_sub_input_normalization_avg_compfalse_input_clipping': '\\choicet{clipinput}',

    # Regularization.
    '_gin.study_design.choice_value_policy_regularization': '\\choicet{regularizationtype}',
    '_gin.study_design.choice_value_sub_policy_regularization_constraint_regularization_constraint': '\\choicet{regularizerconstraint}',
    '_gin.study_design.choice_value_sub_policy_regularization_penalty_regularization_penalty': '\\choicet{regularizerpenalty}',

    # Thresholds of the constraint-style regularizers.
    '_gin.study_design.choice_value_sub_regularization_constraint_klmupi_kl_mu_pi_threshold': '\\choicet{regularizerconstraintklmupi}',
    '_gin.study_design.choice_value_sub_regularization_constraint_klpimu_kl_pi_mu_threshold': '\\choicet{regularizerconstraintklpimu}',
    '_gin.study_design.choice_value_sub_regularization_constraint_klrefpi_kl_ref_pi_threshold': '\\choicet{regularizerconstraintklrefpi}',
    '_gin.study_design.choice_value_sub_regularization_constraint_decoupled_klmupi_kl_mu_pi_mean_threshold': '\\choicet{regularizerconstraintklmupimean}',
    '_gin.study_design.choice_value_sub_regularization_constraint_decoupled_klmupi_kl_mu_pi_std_threshold': '\\choicet{regularizerconstraintklmupistd}',
    '_gin.study_design.choice_value_sub_regularization_constraint_entropy_entropy_threshold': '\\choicet{regularizerconstraintentropy}',

    # Coefficients of the penalty-style regularizers.
    '_gin.study_design.choice_value_sub_regularization_penalty_klmupi_coefficient': '\\choicet{regularizerpenaltyklmupi}',
    '_gin.study_design.choice_value_sub_regularization_penalty_klpimu_coefficient': '\\choicet{regularizerpenaltyklpimu}',
    '_gin.study_design.choice_value_sub_regularization_penalty_klrefpi_coefficient': '\\choicet{regularizerpenaltyklrefpi}',
    '_gin.study_design.choice_value_sub_regularization_penalty_decoupled_klmupi_mean_coefficient': '\\choicet{regularizerpenaltyklmupimean}',
    '_gin.study_design.choice_value_sub_regularization_penalty_decoupled_klmupi_std_coefficient': '\\choicet{regularizerpenaltyklmupistd}',
    '_gin.study_design.choice_value_sub_regularization_penalty_entropy_coefficient': '\\choicet{regularizerpenaltyentropy}',

    # Values of the regularization-type choice.
    'constraint': 'Constraint',
    'penalty': 'Penalty',
    'no regularization': 'No regularization',
  }



import sys
if sys.version_info[0] < 3:
  raise Exception("Must be using Python 3")


from concurrent import futures
import hashlib
import os
import copy
import time
import dill
import numpy as np
import pandas as pd
import math
import six.moves.cPickle


start_time = time.time()

import matplotlib.pyplot as plt
from matplotlib import lines

import copy
from io import BytesIO

from scipy.stats import binom

import base64

# The "Tableau 20" palette, written down as 8-bit RGB triples.
tableau20 = [
    (31, 119, 180), (174, 199, 232),
    (255, 127, 14), (255, 187, 120),
    (44, 160, 44), (152, 223, 138),
    (214, 39, 40), (255, 152, 150),
    (148, 103, 189), (197, 176, 213),
    (140, 86, 75), (196, 156, 148),
    (227, 119, 194), (247, 182, 210),
    (127, 127, 127), (199, 199, 199),
    (188, 189, 34), (219, 219, 141),
    (23, 190, 207), (158, 218, 229),
]

# matplotlib takes colour channels in the [0, 1] range, so rescale every
# triple in place.
for i, (r, g, b) in enumerate(tableau20):
  tableau20[i] = (r / 255., g / 255., b / 255.)



def encode_fig_pdf(plt, name):
  """Renders the current figure to PDF bytes held in memory.

  Args:
    plt: Object exposing `savefig(file, format=..., bbox_inches=...,
      transparent=...)`; callers in this file pass `matplotlib.pyplot`.
    name: Base name for the file; every "." in it is replaced by "_".

  Returns:
    A `(filename, pdf_bytes)` tuple where `filename` ends in ".pdf".
  """
  buffer = BytesIO()
  plt.savefig(buffer, format='pdf', bbox_inches="tight", transparent=True)
  filename = name.replace(".", "_") + '.pdf'
  return (filename, buffer.getvalue())

def percentile_ci(vector, p, alpha=0.95):
  """Computes the pth percentile and an alpha CI based on binomial coverage.

  See https://staff.math.su.se/hoehle/blog/2016/10/23/quantileCI.html. For
  efficiency the confidence interval is based on the
  [(1-alpha)/2, 1-(1-alpha)/2] interval of a binomial distribution.

  The interval bounds are order statistics of the sample, so the values are
  sorted here; callers do not need to pre-sort them.

  Args:
    vector: One-dimensional array-like with values (any order).
    p: Float with percentile as a fraction (e.g. 0.95).
    alpha: Float with coverage of the CI.

  Returns:
    A tuple `(percentile, low, high, clipped)`: the percentile, the lower and
    upper bound of the CI, and a bool that is True when the upper index fell
    outside the sample and had to be clamped to the largest value (i.e. there
    is not enough data for the requested coverage).
  """
  vector = np.array(vector)
  assert vector.ndim == 1
  # Indexing by rank below is only meaningful on sorted data.
  vector = np.sort(vector)
  n = vector.shape[0]
  low_index, high_index = binom.interval(alpha, n, p, loc=0)
  high_index = int(high_index)
  f = False
  if high_index > n - 1:
    # The upper rank does not exist in a sample this small; use the maximum
    # and tell the caller.
    high_index = n - 1
    f = True

  percentile = np.percentile(vector, p * 100.)
  low = vector[int(low_index)]
  high = vector[int(high_index)]
  return percentile, low, high, f

def sanitize(name, replace=True):
  """Splits a raw choice value into a sortable `(text, number)` pair.

  Exactly one half of the pair carries the value: numeric values come back as
  `('', number)` and everything else as `(text, 0)`, so that a list of pairs
  sorts numbers numerically and text alphabetically.

  Args:
    name: The raw value. `None`, a string (optionally wrapped in single
      quotes, which are stripped), or an already-numeric value (Python or
      numpy int/float, but not bool).
    replace: If True, the result is passed through `overwrite_dict` to obtain
      its display name; unknown values are returned unchanged.

  Returns:
    `('None', 0)` for `None`; `('', number)` if the value is numeric or
    parses as int or float; otherwise `(text, 0)`.
  """
  if name is None:
    return 'None', 0

  def _display(value):
    return overwrite_dict.get(value, value) if replace else value

  # Values that are already numbers have no len() and must not go through
  # int(), which would silently truncate floats.
  if isinstance(name, (int, float, np.integer, np.floating)) and not isinstance(name, bool):
    return '', _display(name)

  if len(name) > 0 and name[0] == '\'':
    name = name[1:]
  if len(name) > 0 and name[-1] == '\'':
    name = name[:-1]

  # Prefer int over float so that '3' stays 3 rather than 3.0.
  for parse in (int, float):
    try:
      return '', _display(parse(name))
    except ValueError:
      continue
  return _display(name), 0


def primary_plot95(data,percentile):
  """Summarises final performance per environment as HTML and LaTeX tables.

  For every environment the final score of a run is the last entry of its
  `_eval_` series. The summary holds the 90th/95th/99th percentile and the
  maximum of those scores plus the `wid` values of the five best runs.

  A bar chart of the requested percentile is drawn as well, but the figure is
  closed without being saved or returned.

  Args:
    data: DataFrame with columns `env_name`, `_eval_` (one Series per row)
      and `wid`.
    percentile: Percentile in percent (e.g. 95) used for the bar chart and
      the "not enough data" check.

  Returns:
    `(res, reslatex)`: an HTML fragment (warnings followed by the summary
    table) and a LaTeX table without the "best wu" row. The LaTeX label
    contains the literal placeholder `EXPNAME`.
  """
  fig, axs = plt.subplots(1, 1,sharex=True, figsize=(15, 6),gridspec_kw={'wspace': 0.5})
  fig.suptitle('%d%% performance' % percentile)
  res=""
  w = []
  summary = {}

  # Group by the scalar key: grouping by the one-element list ['env_name']
  # makes pandas >= 2.0 yield 1-tuples as group names, which sanitize() cannot
  # handle and which would turn the summary columns into tuples.
  for name, group in data.groupby('env_name'):
    final_scores = [curve.iloc[-1] for curve in group['_eval_'].values]
    # Sorting (score, wid) pairs keeps each run id attached to its score.
    ranked = sorted(zip(final_scores, group['wid'].values))
    best_wids = [wid for _, wid in ranked][-5:]
    x = [score for score, _ in ranked]
    summary[name] = {
        "90th percentile": percentile_ci(x, 0.9)[0],
        "95th percentile": percentile_ci(x, 0.95)[0],
        "99th percentile": percentile_ci(x, 0.99)[0],
        "Max": max(x),
        'best wu': '/'.join([str(wid) for wid in best_wids]),
    }
    p, l, h, f = percentile_ci(x, percentile / 100.0)
    if f:
      res += "<br/><b>CAREFUL. Not enough data to compute CI.</b><br/>"
    # Error bars are distances from the percentile, not absolute bounds.
    w.append((sanitize(name), p, p-l, h-p))

  summary = pd.DataFrame(summary)

  w.sort()
  nl = [str(val) if name == '' else name for (name, val), _ , _ , _ in w]
  pl = [p for _,  p , _ , _ in w]
  ll = [l for _, _ , l , _ in w]
  hl = [h for _, _ , _ , h in w]
  axs.bar(nl,pl,yerr=[ll, hl], color=tableau20[:len(hl)])
  axs.set_xticklabels(nl, rotation=90)
  axs.set_title('Overall')

  reslatex = ""
  plt.close()

  res += '<br/>' + summary.to_html()

  # The run ids are only useful interactively; keep them out of the paper table.
  summary = summary.drop("best wu")
  reslatex += "\\begin{table}[ht]\n\\begin{center}\n\\caption{Performance quantiles across choice configurations.}\n\\label{tab:EXPNAME_overview}\n"
  reslatex += summary.to_latex(column_format="lrrrrr",float_format=(lambda x: '%.0f' % x)) + '\n'
  reslatex += "\\end{center}\n\\end{table}"
  return res, reslatex

def split_data(nl, pl, ll, hl):
  """Inserts an empty bar wherever the three-character label prefix changes.

  The four inputs are parallel sequences (labels, bar heights, lower errors,
  upper errors). Each inserted spacer has height and errors 0 and a label
  made of spaces only; the k-th spacer (counting from 0) uses k spaces, so
  the spacer labels are all different.

  Returns:
    Four new lists, in the same order as the arguments.
  """
  rows = []
  spacers_so_far = 0
  for label, value, lower, upper in zip(nl, pl, ll, hl):
    # rows[-1] is always a real entry here: a spacer is immediately followed
    # by the entry that triggered it.
    if rows and rows[-1][0][0:3] != label[0:3]:
      rows.append((" " * spacers_so_far, 0, 0, 0))
      spacers_so_far += 1
    rows.append((label, value, lower, upper))
  return tuple([row[column] for row in rows] for column in range(4))

def plot95(data, param,percentile,lines=1,height=5):
  """Plots, per environment, a performance percentile for each value of a choice.

  For every environment and every value of `param`, the final scores (last
  entry of each `_eval_` series) are collected and their `percentile`-th
  percentile is drawn as a bar, with the binomial confidence interval from
  percentile_ci() as error bar.

  Args:
    data: DataFrame with columns `env_name`, `_eval_` and `param`.
    param: Name of the column holding the choice to condition on.
    percentile: Percentile in percent (e.g. 95).
    lines: With 2, the environments are split over two figures (first half
      rounded up); any other value produces a single figure.
    height: Figure height; a non-default value is encoded in the file name.

  Returns:
    `(res, pdfs)`: an HTML fragment with "not enough data" warnings and a
    list of `(filename, pdf_bytes)` tuples, one per figure.
  """
  # Sorted so that the panel order, and the split used for lines == 2, do not
  # depend on set iteration order.
  tenvs = sorted(set(data['env_name'].values))
  envs_list = [tenvs]
  if lines == 2:
    half = (len(tenvs)+1)//2
    envs_list = [tenvs[:half], tenvs[half:]]
  res = ""
  pdfs = []
  for ei, envs in enumerate(envs_list):
    # squeeze=False keeps `axs` an array even for a single environment.
    fig, axs = plt.subplots(1, len(envs),sharex=True, figsize=(15 // len(envs_list[0]) * len(envs) * lines , height),gridspec_kw={'wspace': 0.5 / lines}, squeeze=False)
    axs = axs[0]
    for i, env in enumerate(envs):
      edata = data[data['env_name'] == env]
      # Scalar key: a one-element list would make pandas >= 2.0 yield 1-tuples
      # as group names, which sanitize() cannot handle.
      groups = edata.groupby(param)['_eval_']

      w = []

      for name, group_values in groups:
        x = [x.iloc[-1] for x in group_values.values]
        x.sort()
        p, l, h, f = percentile_ci(x, percentile / 100.0)
        if f:
          res += "<br/><b>CAREFUL. Not enough data to compute CI for %s/%s, len %d.</b><br/>" % (env,name,len(x))
        # Error bars are distances from the percentile, not absolute bounds.
        w.append((sanitize(name), p, p-l, h-p))

      w.sort()
      nl = [str(val) if name == '' else name for (name, val), _ , _ , _ in w]
      pl = [p for _,  p , _ , _ in w]
      ll = [l for _, _ , l , _ in w]
      hl = [h for _, _ , _ , h in w]
      # Hack: entropy thresholds are shown with a leading minus sign.
      if param == '_gin.study_design.choice_value_sub_regularization_constraint_entropy_entropy_threshold':
        nl = ["-"+x if x!="0.0" else x for x in nl]
      if 'custom' in param:
        nl, pl, ll, hl = split_data(nl, pl, ll, hl)
      axs[i].bar(nl,pl,yerr=[ll, hl], color=tableau20[:len(hl)])
      axs[i].set_xticklabels(nl, rotation=90)
      axs[i].set_title(env, fontsize=20)

    suffix = ''
    if lines > 1:
      suffix = '_' + str(ei)
    if height != 5:
      suffix += '_height_'+str(height)
    pdfs.append(encode_fig_pdf(plt, 'perf_' + param + suffix))

    plt.close()

  return res, pdfs
  

def plot_top_90(data, param,percentile):
  """Plots how often each value of a choice occurs among the best runs.

  Per environment, runs with a non-null `param` are ranked by final score
  (last entry of `_eval_`) and the top share is kept. The relative frequency
  of every value of `param` among the kept runs is drawn as a bar chart, one
  panel per environment plus an "all" panel pooling all environments.

  Args:
    data: DataFrame with columns `env_name`, `_eval_` and `param`.
    param: Name of the column holding the choice.
    percentile: Percentile in percent; with k = 100 // (100 - percentile) the
      best ceil(n / k) of n runs are kept (the top 5% for 95).

  Returns:
    `("", "")` if `param` contains 'custom'; otherwise `(res, pdf)` with an
    empty HTML fragment and a `(filename, pdf_bytes)` tuple.
  """
  if 'custom' in param:
    return "", ""
  envs = set(data['env_name'].values)
  pv = set(data[param].values)
  if None in pv:
    pv.remove(None)
  # Sorted so that the panel order does not depend on set iteration order;
  # the pooled "all" panel always comes last.
  inv_env = {}
  for i, e in enumerate(sorted(envs)):
    inv_env[e] = i
  inv_env["all"] = len(envs)
  envs.add("all")
  # squeeze=False keeps `axs` an array even if only the "all" panel exists.
  fig, axs = plt.subplots(1, len(envs),sharex=True, figsize=(15, 5),gridspec_kw={'wspace': 0.5}, squeeze=False)
  axs = axs[0]
  res=""

  tdata = data[data[param].notnull()]
  # Scalar key: a one-element list would make pandas >= 2.0 yield 1-tuples as
  # group names, which would then miss in `inv_env`.
  groups = tdata.groupby('env_name')
  # env -> (series name, choice values, eval curves), the last two row-aligned.
  data = {name: (group.name,
                 [x for x in group.values],
                 [x for x in group_v.values])
          for (name, group), (_, group_v) in zip(groups[param],
                                                 groups['_eval_'])}
  allv=[]
  for a, (b, c, d) in data.items():
    p90 = [(x.iloc[-1], i) for i, x in enumerate(d)]
    p90 = sorted(p90)
    # (n - 1) // k + 1 equals ceil(n / k) for n >= 1.
    p90 = p90[-((len(p90) - 1) // (100 // (100 - percentile)) + 1):]
    v=[data[a][1][x] for _, x in p90]
    allv.extend(v)
    data[a] = (data[a][0], v)
  data["all"] = ("all", allv)

  # s[env][value] counts occurrences among the kept runs; tt[env] is the total.
  s = {}
  tt = {}
  for (a), (_, c) in data.items():
    if a not in tt:
      tt[a]=0
    if a not in s:
      s[a] = {}
      for ww in pv:
        s[a][ww] = 0
    for x in c:
      if x not in pv:
        continue
      tt[a] = tt.get(a,0) + 1
      s[a][x] = s[a].get(x, 0) + 1
  for a, b in s.items():
    w = []
    for x, y in b.items():
      # The small epsilon avoids a division by zero when nothing was counted.
      w.append((sanitize(x), y/(tt[a]+1e-6),0,0))

    w.sort()
    nl = [str(val) if name == '' else name for (name, val), _ , _ , _ in w]
    pl = [p for _,  p , _ , _ in w]
    ll = [l for _, _ , l , _ in w]
    hl = [h for _, _ , _ , h in w]
    # Hack: entropy thresholds are shown with a leading minus sign.
    if param == '_gin.study_design.choice_value_sub_regularization_constraint_entropy_entropy_threshold':
      nl = ["-"+x if x!="0.0" else x for x in nl]
    axs[ inv_env[a] ].bar(nl,pl,yerr=[ll, hl], color=tableau20[:len(hl)])
    axs[ inv_env[a] ].set_xticklabels(nl, rotation=90)
    axs[ inv_env[a] ].set_title(a, fontsize=20)

  pdf = encode_fig_pdf(plt, 'frequency_' + param)
  plt.close()

  return res, pdf

# NOTE(review): the two step limits and `use_average` are not referenced by
# any function defined above this point -- presumably report settings used
# further down; confirm.
step_limit_easy = 1_000_000
step_limit_hard = 2_000_000
use_average = True

# Module-level percentile, in percent.
percentile = 95


# Per-environment score used by plot_ecdf_curves() as the lower end when
# normalising final scores.
# NOTE(review): going by the name these are the scores of a random policy --
# confirm.
RANDOM_POLICY = {
    'HalfCheetah-v1': -290,
    'Hopper-v1': 18,
    'Walker2d-v1': 2,
    'Ant-v1': 50,
    'Humanoid-v1': 125,
}
easy_envs = {'HalfCheetah-v1', 'Hopper-v1', 'Walker2d-v1'}


def plot95correlation(data, param1, param2, ename, percentile):
  """Plots a performance percentile for `param2`, split by `param1` and environment.

  The figure is a grid with one row per value of `param1` and one column per
  environment. Each panel shows, for every value of `param2`, the
  `percentile`-th percentile of the final scores (last entry of `_eval_`)
  with the confidence interval from percentile_ci() as error bar.

  Args:
    data: DataFrame with columns `env_name`, `_eval_`, `param1`, `param2`.
    param1: Column whose values define the rows.
    param2: Column whose values define the bars inside each panel.
    ename: Experiment name, used only for the PDF file name.
    percentile: Percentile in percent (e.g. 95).

  Returns:
    `(res, pdf)`: an HTML fragment with "not enough data" warnings and a
    `(filename, pdf_bytes)` tuple.
  """
  # Sorted so that the column order does not depend on set iteration order.
  envs = sorted(set(data['env_name'].values))
  p1 = [sanitize(x) for x in list(set(data[param1].values)) if x is not None]
  p1.sort()
  # squeeze=False keeps `axs` two-dimensional even when there is only one
  # row or one column, so that axs[row, col] is always valid.
  fig, axs = plt.subplots(len(p1), len(envs),sharex=True, figsize=(15, 4 * len(p1)),gridspec_kw={'wspace': 0.5}, squeeze=False)
  res=""
  for i, env in enumerate(envs):
    edata = data[data['env_name'] == env]
    groups = edata.groupby([param1, param2])['_eval_']

    for ii, row_value in enumerate(p1):
      w = []
      for name, group_values in groups:
        # `name` is the (param1 value, param2 value) pair of this group.
        if sanitize(name[0]) != row_value:
          continue
        x = [x.iloc[-1] for x in group_values.values]
        x.sort()
        p, l, h, f = percentile_ci(x, percentile / 100.0)
        if f:
          res += "<br/><b>CAREFUL. Not enough data to compute CI.</b><br/>"
        # Error bars are distances from the percentile, not absolute bounds.
        w.append((sanitize(name[1]), p, p-l, h-p))

      w.sort()
      nl = [str(val) if name == '' else name for (name, val), _ , _ , _ in w]
      pl = [p for _,  p , _ , _ in w]
      ll = [l for _, _ , l , _ in w]
      hl = [h for _, _ , _ , h in w]
      axs[ii, i].bar(nl,pl,yerr=[ll, hl], color=tableau20[:len(hl)])
      axs[ii, i].set_xticklabels(nl, rotation=90)
      axs[ii, i].set_title(env)
      if i == 0:
        # Label each row once, on its left-most panel.
        axs[ii, i].set(ylabel=(str(row_value[1]) if row_value[0] == '' else row_value[0]))

  pdf = encode_fig_pdf(plt, 'correlation_%s_%s_%s' % (ename, param1.replace(".", "_").replace("_gin_study_design_choice", ""), param2.replace(".", "_").replace("_gin_study_design_choice", "")))

  plt.close()

  return res, pdf

def add_line(str, level=0):
  """Returns `str` indented by `level` spaces and terminated by a newline.

  A `level` of zero or below adds no indentation.
  """
  indent = " " * level
  return "".join((indent, str, "\n"))

def add_sub_params(param,paramshort, params, frame, vals, level):
  """Emits nested LaTeX `itemize` lists for the sub-choices of a choice.

  For every value in `vals`, all entries of `params` whose name starts with
  `_gin.study_design.choice_value_sub_<choice>_<value>` are listed together
  with the values they take in `frame`. The function then recurses into each
  of them to pick up sub-choices of sub-choices.

  Args:
    param: Full column name of the parent choice; used for its display name.
    paramshort: Name from which the `<choice>` part is taken by dropping the
      `_gin.study_design.choice_value_` prefix. Equal to `param` at the top
      level; a shortened name when recursing.
    params: All column names to search for sub-choices.
    frame: DataFrame indexed by column name; only `value_counts()` of the
      matching columns is used.
    vals: Raw (not display-name substituted) values of the parent choice.
    level: Indentation, in spaces, of the emitted `\\item` lines.

  Returns:
    `(latex, pp)`: the generated LaTeX and the list of all sub-choice column
    names that were handled, including those found by recursion.
  """
  latex = ""
  # The bare choice name, e.g. 'optimizer' for '..._choice_value_optimizer'.
  sparam = paramshort[len('_gin.study_design.choice_value_'):] 
  # True until the outer \begin{itemize} has been written.
  vfirst = True
  pp = []
  for val in vals:
    # True until the first sub-choice of this particular value has been seen.
    first = True
    # Normalise the value the way it appears inside sub-choice column names:
    # lower case, '-' and ' ' become '_', and '(', ')', '=', '|' are dropped.
    tsparam = sparam + '_' + str(val).lower().replace('-', '_').replace(' ','_').replace('(','').replace(')','').replace('=','').replace('|', '')
    for p in params:
      if p.startswith("_gin.study_design.choice_value_sub_" + tsparam):
        pp.append(p)
        if first:
          if vfirst:
            vfirst = False
            latex += add_line("\\begin{itemize}", level-4)
          tv = sanitize(val)
          ttv0 = tv[0]
          # Typeset the KL variants as math.
          ttv0 = ttv0.replace("KL(mu||pi)", "$\\kl(\\mu||\\pi)$").replace("KL(pi||mu)", "$\\kl(\\pi||\\mu)$").replace("KL(ref||pi)", "$\\kl(\\texttt{ref}||\\pi)$")
          tv = ttv0, tv[1]
          latex += add_line("\\item For the case ``%s = %s'', we further sampled the sub-choices:" % (overwrite_dict.get(param, param).replace("_", "\\_"), str(tv[1]) if tv[0]=='' else tv[0]), level)
          latex += add_line("\\begin{itemize}", level)
        first = False

        # vz: values of sub-choice `p` formatted for LaTeX.
        # vzorig: the same values without display-name substitution; they are
        # passed on as `vals` of the recursive call.
        vz = []
        vzorig = []
        t = []
        torig = []
        for x in frame[p].value_counts().keys():
          t.append(sanitize(x))
        for x in frame[p].value_counts().keys():
          torig.append(sanitize(x, replace=False))
        t.sort()
        torig.sort()
        for a, b in t:
          if a == '':
            # Hack: positive entropy thresholds are shown with a minus sign.
            if p == '_gin.study_design.choice_value_sub_regularization_constraint_entropy_entropy_threshold':
              if b > 0:
                vz.append(str(-b))
              else:
                vz.append(str(b))
            else:
              vz.append(str(b))
          else:
            if p == '_gin.study_design.choice_value_batch_mode':
              vz.append('\\texttt{'+a.replace("_", "\\_")+'}')
            else:
              vz.append(a.replace("_", "\\_"))
            vz[-1] = vz[-1].replace("KL(mu||pi)", "$\\kl(\\mu||\\pi)$").replace("KL(pi||mu)", "$\\kl(\\pi||\\mu)$").replace("KL(ref||pi)", "$\\kl(\\texttt{ref}||\\pi)$")
          
        for a, b in torig:
          if a == '':
            vzorig.append(str(b))
          else:
            vzorig.append(a)
        latex += add_line("\\item %s: \\{%s\\}" % (overwrite_dict.get(p, p).replace("_", "\\_"), ', '.join(vz)), level+4)
        # Recurse with the part of `p` that follows the matched prefix as the
        # new short name, so that deeper levels match on
        # '<remaining name>_<value>'.
        l, tpp = add_sub_params(
            p,'_gin.study_design.choice_value'+p[len("_gin.study_design.choice_value_sub_" + tsparam):],
            params, frame, vzorig, level=level+8)
        latex += l
        pp.extend(tpp)
    if not first:
      latex += add_line("\\end{itemize}", level)
  if not vfirst:
    latex += add_line("\\end{itemize}", level-4)
  return latex, pp

def average(data):
  """Averages several curves that share the same x positions.

  Args:
    data: Sequence of mapping-like objects exposing `.items()` and `len()`
      (e.g. pandas Series). All of them must have the same keys in the same
      order; this is asserted.

  Returns:
    `(x, y)`: the shared keys and the element-wise mean of the values, both
    as lists. Two empty lists if `data` is empty.
  """
  if not data:
    return [], []

  count = len(data)
  positions = [position for position, _ in data[0].items()]
  means = [0] * len(positions)

  for curve in data:
    assert len(curve) == len(positions)
    for slot, (position, value) in enumerate(curve.items()):
      assert positions[slot] == position
      means[slot] = means[slot] + value / count

  return positions, means

def get_last_score(data):
  """Returns the value stored under the largest key of `data`.

  Args:
    data: Mapping-like object exposing `.items()` (e.g. a pandas Series)
      whose keys are comparable with integers.

  Returns:
    The value whose key is the largest one above -1, or 0 if there is no
    such key (for instance when `data` is empty).
  """
  ind = -1
  sc = 0
  for a, b in data.items():
    if a > ind:
      ind = a
      sc = b
  # Return the tracked best entry, not whatever happened to be iterated last.
  return sc

def plot_training_curves(data):
  """Plots mean training curves per environment, overall and for the best runs.

  Within each environment the curves in `_eval_original_avg` are ranked by
  get_last_score(). Four averaged curves are drawn: over all runs and over
  the best 10%, 5% and 1% of them (each share rounded up to at least one
  run).

  Args:
    data: DataFrame with columns `env_name` and `_eval_original_avg` (one
      Series per row).

  Returns:
    `(res, pdf)`: an empty HTML fragment and a `(filename, pdf_bytes)` tuple.
  """
  # Sorted so that the panel order does not depend on set iteration order.
  envs = sorted(set(data['env_name'].values))
  e_inv = {env: i for i, env in enumerate(envs)}
  # squeeze=False keeps `axs` an array even for a single environment.
  fig, axs = plt.subplots(1, len(envs),sharex=False, figsize=(15 * 3, 6),gridspec_kw={'wspace': 0.5 /3}, squeeze=False)
  axs = axs[0]
  res=""
  # Scalar key: a one-element list would make pandas >= 2.0 yield 1-tuples as
  # group names, which would then miss in `e_inv`.
  for name, group in data.groupby('env_name'):
    ind = e_inv[name]
    curves = group['_eval_original_avg'].values
    # Rank by score; the position breaks ties and lets us recover the curve.
    order = sorted((get_last_score(curve), i) for i, curve in enumerate(curves))
    ranked = [curves[i] for _, i in order]
    data0 = average(ranked)
    # -n // k floors towards minus infinity, so each slice keeps the best
    # ceil(n / k) curves.
    data90 = average(ranked[-len(ranked)//10:])
    data95 = average(ranked[-len(ranked)//20:])
    data99 = average(ranked[-len(ranked)//100:])

    axs[ind].set_title(name, fontsize=20)
    axs[ind].plot(data0[0], data0[1], color=tableau20[0], lw=3)
    axs[ind].plot(data90[0], data90[1], color=tableau20[1], lw=3)
    axs[ind].plot(data95[0], data95[1], color=tableau20[2], lw=3)
    axs[ind].plot(data99[0], data99[1], color=tableau20[3], lw=3)
    axs[ind].legend(['Overall mean', 'top 10% mean', 'top 5% mean', 'top 1% mean'], fontsize=20)

  pdf = encode_fig_pdf(plt, 'training_curves')
  plt.close()
  return res, pdf

def publish_report(html, pdf_plots, ename, path):
  """Writes the generated files of one experiment to disk.

  The files go to the directory `path + ename` (plain string concatenation,
  so `path` should end with a separator), which is created if necessary.

  Args:
    html: Unused.
    pdf_plots: Iterable of `(filename, content)` pairs. Files whose name
      contains '.pdf' are opened in binary mode and need `bytes` content;
      all others are opened in text mode and need `str` content.
    ename: Experiment name; appended to `path` to form the directory.
    path: Directory prefix.

  Raises:
    Whatever `open()` or `write()` raises for any of the files. Errors are
    re-raised here instead of being dropped inside the worker threads.
  """
  d = path + ename

  os.makedirs(d, exist_ok=True)

  def save_file(inp):
    a, b = inp
    target = d + '/' + a
    with open(target, 'w' + ('b' if '.pdf' in a else '')) as f:
      f.write(b)

    print ('%s available at %s' % (a, target))

  with futures.ThreadPoolExecutor(max_workers=100) as executor:
    # Executor.map only re-raises a worker's exception when its result is
    # retrieved; consume the iterator so that failed writes are not silent.
    list(executor.map(save_file, pdf_plots))


def generate_config(params, frame):
  """Builds the LaTeX list of sampled choices and their value ranges.

  Every top-level choice (a column starting with
  `_gin.study_design.choice_value` but not with
  `_gin.study_design.choice_value_sub_`) becomes one `\\item` listing the
  values found in `frame`, followed by whatever add_sub_params() emits for
  its sub-choices.

  Args:
    params: Iterable of column names.
    frame: DataFrame indexed by column name; only `value_counts()` of the
      choice columns is used.

  Returns:
    `(latex, processed_params)`: the LaTeX text and the set of column names
    that were handled. That set contains all non-choice columns, all
    top-level choices, and the sub-choices reported by add_sub_params().
  """
  choice_prefix = '_gin.study_design.choice_value'
  sub_choice_prefix = '_gin.study_design.choice_value_sub_'

  processed_params = set()
  chunks = [add_line("\\begin{itemize}")]

  for param in params:
    if not param.startswith(choice_prefix):
      # Not a choice column at all: nothing to print, but count it as handled.
      processed_params.add(param)
      continue
    if param.startswith(sub_choice_prefix):
      # Sub-choices are only reached through their parent choice.
      continue

    processed_params.add(param)
    observed = list(frame[param].value_counts().keys())
    shown_pairs = sorted(sanitize(value) for value in observed)
    raw_pairs = sorted(sanitize(value, replace=False) for value in observed)

    is_batch_mode = param == '_gin.study_design.choice_value_batch_mode'
    vals = []
    for text, number in shown_pairs:
      if text == '':
        vals.append(str(number))
      elif is_batch_mode:
        vals.append('\\texttt{'+text.replace("_", "\\_")+'}')
      else:
        vals.append(text.replace("_", "\\_"))
    # Raw values are what add_sub_params() matches sub-choice names against.
    valsorig = [str(number) if text == '' else text for text, number in raw_pairs]

    title = overwrite_dict.get(param, param).replace("_", "\\_").replace("#", "\\#")
    chunks.append(add_line("    \\item %s: \\{%s\\}" % (title, ', '.join(vals))))
    sub_latex, sub_params = add_sub_params(param, param, params, frame, valsorig, level=8)
    chunks.append(sub_latex)
    processed_params.update(sub_params)

  chunks.append(add_line("\\end{itemize}"))
  return "".join(chunks), processed_params

def ecdf(data):
  """Computes the empirical cumulative distribution function of a sample.

  Args:
    data: Array-like with the sample values.

  Returns:
    `(x, y)`: the sorted values and, for a sample of n values, the
    cumulative fractions 1/n, 2/n, ..., 1.
  """
  ordered = np.sort(data)
  count = ordered.size
  fractions = np.arange(1, count + 1) / count
  return ordered, fractions

def plot_ecdf_curves(data, param):
  """Plots, per environment, the ECDF of normalised final scores for each value of a choice.

  Final scores (last entry of `_eval_`) are rescaled so that the
  environment's `RANDOM_POLICY` entry maps to 0 and the best final score in
  the environment maps to 1. One step curve is drawn per value of `param`.

  Args:
    data: DataFrame with columns `env_name`, `_eval_` and `param`. Every
      environment must have an entry in `RANDOM_POLICY`.
    param: Name of the column holding the choice.

  Returns:
    `(res, pdf)`: an empty HTML fragment and a `(filename, pdf_bytes)` tuple.
  """
  # Sorted so that the panel order does not depend on set iteration order.
  envs = sorted(set(data['env_name'].values))
  # squeeze=False keeps `axs` an array even for a single environment.
  fig, axs = plt.subplots(1, len(envs),sharex=False, figsize=(15 * 3, 6),gridspec_kw={'wspace': 0.5 /3}, squeeze=False)
  axs = axs[0]
  res=""

  for ind, e in enumerate(envs):
    sdata = data[data['env_name'] == e]
    max_val = max([x.iloc[-1] for x in sdata['_eval_']])
    min_val = RANDOM_POLICY[e]
    leg = []
    # Scalar key: a one-element list would make pandas >= 2.0 yield 1-tuples
    # as group names, which sanitize() cannot handle.
    for name, group in sdata.groupby(param):
      # sanitize() returns ('', number) for numeric values; fall back to the
      # number so that such curves do not get an empty legend entry.
      label, number = sanitize(name)
      leg.append(str(number) if label == '' else label)
      x = [(x.iloc[-1] - min_val) / (max_val - min_val) for x in group['_eval_'].values]
      x, y = ecdf(x)
      axs[ind].step(x=x, y=y, where='post')
      axs[ind].set_title('%s (max=%.0f)' % (e, max_val), fontsize=20)
    axs[ind].legend(leg, fontsize=20)

  pdf = encode_fig_pdf(plt, 'ecdf_'+param.replace(".","_"))
  plt.close()
  return res, pdf


def generate_report(loaded_data, xid, path, BREAK=False):
  """Generate and publish the HTML/LaTeX report of a single experiment.

  The function accumulates an HTML string, a LaTeX string and a list of
  generated plots, then hands all of them to `publish_report`. The LaTeX text
  is written with the placeholder `EXPNAME`, which is replaced by the short
  experiment name just before publishing.

  Args:
    loaded_data: mapping from experiment id to a 5-item sequence
      `(frame, params, experiment_name, num_seeds, num_wus)`.
    xid: key of the experiment in `loaded_data`.
    path: forwarded unchanged to `publish_report`.
    BREAK: if true, the overview plot, the training-curve plot, all per-choice
      plots and the trailing correlation/extra figures are skipped. The LaTeX
      figure environment for the training curves and the `final_losses` ECDF
      plot are still produced.

  Reads the module-level `percentile`. Prints progress messages and calls
  `publish_report(html, pdf_plots, ename, path)`; returns nothing.
  """
  frame, params, experiment_name, num_seeds, num_wus = loaded_data[xid]

  html = ""
  latex = ""
  pdf_plots = []

  # The short experiment name is the second space-separated token.
  ename = experiment_name.split(' ')[1]
  print('Starting to generate report for %s' % ename)

  latex += add_line("\\clearpage")
  latex += add_line("\\section{Experiment %s}" % overwrite_dict.get(ename, ename).replace("_", "\\_"))
  latex += add_line("\\label{exp_EXPNAME}")
  latex += add_line("\\subsection{Design}")
  latex += add_line("\\label{exp_design_EXPNAME}")
  latex += add_line("For each of the 5 environments, we sampled %d choice configurations where we sampled the following choices independently and uniformly from the following ranges:" % (num_wus // 5))
  if ename == "final_arch2":
    # The design description of final_arch2 is taken from the experiment
    # stored under this hard-coded id rather than from its own frame.
    tframe, tparams, _, _, _ = loaded_data[13759846]
    t, processed_params = generate_config(tparams, tframe)
  else:
    t, processed_params = generate_config(params, frame)
  latex += t

  latex += add_line("All the other choices were set to the default values as described in Appendix~\\ref{sec:default_settings}.")
  latex += add_line("")
  latex += add_line("For each of the sampled choice configurations, we train %d agents with different random seeds and compute the performance metric as described in Section~\\ref{sec:performance}." % num_seeds)

  if ename == 'final_arch2':
    latex += add_line("")
    latex += add_line("After running the experiment described above we noticed (Fig.~\\ref{fig:final_arch__mlpshared}) that separate policy and value function networks (\\choicep{mlpshared}) perform better and we have rerun the experiment with only this variant present.")
    latex += add_line("")
    latex += add_line("\\begin{figure}[ht]")
    latex += add_line("\\begin{center}")
    latex += add_line("\\centerline{\\includegraphics[width=0.45\\textwidth]{final_arch/perf__gin_study_design_choice_value_policy_and_value_function_network.pdf}\\hspace{1cm}\\includegraphics[width=0.45\\textwidth]{final_arch/frequency__gin_study_design_choice_value_policy_and_value_function_network.pdf}}")
    latex += add_line("\\caption{Analysis of choice \\choicet{mlpshared}: "+str(percentile)+ "th percentile of performance scores conditioned on choice (left) and distribution of choices in top "+str(100-percentile)+"\\% of configurations (right).}")
    latex += add_line("\\label{fig:final_arch__mlpshared}")
    latex += add_line("\\end{center}")
    latex += add_line("\\end{figure}")

  latex += add_line("\\subsection{Results}")
  latex += add_line("\\label{exp_results_EXPNAME}")
  latex += add_line("We report aggregate statistics of the experiment in Table~\\ref{tab:EXPNAME_overview} as well as training curves in Figure~\\ref{fig:EXPNAME_training_curves}.")
  # Label of the last per-choice figure, used for the "Figures a-b" range.
  last = "fig:EXPNAME_"+params[-1].replace(".","_")
  if ename == 'final_setup':
    # final_setup ends with the extra batch_mode figure appended further down.
    last = "fig:final_setup2__gin_study_design_choice_value_batch_mode"
  latex += add_line("For each of the investigated choices in this experiment, we further provide a per-choice analysis in Figures~\\ref{fig:EXPNAME_"+params[0].replace(".","_")+"}-\\ref{"+last+"}.")

  html += "<h1>Analysis of %s %s </h1>" % (xid, experiment_name)

  data = frame
  if not BREAK:
    t, tl = primary_plot95(data,percentile)
    html += t
    latex += tl

    t, pdf = plot_training_curves(frame)
    html += t
    pdf_plots.append(pdf)

  latex += add_line("")
  latex += add_line("\\begin{figure}[ht]")
  latex += add_line("\\begin{center}")
  latex += add_line("\\centerline{\\includegraphics[width=1\\textwidth]{EXPNAME/training_curves.pdf}}")
  latex += add_line("\\caption{Training curves.}")
  latex += add_line("\\label{fig:EXPNAME_training_curves}")
  latex += add_line("\\end{center}")
  latex += add_line("\\end{figure}")

  # Note: this ECDF figure is produced even when BREAK is set.
  if ename == "final_losses":
    param='_gin.study_design.choice_value_standard_policy_losses'
    chtml, pdf = plot_ecdf_curves(frame, param)
    pdf_plots.append(pdf)
    html += chtml
    latex += add_line("")
    latex += add_line("\\begin{figure}[ht]")
    latex += add_line("\\begin{center}")
    latex += add_line("\\centerline{\\includegraphics[width=1\\textwidth]{"+ename+"/"+pdf[0]+"}}")
    latex += add_line("\\caption{Empirical cumulative density functions of agent performance conditioned on different values of "+overwrite_dict.get(param,param)+". The x axis denotes performance rescaled so that 0 corresponds to a random policy and 1 to the best found configuration, and the y axis denotes the quantile.}")
    latex += add_line("\\label{fig:"+ename+"__ecdf_standard_policy_losses}")
    latex += add_line("\\end{center}")
    latex += add_line("\\end{figure}")

  # Per-choice analysis: one figure environment per entry of `params`.
  for param in params:
    if BREAK:
      break
    html += "<h2>Analysis of %s</h2>" % overwrite_dict.get(param, param)
    pname = param.replace(".", "_")

    if 'custom' in param:
      # "custom" parameters get a two-row performance plot and no frequency plot.
      t1, pdfs = plot95(data, param,percentile,lines=2)
      html += t1
      pdf_plots.extend(pdfs)

      latex += add_line("")
      latex += add_line("\\begin{figure}[ht]")
      latex += add_line("\\begin{center}")
      latex += add_line("\\centerline{\\includegraphics[width=1\\textwidth]{EXPNAME/perf_"+pname+"_0.pdf}}")
      latex += add_line("\\centerline{\\includegraphics[width=0.65\\textwidth]{EXPNAME/perf_"+pname+"_1.pdf}}")
      if ename == "final_losses":
        latex += add_line("\\caption{Comparison of "+str(percentile)+ "th percentile of the performance of different policy losses conditioned on their hyperparameters.}")
      elif ename == "final_advantages":
        latex += add_line("\\caption{Comparison of "+str(percentile)+ "th percentile of the performance of different advantage estimators conditioned on their hyperparameters.}")
      elif ename == "final_regularizer":
        latex += add_line("\\caption{Comparison of "+str(percentile)+ "th percentile of the performance of different regularization approaches conditioned on their type.}")
      else:
        latex += add_line("\\caption{Analysis of choice "+overwrite_dict.get(param,param).replace("_", "\\_")+": "+str(percentile)+ "th percentile of performance scores conditioned on choice.}")
      latex += add_line("\\label{fig:EXPNAME_"+pname+"}")
      latex += add_line("\\end{center}")
      latex += add_line("\\end{figure}")
    else:
      # NOTE(review): in the two special cases below plot95 is additionally
      # run with height=3. Its HTML (t1) is overwritten by the default call
      # that follows; only the returned pdfs are kept. Whether both calls
      # write the same file name cannot be told from here - confirm.
      if ename == 'final_losses' and param == "_gin.study_design.choice_value_standard_policy_losses":
        t1, pdfs = plot95(data, param,percentile,height=3)
        pdf_plots.extend(pdfs)
      if (ename == 'final_arch2' or ename == 'final_arch') and param == "_gin.study_design.choice_value_initial_action_std":
        t1, pdfs = plot95(data, param,percentile,height=3)
        pdf_plots.extend(pdfs)
      t1, pdfs = plot95(data, param,percentile)
      pdf_plots.extend(pdfs)
      html += t1
      t2, pdf = plot_top_90(data, param,percentile)
      pdf_plots.append(pdf)
      html += t2
      sub = 'sub-' if '_sub_' in param else ''
      latex += add_line("")
      latex += add_line("\\begin{figure}[ht]")
      latex += add_line("\\begin{center}")
      latex += add_line("\\centerline{\\includegraphics[width=0.45\\textwidth]{EXPNAME/perf_"+pname+".pdf}\\hspace{1cm}\\includegraphics[width=0.45\\textwidth]{EXPNAME/frequency_"+pname+".pdf}}")
      latex += add_line("\\caption{Analysis of choice "+overwrite_dict.get(param,param).replace("_", "\\_").replace("#", "\\#")+": "+str(percentile)+ "th percentile of performance scores conditioned on "+sub+"choice (left) and distribution of "+sub+"choices in top "+str(100-percentile)+"\\% of configurations (right).}")
      latex += add_line("\\label{fig:EXPNAME_"+pname+"}")
      latex += add_line("\\end{center}")
      latex += add_line("\\end{figure}")

  if not BREAK:
    if ename == "final_losses":
      param2 = "_gin.study_design.choice_value_epochs_per_step"
      param1 = "_gin.study_design.choice_value_standard_policy_losses"

      chtml, pdf = plot95correlation(frame, param1, param2, ename, percentile)
      pdf_plots.append(pdf)
      html += chtml
      latex += add_line("")
      latex += add_line("\\begin{figure}[ht]")
      latex += add_line("\\begin{center}")
      latex += add_line("\\centerline{\\includegraphics[width=1\\textwidth]{"+ename+"/"+pdf[0]+"}}")
      latex += add_line("\\caption{"+str(percentile)+"th percentile of performance scores conditioned on "+overwrite_dict.get(param1,param1)+"(rows) and "+overwrite_dict.get(param2,param2)+"(bars).}")
      latex += add_line("\\label{fig:"+ename+"__correlation_epochs_per_step_vs_losses}")
      latex += add_line("\\end{center}")
      latex += add_line("\\end{figure}")

    if ename == "final_optimize":
      param1="_gin.study_design.choice_value_sub_optimizer_rmsprop_momentum"
      param2="_gin.study_design.choice_value_sub_optimizer_rmsprop_learning_rate"

      chtml, pdf = plot95correlation(frame, param1, param2, ename, percentile)
      pdf_plots.append(pdf)
      html += chtml
      latex += add_line("")
      latex += add_line("\\begin{figure}[ht]")
      latex += add_line("\\begin{center}")
      latex += add_line("\\centerline{\\includegraphics[width=1\\textwidth]{"+ename+"/"+pdf[0]+"}}")
      latex += add_line("\\caption{"+str(percentile)+"th percentile of performance scores conditioned on \\choicet{rmsmom}(rows) and \\choicet{rmslr}(bars).}")
      latex += add_line("\\label{fig:"+ename+"__correlation_rmsprop_momentum_vs_lr}")
      latex += add_line("\\end{center}")
      latex += add_line("\\end{figure}")

    if False and ename == "final_setup":  # we do not want this graph anymore
      param1="_gin.study_design.choice_value_batch_mode"
      param2="_gin.study_design.choice_value_num_actors_in_learner"

      chtml, pdf = plot95correlation(frame, param1, param2, ename, percentile)
      pdf_plots.append(pdf)
      html += chtml
      latex += add_line("")
      latex += add_line("\\begin{figure}[ht]")
      latex += add_line("\\begin{center}")
      latex += add_line("\\centerline{\\includegraphics[width=1\\textwidth]{"+ename+"/"+pdf[0]+"}}")
      # The backslash must be doubled: "\t" would be a TAB, not the \todo macro.
      latex += add_line("\\caption{\\todo{Marcin, please change me!!!}}")
      latex += add_line("\\label{fig:"+ename+"__correlation_batch_mode_vs_num_actors}")
      latex += add_line("\\end{center}")
      latex += add_line("\\end{figure}")

    if ename == "final_setup":
      # This figure points at plots stored under final_setup2/, not at plots
      # generated in this call.
      latex += add_line("")
      latex += add_line("\\begin{figure}[ht]")
      latex += add_line("\\begin{center}")
      latex += add_line("\\centerline{\\includegraphics[width=0.45\\textwidth]{final_setup2/perf__gin_study_design_choice_value_batch_mode.pdf}\\hspace{1cm}\\includegraphics[width=0.45\\textwidth]{final_setup2/frequency__gin_study_design_choice_value_batch_mode.pdf}}")
      latex += add_line("\\caption{Analysis of choice \\choicet{batchhandling}: "+str(percentile)+"th percentile of performance scores conditioned on choice (left) and distribution of choices in top "+str(100-percentile)+"\\% of configurations(right). In order to obtain narrower confidence intervals in this experiment we only sweep \\choicet{batchhandling}, \\choicet{numenvs}, \\choicet{adamlr}.}")
      latex += add_line("\\label{fig:final_setup2__gin_study_design_choice_value_batch_mode}")
      latex += add_line("\\end{center}")
      latex += add_line("\\end{figure}")

  html = html.replace("width=\"\"", "width=700")

  latex += add_line("\\clearpage")
  # EXPNAMEREADABLE must be substituted before EXPNAME, which is its prefix.
  latex = latex.replace("EXPNAMEREADABLE", "experiment " + overwrite_dict.get(ename,ename).replace("_", "\\_"))
  latex = latex.replace("EXPNAME", ename)

  # The LaTeX source is published next to the plots as a (name, content) pair.
  pdf_plots.append(('main.tex', latex))

  print('Report generation finished for %s' % ename)
  publish_report(html, pdf_plots, ename, path)
  print('FINISHED publishing with %d/%s' % (xid, experiment_name))
  print('')
  print('')


In [None]:
#@title Generate reports

# Build and publish one report per loaded experiment; iterating the dict
# yields its keys (the experiment ids). An empty string is passed as `path`.
for xid in loaded_data:
  generate_report(loaded_data, xid, '', BREAK=False)

Starting to generate report for final_advantages
Report generation finished for final_advantages
training_curves.pdf available at final_advantages/training_curves.pdf
perf__gin_study_design_choice_value_num_actors_in_learner.pdf available at final_advantages/perf__gin_study_design_choice_value_num_actors_in_learner.pdfperf_custom_advantage_estimator_1.pdf available at final_advantages/perf_custom_advantage_estimator_1.pdf

perf_custom_advantage_estimator_0.pdf available at final_advantages/perf_custom_advantage_estimator_0.pdf
frequency__gin_study_design_choice_value_value_loss.pdf available at final_advantages/frequency__gin_study_design_choice_value_value_loss.pdf
frequency__gin_study_design_choice_value_num_actors_in_learner.pdf available at final_advantages/frequency__gin_study_design_choice_value_num_actors_in_learner.pdf
perf__gin_study_design_choice_value_value_loss.pdf available at final_advantages/perf__gin_study_design_choice_value_value_loss.pdf
frequency__gin_study_design_c

In [None]:
!ls -l final_advantages

total 388
-rw-r--r-- 1 root root 15006 Jan  5 15:43 frequency__gin_study_design_choice_value_advantage_estimator.pdf
-rw-r--r-- 1 root root 15261 Jan  5 15:43 frequency__gin_study_design_choice_value_learning_rate.pdf
-rw-r--r-- 1 root root 14476 Jan  5 15:43 frequency__gin_study_design_choice_value_num_actors_in_learner.pdf
-rw-r--r-- 1 root root 14120 Jan  5 15:43 frequency__gin_study_design_choice_value_ppo_style_value_clipping_epsilon.pdf
-rw-r--r-- 1 root root 15252 Jan  5 15:43 frequency__gin_study_design_choice_value_sub_advantage_estimator_gae_gae_lambda.pdf
-rw-r--r-- 1 root root 14370 Jan  5 15:43 frequency__gin_study_design_choice_value_sub_advantage_estimator_n_step_n.pdf
-rw-r--r-- 1 root root 15022 Jan  5 15:43 frequency__gin_study_design_choice_value_sub_advantage_estimator_v_trace_lambda.pdf
-rw-r--r-- 1 root root 13458 Jan  5 15:43 frequency__gin_study_design_choice_value_sub_advantage_estimator_v_trace_max_importance_weight_in_advantage_estimation.pdf
-rw-r--r-- 1 roo