Skip to content

Commit

Permalink
minor fixes to plotting and logging
Browse files Browse the repository at this point in the history
  • Loading branch information
priyald17 committed Nov 29, 2020
1 parent 76d19e4 commit 2bf4384
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 5 deletions.
7 changes: 4 additions & 3 deletions power_sched/nets.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from constants import *



def task_loss(Y_sched, Y_actual, params):
return (params["gamma_under"] * torch.clamp(Y_actual - Y_sched, min=0) +
params["gamma_over"] * torch.clamp(Y_sched - Y_actual, min=0) +
Expand Down Expand Up @@ -219,11 +220,11 @@ def eval_net(which, model, variables, params, save_folder):
hold_loss_task = task_loss(
Y_sched_hold.float(), variables['Y_hold_'], params)

torch.save(train_loss_task.data,
torch.save(train_loss_task.detach().cpu().numpy(),
os.path.join(save_folder, '{}_train_task'.format(which)))
torch.save(test_loss_task.data,
torch.save(test_loss_task.detach().cpu().numpy(),
os.path.join(save_folder, '{}_test_task'.format(which)))

if (which == "task_net"):
torch.save(hold_loss_task.data,
torch.save(hold_loss_task.detach().cpu().numpy(),
os.path.join(save_folder, '{}_hold_task'.format(which)))
4 changes: 2 additions & 2 deletions power_sched/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,14 +55,14 @@ def plot_results(load_folders, save_folder):
colors = [sns.color_palette()[i] for i in [1,2]] + ['gray']

ax = axes[0]
ax.set_axis_bgcolor('none')
# ax.set_axis_bgcolor('none')
for col, style, color in zip(rmse_mean.columns, styles, colors):
rmse_mean[col].plot(
ax=ax, lw=2, fmt=style, color=color, yerr=rmse_stds[col])
ax.set_ylabel('RMSE')

ax2 = axes[1]
ax2.set_axis_bgcolor('none')
# ax2.set_axis_bgcolor('none')
for col, style, color in zip(task_mean.columns, styles, colors):
if col == 'Cost-weighted RMSE':
task_mean[col].plot(
Expand Down

0 comments on commit 2bf4384

Please sign in to comment.