# Unamortised VAE implementation

In [1]:
import os

os.environ["TPU_CHIPS_PER_HOST_BOUNDS"] = "1,1,1"
os.environ["TPU_HOST_BOUNDS"] = "1,1,1"
os.environ["TPU_VISIBLE_DEVICES"] = "0"

In [65]:
import jax
import hlax
import jax.numpy as jnp
import flax.linen as nn

In [66]:
%load_ext autoreload
%autoreload 2

In [3]:
class Decoder(nn.Module):
    """
    For the generative model
    p(x,z) = p(x|z) * p(z)
    """
    dim_full: int
    dim_latent: int = 20
    
    def setup(self):
        self.mean = nn.Dense(self.dim_full, use_bias=True, name="mean")
        self.logvar = nn.Dense(self.dim_full, use_bias=False, name="logvar")
    
    @nn.compact
    def __call__(self, z):
        x = nn.Dense(20)(z)
        x = nn.elu(x)
        mean_x = self.mean(x)
        logvar_x = self.logvar(x)
        return mean_x, logvar_x 


class Encoder(nn.Module):
    """
    two-layered encoder
    """
    latent_dim: int
    n_hidden: int = 100
    
    @nn.compact
    def __call__(self, x):
        z = nn.Dense(self.n_hidden)(x)
        z = nn.elu(z)
        z = nn.Dense(self.n_hidden)(z)
        z = nn.elu(z)
        mean_z = nn.Dense(self.latent_dim)(z)
        logvar_z = nn.Dense(self.latent_dim)(z)
        return mean_z, logvar_z

## Initialisiation

In [17]:
key = jax.random.PRNGKey(314)

In [10]:
warmup, test = hlax.datasets.load_fashion_mnist(n_train=1000, n_test=100)
X_warmup = warmup[0]

In [12]:
lossfn_vae = hlax.losses.iwae
lossfn_hardem = hlax.losses.loss_hard_nmll

In [75]:
dim_latent = 100
num_obs, *dim_obs = X_warmup.shape

In [92]:
model_vae = hlax.models.UnamortisedVAEBern(dim_latent, dim_obs[0], Encoder, Decoder)

In [140]:
batch_size = 200
key_params_init, key_eps_init = jax.random.split(key)

# Initialise model consider total number of observations.
# We will then slice over the batches
batch_init = jnp.ones((num_obs, *dim_obs))
params = model_vae.init(key_params_init, batch_init, key_eps_init, num_samples=4)

jax.tree_map(jnp.shape, params)

FrozenDict({
    params: {
        decoder: {
            Dense_0: {
                bias: (20,),
                kernel: (100, 20),
            },
            logvar: {
                kernel: (20, 784),
            },
            mean: {
                bias: (784,),
                kernel: (20, 784),
            },
        },
        encoder: {
            Dense_0: {
                bias: (1000, 100),
                kernel: (1000, 784, 100),
            },
            Dense_1: {
                bias: (1000, 100),
                kernel: (1000, 100, 100),
            },
            Dense_2: {
                bias: (1000, 100),
                kernel: (1000, 100, 100),
            },
            Dense_3: {
                bias: (1000, 100),
                kernel: (1000, 100, 100),
            },
        },
    },
})

In [106]:
key_batch, keys_vae = jax.random.split(key)
batch_ixs = hlax.training.get_batch_train_ixs(key_batch, num_obs, batch_size)
num_batches = len(batch_ixs)
keys_vae = jax.random.split(keys_vae, num_batches)

In [107]:
batch_ix = batch_ixs[0]

## Param surgery

In [146]:
import optax
from functools import partial
from flax.core import freeze, unfreeze
from flax.training.train_state import TrainState

In [150]:
num_samples = 10
tx = optax.adam(1e-5)

state = TrainState.create(
    apply_fn=partial(model_vae.apply, num_samples=num_samples),
    params=params,
    tx=tx
)

In [168]:
state.params

FrozenDict({
    params: {
        encoder: {
            Dense_0: {
                bias: DeviceArray([[0., 0., 0., ..., 0., 0., 0.],
                             [0., 0., 0., ..., 0., 0., 0.],
                             [0., 0., 0., ..., 0., 0., 0.],
                             ...,
                             [0., 0., 0., ..., 0., 0., 0.],
                             [0., 0., 0., ..., 0., 0., 0.],
                             [0., 0., 0., ..., 0., 0., 0.]], dtype=float32),
                kernel: DeviceArray([[[-5.91159391e-04, -5.99345751e-02, -1.70421340e-02, ...,
                               -1.11749899e-02, -4.11583930e-02,  6.92131172e-04],
                              [-8.39921486e-05, -9.13265161e-03, -2.38039382e-02, ...,
                               -3.54486592e-02,  3.54229026e-02,  7.44586764e-03],
                              [-2.64137685e-02,  8.02890062e-02,  1.09117301e-02, ...,
                               -1.07408445e-02, -5.21490984e-02,  5.91373891e-0

In [164]:
state.replace(params={"3":3})

TrainState(step=0, apply_fn=functools.partial(<bound method Module.apply of UnamortisedVAEBern(
    # attributes
    latent_dim = 100
    obs_dim = 784
    Encoder = Encoder
    Decoder = Decoder
)>, num_samples=10), params={'3': 3}, tx=GradientTransformation(init=<function chain.<locals>.init_fn at 0x7f6a1e16add0>, update=<function chain.<locals>.update_fn at 0x7f6a1e16a710>), opt_state=(ScaleByAdamState(count=DeviceArray(0, dtype=int32), mu=FrozenDict({
    params: {
        decoder: {
            Dense_0: {
                bias: DeviceArray([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
                             0., 0., 0., 0., 0.], dtype=float32),
                kernel: DeviceArray([[0., 0., 0., ..., 0., 0., 0.],
                             [0., 0., 0., ..., 0., 0., 0.],
                             [0., 0., 0., ..., 0., 0., 0.],
                             ...,
                             [0., 0., 0., ..., 0., 0., 0.],
                             [0., 0., 0.,

In [153]:
jax.tree_map(jnp.shape, state)

TrainState(step=(), apply_fn=functools.partial(<bound method Module.apply of UnamortisedVAEBern(
    # attributes
    latent_dim = 100
    obs_dim = 784
    Encoder = Encoder
    Decoder = Decoder
)>, num_samples=10), params=FrozenDict({
    params: {
        decoder: {
            Dense_0: {
                bias: (20,),
                kernel: (100, 20),
            },
            logvar: {
                kernel: (20, 784),
            },
            mean: {
                bias: (784,),
                kernel: (20, 784),
            },
        },
        encoder: {
            Dense_0: {
                bias: (1000, 100),
                kernel: (1000, 784, 100),
            },
            Dense_1: {
                bias: (1000, 100),
                kernel: (1000, 100, 100),
            },
            Dense_2: {
                bias: (1000, 100),
                kernel: (1000, 100, 100),
            },
            Dense_3: {
                bias: (1000, 100),
                kernel

In [126]:
params_batch = unfreeze(params)["params"]
params_batch_encoder = jax.tree_map(lambda x: x[batch_ix], params_batch["encoder"])
jax.tree_map(jnp.shape, params_batch_encoder)

params_batch = freeze({
    "params": {
        "encoder": params_batch_encoder,
        "decoder": params_batch["decoder"]
    }
})

In [128]:
jax.tree_map(jnp.shape, params_batch)

FrozenDict({
    params: {
        decoder: {
            Dense_0: {
                bias: (20,),
                kernel: (100, 20),
            },
            logvar: {
                kernel: (20, 784),
            },
            mean: {
                bias: (784,),
                kernel: (20, 784),
            },
        },
        encoder: {
            Dense_0: {
                bias: (200, 100),
                kernel: (200, 784, 100),
            },
            Dense_1: {
                bias: (200, 100),
                kernel: (200, 100, 100),
            },
            Dense_2: {
                bias: (200, 100),
                kernel: (200, 100, 100),
            },
            Dense_3: {
                bias: (200, 100),
                kernel: (200, 100, 100),
            },
        },
    },
})