# In [None]:
import pandas as pd
import json
import matplotlib.pyplot as plt


# Hex colours assigned to plotted lines in list order. The "CB" prefix
# presumably marks a colour-blind-friendly palette — confirm.
CB_color_cycle = [
    '#377eb8',  # blue
    '#ff7f00',  # orange
    '#4daf4a',  # green
    '#f781bf',  # pink
    '#a65628',  # brown
    '#984ea3',  # purple
    '#999999',  # grey
    '#e41a1c',  # red
    '#dede00',  # yellow
]


def load_unique_data(file_path):
    """Load a JSON-lines log into a DataFrame with one row per episode.

    Every non-blank line of *file_path* is parsed as a JSON object and becomes
    one row. Rows are ordered by their ``"episode"`` value; when an episode
    occurs more than once, the row written last in the file is kept.

    Args:
        file_path: Path to a UTF-8 text file holding one JSON object per line.
            Each object is expected to carry an ``"episode"`` key.

    Returns:
        pandas.DataFrame sorted by ``"episode"`` with duplicate episodes
        removed. The index is not reset: it holds each surviving row's
        position among the parsed (non-blank) lines.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        KeyError: If no record provides an ``"episode"`` column
            (for example, when the file is empty).
    """
    records = []
    # JSON text is UTF-8; don't depend on the platform's locale encoding.
    with open(file_path, "r", encoding="utf-8") as file:
        for line in file:
            stripped = line.strip()
            # Skip blank lines instead of letting json.loads("") raise.
            if not stripped:
                continue
            records.append(json.loads(stripped))
    df = pd.DataFrame(records)
    # Use a stable sort so rows sharing an episode stay in file order. Only
    # then does keep="last" reliably pick the most recently written record.
    # The default quicksort makes no ordering guarantee for equal keys.
    df = df.sort_values(by="episode", kind="mergesort")
    df = df.drop_duplicates(subset="episode", keep="last")
    return df

# TODO(review): the original TODO carried no description; it may refer to the
# disabled CURL entry below — confirm and state what is still pending.
# Method name -> log file path for that SRL method.
files = {
    "CPC": "cpc.txt",
    "DBC": "dbc.txt",
    "DeepMDP": "deepmdp.txt",
    "RAD": "rad.txt",
    # "CURL": "curl.txt"
}

# Parse every log in dict order, keeping one row per episode.
srl_data = dict(zip(files, map(load_unique_data, files.values())))

# Draw one reward-vs-episode curve per method, pairing methods with palette
# colours in order. zip() stops at the shorter input, so a method beyond the
# end of CB_color_cycle would not be drawn.
fig, ax = plt.subplots(figsize=(10, 6))
for color, (key, df) in zip(CB_color_cycle, srl_data.items()):
    ax.plot(
        df["episode"],
        df["episode_reward"],
        color=color,
        label=key,
    )

# Axis labels, title, legend and grid.
ax.set_xlabel("Episode")
ax.set_ylabel("Reward")
ax.set_title("Cheetah Performance on Various SRL Methods")
ax.legend()
ax.grid(True)

# Render the figure.
plt.show()