In [1]:
import tensorflow as tf
from tensorflow import keras
import numpy as np

In [2]:
physical_devices = tf.config.list_physical_devices('GPU')
tf.config.experimental.set_memory_growth(physical_devices[0], enable=True)

In [3]:
from mixtures import discretized_mix_logistic_loss, sample_from_discretized_mix_logistic

In [4]:
class Conv1D(keras.layers.Conv1D):
    def __init__(self, filters, kernel_size, strides=1, padding="causal", dilation_rate=1, use_bias=False, *args, **kwargs):
        super().__init__(filters, kernel_size=kernel_size, strides=strides, padding=padding, dilation_rate=dilation_rate)
        
        ## (issue) Set name other than k and d invoke error : TypeError: unsupported operand type(s) for +: 'int' and 'tuple'
        self.k = kernel_size                
        self.d = dilation_rate

        self.use_bias = use_bias

        if kernel_size > 1:
            self.current_receptive_field = kernel_size + (kernel_size - 1) * (dilation_rate - 1)       # == queue_len (tf2)
            self.residual_channels = residual_channels
            self.queue = tf.zeros([1, self.current_receptive_field, filters])

    def build(self, x_shape):
        super().build(x_shape)

        self.linearized_weights = tf.cast(tf.reshape(self.kernel, [-1, self.filters]), dtype=tf.float32)

    def call(self, x, training=False):
        if not training:
            return super().call(x)

        if self.kernel_size > 1:
            self.queue = self.queue[:, 1:, :]
            self.queue = tf.concat([self.queue, tf.expand_dims(x[:, -1, :], axis=1)], axis=1)

            if self.dilation_rate > 1:
                x = self.queue[:, 0::self.d, :]
            else:
                x = self.queue

            outputs = tf.matmul(tf.reshape(x, [1, -1]), self.linearized_weights)
            
            if self.use_bias:
                outputs = tf.nn.bias_add(outputs, self.bias)

            return tf.reshape(outputs, [-1, 1, self.filters])

    #def init_queue(self):
        

In [5]:
class ResidualBlock(keras.Model):
    def __init__(self, layer_index, dilation, filter_width, dilation_channels, residual_channels, skip_channels, use_biases, output_width):
        super().__init__()

        self.layer_index = layer_index
        self.dilation = dilation
        self.filter_width = filter_width
        self.dilation_channels = dilation_channels
        self.residual_channels = residual_channels
        self.skip_channels = skip_channels
        self.use_biases = use_biases
        self.output_width = output_width

    def build(self, input_shape):
        self.conv_filter = keras.layers.Conv1D(
            filters=self.dilation_channels,
            kernel_size=self.filter_width,
            dilation_rate=self.dilation,
            padding='valid',
            use_bias=self.use_biases,
            name="residual_block_{}/conv_filter".format(self.layer_index)
        )
        self.conv_gate = keras.layers.Conv1D(
            filters=self.dilation_channels,
            kernel_size=self.filter_width,
            dilation_rate=self.dilation,
            padding='valid',
            use_bias=self.use_biases,
            name="residual_block_{}/conv_gate".format(self.layer_index)
        )
        ## transformed : 1x1 conv to out (= gate * filter) to produce residuals (= dense output)
        ## conv_residual (=skip_contribution in original)
        self.conv_residual = keras.layers.Conv1D(
            filters=self.residual_channels,
            kernel_size=1,
            padding="same",
            use_bias=self.use_biases,
            name="residual_block_{}/dense".format(self.layer_index)
        )
        self.conv_skip = keras.layers.Conv1D(
            filters=self.skip_channels,
            kernel_size=1,
            padding="same",
            use_bias=self.use_biases,
            name="residual_block_{}/skip".format(self.layer_index)
        )


    @tf.function
    def call(self, inputs, training=False):
        out = tf.tanh(self.conv_filter(inputs)) * tf.sigmoid(self.conv_gate(inputs))
        
        if training:
            skip_cut = tf.shape(out)[1] - self.output_width
        else:
            skip_cut = tf.shape(out)[1] - 1
            
        out_skip = tf.slice(out, [0, skip_cut, 0], [-1, -1, self.dilation_channels])
        skip_output = self.conv_skip(out_skip)

        transformed = self.conv_residual(out)
        input_cut = tf.shape(inputs)[1] - tf.shape(transformed)[1]
        x_cut = tf.slice(inputs, [0, input_cut, 0], [-1, -1, -1])
        dense_output = x_cut + transformed

        return skip_output, dense_output

In [6]:
class PostProcessing(keras.Model):
    def __init__(self, skip_channels, out_channels, use_biases):
        super().__init__()

        self.skip_channels = skip_channels
        self.out_channels = out_channels
        self.use_biases = use_biases

    def build(self, input_shape):
        self.conv_1 = keras.layers.Conv1D(
            filters=self.skip_channels,
            kernel_size=1,
            padding="same",
            use_bias=self.use_biases,
            name="postprocessing/conv_1"
        )
        self.conv_2 = keras.layers.Conv1D(
            filters=self.out_channels,
            kernel_size=1,
            padding="same",
            use_bias=self.use_biases,
            name="postprocessing/conv_2"
        )
    
    @tf.function
    def call(self, inputs, training=False):
        x = tf.nn.relu(inputs)
        x = self.conv_1(x)

        x = tf.nn.relu(x)
        x = self.conv_2(x)

        return x

In [7]:
class WaveNet(keras.Model):
    def __init__(self, batch_size, dilations, filter_width, initial_filter_width, dilation_channels, residual_channels, skip_channels, quantization_channels, out_channels, use_biases):
        super().__init__()

        self.batch_size = batch_size
        self.dilations = dilations
        self.filter_width = filter_width
        self.initial_filter_width = initial_filter_width
        self.dilation_channels = dilation_channels
        self.residual_channels = residual_channels
        self.skip_channels = skip_channels
        self.quantization_channels = quantization_channels
        self.out_channels = out_channels
        self.use_biases = use_biases

        self.receptive_field = (self.filter_width - 1) * sum(self.dilations) + self.initial_filter_width

    def build(self, input_shape):     
        self.output_width = input_shape[1] - self.receptive_field + 1       # total output width of model

        self.preprocessing_layer = keras.layers.Conv1D(
            filters=self.residual_channels,
            kernel_size=self.initial_filter_width,
            use_bias=self.use_biases,
            name="preprocessing/conv")

        self.residual_blocks = []
        for _ in range(1):
            for i, dilation in enumerate(self.dilations):
                self.residual_blocks.append(
                    ResidualBlock(
                        layer_index=i,
                        dilation=self.dilations[0], 
                        filter_width=self.filter_width, 
                        dilation_channels=self.dilation_channels, 
                        residual_channels=self.residual_channels, 
                        skip_channels=self.skip_channels, 
                        use_biases=self.use_biases, 
                        output_width=self.output_width)
                    )

        self.postprocessing_layer = PostProcessing(self.skip_channels, self.out_channels, self.use_biases)

    @tf.function(experimental_relax_shapes=True)
    def call(self, inputs, training=False):
        '''
        == predict_proba_incremental

        Assume that x is integer (== scalar_input = True)
        '''

        x = self.preprocessing_layer(inputs)
        skip_outputs = []

        for layer_index in range(len(self.dilations)):
            skip_output, x = self.residual_blocks[layer_index](x, training=training)
            skip_outputs.append(skip_output)
      
        skip_sum = tf.math.add_n(skip_outputs)          
        
        output = self.postprocessing_layer(skip_sum)

        if not training:
            out = tf.reshape(output, [self.batch_size, -1, self.out_channels])
            #output = sample_from_discretized_mix_logistic(out)

        return output

    def train_step(self, data): 
        x, y = data

        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)            
            loss = self.compiled_loss(y, y_pred)

        trainable_vars = self.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)

        self.optimizer.apply_gradients(zip(gradients, trainable_vars))
        self.compiled_metrics.update_state(y, y_pred)

        return {m.name: m.result() for m in self.metrics}

In [8]:
class DiscretizedMixLogisticLoss(keras.losses.Loss):
    def __init__(self, name="discretized_mix_logistic_loss"):
        super().__init__(name=name)

    def call(self, y_true, y_pred):
        loss =  discretized_mix_logistic_loss(y_pred, y_true)
        return tf.reduce_mean(loss)

inputs = tf.random.uniform([1, 1025, 1], minval=0, maxval=255, dtype=tf.int32)
inputs = tf.cast(inputs, tf.float32)
inputs

In [9]:
rg = tf.range(255)
inputs = tf.concat([rg, rg, rg, rg, rg], axis=0)[:1025]
inputs = tf.cast(tf.reshape(inputs, [1, 1025, 1]), tf.float32)
inputs

<tf.Tensor: shape=(1, 1025, 1), dtype=float32, numpy=
array([[[0.],
        [1.],
        [2.],
        ...,
        [2.],
        [3.],
        [4.]]], dtype=float32)>

In [10]:
#x_train = inputs[:, :1024, :]
x_train = inputs[:, :287, :]
#y_train = inputs[:, 287:, :]
y_train = inputs[:, 287:288, :]

In [11]:
x_train.shape, y_train.shape

(TensorShape([1, 287, 1]), TensorShape([1, 1, 1]))

In [12]:
# HParms follows the Diagram 
batch_size = 1
dilations = [1, 2, 4, 8, 16, 32, 64, 128]
filter_width = 2        # == kernel_size
initial_filter_width = 32       # from (tacokr)
dilation_channels = 32  # unknown
residual_channels = 24
skip_channels = 128
quantization_channels = 2**8
out_channels = 10*3
use_biases = False

#wavenet.receptive_field = 287
#wavenet.output_width = 738

In [13]:
wavenet = WaveNet(batch_size, dilations, filter_width, initial_filter_width, dilation_channels, residual_channels, skip_channels, quantization_channels, out_channels, use_biases)

In [14]:
wavenet.compile(keras.optimizers.Nadam(), loss=DiscretizedMixLogisticLoss())

In [15]:
#wavenet.fit(x=x, y=x)
wavenet.fit(x_train, y_train, batch_size=8, epochs=1000)

0
Epoch 757/1000
Epoch 758/1000
Epoch 759/1000
Epoch 760/1000
Epoch 761/1000
Epoch 762/1000
Epoch 763/1000
Epoch 764/1000
Epoch 765/1000
Epoch 766/1000
Epoch 767/1000
Epoch 768/1000
Epoch 769/1000
Epoch 770/1000
Epoch 771/1000
Epoch 772/1000
Epoch 773/1000
Epoch 774/1000
Epoch 775/1000
Epoch 776/1000
Epoch 777/1000
Epoch 778/1000
Epoch 779/1000
Epoch 780/1000
Epoch 781/1000
Epoch 782/1000
Epoch 783/1000
Epoch 784/1000
Epoch 785/1000
Epoch 786/1000
Epoch 787/1000
Epoch 788/1000
Epoch 789/1000
Epoch 790/1000
Epoch 791/1000
Epoch 792/1000
Epoch 793/1000
Epoch 794/1000
Epoch 795/1000
Epoch 796/1000
Epoch 797/1000
Epoch 798/1000
Epoch 799/1000
Epoch 800/1000
Epoch 801/1000
Epoch 802/1000
Epoch 803/1000
Epoch 804/1000
Epoch 805/1000
Epoch 806/1000
Epoch 807/1000
Epoch 808/1000
Epoch 809/1000
Epoch 810/1000
Epoch 811/1000
Epoch 812/1000
Epoch 813/1000
Epoch 814/1000
Epoch 815/1000
Epoch 816/1000
Epoch 817/1000
Epoch 818/1000
Epoch 819/1000
Epoch 820/1000
Epoch 821/1000
Epoch 822/1000
Epoch 82

<tensorflow.python.keras.callbacks.History at 0x24a89940348>

In [16]:
pred_training = wavenet(x_train, training=True)
pred_training

<tf.Tensor: shape=(1, 1, 30), dtype=float32, numpy=
array([[[ -4.7823987 ,  -2.524256  ,  -8.152562  ,  -8.037164  ,
          -2.1420841 , -10.622566  ,  -3.7060866 ,  15.436218  ,
          -6.561074  ,  -4.87299   ,   3.2193644 ,   5.4519005 ,
           3.441081  ,  -4.6090403 ,  -3.4138145 ,   6.53657   ,
          10.395798  ,  10.736645  ,  -1.3931888 ,   0.20424128,
          -4.651189  ,   7.193293  ,  -1.5946159 ,  -0.74131685,
          10.254593  ,  -1.7528414 ,   7.6960487 ,  21.846106  ,
          -1.8778284 ,   6.5342507 ]]], dtype=float32)>

In [17]:
loss_fn = DiscretizedMixLogisticLoss()
loss_fn(y_train, pred_training)

<tf.Tensor: shape=(), dtype=float32, numpy=0.6931472>

In [18]:
sample_from_discretized_mix_logistic(wavenet(x_train))

<tf.Tensor: shape=(1, 1), dtype=float32, numpy=array([[-1.]], dtype=float32)>

In [19]:
log_scale_min=float(tf.math.log(1e-14))
y = wavenet(inputs[:, :287, :])
y_true = inputs[:, 287:288, :]

In [20]:
y, y_true

(<tf.Tensor: shape=(1, 1, 30), dtype=float32, numpy=
 array([[[ -4.7823987 ,  -2.524256  ,  -8.152562  ,  -8.037164  ,
           -2.1420841 , -10.622566  ,  -3.7060866 ,  15.436218  ,
           -6.561074  ,  -4.87299   ,   3.2193644 ,   5.4519005 ,
            3.441081  ,  -4.6090403 ,  -3.4138145 ,   6.53657   ,
           10.395798  ,  10.736645  ,  -1.3931888 ,   0.20424128,
           -4.651189  ,   7.193293  ,  -1.5946159 ,  -0.74131685,
           10.254593  ,  -1.7528414 ,   7.6960487 ,  21.846106  ,
           -1.8778284 ,   6.5342507 ]]], dtype=float32)>,
 <tf.Tensor: shape=(1, 1, 1), dtype=float32, numpy=array([[[32.]]], dtype=float32)>)

In [21]:
nr_mix = y.shape[2] // 3
nr_mix

10

In [22]:
logit_probs = y[:, :, :nr_mix]
logit_probs

<tf.Tensor: shape=(1, 1, 10), dtype=float32, numpy=
array([[[ -4.7823987,  -2.524256 ,  -8.152562 ,  -8.037164 ,
          -2.1420841, -10.622566 ,  -3.7060866,  15.436218 ,
          -6.561074 ,  -4.87299  ]]], dtype=float32)>

In [23]:
sel = tf.one_hot(tf.argmax(logit_probs - tf.math.log(-tf.math.log(tf.random.uniform(tf.shape(logit_probs), minval=1e-5, maxval=1. - 1e-5))), 2), depth=nr_mix, dtype=tf.float32)
sel

<tf.Tensor: shape=(1, 1, 10), dtype=float32, numpy=array([[[0., 0., 0., 0., 0., 0., 0., 1., 0., 0.]]], dtype=float32)>

In [24]:
means = tf.math.reduce_sum(y[:, :, nr_mix:nr_mix * 2] * sel, axis=2)
means

<tf.Tensor: shape=(1, 1), dtype=float32, numpy=array([[10.736645]], dtype=float32)>

In [25]:
log_scales = tf.math.maximum(tf.math.reduce_sum(y[:, :, nr_mix * 2:nr_mix * 3] * sel, axis=2), log_scale_min)
log_scales

<tf.Tensor: shape=(1, 1), dtype=float32, numpy=array([[21.846106]], dtype=float32)>

In [26]:
u = tf.random.uniform(tf.shape(means), minval=1e-5, maxval=1. - 1e-5)
u

<tf.Tensor: shape=(1, 1), dtype=float32, numpy=array([[0.79140556]], dtype=float32)>

In [27]:
x = means + tf.math.exp(log_scales) * (tf.math.log(u) - tf.math.log(1. - u))
x

<tf.Tensor: shape=(1, 1), dtype=float32, numpy=array([[4.0983555e+09]], dtype=float32)>

In [28]:
x = tf.math.minimum(tf.math.maximum(x, -1.), 1.)
x

<tf.Tensor: shape=(1, 1), dtype=float32, numpy=array([[1.]], dtype=float32)>