In [1]:
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
from sklearn import preprocessing
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_absolute_percentage_error
from keras.models import Sequential
from keras.layers import Dense, SimpleRNN, Input, Activation, Dropout, Add, LSTM, GRU, RNN, LayerNormalization, BatchNormalization, Conv1D, MaxPooling1D, Flatten, Layer
from keras import backend as K
from keras.optimizers import Adam,SGD
import tensorflow as tf
from keras import Model, regularizers, activations
from keras.constraints import Constraint
import pickle
from keras.utils import to_categorical, plot_model
from keras.datasets import mnist

In [2]:
# --- Björck orthonormalization defaults ---
# beta: step size of the Björck update (bjorck_normalization expects ]0, 0.5]).
DEFAULT_BETA_BJORCK = 0.5
# Björck stops once ||w_t - w_{t-1}|| < eps, or after DEFAULT_MAXITER_BJORCK steps.
DEFAULT_EPS_BJORCK = 1e-3
DEFAULT_MAXITER_BJORCK = 15

# --- Power-iteration (largest singular value) defaults ---
# eps is both the power-iteration tolerance and the margin added to sigma when
# rescaling the kernel in spectral_normalization.
DEFAULT_EPS_SPECTRAL = 1e-3
DEFAULT_MAXITER_SPECTRAL = 10

# --- Loop / gradient switches ---
SWAP_MEMORY = True  # forwarded as swap_memory= to every tf.while_loop below
STOP_GRAD_SPECTRAL = True  # wrap the power-iteration vector in tf.stop_gradient

def reshaped_kernel_orthogonalization(
    kernel,
    u,
    adjustment_coef,
    eps_spectral=DEFAULT_EPS_SPECTRAL,
    eps_bjorck=DEFAULT_EPS_BJORCK,
    beta=DEFAULT_BETA_BJORCK,
    maxiter_spectral=DEFAULT_MAXITER_SPECTRAL,
    maxiter_bjorck=DEFAULT_MAXITER_BJORCK,
):
    """
    Apply reshaped kernel orthogonalization (RKO) to a kernel.

    The kernel is flattened to a 2-D matrix whose columns follow its last axis. It is
    first divided by (an estimate of) its largest singular value, found with the power
    method. Then, unless disabled, the Björck algorithm is run on that rescaled matrix.
    Rescaling first improves the stability and the convergence speed of Björck. The
    result is multiplied by ``adjustment_coef`` and reshaped back to the kernel's shape.

    Args:
        kernel (tf.Tensor): the kernel to orthogonalize.
        u (tf.Tensor): starting vector for the power iteration; may be None, in which
            case spectral_normalization draws a random one.
        adjustment_coef (float): scale applied to the orthogonalized matrix (e.g. the
            adjustment coefficient used for convolutions).
        eps_spectral (float): stopping criterion of the power iteration.
        eps_bjorck (float): stopping criterion of the Björck algorithm; None skips it.
        beta (float): beta used by the Björck algorithm; None skips it.
        maxiter_spectral (int): maximum number of power-iteration steps.
        maxiter_bjorck (int): maximum number of Björck steps.

    Returns:
        tuple: ``(W_bar, u, sigma)`` where W_bar is the orthogonalized kernel (same
            shape as ``kernel``), u is the updated power-iteration vector and sigma is
            the estimated largest singular value, shaped (1, 1).

    Reference:
        Serrurier, M., Mamalet, F., González-Sanz, A., Boissin, T., Loubes, J. M., & Del Barrio, E. (2021). 
        Achieving robustness in classification using optimal transport with hinge regularization. 
        In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (pp. 505-514).

    """
    original_shape = kernel.shape
    # Collapse every leading axis so the kernel becomes a (-1, last_dim) matrix.
    matrix = tf.reshape(kernel, [-1, original_shape[-1]])
    matrix, u, sigma = spectral_normalization(
        matrix, u, eps=eps_spectral, maxiter=maxiter_spectral
    )
    run_bjorck = eps_bjorck is not None and beta is not None
    if run_bjorck:
        matrix = bjorck_normalization(
            matrix, eps=eps_bjorck, beta=beta, maxiter=maxiter_bjorck
        )
    scaled = matrix * adjustment_coef
    return tf.reshape(scaled, original_shape), u, sigma


def _wwtw(w):
    """Return the matrix product w @ w^T @ w.

    The product is associated so that the intermediate Gram matrix is built on the
    smaller dimension: (w w^T) w when w has no more rows than columns, w (w^T w)
    otherwise.
    """
    w_t = tf.transpose(w)
    n_rows, n_cols = w.shape[0], w.shape[1]
    if n_rows <= n_cols:
        return (w @ w_t) @ w
    return w @ (w_t @ w)


def bjorck_normalization(
    w, eps=DEFAULT_EPS_BJORCK, beta=DEFAULT_BETA_BJORCK, maxiter=DEFAULT_MAXITER_BJORCK
):
    """
    Orthonormalize ``w`` with the Björck iteration.

    Repeats ``w <- (1 + beta) * w - beta * w w^T w`` until the Frobenius norm of the
    change between two consecutive iterates drops below ``eps``, or until ``maxiter``
    steps have been taken.

    Args:
        w (tf.Tensor): 2-D weight to normalize; the iteration is only well behaved when
            its largest singular value is about 1 (e.g. after spectral_normalization).
        eps (float): stopping tolerance on norm(w_t - w_{t-1}).
        beta (float): step size of each update, expected in ]0, 0.5].
        maxiter (int): maximum number of iterations.

    Returns:
        tf.Tensor: the (approximately) orthonormal weights.

    Reference:
        Serrurier, M., Mamalet, F., González-Sanz, A., Boissin, T., Loubes, J. M., & Del Barrio, E. (2021). 
        Achieving robustness in classification using optimal transport with hinge regularization. 
        In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (pp. 505-514).

    """

    def _not_converged(current, previous):
        return tf.linalg.norm(current - previous) >= eps

    def _bjorck_step(current, previous):
        updated = (1 + beta) * current - beta * _wwtw(current)
        # The iterate we just consumed becomes the "previous" one.
        return updated, current

    # Placeholder previous iterate: the first check compares w with 10 * w, so the loop
    # is entered unless norm(9 * w) < eps. The placeholder is replaced after one step.
    placeholder = 10 * w

    result, _ = tf.while_loop(
        _not_converged,
        _bjorck_step,
        (w, placeholder),
        parallel_iterations=30,
        maximum_iterations=maxiter,
        swap_memory=SWAP_MEMORY,
    )
    return result


def _power_iteration(
    linear_operator,
    adjoint_operator,
    u,
    eps=DEFAULT_EPS_SPECTRAL,
    maxiter=DEFAULT_MAXITER_SPECTRAL,
    axis=None,
):
    """Estimate the leading singular vector of a linear operator by power iteration.

    Each step applies ``linear_operator`` and then ``adjoint_operator`` to the current
    vector and L2-normalizes the result. The loop stops when the Frobenius norm of the
    change between consecutive vectors is below ``eps``, or after ``maxiter`` steps.

    Args:
        linear_operator (Callable): maps u to the operator's output.
        adjoint_operator (Callable): maps an output back through the adjoint.
        u (tf.Tensor): initial guess for the singular vector.
        eps (float, optional): stopping tolerance on norm(u[t] - u[t-1]).
            Defaults to DEFAULT_EPS_SPECTRAL.
        maxiter (int, optional): maximum number of iterations.
            Defaults to DEFAULT_MAXITER_SPECTRAL.
        axis (int/list, optional): axis passed to tf.math.l2_normalize (e.g. for
            depthwise convolutions). None normalizes over the whole tensor.

    Returns:
        tf.Tensor: the estimated leading singular vector. It is wrapped in
        tf.stop_gradient when STOP_GRAD_SPECTRAL is set.

    Reference:
        Serrurier, M., Mamalet, F., González-Sanz, A., Boissin, T., Loubes, J. M., & Del Barrio, E. (2021). 
        Achieving robustness in classification using optimal transport with hinge regularization. 
        In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (pp. 505-514).

    """

    def _step(current, previous):
        projected = linear_operator(current)
        pulled_back = adjoint_operator(projected)
        return tf.math.l2_normalize(pulled_back, axis=axis), current

    def _keep_going(current, previous):
        return tf.linalg.norm(current - previous) >= eps

    start = tf.math.l2_normalize(u, axis=axis)
    # Placeholder previous vector offset by 2 * eps per entry, so the first check passes.
    placeholder = start + 2 * eps

    estimate, _ = tf.while_loop(
        _keep_going,
        _step,
        (start, placeholder),
        maximum_iterations=maxiter,
        swap_memory=SWAP_MEMORY,
    )

    # Optionally keep gradients from flowing back through the while loop.
    return tf.stop_gradient(estimate) if STOP_GRAD_SPECTRAL else estimate


def spectral_normalization(
    kernel, u, eps=DEFAULT_EPS_SPECTRAL, maxiter=DEFAULT_MAXITER_SPECTRAL
):
    """
    Rescale a 2-D kernel so that its largest singular value is (just below) 1.

    Args:
        kernel (tf.Tensor): 2-D kernel to normalize.
        u (tf.Tensor): initial guess for the leading singular vector, shaped
            (1, kernel.shape[-1]); if None, a uniform random row vector is drawn.
        eps (float, optional): power-iteration tolerance, also added to sigma before
            dividing. Defaults to DEFAULT_EPS_SPECTRAL.
        maxiter (int, optional): maximum number of power-iteration steps.
            Defaults to DEFAULT_MAXITER_SPECTRAL.

    Returns:
        tuple: ``(kernel / (sigma + eps), u, sigma)`` where u is the estimated singular
            vector and sigma, shaped (1, 1), the estimated largest singular value.

    Reference:
        Serrurier, M., Mamalet, F., González-Sanz, A., Boissin, T., Loubes, J. M., & Del Barrio, E. (2021). 
        Achieving robustness in classification using optimal transport with hinge regularization. 
        In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (pp. 505-514).

    """

    def linear_op(vec):
        return vec @ tf.transpose(kernel)

    def adjoint_op(vec):
        return vec @ kernel

    if u is None:
        u = tf.random.uniform(
            shape=(1, kernel.shape[-1]), minval=0.0, maxval=1.0, dtype=kernel.dtype
        )

    u = _power_iteration(linear_op, adjoint_op, u, eps, maxiter)

    # sigma = ||u K^T|| with u unit-norm. Dividing by sigma + eps (not sigma) keeps the
    # normalized operator norm strictly below 1 even if the estimate is off by up to
    # eps. This keeps Björck stable even with beta = 0.5.
    sigma = tf.reshape(tf.norm(linear_op(u)), (1, 1))
    return kernel / (sigma + eps), u, sigma

In [3]:
class SpectralConstraint(Constraint):
    def __init__(
        self,
        k_coef_lip=1.0,
        eps_spectral=DEFAULT_EPS_SPECTRAL,
        eps_bjorck=DEFAULT_EPS_BJORCK,
        beta_bjorck=DEFAULT_BETA_BJORCK,
        u=None,
    ) -> None:
        """
        Keras constraint: orthogonalize the kernel, then clip it to be non-negative.

        The projection has two steps:

        1. Reduce the largest singular value to about 1 with the iterated power method,
           then push the other singular values towards 1 with the Björck algorithm.
           The result is scaled by k_coef_lip (see reshaped_kernel_orthogonalization).
        2. Clip negative entries to 0.

        Step 2 runs after the orthogonalization, so the returned kernel is
        non-negative. It is not guaranteed to have all singular values equal to
        k_coef_lip.

        Args:
            k_coef_lip (float): scale applied to the orthogonalized kernel (the
                Lipschitz coefficient of the weight matrix before clipping).
            eps_spectral (float): stopping criterion of the iterated power method.
            eps_bjorck (float): stopping criterion of the Björck algorithm.
            beta_bjorck (float): beta parameter of the Björck algorithm.
            u (tf.Tensor): starting vector for the power method. It may be None (a
                random one is drawn on every call) and is kept for
                serialization/deserialization. The vector estimated during a call is
                not written back, so every call restarts from this value.

        Reference:
            Serrurier, M., Mamalet, F., González-Sanz, A., Boissin, T., Loubes, J. M., & Del Barrio, E. (2021). 
            Achieving robustness in classification using optimal transport with hinge regularization. 
            In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (pp. 505-514).

        """
        self.eps_spectral = eps_spectral
        self.eps_bjorck = eps_bjorck
        self.beta_bjorck = beta_bjorck
        self.k_coef_lip = k_coef_lip
        # Accept array-likes (e.g. the numpy array stored by get_config).
        if not (isinstance(u, tf.Tensor) or (u is None)):
            u = tf.convert_to_tensor(u)
        self.u = u
        super(SpectralConstraint, self).__init__()

    def __call__(self, w):
        """Return the orthogonalized, scaled and non-negatively clipped version of w."""
        wbar, _, _ = reshaped_kernel_orthogonalization(
            w,
            self.u,
            self.k_coef_lip,
            self.eps_spectral,
            self.eps_bjorck,
            self.beta_bjorck,
        )

        # Clip to ensure non-negative weights. The previous K.clip(wbar, 0, wbar) used a
        # max bound below the min bound for every negative entry, which
        # tf.clip_by_value does not support. tf.maximum states the intent directly.
        return tf.maximum(wbar, 0.0)

    def get_config(self):
        """Serialize the constructor arguments (u is stored as a numpy array)."""
        config = {
            "k_coef_lip": self.k_coef_lip,
            "eps_spectral": self.eps_spectral,
            "eps_bjorck": self.eps_bjorck,
            "beta_bjorck": self.beta_bjorck,
            "u": None if self.u is None else self.u.numpy(),
        }
        base_config = super(SpectralConstraint, self).get_config()
        return {**base_config, **config}

In [4]:
def add_salt_and_pepper_noise(images, amount, salt_value=1, pepper_value=0, rng=None):
    """Return a copy of ``images`` corrupted with salt-and-pepper noise.

    For every image, ``ceil(amount * H * W)`` distinct pixel positions are set to
    ``salt_value``. Then an independently drawn set of the same size is set to
    ``pepper_value``. Pepper is applied second, so it can overwrite salt: the
    realized salt fraction is therefore at most ``amount``, while the pepper fraction
    is exactly ``ceil(amount * H * W) / (H * W)``.

    Any axes after (H, W) (e.g. colour channels) are treated as part of a pixel, so a
    noisy pixel has all of its channels set to the same value.

    Args:
        images (np.ndarray): batch shaped (N, H, W) or (N, H, W, ...).
        amount (float): fraction of pixels used for each of salt and pepper, in [0, 1].
        salt_value: value written for salt pixels (default 1, i.e. images in [0, 1]).
        pepper_value: value written for pepper pixels (default 0).
        rng (np.random.Generator, optional): source of randomness. When None, NumPy's
            global RNG (np.random) is used, so np.random.seed makes the result
            reproducible.

    Returns:
        np.ndarray: noisy copy with the same shape and dtype as ``images``; the input
        is not modified.

    Raises:
        ValueError: if ``amount`` is outside [0, 1] or ``images`` has fewer than 3 axes.
    """
    if not 0 <= amount <= 1:
        raise ValueError(f"amount must be in [0, 1], got {amount!r}")
    if images.ndim < 3:
        raise ValueError(
            f"images must be shaped (N, H, W[, ...]), got shape {images.shape}"
        )

    num_images = images.shape[0]
    num_pixels = images.shape[1] * images.shape[2]
    # Product of trailing axes (1 for grayscale (N, H, W) input).
    channels = int(np.prod(images.shape[3:]))
    num_noisy = int(np.ceil(amount * num_pixels))
    chooser = np.random if rng is None else rng

    # (N, H*W, C) layout: one index along axis 1 addresses a whole pixel.
    pixels = images.copy().reshape(num_images, num_pixels, channels)
    for i in range(num_images):
        # Add salt noise
        salt_indices = chooser.choice(num_pixels, size=num_noisy, replace=False)
        pixels[i, salt_indices] = salt_value

        # Add pepper noise (may overwrite some salt pixels)
        pepper_indices = chooser.choice(num_pixels, size=num_noisy, replace=False)
        pixels[i, pepper_indices] = pepper_value
    return pixels.reshape(images.shape)

In [5]:
import numpy as np
from keras.models import Sequential
from keras.layers import Dense, Activation, SimpleRNN
from keras.utils import to_categorical, plot_model
from keras.datasets import mnist, cifar10

# Train one Lipschitz-constrained RNN per salt-and-pepper noise level (noise is added
# to the training images only). Record the final validation accuracy and the accuracy
# on the clean MNIST test set.
#
# Reproducibility: TensorFlow's seed covers weight init / shuffling, while the
# salt-and-pepper noise is drawn from NumPy's global RNG, so both are seeded.
SEED = 42
tf.random.set_seed(SEED)
np.random.seed(SEED)

noise = [0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
HIDDEN_UNITS = 512
EPOCHS = 100
BATCH_SIZE = 128
VALIDATION_SPLIT = 0.25
MODEL_PREFIX = f'iclrnn_{HIDDEN_UNITS}_'

# Load and preprocess MNIST once; only the training images change per noise level.
(X_train_raw, y_train_raw), (X_test_raw, y_test_raw) = mnist.load_data()

# compute the number of labels
num_labels = len(np.unique(y_train_raw))

# convert to one-hot vectors
y_train = to_categorical(y_train_raw)
y_test = to_categorical(y_test_raw)

# resize and normalize to [0, 1] (rows are fed to the RNN as timesteps)
image_size = X_train_raw.shape[1]
X_train_clean = np.reshape(X_train_raw, [-1, image_size, image_size]).astype('float32') / 255
X_test = np.reshape(X_test_raw, [-1, image_size, image_size]).astype('float32') / 255

# NOTE: despite its name, training_acc stores the final *validation* accuracy
# (history['val_accuracy'][-1]); the name is kept so downstream cells still work.
training_acc = [0] * len(noise)
test_acc = [0] * len(noise)

for i, noise_level in enumerate(noise):
    # add_salt_and_pepper_noise returns a copy, so X_train_clean stays clean.
    X_train = add_salt_and_pepper_noise(X_train_clean, noise_level)

    inputs = Input(shape=(X_train.shape[1], X_train.shape[2]))
    x = SimpleRNN(HIDDEN_UNITS, activation='relu', return_sequences=True,
                  kernel_constraint=SpectralConstraint())(inputs)
    x = SimpleRNN(HIDDEN_UNITS, activation='relu', return_sequences=False,
                  kernel_constraint=SpectralConstraint())(x)
    x = Dense(num_labels, activation='softmax',
              kernel_constraint=tf.keras.constraints.NonNeg())(x)
    model = Model(inputs, x)

    model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
    history = model.fit(X_train, y_train, epochs=EPOCHS, batch_size=BATCH_SIZE,
                        validation_split=VALIDATION_SPLIT, verbose=2)

    training_acc[i] = history.history['val_accuracy'][-1]
    print(training_acc)

    _, acc = model.evaluate(X_test, y_test, batch_size=BATCH_SIZE)
    test_acc[i] = acc
    print(test_acc)

    model.save(MODEL_PREFIX + str(noise_level) + '.h5')

model.summary()

Epoch 1/100
352/352 - 163s - loss: 0.8552 - accuracy: 0.7124 - val_loss: 0.3708 - val_accuracy: 0.8922 - 163s/epoch - 462ms/step
Epoch 2/100
352/352 - 119s - loss: 0.2688 - accuracy: 0.9191 - val_loss: 0.1892 - val_accuracy: 0.9419 - 119s/epoch - 339ms/step
Epoch 3/100
352/352 - 121s - loss: 0.1825 - accuracy: 0.9469 - val_loss: 0.2189 - val_accuracy: 0.9353 - 121s/epoch - 343ms/step
Epoch 4/100
352/352 - 121s - loss: 0.1420 - accuracy: 0.9576 - val_loss: 0.1356 - val_accuracy: 0.9622 - 121s/epoch - 343ms/step
Epoch 5/100
352/352 - 121s - loss: 0.1116 - accuracy: 0.9680 - val_loss: 0.1709 - val_accuracy: 0.9513 - 121s/epoch - 344ms/step
Epoch 6/100
352/352 - 120s - loss: 0.1094 - accuracy: 0.9684 - val_loss: 0.1315 - val_accuracy: 0.9624 - 120s/epoch - 340ms/step
Epoch 7/100
352/352 - 121s - loss: 0.1031 - accuracy: 0.9704 - val_loss: 0.1212 - val_accuracy: 0.9630 - 121s/epoch - 344ms/step
Epoch 8/100
352/352 - 122s - loss: 0.0843 - accuracy: 0.9758 - val_loss: 0.1286 - val_accuracy: 0

352/352 - 112s - loss: 0.0292 - accuracy: 0.9922 - val_loss: 0.0891 - val_accuracy: 0.9788 - 112s/epoch - 317ms/step
Epoch 65/100
352/352 - 112s - loss: 0.0247 - accuracy: 0.9928 - val_loss: 0.1004 - val_accuracy: 0.9793 - 112s/epoch - 317ms/step
Epoch 66/100
352/352 - 103s - loss: 0.0245 - accuracy: 0.9934 - val_loss: 0.0896 - val_accuracy: 0.9785 - 103s/epoch - 293ms/step
Epoch 67/100
352/352 - 99s - loss: 0.0207 - accuracy: 0.9937 - val_loss: 0.0687 - val_accuracy: 0.9841 - 99s/epoch - 281ms/step
Epoch 68/100
352/352 - 113s - loss: 0.0180 - accuracy: 0.9952 - val_loss: 0.0734 - val_accuracy: 0.9818 - 113s/epoch - 320ms/step
Epoch 69/100
352/352 - 116s - loss: 0.0250 - accuracy: 0.9936 - val_loss: 0.0746 - val_accuracy: 0.9842 - 116s/epoch - 329ms/step
Epoch 70/100
352/352 - 117s - loss: 0.0190 - accuracy: 0.9947 - val_loss: 0.0908 - val_accuracy: 0.9811 - 117s/epoch - 332ms/step
Epoch 71/100
352/352 - 114s - loss: 0.0207 - accuracy: 0.9941 - val_loss: 0.1467 - val_accuracy: 0.9684 -

  saving_api.save_model(


Epoch 1/100
352/352 - 135s - loss: 1.0357 - accuracy: 0.6516 - val_loss: 0.5469 - val_accuracy: 0.8237 - 135s/epoch - 384ms/step
Epoch 2/100
352/352 - 105s - loss: 0.4278 - accuracy: 0.8653 - val_loss: 0.3762 - val_accuracy: 0.8839 - 105s/epoch - 297ms/step
Epoch 3/100
352/352 - 105s - loss: 0.2982 - accuracy: 0.9080 - val_loss: 0.2785 - val_accuracy: 0.9238 - 105s/epoch - 298ms/step
Epoch 4/100
352/352 - 107s - loss: 0.2171 - accuracy: 0.9341 - val_loss: 0.2098 - val_accuracy: 0.9393 - 107s/epoch - 303ms/step
Epoch 5/100
352/352 - 104s - loss: 0.1858 - accuracy: 0.9454 - val_loss: 0.1745 - val_accuracy: 0.9453 - 104s/epoch - 296ms/step
Epoch 6/100
352/352 - 104s - loss: 0.1655 - accuracy: 0.9506 - val_loss: 0.2026 - val_accuracy: 0.9397 - 104s/epoch - 295ms/step
Epoch 7/100
352/352 - 106s - loss: 0.1426 - accuracy: 0.9575 - val_loss: 0.1954 - val_accuracy: 0.9428 - 106s/epoch - 300ms/step
Epoch 8/100
352/352 - 105s - loss: 0.1344 - accuracy: 0.9599 - val_loss: 0.2391 - val_accuracy: 0

352/352 - 103s - loss: 0.0310 - accuracy: 0.9906 - val_loss: 0.1576 - val_accuracy: 0.9638 - 103s/epoch - 293ms/step
Epoch 65/100
352/352 - 104s - loss: 0.0309 - accuracy: 0.9915 - val_loss: 0.1564 - val_accuracy: 0.9643 - 104s/epoch - 296ms/step
Epoch 66/100
352/352 - 103s - loss: 0.0337 - accuracy: 0.9902 - val_loss: 0.1643 - val_accuracy: 0.9569 - 103s/epoch - 291ms/step
Epoch 67/100
352/352 - 101s - loss: 0.0444 - accuracy: 0.9875 - val_loss: 0.1532 - val_accuracy: 0.9607 - 101s/epoch - 288ms/step
Epoch 68/100
352/352 - 103s - loss: 0.0385 - accuracy: 0.9887 - val_loss: 0.1447 - val_accuracy: 0.9664 - 103s/epoch - 293ms/step
Epoch 69/100
352/352 - 103s - loss: 0.0317 - accuracy: 0.9910 - val_loss: 0.1357 - val_accuracy: 0.9673 - 103s/epoch - 293ms/step
Epoch 70/100
352/352 - 103s - loss: 0.0310 - accuracy: 0.9911 - val_loss: 0.1907 - val_accuracy: 0.9551 - 103s/epoch - 293ms/step
Epoch 71/100
352/352 - 104s - loss: 0.0405 - accuracy: 0.9886 - val_loss: 0.1743 - val_accuracy: 0.9579

Epoch 26/100
352/352 - 104s - loss: 0.0931 - accuracy: 0.9713 - val_loss: 0.2809 - val_accuracy: 0.9209 - 104s/epoch - 295ms/step
Epoch 27/100
352/352 - 105s - loss: 0.0878 - accuracy: 0.9726 - val_loss: 0.2810 - val_accuracy: 0.9176 - 105s/epoch - 298ms/step
Epoch 28/100
352/352 - 104s - loss: 0.0924 - accuracy: 0.9717 - val_loss: 0.3005 - val_accuracy: 0.9143 - 104s/epoch - 295ms/step
Epoch 29/100
352/352 - 103s - loss: 0.0908 - accuracy: 0.9715 - val_loss: 0.2964 - val_accuracy: 0.9210 - 103s/epoch - 293ms/step
Epoch 30/100
352/352 - 102s - loss: 0.0870 - accuracy: 0.9737 - val_loss: 0.2997 - val_accuracy: 0.9180 - 102s/epoch - 291ms/step
Epoch 31/100
352/352 - 102s - loss: 0.0925 - accuracy: 0.9710 - val_loss: 0.3180 - val_accuracy: 0.9155 - 102s/epoch - 290ms/step
Epoch 32/100
352/352 - 103s - loss: 0.0898 - accuracy: 0.9724 - val_loss: 0.3597 - val_accuracy: 0.9245 - 103s/epoch - 294ms/step
Epoch 33/100
352/352 - 103s - loss: 0.0783 - accuracy: 0.9762 - val_loss: 0.3007 - val_acc

Epoch 90/100
352/352 - 98s - loss: 0.0713 - accuracy: 0.9775 - val_loss: 0.3233 - val_accuracy: 0.9232 - 98s/epoch - 279ms/step
Epoch 91/100
352/352 - 97s - loss: 0.0585 - accuracy: 0.9821 - val_loss: 0.3256 - val_accuracy: 0.9197 - 97s/epoch - 276ms/step
Epoch 92/100
352/352 - 100s - loss: 0.0622 - accuracy: 0.9799 - val_loss: 0.3385 - val_accuracy: 0.9229 - 100s/epoch - 283ms/step
Epoch 93/100
352/352 - 100s - loss: 0.0596 - accuracy: 0.9813 - val_loss: 0.3048 - val_accuracy: 0.9263 - 100s/epoch - 283ms/step
Epoch 94/100
352/352 - 98s - loss: 0.0614 - accuracy: 0.9807 - val_loss: 0.3710 - val_accuracy: 0.9196 - 98s/epoch - 280ms/step
Epoch 95/100
352/352 - 99s - loss: 0.0633 - accuracy: 0.9809 - val_loss: 0.3121 - val_accuracy: 0.9253 - 99s/epoch - 280ms/step
Epoch 96/100
352/352 - 99s - loss: 0.0712 - accuracy: 0.9785 - val_loss: 0.3113 - val_accuracy: 0.9225 - 99s/epoch - 282ms/step
Epoch 97/100
352/352 - 98s - loss: 0.0597 - accuracy: 0.9814 - val_loss: 0.3324 - val_accuracy: 0.92

352/352 - 77s - loss: 0.1439 - accuracy: 0.9540 - val_loss: 0.6990 - val_accuracy: 0.8283 - 77s/epoch - 218ms/step
Epoch 53/100
352/352 - 78s - loss: 0.1456 - accuracy: 0.9534 - val_loss: 0.7255 - val_accuracy: 0.8216 - 78s/epoch - 220ms/step
Epoch 54/100
352/352 - 78s - loss: 0.1356 - accuracy: 0.9575 - val_loss: 0.6934 - val_accuracy: 0.8271 - 78s/epoch - 222ms/step
Epoch 55/100
352/352 - 78s - loss: 0.1362 - accuracy: 0.9556 - val_loss: 0.7016 - val_accuracy: 0.8266 - 78s/epoch - 222ms/step
Epoch 56/100
352/352 - 78s - loss: 0.1414 - accuracy: 0.9547 - val_loss: 0.6490 - val_accuracy: 0.8293 - 78s/epoch - 222ms/step
Epoch 57/100
352/352 - 78s - loss: 0.1403 - accuracy: 0.9545 - val_loss: 0.6514 - val_accuracy: 0.8297 - 78s/epoch - 223ms/step
Epoch 58/100
352/352 - 78s - loss: 0.1317 - accuracy: 0.9574 - val_loss: 0.6631 - val_accuracy: 0.8297 - 78s/epoch - 222ms/step
Epoch 59/100
352/352 - 78s - loss: 0.1405 - accuracy: 0.9551 - val_loss: 0.6782 - val_accuracy: 0.8291 - 78s/epoch - 

Epoch 14/100
352/352 - 76s - loss: 0.5848 - accuracy: 0.8033 - val_loss: 1.0174 - val_accuracy: 0.6825 - 76s/epoch - 215ms/step
Epoch 15/100
352/352 - 75s - loss: 0.5484 - accuracy: 0.8186 - val_loss: 0.9748 - val_accuracy: 0.6941 - 75s/epoch - 214ms/step
Epoch 16/100
352/352 - 75s - loss: 0.5338 - accuracy: 0.8230 - val_loss: 0.9550 - val_accuracy: 0.6911 - 75s/epoch - 213ms/step
Epoch 17/100
352/352 - 76s - loss: 0.4984 - accuracy: 0.8344 - val_loss: 0.9757 - val_accuracy: 0.6943 - 76s/epoch - 215ms/step
Epoch 18/100
352/352 - 74s - loss: 0.4764 - accuracy: 0.8416 - val_loss: 1.0557 - val_accuracy: 0.6931 - 74s/epoch - 210ms/step
Epoch 19/100
352/352 - 75s - loss: 0.4532 - accuracy: 0.8498 - val_loss: 0.9967 - val_accuracy: 0.6937 - 75s/epoch - 213ms/step
Epoch 20/100
352/352 - 76s - loss: 0.4261 - accuracy: 0.8588 - val_loss: 1.0154 - val_accuracy: 0.6841 - 76s/epoch - 215ms/step
Epoch 21/100
352/352 - 76s - loss: 0.4092 - accuracy: 0.8658 - val_loss: 1.0430 - val_accuracy: 0.6936 -

352/352 - 75s - loss: 0.2680 - accuracy: 0.9100 - val_loss: 1.2770 - val_accuracy: 0.6910 - 75s/epoch - 214ms/step
Epoch 79/100
352/352 - 75s - loss: 0.2775 - accuracy: 0.9080 - val_loss: 1.2036 - val_accuracy: 0.6913 - 75s/epoch - 214ms/step
Epoch 80/100
352/352 - 72s - loss: 0.2656 - accuracy: 0.9114 - val_loss: 1.1875 - val_accuracy: 0.6907 - 72s/epoch - 206ms/step
Epoch 81/100
352/352 - 76s - loss: 0.2565 - accuracy: 0.9155 - val_loss: 1.2674 - val_accuracy: 0.6929 - 76s/epoch - 215ms/step
Epoch 82/100
352/352 - 77s - loss: 0.2476 - accuracy: 0.9180 - val_loss: 1.2013 - val_accuracy: 0.6847 - 77s/epoch - 219ms/step
Epoch 83/100
352/352 - 76s - loss: 0.2651 - accuracy: 0.9117 - val_loss: 1.3418 - val_accuracy: 0.6962 - 76s/epoch - 216ms/step
Epoch 84/100
352/352 - 76s - loss: 0.2516 - accuracy: 0.9172 - val_loss: 1.2739 - val_accuracy: 0.6912 - 76s/epoch - 216ms/step
Epoch 85/100
352/352 - 76s - loss: 0.2712 - accuracy: 0.9107 - val_loss: 1.1860 - val_accuracy: 0.6872 - 76s/epoch - 

Epoch 40/100
352/352 - 73s - loss: 0.5055 - accuracy: 0.8275 - val_loss: 2.1743 - val_accuracy: 0.4599 - 73s/epoch - 207ms/step
Epoch 41/100
352/352 - 72s - loss: 0.4787 - accuracy: 0.8371 - val_loss: 2.1566 - val_accuracy: 0.4622 - 72s/epoch - 206ms/step
Epoch 42/100
352/352 - 74s - loss: 0.4763 - accuracy: 0.8426 - val_loss: 2.3341 - val_accuracy: 0.4585 - 74s/epoch - 209ms/step
Epoch 43/100
352/352 - 74s - loss: 0.4673 - accuracy: 0.8423 - val_loss: 2.0354 - val_accuracy: 0.4697 - 74s/epoch - 212ms/step
Epoch 44/100
352/352 - 75s - loss: 0.4742 - accuracy: 0.8407 - val_loss: 2.1413 - val_accuracy: 0.4648 - 75s/epoch - 213ms/step
Epoch 45/100
352/352 - 75s - loss: 0.4623 - accuracy: 0.8437 - val_loss: 2.0427 - val_accuracy: 0.4613 - 75s/epoch - 214ms/step
Epoch 46/100
352/352 - 75s - loss: 0.4633 - accuracy: 0.8446 - val_loss: 2.2301 - val_accuracy: 0.4656 - 75s/epoch - 213ms/step
Epoch 47/100
352/352 - 75s - loss: 0.4591 - accuracy: 0.8441 - val_loss: 2.1553 - val_accuracy: 0.4575 -

Epoch 2/100
352/352 - 75s - loss: 2.1347 - accuracy: 0.2200 - val_loss: 2.0284 - val_accuracy: 0.2763 - 75s/epoch - 214ms/step
Epoch 3/100
352/352 - 100s - loss: 2.0206 - accuracy: 0.2775 - val_loss: 1.9679 - val_accuracy: 0.3011 - 100s/epoch - 283ms/step
Epoch 4/100
352/352 - 112s - loss: 1.9496 - accuracy: 0.3098 - val_loss: 1.8953 - val_accuracy: 0.3283 - 112s/epoch - 317ms/step
Epoch 5/100
352/352 - 112s - loss: 1.9170 - accuracy: 0.3230 - val_loss: 1.9199 - val_accuracy: 0.3253 - 112s/epoch - 319ms/step
Epoch 6/100
352/352 - 112s - loss: 1.8849 - accuracy: 0.3364 - val_loss: 1.8870 - val_accuracy: 0.3363 - 112s/epoch - 319ms/step
Epoch 7/100
352/352 - 114s - loss: 1.8497 - accuracy: 0.3459 - val_loss: 1.9322 - val_accuracy: 0.3161 - 114s/epoch - 325ms/step
Epoch 8/100
352/352 - 114s - loss: 1.8175 - accuracy: 0.3599 - val_loss: 1.8916 - val_accuracy: 0.3357 - 114s/epoch - 324ms/step
Epoch 9/100
352/352 - 116s - loss: 1.7675 - accuracy: 0.3795 - val_loss: 1.8985 - val_accuracy: 0.3

352/352 - 184s - loss: 0.6518 - accuracy: 0.7774 - val_loss: 3.3946 - val_accuracy: 0.2640 - 184s/epoch - 524ms/step
Epoch 66/100
352/352 - 223s - loss: 0.6375 - accuracy: 0.7815 - val_loss: 3.2752 - val_accuracy: 0.2718 - 223s/epoch - 633ms/step
Epoch 67/100
352/352 - 211s - loss: 0.6376 - accuracy: 0.7824 - val_loss: 3.2207 - val_accuracy: 0.2690 - 211s/epoch - 600ms/step
Epoch 68/100
352/352 - 208s - loss: 0.6319 - accuracy: 0.7827 - val_loss: 3.4154 - val_accuracy: 0.2701 - 208s/epoch - 590ms/step
Epoch 69/100
352/352 - 207s - loss: 0.6248 - accuracy: 0.7851 - val_loss: 3.4964 - val_accuracy: 0.2666 - 207s/epoch - 589ms/step
Epoch 70/100
352/352 - 208s - loss: 0.6243 - accuracy: 0.7841 - val_loss: 3.3628 - val_accuracy: 0.2720 - 208s/epoch - 589ms/step
Epoch 71/100
352/352 - 209s - loss: 0.6163 - accuracy: 0.7881 - val_loss: 3.5848 - val_accuracy: 0.2685 - 209s/epoch - 593ms/step
Epoch 72/100
352/352 - 208s - loss: 0.6282 - accuracy: 0.7823 - val_loss: 3.4860 - val_accuracy: 0.2683

Epoch 26/100
352/352 - 109s - loss: 1.4582 - accuracy: 0.4858 - val_loss: 2.8046 - val_accuracy: 0.1687 - 109s/epoch - 311ms/step
Epoch 27/100
352/352 - 110s - loss: 1.4481 - accuracy: 0.4904 - val_loss: 2.6185 - val_accuracy: 0.1711 - 110s/epoch - 312ms/step
Epoch 28/100
352/352 - 109s - loss: 1.4331 - accuracy: 0.4958 - val_loss: 2.7609 - val_accuracy: 0.1723 - 109s/epoch - 311ms/step
Epoch 29/100
352/352 - 110s - loss: 1.4159 - accuracy: 0.4995 - val_loss: 2.7192 - val_accuracy: 0.1716 - 110s/epoch - 312ms/step
Epoch 30/100
352/352 - 111s - loss: 1.4031 - accuracy: 0.5050 - val_loss: 2.7751 - val_accuracy: 0.1705 - 111s/epoch - 316ms/step
Epoch 31/100
352/352 - 111s - loss: 1.3828 - accuracy: 0.5141 - val_loss: 2.8312 - val_accuracy: 0.1697 - 111s/epoch - 316ms/step
Epoch 32/100
352/352 - 111s - loss: 1.3764 - accuracy: 0.5168 - val_loss: 2.8035 - val_accuracy: 0.1739 - 111s/epoch - 316ms/step
Epoch 33/100
352/352 - 110s - loss: 1.3629 - accuracy: 0.5217 - val_loss: 2.8205 - val_acc

Epoch 90/100
352/352 - 94s - loss: 1.3793 - accuracy: 0.5246 - val_loss: 2.9803 - val_accuracy: 0.1744 - 94s/epoch - 266ms/step
Epoch 91/100
352/352 - 95s - loss: 1.4042 - accuracy: 0.5143 - val_loss: 2.8130 - val_accuracy: 0.1717 - 95s/epoch - 270ms/step
Epoch 92/100
352/352 - 96s - loss: 1.4079 - accuracy: 0.5140 - val_loss: 2.8063 - val_accuracy: 0.1666 - 96s/epoch - 272ms/step
Epoch 93/100
352/352 - 98s - loss: 1.3980 - accuracy: 0.5146 - val_loss: 2.8841 - val_accuracy: 0.1696 - 98s/epoch - 278ms/step
Epoch 94/100
352/352 - 96s - loss: 1.4238 - accuracy: 0.5035 - val_loss: 2.8953 - val_accuracy: 0.1725 - 96s/epoch - 274ms/step
Epoch 95/100
352/352 - 96s - loss: 1.4207 - accuracy: 0.5072 - val_loss: 2.8021 - val_accuracy: 0.1728 - 96s/epoch - 271ms/step
Epoch 96/100
352/352 - 96s - loss: 1.4122 - accuracy: 0.5095 - val_loss: 2.8969 - val_accuracy: 0.1675 - 96s/epoch - 272ms/step
Epoch 97/100
352/352 - 96s - loss: 1.4262 - accuracy: 0.5050 - val_loss: 2.7070 - val_accuracy: 0.1647 -

Epoch 51/100
352/352 - 85s - loss: 2.1033 - accuracy: 0.2441 - val_loss: 2.4198 - val_accuracy: 0.1083 - 85s/epoch - 242ms/step
Epoch 52/100
352/352 - 85s - loss: 2.1145 - accuracy: 0.2395 - val_loss: 2.4264 - val_accuracy: 0.1070 - 85s/epoch - 240ms/step
Epoch 53/100
352/352 - 85s - loss: 2.1080 - accuracy: 0.2422 - val_loss: 2.4348 - val_accuracy: 0.1093 - 85s/epoch - 241ms/step
Epoch 54/100
352/352 - 86s - loss: 2.0997 - accuracy: 0.2474 - val_loss: 2.4395 - val_accuracy: 0.1092 - 86s/epoch - 244ms/step
Epoch 55/100
352/352 - 90s - loss: 2.1542 - accuracy: 0.2174 - val_loss: 2.3975 - val_accuracy: 0.1063 - 90s/epoch - 256ms/step
Epoch 56/100
352/352 - 94s - loss: 2.1274 - accuracy: 0.2301 - val_loss: 2.4035 - val_accuracy: 0.1072 - 94s/epoch - 266ms/step
Epoch 57/100
352/352 - 88s - loss: 2.1251 - accuracy: 0.2321 - val_loss: 2.4084 - val_accuracy: 0.1097 - 88s/epoch - 249ms/step
Epoch 58/100
352/352 - 114s - loss: 2.1406 - accuracy: 0.2252 - val_loss: 2.4116 - val_accuracy: 0.1098 

Epoch 11/100
352/352 - 168s - loss: 2.2295 - accuracy: 0.1781 - val_loss: 2.3360 - val_accuracy: 0.1034 - 168s/epoch - 478ms/step
Epoch 12/100
352/352 - 162s - loss: 2.2172 - accuracy: 0.1857 - val_loss: 2.3503 - val_accuracy: 0.1024 - 162s/epoch - 461ms/step
Epoch 13/100
352/352 - 159s - loss: 2.2098 - accuracy: 0.1900 - val_loss: 2.3475 - val_accuracy: 0.1027 - 159s/epoch - 452ms/step
Epoch 14/100
352/352 - 159s - loss: 2.1993 - accuracy: 0.1954 - val_loss: 2.3580 - val_accuracy: 0.1040 - 159s/epoch - 451ms/step
Epoch 15/100
352/352 - 157s - loss: 2.1970 - accuracy: 0.1971 - val_loss: 2.3522 - val_accuracy: 0.1017 - 157s/epoch - 446ms/step
Epoch 16/100
352/352 - 162s - loss: 2.1902 - accuracy: 0.2010 - val_loss: 2.3600 - val_accuracy: 0.1001 - 162s/epoch - 460ms/step
Epoch 17/100
352/352 - 159s - loss: 2.1840 - accuracy: 0.2040 - val_loss: 2.3711 - val_accuracy: 0.1017 - 159s/epoch - 452ms/step
Epoch 18/100
352/352 - 166s - loss: 2.1844 - accuracy: 0.2036 - val_loss: 2.3666 - val_acc

352/352 - 154s - loss: 2.2474 - accuracy: 0.1653 - val_loss: 2.3333 - val_accuracy: 0.1027 - 154s/epoch - 436ms/step
Epoch 75/100
352/352 - 150s - loss: 2.2687 - accuracy: 0.1567 - val_loss: 2.3588 - val_accuracy: 0.1013 - 150s/epoch - 426ms/step
Epoch 76/100
352/352 - 155s - loss: 2.2926 - accuracy: 0.1319 - val_loss: 2.3253 - val_accuracy: 0.1023 - 155s/epoch - 441ms/step
Epoch 77/100
352/352 - 156s - loss: 2.2750 - accuracy: 0.1432 - val_loss: 2.3320 - val_accuracy: 0.0994 - 156s/epoch - 444ms/step
Epoch 78/100
352/352 - 152s - loss: 2.2672 - accuracy: 0.1507 - val_loss: 2.3299 - val_accuracy: 0.1010 - 152s/epoch - 430ms/step
Epoch 79/100
352/352 - 152s - loss: 2.2601 - accuracy: 0.1570 - val_loss: 2.3354 - val_accuracy: 0.0979 - 152s/epoch - 432ms/step
Epoch 80/100
352/352 - 152s - loss: 2.2973 - accuracy: 0.1350 - val_loss: 2.3246 - val_accuracy: 0.0995 - 152s/epoch - 432ms/step
Epoch 81/100
352/352 - 151s - loss: 2.2952 - accuracy: 0.1279 - val_loss: 2.3187 - val_accuracy: 0.1025

In [6]:
from tensorflow.python.framework.convert_to_constants import  convert_variables_to_constants_v2_as_graph

def get_flops(model, batch_size=1):
    """Estimate the floating-point operations of one forward pass of ``model``.

    The model is traced with tf.function on one TensorSpec per model input, shaped
    (batch_size, *input.shape[1:]) with the input's dtype. Its variables are frozen into
    constants, and the TF1 profiler's ``float_operation`` view sums the FLOPs
    registered for the ops of the resulting graph.

    NOTE(review): the graph is frozen with lower_control_flow=False, so the SimpleRNN
    time loop stays a while-loop function. The profiler counts a static graph, so it
    presumably does not multiply loop-body ops by the number of timesteps. Treat the
    figure as a lower bound for recurrent models; verify.

    Args:
        model: a built Keras model.
        batch_size (int): leading dimension used for the traced inputs (default 1).

    Returns:
        int: total float operations reported by the profiler.
    """
    forward = tf.function(lambda model_inputs: model(model_inputs))
    input_specs = [
        tf.TensorSpec([batch_size, *inp.shape[1:]], dtype=inp.dtype)
        for inp in model.inputs
    ]
    concrete_func = forward.get_concrete_function(input_specs)
    _, graph_def = convert_variables_to_constants_v2_as_graph(
        concrete_func, lower_control_flow=False
    )
    with tf.Graph().as_default() as graph:
        tf.graph_util.import_graph_def(graph_def, name='')
        run_meta = tf.compat.v1.RunMetadata()
        opts = tf.compat.v1.profiler.ProfileOptionBuilder.float_operation()
        flops = tf.compat.v1.profiler.profile(
            graph=graph, run_meta=run_meta, cmd="op", options=opts
        )
    return flops.total_float_ops

print("The FLOPs is:{}".format(get_flops(model)), flush=True)

Instructions for updating:
This API was designed for TensorFlow v1. See https://www.tensorflow.org/guide/migrate for instructions on how to migrate your code to TensorFlow v2.
The FLOPs is:1613884
