In [1]:
import tensorflow as tf 
import keras as k
from keras import ops

In [14]:
@k.saving.register_keras_serializable(name="RMSNormalization")
class RMSNormalization(k.layers.Layer):
    """Root-mean-square normalization over the last axis (Gemma-style).

    Every vector along the last axis is divided by its RMS and then scaled by
    ``(1 + scale)``, where ``scale`` is a learned per-feature vector that starts
    at zero. At initialization the layer is therefore a pure RMS normalization.

    Adapted from keras-nlp:
    https://github.com/keras-team/keras-nlp/blob/master/keras_nlp/src/models/gemma/rms_normalization.py
    Serialization follows:
    https://keras.io/guides/serialization_and_saving/#config_methods

    Args:
        epsilon: small constant added to the mean square before the square
            root so an all-zero vector does not divide by zero.
        **kwargs: forwarded unchanged to ``keras.layers.Layer``.
    """

    def __init__(self, epsilon=1e-6, **kwargs):
        super().__init__(**kwargs)
        self.epsilon = epsilon

    def get_config(self):
        """Return the base layer config extended with ``epsilon``.

        ``epsilon`` is the only extra constructor argument, so the inherited
        ``from_config`` (which rebuilds via ``cls(**config)``) needs no override.
        """
        return {**super().get_config(), "epsilon": self.epsilon}

    def build(self, input_shape):
        """Create the zero-initialized per-feature ``scale`` vector.

        Its length equals the size of the last axis of the input.
        """
        feature_dim = input_shape[-1]
        self.scale = self.add_weight(
            shape=(feature_dim,),
            initializer="zeros",
            trainable=True,
            name="scale",
        )
        self.built = True

    def call(self, x):
        """Normalize ``x`` by its RMS over the last axis and apply ``(1 + scale)``.

        All arithmetic runs in float32 whatever the input dtype (relevant e.g.
        under a mixed-precision policy). The result is cast back to the
        layer's ``compute_dtype``.
        """
        x32 = ops.cast(x, "float32")
        gain = 1 + ops.cast(self.scale, "float32")
        mean_square = ops.mean(ops.square(x32), axis=-1, keepdims=True)
        inv_rms = ops.reciprocal(ops.sqrt(mean_square + self.epsilon))
        return ops.cast(x32 * inv_rms * gain, self.compute_dtype)
    
def agent(input_shape=(210, 160, 3), num_actions=6, l2=1e-4,
          hidden_units=1024, num_hidden=7):
    """Build the policy network: a convolutional encoder plus a deep MLP head.

    The defaults reproduce the original hard-coded model exactly, and layers
    are created in the same order:

      * 5 x Conv2D with "SAME" padding and mish activation:
        filters 32/64/128/256/512, kernels 7/5/5/3/3, strides 2/2/2/2/1.
        Each of the first four is followed by LayerNormalization.
      * Flatten, then RMSNormalization.
      * ``num_hidden`` x [Dense(hidden_units, mish) -> RMSNormalization].
      * Dense(num_actions, softmax) output (no regularizer).

    Every Conv2D and hidden Dense kernel gets its own L2 regularizer.

    NOTE(review): with the default input the encoder ends at 14 x 10 x 512, so
    Flatten produces 71,680 features. The first hidden Dense alone therefore
    holds about 73M parameters. Consider pooling or a strided final conv if
    that is not intended.

    Args:
        input_shape: shape of one observation, excluding the batch axis.
            Presumably an Atari RGB frame by default; TODO confirm.
        num_actions: number of softmax outputs.
        l2: L2 regularization factor for conv and hidden-dense kernels.
        hidden_units: width of each hidden Dense layer.
        num_hidden: number of hidden Dense + RMSNormalization pairs.

    Returns:
        An uncompiled ``keras.Model`` mapping observations to action
        probabilities.
    """
    # (filters, kernel size, stride, followed by LayerNormalization?)
    conv_specs = (
        (32, 7, 2, True),
        (64, 5, 2, True),
        (128, 5, 2, True),
        (256, 3, 2, True),
        (512, 3, 1, False),
    )

    inputs = k.Input(shape=tuple(input_shape))
    h = inputs
    for filters, kernel, stride, layer_norm in conv_specs:
        h = k.layers.Conv2D(
            filters,
            (kernel, kernel),
            strides=(stride, stride),
            padding="SAME",
            kernel_regularizer=k.regularizers.L2(l2),
            activation=k.activations.mish,
        )(h)
        if layer_norm:
            h = k.layers.LayerNormalization()(h)

    h = k.layers.Flatten()(h)
    h = RMSNormalization()(h)
    for _ in range(num_hidden):
        h = k.layers.Dense(
            hidden_units,
            kernel_regularizer=k.regularizers.L2(l2),
            activation=k.activations.mish,
        )(h)
        h = RMSNormalization()(h)

    outputs = k.layers.Dense(num_actions, activation=k.activations.softmax)(h)
    return k.Model(inputs, outputs)

In [15]:
# Build the agent network and write it to disk in the native ".keras" format.
ag = agent()
k.saving.save_model(ag, "ag.keras")

In [16]:
reload_ag = k.saving.load_model("ag.keras")

# Round-trip check: the custom RMSNormalization layer must deserialize back to
# the same layer classes, and every weight must survive save/load unchanged.
# Variables are compared one at a time to avoid copying the whole weight list.
assert [type(layer).__name__ for layer in reload_ag.layers] == [
    type(layer).__name__ for layer in ag.layers
], "layer classes changed across save/load"
assert len(reload_ag.weights) == len(ag.weights), "weight count changed across save/load"
for saved, loaded in zip(ag.weights, reload_ag.weights):
    saved_value, loaded_value = saved.numpy(), loaded.numpy()
    assert saved_value.shape == loaded_value.shape, f"shape mismatch for {saved.name}"
    assert (saved_value == loaded_value).all(), f"value mismatch for {saved.name}"