In [None]:
import numpy as np
import tensorflow as tf
import tensorflow_compression as tfc
import tensorflow_probability as tfp
import matplotlib.pyplot as plt

from time import time
from tqdm import tqdm

import os

In [None]:
class H_a(tf.keras.layers.Layer):
    """Hyperprior analysis transform (the encoder side of the hyperprior).

    Takes the element-wise magnitude of its input and passes it through three
    convolutions: a 3x3 stride-1 layer, then two 5x5 stride-2 layers. The
    first two use ReLU; the last one has no activation.
    """

    def __init__(self, N):
        """Create the three convolution layers.

        Args:
            N: Number of filters used by every convolution in this layer.
        """
        super().__init__()
        self.N = N

        # No `padding` argument is given, so Keras' default applies; the
        # spatial size of the output therefore depends on the input size.
        # Layers are created in conv1 -> conv2 -> conv3 order.
        self.conv1 = tf.keras.layers.Conv2D(
            filters=self.N, kernel_size=3, strides=1, activation='relu')
        self.conv2 = tf.keras.layers.Conv2D(
            filters=self.N, kernel_size=5, strides=2, activation='relu')
        # Final layer is linear (no activation).
        self.conv3 = tf.keras.layers.Conv2D(
            filters=self.N, kernel_size=5, strides=2)

    def call(self, inputs):
        """Run the encoder.

        Args:
            inputs: Tensor fed to the first convolution after `tf.abs`.
                Presumably a 4-D image-like tensor as `Conv2D` expects --
                TODO confirm against the caller.

        Returns:
            The output of the third convolution.
        """
        # Only magnitudes are passed on; the sign of the input is discarded.
        hidden = tf.abs(inputs)
        for conv in (self.conv1, self.conv2, self.conv3):
            hidden = conv(hidden)
        return hidden

In [None]:
class H_s(tf.keras.layers.Layer):
    """Hyperprior synthesis transform (the decoder side of the hyperprior).

    Applies three transposed convolutions: two 5x5 stride-2 layers with N
    filters, followed by a 3x3 stride-1 layer with M filters. All three use
    ReLU, so the output is non-negative.
    """

    def __init__(self, N, M):
        """Create the three transposed-convolution layers.

        Args:
            N: Number of filters of the first two transposed convolutions.
            M: Number of filters of the last transposed convolution, i.e. the
                channel count of the decoder output.
        """
        super().__init__()
        self.N = N
        self.M = M

        # No `padding` argument is given, so Keras' default applies.
        # Layers are created in conv1 -> conv2 -> conv3 order.
        self.conv1 = tf.keras.layers.Conv2DTranspose(
            filters=self.N, kernel_size=5, strides=2, activation='relu')
        self.conv2 = tf.keras.layers.Conv2DTranspose(
            filters=self.N, kernel_size=5, strides=2, activation='relu')
        self.conv3 = tf.keras.layers.Conv2DTranspose(
            filters=self.M, kernel_size=3, strides=1, activation='relu')

    def call(self, inputs):
        """Run the decoder.

        Args:
            inputs: Tensor fed to the first transposed convolution.
                Presumably the (quantized) hyper-latents produced by the
                matching encoder -- TODO confirm against the caller.

        Returns:
            The output of the third transposed convolution.
        """
        hidden = inputs
        for deconv in (self.conv1, self.conv2, self.conv3):
            hidden = deconv(hidden)
        return hidden