Skip to content

Commit

Permalink
feat(nbeats): add parameter coefficient to backcast loss
Browse files Browse the repository at this point in the history
  • Loading branch information
Bycob authored and mergify[bot] committed Feb 22, 2021
1 parent 6597b53 commit 35b3c31
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 5 deletions.
8 changes: 7 additions & 1 deletion src/backends/torch/native/native_factory.cc
Expand Up @@ -50,12 +50,18 @@ namespace dd
if (tdef.find("nbeats") != std::string::npos)
{
std::vector<std::string> p;
double bc_loss_coef = 1;
if (template_params.has("stackdef"))
{
p = template_params.get("stackdef")
.get<std::vector<std::string>>();
}
return new NBeats(inputc, p);
if (template_params.has("backcast_loss_coef"))
{
bc_loss_coef
= template_params.get("backcast_loss_coef").get<double>();
}
return new NBeats(inputc, p, bc_loss_coef);
}
else if (tdef.find("ttransformer") != std::string::npos)
return new TTransformer(inputc, template_params, logger);
Expand Down
13 changes: 9 additions & 4 deletions src/backends/torch/native/templates/nbeats.h
Expand Up @@ -297,7 +297,7 @@ namespace dd
}

NBeats(const CSVTSTorchInputFileConn &inputc,
std::vector<std::string> stackdef,
std::vector<std::string> stackdef, double backcast_loss_coef = 1,
std::vector<BlockType> stackTypes = NBEATS_DEFAULT_STACK_TYPES,
int nb_blocks_per_stack = NBEATS_DEFAULT_NB_BLOCKS,
int data_size = NBEATS_DEFAULT_DATA_SIZE,
Expand All @@ -311,7 +311,8 @@ namespace dd
_hidden_layer_units(hidden_layer_units),
_nb_blocks_per_stack(nb_blocks_per_stack),
_share_weights_in_stack(share_weights_in_stack),
_stack_types(stackTypes), _thetas_dims(thetas_dims)
_stack_types(stackTypes), _thetas_dims(thetas_dims),
_backcast_loss_coef(backcast_loss_coef)
{
parse_stackdef(stackdef);
update_params(inputc);
Expand Down Expand Up @@ -413,12 +414,14 @@ namespace dd
torch::Tensor y_pred = torch::slice(output, 1, _backcast_length,
_backcast_length + _forecast_length);
torch::Tensor input_zeros = torch::zeros_like(input_real);

if (loss.empty() || loss == "L1" || loss == "l1")
return torch::l1_loss(y_pred, target)
+ torch::l1_loss(x_pred, input_zeros);
+ torch::l1_loss(x_pred, input_zeros) * _backcast_loss_coef;
if (loss == "L2" || loss == "l2" || loss == "eucl")
return torch::mse_loss(y_pred, target)
+ torch::mse_loss(x_pred, input_zeros);
+ torch::mse_loss(x_pred, input_zeros) * _backcast_loss_coef;

throw MLLibBadParamException("unknown loss " + loss);
}

Expand All @@ -435,6 +438,8 @@ namespace dd
bool _share_weights_in_stack = NBEATS_DEFAULT_SHARE_WEIGHTS;
std::vector<BlockType> _stack_types = NBEATS_DEFAULT_STACK_TYPES;
std::vector<int> _thetas_dims = NBEATS_DEFAULT_THETAS;
double _backcast_loss_coef
= 1; /** < Coefficient applied to backcast loss */

std::vector<Stack> _stacks;
torch::nn::Linear _fcn{ nullptr };
Expand Down

0 comments on commit 35b3c31

Please sign in to comment.