Skip to content

Commit

Permalink
Rejig ELBO so that baselines don't have to output scaled costs.
Browse files Browse the repository at this point in the history
This means that the baseline objective doesn't change with the batch
size, and it also makes our baseline net step sizes more likely to be
comparable to those used outside of Pyro.
  • Loading branch information
null-a committed Nov 10, 2017
1 parent d4975ae commit b84cc99
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions pyro/infer/tracegraph_elbo.py
Expand Up @@ -235,7 +235,7 @@ def _loss_and_grads_particle(self, weight, model_trace, guide_trace):
for node in non_reparam_nodes:
guide_site = guide_trace.nodes[node]
log_pdf_key = 'batch_log_pdf' if node in guide_vec_md_nodes else 'log_pdf'
downstream_cost = downstream_costs[node]
downstream_cost = downstream_costs[node] / guide_site["scale"] # not scaled by subsampling
baseline = 0.0
(nn_baseline, nn_baseline_input, use_decaying_avg_baseline, baseline_beta,
baseline_value) = _get_baseline_options(guide_site)
Expand All @@ -260,7 +260,7 @@ def _loss_and_grads_particle(self, weight, model_trace, guide_trace):
# accumulate baseline loss
baseline_loss += torch.pow(downstream_cost.detach() - baseline, 2.0).sum()

guide_log_pdf = guide_site[log_pdf_key] / guide_site["scale"] # not scaled by subsampling
guide_log_pdf = guide_site[log_pdf_key]
if use_nn_baseline or use_decaying_avg_baseline or use_baseline_value:
if downstream_cost.size() != baseline.size():
raise ValueError("Expected baseline at site {} to be {} instead got {}".format(
Expand Down

0 comments on commit b84cc99

Please sign in to comment.