In [31]:
from jax.random import PRNGKey, randint, uniform, permutation, split
import jax.numpy as jnp
import jax.dlpack
from jax.lax import scan
from jax import vmap, jit, value_and_grad
from jax.nn import softmax
from jax.scipy.special import expit, logit
from jax.example_libraries import optimizers

import os
import matplotlib.pyplot as plt
import numpy as np
import struct
import tensorflow as tf
import tensorflow_probability as tfp
import distrax
from distrax._src.utils import jittable

import itertools
import haiku as hk

tfd = tfp.distributions

In [56]:
# %pip install -U tensorflow_datasets
import tensorflow_datasets as tfds
from typing import Any, Iterator, Mapping, NamedTuple, Sequence, Tuple
Batch = Mapping[str, np.ndarray]

def load_dataset(split: str, batch_size: int) -> Iterator[Batch]:
  ds = tfds.load("binarized_mnist", split=split, shuffle_files=True)
  ds = ds.shuffle(buffer_size=10 * batch_size)
  ds = ds.batch(batch_size)
  ds = ds.prefetch(buffer_size=5)
  ds = ds.repeat()
  data = tfds.as_numpy(ds)
  return iter(tfds.as_numpy(ds))

In [34]:
from typing import Tuple

class Encoder(hk.Module):
  """Encoder model."""

  def __init__(self, hidden_size: int = 512, latent_size: int = 10):
    super().__init__()
    self._hidden_size = hidden_size
    self._latent_size = latent_size
    

  def __call__(self, x: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
    x = hk.Flatten()(x)
    x = hk.Linear(self._hidden_size)(x)
    x = jax.nn.relu(x)
    x = hk.Linear(self._latent_size)(x)
    x = jax.nn.sigmoid(x)
    mean = hk.Linear(self._latent_size)(x)
    log_stddev = hk.Linear(self._latent_size)(x)
    stddev = jnp.exp(log_stddev)
    return mean, stddev

In [16]:
import optax

OptState = Any
no_of_steps = 5000
batch_size = 128
no_of_batches = 10
latent_size = 10
batch_size = 10

model = hk.transform(
    lambda x: Encoder(latent_size)(x),  # pylint: disable=unnecessary-lambda
    apply_rng=True)
optimizer = optax.adam(0.001)

@jax.jit
def update(
    params: hk.Params,
    rng_key: PRNGKey,
    opt_state: OptState,
    batch: Batch,
) -> Tuple[hk.Params, OptState]:
    """Single SGD update step."""
    grads = jax.grad(loss_fn)(params, rng_key, batch)
    updates, new_opt_state = optimizer.update(grads, opt_state)
    new_params = optax.apply_updates(params, updates)
    return new_params, new_opt_state

@jax.jit
def loss_fn(params: hk.Params, rng_key: PRNGKey, batch: Batch) -> jnp.ndarray:
    """Loss = -ELBO, where ELBO = E_q[log p(x|z)] - KL(q(z|x) || p(z))."""
    mean,std = model.apply(params, rng_key, batch["image"])
    likelihood_distrib = distrax.Independent(
        distrax.Bernoulli(logits=10))
    
    z = likelihood_distrib.sample(seed=hk.next_rng_key())
    print('z : ',z.shape)
    # p(z) = N(0, I)
    # prior_z = distrax.MultivariateNormalDiag(
    #     loc=jnp.zeros((latent_size,)),
    #     scale_diag=jnp.ones((latent_size,)))

    log_likelihood = likelihood_distrib.log_prob(batch["image"])
    return -jnp.mean(log_likelihood)
rng_seq = hk.PRNGSequence(42)
params = model.init(next(rng_seq), np.zeros((1, *(1,28,28))))
opt_state = optimizer.init(params)
train_data = load_dataset(tfds.Split.TRAIN, batch_size)
test_data = load_dataset(tfds.Split.TEST, batch_size)


for step in range(no_of_steps):
    params, opt_state = update(params, next(rng_seq), opt_state, next(train_data))

    if step % 100 == 0:
        loss = loss_fn(params,next(rng_seq),next(test_data))
        print(f' Step: {step}; Validation loss : {-loss}')


  unscaled = jax.random.truncated_normal(
  param = init(shape, dtype)
2023-12-19 19:40:28.714288: W tensorflow/core/kernels/data/cache_dataset_ops.cc:858] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
  return getattr(self.aval, name).fun(self, *args, **kwargs)


ValueError: `hk.next_rng_key` must be used as part of an `hk.transform`