Merge pull request #148 from blei-lab/doc/details
Doc/details
akucukelbir committed Jul 3, 2016
2 parents 42034ec + cfb00f7 commit 29a1623
Showing 7 changed files with 1,595 additions and 348 deletions.
118 changes: 68 additions & 50 deletions edward/criticisms.py
@@ -5,18 +5,28 @@
from edward.util import logit, get_session

def evaluate(metrics, model, variational, data):
"""
Evaluate fitted model using a set of metrics.
"""Evaluate fitted model using a set of metrics.
Parameters
----------
metric : list or str
metrics : list or str
List of metrics or a single metric.
model : ed.Model
Probability model p(x, z)
variational : ed.Variational
Variational approximation to the posterior p(z | x)
data : ed.Data
Data to evaluate the model at
Returns
-------
list or float
A list of evaluations or a single evaluation.
Raises
------
NotImplementedError
If an input metric does not match an implemented metric in Edward.
"""
sess = get_session()
# Monte Carlo estimate the mean of the posterior predictive:
@@ -85,35 +95,39 @@ def evaluate(metrics, model, variational, data):
return evaluations
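
For orientation, here is a minimal sketch of how this criticism routine might be
driven after inference. The ``model``, ``variational``, and ``data`` objects are
hypothetical placeholders from an earlier modeling and inference step, and the
metric strings are illustrative; they would need to match the names that
``evaluate`` actually recognizes.

from edward.criticisms import evaluate

# A single metric string returns a single evaluation ...
mse = evaluate('mean_squared_error', model, variational, data)

# ... while a list of metrics returns a list of evaluations, in order.
evals = evaluate(['binary_accuracy', 'binary_crossentropy'],
                 model, variational, data)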

def ppc(model, variational=None, data=Data(), T=None, size=100):
"""
Posterior predictive check.
"""Posterior predictive check.
(Rubin, 1984; Meng, 1994; Gelman, Meng, and Stern, 1996)
If variational is not specified, it defaults to a prior predictive
check (Box, 1980).
If no posterior approximation is provided through ``variational``,
then we default to a prior predictive check (Box, 1980).
PPCs form an empirical distribution for the predictive discrepancy,
p(T) = \int p(T(yrep) | z) p(z | y) dz
by drawing replicated data sets yrep and calculating T(yrep) for
each data set. Then it compares it to T(y).
.. math::
p(T) = \int p(T(yrep) | z) p(z | y) dz
by drawing replicated data sets yrep and calculating
:math:`T(yrep)` for each data set. Then it compares it to
:math:`T(y)`.
Parameters
----------
model : Model
class object with a 'sample_likelihood' method
variational : Variational, optional
model : ed.Model
class object that implements the ``sample_likelihood`` method
variational : ed.Variational, optional
latent variable distribution q(z) to sample from. It is an
approximation to the posterior, e.g., a variational
approximation or an empirical distribution from MCMC samples.
If not specified, samples will be obtained from model
with a 'sample_prior' method.
data : Data, optional
If not specified, samples will be obtained from the model
through the ``sample_prior`` method.
data : ed.Data, optional
Observed data to compare to. If not specified, will return
only the reference distribution with an assumed replicated
data set size of 1.
T : function, optional
Discrepancy function written in TensorFlow. Default is
identity. It is a function taking in a data set
y and optionally a set of latent variables z as input.
Discrepancy function taking tf.Tensor inputs and returning
a tf.Tensor output. Default is the identity function.
In general this is a function taking in a data set ``y``
and optionally a set of latent variables ``z`` as input.
size : int, optional
number of replicated data sets
@@ -122,10 +136,15 @@ class object with a 'sample_likelihood' method
list
List containing the reference distribution, which is a NumPy
vector of ``size`` elements,
(T(yrep^{1}, z^{1}), ..., T(yrep^{size}, z^{size}));
.. math::
(T(yrep^{1}, z^{1}), ..., T(yrep^{size}, z^{size}))
and the realized discrepancy, which is a NumPy vector of ``size``
elements,
(T(y, z^{1}), ..., T(y, z^{size})).
.. math::
(T(y, z^{1}), ..., T(y, z^{size})).
"""
sess = get_session()
y = data.data
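
As a rough usage sketch (``model``, ``variational``, and ``data`` again being
placeholders from an earlier inference run), a posterior predictive check with a
custom discrepancy might look like the following; the mean-based discrepancy is
only one possible choice.

import tensorflow as tf
from edward.criticisms import ppc

def T(y, z=None):
    # Discrepancy: the sample mean of a (replicated) data set. Any function
    # of the data, and optionally the latent variables, that returns a
    # tf.Tensor would do.
    return tf.reduce_mean(y)

T_rep, T_obs = ppc(model, variational=variational, data=data, T=T, size=100)
# T_rep holds discrepancies on the replicated data sets, T_obs those on the
# observed data; comparing the two yields a posterior predictive p-value.
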
@@ -173,8 +192,7 @@ class object with a 'sample_likelihood' method
# Classification metrics

def binary_accuracy(y_true, y_pred):
"""
Binary prediction accuracy, also known as 0/1-loss.
"""Binary prediction accuracy, also known as 0/1-loss.
Parameters
----------
@@ -188,17 +206,15 @@ def binary_accuracy(y_true, y_pred):
return tf.reduce_mean(tf.cast(tf.equal(y_true, y_pred), tf.float32))
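
A toy check of the 0/1-loss, assuming the elided lines above round the predicted
probabilities to the nearest class before comparing:

import tensorflow as tf
from edward.criticisms import binary_accuracy
from edward.util import get_session

y_true = tf.constant([0.0, 1.0, 1.0, 0.0])
y_pred = tf.constant([0.1, 0.8, 0.4, 0.3])  # predicted probabilities

sess = get_session()
print(sess.run(binary_accuracy(y_true, y_pred)))
# Three of the four rounded predictions match y_true, so this should print 0.75.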

def categorical_accuracy(y_true, y_pred):
"""
Multi-class prediction accuracy. One-hot representation for
y_true.
"""Multi-class prediction accuracy. One-hot representation for ``y_true``.
Parameters
----------
y_true : tf.Tensor
Tensor of 0s and 1s, where the outermost dimension of size K
Tensor of 0s and 1s, where the outermost dimension of size ``K``
has only one 1 per row.
y_pred : tf.Tensor
Tensor of probabilities, with same shape as y_true.
Tensor of probabilities, with same shape as ``y_true``.
The outermost dimension denotes the categorical probabilities for
that data point per row.
"""
@@ -207,16 +223,15 @@ def categorical_accuracy(y_true, y_pred):
return tf.reduce_mean(tf.cast(tf.equal(y_true, y_pred), tf.float32))

def sparse_categorical_accuracy(y_true, y_pred):
"""
Multi-class prediction accuracy. Label {0, 1, .., K-1}
representation for y_true.
"""Multi-class prediction accuracy. Label {0, 1, .., K-1}
representation for ``y_true``.
Parameters
----------
y_true : tf.Tensor
Tensor of integers {0, 1, ..., K-1}.
y_pred : tf.Tensor
Tensor of probabilities, with shape (y_true.get_shape(), K).
Tensor of probabilities, with shape ``(y_true.get_shape(), K)``.
The outermost dimension contains the categorical probabilities for
that data point.
"""
@@ -225,7 +240,8 @@ def sparse_categorical_accuracy(y_true, y_pred):
return tf.reduce_mean(tf.cast(tf.equal(y_true, y_pred), tf.float32))
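
The two accuracy variants above differ only in how the ground truth is encoded.
A small sketch with arbitrary values, assuming both reduce to an argmax
comparison as the elided lines suggest:

import tensorflow as tf
from edward.criticisms import categorical_accuracy, sparse_categorical_accuracy

probs = tf.constant([[0.7, 0.2, 0.1],
                     [0.1, 0.1, 0.8]])   # predicted class probabilities

one_hot = tf.constant([[1.0, 0.0, 0.0],
                       [0.0, 1.0, 0.0]])  # one-hot labels
labels = tf.constant([0, 1])              # equivalent integer labels

# Both should evaluate to 0.5: the first prediction matches its label,
# the second does not.
acc_dense = categorical_accuracy(one_hot, probs)
acc_sparse = sparse_categorical_accuracy(labels, probs)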

def binary_crossentropy(y_true, y_pred):
"""
"""Binary cross-entropy.
Parameters
----------
y_true : tf.Tensor
@@ -238,8 +254,7 @@ def binary_crossentropy(y_true, y_pred):
return tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(y_pred, y_true))

def categorical_crossentropy(y_true, y_pred):
"""
Multi-class cross entropy. One-hot representation for y_true.
"""Multi-class cross entropy. One-hot representation for ``y_true``.
Parameters
----------
@@ -256,16 +271,15 @@ def categorical_crossentropy(y_true, y_pred):
return tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(y_pred, y_true))

def sparse_categorical_crossentropy(y_true, y_pred):
"""
Multi-class cross entropy. Label {0, 1, .., K-1} representation
for y_true.
"""Multi-class cross entropy. Label {0, 1, .., K-1} representation
for ``y_true.``
Parameters
----------
y_true : tf.Tensor
Tensor of integers {0, 1, ..., K-1}.
y_pred : tf.Tensor
Tensor of probabilities, with shape (y_true.get_shape(), K).
Tensor of probabilities, with shape ``(y_true.get_shape(), K)``.
The outermost dimension contains the categorical probabilities for
that data point.
"""
@@ -274,7 +288,8 @@ def sparse_categorical_crossentropy(y_true, y_pred):
return tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(y_pred, y_true))

def hinge(y_true, y_pred):
"""
"""Hinge loss.
Parameters
----------
y_true : tf.Tensor
@@ -287,7 +302,8 @@ def hinge(y_true, y_pred):
return tf.reduce_mean(tf.maximum(1.0 - y_true * y_pred, 0.0))

def squared_hinge(y_true, y_pred):
"""
"""Squared hinge loss.
Parameters
----------
y_true : tf.Tensor
@@ -302,7 +318,8 @@ def squared_hinge(y_true, y_pred):
# Regression metrics

def mean_squared_error(y_true, y_pred):
"""
"""Mean squared error loss.
Parameters
----------
y_true : tf.Tensor
@@ -312,7 +329,8 @@ def mean_squared_error(y_true, y_pred):
return tf.reduce_mean(tf.square(y_pred - y_true))
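
A quick sanity check of the regression metrics defined in this and the following
blocks, with arbitrary values:

import tensorflow as tf
from edward.criticisms import mean_squared_error, mean_absolute_error
from edward.util import get_session

y_true = tf.constant([1.0, 2.0, 3.0])
y_pred = tf.constant([1.0, 2.5, 2.0])

sess = get_session()
print(sess.run(mean_squared_error(y_true, y_pred)))   # (0 + 0.25 + 1) / 3 ≈ 0.417
print(sess.run(mean_absolute_error(y_true, y_pred)))  # (0 + 0.5 + 1) / 3 = 0.5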

def mean_absolute_error(y_true, y_pred):
"""
"""Mean absolute error loss.
Parameters
----------
y_true : tf.Tensor
@@ -322,7 +340,8 @@ def mean_absolute_error(y_true, y_pred):
return tf.reduce_mean(tf.abs(y_pred - y_true))

def mean_absolute_percentage_error(y_true, y_pred):
"""
"""Mean absolute percentage error loss.
Parameters
----------
y_true : tf.Tensor
@@ -333,7 +352,8 @@ def mean_absolute_percentage_error(y_true, y_pred):
return 100.0 * tf.reduce_mean(diff)

def mean_squared_logarithmic_error(y_true, y_pred):
"""
"""Mean squared logarithmic error loss.
Parameters
----------
y_true : tf.Tensor
@@ -345,9 +365,8 @@ def mean_squared_logarithmic_error(y_true, y_pred):
return tf.reduce_mean(tf.square(first_log - second_log))

def poisson(y_true, y_pred):
"""
Negative Poisson log-likelihood of data y_true given predictions
y_pred (up to proportion).
"""Negative Poisson log-likelihood of data ``y_true`` given predictions
``y_pred`` (up to proportion).
Parameters
----------
@@ -358,8 +377,7 @@ def poisson(y_true, y_pred):
return tf.reduce_sum(y_pred - y_true * tf.log(y_pred + 1e-8))
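
Since the constant ``log(y_true!)`` term is dropped, this is the negative Poisson
log-likelihood only up to an additive constant, which is what "up to proportion"
refers to. A small worked example with arbitrary counts and rates:

import tensorflow as tf
from edward.criticisms import poisson
from edward.util import get_session

y_true = tf.constant([2.0, 0.0])   # observed counts
y_pred = tf.constant([1.5, 0.5])   # predicted rates

sess = get_session()
print(sess.run(poisson(y_true, y_pred)))
# 1.5 - 2*log(1.5) + 0.5 - 0*log(0.5) ≈ 1.19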

def cosine_proximity(y_true, y_pred):
"""
Cosine similarity of two vectors.
"""Cosine similarity of two vectors.
Parameters
----------
60 changes: 43 additions & 17 deletions edward/data.py
@@ -2,40 +2,44 @@
import tensorflow as tf

class Data:
"""
Base class for data.
"""Base class for Edward data objects.
By default, it assumes the data is an array (or list of arrays).
If requested, it will perform data subsampling according to slices of
the first index (e.g., elements in a vector, rows in a matrix,
y-by-z matrices in an x-by-y-by-z tensor). Use one of the derived
classes for subsampling more complex data structures.
Arguments
----------
data: tf.tensor, np.ndarray, list, dict, optional
Data whose type depends on the type of model it is fed into.
If TensorFlow, must be tf.tensor or list (see notes).
If Stan, must be dict.
If PyMC3, must be np.ndarray.
If NumPy/SciPy, must be np.ndarray or list of np.ndarrays.
shuffled: bool, optional
Whether the data is shuffled.
Notes
-----
For TensorFlow models, the data argument can be a list of placeholders
or a list of np.ndarrays. If np.ndarrays, mini-batches of the arrays are
used during computation. If placeholders, the user must manually control
mini-batches and feed in the placeholders.
Data subsampling is not currently available for Stan models.
Internally, self.counter stores the last accessed data index. It
Internally, ``self.counter`` stores the last accessed data index. It
is used to obtain the next batch of data starting from
self.counter to the size of the data set.
``self.counter`` to the size of the data set.
"""
def __init__(self, data=None, shuffled=True):
"""Initialization.
Parameters
----------
data : tf.tensor, np.ndarray, list, dict, optional
Data whose type depends on the type of model it is fed into.
If TensorFlow, must be ``tf.tensor`` or ``list``.
If Stan, must be ``dict``.
If PyMC3, must be ``np.ndarray``.
If NumPy/SciPy, must be ``np.ndarray`` or ``list`` of ``np.ndarrays``.
shuffled: bool, optional
Whether the data is shuffled when sampling.
"""
self.data = data
if not shuffled:
# TODO
@@ -63,6 +67,28 @@ def __init__(self, data=None, shuffled=True):
raise NotImplementedError()

def sample(self, n_data=None):
"""Data sampling method.
At any given point, the internal counter ``self.counter`` tracks the
last datapoint returned by ``sample``.
If the requested number of datapoints ``n_data`` goes beyond the size
of the dataset, the internal counter wraps around to the beginning, so
the returned minibatch may include datapoints from the start of the
dataset.
Parameters
----------
n_data : int, optional
Number of datapoints to sample.
Defaults to the total number of datapoints in the ``Data`` object.
Returns
-------
minibatch : tf.Tensor
A tensor whose first dimension has size ``n_data``.
"""
# TODO
# In general, there should be a scale factor due to data
# subsampling, so that
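
A sketch of the subsampling behaviour described above, wrapping a plain NumPy
array; the array contents are arbitrary.

import numpy as np
from edward.data import Data

x = np.arange(10).reshape(10, 1).astype(np.float32)  # ten rows of toy data
data = Data(x)

# Per the docstring above, `sample` returns a tf.Tensor whose first dimension
# has size `n_data`; it would be evaluated in a session before inspection.
batch = data.sample(n_data=4)

# Repeated calls advance the internal counter, and once it passes the end of
# the data set the counter wraps around, so a later minibatch may mix rows
# from the tail and the head of the array.
another = data.sample(n_data=8)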
