Skip to content

Commit

Permalink
fix(nbeats): fix loss for backcast part
Browse files Browse the repository at this point in the history
  • Loading branch information
fantes authored and mergify[bot] committed Dec 3, 2020
1 parent 8eef648 commit 986728c
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions src/backends/torch/native/templates/nbeats.h
Expand Up @@ -366,17 +366,19 @@ namespace dd
_backcast_length + _forecast_length);
}

virtual torch::Tensor loss(std::string loss, torch::Tensor input,
virtual torch::Tensor loss(std::string loss, torch::Tensor input_real,
torch::Tensor output, torch::Tensor target)
{
torch::Tensor x_pred = torch::slice(output, 1, 0, _backcast_length);
torch::Tensor y_pred = torch::slice(output, 1, _backcast_length,
_backcast_length + _forecast_length);
torch::Tensor input_zeros = torch::zeros_like(input_real);
if (loss.empty() || loss == "L1" || loss == "l1")
return torch::l1_loss(y_pred, target) + torch::l1_loss(x_pred, input);
return torch::l1_loss(y_pred, target)
+ torch::l1_loss(x_pred, input_zeros);
if (loss == "L2" || loss == "l2" || loss == "eucl")
return torch::mse_loss(y_pred, target)
+ torch::mse_loss(x_pred, input);
+ torch::mse_loss(x_pred, input_zeros);
throw MLLibBadParamException("unknown loss " + loss);
}

Expand Down

0 comments on commit 986728c

Please sign in to comment.