<a href="https://colab.research.google.com/github/ezgimez/neural-distributed-compressor-JSAIT2024/blob/main/JSAIT'24_source_code_Variational_Entropy_Constrained_Vector_Quantizers.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Code definitions

In [None]:
import matplotlib.pyplot as plt
from flax import linen as nn
from flax.training import train_state
import jax.numpy as jnp
import jax
from jax import random
import numpy as np
import optax
from google import colab
import os
from argparse import Namespace
from collections import defaultdict
from typing import Callable
import numpy as np

from tensorflow_probability.substrates import jax as tfp
tfpd = tfp.distributions


In [None]:
class MLP(nn.Module):
  """Two-hidden-layer perceptron used as the neural decoder.

  See Eqs. (11) and (12) in the paper. All positional inputs are concatenated
  along their last axis, passed through two 100-unit dense layers (each
  followed by `activation_fn`) and a final dense layer of width `output_dims`.
  """
  output_dims: int
  activation_fn: Callable

  @nn.compact
  def __call__(self, *x):
    hidden = jnp.concatenate(x, -1)
    # Explicit layer names keep the parameter tree keyed as fc1 / fc2 / fc3.
    for layer_name in ("fc1", "fc2"):
      hidden = self.activation_fn(nn.Dense(100, name=layer_name)(hidden))
    return nn.Dense(self.output_dims, name="fc3")(hidden)


In [None]:
def sample_xyn_conditional(num_samples, x, var_n, n_rng):
  """Samples side information given `x` for the `x = y + n` setup.

  Conditional sampling at the encoder, see Eqs. (14) and (15) in the paper.
  Zero-mean Gaussian noise of variance `var_n` is drawn with shape
  (num_samples, 1, 1), added to `x` (broadcasting), and the sum is divided by
  (1 + 2 * var_n).

  NOTE(review): with y ~ N(0, 1), as drawn in `sample_xyn`, the exact
  posterior mean of y given x is x / (1 + var_n); confirm that the
  (1 + 2 * var_n) scaling used here is intended.
  """
  noise_std = var_n ** .5
  noise = tfp.distributions.Normal(loc=0., scale=noise_std).sample(
      sample_shape=(num_samples, 1, 1), seed=n_rng)
  return (x + noise) / (1+2*var_n)

def sample_yxn_conditional(num_samples, x, var_n, n_rng):
  """Samples side information given `x` for the `y = x + n` setup.

  Conditional sampling at the encoder, see Eqs. (14) and (15) in the paper.
  Zero-mean Gaussian noise of variance `var_n` is drawn with shape
  (num_samples, 1, 1) and added to `x` (broadcasting), i.e. y = x + n in the
  quadratic-Gaussian setup.
  """
  noise_std = var_n ** .5
  noise = tfp.distributions.Normal(loc=0., scale=noise_std).sample(
      sample_shape=(num_samples, 1, 1), seed=n_rng)
  return x + noise

def sample_laplacian_conditional(x):
  """Returns sign(x) with a leading singleton axis.

  In the signed-Laplacian setup the side information is y = sign(x), so the
  "conditional sampling" at the encoder (Eqs. (14) and (15) in the paper) is
  deterministic and a single sample suffices. `x` must be 2-D.
  """
  return jnp.sign(x)[jnp.newaxis, :, :]

def sample_xyn(num_samples, var_n, xy_rng):
  """Draws (x, y) pairs for the `x = y + n` quadratic-Gaussian setup.

  See Section IV.A in the paper. y has unit standard deviation and n has
  variance `var_n`; both outputs have shape (num_samples, 1).
  """
  joint = tfp.distributions.Normal(loc=0., scale=[1., var_n ** .5])
  # Last axis holds the two independent draws: index 0 is y, index 1 is n.
  draws = joint.sample(sample_shape=(num_samples, 1), seed=xy_rng)
  y, noise = draws[..., 0], draws[..., 1]
  return y + noise, y

def sample_yxn(num_samples, var_n, xy_rng):
  """Draws (x, y) pairs for the `y = x + n` quadratic-Gaussian setup.

  See Section IV.A in the paper. x has unit standard deviation and n has
  variance `var_n`; both outputs have shape (num_samples, 1).
  """
  joint = tfp.distributions.Normal(loc=0., scale=[1., var_n ** .5])
  # Last axis holds the two independent draws: index 0 is x, index 1 is n.
  draws = joint.sample(sample_shape=(num_samples, 1), seed=xy_rng)
  x, noise = draws[..., 0], draws[..., 1]
  return x, x + noise

def sample_signed_laplacian(num_samples, var_n, xy_rng):
  """Draws (x, y) pairs for the 'signed' Laplacian setup.

  See Section IV.B in the paper. x is zero-mean Laplace and y = sign(x);
  both outputs have shape (num_samples, 1).

  NOTE(review): `var_n` is passed as the Laplace *scale* parameter here, not
  as a variance -- confirm this is the intended meaning.
  """
  source = tfp.distributions.Laplace(loc=0., scale=var_n)
  x = source.sample(sample_shape=(num_samples, 1), seed=xy_rng)
  return x, jnp.sign(x)

def sample_xy(num_samples, xy_rng):
  """Draws `num_samples` (x, y) pairs for the globally selected setup.

  Dispatches on the module-level `corr_pattern` and forwards the module-level
  `var_n` to the matching sampler.

  Args:
    num_samples: number of pairs to draw.
    xy_rng: JAX PRNG key used for sampling.

  Returns:
    Tuple (x, y) as returned by the selected sampler.

  Raises:
    ValueError: if `corr_pattern` is not one of the supported patterns.
  """
  if corr_pattern == "y=x+n":
    return sample_yxn(num_samples, var_n, xy_rng)
  if corr_pattern == "x=y+n":
    return sample_xyn(num_samples, var_n, xy_rng)
  if corr_pattern == "signed_laplacian":
    return sample_signed_laplacian(num_samples, var_n, xy_rng)
  # An explicit raise (instead of `assert False`) still fires under `python -O`
  # and matches the error raised by WynerZivModel.encode.
  raise ValueError("Invalid correlation type. Only `y=x+n`, `x=y+n` and "
                   "`signed_laplacian` correlation patterns are supported.")

def sample_conditional(repeat_num, init_x, sampling_rng):
  """Samples side information given `init_x` for the globally selected setup.

  Dispatches on the module-level `corr_pattern` and forwards the module-level
  `var_n` to the Gaussian samplers.

  Args:
    repeat_num: number of conditional samples N to draw per input.
    init_x: source samples; the shape comments below assume (B, 1).
    sampling_rng: JAX PRNG key (unused for the signed-Laplacian pattern).

  Returns:
    Array of shape (N, B, 1) for the Gaussian patterns, or (1, B, 1) for the
    signed-Laplacian pattern.

  Raises:
    ValueError: if `corr_pattern` is not one of the supported patterns.
  """
  if corr_pattern == "y=x+n":
    # shape: (N, B, 1)
    return sample_yxn_conditional(repeat_num, init_x, var_n, sampling_rng)
  if corr_pattern == "x=y+n":
    # shape: (N, B, 1)
    return sample_xyn_conditional(repeat_num, init_x, var_n, sampling_rng)
  if corr_pattern == "signed_laplacian":
    # shape: (1, B, 1)
    # N=1 because the correlation pattern is deterministic in this case, so
    # no sampling is needed.
    return sample_laplacian_conditional(init_x)
  # An explicit raise (instead of `assert False`) still fires under `python -O`
  # and matches the error raised by WynerZivModel.encode.
  raise ValueError("Invalid correlation type. Only `y=x+n`, `x=y+n` and "
                   "`signed_laplacian` correlation patterns are supported.")


In [None]:
class WynerZivModel(nn.Module):
  """Entropy-constrained vector quantizer with decoder-side information.

  The encoder maps each source sample x to one of `latent_dims` indices by
  minimizing rate + lmbda * (estimated distortion); the decoder is an MLP that
  reconstructs x from the one-hot index and the side information y.
  """
  # Dimensionality of the decoder output (the reconstructed source).
  source_dims: int
  # Number of quantization indices K.
  latent_dims: int
  # Noise variance forwarded to the Gaussian conditional samplers.
  var_n: float
  # Number N of side-information samples used to estimate the distortion.
  num_y_samples: int
  # Weight of the distortion term in the rate-distortion objective.
  lmbda: float
  # One of "y=x+n", "x=y+n" or "signed_laplacian".
  corr_pattern: str

  def setup(self):
    """Creates the learnable index logits and the decoder network."""
    # this is 'marginal' entropy bottleneck, just learnable parameters.
    self.logits = self.param(
          "logits", jax.nn.initializers.uniform(1e-1), (self.latent_dims,))

    self.g = MLP(self.source_dims, nn.leaky_relu) # this is the decoder.


  def encode(self, x, encoder_rng):
    """Implements the marginal encoder function (see Eq. (18) in the paper).

    See Section IV.A in the paper for discussion on this compressor.

    For every input, the distortion of each of the K candidate indices is
    estimated by decoding with side information drawn conditionally on x and
    averaging the squared error over those draws. The index minimizing
    rate + lmbda * distortion is selected.

    Args:
      x: source samples, shape (B, 1).
      encoder_rng: JAX PRNG key for the conditional sampling (unused for the
        "signed_laplacian" pattern).

    Returns:
      Tuple (z, rates[z], distortion) where z has shape (B,) and holds the
      selected indices, rates[z] is the rate in bits of each selected index,
      and distortion is the estimated distortion of each selected index.

    Raises:
      ValueError: if `corr_pattern` is not a supported pattern.
    """
    # shape: (B, 1)
    B = x.shape[0]
    K = self.latent_dims
    N = self.num_y_samples

    if self.corr_pattern == "y=x+n":
      # shape: (N, B, 1)
      y_given_x = sample_yxn_conditional(N, x, self.var_n, encoder_rng)

    elif self.corr_pattern == "x=y+n":
      # shape: (N, B, 1)
      y_given_x = sample_xyn_conditional(N, x, self.var_n, encoder_rng)

    elif self.corr_pattern == "signed_laplacian":
      # shape: (1, B, 1); broadcast to N identical copies further below.
      y_given_x = sample_laplacian_conditional(x)

    else:
      raise ValueError("Invalid correlation type. Only `y=x+n`, `x=y+n` and "
                       "`signed_laplacian` correlation patterns are supported.")

    # Rate of each index in bits: -log2(softmax(logits)).
    # shape: (K,)
    rates = jax.nn.log_softmax(self.logits) / -jnp.log(2.)

    # One-hot vectors of all K candidate indices.
    # shape: (K, K)
    all_possible_indices = jnp.eye(K)
    # shape: (B*N, K, K)
    all_possible_indices_broadcasted = jnp.broadcast_to(
        all_possible_indices, (B*N, K, K))
    # shape: (B*N*K, K)
    all_possible_indices_reshaped = all_possible_indices_broadcasted.reshape(
        (B*N*K, K))

    # Repeat every side-information sample once per candidate index, so that
    # the candidate axis K varies fastest, matching the one-hot rows above.
    # shape: (N, B, K, 1)
    y_given_x_broadcasted = jnp.broadcast_to(
        y_given_x[:, :, None, :], (N, B, K, 1))
    # shape: (B*N*K, 1)
    y_given_x_reshaped = y_given_x_broadcasted.reshape((B*N*K, 1))

    # Decode every (sample, input, candidate index) combination in one batch.
    # shape: (B*N*K, 1)
    reconstructed_values = self.g(
        all_possible_indices_reshaped, y_given_x_reshaped)
    # shape: (N, B, K, 1)
    reconstructed_values_reshaped = reconstructed_values.reshape(
        (N, B, K, 1))
    # shape: (N, B, K, 1)
    squared_dist = jnp.square(x[:, None, :] - reconstructed_values_reshaped)
    # shape: (N, B, K)
    summed_dist = jnp.sum(squared_dist, axis=-1)
    # Average over the N side-information samples.
    # shape: (B, K)
    distortion_estimated = jnp.mean(summed_dist, axis=0)

    # shape: (B, K)
    all_rd = rates + self.lmbda * distortion_estimated
    # shape: (B,)
    z = jnp.argmin(all_rd, axis=-1)

    # Picks distortion_estimated[b, z[b]] for every batch element b.
    distortion_estimated_passed = jnp.choose(
        z, distortion_estimated.T, mode='wrap')

    return z, rates[z], distortion_estimated_passed

  def decode(self, z, y):
    """Implements the neural decoder function in the paper.

    `z` holds integer indices; it is one-hot encoded to `latent_dims` classes
    and passed to the decoder MLP together with the side information `y`.
    """
    z = nn.one_hot(z, self.latent_dims, axis=-1)
    return self.g(z, y)

  def decode_visualization(self, z, y):
    """Runs the decoder MLP on `z` as given (no one-hot encoding) and `y`."""
    return self.g(z, y)

  def __call__(self, x, y, encoder_rng, training):
    """Computes the mean rate and mean distortion for a batch.

    When `training` is true the distortion is the encoder's estimate for the
    selected indices; otherwise it is the squared error of decoding the
    selected indices with the provided side information `y`.

    Returns:
      dict with keys "distortion" and "rate".
    """
    z, rates, distortion = self.encode(x, encoder_rng)

    rate = jnp.mean(rates)

    if not training:
      x_hat = self.decode(z, y)
      distortion = jnp.sum(jnp.square(x - x_hat), axis=-1)

    distortion = jnp.mean(distortion)

    return dict(distortion=distortion, rate=rate)


In [None]:
@jax.jit
def pretrain_step(state, init_x, var_n, sampling_rng, init_z):
  """Runs one decoder pretraining step.

  The decoder is trained to reconstruct `init_x` from the fixed indices
  `init_z` and freshly sampled side information. `var_n` is accepted but not
  read in this body; `sample_conditional` uses the module-level `var_n`.

  Returns:
    Tuple (updated state, dict with key "distortion").
  """
  batch = init_x.shape[0]
  # (N, B, 1) conditional samples flattened to (B*N, 1).
  noisy_y = sample_conditional(repeat_num, init_x, sampling_rng).reshape(
      (batch*repeat_num, 1))

  def loss_fn(params):
    net = model()
    recon = net.apply({'params': params}, method=net.decode,
                      z=init_z, y=noisy_y)
    # Tile the targets so they line up with the flattened samples.
    targets = jnp.tile(init_x, (repeat_num, 1))
    mse = jnp.mean(jnp.sum(jnp.square(targets - recon), axis=-1))
    return mse, {"distortion": mse}

  gradients, aux = jax.grad(loss_fn, has_aux=True)(state.params)
  new_state = state.apply_gradients(grads=gradients)
  return new_state, aux

In [None]:
@jax.jit
def train_step(state, data_rng, encoder_rng):
  """Runs one rate-distortion training step on a freshly sampled batch.

  Returns:
    Tuple (updated state, dict with keys "distortion", "rate" and "loss").
  """
  x, y = sample_xy(batch_size, data_rng)

  def loss_fn(params):
    net = model()
    metrics = net.apply(
        {'params': params}, x=x, y=y, encoder_rng=encoder_rng, training=True)
    # Lagrangian objective: rate + lmbda * distortion.
    objective = metrics["rate"] + net.lmbda * metrics["distortion"]
    metrics["loss"] = objective
    return objective, metrics

  gradients, aux = jax.grad(loss_fn, has_aux=True)(state.params)
  new_state = state.apply_gradients(grads=gradients)
  return new_state, aux


In [None]:
@jax.jit
def eval_step(state, x, y):
  """Evaluates rate, distortion and loss of the current parameters on (x, y).

  The encoder's conditional sampling uses the fixed key PRNGKey(0).

  Returns:
    dict with keys "val_rate", "val_distortion" and "val_loss".
  """
  net = model()
  metrics = net.apply(
      {'params': state.params}, x=x, y=y, training=False,
      encoder_rng=random.PRNGKey(0))
  rate = metrics["rate"]
  distortion = metrics["distortion"]
  return {
      "val_rate": rate,
      "val_distortion": distortion,
      "val_loss": rate + net.lmbda * distortion,
  }


In [None]:
@jax.jit
def encoder_behaviour(state, x, encoder_rng):
  """Applies the model's `encode` method to `x` with the state's parameters."""
  variables = {'params': state.params}
  return model().apply(
      variables, x=x, encoder_rng=encoder_rng, method="encode")


@jax.jit
def decoder_behaviour(state, z, y):
  """Applies the model's `decode_visualization` method to (z, y)."""
  variables = {'params': state.params}
  return model().apply(
      variables, z=z, y=y, method="decode_visualization")


In [None]:
def plot_history(history):
  """Plots training curves on four log-scaled panels.

  Panels, left to right: entropy (rate), distortion, loss and learning rate.
  `history` maps each metric name to its per-epoch series.
  """
  _, axs = plt.subplots(1, 4, figsize=(25, 5))

  # (series to draw, panel title); the learning-rate panel carries no title.
  panels = (
      (("rate", "val_rate"), "entropy"),
      (("distortion", "val_distortion"), "distortion"),
      (("loss", "val_loss"), "loss"),
      (("lr",), None),
  )

  for ax, (series_names, title) in zip(axs, panels):
    for series_name in series_names:
      ax.plot(history[series_name], label=series_name)
    ax.legend(loc="best")
    ax.grid()
    if title is not None:
      ax.set_title(title)
    ax.set_yscale("log")


# Experiment

In [None]:
# Correlation pattern between the source x and the side information y.
corr_pattern = "signed_laplacian"
# options for `corr_pattern` are : {y=x+n, x=y+n, signed_laplacian}
# (see Section V in the paper for discussion on different correlation patterns)

# Noise variance for the Gaussian patterns; passed as the Laplace scale in the
# signed-Laplacian pattern.
var_n = 1.0
# Number of quantization indices available to the encoder.
num_latents = 32
# Weight of the distortion term in the loss: rate + lmbda * distortion.
lmbda = 42.
# Number of conditional side-information samples drawn inside the encoder.
num_y_samples = 16

batch_size = 512
num_epochs =  100 * 2
steps_per_epoch = 1000
validation_size = 10 * 1024
# Number of batches evaluated in the final test loop.
test_steps = 1024

# Number of conditional samples drawn per input during pretraining.
# repeat_num = 2 # recommended for quadratic-Gaussian case.
repeat_num = 1 # recommended for signed Laplacian case.

# Number of random target indices drawn for the pretraining stage.
init_num_bins = repeat_num * num_latents

total_steps = num_epochs * steps_per_epoch
# Optimizer steps below this value run the pretraining step; later steps run
# the rate-distortion training step.
start_training_step = steps_per_epoch * 30
# Learning rate starts at 1e-3 and is multiplied by 1e-1 once 70% of the total
# steps have elapsed.
lr_schedule = optax.piecewise_constant_schedule(
    init_value=1e-3,
    boundaries_and_scales={
        int(7/10 * total_steps): 1e-1,
    },
)

# Snapshots of the train state, stored every `mod_epoch` epochs and keyed by
# epoch number.
state_dict = dict()
mod_epoch = 10

def model():
  """Builds a WynerZivModel from the experiment's module-level settings."""
  hyperparams = dict(
      source_dims=1,
      latent_dims=num_latents,
      var_n=var_n,
      num_y_samples=num_y_samples,
      lmbda=lmbda,
      corr_pattern=corr_pattern,
  )
  return WynerZivModel(**hyperparams)

# Seed the JAX PRNG from 4 bytes of OS entropy, so runs are not reproducible.
# NOTE(review): os.getrandom is not available on every platform -- confirm the
# target environment provides it.
seed, = np.frombuffer(os.getrandom(4), dtype=np.int32)
rng = random.PRNGKey(seed)

rng, init_rng, init_xy_rng, init_z_rng = random.split(rng, 4)
# `num_latents` (x, y) pairs and `init_num_bins` random indices; init_x and
# init_z are the fixed inputs of the pretraining step.
init_x, init_y = sample_xy(num_latents, init_xy_rng)
init_z = jax.random.randint(init_z_rng, (init_num_bins,), 0, num_latents)
# Parameters are initialized by tracing the model once on the init samples.
state = train_state.TrainState.create(
    apply_fn=model().apply,
    params=model().init(
        init_rng, x=init_x, y=init_y, training=False,
        encoder_rng=random.PRNGKey(0))['params'],
    tx=optax.adam(lr_schedule),
)

# Fixed validation set, reused at the end of every epoch.
rng, data_rng = random.split(rng)
validation_set = sample_xy(validation_size, data_rng)

# Per-epoch metric series; entries stay NaN until the metric is recorded.
history = defaultdict(lambda: np.full((num_epochs,), float("nan")))

# Main loop: pretraining steps first, then rate-distortion training, with
# validation and live plotting at the end of every epoch.
for epoch in range(num_epochs):

  # Keep a snapshot of the train state every `mod_epoch` epochs.
  if epoch%mod_epoch==0:
    state_dict[epoch] = state

  # Learning rate in effect at the start of the epoch. It is recorded
  # separately from the accumulated metrics so that it is NOT divided by
  # `steps_per_epoch` when the metrics are averaged.
  epoch_lr = lr_schedule(state.step)
  totals = defaultdict(lambda: 0)

  for _ in range(steps_per_epoch):
    rng, sampling_rng = random.split(rng, 2)
    if state.step < start_training_step:
      state, result = pretrain_step(state, init_x, var_n=var_n, sampling_rng=sampling_rng, init_z=init_z)
    else:
      rng, data_rng, encoder_rng = random.split(rng, 3)
      state, result = train_step(state, data_rng=data_rng, encoder_rng=encoder_rng)
    for k in result:
      totals[k] += result[k]

  # Average the per-step training metrics over the epoch.
  results = {k: total / steps_per_epoch for k, total in totals.items()}
  results["lr"] = epoch_lr

  results.update(eval_step(state, *validation_set))

  for k in results:
    history[k][epoch] = results[k]

  colab.output.clear(wait=True)
  plot_history(history)
  plt.show()
  print(f"epoch {epoch:4}")
  print(f"train      entropy {history['rate'][epoch]:6.4f}, distortion {history['distortion'][epoch]:7.4f} ({10*np.log10(history['distortion'][epoch]):6.3f} dB), loss {history['loss'][epoch]:7.4f}")
  print(f"validation entropy {history['val_rate'][epoch]:6.4f}, distortion {history['val_distortion'][epoch]:7.4f} ({10*np.log10(history['val_distortion'][epoch]):6.3f} dB), loss {history['val_loss'][epoch]:7.4f}", flush=True)


In [None]:
# Average the evaluation metrics over `test_steps` freshly sampled batches.
metrics = defaultdict(lambda: 0)
for _ in range(test_steps):
  rng, data_rng = random.split(rng)
  result = eval_step(state, *sample_xy(batch_size, data_rng))
  for k in result:
    metrics[k] += result[k]
for k in metrics:
  metrics[k] /= test_steps
# Both literals need the `f` prefix: adjacent string literals are concatenated,
# but each one is formatted only if it is itself an f-string.
print(f"test rate {metrics['val_rate']:6.4f}, distortion {metrics['val_distortion']:7.4f}"
      f"({10*np.log10(metrics['val_distortion']):6.3f} dB), loss {metrics['val_loss']:6.3f}")

print("pgfplots:", metrics['val_rate'], 10*np.log10(metrics['val_distortion']))


# Visualizing the Neural Compressor

In [None]:
def plot_binning(state):
  """Extracts the learned quantization regions and decoder curves.

  The encoder is evaluated on a uniform grid of 10000 source values in
  [-10, 10]. Positions where the selected index changes are the quantization
  boundaries. For every index that is used, the decoder output is evaluated
  over the side-information values appropriate for `corr_pattern`.

  Args:
    state: train state whose `params` are used by the encoder and decoder.

  Returns:
    Tuple (reconstructed, boundaries, categories_used, unique_categories):
      reconstructed: list of (decoder outputs over the grid, index) pairs, one
        per distinct index.
      boundaries: grid positions i where the index at i differs from the
        index at i + 1.
      categories_used: index of every region in grid order; it has
        len(boundaries) + 1 entries.
      unique_categories: the distinct indices in `categories_used`.

  Raises:
    ValueError: if `corr_pattern` is not a supported pattern.
  """
  plt.figure(figsize=(8, 6))
  num_points = 10000
  x = np.linspace(-10, 10, num_points)[:, None]

  # Use a proper PRNG key for the encoder's conditional sampling rather than a
  # plain tuple of ints.
  z, _, _ = encoder_behaviour(state, x, encoder_rng=random.PRNGKey(0))

  boundaries = jnp.nonzero(z[1:] != z[:-1])[0]
  print(boundaries)

  # z[b] is the index of the region ending at boundary b. The region after the
  # last boundary runs to the end of the grid, so its index is z[-1]. Written
  # this way, a single region with no boundaries is handled as well.
  z_np = np.asarray(z)
  categories_used = np.concatenate(
      [z_np[np.asarray(boundaries)], z_np[-1:]])

  unique_categories = jnp.unique(categories_used)

  if corr_pattern == "y=x+n" or corr_pattern == "x=y+n":
    y_at_decoder_varying = np.linspace(-10, 10, num_points)[:, None]

  elif corr_pattern == "signed_laplacian":
    y_at_decoder_varying = np.sign(x)

  else:
    raise ValueError("Invalid correlation type. Only `y=x+n`, `x=y+n` and "
                     "`signed_laplacian` correlation patterns are supported.")

  reconstructed = []
  for category in unique_categories:
    # The same one-hot index is repeated for every side-information value.
    ex_z = nn.one_hot(category, num_latents, axis=-1, dtype='f')[None, :]
    ex_z = np.tile(ex_z, (num_points, 1))
    intermed_reconstructed = decoder_behaviour(
        state, ex_z, y_at_decoder_varying)
    intermed_reconstructed = jnp.squeeze(np.array(intermed_reconstructed), -1)
    reconstructed.append((intermed_reconstructed, category))

  return reconstructed, boundaries, categories_used, unique_categories


In [None]:
# Compute the binning once, then report the region indices and the distinct
# indices together with their counts.
reconstructed, boundaries, categories_used, _ = plot_binning(state)
distinct_categories = np.unique(categories_used)
print(categories_used)
print(len(categories_used))
print(distinct_categories)
print(len(distinct_categories))

In [None]:
# Draws the quantization boundaries (dashed horizontal lines), the shaded
# quantization regions and the decoder output of every used index.
plt.figure(figsize=(16, 14))

ys = np.linspace(-10, 10, 10000)

xs = np.linspace(-10, 10, 10000)

# Source values at which the encoder output changes.
boundary_positions = xs[np.asarray(boundaries, dtype=int)]

for boundary_position in boundary_positions:
  # The boundaries do not depend on 'y', so each one is a full-width
  # horizontal line. axhline's xmin/xmax are axes fractions in [0, 1], so the
  # defaults already span the whole axes.
  plt.axhline(boundary_position, linestyle='--', linewidth=2.5)

plt.ylim(-7, 7)
plt.xlim(-10, 10)


# getting random colours
vals = np.linspace(0,1,num_latents)
np.random.shuffle(vals)
#cmap = plt.cm.colors.ListedColormap(plt.cm.gist_rainbow(vals))
cmap = plt.cm.colors.ListedColormap(plt.cm.nipy_spectral(vals))

# Unused in this cell; kept in case later cells expect the name.
points = []
for ind_reconstructed, ind_category in reconstructed:
  plt.plot(ys, ind_reconstructed, color=cmap(ind_category.item()), linewidth=3)

# Raw string: "\h" is not a valid escape sequence in a normal string literal.
plt.ylabel(r"learned quantization boundaries and $\hat{x}$", fontsize=30)
plt.grid()

# Region i lies between consecutive edges, where the edges are the plot limits
# -10 and 10 with the boundaries in between. This avoids indexing `boundaries`
# past its end for the last region.
region_edges = np.concatenate(([-10.], boundary_positions, [10.]))
for category, lower, upper in zip(
    categories_used, region_edges[:-1], region_edges[1:]):
  plt.fill_between(xs, lower, upper, alpha=0.15, color=cmap(category))


plt.yticks(fontsize=20)
plt.xticks([-3, 3], ["$y=-1$", "$y=+1$"], fontsize=30) # valid for signed Laplacian setup.


#plt.savefig('binning_plot.pdf', bbox_inches='tight')