In [1]:
import tensorflow as tf
from tensorflow import keras
import numpy as np

In [2]:
# Let TensorFlow allocate GPU memory on demand instead of reserving it all up front.
# Iterating over the list (rather than indexing [0]) keeps this cell working on
# CPU-only machines, where the list is empty, and applies the same setting to
# every visible GPU.
physical_devices = tf.config.list_physical_devices('GPU')
for gpu_device in physical_devices:
    tf.config.experimental.set_memory_growth(gpu_device, enable=True)

In [3]:
from mixtures import discretized_mix_logistic_loss, sample_from_discretized_mix_logistic

In [5]:
class Conv1D(keras.layers.Conv1D):
    """Conv1D layer with an extra queue-based, one-step-at-a-time code path.

    With ``training=False`` (the default) the layer behaves exactly like
    ``keras.layers.Conv1D``. With ``training=True`` and ``kernel_size > 1`` it
    pushes the last time step of the input into an internal queue and computes
    a single output frame as a matrix product with the flattened kernel.

    NOTE(review): tying the incremental path to ``training=True`` looks
    inverted (incremental evaluation is normally used for generation), but the
    flag semantics are kept as they were — confirm against the callers.

    Args:
        filters: number of output channels.
        kernel_size: width of the convolution kernel (int or 1-element tuple).
        strides: convolution stride.
        padding: Keras padding mode; defaults to ``"causal"``.
        dilation_rate: dilation of the kernel (int or 1-element tuple).
        use_bias: whether a bias vector is created and added.
        *args: accepted for backward compatibility and ignored, because their
            positional meaning in the parent constructor is ambiguous.
        **kwargs: forwarded to ``keras.layers.Conv1D`` (e.g. ``name``).
    """

    def __init__(self, filters, kernel_size, strides=1, padding="causal", dilation_rate=1, use_bias=False, *args, **kwargs):
        super().__init__(
            filters,
            kernel_size=kernel_size,
            strides=strides,
            padding=padding,
            dilation_rate=dilation_rate,
            use_bias=use_bias,
            **kwargs
        )

        # Keras stores `kernel_size` and `dilation_rate` as tuples on the layer,
        # so integer copies are kept under separate names. Overwriting the Keras
        # attributes with ints breaks the parent class
        # ("unsupported operand type(s) for +: 'int' and 'tuple'").
        self.k = self._as_int(kernel_size)
        self.d = self._as_int(dilation_rate)

        # Number of past input frames the dilated kernel spans (== queue length).
        self.current_receptive_field = self.k + (self.k - 1) * (self.d - 1)

        # Allocated in build(), once the number of input channels is known.
        self.queue = None

    @staticmethod
    def _as_int(value):
        """Return `value` as a plain int, unwrapping a one-element tuple/list."""
        if isinstance(value, (tuple, list)):
            (value,) = value
        return int(value)

    def _linearize_kernel(self):
        """Flatten the kernel to a 2-D float32 matrix with `filters` columns."""
        return tf.cast(tf.reshape(self.kernel, [-1, self.filters]), dtype=tf.float32)

    def build(self, x_shape):
        """Create the layer weights and the zero-filled input queue."""
        super().build(x_shape)

        # Kept as an attribute for existing readers; call() re-derives it so
        # that later kernel updates (training, weight loading) are not missed.
        self.linearized_weights = self._linearize_kernel()

        if self.k > 1:
            # The queue stores *input* frames, so its channel dimension must be
            # the input channel count, not `filters`.
            # Assumes channels-last input, i.e. channels on the final axis.
            in_channels = int(x_shape[-1])
            self.residual_channels = in_channels
            self.queue = tf.zeros([1, self.current_receptive_field, in_channels])

    def call(self, x, training=False):
        """Apply the convolution.

        Args:
            x: input tensor; only its last time step is used on the queue path.
            training: ``False`` runs the regular Keras convolution, ``True``
                runs the queue-based single-step computation.

        Returns:
            The regular Conv1D output, or on the queue path a tensor reshaped
            to ``[-1, 1, filters]``.
        """
        # A width-1 kernel needs no history, so the regular convolution already
        # gives the per-step result (previously this case returned None).
        if not training or self.k == 1:
            return super().call(x)

        # Drop the oldest frame and append the newest one.
        newest_frame = tf.expand_dims(x[:, -1, :], axis=1)
        self.queue = tf.concat([self.queue[:, 1:, :], newest_frame], axis=1)

        # Pick every d-th frame so the selected taps line up with the dilated kernel.
        taps = self.queue[:, 0::self.d, :] if self.d > 1 else self.queue

        # The queue path is hard-wired to a batch of one.
        outputs = tf.matmul(tf.reshape(taps, [1, -1]), self._linearize_kernel())

        if self.use_bias:
            outputs = tf.nn.bias_add(outputs, self.bias)

        return tf.reshape(outputs, [-1, 1, self.filters])

    #def init_queue(self):
        

In [150]:
class ResidualBlock(keras.Model):
    """One gated, dilated convolution block with residual and skip outputs.

    Args:
        layer_index: index of the block, used only for the name scope.
        dilation: dilation rate of the filter and gate convolutions.
        filter_width: kernel size of the filter and gate convolutions.
        dilation_channels: output channels of the filter and gate convolutions.
        residual_channels: output channels of the 1x1 residual convolution.
        skip_channels: output channels of the 1x1 skip convolution.
        use_biases: whether the convolutions use a bias term.
        output_width: number of trailing time steps kept for the skip output.
    """

    def __init__(self, layer_index, dilation, filter_width, dilation_channels, residual_channels, skip_channels, use_biases, output_width):
        super().__init__()

        self.layer_index = layer_index
        self.dilation = dilation
        self.filter_width = filter_width
        self.dilation_channels = dilation_channels
        self.residual_channels = residual_channels
        self.skip_channels = skip_channels
        self.use_biases = use_biases
        self.output_width = output_width

    def build(self, input_shape):
        """Create the four convolution layers of the block."""
        with tf.name_scope("residual_block_{}".format(self.layer_index)):
            # Filter and gate share the same geometry. With 'valid' padding each
            # shortens the time axis by (filter_width - 1) * dilation steps.
            self.conv_filter = keras.layers.Conv1D(
                filters=self.dilation_channels,
                kernel_size=self.filter_width,
                dilation_rate=self.dilation,
                padding='valid',
                use_bias=self.use_biases,
                name="conv_filter"
            )
            self.conv_gate = keras.layers.Conv1D(
                filters=self.dilation_channels,
                kernel_size=self.filter_width,
                dilation_rate=self.dilation,
                padding='valid',
                use_bias=self.use_biases,
                name="conv_gate"
            )
            # 1x1 convolution applied to the gated activation; its result is
            # added to the (cropped) block input to form the residual output.
            self.conv_residual = keras.layers.Conv1D(
                filters=self.residual_channels,
                kernel_size=1,
                padding="same",
                use_bias=self.use_biases,
                name="dense"
            )
            # 1x1 convolution producing this block's skip contribution.
            self.conv_skip = keras.layers.Conv1D(
                filters=self.skip_channels,
                kernel_size=1,
                padding="same",
                use_bias=self.use_biases,
                name="skip"
            )

    @tf.function
    def call(self, inputs, training=False):
        """Run the block.

        Args:
            inputs: block input; its channel count must match
                `residual_channels` for the residual addition to work.
            training: unused by this block.

        Returns:
            A tuple ``(skip_output, dense_output)``: the skip contribution
            cropped to the last `output_width` steps, and the residual output
            passed on to the next block.
        """
        out = tf.tanh(self.conv_filter(inputs)) * tf.sigmoid(self.conv_gate(inputs))

        # Keep only the trailing `output_width` steps for the skip path so that
        # the skip outputs of all blocks share one length and can be summed.
        skip_cut = tf.shape(out)[1] - self.output_width
        out_skip = tf.slice(out, [0, skip_cut, 0], [-1, -1, self.dilation_channels])
        skip_output = self.conv_skip(out_skip)

        # Residual path: the 'valid' convolutions shortened the sequence, so the
        # input is cropped from the front to the same length before adding.
        # (This previously referenced an undefined name `x` instead of `inputs`.)
        transformed = self.conv_residual(out)
        input_cut = tf.shape(inputs)[1] - tf.shape(transformed)[1]
        inputs_cut = tf.slice(inputs, [0, input_cut, 0], [-1, -1, -1])
        dense_output = inputs_cut + transformed

        return skip_output, dense_output

In [151]:
class PostProcessing(keras.Model):
    """Output head: ReLU -> 1x1 conv -> ReLU -> 1x1 conv.

    Args:
        skip_channels: output channels of the first 1x1 convolution.
        out_channels: output channels of the second 1x1 convolution.
        use_biases: whether both convolutions use a bias term.
    """

    def __init__(self, skip_channels, out_channels, use_biases):
        super().__init__()

        self.skip_channels, self.out_channels = skip_channels, out_channels
        self.use_biases = use_biases

    def build(self, input_shape):
        """Instantiate the two pointwise convolutions."""
        # Both layers differ only in their number of filters.
        pointwise = dict(kernel_size=1, padding="same", use_bias=self.use_biases)

        with tf.name_scope("postprocessing"):
            self.conv_1 = keras.layers.Conv1D(filters=self.skip_channels, **pointwise)
            self.conv_2 = keras.layers.Conv1D(filters=self.out_channels, **pointwise)

    @tf.function
    def call(self, inputs, training=False):
        """Apply the two ReLU + 1x1-convolution stages to `inputs`."""
        hidden = self.conv_1(tf.nn.relu(inputs))
        return self.conv_2(tf.nn.relu(hidden))

In [172]:
class WaveNet(keras.Model):
    """Stack of dilated residual blocks followed by a post-processing head.

    Args:
        batch_size: batch size used when reshaping the network output.
        dilations: dilation rate of each residual block, in order.
        filter_width: kernel size of the dilated convolutions.
        initial_filter_width: kernel size of the preprocessing convolution.
        dilation_channels: channels of the gated convolutions in each block.
        residual_channels: channels carried along the residual path.
        skip_channels: channels of the skip connections.
        quantization_channels: stored on the model; not used by the code here.
        out_channels: channels produced by the post-processing head.
        use_biases: whether the convolutions use a bias term.
    """

    def __init__(self, batch_size, dilations, filter_width, initial_filter_width, dilation_channels, residual_channels, skip_channels, quantization_channels, out_channels, use_biases):
        super().__init__()

        self.batch_size = batch_size
        self.dilations = dilations
        self.filter_width = filter_width
        self.initial_filter_width = initial_filter_width
        self.dilation_channels = dilation_channels
        self.residual_channels = residual_channels
        self.skip_channels = skip_channels
        self.quantization_channels = quantization_channels
        self.out_channels = out_channels
        self.use_biases = use_biases

        # Number of input samples that influence a single output sample.
        self.receptive_field = (self.filter_width - 1) * sum(self.dilations) + self.initial_filter_width

    def build(self, input_shape):
        """Create all sub-layers; requires a known time dimension (`input_shape[1]`)."""
        # Number of time steps the model emits for an input of this length.
        self.output_width = input_shape[1] - self.receptive_field + 1

        with tf.name_scope("preprocessing"):
            self.preprocessing_layer = keras.layers.Conv1D(
                filters=self.residual_channels,
                kernel_size=self.initial_filter_width,
                use_bias=self.use_biases)

        # One block per entry in `dilations`. Each block must use its own
        # dilation rate: the receptive-field formula above sums all of them
        # (previously every block was built with `self.dilations[0]`).
        self.residual_blocks = [
            ResidualBlock(
                layer_index=layer_index,
                dilation=dilation,
                filter_width=self.filter_width,
                dilation_channels=self.dilation_channels,
                residual_channels=self.residual_channels,
                skip_channels=self.skip_channels,
                use_biases=self.use_biases,
                output_width=self.output_width)
            for layer_index, dilation in enumerate(self.dilations)
        ]

        self.postprocessing_layer = PostProcessing(self.skip_channels, self.out_channels, self.use_biases)

    @tf.function
    def call(self, inputs, training=False):
        '''
        Run the full network on `inputs` and sample from its output.

        Args:
            inputs: input tensor whose axis 1 is the time axis (build() reads
                its length from `input_shape[1]`).
            training: unused by this model.

        Returns:
            The result of `sample_from_discretized_mix_logistic` applied to the
            network output reshaped to [batch_size, -1, out_channels].
        '''

        x = self.preprocessing_layer(inputs)
        skip_outputs = []

        # Each block returns its skip contribution and the input for the next block.
        for residual_block in self.residual_blocks:
            skip_output, x = residual_block(x)
            skip_outputs.append(skip_output)

        skip_sum = tf.math.add_n(skip_outputs)

        raw_output = self.postprocessing_layer(skip_sum)

        # The reshape uses the configured batch size, so inputs must have
        # exactly `batch_size` rows.
        out = tf.reshape(raw_output, [self.batch_size, -1, self.out_channels])
        proba = sample_from_discretized_mix_logistic(out)

        return proba

In [173]:
# Random stand-in input of shape (1, 1024, 1): integers drawn with minval=0 and
# maxval=255, then converted to float32 before being fed to the model.
x = tf.cast(
    tf.random.uniform([1, 1024, 1], minval=0, maxval=255, dtype=tf.int32),
    tf.float32,
)
x

<tf.Tensor: shape=(1, 1024, 1), dtype=float32, numpy=
array([[[122.],
        [ 63.],
        [110.],
        ...,
        [222.],
        [200.],
        [154.]]], dtype=float32)>

In [174]:
# Hyperparameters, chosen to follow the reference diagram.
batch_size = 1
dilations = [2 ** power for power in range(8)]  # 1, 2, 4, ..., 128
filter_width = 2                 # kernel size of the dilated convolutions
initial_filter_width = 32        # kernel size of the preprocessing convolution (value taken from "tacokr")
dilation_channels = 32           # reference value unknown; placeholder choice
residual_channels = 24
skip_channels = 128
quantization_channels = 1 << 8   # 256
out_channels = 3 * 10            # 30
use_biases = False

In [175]:
# Instantiate the model from the hyperparameters defined above; keyword
# arguments make the long argument list robust against ordering mistakes.
wavenet = WaveNet(
    batch_size=batch_size,
    dilations=dilations,
    filter_width=filter_width,
    initial_filter_width=initial_filter_width,
    dilation_channels=dilation_channels,
    residual_channels=residual_channels,
    skip_channels=skip_channels,
    quantization_channels=quantization_channels,
    out_channels=out_channels,
    use_biases=use_biases,
)

In [176]:
# Forward pass on the random input; the displayed value is what WaveNet.call returns.
wavenet(x)

<tf.Tensor: shape=(1, 738), dtype=float32, numpy=
array([[ -3.9543817,  -1.       ,  -1.5890498,  -1.       ,  -4.748875 ,
        -10.696519 ,  -1.       ,  -1.       ,  -6.1428704,  -1.       ,
         -1.       ,  -1.       ,  -1.       ,  -1.       ,  -2.676597 ,
         -1.       ,  -1.       ,  -1.       ,  -1.       ,  -1.       ,
         -8.228392 ,  -2.2503147,  -1.4384542,  -2.3819356,  -1.       ,
         -1.       ,  -2.6098852,  -1.       ,  -2.3426914,  -1.       ,
         -3.7028098,  -3.713524 ,  -2.4401789,  -1.       ,  -1.       ,
         -1.       ,  -1.       ,  -1.       ,  -7.271158 ,  -1.1489387,
         -1.5377283,  -1.       ,  -1.       ,  -1.       ,  -1.       ,
         -1.       ,  -1.       ,  -1.       ,  -1.       ,  -1.       ,
         -1.       ,  -1.       ,  -1.       ,  -3.0452144,  -1.       ,
         -1.       ,  -1.       ,  -1.       ,  -1.       ,  -1.       ,
         -1.       ,  -1.       ,  -1.       ,  -1.2507843,  -1.       ,
 

In [171]:
# `output_width` is assigned inside WaveNet.build, so this attribute only exists
# once the model has been built (e.g. after it has been called on an input).
wavenet.output_width

738