Construct a probabilistic mapping from one set of integers to another: randomly generate the mapping, write a function that performs it, and then try to learn the mapping back from samples.

In [104]:
# Colab setup: install optax into this kernel's environment (it is imported in the next cell).
# NOTE(review): optax is not used in any of the cells shown here — TODO confirm it is still needed, and pin a version if so.
%pip install optax

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [126]:
from collections import namedtuple
from functools import partial
import jax
import jax.numpy as jnp
from jax import random
from jax import jit, vmap, grad
from jax import nn
import optax
import numpy as np
rng = jax.random.PRNGKey(42)  # root key (seed 42); later cells do `rng, r = random.split(rng)` to get fresh subkeys

In [106]:
# How would we model a probability distribution over a set of choices?
# Start with a set of unnormalized values (logits) over the choices, then turn
# them into a probability distribution and sample from it.

In [107]:
# One unnormalized score (logit) per choice, drawn from a standard normal.
num_choices = 16
rng, r = jax.random.split(rng)
logits = jax.random.normal(r, (num_choices,))
logits

DeviceArray([ 0.7679137 ,  0.46966743, -1.4884446 , -1.155719  ,
             -1.2574353 , -0.66447204, -1.3314192 ,  0.6470733 ,
             -0.55124223, -0.3213504 ,  1.1168208 , -0.4216216 ,
              0.32398054,  1.3500887 , -0.22909231,  0.24462827],            dtype=float32)

In [108]:
# now sample from the logits
# Draw num_samples choice indices from the categorical distribution defined by `logits`.
num_samples = 1 << 14  # 16384
rng, r = jax.random.split(rng)
samples = jax.random.categorical(r, logits=logits, shape=(num_samples,))
samples.shape

(16384,)

In [109]:
# do the sample distributions line up with the
# probabilities?
# Exact probabilities implied by the logits (softmax over the last axis).
probs = jax.nn.softmax(logits, axis=-1)
probs

DeviceArray([0.10999867, 0.08163206, 0.01152029, 0.01606809, 0.01451408,
             0.02626094, 0.01347903, 0.09747812, 0.02940934, 0.0370106 ,
             0.15592505, 0.03347949, 0.07056507, 0.19688964, 0.04058759,
             0.06518198], dtype=float32)

In [110]:
# they do!
# Empirical frequency of each choice, compared with its softmax probability.
samples_oh = nn.one_hot(samples, num_choices)  # one row per sample, one column per choice
samples_counts = jnp.sum(samples_oh, axis=0)
samples_dist = samples_counts / num_samples
abs_err = np.abs(probs - samples_dist)
# Each empirical frequency has std sqrt(p(1-p)/num_samples) <= 0.5/sqrt(num_samples)
# (~0.004 for 2**14 samples). A few multiples of that is a safe tolerance, so a
# fresh Restart & Run All checks the "they do!" claim instead of relying on eyeballing.
assert abs_err.max() < 4 * 0.5 / np.sqrt(num_samples), abs_err.max()
abs_err

array([9.2884898e-04, 8.2150847e-04, 7.6388009e-05, 2.8932840e-04,
       4.1496195e-04, 2.0593740e-03, 1.9018287e-03, 1.4087856e-03,
       4.7867931e-04, 8.4325671e-05, 3.7644058e-03, 1.7988309e-03,
       8.4606558e-04, 1.4746189e-03, 2.4493411e-04, 7.8988820e-04],
      dtype=float32)

In [112]:
# So now we know how to generate a random distribution over
# categorical choices and sample from it. Assume we want to model a
# mapping from x_choices to y_choices: we will need one row per
# x_choice and one column per y_choice.
# Size of the input (x) and output (y) integer sets for the mapping.
x_choices, y_choices = 16, 8

In [113]:
# To apply the mapping to an input x, select that input's distribution and sample from it.
# Call the mapping w (for weights): row x of w holds the logits over the y choices.
@jit
def predict(rng, x, w):
    """Draw one y index for input ``x`` using row ``x`` of ``w`` as logits.

    Args:
        rng: PRNG key consumed by the categorical draw.
        x: integer index selecting a row of ``w``.
        w: array of unnormalized logits; in this notebook it has one row per
           x choice and one column per y choice.

    Returns:
        The sampled column index (a scalar int array when ``w[x]`` is 1-D).
    """
    row_logits = w[x]
    return random.categorical(rng, row_logits)

In [114]:
# now let's put it together and see if we can learn the weights
# generate the weights
# Ground-truth mapping: standard-normal logits, one row per x choice and one column per y choice.
rng, r = jax.random.split(rng)
w = jax.random.normal(r, shape=(x_choices, y_choices))

In [115]:
# create a sampler to generate a training pair
@jit
def sample(rng, w):
    """Draw one (x, y) training pair from the mapping ``w``.

    x is drawn uniformly from the row indices of ``w``; y is then drawn by
    ``predict`` from the categorical distribution in row x.

    NOTE(review): a later cell redefines ``sample`` for the Gaussian model,
    shadowing this function from that point on.

    Args:
        rng: PRNG key; split internally so x and y use independent subkeys.
        w: logits array with one row per x choice.

    Returns:
        (x, y) as scalar int arrays.
    """
    rx, ry = random.split(rng)
    # Take the range of x from w's (static) row count instead of the global
    # x_choices. JAX clamps out-of-bounds indices, so a mismatch between the two
    # would silently skew the sampled pairs instead of raising.
    num_x = w.shape[0]
    x = random.randint(rx, shape=(), minval=0, maxval=num_x)
    y = predict(ry, x, w)
    return x, y

# Draw one (x, y) training pair using a fresh subkey.
rng, r = jax.random.split(rng)
sample(r, w)

(DeviceArray(3, dtype=int32), DeviceArray(1, dtype=int32))

In [117]:
# Define the loss function. Actually, this isn't going to work:
# our loss function needs to be differentiable, and the predict function
# isn't differentiable. What would the loss even be? When
# we call predict, we get a single categorical value back,
# so how would we compute a loss in a way that connects
# back to the weights?

# I think this is where reparameterization trick(s) come in
# (e.g. a Gumbel-softmax / Concrete relaxation for categorical samples).

In [118]:
# So let's take a step back and just try to learn the mean and
# standard deviation (std) of a normal distribution.

In [120]:
mean, std = 1.0, 2.5
rng, r = jax.random.split(rng)
z = jax.random.normal(r)  # one standard-normal draw (shape ())
mean + std * z  # shift and scale z to get a draw from N(mean, std**2)

DeviceArray(-5.5664186, dtype=float32)

In [124]:
# run this a bunch and see if we get back mean and std from the samples
# Shift and scale many standard-normal draws, then check that the sample
# statistics recover mean and std.
num_samples = 10000
rng, r = jax.random.split(rng)
zs = jax.random.normal(r, (num_samples,))
ys = mean + std * zs
ys.mean(), ys.std()

(DeviceArray(1.0074186, dtype=float32), DeviceArray(2.508156, dtype=float32))

In [125]:
# OK, so the problem is that a random draw isn't differentiable with respect to
# the distribution's parameters. But the call to random.normal itself
# doesn't actually depend on our weights (mean, std): we draw z ~ N(0, 1) and
# then compute mean + std * z, which is differentiable in mean and std. So we
# can come up with a loss function that is differentiable and pulls the mean and
# std in the right direction.

In [127]:
# let's clean this up for a second, our
# weights will be a mean and std
# The model's weights as a namedtuple; JAX treats it as a pytree, so grad
# returns a GaussianModel of per-field gradients.
GaussianModel = namedtuple('GaussianModel', ['mean', 'std'])
w = GaussianModel(1.0, 2.5)

In [137]:
# Our sample function performs the reparameterized sampling from above:
# draw z ~ N(0, 1) and return mean + std * z.
@jit
def sample(rng, w):
    """Draw one value from N(w.mean, w.std**2) via reparameterization.

    All randomness comes from a standard-normal ``noise`` drawn with ``rng``.
    For a fixed key, the result ``w.mean + w.std * noise`` is a deterministic,
    differentiable function of ``w``.

    NOTE(review): this replaces the earlier (x, y) pair sampler of the same name.

    Args:
        rng: PRNG key for the standard-normal draw.
        w: object with ``mean`` and ``std`` fields (e.g. ``GaussianModel``).

    Returns:
        A scalar float array.
    """
    noise = random.normal(rng, shape=())
    return w.mean + w.std * noise
# One reparameterized draw from N(w.mean, w.std**2) with a fresh subkey.
rng, r = jax.random.split(rng)
sample(r, w)

DeviceArray(1.8244866, dtype=float32)

In [143]:
# this is differentiable wrt the weights
rng, r = jax.random.split(rng)
# For mean + std * z, d/d(mean) is 1 and d/d(std) is the drawn z; the result is a GaussianModel.
jax.grad(lambda params: sample(r, params))(w)

GaussianModel(mean=DeviceArray(1., dtype=float32, weak_type=True), std=DeviceArray(0.29066396, dtype=float32, weak_type=True))

In [144]:
# Now we need a loss that will pull the mean and std
# in the right direction. This is the actual trick, in a sense, but it's not
# really a trick, is it? Or is the trick the fact that we can call sample
# in a way that keeps the random part separate from the part that depends on
# the weights?

In [146]:
# So we need to implement some loss function, which involves a bunch of math
# from Kingma & Welling, "Auto-Encoding Variational Bayes": https://arxiv.org/pdf/1312.6114v10.pdf

In [148]:
# Maybe as a first step: how would I learn this with a per-sample
# algorithm, first manually in a loop through the samples and then
# expressed as a gradient? Take a look at the formula here:
# https://stats.stackexchange.com/questions/365192/bayesian-update-for-a-univariate-normal-distribution-with-unknown-mean-and-varia