https://medium.com/@hhpatil001/transformers-from-scratch-in-simple-python-part-i-b290760c1040

In [11]:
from transformers import AutoTokenizer
from transformers import AutoConfig

In [14]:
from torch import nn

In [19]:
import torch
import torch.nn.functional as F
from math import sqrt

In [6]:
# Load the tokenizer that belongs to the 'bert-base-uncased' checkpoint.
# clean_up_tokenization_spaces is passed explicitly; presumably this pins the
# decode-time whitespace clean-up behaviour for this transformers version — TODO confirm.
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased', clean_up_tokenization_spaces=True)

In [7]:
text = 'I love data science.'

In [9]:
# Tokenize the sentence into PyTorch tensors (return_tensors='pt').
# add_special_tokens=False keeps the tokenizer from adding its special tokens,
# so the result holds exactly the 5 ids for the sentence itself (see the output below).
inputs = tokenizer(text, add_special_tokens=False, return_tensors='pt')

In [10]:
inputs

{'input_ids': tensor([[1045, 2293, 2951, 2671, 1012]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1]])}

In [12]:
config = AutoConfig.from_pretrained('bert-base-uncased')

In [13]:
config

BertConfig {
  "_name_or_path": "bert-base-uncased",
  "architectures": [
    "BertForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "classifier_dropout": null,
  "gradient_checkpointing": false,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 0,
  "position_embedding_type": "absolute",
  "transformers_version": "4.44.0",
  "type_vocab_size": 2,
  "use_cache": true,
  "vocab_size": 30522
}

In [15]:
# Token-embedding lookup table sized from the config: vocab_size x hidden_size
# (30522 x 768 for the config printed above).
# NOTE: this is a freshly constructed nn.Embedding, not the pretrained BERT weights,
# so its vectors are randomly initialised and change between runs unless a seed is set.
token_embeddings = nn.Embedding(config.vocab_size, config.hidden_size)

In [16]:
token_embeddings

Embedding(30522, 768)

In [17]:
# Look up one hidden_size-dimensional vector per token id.
# Result shape is [batch, seq_len, hidden_size] = [1, 5, 768] (see the size() output below).
inputs_embeds = token_embeddings(inputs.input_ids)

In [18]:
inputs_embeds.size()

torch.Size([1, 5, 768])

In [20]:
# Self-attention setup: query, key and value all alias the same embeddings tensor.
# NOTE(review): these module-level names are not referenced again in the visible cells;
# the function below receives its own query/key/value arguments from its callers.
query = key = value = inputs_embeds

def scaled_dot_product_attention(query, key, value, mask=None):
    """Compute scaled dot-product attention: ``softmax(Q @ K^T / sqrt(d_k)) @ V``.

    The last two dimensions are treated as ``(sequence, features)``; any number
    of leading dimensions (none, a batch dimension, batch + heads, ...) is
    accepted and broadcast by ``torch.matmul``.

    Args:
        query: Tensor of shape ``(..., seq_q, dim_k)``.
        key: Tensor of shape ``(..., seq_k, dim_k)``.
        value: Tensor of shape ``(..., seq_k, dim_v)``.
        mask: Optional tensor broadcastable to ``(..., seq_q, seq_k)``.
            Positions where ``mask == 0`` are excluded from attention.
            ``None`` (the default) lets every query attend to every key.

    Returns:
        Tensor of shape ``(..., seq_q, dim_v)``: for each query position, the
        attention-weighted average of the ``value`` vectors.

    Note:
        If a mask removes *every* key for some query position, the softmax
        over that row is taken over all ``-inf`` and yields NaN for that row.
    """
    dim_k = query.size(-1)
    # Pairwise query/key similarity, scaled by sqrt(d_k) so the logits do not
    # grow with the feature dimension. matmul (rather than bmm) and negative
    # transpose indices keep this valid for inputs that are not exactly 3-D.
    scores = torch.matmul(query, key.transpose(-2, -1)) / sqrt(dim_k)
    if mask is not None:
        # -inf logits become exactly zero weight after the softmax.
        scores = scores.masked_fill(mask == 0, float("-inf"))
    # Normalise over the key axis so each query's weights sum to 1.
    weights = F.softmax(scores, dim=-1)
    return torch.matmul(weights, value)

In [21]:
class AttentionHead(nn.Module):
    """One self-attention head.

    Projects the incoming hidden states into query, key and value spaces of
    width ``head_dim`` with three independent linear layers, then combines
    them with ``scaled_dot_product_attention``.

    Args:
        embed_dim: Size of the last dimension of the incoming hidden states.
        head_dim: Size of the last dimension produced by each projection.
    """

    def __init__(self, embed_dim, head_dim):
        super().__init__()
        # Registered in q, k, v order so parameter names and initialisation
        # order stay the same as writing the three assignments out by hand.
        for projection in ("q", "k", "v"):
            setattr(self, projection, nn.Linear(embed_dim, head_dim))

    def forward(self, hidden_state):
        """Return the attention output for ``hidden_state``.

        The same tensor feeds all three projections (self-attention); the
        result has ``head_dim`` features in its last dimension.
        """
        queries = self.q(hidden_state)
        keys = self.k(hidden_state)
        values = self.v(hidden_state)
        return scaled_dot_product_attention(queries, keys, values)

In [22]:
class MultiHeadAttention(nn.Module):
    """Multi-head self-attention built from independent ``AttentionHead`` modules.

    ``config.hidden_size`` is split evenly across ``config.num_attention_heads``
    heads. Each head attends over the full input, the head outputs are
    concatenated along the last dimension (restoring ``hidden_size`` features)
    and a final linear layer mixes them.

    Args:
        config: Object exposing ``hidden_size`` and ``num_attention_heads``.

    Raises:
        ValueError: If ``num_attention_heads`` is not positive or does not
            divide ``hidden_size`` evenly. Without this check the floor
            division would silently drop features and the mismatch would only
            show up later as a shape error inside ``output_linear``.
    """

    def __init__(self, config):
        super().__init__()
        embed_dim = config.hidden_size
        num_heads = config.num_attention_heads
        # Fail fast: the concatenated head outputs must add up to embed_dim.
        if num_heads <= 0 or embed_dim % num_heads != 0:
            raise ValueError(
                f"hidden_size ({embed_dim}) must be a positive multiple of "
                f"num_attention_heads ({num_heads})"
            )
        head_dim = embed_dim // num_heads
        self.heads = nn.ModuleList([AttentionHead(embed_dim, head_dim) for _ in range(num_heads)])
        self.output_linear = nn.Linear(embed_dim, embed_dim)

    def forward(self, hidden_state):
        """Run every head on ``hidden_state`` and project the concatenated result.

        The output has the same last-dimension size (``hidden_size``) as the input.
        """
        # Each head yields head_dim features; concatenating num_heads of them
        # along the feature axis gives embed_dim features again.
        x = torch.cat([h(hidden_state) for h in self.heads], dim=-1)
        x = self.output_linear(x)
        return x

In [23]:
multihead_attn = MultiHeadAttention(config)

In [24]:
# Run multi-head self-attention over the token embeddings. The output keeps the
# input's shape, [1, 5, 768] (see the size() output below); its values depend on
# the randomly initialised layers, so they are not meaningful beyond a shape check.
attn_output = multihead_attn(inputs_embeds)

In [25]:
attn_output.size()

torch.Size([1, 5, 768])

In [26]:
class FeedForward(nn.Module):
    """Feed-forward sub-layer: Linear -> GELU -> Linear -> Dropout.

    The first linear layer widens the features from ``config.hidden_size`` to
    ``config.intermediate_size``; the second narrows them back, so the output
    has the same last-dimension size as the input.

    Args:
        config: Object exposing ``hidden_size``, ``intermediate_size`` and
            ``hidden_dropout_prob``.
    """

    def __init__(self, config):
        super().__init__()
        model_dim, inner_dim = config.hidden_size, config.intermediate_size
        # Expand to the intermediate width, then project back to the model width.
        self.linear_1 = nn.Linear(model_dim, inner_dim)
        self.linear_2 = nn.Linear(inner_dim, model_dim)
        self.gelu = nn.GELU()
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, x):
        """Apply the sub-layer to ``x``; dropout only has an effect in training mode."""
        expanded = self.gelu(self.linear_1(x))
        return self.dropout(self.linear_2(expanded))

In [27]:
feed_forward = FeedForward(config)

In [28]:
# Pass the attention output through the feed-forward sub-layer; the shape stays
# [1, 5, 768] (see the size() output below).
# NOTE(review): feed_forward is never switched to eval mode in the visible cells, so
# its Dropout (p = config.hidden_dropout_prob, 0.1 in the config above) is active here
# and the values differ from run to run; call feed_forward.eval() for deterministic output.
ff_outputs = feed_forward(attn_output)

In [29]:
ff_outputs.size()

torch.Size([1, 5, 768])