handling plot category (#36)
Summary:
Pull Request resolved: #36

Add support for trace plots and autocorrelation plots for sampled variables. Users can override these functions or add new plot functionality by registering new methods via plotfn(...).
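
As a rough usage sketch (not part of the commit itself): assuming samples is a MonteCarloSamples object returned by inference and my_rv is one of its queried random variables (both hypothetical names), a custom plot function can be registered and rendered roughly as follows.

from typing import List, Tuple

import torch
from torch import Tensor

from beanmachine.ppl.diagnostics.diagnostics import Diagnostics


def running_mean(x: Tensor) -> Tuple[List[int], Tensor]:
    # x holds the draws of one scalar component of a queried variable;
    # return (x-axis values, y-axis values), mirroring trace_plot/autocorr.
    denom = torch.arange(1, x.size(0) + 1, dtype=x.dtype)
    return list(range(x.size(0))), x.cumsum(0) / denom


d = Diagnostics(samples)  # samples: a MonteCarloSamples instance (assumed)
d.plotfn(running_mean, display_name="running mean")  # register the new plot
figs = d.plot(query_list=[my_rv], display=False)  # my_rv: a queried RVIdentifier (assumed); display=True would render inline in a notebook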

Reviewed By: nimar

Differential Revision: D17911632

fbshipit-source-id: f5db51ed4bb8e20d6be2b0845d42d20e47e3792b
Narges Torabi authored and facebook-github-bot committed Dec 12, 2019
1 parent 9262745 commit 98f8b7a
Showing 4 changed files with 357 additions and 33 deletions.
95 changes: 95 additions & 0 deletions beanmachine/ppl/diagnostics/common_plots.py
@@ -0,0 +1,95 @@
from typing import Callable, List, NamedTuple, Tuple

import numpy as np
import plotly.graph_objs as go
import torch
from torch import Tensor


class SamplesSummary(NamedTuple):
num_chain: int
num_samples: int
single_sample_sz: Tensor


def _samples_info(query_samples: Tensor):
return SamplesSummary(
num_chain=query_samples.size(0),
num_samples=query_samples.size(1),
single_sample_sz=query_samples.size()[2:],
)


def trace_helper(
x: List[List[List[int]]], y: List[List[List[float]]], labels: List[str]
) -> Tuple[List[go.Scatter], List[str]]:
"""
    This function takes results prepared by a plot-related function and
    outputs a tuple of plotly objects and their corresponding legends.
"""
all_traces = []
num_chains = len(x)
num_indices = len(x[0])
for index in range(num_indices):
trace = []
for chain in range(num_chains):
trace.append(
go.Scatter(
x=x[chain][index],
y=y[chain][index],
mode="lines",
name="chain" + str(chain),
)
)
all_traces.append(trace)
return (all_traces, labels)


def plot_helper(
query_samples: Tensor, func: Callable
) -> Tuple[List[go.Scatter], List[str]]:
"""
    This function executes a plot-related function, passed in as the func parameter,
    and outputs a tuple of plotly objects and their corresponding legends.
"""
num_chain, num_samples, single_sample_sz = _samples_info(query_samples)

x_axis, y_axis, all_labels = [], [], []
for chain in range(num_chain):
flattened_data = query_samples[chain].reshape(num_samples, -1)
numel = flattened_data[0].numel()
x_axis_data, y_axis_data, labels = [], [], []
for i in range(numel):
index = np.unravel_index(i, single_sample_sz)
data = flattened_data[:, i]
partial_label = f" for {[j for j in index]}"

x_data, y_data = func(data.detach())
x_axis_data.append(x_data)
y_axis_data.append(y_data)
labels.append(partial_label)
x_axis.append(x_axis_data)
y_axis.append(y_axis_data)
all_labels.append(labels)
return trace_helper(x_axis, y_axis, all_labels[0])


def autocorr(x: Tensor) -> Tuple[List[int], List[float]]:
def autocorr_calculation(x: Tensor, lag: int) -> Tensor:
y1 = x[: (len(x) - lag)]
y2 = x[lag:]

sum_product = (
(y1 - (x.mean(dim=0).expand(y1.size())))
* (y2 - (x.mean(dim=0).expand(y2.size())))
).sum(0)
return sum_product / ((len(x) - lag) * torch.var(x, dim=0))

max_lag = x.size(0)
y_axis_data = [autocorr_calculation(x, lag).item() for lag in range(max_lag)]
x_axis_data = [k for k in range(max_lag)]
return (x_axis_data, y_axis_data)


def trace_plot(x: Tensor) -> Tuple[List[int], Tensor]:
return ([k for k in range(x.size(0))], x)
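
For orientation (this note and the snippet below are not part of the diff), the helpers above can also be exercised directly; a minimal sketch, assuming a 1-D tensor of draws for a single scalar component:

import torch

from beanmachine.ppl.diagnostics.common_plots import autocorr, trace_plot

draws = torch.randn(500)      # toy chain of 500 draws for one component
lags, acf = autocorr(draws)   # lags = [0, 1, ..., 499]; acf[0] is close to 1.0
xs, ys = trace_plot(draws)    # xs = [0, 1, ..., 499]; ys is the raw draws tensor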
228 changes: 198 additions & 30 deletions beanmachine/ppl/diagnostics/diagnostics.py
@@ -1,30 +1,52 @@
# Copyright (c) Facebook, Inc. and its affiliates.

from typing import List, Optional
import functools
import math
from typing import Callable, Dict, List, Optional, Tuple

import beanmachine.ppl.diagnostics.common_plots as common_plots
import beanmachine.ppl.diagnostics.common_statistics as common_stats
import numpy as np
import pandas as pd
import plotly
from beanmachine.ppl.inference.monte_carlo_samples import MonteCarloSamples
from beanmachine.ppl.model.utils import RVIdentifier
from plotly import tools
from torch import Tensor


class BaseDiagnostics:
def __init__(self, samples: MonteCarloSamples):
self.samples = samples
self.statistics_dict = {}
self.plots_dict = {}

def summaryfn(self, func, display_names: List[str], statistics_name: str = None):
def _prepare_query_list(
self, query_list: Optional[List[RVIdentifier]] = None
) -> List[RVIdentifier]:
if query_list is None:
return list(self.samples.get_rv_names())
for query in query_list:
if not (query in self.samples.get_rv_names()):
raise ValueError(f"query {self._stringify_query(query)} does not exist")
return query_list

def summaryfn(self, func: Callable, display_names: List[str]) -> Callable:
"""
        This function keeps a dictionary of all summary-related functions,
        so it can handle overridden functions and new ones that the user defines.
        :param func: method which is going to be executed when summary() is called.
        :param display_names: the names that appear in the summary() output dataframe
:returns: user-visible function that can be called over a list of queries
"""
if not statistics_name:
statistics_name = func.__name__
statistics_name = func.__name__
self.statistics_dict[statistics_name] = (func, display_names)
return self._standalone_summary_stat_function(statistics_name, func)

def _prepare_input(self, query: RVIdentifier, chain: Optional[int] = None):
def _prepare_summary_stat_input(
self, query: RVIdentifier, chain: Optional[int] = None
):
query_samples = self.samples[query]
if chain is not None:
query_samples = query_samples[chain].unsqueeze(0)
@@ -33,6 +55,9 @@ def _prepare_input(self, query: RVIdentifier, chain: Optional[int] = None):
def _create_table(
self, query: RVIdentifier, results: List[Tensor], func_list: List[str]
) -> pd.DataFrame:
"""
this function turns output of each summary stat function to a dataframe
"""
out_pd = pd.DataFrame()
single_result_set = results[0]
for flattened_index in range(single_result_set[0].numel()):
@@ -56,6 +81,32 @@ def _create_table(
def _stringify_query(self, query: RVIdentifier) -> str:
return f"{query.function.__name__}{query.arguments}"

def _execute_summary_stat_funcs(
self,
query: RVIdentifier,
func_dict: Dict[str, Tuple[Callable, str]],
chain: Optional[int] = None,
):
frames = pd.DataFrame()
query_results = []
func_list = []
queried_samples = self._prepare_summary_stat_input(query, chain)
for _k, (func, display_names) in func_dict.items():
result = func(queried_samples)
if result is None:
continue
            # the first dimension is equivalent to the size of display_names
if len(display_names) <= 1:
result = result.unsqueeze(0)
query_results.append(result)
func_list.extend(display_names)
out_df = self._create_table(query, query_results, func_list)
if frames.empty:
frames = out_df
else:
frames = pd.concat([frames, out_df])
return frames

def summary(
self,
query_list: Optional[List[RVIdentifier]] = None,
@@ -67,41 +118,158 @@ def summary(
        if chain is None, results correspond to the aggregated chains
"""
frames = pd.DataFrame()
if query_list is None:
query_list = list(self.samples.get_rv_names())
query_list = self._prepare_query_list(query_list)
for query in query_list:
if not (query in self.samples.get_rv_names()):
raise ValueError(f"query {self._stringify_query(query)} does not exist")
query_results = []
func_list = []
queried_samples = self._prepare_input(query, chain)
for _k, (func, display_names) in self.statistics_dict.items():
result = func(queried_samples)
if result is not None:
                    # the first dimension is equivalent to the size of display_names
if len(display_names) <= 1:
result = result.unsqueeze(0)
query_results.append(result)
func_list.extend(display_names)
out_df = self._create_table(query, query_results, func_list)
if frames.empty:
frames = out_df
else:
frames = pd.concat([frames, out_df])
out_df = self._execute_summary_stat_funcs(
query, self.statistics_dict, chain
)
frames = pd.concat([frames, out_df])
frames.sort_index(inplace=True)
return frames

def _prepare_plots_input(
self, query: RVIdentifier, chain: Optional[int] = None
) -> Tensor:
"""
:param query: the query for which registered plot functions are called
:param chain: the chain that query samples are extracted from
:returns: tensor of query samples
"""
query_samples = self.samples[query]
if chain is not None:
return query_samples[chain].unsqueeze(0)
return query_samples

def plotfn(self, func: Callable, display_name: str) -> Callable:
"""
        This function keeps a dictionary of all plot-related functions,
        so it can handle overridden functions and new ones that the user defines.
:param func: method which is going to be executed when plot() is called.
:param display_name: appears as part of the plot title for func
:returns: user-visible function that can be called over a list of queries
"""
self.plots_dict[func.__name__] = (func, display_name)
return self._standalone_plot_function(func.__name__, func)

def _execute_plot_funcs(
self,
query: RVIdentifier,
func_dict: Dict[str, Tuple[Callable, str]],
chain: Optional[int] = None,
display: Optional[bool] = False,
): # task T57168727 to add type
figs = []
queried_samples = self._prepare_plots_input(query, chain)
for _k, (func, display_name) in func_dict.items():
trace, labels = common_plots.plot_helper(queried_samples, func)
title = f"{self._stringify_query(query)} {display_name}"
fig = self._display_results(
trace, [title + label for label in labels], display
)
figs.append(fig)
return figs

def plot(
self,
query_list: Optional[List[RVIdentifier]] = None,
display: Optional[bool] = False,
chain: Optional[int] = None,
): # task T57168727 to add type
"""
this function outputs plots related to registered functions in
self.plots_dict for requested queries in query_list
:param query_list: list of queries for which plot functions will be called
:param chain: the chain that query samples are extracted from
:returns: plotly object holding the results from registered plot functions
"""
figs = []
query_list = self._prepare_query_list(query_list)
for query in query_list:
fig = self._execute_plot_funcs(query, self.plots_dict, chain, display)
figs.extend(fig)
return figs

def _display_results(
self, traces, labels: List[str], display: bool
): # task T57168727 to add type
"""
:param traces: a list of plotly objects
:param labels: plot labels
:returns: a plotly subplot object
"""
fig = tools.make_subplots(
rows=math.ceil(len(traces) / 2), cols=2, subplot_titles=tuple(labels)
)

r = 1
for trace in traces:
for data in trace:
fig.append_trace(data, row=math.ceil(r / 2), col=((r - 1) % 2) + 1)
r += 1
if display:
plotly.offline.iplot(fig)
return fig

def _standalone_plot_function(self, func_name: str, func: Callable) -> Callable:
"""
this function makes each registered plot function directly callable by the user
"""

@functools.wraps(func)
def _wrapper(
query_list: List[RVIdentifier],
chain: Optional[int] = None,
display: Optional[bool] = False,
):
figs = []
query_list = self._prepare_query_list(query_list)
for query in query_list:
fig = self._execute_plot_funcs(
query, {func_name: self.plots_dict[func_name]}, chain, display
)
figs.extend(fig)
return figs

return _wrapper

def _standalone_summary_stat_function(
self, func_name: str, func: Callable
) -> Callable:
"""
this function makes each registered summary-stat related function directly callable by the user
"""

@functools.wraps(func)
def _wrapper(query_list: List[RVIdentifier], chain: Optional[int] = None):
frames = pd.DataFrame()
query_list = self._prepare_query_list(query_list)
for query in query_list:
out_df = self._execute_summary_stat_funcs(
query, {func_name: self.statistics_dict[func_name]}, chain
)
frames = pd.concat([frames, out_df])
return frames

return _wrapper


class Diagnostics(BaseDiagnostics):
def __init__(self, samples: MonteCarloSamples):
super().__init__(samples)
"""
every function related to summary stat should be registered in the constructor
"""
self.summaryfn(common_stats.mean, display_names=["avg"])
self.summaryfn(common_stats.std, display_names=["std"])
self.summaryfn(
self.mean = self.summaryfn(common_stats.mean, display_names=["avg"])
self.std = self.summaryfn(common_stats.std, display_names=["std"])
self.confidence_interval = self.summaryfn(
common_stats.confidence_interval, display_names=["2.5%", "50%", "97.5%"]
)
self.summaryfn(common_stats.split_r_hat, display_names=["r_hat"])
self.summaryfn(common_stats.effective_sample_size, display_names=["n_eff"])
self.split_r_hat = self.summaryfn(
common_stats.split_r_hat, display_names=["r_hat"]
)
self.effective_sample_size = self.summaryfn(
common_stats.effective_sample_size, display_names=["n_eff"]
)
self.trace = self.plotfn(common_plots.trace_plot, display_name="trace")
self.autocorr = self.plotfn(common_plots.autocorr, display_name="autocorr")
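
A hedged end-to-end sketch of the standalone handles registered above (not part of the diff), assuming samples is a MonteCarloSamples object and theta is one of its queried random variables (both hypothetical names):

from beanmachine.ppl.diagnostics.diagnostics import Diagnostics

d = Diagnostics(samples)

full_summary = d.summary()                    # all registered stats for all queries
n_eff = d.effective_sample_size([theta])      # one stat for one query, as a DataFrame
trace_figs = d.trace([theta], display=False)  # per-chain trace plots (plotly figures)
acf_figs = d.autocorr([theta], chain=0)       # autocorrelation plots for chain 0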
