In [1]:
import logging

import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

In [2]:
# Keep notebook output readable: drop Python-side "tensorflow" log records below ERROR.
tf_logger = logging.getLogger("tensorflow")
tf_logger.setLevel(logging.ERROR)

In [3]:
# Working floating-point precision for tensors built in this notebook.
dtype = tf.float32

In [4]:
# Problem size: N observations, p features.
N = 100000
p = 1000

# Noise variance of the simulated linear model y = X @ beta + error.
sigma_squared = tf.constant(1.0, dtype=dtype)

# Build every distribution from `dtype` tensors so all samples share one dtype
# (the original mixed int and float literals here).
X_dist = tfd.Uniform(low=tf.constant(-1.0, dtype=dtype), high=tf.constant(1.0, dtype=dtype))
beta_dist = tfd.Uniform(low=tf.constant(-10.0, dtype=dtype), high=tf.constant(10.0, dtype=dtype))
error_dist = tfd.Normal(loc=tf.constant(0.0, dtype=dtype), scale=sigma_squared ** 0.5)

# Sample parameters. Each draw gets its own op-level seed: reusing seed=0 for all
# three may, depending on the TF/TFP version, make independent draws share the same
# underlying random bits.
tf.random.set_seed(0)
beta = beta_dist.sample((p, 1), seed=1, name="beta")    # (p, 1)
X = X_dist.sample((N, p), seed=2, name="X")              # (N, p)
error = error_dist.sample((N, 1), seed=3, name="error")  # (N, 1)

# Generate responses and the OLS fit.
y = X @ beta + error
# OLS solves the normal equations (X^T X) beta = X^T y. X^T X is symmetric positive
# definite when X has full column rank, so a Cholesky solve is used instead of an
# explicit matrix inverse (slower and less numerically stable, notably in float32).
XtX = tf.matmul(X, X, transpose_a=True)  # (p, p)
Xty = tf.matmul(X, y, transpose_a=True)  # (p, 1)
beta_hat_ols = tf.linalg.cholesky_solve(tf.linalg.cholesky(XtX), Xty)
y_hat_ols = X @ beta_hat_ols             # (N, 1)


In [5]:
# Sanity-check the tensors produced by the data-generation cell before benchmarking,
# so every timing below runs on the intended shapes.
assert X.shape == (N, p), X.shape
assert beta.shape == (p, 1), beta.shape
assert y.shape == (N, 1), y.shape
assert y_hat_ols.shape == (N, 1), y_hat_ols.shape
print((N, p))

(100000, 1000)


In [6]:
def f():
    y_dist = tfd.MultivariateNormalDiag(y_hat_ols, scale_diag=tf.ones((N, 1)) * sigma_squared**0.5)
    log_probs = y_dist.log_prob(y)
    l = -1 * tf.reduce_sum(log_probs)
    return l


%timeit f()

9.2 ms ± 1.26 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [7]:
def f():
    y_dist = tfd.Normal(y_hat_ols, sigma_squared**0.5)
    log_probs = y_dist.log_prob(y)
    l = -1 * tf.reduce_sum(log_probs)
    return l


%timeit f()

1.18 ms ± 112 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [8]:
def f():
    y_dist = tfd.Normal(y_hat_ols[..., 0], sigma_squared**0.5)
    log_probs = y_dist.log_prob(y[..., 0])
    l = -1 * tf.reduce_sum(log_probs)
    return l


%timeit f()

1.43 ms ± 186 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [9]:
# Hoist the trailing-axis slice out of the timed function, so the next benchmark
# measures only distribution construction, log_prob and the reduction.
y_hat_ols_vector, y_vector = y_hat_ols[..., 0], y[..., 0]

In [10]:
def f():
    y_dist = tfd.Normal(y_hat_ols_vector, sigma_squared**0.5)
    log_probs = y_dist.log_prob(y_vector)
    l = -1 * tf.reduce_sum(log_probs)
    return l


%timeit f()

1.1 ms ± 54.1 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
