In [None]:
'''
This demo code accompanies 
"The Distributed Information Bottleneck reveals the explanatory structure of complex systems"
Kieran A Murphy and Dani S Bassett

https://arxiv.org/abs/2204.07576
'''

In [None]:
!pip install scikit-image

import tensorflow as tf
import tensorflow_datasets as tfds
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
import scipy.ndimage as nim
from google.colab import files

tfkl = tf.keras.layers
default_mpl_colors = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd', '#8c564b', '#e377c2', '#7f7f7f', '#bcbd22', '#17becf']

# Boolean circuit

In [None]:
'''
In this first section, we approximate a 10-input, 1-output Boolean circuit.
Because the input components may take only two values, the encoders are simple.
They are just two trainable constants that provide the mean and variance of a 
distribution in representation space.
'''

In [None]:
# Lookup table from gate_id to the numpy function implementing that gate:
# 0 -> AND, 1 -> OR, 2 -> XOR.
gates = [np.logical_and, np.logical_or, np.logical_xor]

# This is the circuit from the paper, formatted such that each intermediate output is defined by the contents of the brackets: [gate_id, input1, input2]
# The leading bare integers stand for the circuit inputs. input1 and input2 are
# column positions (an input or an earlier gate output), which is how
# apply_gates indexes them. The last entry is read as the circuit output Y.
circuit_specification = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, [1, 0, 1], [2, 8, 7], [0, 4, 3], [1, 11, 5], [2, 6, 12], [2, 13, 9], [1, 14, 10], [0, 15, 2], [0, 17, 16]]
# Number of leading entries of circuit_specification that are inputs rather than gates.
number_input_gates = 10

def apply_gates(inputs):
  """Evaluates every gate of `circuit_specification` on a batch of inputs.

  Gates are applied in the order they are listed; each gate reads two columns
  that already exist (inputs or earlier gate outputs) and appends one column.

  Args:
    inputs: [batch, num_inputs] array of 0/1 values.
  Returns:
    Array holding the input columns followed by one column per gate listed
    after the first num_inputs entries of `circuit_specification`.
  """
  num_inputs = inputs.shape[-1]
  node_values = inputs
  for gate_id, left_ind, right_ind in circuit_specification[num_inputs:]:
    gate_fn = gates[gate_id]
    gate_output = np.int32(gate_fn(node_values[:, left_ind], node_values[:, right_ind]))
    # Append the new gate output as an extra trailing column.
    node_values = np.concatenate([node_values, gate_output[:, np.newaxis]], -1)
  return node_values

def compute_entropy(vals):
  """Returns the empirical Shannon entropy of `vals`, in nats.

  Args:
    vals: 1D array of non-negative integers (as required by np.bincount).
  Returns:
    -sum(p * log(p)) over the values that occur at least once.
  """
  counts = np.bincount(vals)
  probs = counts/float(vals.shape[0])
  # Values that never occur contribute nothing and would produce log(0).
  occurring = probs[probs>0]
  return -np.sum(occurring * np.log(occurring))

# Evaluate the full truth table
# Enumerate every combination of 0/1 across the number_input_gates inputs,
# flattened to one combination per row.
possible_inputs = np.meshgrid(*[[0, 1]]*number_input_gates)
possible_inputs = np.stack(possible_inputs, -1)
possible_inputs = np.reshape(possible_inputs, [-1, number_input_gates])

# Columns are the inputs followed by every gate output (see apply_gates).
truth_table = apply_gates(possible_inputs)

# The last column is the circuit output Y. compute_entropy uses natural logs,
# so divide by log(2) to report bits.
entropy_y = compute_entropy(truth_table[:, -1])
print(f'Entropy of Y: {entropy_y/np.log(2):.3f} bits.')

In [None]:
# During pretraining \beta is at its minimum value so that the bottleneck
# is too weak to prevent finding the max-fidelity relationship
number_pretraining_steps = 10**4
# Then \beta is increased logarithmically during the annealing stage
number_annealing_steps = 2*10**5
# \beta is held at beta_start during pretraining and reaches beta_end at the end of annealing
beta_start = 1e-4
beta_end = 1e0
batch_size = 512
lr = 3e-4
optimizer = tf.keras.optimizers.Adam(lr)
# Per-step logs filled in by the training loop
cross_entropy_series, kl_divergence_series = [[] for _ in range(2)]
# Non-trainable variable so the training loop can update \beta with assign()
# while train_step reads it
beta_var = tf.Variable(beta_start, dtype=tf.float32, trainable=False)
##############################################################################
# Network creation
# The encoders are trainable constants, taking each binary 0/1 to a normal
# distribution in representation space with +-mu as the mean
# One (mu, log-variance) pair per input; log-variances start at -3
input_mus = tf.Variable(tf.ones(number_input_gates), dtype=tf.float32, trainable=True)
input_logvars = tf.Variable(-3*tf.ones(number_input_gates), dtype=tf.float32, trainable=True)

# Maps the number_input_gates noisy representations to a single logit for Y
combined_encoder = tf.keras.Sequential([tfkl.Input((number_input_gates,)),
                                        tfkl.Dense(256, 'relu'),
                                        tfkl.Dense(256, 'relu'),
                                        tfkl.Dense(256, 'relu'),
                                        tfkl.Dense(1)])
# Everything the optimizer updates: the MLP weights plus the encoder constants
all_trainable_vars = combined_encoder.trainable_variables
all_trainable_vars += [input_mus, input_logvars]
    
##############################################################################
bce_loss = tf.keras.losses.BinaryCrossentropy(from_logits=True)
@tf.function
def train_step():
  """Runs one optimization step of the Boolean circuit model.

  Samples batch_size rows of the truth table uniformly at random, passes each
  input through its noisy channel, predicts the output logit, and minimizes
  cross entropy plus beta_var times the summed per-input KL divergences.

  Returns:
    cross_entropy: scalar cross entropy of the batch, in nats.
    kl_divergence_channels: [number_input_gates] KL divergence of each input's
      channel from the standard normal prior, in nats.
  """
  # Zero logits make every truth-table row equally likely.
  uniform_logits = tf.zeros((batch_size, truth_table.shape[0]))
  row_inds = tf.random.categorical(uniform_logits, 1)
  sampled_rows = tf.gather(truth_table, row_inds, axis=0)[:, 0]
  x = tf.cast(sampled_rows[:, :number_input_gates], tf.float32)
  y = sampled_rows[:, -1]
  with tf.GradientTape() as tape:
    # Binary 0/1 is mapped to a mean of -mu/+mu.
    channel_means = input_mus*(x*2-1.)
    channel_stddevs = tf.exp(input_logvars/2.)
    x_channeled = tf.random.normal(x.shape, mean=channel_means, stddev=channel_stddevs)
    y_predicted = combined_encoder(x_channeled)
    # KL(N(+-mu, var) || N(0, 1)); the sign of the mean drops out in the square.
    kl_divergence_channels = 0.5 * (tf.square(input_mus) + tf.exp(input_logvars) - input_logvars - 1.) # shape [num_spins]
    cross_entropy = tf.reduce_mean(bce_loss(y, y_predicted))
    total_kl = tf.reduce_sum(kl_divergence_channels)
    loss = cross_entropy + beta_var*total_kl
  grads = tape.gradient(loss, all_trainable_vars)
  optimizer.apply_gradients(zip(grads, all_trainable_vars))
  return cross_entropy, kl_divergence_channels

# \beta stays at beta_start during pretraining, then is interpolated
# log-linearly from beta_start towards beta_end over the annealing steps.
log_beta_start = np.log(beta_start)
log_beta_end = np.log(beta_end)
total_steps = number_pretraining_steps+number_annealing_steps
for step in range(total_steps):
  annealing_progress = float(max(step-number_pretraining_steps, 0))/number_annealing_steps
  beta_var.assign(np.exp(log_beta_start+annealing_progress*(log_beta_end-log_beta_start)))
  step_cross_entropy, step_kl_divergences = train_step()
  cross_entropy_series.append(step_cross_entropy.numpy())
  kl_divergence_series.append(step_kl_divergences.numpy())

# [num_steps, number_input_gates]
kl_divergence_series = np.stack(kl_divergence_series, 0)

## Plot the error and KL divergences during the run
# Only the annealing stage is shown; both quantities are logged in nats and
# converted to bits by dividing by log(2).
beta_lims = [5e-3, 5e-1]
likelihood_lims = [1e-3, None]
kl_lims = [0, 1.75]
plt.figure(figsize=(8, 4))
ax = plt.gca()
smoothing_sigma = 50
# \beta value at each annealing step, matching the schedule used in training
betas = np.exp(np.log(beta_start)+np.linspace(0, 1, number_annealing_steps)*(np.log(beta_end)-np.log(beta_start)))
# gaussian_filter1d is called from scipy.ndimage directly; the
# scipy.ndimage.filters namespace is deprecated.
ax.plot(betas, nim.gaussian_filter1d(cross_entropy_series[-number_annealing_steps:], smoothing_sigma)/np.log(2), lw=3, color='k')

# KL divergences share the beta axis but get their own y axis
ax2 = ax.twinx()
for i in range(number_input_gates):
  ax2.plot(betas, nim.gaussian_filter1d(kl_divergence_series[-number_annealing_steps:, i], smoothing_sigma)/np.log(2), lw=3.5)
# Dotted reference line at the entropy of Y
ax.plot([0, 1], [entropy_y/np.log(2)]*2, 'k:', lw=2)
ax.set_ylabel('Cross entropy (bits)', color='k', fontsize=14)
ax2.set_ylabel('KL Divergence (bits)', color='b', fontsize=14)
ax.set_xscale('log')
ax.set_xlabel('Bottleneck strength beta', fontsize=14)
ax.set_xlim(beta_lims)
ax.set_ylim(likelihood_lims)
ax2.set_ylim(kl_lims)

# Draw the cross entropy curve on top of the KL curves
ax.set_zorder(ax2.get_zorder()+1)
ax.patch.set_visible(False)
ax.tick_params(which='both', width=2, length=8, direction='in')
ax2.tick_params(which='both', width=2, length=8, direction='in')
plt.show()

# Mona Lisa

In [None]:
'''
Images are rich relationships between position and color.
We use the Distributed IB to find an approximation scheme to Leonardo da Vinci's Mona Lisa.

Rather than predicting a distribution over color and evaluating the cross entropy
in color space, we operate in embedding space with the help of the InfoNCE loss.
To optimize the InfoNCE contribution, each position-color pair should be co-located in a 
shared embedding space, with color embedded in its entirety and position embedded 
by its horizontal and vertical components independently before passing through a 
combined encoder as with the Boolean circuit.

The degree to which the embedding space leads to consistent matches between a 
position and its correct color serves as a bound on the mutual information while 
letting us sidestep any manual discretization of color space.

We display approximate relationships throughout training to visualize the gradual
erosion of fidelity and the coarsening of block-like interactions between the horizontal
and vertical components of the pixel position.
'''

In [None]:
# Load from wikipedia
!wget https://upload.wikimedia.org/wikipedia/commons/6/6a/Mona_Lisa.jpg

In [None]:
from PIL import Image
from skimage.transform import resize
fname = './Mona_Lisa.jpg'
# Open with a context manager so the file handle is released once the pixels
# have been copied into a numpy array.
with Image.open(fname) as image_file:
  # Later cells unpack image.shape into three values and reshape to [-1, 3],
  # so make sure there are exactly three colour channels.
  rgb_image = image_file if image_file.mode == 'RGB' else image_file.convert('RGB')
  image = np.float32(np.asarray(rgb_image))/255.

# Downsample to 600 rows by 400 columns
image = resize(image, (600, 400))
plt.figure(figsize=(5, 8))
plt.imshow(image)
plt.axis('off')
plt.show()

In [None]:
#@title Similarity functions for InfoNCE evaluation
# Borrowed from https://github.com/google-research/google-research/tree/master/isolating_factors
@tf.function
def pairwise_l2_distance(pts1, pts2):
  """Computes squared L2 distances between each element of each set of points.

  Uses the expansion |a - b|^2 = |a|^2 + |b|^2 - 2 a.b so that the pairwise
  terms come from a single matrix product.

  Args:
    pts1: [N, d] tensor of points.
    pts2: [M, d] tensor of points.
  Returns:
    distance_matrix: [N, M] tensor of squared distances.
  """
  sq_norms1 = tf.reduce_sum(tf.square(pts1), axis=1, keepdims=True)  # [N, 1]
  sq_norms2 = tf.reshape(tf.reduce_sum(tf.square(pts2), axis=1), [1, -1])  # [1, M]
  cross_terms = tf.matmul(pts1, pts2, transpose_b=True)  # [N, M]
  # Clamp at zero so the result is never negative.
  return tf.maximum(sq_norms1 + sq_norms2 - 2.0 * cross_terms, 0.0)


@tf.function
def pairwise_l1_distance(pts1, pts2):
  """Computes L1 distances between each element of each set of points.

  The pairwise differences are formed by broadcasting rather than by tiling
  pts1 with the static size of pts2, so this also works when the number of
  points is not known statically.

  Args:
    pts1: [N, d] tensor of points.
    pts2: [M, d] tensor of points.
  Returns:
    distance_matrix: [N, M] tensor of distances.
  """
  # [N, 1, d] - [1, M, d] broadcasts to [N, M, d]
  pairwise_diffs = tf.expand_dims(pts1, 1) - tf.expand_dims(pts2, 0)
  distance_matrix = tf.reduce_sum(tf.abs(pairwise_diffs), -1)
  return distance_matrix


@tf.function
def pairwise_linf_distance(pts1, pts2):
  """Computes Chebyshev distances between each element of each set of points.

  The Chebyshev/chessboard distance is the L_infinity distance between two
  points, the maximum difference between any of their dimensions.

  The pairwise differences are formed by broadcasting rather than by tiling
  pts1 with the static size of pts2, so this also works when the number of
  points is not known statically.

  Args:
    pts1: [N, d] tensor of points.
    pts2: [M, d] tensor of points.
  Returns:
    distance_matrix: [N, M] tensor of distances.
  """
  # [N, 1, d] - [1, M, d] broadcasts to [N, M, d]
  pairwise_diffs = tf.expand_dims(pts1, 1) - tf.expand_dims(pts2, 0)
  distance_matrix = tf.reduce_max(tf.abs(pairwise_diffs), -1)
  return distance_matrix


def get_scaled_similarity(embeddings1,
                          embeddings2,
                          similarity_type,
                          temperature):
  """Returns the temperature-scaled similarity between two sets of embeddings.

  A higher value means a more similar pair. Distance-based similarities are
  the negated distance (so they are at most 0); cosine similarity lies in
  [-1, 1]. Every value is divided by `temperature` before being returned.

  Args:
    embeddings1: [N, d] float tensor of embeddings.
    embeddings2: [M, d] float tensor of embeddings.
    similarity_type: String with the method of computing similarity between
      embeddings. Implemented:
        l2sq -- Squared L2 (Euclidean) distance
        l2 -- L2 (Euclidean) distance
        l1 -- L1 (Manhattan) distance
        linf -- L_inf (Chebyshev) distance
        cosine -- Cosine similarity, the inner product of the normalized vectors
    temperature: Float value which divides all similarity values, setting a
      scale for the similarity values.  Should be positive.
  Returns:
    [N, M] tensor of scaled similarities.
  Raises:
    ValueError: If the similarity type is not recognized.
  """
  if similarity_type == 'cosine':
    unit1, _ = tf.linalg.normalize(embeddings1, ord=2, axis=-1)
    unit2, _ = tf.linalg.normalize(embeddings2, ord=2, axis=-1)
    return tf.matmul(unit1, unit2, transpose_b=True) / temperature

  if similarity_type == 'l2sq':
    distances = pairwise_l2_distance(embeddings1, embeddings2)
  elif similarity_type == 'l2':
    # A small offset inside the square root keeps its argument nonzero, so the
    # gradient is always taken with respect to a nonzero value.
    eps = 1e-9
    distances = tf.sqrt(pairwise_l2_distance(embeddings1, embeddings2) + eps)
  elif similarity_type == 'l1':
    distances = pairwise_l1_distance(embeddings1, embeddings2)
  elif similarity_type == 'linf':
    distances = pairwise_linf_distance(embeddings1, embeddings2)
  else:
    raise ValueError('Similarity type not implemented: ', similarity_type)

  # Smaller distance means more similar, hence the sign flip.
  return (-1.0 * distances) / temperature

In [None]:
# Positional encoding helps with high frequency information in the image; 
# see Appendix A in the manuscript for more discussion
# Number of features per coordinate after encoding: the data-setup cell keeps
# the raw value and adds one sinusoid for each of the remaining slots
number_positional_encoding_freqs = 10

# Output width of the per-axis encoder; the training step splits it in half
# into means and log-variances
input_embedding_dim = 32
# Dimension of the space shared by position and color embeddings
shared_embedding_dim = 64
# Hidden layer widths: encoder_spec for the per-axis encoder, decoder_spec for
# the combined position encoder and the color encoder
encoder_spec = [512]*3
decoder_spec = [128]*3
activation_fn = 'relu'
lr = 3e-4
# \beta is held at beta_start during pretraining, then annealed towards beta_end
beta_start = 1e-6
beta_end = 3e0
number_pretraining_steps = 2*10**4
number_annealing_steps = 10**5
# During annealing, a reconstruction is displayed every this many steps
plot_every_n_steps = 1000
batch_size = 2048

## For the InfoNCE loss; other valid parameter values are discussed in the cell above
# Passed to get_scaled_similarity as similarity_type and temperature
similarity = 'l2'
temperature = 1.

In [None]:
## Set up the data and networks
# NOTE: image_x is the size of the first array axis (rows) and image_y the
# size of the second (columns).
image_x, image_y, _ = image.shape
aspect_ratio = float(image_y)/float(image_x)
# Pixel coordinates: the row direction spans [-1, 1] and the column direction
# is scaled by the aspect ratio
xx, yy = np.meshgrid(np.linspace(-aspect_ratio, aspect_ratio, image_y),
                     np.linspace(-1, 1, image_x))

# Frequencies 2*pi, 4*pi, ...; arange starts at 1, so this holds
# number_positional_encoding_freqs-1 values
positional_encoding_freqs = np.power(2., np.arange(1, number_positional_encoding_freqs))*np.pi

# One row per pixel: inputs are the two coordinates, outputs the three color values
inp_pts_train = np.stack([xx.reshape([-1]), yy.reshape([-1])], -1)
outp_pts_train = image.reshape([-1, 3])

# The two position axes are treated as separate input channels
number_input_channels = 2

# Stacking on a new last axis keeps each coordinate's features separate:
# [num_pixels, 2, number_positional_encoding_freqs]
inp_pts_train = np.stack([inp_pts_train] + [np.sin(inp_pts_train*freq) for freq in positional_encoding_freqs], axis=-1).astype(np.float32)
# 'Positionally' encode the RGB colors as well
# Concatenated on the last axis: [num_pixels, 3*number_positional_encoding_freqs]
outp_pts_train = np.concatenate([outp_pts_train] + [np.sin(outp_pts_train*freq) for freq in positional_encoding_freqs], axis=-1)

# This is to signify the different channels to the encoder so that we can get away with using a single encoder for all channels
# Shape [batch_size, number_input_channels, number_input_channels]
one_hot_appendix = tf.tile(tf.expand_dims(tf.eye(number_input_channels), 0), [batch_size, 1, 1])

# To get a color out, take a random sampling of colors from the GT image and then paint by numbers
palette_size = 1024
color_palette_inds = np.random.choice(outp_pts_train.shape[0], size=palette_size, replace=False)
# Raw RGB values of the sampled pixels, used to paint the reconstruction
color_palette = image.reshape([-1, 3])[color_palette_inds]

# Input shapes are passed as tuples, as in the other sections of this notebook;
# a bare integer is not accepted as a shape by every Keras version.

# Shared per-axis encoder: positional features plus the one-hot channel id in,
# input_embedding_dim values out (split into means and log-variances by train_step)
input_encoder = tf.keras.Sequential([tfkl.Input((number_positional_encoding_freqs+number_input_channels,))]+[tfkl.Dense(num_units, activation_fn) for num_units in encoder_spec]+[tfkl.Dense(input_embedding_dim)])
# Combines the sampled embeddings of all axes (input_embedding_dim//2 values each)
# into the shared embedding space
combined_encoder = tf.keras.Sequential([tfkl.Input((number_input_channels*(input_embedding_dim//2),))]+[tfkl.Dense(num_units, activation_fn) for num_units in decoder_spec]+[tfkl.Dense(shared_embedding_dim)])
# Embeds the encoded RGB color into the same shared space
color_encoder = tf.keras.Sequential([tfkl.Input((3*number_positional_encoding_freqs,))]+[tfkl.Dense(num_units, activation_fn) for num_units in decoder_spec]+[tfkl.Dense(shared_embedding_dim)])
all_trainable_vars = (input_encoder.trainable_variables
                      + combined_encoder.trainable_variables
                      + color_encoder.trainable_variables)

# Non-trainable variable so the training loop can update \beta with assign()
beta_var = tf.Variable(beta_start, dtype=tf.float32, trainable=False)
optimizer = tf.keras.optimizers.Adam(lr)

In [None]:
xent_loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
@tf.function
def train_step(batch_inps, batch_outps):
  """Runs one optimization step of the position/color matching model.

  Each position axis is encoded separately by the shared input encoder into a
  Gaussian, sampled, and the samples of all axes are combined into one position
  embedding. The loss is a symmetric InfoNCE term between position and color
  embeddings plus beta_var times the summed per-axis KL divergences.

  Args:
    batch_inps: batch of encoded positions, one row of features per axis, so
      that it can be concatenated with one_hot_appendix on the last axis.
    batch_outps: batch of encoded colors, fed to color_encoder.
  Returns:
    loss_infonce: scalar sum of the two InfoNCE directions, in nats.
    kl_divergence_channels: [number_input_channels] batch-averaged KL
      divergence per axis, in nats.
  """
  # Tag every axis' features with its one-hot channel id, then flatten so the
  # shared encoder sees one row per (example, axis) pair.
  tagged_inps = tf.concat([batch_inps, one_hot_appendix], -1)
  flat_inps = tf.reshape(tagged_inps, [batch_size*number_input_channels, -1])
  # The i-th position matches the i-th color.
  match_labels = tf.range(batch_size)
  with tf.GradientTape() as tape:
    enc_mus, enc_logvars = tf.split(input_encoder(flat_inps), 2, axis=-1)
    enc_stddevs = tf.exp(enc_logvars/2.)
    sampled_embs = tf.random.normal(enc_mus.shape, mean=enc_mus, stddev=enc_stddevs)
    # Regroup per example: this concatenates the per-axis embeddings which were
    # siloed up to here.
    joint_embs = tf.reshape(sampled_embs, [batch_size, -1])

    x_embs = combined_encoder(joint_embs)
    y_embs = color_encoder(batch_outps)

    similarity_matrix = get_scaled_similarity(x_embs, y_embs, similarity, temperature)
    # Match positions to colors and colors to positions.
    infonce_pos_to_color = tf.reduce_mean(xent_loss(match_labels, similarity_matrix))
    infonce_color_to_pos = tf.reduce_mean(xent_loss(match_labels, tf.transpose(similarity_matrix)))
    loss_infonce = infonce_pos_to_color + infonce_color_to_pos

    # KL from the standard normal prior, summed over embedding dimensions
    kl_per_row = tf.reduce_sum(0.5 * (tf.square(enc_mus) + tf.exp(enc_logvars) - enc_logvars - 1.), axis=-1)
    # Average across the batch; all that's left is a different term per axis
    kl_divergence_channels = tf.reduce_mean(tf.reshape(kl_per_row, [batch_size, -1]), 0)

    loss = loss_infonce + beta_var*tf.reduce_sum(kl_divergence_channels)
  grads = tape.gradient(loss, all_trainable_vars)
  optimizer.apply_gradients(zip(grads, all_trainable_vars))
  return loss_infonce, kl_divergence_channels

In [None]:
## Training
infonce_series, kl_divergence_series = [[] for _ in range(2)]
# This code displays outputs in groups of 10, one each <plot_every_n_steps> steps
plot_ind = 1
inches_per_subplot = 2
plt.figure(figsize=(inches_per_subplot*10, inches_per_subplot*1.5))
for step in range(number_pretraining_steps+number_annealing_steps):
  # \beta stays at beta_start during pretraining, then is interpolated
  # log-linearly towards beta_end over the annealing steps
  beta_var.assign(np.exp(np.log(beta_start)+float(max(step-number_pretraining_steps, 0))/number_annealing_steps*(np.log(beta_end)-np.log(beta_start))))
  batch_inds = np.random.choice(inp_pts_train.shape[0], size=batch_size)
  batch_inps = inp_pts_train[batch_inds]
  batch_outps = outp_pts_train[batch_inds]

  loss_infonce, kl_divergence_channels = train_step(batch_inps, batch_outps)

  # Store plain numpy values, as the Boolean circuit loop does
  infonce_series.append(loss_infonce.numpy())
  kl_divergence_series.append(kl_divergence_channels.numpy())
  if (step % plot_every_n_steps == 0) and (step > number_pretraining_steps):
    ## Display the current approximation
    # Only the palette colors are candidates, so embed just those rows instead
    # of every pixel of the image
    color_palette_embs = color_encoder(outp_pts_train[color_palette_inds])

    reconstructed_image = []
    num_pix = inp_pts_train.shape[0]
    # Chunk up the operations for memory
    chunking = 8192
    for pix_ind_start in range(0, num_pix, chunking):
      chunk_size = min(num_pix, pix_ind_start+chunking) - pix_ind_start
      one_hot_appendix_disp = tf.tile(tf.expand_dims(tf.eye(number_input_channels), 0), [chunk_size, 1, 1])
      x_inp = tf.concat([inp_pts_train[pix_ind_start:pix_ind_start+chunk_size], one_hot_appendix_disp], -1)   ## [bs, num_axes, num_pos] + [bs, num_axes, num_axes]
      x_inp = tf.reshape(x_inp, [chunk_size*number_input_channels, -1])
      # Use the encoder means only (no sampling noise) for the preview
      enc_mus, _ = tf.split(input_encoder(x_inp), 2, axis=-1)  ## This will be [bs*num_axes, dim]
      all_embs = tf.reshape(enc_mus, [chunk_size, -1])

      x_embs = combined_encoder(all_embs)

      # Paint each pixel with the palette color whose embedding is most similar
      sim_mat = get_scaled_similarity(x_embs, color_palette_embs, similarity, temperature)
      color_numbers = np.argmax(sim_mat, axis=-1)
      reconstructed_image.append(color_palette[color_numbers])

    reconstructed_image = np.concatenate(reconstructed_image, 0)
    plt.subplot(1, 10, plot_ind)
    plt.imshow(np.reshape(reconstructed_image, [image_x, image_y, 3]))
    if plot_ind in [1, 6]:
      # Convert to a Python float before applying the format spec
      plt.title(f'Beta={float(beta_var.numpy()):5f}', fontsize=14)
    plt.axis('off')
    plot_ind+=1
    if plot_ind > 10:
      plt.show()
      plot_ind = 1
      plt.figure(figsize=(inches_per_subplot*10, inches_per_subplot*1.5))
plt.show() # in case the division got messed up
# Display the evolution of the loss terms over the \beta sweep
# Both quantities are logged in nats and converted to bits by dividing by log(2).
infonce_series = np.stack(infonce_series, 0)
kl_divergence_series = np.stack(kl_divergence_series, 0)
# \beta value at each annealing step, matching the schedule used in training
betas = np.exp(np.log(beta_start)+np.linspace(0, 1, number_annealing_steps)*(np.log(beta_end)-np.log(beta_start)))
beta_lims = [1e-3, 3]
kl_lims = [1e-2, 8e1]
smoothing_sigma = 100
plt.figure(figsize=(6, 4))
ax = plt.gca()
# gaussian_filter1d is called from scipy.ndimage directly; the
# scipy.ndimage.filters namespace is deprecated.
ax.plot(betas[-number_annealing_steps:], nim.gaussian_filter1d(infonce_series[-number_annealing_steps:], smoothing_sigma)/np.log(2), lw=4, color='k')
# KL divergences share the beta axis but get their own (log-scaled) y axis
ax2 = ax.twinx()
axis_labels = ['Horizontal', 'Vertical']
for i in range(number_input_channels):
  ax2.plot(betas[-number_annealing_steps:], nim.gaussian_filter1d(kl_divergence_series[-number_annealing_steps:, i], smoothing_sigma)/np.log(2), lw=4, label=axis_labels[i])
ax2.legend(loc='upper left')
ax.set_ylabel('InfoNCE (bits)', color='k', fontsize=14)
ax2.set_ylabel('KL Divergence (bits)', color='b', fontsize=14)
ax2.set_yscale('log')
ax.set_xscale('log')
ax.set_xlabel('Bottleneck strength beta', fontsize=14)
plt.xlim(beta_lims)
ax2.set_ylim(kl_lims)
ax.tick_params(which='both', width=2, length=8, direction='in')
ax2.tick_params(which='both', width=2, length=8, direction='in')
# Draw the InfoNCE curve on top of the KL curves
ax.set_zorder(ax2.get_zorder()+1)
ax.patch.set_visible(False)
plt.show()


# Titanic dataset

In [None]:
'''
Finally, we seek insight into a noisy relationship with primarily 
categorical features: the Titanic dataset.
'''

In [None]:
# Load the 'train' split of the Titanic dataset from TensorFlow Datasets
dset = tfds.load('titanic', split='train')

In [None]:
## Since the dataset is so small, just load everything manually
features = []
survived = []
good_keys = ['pclass', 'sex', 'age', 'sibsp', 'parch', 'fare', 'embarked']
for passenger in dset:
  features.append([passenger[key].numpy() for key in good_keys])
  survived.append(passenger['survived'])

## Munge the data
# Each entry of `features` follows the order of good_keys, so passenger[0] is
# pclass, passenger[1] is sex, and so on.
# NOTE(review): tf.one_hot produces an all-zero row for an index outside
# [0, depth) -- confirm the depths below (and the +1 shift on embarked) cover
# every value present in the dataset.
pclasses = tf.stack([tf.one_hot(passenger[0], 3) for passenger in features], 0)
sexes = tf.stack([tf.one_hot(passenger[1], 2) for passenger in features], 0)
# Age becomes [negative-value indicator, log(age)]; negative values presumably
# mark missing entries -- TODO confirm.
ages = tf.stack([(lambda x: np.float32([1., 0.])*(x<0) + np.float32([0., np.log(np.abs(x))])*(x>0))(passenger[2]) for passenger in features], 0)
sibsps = tf.stack([tf.one_hot(passenger[3], 7) for passenger in features], 0)
parches = tf.stack([tf.one_hot(passenger[4], 8) for passenger in features], 0)
# Fare becomes [negative-value indicator, zero-fare indicator, log(fare)]
fares = tf.stack([(lambda x: np.float32([1., 0., 0.])*(x<0) + np.float32([0., 1., 0.])*(x==0) + np.float32([0., 0., np.log(np.abs(x+1e-4))])*(x>0))(passenger[5]) for passenger in features], 0)
embarkeds = tf.stack([tf.one_hot(passenger[6]+1, 4) for passenger in features], 0)
surviveds = tf.stack(survived, 0)

# Random train/validation split without overlap
training_fraction = 0.9
training_inds = np.random.choice(surviveds.shape[0], size=int(training_fraction*surviveds.shape[0]), replace=False)
# Set difference yields the same sorted held-out indices as scanning
# training_inds once per passenger, without the quadratic cost.
validation_inds = np.setdiff1d(np.arange(surviveds.shape[0]), training_inds)

In [None]:
## Create a separate encoder for each feature
# Each encoder outputs 2*feature_embedding_dim values, which the training step
# splits into feature_embedding_dim means and feature_embedding_dim log-variances
feature_embedding_dim = 8
encoder_arch = [128, 128]
activation_fn = 'tanh'
# Width of each munged feature tensor, in the order the features are passed to
# train_step: pclass, sex, age, sibsp, parch, fare, embarked
input_dimensionalities = [3, 2, 2, 7, 8, 3, 4]  ## one for each feature being used
number_features = len(input_dimensionalities)

encoders = []
for input_dimensionality in input_dimensionalities:
  encoders.append(tf.keras.Sequential([tfkl.Input((input_dimensionality,))] + [tfkl.Dense(num_units, activation_fn) for num_units in encoder_arch] + [tfkl.Dense(2*feature_embedding_dim)]))

# Maps the concatenated feature embeddings to a single survival logit
combined_encoder = tf.keras.Sequential([tfkl.Input((number_features*feature_embedding_dim,)),
                                         tfkl.Dense(128, 'tanh'),
                                         tfkl.Dense(128, 'tanh'),
                                         tfkl.Dense(1)])


# Everything the optimizer updates: the combined network plus every feature encoder
all_trainable_vars = combined_encoder.trainable_variables
for network in encoders:
  all_trainable_vars += network.trainable_variables

In [None]:
# \beta is held at beta_start for the pretraining steps, then annealed
# log-linearly towards beta_end over the annealing steps
number_pretraining_steps = 2*10**4
number_annealing_steps = 10**5
lr = 1e-3
beta_start = 5e-5
beta_end = 1e0
batch_size = 128
# During annealing, the validation cross entropy is evaluated every this many steps
eval_every = 1000
# Non-trainable variable so the training loop can update \beta with assign()
beta_var = tf.Variable(beta_start, dtype=tf.float32, trainable=False)

# The combined network outputs a logit, hence from_logits=True
bce_loss = tf.keras.losses.BinaryCrossentropy(from_logits=True)
optimizer = tf.keras.optimizers.Adam(lr)
@tf.function
def train_step(features, target_var):
  """Runs one optimization step of the Titanic model.

  Every feature is encoded by its own network into a Gaussian, sampled, and the
  samples are concatenated to predict the target logit. The loss is the cross
  entropy plus beta_var times the summed per-feature KL divergences.

  Args:
    features: list of number_features tensors, one batch per feature, in the
      same order as `encoders`.
    target_var: batch of binary targets.
  Returns:
    cross_entropy: scalar cross entropy of the batch, in nats.
    kl_divergence_channels: list with one scalar per feature, the
      batch-averaged KL divergence from the standard normal prior, in nats.
  """
  sampled_embeddings = []
  kl_divergence_channels = []
  with tf.GradientTape() as tape:
    for feature_ind in range(number_features):
      encoder_output = encoders[feature_ind](features[feature_ind])
      emb_mus, emb_logvars = tf.split(encoder_output, 2, axis=-1)
      emb_stddevs = tf.exp(emb_logvars/2.)
      sampled_embeddings.append(tf.random.normal(emb_mus.shape, mean=emb_mus, stddev=emb_stddevs))
      # Sum over embedding dimensions, then average over the batch
      kl_per_example = tf.reduce_sum(0.5 * (tf.square(emb_mus) + tf.exp(emb_logvars) - emb_logvars - 1.), axis=-1)
      kl_divergence_channels.append(tf.reduce_mean(kl_per_example))

    prediction = combined_encoder(tf.concat(sampled_embeddings, -1))
    cross_entropy = tf.reduce_mean(bce_loss(target_var, prediction))

    total_kl = tf.reduce_sum(kl_divergence_channels)
    loss = cross_entropy + beta_var*(total_kl)
  grads = tape.gradient(loss, all_trainable_vars)
  optimizer.apply_gradients(zip(grads, all_trainable_vars))

  return cross_entropy, kl_divergence_channels
@tf.function
def eval_validation(features, target_var):
  """Computes the cross entropy on held-out data without channel noise.

  Unlike train_step, the encoder means are used directly instead of sampling.

  Args:
    features: list of number_features tensors, one per feature, in the same
      order as `encoders`.
    target_var: binary targets for the same examples.
  Returns:
    Scalar cross entropy in nats, averaged over the examples.
  """
  mean_embeddings = []
  for i in range(number_features):
    emb_mus, _ = tf.split(encoders[i](features[i]), 2, axis=-1)
    mean_embeddings.append(emb_mus)
  prediction = combined_encoder(tf.concat(mean_embeddings, -1))
  # bce_loss already averages over the examples and returns a scalar; reduce it
  # the same way train_step does so the two are written consistently.
  cross_entropy = tf.reduce_mean(bce_loss(target_var, prediction))
  return cross_entropy

betas_time_series = []
cross_entropy_series, cross_entropy_series_validation, kl_divergence_series = [[] for _ in range(3)]

# Feature tensors in the order expected by the encoders
feature_tensors = [pclasses, sexes, ages, sibsps, parches, fares, embarkeds]
# The validation set never changes, so gather it once instead of at every evaluation
validation_features = [tf.gather(feature_tensor, validation_inds) for feature_tensor in feature_tensors]
validation_targets = tf.gather(surviveds, validation_inds)

for step in range(number_pretraining_steps+number_annealing_steps):
  # \beta stays at beta_start during pretraining, then is interpolated
  # log-linearly towards beta_end over the annealing steps
  beta_var.assign(np.exp(np.log(beta_start)+float(max(step-number_pretraining_steps, 0))/number_annealing_steps*(np.log(beta_end)-np.log(beta_start))))

  batch_inds = np.random.choice(training_inds, size=batch_size, replace=False)
  cross_entropy, kl_divergence_channels = train_step(
      [tf.gather(feature_tensor, batch_inds) for feature_tensor in feature_tensors],
      tf.gather(surviveds, batch_inds))

  # Store plain numpy values rather than tensors: one scalar for the cross
  # entropy and one value per feature for the KL divergences
  cross_entropy_series.append(cross_entropy.numpy())
  kl_divergence_series.append([kl_term.numpy() for kl_term in kl_divergence_channels])
  if (step % eval_every) == 0 and (step >= number_pretraining_steps):
    cross_entropy_series_validation.append(eval_validation(validation_features, validation_targets).numpy())

# \beta value at each annealing step, matching the schedule used in training
betas = np.exp(np.log(beta_start)+np.linspace(0, 1, number_annealing_steps)*(np.log(beta_end)-np.log(beta_start)))
# [number_features, num_steps]
kl_divergence_series = np.stack(kl_divergence_series, 1)
beta_lims = [5e-3, 0.5]
kl_lims = [0, 4e0]
xent_lims = [0, 1]
smoothing_sigma = 50
# The losses are logged in nats; the axes are labelled in bits, so convert
# here as the other plots in this notebook do.
nats_per_bit = np.log(2)
plt.figure(figsize=(12, 6))
ax = plt.gca()
# gaussian_filter1d is called from scipy.ndimage directly; the
# scipy.ndimage.filters namespace is deprecated.
ax.plot(betas, nim.gaussian_filter1d(cross_entropy_series[number_pretraining_steps:], smoothing_sigma)/nats_per_bit, lw=4, label="Cross entropy, Train", color='k')
# One validation point every eval_every annealing steps
validation_bits = np.array([float(value) for value in cross_entropy_series_validation])/nats_per_bit
ax.plot(betas[::eval_every], validation_bits, lw=2, label="Cross entropy, Validation", color='k')
# KL divergences share the beta axis but get their own y axis
ax2 = ax.twinx()
for kl_term in kl_divergence_series:
  ax2.plot(betas, nim.gaussian_filter1d(kl_term[number_pretraining_steps:], smoothing_sigma)/nats_per_bit, lw=2.5)
# Single-point dummy lines in the default color order, used only to create
# thick legend entries for the features
for i, label in enumerate(['Passenger class', 'Sex', 'Age', 'Siblings/spouses', 'Parents/children', 'Fare', 'Location embarked']):
  ax2.plot(0, 0.5, lw=5, color=default_mpl_colors[i], label=label)

ax.set_ylabel('Cross entropy (bits)', color='k', fontsize=14)
ax2.set_ylabel('KL Divergence (bits)', color='b', fontsize=14)
ax.set_xscale('log')
ax.set_xlabel('Bottleneck strength beta', fontsize=14)
ax2.set_ylim(kl_lims)
ax.set_xlim(beta_lims)
ax.set_ylim(xent_lims)
# A legend only lists artists from its own axes, so merge the entries of both
# axes to include the labelled cross entropy curves.
handles_xent, labels_xent = ax.get_legend_handles_labels()
handles_kl, labels_kl = ax2.get_legend_handles_labels()
ax2.legend(handles_xent + handles_kl, labels_xent + labels_kl)
plt.show()