In [156]:
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

In [160]:
import tensorflow as tf

class AdvancedCNN(tf.keras.Model):
    def __init__(self, cnn_input_reshape, num_classes):
        super(AdvancedCNN, self).__init__()
        
        self.reshape = tf.keras.layers.Reshape(cnn_input_reshape)
        
        self.conv1 = tf.keras.layers.Conv2D(64, kernel_size=3, activation='relu', padding='same')
        self.conv2 = tf.keras.layers.Conv2D(64, kernel_size=3, activation='relu', padding='same')
        self.max_pool1 = tf.keras.layers.MaxPooling2D(pool_size=(2, 2))
        
        self.conv3 = tf.keras.layers.Conv2D(128, kernel_size=3, activation='relu', padding='same')
        self.conv4 = tf.keras.layers.Conv2D(128, kernel_size=3, activation='relu', padding='same')
        self.max_pool2 = tf.keras.layers.MaxPooling2D(pool_size=(2, 2))
        
        self.conv5 = tf.keras.layers.Conv2D(256, kernel_size=3, activation='relu', padding='same')
        self.conv6 = tf.keras.layers.Conv2D(256, kernel_size=3, activation='relu', padding='same')
        self.max_pool3 = tf.keras.layers.MaxPooling2D(pool_size=(2, 2))

        self.flatten = tf.keras.layers.Flatten()
        self.dense1 = tf.keras.layers.Dense(512, activation='relu')
        self.dropout1 = tf.keras.layers.Dropout(0.5)
        self.dense2 = tf.keras.layers.Dense(512, activation='relu')
        self.dropout2 = tf.keras.layers.Dropout(0.5)
        self.dense3 = tf.keras.layers.Dense(num_classes, activation='softmax')

    def call(self, inputs, training=None):
        x = self.reshape(inputs)  # Add a channel dimension
        
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.max_pool1(x)

        x = self.conv3(x)
        x = self.conv4(x)
        x = self.max_pool2(x)

        x = self.conv5(x)
        x = self.conv6(x)
        x = self.max_pool3(x)

        x = self.flatten(x)
        x = self.dense1(x)
        x = self.dropout1(x, training=training)
        x = self.dense2(x)
        x = self.dropout2(x, training=training)
        x = self.dense3(x)
        return x

    @tf.function
    def step(self, batch):
        x_batch, y_batch = batch

        with tf.GradientTape() as tape:
            # Forward pass: Compute predictions
            y_batch_pred = self(x_batch, training=True)

            # Compute the loss value
            loss = self.loss(y_batch, y_batch_pred)

        # Compute gradients
        gradients = tape.gradient(loss, self.trainable_variables)
        
        # Apply gradients to the model's trainable variables (update weights)
        self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))
       
    def train(self, dataset):
        for batch in dataset:
            self.step(batch)
            
    def set_trainable_variables(self, trainable_vars):
        for model_var, var in zip(self.trainable_variables, trainable_vars):
            model_var.assign(var)

    def trainable_vars_as_vector(self):
        return tf.concat([tf.reshape(var, [-1]) for var in self.trainable_variables], axis=0)

    @tf.function
    def trainable_vars_as_vector1(self):
        print("Retrace trainable_vars_as_vector1")
        return tf.concat([tf.reshape(var, [-1]) for var in self.trainable_variables], axis=0)
    

def get_compiled_and_built_advanced_cnn(cnn_batch_input, cnn_input_reshape, num_classes):
    """
    Compile and build an Advanced CNN model.

    Args:
    - cnn_batch_input (tuple): The shape of the input including batch size (e.g., (None, 28, 28)).
    - cnn_input_reshape (tuple): The shape to which the input should be reshaped (e.g., (28, 28, 1)).
    - num_classes (int): Number of output classes.

    Returns:
    - AdvancedCNN: A compiled and built Advanced CNN model.
    """
    advanced_cnn = AdvancedCNN(cnn_input_reshape, num_classes)
    
    advanced_cnn.compile(
        optimizer=tf.keras.optimizers.Adam(),
        loss=tf.keras.losses.SparseCategoricalCrossentropy(),  # we have softmax
        metrics=[tf.keras.metrics.SparseCategoricalAccuracy(name='test_accuracy')]
    )
    
    advanced_cnn.build(cnn_batch_input)
    
    return advanced_cnn

In [168]:
adv = get_compiled_and_built_advanced_cnn((None, 28, 28), (28, 28, 1), 10)

<tf.Variable 'conv2d_94/kernel:0' shape=(3, 3, 1, 64) dtype=float32, numpy=
array([[[[ 6.62436336e-03, -8.48598927e-02,  9.64885205e-02,
          -6.27466589e-02,  2.55306214e-02, -7.81978294e-02,
          -5.68179972e-02, -4.94226888e-02,  5.32500595e-02,
          -4.78790589e-02, -5.79350702e-02,  8.70027095e-02,
           1.47773400e-02, -8.15537274e-02, -9.88330096e-02,
           8.80521238e-02, -4.49408554e-02,  3.05297822e-02,
           8.27989727e-02,  8.73940438e-02,  4.85185385e-02,
          -7.53182992e-02, -2.51235589e-02,  3.39087397e-02,
          -5.78612834e-03,  3.15476358e-02, -1.19907707e-02,
          -1.21367723e-03, -6.53995574e-02,  9.87221599e-02,
          -5.86902946e-02, -9.97259393e-02,  1.48980469e-02,
           5.64086139e-02, -8.67383927e-02,  7.35486150e-02,
          -7.35185593e-02,  5.53694367e-02, -1.34367794e-02,
          -9.35118571e-02, -5.23242354e-03,  7.36542344e-02,
          -1.00085177e-01,  5.25222570e-02,  2.76980549e-02,
         

In [153]:
(X_train, y_train), (X_test, y_test) = tf.keras.datasets.mnist.load_data()
X_train, X_test = X_train / 255.0, X_test / 255.0
train_dataset = tf.data.Dataset.from_tensor_slices((X_train, y_train)).batch(256).repeat().take(1)

In [154]:
train_dataset

<_TakeDataset element_spec=(TensorSpec(shape=(None, 28, 28), dtype=tf.float64, name=None), TensorSpec(shape=(None,), dtype=tf.uint8, name=None))>

In [144]:
%%timeit 

adv.train(train_dataset)

1.16 s ± 22.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [155]:
%%timeit 

adv.train(train_dataset)

Retrace call
1.1 s ± 19.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [78]:
adv.trainable_vars_as_vector1()
adv.set_trainable_variables1(adv.trainable_variables)

In [48]:
%%timeit

adv.trainable_vars_as_vector()

7.64 ms ± 261 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [49]:
%%timeit

adv.trainable_vars_as_vector1()

2.31 ms ± 51.9 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [79]:
adv.trainable_vars_as_vector()

<tf.Tensor: shape=(2592202,), dtype=float32, numpy=
array([-0.00058674,  0.04495262, -0.04219111, ...,  0.        ,
        0.        ,  0.        ], dtype=float32)>

In [112]:
adv.set_trainable_variables(adv3.trainable_variables)
adv.trainable_vars_as_vector1()

<tf.Tensor: shape=(2592202,), dtype=float32, numpy=
array([-0.00058674,  0.04495262, -0.04219111, ...,  0.        ,
        0.        ,  0.        ], dtype=float32)>

In [113]:
adv.trainable_vars_as_vector()

<tf.Tensor: shape=(2592202,), dtype=float32, numpy=
array([-0.00058674,  0.04495262, -0.04219111, ...,  0.        ,
        0.        ,  0.        ], dtype=float32)>

In [117]:
adv.set_trainable_variables(adv2.trainable_variables)
adv.trainable_vars_as_vector1()

<tf.Tensor: shape=(2592202,), dtype=float32, numpy=
array([-0.07618969, -0.02637688,  0.08039571, ...,  0.        ,
        0.        ,  0.        ], dtype=float32)>

In [118]:
adv.trainable_vars_as_vector()

<tf.Tensor: shape=(2592202,), dtype=float32, numpy=
array([-0.07618969, -0.02637688,  0.08039571, ...,  0.        ,
        0.        ,  0.        ], dtype=float32)>