
NormalWishart-Normal model #471

Closed · bertini36 opened this issue Feb 21, 2017 · 16 comments

bertini36 (Contributor) commented Feb 21, 2017

Hi there!

I'm trying to implement a NormalWishart-Normal model with Edward. I think the model representation is OK, but what do you think? Here is the code:

# -*- coding: UTF-8 -*-

"""
NormalWishart-Normal Model
Posterior inference with Edward BBVI
[DOING]
"""

import edward as ed
import matplotlib.cm as cm
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from edward.models import MultivariateNormalFull, WishartCholesky
from scipy.stats import invwishart

N = 1000
D = 2

# Data generation
# NIW Inverse Wishart hyperparameters
v = 3.
W = np.array([[20., 30.], [25., 40.]])
sigma = invwishart.rvs(v, W)
# NIW Normal hyperparameters
m = np.array([1., 1.])
k = 0.8
mu = np.random.multivariate_normal(m, sigma / k)
xn_data = np.random.multivariate_normal(mu, sigma, N)
plt.scatter(xn_data[:, 0], xn_data[:, 1], cmap=cm.gist_rainbow, s=5)
plt.show()
print('mu={}'.format(mu))
print('sigma={}'.format(sigma))

# Prior definition
v_prior = tf.Variable(3., dtype=tf.float64, trainable=False)
W_prior = tf.Variable(np.array([[1., 0.], [0., 1.]]),
                      dtype=tf.float64, trainable=False)
m_prior = tf.Variable(np.array([0.5, 0.5]), dtype=tf.float64, trainable=False)
k_prior = tf.Variable(0.6, dtype=tf.float64, trainable=False)

print('***** PRIORS *****')
print('v_prior: {}'.format(v_prior))
print('W_prior: {}'.format(W_prior))
print('m_prior: {}'.format(m_prior))
print('k_prior: {}'.format(k_prior))

# Posterior inference
# Probabilistic model
sigma = WishartCholesky(df=v_prior, scale=W_prior)
mu = MultivariateNormalFull(m_prior, k_prior * sigma)
# Tile mu and sigma so all N observations share the same parameters
xn = MultivariateNormalFull(tf.reshape(tf.tile(mu, [N]), [N, D]),
                            tf.reshape(tf.tile(sigma, [N, 1]), [N, D, D]))

print('***** PROBABILISTIC MODEL *****')
print('mu: {}'.format(mu))
print('sigma: {}'.format(sigma))
print('xn: {}'.format(xn))

# Variational model
qmu = MultivariateNormalFull(
    tf.Variable(tf.random_normal([D], dtype=tf.float64), name='v1'),
    tf.nn.softplus(
        tf.Variable(tf.random_normal([D, D], dtype=tf.float64), name='v2')))
qsigma = WishartCholesky(
    df=tf.nn.softplus(
        tf.Variable(tf.random_normal([], dtype=tf.float64), name='v3')),
    scale=tf.nn.softplus(
        tf.Variable(tf.random_normal([D, D], dtype=tf.float64), name='v4')))

print('***** VARIATIONAL MODEL *****')
print('qmu: {}'.format(qmu))
print('qsigma: {}'.format(qsigma))

# Inference
print('xn_data: {}'.format(xn_data.dtype))
inference = ed.KLqp({mu: qmu, sigma: qsigma}, data={xn: xn_data})
inference.run(n_iter=2000, n_samples=20)

But it seems there is a type error:

File "NW_normal_edward.py", line 78, in <module>
    inference.run(n_iter=2000, n_samples=20)
  File "/home/alberto/.virtualenvs/GMM/local/lib/python2.7/site-packages/edward/inferences/inference.py", line 218, in run
    self.initialize(*args, **kwargs)
  File "/home/alberto/.virtualenvs/GMM/local/lib/python2.7/site-packages/edward/inferences/klqp.py", line 66, in initialize
    return super(KLqp, self).initialize(*args, **kwargs)
  File "/home/alberto/.virtualenvs/GMM/local/lib/python2.7/site-packages/edward/inferences/variational_inference.py", line 70, in initialize
    self.loss, grads_and_vars = self.build_loss_and_gradients(var_list)
  File "/home/alberto/.virtualenvs/GMM/local/lib/python2.7/site-packages/edward/inferences/klqp.py", line 108, in build_loss_and_gradients
    return build_reparam_loss_and_gradients(self, var_list)
  File "/home/alberto/.virtualenvs/GMM/local/lib/python2.7/site-packages/edward/inferences/klqp.py", line 343, in build_reparam_loss_and_gradients
    z_copy = copy(z, dict_swap, scope=scope)
  File "/home/alberto/.virtualenvs/GMM/local/lib/python2.7/site-packages/edward/util/random_variables.py", line 176, in copy
    new_rv = rv.__class__(*args, **kwargs)
  File "/home/alberto/.virtualenvs/GMM/local/lib/python2.7/site-packages/edward/models/random_variable.py", line 62, in __init__
    super(RandomVariable, self).__init__(*args, **kwargs)
  File "/home/alberto/.virtualenvs/GMM/local/lib/python2.7/site-packages/tensorflow/contrib/distributions/python/ops/wishart.py", line 521, in __init__
    name=ns)
  File "/home/alberto/.virtualenvs/GMM/local/lib/python2.7/site-packages/tensorflow/contrib/distributions/python/ops/wishart.py", line 125, in __init__
    dtype=self._scale_operator_pd.dtype, name="dimension")
  File "/home/alberto/.virtualenvs/GMM/local/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 651, in convert_to_tensor
    as_ref=False)
  File "/home/alberto/.virtualenvs/GMM/local/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 730, in internal_convert_to_tensor
    dtype.name, ret.dtype.name))
RuntimeError: dimension: Conversion function <function _constant_tensor_conversion_function at 0x7f35210922a8> for type <type 'object'> returned incompatible dtype: requested = float64_ref, actual = float64

Do you think it is a problem with TensorFlow's WishartCholesky? Do you have an example model that uses the Wishart distribution in Edward?

AlexLewandowski (Contributor) commented:

I was working on a similar problem, but used only tf.float32 and it ran fine. I understand this may not be relevant to you anymore, but it may help someone else!
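
For what it's worth, a sketch of how that change maps onto the snippet above (the exact mapping is my assumption):

# Hypothetical float32 variant of the priors and data from the original snippet.
v_prior = tf.Variable(3., dtype=tf.float32, trainable=False)
m_prior = tf.Variable(np.array([0.5, 0.5], dtype=np.float32), trainable=False)
xn_data = xn_data.astype(np.float32)  # cast the observations to match the model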

gundeep59 commented:

@bertini36 Were you able to resolve this? I am facing the same problem. @AlexLewandowski Any links to your code?

bertini36 (Author) commented:

Nope, sorry @gundeep59. I tried it using float32 but it didn't work for me. Same error.

bertini36 (Author) commented:

@AlexLewandowski Can you tell us which Edward and TensorFlow versions you used?

gundeep59 commented:

@AlexLewandowski Also, did you use WishartCholesky or WishartFull from the models? I tried both, but WishartFull returned an error too.

dustinvtran (Member) commented Apr 4, 2017

Can you report your versions of Edward and TensorFlow?

gundeep59 commented:

Edward: 1.2.4, TensorFlow: 1.0.1

bertini36 (Author) commented:

Latest versions:
edward==1.2.4
tensorflow==1.0.1

dustinvtran (Member) commented:

Looks like a bug with the copy function when using a tf.Variable with trainable=False. I'll see if I can find the issue.

For now, if you replace the prior hyperparameters with

v_prior = tf.constant(3., dtype=tf.float64)
W_prior = tf.constant(np.array([[1., 0.], [0., 1.]]),
                      dtype=tf.float64)
m_prior = tf.constant(np.array([0.5, 0.5]), dtype=tf.float64)
k_prior = tf.constant(0.6, dtype=tf.float64)

this will run. (Note that an additional error then appears: the multivariate normal's covariance matrix is overparameterized, and the parameterization does not guarantee positive semi-definiteness. Optimizing it freely would be easier if you use MultivariateNormalCholesky.)
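
For instance, a minimal sketch of a Cholesky-parameterized variational posterior for mu, assuming Edward's MultivariateNormalCholesky (names like qmu_chol are illustrative):

from edward.models import MultivariateNormalCholesky

# Optimize a free [D, D] matrix but keep only its lower triangle, so the
# implied covariance qmu_chol * qmu_chol^T is positive semi-definite.
qmu_loc = tf.Variable(tf.random_normal([D], dtype=tf.float64))
qmu_chol = tf.matrix_band_part(
    tf.Variable(tf.random_normal([D, D], dtype=tf.float64)), -1, 0)
qmu = MultivariateNormalCholesky(qmu_loc, qmu_chol)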

gundeep59 commented:

@dustinvtran Thanks for the help! I tried using MultivariateNormalCholesky for mu as well as xn, and in the variational model for qmu, but I'm facing an error in the Cholesky decomposition.

dustinvtran (Member) commented Apr 4, 2017

The Cholesky decomposition only raises an error if it's trying to decompose a matrix that isn't positive definite. This shouldn't be an issue if you're directly optimizing with respect to a lower triangular matrix (by construction its outer product with itself will always be positive definite). Can you provide more details on the error and modified code?
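
For instance, a small sketch of that construction (names are illustrative; note the diagonal of the factor should also be kept nonzero, e.g. forced positive, to get strict rather than semi- positive definiteness):

# L_raw is free; keep its lower triangle and force the diagonal positive so
# that L * L^T is strictly positive definite, not just positive semi-definite.
L_raw = tf.Variable(tf.random_normal([D, D], dtype=tf.float64))
L_tril = tf.matrix_band_part(L_raw, -1, 0)
L_posdiag = tf.matrix_set_diag(
    L_tril, tf.nn.softplus(tf.matrix_diag_part(L_tril)))
covariance = tf.matmul(L_posdiag, L_posdiag, transpose_b=True)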

gundeep59 commented:

Sure. The code for the probabilistic model and the inference is as below.

# Probabilistic model
sigma = WishartCholesky(df=v_prior, scale=W_prior)
mu = MultivariateNormalCholesky(m_prior, k_prior * sigma)
xn = MultivariateNormalCholesky(tf.reshape(tf.tile(mu, [N]), [N, D]),
                                tf.reshape(tf.tile(sigma, [N, 1]), [N, D, D]))

print('***** PROBABILISTIC MODEL *****')
print('mu: {}'.format(mu))
print('sigma: {}'.format(sigma))
print('xn: {}'.format(xn))

# Variational model
qmu = MultivariateNormalCholesky(
    tf.Variable(tf.random_normal([D], dtype=tf.float64), name='v1'),
    tf.nn.softplus(
        tf.Variable(tf.random_normal([D, D], dtype=tf.float64), name='v2')))
qsigma = WishartCholesky(
    tf.nn.softplus(
        tf.Variable(tf.random_normal([], dtype=tf.float64), name='v3')),
    tf.nn.softplus(
        tf.Variable(tf.random_normal([D, D], dtype=tf.float64), name='v4')))

print('***** VARIATIONAL MODEL *****')
print('qmu: {}'.format(qmu))
print('qsigma: {}'.format(qsigma))

inference = ed.KLqp({mu: qmu, sigma: qsigma}, data={xn: xn_data})
inference.run(n_iter=2000, n_samples=20, n_print=100)

And the error is

InvalidArgumentError (see above for traceback): LLT decomposition was not successful. The input might not be valid.
[[Node: inference_4665971536/8/WishartCholesky_6/log_prob/Cholesky = Cholesky[T=DT_DOUBLE, _device="/job:localhost/replica:0/task:0/cpu:0"]]]

This "WishartCholesky_6" is for qsigma.

bertini36 (Author) commented:

I was trying to generate positive definite matrices this way:

# M * M^T + D * I is positive definite by construction
random_matrix_1 = tf.Variable(tf.random_normal([D, D], dtype=tf.float64))
qmu = MultivariateNormalCholesky(
    tf.Variable(tf.random_normal([D], dtype=tf.float64)),
    tf.matmul(random_matrix_1, tf.transpose(random_matrix_1))
    + D * tf.eye(D, dtype=tf.float64))
random_matrix_2 = tf.Variable(tf.random_normal([D, D], dtype=tf.float64))
qsigma = WishartCholesky(
    df=tf.nn.softplus(tf.Variable(tf.random_normal([], dtype=tf.float64))),
    scale=tf.matmul(random_matrix_2, tf.transpose(random_matrix_2)) +
          D * tf.eye(D, dtype=tf.float64))

But I get the same error.

dustinvtran (Member) commented Apr 6, 2017

@gundeep59 | you overparameterized qmu and qsigma. There's no guarantee their scale parameters are lower triangular, which might be an issue.

@bertini36 | I think this runs into the same issue?

In general, it's worth debugging where things go wrong. For example, first try doing the optimization with a fixed scale parameter. If that works, then the second step is to find the specific values of the variables parameterizing scale that break the Cholesky decomposition. This will help diagnose whether there's an error in how TensorFlow implemented the Cholesky operator.
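
A minimal sketch of that first step, adapted to the code above (the identity scale and the qsigma_debug name are illustrative):

# Debugging step 1 (hypothetical): fix qsigma's scale at a known positive
# definite constant and learn only df; if this runs, the free scale
# parameterization is what breaks the Cholesky decomposition.
qsigma_debug = WishartCholesky(
    df=tf.nn.softplus(tf.Variable(tf.random_normal([], dtype=tf.float64))) + D,
    scale=tf.constant(np.eye(D)))  # identity: trivially positive definite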

deoxyribose commented May 24, 2017

Struggling with the same errors, I resolved them by constraining qsigma like this:

from tensorflow.contrib.linalg import LinearOperatorTriL

L = tf.Variable(tf.random_normal([D, D], dtype=tf.float32))
qsigma = WishartCholesky(
    tf.nn.softplus(
        tf.Variable(tf.random_normal([], dtype=tf.float32))) + D + 1,  # df > D
    LinearOperatorTriL(L).to_dense())  # keep only the lower triangle of L

The problem with the code posted by @bertini36 and @gundeep59 is probably that the df parameter is not constrained to be greater than or equal to the dimensionality D, and that the scale is not lower triangular.
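
By the same reasoning, qmu's scale can be constrained too; a sketch under the same assumptions (float32, LinearOperatorTriL) as above:

# Constrain qmu's scale to a lower-triangular Cholesky factor as well.
qmu = MultivariateNormalCholesky(
    tf.Variable(tf.random_normal([D], dtype=tf.float32)),
    LinearOperatorTriL(
        tf.Variable(tf.random_normal([D, D], dtype=tf.float32))).to_dense())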

dustinvtran (Member) commented:

Thanks @deoxyribose. Closing issue.
