
<h1 align="center">5. Transformer</h1>

### [Pytorch Reference](https://pytorch.org/tutorials/beginner/transformer_tutorial.html)


<img width="500" src="img/bert.svg">

In [1]:
import pandas as pd
import torch.nn as nn
import torchviz
import torchsummary

In [2]:
def numParams(vocab, d_model, d_ff, n_layers):

	params_emb_mat   = vocab * d_model
    
	params_attention = d_model*d_model*4 + d_model*4 # Q, K, V, O weights & biases
	params_dense_1   = d_model*d_ff + d_ff
	params_dense_2   = d_ff*d_model + d_model
	params_norm_1    = d_model + d_model
	params_norm_2    = d_model + d_model
	params_per_layer = params_attention + params_dense_1 + params_dense_2 + params_norm_1 + params_norm_2
	params_encoder   = params_per_layer * n_layers
    
	params_total     = params_emb_mat + params_encoder
    
	return params_emb_mat, params_encoder, params_total

In [3]:
configurations = {
    "bertSmall": {"vocab":30522, "d_model":768,  "d_ff":768*4,  "n_layers":6},
    "bertBase":  {"vocab":30522, "d_model":768,  "d_ff":768*4,  "n_layers":12},
    "bertLarge": {"vocab":30522, "d_model":1024, "d_ff":1024*4, "n_layers":24},
    "aminSmall": {"vocab":20,    "d_model":8,    "d_ff":8*4,    "n_layers":6},
    "aminBase":  {"vocab":20,    "d_model":8,    "d_ff":8*4,    "n_layers":12},
    "aminLarge": {"vocab":20,    "d_model":16,   "d_ff":16*4,   "n_layers":12},
    "aminxLarge": {"vocab":20,   "d_model":16,   "d_ff":16*4,   "n_layers":16},
}

In [4]:
df = pd.DataFrame.from_dict(configurations).T
df

Unnamed: 0,vocab,d_model,d_ff,n_layers
bertSmall,30522,768,3072,6
bertBase,30522,768,3072,12
bertLarge,30522,1024,4096,24
aminSmall,20,8,32,6
aminBase,20,8,32,12
aminLarge,20,16,64,12
aminxLarge,20,16,64,16


In [5]:
numParams(**configurations["bertBase"])

(23440896, 85054464, 108495360)

In [6]:
df["Params Emb"], df["Params Enc"], df["Total params"] = numParams(**df[["vocab", "d_model", "d_ff", "n_layers"]])
df

Unnamed: 0,vocab,d_model,d_ff,n_layers,Params Emb,Params Enc,Total params
bertSmall,30522,768,3072,6,23440896,42527232,65968128
bertBase,30522,768,3072,12,23440896,85054464,108495360
bertLarge,30522,1024,4096,24,31254528,302309376,333563904
aminSmall,20,8,32,6,160,5232,5392
aminBase,20,8,32,12,160,10464,10624
aminLarge,20,16,64,12,320,39360,39680
aminxLarge,20,16,64,16,320,52480,52800


### My pytroch code

In [8]:
class Transformer(nn.Module):

    def __init__(self, vocab, d_model, nhead, d_ff, n_layers):
        super(Transformer, self).__init__()
        
        self.embeding  = nn.Embedding(vocab, d_model)
        layer          = nn.TransformerEncoderLayer(d_model, nhead, d_ff, dropout=.1, activation='relu')
        self.trans_enc = nn.TransformerEncoder(layer, n_layers, norm=None)

    def forward(self, seq):

        seq = self.embeding(seq)
        seq = self.trans_enc(seq)
        return seq

In [9]:
numParams(**configurations["bertDebug"])

KeyError: 'bertDebug'

In [11]:
transformer = Transformer(vocab=30522, d_model=768, d_ff=3072, n_layers=12, nhead=12)
sum(p.numel() for p in transformer.parameters())

108495360

In [10]:
for name, param in transformer.named_parameters():
    if param.requires_grad:
        print(name, param.shape)

embeding.weight torch.Size([30522, 768])
trans_enc.layers.0.self_attn.in_proj_weight torch.Size([2304, 768])
trans_enc.layers.0.self_attn.in_proj_bias torch.Size([2304])
trans_enc.layers.0.self_attn.out_proj.weight torch.Size([768, 768])
trans_enc.layers.0.self_attn.out_proj.bias torch.Size([768])
trans_enc.layers.0.linear1.weight torch.Size([3072, 768])
trans_enc.layers.0.linear1.bias torch.Size([3072])
trans_enc.layers.0.linear2.weight torch.Size([768, 3072])
trans_enc.layers.0.linear2.bias torch.Size([768])
trans_enc.layers.0.norm1.weight torch.Size([768])
trans_enc.layers.0.norm1.bias torch.Size([768])
trans_enc.layers.0.norm2.weight torch.Size([768])
trans_enc.layers.0.norm2.bias torch.Size([768])


In [11]:
30522 * 768 + \
2304 * 768  + \
2304        + \
768 * 768   + \
768         + \
3072 * 768  + \
3072        + \
768 * 3072  + \
768         + \
768         + \
768         + \
768         + \
768

30528768

In [12]:
embadding_matrix = 30522 * 768

params_attention = 768*768*4 + 768*4 # Q, K, V, O weights & biases
params_dense_1   = 768*3072 + 3072
params_dense_2   = 3072*768 + 768
params_norm_1    = 768 + 768
params_norm_2    = 768 + 768
param_per_layer  = params_attention + params_dense_1 + params_dense_2 + params_norm_1 + params_norm_2

embadding_matrix + (param_per_layer * 1)

30528768

In [13]:
numParams(30522, 768, 3072, 1)

30528768

In [22]:
768*4 * 768  + \
768*4 

2362368