# GPT-2 
---
## TF Transpose, Stack, Unstack, Split

Documentation [here](https://www.tensorflow.org/api_docs/python/tf/transpose).

In [4]:
import tensorflow as tf
import numpy as np

In [5]:
tf.enable_eager_execution()  # TF 1.x API; eager execution is the default in TF 2.x

---

Transposition swaps the dimensions of a tensor: a matrix of shape `[5,2]` (5 rows of 2 numbers) becomes one of shape `[2,5]` (2 rows of 5 numbers). The `perm` parameter specifies where each dimension ends up: output dimension `i` is input dimension `perm[i]`. A `perm` of `[1,0]` thus means "put dimension 1 (the second one) first, and dimension 0 second", i.e. swap them. Without `perm`, `tf.transpose` reverses the order of the dimensions. The parameter becomes much more useful when more dimensions are involved.
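A minimal sketch of the `perm` rule, using NumPy as a stand-in (assumption: the `axes` argument of `np.transpose` follows the same convention as `perm` in `tf.transpose`, and both reverse the dimensions by default). The hypothetical helper `transposed_shape` computes the expected output shape directly from `perm`:

```python
import numpy as np

def transposed_shape(shape, perm):
    """Shape after transposing: output dim i takes the size of input dim perm[i]."""
    return tuple(shape[p] for p in perm)

x = np.arange(10).reshape(5, 2)
# Without perm, the dimensions are reversed (like tf.transpose's default)
assert np.transpose(x).shape == (2, 5)
assert np.array_equal(np.transpose(x), np.transpose(x, (1, 0)))

x3 = np.zeros((5, 2, 3))
for perm in [(0, 2, 1), (1, 2, 0), (1, 0, 2), (2, 0, 1)]:
    assert np.transpose(x3, perm).shape == transposed_shape(x3.shape, perm)
    print(perm, np.transpose(x3, perm).shape)
```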

In [3]:
x = tf.reshape(tf.constant(np.arange(10)), [5,2])
print(x.numpy(), end='\n\n')
print(tf.transpose(x).numpy(), end='\n\n')
# print(tf.transpose(x, perm=[1,0]).numpy(), end='\n\n') # identical to the one above

[[0 1]
 [2 3]
 [4 5]
 [6 7]
 [8 9]]

[[0 2 4 6 8]
 [1 3 5 7 9]]



---

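A variable of shape `[5, 2, 3]`, i.e. five groups of two rows of three numbers (its random initial values are scaled by 20 and cast to integers to keep the printout readable):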

In [4]:
x = tf.cast(
            20*tf.get_variable('x',[5, 2, 3], 
                            dtype=tf.float32), 
            tf.int32) 

In [5]:
print(x.numpy())

[[[-8 -1  2]
  [ 0  7 -2]]

 [[ 0  0  4]
  [-1 -4 -5]]

 [[-8  2 -3]
  [ 1  9  9]]

 [[ 8  7  8]
  [-4 -2  4]]

 [[ 8  3  6]
  [ 4  6 -5]]]


Transposing the last two dimensions (`perm=[0, 2, 1]`) transposes each inner matrix: the five `[2,3]` matrices become five `[3,2]` matrices (rows become columns).

In [6]:
print(tf.transpose(x, perm=[0, 2, 1]).numpy())

[[[-8  0]
  [-1  7]
  [ 2 -2]]

 [[ 0 -1]
  [ 0 -4]
  [ 4 -5]]

 [[-8  1]
  [ 2  9]
  [-3  9]]

 [[ 8 -4]
  [ 7 -2]
  [ 8  4]]

 [[ 8  4]
  [ 3  6]
  [ 6 -5]]]


Another one: with `perm=[1,2,0]` the result has shape `[2,3,5]`, i.e. 2 (original dimension 1) groups of 3 (original dimension 2) groups of 5 (original dimension 0) numbers.
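At the level of single elements, `perm=[1,2,0]` means `y[a, b, c] == x[c, a, b]`: the index for output dimension `i` is used on input dimension `perm[i]`. A NumPy check of this, on a stand-in array with the same `[5,2,3]` shape but deterministic values (assuming `np.transpose` matches `tf.transpose` here):

```python
import numpy as np

x = np.arange(30).reshape(5, 2, 3)      # stand-in for the [5, 2, 3] variable
perm = (1, 2, 0)
y = np.transpose(x, perm)               # shape (2, 3, 5)

# Output index (a, b, c): the i-th output index goes to input dimension perm[i]
for a in range(y.shape[0]):
    for b in range(y.shape[1]):
        for c in range(y.shape[2]):
            src = [0, 0, 0]
            for i, p in enumerate(perm):
                src[p] = (a, b, c)[i]
            assert y[a, b, c] == x[tuple(src)]

assert y[1, 2, 4] == x[4, 1, 2]         # i.e. y[a, b, c] == x[c, a, b]
print(y.shape)
```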

In [8]:
print(tf.transpose(x, perm=[1,2,0]).numpy()) 

[[[-8  0 -8  8  8]
  [-1  0  2  7  3]
  [ 2  4 -3  8  6]]

 [[ 0 -1  1 -4  4]
  [ 7 -4  9 -2  6]
  [-2 -5  9  4 -5]]]


`perm=[1,0,2]` differs from `[1,2,0]` only by swapping its last two entries, so its result (shape `[2,5,3]`) is the previous one with each inner matrix transposed.

In [10]:
print(tf.transpose(x, perm=[1,0,2]).numpy()) 

[[[-8 -1  2]
  [ 0  0  4]
  [-8  2 -3]
  [ 8  7  8]
  [ 8  3  6]]

 [[ 0  7 -2]
  [-1 -4 -5]
  [ 1  9  9]
  [-4 -2  4]
  [ 4  6 -5]]]
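The relationship between the `[1,2,0]` and `[1,0,2]` results can be checked directly: swapping the last two axes of the first gives the second. More generally, transposing by `p` and then by `q` amounts to a single transpose by `[p[i] for i in q]`. A NumPy sketch with a stand-in array (assuming the same semantics as `tf.transpose`):

```python
import numpy as np

x = np.arange(30).reshape(5, 2, 3)
a = np.transpose(x, (1, 2, 0))          # shape (2, 3, 5)
b = np.transpose(x, (1, 0, 2))          # shape (2, 5, 3)

# Swapping the last two axes of `a` yields `b`
assert np.array_equal(np.transpose(a, (0, 2, 1)), b)

# Composition: transpose by p, then by q == one transpose by [p[i] for i in q]
p, q = (1, 2, 0), (0, 2, 1)
assert np.array_equal(np.transpose(a, q), np.transpose(x, tuple(p[i] for i in q)))
print(a.shape, b.shape)
```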


In [11]:
print(tf.transpose(x).numpy())
# print(tf.transpose(x, perm=[2, 1, 0]).numpy()) # same

[[[-8  0 -8  8  8]
  [ 0 -1  1 -4  4]]

 [[-1  0  2  7  3]
  [ 7 -4  9 -2  6]]

 [[ 2  4 -3  8  6]
  [-2 -5  9  4 -5]]]


In [12]:
print(tf.transpose(x, perm=[2, 0, 1]).numpy())

[[[-8  0]
  [ 0 -1]
  [-8  1]
  [ 8 -4]
  [ 8  4]]

 [[-1  7]
  [ 0 -4]
  [ 2  9]
  [ 7 -2]
  [ 3  6]]

 [[ 2 -2]
  [ 4 -5]
  [-3  9]
  [ 8  4]
  [ 6 -5]]]


---

### TF Stack, Unstack

Packing and unpacking tensors.

Documentation [here](https://www.tensorflow.org/api_docs/python/tf/stack) and [here](https://www.tensorflow.org/api_docs/python/tf/unstack). For an input tensor of shape `(A, B, C, D)`, the `unstack` documentation says:

> If `axis == 0` then the i'th tensor in `output` is the slice `value[i, :, :, :]` and each tensor in `output` will have shape `(B, C, D)`. (Note that the dimension unpacked along is gone, unlike `split`).  
> If `axis == 1` then the i'th tensor in `output` is the slice `value[:, i, :, :]` and each tensor in `output` will have shape `(A, C, D)`. Etc.
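The shape rules quoted above, sketched in NumPy: `np.stack` is assumed to behave like `tf.stack`, and `unstack` below is a hypothetical helper emulating `tf.unstack` with `np.take`, which drops the axis it indexes:

```python
import numpy as np

def unstack(value, axis=0):
    """Hypothetical NumPy stand-in for tf.unstack: one slice per index
    along `axis`, with that axis removed from each slice."""
    return [np.take(value, i, axis=axis) for i in range(value.shape[axis])]

y1, y2, y3 = np.arange(0, 3), np.arange(3, 6), np.arange(6, 9)
rows = np.stack([y1, y2, y3])           # new axis 0: the inputs become rows
cols = np.stack([y1, y2, y3], axis=1)   # new axis 1: the inputs become columns

value = np.zeros((2, 3, 4, 5))          # shape (A, B, C, D)
assert all(t.shape == (3, 4, 5) for t in unstack(value, axis=0))  # (B, C, D)
assert all(t.shape == (2, 4, 5) for t in unstack(value, axis=1))  # (A, C, D)
assert len(unstack(value, axis=1)) == 3

# Unstacking along the axis used for stacking gives the inputs back
assert all(np.array_equal(u, v) for u, v in zip(unstack(cols, axis=1), [y1, y2, y3]))
print(cols)
```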

Simple case:

In [13]:
y1 = tf.constant(np.arange(0,3), dtype=tf.int32)
y2 = tf.constant(np.arange(3,6), dtype=tf.int32)
y3 = tf.constant(np.arange(6,9), dtype=tf.int32)
print(y1.numpy(), y2.numpy(), y3.numpy(), sep='\n')

[0 1 2]
[3 4 5]
[6 7 8]


In [20]:
y = tf.stack([y1,y2,y3])
print(y.numpy())

[[0 1 2]
 [3 4 5]
 [6 7 8]]


In [48]:
for unst in tf.unstack(y):
    print(unst.numpy())

[0 1 2]
[3 4 5]
[6 7 8]


In [50]:
y = tf.stack([y1,y2,y3], axis=1)
print(y.numpy())

[[0 3 6]
 [1 4 7]
 [2 5 8]]


In [51]:
for unst in tf.unstack(y):
    print(unst.numpy())

[0 3 6]
[1 4 7]
[2 5 8]


Easy build-up of new axes:

In [53]:
yy = tf.stack([y,y])
print(yy.numpy())

[[[0 3 6]
  [1 4 7]
  [2 5 8]]

 [[0 3 6]
  [1 4 7]
  [2 5 8]]]


Another example (try it with `axis=0`, `1` and `2`):

In [22]:
axis=2
cnst = tf.constant([[[0,1,2],[3,4,5]],[[0,1,2],[3,4,5]]])
print(cnst.shape)
print()
for unst in tf.unstack(cnst, axis=axis):
    print(unst.shape)
    print(unst.numpy())
    print()

(2, 2, 3)

(2, 2)
[[0 3]
 [0 3]]

(2, 2)
[[1 4]
 [1 4]]

(2, 2)
[[2 5]
 [2 5]]



---

In [41]:
z = tf.cast(tf.get_variable(name='x',shape=[2,3,4,5])*20, dtype=tf.int32)
print(z.numpy())

[[[[ 0  6  1 -4  0]
   [ 0 -4 -5 -1 -5]
   [ 6 -6 -5  0 -2]
   [ 6 -1  2 -6 -3]]

  [[-1  3  2 -2  4]
   [ 4 -1  3  4  5]
   [ 4 -5  4  0  0]
   [-3 -4 -3  3  3]]

  [[ 0  4 -3 -4  6]
   [-5  4 -5 -5  4]
   [ 5 -4  1 -5 -2]
   [-2 -4 -4 -5 -4]]]


 [[[ 4  0  2  4  5]
   [ 4 -5  1  4  1]
   [-5 -1 -1 -4  3]
   [-1  3  0 -6  2]]

  [[ 0  5  5 -4 -1]
   [ 1  2  1  1  3]
   [ 3 -6  1  6 -5]
   [ 6  5 -5 -3  2]]

  [[ 5 -4 -5 -4 -3]
   [ 6  5  3 -6  3]
   [ 2 -4  0  6  3]
   [-1  1 -5  2  1]]]]


In [42]:
for unst in tf.unstack(z): # unstacks along the outermost dimension -> two tensors of shape (3, 4, 5)
    print(unst.numpy(), end='\n----------\n')

[[[ 0  6  1 -4  0]
  [ 0 -4 -5 -1 -5]
  [ 6 -6 -5  0 -2]
  [ 6 -1  2 -6 -3]]

 [[-1  3  2 -2  4]
  [ 4 -1  3  4  5]
  [ 4 -5  4  0  0]
  [-3 -4 -3  3  3]]

 [[ 0  4 -3 -4  6]
  [-5  4 -5 -5  4]
  [ 5 -4  1 -5 -2]
  [-2 -4 -4 -5 -4]]]
----------
[[[ 4  0  2  4  5]
  [ 4 -5  1  4  1]
  [-5 -1 -1 -4  3]
  [-1  3  0 -6  2]]

 [[ 0  5  5 -4 -1]
  [ 1  2  1  1  3]
  [ 3 -6  1  6 -5]
  [ 6  5 -5 -3  2]]

 [[ 5 -4 -5 -4 -3]
  [ 6  5  3 -6  3]
  [ 2 -4  0  6  3]
  [-1  1 -5  2  1]]]
----------


In [43]:
for unst in tf.unstack(z, axis=1):
    print(unst.numpy(), end='\n----------\n')

[[[ 0  6  1 -4  0]
  [ 0 -4 -5 -1 -5]
  [ 6 -6 -5  0 -2]
  [ 6 -1  2 -6 -3]]

 [[ 4  0  2  4  5]
  [ 4 -5  1  4  1]
  [-5 -1 -1 -4  3]
  [-1  3  0 -6  2]]]
----------
[[[-1  3  2 -2  4]
  [ 4 -1  3  4  5]
  [ 4 -5  4  0  0]
  [-3 -4 -3  3  3]]

 [[ 0  5  5 -4 -1]
  [ 1  2  1  1  3]
  [ 3 -6  1  6 -5]
  [ 6  5 -5 -3  2]]]
----------
[[[ 0  4 -3 -4  6]
  [-5  4 -5 -5  4]
  [ 5 -4  1 -5 -2]
  [-2 -4 -4 -5 -4]]

 [[ 5 -4 -5 -4 -3]
  [ 6  5  3 -6  3]
  [ 2 -4  0  6  3]
  [-1  1 -5  2  1]]]
----------


In [46]:
print(z.numpy(), end='\n----------\n----------\n')
for unst in tf.unstack(z, axis=2):
    print(unst.numpy(), end='\n----------\n')

[[[[ 0  6  1 -4  0]
   [ 0 -4 -5 -1 -5]
   [ 6 -6 -5  0 -2]
   [ 6 -1  2 -6 -3]]

  [[-1  3  2 -2  4]
   [ 4 -1  3  4  5]
   [ 4 -5  4  0  0]
   [-3 -4 -3  3  3]]

  [[ 0  4 -3 -4  6]
   [-5  4 -5 -5  4]
   [ 5 -4  1 -5 -2]
   [-2 -4 -4 -5 -4]]]


 [[[ 4  0  2  4  5]
   [ 4 -5  1  4  1]
   [-5 -1 -1 -4  3]
   [-1  3  0 -6  2]]

  [[ 0  5  5 -4 -1]
   [ 1  2  1  1  3]
   [ 3 -6  1  6 -5]
   [ 6  5 -5 -3  2]]

  [[ 5 -4 -5 -4 -3]
   [ 6  5  3 -6  3]
   [ 2 -4  0  6  3]
   [-1  1 -5  2  1]]]]
----------
----------
[[[ 0  6  1 -4  0]
  [-1  3  2 -2  4]
  [ 0  4 -3 -4  6]]

 [[ 4  0  2  4  5]
  [ 0  5  5 -4 -1]
  [ 5 -4 -5 -4 -3]]]
----------
[[[ 0 -4 -5 -1 -5]
  [ 4 -1  3  4  5]
  [-5  4 -5 -5  4]]

 [[ 4 -5  1  4  1]
  [ 1  2  1  1  3]
  [ 6  5  3 -6  3]]]
----------
[[[ 6 -6 -5  0 -2]
  [ 4 -5  4  0  0]
  [ 5 -4  1 -5 -2]]

 [[-5 -1 -1 -4  3]
  [ 3 -6  1  6 -5]
  [ 2 -4  0  6  3]]]
----------
[[[ 6 -1  2 -6 -3]
  [-3 -4 -3  3  3]
  [-2 -4 -4 -5 -4]]

 [[-1  3  0 -6  2]
  [ 6  5 -5 -3  2]
  [-1  1 -5  2  1]]]
----------

In [47]:
print(z.numpy(), end='\n----------\n----------\n')
for unst in tf.unstack(z, axis=3):
    print(unst.numpy(), end='\n----------\n')

[[[[ 0  6  1 -4  0]
   [ 0 -4 -5 -1 -5]
   [ 6 -6 -5  0 -2]
   [ 6 -1  2 -6 -3]]

  [[-1  3  2 -2  4]
   [ 4 -1  3  4  5]
   [ 4 -5  4  0  0]
   [-3 -4 -3  3  3]]

  [[ 0  4 -3 -4  6]
   [-5  4 -5 -5  4]
   [ 5 -4  1 -5 -2]
   [-2 -4 -4 -5 -4]]]


 [[[ 4  0  2  4  5]
   [ 4 -5  1  4  1]
   [-5 -1 -1 -4  3]
   [-1  3  0 -6  2]]

  [[ 0  5  5 -4 -1]
   [ 1  2  1  1  3]
   [ 3 -6  1  6 -5]
   [ 6  5 -5 -3  2]]

  [[ 5 -4 -5 -4 -3]
   [ 6  5  3 -6  3]
   [ 2 -4  0  6  3]
   [-1  1 -5  2  1]]]]
----------
----------
[[[ 0  0  6  6]
  [-1  4  4 -3]
  [ 0 -5  5 -2]]

 [[ 4  4 -5 -1]
  [ 0  1  3  6]
  [ 5  6  2 -1]]]
----------
[[[ 6 -4 -6 -1]
  [ 3 -1 -5 -4]
  [ 4  4 -4 -4]]

 [[ 0 -5 -1  3]
  [ 5  2 -6  5]
  [-4  5 -4  1]]]
----------
[[[ 1 -5 -5  2]
  [ 2  3  4 -3]
  [-3 -5  1 -4]]

 [[ 2  1 -1  0]
  [ 5  1  1 -5]
  [-5  3  0 -5]]]
----------
[[[-4 -1  0 -6]
  [-2  4  0  3]
  [-4 -5 -5 -5]]

 [[ 4  4 -4 -6]
  [-4  1  6 -3]
  [-4 -6  6  2]]]
----------
[[[ 0 -5 -2 -3]
  [ 4  5  0  3]
  [ 6  4 -2 -4]]

 [[ 5  1  3  2]
  [-1  3 -5  2]
  [-3  3  3  1]]]
----------

---

### TF Concat

Documentation [here](https://www.tensorflow.org/api_docs/python/tf/concat).

The idea here is to mash tensors together within one dimension, instead of packing (stacking) them into a higher one.
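A quick NumPy comparison of the two operations (assuming `np.concatenate` and `np.stack` match `tf.concat` and `tf.stack` in terms of shapes): concatenation keeps the rank, stacking adds one dimension, and stacking can be seen as concatenating inputs that each received a new size-1 axis.

```python
import numpy as np

n1 = np.zeros((2, 3), dtype=int)
n2 = np.ones((2, 3), dtype=int)

# concat joins along an existing axis: the rank stays 2
assert np.concatenate([n1, n2], axis=0).shape == (4, 3)
assert np.concatenate([n1, n2], axis=1).shape == (2, 6)

# stack creates a new axis: the rank grows to 3
assert np.stack([n1, n2], axis=0).shape == (2, 2, 3)

# stacking == concatenating after inserting a size-1 axis in every input
assert np.array_equal(
    np.stack([n1, n2], axis=1),
    np.concatenate([n1[:, None, :], n2[:, None, :]], axis=1),
)
```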

In [21]:
y1 = tf.constant(np.arange(0,3), dtype=tf.int32)
y2 = tf.constant(np.arange(3,6), dtype=tf.int32)
y3 = tf.constant(np.arange(6,9), dtype=tf.int32)
print(y1.numpy(), y2.numpy(), y3.numpy(), sep='\n')

[0 1 2]
[3 4 5]
[6 7 8]


In [22]:
y = tf.concat([y1,y2,y3], axis=0)
print(y.numpy())

[0 1 2 3 4 5 6 7 8]


In [29]:
n1 = tf.cast(20*tf.get_variable('n', [2,3]), tf.int32)
n2 = tf.cast(20*tf.get_variable('n', [2,3]), tf.int32)
n3 = tf.cast(20*tf.get_variable('n', [2,3]), tf.int32)
n4 = tf.cast(20*tf.get_variable('n', [2,3]), tf.int32)
print(n1.numpy(), n2.numpy(), n3.numpy(), n4.numpy(), sep='\n\n')

[[-21  -8 -19]
 [ 18   2  10]]

[[  3  10  -2]
 [-17  13  -8]]

[[ 20  10   5]
 [-15  -5  -9]]

[[-21   0  -3]
 [-17  11   2]]


In [31]:
n = tf.concat([n1,n2,n3,n4], axis=0)
print(n.numpy())

[[-21  -8 -19]
 [ 18   2  10]
 [  3  10  -2]
 [-17  13  -8]
 [ 20  10   5]
 [-15  -5  -9]
 [-21   0  -3]
 [-17  11   2]]


In [32]:
n = tf.concat([n1,n2,n3,n4], axis=1)
print(n.numpy())

[[-21  -8 -19   3  10  -2  20  10   5 -21   0  -3]
 [ 18   2  10 -17  13  -8 -15  -5  -9 -17  11   2]]


In [35]:
m1 = tf.cast(20*tf.get_variable('n', [2,3,4]), tf.int32)
m2 = tf.cast(20*tf.get_variable('n', [2,3,4]), tf.int32)
print(m1.numpy(), m2.numpy(), sep='\n\n')

[[[  0   0 -12   7]
  [  1  -1  -9  -7]
  [ -8  11 -11  -6]]

 [[  0 -11   1  12]
  [  8  -7 -10  -7]
  [  9   0  -6  10]]]

[[[ -5  -6   4  -9]
  [ -2 -12  -1  -1]
  [-13  10  -2   3]]

 [[ -9  -9 -12   2]
  [-11   2  -6  -5]
  [ -4  -3   3   7]]]


In [39]:
print(tf.concat([m1,m2], axis=0).numpy()) # shape (4, 3, 4): the outer dimension now holds 2+2 = 4 tensors of shape [3,4]
print('-'*30)
print(tf.concat([m1,m2], axis=1).numpy()) # shape (2, 6, 4): two tensors of shape [3+3,4] == [6,4]
print('-'*30)
print(tf.concat([m1,m2], axis=2).numpy()) # shape (2, 3, 8): two tensors of shape [3,4+4] == [3,8]

[[[  0   0 -12   7]
  [  1  -1  -9  -7]
  [ -8  11 -11  -6]]

 [[  0 -11   1  12]
  [  8  -7 -10  -7]
  [  9   0  -6  10]]

 [[ -5  -6   4  -9]
  [ -2 -12  -1  -1]
  [-13  10  -2   3]]

 [[ -9  -9 -12   2]
  [-11   2  -6  -5]
  [ -4  -3   3   7]]]
------------------------------
[[[  0   0 -12   7]
  [  1  -1  -9  -7]
  [ -8  11 -11  -6]
  [ -5  -6   4  -9]
  [ -2 -12  -1  -1]
  [-13  10  -2   3]]

 [[  0 -11   1  12]
  [  8  -7 -10  -7]
  [  9   0  -6  10]
  [ -9  -9 -12   2]
  [-11   2  -6  -5]
  [ -4  -3   3   7]]]
------------------------------
[[[  0   0 -12   7  -5  -6   4  -9]
  [  1  -1  -9  -7  -2 -12  -1  -1]
  [ -8  11 -11  -6 -13  10  -2   3]]

 [[  0 -11   1  12  -9  -9 -12   2]
  [  8  -7 -10  -7 -11   2  -6  -5]
  [  9   0  -6  10  -4  -3   3   7]]]


---

### TF Split

Documentation [here](https://www.tensorflow.org/api_docs/python/tf/split).

The difference from `unstack` is that `split` preserves the rank: the split dimension is kept (with a smaller size) instead of being removed. As the [`unstack` docs](https://www.tensorflow.org/api_docs/python/tf/unstack) put it:
> (Note that the dimension unpacked along is gone, unlike split).
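Put differently, splitting into pieces of size 1 and then squeezing the split axis away gives exactly what `unstack` returns. A NumPy sketch, with `np.split` standing in for `tf.split` with an integer `num_or_size_splits` and `np.take` standing in for `unstack` (both substitutions are assumptions for illustration):

```python
import numpy as np

z = np.arange(2 * 3 * 4 * 5).reshape(2, 3, 4, 5)
axis = 1

pieces = np.split(z, z.shape[axis], axis=axis)                        # keeps the axis (size 1)
unstacked = [np.take(z, i, axis=axis) for i in range(z.shape[axis])]  # drops the axis

assert [p.shape for p in pieces] == [(2, 1, 4, 5)] * 3
assert [u.shape for u in unstacked] == [(2, 4, 5)] * 3
assert all(np.array_equal(np.squeeze(p, axis=axis), u) for p, u in zip(pieces, unstacked))
```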

In [24]:
z = tf.cast(
            20*tf.get_variable(name='x',
                               shape=[2,3,4,5]), 
            dtype=tf.int32)

print(z.shape)

(2, 3, 4, 5)


With `unstack` we get a list of unpacked tensors:

In [28]:
for unst in tf.unstack(z, axis=0):
    print(unst.shape)

(3, 4, 5)
(3, 4, 5)


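With an integer, `tf.split` cuts into that many equal parts, along axis 0 by default: here the outermost dimension (of size 2) is split in two. Unlike with `unstack`, that dimension is kept, now with size 1.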

In [49]:
for splt in tf.split(z, 2):
    print(splt.shape) 

(1, 3, 4, 5)
(1, 3, 4, 5)


With a list, `tf.split` uses it as the sizes of the pieces: here dimension 1 (of size 3) is split into a tensor of size 1 and one of size 2, following `[1,2]`.

In [50]:
for splt in tf.split(z, [1,2], 1):
    print(splt.shape) 

(2, 1, 4, 5)
(2, 2, 4, 5)


Split along dimension 2 (of size 4): first into pieces of sizes `[1,3]`, then into two equal parts.

In [51]:
for splt in tf.split(z, [1,3], 2):
    print(splt.shape)
print()
for splt in tf.split(z, num_or_size_splits=2, axis=2):
    print(splt.shape) 

(2, 3, 1, 5)
(2, 3, 3, 5)

(2, 3, 2, 5)
(2, 3, 2, 5)


In [52]:
for splt in tf.split(z, [2,3], 3):
    print(splt.shape)
print()
for splt in tf.split(z, num_or_size_splits=5, axis=3):
    print(splt.shape) 

(2, 3, 4, 2)
(2, 3, 4, 3)

(2, 3, 4, 1)
(2, 3, 4, 1)
(2, 3, 4, 1)
(2, 3, 4, 1)
(2, 3, 4, 1)
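A caveat when switching between TensorFlow and NumPy: in `tf.split` a list gives the *sizes* of the pieces, whereas in `np.split` a list gives the *indices at which to cut*. A sketch of the conversion, via a hypothetical helper `split_sizes` that reproduces the shapes above:

```python
import numpy as np

def split_sizes(value, sizes, axis=0):
    """Hypothetical helper mimicking tf.split with a list of sizes,
    by turning the sizes into the cut points that np.split expects."""
    if sum(sizes) != value.shape[axis]:
        raise ValueError("sizes must add up to the dimension being split")
    cut_points = np.cumsum(sizes)[:-1]   # e.g. [1, 3] -> [1]
    return np.split(value, cut_points, axis=axis)

z = np.zeros((2, 3, 4, 5))
print([p.shape for p in split_sizes(z, [1, 2], axis=1)])   # [(2, 1, 4, 5), (2, 2, 4, 5)]
print([p.shape for p in split_sizes(z, [1, 3], axis=2)])   # [(2, 3, 1, 5), (2, 3, 3, 5)]
print([p.shape for p in split_sizes(z, [2, 3], axis=3)])   # [(2, 3, 4, 2), (2, 3, 4, 3)]

# Passing [1, 3] straight to np.split cuts at indices 1 and 3 instead: sizes 1, 2, 1
print([p.shape[2] for p in np.split(z, [1, 3], axis=2)])   # [1, 2, 1]
```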


In [59]:
print(z.shape, end='\n---------\n')
print(z.numpy(), end='\n--------\n--------\n')
for splt in tf.split(z, [1,2], 1): # split the second-outermost dimension (size 3) into sizes 1 and 2
    print(splt.numpy(), end='\n--------\n')

(2, 3, 4, 5)
---------
[[[[ 0 -2  3  3  0]
   [-3  6  0 -1 -1]
   [ 0  0  0 -5  3]
   [-5 -1  3 -3 -2]]

  [[-3 -5 -4  0  1]
   [ 3 -5  1  5  0]
   [ 0  0  6 -1  0]
   [ 0 -5  0 -6 -5]]

  [[ 5 -4  0 -6  0]
   [ 3 -3 -1  0 -3]
   [ 1  2  4  4 -1]
   [ 3  5 -3 -4 -4]]]


 [[[-3  5  2  1 -2]
   [ 1  5  4  2  2]
   [-2 -1 -4  0 -4]
   [ 5  3  0  5 -4]]

  [[-2  4  1 -2 -4]
   [-5 -5 -5  3  3]
   [-3  6 -1  5 -5]
   [ 0  1 -6  0  2]]

  [[ 4  4 -2 -5 -4]
   [ 6  0 -2  3  6]
   [-4 -6  2 -3  4]
   [ 0 -4  6  0  1]]]]
--------
--------
[[[[ 0 -2  3  3  0]
   [-3  6  0 -1 -1]
   [ 0  0  0 -5  3]
   [-5 -1  3 -3 -2]]]


 [[[-3  5  2  1 -2]
   [ 1  5  4  2  2]
   [-2 -1 -4  0 -4]
   [ 5  3  0  5 -4]]]]
--------
[[[[-3 -5 -4  0  1]
   [ 3 -5  1  5  0]
   [ 0  0  6 -1  0]
   [ 0 -5  0 -6 -5]]

  [[ 5 -4  0 -6  0]
   [ 3 -3 -1  0 -3]
   [ 1  2  4  4 -1]
   [ 3  5 -3 -4 -4]]]


 [[[-2  4  1 -2 -4]
   [-5 -5 -5  3  3]
   [-3  6 -1  5 -5]
   [ 0  1 -6  0  2]]

  [[ 4  4 -2 -5 -4]
   [ 6  0 -2  3  6]
   [-4 -6  2 -3  4]
   [ 0 -4  6  0  1]]]]
--------


In [58]:
print(z.shape, end='\n---------\n')
print(z.numpy(), end='\n--------\n--------\n')
for splt in tf.split(z, num_or_size_splits=2, axis=2): # split dimension 2 (size 4) into two halves of size 2
    print(splt.numpy(), end='\n--------\n')

(2, 3, 4, 5)
---------
[[[[ 0 -2  3  3  0]
   [-3  6  0 -1 -1]
   [ 0  0  0 -5  3]
   [-5 -1  3 -3 -2]]

  [[-3 -5 -4  0  1]
   [ 3 -5  1  5  0]
   [ 0  0  6 -1  0]
   [ 0 -5  0 -6 -5]]

  [[ 5 -4  0 -6  0]
   [ 3 -3 -1  0 -3]
   [ 1  2  4  4 -1]
   [ 3  5 -3 -4 -4]]]


 [[[-3  5  2  1 -2]
   [ 1  5  4  2  2]
   [-2 -1 -4  0 -4]
   [ 5  3  0  5 -4]]

  [[-2  4  1 -2 -4]
   [-5 -5 -5  3  3]
   [-3  6 -1  5 -5]
   [ 0  1 -6  0  2]]

  [[ 4  4 -2 -5 -4]
   [ 6  0 -2  3  6]
   [-4 -6  2 -3  4]
   [ 0 -4  6  0  1]]]]
--------
--------
[[[[ 0 -2  3  3  0]
   [-3  6  0 -1 -1]]

  [[-3 -5 -4  0  1]
   [ 3 -5  1  5  0]]

  [[ 5 -4  0 -6  0]
   [ 3 -3 -1  0 -3]]]


 [[[-3  5  2  1 -2]
   [ 1  5  4  2  2]]

  [[-2  4  1 -2 -4]
   [-5 -5 -5  3  3]]

  [[ 4  4 -2 -5 -4]
   [ 6  0 -2  3  6]]]]
--------
[[[[ 0  0  0 -5  3]
   [-5 -1  3 -3 -2]]

  [[ 0  0  6 -1  0]
   [ 0 -5  0 -6 -5]]

  [[ 1  2  4  4 -1]
   [ 3  5 -3 -4 -4]]]


 [[[-2 -1 -4  0 -4]
   [ 5  3  0  5 -4]]

  [[-3  6 -1  5 -5]
   [ 0  1 -6  0  2]]

  [[-4 -6  2 -3  4]
   [ 0 -4  6  0  1]]]]
--------

In [56]:
print(z.shape, end='\n---------\n')
print(z.numpy(), end='\n--------\n--------\n')
for splt in tf.split(z, [2,3], 3): # split the innermost dimension (size 5) into sizes 2 and 3
    print(splt.numpy(), end='\n--------\n')

(2, 3, 4, 5)
---------
[[[[ 0 -2  3  3  0]
   [-3  6  0 -1 -1]
   [ 0  0  0 -5  3]
   [-5 -1  3 -3 -2]]

  [[-3 -5 -4  0  1]
   [ 3 -5  1  5  0]
   [ 0  0  6 -1  0]
   [ 0 -5  0 -6 -5]]

  [[ 5 -4  0 -6  0]
   [ 3 -3 -1  0 -3]
   [ 1  2  4  4 -1]
   [ 3  5 -3 -4 -4]]]


 [[[-3  5  2  1 -2]
   [ 1  5  4  2  2]
   [-2 -1 -4  0 -4]
   [ 5  3  0  5 -4]]

  [[-2  4  1 -2 -4]
   [-5 -5 -5  3  3]
   [-3  6 -1  5 -5]
   [ 0  1 -6  0  2]]

  [[ 4  4 -2 -5 -4]
   [ 6  0 -2  3  6]
   [-4 -6  2 -3  4]
   [ 0 -4  6  0  1]]]]
--------
--------
[[[[ 0 -2]
   [-3  6]
   [ 0  0]
   [-5 -1]]

  [[-3 -5]
   [ 3 -5]
   [ 0  0]
   [ 0 -5]]

  [[ 5 -4]
   [ 3 -3]
   [ 1  2]
   [ 3  5]]]


 [[[-3  5]
   [ 1  5]
   [-2 -1]
   [ 5  3]]

  [[-2  4]
   [-5 -5]
   [-3  6]
   [ 0  1]]

  [[ 4  4]
   [ 6  0]
   [-4 -6]
   [ 0 -4]]]]
--------
[[[[ 3  3  0]
   [ 0 -1 -1]
   [ 0 -5  3]
   [ 3 -3 -2]]

  [[-4  0  1]
   [ 1  5  0]
   [ 6 -1  0]
   [ 0 -6 -5]]

  [[ 0 -6  0]
   [-1  0 -3]
   [ 4  4 -1]
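   [-3 -4 -4]]]


 [[[ 2  1 -2]
   [ 4  2  2]
   [-4  0 -4]
   [ 0  5 -4]]

  [[ 1 -2 -4]
   [-5  3  3]
   [-1  5 -5]
   [-6  0  2]]

  [[-2 -5 -4]
   [-2  3  6]
   [ 2 -3  4]
   [ 6  0  1]]]]
--------
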
---

### TF Gather, Gather_nd

Documentation [here](https://www.tensorflow.org/api_docs/python/tf/gather) and [here](https://www.tensorflow.org/api_docs/python/tf/gather_nd).

Two blog posts with examples [here](https://riptutorial.com/tensorflow/example/29018/extract-non-contiguous-slices-from-the-first-dimension-of-a-tensor) and [here](https://riptutorial.com/tensorflow/example/29069/how-to-use-tf-gather-nd).

`tf.gather` selects slices along a single axis, by default `axis=0`, i.e. the outermost dimension: each index picks a whole slice (here, a row of the matrix). Another dimension can be targeted with the `axis` parameter.

In [4]:
gth = tf.constant(np.arange(15), shape=[3,5])
print(gth.numpy())

[[ 0  1  2  3  4]
 [ 5  6  7  8  9]
 [10 11 12 13 14]]


In [5]:
print(tf.gather(gth, [1]).numpy())

[[5 6 7 8 9]]


A scalar index also works; the gathered dimension then disappears from the result:

In [6]:
print(tf.gather(gth, 0).numpy()) 

[0 1 2 3 4]


In [7]:
print(tf.gather(gth, [1, 2]).numpy())

[[ 5  6  7  8  9]
 [10 11 12 13 14]]


Indices can be repeated:

In [8]:
print(tf.gather(gth, [0, 0, 0]).numpy()) 

[[0 1 2 3 4]
 [0 1 2 3 4]
 [0 1 2 3 4]]


The indices can themselves be nested; the result then has shape `indices.shape + params.shape[1:]` (for `axis=0`):

In [9]:
print(tf.gather(gth, [[0,0],[0,0]]).numpy()) 

[[[0 1 2 3 4]
  [0 1 2 3 4]]

 [[0 1 2 3 4]
  [0 1 2 3 4]]]


In [21]:
print(tf.gather(gth, [[[0,0],[0,0]],[[1,1],[2,2]]]).numpy()) 

[[[[ 0  1  2  3  4]
   [ 0  1  2  3  4]]

  [[ 0  1  2  3  4]
   [ 0  1  2  3  4]]]


 [[[ 5  6  7  8  9]
   [ 5  6  7  8  9]]

  [[10 11 12 13 14]
   [10 11 12 13 14]]]]
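For valid indices, `tf.gather` along the default axis behaves like NumPy's `np.take(params, indices, axis=0)` (treated as equivalent here for illustration). The sketch below checks the general shape rule `params.shape[:axis] + indices.shape + params.shape[axis+1:]` on the index sets used above; `gathered_shape` is a hypothetical helper:

```python
import numpy as np

gth = np.arange(15).reshape(3, 5)

def gathered_shape(params_shape, indices_shape, axis=0):
    """Expected output shape of a gather along `axis`."""
    return params_shape[:axis] + indices_shape + params_shape[axis + 1:]

for idx in [[1], 0, [1, 2], [0, 0, 0], [[0, 0], [0, 0]],
            [[[0, 0], [0, 0]], [[1, 1], [2, 2]]]]:
    out = np.take(gth, idx, axis=0)
    assert out.shape == gathered_shape(gth.shape, np.shape(idx))
    print(np.shape(idx), '->', out.shape)

# Gathering along axis 1 picks columns instead of rows
print(np.take(gth, [0, 4], axis=1))
```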


---

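Now `gather_nd`: here the last dimension of the indices is an index tuple into the leading dimensions of the tensor. A full-length tuple picks a single element, a shorter one picks a whole slice: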

In [101]:
gth2 = tf.constant(np.arange(24), shape=[2,3,4])
print(gth2.numpy())
print()
print(tf.gather_nd(gth2, [0,0,3]).numpy()) # identical to 
print(gth2[0][0][3].numpy())               # slicing!
print()
print(tf.gather_nd(gth2, [[0,0,3],[0,1,3],[0,2,3]]).numpy()) 
print()
print(tf.gather_nd(gth2, [0,0]).numpy()) # similar to gather
print()
print(tf.gather_nd(gth2, [[0,0],[0,1],[1,2],[1,2]]).numpy()) # retrieve several rows, with repetition
print()

[[[ 0  1  2  3]
  [ 4  5  6  7]
  [ 8  9 10 11]]

 [[12 13 14 15]
  [16 17 18 19]
  [20 21 22 23]]]

3
3

[ 3  7 11]

[0 1 2 3]

[[ 0  1  2  3]
 [ 4  5  6  7]
 [20 21 22 23]
 [20 21 22 23]]
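The examples above follow one rule: the last dimension of `indices` is an index tuple into the leading dimensions of `params`, and the other dimensions of `indices` are kept in the output. A NumPy sketch of that reading; the helper `gather_nd` is hypothetical and ignores TF features such as `batch_dims`:

```python
import numpy as np

def gather_nd(params, indices):
    """Hypothetical NumPy stand-in for tf.gather_nd (no batch_dims)."""
    params, indices = np.asarray(params), np.asarray(indices)
    # Move the coordinate axis to the front and use it for advanced indexing:
    # coordinate k indexes dimension k of params
    return params[tuple(np.moveaxis(indices, -1, 0))]

gth2 = np.arange(24).reshape(2, 3, 4)
print(gather_nd(gth2, [0, 0, 3]))                               # 3
print(gather_nd(gth2, [[0, 0, 3], [0, 1, 3], [0, 2, 3]]))       # [ 3  7 11]
print(gather_nd(gth2, [0, 0]))                                  # [0 1 2 3]
print(gather_nd(gth2, [[0, 0], [0, 1], [1, 2], [1, 2]]).shape)  # (4, 4)
```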



---

### TF Scatter_nd

Not used in this code base, but it is the inverse operation of `tf.gather_nd`. The documentation is [here](https://www.tensorflow.org/api_docs/python/tf/scatter_nd).

Nice pics on that page, couldn't resist...:

![ScatterNd1](ScatterNd1.png)

In [7]:
indices = tf.constant([[4], [3], [1], [7]]) # where the elements will end up
updates = tf.constant([9, 10, 11, 12]) # the elements to be inserted/scattered
shape = tf.constant([8]) # shape of final tensor, will be initialized at zero
scatter = tf.scatter_nd(indices, updates, shape)
print(scatter.numpy())

[ 0 11  0 10  9  0  0 12]


As mentioned in the docs, if the same index appears several times, the corresponding updates are summed.

In [16]:
indices = tf.constant([[0], [0], [1], [1]])
updates = tf.constant([9, 10, 11, 12]) 
shape = tf.constant([8]) 
scatter = tf.scatter_nd(indices, updates, shape)
print(scatter.numpy())

[19 23  0  0  0  0  0  0]
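A NumPy sketch of the scatter rule, reproducing the two results above; `scatter_nd` here is a hypothetical helper, not TF's implementation. `np.add.at` is used because, unlike `out[idx] += upd`, it accumulates updates whose indices repeat:

```python
import numpy as np

def scatter_nd(indices, updates, shape):
    """Hypothetical NumPy stand-in for tf.scatter_nd: zeros of `shape`,
    with each update added at the (partial) coordinate given by `indices`."""
    indices, updates = np.asarray(indices), np.asarray(updates)
    out = np.zeros(shape, dtype=updates.dtype)
    # The last axis of `indices` holds the coordinates; np.add.at sums duplicates
    np.add.at(out, tuple(np.moveaxis(indices, -1, 0)), updates)
    return out

print(scatter_nd([[4], [3], [1], [7]], [9, 10, 11, 12], [8]))  # [ 0 11  0 10  9  0  0 12]
print(scatter_nd([[0], [0], [1], [1]], [9, 10, 11, 12], [8]))  # [19 23  0  0  0  0  0  0]
```

Because the last axis of `indices` is used as a tuple of index arrays, the same helper also covers the 2-D and slice examples that follow.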


Scattering single elements into a 2-D tensor works the same way, with each index now a full `[row, column]` coordinate:

In [22]:
indices = tf.constant([[0,0], [0,1], [1,1], [1,3]])
updates = tf.constant([9, 10, 11, 12]) 
shape = tf.constant([2,8]) 
scatter = tf.scatter_nd(indices, updates, shape)
print(scatter.numpy())

[[ 9 10  0  0  0  0  0  0]
 [ 0 11  0 12  0  0  0  0]]


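Now scattering whole slices: each index has a single coordinate, so each element of `updates` is a complete `[4,4]` slice placed along the first dimension: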

![ScatterNd2](ScatterNd2.png)

In [15]:
indices = tf.constant([[0], [2]]) 
updates = tf.constant([[[5, 5, 5, 5], [6, 6, 6, 6],
                        [7, 7, 7, 7], [8, 8, 8, 8]],
                       [[5, 5, 5, 5], [6, 6, 6, 6],
                        [7, 7, 7, 7], [8, 8, 8, 8]]])
shape = tf.constant([4, 4, 4])
scatter = tf.scatter_nd(indices, updates, shape)
print(scatter.numpy())

[[[5 5 5 5]
  [6 6 6 6]
  [7 7 7 7]
  [8 8 8 8]]

 [[0 0 0 0]
  [0 0 0 0]
  [0 0 0 0]
  [0 0 0 0]]

 [[5 5 5 5]
  [6 6 6 6]
  [7 7 7 7]
  [8 8 8 8]]

 [[0 0 0 0]
  [0 0 0 0]
  [0 0 0 0]
  [0 0 0 0]]]


Duplicate indices sum slices too:

In [19]:
indices = tf.constant([[0], [0]]) 
updates = tf.constant([[[5, 5, 5, 5], [6, 6, 6, 6],
                        [7, 7, 7, 7], [8, 8, 8, 8]],
                       [[5, 5, 5, 5], [6, 6, 6, 6],
                        [7, 7, 7, 7], [8, 8, 8, 8]]])
shape = tf.constant([2, 4, 4])
scatter = tf.scatter_nd(indices, updates, shape)
print(scatter.numpy())

[[[10 10 10 10]
  [12 12 12 12]
  [14 14 14 14]
  [16 16 16 16]]

 [[ 0  0  0  0]
  [ 0  0  0  0]
  [ 0  0  0  0]
  [ 0  0  0  0]]]


The same works with slices, although it is trickier to think about. This time the first 2-D tensor of `updates` is distributed over the first rows of the 2-D tensors of `scatter` (indices `[i,0]`), while the second one is spread over their last rows (indices `[i,3]`).
Remember:
- the leading dimensions of `indices` (all but the last) must match those of `updates`, so that every element or slice of `updates` is accounted for; it is the position within `indices` that tells TF which element of `updates` to place next. Formally, `updates.shape` must equal `indices.shape[:-1] + shape[indices.shape[-1]:]`;
- the last dimension of each index works like a (partial) coordinate into `scatter`: a full coordinate addresses a single element, a shorter one a whole slice.

In [27]:
indices = tf.constant([[[0,0],[1,0],
                        [2,0],[3,0]],
                       [[0,3],[1,3],
                        [2,3],[3,3]]]) 
updates = tf.constant([[[5, 5, 5, 5], [6, 6, 6, 6],
                        [7, 7, 7, 7], [8, 8, 8, 8]],
                       [[1, 1, 1, 1], [2, 2, 2, 2],
                        [3, 3, 3, 3], [4, 4, 4, 4]]])
shape = tf.constant([4, 4, 4])
scatter = tf.scatter_nd(indices, updates, shape)
print(scatter.numpy())

[[[5 5 5 5]
  [0 0 0 0]
  [0 0 0 0]
  [1 1 1 1]]

 [[6 6 6 6]
  [0 0 0 0]
  [0 0 0 0]
  [2 2 2 2]]

 [[7 7 7 7]
  [0 0 0 0]
  [0 0 0 0]
  [3 3 3 3]]

 [[8 8 8 8]
  [0 0 0 0]
  [0 0 0 0]
  [4 4 4 4]]]
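The shape requirement behind these examples, as given in the `tf.scatter_nd` documentation, is `updates.shape == indices.shape[:-1] + shape[indices.shape[-1]:]`. A small check against the calls in this section, using a hypothetical helper `expected_updates_shape`:

```python
def expected_updates_shape(indices_shape, shape):
    """Shape that `updates` must have for tf.scatter_nd(indices, updates, shape)."""
    k = indices_shape[-1]            # number of leading dimensions each index addresses
    return tuple(indices_shape[:-1]) + tuple(shape[k:])

assert expected_updates_shape((4, 1), (8,)) == (4,)                 # single elements
assert expected_updates_shape((4, 2), (2, 8)) == (4,)               # elements of a 2-D tensor
assert expected_updates_shape((2, 1), (4, 4, 4)) == (2, 4, 4)       # whole [4,4] slices
assert expected_updates_shape((2, 4, 2), (4, 4, 4)) == (2, 4, 4)    # rows sprayed over blocks
print(expected_updates_shape((2, 4, 2), (4, 4, 4)))
```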
