In [1]:
import pandas as pd
import numpy as np
from glob import glob
import matplotlib.pyplot as plt
import seaborn as sns
import os

# Route all matplotlib text through LaTeX (needs a working TeX installation).
# NOTE(review): a later cell resets rcParams to matplotlib's defaults, which
# switches this back off before `plot` could ever use it.
plt.rcParams.update({'text.usetex': True})

In [2]:
def read_tsv(env_name, method, type_, ext="csv"):
    """
    Collect the per-run loss curves for one (method, environment, split).

    Every directory matching ``losses/{method}-{env_name}-*/`` is one run.
    Inside each run directory, ``{type_}.{ext}`` is read. If that file does not
    exist, ``{type_}.tsv`` is tried instead. Files of either extension are
    parsed as tab-separated tables with ``im_loss`` and ``id_loss`` columns.

    e.g.
    env_name = "cartpole"
    method = "koopman", "rnn", "reflex"
    type_ = "train", "val1", "val2", "test1", "test2"

    Returns
    -------
    dict
        "epochs"  : ``np.arange(26) * 10`` (fixed, not read from the files)
        "im_loss" : np.ndarray with one row per file read
        "id_loss" : np.ndarray with one row per file read
        Run directories are visited in sorted order, so row order is
        reproducible. When nothing is found, both loss arrays have shape (0,).
    """
    result = {
        "epochs": np.arange(26) * 10,
        "im_loss": [],
        "id_loss": [],
    }

    for run_dir in sorted(glob(f"losses/{method}-{env_name}-*/")):
        # Pick the extension per directory WITHOUT reassigning `ext`.
        # Otherwise one directory that only has a .tsv file would change what
        # every later directory is searched for, and .csv-only runs would be
        # dropped silently.
        fnames = sorted(glob(os.path.join(run_dir, f"{type_}.{ext}")))
        if not fnames:
            fnames = sorted(glob(os.path.join(run_dir, f"{type_}.tsv")))

        for fname in fnames:
            # Both .csv and .tsv files are tab-separated in this project.
            df = pd.read_csv(fname, sep="\t")
            result["im_loss"].append(list(df["im_loss"].to_numpy()))
            result["id_loss"].append(list(df["id_loss"].to_numpy()))

    result["im_loss"] = np.array(result["im_loss"])
    result["id_loss"] = np.array(result["id_loss"])

    return result

In [4]:
import matplotlib as mpl
# Restore matplotlib's built-in defaults; this also turns `text.usetex` back off.
# `rcdefaults()` skips style-blacklisted keys such as the backend, so unlike
# `rcParams.update(rcParamsDefault)` it does not switch the notebook away from
# its inline backend.
mpl.rcdefaults()

def plot(env_name, ylim, title=None):
    """
    Static matplotlib figure of final-epoch imitation loss, one point per run.

    The left panel shows split "test2" (titled "In-Distribution Constraints").
    The right panel shows split "test1" (titled "OOD Constraints"). Each panel
    is a log-scale strip plot per method, with a black line at the method's
    mean. The figure is saved to
    ``./loss_figures/{env_name}_imitation_loss.pdf`` (the directory is created
    if needed) and then shown.

    Parameters
    ----------
    env_name : str
        Environment key passed through to ``read_tsv`` (e.g. "cartpole").
    ylim : sequence of two floats
        Shared y-axis limits for both panels. The axes are log-scaled, so both
        values must be positive.
    title : str, optional
        Figure suptitle; omitted when falsy.

    Raises
    ------
    FileNotFoundError
        If no loss files exist for some method/split combination.
    """
    methods = ["koopman", "rnn", "reflex", "cyin"]
    method_names = ["KCPO", "RNN", "ReflexNet", "RiccatiNet"]

    def load_final_losses(type_):
        # One column per method, holding the last-epoch im_loss of every run.
        columns = {}
        for method, name in zip(methods, method_names):
            print(env_name, type_, method)
            im_loss = read_tsv(env_name, method, type_)["im_loss"]
            if im_loss.size == 0:
                raise FileNotFoundError(
                    f"no {type_} loss files found for {method}-{env_name}")
            columns[name] = pd.Series(im_loss[:, -1])
        # Building from Series (not np.array(...).T) tolerates methods with
        # different numbers of runs; shorter columns are NaN-padded.
        return pd.DataFrame(columns)

    def draw_panel(ax, df, panel_title, palette):
        # Strip plot of every run, plus a boxplot with everything hidden except the mean line.
        ax.set_title(panel_title, size=17)
        ax.set_yscale("log")
        ax.set_ylim(ylim)
        sns.stripplot(data=df, size=6, ax=ax, palette=palette)
        ax.tick_params(axis='y', labelsize=13)
        ax.set(yscale='log')
        sns.boxplot(data=df,
                    ax=ax,
                    showmeans=True,
                    meanline=True,
                    meanprops={'color': 'k', 'ls': '-', 'lw': 2},
                    medianprops={'visible': False},
                    whiskerprops={'visible': False},
                    zorder=10,
                    showfliers=False,
                    showbox=False,
                    showcaps=False)
        ax.set_xticks([])
        ax.set_xticklabels([])

    in_df = load_final_losses("test2")
    out_df = load_final_losses("test1")

    fig, (left, right) = plt.subplots(1, 2, figsize=(8, 5), dpi=300)
    if title:
        fig.suptitle(title, fontsize=30)

    # The in-distribution panel deliberately uses a desaturated Set1 palette.
    draw_panel(left, in_df, "In-Distribution Constraints",
               sns.color_palette("Set1", desat=0.4))
    draw_panel(right, out_df, "OOD Constraints", sns.color_palette("Set1"))

    fig.tight_layout()

    # savefig does not create missing directories.
    os.makedirs("./loss_figures", exist_ok=True)
    fig.savefig(f"./loss_figures/{env_name}_imitation_loss.pdf",
                transparent=True, bbox_inches='tight', pad_inches=0)

    plt.show()

In [8]:
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import pandas as pd
import numpy as np
import plotly.express as px
import plotly.offline as pyo

def plotly_plot(env_name, ylim, title=None):
    """
    Interactive plotly figure of final-epoch imitation loss, one point per run.

    The left subplot shows split "test2" (titled "In-Distribution
    Constraints"). The right subplot shows split "test1" (titled "OOD
    Constraints"). Each method gets a box trace with all points drawn, coloured
    from plotly's qualitative Set1 palette. The legend is only attached to the
    right-hand traces. After showing the figure, its HTML ``<div>`` (without
    plotly.js embedded) is printed.

    Parameters
    ----------
    env_name : str
        Environment key passed through to ``read_tsv`` (e.g. "cartpole").
    ylim : sequence of two floats
        Shared y-axis limits. They are converted to log10 because plotly log
        axes take their range in exponent units; both values must be positive.
    title : str, optional
        Figure title, centred above the subplots.

    Raises
    ------
    FileNotFoundError
        If no loss files exist for some method/split combination.
    """
    methods = ["koopman", "rnn", "reflex", "cyin"]
    method_names = ["KCPO", "RNN", "ReflexNet", "RiccatiNet"]
    colors = px.colors.qualitative.Set1

    def load_final_losses(type_):
        # One column per method, holding the last-epoch im_loss of every run.
        columns = {}
        for method, name in zip(methods, method_names):
            print(env_name, type_, method)
            im_loss = read_tsv(env_name, method, type_)["im_loss"]
            if im_loss.size == 0:
                raise FileNotFoundError(
                    f"no {type_} loss files found for {method}-{env_name}")
            columns[name] = pd.Series(im_loss[:, -1])
        # Building from Series (not np.array(...).T) tolerates methods with
        # different numbers of runs; shorter columns are NaN-padded.
        return pd.DataFrame(columns)

    in_df = load_final_losses("test2")
    out_df = load_final_losses("test1")

    fig = make_subplots(rows=1, cols=2, subplot_titles=('In-Distribution Constraints', 'OOD Constraints'))

    for i, method in enumerate(method_names):
        # Add the in-distribution trace, then the OOD trace, for each method.
        for col, df in ((1, in_df), (2, out_df)):
            fig.add_trace(go.Box(y=df[method].dropna(),
                                 name=method,
                                 boxpoints='all',
                                 jitter=0.3,
                                 pointpos=-1.8,
                                 marker_color=colors[i],
                                 line_color=colors[i],
                                 # Show each method in the legend only once.
                                 showlegend=(col == 2)),
                          row=1, col=col)

    log_range = [np.log10(ylim[0]), np.log10(ylim[1])]
    for col in (1, 2):
        fig.update_yaxes(type="log", range=log_range, row=1, col=col)
        fig.update_xaxes(showticklabels=False, row=1, col=col)

    fig.update_layout(height=500, width=700, title={'text': title, 'x': 0.5, 'y': 0.9, 'font': {'size': 24}})

    fig.show()

    div_str = pyo.plot(fig, output_type="div", include_plotlyjs=False)
    print(div_str)

In [9]:
# Differential drive. For the static matplotlib/PDF version, call
# plot(env_name, ylim, title="Differential Drive") instead.
# Wider alternative range: ylim = [10**-4, 10**3]
env_name, ylim = "diffdrive", [10**-2, 10**3]
plotly_plot(env_name, ylim=ylim, title="Differential Drive")

diffdrive test2 koopman
diffdrive test2 rnn
diffdrive test2 reflex
diffdrive test2 cyin
diffdrive test1 koopman
diffdrive test1 rnn
diffdrive test1 reflex
diffdrive test1 cyin


<div>                            <div id="23639592-bb5a-4d2b-8697-0c34d5c0f802" class="plotly-graph-div" style="height:500px; width:700px;"></div>            <script type="text/javascript">                                    window.PLOTLYENV=window.PLOTLYENV || {};                                    if (document.getElementById("23639592-bb5a-4d2b-8697-0c34d5c0f802")) {                    Plotly.newPlot(                        "23639592-bb5a-4d2b-8697-0c34d5c0f802",                        [{"boxpoints":"all","jitter":0.3,"line":{"color":"rgb(228,26,28)"},"marker":{"color":"rgb(228,26,28)"},"name":"KCPO","pointpos":-1.8,"showlegend":false,"xaxis":"x","y":[5.1824985,27.51522,4.574925,2.0105896,11.269353,35.38388,74.78879,2.9682274,7.461791,2.9172785],"yaxis":"y","type":"box"},{"boxpoints":"all","jitter":0.3,"line":{"color":"rgb(228,26,28)"},"marker":{"color":"rgb(228,26,28)"},"name":"KCPO","pointpos":-1.8,"showlegend":true,"xaxis":"x2","y":[5.1824985,27.51522,4.574925,2.0105896,11.269353,

In [11]:
# Simple pendulum. For the static matplotlib/PDF version, call
# plot(env_name, ylim, title="Simple Pendulum") instead.
# Wider alternative range: ylim = [10**-4, 10**3]
env_name, ylim = "pendulum", [10**-4, 10**-1]
plotly_plot(env_name, ylim=ylim, title="Simple Pendulum")

pendulum test2 koopman
pendulum test2 rnn
pendulum test2 reflex
pendulum test2 cyin
pendulum test1 koopman
pendulum test1 rnn
pendulum test1 reflex
pendulum test1 cyin


<div>                            <div id="d2db3f5a-2b7a-49f3-9f4e-58a485665a12" class="plotly-graph-div" style="height:500px; width:700px;"></div>            <script type="text/javascript">                                    window.PLOTLYENV=window.PLOTLYENV || {};                                    if (document.getElementById("d2db3f5a-2b7a-49f3-9f4e-58a485665a12")) {                    Plotly.newPlot(                        "d2db3f5a-2b7a-49f3-9f4e-58a485665a12",                        [{"boxpoints":"all","jitter":0.3,"line":{"color":"rgb(228,26,28)"},"marker":{"color":"rgb(228,26,28)"},"name":"KCPO","pointpos":-1.8,"showlegend":false,"xaxis":"x","y":[0.00091132795,0.0006001711,0.0008467387,0.0004767382,0.00075143966,0.00028767914,0.0020105678,0.0005978285,0.0009446145,0.0004741274],"yaxis":"y","type":"box"},{"boxpoints":"all","jitter":0.3,"line":{"color":"rgb(228,26,28)"},"marker":{"color":"rgb(228,26,28)"},"name":"KCPO","pointpos":-1.8,"showlegend":true,"xaxis":"x2","y":[0.01191186

In [12]:
# Cartpole swing-up. For the static matplotlib/PDF version, call
# plot(env_name, ylim, title="Cartpole Swing-Up") instead.
# Wider alternative range: ylim = [10**-4, 10**3]
env_name, ylim = "cartpole", [10**-2, 10**1]
plotly_plot(env_name, ylim=ylim, title="Cartpole Swing-Up")

cartpole test2 koopman
cartpole test2 rnn
cartpole test2 reflex
cartpole test2 cyin
cartpole test1 koopman
cartpole test1 rnn
cartpole test1 reflex
cartpole test1 cyin


<div>                            <div id="02f3c4bf-f767-4fbf-b767-b371f06e91e2" class="plotly-graph-div" style="height:500px; width:700px;"></div>            <script type="text/javascript">                                    window.PLOTLYENV=window.PLOTLYENV || {};                                    if (document.getElementById("02f3c4bf-f767-4fbf-b767-b371f06e91e2")) {                    Plotly.newPlot(                        "02f3c4bf-f767-4fbf-b767-b371f06e91e2",                        [{"boxpoints":"all","jitter":0.3,"line":{"color":"rgb(228,26,28)"},"marker":{"color":"rgb(228,26,28)"},"name":"KCPO","pointpos":-1.8,"showlegend":false,"xaxis":"x","y":[0.042085662,0.053336166,0.03941295,0.026288336,0.05894171,0.039778654,0.04677989,0.07750407,0.03207238,0.038071956],"yaxis":"y","type":"box"},{"boxpoints":"all","jitter":0.3,"line":{"color":"rgb(228,26,28)"},"marker":{"color":"rgb(228,26,28)"},"name":"KCPO","pointpos":-1.8,"showlegend":true,"xaxis":"x2","y":[1.3447948,1.1782117,1.481545

In [14]:
# Reacher. For the static matplotlib/PDF version, call
# plot(env_name, ylim, title="Reacher") instead.
# Wider alternative range: ylim = [10**-4, 10**3]
env_name, ylim = "reacher", [10**-3, 10**-1]
plotly_plot(env_name, ylim=ylim, title="Reacher")

reacher test2 koopman
reacher test2 rnn
reacher test2 reflex
reacher test2 cyin
reacher test1 koopman
reacher test1 rnn
reacher test1 reflex
reacher test1 cyin


<div>                            <div id="8d1c6d01-0322-4e79-b210-18c1393bf7d3" class="plotly-graph-div" style="height:500px; width:700px;"></div>            <script type="text/javascript">                                    window.PLOTLYENV=window.PLOTLYENV || {};                                    if (document.getElementById("8d1c6d01-0322-4e79-b210-18c1393bf7d3")) {                    Plotly.newPlot(                        "8d1c6d01-0322-4e79-b210-18c1393bf7d3",                        [{"boxpoints":"all","jitter":0.3,"line":{"color":"rgb(228,26,28)"},"marker":{"color":"rgb(228,26,28)"},"name":"KCPO","pointpos":-1.8,"showlegend":false,"xaxis":"x","y":[0.012164485,0.00879578,0.010828233,0.010574724,0.012094905,0.005030518,0.010765644,0.010893947,0.022893587,0.010805615],"yaxis":"y","type":"box"},{"boxpoints":"all","jitter":0.3,"line":{"color":"rgb(228,26,28)"},"marker":{"color":"rgb(228,26,28)"},"name":"KCPO","pointpos":-1.8,"showlegend":true,"xaxis":"x2","y":[0.0056188335,0.005644042