In [None]:
import os

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from seaborn import set_theme
# Apply seaborn's default theme so it styles every matplotlib figure created below.
set_theme()

In [None]:
# Uncomment this cell to switch matplotlib to the PGF backend (LaTeX-rendered text) before exporting the figure.

# matplotlib.use("pgf")
# matplotlib.rcParams.update({
#     "pgf.texsystem": "pdflatex",
#     'font.family': 'serif',
#     'text.usetex': True,
#     'pgf.rcfonts': False,
# })

In [None]:
# Directories that hold the CSV files of each agent.
MODEL_FREE_PATH, MODEL_BASED_PATH = "model_free", "model_based"

# Metric names available for both agents.
MODEL_FREE_TYPES = ["reward", "steps", "actor", "critic"]

# The model-based agent has every shared metric plus four model-specific ones.
MODEL_BASED_TYPES = list(MODEL_FREE_TYPES)
MODEL_BASED_TYPES.extend(["model_p", "model_t", "model_r", "model_total"])

In [None]:
def _read_runs(directory, prefix, type_):
    """Read runs 1-3 of one metric and stack them into a single array.

    Each file is ``<directory>/<prefix>_<type_>_<run>.csv``; its first column
    is dropped before stacking, so the result is indexed as
    ``[run, row, remaining column]``.
    """
    frames = []
    for run in range(1, 4):
        csv_name = "{}_{}_{}.csv".format(prefix, type_, run)
        frames.append(pd.read_csv(os.path.join(directory, csv_name)).iloc[:, 1:])
    return np.array(frames)


# One stacked array per metric, keyed by metric name.
mf_dfs = {type_: _read_runs(MODEL_FREE_PATH, "mf", type_) for type_ in MODEL_FREE_TYPES}
mb_dfs = {type_: _read_runs(MODEL_BASED_PATH, "mb", type_) for type_ in MODEL_BASED_TYPES}

In [None]:
# Text for each subplot, keyed by metric name: (x-axis label, y-axis label, title).
# The x-axis label previously read "Millon Steps" (typo).
_X_LABEL = "Million Steps"
LABELS = {
    "reward": (_X_LABEL, "Return", "Average Return"),
    "steps": (_X_LABEL, "Number of Steps", "Average Episode Duration"),
    "actor": (_X_LABEL, "Loss", "Actor Training Loss"),
    "critic": (_X_LABEL, "Loss", "Critic Training Loss"),
    "model_p": (_X_LABEL, "Loss", "Model Dynamics Training Loss"),
    "model_t": (_X_LABEL, "Loss", "Model Termination Training Loss"),
    "model_r": (_X_LABEL, "Loss", "Model Reward Training Loss"),
    "model_total": (_X_LABEL, "Loss", "Model Total Training Loss Function"),
}
NB_SIGMA = 1          # half-width of the shaded band, in standard deviations
DOWN_SAMPLING = 25    # number of consecutive rows merged into one plotted point
NB_RUNS = 3           # number of runs per metric
FIG_SIZE = (10, 5)    # figure size passed to matplotlib
FONT_SIZE = 8         # font size for labels, titles and legends (ticks use FONT_SIZE - 2)
MAX_STEPS = 1000000   # divisor applied to the x values so the axis reads in millions


# Fixed colours per series so "Model-Based" looks the same in both rows.
# "C0"/"C1" are the first two colours of the active colour cycle, i.e. the
# colours the first row already received implicitly.
_MODEL_FREE_COLOR = "C0"
_MODEL_BASED_COLOR = "C1"


def _smooth_runs(runs, window, x_scale):
    """Down-sample stacked runs into window centres, means and spreads.

    Parameters
    ----------
    runs : array indexed as ``[run, row, column]``. Column 0 supplies the x
        values (taken from the first run) and column 1 the plotted quantity.
    window : number of consecutive rows merged into one plotted point.
    x_scale : divisor applied to the x values.

    Returns
    -------
    (x, y_mean, y_std) : three 1-D arrays with one entry per complete window.
        ``x`` is the x value at the centre row of each window; the mean and
        standard deviation pool every run and every row of the window.
    """
    n_runs, n_rows = runs.shape[0], runs.shape[1]
    n_windows = n_rows // window
    # Ignore a trailing partial window instead of letting reshape() raise.
    usable = n_windows * window
    values = runs[:, :usable, 1]

    y_mean = values.mean(axis=0).reshape(n_windows, window).mean(axis=1)
    # Transposing first puts all runs of a row next to each other, so each
    # reshaped row holds `window` consecutive rows from every run. The run
    # count comes from the data rather than a separate constant.
    y_std = values.transpose().reshape(n_windows, n_runs * window).std(axis=1)
    x = runs[0, (window - 1) // 2:usable:window, 0] / x_scale
    return x, y_mean, y_std


def _plot_band(ax, runs, label, color, alpha):
    """Draw the smoothed mean curve of ``runs`` with a +/- NB_SIGMA std band."""
    x, y_mean, y_std = _smooth_runs(runs, DOWN_SAMPLING, MAX_STEPS)
    ax.plot(x, y_mean, label=label, color=color)
    # Only the face colour is set so the band edge keeps the theme's default.
    ax.fill_between(
        x,
        y_mean - NB_SIGMA * y_std,
        y_mean + NB_SIGMA * y_std,
        alpha=alpha,
        facecolor=color,
    )


def _style_axes(ax, type_, legend_kwargs=None):
    """Apply labels, title and font sizes for ``type_``; add a legend if requested.

    ``legend_kwargs`` is ``None`` for no legend, otherwise a dict of extra
    keyword arguments forwarded to ``ax.legend``.
    """
    xlab, ylab, title = LABELS[type_]
    ax.set_xlabel(xlab, fontsize=FONT_SIZE)
    ax.set_ylabel(ylab, fontsize=FONT_SIZE, labelpad=0)
    ax.set_title(title, fontsize=FONT_SIZE)
    ax.grid(False)
    ax.tick_params(axis="both", labelsize=FONT_SIZE - 2)
    if legend_kwargs is not None:
        ax.legend(fontsize=FONT_SIZE, **legend_kwargs)


fig = plt.figure(figsize=FIG_SIZE)

# Top row: metrics available for both agents, drawn on shared axes.
for i, type_ in enumerate(MODEL_FREE_TYPES):
    ax = fig.add_subplot(2, 4, i + 1)
    _plot_band(ax, mf_dfs[type_], "Model-Free", _MODEL_FREE_COLOR, alpha=0.5)
    _plot_band(ax, mb_dfs[type_], "Model-Based", _MODEL_BASED_COLOR, alpha=0.5)
    _style_axes(ax, type_, legend_kwargs={} if i == 0 else None)

# Bottom row: metrics that exist only for the model-based agent.
model_only_types = [type_ for type_ in MODEL_BASED_TYPES if type_ not in MODEL_FREE_TYPES]
for i, type_ in enumerate(model_only_types):
    ax = fig.add_subplot(2, 4, i + 5)
    _plot_band(ax, mb_dfs[type_], "Model-Based", _MODEL_BASED_COLOR, alpha=0.3)
    _style_axes(ax, type_, legend_kwargs={"loc": "upper left"} if i == 0 else None)

fig.subplots_adjust(hspace=0.5, wspace=0.25)
# plt.tight_layout(pad=0, h_pad=1.5, w_pad=0.5, rect=None);
plt.show();
# with plt.style.context(['classic', {'text.usetex': True}]):
#     plt.savefig("plot_ddpg_training.pgf");