Copyright 2021 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

In [None]:
#@title Imports
import os
import pickle

from google.colab import files
import numpy as np
import pandas as pd
import plotnine as gg
from scipy import stats

In [None]:
#@title Utility functions
def saveplot(plot, path):
  """Save `plot` to `path`, then pass the file to `files.download`.

  Args:
    plot: Object exposing a `save(path)` method.
    path: Destination file path; also shown in the progress message.
  """
  message = 'Saving figure to {}'.format(path)
  print(message)
  plot.save(path)
  # `files` is the google.colab helper imported at the top of the notebook.
  files.download(path)


def chunks(lst, n):
  """Yield consecutive slices of `lst`, each holding at most `n` items.

  The last slice is shorter when `len(lst)` is not a multiple of `n`.
  """
  yield from (lst[start:start + n] for start in range(0, len(lst), n))


def bootstrap_metric(prediction, ground_truth, metric_func, num_repeats=1000):
  """Evaluate `metric_func` on bootstrap resamples of paired data.

  Every repeat draws `len(ground_truth)` indices with `np.random.choice` and
  applies the same indices to both inputs, so each prediction stays paired
  with its ground-truth value.

  Args:
    prediction: Sequence of predicted values.
    ground_truth: Sequence of true values, paired element-wise with
      `prediction`.
    metric_func: Callable invoked as `metric_func(prediction, ground_truth)`
      on the resampled arrays.
    num_repeats: Number of resamples to draw.

  Returns:
    List holding one `metric_func` result per resample.
  """
  predicted = np.array(prediction)
  actual = np.array(ground_truth)
  length = len(actual)

  def one_resample():
    # One index draw is shared by both arrays to keep the pairs aligned.
    picks = np.random.choice(range(length), size=length)
    return metric_func(predicted[picks], actual[picks])

  return [one_resample() for _ in range(num_repeats)]


def normalized_regret(prediction, ground_truth, top_k=1):
  """Regret of choosing the `top_k` highest-predicted items.

  The items are ranked by `prediction`; the best ground-truth value among the
  `top_k` highest-ranked items is compared with the best ground-truth value
  overall, and the gap is divided by that overall best.

  Args:
    prediction: Sequence of predicted values.
    ground_truth: Sequence of true values, paired element-wise with
      `prediction`.
    top_k: Number of highest-predicted items that may be chosen from.

  Returns:
    `(max(ground_truth) - best chosen value) / max(ground_truth)`, or 0.0
    when the chosen value already equals the maximum.

  Raises:
    ValueError: If `top_k` is smaller than 1.
  """
  if top_k < 1:
    # `[-0:]` would silently select every item and always report zero regret.
    raise ValueError('top_k must be at least 1, got {}'.format(top_k))
  # Accept plain sequences too; the fancy indexing below needs arrays.
  prediction = np.asarray(prediction)
  ground_truth = np.asarray(ground_truth)
  max_val = np.max(ground_truth)
  ranking = np.argsort(prediction)
  best_chosen = np.max(ground_truth[ranking][-top_k:])
  if best_chosen == max_val:
    # Zero regret; returning here also avoids 0/0 when max_val is 0.
    return 0.0
  return (max_val - best_chosen) / max_val


def absolute_error(prediction, ground_truth):
  """Return the mean of the element-wise |prediction - ground_truth|."""
  residual = prediction - ground_truth
  return np.mean(np.absolute(residual))


def rank_correlation(prediction, ground_truth):
  """Return the Spearman rank correlation of the two value sequences."""
  spearman = stats.spearmanr(ground_truth, prediction)
  return spearman.correlation


def make_scatter_data(policy_dict, gt_data, baselines, whitelist=None):
  """Make dataframe for plotting scatter plots.

  Values are min-max normalized per task using that task's ground-truth
  values, and baseline predictions are capped at 2 after normalization.

  Args:
    policy_dict: Dict mapping task name to a sequence of policy ids.
    gt_data: Dict mapping policy id to a sequence whose first element is the
      ground-truth value.
    baselines: Dict mapping baseline name to a dict shaped like `gt_data`
      that holds the baseline's estimates.
    whitelist: Optional container of policy ids. When given, only policies it
      contains are used.

  Returns:
    DataFrame with columns `policy_id`, `task_name`, `baseline`,
    `ground_truth` and `prediction`; one row per (task, baseline, policy).
    Tasks left without any policy after whitelist filtering contribute no
    rows.
  """
  data = []
  for task, policy_ids in policy_dict.items():
    # TODO(tpaine): remove this
    if whitelist is not None:
      policy_ids = [p for p in policy_ids if p in whitelist]
    if len(policy_ids) == 0:
      # np.min / np.max raise on empty input, so skip tasks with no policies.
      continue

    gt_raw = [gt_data[p][0] for p in policy_ids]
    worst = np.min(gt_raw)
    best = np.max(gt_raw)
    # NOTE: span is 0 when every policy has the same ground-truth value; the
    # divisions below then produce nan/inf rather than raising.
    span = best - worst
    # The normalized ground truth does not depend on the baseline, so compute
    # it once per task.
    gt_values = [(value - worst) / span for value in gt_raw]

    for baseline_name, baseline in baselines.items():
      for policy_id, gt_value in zip(policy_ids, gt_values):
        baseline_value = np.minimum(
            (baseline[policy_id][0] - worst) / span, 2)
        data.append((policy_id, task, baseline_name, gt_value, baseline_value))
  df = pd.DataFrame(
      data,
      columns=[
          'policy_id', 'task_name', 'baseline', 'ground_truth', 'prediction'
      ])
  return df


def make_results(policy_dict, gt_data, baselines, metric_func, whitelist=None):
  """Make dataframe containing metric results for each bootstrap 'trial'.

  For every task the ground-truth values and each baseline's estimates are
  min-max normalized with the task's ground-truth range, the estimates are
  capped at 2, and `metric_func` is bootstrapped via `bootstrap_metric`.

  Args:
    policy_dict: Dict mapping task name to a sequence of policy ids.
    gt_data: Dict mapping policy id to a sequence whose first element is the
      ground-truth value.
    baselines: Dict mapping baseline name to a dict shaped like `gt_data`
      that holds the baseline's estimates.
    metric_func: Callable invoked as `metric_func(prediction, ground_truth)`.
    whitelist: Optional container of policy ids. When given, only policies it
      contains are used.

  Returns:
    DataFrame with columns `baseline`, `task_name`, `trial` and `metric`; one
    row per (baseline, task, bootstrap trial). Tasks left without any policy
    after whitelist filtering contribute no rows.
  """
  results = []
  for task, policy_ids in policy_dict.items():
    if whitelist is not None:
      policy_ids = [p for p in policy_ids if p in whitelist]
    if len(policy_ids) == 0:
      # np.min / np.max raise on empty input, so skip tasks with no policies.
      continue

    gt_raw = np.array([gt_data[pid][0] for pid in policy_ids])
    worst = np.min(gt_raw)
    best = np.max(gt_raw)
    # NOTE: span is 0 when every policy has the same ground-truth value; the
    # divisions below then produce nan/inf rather than raising.
    span = best - worst
    # The normalized ground truth is identical for every baseline.
    gt_values = (gt_raw - worst) / span

    for baseline_name, baseline in baselines.items():
      baseline_values = np.array([baseline[pid][0] for pid in policy_ids])
      baseline_values = np.minimum((baseline_values - worst) / span, 2)
      repeats = bootstrap_metric(
          prediction=baseline_values,
          ground_truth=gt_values,
          metric_func=metric_func)
      results.extend(
          (baseline_name, task, i, repeat) for i, repeat in enumerate(repeats))
  result_df = pd.DataFrame(
      results, columns=['baseline', 'task_name', 'trial', 'metric'])
  return result_df


def make_per_task(results):
  """Make dataframe containing per task performance per baseline.

  Groups `results` by (`baseline`, `task_name`) and reports the mean of
  `metric` together with a band of one standard deviation either side
  (`metric_dn`, `metric_up`).
  """
  grouped = results.groupby(
      ['baseline', 'task_name'], as_index=False)['metric']
  per_task = grouped.mean()
  spread = grouped.std()['metric']
  per_task['metric_dn'] = per_task['metric'] - spread
  per_task['metric_up'] = per_task['metric'] + spread
  return per_task


def make_d4rl_per_domain(results):
  """Make dataframe containing per domain performance per baseline.

  The domain of a task is the part of `task_name` before the first '-'.
  The reported `metric` is the mean over all rows of a (baseline, domain)
  pair; the `metric_dn` / `metric_up` band is one standard deviation of the
  per-trial means either side of it.

  Args:
    results: DataFrame with `baseline`, `task_name`, `trial` and `metric`
      columns. It is not modified.

  Returns:
    DataFrame with columns `baseline`, `task_domain`, `metric`, `metric_dn`
    and `metric_up`.
  """
  # `assign` returns a new frame, so the caller's `results` does not gain a
  # `task_domain` column as a side effect.
  domains = [t.split('-')[0] for t in results['task_name']]
  with_domain = results.assign(task_domain=domains)

  per_domain = with_domain.groupby(['baseline', 'task_domain'],
                                   as_index=False)['metric'].mean()
  # Average within each trial first so the spread reflects trial-to-trial
  # variation of the domain-level mean.
  temp = with_domain.groupby(['baseline', 'task_domain', 'trial'],
                             as_index=False)['metric'].mean()
  temp = temp.groupby(['baseline', 'task_domain'],
                      as_index=False)['metric'].std()
  per_domain['metric_dn'] = per_domain['metric'] - temp['metric']
  per_domain['metric_up'] = per_domain['metric'] + temp['metric']
  return per_domain


def make_overall(results):
  """Make dataframe containing aggregate performance per baseline.

  `metric` is the mean over all rows of a baseline. The `metric_dn` /
  `metric_up` band is one standard deviation of the per-trial means either
  side of it.
  """
  by_baseline = ['baseline']
  overall = results.groupby(by_baseline, as_index=False)['metric'].mean()

  per_trial = results.groupby(
      ['baseline', 'trial'], as_index=False)['metric'].mean()
  spread = per_trial.groupby(by_baseline, as_index=False)['metric'].std()

  overall['metric_dn'] = overall['metric'] - spread['metric']
  overall['metric_up'] = overall['metric'] + spread['metric']
  return overall


def plot_per_task(per_task, y_axis_name, figure_size=(20, 5), order=None):
  """Build the per-task bar plot, faceted by baseline.

  Args:
    per_task: DataFrame with `baseline`, `task_name`, `metric`, `metric_dn`
      and `metric_up` columns. A `baseline_sorted` column is added to it.
    y_axis_name: Label for the metric axis.
    figure_size: Figure size passed to the plot theme.
    order: Optional list of baseline names used as the category order of the
      facets.

  Returns:
    The plotnine plot object.
  """
  per_task['baseline_sorted'] = pd.Categorical(
      per_task['baseline'], categories=order)
  # Reversed sorted task names are used as the discrete axis limits.
  task_limits = list(np.unique(per_task['task_name']))[::-1]

  bars = gg.geom_bar(
      gg.aes(fill='baseline'), stat='identity', position=gg.position_dodge())
  error_bars = gg.geom_errorbar(
      gg.aes(
          x='task_name', group='baseline', ymin='metric_dn', ymax='metric_up'),
      width=.2,
      position=gg.position_dodge(0.9))
  labels = gg.labs(y=y_axis_name, x='Tasks')
  facets = gg.facet_wrap(['baseline_sorted'], ncol=8)
  styling = gg.theme(
      legend_position='none',
      figure_size=figure_size,
      panel_spacing=.2,
      text=gg.element_text(size=16),
      axis_text=gg.element_text(size=14),
      axis_title=gg.element_text(size=16))

  return (gg.ggplot(per_task) + gg.aes('task_name', y='metric') + bars +
          error_bars + labels + gg.theme_minimal() + facets + styling +
          gg.scale_x_discrete(limits=task_limits) + gg.coord_flip())


def plot_overall(overall, y_axis_name, ascending=True, order=None):
  """Build the bar plot of the aggregate metric per baseline.

  Args:
    overall: DataFrame with `baseline`, `metric`, `metric_dn` and `metric_up`
      columns. A `baseline_sorted` column is added to it.
    y_axis_name: Label for the metric axis.
    ascending: Accepted for interface compatibility; not used by this
      function.
    order: Optional list of baseline names used as the category order along
      the x axis.

  Returns:
    The plotnine plot object.
  """
  overall['baseline_sorted'] = pd.Categorical(
      overall['baseline'], categories=order)

  bars = gg.geom_bar(
      gg.aes(fill='baseline'), stat='identity', position=gg.position_dodge())
  error_bars = gg.geom_errorbar(
      gg.aes(x='baseline_sorted', ymin='metric_dn', ymax='metric_up'),
      width=.2,
      position=gg.position_dodge(0.9))
  labels = gg.labs(y=y_axis_name, x='Baselines')
  styling = gg.theme(
      legend_position='none',
      figure_size=(9, 3),
      panel_spacing=.5,
      text=gg.element_text(size=16),
      axis_text=gg.element_text(size=14),
      axis_title=gg.element_text(size=16))

  return (gg.ggplot(overall) + gg.aes('baseline_sorted', y='metric') + bars +
          error_bars + labels + gg.theme_minimal() + styling)


def plot_scatter(scatter_data, order=None):
  """Build the grid of estimate-vs-ground-truth scatter plots.

  Rows of the grid are tasks and columns are baselines. Both axes are limited
  to [0, 2].

  Args:
    scatter_data: DataFrame with `task_name`, `baseline`, `ground_truth` and
      `prediction` columns. It is modified: '_' and '-' in `task_name` are
      replaced by newlines and a `baseline_sorted` column is added.
    order: Optional list of baseline names used as the category order of the
      grid columns. When None the `baseline` column is used as is.

  Returns:
    The plotnine plot object.
  """
  # Break long task names over several lines so the facet labels fit.
  scatter_data['task_name'] = [
      name.replace('_', '\n').replace('-', '\n')
      for name in scatter_data['task_name']
  ]
  if order is not None:
    scatter_data['baseline_sorted'] = pd.Categorical(
        scatter_data['baseline'], categories=order)
  else:
    scatter_data['baseline_sorted'] = scatter_data['baseline']

  diagonal = gg.geom_abline(color='grey', alpha=0.8)
  mapping = gg.aes(y='prediction', x='ground_truth', color='baseline')
  points = gg.geom_point(alpha=0.8)
  grid = gg.facet_grid(['task_name', 'baseline_sorted'])
  labels = gg.labs(y='Estimate', x='Return (d=0.995)')
  styling = gg.theme(
      legend_position='none',
      figure_size=(15, 25),
      panel_spacing=.2,
      text=gg.element_text(size=16),
      axis_text=gg.element_text(size=14),
      axis_title=gg.element_text(size=16))

  return (gg.ggplot(scatter_data) + diagonal + mapping + points + grid +
          gg.coord_fixed() + labels + gg.theme_minimal() + gg.xlim(0, 2) +
          gg.ylim(0, 2) + styling)


def plot_per_domain(per_domain, y_axis_name, order=None):
  """Build the per-domain bar plot, faceted by baseline.

  Args:
    per_domain: DataFrame with `baseline`, `task_domain`, `metric`,
      `metric_dn` and `metric_up` columns. A `baseline_sorted` column is
      added to it.
    y_axis_name: Label for the metric axis.
    order: Optional list of baseline names used as the category order of the
      facets.

  Returns:
    The plotnine plot object.
  """
  per_domain['baseline_sorted'] = pd.Categorical(
      per_domain['baseline'], categories=order)
  # Reversed sorted domain names are used as the discrete axis limits.
  domain_limits = list(np.unique(per_domain['task_domain']))[::-1]

  bars = gg.geom_bar(
      gg.aes(fill='baseline'), stat='identity', position=gg.position_dodge())
  error_bars = gg.geom_errorbar(
      gg.aes(
          x='task_domain',
          group='baseline',
          ymin='metric_dn',
          ymax='metric_up'),
      width=.2,
      position=gg.position_dodge(0.9))
  labels = gg.labs(y=y_axis_name, x='Domains')
  facets = gg.facet_wrap(['baseline_sorted'], ncol=8)
  styling = gg.theme(
      legend_position='none',
      figure_size=(20, 5),
      panel_spacing=.1,
      text=gg.element_text(size=16),
      axis_text=gg.element_text(size=14),
      axis_title=gg.element_text(size=16))

  return (gg.ggplot(per_domain) + gg.aes('task_domain', y='metric') + bars +
          error_bars + labels + gg.theme_minimal() + facets + styling +
          gg.scale_x_discrete(limits=domain_limits) + gg.coord_flip())


def make_all_metric_plots(benchmark,
                          metric_type,
                          policy_dict,
                          gt_data,
                          baselines,
                          whitelist=None):
  """Utility function to make and save all metric plots.

  Bootstraps the chosen metric with `make_results`, then prints and saves the
  per-task and overall plots, plus the per-domain plot when `benchmark` is
  'd4rl'. Output paths are built from the module-level `output_pattern`.

  Args:
    benchmark: Either 'rlunplugged' or 'd4rl'.
    metric_type: One of 'absolute_error', 'rank_correlation' or
      'normalized_regret'.
    policy_dict: Passed through to `make_results`.
    gt_data: Passed through to `make_results`.
    baselines: Passed through to `make_results`.
    whitelist: Passed through to `make_results`.

  Raises:
    ValueError: If `benchmark` or `metric_type` is not one of the values
      listed above.
  """
  # setup
  # Both selectors are validated before any work is done, so a typo fails with
  # a clear error instead of an unbound-variable error later on.
  if benchmark == 'rlunplugged':
    per_task_figure_size = (25, 5)
  elif benchmark == 'd4rl':
    per_task_figure_size = (25, 15)
  else:
    raise ValueError('Unknown benchmark')

  # The '<' / '>' prefix in the axis label marks whether lower or higher
  # values are better; `ascending` sorts the best baseline first.
  if metric_type == 'absolute_error':
    y_axis_name = '< Absolute Error'
    metric_func = absolute_error
    ascending = True
  elif metric_type == 'rank_correlation':
    y_axis_name = '> Rank Correlation'
    metric_func = rank_correlation
    ascending = False
  elif metric_type == 'normalized_regret':
    y_axis_name = '< Regret@1'
    metric_func = normalized_regret
    ascending = True
  else:
    raise ValueError('Unknown metric_type')

  def show_and_save(plot, plot_name):
    """Print the plot and save it under the name derived from `plot_name`."""
    print(plot)
    saveplot(plot, output_pattern.format(benchmark, plot_name))

  # Prepare data for plots
  results = make_results(
      policy_dict,
      gt_data,
      baselines,
      metric_func=metric_func,
      whitelist=whitelist)
  per_task = make_per_task(results)
  overall = make_overall(results)
  if benchmark == 'd4rl':
    per_domain = make_d4rl_per_domain(results)

  # Order baselines by overall performance. Sorting a copy leaves `overall`
  # itself in its original row order, so it can be plotted without being
  # recomputed.
  order = list(
      overall.sort_values('metric', ascending=ascending)['baseline'])

  # Make plots
  p = plot_per_task(
      per_task,
      y_axis_name=y_axis_name,
      figure_size=per_task_figure_size,
      order=order)
  show_and_save(p, 'per_task_{}'.format(metric_type))

  p = plot_overall(
      overall, y_axis_name=y_axis_name, ascending=ascending, order=order)
  show_and_save(p, 'overall_{}'.format(metric_type))

  if benchmark == 'd4rl':
    p = plot_per_domain(per_domain, y_axis_name, order=order)
    show_and_save(p, 'per_domain_{}'.format(metric_type))

In [None]:
#@title Download benchmark data
# Copy benchmark.zip from the gs://gresearch/deep-ope bucket into /tmp/ and
# unpack it under /tmp/.
!gsutil cp -r gs://gresearch/deep-ope/benchmark.zip /tmp/
!unzip /tmp/benchmark.zip -d /tmp/

In [None]:
#@title Read benchmark data
BENCHMARK_ROOT = '/tmp/dope/'
# BENCHMARK = 'rlunplugged'
BENCHMARK = 'd4rl'

# Per benchmark: (baseline display name, file-name suffix) pairs. The files
# are named '<benchmark>_<suffix>.pkl'. The listed order is the order in which
# `baseline_paths` and `baselines` are populated.
_BASELINE_FILE_SUFFIXES = {
    'rlunplugged': [
        ('FQE-D', 'fqed'),
        ('FQE-L2', 'fqel2'),
        ('VPM', 'vpm'),
        ('DICE', 'dice'),
        ('MB-FF', 'mb_ff'),
        ('MB-AR', 'mb_ar'),
        ('DR', 'dr'),
        ('IS', 'is'),
    ],
    'd4rl': [
        ('FQE-L2', 'fqel2'),
        ('DICE', 'dice'),
        ('DR', 'dr'),
        ('IS', 'is'),
        ('VPM', 'vpm'),
    ],
}

if BENCHMARK not in _BASELINE_FILE_SUFFIXES:
  # Fail here rather than with a NameError on `gt_path` further down.
  raise ValueError('Unknown benchmark: {!r}'.format(BENCHMARK))

gt_path = os.path.join(BENCHMARK_ROOT, '{}_gt.pkl'.format(BENCHMARK))
policy_path = os.path.join(BENCHMARK_ROOT, '{}_policys.pkl'.format(BENCHMARK))
baseline_paths = {
    name: os.path.join(BENCHMARK_ROOT, '{}_{}.pkl'.format(BENCHMARK, suffix))
    for name, suffix in _BASELINE_FILE_SUFFIXES[BENCHMARK]
}

# Ground truth AND baseline pickle files are simple dicts with the format
# key: policy_id
# value: (Return average, Return standard deviation)
#
# NOTE: pickle.load can execute arbitrary code contained in the file. Only
# load benchmark archives obtained from a source you trust.
with open(gt_path, 'rb') as f:
  gt_data = pickle.load(f)

with open(policy_path, 'rb') as f:
  policy_dict = pickle.load(f)

baselines = {}
for baseline, path in baseline_paths.items():
  with open(path, 'rb') as f:
    baseline_data = pickle.load(f)
  baselines[baseline] = baseline_data

# Whitelist of policy ids.
# Used to analyze results on a subset of the policies.
# If None, analysis will be over all policies in policy_dict.
whitelist = None


## Make plots

In [None]:
OUTPUT_ROOT = '/tmp/figures'
# os.makedirs with exist_ok=True keeps this cell re-runnable without the
# "File exists" error that `!mkdir` prints, and does not need a shell.
os.makedirs(OUTPUT_ROOT, exist_ok=True)
# Filled in as output_pattern.format(<benchmark>, <plot name>).
output_pattern = os.path.join(OUTPUT_ROOT, '{}_{}.pdf')

In [None]:
#@title Scatterplots
# setup: column order of the baselines and number of task rows per page.
if BENCHMARK == 'rlunplugged':
  tasks_per_page = 9
  order = ['MB-AR', 'FQE-D', 'FQE-L2', 'MB-FF', 'DR', 'VPM', 'IS', 'DICE']
elif BENCHMARK == 'd4rl':
  tasks_per_page = 7
  order = ['IS', 'DR', 'FQE-L2', 'VPM', 'DICE']

scatter_data = make_scatter_data(
    policy_dict, gt_data, baselines, whitelist=whitelist)

# Make scatter plot with pagination: each page holds up to `tasks_per_page`
# tasks and is saved as its own figure.
tasks = list(np.unique(scatter_data['task_name']))
task_pages = list(chunks(tasks, tasks_per_page))
for page_i, page in enumerate(task_pages):
  mask = scatter_data['task_name'].isin(page)
  # plot_scatter edits its input frame, so hand it a copy of the page's rows.
  scatter_copy = scatter_data[mask].copy()
  p = plot_scatter(scatter_copy, order=order)
  print(p)
  page_name = 'scatter_{:02}'.format(page_i)
  output_path = output_pattern.format(BENCHMARK, page_name)
  saveplot(p, output_path)

In [None]:
#@title Absolute Error
# Bootstrap the absolute error of every baseline and save the resulting plots.
make_all_metric_plots(benchmark=BENCHMARK,
                      metric_type='absolute_error',
                      policy_dict=policy_dict,
                      gt_data=gt_data,
                      baselines=baselines,
                      whitelist=whitelist)

In [None]:
#@title Rank Correlation
# Bootstrap the rank correlation of every baseline and save the resulting
# plots.
make_all_metric_plots(benchmark=BENCHMARK,
                      metric_type='rank_correlation',
                      policy_dict=policy_dict,
                      gt_data=gt_data,
                      baselines=baselines,
                      whitelist=whitelist)

In [None]:
#@title Normalized Regret
# Bootstrap the normalized regret of every baseline and save the resulting
# plots.
make_all_metric_plots(benchmark=BENCHMARK,
                      metric_type='normalized_regret',
                      policy_dict=policy_dict,
                      gt_data=gt_data,
                      baselines=baselines,
                      whitelist=whitelist)