In [1]:
import pandas as pd
import numpy as np
import pickle

In [2]:
# Load the raw pickled dataset.
# NOTE: unpickling can execute arbitrary code, so this file must come from a
# trusted source.
with open("../../data/mat_grades.pickle", "rb") as pickle_file:
    data = pickle.load(pickle_file)

In [3]:
# Remove groups with no observations

In [4]:
# Keep only the groups that contain at least one observation.
data = [group for group in data if len(group) > 0]
len(data)

1048

In [68]:
# Pad the ragged groups with NaNs into one rectangular array and add Gaussian
# jitter to every observed value.

# Standard deviation of the jitter added to each observation.
jitter_sd = 0.25
# Dedicated, seeded generator so the array saved afterwards is reproducible
# across kernel restarts (the global NumPy RNG was previously left unseeded).
jitter_rng = np.random.default_rng(0)

maxlen = max(len(x) for x in data)
# Start from an all-NaN array; the tail of every short row stays NaN.
out = np.full((len(data), maxlen), np.nan)

for i, group in enumerate(data):
    k = len(group)
    out[i, :k] = np.asarray(group) + jitter_rng.normal(size=k) * jitter_sd

In [69]:
# Persist the padded array next to the notebook.
with open("math_grades.pickle", "wb") as sink:
    pickle.dump(out, sink)

In [21]:
from jax import jit
import jax.numpy as jnp
from tensorflow_probability.substrates import jax as tfp
tfd = tfp.distributions


@jit
def get_mean_and_var(alldata, clus_idx, c):
    """Mean and variance of the observations allocated to cluster ``c``.

    Args:
        alldata: array of observations.
        clus_idx: array of cluster labels, compared elementwise with ``c``;
            entries whose label differs from ``c`` are ignored.
        c: label of the cluster to summarise.

    Returns:
        A tuple ``(ybar, var)``: the sum of the selected entries divided by
        their count, and their mean squared deviation from ``ybar`` (the
        divisor is the count itself, not ``count - 1``). When no entry carries
        the label ``c`` both divisions are 0/0 and both values are NaN.
    """
    in_cluster = clus_idx == c
    n_members = jnp.count_nonzero(in_cluster)

    # Entries outside the cluster are replaced by 0 so they add nothing to the
    # sums.
    total = jnp.where(in_cluster, alldata, 0).sum()
    ybar = total / n_members

    sq_dev = (alldata - ybar) ** 2
    var = jnp.where(in_cluster, sq_dev, 0).sum() / n_members
    return ybar, var


@jit
def norm_lpdf(data, atoms):
    """Normal log-density of ``data`` under every atom.

    Args:
        data: array of observations.
        atoms: sequence of atoms; element 0 of each atom is used as the mean
            and element 1 as the variance (its square root is the scale).

    Returns:
        The log-densities with all axes reversed by the final transpose. For
        2-D ``data`` of shape ``(n, m)`` and ``K`` atoms the result has shape
        ``(m, n, K)``.
    """
    # Stack the per-atom parameters into (K, 1) columns.
    locs = jnp.vstack([atom[0] for atom in atoms])
    variances = jnp.vstack([atom[1] for atom in atoms])
    scales = jnp.sqrt(variances)

    # The extra axis makes the parameters (K, 1, 1) so they broadcast against
    # the data.
    kernels = tfd.Normal(locs[:, jnp.newaxis], scales[:, jnp.newaxis])
    log_dens = kernels.log_prob(data)
    return log_dens.T

In [22]:
# Random allocation to one of 10 clusters for every slot of `out`; slots that
# hold NaN get the sentinel label -10 instead.
clus = np.where(np.isnan(out), -10, np.random.choice(10, size=out.shape))
# Cluster label used for the check of get_mean_and_var.
c = 2

In [23]:
# Convert the NumPy data and labels to JAX arrays in one step.
alldata, clus = jnp.array(out), jnp.array(clus)

# Display the summary statistics of cluster `c`.
get_mean_and_var(alldata, clus, c)

(DeviceArray(6.019065, dtype=float32), DeviceArray(2.098379, dtype=float32))

In [52]:
from jax import random
from jax.ops import index_update

seed = 0
key = random.PRNGKey(seed)
# A JAX key must not be consumed again once it has been split, so every random
# draw below gets its own sub-key; `rng_key` is the one to carry forward.
rng_key, atoms_key, subkey = random.split(key, 3)

# Draw the 10 atoms from the seeded JAX generator (column 0 is used as the
# mean, column 1 as the variance) so that the whole cell is reproducible.
atoms = random.uniform(atoms_key, shape=(10, 2))
likes = norm_lpdf(alldata, atoms)

# Turn the log-likelihoods into allocation probabilities. The atoms live on
# the LAST axis of `likes`, so that is the axis to normalise over (the previous
# code normalised over the groups axis). Subtracting the per-row maximum before
# exponentiating avoids underflow and leaves the normalised result unchanged.
probas = jnp.exp(likes - jnp.max(likes, axis=-1, keepdims=True))
probas /= jnp.sum(probas, axis=-1, keepdims=True)

# Sample with the fresh sub-key, then transpose back to the layout of
# `alldata`.
clus = tfd.Categorical(probs=probas).sample(seed=subkey).T
# Missing observations (NaN) get the sentinel label -10. `jnp.where` replaces
# the deprecated `jax.ops.index_update`.
clus = jnp.where(jnp.isnan(alldata), -10, clus)

In [63]:
from jax.ops import index

array([[False, False, False, ...,  True,  True,  True],
       [False, False, False, ...,  True,  True,  True],
       [False, False, False, ...,  True,  True,  True],
       ...,
       [False, False, False, ...,  True,  True,  True],
       [False, False, False, ...,  True,  True,  True],
       [False, False, False, ...,  True,  True,  True]])

In [35]:
# Shape of the padded data: (number of groups, length of the longest group).
alldata.shape

(1048, 131)

In [30]:
# Shape of the log-likelihoods returned by norm_lpdf: the axes of `alldata`
# reversed, followed by one axis for the atoms.
likes.shape

(131, 1048, 10)