In [1]:
import jax
from jax import random, lax, jit, vmap, value_and_grad
from jax.tree_util import tree_map
from jax.nn import initializers
import jax.numpy as np
import numpy as onp
from flax import linen as nn
from flax.core.frozen_dict import freeze, unfreeze

import tensorflow as tf
tf.config.set_visible_devices([], device_type = 'GPU')
import tensorflow_datasets as tfds

import optax
from flax.training import checkpoints, train_state
from flax.training.early_stopping import EarlyStopping

import time
from copy import copy

In [2]:
class encoder(nn.Module):
    n_loops_top_layer: int
    x_dim_top_layer: int

    @nn.compact
    def __call__(self, x):
        
        # CNN based on End-to-End Training of Deep Visuomotor Policies
        x = nn.Conv(features = 64, kernel_size = (7, 7))(x)
        x = nn.relu(x)
        x = nn.Conv(features = 32, kernel_size = (5, 5))(x)
        x = nn.relu(x)
        x = nn.Conv(features = 32, kernel_size = (5, 5))(x)
        x = nn.relu(x)
        x = x.reshape((x.shape[0], -1)) # flatten
        x = nn.Dense(features = 64)(x)
        x = nn.relu(x)
        x = nn.Dense(features = 40)(x)
        x = nn.relu(x)
        x = nn.Dense(features = 40)(x)
        x = nn.relu(x)
        x = nn.Dense(features = (self.n_loops_top_layer + self.x_dim_top_layer) * 2)(x)
        
        # mean and log variances of Gaussian distribution over latents
        z_mean, z_log_var = np.split(x, 2, axis = 1)
        
        return {'z_mean': z_mean, 
                'z_log_var': z_log_var}
    
class sampler(nn.Module):
    n_loops_top_layer: int
    
    @nn.compact
    def __call__(self, results, params, hyperparams, key):
        
        def diag_Gaussian_sample(mean, log_var, hyperparams, key):
            """
            sample from a diagonal Gaussian distribution
            """
            log_var = stabilise_varance(log_var, hyperparams)

            return mean + np.exp(0.5 * log_var) * random.normal(key, mean.shape)

        # sample the latents
        z = diag_Gaussian_sample(results['z_mean'], results['z_log_var'], hyperparams, key)

        # split the latents into top-layer alphas, softmax(z1), and initial state, z2
        z1, z2 = np.split(z, [self.n_loops_top_layer], axis = 1)

        return nn.activation.softmax(z1 / params['t'][0], axis = 1), np.squeeze(z2)

class decoder(nn.Module):
    T: int
    
    @nn.compact
    def __call__(self, params, A, hyperparams, x0, z1, z2):        

        def decode_one_step(carry, inputs):
            
            def compute_alphas(W, x, b, t):

                return nn.activation.softmax( (W @ x + b) / t, axis = 0)

            def compute_inputs(W, x):

                return W @ x

            def update_state(A, x, alphas, u, dt):

                return x + (np.sum(alphas[:, None, None] * A, axis = 0) @ x + u) * dt

            def compute_pen_actions(W, x, b):

                return W @ x + b
            
            def per_pixel_bernoulli_parameter(params, hyperparams, pen_xy, pen_down_log_p):
    
                def log_Gaussian_kernel(x, mu, log_var, hyperparams):
                    """
                    calculate the log likelihood of x under a diagonal Gaussian distribution
                    """
                    log_var = stabilise_varance(log_var, hyperparams)

                    return -0.5 * (x - mu)**2 / np.exp(log_var)

                ll_p_x = log_Gaussian_kernel(pen_xy[0], hyperparams['x_pixels'], params['pen_log_var'], hyperparams)
                ll_p_y = log_Gaussian_kernel(pen_xy[1], hyperparams['y_pixels'], params['pen_log_var'], hyperparams)

                p_xy_t = np.exp(ll_p_x[None,:] + ll_p_y[:,None] + pen_down_log_p)

                return p_xy_t
            
            def update_pen_position(pen_xy, d_xy, hyperparams):
    
                # candidate new pen position
                pen_xy = pen_xy + d_xy

                # align pen position relative to centre of canvas
                pen_xy = pen_xy - hyperparams['image_dim'] / 2

                # transform canvas boundaries to -/+ 5
                pen_xy = pen_xy * 2 / hyperparams['image_dim'] * 5

                # squash pen position to be within canvas boundaries
                pen_xy = nn.sigmoid(pen_xy)

                # transform canvas boundaries back to their original values
                pen_xy_new = pen_xy * hyperparams['image_dim']

                return pen_xy_new

            x, pen_xy = carry
            top_layer_alphas = inputs

            # compute the alphas
            alphas = jax.tree_map(compute_alphas, params['W_a'], x[:2], params['b_a'], params['t'][1:])

            # prepend the top-layer alphas
            alphas.insert(0, np.squeeze(top_layer_alphas))

            # compute the additive inputs
            u = jax.tree_map(compute_inputs, params['W_u'], x[:2])

            # prepend the top-layer additive inputs
            u.insert(0, x[0] * 0)

            # update the states
            x_new = jax.tree_map(update_state, A, x, alphas, u, hyperparams['dt'])

            # linear readout from the state at the bottom layer
            pen_actions = compute_pen_actions(params['W_p'], x_new[-1], params['b_p'])

            # pen velocities in x and y directions
            d_xy = pen_actions[:2]

            # log probability that the pen is down (actively drawing)
            pen_down_log_p = nn.log_sigmoid(pen_actions[2])

            # calculate the per-pixel bernoulli parameter
            p_xy = per_pixel_bernoulli_parameter(params, hyperparams, pen_xy, pen_down_log_p)

            # update the pen position based on the pen velocity
            pen_xy_new = update_pen_position(pen_xy, d_xy, hyperparams)

            carry = x_new, pen_xy_new
            outputs = alphas, x_new, pen_xy_new, p_xy, pen_down_log_p

            return carry, outputs

        x0[0] = z2[:]

        pen_xy0 = hyperparams['image_dim'] / 2 # initialise pen in centre of canvas

        carry = x0, pen_xy0
        inputs = np.repeat(z1[None,:], self.T, axis = 0)

        _, (alphas, x, pen_xy, p_xy_t, pen_down_log_p) = lax.scan(decode_one_step, carry, inputs)
    
        return {'alphas': alphas,
                'x0': x0,
                'x': x,
                'pen_xy0': pen_xy0,
                'pen_xy': pen_xy,
                'p_xy_t': p_xy_t,
                'pen_down_log_p': pen_down_log_p}

def stabilise_varance(log_var, hyperparams):
    """
    var_min is added to the variances for numerical stability
    """
    return np.log(np.exp(log_var) + hyperparams['var_min'])

In [3]:
class VAE(nn.Module):
    n_loops_top_layer: int
    x_dim_top_layer: int
    T: int

    def setup(self):
        
        self.encoder = encoder(self.n_loops_top_layer, self.x_dim_top_layer)
        self.sampler = sampler(self.n_loops_top_layer)
        self.decoder = decoder(self.T)

    def __call__(self, data, params, hyperparams, key, A, x0):
        
        results_encode = self.encoder(data[None,:,:,None])
        z1, z2 = self.sampler(results_encode, params, hyperparams, key)
        results_decode = self.decoder(params, A, hyperparams, x0, z1, z2)
        
        return results_encode | results_decode

In [4]:
def initialise_decoder_parameters(hyperparams,  key):
    
    P = []
    S_U = []
    S_V = []
    L = []
    W_u = []
    W_a = []
    b_a = []
    t = []
    
    n_layers = len(hyperparams['x_dim'])
    for layer in range(n_layers):
        
        key, *subkeys = random.split(key, num = 8)
        
        n_loops = hyperparams['n_loops'][layer]
        x_dim = hyperparams['x_dim'][layer]
        
        # parameters of layer-specific P
        p = random.normal(subkeys[0], (x_dim, x_dim))
        
        # set trace of P @ P.T to x_dim
        P.append(p * np.sqrt(x_dim / np.trace(p @ p.T)))
        
        # parameters of layer- and loop-specific S
        u = random.normal(subkeys[1], (n_loops, x_dim, int(x_dim / n_loops)))
        v = random.normal(subkeys[2], (n_loops, x_dim, int(x_dim / n_loops)))
        
        # set variance of elements of S to 1/(n_loops * x_dim)
        s = u @ np.transpose(v, (0, 2, 1)) - v @ np.transpose(u, (0, 2, 1))
        # f = 1 / np.linalg.norm(s, axis = (1, 2))[:, None, None] / n_loops # frobenius norm of each loop 1/n_loops
        f = 1 / np.std(s, axis = (1, 2))[:, None, None] / np.sqrt(n_loops * x_dim)
        S_U.append(u * np.sqrt(f))
        S_V.append(v * np.sqrt(f))

        # parameters of layer- and loop-specific L
        Q, _ = np.linalg.qr(random.normal(subkeys[3], (x_dim, x_dim)))
        L_i = np.split(Q * np.sqrt(n_loops), n_loops, axis = 1)
        L.append(np.stack(L_i, axis = 0))

        # parameters of the mapping from hidden states to alphas, alphas = softmax(W @ x + b, temperature)
        if layer != 0:
            
            # weights for additive inputs
            std_W = 1 / np.sqrt(hyperparams['x_dim'][layer - 1])
            W_u.append(random.normal(subkeys[4], (hyperparams['x_dim'][layer], hyperparams['x_dim'][layer - 1])) * std_W)
            
            # weights for modulatory factors
            W_a.append(random.normal(subkeys[5], (n_loops, hyperparams['x_dim'][layer - 1])) * std_W)

            # bias for modulatory factors
            b_a.append(np.zeros((n_loops)))
            
        if layer == n_layers - 1:
            
            # weights for pen actions
            std_W = 1 / np.sqrt(hyperparams['x_dim'][layer])
            W_p = random.normal(subkeys[6], (3, hyperparams['x_dim'][layer])) * std_W
            
            # bias for pen actions
            b_p = np.zeros((3))
            
        # temperature of layer-specific softmax function
        t.append(1.0)

    return {'P': P, 
            'S_U': S_U,
            'S_V': S_V, 
            'L': L, 
            'W_u': W_u,
            'W_a': W_a, 
            'b_a': b_a,
            't': t,
            'W_p': W_p,
            'b_p': b_p,
            'pen_log_var': hyperparams['init_pen_log_var']}

def construct_dynamics_matrix(params, hyperparams):

    def construct_P(P):
        return P @ P.T

    def construct_S(U, V):
        return U @ np.transpose(V, (0, 2, 1)) - V @ np.transpose(U, (0, 2, 1))

    def construct_A(L, P, S):
        return (-L @ np.transpose(L, (0, 2, 1)) + S) @ P

    # positive semi-definite matrix P
    P = jax.tree_map(construct_P, params['P'])

    # skew symmetric matrix S
    S = jax.tree_map(construct_S, params['S_U'], params['S_V'])

    # dynamics matrix A (loops organised along axis 0)
    A = jax.tree_map(construct_A, params['L'], S, P)

    return A

def initialise_LDS_states(hyperparams):
              
    # initialise the states of the LDS in the decoder to zero (not learned)
    # the state of the top layer will be inferred later by the encoder and so the value here will be overwritten
    n_layers = len(hyperparams['x_dim'])
    init_states = []
    for layer in range(n_layers):
        
        init_states.append(np.zeros(hyperparams['x_dim'][layer]))

    return init_states

In [5]:
# https://jax.readthedocs.io/en/latest/notebooks/neural_network_with_tfds_data.html
# https://www.tensorflow.org/datasets/catalog/omniglot
# https://www.tensorflow.org/datasets/api_docs/python/tfds/load

(full_train_set, test_dataset), ds_info = \
tfds.load('Omniglot', split = ['train', 'test'], shuffle_files = True, as_supervised = False, with_info = True)

def prepare_image(dictionary):
    
    # remove redundant channels
    dictionary['image'] = dictionary['image'][:,:,0]
    
    # invert image so drawn pixels are 1
    dictionary['image'] = tf.cast(dictionary['image'] == 0, tf.float32)

    return dictionary

def transform_dataset(dataset, batch_size, tfds_seed):
    
    dataset = dataset.cache()
    dataset = dataset.shuffle(tf.data.experimental.cardinality(dataset).numpy(), seed = tfds_seed, reshuffle_each_iteration = True)
    dataset = dataset.batch(batch_size, drop_remainder = True).prefetch(1)
    
    return dataset

full_train_set = full_train_set.map(prepare_image, num_parallel_calls = tf.data.AUTOTUNE)

validation_split = 0.2
num_data = tf.data.experimental.cardinality(full_train_set).numpy()
train_dataset = full_train_set.take(num_data * (1 - validation_split))
val_dataset = full_train_set.take(num_data * (validation_split))

tfds_seed = 0
batch_size = 5
train_dataset = transform_dataset(train_dataset, batch_size, tfds_seed)

In [6]:
# primary hyperparameters
hyperparams = {'jax_seed': 0,
               'tfds_seed': tfds_seed,
               'x_dim': [20, 50, 200],
               'alpha_fraction': 0.1,
               'dt': [0.01, 0.01, 0.01],
               'time_steps': 100,
               'var_min': 1e-16,
               'smooth_max_parameter': 1e3,
               'init_pen_log_var': 10.0,
               'image_dim': np.array(ds_info.features['image'].shape[:2])}

# secondary hyperparameters (derived from primary hyperparameters)
hyperparams['n_loops'] = [int(np.ceil(i * hyperparams['alpha_fraction'])) for i in hyperparams['x_dim']]
hyperparams['x_pixels'] = np.linspace(0.5, hyperparams['image_dim'][1] - 0.5, hyperparams['image_dim'][1])
hyperparams['y_pixels'] = np.linspace(0.5, hyperparams['image_dim'][0] - 0.5, hyperparams['image_dim'][0])

optimisation_hyperparams = {'kl_warmup_start': 500,
                            'kl_warmup_end': 1000,
                            'kl_min': 0.01,
                            'kl_max': 1,
                            'adam_b1': 0.9,
                            'adam_b2': 0.999,
                            'adam_eps': 1e-8,
                            'weight_decay': 0.0001,
                            'max_grad_norm': 10,
                            'step_size': 0.001,
                            'decay_steps': 1,
                            'decay_factor': 0.9999,
                            'batch_size': batch_size,
                            'print_every': 1,
                            'n_epochs': 10,
                            'n_batches': len(train_dataset),
                            'min_delta': 1e-3,
                            'patience': 2}

In [7]:
# explicitly generate a PRNG key
key = random.PRNGKey(hyperparams['jax_seed'])

# generate the required number of subkeys
key, *subkeys = random.split(key, num = 3)

# initialise model parameters and LDS states
x0 = initialise_LDS_states(hyperparams)
model = VAE(n_loops_top_layer = hyperparams['n_loops'][0], x_dim_top_layer = hyperparams['x_dim'][0], T = hyperparams['time_steps'])
params = {}
params['prior_z_log_var'] = np.log(0.1)
params['decoder'] = initialise_decoder_parameters(hyperparams, subkeys[0])
init_params = model.init(data = np.ones((1, hyperparams['image_dim'][0], hyperparams['image_dim'][1], 1)), 
                         params = params['decoder'], hyperparams = hyperparams, key = subkeys[1], 
                         A = construct_dynamics_matrix(params['decoder'], hyperparams), x0 = x0, 
                         rngs = {'params': random.PRNGKey(0)})['params']

# concatenate all params into one dictionary
init_params = unfreeze(init_params)
init_params = init_params | params
init_params = freeze(init_params)

In [8]:
# # run the model
# A = construct_dynamics_matrix(init_params, hyperparams)
# results = model.apply({'params': init_params}, training_data[0,:,:], params, hyperparams, key, A, x0)

In [9]:
def kl_scheduler(optimisation_hyperparams):
    
    kl_warmup_start = optimisation_hyperparams['kl_warmup_start']
    kl_warmup_end = optimisation_hyperparams['kl_warmup_end']
    kl_min = optimisation_hyperparams['kl_min']
    kl_max = optimisation_hyperparams['kl_max']
    n_batches = optimisation_hyperparams['n_batches']
    
    kl_schedule = []
    for i_batch in range(n_batches):
        
        warm_up_fraction = min(max((i_batch - kl_warmup_start) / (kl_warmup_end - kl_warmup_start), 0), 1)
        
        kl_schedule.append(kl_min + warm_up_fraction * (kl_max - kl_min))

    return iter(kl_schedule)

def create_train_state(model, params, optimisation_hyperparams):
    
    lr_scheduler = optax.exponential_decay(optimisation_hyperparams['step_size'], 
                                        optimisation_hyperparams['decay_steps'], 
                                        optimisation_hyperparams['decay_factor'])

    optimiser = optax.chain(optax.adamw(learning_rate = lr_scheduler, 
                            b1 = optimisation_hyperparams['adam_b1'],
                            b2 = optimisation_hyperparams['adam_b2'],
                            eps = optimisation_hyperparams['adam_eps'],
                            weight_decay = optimisation_hyperparams['weight_decay']),
                            optax.clip_by_global_norm(optimisation_hyperparams['max_grad_norm']))
    
    state = train_state.TrainState.create(apply_fn = model.apply, params = init_params, tx = optimiser)

    return state, lr_scheduler

def apply_model(state, data, hyperparams, key, A, x0):

    return state.apply_fn({'params': {'encoder': state.params['encoder']}}, data, state.params['decoder'], hyperparams, key, A, x0)
    # return state.apply_fn({'params': state.params}, data, state.params['decoder'], hyperparams, key, A, x0)

batch_apply_model = vmap(apply_model, in_axes = (None, 0, None, 0, None, None))
    
def loss_fn(params, state, hyperparams, data, x0, kl_weight, key):

    def cross_entropy_loss(hyperparams, p_xy_t, data):
    
        # compute the smooth maximum of the per-pixel bernoulli parameter across time steps
        p_xy = np.sum(p_xy_t * nn.activation.softmax(p_xy_t * hyperparams['smooth_max_parameter'], axis = 0), axis = 0)

        # compute the logit for each pixel
        logits = np.log(p_xy / (1 - p_xy))

        # compute the average cross entropy across pixels
        cross_entropy = np.mean(optax.sigmoid_binary_cross_entropy(logits, data))

        return cross_entropy

    def KL_diagonal_Gaussians(mu_0, log_var_0, mu_1, log_var_1, hyperparams):
        """
        KL(q||p), where q is posterior and p is prior
        mu_0, log_var_0 is the mean and log variances of the prior
        mu_1, log_var_1 is the mean and log variances of the posterior
        var_min is added to the variances for numerical stability
        """
        log_var_1 = stabilise_varance(log_var_1, hyperparams)

        return np.sum(0.5 * (log_var_0 - log_var_1 + np.exp(log_var_1 - log_var_0) 
                             - 1.0 + (mu_1 - mu_0)**2 / np.exp(log_var_0)))

    batch_cross_entropy_loss = vmap(cross_entropy_loss, in_axes = (None, 0, 0))
    batch_KL_diagonal_Gaussians = vmap(KL_diagonal_Gaussians, in_axes = (None, None, 0, 0, None))

    A = construct_dynamics_matrix(params['decoder'], hyperparams)
    
    # apply the model
    batch_size = data.shape[0]
    subkeys = random.split(key, batch_size)
    results = batch_apply_model(state, data, hyperparams, subkeys, A, x0)

    # calculate cross entropy
    cross_entropy = batch_cross_entropy_loss(hyperparams, results['p_xy_t'], data).mean()

    # calculate KL divergence between the approximate posterior and prior over the latents
    mu_0 = 0
    log_var_0 = params['prior_z_log_var']
    mu_1 = results['z_mean']
    log_var_1 = results['z_log_var']
    kl_loss_prescale = batch_KL_diagonal_Gaussians(mu_0, log_var_0, mu_1, log_var_1, hyperparams).mean()
    kl_loss = kl_weight * kl_loss_prescale

    loss = cross_entropy + kl_loss

    all_losses = {'total': loss, 'cross_entropy': cross_entropy, 'kl': kl_loss, 'kl_prescale': kl_loss_prescale}

    return loss, all_losses

loss_grad = value_and_grad(loss_fn, has_aux = True)
eval_step_jit = jit(loss_fn)

def train_step(state, hyperparams, training_data, x0, kl_weight, key):
    
    (loss, all_losses), grads = loss_grad(state.params, state, hyperparams, training_data, x0, kl_weight, key)

    state = state.apply_gradients(grads = grads)
    
    gradient_fn = jax.value_and_grad(loss_fn, has_aux = True)
    
    return state, loss, all_losses

train_step_jit = jit(train_step)

def print_metrics(phase, duration, t_losses, v_losses, batch_range = [], lr = [], epoch = []):
    
    if phase == "batch":
        
        s1 = '\033[1m' + "Batches {}-{} in {:.2f} seconds, learning rate: {:.5f}" + '\033[0m'
        print(s1.format(batch_range[0], batch_range[1], duration, lr))
        
    elif phase == "epoch":
        
        s1 = '\033[1m' + "Epoch {} in {:.1f} minutes" + '\033[0m'
        print(s1.format(epoch, duration / 60))
        
    s2 = """  Training losses {:.4f} = cross entropy {:.4f} + KL {:.4f} ({:.4f})"""
    s3 = """  Validation losses {:.4f} = cross entropy {:.4f} + KL {:.4f} ({:.4f})"""
    s3 = """  Validation losses {:.4f} = cross entropy {:.4f} + KL {:.4f} ({:.4f})"""
    print(s2.format(t_losses['total'].mean(), t_losses['cross_entropy'].mean(),
                    t_losses['kl'].mean(), t_losses['kl_prescale'].mean()))
    print(s3.format(v_losses['total'].mean(), v_losses['cross_entropy'].mean(),
                    v_losses['kl'].mean(), v_losses['kl_prescale'].mean()))
    
    if phase == "epoch":
        print("""\n""")
        
def write_to_tensorboard(writer, t_losses, v_losses, epoch):

    writer.scalar('loss (train)', t_losses['total'].mean(), epoch)
    writer.scalar('cross entropy (train)', t_losses['cross_entropy'].mean(), epoch)
    writer.scalar('KL (train)', t_losses['kl'].mean(), epoch)
    writer.scalar('KL prescale (train)', t_losses['kl_prescale'].mean(), epoch)
    writer.scalar('loss (validation)', v_losses['total'].mean(), epoch)
    writer.scalar('cross entropy (validation)', v_losses['cross_entropy'].mean(), epoch)
    writer.scalar('KL (validation)', v_losses['kl'].mean(), epoch)
    writer.scalar('KL prescale (validation)', v_losses['kl_prescale'].mean(), epoch)
    writer.flush()

def optimise_VAE(init_params, x0, hyperparams, model, training_data, validation_data, 
                           optimisation_hyperparams, key, ckpt_dir, writer):

    kl_schedule = kl_scheduler(optimisation_hyperparams)
    
    state, lr_scheduler = create_train_state(model, params, optimisation_hyperparams)
    
    # set early stopping criteria
    early_stop = EarlyStopping(min_delta = optimisation_hyperparams['min_delta'], 
                               patience = optimisation_hyperparams['patience'])
    
    # loop over epochs
    n_epochs = optimisation_hyperparams['n_epochs']
    print_every = optimisation_hyperparams['print_every']
    n_batches = optimisation_hyperparams['n_batches']
    losses = {}
    for epoch in range(n_epochs):
        
        # start epoch timer
        epoch_start_time = time.time()
        
        # convert the tf.data.Dataset training_data into an iterable
        # this iterable is shuffled differently each epoch
        train_datagen = iter(tfds.as_numpy(training_data))

        # generate subkeys
        key, *training_subkeys = random.split(key, num = n_batches + 1)
        key, *validation_subkeys = random.split(key, num = int(n_batches / print_every) + 1)

        # initialise the losses and the timer
        training_losses = {'total': 0, 'cross_entropy': 0, 'kl': 0, 'kl_prescale': 0}
        batch_start_time = time.time()

        # loop over batches
        for i in range(n_batches):

            state, loss, all_losses = train_step_jit(state, hyperparams, np.array(next(train_datagen)['image']), 
                                                     x0, np.array(next(kl_schedule)), training_subkeys[i])

            # training losses (average of 'print_every' batches)
            training_losses = tree_map(lambda x, y: x + y / print_every, training_losses, all_losses)

            if (i + 1) % print_every == 0:

                # calculate loss on validation data
                # _, validation_losses = eval_step_jit(state.params, state, hyperparams, validation_data, x0, kl_weight, 
                                                     # validation_subkeys[int((i + 1) / print_every) - 1])
                validation_losses = training_losses
                    
                # end batches timer
                batches_duration = time.time() - batch_start_time

                # print metrics
                print_metrics("batch", batches_duration, training_losses, training_losses, 
                              batch_range = [i + 1 - print_every + 1, i + 1], lr = lr_scheduler(i + epoch * n_batches))

                # store losses
                if (i + 1) == print_every:
                    
                    t_losses_thru_training = copy(training_losses)
                    v_losses_thru_training = copy(validation_losses)
                    
                else:
                    
                    t_losses_thru_training = tree_map(lambda x, y: np.append(x, y), t_losses_thru_training, training_losses)
                    v_losses_thru_training = tree_map(lambda x, y: np.append(x, y), v_losses_thru_training, validation_losses)

                # re-initialise the losses and timer
                training_losses = {'total': 0, 'cross_entropy': 0, 'kl': 0, 'kl_prescale': 0}
                batch_start_time = time.time()

        losses['epoch ' + str(epoch)] = {'t_losses' : t_losses_thru_training, 'v_losses' : v_losses_thru_training}
        
        # end epoch timer
        epoch_duration = time.time() - epoch_start_time
        
        # print losses (mean over all batches in epoch)
        print_metrics("epoch", epoch_duration, t_losses_thru_training, v_losses_thru_training, epoch = epoch + 1)

        # write metrics to tensorboard
        write_to_tensorboard(writer, t_losses_thru_training, v_losses_thru_training, epoch)
        
        # save checkpoint
        ckpt = {'train_state': state, 'losses': losses, 'hyperparams': hyperparams, 'optimisation_hyperparams': optimisation_hyperparams}
        checkpoints.save_checkpoint(ckpt_dir = ckpt_dir, target = ckpt, step = epoch)
        
        # if early stopping criteria met, break
        _, early_stop = early_stop.update(v_losses_thru_training['total'].mean())
        if early_stop.should_stop:
            
            print('Early stopping criteria met, breaking...')
            
            break
            
    return state, losses

In [None]:
from jax.config import config
config.update("jax_debug_nans", False)
config.update("jax_disable_jit", False)
# use xeus-python kernel -- Python 3.9 (XPython) -- for debugging
# typing help at a breakpoint() gives you list of available commands

from flax.metrics import tensorboard
log_folder = "runs/exp9/profile"
writer = tensorboard.SummaryWriter(log_folder)
%load_ext tensorboard
%tensorboard --logdir=runs/exp9

ckpt_dir = 'tmp/flax-checkpointing'

import shutil, os
if os.path.exists(ckpt_dir):
    shutil.rmtree(ckpt_dir) # remove any existing checkpoints from the last notebook run

state, losses = \
optimise_VAE(init_params, x0, hyperparams, model, train_dataset, val_dataset, 
             optimisation_hyperparams, key, ckpt_dir, writer)

# # restore checkpoint
# ckpt = {'train_state': state, 'losses': losses, 'hyperparams': hyperparams, 'optimisation_hyperparams': optimisation_hyperparams}
# restored_state = checkpoints.restore_checkpoint(ckpt_dir = ckpt_dir, target = ckpt)

Launching TensorBoard...
Please visit http://localhost:6006 in a web browser.
[1mBatches 1-1 in 16.42 seconds, learning rate: 0.00100[0m
  Training losses 1.4123 = cross entropy 0.6687 + KL 0.7436 (74.3640)
  Validation losses 1.4123 = cross entropy 0.6687 + KL 0.7436 (74.3640)
[1mBatches 2-2 in 14.45 seconds, learning rate: 0.00100[0m
  Training losses 1.4387 = cross entropy 0.6997 + KL 0.7390 (73.9018)
  Validation losses 1.4387 = cross entropy 0.6997 + KL 0.7390 (73.9018)
[1mBatches 3-3 in 11.61 seconds, learning rate: 0.00100[0m
  Training losses 1.4506 = cross entropy 0.7069 + KL 0.7437 (74.3652)
  Validation losses 1.4506 = cross entropy 0.7069 + KL 0.7437 (74.3652)
[1mBatches 4-4 in 2.41 seconds, learning rate: 0.00100[0m
  Training losses 1.4021 = cross entropy 0.6639 + KL 0.7382 (73.8201)
  Validation losses 1.4021 = cross entropy 0.6639 + KL 0.7382 (73.8201)
[1mBatches 5-5 in 2.43 seconds, learning rate: 0.00100[0m
  Training losses 1.4944 = cross entropy 0.7566 + K

In [None]:
# # run model and plot
# T = 100
# data = training_data[0,:,:]
# hyperparams['dt'] = 0.01
# hyperparams['time_steps'] = 1000

# params = init_params
# results_encode = encode(models, params, data[None,:,:])
# _, z_sampler, _ = models
# z1, z2 = z_sampler(results_encode, params, hyperparams, key)
# A = construct_dynamics_matrix(params, hyperparams)
# results_decode = batch_decode(params, hyperparams, models, A, x0, T, z1, z2)
# results = results_encode | results_decode

# for ex in range(1):
#     plt.scatter(results['pen_xy'][ex][:,0],results['pen_xy'][ex][:,1], c = 'k', alpha =  np.exp(results['pen_down_log_p'][ex,:]))
# plt.ylim(0,105)
# plt.xlim(0,105)
# plt.show()

In [None]:
# most of time when running model forward is decoder and x-entropy functions

# replace cnn with capsule?
# do i need to standardize images (and how) or is normalised (as is) fine?
# ultimately replace pen with myosuite finger to make it a finger painting task - could have flexion/extension determine finger up/down

# pen_xy0 options: always start in center of canvas (105/2, 105/2), randomise and use extra feature dimension in image, let the CNN choose
# consider learning initial neural states
# add control costs? e.g. squared velocity costs?
# add batch norm and other tricks to CNN, dropout?

# if there are l layers (e.g. 3), the top-layer alphas on the last l-1 time steps (e.g. 2) don't influence the state of the lowest layer and hence the objective
# when n_layer = 3, top-layer alphas at time 1 inlfuence the state of the lowest layer and hence the objective at time 3, and so on
# this leads to zeros in the biases and columns of weight matrix in last dense layer

# relu activation function, and to some extent binary image data, can lead to zero gradients scattered throughout CNN
# these gradients go to zero if you change relu to tanh and make data continuous on [0, 1], so not pathological, i don't think

# print values without tracer information
# jax.debug.print("{z_mean}", z_mean = z_mean)
# jax.debug.breakpoint() - didn't work for me