diff --git a/src/beanmachine/ppl/inference/monte_carlo_samples.py b/src/beanmachine/ppl/inference/monte_carlo_samples.py index 8380ced021..2b605f505e 100644 --- a/src/beanmachine/ppl/inference/monte_carlo_samples.py +++ b/src/beanmachine/ppl/inference/monte_carlo_samples.py @@ -45,7 +45,10 @@ def __init__( self.samples[rv] = val[:, num_adaptive_samples:] if logll_results is not None: - logll = merge_dicts(logll_results, 0, stack_not_cat) + if isinstance(logll_results, list): + logll = merge_dicts(logll_results, 0, stack_not_cat) + else: + logll = logll_results self.log_likelihoods = {} self.adaptive_log_likelihoods = {} for rv, val in logll.items(): @@ -92,7 +95,22 @@ def get_chain(self, chain: int = 0) -> "MonteCarloSamples": raise IndexError("Please specify a valid chain") samples = {rv: self.get_variable(rv, True)[[chain]] for rv in self} - new_mcs = MonteCarloSamples(samples, self.num_adaptive_samples) + + if self.log_likelihoods is None: + logll = None + else: + logll = { + rv: self.get_log_likelihoods(rv, True)[[chain]] + for rv in self.log_likelihoods + } + + new_mcs = MonteCarloSamples( + samples, + self.num_adaptive_samples, + True, + logll, + self.observations, + ) new_mcs.single_chain_view = True return new_mcs @@ -135,6 +153,30 @@ def get_variable( samples = samples.squeeze(0) return samples + def get_log_likelihoods( + self, + rv: RVIdentifier, + include_adapt_steps: bool = False, + ) -> torch.Tensor: + """ + :returns: log_likelihoods computed during inference for the specified variable + """ + + if not isinstance(rv, RVIdentifier): + raise TypeError( + "The key is required to be a random variable " + + f"but is of type {type(rv).__name__}." + ) + + logll = self.log_likelihoods[rv] + + if include_adapt_steps: + logll = torch.cat([self.adaptive_log_likelihoods[rv], logll], dim=1) + + if self.single_chain_view: + logll = logll.squeeze(0) + return logll + def get( self, rv: RVIdentifier,