# GPT-2 
---
## Study notebook

Mostly TF utils & standard functions

In [7]:
import tensorflow as tf
import numpy as np

In [2]:
# Switch TensorFlow to eager mode so the cells below evaluate immediately and
# display concrete values (the recorded outputs show actual numbers).
tf.enable_eager_execution()

---

## Shape list

Why the dynamic thing in the first place? See [this comment](https://stackoverflow.com/a/34082273). It is there to deal with the difference between dynamic and static shapes: when data flows through the network batch by batch, the shape will be, for instance, `[None, x, y, z]`, so some dimensions are not defined statically and can only be read at run time.

In [79]:
def shape_list(x):
    """Return the shape of ``x`` as a list, one entry per axis.

    Axes whose size is known statically are reported as the value found in
    ``x.shape.as_list()``; axes reported there as ``None`` are replaced by the
    matching element of ``tf.shape(x)``, which is resolved at run time.
    """
    known_dims = x.shape.as_list()
    runtime_shape = tf.shape(x)
    resolved = []
    for axis, size in enumerate(known_dims):
        # Only fall back to the run-time shape when the static size is missing.
        resolved.append(runtime_shape[axis] if size is None else size)
    return resolved

In [77]:
def shape_list_comm(x):
    """Verbose variant of ``shape_list`` that prints every intermediate value.

    Prints the static shape list, the result of ``tf.shape(x)``, and then, for
    each axis, the static size followed by the run-time size. Returns the same
    list as ``shape_list``: the static size where it is known, otherwise the
    run-time entry.
    """
    known_dims = x.shape.as_list()
    print('static:', known_dims)
    runtime_shape = tf.shape(x)
    print('dynamic:', runtime_shape)

    # First pass: show both views side by side, axis by axis.
    for axis in range(len(known_dims)):
        print(known_dims[axis])
        print(runtime_shape[axis])

    # Second pass: build the merged shape list.
    resolved = []
    for axis, size in enumerate(known_dims):
        resolved.append(runtime_shape[axis] if size is None else size)
    return resolved

In [119]:
# Small random tensor to experiment with. No seed is set, so the values
# printed below change on every run.
t1 = tf.random.normal(shape=[4, 3, 2], mean=0.0, stddev=20, dtype=tf.float32)
print(tf.cast(t1, tf.int16))  # casting to int for readability

tf.Tensor(
[[[-10 -21]
  [-14   9]
  [  1  -8]]

 [[-26   8]
  [  7 -14]
  [-33 -21]]

 [[ 25  14]
  [ -5  -8]
  [ 48  15]]

 [[ 13  18]
  [-36   8]
  [-45   5]]], shape=(4, 3, 2), dtype=int16)


In [78]:
# Trace the static and dynamic views of a tensor whose shape is fully known:
# both agree, so the returned list contains only the static sizes.
shape_list_comm(t1)

static: [4, 3, 2]
dynamic: tf.Tensor([4 3 2], shape=(3,), dtype=int32)
4
tf.Tensor(4, shape=(), dtype=int32)
3
tf.Tensor(3, shape=(), dtype=int32)
2
tf.Tensor(2, shape=(), dtype=int32)


[4, 3, 2]

---

In [72]:
# `xmean` was never defined in this notebook (left over from a deleted cell),
# so the cells below failed on a fresh kernel. The values are reconstructed
# from the recorded reshape output further down: a float32 tensor of shape [3, 2].
xmean = tf.constant([[1., 1.],
                     [3., 2.],
                     [5., 9.]], dtype=tf.float32)

# Static shape as a plain Python list.
xmean.shape.as_list()

[3, 2]

In [92]:
# Number of axes of the tensor (2 for a [3, 2] tensor).
xmean.ndim

2

In [73]:
# Static shape object: a TensorShape wrapping one Dimension per axis.
xmean.shape

TensorShape([Dimension(3), Dimension(2)])

In [138]:
# Indexing a TensorShape yields a Dimension object, not a bare number.
xmean.shape[1]

Dimension(2)

In [91]:
# `.value` unwraps the Dimension and gives the underlying size.
xmean.shape[1].value

2

In [74]:
# Dynamic shape: tf.shape returns the shape as an int32 tensor rather than a
# TensorShape, which is what shape_list indexes into for unknown axes.
tf.shape(xmean)

<tf.Tensor: id=616, shape=(2,), dtype=int32, numpy=array([3, 2], dtype=int32)>

In [103]:
# The same six values viewed under two different layouts.
print(tf.reshape(xmean, [2, 3]), tf.reshape(xmean, [1, 6]), sep="\n")

tf.Tensor(
[[1. 1. 3.]
 [2. 5. 9.]], shape=(2, 3), dtype=float32)
tf.Tensor([[1. 1. 3. 2. 5. 9.]], shape=(1, 6), dtype=float32)


---

## Splitting & merging states

In [120]:
def split_states(x, n):
    """Split the last axis of ``x`` into two axes of sizes ``n`` and ``m // n``.

    ``m`` is the size of the last axis as reported by ``shape_list``; every
    leading axis is kept as is, so a ``[..., m]`` input is reshaped to
    ``[..., n, m // n]``.
    """
    *batch_dims, width = shape_list(x)
    target_shape = [*batch_dims, n, width // n]
    return tf.reshape(x, target_shape)

In [129]:
# Toy [2, 3, 4] float32 variable with Glorot-uniform initial values, used as
# input for the split/merge experiments below.
splitx = tf.get_variable(
    name="splitx",
    shape=[2, 3, 4],
    dtype=tf.float32,
    initializer=tf.glorot_uniform_initializer,
)

In [130]:
# Inspect the initial values of the variable.
splitx

<tf.Variable 'splitx:0' shape=(2, 3, 4) dtype=float32, numpy=
array([[[ 0.09519899, -0.5179945 ,  0.60571945, -0.61901444],
        [ 0.5326501 ,  0.6256759 , -0.21402407, -0.07976532],
        [-0.14802265,  0.11726612,  0.35453045,  0.42391026]],

       [[ 0.45691597,  0.24667728,  0.04992068, -0.42152926],
        [ 0.07240939, -0.40548953,  0.5585022 ,  0.50473773],
        [-0.22778407,  0.28633028, -0.44296736, -0.0016287 ]]],
      dtype=float32)>

In [137]:
# split_states unrolled by hand: leading axes, size of the last axis, and the
# target shape that is handed to tf.reshape.
n = 2
*start, m = shape_list(splitx)
print(start, m, start + [n, m // n], sep="\n")

[2, 3]
4
[2, 3, 2, 2]


In [131]:
# Split the last axis (size 4) into 2 groups of 2: [2, 3, 4] -> [2, 3, 2, 2].
splitx_states = split_states(splitx, 2)
splitx_states

<tf.Tensor: id=799, shape=(2, 3, 2, 2), dtype=float32, numpy=
array([[[[ 0.09519899, -0.5179945 ],
         [ 0.60571945, -0.61901444]],

        [[ 0.5326501 ,  0.6256759 ],
         [-0.21402407, -0.07976532]],

        [[-0.14802265,  0.11726612],
         [ 0.35453045,  0.42391026]]],


       [[[ 0.45691597,  0.24667728],
         [ 0.04992068, -0.42152926]],

        [[ 0.07240939, -0.40548953],
         [ 0.5585022 ,  0.50473773]],

        [[-0.22778407,  0.28633028],
         [-0.44296736, -0.0016287 ]]]], dtype=float32)>

In [132]:
def merge_states(x):
    """Collapse the last two axes of ``x`` into one axis of size ``a * b``.

    ``a`` and ``b`` are the sizes of the last two axes as reported by
    ``shape_list``; every leading axis is kept as is, so a ``[..., a, b]``
    input is reshaped to ``[..., a * b]``.
    """
    *batch_dims, rows, cols = shape_list(x)
    merged_shape = [*batch_dims, rows * cols]
    return tf.reshape(x, merged_shape)

In [133]:
# Round trip: merging the split tensor gives back the original [2, 3, 4] layout.
re_splitx = merge_states(splitx_states)
re_splitx

<tf.Tensor: id=803, shape=(2, 3, 4), dtype=float32, numpy=
array([[[ 0.09519899, -0.5179945 ,  0.60571945, -0.61901444],
        [ 0.5326501 ,  0.6256759 , -0.21402407, -0.07976532],
        [-0.14802265,  0.11726612,  0.35453045,  0.42391026]],

       [[ 0.45691597,  0.24667728,  0.04992068, -0.42152926],
        [ 0.07240939, -0.40548953,  0.5585022 ,  0.50473773],
        [-0.22778407,  0.28633028, -0.44296736, -0.0016287 ]]],
      dtype=float32)>