In [529]:
import tensorflow as tf
from tensorflow import keras
import numpy as np

In [2]:
physical_devices = tf.config.list_physical_devices('GPU')
tf.config.experimental.set_memory_growth(physical_devices[0], enable=True)

In [579]:
from mixtures import discretized_mix_logistic_loss, sample_from_discretized_mix_logistic

In [4]:
class Conv1D(keras.layers.Conv1D):
    def __init__(self, filters, kernel_size, strides=1, padding="causal", dilation_rate=1, use_bias=False, *args, **kwargs):
        super().__init__(filters, kernel_size=kernel_size, strides=strides, padding=padding, dilation_rate=dilation_rate)
        
        ## (issue) Set name other than k and d invoke error : TypeError: unsupported operand type(s) for +: 'int' and 'tuple'
        self.k = kernel_size                
        self.d = dilation_rate

        self.use_bias = use_bias

        if kernel_size > 1:
            self.current_receptive_field = kernel_size + (kernel_size - 1) * (dilation_rate - 1)       # == queue_len (tf2)
            self.residual_channels = residual_channels
            self.queue = tf.zeros([1, self.current_receptive_field, filters])

    def build(self, x_shape):
        super().build(x_shape)

        self.linearized_weights = tf.cast(tf.reshape(self.kernel, [-1, self.filters]), dtype=tf.float32)

    def call(self, x, training=False):
        if not training:
            return super().call(x)

        if self.kernel_size > 1:
            self.queue = self.queue[:, 1:, :]
            self.queue = tf.concat([self.queue, tf.expand_dims(x[:, -1, :], axis=1)], axis=1)

            if self.dilation_rate > 1:
                x = self.queue[:, 0::self.d, :]
            else:
                x = self.queue

            outputs = tf.matmul(tf.reshape(x, [1, -1]), self.linearized_weights)
            
            if self.use_bias:
                outputs = tf.nn.bias_add(outputs, self.bias)

            return tf.reshape(outputs, [-1, 1, self.filters])

    #def init_queue(self):
        

In [510]:
class ResidualBlock(keras.Model):
    def __init__(self, layer_index, dilation, filter_width, dilation_channels, residual_channels, skip_channels, use_biases, output_width):
        super().__init__()

        self.layer_index = layer_index
        self.dilation = dilation
        self.filter_width = filter_width
        self.dilation_channels = dilation_channels
        self.residual_channels = residual_channels
        self.skip_channels = skip_channels
        self.use_biases = use_biases
        self.output_width = output_width

    def build(self, input_shape):
        self.conv_filter = keras.layers.Conv1D(
            filters=self.dilation_channels,
            kernel_size=self.filter_width,
            dilation_rate=self.dilation,
            padding='valid',
            use_bias=self.use_biases,
            name="residual_block_{}/conv_filter".format(self.layer_index)
        )
        self.conv_gate = keras.layers.Conv1D(
            filters=self.dilation_channels,
            kernel_size=self.filter_width,
            dilation_rate=self.dilation,
            padding='valid',
            use_bias=self.use_biases,
            name="residual_block_{}/conv_gate".format(self.layer_index)
        )
        ## transformed : 1x1 conv to out (= gate * filter) to produce residuals (= dense output)
        ## conv_residual (=skip_contribution in original)
        self.conv_residual = keras.layers.Conv1D(
            filters=self.residual_channels,
            kernel_size=1,
            padding="same",
            use_bias=self.use_biases,
            name="residual_block_{}/dense".format(self.layer_index)
        )
        self.conv_skip = keras.layers.Conv1D(
            filters=self.skip_channels,
            kernel_size=1,
            padding="same",
            use_bias=self.use_biases,
            name="residual_block_{}/skip".format(self.layer_index)
        )


    @tf.function
    def call(self, inputs, training=False):
        out = tf.tanh(self.conv_filter(inputs)) * tf.sigmoid(self.conv_gate(inputs))
        
        if training:
            skip_cut = tf.shape(out)[1] - self.output_width
        else:
            skip_cut = tf.shape(out)[1] - 1
            
        out_skip = tf.slice(out, [0, skip_cut, 0], [-1, -1, self.dilation_channels])
        skip_output = self.conv_skip(out_skip)

        transformed = self.conv_residual(out)
        input_cut = tf.shape(x)[1] - tf.shape(transformed)[1]
        x_cut = tf.slice(x, [0, input_cut, 0], [-1, -1, -1])
        dense_output = x_cut + transformed

        return skip_output, dense_output

In [511]:
class PostProcessing(keras.Model):
    def __init__(self, skip_channels, out_channels, use_biases):
        super().__init__()

        self.skip_channels = skip_channels
        self.out_channels = out_channels
        self.use_biases = use_biases

    def build(self, input_shape):
        self.conv_1 = keras.layers.Conv1D(
            filters=self.skip_channels,
            kernel_size=1,
            padding="same",
            use_bias=self.use_biases,
            name="postprocessing/conv_1"
        )
        self.conv_2 = keras.layers.Conv1D(
            filters=self.out_channels,
            kernel_size=1,
            padding="same",
            use_bias=self.use_biases,
            name="postprocessing/conv_2"
        )
    
    @tf.function
    def call(self, inputs, training=False):
        x = tf.nn.relu(inputs)
        x = self.conv_1(x)

        x = tf.nn.relu(x)
        x = self.conv_2(x)

        return x

In [531]:
class WaveNet(keras.Model):
    def __init__(self, batch_size, dilations, filter_width, initial_filter_width, dilation_channels, residual_channels, skip_channels, quantization_channels, out_channels, use_biases):
        super().__init__()

        self.batch_size = batch_size
        self.dilations = dilations
        self.filter_width = filter_width
        self.initial_filter_width = initial_filter_width
        self.dilation_channels = dilation_channels
        self.residual_channels = residual_channels
        self.skip_channels = skip_channels
        self.quantization_channels = quantization_channels
        self.out_channels = out_channels
        self.use_biases = use_biases

        self.receptive_field = (self.filter_width - 1) * sum(self.dilations) + self.initial_filter_width

    def build(self, input_shape):
        #self.receptive_field = input_shape[1] - sum(self.dilations)      
        self.output_width = input_shape[1] - self.receptive_field + 1       # total output width of model

        self.preprocessing_layer = keras.layers.Conv1D(
            filters=self.residual_channels,
            kernel_size=self.initial_filter_width,
            use_bias=self.use_biases,
            name="preprocessing/conv")

        self.residual_blocks = []
        for _ in range(1):
            for i, dilation in enumerate(self.dilations):
                self.residual_blocks.append(
                    ResidualBlock(
                        layer_index=i,
                        dilation=self.dilations[0], 
                        filter_width=self.filter_width, 
                        dilation_channels=self.dilation_channels, 
                        residual_channels=self.residual_channels, 
                        skip_channels=self.skip_channels, 
                        use_biases=self.use_biases, 
                        output_width=self.output_width)
                    )

        self.postprocessing_layer = PostProcessing(self.skip_channels, self.out_channels, self.use_biases)

    @tf.function
    def call(self, inputs, training=False):
        '''
        == predict_proba_incremental

        Assume that x is integer (== scalar_input = True)
        '''

        x = self.preprocessing_layer(inputs)
        skip_outputs = []

        for layer_index in range(len(self.dilations)):
            skip_output, x = self.residual_blocks[layer_index](x, training=training)
            skip_outputs.append(skip_output)

        skip_sum = tf.math.add_n(skip_outputs)          
        
        output = self.postprocessing_layer(skip_sum)

        if not training:
            out = tf.reshape(output, [self.batch_size, -1, self.out_channels])
            #output = sample_from_discretized_mix_logistic(out)

        return output

    @tf.function
    def train_step(self, data): 
        x, y = data

        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)            
            loss = self.compiled_loss(y, y_pred)

        trainable_vars = self.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)

        self.optimizer.apply_gradients(zip(gradients, trainable_vars))
        self.compiled_metrics.update_state(y, y_pred)

        return {m.name: m.result() for m in self.metrics}

In [532]:
class DiscretizedMixLogisticLoss(keras.losses.Loss):
    def __init__(self, name="discretized_mix_logistic_loss"):
        super().__init__(name=name)

    def call(self, y_true, y_pred):
        loss =  discretized_mix_logistic_loss(y_pred, y_true)
        return tf.reduce_mean(loss)

In [559]:
x = tf.random.uniform([1, 1024, 1], minval=0, maxval=255, dtype=tf.int32)
x = tf.cast(x, tf.float32)
x

<tf.Tensor: shape=(1, 1024, 1), dtype=float32, numpy=
array([[[109.],
        [184.],
        [ 55.],
        ...,
        [186.],
        [162.],
        [243.]]], dtype=float32)>

In [560]:
# HParms follows the Diagram 
batch_size = 1
dilations = [1, 2, 4, 8, 16, 32, 64, 128]
filter_width = 2        # == kernel_size
initial_filter_width = 32       # from (tacokr)
dilation_channels = 32  # unknown
residual_channels = 24
skip_channels = 128
quantization_channels = 2**8
out_channels = 10*3
use_biases = False

#wavenet.receptive_field = 287
#wavenet.output_width = 738

In [561]:
wavenet = WaveNet(batch_size, dilations, filter_width, initial_filter_width, dilation_channels, residual_channels, skip_channels, quantization_channels, out_channels, use_biases)

In [572]:
pred_training = wavenet(x, training=True)
pred_training

<tf.Tensor: shape=(1, 738, 30), dtype=float32, numpy=
array([[[ 31.044838  , -18.3206    , -57.86234   , ...,  -1.4861172 ,
          -0.27262142,  18.613617  ],
        [ 37.923584  , -21.745275  , -71.82771   , ...,  -1.5249838 ,
           0.10239481,  22.977215  ],
        [ 37.41273   , -19.714155  , -66.07158   , ...,  -2.2198381 ,
           1.0691111 ,  21.501781  ],
        ...,
        [ 38.902298  , -21.019793  , -70.890495  , ...,  -2.6501193 ,
           0.61166775,  22.990387  ],
        [ 37.43091   , -21.598278  , -71.11528   , ...,  -0.99064434,
           0.27957615,  22.744024  ],
        [ 38.454304  , -21.674555  , -71.42498   , ...,  -1.9789606 ,
           0.2171533 ,  23.155088  ]]], dtype=float32)>

In [538]:
#[tv.name for tv in wavenet.trainable_variables]

In [563]:
wavenet.compile(keras.optimizers.Nadam(), loss=DiscretizedMixLogisticLoss(), metrics=["accuracy"])

In [573]:
loss_fn = DiscretizedMixLogisticLoss()
loss_fn(x[:, (1024-738):, :], pred_training)

<tf.Tensor: shape=(), dtype=float32, numpy=6.358543>

In [569]:
#wavenet.fit(x=x, y=x)
wavenet.fit(x, x[:, (1024-738):, :], batch_size=8, epochs=5000)

racy: 0.0054
Epoch 4804/5000
Epoch 4805/5000
Epoch 4806/5000
Epoch 4807/5000
Epoch 4808/5000
Epoch 4809/5000
Epoch 4810/5000
Epoch 4811/5000
Epoch 4812/5000
Epoch 4813/5000
Epoch 4814/5000
Epoch 4815/5000
Epoch 4816/5000
Epoch 4817/5000
Epoch 4818/5000
Epoch 4819/5000
Epoch 4820/5000
Epoch 4821/5000
Epoch 4822/5000
Epoch 4823/5000
Epoch 4824/5000
Epoch 4825/5000
Epoch 4826/5000
Epoch 4827/5000
Epoch 4828/5000
Epoch 4829/5000
Epoch 4830/5000
Epoch 4831/5000
Epoch 4832/5000
Epoch 4833/5000
Epoch 4834/5000
Epoch 4835/5000
Epoch 4836/5000
Epoch 4837/5000
Epoch 4838/5000
Epoch 4839/5000
Epoch 4840/5000
Epoch 4841/5000
Epoch 4842/5000
Epoch 4843/5000
Epoch 4844/5000
Epoch 4845/5000
Epoch 4846/5000
Epoch 4847/5000
Epoch 4848/5000
Epoch 4849/5000
Epoch 4850/5000
Epoch 4851/5000
Epoch 4852/5000
Epoch 4853/5000
Epoch 4854/5000
Epoch 4855/5000
Epoch 4856/5000
Epoch 4857/5000
Epoch 4858/5000
Epoch 4859/5000
Epoch 4860/5000
Epoch 4861/5000
Epoch 4862/5000
Epoch 4863/5000
Epoch 4864/5000
Epoch 4865/

<tensorflow.python.keras.callbacks.History at 0x25e748f6708>

In [450]:
#wavenet(x)

In [587]:
#tf.argmin(wavenet(x), axis=1)
#wavenet(x[:, :287, :])
wavenet(x)

<tf.Tensor: shape=(1, 1), dtype=float32, numpy=array([[-1.]], dtype=float32)>