# get_causal_attention_mask()

In [1]:
import tensorflow as tf

In [2]:
# Toy sizes for the step-by-step walkthrough: a 10-token sequence, 4 examples per batch.
seq_length, batch_size = 10, 4

In [3]:
# Row positions as a column vector (seq_length, 1) and column positions as a
# row vector (seq_length,); comparing the two broadcasts to (seq_length, seq_length).
i = tf.expand_dims(tf.range(seq_length), axis=-1)
j = tf.range(seq_length)
print(i.shape)
print(j.shape)

# Entry [r, c] is 1 when c <= r: a lower-triangular 0/1 matrix (diagonal included),
# i.e. row r keeps only columns 0..r.
mask = tf.cast(tf.greater_equal(i, j), dtype="int32")
print(mask)

# Prepend a size-1 leading axis -> (1, seq_length, seq_length), ready for tf.tile below.
mask = tf.reshape(mask, (1, seq_length, seq_length))
print(mask.shape)


(10, 1)
(10,)
tf.Tensor(
[[1 0 0 0 0 0 0 0 0 0]
 [1 1 0 0 0 0 0 0 0 0]
 [1 1 1 0 0 0 0 0 0 0]
 [1 1 1 1 0 0 0 0 0 0]
 [1 1 1 1 1 0 0 0 0 0]
 [1 1 1 1 1 1 0 0 0 0]
 [1 1 1 1 1 1 1 0 0 0]
 [1 1 1 1 1 1 1 1 0 0]
 [1 1 1 1 1 1 1 1 1 0]
 [1 1 1 1 1 1 1 1 1 1]], shape=(10, 10), dtype=int32)
(1, 10, 10)


In [4]:
# batch_size is a plain Python int; expand_dims turns it into a rank-1 tensor of
# shape (1,) so it can be concatenated with other 1-D tensors.
tf.expand_dims(batch_size, axis=-1)

<tf.Tensor: shape=(1,), dtype=int32, numpy=array([4], dtype=int32)>

In [5]:
# Tiling multiples [batch_size, 1, 1]: repeat along the leading axis only.
mult = tf.concat(
    [
        tf.expand_dims(batch_size, axis=-1),  # [batch_size]
        tf.constant([1, 1], dtype=tf.int32),  # [1, 1]
    ],
    axis=0,
)

In [6]:
# Inspect the tiling multiples as a NumPy array: [batch_size, 1, 1].
mult.numpy()

array([4, 1, 1], dtype=int32)

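`tf.tile()`: constructs a tensor by tiling a given tensor.

This operation creates a new tensor by replicating `input` `multiples` times. The output tensor's i-th dimension has `input.dims(i) * multiples[i]` elements, and the values of `input` are replicated `multiples[i]` times along the i-th dimension. For example, tiling `[a b c d]` by `[2]` produces `[a b c d a b c d]`.

Here `mult = [batch_size, 1, 1] = [4, 1, 1]`. The `(1, 10, 10)` mask is therefore copied 4 times along the first axis and left unchanged along the other two, giving shape `(4, 10, 10)`.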

In [7]:
# Copy the (1, seq_length, seq_length) mask batch_size times along axis 0.
tf.tile(input=mask, multiples=mult)

<tf.Tensor: shape=(4, 10, 10), dtype=int32, numpy=
array([[[1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]],

       [[1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]],

       [[1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 0, 0

### 원본 Code (original implementation, from the Keras Transformer example listed under References)

In [8]:
def get_causal_attention_mask(inputs):
    """Build a batched lower-triangular (causal) 0/1 mask.

    Args:
        inputs: Tensor (or Keras symbolic tensor) of rank >= 2 whose first two
            dimensions are (batch_size, sequence_length). Trailing dimensions
            are not used.

    Returns:
        int32 tensor of shape (batch_size, sequence_length, sequence_length).
        Entry [b, r, c] is 1 if c <= r and 0 otherwise; it is the same for
        every batch element b.
    """
    # Read sizes dynamically with tf.shape so this also works when the batch
    # dimension is unknown (None) at graph-construction time.
    input_shape = tf.shape(inputs)
    batch_size, sequence_length = input_shape[0], input_shape[1]
    # Column vector (L, 1) vs. row vector (L,): the comparison broadcasts to (L, L).
    i = tf.range(sequence_length)[:, tf.newaxis]
    j = tf.range(sequence_length)
    mask = tf.cast(i >= j, dtype="int32")
    # Prepend a size-1 leading axis. Reuse the already-extracted sequence_length
    # instead of slicing input_shape again.
    mask = tf.reshape(mask, (1, sequence_length, sequence_length))
    # Tiling multiples [batch_size, 1, 1]: copy along the leading axis only.
    mult = tf.concat(
        [tf.expand_dims(batch_size, -1), tf.constant([1, 1], dtype=tf.int32)],
        axis=0,
    )
    return tf.tile(mask, mult)

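Keras always prepends a batch dimension to the shape you declare. Even if you pass `input_shape=(50, 50, 3)`, Keras messages and the model summary will show `(None, 50, 50, 3)`. The leading `None` is the batch size, which stays unknown until data is fed in. This is why `inputs` below has shape `(None, 10, 64)` and the causal mask has shape `(None, 10, 10)`.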

See https://stackoverflow.com/a/44748370

In [9]:
# Declare only the per-example shape (seq_len, emb_dim); Keras adds the batch axis as None.
seq_len, emb_dim = 10, 64
inputs = tf.keras.layers.Input(shape=(seq_len, emb_dim))
print(inputs.shape)

(None, 10, 64)


In [10]:
# On a symbolic Keras input the batch axis stays None, while the two sequence
# axes are static (10, 10), as the printed KerasTensor below shows.
causal_mask = get_causal_attention_mask(inputs)
print(causal_mask)

KerasTensor(type_spec=TensorSpec(shape=(None, 10, 10), dtype=tf.int32, name=None), name='tf.tile/Tile:0', description="created by layer 'tf.tile'")


### References

* https://keras.io/examples/nlp/neural_machine_translation_with_transformer/