In [1]:
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1' 
import tensorflow as tf



In [2]:
import gpt2
from utils import load_encoder_hparams_and_params

### load model: 

In [3]:
# GPT-2 checkpoint to load
model_size = '124M'  # alternatives noted by the author: 355M, 774M, 1558M
# number of tokens to generate (reassigned to 1 before the check call at the end of the notebook)
n_tokens_to_generate = 12
# returns the tokenizer, the hyperparameter dict and the weight dict
# NOTE(review): 'models' is presumably the directory holding the checkpoints — confirm in utils
encoder, hparams, params = load_encoder_hparams_and_params(model_size, 'models')

### hyperparameters

In [4]:
hparams

{'n_vocab': 50257, 'n_ctx': 1024, 'n_embd': 768, 'n_head': 12, 'n_layer': 12}

### parameters (nested JSON dictionary with trained weights)

In [5]:
# top-level entries of the weight dictionary
list(params.keys())

['blocks', 'ln_f', 'wpe', 'wte']

In [6]:
# Reference sketch of the layout of `params` (array shapes written in terms of
# the hyperparameter names). It is only documentation and is never evaluated.
# Bound to a descriptive name rather than `_`, because IPython uses `_` for the
# most recent cell output and stops updating it once user code assigns to it.
params_schema = '''
{
    "wpe": [n_ctx, n_embd],
    "wte": [n_vocab, n_embd],
    "ln_f": {"b": [n_embd], "g": [n_embd]},
    "blocks": [
        {
            "attn": {
                "c_attn": {"b": [3*n_embd], "w": [n_embd, 3*n_embd]},
                "c_proj": {"b": [n_embd], "w": [n_embd, n_embd]},
            },
            "ln_1": {"b": [n_embd], "g": [n_embd]},
            "ln_2": {"b": [n_embd], "g": [n_embd]},
            "mlp": {
                "c_fc": {"b": [4*n_embd], "w": [n_embd, 4*n_embd]},
                "c_proj": {"b": [n_embd], "w": [4*n_embd, n_embd]},
            },
        },
        ... # repeated n_layer times
    ]
}
'''

### GPT2 loop: autoregressive output generation (each step predicts the next token from all tokens so far)

In [7]:
def __generate(inputs, params, n_head, n_tokens_to_generate):
    """Greedily decode ``n_tokens_to_generate`` tokens that continue ``inputs``.

    Every step embeds the current token sequence, passes it through each
    transformer block, applies the final layer norm, projects the result onto
    the vocabulary with the transposed token-embedding matrix and appends the
    highest-scoring token of the last position to the sequence.

    Args:
        inputs: sequence of prompt token ids. It is copied, so the caller's
            list is not modified.
        params: weight dictionary; the entries 'wte', 'wpe', 'blocks' and
            'ln_f' are read from it.
        n_head: forwarded unchanged to ``gpt2.transformer_block``.
        n_tokens_to_generate: number of decode steps to run.

    Returns:
        list[int]: the generated token ids only (the prompt is excluded).

    Notes:
        Uses the notebook-level ``gpt2`` module and ``encoder`` object. The
        prompt and each predicted token (id and decoded text) are printed.
    """
    # Take the weights from the `params` argument instead of relying on
    # same-named notebook globals that are only defined in later cells.
    wte, wpe = params['wte'], params['wpe']
    blocks, ln_f = params['blocks'], params['ln_f']

    # Work on a private copy so repeated calls never grow the caller's prompt.
    tokens = list(inputs)

    print(tokens)
    print(encoder.decode(tokens))
    print()

    for _ in range(n_tokens_to_generate):                    # auto-regressive decode loop

        # project input tokens to vector embedding space
        x = wte[tokens] + wpe[range(len(tokens))]  # [n_seq] -> [n_seq, n_embd]

        # forward pass through n_layer transformer blocks
        for block in blocks:
            x = gpt2.transformer_block(x, **block, n_head=n_head)  # [n_seq, n_embd] -> [n_seq, n_embd]

        # reproject output vector to token space
        x = gpt2.layer_norm(x, **ln_f)
        logits = x @ wte.T

        # greedy sampling: the single most likely token after the last position
        next_id = int(gpt2.np.argmax(logits[-1]))
        print(next_id)
        print(encoder.decode([next_id]))

        tokens.append(next_id)                               # feed the prediction back in

    # len(tokens) - n rather than -n, so that n == 0 yields an empty list
    return tokens[len(tokens) - n_tokens_to_generate:]       # only return generated ids


### 1. input tokenization: words to tokens

In [8]:
prompt = 'what is your favorite musical band?'

In [9]:
# turn the prompt text into its sequence of token ids and count them
inputs = encoder.encode(prompt)
len(inputs)

7

In [10]:
# id of the first token and the text that this single token decodes back to
inputs[0], encoder.decode(inputs[0:1])

(10919, 'what')

### 2. input embedding: tokens to vectors
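- 50257: size of the vocabulary (number of distinct tokens).
- 768: number of dimensions of the embedding.
- The embedding matrix works as a lookup table: given input 0, "what", the token is looked up in the vocabulary and the vector assigned to that token is returned.
- A word that is not in the vocabulary as a single token is split into smaller pieces (sub-words, down to single characters or bytes if necessary), and each piece has its own token.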

In [11]:
# shapes of the token and position embedding matrices: (n_vocab, n_embd) and (n_ctx, n_embd)
params['wte'].shape, params['wpe'].shape

((50257, 768), (1024, 768))

#### the token embedding space is learned during training

In [12]:
wte = params['wte'] # token embedding space

In [13]:
# each token corresponds to a vector in the token embedding space
wte[inputs[0]][:5] # show only first 5 components (out of n_embd dimensions = 768)

array([ 0.00742554, -0.0905292 ,  0.09560295, -0.07777912, -0.04478046],
      dtype=float32)

#### position embedding space

In [14]:
wpe = params['wpe'] # position embedding space

In [15]:
# each position index, from 0 to n_ctx - 1 (n_ctx = max input length), corresponds to a vector in the position embedding space
wpe[0][:5] # show only first 5 components (out of n_embd dimensions = 768)

array([-0.01882072, -0.1974186 ,  0.00402672,  0.01134686,  0.06382412],
      dtype=float32)

- Each position already has its own vector, and each word (token) has a vector too. So when token x sits at position p, the two vectors are combined (added element-wise).
- The position vectors are not hand-crafted: they are learned during training, from the tokens that appeared at each position.

#### each input token is converted into a vector combining the token vector and the position vector

In [16]:
# row i of x = token vector of inputs[i] + position vector of position i
x = wte[inputs] + wpe[range(len(inputs))]
x.shape # (len(inputs), n_embd): one embedding vector per input token

(7, 768)

In [17]:
x[0, :5]

array([-0.01139518, -0.2879478 ,  0.09962968, -0.06643226,  0.01904366],
      dtype=float32)

In [18]:
# sanity check: building the first row by hand gives the same values as x[0, :5]
(wte[inputs[0]] +wpe[0])[:5]

array([-0.01139518, -0.2879478 ,  0.09962968, -0.06643226,  0.01904366],
      dtype=float32)

### 3. forward pass through transformer blocks (next notebook)

In [19]:
blocks = params['blocks']

In [20]:
# Recompute the input embedding first: the loop overwrites x, so without this
# a second run of the cell would push already-transformed activations through
# the whole stack again and print different values.
x = wte[inputs] + wpe[range(len(inputs))]

# pass the sequence through every transformer block in order and show the
# first 5 components of the first token's vector after each block
for block in blocks:
    x = gpt2.transformer_block(x, **block, n_head=hparams['n_head'])
    print(x[0, :5])

[-0.8682625   0.40907443 -0.7032808  -0.5987273   0.14790252]
[-0.75784254  0.482823   -0.13060603 -0.2809174  -0.10698892]
[-0.62990224  0.17336455 -0.1549139  -0.09451213 -0.02711199]
[-0.6165827   0.16350532 -0.06441966 -0.09630389  0.12699506]
[-0.7666638   0.04476391 -0.122725   -0.1490272   0.13410892]
[-0.7719635  -0.0340901  -0.08117952 -0.17332554  0.34843472]
[-0.86027086 -0.03359476 -0.03841237 -0.25495803  0.5139295 ]
[-0.90539     0.07989773 -0.12868711 -0.4138459   0.61734265]
[-0.87388366  0.15727827 -0.14486846 -0.4829519   0.57969224]
[-0.9719106   0.2692914  -0.36979988 -0.5350647   0.6047204 ]
[-1.116415    0.4122141  -0.720613   -0.50802207  0.5594149 ]
[-1.5974737  1.0517974 -1.7663152  0.9150096 -0.4267251]


### 4. reproject to token space: vectors to tokens

#### vector normalization

In [21]:
ln_f = params['ln_f']

In [22]:
# final layer norm using the 'ln_f' parameters; the shape of x is unchanged
# NOTE: x is overwritten, so re-running this cell feeds the already-normalized x through layer_norm again
x = gpt2.layer_norm(x, **ln_f)
x.shape

(7, 768)

#### find the next token: reproject the output vector to the token space

In [23]:
# score every vocabulary entry at every position by multiplying with the transposed token embedding matrix
logits = x @ wte.T
# the last row scores the token that follows the prompt; its argmax is the greedy prediction
logits.shape, logits[-1].shape, gpt2.np.argmax(logits[-1])

((7, 50257), (50257,), 198)

In [24]:
# projecting only the last position is enough: it yields the same argmax as the last row of the full logits
logits_last = x[-1] @ wte.T
gpt2.np.argmax(logits_last)

198

### check 

In [25]:
# tokenize the prompt again so the check starts from a fresh token list
inputs = encoder.encode(prompt)
len(inputs)

7

In [26]:
# one decode step is enough to compare against the token id found by hand (198)
n_tokens_to_generate = 1  # replaces the value assigned when the model was loaded
__generate(inputs, params, hparams['n_head'], n_tokens_to_generate)

[10919, 318, 534, 4004, 10530, 4097, 30]
what is your favorite musical band?

198




[198]