The high-level idea is that the encoder produces a representation whose length equals that of the original input sequence. Then, at decoding time, the decoder can (via some control mechanism) receive as input a context vector consisting of a weighted sum of the encoder's representations of the input at each time step. Intuitively, the weights determine the extent to which each step’s context “focuses” on each input token, and the key is to make the process that assigns these weights differentiable so that it can be learned along with all of the other neural network parameters.

In [1]:
import tensorflow as tf
from utils import util_functions as utils

2024-12-26 13:54:10.448519: I metal_plugin/src/device/metal_device.cc:1154] Metal device set to: Apple M1 Max
2024-12-26 13:54:10.448543: I metal_plugin/src/device/metal_device.cc:296] systemMemory: 32.00 GB
2024-12-26 13:54:10.448547: I metal_plugin/src/device/metal_device.cc:313] maxCacheSize: 10.67 GB
2024-12-26 13:54:10.448579: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:306] Could not identify NUMA node of platform GPU ID 0, defaulting to 0. Your kernel may not have been built with NUMA support.
2024-12-26 13:54:10.448594: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:272] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 0 MB memory) -> physical PluggableDevice (device: 0, name: METAL, pci bus id: <undefined>)


scores: (batch_size, num_queries_per_sample, num_keys)
valid_lens: (batch_size,) or (batch_size, num_queries_per_sample)

In [2]:
def masked_softmax(X, valid_lens, value=0):
  """Softmax over the last axis of X, overwriting keys past each valid length.

  X: (batch_size, num_queries_per_sample, num_keys) attention scores.
  valid_lens: None, or lengths of shape (batch_size, ) (one length shared by
    every query of a sample) OR (batch_size, num_queries_per_sample) (one
    length per query). Any numeric dtype is accepted; it is cast to int32.
  value: logit written into masked positions before the softmax. NOTE: the
    default 0 does not remove masked keys from the distribution; pass a
    large negative number (e.g. -1e6) so they receive ~zero weight.

  Returns a tensor with the same shape as X whose last axis sums to 1.
  """
  if valid_lens is None:
    return tf.nn.softmax(logits=X, axis=-1)

  # Dynamic shape so the function also works when some dims are unknown
  # statically (e.g. inside tf.function with a None batch dimension).
  shape = tf.shape(X)
  num_queries_per_sample = shape[1]
  num_keys = shape[2]

  X_flat = tf.reshape(X, shape=(-1, num_keys))
  ## Shape: (batch_size * num_queries_per_sample, num_keys)

  # tf.range yields int32; cast so int64/float lengths compare without error.
  valid_lens = tf.cast(valid_lens, dtype=tf.int32)
  if len(valid_lens.shape) == 1:
    # One length per sample: repeat it for each of that sample's queries.
    valid_lens = tf.repeat(valid_lens, repeats=num_queries_per_sample)
  else:
    valid_lens = tf.reshape(valid_lens, shape=[-1])
  ## Shape (batch_size * num_queries_per_sample, )

  # Broadcast (1, num_keys) against (rows, 1) instead of tiling both sides.
  keep = tf.range(num_keys)[tf.newaxis, :] < valid_lens[:, tf.newaxis]
  ## Shape (batch_size * num_queries_per_sample, num_keys)

  masked_X = tf.where(keep, X_flat, tf.cast(value, X.dtype))
  masked_X = tf.reshape(masked_X, shape=shape)
  return tf.nn.softmax(masked_X, axis=-1)

In [None]:
# Two samples, each with 2 queries over 4 keys; valid lengths are 2 and 3, and
# masked positions are given the logit -1e6 so their softmax weight is ~0.
masked_softmax(X=tf.random.uniform((2, 2, 4)), valid_lens=tf.constant([2, 3]), value=-1e6)

In [8]:
class DotProductAttention(tf.keras.layers.Layer):
  """Scaled dot-product attention with optional key masking and dropout."""

  def __init__(self, dropout, **kwargs):
    # **kwargs (name, dtype, ...) are forwarded to the Keras Layer base,
    # matching AdditiveAttention's constructor.
    super().__init__(**kwargs)
    self.dropout = tf.keras.layers.Dropout(dropout)

  def call(self, queries, keys, values, valid_lens=None, **kwargs):
    """
    Queries: Decoder Input: (batch_size, num_queries_per_sample, dims)
    Keys: Encoder Hidden States: (batch_size, num_keys, dims)           ## (num_keys = num_steps)
    Values: Encoder Hidden States: (batch_size, num_keys, value_dims)   ## (num_keys = num_steps)
    Valid Lens: None, (batch_size, ) OR (batch_size, num_queries_per_sample)
    kwargs: forwarded to the dropout layer (e.g. training=...).

    Returns: (batch_size, num_queries_per_sample, value_dims). The attention
    weights are stored on self.attention_weights (and on the legacy,
    misspelt self.attention_weigts).
    """
    scores = tf.matmul(a=queries, b=keys, transpose_b=True)
    # Scale by sqrt(d); the cast follows the scores dtype instead of forcing float32.
    dims = tf.cast(tf.shape(queries)[-1], dtype=scores.dtype)
    scores = scores / tf.math.sqrt(dims)
    ## (batch_size, num_queries, num_keys)

    # Padded keys get a large negative logit so they receive ~zero weight;
    # the masked_softmax default (0) would leave them in the distribution.
    self.attention_weights = masked_softmax(X=scores, valid_lens=valid_lens, value=-1e6)
    # Legacy misspelt alias kept for existing callers.
    self.attention_weigts = self.attention_weights
    ## (batch_size, num_queries, num_keys)

    weights = self.dropout(self.attention_weights, **kwargs)
    ## (batch_size, num_queries, num_keys)

    ## (batch_size, num_queries, value_dims)
    return tf.matmul(weights, values)

In [None]:
# Toy batch: 2 samples, 1 query each, 10 key/value pairs per sample;
# the valid lengths of the two samples are 2 and 6.
queries = tf.random.normal((2, 1, 2))
keys = tf.random.normal((2, 10, 2))
values = tf.random.normal((2, 10, 4))
valid_lens = tf.constant([2, 6])

attention = DotProductAttention(dropout=0.5)
# training=False makes dropout a no-op; (2, 1, 10) weights @ (2, 10, 4) values -> (2, 1, 4).
attention(queries, keys, values, valid_lens=valid_lens, training=False).shape

In [14]:
class AdditiveAttention(tf.keras.layers.Layer):
  """Additive attention: score(q, k) = w_v^T tanh(W_q q + W_k k), then masked softmax."""

  def __init__(self, key_dims, query_dims, num_hiddens, dropout, **kwargs):
    super().__init__(**kwargs)
    # key_dims / query_dims are accepted for API compatibility only: the Dense
    # layers below are built lazily from the input width on their first call.
    self.dropout = tf.keras.layers.Dropout(dropout)
    self.W_k = tf.keras.layers.Dense(units=num_hiddens, use_bias=False)
    self.W_q = tf.keras.layers.Dense(units=num_hiddens, use_bias=False)
    self.w_v = tf.keras.layers.Dense(units=1, use_bias=False)

  def call(self, queries, keys, values, valid_lens=None, **kwargs):
    """
    Queries: Decoder Input: (batch_size, num_queries_per_sample, query_dims)
    Keys: Encoder Hidden States: (batch_size, num_keys, key_dims)         ## (num_keys = num_steps)
    Values: Encoder Hidden States: (batch_size, num_keys, value_dims)     ## (num_keys = num_steps)
    Valid Lens: None, (batch_size, ) OR (batch_size, num_queries_per_sample)
    kwargs: forwarded to the dropout layer (e.g. training=...).

    Returns: (batch_size, num_queries_per_sample, value_dims). The attention
    weights are stored on self.attention_weights (and on the legacy,
    misspelt self.attention_weigts).
    """
    queries = self.W_q(queries) # (batch_size, num_queries_per_sample, num_hiddens)
    keys = self.W_k(keys)       # (batch_size, num_keys, num_hiddens)

    # Broadcast-add so every query is paired with every key:
    # (batch_size, num_queries_per_sample, 1, num_hiddens)
    #   + (batch_size, 1, num_keys, num_hiddens)
    #   -> (batch_size, num_queries_per_sample, num_keys, num_hiddens)
    features = tf.expand_dims(input=queries, axis=2) + tf.expand_dims(input=keys, axis=1)
    features = tf.nn.tanh(features)
    scores = tf.squeeze(self.w_v(features), axis=-1)
    ## (batch_size, num_queries, num_keys)

    # Padded keys get a large negative logit so they receive ~zero weight;
    # the masked_softmax default (0) would leave them in the distribution.
    self.attention_weights = masked_softmax(X=scores, valid_lens=valid_lens, value=-1e6)
    # Legacy misspelt alias kept for existing callers.
    self.attention_weigts = self.attention_weights
    ## (batch_size, num_queries, num_keys)

    weights = self.dropout(self.attention_weights, **kwargs)
    ## (batch_size, num_queries, num_keys)

    ## (batch_size, num_queries, value_dims)
    return tf.matmul(weights, values)

In [None]:
# Additive attention allows queries (width 20) and keys (width 2) of different
# sizes; keys, values and valid_lens are reused from the previous cell.
queries = tf.random.normal((2, 1, 20))

attention = AdditiveAttention(
    key_dims=2, query_dims=20, num_hiddens=8, dropout=0.1)
attention(queries, keys, values, valid_lens=valid_lens, training=False).shape

In [3]:
class AttentionDecoder(utils.Decoder):
  """Interface for decoders that expose their attention weights.

  Subclasses must override the ``attention_weights`` property.
  """

  def __init__(self):
    super(AttentionDecoder, self).__init__()

  @property
  def attention_weights(self):
    """Abstract: concrete decoders return their recorded attention weights."""
    raise NotImplementedError()

In [None]:
class Seq2SeqAttentionDecoder(AttentionDecoder):
  """GRU decoder that attends over the encoder outputs at every decoding step.

  At each step the query is the last layer's decoder hidden state. Additive
  attention over the encoder outputs yields a context vector, which is
  concatenated with the current token embedding and fed through the stacked GRU.
  """

  def __init__(self, vocab_size, embed_size, num_hiddens, num_layers=2, dropout=0):
    super().__init__()
    self.attention = utils.AdditiveAttention(key_dims=num_hiddens, query_dims=num_hiddens, num_hiddens=num_hiddens, dropout=0)

    ## accepts (batch_size, num_steps) token ids
    ## gives out (batch_size, num_steps, embed_size)
    self.embedding = tf.keras.layers.Embedding(input_dim=vocab_size, output_dim=embed_size)

    ## Keras RNNs are batch-major (time_major is not set): accepts
    ## (batch_size, steps, features) and, with return_state=True, returns the
    ## (batch_size, steps, num_hiddens) sequence followed by the per-layer states.
    self.rnn = tf.keras.layers.RNN(tf.keras.layers.StackedRNNCells([
      tf.keras.layers.GRUCell(num_hiddens, dropout=dropout) for _ in range(num_layers)
    ]), return_sequences=True, return_state=True)
    self.dense = tf.keras.layers.Dense(units=vocab_size)

    # Filled by call(); initialised here so the property is usable before the first call.
    self._attention_weights = []

  def init_state(self, enc_all_outputs, *args):
    """Build the decoder state (encoder_outputs, hidden_state, valid_lens).

    Accepts either init_state(encoder_outputs, hidden_state, valid_lens) or
    init_state((encoder_outputs, hidden_state), valid_lens). The nested form
    is flattened, because call() unpacks exactly three items.
    """
    if isinstance(enc_all_outputs, (tuple, list)):
      return (*enc_all_outputs, *args)
    return (enc_all_outputs, *args)

  def call(self, X, state, **kwargs):
    """Decode a target batch step by step with teacher forcing.

    X: token ids, shape (batch_size, num_steps); num_steps must be static.
    state: (encoder_outputs, hidden_state, valid_lens), as built by init_state.
    kwargs: forwarded to the attention and RNN layers (e.g. training=...).
    Output: logits (batch_size, num_steps, vocab_size) and the updated state
      [encoder_outputs, hidden_state, valid_lens].
    """
    encoder_outputs, hidden_state, valid_lens = state

    # NOTE(review): assumes encoder_outputs is time-major
    # (num_enc_steps, batch_size, num_hiddens), as the previous implementation
    # implied by using encoder_outputs[-1] as the (batch_size, num_hiddens)
    # last step. Confirm this against utils.Encoder. Attention needs batch-major keys.
    enc_keys = tf.transpose(encoder_outputs, perm=[1, 0, 2])
    # (batch_size, num_enc_steps, num_hiddens)

    embeddings = self.embedding(X)
    # (batch_size, num_steps, embed_size)

    step_outputs, self._attention_weights = [], []
    for t in range(X.shape[1]):
      # Query = last layer's hidden state. Assumes hidden_state is a sequence of
      # per-layer (batch_size, num_hiddens) tensors -> (batch_size, 1, num_hiddens).
      query = tf.expand_dims(hidden_state[-1], axis=1)
      context = self.attention(query, enc_keys, enc_keys, valid_lens, **kwargs)
      # (batch_size, 1, num_hiddens)

      step_input = tf.concat([context, embeddings[:, t:t + 1, :]], axis=-1)
      # (batch_size, 1, num_hiddens + embed_size)

      rnn_out = self.rnn(step_input, initial_state=hidden_state, **kwargs)
      hidden_state = list(rnn_out[1:])
      step_outputs.append(rnn_out[0])  # (batch_size, 1, num_hiddens)

      # Attention layers in this project may expose their weights under the
      # correct name or under the legacy misspelt one.
      weights = getattr(self.attention, "attention_weights", None)
      if weights is None:
        weights = getattr(self.attention, "attention_weigts", None)
      self._attention_weights.append(weights)

    outputs = self.dense(tf.concat(step_outputs, axis=1))
    # (batch_size, num_steps, vocab_size)

    # valid_lens is carried along so the returned state can be passed to call() again.
    return outputs, [encoder_outputs, hidden_state, valid_lens]

  @property
  def attention_weights(self):
    """List of per-step attention weights recorded by the most recent call()."""
    return self._attention_weights