### Transformer Theoretical Model

In [14]:
from collections import OrderedDict

In [29]:
# Model hyperparameters read by params().
vocab_size = 33_278  # number of tokens in the vocabulary
n_embd = 368         # embedding width
n_layer = 4          # number of stacked blocks
H = 2                # hidden size of each of the n_embd recurrent units
# NOTE(review): the recorded parameter table further down
# (recurrent/ih = 1472 = 368 * 4) was produced with H = 4, not H = 2;
# re-run the cells so the saved output matches this configuration.

# Not read by params(); presumably used by other cells -- verify.
block_size = 1
n_head = 8
bias = False

In [30]:
def params(*, n_embd=None, vocab_size=None, n_layer=None, H=None):
    """Estimate the number of parameters in the model, broken down by component.

    Each hyperparameter left as ``None`` is looked up in the module-level
    (notebook) variable of the same name at call time. ``params()`` with no
    arguments therefore uses the configuration cell's values and picks up
    edits to that cell without redefining this function. Passing values
    explicitly sizes other configurations without touching the globals.

    Args:
        n_embd: embedding width.
        vocab_size: number of tokens in the vocabulary.
        n_layer: number of stacked blocks.
        H: hidden size of each of the ``n_embd`` recurrent units.

    Returns:
        OrderedDict mapping component names (e.g. ``'attention/kqv'``) to
        parameter counts, in display order. Aggregate entries ('embedding',
        'recurrent', 'attention', 'mlp', 'block', 'transformer') combine the
        finer-grained entries listed before them; 'total' is the grand total.

    Raises:
        KeyError: if a hyperparameter is omitted and no global of that name
            is defined.
    """
    # Resolve at call time (not as default values) so the current notebook
    # configuration is used, not the one in effect when this cell ran.
    cfg = globals()
    if n_embd is None:
        n_embd = cfg['n_embd']
    if vocab_size is None:
        vocab_size = cfg['vocab_size']
    if n_layer is None:
        n_layer = cfg['n_layer']
    if H is None:
        H = cfg['H']

    out = OrderedDict()

    # token embedding (no position embedding is counted)
    out['embedding/token'] = n_embd * vocab_size
    out['embedding'] = out['embedding/token']

    # recurrent units: n_embd independent units, each with hidden size H
    out['recurrent/ih'] = n_embd * H  # input-to-hidden weights
    out['recurrent/hh'] = n_embd * H * H  # hidden-to-hidden weights
    out['recurrent/bias'] = n_embd * 2 * H  # two bias vectors of size H per unit
    out['recurrent'] = out['recurrent/ih'] + out['recurrent/hh'] + out['recurrent/bias']

    # attention sub-block; LayerNorms carry a weight but no bias
    out['attention/ln'] = n_embd
    out['attention/ln2'] = n_embd
    out['attention/kqv'] = n_embd * 3 * n_embd  # key, query and value projections
    out['attention/proj'] = n_embd ** 2  # output projection
    out['attention'] = (out['attention/ln'] + out['attention/ln2']
                        + out['attention/kqv'] + out['attention/proj'])

    # MLP sub-block with hidden width 4 * n_embd
    ffw_size = 4 * n_embd
    out['mlp/ln'] = n_embd
    out['mlp/ffw'] = n_embd * ffw_size
    out['mlp/proj'] = ffw_size * n_embd
    out['mlp'] = out['mlp/ln'] + out['mlp/ffw'] + out['mlp/proj']

    # One block = attention + MLP + recurrent; n_layer blocks are stacked.
    # Linear layers are counted without bias terms.
    out['block'] = out['attention'] + out['mlp'] + out['recurrent']
    out['transformer'] = n_layer * out['block']
    # No final LayerNorm is counted. The output (dense) layer reuses the
    # token-embedding weights, so it contributes no parameters of its own.
    out['dense'] = 0

    out['total'] = out['embedding'] + out['transformer'] + out['dense']
    return out

# Print the analytic estimate from params(): first the grand total, then one
# row per component with its share of that total. (No PyTorch model is built
# or compared against here; these are the estimated counts only.)
p = params()
params_total = p['total']
print(f"we see: {params_total}")
print(f"{'name':20s} {'params':10s} {'ratio (%)':10s}")  # column header
for name, count in p.items():
    share = count/params_total*100
    print(f"{name:20s} {count:10d} {share:10.4f}")
    

we see: 18792288
name                 params     ratio (%) 
embedding/token        12246304    65.1666
embedding              12246304    65.1666
recurrent/ih               1472     0.0078
recurrent/hh               5888     0.0313
recurrent/bias             2944     0.0157
recurrent                 10304     0.0548
attention/ln                368     0.0020
attention/ln2               368     0.0020
attention/kqv            406272     2.1619
attention/proj           135424     0.7206
attention                542432     2.8865
mlp/ln                      368     0.0020
mlp/ffw                  541696     2.8825
mlp/proj                 541696     2.8825
mlp                     1083760     5.7670
block                   1636496     8.7083
transformer             6545984    34.8334
dense                         0     0.0000
total                  18792288   100.0000
