In [5]:
import numpy as np
import torch

import torch.nn as nn

In [3]:
class Tokenizer:
    """Character-level tokenizer that maps each element of a sequence to an id.

    Args:
        vocab: Optional mapping from token to integer id. When omitted, the
            built-in three-character vocabulary ``{'a': 0, 'b': 1, 'c': 2}``
            is used. The mapping is copied, so later changes to the caller's
            dict do not affect the tokenizer.
    """

    def __init__(self, vocab=None):
        if vocab is None:
            vocab = {'a': 0, 'b': 1, 'c': 2}
        # Copy so the tokenizer never shares mutable state with the caller.
        self.vocab = dict(vocab)

    def __call__(self, x):
        """Encode ``x`` (a string or any iterable of tokens) as a list of ids.

        Returns:
            A list with one id per element of ``x``; empty input gives ``[]``.

        Raises:
            KeyError: If an element of ``x`` is not in the vocabulary.
        """
        try:
            return [self.vocab[token] for token in x]
        except KeyError as err:
            # Same exception type as before, but say which token was missing.
            missing = err.args[0]
            raise KeyError(
                f"token {missing!r} is not in the tokenizer vocabulary"
            ) from None

In [4]:
# Build a tokenizer and encode a string that contains every vocabulary character.
tok = Tokenizer()
tok('abc')

[0, 1, 2]

In [35]:
class AttentionLayer(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Queries, keys and values are bias-free linear projections of the input.
    There is no attention mask and no output projection; the per-head results
    are simply concatenated back to ``hidden_size``.

    Args:
        hidden_size: Size of the last dimension of the input and the output.
        num_heads: Number of attention heads; must divide ``hidden_size``.

    Raises:
        ValueError: If ``num_heads`` is not positive or does not divide
            ``hidden_size``.
    """

    def __init__(self, hidden_size=64, num_heads=4):
        super().__init__()
        if num_heads <= 0 or hidden_size % num_heads != 0:
            raise ValueError(
                f"hidden_size ({hidden_size}) must be divisible by "
                f"a positive num_heads ({num_heads})"
            )
        self.hidden_size = hidden_size
        self.num_heads = num_heads

        self.fc_q = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
        self.fc_k = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
        self.fc_v = nn.Linear(self.hidden_size, self.hidden_size, bias=False)

    def _split_heads(self, projected, bs, seq_len, head_size):
        """Reshape ``(bs, seq_len, hidden)`` into ``(bs, num_heads, seq_len, head_size)``."""
        # Split the feature axis first, THEN move the head axis forward.
        # Transposing before the reshape would interleave sequence positions
        # with feature channels.
        return projected.reshape(bs, seq_len, self.num_heads, head_size).transpose(1, 2)

    def forward(self, x):
        """Apply self-attention.

        Args:
            x: Tensor whose first two dimensions are read as
                ``(batch, seq_len)`` and whose last dimension is ``hidden_size``.

        Returns:
            Tensor of shape ``(batch, seq_len, hidden_size)``.
        """
        bs, seq_len = x.shape[:2]
        head_size = self.hidden_size // self.num_heads

        q = self._split_heads(self.fc_q(x), bs, seq_len, head_size)
        k = self._split_heads(self.fc_k(x), bs, seq_len, head_size)
        v = self._split_heads(self.fc_v(x), bs, seq_len, head_size)

        # Scale by sqrt(d_k), where d_k is the per-head dimension, so the
        # scale tracks the configuration instead of a hard-coded constant.
        scores = torch.matmul(q, k.transpose(2, 3)) / head_size ** 0.5
        attn = torch.softmax(scores, dim=-1)

        # (bs, heads, seq, head) -> (bs, seq, heads, head) -> (bs, seq, hidden)
        context = torch.matmul(attn, v).transpose(1, 2)
        return context.reshape(bs, seq_len, self.hidden_size)
    
    
class TransformerLayer(nn.Module):
    """One attention block with a skip connection.

    The forward pass computes ``layer_norm(attn_layer(x)) + x``: the
    normalisation is applied to the attention output before the input is
    added back.

    Args:
        hidden_size: Feature size passed to the attention layer and LayerNorm.
        num_heads: Number of heads passed to the attention layer.
    """

    def __init__(self, hidden_size=64, num_heads=4):
        super().__init__()
        self.attn_layer = AttentionLayer(hidden_size, num_heads)
        self.layer_norm = nn.LayerNorm(hidden_size)

    def forward(self, x):
        """Return the normalised attention output plus the untouched input."""
        normed = self.layer_norm(self.attn_layer(x))
        return normed + x
        
class Transformer(nn.Module):
    """A stack of ``TransformerLayer`` blocks applied one after another.

    Args:
        hidden_size: Feature size forwarded to every layer.
        num_heads: Number of attention heads forwarded to every layer.
        num_layers: How many layers to stack.
    """

    def __init__(self, hidden_size=64, num_heads=4, num_layers=10):
        super().__init__()

        blocks = []
        for _ in range(num_layers):
            blocks.append(TransformerLayer(hidden_size, num_heads))
        self.layers = nn.ModuleList(blocks)

    def forward(self, x):
        """Feed ``x`` through each layer in order and return the final output."""
        out = x
        for block in self.layers:
            out = block(out)
        return out

In [42]:
# Smoke test: build the full stack with its default arguments.
model = Transformer()

# Random input; the first two dims are read as (batch, seq_len) by the
# attention layer, and 64 matches the default hidden_size.
x = torch.randn(2, 15, 64)

# The stack should hand back a tensor of the same shape as its input.
model(x).shape

torch.Size([2, 15, 64])

In [38]:
# Smoke test: a single TransformerLayer built with its default arguments.
layer = TransformerLayer()

# Random input; the last dim (64) matches the default hidden_size.
x = torch.randn(2, 15, 64)

# One layer should leave the shape unchanged.
layer(x).shape

torch.Size([2, 15, 64])

In [32]:
# Smoke test: the bare attention layer built with its default arguments.
att_layer = AttentionLayer()

# Random input; the last dim (64) matches the default hidden_size.
x = torch.randn(2, 15, 64)

# Attention should return a tensor with the same shape as its input.
att_layer(x).shape

torch.Size([2, 15, 64])

In [24]:
# Scratch check of the attention arithmetic on hand-built tensors.
# randn gives (2, 5, 4, 16); swapping dims 1 and 2 yields (2, 4, 5, 16)
# — presumably (batch, heads, seq_len, head_size); confirm against the layer.
q = torch.randn(2, 5, 4, 16).transpose(2, 1)
k = torch.randn(2, 5, 4, 16).transpose(2, 1)
v = torch.randn(2, 5, 4, 16).transpose(2, 1)

#k.transpose(2, 1).shape
# q @ k^T over the last two dims gives (2, 4, 5, 5); softmax normalises the last dim.
# NOTE(review): 8 == sqrt(64), but the last dim of q/k here is 16 (sqrt = 4)
# — confirm which scale is intended.
attn = torch.softmax(torch.matmul(q, k.transpose(2,3) / 8), dim=-1)


#torch.matmul(attn, v).shape
#attn.shape
# (2, 4, 5, 16) -> transpose(1, 2) -> (2, 5, 4, 16) -> view merges the last two dims into 64.
# NOTE(review): this displays the whole tensor and the inputs are unseeded,
# so the printed values will differ on every run.
torch.matmul(attn, v).transpose(1,2).contiguous().view(2, 5, 64)

tensor([[[ 0.0780, -0.2307, -0.2466,  0.3914,  0.5816,  0.1741, -0.2141,
           0.1885, -0.2946, -0.5029, -0.9384,  0.2282,  0.3352,  0.6129,
           0.2073, -0.5081,  0.1921, -0.3360,  0.4406,  0.0495, -0.0905,
          -0.1582,  0.7354, -0.5887, -0.5715,  0.0904,  0.4597,  0.3446,
          -0.0194, -0.2957,  0.1726, -0.2321, -0.0954,  0.0038, -0.4618,
          -0.1550, -0.5217,  0.4981,  0.4039,  0.2940,  0.4016,  0.5784,
           0.8469,  0.2532, -0.6792, -0.7106,  0.1902,  0.1428, -0.5776,
           0.5145,  0.2267,  0.4681,  0.0272,  0.7518, -0.7133,  0.0617,
          -0.1920, -0.1159,  0.1193, -0.1058,  0.0744, -0.0966,  0.4280,
           0.5555],
         [ 0.0214, -0.2793, -0.2312,  0.3210,  0.5137,  0.0801, -0.0627,
           0.1494, -0.4391, -0.5463, -0.9114,  0.1929,  0.4407,  0.6383,
           0.2059, -0.4331,  0.4221, -0.0648,  0.1145,  0.0468, -0.1162,
          -0.2959,  0.4275, -0.2808, -0.2020, -0.1531,  0.3467,  0.1109,
           0.1768, -0.4886, -0.

In [None]:
class Transformer(nn.Modoule):
    def __init__(self, num_layers)