In [1]:
import torch
import torch.nn as nn
import copy

In [None]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention, computed for all heads at once.

    Inputs are projected to ``heads * query_size`` features, split into
    ``heads`` heads of width ``query_size``, attended independently, then the
    heads are concatenated and projected back to ``embed_dim``.

    NOTE(review): a later cell defines another ``MultiHeadAttention`` (built
    from per-head ``AttentionHead`` modules); when cells run top to bottom that
    definition shadows this one.

    Args:
        heads: number of attention heads.
        embed_dim: feature size of the inputs and of the output.
        query_size: per-head width of queries/keys/values. Defaults to
            ``embed_dim // heads``, in which case ``embed_dim`` must be
            divisible by ``heads``.

    Raises:
        ValueError: if ``query_size`` is omitted and ``embed_dim`` is not
            divisible by ``heads``.
    """

    def __init__(self, heads, embed_dim, query_size=None):
        super().__init__()
        if query_size is None:
            if embed_dim % heads != 0:
                raise ValueError(
                    f"embed_dim ({embed_dim}) must be divisible by heads ({heads}) "
                    "when query_size is not given"
                )
            # Integer division: `view` below rejects float sizes.
            query_size = embed_dim // heads
        self.heads = heads
        self.query_size = query_size
        inner_dim = heads * query_size
        # Independent query/key/value projections, each freshly initialised
        # (deep copies of a single layer would all start with identical weights).
        self.qkv = nn.ModuleList([nn.Linear(embed_dim, inner_dim) for _ in range(3)])
        # Merges the concatenated heads back to embed_dim.
        self.out_proj = nn.Linear(inner_dim, embed_dim)

    def forward(self, query, key=None, value=None):
        """Attend from ``query`` to ``key``/``value``.

        Args:
            query: ``[BATCH_SIZE x Q_LEN x EMB_DIM]``.
            key: ``[BATCH_SIZE x KV_LEN x EMB_DIM]``; defaults to ``query``
                (self-attention).
            value: ``[BATCH_SIZE x KV_LEN x EMB_DIM]``; defaults to ``key``.

        Returns:
            ``[BATCH_SIZE x Q_LEN x EMB_DIM]`` attention output.
        """
        if key is None:
            key = query
        if value is None:
            value = key
        n_batches = query.size(0)

        # Project, then split the features into heads:
        # [B x LEN x HEADS*QS] -> [B x LEN x HEADS x QS] -> [B x HEADS x LEN x QS]
        q, k, v = [
            proj(t).view(n_batches, -1, self.heads, self.query_size).transpose(1, 2)
            for proj, t in zip(self.qkv, (query, key, value))
        ]

        # Scale by sqrt(query_size) so the softmax logits do not grow with head width.
        scores = q @ k.transpose(-2, -1) / self.query_size ** 0.5  # [B x HEADS x Q_LEN x KV_LEN]
        weights = torch.softmax(scores, dim=-1)
        z = weights @ v  # [B x HEADS x Q_LEN x QS]

        # Merge heads back: [B x Q_LEN x HEADS*QS]. `contiguous` is required
        # because the transpose makes the tensor non-contiguous for `view`.
        z = z.transpose(1, 2).contiguous().view(n_batches, -1, self.heads * self.query_size)
        return self.out_proj(z)

In [2]:
class AttentionHead(nn.Module):
    """A single scaled dot-product self-attention head.

    Args:
        embed_dim: feature size of the input.
        hidden_dim: width of this head's query/key/value projections.
    """

    def __init__(self, embed_dim, hidden_dim):
        super().__init__()
        self.hidden_dim = hidden_dim
        # One fused projection producing q, k and v side by side along the last dim.
        self.qkv = nn.Linear(embed_dim, hidden_dim * 3)

    def forward(self, x: torch.Tensor):
        """Self-attend over ``x``.

        Args:
            x: ``[..., SEQ_LEN, embed_dim]`` (e.g. ``[BATCH_SIZE x SEQ_LEN x embed_dim]``).

        Returns:
            ``(z, k, v)``: the attention output, keys and values, each
            ``[..., SEQ_LEN, hidden_dim]``.
        """
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        # Scale by sqrt(d_k); a fixed 8 is only correct when hidden_dim == 64.
        scores = q @ k.transpose(-2, -1) / self.hidden_dim ** 0.5
        weights = torch.softmax(scores, dim=-1)
        z = weights @ v
        return z, k, v

In [102]:
class MultiHeadAttention(nn.Module):
    """Multi-head self-attention built from independent ``AttentionHead`` modules.

    Args:
        heads: number of attention heads.
        embed_dim: input/output feature size.
        hidden_dim: per-head width; the head outputs are concatenated to
            ``heads * hidden_dim`` features and projected back to ``embed_dim``.
    """

    def __init__(self, heads, embed_dim, hidden_dim):
        super().__init__()
        # nn.ModuleList (not a plain list) so the heads are registered submodules:
        # their parameters show up in .parameters(), move with .to(), are saved
        # in state_dict, and get trained.
        self.heads = nn.ModuleList([AttentionHead(embed_dim, hidden_dim) for _ in range(heads)])
        self.linear_combine = nn.Linear(heads * hidden_dim, embed_dim)

    def forward(self, x):
        """Self-attend over ``x`` and return a tensor with last dim ``embed_dim``.

        ``x`` is assumed to be ``[BATCH_SIZE x SEQ_LEN x embed_dim]``.
        """
        # Each head returns (z, k, v); only the attention output z is combined.
        head_outputs = []
        for head in self.heads:
            z, _, _ = head(x)
            head_outputs.append(z)
        return self.linear_combine(torch.cat(head_outputs, dim=-1))

In [121]:
class TransformerEncoderLayer(nn.Module):
    """Post-norm Transformer encoder layer.

    Self-attention and a position-wise feed-forward network, each wrapped in a
    residual connection followed by LayerNorm.

    Args:
        embed_dim: model width (default 512).
        heads: number of attention heads (default 8).
        head_dim: per-head width passed to ``MultiHeadAttention`` (default 64).
        ff_dim: hidden width of the feed-forward network (default 2048, as in
            the original Transformer paper).
    """

    def __init__(self, embed_dim=512, heads=8, head_dim=64, ff_dim=2048):
        super().__init__()
        self.self_attn = MultiHeadAttention(heads, embed_dim, head_dim)
        # Feed-forward sub-layer: Linear -> ReLU -> Linear. Without the
        # nonlinearity a single Linear only adds an affine map.
        self.fc1 = nn.Linear(embed_dim, ff_dim)
        self.fc2 = nn.Linear(ff_dim, embed_dim)
        self.lnorm1 = nn.LayerNorm(embed_dim)
        self.lnorm2 = nn.LayerNorm(embed_dim)

    def forward(self, x):
        """Encode ``x`` (assumed ``[BATCH_SIZE x SEQ_LEN x embed_dim]``); output has the same shape."""
        attn_out = self.lnorm1(x + self.self_attn(x))
        ff_out = self.fc2(torch.relu(self.fc1(attn_out)))
        return self.lnorm2(attn_out + ff_out)

In [None]:
class TransformerDecoderLayer(nn.Module):
    """Post-norm Transformer decoder layer.

    Self-attention, attention over externally supplied keys/values, and a
    position-wise feed-forward network. Each is followed by a residual
    connection and LayerNorm.

    Args:
        embed_dim: model width (default 512).
        heads: number of attention heads (default 8).
        head_dim: per-head width of the self-attention (default 64). The
            encoder-decoder attention uses ``embed_dim // heads`` per head,
            so ``embed_dim`` must be divisible by ``heads``.
        ff_dim: hidden width of the feed-forward network (default 2048).
    """

    def __init__(self, embed_dim=512, heads=8, head_dim=64, ff_dim=2048):
        super().__init__()
        self.self_attn = MultiHeadAttention(heads, embed_dim, head_dim)
        # The notebook's MultiHeadAttention.forward takes a single input
        # (self-attention only). Encoder-decoder attention needs separate
        # query / key / value, which torch's built-in module accepts.
        self.enc_dec_attn = nn.MultiheadAttention(embed_dim, heads)
        # Feed-forward sub-layer: Linear -> ReLU -> Linear.
        self.fc1 = nn.Linear(embed_dim, ff_dim)
        self.fc2 = nn.Linear(ff_dim, embed_dim)
        self.lnorm1 = nn.LayerNorm(embed_dim)
        self.lnorm2 = nn.LayerNorm(embed_dim)
        self.lnorm3 = nn.LayerNorm(embed_dim)

    def forward(self, x, k, v):
        """Decode ``x`` while attending to ``k``/``v``.

        Args:
            x: decoder input, assumed ``[BATCH_SIZE x TGT_LEN x embed_dim]``.
            k: keys for encoder-decoder attention, assumed
                ``[BATCH_SIZE x SRC_LEN x embed_dim]``. In ``Transformer`` these
                come from the encoder.
            v: values, same shape as ``k``.

        Returns:
            A tensor shaped like ``x``.
        """
        # TODO(review): this self-attention is unmasked, so each position can see
        # later ones. Autoregressive decoding needs a causal mask, which the
        # notebook's MultiHeadAttention does not accept yet.
        norm_out = self.lnorm1(x + self.self_attn(x))

        # nn.MultiheadAttention (without batch_first) expects [LEN x BATCH x EMB],
        # hence the transposes in and out.
        cross, _ = self.enc_dec_attn(
            norm_out.transpose(0, 1),
            k.transpose(0, 1),
            v.transpose(0, 1),
            need_weights=False,
        )
        norm_out = self.lnorm2(norm_out + cross.transpose(0, 1))

        ff_out = self.fc2(torch.relu(self.fc1(norm_out)))
        return self.lnorm3(norm_out + ff_out)

In [122]:
class Transformer(nn.Module):
    """Toy encoder-decoder Transformer with one encoder and one decoder layer.

    Args:
        vocab_size: number of token ids the embedding accepts; ids must lie in
            ``[0, vocab_size)`` (default 10).

    NOTE(review): no positional encoding is added anywhere visible, so token
    order is presumably not modelled. There is also no final projection to
    vocabulary logits; the output is the decoder's 512-d hidden state.
    """

    def __init__(self, vocab_size=10):
        super().__init__()
        # Produces [BATCH_SIZE x SEQ_LEN x 512]. The width must match the
        # encoder/decoder layers, which are constructed with their 512 default.
        self.emb = nn.Embedding(vocab_size, 512)
        self.enc = TransformerEncoderLayer()
        self.dec = TransformerDecoderLayer()

    def forward(self, x, tgt=None):
        """Run the encoder on ``x`` and the decoder against its output.

        Args:
            x: source token ids, ``[BATCH_SIZE x SRC_LEN]`` (LongTensor).
            tgt: optional target token ids ``[BATCH_SIZE x TGT_LEN]``. If
                omitted, the encoder output itself is fed to the decoder
                (the notebook's original data flow).

        Returns:
            ``[BATCH_SIZE x TGT_LEN x 512]``, or ``[BATCH_SIZE x SRC_LEN x 512]``
            when ``tgt`` is None.
        """
        # The encoder returns a single tensor; unpacking it as (out, k, v) would
        # split along the batch dimension instead.
        memory = self.enc(self.emb(x))
        dec_in = memory if tgt is None else self.emb(tgt)
        # The encoder output serves as both keys and values for encoder-decoder attention.
        return self.dec(dec_in, memory, memory)

In [123]:
torch.manual_seed(0)  # reproducible parameter initialisation across Restart & Run All
# [BATCH_SIZE x SEQ_LEN] token ids; torch.tensor replaces the legacy torch.LongTensor constructor.
inp = torch.tensor([[1, 2, 3, 4], [3, 2, 5, 1]], dtype=torch.long)
trnfrmr = Transformer()
out = trnfrmr(inp)
assert out.shape == (*inp.shape, 512), out.shape
out.shape

torch.Size([2, 4, 512])