In [None]:
import numpy as np
import gym
import matplotlib.pyplot as plt
import pickle
import glob

import multiworld
# Presumably makes multiworld's pygame environments available for creation
# by id -- inferred from the function name only; confirm against multiworld.
multiworld.register_pygame_envs()

# IPython autoreload extension: with mode 2, imported modules are reloaded
# before code is executed, so edits to local .py files are picked up live.
%load_ext autoreload
%autoreload 2

In [None]:
import os
import csv
import seaborn as sns
import glob
import itertools

from matplotlib import rc
# Global matplotlib font setup: sans-serif family, with Helvetica as the
# sans-serif face.
rc('font',**{'family':'sans-serif','sans-serif':['Helvetica']})
# rc('text', usetex=True)

#sns.set_style("whitegrid")
# Seaborn "darkgrid" style with the axes background overridden to grey level 0.9.
sns.set_style("darkgrid", {"axes.facecolor": "0.9"})
# NOTE(review): despine is called here before any figure has been created in
# this cell, so it probably does not affect figures drawn later -- confirm
# whether this call is still needed.
sns.despine(left=True)

def smooth(scalars, weight):
    """Exponentially smooth a sequence of values.

    Each output value is ``previous_output * weight + (1 - weight) * point``,
    where the "previous output" for the first point is the first point itself
    (so the first output always equals the first input).

    Args:
        scalars: indexable, sized sequence of numbers (list, 1-D array, ...).
        weight: smoothing factor, expected to be between 0 and 1. ``0`` returns
            the input values unchanged; ``1`` repeats the first value.

    Returns:
        A list with one smoothed value per input value. An empty input yields
        an empty list.
    """
    # Guard the empty case explicitly: indexing scalars[0] below would
    # otherwise raise IndexError.
    if len(scalars) == 0:
        return []

    last = scalars[0]  # Seed the running average with the first value
    smoothed = []
    for point in scalars:
        last = last * weight + (1 - weight) * point  # Update the running average
        smoothed.append(last)
    return smoothed

def plot(xs, ys, legend_names, file_name, xlim=None, ylim=(-0.01, 1.01),
         xlabel="Iterations", ylabel="Success", title="", smoothing_weight=0.9):
    """Draw one smoothed mean curve with a +/- one-std band per series.

    Args:
        xs: sequence of x-axis value arrays, one per curve.
        ys: sequence of arrays, one per curve; mean and standard deviation
            are taken over axis 1 of each array.
        legend_names: labels handed to ``plt.legend`` as a list; a falsy
            value skips the legend entirely.
        file_name: currently unused -- the figure is only shown, not saved.
        xlim: forwarded to ``plt.xlim``.
        ylim: forwarded to ``plt.ylim``.
        xlabel: x-axis label.
        ylabel: y-axis label.
        title: figure title.
        smoothing_weight: weight passed to ``smooth`` for the mean curve and
            for both edges of the band.
    """
    label_size = 20

    # Start from the current seaborn palette, drop two entries and insert a
    # custom orange just before the last colour.
    swatches = sns.color_palette()
    swatches.pop(-2)
    swatches.pop(1)
    swatches.insert(-1, (242/255, 121/255, 53/255))
    sns.palplot(swatches)  # Shows the resulting palette as its own figure
    color_cycle = itertools.cycle(swatches)

    plt.figure(figsize=(8, 6), dpi=160)
    for idx, (x_vals, y_vals) in enumerate(zip(xs, ys)):
        curve_color = next(color_cycle)
        # The first two curves are drawn slightly thicker than the rest.
        line_width = 2 if idx <= 1 else 1.5

        center = np.mean(y_vals, axis=1)
        spread = np.std(y_vals, axis=1)
        lower_edge = smooth(center - spread, smoothing_weight)
        upper_edge = smooth(center + spread, smoothing_weight)

        plt.plot(x_vals, smooth(center, smoothing_weight),
                 linewidth=line_width, color=curve_color)
        plt.fill_between(x_vals, lower_edge, upper_edge,
                         color=curve_color, alpha=0.1)

    plt.title(title, fontsize=label_size)
    plt.xlabel(xlabel, fontsize=label_size)
    plt.ylabel(ylabel, fontsize=label_size)
    plt.xlim(xlim)
    plt.ylim(ylim)

    if legend_names:
        # Legend sits below the axes, laid out in three columns.
        legend = plt.legend(
            legend_names,
            frameon=True,
            facecolor='w',
            framealpha=0.9,
            shadow=True,
            fontsize=label_size - 5,
            loc='upper center',
            bbox_to_anchor=(0.5, -0.15),
            ncol=3,
        )
        frame = legend.get_frame()
        frame.set_linewidth(1.5)
        frame.set_edgecolor("black")

    plt.show()

In [None]:
import collections
def get_keys_data(csv_filename, data_keys):
    """Read selected columns of a CSV file into numpy arrays.

    The first row of the file is treated as the header. Values in the
    ``"epoch"`` column are parsed with ``int``; values in every other
    requested column are parsed with ``float``.

    Args:
        csv_filename: path of the CSV file to read.
        data_keys: header names of the columns to extract.

    Returns:
        A mapping from each requested key to a 1-D ``np.ndarray`` holding that
        column's values in file order. The mapping is empty when the file does
        not exist or has no data rows.

    Raises:
        ValueError: if a requested key is not in the header, or a cell cannot
            be parsed as a number.
        IndexError: if a non-empty data row has fewer cells than the header
            position of a requested key.
    """
    data = collections.defaultdict(list)
    if not os.path.exists(csv_filename):
        return data

    # Read the file named by the argument. The previous code opened a global
    # ``prog_csv`` here, which only worked when a caller happened to define a
    # module-level variable of that name with the same path.
    # newline="" is what the csv module expects for files passed to csv.reader.
    with open(csv_filename, "r", newline="") as csvfile:
        reader = csv.reader(csvfile)
        header = next(reader, None)
        if header is not None:
            # Resolve each key's column position and parser once, up front.
            columns = [
                (key, header.index(key), int if key == "epoch" else float)
                for key in data_keys
            ]
            for row in reader:
                if not row:
                    # csv.reader yields [] for blank lines; skip them so a
                    # stray empty line does not abort the whole load.
                    continue
                for key, idx, parse in columns:
                    data[key].append(parse(row[idx]))

    return {key: np.array(val) for key, val in data.items()}

In [None]:
# exp_dirs = [
#     "/home/justinvyu/doodad-logs/20-12-16-gc-explore-SAC-env=Point2DRooms-v0-goal-generator=EmpowermentGoalGenerator-schedule=linear-50000-0.1-1",
#     "/home/justinvyu/doodad-logs/20-12-16-gc-explore-SAC-env=Point2DRooms-v0-goal-generator=OracleSkewFitGoalGenerator-schedule=linear-50000-0.1-1",
#     "/home/justinvyu/doodad-logs/20-12-16-gc-explore-SAC-env=Point2DRooms-v0-goal-generator=RandomGoalGenerator-schedule=linear-50000-0.1-1",
#     "/home/justinvyu/doodad-logs/20-12-16-sac-env=Point2DRooms-v0-count-bonus=True",
#     "/home/justinvyu/doodad-logs/20-12-14-sac-env=Point2DRooms-v0-count-bonus=False",
# ]

# exp_dirs = [
#     "/home/justinvyu/doodad-logs/20-12-16-gc-explore-SAC-env=Point2DRoomsLarge-v0-goal-generator=EmpowermentGoalGenerator-schedule=linear-50000-0.1-1",
#     "/home/justinvyu/doodad-logs/20-12-16-gc-explore-SAC-env=Point2DRoomsLarge-v0-goal-generator=OracleSkewFitGoalGenerator-schedule=linear-50000-0.1-1",
#     "/home/justinvyu/doodad-logs/20-12-16-gc-explore-SAC-env=Point2DRoomsLarge-v0-goal-generator=RandomGoalGenerator-schedule=linear-50000-0.1-1",
#     "/home/justinvyu/doodad-logs/20-12-16-sac-env=Point2DRoomsLarge-v0-count-bonus=True",
#     "/home/justinvyu/doodad-logs/20-12-16-sac-env=Point2DRoomsLarge-v0-count-bonus=False", 
# ]

# Experiment directories to compare; each one contributes a single curve.
# Every entry directly inside an experiment directory is treated as one seed.
exp_dirs = [
    "/home/justinvyu/doodad-logs/teleport_sac/dir1",
    "/home/justinvyu/doodad-logs/teleport_sac/dir2",
    "/home/justinvyu/doodad-logs/teleport_sac/dir3"
]

# One entry per experiment: x_stream holds the epoch axis, y_stream holds an
# array with one row per epoch and one column per seed.
x_stream = []
y_stream = []

for exp_dir in exp_dirs:
    seed_dirs = list(glob.iglob(os.path.join(exp_dir, "*")))

    # Columns pulled from each seed's progress.csv.
    keys = [
        "epoch",
        "eval/Final distance_to_target Mean",
    ]

    xs = []
    ys = []
    for seed_dir in seed_dirs:
        prog_csv = os.path.join(seed_dir, "progress.csv")
        keys_data = get_keys_data(prog_csv, keys)
        # An empty result (e.g. no progress.csv in this entry) is skipped.
        if keys_data:
            xs.append(keys_data["epoch"])
            ys.append(keys_data["eval/Final distance_to_target Mean"])
    
    # Seeds may have logged different numbers of epochs; truncate every seed
    # to the shortest run so the columns line up.
    # NOTE(review): min() raises ValueError if no seed produced any data.
    min_len = min([len(x) for x in xs])
    epochs = min(xs, key=lambda l: len(l))
    # Stack the truncated per-seed series as columns.
    dist = np.concatenate([
        y[:min_len].reshape(-1, 1) for y in ys
    ], axis=1)
    x_stream.append(epochs)
    y_stream.append(dist)

In [None]:
# Plot the per-experiment curves collected in x_stream / y_stream.
# NOTE(review): the legend labels are passed as a plain list, so they rely on
# matching the order of the experiments in x_stream / y_stream -- confirm they
# correspond to the directories that were loaded.
# The "test" file-name argument is accepted by plot() but not used by it.
plot(
    x_stream,
    y_stream,
    ["Empowerment", "SkewFit", "Random"],
    "test",
    ylim=(-0.2, 10),
    xlim=(-5, 500),
    xlabel="Iteration (1000 steps each)",
    ylabel="Final Distance to Goal",
    title="S-Maze"
)