In [None]:
import torch
from torch.utils.data import Dataset
import pandas as pd
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-de")
special_token_dict = {"bos_token": "<s>"}
tokenizer.add_special_tokens(special_token_dict)

tokenizer.encode("hi my name is neet")
import torch
from torch.utils.data import Dataset
import pandas as pd

class WMTDataset(Dataset):
    
    def __init__(self, data_path, src_tokenizer, tgt_tokenizer, seq_len):
        super().__init__()
        self.data = pd.read_csv(data_path)
        self.src_vocab_size = src_tokenizer.vocab_size
        self.tgt_vocab_size = tgt_tokenizer.vocab_size
        self.src_tokenizer = src_tokenizer
        self.tgt_tokenizer = tgt_tokenizer
        self.seq_len = seq_len
        
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, index):
        sos_token = self.src_tokenizer.encode(['<s>'])[0]
        eos_token = self.src_tokenizer.encode(['</s>'])[0]
        pad_token = self.src_tokenizer.encode(['<pad>'])[0]
        src_encoding = self.src_tokenizer.encode(self.data.iloc[index]['en'])[:-1] # remove default eos token
        tgt_encoding = self.tgt_tokenizer.encode(self.data.iloc[index]['de'])[:-1] # remove default eos token
        print("len of src sen: ", len(src_encoding))
        print("len of tgt sen: ", len(tgt_encoding))
        assert len(src_encoding) < self.seq_len + 2, "sentence too big"
        assert len(tgt_encoding) < self.seq_len + 2, "sentence too big"
        
        src_padding_len = self.seq_len - (len(src_encoding) + 2)  
        tgt_padding_len = self.seq_len - (len(tgt_encoding) + 2) 
        
        src_encoding = torch.tensor([sos_token] + src_encoding + [eos_token] + [pad_token]*src_padding_len, dtype=torch.uint64)
        tgt_encoding = torch.tensor([sos_token] + tgt_encoding + [eos_token] + [pad_token]*tgt_padding_len, dtype=torch.uint64)
        
        causal_mask = torch.triu(torch.ones(self.seq_len, self.seq_len, dtype=bool), diagonal=1).to(bool)

        encoder_self_attention_mask = (src_encoding == pad_token).unsqueeze(0) | (src_encoding == pad_token).unsqueeze(1)
        decoder_self_attention_mask = (tgt_encoding == pad_token).unsqueeze(0) | (tgt_encoding == pad_token).unsqueeze(1) | causal_mask
        decoder_cross_attention_mask = (tgt_encoding == pad_token).unsqueeze(0) | (src_encoding == pad_token).unsqueeze(1)
        
        return src_encoding, tgt_encoding, encoder_self_attention_mask, decoder_self_attention_mask, decoder_cross_attention_mask
    
ds = WMTDataset("wmt14_translate_de-en_test.csv", tokenizer, tokenizer, 100)

In [None]:
ds[0][3]

In [None]:
x = torch.zeros(4,4)
y = torch.tensor([[1,0,0,1]], dtype=bool)
x.masked_fill_(y, -1e9)

In [None]:
x = torch.rand(2,5,2)
y = torch.randint(0,2,(1,2)).to(bool)
print(x)
print(y)

In [None]:
x.masked_fill_(y, 0)

In [None]:
import torch
import torch.nn as nn
import math

torch.manual_seed(42)
class WordEmbeddings(nn.Module):
    
    def __init__(self, vocab_size, d_model):
        super().__init__()
        
        self.word_embd = nn.Embedding(vocab_size, d_model)
        
    def forward(self, tokens):
        
        x = self.word_embd(tokens)
        return x
    
class PositionalEmbedding(nn.Module):
    
    def __init__(self, seq_len, d_model):
        super().__init__()
 
        pe = torch.zeros(seq_len, d_model, requires_grad=False)
        div_term = torch.pow(10000, torch.arange(0, d_model, 2)/d_model)
        pos = torch.arange(seq_len).unsqueeze(1)
        pe[:, 0::2] = torch.sin(pos/div_term)
        pe[:, 1::2] = torch.cos(pos/div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)
        
    def forward(self, x):
        
        return x + self.pe[:, :x.shape[1], :]

class MultiheadAttention(nn.Module):
    
    def __init__(self, d_model, num_heads):
        super().__init__()
        
        self.h = num_heads
        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        
        self.w_o = nn.Linear(d_model, d_model)
        
        self.d_k = d_model // num_heads
        
        
        assert d_model % num_heads == 0, "d_model not divisible by num_heads"
        
    def forward(self, q, k, v, mask=None):
        
        B, seq_len, embd_dim = q.shape
        # q -> (B, seq_len, embd_dim)
        # k -> (B, seq_len, embd_dim)
        # v -> (B, seq_len, embd_dim)
        queries = self.w_q(q)
        keys = self.w_k(k)
        values = self.w_v(v)
        # q -> (B, seq_len, embd_dim)
        # k -> (B, seq_len, embd_dim)
        # v -> (B, seq_len, embd_dim)
        
        queries = queries.view(B, seq_len, self.h, self.d_k).transpose(1,2)
        keys = keys.view(B, seq_len, self.h, self.d_k).transpose(1,2)
        values = values.view(B, seq_len, self.h, self.d_k).transpose(1,2)
        
        # q -> (B, h, seq_len, dk)
        # k -> (B, h, seq_len, dk)
        # v -> (B, h, seq_len, dk)
        
        attention_weights = (queries @ keys.transpose(-2, -1)) / math.sqrt(self.d_k)
        
        if mask is not None:
            # boolean mask
            attention_weights = attention_weights.masked_fill(mask, float('-inf'))
        
        attention_weights = attention_weights.softmax(dim=-1)
        
        # attention_weights -> (B, h, seq_len, seq_len)
        
        attention_output = attention_weights @ values
        
        # attention_weights -> (B, h, seq_len, dk)
        
        attention_output = attention_output.transpose(1, 2).reshape(B, seq_len, -1)
        
        attention_output = self.w_o(attention_output)
        return attention_output, attention_weights

class FeedForwardNetwork(nn.Module):
    
    def __init__(self, d_model, d_ff, p_d=0.0):
        super().__init__()
        
        # p_d = 0.1 from original paper
        self.linear_1 = nn.Linear(d_model, d_ff)
        self.dropout = nn.Dropout(p_d)
        self.linear_2 = nn.Linear(d_ff, d_model)
        
    def forward(self, x):
        
        x = torch.relu(self.dropout(self.linear_1(x)))
        x = self.linear_2(x)
        return x
    
class EncoderBlock(nn.Module):
    
    def __init__(self, d_model, num_heads, d_ff, p_d=0.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(d_model)
        self.norm1 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(p_d)
        self.multi_headed_self_attention = MultiheadAttention(d_model=d_model, num_heads=num_heads)
        self.ffn = FeedForwardNetwork(d_model=d_model, d_ff=d_ff, p_d=p_d)
        
    def forward(self, x, self_attn_mask):
        
        x_residual_1 = x
        
        attention_out = self.dropout(self.multi_headed_self_attention(x, x, x, self_attn_mask)[0])
        
        sublayer_1_out = self.layer_norm(x_residual_1 + attention_out)
        
        x_residual_2 = sublayer_1_out
        
        ffn_out = self.dropout(self.ffn(sublayer_1_out))
        
        sublayer_2_out = self.layer_norm(x_residual_2 + ffn_out)
        
        return sublayer_2_out
    
class DecoderBlock(nn.Module):
    
    def __init__(self, d_model, num_heads, d_ff, p_d=0.0, layer_norm_eps=1e-5, bias=True):
        super().__init__()
        self.d_model = d_model
        self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps, bias=bias)
        self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps, bias=bias)
        self.norm3 = nn.LayerNorm(d_model, eps=layer_norm_eps, bias=bias)
        self.dropout1 = nn.Dropout(p_d)
        self.dropout2 = nn.Dropout(p_d)
        self.dropout3 = nn.Dropout(p_d)
        self.masked_multi_headed_self_attention = MultiheadAttention(d_model=d_model, num_heads=num_heads)
        self.multi_headed_cross_attention = MultiheadAttention(d_model=d_model, num_heads=num_heads)
        self.ffn = FeedForwardNetwork(d_model=d_model, d_ff=d_ff, p_d=p_d)
        
    def forward(self, x, memory, masked_self_attn_mask, cross_attn_mask):
        print(self.masked_multi_headed_self_attention(x, x, x, masked_self_attn_mask)[0])
        x = self.norm1(x + self.dropout1(self.masked_multi_headed_self_attention(x, x, x, masked_self_attn_mask)[0]))
        
        x = self.norm2(x + self.dropout2(self.multi_headed_cross_attention(x, memory, memory, cross_attn_mask)[0]))
        
        x = self.norm3(x + self.dropout3(self.ffn(x)))
        
        return x
        


In [56]:
torch.manual_seed(42)
x = torch.rand(1, 5, 8)  # (batch, seq_len, d_model)

m1 = DecoderBlock(8, 2, 16, p_d=0.0)
m2 = nn.TransformerDecoderLayer(8, 2, 16, dropout=0.0, activation='relu', batch_first=True)

with torch.no_grad():
    # Self-attention weights
    m2.self_attn.in_proj_weight.copy_(
        torch.cat([
            m1.masked_multi_headed_self_attention.w_q.weight,
            m1.masked_multi_headed_self_attention.w_k.weight,
            m1.masked_multi_headed_self_attention.w_v.weight
        ], dim=0)
    )
    m2.self_attn.in_proj_bias.copy_(
        torch.cat([
            m1.masked_multi_headed_self_attention.w_q.bias,
            m1.masked_multi_headed_self_attention.w_k.bias,
            m1.masked_multi_headed_self_attention.w_v.bias
        ], dim=0)
    )
    m2.self_attn.out_proj.weight.copy_(m1.masked_multi_headed_self_attention.w_o.weight)
    m2.self_attn.out_proj.bias.copy_(m1.masked_multi_headed_self_attention.w_o.bias)

    # Cross attention weights
    m2.multihead_attn.in_proj_weight.copy_(
        torch.cat([
            m1.multi_headed_cross_attention.w_q.weight,
            m1.multi_headed_cross_attention.w_k.weight,
            m1.multi_headed_cross_attention.w_v.weight
        ], dim=0)
    )
    m2.multihead_attn.in_proj_bias.copy_(
        torch.cat([
            m1.multi_headed_cross_attention.w_q.bias,
            m1.multi_headed_cross_attention.w_k.bias,
            m1.multi_headed_cross_attention.w_v.bias
        ], dim=0)
    )
    m2.multihead_attn.out_proj.weight.copy_(m1.multi_headed_cross_attention.w_o.weight)
    m2.multihead_attn.out_proj.bias.copy_(m1.multi_headed_cross_attention.w_o.bias)

    # Feed-forward layers
    m2.linear1.weight.copy_(m1.ffn.linear_1.weight)
    m2.linear1.bias.copy_(m1.ffn.linear_1.bias)
    m2.linear2.weight.copy_(m1.ffn.linear_2.weight)
    m2.linear2.bias.copy_(m1.ffn.linear_2.bias)

    # LayerNorms
    m2.norm1.weight.copy_(m1.norm1.weight)
    m2.norm1.bias.copy_(m1.norm1.bias)
    m2.norm2.weight.copy_(m1.norm2.weight)
    m2.norm2.bias.copy_(m1.norm2.bias)
    m2.norm3.weight.copy_(m1.norm3.weight)
    m2.norm3.bias.copy_(m1.norm3.bias)


# Disable dropout randomness
m1.eval()
m2.eval()

out1 = m1(x, x, None, None)

out2 = m2(x, x, tgt_mask=None, memory_mask=None)

print(torch.allclose(out1, out2, atol=1e-5))  # Should be True!


tensor([[[-0.3731,  0.5099,  0.4786,  0.0406,  0.1351, -0.0435, -0.1737,
          -0.7356],
         [-0.3703,  0.5110,  0.4788,  0.0455,  0.1395, -0.0410, -0.1781,
          -0.7381],
         [-0.3719,  0.5099,  0.4792,  0.0426,  0.1367, -0.0425, -0.1757,
          -0.7360],
         [-0.3704,  0.5129,  0.4783,  0.0455,  0.1409, -0.0408, -0.1780,
          -0.7404],
         [-0.3721,  0.5115,  0.4782,  0.0419,  0.1371, -0.0427, -0.1747,
          -0.7377]]], grad_fn=<ViewBackward0>)
tensor([[[-0.3731,  0.5099,  0.4786,  0.0406,  0.1351, -0.0435, -0.1737,
          -0.7356],
         [-0.3703,  0.5110,  0.4788,  0.0455,  0.1395, -0.0410, -0.1781,
          -0.7381],
         [-0.3719,  0.5099,  0.4792,  0.0426,  0.1367, -0.0425, -0.1757,
          -0.7360],
         [-0.3704,  0.5129,  0.4783,  0.0455,  0.1409, -0.0408, -0.1780,
          -0.7404],
         [-0.3721,  0.5115,  0.4782,  0.0419,  0.1371, -0.0427, -0.1747,
          -0.7377]]], grad_fn=<TransposeBackward0>)
True


In [45]:
m1 = MultiheadAttention(8, 2)

m2 = nn.MultiheadAttention(8, 2, 0.0, batch_first=True)

In [46]:
m2.state_dict().keys()

odict_keys(['in_proj_weight', 'in_proj_bias', 'out_proj.weight', 'out_proj.bias'])

In [47]:

m2.in_proj_weight.data.copy_(torch.cat([m1.w_q.weight.data, m1.w_k.weight.data, m1.w_v.weight.data]))
m2.in_proj_bias.data.copy_(torch.cat([m1.w_q.bias.data, m1.w_k.bias.data, m1.w_v.bias.data]))
m2.out_proj.weight.data.copy_(m1.w_o.weight.data)
m2.out_proj.bias.data.copy_(m1.w_o.bias.data)

tensor([-0.0909,  0.1994,  0.1285,  0.2801, -0.1324,  0.1190,  0.1258, -0.2944])

In [53]:
# torch.manual_seed(42)
with torch.no_grad():
    m1.eval()
    m2.eval()

    x = torch.rand(1, 5, 8)

    y = m1(x,x,x,None)[0]
    z = m2(x,x,x)[0]

In [55]:
y-z

tensor([[[ 2.9802e-08,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          -2.9802e-08,  2.9802e-08,  2.9802e-08],
         [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00,  0.0000e+00,  0.0000e+00],
         [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00,  0.0000e+00,  0.0000e+00],
         [ 2.9802e-08,  0.0000e+00,  0.0000e+00,  0.0000e+00, -2.9802e-08,
          -2.9802e-08, -2.9802e-08,  0.0000e+00],
         [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00,  0.0000e+00,  0.0000e+00]]])