# GPT-2 
---

In [1]:
import tensorflow as tf

In [2]:
# Switch TF to eager mode so the tensors below are evaluated immediately;
# the later cells rely on this when they call `.numpy()`.
tf.enable_eager_execution()

---
## Attention mask

In [3]:
def attention_mask(nd, ns, *, dtype):
    """
    Build an [nd, ns] mask with 1's in the lower triangle, anchored at the lower right corner.

    Entry (row, col) is 1 where row >= col - ns + nd and 0 elsewhere; the result is
    cast to `dtype`.

    For nd <= ns this matches tf.matrix_band_part(tf.ones([nd, ns]), -1, ns-nd),
    but doesn't produce garbage on TPUs.
    """
    rows = tf.range(nd)[:, tf.newaxis]  # column of row indices, shape [nd, 1]
    cols = tf.range(ns)                 # column indices, shape [ns]
    # Broadcasting the comparison yields a boolean [nd, ns] grid.
    lower_right = tf.greater_equal(rows, cols - ns + nd)
    return tf.cast(lower_right, dtype)

`nd` is the number of rows (vertical/outer dimension).  
`ns` is the number of columns (horizontal/inner dimension).

In [4]:
# nd < ns: the triangle is anchored at the bottom-right corner, so the first
# row already holds ns - nd + 1 = 2 ones.
attention_mask(4, 5, dtype=tf.float32)

<tf.Tensor: id=18, shape=(4, 5), dtype=float32, numpy=
array([[1., 1., 0., 0., 0.],
       [1., 1., 1., 0., 0.],
       [1., 1., 1., 1., 0.],
       [1., 1., 1., 1., 1.]], dtype=float32)>

In [5]:
# nd > ns: the first nd - ns = 9 rows are all zeros.
attention_mask(14, 5, dtype=tf.float32)

<tf.Tensor: id=37, shape=(14, 5), dtype=float32, numpy=
array([[0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [1., 0., 0., 0., 0.],
       [1., 1., 0., 0., 0.],
       [1., 1., 1., 0., 0.],
       [1., 1., 1., 1., 0.],
       [1., 1., 1., 1., 1.]], dtype=float32)>

In [6]:
# nd < ns again: the first row holds ns - nd + 1 = 6 ones.
attention_mask(4, 9, dtype=tf.float32)

<tf.Tensor: id=56, shape=(4, 9), dtype=float32, numpy=
array([[1., 1., 1., 1., 1., 1., 0., 0., 0.],
       [1., 1., 1., 1., 1., 1., 1., 0., 0.],
       [1., 1., 1., 1., 1., 1., 1., 1., 0.],
       [1., 1., 1., 1., 1., 1., 1., 1., 1.]], dtype=float32)>

In [7]:
# nd == ns: a plain lower-triangular matrix (diagonal included).
attention_mask(9, 9, dtype=tf.float32)

<tf.Tensor: id=75, shape=(9, 9), dtype=float32, numpy=
array([[1., 0., 0., 0., 0., 0., 0., 0., 0.],
       [1., 1., 0., 0., 0., 0., 0., 0., 0.],
       [1., 1., 1., 0., 0., 0., 0., 0., 0.],
       [1., 1., 1., 1., 0., 0., 0., 0., 0.],
       [1., 1., 1., 1., 1., 0., 0., 0., 0.],
       [1., 1., 1., 1., 1., 1., 0., 0., 0.],
       [1., 1., 1., 1., 1., 1., 1., 0., 0.],
       [1., 1., 1., 1., 1., 1., 1., 1., 0.],
       [1., 1., 1., 1., 1., 1., 1., 1., 1.]], dtype=float32)>

--- 

### Tech deets

It is possible to use `None` in a slice as an equivalent of [tf.newaxis](https://www.tensorflow.org/api_docs/python/tf/Tensor#__getitem__) (scroll down a little on that page); see also [this Stack Overflow answer](https://stackoverflow.com/a/44787333).

In [34]:
nd = 5 # number of rows (vertical/outer dimension)
print(tf.range(nd).numpy()) # 1-D tensor holding 0 .. nd-1

[0 1 2 3 4]


In [35]:
i = tf.range(nd)[:, None] # None acts like tf.newaxis here
print(i.numpy()) # i is now a column: a new innermost (trailing) axis of size 1 was added

[[0]
 [1]
 [2]
 [3]
 [4]]


In [36]:
i2 = tf.range(nd)[None, :] # None first: the new axis becomes the outermost one
print(i2.numpy()) # a single row, i.e. shape (1, nd)

[[0 1 2 3 4]]


In [32]:
ns = 8 # number of columns (horizontal/inner dimension)
j = tf.range(ns) # 1-D tensor holding 0 .. ns-1
print(j.numpy())

[0 1 2 3 4 5 6 7]


'Shift' `j` to the left so that its rightmost element equals the largest value in `i`, namely `nd - 1`: `j - ns` moves every entry into the negative range (ending at `-1`), and adding `nd` moves them back up so that, here, the last `nd` entries are `0, …, nd - 1`.

In [37]:
print(i.numpy()) # row indices as a column
print((j - ns).numpy()) # every entry shifted below zero: -ns .. -1
print((j - ns + nd).numpy()) # shifted back up by nd: the last entry is nd - 1

[[0]
 [1]
 [2]
 [3]
 [4]]
[-8 -7 -6 -5 -4 -3 -2 -1]
[-3 -2 -1  0  1  2  3  4]


In [27]:
# Broadcast comparison of the column `i` against the shifted row `j`.
i >= j - ns + nd

<tf.Tensor: id=296, shape=(5, 8), dtype=bool, numpy=
array([[ True,  True,  True,  True, False, False, False, False],
       [ True,  True,  True,  True,  True, False, False, False],
       [ True,  True,  True,  True,  True,  True, False, False],
       [ True,  True,  True,  True,  True,  True,  True, False],
       [ True,  True,  True,  True,  True,  True,  True,  True]])>

Clever use of the comparison operator: it relies on NumPy-style broadcasting, so the 'vertical' vector `i` (shape `nd × 1`) is compared elementwise with the 'horizontal' vector `j` (length `ns`), giving an `nd × ns` matrix of `True`/`False` values that can then be cast to 1s and 0s in the last step.

In [38]:
m = (i >= j - ns + nd) # boolean mask; attention_mask casts this to the requested dtype
m

<tf.Tensor: id=366, shape=(5, 8), dtype=bool, numpy=
array([[ True,  True,  True,  True, False, False, False, False],
       [ True,  True,  True,  True,  True, False, False, False],
       [ True,  True,  True,  True,  True,  True, False, False],
       [ True,  True,  True,  True,  True,  True,  True, False],
       [ True,  True,  True,  True,  True,  True,  True,  True]])>

---
## In TF

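The [tf function](https://www.tensorflow.org/api_docs/python/tf/linalg/band_part) mentioned in the `attention_mask` docstring (`tf.matrix_band_part`, documented under the name `tf.linalg.band_part`) is described as "Copy a tensor setting everything outside a central band in each innermost matrix to zero":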

In [53]:
nd, ns = 4, 8
ones = tf.ones([nd, ns])
# num_lower = -1 keeps the whole lower triangle; num_upper = ns - nd = 4 keeps 4 super-diagonals.
tf.matrix_band_part(ones, -1, ns-nd)

<tf.Tensor: id=461, shape=(4, 8), dtype=float32, numpy=
array([[1., 1., 1., 1., 1., 0., 0., 0.],
       [1., 1., 1., 1., 1., 1., 0., 0.],
       [1., 1., 1., 1., 1., 1., 1., 0.],
       [1., 1., 1., 1., 1., 1., 1., 1.]], dtype=float32)>

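Not in fact identical to the implemented function if nd > ns! In that case `ns - nd` is negative, and a negative `num_upper` tells `matrix_band_part` to keep the entire upper triangle, so nothing is masked: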

In [60]:
nd, ns = 8, 4
ones = tf.ones([nd, ns])
# Here ns - nd = -4 is negative, which makes matrix_band_part keep the entire upper triangle too.
tf.matrix_band_part(ones, -1, ns-nd)

<tf.Tensor: id=534, shape=(8, 4), dtype=float32, numpy=
array([[1., 1., 1., 1.],
       [1., 1., 1., 1.],
       [1., 1., 1., 1.],
       [1., 1., 1., 1.],
       [1., 1., 1., 1.],
       [1., 1., 1., 1.],
       [1., 1., 1., 1.],
       [1., 1., 1., 1.]], dtype=float32)>

Examples from the tf doc:

In [55]:
# Arguments are (input, num_lower, num_upper); a negative count keeps that whole triangle.
ones = tf.ones([5,5])
print(tf.matrix_band_part(ones, 0, -1)) # ==> Upper triangular part.
print(tf.matrix_band_part(ones, -1, 0)) # ==> Lower triangular part
print(tf.matrix_band_part(ones, 0, 0)) # ==> Diagonal.

tf.Tensor(
[[1. 1. 1. 1. 1.]
 [0. 1. 1. 1. 1.]
 [0. 0. 1. 1. 1.]
 [0. 0. 0. 1. 1.]
 [0. 0. 0. 0. 1.]], shape=(5, 5), dtype=float32)
tf.Tensor(
[[1. 0. 0. 0. 0.]
 [1. 1. 0. 0. 0.]
 [1. 1. 1. 0. 0.]
 [1. 1. 1. 1. 0.]
 [1. 1. 1. 1. 1.]], shape=(5, 5), dtype=float32)
tf.Tensor(
[[1. 0. 0. 0. 0.]
 [0. 1. 0. 0. 0.]
 [0. 0. 1. 0. 0.]
 [0. 0. 0. 1. 0.]
 [0. 0. 0. 0. 1.]], shape=(5, 5), dtype=float32)


In [46]:
# Keep 3 sub-diagonals and 3 super-diagonals: only the two far corners are zeroed.
tf.matrix_band_part(ones,3,3)

<tf.Tensor: id=423, shape=(5, 5), dtype=float32, numpy=
array([[1., 1., 1., 1., 0.],
       [1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1.],
       [0., 1., 1., 1., 1.]], dtype=float32)>

In [46]:
# NOTE(review): same call as the previous cell (duplicate) -- 3 diagonals kept on each side.
tf.matrix_band_part(ones,3,3)

<tf.Tensor: id=423, shape=(5, 5), dtype=float32, numpy=
array([[1., 1., 1., 1., 0.],
       [1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1.],
       [0., 1., 1., 1., 1.]], dtype=float32)>

In [47]:
# Keep 2 diagonals on each side of the main diagonal.
tf.matrix_band_part(ones,2,2)

<tf.Tensor: id=427, shape=(5, 5), dtype=float32, numpy=
array([[1., 1., 1., 0., 0.],
       [1., 1., 1., 1., 0.],
       [1., 1., 1., 1., 1.],
       [0., 1., 1., 1., 1.],
       [0., 0., 1., 1., 1.]], dtype=float32)>

In [48]:
# Keep 1 diagonal on each side of the main diagonal: a tridiagonal matrix.
tf.matrix_band_part(ones,1,1)

<tf.Tensor: id=431, shape=(5, 5), dtype=float32, numpy=
array([[1., 1., 0., 0., 0.],
       [1., 1., 1., 0., 0.],
       [0., 1., 1., 1., 0.],
       [0., 0., 1., 1., 1.],
       [0., 0., 0., 1., 1.]], dtype=float32)>

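Interestingly, not quite the desired result either if nd > ns: the band follows the main diagonal starting at the top-left corner, whereas `attention_mask` anchors its triangle at the bottom-right corner.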

In [56]:
# Same three calls on a tall (8 x 5) matrix: the bands follow the main diagonal from the top-left corner.
ones = tf.ones([8,5])
print(tf.matrix_band_part(ones, 0, -1)) # ==> Upper triangular part.
print(tf.matrix_band_part(ones, -1, 0)) # ==> Lower triangular part
print(tf.matrix_band_part(ones, 0, 0)) # ==> Diagonal.

tf.Tensor(
[[1. 1. 1. 1. 1.]
 [0. 1. 1. 1. 1.]
 [0. 0. 1. 1. 1.]
 [0. 0. 0. 1. 1.]
 [0. 0. 0. 0. 1.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]], shape=(8, 5), dtype=float32)
tf.Tensor(
[[1. 0. 0. 0. 0.]
 [1. 1. 0. 0. 0.]
 [1. 1. 1. 0. 0.]
 [1. 1. 1. 1. 0.]
 [1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1.]], shape=(8, 5), dtype=float32)
tf.Tensor(
[[1. 0. 0. 0. 0.]
 [0. 1. 0. 0. 0.]
 [0. 0. 1. 0. 0.]
 [0. 0. 0. 1. 0.]
 [0. 0. 0. 0. 1.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]], shape=(8, 5), dtype=float32)
