In [None]:
import jax
from jax import random, lax, jit, vmap, value_and_grad
from jax.tree_util import tree_map
from jax.nn import initializers
import jax.numpy as np
import numpy as onp
from flax import linen as nn

import optax
from flax.training import checkpoints
from flax.training.early_stopping import EarlyStopping

import time
from copy import copy

In [2]:
class CNN(nn.Module):
    n_loops_top_layer: int
    x_dim_top_layer: int

    @nn.compact
    def __call__(self, x):
        
        # CNN based on End-to-End Training of Deep Visuomotor Policies
        x = nn.Conv(features = 64, kernel_size = (7, 7))(x)
        x = nn.relu(x)
        x = nn.Conv(features = 32, kernel_size = (5, 5))(x)
        x = nn.relu(x)
        x = nn.Conv(features = 32, kernel_size = (5, 5))(x)
        x = nn.relu(x)
        x = x.reshape((x.shape[0], -1)) # flatten
        x = nn.Dense(features = 64)(x)
        x = nn.relu(x)
        x = nn.Dense(features = 40)(x)
        x = nn.relu(x)
        x = nn.Dense(features = 40)(x)
        x = nn.relu(x)
        x = nn.Dense(features = (self.n_loops_top_layer + self.x_dim_top_layer) * 2)(x)
        
        # # CNN based on deepmimic
        # x = nn.Conv(features = 16, kernel_size = (8, 8))(x)
        # x = nn.relu(x)
        # x = nn.Conv(features = 32, kernel_size = (4, 4))(x)
        # x = nn.relu(x)
        # x = nn.Conv(features = 32, kernel_size = (4, 4))(x)
        # x = nn.relu(x)
        # x = x.reshape((x.shape[0], -1)) # flatten
        # x = nn.Dense(features = 64)(x)
        # x = nn.relu(x)
        # x = nn.Dense(features = 1024)(x)
        # x = nn.relu(x)
        # x = nn.Dense(features = 512)(x)
        # x = nn.relu(x)
        # x = nn.Dense(features = (self.n_loops_top_layer + self.x_dim_top_layer) * 2)(x)
        
        # mean and log variances of Gaussian distribution over latents
        z_mean, z_log_var = np.split(x, 2, axis = 1)
        
        return z_mean, z_log_var
    
class pen_actions_readout(nn.Module):

    @nn.compact
    def __call__(self, x):
        
        # linear layer
        outputs = nn.Dense(3)(x)

        # pen velocity in x and y directions
        d_xy = outputs[:2]

        # log probability that the pen is down (actively drawing)
        pen_down_log_p = nn.log_sigmoid(outputs[2])
    
        return d_xy, pen_down_log_p
    
class sample_latents():
    
    def __init__(self, n_loops_top_layer):
        self.n_loops_top_layer = n_loops_top_layer
    
    def __call__(self, results, params, hyperparams, key):

        # sample the latents
        z = diag_Gaussian_sample(results['z_mean'], results['z_log_var'], hyperparams, key)

        # split into latents that determine the top-layer alphas (z1) and initial state (z2)
        z1, z2 = np.split(z, [self.n_loops_top_layer], axis = 1)

        return nn.activation.softmax(z1 / params['dynamics']['t'][0], axis = 1), z2
    
def update_pen_position(pen_xy, d_xy, hyperparams):
    
    # candidate new pen position
    pen_xy = pen_xy + d_xy
    
    # align pen position relative to centre of canvas
    pen_xy = pen_xy - hyperparams['image_dim'] / 2
    
    # transform canvas boundaries to -/+ 5
    pen_xy = pen_xy * 2 / hyperparams['image_dim'] * 5
    
    # squash pen position to be within canvas boundaries
    pen_xy = nn.sigmoid(pen_xy)
    
    # transform canvas boundaries back to their original values
    pen_xy_new = pen_xy * hyperparams['image_dim']
    
    return pen_xy_new

def compute_alphas(W, x, b, t):
    
    return nn.activation.softmax( (W @ x + b) / t, axis = 0)

def compute_inputs(W, x):
    
    return W @ x

def update_state(A, x, alphas, u, dt):
    
    return x + (np.sum(alphas[:, None, None] * A, axis = 0) @ x + u) * dt

def decode_one_step(carry, inputs):
    
    params, hyperparams, A, x, pen_xy = carry
    top_layer_alphas = inputs
    
    # compute the alphas
    alphas = jax.tree_map(compute_alphas, params['dynamics']['W_a'], x[:2], params['dynamics']['b'], 
                          params['dynamics']['t'][1:])

    # prepend the top-layer alphas
    alphas.insert(0, top_layer_alphas)
    
    # compute the additive inputs
    u = jax.tree_map(compute_inputs, params['dynamics']['W_u'], x[:2])
    
    # prepend the top-layer additive inputs
    u.insert(0, x[0] * 0)

    # update the states
    x_new = jax.tree_map(update_state, A, x, alphas, u, hyperparams['dt'])

    # readout the pen actions (pen velocity and pen down probability) from the state at the bottom layer
    d_xy, pen_down_log_p = readout.apply(params['readout'], x = x_new[-1])
    
    # calculate the per-pixel bernoulli parameter
    p_xy = per_pixel_bernoulli_parameter(params, hyperparams, pen_xy, pen_down_log_p)
    
    # update the pen position based on the pen velocity
    pen_xy_new = update_pen_position(pen_xy, d_xy, hyperparams)

    carry = params, hyperparams, A, x_new, pen_xy_new
    outputs = alphas, x_new, pen_xy_new, p_xy, pen_down_log_p
    
    return carry, outputs

def decode(params, hyperparams, models, A, x0, T, z1, z2):

    x0[0] = z2
    pen_xy0 = hyperparams['image_dim'] / 2 # initialise pen in centre of canvas

    # decoder = lambda state, inputs: decode_one_step(params, hyperparams, models, A, state = state, inputs = inputs)
    # _, (alphas, x, pen_xy, p_xy_t, pen_down_log_p) = lax.scan(decoder, (x0, pen_xy0), np.repeat(z1[None,:], T, axis = 0))
    
    # _, _, readout = models
    carry = params, hyperparams, A, x0, pen_xy0
    inputs = np.repeat(z1[None,:], T, axis = 0)
    
    _, (alphas, x, pen_xy, p_xy_t, pen_down_log_p) = lax.scan(decode_one_step, carry, inputs)
    
    return {'alphas': alphas,
            'x0': x0,
            'x': x,
            'pen_xy0': pen_xy0,
            'pen_xy': pen_xy,
            'p_xy_t': p_xy_t,
            'pen_down_log_p': pen_down_log_p}

batch_decode = vmap(decode, in_axes = (None, None, None, None, None, None, 0, 0))

def encode(models, params, data):

    encoder, _, _ = models
    z_mean, z_log_var = encoder.apply(params['encoder'], x = data)

    return {'z_mean': z_mean, 
            'z_log_var': z_log_var}
    
def losses(params, hyperparams, models, data, x0, T, kl_weight, key):
    
    # pass the data through the encoder
    results_encode = encode(models, params, data)
    
    # sample the latent variables from the approximate posterior
    _, z_sampler, _ = models
    z1, z2 = z_sampler(results_encode, params, hyperparams, key)

    # pass the latent variables through the decoder
    A = construct_dynamics_matrix(params, hyperparams)
    results_decode = batch_decode(params, hyperparams, models, A, x0, T, z1, z2)
    
    results = results_encode | results_decode

    # cross entropy given latent variables
    cross_entropy = batch_cross_entropy_loss(params, hyperparams, results['p_xy_t'], data).mean()
    
    # KL divergence between the approximate posterior and prior over the latents
    mu_0 = 0
    log_var_0 = params['prior_z_log_var']
    mu_1 = results['z_mean']
    log_var_1 = results['z_log_var']
    kl_loss_prescale = batch_KL_diagonal_Gaussians(mu_0, log_var_0, mu_1, log_var_1, hyperparams).mean()
    kl_loss = kl_weight * kl_loss_prescale
    
    loss = cross_entropy + kl_loss
    
    all_losses = {'total': loss, 'cross_entropy': cross_entropy, 'kl': kl_loss, 'kl_prescale': kl_loss_prescale}
    
    return loss, all_losses

losses_jit = jit(losses, static_argnums = (2, 5))
training_loss_grad = value_and_grad(losses_jit, has_aux = True)

def construct_dynamics_matrix(params, hyperparams):
    
    def construct_P(P):
        return P @ P.T
    
    def construct_S(U, V):
        return U @ np.transpose(V, (0, 2, 1)) - V @ np.transpose(U, (0, 2, 1))
    
    def construct_A(L, P, S):
        return (-L @ np.transpose(L, (0, 2, 1)) + S) @ P
    
    # positive semi-definite matrix P
    P = jax.tree_map(construct_P, params['dynamics']['P'])

    # skew symmetric matrix S
    S = jax.tree_map(construct_S, params['dynamics']['S_U'], params['dynamics']['S_V'])

    # dynamics matrix A (loops organised along axis 0)
    A = jax.tree_map(construct_A, params['dynamics']['L'], S, P)

    return A

def stabilise_varance(log_var, hyperparams):
    """
    var_min is added to the variances for numerical stability
    """
    return np.log(np.exp(log_var) + hyperparams['var_min'])

def diag_Gaussian_sample(mean, log_var, hyperparams, key):
    """
    sample from a diagonal Gaussian distribution
    """
    log_var = stabilise_varance(log_var, hyperparams)
    
    return mean + np.exp(0.5 * log_var) * random.normal(key, mean.shape)

def log_Gaussian_kernel(x, mu, log_var, hyperparams):
    """
    calculate the log likelihood of x under a diagonal Gaussian distribution
    """
    log_var = stabilise_varance(log_var, hyperparams)
    
    return -0.5 * (x - mu)**2 / np.exp(log_var)

def log_likelihood_diagonal_Gaussian(x, mu, log_var, hyperparams):
    """
    calculate the log likelihood of x under a diagonal Gaussian distribution
    """
    log_var = stabilise_varance(log_var, hyperparams)
    
    return np.sum(-0.5 * (log_var + np.log(2 * np.pi) + (x - mu)**2 / np.exp(log_var)))

def per_pixel_bernoulli_parameter(params, hyperparams, pen_xy, pen_down_log_p):
    
    ll_p_x = log_Gaussian_kernel(pen_xy[0], hyperparams['x_pixels'], params['pen_log_var'], hyperparams)
    ll_p_y = log_Gaussian_kernel(pen_xy[1], hyperparams['y_pixels'], params['pen_log_var'], hyperparams)
    
    p_xy_t = np.exp(ll_p_x[None,:] + ll_p_y[:,None] + pen_down_log_p)
    
    return p_xy_t

def cross_entropy_loss(params, hyperparams, p_xy_t, data):
    
    # compute the smooth maximum of the per-pixel bernoulli parameter across time steps
    p_xy = np.sum(p_xy_t * nn.activation.softmax(p_xy_t * hyperparams['smooth_max_parameter'], axis = 0), axis = 0)

    # compute the logit for each pixel
    logits = np.log(p_xy / (1 - p_xy))
    
    # compute the average cross entropy across pixels
    cross_entropy = np.mean(optax.sigmoid_binary_cross_entropy(logits, data))

    return cross_entropy

batch_cross_entropy_loss = vmap(cross_entropy_loss, in_axes = (None, None, 0, 0))

def KL_diagonal_Gaussians(mu_0, log_var_0, mu_1, log_var_1, hyperparams):
    """
    KL(q||p), where q is posterior and p is prior
    mu_0, log_var_0 is the mean and log variances of the prior
    mu_1, log_var_1 is the mean and log variances of the posterior
    var_min is added to the variances for numerical stability
    """
    log_var_1 = stabilise_varance(log_var_1, hyperparams)
    
    return np.sum(0.5 * (log_var_0 - log_var_1 + np.exp(log_var_1 - log_var_0) 
                         - 1.0 + (mu_1 - mu_0)**2 / np.exp(log_var_0)))

batch_KL_diagonal_Gaussians = vmap(KL_diagonal_Gaussians, in_axes = (None, None, 0, 0, None))

In [3]:
def initialise_dynamics_parameters(hyperparams,  key):
    
    P = []
    S_U = []
    S_V = []
    L = []
    W_u = []
    W_a = []
    b = []
    t = []
    
    n_layers = len(hyperparams['x_dim'])
    for layer in range(n_layers):
        
        key, *subkeys = random.split(key, num = 7)
        
        n_loops = hyperparams['n_loops'][layer]
        x_dim = hyperparams['x_dim'][layer]
        
        # parameters of layer-specific P
        p = random.normal(subkeys[0], (x_dim, x_dim))
        
        # set trace of P @ P.T to x_dim
        P.append(p * np.sqrt(x_dim / np.trace(p @ p.T)))
        
        # parameters of layer- and loop-specific S
        u = random.normal(subkeys[1], (n_loops, x_dim, int(x_dim / n_loops)))
        v = random.normal(subkeys[2], (n_loops, x_dim, int(x_dim / n_loops)))
        
        # set variance of elements of S to 1/(n_loops * x_dim)
        s = u @ np.transpose(v, (0, 2, 1)) - v @ np.transpose(u, (0, 2, 1))
        # f = 1 / np.linalg.norm(s, axis = (1, 2))[:, None, None] / n_loops # frobenius norm of each loop 1/n_loops
        f = 1 / np.std(s, axis = (1, 2))[:, None, None] / np.sqrt(n_loops * x_dim)
        S_U.append(u * np.sqrt(f))
        S_V.append(v * np.sqrt(f))

        # parameters of layer- and loop-specific L
        Q, _ = np.linalg.qr(random.normal(subkeys[3], (x_dim, x_dim)))
        L_i = np.split(Q * np.sqrt(n_loops), n_loops, axis = 1)
        L.append(np.stack(L_i, axis = 0))

        # parameters of the mapping from hidden states to alphas, alphas = softmax(W @ x + b, temperature)
        if layer != 0:
            
            # weights for additive inputs
            std_W = 1 / np.sqrt(hyperparams['x_dim'][layer - 1])
            W_u.append(random.normal(subkeys[4], (hyperparams['x_dim'][layer], hyperparams['x_dim'][layer - 1])) * std_W)
            
            # weights for modulatory factors
            W_a.append(random.normal(subkeys[5], (n_loops, hyperparams['x_dim'][layer - 1])) * std_W)

            # bias for modulatory factors
            b.append(np.zeros((n_loops)))
            
        # temperature of layer-specific softmax function
        t.append(1.0)

    return {'P': P, 
            'S_U': S_U,
            'S_V': S_V, 
            'L': L, 
            'W_u': W_u,
            'W_a': W_a, 
            'b': b,
            't': t}

def initialise_hidden_states(hyperparams):
              
    # initialise the hidden states of the decoder model to zero (not learned)
    n_layers = len(hyperparams['x_dim'])
    init_states = []
    for layer in range(n_layers):
        init_states.append(np.zeros(hyperparams['x_dim'][layer]))
        
    return init_states

In [4]:
import os
import random as rdm
from sys import platform as sys_pf
import matplotlib
if sys_pf == 'darwin':
    matplotlib.use("TkAgg")
from matplotlib import pyplot as plt

%matplotlib inline

def num2str(idx):
    if idx < 10:
        return '0'+str(idx)
    return str(idx)

def load_img(fn):
    I = plt.imread(fn)
    I = np.array(I,dtype=bool)
    return I

def load_motor(fn):
    motor = []
    with open(fn,'r') as fid:
        lines = fid.readlines()
    lines = [l.strip() for l in lines]
    for myline in lines:
        if myline =='START': # beginning of character
            stk = []
        elif myline =='BREAK': # break between strokes
            stk = np.array(stk)
            motor.append(stk) # add to list of strokes
            stk = [] 
        else:
            arr = np.fromstring(myline,dtype=float,sep=',')
            stk.append(arr)
    return motor

img_dir = '/Users/James/Dropbox/James MacBook/Guillaume/omniglot/omniglot/python/images_background'
stroke_dir = '/Users/James/Dropbox/James MacBook/Guillaume/omniglot/omniglot/python/strokes_background'
pth_i = os.path.join(img_dir, 'Sanskrit')
pth_s = os.path.join(stroke_dir, 'Sanskrit')
n_characters = len([s for s in os.listdir(pth_i) if 'character' in s])
n_reps_per_character = 20
training_data = onp.empty((n_characters * n_reps_per_character, 105, 105))
for character in range(1):
    
    # get directories for this character
    img_char_dir = os.path.join(pth_i, 'character' + num2str(character + 1))
    stroke_char_dir = os.path.join(pth_s, 'character' + num2str(character + 1))
    
    # get base file name for this character
    # print(os.listdir(img_char_dir)[0])
    fn_example = os.listdir(img_char_dir)[0]
    fn_base = fn_example[:fn_example.find('_')] 
    
    for rep in range(1):
        
        fn_img = img_char_dir + '/' + fn_base + '_' + num2str(rep + 1) + '.png'
        I = load_img(fn_img) == False # ensures letter pixels are 1
        training_data[character * n_reps_per_character + rep,:,:] = I
        
        fn_stk = stroke_char_dir + '/' + fn_base + '_' + num2str(rep + 1) + '.txt'
        # motor = load_motor(fn_stk)

training_data = np.array(training_data)

# # plot images and strokes
# plt.imshow(training_data[0,:,:], cmap = 'binary')
# plt.show()
# for ex in range(1):
#     for m in motor:
#         plt.scatter(m[:,0],m[:,1], c = 'k')
# plt.ylim(-105, 0)
# plt.xlim(0, 105)
# plt.show()

In [5]:
# explicitly generate a PRNG key
seed = 0
key = random.PRNGKey(seed)

# generate the required number of subkeys
key, *subkeys = random.split(key, num = 5)

image_height = 105
image_width = 105
validation_data = training_data[0:20,:,:]

# primary hyperparameters
hyperparams = {'image_dim': np.array(training_data.shape[1:]),
               'x_dim': [20, 50, 200],
               'alpha_fraction': 0.1,
               'dt': [0.01, 0.01, 0.01],
               'time_steps': 100,
               'var_min': 1e-16,
               'smooth_max_parameter': 1e3}

# secondary hyperparameters (derived from primary hyperparameters)
hyperparams['n_loops'] = [int(np.ceil(i * hyperparams['alpha_fraction'])) for i in hyperparams['x_dim']]
hyperparams['x_pixels'] = np.linspace(0.5, hyperparams['image_dim'][1] - 0.5, hyperparams['image_dim'][1])
hyperparams['y_pixels'] = np.linspace(0.5, hyperparams['image_dim'][0] - 0.5, hyperparams['image_dim'][0])

encoder = CNN(n_loops_top_layer = hyperparams['n_loops'][0], x_dim_top_layer = hyperparams['x_dim'][0])
z_sampler = sample_latents(n_loops_top_layer = hyperparams['n_loops'][0])
readout = pen_actions_readout()
models = (encoder, z_sampler, readout)

init_params = {}
init_params['encoder'] = encoder.init(x = np.ones((1, training_data.shape[1], training_data.shape[2])), 
                                      rngs = {'params': subkeys[0]})
init_params['prior_z_log_var'] = np.log(0.1)
init_params['dynamics'] = initialise_dynamics_parameters(hyperparams, subkeys[1])
init_params['readout'] = readout.init(x = np.ones((hyperparams['x_dim'][-1])), rngs = {'params': subkeys[2]})
init_params['pen_log_var'] = 10.0

x0 = initialise_hidden_states(hyperparams)

optimisation_hyperparams = {'kl_warmup_start': 500,
                            'kl_warmup_end': 1000,
                            'kl_min': 0.01,
                            'kl_max': 1,
                            'adam_b1': 0.9,
                            'adam_b2': 0.999,
                            'adam_eps': 1e-8,
                            'weight_decay': 0.0001,
                            'max_grad_norm': 10,
                            'step_size': 0.001,
                            'decay_steps': 1,
                            'decay_factor': 0.9999,
                            'batch_size': 5,
                            'print_every': 1,
                            'n_epochs': 10,
                            'min_delta': 1e-3,
                            'patience': 2}
optimisation_hyperparams['n_batches'] = int(training_data.shape[0] / optimisation_hyperparams['batch_size'])

In [6]:
def kl_scheduler(optimisation_hyperparams):
    
    kl_warmup_start = optimisation_hyperparams['kl_warmup_start']
    kl_warmup_end = optimisation_hyperparams['kl_warmup_end']
    kl_min = optimisation_hyperparams['kl_min']
    kl_max = optimisation_hyperparams['kl_max']
    n_batches = optimisation_hyperparams['n_batches']
    
    kl_schedule = []
    for i_batch in range(n_batches):
        
        warm_up_fraction = min(max((i_batch - kl_warmup_start) / (kl_warmup_end - kl_warmup_start),0),1)
        
        kl_schedule.append(kl_min + warm_up_fraction * (kl_max - kl_min))

    return np.array(kl_schedule)

def reshape_training_data(training_data):
    
    return np.reshape(training_data, (optimisation_hyperparams['n_batches'],
                                      optimisation_hyperparams['batch_size'],
                                      training_data.shape[1],
                                      training_data.shape[2]))

def optimize_dynamical_VAE_core(params, x0, hyperparams, models, training_data, validation_data, optimizer, optimizer_state, 
                                optimisation_hyperparams, kl_schedule, print_every, epoch, scheduler, key):
    
    n_batches = optimisation_hyperparams['n_batches']
    
    # generate subkeys
    key, *training_subkeys = random.split(key, num = n_batches + 1)
    key, *validation_subkeys = random.split(key, num = int(n_batches / print_every) + 1)
    
    # initialise the losses and the timer
    training_losses = {'total': 0, 'cross_entropy': 0, 'kl': 0, 'kl_prescale': 0}
    start_time = time.time()

    # loop over batches
    for i in range(n_batches):
        
        i_batch = i + epoch * n_batches

        kl_weight = kl_schedule[i_batch]
        
        (loss, all_losses), grad = training_loss_grad(params, hyperparams, models, training_data[i],
                                                      x0, hyperparams['time_steps'], kl_weight, training_subkeys[i]) 

        updates, optimizer_state = optimizer.update(grad, optimizer_state, params)

        params = optax.apply_updates(params, updates)
        
        # training losses (average of 'print_every' batches)
        training_losses = tree_map(lambda x, y: x + y / print_every, training_losses, all_losses)
        
        if (i + 1) % print_every == 0:
        
            # calculate loss on validation data
            _, validation_losses = losses_jit(params, hyperparams, models, validation_data, x0, 
                                              hyperparams['time_steps'], kl_weight, validation_subkeys[int((i + 1) / print_every)])
            
            # end batches timer
            batches_duration = time.time() - start_time

            # print and store losses
            s1 = '\033[1m' + "Batches {}-{} in {:.2f} seconds, step size: {:.5f}" + '\033[0m'
            s2 = """  Training losses {:.4f} = cross entropy {:.4f} + KL {:.4f} ({:.4f})"""
            s3 = """  Validation losses {:.4f} = cross entropy {:.4f} + KL {:.4f} ({:.4f})"""
            print(s1.format(i + 1 - print_every + 1, i + 1, batches_duration, scheduler(i_batch)))
            print(s2.format(training_losses['total'], training_losses['cross_entropy'],
                            training_losses['kl'], training_losses['kl_prescale']))
            print(s3.format(validation_losses['total'], validation_losses['cross_entropy'],
                            validation_losses['kl'], validation_losses['kl_prescale']))

            if (i + 1) == print_every:
                tlosses_thru_training = copy(training_losses)
                vlosses_thru_training = copy(validation_losses)
            else:
                tlosses_thru_training = tree_map(lambda x, y: np.append(x, y), tlosses_thru_training, training_losses)
                vlosses_thru_training = tree_map(lambda x, y: np.append(x, y), vlosses_thru_training, validation_losses)
            
            # re-initialise the losses and timer
            training_losses = {'total': 0, 'cross_entropy': 0, 'kl': 0, 'kl_prescale': 0}
            start_time = time.time()
            
    losses = {'tlosses' : tlosses_thru_training, 'vlosses' : vlosses_thru_training}

    return params, optimizer_state, losses

def optimize_dynamical_VAE(params, x0, hyperparams, models, training_data, validation_data, 
                           optimisation_hyperparams, key, ckpt_dir, writer):

    kl_schedule = kl_scheduler(optimisation_hyperparams)
    
    scheduler = optax.exponential_decay(optimisation_hyperparams['step_size'], 
                                        optimisation_hyperparams['decay_steps'], 
                                        optimisation_hyperparams['decay_factor'])

    optimizer = optax.chain(optax.adamw(learning_rate = scheduler, 
                            b1 = optimisation_hyperparams['adam_b1'],
                            b2 = optimisation_hyperparams['adam_b2'],
                            eps = optimisation_hyperparams['adam_eps'],
                            weight_decay = optimisation_hyperparams['weight_decay']),
                            optax.clip_by_global_norm(optimisation_hyperparams['max_grad_norm']))
    
    optimizer_state = optimizer.init(params)
    
    # reshape training data
    training_data = reshape_training_data(training_data)
    
    # set early stopping criteria
    early_stop = EarlyStopping(min_delta = optimisation_hyperparams['min_delta'], 
                               patience = optimisation_hyperparams['patience'])
    
    # loop over epochs
    n_epochs = optimisation_hyperparams['n_epochs']
    print_every = optimisation_hyperparams['print_every']
    losses = {}
    for epoch in range(n_epochs):
        
        # start epoch timer
        epoch_start_time = time.time()
        
        # generate subkeys
        key, *subkeys = random.split(key, num = 3)
        
        # shuffle the batches every epoch
        data = random.permutation(subkeys[0], training_data, axis = 0)
        
        # perform optimisation
        params, optimizer_state, losses['epoch ' + str(epoch)] = \
        optimize_dynamical_VAE_core(params, x0, hyperparams, models, data, validation_data, optimizer, optimizer_state, 
                                    optimisation_hyperparams, kl_schedule, print_every, epoch, scheduler, subkeys[1])
        
        # end epoch timer
        epoch_duration = time.time() - epoch_start_time
        
        # print metrics (mean over all batches in epoch) 
        s1 = '\033[1m' + "Epoch {} in {:.1f} minutes" + '\033[0m'
        s2 = """  Training losses {:.4f} = cross entropy {:.4f} + KL {:.4f} ({:.4f})"""
        s3 = """  Validation losses {:.4f} = cross entropy {:.4f} + KL {:.4f}, ({:.4f})\n"""
        print(s1.format(epoch + 1, epoch_duration / 60))
        tlosses = losses['epoch ' + str(epoch)]['tlosses']
        print(s2.format(tlosses['total'].mean(), tlosses['cross_entropy'].mean(),
                        tlosses['kl'].mean(), tlosses['kl_prescale'].mean()))
        vlosses = losses['epoch ' + str(epoch)]['vlosses']
        print(s3.format(vlosses['total'].mean(), vlosses['cross_entropy'].mean(),
                        vlosses['kl'].mean(), vlosses['kl_prescale'].mean()))

        # write metrics to tensorboard
        writer.scalar('loss (train)', tlosses['total'].mean(), epoch)
        writer.scalar('cross entropy (train)', tlosses['cross_entropy'].mean(), epoch)
        writer.scalar('KL (train)', tlosses['kl'].mean(), epoch)
        writer.scalar('KL prescale (train)', tlosses['kl_prescale'].mean(), epoch)
        writer.scalar('loss (validation)', vlosses['total'].mean(), epoch)
        writer.scalar('cross entropy (validation)', vlosses['cross_entropy'].mean(), epoch)
        writer.scalar('KL (validation)', vlosses['kl'].mean(), epoch)
        writer.scalar('KL prescale (validation)', vlosses['kl_prescale'].mean(), epoch)
        writer.flush()
        
        # save checkpoint
        ckpt = {'params': params, 'optimizer_state': optimizer_state, 'losses': losses}
        checkpoints.save_checkpoint(ckpt_dir = ckpt_dir, target = ckpt, step = epoch)
        
        # if early stopping criteria met, break
        _, early_stop = early_stop.update(vlosses['total'].mean())
        if early_stop.should_stop:
            
            print('Early stopping criteria met, breaking...')
            
            break
            
    return params, optimizer_state, losses

In [None]:
from jax.config import config
config.update("jax_debug_nans", True)
config.update("jax_disable_jit", False)
# use xeus-python kernel -- Python 3.9 (XPython) -- for debugging
# typing help at a breakpoint() gives you list of available commands

from flax.metrics import tensorboard
log_folder = "runs/exp9/profile"
writer = tensorboard.SummaryWriter(log_folder)
%load_ext tensorboard
%tensorboard --logdir=runs/exp9

ckpt_dir = 'tmp/flax-checkpointing'

import shutil
if os.path.exists(ckpt_dir):
    shutil.rmtree(ckpt_dir) # remove any existing checkpoints from the last notebook run

trained_params, optimizer_state, losses = \
optimize_dynamical_VAE(init_params, x0, hyperparams, models, training_data, validation_data, 
                       optimisation_hyperparams, key, ckpt_dir, writer)

# # restore checkpoint
# ckpt = {'params': trained_params, 'optimizer_state': optimizer_state, 'losses': losses}
# restored_state = checkpoints.restore_checkpoint(ckpt_dir = ckpt_dir, target = ckpt)

Launching TensorBoard...
Please visit http://localhost:6006 in a web browser.
[1mBatches 1-1 in 20.76 seconds, step size: 0.00100[0m
  Training losses 1.4011 = cross entropy 0.6644 + KL 0.7367 (73.6715)
  Validation losses 1.4568 = cross entropy 0.7225 + KL 0.7344 (73.4391)
[1mBatches 2-2 in 15.77 seconds, step size: 0.00100[0m
  Training losses 1.4255 = cross entropy 0.6908 + KL 0.7347 (73.4739)
  Validation losses 1.4666 = cross entropy 0.7342 + KL 0.7324 (73.2424)
[1mBatches 3-3 in 7.78 seconds, step size: 0.00100[0m
  Training losses 1.4159 = cross entropy 0.6832 + KL 0.7328 (73.2768)
  Validation losses 1.4048 = cross entropy 0.6744 + KL 0.7305 (73.0463)
[1mBatches 4-4 in 7.53 seconds, step size: 0.00100[0m
  Training losses 1.4306 = cross entropy 0.6998 + KL 0.7308 (73.0804)
  Validation losses 1.4584 = cross entropy 0.7299 + KL 0.7285 (72.8508)
[1mBatches 5-5 in 7.19 seconds, step size: 0.00100[0m
  Training losses 1.4384 = cross entropy 0.7096 + KL 0.7288 (72.8847)
  

In [None]:
# # run model and plot
# T = 100
# data = training_data[0,:,:]
# hyperparams['dt'] = 0.01
# hyperparams['time_steps'] = 1000

# params = init_params
# results_encode = encode(models, params, data[None,:,:])
# _, z_sampler, _ = models
# z1, z2 = z_sampler(results_encode, params, hyperparams, key)
# A = construct_dynamics_matrix(params, hyperparams)
# results_decode = batch_decode(params, hyperparams, models, A, x0, T, z1, z2)
# results = results_encode | results_decode

# for ex in range(1):
#     plt.scatter(results['pen_xy'][ex][:,0],results['pen_xy'][ex][:,1], c = 'k', alpha =  np.exp(results['pen_down_log_p'][ex,:]))
# plt.ylim(0,105)
# plt.xlim(0,105)
# plt.show()

In [None]:
# most of time when running model forward is decoder and x-entropy functions

# replace cnn with capsule?
# do i need to standardize images (and how) or is normalised (as is) fine?
# ultimately replace pen with myosuite finger to make it a finger painting task - could have flexion/extension determine finger up/down

# pen_xy0 options: always start in center of canvas (105/2, 105/2), randomise and use extra feature dimension in image, let the CNN choose
# consider learning initial neural states
# add control costs? e.g. squared velocity costs?
# add batch norm and other tricks to CNN, dropout?

# if there are l layers (e.g. 3), the top-layer alphas on the last l-1 time steps (e.g. 2) don't influence the state of the lowest layer and hence the objective
# when n_layer = 3, top-layer alphas at time 1 inlfuence the state of the lowest layer and hence the objective at time 3, and so on
# this leads to zeros in the biases and columns of weight matrix in last dense layer

# relu activation function, and to some extent binary image data, can lead to zero gradients scattered throughout CNN
# these gradients go to zero if you change relu to tanh and make data continuous on [0, 1], so not pathological, i don't think

# print values without tracer information
# jax.debug.print("{z_mean}", z_mean = z_mean)
# jax.debug.breakpoint() - didn't work for me