In [3]:
# from torchtext.transforms import BERTTokenizer

# VOCAB_FILE = "https://huggingface.co/bert-base-uncased/resolve/main/vocab.txt"
# tokenizer = BERTTokenizer(vocab_path=VOCAB_FILE, do_lower_case=True, never_split=["[CLS]", "[SEP]"])
# tokenizer("Hello World, How are you!") # single sentence input
# tokenizer(["Hello World","How are you!"])

In [84]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
from transformers import BertTokenizer
from dataclasses import dataclass
from typing import Tuple, Union, List

@dataclass
class BERTConfig:
    """Hyperparameters shared by the BERT modules defined in this notebook."""
    # Width of every token representation (used as the nn.Linear / nn.LayerNorm /
    # nn.Embedding feature size). Despite the name this is a dimension, not a layer count.
    hidden_layers: int = 768
    # Number of attention heads; MultiHeadAttention asserts it divides hidden_layers.
    num_heads: int = 12
    # Number of encoder blocks stacked by the BERT module.
    attention_blocks: int = 12
    # Dropout probability used in the attention and feed-forward layers.
    dropout: float = 0.2
    # Number of rows of the token embedding table. The None default is only a
    # placeholder: it has to be set before BERTEmbeddings can be built.
    vocabulary_size: int = None
    # Number of rows of the positional embedding table.
    sequence_len: int = 32

@dataclass
class MLMData:
    """One encoded masked-language-modelling example, as built by MLM.encode."""
    # Token ids after the masking/corruption step (model input).
    x: torch.Tensor
    # Original token ids with padding positions replaced by -1 (targets).
    y: torch.Tensor
    # Attention mask returned by the tokenizer (1 = real token, 0 = padding).
    att_mask: torch.Tensor

# Load the tokenizer of the "bert-base-uncased" checkpoint; it is reused by the cells below.
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

In [5]:
# Sanity check: tokenize one sentence and get PyTorch tensors back
# (input_ids, token_type_ids and attention_mask).
tokenizer("Hi how are you?", return_tensors='pt')

{'input_ids': tensor([[ 101, 7632, 2129, 2024, 2017, 1029,  102]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1]])}

In [6]:
import math

class Head(nn.Module):
    """Single scaled dot-product self-attention head.

    The input is projected to queries, keys and values of size
    ``config.hidden_layers // config.num_heads``; the output is the
    attention-weighted sum of the values.

    Args:
        config: Model hyperparameters; ``hidden_layers``, ``num_heads`` and
            ``dropout`` are read.
        bias: Whether the query/key/value projections carry a bias term.
    """

    def __init__(self, config: "BERTConfig", bias: bool=False):
        super().__init__()
        head_dim = config.hidden_layers // config.num_heads
        self.W_q = nn.Linear(config.hidden_layers, head_dim, bias=bias)
        self.W_k = nn.Linear(config.hidden_layers, head_dim, bias=bias)
        self.W_v = nn.Linear(config.hidden_layers, head_dim, bias=bias)
        self.dropout = nn.Dropout(p=config.dropout)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x: torch.Tensor, mask: torch.Tensor=None) -> torch.Tensor:
        """Attend over the sequence dimension of ``x``.

        Args:
            x: Input of shape (batch, t, emb).
            mask: Optional mask where 0 marks positions that must not be
                attended to. A (batch, t) padding mask is applied to the key
                positions of every query; a mask that already has the rank of
                the score matrix, e.g. (batch, t, t), is used as given.

        Returns:
            Tensor of shape (batch, t, emb // num_heads).
        """
        q = self.W_q(x)
        k = self.W_k(x)
        v = self.W_v(x)

        # Scale by sqrt(d_k), the width of the projected queries/keys, rather
        # than by the width of the input embeddings.
        scores = (q @ k.transpose(-2, -1)) / math.sqrt(q.size(-1))

        if mask is not None:
            if mask.dim() == scores.dim() - 1:
                # (batch, t) -> (batch, 1, t): broadcast over queries so that
                # each sequence masks its own keys.
                mask = mask.unsqueeze(-2)
            scores = scores.masked_fill(mask == 0, float('-inf'))

        return self.dropout(self.softmax(scores)) @ v
        
        
# Smoke test: one head maps (batch, t, hidden) to (batch, t, hidden // num_heads).
# Pass a config *instance* and call the module itself (not .forward) so that
# nn.Module's call machinery (hooks etc.) is exercised as in real use.
h = Head(BERTConfig())
test = torch.randn((1, 20, 768))
h(test).shape

torch.Size([1, 20, 64])

In [7]:
class MultiHeadAttention(nn.Module):
    """Multi-head self-attention built from independent ``Head`` modules.

    Each of the ``config.num_heads`` heads produces a slice of width
    ``config.hidden_layers // config.num_heads``; the slices are concatenated
    along the feature dimension, mixed by the output projection ``W_o`` and
    passed through dropout.

    Args:
        config: Model hyperparameters; ``hidden_layers`` must be divisible by
            ``num_heads``.
    """

    def __init__(self, config: BERTConfig):
        super().__init__()
        assert config.hidden_layers % config.num_heads == 0, f"Cannot equally distribute {config.hidden_layers} to {config.num_heads} heads!"
        self.attention = nn.ModuleList([Head(config) for _ in range(config.num_heads)])
        self.W_o = nn.Linear(config.hidden_layers, config.hidden_layers, bias=False)
        self.dropout = nn.Dropout(p=config.dropout)

    def forward(self, x: torch.Tensor, mask: torch.Tensor=None) -> torch.Tensor:
        """Run every head on ``x`` and merge the results.

        Args:
            x: Input of shape (batch, t, hidden).
            mask: Optional attention mask forwarded unchanged to every head
                (0 marks positions to ignore). Defaults to no masking, which
                is the previous behaviour.

        Returns:
            Tensor with the same shape as ``x``.
        """
        per_head = [head(x, mask) for head in self.attention]
        result = torch.cat(per_head, dim=-1)
        return self.dropout(self.W_o(result))
        
# Smoke test: multi-head attention keeps the (batch, t, hidden) shape.
# Use a config instance and call the module itself rather than .forward.
mh = MultiHeadAttention(BERTConfig())
mh(torch.randn((32, 40, 768))).shape

torch.Size([32, 40, 768])

In [8]:
class Block(nn.Module):
    """Encoder block with two residual sub-layers.

    Each sub-layer normalises its input first (LayerNorm), then applies
    multi-head self-attention or the position-wise feed-forward network, and
    finally adds the result back onto its input.
    """

    def __init__(self, config: BERTConfig):
        super().__init__()
        width = config.hidden_layers
        inner_width = width * 4
        self.norm_1 = nn.LayerNorm(width)
        self.norm_2 = nn.LayerNorm(width)
        # Feed-forward network: expand to 4x the width, GELU, project back, dropout.
        feed_forward_layers = [
            nn.Linear(width, inner_width),
            nn.GELU(),
            nn.Linear(inner_width, width),
            nn.Dropout(p=config.dropout),
        ]
        self.fully_connected = nn.Sequential(*feed_forward_layers)
        self.multi_head_att = MultiHeadAttention(config)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the attention and feed-forward sub-layers to ``x``; the shape is preserved."""
        attended = self.multi_head_att(self.norm_1(x))
        x = x + attended
        transformed = self.fully_connected(self.norm_2(x))
        return x + transformed

class BERTEmbeddings(nn.Module):
    """Input embeddings: token + position + segment, followed by LayerNorm.

    Args:
        config: Model hyperparameters; ``vocabulary_size``, ``sequence_len``
            and ``hidden_layers`` are read.

    Raises:
        ValueError: If ``config.vocabulary_size`` has not been set.
    """

    def __init__(self, config: "BERTConfig"):
        super().__init__()
        if config.vocabulary_size is None:
            raise ValueError("config.vocabulary_size must be set before building BERTEmbeddings")
        self.embeddings = nn.Embedding(config.vocabulary_size, config.hidden_layers)
        self.positional_embeddings = nn.Embedding(config.sequence_len, config.hidden_layers)
        self.segment_embeddings = nn.Embedding(2, config.hidden_layers)
        self.emb_norm = nn.LayerNorm(config.hidden_layers)

    def forward(self, tokens: torch.Tensor, token_type_ids: torch.Tensor=None) -> torch.Tensor:
        """Embed a batch of token ids.

        Args:
            tokens: Integer token ids of shape (batch, t).
            token_type_ids: Segment ids (0 or 1) with the same shape as
                ``tokens``. When omitted, every token is put in segment 0.

        Returns:
            Normalised embeddings of shape (batch, t, hidden).

        Raises:
            ValueError: If the sequence is longer than the positional table.
        """
        seq_len = tokens.size(-1)
        max_len = self.positional_embeddings.num_embeddings
        if seq_len > max_len:
            raise ValueError(
                f"Sequence length {seq_len} exceeds the {max_len} positions supported by the positional embeddings"
            )

        # Positional embeddings are looked up by position index (0..t-1), not
        # by token id; the (t, hidden) result broadcasts over the batch.
        positions = torch.arange(seq_len, device=tokens.device)
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(tokens)

        x = (
            self.embeddings(tokens)
            + self.positional_embeddings(positions)
            + self.segment_embeddings(token_type_ids)
        )
        return self.emb_norm(x)
    
class BERT(nn.Module):
    """Container for the embedding layer and the stack of encoder blocks.

    NOTE(review): only the sub-modules are created so far; ``forward`` is
    still a placeholder and does not run them.
    """
    
    def __init__(self, config: BERTConfig):
        super().__init__()
        # `config.attention_blocks` independent encoder blocks.
        self.blocks = nn.ModuleList([Block(config) for _ in range(config.attention_blocks)])
        # Token/position/segment embedding layer.
        self.embeddings = BERTEmbeddings(config)
    
    def forward(self,
                x: torch.Tensor,
                attention_mask: torch.Tensor=None,
                input_token_ids: torch.Tensor=None,
                targets: torch.Tensor=None) -> Tuple[torch.Tensor, torch.Tensor]:
        # TODO: not implemented - the body is a bare `...`, so calling the model
        # currently returns None instead of the annotated tuple.
        # NOTE(review): presumably `x` holds token ids and `input_token_ids` the
        # segment ids expected by BERTEmbeddings - confirm when implementing.
        ...

In [82]:
class MLM(Dataset):
    """Dataset helper that tokenizes text and corrupts it for masked-language modelling.

    About 15% of the real, non-[CLS]/[SEP] tokens are selected. Each selected
    token is independently replaced by the mask token (probability 0.8),
    replaced by a random vocabulary token (0.1) or left unchanged (0.1).

    Args:
        tokenizer: Callable tokenizer exposing ``vocab_size`` and returning a
            mapping with ``'input_ids'`` and ``'attention_mask'`` tensors.
        data: Raw examples; only its length is used so far (see ``__getitem__``).
        mask_idx: Id of the mask token written into selected positions.
        sequence_length: Length every example is truncated/padded to. Note that
            the default (64) is larger than the default
            ``BERTConfig.sequence_len`` (32).
    """

    # Token ids hard-wired for the vocabulary used in this notebook.
    PAD_ID = 0
    CLS_ID = 101
    SEP_ID = 102
    # Ids below this value are never drawn as random replacements.
    FIRST_SAMPLEABLE_ID = 1996
    # Probability of selecting an eligible token for corruption.
    SELECT_PROB = 0.15

    def __init__(self, tokenizer, data: list, mask_idx: int, sequence_length: int=64):
        super().__init__()
        self.tokenizer = tokenizer
        # Honour the caller-supplied mask token id instead of a hard-coded 103.
        self.mask_token_id = mask_idx

        # Uniform distribution over the ids that may be used as random replacements.
        vocab_weights = torch.zeros((tokenizer.vocab_size,))
        vocab_weights[self.FIRST_SAMPLEABLE_ID:] = 1
        self.sample_from_vocab = vocab_weights / vocab_weights.sum()

        # Strategy probabilities: 0 -> mask token, 1 -> random token, 2 -> keep.
        self.strategy_probs = torch.tensor([0.8, 0.1, 0.1])
        # Kept under the historical attribute name for backward compatibility.
        self.mask_idx = self.strategy_probs
        # Kept for backward compatibility; encode() now samples per token instead.
        self.mask_mapping = {0: lambda: self.mask_token_id, 1: lambda: torch.multinomial(self.sample_from_vocab, num_samples=1).item()}
        self.data = data
        self.seq_len = sequence_length

    def __len__(self) -> int:
        return len(self.data)

    def encode(self, data: Union[str, List[str]]) -> MLMData:
        """Tokenize ``data`` and build one corrupted training example.

        Args:
            data: Text to encode. The tokenizer output is flattened to 1-D, so
                a list of strings yields one concatenated sequence.

        Returns:
            MLMData with ``x`` (corrupted ids), ``y`` (original ids, padding
            replaced by -1) and ``att_mask`` (tokenizer attention mask).
        """
        tokenizer_data = self.tokenizer(data, return_tensors='pt',
                                        max_length=self.seq_len,
                                        truncation=True,
                                        padding='max_length')
        tokens = tokenizer_data['input_ids'].reshape(-1)
        att_mask = tokenizer_data['attention_mask'].reshape(-1)
        # Padding gets index -1 so that it can be skipped by the loss function.
        y = tokens.masked_fill(tokens == self.PAD_ID, -1)

        # Padding (attention mask 0) and [CLS]/[SEP] are never selected.
        special_tokens_filter = tokens.eq(self.CLS_ID) | tokens.eq(self.SEP_ID)
        select_prob = att_mask.masked_fill(special_tokens_filter, 0) * self.SELECT_PROB
        selected = torch.bernoulli(select_prob).bool()

        # Draw the strategy and a candidate replacement independently for every
        # position, so each selected token gets its own 80/10/10 decision.
        num_tokens = tokens.numel()
        strategy = torch.multinomial(self.strategy_probs, num_samples=num_tokens, replacement=True)
        random_ids = torch.multinomial(self.sample_from_vocab, num_samples=num_tokens, replacement=True)

        x = tokens.clone()
        x[selected & strategy.eq(0)] = self.mask_token_id
        use_random = selected & strategy.eq(1)
        x[use_random] = random_ids[use_random]
        # strategy == 2: the selected token is intentionally left unchanged.
        return MLMData(x, y, att_mask)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        # TODO: not implemented - the body is a bare `...`, so indexing the
        # dataset currently returns None.
        ...
        
# Demo: encode one example and decode the corrupted ids back to text.
# The mask id comes from the tokenizer rather than a magic 103.
mlm = MLM(tokenizer, [], tokenizer.mask_token_id)
obj = mlm.encode("how are you? [SEP] something")
print(obj)
tokenizer.decode(obj.x)

MLMData(x=tensor([ 101,  103, 2024, 2017, 1029,  102, 2242,  102,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0]), y=tensor([ 101, 2129, 2024, 2017, 1029,  102, 2242,  102,   -1,   -1,   -1,   -1,
          -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,
          -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,
          -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,
          -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,
          -1,   -1,   -1,   -1]), att_mask=tensor([1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0,

'[CLS] [MASK] are you? [SEP] something [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]'

In [83]:
# Encode a sentence pair by passing two texts: the second segment gets
# token_type_ids of 1 and the result is padded with 0 up to max_length.
tokenizer("how are you?","something", truncation=True, max_length=10, padding='max_length', return_tensors='pt')

{'input_ids': tensor([[ 101, 2129, 2024, 2017, 1029,  102, 2242,  102,    0,    0]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 1, 1, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0]])}

In [11]:
# Draw one index from the categorical distribution [0.8, 0.1, 0.1].
torch.multinomial(torch.tensor([0.8, 0.1, 0.1]), 1)

tensor([0])

In [12]:
# Vocabulary size reported by the tokenizer.
tokenizer.vocab_size

30522

In [55]:
# ones_like only reuses the shape of its argument: every entry is 1 regardless
# of the boolean values, so this does not convert the mask to integers.
torch.ones_like(torch.tensor([True, False, False, True]) == True, dtype=torch.int8)

tensor([1, 1, 1, 1], dtype=torch.int8)

In [107]:
# Unmasked baseline: row-wise softmax over the pairwise dot products of four 2-d vectors.
t = torch.tensor([[0.1, 2.3], [0.06, 0.03], [0.23, 0.43], [1.2, 1.03]])
print(F.softmax((t @ t.T), dim=-1))

tensor([[0.9266, 0.0050, 0.0127, 0.0557],
        [0.2556, 0.2382, 0.2435, 0.2628],
        [0.3875, 0.1447, 0.1787, 0.2891],
        [0.4397, 0.0404, 0.0749, 0.4450]])


In [165]:
# Experiment: build a 2-d mask from a 1-d padding vector and apply it to the scores.
t = torch.tensor([[0.1, 2.3], [0.06, 0.03], [0.23, 0.43], [1.2, 1.03]])
v = torch.tensor([[0.1, 2.3], [0.06, 0.03], [0.23, 0.43], [1.2, 1.03]])
# Every row of `mak` is the padding vector [1, 1, 0, 1].
mak = torch.tensor([1, 1, 0, 1]).expand(4, 4)
# The product is zero wherever the row or the column index is a padded position.
mask = mak.transpose(-2, -1) @ mak
# mask
t = t @ t.T
# Filling with 1e-10 only sets the logit to (almost) zero; it does not remove the
# position from the softmax, and the fully filled row becomes uniform afterwards.
# NOTE: `t` is overwritten with the 4x4 score matrix here and stays in the kernel.
t = t.masked_fill(mask == 0, 1e-10)
t
# F.softmax(t, dim=-1) @ v

# (mak.T @ mak).unsqueeze(0).unsqueeze(0).shape
# mak.eq(0)
# mm

# F.softmax(test @ test.T, dim=-1)

tensor([[5.3000e+00, 7.5000e-02, 1.0000e-10, 2.4890e+00],
        [7.5000e-02, 4.5000e-03, 1.0000e-10, 1.0290e-01],
        [1.0000e-10, 1.0000e-10, 1.0000e-10, 1.0000e-10],
        [2.4890e+00, 1.0290e-01, 1.0000e-10, 2.5009e+00]])

In [168]:
# Softmax of identical logits is uniform, so filling a row with 1e-10 masks nothing.
F.softmax(torch.tensor([1.0000e-10, 1.0000e-10, 1.0000e-10, 1.0000e-10]), dim=-1)

tensor([0.2500, 0.2500, 0.2500, 0.2500])

In [184]:
# Multiplicative mask: each row of the scores is scaled by 0 or -10000. Rows scaled
# by 0 become all-zero logits, which softmax turns into a uniform distribution.
# NOTE(review): `t` is whatever an earlier cell left in the kernel (hidden state);
# to suppress positions the large negative value would have to be added, not multiplied.
F.softmax(torch.tensor([0, 0, -10000.0, 0]).unsqueeze(1) * (t @t.T),  dim=1)
# torch.tensor([0, 0, -10000.0, 0]).unsqueeze(0).shape

tensor([[0.2500, 0.2500, 0.2500, 0.2500],
        [0.2500, 0.2500, 0.2500, 0.2500],
        [0.2500, 0.2500, 0.2500, 0.2500],
        [0.2500, 0.2500, 0.2500, 0.2500]])

In [160]:
# Lower-triangular mask: positions filled with -inf get exactly zero probability after softmax.
F.softmax(torch.randn((4,4)).masked_fill(torch.tril(torch.ones(4,4)) == 0, float('-inf')), dim=-1)

tensor([[1.0000, 0.0000, 0.0000, 0.0000],
        [0.6014, 0.3986, 0.0000, 0.0000],
        [0.6092, 0.3370, 0.0539, 0.0000],
        [0.6851, 0.1298, 0.0614, 0.1238]])