In [16]:
import pandas as pd
import json
import os
import os.path as osp
import numpy as np
from collections import OrderedDict
import matplotlib.pyplot as plt

In [17]:
# Configuration shared by the loading code below.

# Width of the "=" divider lines printed to stdout.
DIV_LINE_WIDTH = 50

# Row order of the exported table (the labels are the values of REPLACE_NAME).
ROW_ORDER = [
    "PPOL-vanilla",
    "PPOL-random",
    "SA-PPOL",
    "SA-PPOL(MC)",
    "SA-PPOL(MR)",
    "ADV-PPOL(MC)",
    "ADV-PPOL(MR)",
]

# Last "_"-separated token of a run's exp_name -> label used in the table.
REPLACE_NAME = {
    "vanilla": "PPOL-vanilla",
    "uniform": "PPOL-random",
    "kl": "SA-PPOL",
    "klmc": "SA-PPOL(MC)",
    "klmr": "SA-PPOL(MR)",
    "cost": "ADV-PPOL(MC)",
    "reward": "ADV-PPOL(MR)",
}

# Column order of the exported table: one Reward/Cost pair per setting.
COL_ORDER = [
    setting + metric
    for setting in ("Natural", "AdvUniform", "AdvMad", "AdvAmad",
                    "AdvMaxCost", "AdvMaxReward", "Average")
    for metric in ("Reward", "Cost")
]

# Environment ids and shorter names (not referenced by the code shown here).
ENV_NAMES = {
    "SafetyCarCircle-v0": "Car-Circle",
    "SafetyAntRun-v0": "Ant-Run",
    "SafetyAntCircle-v0": "Ant-Circle",
    "SafetyCarRun-v0": "Car-Run",
    "SafetyDroneCircle-v0": "Drone-Circle",
    "SafetyDroneRun-v0": "Drone-Run",
}

# Edit these to choose which runs are collected: a run is used only when all
# three strings occur in its config["data_dir"]. TABLE_NAME is also the suffix
# of the CSV file names written at the end.
NAME = "ppo"
ENV = "DroneRun"
TABLE_NAME = "eval"

def get_datasets(logdir, data):
    """
    Recursively look through logdir for output files produced by
    spinup.logx.Logger and fold their reward/cost columns into ``data``.

    Assumes that any file "progress.txt" is a valid hit. A run is used only
    when the config.json next to it can be read, its "data_dir" contains
    TABLE_NAME, NAME and ENV, and the last "_"-separated token of its
    "exp_name" is a key of REPLACE_NAME. Every other run is skipped.

    Args:
        logdir: Directory to walk.
        data: Nested mapping ``env name -> method label -> column name ->
            values``; it is updated in place.

    Returns:
        ``data``. Every column that still held raw samples is replaced by the
        string "<mean>, <std>" (both rounded to 2 decimals).

    Raises:
        ValueError: If a collected column name is not listed in COL_ORDER.
        TypeError: If a run contributes to a column that a previous call has
            already turned into its "<mean>, <std>" string.
    """
    for root, _, files in os.walk(logdir):
        if 'progress.txt' not in files:
            continue

        try:
            # The context manager closes the handle even if parsing fails.
            with open(os.path.join(root, 'config.json')) as config_file:
                config = json.load(config_file)
            data_dir = config["data_dir"]
            if any(tag not in data_dir for tag in (TABLE_NAME, NAME, ENV)):
                continue
            env = config["env_cfg"]["env_name"]
            exp_name = config["exp_name"].split('_')[-1]
            if exp_name not in REPLACE_NAME:
                continue
            exp_name = REPLACE_NAME[exp_name]
        except (OSError, ValueError, KeyError, TypeError, AttributeError) as err:
            # Without a usable config the run has no env/method label, so it
            # cannot be filed anywhere in `data`; skip it instead of falling
            # through and failing on data[None][None].
            print('No usable config.json in %s (%r); skipping run' % (root, err))
            continue

        run_columns = data.setdefault(env, OrderedDict()).setdefault(exp_name, OrderedDict())
        print(root)

        progress_path = os.path.join(root, 'progress.txt')
        try:
            exp_data = pd.read_table(progress_path)
            # Keep only the last "/"-separated part of each column name.
            exp_data = exp_data.rename(columns=lambda x: x.split("/")[-1])
        except Exception:
            print('Could not read from %s' % progress_path)
            continue

        for column_name, column_data in exp_data.items():
            if "Reward" in column_name or "Cost" in column_name:
                samples = run_columns.setdefault(column_name, [])
                if isinstance(samples, str):
                    # The raw samples behind a summary string are gone, so
                    # they cannot be merged with new ones.
                    raise TypeError(
                        'Column %r of %r / %r was already summarised by an '
                        'earlier call; cannot add samples from %s'
                        % (column_name, env, exp_name, root))
                samples.extend(column_data.values)
        data[env][exp_name] = OrderedDict(
            sorted(run_columns.items(), key=lambda i: COL_ORDER.index(i[0])))

    for env_results in data.values():
        for exp_results in env_results.values():
            for column_name, samples in exp_results.items():
                if isinstance(samples, str):
                    # Summarised by an earlier call on the same `data`.
                    continue
                mean = round(np.mean(samples), 2)
                std = round(np.std(samples), 2)
                exp_results[column_name] = str(mean) + ", " + str(std)
    return data


def get_all_datasets(all_logdirs, exp_data):
    """
    Resolve every entry of all_logdirs to directories and load each of them
    with get_datasets.

    For every entry in all_logdirs,
        1) if the entry is an existing directory written with a trailing
           path separator, pull data from it;

        2) otherwise, take the entry's last path component as a name
           fragment and pull data from every entry of its parent directory
           whose name contains that fragment (sorted by path).

    Args:
        all_logdirs: Iterable of path strings.
        exp_data: Mapping handed to get_datasets for each resolved directory.

    Returns:
        The mapping returned by the last get_datasets call, or exp_data itself
        when no directory was resolved.
    """
    # os.altsep is "/" on Windows and None on POSIX; accept either separator.
    separators = tuple(sep for sep in (os.sep, os.altsep) if sep)

    logdirs = []
    for logdir in all_logdirs:
        if osp.isdir(logdir) and logdir.endswith(separators):
            logdirs.append(logdir)
            continue
        basedir, fragment = osp.split(logdir)
        # An entry without a directory part (e.g. "runs") has an empty
        # dirname, and os.listdir("") raises FileNotFoundError, so list the
        # current directory in that case.
        candidates = os.listdir(basedir or os.curdir)
        logdirs += sorted(osp.join(basedir, name)
                          for name in candidates if fragment in name)

    # Verify logdirs
    print('Plotting from...\n' + '=' * DIV_LINE_WIDTH + '\n')
    for logdir in logdirs:
        print(logdir)
    print('\n' + '=' * DIV_LINE_WIDTH)

    # Load data from logdirs
    for log in logdirs:
        exp_data = get_datasets(log, exp_data)
    return exp_data


In [None]:
# Directories (or name fragments) handed to get_all_datasets.
logdir = ["data/"]

print(DIV_LINE_WIDTH * "=")
print("processing csv file")
exp_data = OrderedDict()
# Fills exp_data in place; the return value is not needed here.
get_all_datasets(logdir, exp_data)

env_names = exp_data.keys()
for env in env_names:
    # One CSV per environment, written into the first log directory.
    save_dir = osp.join(logdir[0], "{}_{}.csv".format(env, TABLE_NAME))

    # Rows follow ROW_ORDER; a label missing from ROW_ORDER raises ValueError.
    ordered_rows = sorted(exp_data[env].items(),
                          key=lambda item: ROW_ORDER.index(item[0]))
    exp_data[env] = OrderedDict(ordered_rows)

    # The columns of the table are taken from the first row.
    row_names = list(exp_data[env])
    column_names = list(exp_data[env][row_names[0]])
    exp_data_pd = pd.DataFrame.from_dict(exp_data[env], orient='index', columns=column_names)
    exp_data_pd.to_csv(save_dir)