<a href="https://colab.research.google.com/github/jojivk/The-Ramp/blob/master/on_device_embedding.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import tensorflow as tf

@tf.keras.utils.register_keras_serializable(package='keras_nlp')
class OnDeviceEmbedding(tf.keras.layers.Layer):
  """Embedding lookup layer with an optional one-hot matmul lookup path.

  Owns a float32 weight of shape `[vocab_size, embedding_width]` and maps
  integer ids to rows of that weight. The lookup is done either with
  `tf.gather` (default) or with `tf.one_hot` followed by `tf.matmul`.

  Args:
    vocab_size: Number of rows in the embedding table.
    embedding_width: Number of columns in the embedding table, i.e. the size
      of the trailing dimension appended to the output.
    initializer: Initializer handed to `add_weight` for the embedding table.
      Defaults to "glorot_uniform".
    use_one_hot: If True, look embeddings up via a one-hot matmul instead of
      `tf.gather`.
    scale_factor: Optional multiplier applied to the looked-up embeddings.
      Scaling is skipped when this is falsy (`None` or 0).
    **kwargs: Forwarded unchanged to `tf.keras.layers.Layer.__init__`.
  """

  def __init__(self,
               vocab_size,
               embedding_width,
               initializer="glorot_uniform",
               use_one_hot=False,
               scale_factor=None,
               **kwargs):
    super(OnDeviceEmbedding, self).__init__(**kwargs)
    self._vocab_size = vocab_size
    self._embedding_width = embedding_width
    self._initializer = initializer
    self._use_one_hot = use_one_hot
    self._scale_factor = scale_factor

  def get_config(self):
    """Returns the base layer config extended with this layer's arguments.

    The keys match the `__init__` parameter names so the resulting dict can be
    fed back to the constructor.
    """
    config = {
        "vocab_size": self._vocab_size,
        "embedding_width": self._embedding_width,
        "initializer": self._initializer,
        "use_one_hot": self._use_one_hot,
        "scale_factor": self._scale_factor,
    }
    base_config = super(OnDeviceEmbedding, self).get_config()
    # Entries from `config` win over same-named entries in the base config.
    return dict(list(base_config.items()) + list(config.items()))

  def build(self, input_shape):
    """Creates the `[vocab_size, embedding_width]` float32 embedding table."""
    self.embeddings = self.add_weight(
        # Keyword form: the position of `name` in `add_weight` differs
        # between Keras versions.
        name="embeddings",
        shape=[self._vocab_size, self._embedding_width],
        initializer=self._initializer,
        dtype=tf.float32,
    )
    super(OnDeviceEmbedding, self).build(input_shape)

  def call(self, inputs):
    """Looks up an embedding row for every id in `inputs`.

    Args:
      inputs: Tensor of ids used to index the embedding table (passed to
        `tf.gather` or `tf.one_hot`, so presumably an integer tensor).

    Returns:
      A tensor whose shape is the shape of `inputs` with `embedding_width`
      appended, multiplied by `scale_factor` when one is set.
    """
    # Flatten so both lookup paths operate on a 1-D vector of ids.
    flat_inputs = tf.reshape(inputs, [-1])
    if self._use_one_hot:
      dtype = self._compute_dtype
      if not tf.dtypes.as_dtype(dtype).is_floating:
        # TensorFlow 1 compatibility. In TF1, self._compute_dtype is int32
        # instead of a floating-point dtype, as the dtype is inferred from the
        # dtype of the inputs
        dtype = tf.float32
      one_hot_data = tf.one_hot(
          flat_inputs, depth=self._vocab_size, dtype=dtype)
      embeddings = tf.matmul(one_hot_data, self.embeddings)
    else:
      embeddings = tf.gather(self.embeddings, flat_inputs)
    # Restore the original leading dimensions and append the embedding axis.
    embeddings = tf.reshape(
        embeddings,
        # Work around b/142213824: prefer concat to shape over a Python list.
        tf.concat([tf.shape(inputs), [self._embedding_width]], axis=0))
    # The dynamic reshape above loses static shape information; re-attach it.
    embeddings.set_shape(inputs.shape.as_list() + [self._embedding_width])
    if self._scale_factor:
      embeddings *= self._scale_factor
    return embeddings

  @property
  def vocab_size(self):
    """Number of rows in the embedding table."""
    return self._vocab_size

  @property
  def embedding_width(self):
    """Number of columns in the embedding table."""
    return self._embedding_width