Copyright 2021 DeepMind Technologies Limited

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

     https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.


# Video compression
Simplified demo showing how to:
1. Load a compressed network and apply it to a video
2. Load an augmentation and apply it to a video

# Setup

In [1]:
# @title Imports
import numpy as np
import os
import seaborn as sns
import haiku as hk
import jax
import jax.numpy as jnp
from IPython.display import Image as display_image
from PIL import Image
import sys
sys.path.append(os.path.abspath('../../'))

from compressed_vision.models import equivariant_networks
from compressed_vision.models import encoder_decoder_unet
from compressed_vision.utils import checkpoint_loader
from compressed_vision.utils import data_utils
from compressed_vision.utils import metric_utils
from compressed_vision.utils import video_utils


# Use seaborn's "whitegrid" style for any plots drawn in this notebook.
sns.set_style("whitegrid")

ModuleNotFoundError: No module named 'compressed_vision'

# Run compression



In [None]:
# @title Load models.
# @markdown Investigate different augmentations or compression.
# @markdown Because the models are larger when using augmentation, the parameters are different.
test_augmentations = True # @param {type: 'boolean'}
_BASE_PATH = 'https://storage.googleapis.com/dm_compressed_vision/models/' # @param {type: 'string'}
_SAVE_PATH = '/tmp/' # @param {type: 'string'}

if not test_augmentations:
  # @markdown Controllable params in non augmentation mode.
  augmentation_type = None
  compression_rate = 192 # @param {type: 'integer'}
  if compression_rate == 192:
    model_path = f'{_BASE_PATH}compression/41861759_1_cr%3D192.pkl'
  if compression_rate == 236:
    model_path = f'{_BASE_PATH}compression/42071788_1_cr%3D236.pkl'
  if compression_rate == 384:
    model_path = f'{_BASE_PATH}compression/41877908_2_cr%3D384.pkl'
  if compression_rate == 384:
    model_path = f'{_BASE_PATH}compression/41748435_4_cr%3D786.pkl'
  NUM_FRAMES = 32 # @param
  BATCH_SIZE = 4
elif test_augmentations:
  # @markdown Controllable params in augmentation mode.
  augmentation_type = 'flip' # @param {type: 'string'} ['flip']
  if augmentation_type == 'flip':
    model_path = f'{_BASE_PATH}augmentations/35225852_7_augm%3Dflip.pkl'
  else:
    raise ValueError(f"Unexpected augmentation {augmentation_type}.")

!wget $model_path -O '/tmp/model_path.pkl'

with open('/tmp/model_path.pkl', 'rb') as f:
  all_params = checkpoint_loader.load_params_state(f)

augm_params = all_params['augm_params']
augm_state = all_params['augm_state']
augm_config = all_params['augm_config']
params = all_params['params']
state = all_params['state']
config = all_params['config']

if augm_config is not None:
  augm_config = augm_config.experiment_kwargs.config
  NUM_FRAMES = augm_config.data.num_frames
  BATCH_SIZE = 1

exp_config = config.experiment_kwargs.config
exp_config.data.train_batch_size = BATCH_SIZE
exp_config.data.eval_batch_size = BATCH_SIZE
exp_config.data.num_frames = NUM_FRAMES

In [None]:
# @title Load a video to test on.
VIDEO_TO_TEST = f'https://storage.googleapis.com/dm_compressed_vision/data/video8.gif'
!wget $VIDEO_TO_TEST -O '/tmp/video.gif'

# Load video.
with open('/tmp/video.gif', 'rb') as f:
  sample_video = Image.open(f)
  sample_video.seek(0)

  images = []
  try:
      while True:
          images.append(np.asarray(sample_video.convert()))
          sample_video.seek(sample_video.tell()+1)
  except EOFError:
      pass
print(f'Length of video: {len(images)} frames.')
sample_video = np.array(images)[:NUM_FRAMES][None, :] / 255.
v = video_utils.video_reshaper(sample_video)
video_utils.save_video((v * 255).astype(np.uint8), '/tmp/display.gif')
display_image(filename='/tmp/display.gif', embed=True)

In [None]:
# @title Encode-decode functions.
def forward_codec_fn():
  """Builds the compression codec and exposes its encode/decode methods.

  Written in the form consumed by `hk.multi_transform_with_state`.

  Returns:
    A `(model, apply_fns)` tuple, where `apply_fns` maps 'encoder' and
    'decoder' to the model's `encode` and `decode` methods.
  """
  model = encoder_decoder_unet.CompressionConvEncoderDecoder(
      **exp_config.model_kwargs,
      num_channels=3,
  )
  apply_fns = dict(encoder=model.encode, decoder=model.decode)
  return model, apply_fns

# Only the apply functions of the transformed codec are needed here.
codec_apply_fns = hk.multi_transform_with_state(forward_codec_fn).apply
codec_encoder = codec_apply_fns['encoder']
codec_decoder = codec_apply_fns['decoder']


def _encode_decode(x):
  """Runs `x` through the codec using the loaded params and state."""
  return data_utils.encode_decode(
      codec_encoder=codec_encoder,
      codec_decoder=codec_decoder,
      codec_params=params,
      codec_state=state,
      inputs=x,
  )


encode_decode_jitted = jax.jit(_encode_decode)

In [None]:
# @title Run compression.
# Encode and decode the sample clip in one jitted call.
decompressed_video, codes = encode_decode_jitted(sample_video)

cpr = metric_utils.get_compression_rate(
    sample_video,
    codes,
    bits_per_element=exp_config.model_kwargs.vq_num_embeddings,
)

# Report the compression rate and the shape of the first code tensor.
first_code_shape = codes[0].shape
print(f'Compression rate is {cpr}')
print(f'Codes W x H is {first_code_shape[2]} x {first_code_shape[3]}')
print(f'Codes channels is {first_code_shape[-1]}')
print(f'Codes time is {first_code_shape[1]}')
print(f'Decompressed space is {decompressed_video.shape}')

In [None]:
# @title Visualise reconstructed videos.
# Write the reconstruction to a GIF; values are clipped to [0, 255] before
# the uint8 cast.
e = video_utils.video_reshaper(decompressed_video)
reconstruction_frames = np.clip(e * 255, 0, 255).astype(np.uint8)
video_utils.save_video(reconstruction_frames, '/tmp/display.gif')

# Mean of the squared per-element difference between input and output.
residual = sample_video - decompressed_video
err = jnp.mean(jnp.square(residual))
print(f'Mean L2-norm error is {err}')
display_image(filename='/tmp/display.gif', embed=True)

In [None]:
# @title Save video.
# VIDEO_TO_TEST is a URL whose host name contains dots, so take the file name
# without its extension instead of splitting the whole URL on '.'.
ds_name = os.path.splitext(os.path.basename(VIDEO_TO_TEST))[0]
# Checkpoint file name up to its first '.', used to tag the output files.
name = os.path.basename(model_path).split('.')[0]

# Reconstructed video.
MAIN_PATH = f'{_SAVE_PATH}{ds_name}/neural_codec-unet/'
os.makedirs(MAIN_PATH, exist_ok=True)
path = os.path.join(MAIN_PATH, f'compression-{name}-combined.gif')
# Clip before the uint8 cast so out-of-range values saturate instead of
# wrapping around.
video_utils.save_video(np.clip(e * 255, 0, 255).astype(np.uint8), path)

# Original video.
MAIN_PATH = f'{_SAVE_PATH}{ds_name}/raw_videos/'
os.makedirs(MAIN_PATH, exist_ok=True)
path = os.path.join(MAIN_PATH, f'original-{name}.gif')
video_utils.save_video(np.clip(v * 255, 0, 255).astype(np.uint8), path)

# Run augmentations.

In [None]:
# @title Setup

# This part of the notebook only supports the 'flip' augmentation.
if augmentation_type != 'flip':
  raise ValueError(
      "Augmentation type must be set above to run this part of the CoLAB.")


def get_augmentation(video):
  """Returns an all-ones array of shape (video.shape[0], 1, 1).

  Args:
    video: array with exactly five axes; only the size of the first axis
      is used.
  """
  bs, _, _, _, _ = video.shape
  return jnp.ones((bs, 1, 1))


# Crop sizes applied to the input video when running the augmentation.
num_frames = 32
pixel_width = 256


In [None]:
# @title Functions

def forward_codec_fn():
  """Creates the codec model together with its encoder/decoder entry points.

  Returns:
    A pair `(codec, fns)`: the `CompressionConvEncoderDecoder` instance and a
    dict holding its `encode` method under 'encoder' and its `decode` method
    under 'decoder'.
  """
  codec = encoder_decoder_unet.CompressionConvEncoderDecoder(
      **exp_config.model_kwargs,
      num_channels=3,
  )
  fns = {'encoder': codec.encode}
  fns['decoder'] = codec.decode
  return codec, fns

@hk.transform_with_state
def get_codes(inputs):
  """Encodes a video, transforms its quantized codes and decodes the result.

  Args:
    inputs: dict with two entries:
      'image': the video handed to the codec encoder.
      'transformation': the augmentation parameters; indexed up to axis 2,
        so it needs at least three axes.

  Returns:
    The output of decoding the transformed quantized codes.
  """
  video = inputs['image']
  _, codec_apply_fns = (
      hk.multi_transform_with_state(forward_codec_fn)
  )
  codec_encoder = codec_apply_fns['encoder']
  codec_decoder = codec_apply_fns['decoder']

  # The codec runs with the pre-loaded `params` / `state` globals.
  _, quantized = data_utils.convert_im_to_codes(
      codec_encoder=codec_encoder,
      codec_params=params,
      codec_state=state,
      images=video,
      is_return_quantized=True,
  )

  transformation = inputs['transformation']
  equivariant_model = equivariant_networks.get_equivariant_network(
      augm_config.augmentation.network)

  # Insert a new axis 1 into the codes and repeat it once per entry along
  # axis 1 of `transformation`.
  quantized = quantized[:, None].repeat(transformation.shape[1], 1)
  # Flatten `transformation` to 2-D, then merge the two leading axes of the
  # codes so both arrays share the same leading dimension.
  transformation = transformation.reshape((-1, transformation.shape[2]))
  quantized = quantized.reshape((transformation.shape[0],) +
                                quantized.shape[2:])
  transform_quantized = equivariant_model(
      **augm_config.augmentation.kwargs)(quantized, transformation)

  transform_reconstruction = data_utils.convert_codes_to_im(
      codec_decoder=codec_decoder,
      codec_params=params,
      codec_state=state,
      codes=transform_quantized,
      is_quantized=True,
  )

  return transform_reconstruction

In [None]:
# @title Run augmentation
def _bcast_local_devices(array):
    array = jax.tree_map(
        lambda a: a[None, :].repeat(jax.local_device_count(), 0), array)
    return array

# Replicate the augmentation params/state, the RNG key and the video so each
# local device receives its own copy.
bcast_augm_params, bcast_augm_state, bcast_rng, sample_inputs = (
    _bcast_local_devices(tree)
    for tree in (augm_params, augm_state, jax.random.PRNGKey(0), sample_video)
)

# One transformation per device, computed from the uncropped video.
transformation = jax.vmap(get_augmentation)(sample_inputs)

# Crop the video to `num_frames` frames and `pixel_width` pixels per side.
cropped_video = sample_inputs[:, :, :num_frames, :pixel_width, :pixel_width]
inputs = dict(image=cropped_video, transformation=transformation)

pmapped_codes = jax.pmap(jax.jit(get_codes.apply))
# `apply` returns an (output, state) pair; the state is discarded.
transform_reconstruction, _ = pmapped_codes(
    bcast_augm_params, bcast_augm_state, bcast_rng, inputs)

In [None]:
# @title Visualise augmented videos with original video.
# Concatenate the input and its augmented reconstruction along axis 1, then
# render the first element of the result as a GIF.
video = np.concatenate((inputs['image'], transform_reconstruction), axis=1)
v = video_utils.video_reshaper(video[0])

augmented_frames = np.clip(v * 255, 0, 255).astype(np.uint8)
video_utils.save_video(augmented_frames, '/tmp/display.gif')
display_image(filename='/tmp/display.gif', embed=True)

In [None]:
# @title Save augmented video.
# VIDEO_TO_TEST is a URL whose host name contains dots, so take the file name
# without its extension instead of splitting the whole URL on '.'.
ds_name = os.path.splitext(os.path.basename(VIDEO_TO_TEST))[0]
# Recompute the checkpoint tag here so this cell also works when the earlier
# "Save video" cell has not been run.
name = os.path.basename(model_path).split('.')[0]

MAIN_PATH = f'{_SAVE_PATH}{ds_name}/neural_codec-unet/augmented/'
os.makedirs(MAIN_PATH, exist_ok=True)

path = os.path.join(MAIN_PATH, f'compression-{name}-combined.gif')
video_utils.save_video(np.clip(v * 255, 0, 255).astype(np.uint8), path)