In [30]:
import tensorflow as tf
from tensorflow.keras.layers import Layer, Input, InputLayer, Conv2D, Flatten, Dense, Reshape, Conv2DTranspose
from tqdm import tqdm
import numpy as np

print(tf.executing_eagerly())

fashion_mnist = tf.keras.datasets.fashion_mnist
(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()

# Normaliser les données
train_images = train_images / 255.0
test_images = test_images / 255.0

True


In [34]:
class VAE_Huber_energy:

    def __init__(self, input_dim, latent_dim = 8, encoder = None, decoder = None, optimizer = None):
        self.latent_dim = latent_dim

        if encoder is None:
            self.encoder = tf.keras.Sequential([
                InputLayer(input_shape=input_dim),
                Conv2D(32, (3, 3), activation='relu'),
                Conv2D(64, (3, 3), activation='relu'),
                Flatten(),
                Dense(128, activation='relu'),
                Dense(latent_dim, activation=None)
            ])
        else:
            self.encoder = encoder

        if decoder is None:
            self.decoder = tf.keras.Sequential([
                InputLayer(input_shape=latent_dim),
                Dense(256, activation='relu'),
                Dense(1024, activation='relu'),
                Dense(8*8*8, activation='relu'),
                Reshape((8,8,8)),
                Conv2DTranspose(16, (8, 8), activation='relu'),
                Conv2DTranspose(4, (8, 8), activation='relu'),
                Conv2DTranspose(1, (7, 7)),
            ])
            print(self.decoder.output_shape)
        else:
            self.decoder = decoder

        if optimizer is None:
            self.optimizer = tf.compat.v1.train.AdamOptimizer(
                learning_rate=0.001,
                beta1=0.9,
                beta2=0.999,
                epsilon=1e-08,
                use_locking=False,
                name='Adam'
            )
        else:
            self.optimizer = optimizer

    def forward(self, input):
        return self.decoder(self.encoder(input))

    def RS_distance_to_N01(self, Z):
        Nf = tf.stop_gradient(tf.cast(tf.shape(Z)[0], tf.float32))#batch size as float
        Df = tf.stop_gradient(tf.cast(tf.shape(Z)[1], tf.float32))#dimension as float

        # Term 1: c0 and c1
        f0 = tf.stop_gradient((tf.sqrt(2.)-1.)*tf.exp(tf.math.lgamma(Df/2.0+0.5)-tf.math.lgamma(Df/2.0)))
        ddf0 = tf.stop_gradient(tf.exp(tf.math.lgamma(.5+Df/2.)-tf.math.lgamma(1.+Df/2.))/tf.sqrt(2.))

        c0 = f0 - 1./ddf0
        c1 = 1/(ddf0**2)

        # Term 2
        term2 = tf.reduce_mean(tf.sqrt(tf.square(tf.math.reduce_euclidean_norm(Z, 1)) + c1))

        # Term 3
        _1 = tf.tile(tf.expand_dims(Z, 2), [1,1,Df])
        _2 = tf.math.reduce_euclidean_norm(_1 - tf.transpose(_1, perm = [0,2,1]), axis=0)
        term3 = tf.reduce_mean(_2)/2

        RS_loss = c0 + term2 - term3
        return RS_loss
        #set diag = 1 because of the singularity of sqrt'(x) in zero. This is a numerical
        #issue due to low precision of the GPU computations; it is compensated
        #exactly in the return value
        #distZZ = tf.matrix_set_diag(distZZ,tf.ones([ tf.stop_gradient(tf.shape(Z)[0]) ]),name=None)

        #smalleps=1e-4
        #return c0+tf.reduce_mean(tf.sqrt(dist_real_Z+c1)) -0.5*tf.reduce_mean(tf.sqrt(distZZ+1e-6))+0.5/Nf+2.*tf.sqrt(smalleps)
        #Note: 0.5/Nf has been added in order to compensate the "1" set on the diagonal
        #if necessary to apply the log in the cost, add a small constant for smoothness near zero

    def train(self, tensor_input_x, epochs = 100, batch_size = 10):
        rng = np.random.default_rng()
        epoch = 0

        for epoch in tqdm(range(epochs)):

            batch = rng.integers(low = 0, high = len(tensor_input_x), size = batch_size)
            batch_input = tensor_input_x[batch]

            with tf.GradientTape() as tape:
                encoded_batch = self.encoder(batch_input)
                decoded_batch = self.decoder(encoded_batch)

                flatten_input = tf.cast(tf.reshape(batch_input, [batch_size, -1]), 'float32')
                flatten_output = tf.reshape(decoded_batch, [batch_size, -1])

                _1 = flatten_input - flatten_output
                _2 = tf.square(_1)
                _3 = tf.reduce_sum(_2, axis=1)
                _4 = tf.math.sqrt(_3)

                rec_term = tf.cast(tf.reduce_mean(_4), dtype = 'float32')

                dist_to_N01_term = self.RS_distance_to_N01(encoded_batch)
                lambda_factor = 100.0

                cost_function = rec_term + lambda_factor * dist_to_N01_term

                gradients = tape.gradient(cost_function, self.encoder.trainable_weights + self.decoder.trainable_weights)
                print(gradients)
                self.optimizer.apply_gradients(zip(gradients, self.encoder.trainable_weights + self.decoder.trainable_weights))

            epoch += 1
            if epoch % 10 == 0:
                print(f"\nEpoch {epoch}, Reconstruction Error: {rec_term.numpy()}, RS Distance: {dist_to_N01_term.numpy()}")

input_dim = train_images.shape
print(input_dim)
model = VAE_Huber_energy((28, 28, 1))
model.train(train_images)


(60000, 28, 28)
(None, 28, 28, 1)


  1%|          | 1/100 [00:00<00:19,  5.05it/s]

[<tf.Tensor: shape=(3, 3, 1, 32), dtype=float32, numpy=
array([[[[nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
          nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
          nan, nan, nan, nan, nan, nan, nan, nan]],

        [[nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
          nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
          nan, nan, nan, nan, nan, nan, nan, nan]],

        [[nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
          nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
          nan, nan, nan, nan, nan, nan, nan, nan]]],


       [[[nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
          nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
          nan, nan, nan, nan, nan, nan, nan, nan]],

        [[nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
          nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
          nan, nan, nan, nan, 

  4%|▍         | 4/100 [00:00<00:15,  6.39it/s]

[<tf.Tensor: shape=(3, 3, 1, 32), dtype=float32, numpy=
array([[[[nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
          nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
          nan, nan, nan, nan, nan, nan, nan, nan]],

        [[nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
          nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
          nan, nan, nan, nan, nan, nan, nan, nan]],

        [[nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
          nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
          nan, nan, nan, nan, nan, nan, nan, nan]]],


       [[[nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
          nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
          nan, nan, nan, nan, nan, nan, nan, nan]],

        [[nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
          nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
          nan, nan, nan, nan, 

  6%|▌         | 6/100 [00:00<00:13,  6.77it/s]

[<tf.Tensor: shape=(3, 3, 1, 32), dtype=float32, numpy=
array([[[[nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
          nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
          nan, nan, nan, nan, nan, nan, nan, nan]],

        [[nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
          nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
          nan, nan, nan, nan, nan, nan, nan, nan]],

        [[nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
          nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
          nan, nan, nan, nan, nan, nan, nan, nan]]],


       [[[nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
          nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
          nan, nan, nan, nan, nan, nan, nan, nan]],

        [[nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
          nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
          nan, nan, nan, nan, 

  8%|▊         | 8/100 [00:01<00:13,  6.79it/s]

[<tf.Tensor: shape=(3, 3, 1, 32), dtype=float32, numpy=
array([[[[nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
          nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
          nan, nan, nan, nan, nan, nan, nan, nan]],

        [[nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
          nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
          nan, nan, nan, nan, nan, nan, nan, nan]],

        [[nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
          nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
          nan, nan, nan, nan, nan, nan, nan, nan]]],


       [[[nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
          nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
          nan, nan, nan, nan, nan, nan, nan, nan]],

        [[nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
          nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
          nan, nan, nan, nan, 

 10%|█         | 10/100 [00:01<00:12,  7.37it/s]

[<tf.Tensor: shape=(3, 3, 1, 32), dtype=float32, numpy=
array([[[[nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
          nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
          nan, nan, nan, nan, nan, nan, nan, nan]],

        [[nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
          nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
          nan, nan, nan, nan, nan, nan, nan, nan]],

        [[nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
          nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
          nan, nan, nan, nan, nan, nan, nan, nan]]],


       [[[nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
          nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
          nan, nan, nan, nan, nan, nan, nan, nan]],

        [[nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
          nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
          nan, nan, nan, nan, 

 12%|█▏        | 12/100 [00:01<00:11,  7.47it/s]

[<tf.Tensor: shape=(3, 3, 1, 32), dtype=float32, numpy=
array([[[[nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
          nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
          nan, nan, nan, nan, nan, nan, nan, nan]],

        [[nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
          nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
          nan, nan, nan, nan, nan, nan, nan, nan]],

        [[nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
          nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
          nan, nan, nan, nan, nan, nan, nan, nan]]],


       [[[nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
          nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
          nan, nan, nan, nan, nan, nan, nan, nan]],

        [[nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
          nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
          nan, nan, nan, nan, 

 14%|█▍        | 14/100 [00:01<00:11,  7.70it/s]

[<tf.Tensor: shape=(3, 3, 1, 32), dtype=float32, numpy=
array([[[[nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
          nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
          nan, nan, nan, nan, nan, nan, nan, nan]],

        [[nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
          nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
          nan, nan, nan, nan, nan, nan, nan, nan]],

        [[nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
          nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
          nan, nan, nan, nan, nan, nan, nan, nan]]],


       [[[nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
          nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
          nan, nan, nan, nan, nan, nan, nan, nan]],

        [[nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
          nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
          nan, nan, nan, nan, 

 15%|█▌        | 15/100 [00:02<00:12,  6.74it/s]

[<tf.Tensor: shape=(3, 3, 1, 32), dtype=float32, numpy=
array([[[[nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
          nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
          nan, nan, nan, nan, nan, nan, nan, nan]],

        [[nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
          nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
          nan, nan, nan, nan, nan, nan, nan, nan]],

        [[nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
          nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
          nan, nan, nan, nan, nan, nan, nan, nan]]],


       [[[nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
          nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
          nan, nan, nan, nan, nan, nan, nan, nan]],

        [[nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
          nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
          nan, nan, nan, nan, 




KeyboardInterrupt: 