Skip to content
This repository has been archived by the owner on Dec 18, 2023. It is now read-only.

Commit

Permalink
Only store Sigma values in BART samples instead of object (#1527)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #1527

Background:
We are building Bayesian Additive Regression Trees (BART) as an experimental causal inference model in beanmachine. Details of the project can be found in https://docs.google.com/document/d/11nkB6UTGpvQBEC2yBjfgwAr8VabTlD7R9XufGQG0EvI/edit?usp=sharing and the proposed design can be found in the draft design document: https://docs.google.com/document/d/1o3J7yobDF0M9E27Y0tP2889fycmemXUZbHE5cebRqzs/edit?usp=sharing.

In this diff:
The noise standard deviation (sigma) parameter is never really used in the prediction tasks. While we would like to retain them for diagnostic purposes, there is no reason to store the NoiseStandardDeviation object in the sample trace. In this diff, we are modifying the BART class to only store float samples of the noise standard deviation.

Reviewed By: feynmanliang

Differential Revision: D37635208

fbshipit-source-id: be2f53b61b666fe9d50d2504a57351c91bd24915
  • Loading branch information
himaghna authored and facebook-github-bot committed Jul 6, 2022
1 parent 9ba385f commit 26529f8
Showing 1 changed file with 1 addition and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ def _step(self) -> Tuple[List, float]:
self.X
)
self._update_sigma(self.y - self._predict_step())
return self.all_trees, self.sigma
return self.all_trees, self.sigma.val

def _update_leaf_mean(self, tree: Tree, partial_residual: torch.Tensor):
"""
Expand Down

0 comments on commit 26529f8

Please sign in to comment.