In [1]:
import tensorflow as tf
from tensorflow import keras
import numpy as np

In [2]:
# Enable on-demand GPU memory allocation on every visible GPU.
# Iterating (instead of indexing [0]) makes this cell a no-op on a CPU-only
# machine rather than an IndexError, and keeps the setting identical on all GPUs.
physical_devices = tf.config.list_physical_devices('GPU')
for gpu_device in physical_devices:
    tf.config.experimental.set_memory_growth(gpu_device, enable=True)

In [3]:
from mixtures import discretized_mix_logistic_loss, sample_from_discretized_mix_logistic

In [4]:
class Conv1D(keras.layers.Conv1D):
    """Keras ``Conv1D`` with an additional one-sample-at-a-time code path.

    With ``training=False`` (the default) the layer behaves exactly like the
    parent convolution. With ``training=True`` and ``kernel_size > 1`` it keeps
    a rolling queue of the most recent input samples and computes a single
    output step as a matrix product with the flattened kernel.

    NOTE(review): the queue-based path is normally what one would use for
    sample-by-sample generation, yet it is selected by ``training=True``.
    The flag semantics are kept as they were -- confirm this is intended.

    Args:
        filters: number of output channels.
        kernel_size: integer kernel width.
        strides: convolution stride.
        padding: padding mode, ``"causal"`` by default.
        dilation_rate: integer dilation.
        use_bias: whether the layer owns a bias vector.
        *args: accepted for signature compatibility; not forwarded.
        **kwargs: forwarded to ``keras.layers.Conv1D`` (e.g. ``name``).
    """

    def __init__(self, filters, kernel_size, strides=1, padding="causal", dilation_rate=1, use_bias=False, *args, **kwargs):
        # `use_bias` and keyword arguments are forwarded so the parent layer
        # creates (or skips) the bias consistently with the flag used in call().
        super().__init__(
            filters,
            kernel_size=kernel_size,
            strides=strides,
            padding=padding,
            dilation_rate=dilation_rate,
            use_bias=use_bias,
            **kwargs
        )

        # The parent layer stores `kernel_size` / `dilation_rate` as tuples, so the
        # plain integers are kept under separate names for arithmetic and
        # comparisons (using the tuple attributes raises
        # "TypeError: unsupported operand type(s) for +: 'int' and 'tuple'").
        self.k = kernel_size
        self.d = dilation_rate

        # Number of past samples one output step depends on (== queue length).
        self.current_receptive_field = kernel_size + (kernel_size - 1) * (dilation_rate - 1)

        # Allocated in build(), once the number of input channels is known.
        self.queue = None

    def build(self, x_shape):
        """Create the layer weights and, for kernels wider than 1, the input queue."""
        super().build(x_shape)

        if self.k > 1:
            # The queue stores *input* samples, so its last axis is the input
            # channel count (not `filters`).
            in_channels = int(x_shape[-1])
            self.queue = tf.zeros([1, self.current_receptive_field, in_channels])

    @property
    def linearized_weights(self):
        """Kernel flattened to ``[kernel_size * in_channels, filters]`` as float32.

        Computed on access so it always reflects the current kernel values
        instead of a snapshot taken at build time.
        """
        return tf.cast(tf.reshape(self.kernel, [-1, self.filters]), dtype=tf.float32)

    def call(self, x, training=False):
        """Apply the convolution.

        Args:
            x: input tensor; the queue path reads ``x[:, -1, :]`` and assumes a
                batch size of 1 (the queue is allocated with batch dimension 1).
            training: ``False`` runs the regular convolution; ``True`` runs the
                queue-based single-step path when ``kernel_size > 1``.

        Returns:
            The regular convolution output, or a ``[-1, 1, filters]`` tensor
            from the queue-based path.
        """
        # A width-1 kernel needs no history, so it always uses the regular
        # convolution (previously this case fell through and returned None).
        if not training or self.k <= 1:
            return super().call(x)

        # Drop the oldest sample and append the newest one.
        newest = tf.expand_dims(x[:, -1, :], axis=1)
        self.queue = tf.concat([self.queue[:, 1:, :], newest], axis=1)

        # Pick every `d`-th sample so exactly `k` taps line up with the kernel.
        taps = self.queue[:, 0::self.d, :] if self.d > 1 else self.queue

        outputs = tf.matmul(tf.reshape(taps, [1, -1]), self.linearized_weights)

        if self.use_bias:
            outputs = tf.nn.bias_add(outputs, self.bias)

        return tf.reshape(outputs, [-1, 1, self.filters])

    #def init_queue(self):
        

In [5]:
class ResidualBlock(keras.Model):
    """One gated residual block: dilated filter/gate convolutions plus 1x1 heads.

    ``call`` returns ``(skip_output, dense_output)``: the skip contribution cut
    to the last ``output_width`` steps (training) or the last step (otherwise),
    and the residual sum that is fed to the next block.
    """

    def __init__(self, layer_index, dilation, filter_width, dilation_channels, residual_channels, skip_channels, use_biases, output_width):
        super().__init__()

        # Position of the block in the stack (used only for layer names).
        self.layer_index = layer_index

        # Geometry of the dilated convolutions.
        self.dilation = dilation
        self.filter_width = filter_width

        # Channel widths of the three paths.
        self.dilation_channels = dilation_channels
        self.residual_channels = residual_channels
        self.skip_channels = skip_channels

        self.use_biases = use_biases
        # Number of trailing time steps kept on the skip path while training.
        self.output_width = output_width

    def build(self, input_shape):
        name_prefix = "residual_block_{}/".format(self.layer_index)

        def make_dilated(suffix):
            # Un-padded dilated convolution shared by the filter and gate paths.
            return keras.layers.Conv1D(
                filters=self.dilation_channels,
                kernel_size=self.filter_width,
                dilation_rate=self.dilation,
                padding='valid',
                use_bias=self.use_biases,
                name=name_prefix + suffix
            )

        def make_pointwise(channels, suffix):
            # 1x1 convolution used for the residual and skip projections.
            return keras.layers.Conv1D(
                filters=channels,
                kernel_size=1,
                padding="same",
                use_bias=self.use_biases,
                name=name_prefix + suffix
            )

        self.conv_filter = make_dilated("conv_filter")
        self.conv_gate = make_dilated("conv_gate")
        # 1x1 projection of the gated activation back to the residual width.
        self.conv_residual = make_pointwise(self.residual_channels, "dense")
        self.conv_skip = make_pointwise(self.skip_channels, "skip")

    @tf.function
    def call(self, inputs, training=False):
        # Gated activation unit: tanh(filter) * sigmoid(gate).
        filtered = tf.tanh(self.conv_filter(inputs))
        gated = filtered * tf.sigmoid(self.conv_gate(inputs))

        # Keep `output_width` trailing steps when training, a single step otherwise.
        if training:
            first_kept = tf.shape(gated)[1] - self.output_width
        else:
            first_kept = tf.shape(gated)[1] - 1

        skip_input = tf.slice(gated, [0, first_kept, 0], [-1, -1, self.dilation_channels])
        skip_output = self.conv_skip(skip_input)

        # The 'valid' convolutions shorten the sequence, so the block input is
        # trimmed from the front to the same length before the residual sum.
        residual = self.conv_residual(gated)
        trim = tf.shape(inputs)[1] - tf.shape(residual)[1]
        dense_output = tf.slice(inputs, [0, trim, 0], [-1, -1, -1]) + residual

        return skip_output, dense_output

In [6]:
class PostProcessing(keras.Model):
    """Output head: ReLU -> 1x1 conv -> ReLU -> 1x1 conv.

    The first convolution keeps ``skip_channels`` channels; the second maps to
    ``quantization_channels`` outputs.
    """

    def __init__(self, skip_channels, quantization_channels, use_biases):
        super().__init__()

        self.skip_channels = skip_channels
        self.quantization_channels = quantization_channels
        self.use_biases = use_biases

    def build(self, input_shape):
        def pointwise(channels, layer_name):
            # Both stages are 1x1 convolutions that differ only in width and name.
            return keras.layers.Conv1D(
                filters=channels,
                kernel_size=1,
                padding="same",
                use_bias=self.use_biases,
                name=layer_name
            )

        self.conv_1 = pointwise(self.skip_channels, "postprocessing/conv_1")
        # A scalar-output variant would use `out_channels` filters here instead
        # of `quantization_channels`; that variant is not implemented.
        self.conv_2 = pointwise(self.quantization_channels, "postprocessing/conv_2")

    @tf.function
    def call(self, inputs, training=False):
        hidden = self.conv_1(tf.nn.relu(inputs))
        return self.conv_2(tf.nn.relu(hidden))

In [27]:
class WaveNet(keras.Model):
    """Stack of gated residual blocks with a categorical output head.

    Args:
        batch_size: batch size used to reshape the output when not training.
        dilations: dilation rate of each residual block, in order.
        filter_width: kernel size of the dilated convolutions.
        initial_filter_width: kernel size of the preprocessing convolution.
        dilation_channels: channels of the filter/gate convolutions.
        residual_channels: channels carried between residual blocks.
        skip_channels: channels of the skip path.
        quantization_channels: number of output classes.
        out_channels: stored but not used by the code in this class.
        use_biases: whether the convolutions use a bias.
    """

    def __init__(self, batch_size, dilations, filter_width, initial_filter_width, dilation_channels, residual_channels, skip_channels, quantization_channels=None, out_channels=None, use_biases=False):
        super().__init__()

        self.batch_size = batch_size
        self.dilations = dilations
        self.filter_width = filter_width
        self.initial_filter_width = initial_filter_width
        self.dilation_channels = dilation_channels
        self.residual_channels = residual_channels
        self.skip_channels = skip_channels
        self.quantization_channels = quantization_channels
        self.out_channels = out_channels
        self.use_biases = use_biases

        # Number of input samples needed to produce one output sample
        # (scalar-input case).
        self.receptive_field = (self.filter_width - 1) * sum(self.dilations) + self.initial_filter_width

    def build(self, input_shape):
        """Create the sub-layers; ``input_shape[1]`` must be the known sequence length."""
        # Total number of output steps the model produces for this input length.
        self.output_width = input_shape[1] - self.receptive_field + 1

        self.preprocessing_layer = keras.layers.Conv1D(
            filters=self.residual_channels,
            kernel_size=self.initial_filter_width,
            use_bias=self.use_biases,
            name="preprocessing/conv")

        # One block per entry in `dilations`, each with ITS OWN dilation rate.
        # (Previously every block received `self.dilations[0]`, so the stack was
        # not dilated even though `receptive_field` assumes it is.)
        self.residual_blocks = []
        for i, dilation in enumerate(self.dilations):
            self.residual_blocks.append(
                ResidualBlock(
                    layer_index=i,
                    dilation=dilation,
                    filter_width=self.filter_width,
                    dilation_channels=self.dilation_channels,
                    residual_channels=self.residual_channels,
                    skip_channels=self.skip_channels,
                    use_biases=self.use_biases,
                    output_width=self.output_width)
                )

        self.postprocessing_layer = PostProcessing(self.skip_channels, self.quantization_channels, self.use_biases)

    @tf.function(experimental_relax_shapes=True)
    def call(self, inputs, training=False):
        """Run the network.

        Returns the raw output of the postprocessing head when ``training`` is
        true; otherwise the output is reshaped to
        ``[batch_size, -1, quantization_channels]`` and passed through a
        softmax (computed in float64, returned as float32).

        NOTE(review): the original note says the input is assumed to be a
        scalar (one channel) signal -- confirm against the data pipeline.
        """
        x = self.preprocessing_layer(inputs)
        skip_outputs = []

        for block in self.residual_blocks:
            skip_output, x = block(x, training=training)
            skip_outputs.append(skip_output)

        # All skip contributions share one shape and are summed element-wise.
        skip_sum = tf.math.add_n(skip_outputs)

        output = self.postprocessing_layer(skip_sum)

        if not training:
            out = tf.reshape(output, [self.batch_size, -1, self.quantization_channels])
            #output = sample_from_discretized_mix_logistic(out)
            output = tf.cast(tf.nn.softmax(tf.cast(out, tf.float64)), tf.float32)

        return output

    def train_step(self, data):
        """Single optimisation step; returns the current metric values by name."""
        x, y = data

        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)
            loss = self.compiled_loss(y, y_pred)
            #loss = tf.nn.softmax_cross_entropy_with_logits(logits=y_pred, labels=y)
            reduced_loss = tf.math.reduce_mean(loss)

        trainable_vars = self.trainable_variables
        # Differentiate the reduced (scalar) loss, which was previously computed
        # but never used.
        gradients = tape.gradient(reduced_loss, trainable_vars)

        self.optimizer.apply_gradients(zip(gradients, trainable_vars))
        self.compiled_metrics.update_state(y, y_pred)

        return {m.name: m.result() for m in self.metrics}

In [18]:
class DiscretizedMixLogisticLoss(keras.losses.Loss):
    """Keras loss wrapper around ``discretized_mix_logistic_loss``."""

    def __init__(self, name="discretized_mix_logistic_loss"):
        super().__init__(name=name)

    def call(self, y_true, y_pred):
        # The helper takes the prediction first and the target second;
        # its result is averaged to a scalar here.
        per_element = discretized_mix_logistic_loss(y_pred, y_true)
        return tf.reduce_mean(per_element)

# Random integer samples in [0, 255), shape (1, 1025, 1), stored as float32.
inputs = tf.cast(
    tf.random.uniform([1, 1025, 1], minval=0, maxval=255, dtype=tf.int32),
    tf.float32,
)
inputs

In [19]:
# Deterministic ramp 0..254 repeated five times and cut to 1025 samples,
# shaped (1, 1025, 1) and stored as float32.
rg = tf.range(255)
ramp = tf.tile(rg, [5])[:1025]
inputs = tf.cast(tf.reshape(ramp, [1, 1025, 1]), tf.float32)
inputs

<tf.Tensor: shape=(1, 1025, 1), dtype=float32, numpy=
array([[[0.],
        [1.],
        [2.],
        ...,
        [2.],
        [3.],
        [4.]]], dtype=float32)>

In [20]:
# Feed the first 287 samples and predict the one that follows
# (sample index 287).
x_train, y_train = inputs[:, :287, :], inputs[:, 287:288, :]

In [21]:
# Sanity check: the input window and the single-sample target shapes.
x_train.shape, y_train.shape

(TensorShape([1, 287, 1]), TensorShape([1, 1, 1]))

In [37]:
# Inspect the target sample. NOTE(review): the saved output shows int32, which
# only holds after the cast in the one-hot cell below has run -- this cell was
# executed out of order (execution count 37).
y_train

<tf.Tensor: shape=(1, 1, 1), dtype=int32, numpy=array([[[32]]])>

In [22]:
# One-hot encode the integer target over 2**8 classes and drop the former
# channel axis (axis 2).
y_train = tf.cast(y_train, tf.int32)
y_train_onehot = tf.squeeze(tf.one_hot(y_train, depth=2**8), axis=2)
y_train_onehot

<tf.Tensor: shape=(1, 1, 256), dtype=float32, numpy=
array([[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.

In [23]:
# Hyperparameters (the original note says they follow "the Diagram" -- TODO link the reference)
batch_size = 1
dilations = [1, 2, 4, 8, 16, 32, 64, 128]
filter_width = 2        # kernel size of the dilated convolutions
initial_filter_width = 32       # kernel size of the preprocessing convolution; value taken from "tacokr" (TODO confirm source)
dilation_channels = 32  # not specified by the reference; chosen value -- TODO confirm
residual_channels = 24
skip_channels = 128
quantization_channels = 2**8
out_channels = 10*3     # passed to WaveNet and stored there, but not used by the model code
use_biases = False

# With these values WaveNet.receptive_field = (2 - 1) * 255 + 32 = 287
# and, for a 1024-sample input, WaveNet.output_width = 1024 - 287 + 1 = 738

In [28]:
# Build the model from the hyperparameters defined in the cell above.
wavenet = WaveNet(
    batch_size=batch_size,
    dilations=dilations,
    filter_width=filter_width,
    initial_filter_width=initial_filter_width,
    dilation_channels=dilation_channels,
    residual_channels=residual_channels,
    skip_channels=skip_channels,
    quantization_channels=quantization_channels,
    out_channels=out_channels,
    use_biases=use_biases,
)

In [31]:
# The model returns raw logits in training mode, hence from_logits=True.
wavenet.compile(
    optimizer=keras.optimizers.Nadam(),
    loss=keras.losses.CategoricalCrossentropy(from_logits=True),
)

In [32]:
# Fit on the single (window, next-sample) pair.
wavenet.fit(x=x_train, y=y_train_onehot, batch_size=8, epochs=1000)

0
Epoch 768/1000
Epoch 769/1000
Epoch 770/1000
Epoch 771/1000
Epoch 772/1000
Epoch 773/1000
Epoch 774/1000
Epoch 775/1000
Epoch 776/1000
Epoch 777/1000
Epoch 778/1000
Epoch 779/1000
Epoch 780/1000
Epoch 781/1000
Epoch 782/1000
Epoch 783/1000
Epoch 784/1000
Epoch 785/1000
Epoch 786/1000
Epoch 787/1000
Epoch 788/1000
Epoch 789/1000
Epoch 790/1000
Epoch 791/1000
Epoch 792/1000
Epoch 793/1000
Epoch 794/1000
Epoch 795/1000
Epoch 796/1000
Epoch 797/1000
Epoch 798/1000
Epoch 799/1000
Epoch 800/1000
Epoch 801/1000
Epoch 802/1000
Epoch 803/1000
Epoch 804/1000
Epoch 805/1000
Epoch 806/1000
Epoch 807/1000
Epoch 808/1000
Epoch 809/1000
Epoch 810/1000
Epoch 811/1000
Epoch 812/1000
Epoch 813/1000
Epoch 814/1000
Epoch 815/1000
Epoch 816/1000
Epoch 817/1000
Epoch 818/1000
Epoch 819/1000
Epoch 820/1000
Epoch 821/1000
Epoch 822/1000
Epoch 823/1000
Epoch 824/1000
Epoch 825/1000
Epoch 826/1000
Epoch 827/1000
Epoch 828/1000
Epoch 829/1000
Epoch 830/1000
Epoch 831/1000
Epoch 832/1000
Epoch 833/1000
Epoch 83

<tensorflow.python.keras.callbacks.History at 0x2b222420248>

In [33]:
# Forward pass in training mode: WaveNet.call skips the softmax here, so the
# values below are raw logits.
pred_training = wavenet(x_train, training=True)
pred_training

<tf.Tensor: shape=(1, 1, 256), dtype=float32, numpy=
array([[[ -5.786551 ,  -2.552747 ,  -3.067538 ,  -5.2054114,
          -1.7160194,  -3.3329356,  -3.630709 ,  -8.918618 ,
          -5.8658834,  -4.398642 ,  -3.149269 ,  -3.085446 ,
          -4.090922 ,  -3.0003588,  -2.8042355,  -4.095027 ,
          -5.67526  ,  -6.582371 ,  -2.8807018,  -2.9774275,
          -2.5274134,  -2.8233483,  -2.9248116,  -1.6796691,
          -5.1662602,  -3.086809 ,  -5.8180456,  -2.127972 ,
          -4.9839396,  -2.3052182,  -3.6668222,  -2.8676462,
          19.67266  ,  -2.882491 ,  -5.4081235,  -2.665713 ,
          -4.8926096,  -3.022813 ,  -2.7793763,  -2.4388816,
          -5.096632 ,  -8.119506 ,  -2.2289367,  -2.7809553,
          -4.3246856,  -5.8219786,  -4.32346  , -12.484178 ,
          -3.6884115,  -4.4123983,  -5.8725643,  -5.1686306,
          -2.508771 ,  -5.0169435,  -3.159831 ,  -3.145667 ,
          -3.8122768,  -5.2268686,  -3.1002738,  -3.508804 ,
          -2.8082619,  -2.488234

In [36]:
# Inference-mode forward pass (softmax probabilities); argmax over the class
# axis gives the predicted class index.
tf.argmax(wavenet(x_train), axis=-1)

<tf.Tensor: shape=(1, 1), dtype=int64, numpy=array([[32]], dtype=int64)>