In [1]:
import glob
import pandas as pd 
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np 
import argparse
from collections import namedtuple

# Global plotting setup shared by every figure in this notebook.
sns.set()  # apply seaborn's default theme to all subsequent matplotlib plots
# Read after sns.set() so the palette presumably reflects the active seaborn theme;
# plot_save indexes into this list to give each algorithm its own line color.
colors = sns.color_palette() 
plt.ioff()  # interactive mode off: figures are produced only via explicit savefig()/show() calls

In [2]:
# Figure-level settings: output path (None -> show instead of save), y-axis label,
# and the x value assigned to the last logged point.
Args = namedtuple('Args', 'fig_name ylabel x_max')
# Per-algorithm settings: glob matching that algorithm's per-run CSV files, and the
# label its plotted line is given.
AlgArgs = namedtuple('AlgArgs', 'file_glob label')

### Run the cell below to define `plot_save`; each later cell sets the figure (`Args`) and per-algorithm (`AlgArgs`) settings for one metric, then plots and saves it

In [3]:
def plot_save(args, alg_args, show=False):
    """Draw one seaborn line per algorithm on the current figure, then save or show it.

    For each ``AlgArgs`` in ``alg_args``, every CSV matching ``file_glob`` is read
    and its ``Value`` column becomes one column of a wide frame (one column per
    run; shorter runs are NaN-padded by ``pd.concat``). All run columns are renamed
    to the algorithm's ``label`` (presumably so seaborn treats them as repeated
    samples of a single line and draws their mean with an error band). The frame
    is indexed by an evenly spaced "steps" axis from 0 to ``args.x_max``.

    Args:
        args: ``Args`` tuple; ``fig_name`` is the output path (None -> show
            instead), ``ylabel`` is the y-axis label, and ``x_max`` is the x value
            of the last row.
        alg_args: sequence of ``AlgArgs`` (``file_glob``, ``label``). The i-th
            entry is drawn with the i-th color of the module-level ``colors``
            palette, wrapping around if there are more entries than colors.
        show: if True, call ``plt.show()`` and skip saving, even when
            ``fig_name`` is set.

    Raises:
        ValueError: if a ``file_glob`` matches no files.

    Note:
        Draws into pyplot's current figure; callers are expected to clear it
        (``plt.clf()``) between figures.
    """
    # The x range is identical for every algorithm, so compute it once.
    x_min, x_max = 0, args.x_max

    for i, alg_arg in enumerate(alg_args):
        # Sort for a deterministic column order regardless of filesystem listing.
        csvs = sorted(glob.glob(alg_arg.file_glob))
        if not csvs:
            raise ValueError(
                f"no files match {alg_arg.file_glob!r} for {alg_arg.label!r}")

        # Assumes every CSV has a 'Value' column (TensorBoard-style export) -- TODO confirm.
        concat_df = pd.concat(
            [pd.read_csv(csv_name).Value for csv_name in csvs], axis=1)

        # One evenly spaced x position per row of the aligned runs.
        # NOTE(review): this ignores any real step column in the CSVs and assumes
        # points were logged at uniform intervals.
        concat_df['steps'] = np.linspace(x_min, x_max, len(concat_df))
        concat_df = (concat_df
                     .set_index('steps')
                     .rename(columns={'Value': alg_arg.label}))

        # Modulo so more algorithms than palette entries does not raise IndexError.
        sns.lineplot(data=concat_df, palette=[colors[i % len(colors)]])

    plt.xlabel('Steps')
    plt.ylabel(args.ylabel)

    if args.fig_name is not None and not show:
        plt.savefig(args.fig_name, bbox_inches='tight')
    else:
        plt.show()

In [13]:
# Battle-win rate of RODE and the three MARDOC variants, saved to battle_won_mean.pdf.
args = Args(fig_name="battle_won_mean.pdf", ylabel='Battle won %', x_max=int(1e6))
alg_args = [
    AlgArgs(file_glob=pattern, label=name)
    for pattern, name in (
        ("marl/battle_won_mean/*", "RODE"),
        ("mardoc1/battle_won_mean/*", "MARDOC1"),
        ("mardoc2/battle_won_mean/*", "MARDOC2"),
        ("mardoc3/battle_won_mean/*", "MARDOC3"),
    )
]

plot_save(args, alg_args)
plt.clf()  # reset the shared figure so the next cell starts from a blank canvas

In [9]:
# Number of dead allies for RODE and the three MARDOC variants, saved to dead_allies_mean.pdf.
args = Args(fig_name="dead_allies_mean.pdf", ylabel='# Dead Allies', x_max=int(1e6))
alg_args = [
    AlgArgs(file_glob=pattern, label=name)
    for pattern, name in (
        ("marl/dead_allies_mean/*", "RODE"),
        ("mardoc1/dead_allies_mean/*", "MARDOC1"),
        ("mardoc2/dead_allies_mean/*", "MARDOC2"),
        ("mardoc3/dead_allies_mean/*", "MARDOC3"),
    )
]

plot_save(args, alg_args)
plt.clf()  # reset the shared figure so the next cell starts from a blank canvas

In [10]:
# Number of dead enemies for RODE and the three MARDOC variants, saved to dead_enemies_mean.pdf.
args = Args(fig_name="dead_enemies_mean.pdf", ylabel='# Dead Enemies', x_max=int(1e6))
alg_args = [
    AlgArgs(file_glob=pattern, label=name)
    for pattern, name in (
        ("marl/dead_enemies_mean/*", "RODE"),
        ("mardoc1/dead_enemies_mean/*", "MARDOC1"),
        ("mardoc2/dead_enemies_mean/*", "MARDOC2"),
        ("mardoc3/dead_enemies_mean/*", "MARDOC3"),
    )
]

plot_save(args, alg_args)
plt.clf()  # reset the shared figure so the next cell starts from a blank canvas

In [11]:
# Episode length for RODE and the three MARDOC variants, saved to ep_length_mean.pdf.
args = Args(fig_name="ep_length_mean.pdf", ylabel='Episode Length', x_max=int(1e6))
alg_args = [
    AlgArgs(file_glob=pattern, label=name)
    for pattern, name in (
        ("marl/ep_length_mean/*", "RODE"),
        ("mardoc1/ep_length_mean/*", "MARDOC1"),
        ("mardoc2/ep_length_mean/*", "MARDOC2"),
        ("mardoc3/ep_length_mean/*", "MARDOC3"),
    )
]

plot_save(args, alg_args)
plt.clf()  # reset the shared figure so the next cell starts from a blank canvas