# Stochastic variational inference

In [2]:
from tqdm import tqdm_notebook as tqdm
import autograd.numpy as np

# https://github.com/bamos/opt.py/blob/master/bamos_opt/pgd.py
def psgd(x0, g, proj, eps, maxit, lam, callback=None, **kwargs):
    """Projected (stochastic) gradient descent with a fixed iteration budget.

    Starting from ``proj(x0)``, repeats ``x <- proj(x - lam(k) * g(x))`` for
    ``maxit`` iterations.

    Args:
        x0: Initial point. It is copied before being projected, so this
            function never assigns into the caller's array.
        g: Callable ``g(x, **kwargs)`` returning the (possibly stochastic)
            gradient at ``x``.
        proj: Callable mapping a point onto the feasible set.
        eps: Currently unused; kept so existing callers keep working.
        maxit: Number of iterations to run.
        lam: Callable mapping the zero-based iteration index ``k`` to a
            step size.
        callback: Optional callable ``callback(k, x, g_x, lam_val)`` invoked
            after every step with the updated iterate and the gradient and
            step size that produced it.
        **kwargs: Forwarded unchanged to ``g``.

    Returns:
        The iterate after ``maxit`` projected steps.
    """
    x = proj(np.copy(x0))

    for k in tqdm(range(maxit)):
        g_x = g(x, **kwargs)
        lam_val = lam(k)
        x = proj(x - lam_val * g_x)

        # Compare against None so a callable with a falsy truth value is
        # still invoked.
        if callback is not None:
            callback(k, x, g_x, lam_val)

    return x

def lam_inverse(init_lam):
    """Return a step-size schedule that decays as ``init_lam / (k + 1)``.

    ``psgd`` passes a zero-based iteration index, so the denominator is
    offset by one: the first step uses ``init_lam`` itself instead of
    dividing by zero.

    Args:
        init_lam: Step size used at iteration ``k = 0``.

    Returns:
        A callable mapping the iteration index ``k`` to the step size.
    """
    return lambda k: init_lam / (k + 1)

def sort_rows(z):
    """Sort every row of a 2-D array in ascending order.

    The sort is expressed as ``argsort`` followed by fancy indexing rather
    than a direct sort call (presumably so autograd can differentiate
    through the result -- TODO confirm).

    Args:
        z: 2-D array indexed as ``[row, column]``.

    Returns:
        An array of the same shape with each row sorted ascending.
    """
    perm = np.argsort(z, axis=1)
    # Column vector of row indices so it broadcasts against ``perm``.
    rows = np.arange(z.shape[0], dtype=int).reshape(-1, 1)
    return z[rows, perm]

def sparsemax_rows(z):
    """Apply sparsemax to every row of a 2-D array.

    Each row is mapped to its Euclidean projection onto the probability
    simplex: the output rows are non-negative, sum to one, and may contain
    exact zeros.

    Args:
        z: 2-D array indexed as ``[row, column]``.

    Returns:
        Array of the same shape holding the projected rows.
    """
    n_rows, ncols = z.shape
    # Rows sorted in descending order.
    u = sort_rows(z)[:, ::-1]
    cssv = u.cumsum(axis=1) - 1
    ind = np.arange(ncols) + 1
    # True for the leading entries that stay in the support.
    cond = u - cssv / ind > 0
    # argmin over booleans returns the first False position, i.e. the support
    # size; it returns 0 when every entry is True, which is mapped to ncols.
    first_false = np.argmin(cond, axis=1)
    rho = np.where(first_false > 0, first_false, ncols)
    theta = cssv[np.arange(n_rows), rho - 1] / rho
    return np.clip(z - theta.reshape(-1, 1), 0, None)

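Convention: arrays of assignments always carry a trailing sample axis, i.e. they have shape `(n_data, n_samples)`, even when there is only a single sample.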

In [109]:
from autograd.scipy.stats import norm

# Toy mixture model: every data point is drawn from a unit-variance Gaussian
# centred at the mean of its cluster; clusters are chosen with weights p.
# NOTE(review): no random seed is set in the visible cells, so the sampled
# values differ from run to run.
n_clusters, n_data = 3, 4
mu = np.arange(n_clusters)
p = np.array([0.2, 0.3, 0.5])
indices = np.arange(n_clusters)
assignments = np.random.choice(indices, size=n_data, p=p)   # drawn with weights p
assignments0 = np.random.choice(indices, size=n_data)        # drawn uniformly
x = np.random.normal(loc=mu[assignments])


def log_p(assignments):
    """Joint log-density of the data ``x`` and the given assignments.

    ``assignments`` is indexed as ``[data point, sample]``; the result has
    one entry per sample.
    """
    log_likelihood = norm.logpdf(x[:, np.newaxis], loc=mu[assignments])
    log_prior = np.log(p[assignments])
    return np.sum(log_likelihood + log_prior, axis=0)


log_p(assignments[:, np.newaxis]), log_p(assignments0[:, np.newaxis])

(array([-9.99541264]), array([-13.54310484]))

In [128]:
rows = np.arange(n_data)


def log_q(assignments, lam):
    """Log-probability of the assignments under the row-wise categorical ``lam``.

    ``assignments`` is indexed as ``[data point, sample]`` and ``lam`` as
    ``[data point, cluster]``; the result has one entry per sample.
    """
    # Pick lam[i, assignments[i, s]] for every data point i and sample s.
    picked = lam[rows[:, np.newaxis], assignments]
    return np.log(picked).sum(axis=0)


# Uniform variational parameters: every cluster equally likely for every point.
lam0 = np.ones((n_data, n_clusters)) / n_clusters
log_q(assignments[:, np.newaxis], lam0), log_q(assignments0[:, np.newaxis], lam0)

(array([-4.39444915]), array([-4.39444915]))

In [131]:
def kl_point(assignments, lam):
    """Per-sample integrand ``log q - log p`` of the KL objective.

    Returns one value per column (sample) of ``assignments``.
    """
    log_q_val = log_q(assignments, lam)
    log_p_val = log_p(assignments)
    return log_q_val - log_p_val


kl_point(assignments[:, np.newaxis], lam0)

array([5.60096349])

In [148]:
def random_choice_row(p, n_samples=1):
    """Draw categorical samples independently for every row of ``p``.

    Uses inverse-CDF sampling: a uniform draw is mapped to the first
    category whose cumulative probability exceeds it.

    Args:
        p: 2-D array indexed as ``[row, category]``; each row holds the
            probabilities of one categorical distribution.
        n_samples: Number of draws per row.

    Returns:
        Integer array of shape ``(p.shape[0], n_samples)`` with the sampled
        category indices.
    """
    u = np.random.uniform(size=(p.shape[0], n_samples))  # [row, sample]
    pf = p.cumsum(axis=1)  # [row, category]
    # Count the cumulative values that do not exceed the draw. Because pf is
    # non-decreasing along a row, this equals the index of the first category
    # whose cumulative value exceeds the draw.
    n_not_above = np.sum(u[:, np.newaxis, :] >= pf[:, :, np.newaxis], axis=1)
    # Rounding can leave the final cumulative value slightly below 1; a draw
    # at or above it then belongs to the last category, not category 0.
    return np.minimum(n_not_above, p.shape[1] - 1)

# Smoke test: one draw per row, then ten draws per row, from the uniform lam0.
random_choice_row(lam0), random_choice_row(lam0, n_samples=10)

(array([[0],
        [2],
        [1],
        [1]]), array([[0, 0, 1, 1, 2, 0, 2, 0, 0, 0],
        [2, 1, 2, 1, 1, 1, 1, 2, 1, 0],
        [1, 1, 1, 0, 1, 0, 0, 1, 1, 2],
        [0, 1, 2, 0, 0, 1, 1, 1, 0, 1]]))

In [167]:
def kl_mc(lam, n_samples=1):
    """Monte Carlo estimate of the KL objective at ``lam``.

    Draws ``n_samples`` assignment vectors from ``lam`` and averages
    ``kl_point`` over them.
    """
    draws = random_choice_row(lam, n_samples=n_samples)
    per_sample = kl_point(draws, lam)
    return np.mean(per_sample)

# Monte Carlo estimate of E_q[log q - log p] at lam0 from 100000 samples.
kl_mc(lam0, n_samples=100000)

7.340918732922332

In [161]:
# Every cluster index for every data point: row i is [0, 1, ..., n_clusters - 1].
all_assignments = np.tile(np.arange(n_clusters), (n_data, 1))

def kl(lam):
    """Exact value of ``E_q[log q - log p]`` for the row-wise categorical ``lam``.

    The expectation is computed in closed form by summing over every
    cluster for every data point, so no sampling is involved.

    Args:
        lam: Array indexed as ``[data point, cluster]`` holding the
            variational probabilities.

    Returns:
        Scalar objective value.
    """
    # Joint log-density of each data point with each possible cluster.
    log_joint = norm.logpdf(x[:, np.newaxis], loc=mu[all_assignments]) + np.log(p[all_assignments])
    # Use the convention 0 * log(0) = 0 so that exact zeros in lam (as the
    # sparsemax projection can produce) give a finite value instead of NaN.
    safe_lam = np.where(lam == 0, 1.0, lam)
    return np.sum(lam * np.log(safe_lam)) - np.sum(log_joint * lam)

# Exact objective at lam0, for comparison with the Monte Carlo estimate from kl_mc.
kl(lam0)

7.338103815469151

In [164]:
import autograd 
# Exact gradient of kl with respect to lam, obtained by automatic differentiation.
kl_grad = autograd.grad(kl)
kl_grad(lam0)

array([[2.85179924, 2.02760122, 2.09804268],
       [2.44275679, 2.69849116, 3.84886501],
       [4.32626996, 2.47323929, 1.51484811],
       [5.20904469, 2.9459195 , 1.57743379]])

In [None]:
# Exact posterior over cluster assignments, indexed as [data point, cluster].
# Each row is the joint density of that data point with every cluster,
# normalised to sum to one; this is the row-stochastic lam that minimises kl.
log_joint = norm.logpdf(x[:, np.newaxis], loc=mu[all_assignments]) + np.log(p[all_assignments])
# Subtract the row maximum before exponentiating for numerical stability.
log_joint = log_joint - np.max(log_joint, axis=1, keepdims=True)
true_posterior = np.exp(log_joint) / np.sum(np.exp(log_joint), axis=1, keepdims=True)
true_posterior

In [265]:
# Shape check: kl_point returns one value per sample (13 samples here).
kl_point(random_choice_row(lam0, n_samples=13), lam0).shape

(13,)

In [263]:
# NOTE(review): in the visible notebook log_q_grad is only assigned in a later
# cell, so this cell raises NameError when the notebook is run top to bottom
# on a fresh kernel.
log_q_grad

<function autograd.wrap_util.unary_to_nary.<locals>.nary_operator.<locals>.nary_f(*args, **kwargs)>

In [264]:
# Log-probability of a single assignment under the uniform lam0 (n_clusters = 3).
np.log(1/3)

-1.0986122886681098

In [261]:
# Jacobian of log_q with respect to its second argument (lam): one gradient
# array per sample.
log_q_grad = autograd.jacobian(log_q, argnum=1)

def kl_grad_mc(lam, n_samples=1):
    """Score-function Monte Carlo estimate of the gradient of the KL objective.

    Draws ``n_samples`` assignment vectors from ``lam``, weights the gradient
    of ``log_q`` at each draw by the corresponding ``kl_point`` value, and
    averages over the draws.

    Returns:
        Array with the same shape as ``lam``.
    """
    draws = random_choice_row(lam, n_samples=n_samples)
    weights = kl_point(draws, lam)
    scores = log_q_grad(draws, lam)
    # Broadcast the per-sample weights over the trailing two axes of the scores.
    weighted = weights[:, np.newaxis, np.newaxis] * scores
    return np.mean(weighted, axis=0)

# Score-function gradient estimate at lam0 from 100 samples.
kl_grad_mc(lam0, n_samples=100)

array([[7.68947118, 5.45555353, 8.46922381],
       [7.14464885, 6.7696205 , 7.69997918],
       [7.60757154, 6.81820734, 7.18846964],
       [8.72498135, 8.02323798, 4.86602919]])

Reference: Pyro tutorial "SVI Part III: ELBO Gradient Estimators" — http://pyro.ai/examples/svi_part_iii.html