Let $s$ be a random variable with a discrete probability distribution that depends on $\theta$. Let $z$ be a Gaussian random variable whose mean and variance depend on $\theta$ and $s$. Say we have an objective function $L_{\theta}(z)$. We want to compute the gradient of the expected objective $\mathbb{E}_{s,z}[L_{\theta}(z)]$ w.r.t. $\theta$. We would like to use Monte Carlo, but note that straightforward sampling of $(s,z)$ depends on $\theta$, so the samples themselves cannot be differentiated directly. To compute the gradient using sampling, we will use three different methods (elementary, Gumbel-softmax, and the log trick) and then compare them.

In [66]:
import numpy as np
from matplotlib import pyplot as plt
from scipy import stats
import tensorflow as tf

# Problem size: `dim` input features, `buckets` possible values of the
# discrete variable s.
dim = 10
buckets = 5

# Unnormalised log-probabilities (logits) of s, shape (1, buckets), and the
# categorical distribution they induce.
log_P = tf.Variable(np.random.normal(0,1, (1,buckets)).astype(np.float32))
# NOTE(review): P is evaluated once here; if gradients w.r.t. log_P are taken
# later under a GradientTape, P presumably has to be recomputed inside the
# tape -- confirm against the gradient code.
P = tf.keras.activations.softmax(log_P)

# The second positional parameter of tf.Variable is `trainable`, not `name`,
# so the names must be passed by keyword (the positional form left the
# variables unnamed and stored the string in `trainable`).
# A, B: weights and bias of the mean head, mu = sigmoid([X, s] A + B).
A = tf.Variable(np.random.normal(0, 1, (dim + buckets,1)).astype(np.float32), name='A')
B = tf.Variable(np.random.normal(0, 1, 1).astype(np.float32), name='B')
# C, D: weights and bias of the log-std head, log_sigma = [X, s] C + D.
C = tf.Variable(np.random.normal(0, 1, (dim + buckets,1)).astype(np.float32), name='C')
D = tf.Variable(np.random.normal(0, 1, 1).astype(np.float32), name='D')
# X: a single row of input features, shape (1, dim).
X = tf.Variable(np.random.normal(0, 1, (1, dim)).astype(np.float32), name='X')

def get_mu_log_sigma(ss_oh):
    """Compute the Gaussian mean and log-std of z for each row of `ss_oh`.

    Args:
        ss_oh: 2-D tensor with `buckets` columns; each row is a (possibly
            relaxed) one-hot encoding of s.

    Returns:
        Tuple ``(mu, log_sigma)``, each with one row per row of `ss_oh` and a
        single column. ``mu`` is squashed into (0, 1) by a sigmoid.
    """
    n_rows = len(ss_oh)
    # Repeat the single feature row X so it can sit next to every encoding.
    tiled_x = tf.broadcast_to(X, (n_rows, dim))
    features = tf.concat([tiled_x, ss_oh], axis=1)

    mean_logits = tf.matmul(features, A) + B
    log_std = tf.matmul(features, C) + D
    return tf.sigmoid(mean_logits), log_std
    
# Number of Gumbel noise rows (relaxed samples of s) drawn per estimate.
L1=100
# Number of standard-normal noise draws used to sample z.
L2=1000
# Seed for the z noise used by elem_loss and gumbel_loss.
seed=1

def elem_mean():
    """Return the exact P-weighted mean of mu(s), summing over every bucket.

    Returns:
        Tensor of shape (1, 1) holding ``sum_s P(s) * mu(s)``.
    """
    # One one-hot row per bucket, so `mu` has one row per possible value of s.
    ss = np.arange(buckets)
    ss_oh = tf.keras.backend.one_hot(ss, buckets)
    # Only the mean is needed here; the log-std output is discarded (the
    # previous version computed an unused `sigma`).
    mu, _ = get_mu_log_sigma(ss_oh)
    return tf.matmul(P, mu)

def gumbel_mean(tau=0.01):
    """Estimate the mean of mu(s) from `L1` Gumbel-softmax relaxed samples.

    Args:
        tau: softmax temperature dividing the perturbed logits.

    Returns:
        Scalar tensor: the average of mu over the `L1` relaxed samples.
    """
    # Standard Gumbel noise by inverse transform: g = -log(-log(u)).
    us = np.random.uniform(0, 1, (L1, buckets))
    # Cast explicitly so the noise matches log_P's float32 dtype instead of
    # relying on implicit float64 -> float32 conversion.
    gs = (-np.log(-np.log(us))).astype(np.float32)
    ss_oh = tf.keras.activations.softmax((log_P + gs)/tau)
    # Only the mean is needed; the log-std output is discarded (the previous
    # version computed an unused `sigma`).
    mu, _ = get_mu_log_sigma(ss_oh)
    return tf.reduce_mean(mu)


def elem_loss():
    """Estimate the mean of z**2, summing exactly over s and sampling z.

    For every bucket, z is drawn as ``mu + sigma * eps`` using `L2`
    standard-normal draws shared across buckets; the per-bucket values of
    z**2 are weighted by P and averaged over the draws.

    Returns:
        Scalar tensor with the estimated loss.
    """
    ss = np.arange(buckets)
    ss_oh = tf.keras.backend.one_hot(ss, buckets)
    mu, log_sigma = get_mu_log_sigma(ss_oh)
    sigma = tf.exp(log_sigma)
    # Use a private generator: it yields the same draws as reseeding the
    # global NumPy RNG with `seed`, but leaves the global state alone so
    # other sampling (e.g. the Gumbel noise) is not forced into a fixed
    # sequence after this call.
    rng = np.random.RandomState(seed)
    eps = rng.normal(0, 1, (len(P), L2)).astype(np.float32)
    zs = sigma * eps + mu
    return tf.reduce_mean(tf.matmul(P, zs**2))

def gumbel_loss(tau=0.01):
    """Estimate the mean of z**2 with Gumbel-softmax samples of s.

    Draws `L1` relaxed one-hot samples of s, then samples z for each of them
    as ``mu + sigma * eps`` with `L2` standard-normal draws shared across the
    relaxed samples.

    Args:
        tau: softmax temperature dividing the perturbed logits.

    Returns:
        Scalar tensor: the average of z**2 over all L1 x L2 samples.
    """
    # Standard Gumbel noise by inverse transform: g = -log(-log(u)).
    us = np.random.uniform(0, 1, (L1, buckets))
    gs = (-np.log(-np.log(us))).astype(np.float32)
    ss_oh = tf.keras.activations.softmax((log_P + gs)/tau)
    mu, log_sigma = get_mu_log_sigma(ss_oh)
    sigma = tf.exp(log_sigma)
    # Use a private generator: it yields the same draws as reseeding the
    # global NumPy RNG with `seed`, but leaves the global state alone so the
    # Gumbel noise of later calls is not forced into a fixed sequence.
    rng = np.random.RandomState(seed)
    eps = rng.normal(0, 1, (1, L2)).astype(np.float32)
    zs = sigma * eps + mu
    return tf.reduce_mean(zs**2)

# Sanity check: show the exact mean next to the Gumbel-softmax estimate.
elem_mean(), gumbel_mean()

(<tf.Tensor: id=3383, shape=(1, 1), dtype=float32, numpy=array([[0.9990226]], dtype=float32)>,
 <tf.Tensor: id=3406, shape=(), dtype=float32, numpy=0.99904543>)