To run this Colab, use your own Colab setup or open
[Sandwich Video Compression Lowres Codec](https://colab.research.google.com/github/google/sandwiched_compression/blob/main/sandwich_video_compression_lowres_codec.ipynb).


In [None]:
# Install mediapy (video display) and a pinned tensorflow-datasets version.
!pip install -q mediapy tensorflow-datasets==4.9.4

In [None]:
# Clone the sandwiched_compression repository and move its contents into the
# working directory so its modules can be imported; skipped when
# compress_intra_model.py is already present.
!if [ ! -f compress_intra_model.py ]; then \
  git clone https://github.com/google/sandwiched_compression; \
  mv sandwiched_compression/* .; \
fi

In [None]:
import functools
import logging
import numpy as np
import tensorflow as tf

import mediapy as media
import compress_intra_model
import compress_video_model
import datasets
from distortion import distortion_fns

# Setup

In [None]:
# Download the sandwich_video_research_444 dataset (TFRecord format) before
# running this notebook; see the README for details. The per-example keys are
# listed in datasets._video_data.
dataset_path = '<Please insert the path to the downloaded dataset.>'


def dataset_fn(
    batch_size: int, is_training: bool, take_count: int = 100
) -> tf.data.Dataset:
  """Loads one split of the video dataset, truncated to `take_count` elements.

  Args:
    batch_size: Batch size forwarded to `datasets.load_video_dataset`.
    is_training: Selects the train split when True, the eval split otherwise.
    take_count: Number of (batched) elements kept from the loaded dataset.

  Returns:
    The loaded dataset limited via `tf.data.Dataset.take`.
  """
  full_split = datasets.load_video_dataset(
      path=dataset_path,
      batch_size=batch_size,
      is_training=is_training,
  )
  return full_split.take(take_count)

In [None]:
def create_lowres_codec_model(
    gamma: float, loop_filter_folder: 'str | None' = None
) -> tf.keras.Model:
  """Creates the sandwich model for a low-resolution (half-size) codec.

  Args:
    gamma: Lagrange multiplier forwarded to `create_basic_model`.
    loop_filter_folder: Forwarded to `create_basic_model`. It is recommended to
      train a loop filter (see `compress_intra_model.create_loop_filter_model()`)
      and pass the folder holding its checkpoint here. Defaults to None, which
      matches the value this notebook previously hard-coded.

  Returns:
    The Keras model built by `compress_video_model.create_basic_model`.
  """
  return compress_video_model.create_basic_model(
      model_keys=('clip',),
      bottleneck_channels=3,  # 3-channel bottleneck.
      output_channels=3,
      num_mlp_layers=2,
      use_video_codec_rate_model=True,
      downsample_factor=2,  # Lowres codec at half-resolution.
      gamma=gamma,
      loop_filter_folder=loop_filter_folder,
      video_is_420=False,
      codec_proxy_is_420=False,
  )

In [None]:
def create_loss_fn(gamma: float) -> tf.keras.losses.Loss:
  """Builds the training loss with an L2 distortion on the 'clip' key.

  Args:
    gamma: Lagrange multiplier forwarded to `create_basic_loss`.

  Returns:
    The loss built by `compress_intra_model.create_basic_loss`, with the
    valid-bottleneck-pixels penalty disabled.
  """
  l2_on_clip = functools.partial(
      distortion_fns.distortion_l2norm, image_key='clip', scaler=1
  )
  loss = compress_intra_model.create_basic_loss(
      gamma=gamma,
      distortion_fn=l2_on_clip,
      add_valid_bottleneck_pixels_penalty=False,
  )
  return loss

In [None]:
# Batch sizes per split.
train_batch_size = 3  # Try upping this to 8 if your gpu allows.
eval_batch_size = 1

# Build the (truncated) train and eval splits.
train_dataset = dataset_fn(train_batch_size, is_training=True)
eval_dataset = dataset_fn(eval_batch_size, is_training=False)

In [None]:
# Model and optimizer.
gamma = 50  # Lagrange multiplier
base_model = create_lowres_codec_model(gamma)

learning_rate = 1e-4
optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)

# One checkpoint tracks the model, the optimizer and the bookkeeping variables
# written by the training loop; the manager keeps the 10 most recent saves.
check_point_path = '/tmp/sandwich_video_compression_lowres_codec/checkpoints'
checkpoint = tf.train.Checkpoint(
    model=base_model,
    optimizer=optimizer,
    step=optimizer.iterations,
    epoch=tf.Variable(0, dtype=tf.int64, trainable=False),
    loss=tf.Variable(tf.float32.max, dtype=tf.float32, trainable=False),
    best_loss=tf.Variable(tf.float32.max, dtype=tf.float32, trainable=False),
    training_finished=tf.Variable(False, dtype=tf.bool, trainable=False),
)
checkpoint_manager = tf.train.CheckpointManager(
    checkpoint, directory=check_point_path, max_to_keep=10
)

# Resume from the newest checkpoint when one exists, otherwise start fresh.
if checkpoint_manager.latest_checkpoint:
  logging.info('Restoring from %s', check_point_path)
  checkpoint.restore(checkpoint_manager.latest_checkpoint)
  # The saved epoch was already completed, so continue with the following one.
  epoch = checkpoint.epoch.numpy() + 1
else:
  epoch = 0

# Train

In [None]:
# Simple trainer. It is recommended to use a custom trainer and train to
# convergence.
num_epochs = 1000
checkpoint_every_n_epochs = 100
epoch_stat = tf.keras.metrics.Mean()
loss_fn = create_loss_fn(gamma=gamma)

while epoch < num_epochs:
  for train_batch in train_dataset:
    # Only the forward pass and loss need to be recorded on the tape.
    with tf.GradientTape() as tape:
      out = base_model(train_batch)
      loss = loss_fn(train_batch, out)

    # Compute and apply gradients outside the tape context so these ops are
    # not themselves recorded.
    gradients = tape.gradient(loss, base_model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, base_model.trainable_variables))
    epoch_stat(loss)

  # Note each epoch is over a varying set of take_count x batch_size clips.
  # Calculate a median or change the dataset loader to always use the same set
  # if you prefer.
  epoch_loss = epoch_stat.result()
  print(f'Epoch {epoch:=4d}/{num_epochs:=4d} Loss: {epoch_loss:=4.4f}')

  # Save periodically and always after the final epoch, so the end-of-training
  # state (and training_finished=True) is actually persisted.
  is_last_epoch = epoch == num_epochs - 1
  if epoch % checkpoint_every_n_epochs == 0 or is_last_epoch:
    checkpoint.epoch.assign(epoch)
    checkpoint.loss.assign(epoch_loss)
    # Keep the lowest epoch loss seen so far (including restored runs) rather
    # than overwriting it with the current one.
    checkpoint.best_loss.assign(tf.minimum(checkpoint.best_loss, epoch_loss))
    checkpoint.training_finished.assign(is_last_epoch)
    checkpoint_manager.save(checkpoint_number=optimizer.iterations.numpy())

  epoch += 1
  epoch_stat.reset_state()

# Eval

In [None]:
def convert_to_rgb(yuv: np.ndarray) -> np.ndarray:
  """Converts a YUV clip with at most 3 channels to RGB for display.

  The available channels are clipped to [0, 1]; any missing trailing channels
  are filled with 0.5 before `media.rgb_from_yuv` is applied.
  """
  num_channels = yuv.shape[-1]
  assert num_channels <= 3
  padded_yuv = np.full((*yuv.shape[:-1], 3), 0.5)
  padded_yuv[..., :num_channels] = np.clip(yuv[..., :num_channels], 0, 1)
  return media.rgb_from_yuv(padded_yuv)


def modify_title(title: str) -> str:
  """Marks bottleneck clips, which are displayed 2x zoomed, in their title."""
  is_zoomed = title in ('bottleneck', 'recons_bottleneck')
  return title + ' (2x)' if is_zoomed else title


def simple_linear_path(
    inputs: np.ndarray, downsample_factor=None
) -> np.ndarray:
  """Emulates a linear down-up resampling path without compression.

  The spatial axes are downsampled by `downsample_factor` with bicubic
  filtering and then upsampled back to the original size with Lanczos-3
  filtering, giving a baseline to compare the model's prediction against.

  Args:
    inputs: Array of rank >= 4 whose last three axes are treated as
      (height, width, channels) by `tf.image.resize`; leading axes are
      flattened into one batch axis for resizing and restored afterwards.
    downsample_factor: Integer spatial downsampling factor. When None (the
      default) the global `base_model.downsample_factor` is used, as before.

  Returns:
    A numpy array with the same shape as `inputs`.

  Raises:
    ValueError: If `inputs` has rank lower than 4.
  """
  sample = tf.convert_to_tensor(inputs)
  # Explicit check instead of `assert`, which is stripped under `python -O`.
  if sample.shape.rank is None or sample.shape.rank < 4:
    raise ValueError(
        f'simple_linear_path expects rank >= 4 input, got shape {sample.shape}.'
    )
  factor = (
      base_model.downsample_factor
      if downsample_factor is None
      else downsample_factor
  )
  height, width = sample.shape[-3], sample.shape[-2]
  # tf.image.resize needs rank 3 or 4, so fold all leading axes into a batch.
  sample_rank_four = tf.reshape(sample, [-1, *sample.shape[-3:]])
  low_res = tf.image.resize(
      sample_rank_four,
      size=[height // factor, width // factor],
      method=tf.image.ResizeMethod.BICUBIC,
  )
  restored = tf.image.resize(
      low_res,
      size=[height, width],
      method=tf.image.ResizeMethod.LANCZOS3,
  )
  return tf.reshape(restored, sample.shape).numpy()

In [None]:
# It is recommended to generate R-D curves by training multiple models for
# multiple gammas, then evaluate each model for multiple qsteps, and construct
# the Pareto frontier. Please see the paper for details:
# https://arxiv.org/abs/2402.05887

# Discussion on the results shown below:
#
# For the low-res codec scenario, pay attention to areas where the simple linear
# path has lost detail (e.g., text and textures) or has aliasing (e.g., merging
# lines/edges). Notice how much better the model predictions are, and also
# notice what the compressed bottlenecks transport in these areas. While
# compressed bottlenecks are typically aliased and pixelated with
# modulation-like patterns, the post-processor manages to demodulate these into
# clear video. Running post-processing-only models will typically generate wrong
# results in these areas. Please see the paper for examples.
#
# One can design models to hallucinate detail, but it is important to
# understand that hallucination is not accurate transport. When one watches a
# movie one wants to see it as the director, cinematographer, etc., have
# intended it. One does not want to see some model's hallucinated
# reinterpretation of the art/reality.

# Clips to show. Can also look at the proxy rate through 'rate', calculate
# distortion or whatever else you would like.
show_keys = ['prediction', 'recons_bottleneck']
show_count = 10

# take() bounds the iteration directly, so no extra eval batch is fetched and
# then discarded (as a break-on-index loop would do).
for sample in eval_dataset.take(show_count).as_numpy_iterator():
  # Path 1: Simple demo:
  # Run the pre-processor, codec-proxy, and the post-processor.
  output = base_model(sample)

  # Path 2: Actual performance with your codec:
  # Run the pre-processor, your codec, then post-processor.
  #
  # base_model.set_bit_depth(tf.cast(sample['bit_depth'], tf.float32))
  # bottlenecks = []
  # predictions = []
  # for i in range(sample['clip'].shape[0]):  # Run over a batch.
  #   bottleneck = base_model.run_preprocessor(sample['clip'][i], training=False)
  #   bottlenecks.append(bottleneck)
  # recons_bottlenecks = insert_your_video_codec_binary(bottlenecks)
  # for i in range(sample['clip'].shape[0]):
  #   predictions.append(
  #       base_model.run_postprocessor(recons_bottlenecks[i], training=False)
  #   )
  #
  # output = {
  #     'prediction': predictions,
  #     'bottleneck': bottlenecks,
  #     'recons_bottleneck': recons_bottlenecks,
  # }

  bit_depth = sample['bit_depth'][0][0]
  max_pixel = 2**bit_depth - 1
  # Scale the first clip of the batch once (to [0, 1], assuming pixel values
  # lie in [0, max_pixel]); both the original and the linear baseline use it.
  normalized_clip = sample['clip'][0] / max_pixel
  clips = {
      f'original ({bit_depth:=2d}-bit)': convert_to_rgb(normalized_clip),
      'simple_linear': convert_to_rgb(simple_linear_path(normalized_clip)),
  }
  clips.update({
      modify_title(key): convert_to_rgb(value[0] / max_pixel)
      for key, value in output.items()
      if key in show_keys
  })
  media.show_videos(clips, fps=5, height=400)  # bottlenecks zoomed 2x.

0,1,2,3
original ( 8-bit)  This browser does not support the video tag.,simple_linear  This browser does not support the video tag.,prediction  This browser does not support the video tag.,recons_bottleneck (2x)  This browser does not support the video tag.


0,1,2,3
original ( 8-bit)  This browser does not support the video tag.,simple_linear  This browser does not support the video tag.,prediction  This browser does not support the video tag.,recons_bottleneck (2x)  This browser does not support the video tag.


0,1,2,3
original ( 8-bit)  This browser does not support the video tag.,simple_linear  This browser does not support the video tag.,prediction  This browser does not support the video tag.,recons_bottleneck (2x)  This browser does not support the video tag.


0,1,2,3
original ( 8-bit)  This browser does not support the video tag.,simple_linear  This browser does not support the video tag.,prediction  This browser does not support the video tag.,recons_bottleneck (2x)  This browser does not support the video tag.


0,1,2,3
original ( 8-bit)  This browser does not support the video tag.,simple_linear  This browser does not support the video tag.,prediction  This browser does not support the video tag.,recons_bottleneck (2x)  This browser does not support the video tag.


0,1,2,3
original ( 8-bit)  This browser does not support the video tag.,simple_linear  This browser does not support the video tag.,prediction  This browser does not support the video tag.,recons_bottleneck (2x)  This browser does not support the video tag.


0,1,2,3
original ( 8-bit)  This browser does not support the video tag.,simple_linear  This browser does not support the video tag.,prediction  This browser does not support the video tag.,recons_bottleneck (2x)  This browser does not support the video tag.


0,1,2,3
original ( 8-bit)  This browser does not support the video tag.,simple_linear  This browser does not support the video tag.,prediction  This browser does not support the video tag.,recons_bottleneck (2x)  This browser does not support the video tag.


0,1,2,3
original ( 8-bit)  This browser does not support the video tag.,simple_linear  This browser does not support the video tag.,prediction  This browser does not support the video tag.,recons_bottleneck (2x)  This browser does not support the video tag.


0,1,2,3
original ( 8-bit)  This browser does not support the video tag.,simple_linear  This browser does not support the video tag.,prediction  This browser does not support the video tag.,recons_bottleneck (2x)  This browser does not support the video tag.
