In [None]:
import wandb 
import pandas as pd 
import matplotlib.pyplot as plt
import numpy as np

# Fetch the logged metrics of a single W&B run.
# RUN_PATH has the form "<entity>/<project>/<run_id>"; edit it to analyse another run
# (the other run referenced here before: "tlips/lerobot/rynf40po").
RUN_PATH = "tlips/lerobot/7dhx23nf"
# Number of history rows to request. W&B's `history()` returns a *sample* of the run's
# rows, so sparse metrics (validation / eval) may be under-represented on long runs.
HISTORY_SAMPLES = 1000

api = wandb.Api()
run = api.run(RUN_PATH)

data = run.history(samples=HISTORY_SAMPLES)
data.columns

In [None]:
# Per-column number of non-NaN values, i.e. how many history rows actually carry each
# metric (rows where a metric was not logged presumably show up as NaN).
data.notna().sum()

In [None]:
def paired_series(frame: pd.DataFrame, step_col: str, value_col: str):
    """Return ``(steps, values)`` restricted to rows where BOTH columns are non-NaN.

    Dropping NaNs from each column independently can leave the two series with
    different lengths (e.g. a row that has a step but not the metric), which breaks
    ``ax.plot(steps, values)`` and makes the x/y pairing ambiguous. Filtering jointly
    keeps them row-aligned; the original history index is preserved.

    Args:
        frame: run history, one row per logged history entry.
        step_col: name of the step column for this metric group.
        value_col: name of the metric column.

    Returns:
        Tuple of two equally long Series (named ``step_col`` and ``value_col``).
    """
    pair = frame[[step_col, value_col]].dropna()
    return pair[step_col], pair[value_col]


train_steps, train_loss = paired_series(data, "train/step", "train/loss")
val_steps, val_loss = paired_series(data, "validation/step", "validation/val_loss")
eval_steps, eval_success_rate = paired_series(data, "eval/step", "eval/pc_success")

# Number of (step, value) points available for each curve.
print(f"train: {len(train_loss)}, validation: {len(val_loss)}, eval: {len(eval_success_rate)}")


In [None]:
# Train / validation loss (left y axis) and evaluation success rate (right y axis)
# against training steps.
fig, ax1 = plt.subplots()

ax1.set_xlabel('steps')
ax1.set_ylabel('loss')
# NOTE(review): fixed loss range; any loss above 1 is clipped out of view.
ax1.set_ylim(0, 1)
# Show step ticks in thousands. `:,g` keeps fractional thousands (e.g. "2.5k") so
# closely spaced ticks don't collapse onto the same truncated label ("2k", "2k").
ax1.xaxis.set_major_formatter(plt.FuncFormatter(lambda x, _: f'{x / 1000:,g}k'))

ax1.plot(train_steps, train_loss, color="blue", label="train loss")
ax1.plot(val_steps, val_loss, color="red", label="validation loss")
ax1.legend(loc="upper left")

# Second y axis sharing the same x axis for the success rate (0-100 scale).
ax2 = ax1.twinx()
ax2.set_ylabel('success rate')
ax2.set_ylim(0, 100)
ax2.plot(eval_steps, eval_success_rate, color="green", label="success rate")
ax2.legend(loc="upper right")

# Title set on an explicit axes instead of pyplot's implicit "current axes"
# (which is ax2 after twinx); tags are joined rather than shown as a list repr.
ax1.set_title(f"val loss - {', '.join(run.tags)}")
plt.show()


In [None]:
# Correlation between validation loss and evaluation success rate at matching checkpoints.
# Validation and evaluation may be logged on different history rows, so pair them on
# their step value instead of relying on pandas' index alignment (which would only
# match values logged in the very same row and otherwise yield NaN).
val_at_step = pd.DataFrame({"step": val_steps, "val_loss": val_loss}).dropna()
eval_at_step = pd.DataFrame({"step": eval_steps, "success": eval_success_rate}).dropna()
val_eval_paired = val_at_step.merge(eval_at_step, on="step", how="inner")

# Pearson correlation; a strongly negative value means lower val loss <-> higher success.
correlation = val_eval_paired["val_loss"].corr(val_eval_paired["success"])
print(correlation)

# Trend of the success rate over training: correlate it with the steps at which
# it was evaluated (previously, validation steps were used by mistake).
time_correlation = eval_steps.corr(eval_success_rate)
print(time_correlation)


In [None]:
# Table of validation loss vs. success rate per checkpoint.
# Validation and eval values are paired on their logged step (not on the history row
# index), then each checkpoint is ranked by success rate (1 = best) and the table is
# sorted by validation loss. If validation loss were a good proxy for policy quality,
# success_rank would increase steadily down the table.
val_points = pd.DataFrame({"step": val_steps, "validation/val_loss": val_loss}).dropna()
eval_points = pd.DataFrame({"step": eval_steps, "eval/pc_success": eval_success_rate}).dropna()

table = (
    val_points.merge(eval_points, on="step", how="inner")
    .assign(success_rank=lambda df: df["eval/pc_success"].rank(ascending=False))
    .sort_values(by="validation/val_loss")
)
table