In [7]:
import torch
import torch.nn.functional as F
import torch.nn as nn

class ScaledDotProductAttention(torch.nn.Module):
    def __init__(self, dropout_rate=0.0):
        super(ScaledDotProductAttention, self).__init__()
        self.dropout = nn.Dropout(dropout_rate)
        
    def forward(self, q, k, v, attn_mask=None):
        attention = torch.bmm(q, k.transpose(1, 2))
        if attn_mask:
            attention = attention.masked_fill_(attn_mask, -np.inf)
        attention = F.softmax(attention, dim=2)
        attention = self.dropout(attention)
        context = torch.bmm(attention, v)
        return context, attention

q = torch.randn(2, 10, 512)
k = torch.randn(2, 10, 512)
v = torch.randn(2, 10, 512)
dot_product = ScaledDotProductAttention(0.5)
dot_product(q, k, v)

(tensor([[[ 2.3727e-01,  1.7001e-01,  1.6563e+00,  ..., -1.0552e+00,
           -1.3600e+00,  2.7937e-01],
          [ 2.6515e+00, -3.7611e+00, -3.5257e+00,  ..., -5.3747e-01,
           -8.4978e-01, -8.3924e-01],
          [ 2.6520e+00, -3.7618e+00, -3.5263e+00,  ..., -5.3754e-01,
           -8.4986e-01, -8.3941e-01],
          ...,
          [ 2.0201e+00, -6.8209e-01,  1.1847e+00,  ..., -4.0825e-01,
            1.9488e+00, -1.2257e+00],
          [ 2.0196e+00, -6.8189e-01,  1.1844e+00,  ..., -4.0816e-01,
            1.9483e+00, -1.2254e+00],
          [ 2.0202e+00, -6.8213e-01,  1.1848e+00,  ..., -4.0828e-01,
            1.9489e+00, -1.2258e+00]],
 
         [[-8.1388e-01,  1.4988e+00, -1.9217e+00,  ..., -1.7247e+00,
            1.2137e+00, -2.8184e+00],
          [ 1.8304e+00, -3.6473e-01,  2.7888e-01,  ..., -6.2142e-01,
           -1.2770e+00,  1.4380e-01],
          [ 6.6593e-01, -6.2210e-01,  7.4823e-01,  ..., -8.1016e-01,
           -1.5139e+00, -2.9042e-01],
          ...,
    

In [41]:


class MultiHeadAttention(torch.nn.Module):
    
    def __init__(self, d_model, n_heads, dropout_rate=0.0):
        super(MultiHeadAttention, self).__init__()
        self.dk = d_model // n_heads
        self.dropout = nn.Dropout(dropout_rate)
        self.linear_q = nn.Linear(d_model, self.dk)
        self.linear_k = nn.Linear(d_model, self.dk)
        self.linear_v = nn.Linear(d_model, self.dk)
        self.scaled_dot_product_attention = ScaledDotProductAttention(dropout_rate)     
        self.linear_final = nn.Linear(self.dk * n_heads, self.dk * n_heads)
    
    def forward(self, Q, K, V):
        Qs = Q.split(self.dk, dim=2)
        Ks = K.split(self.dk, dim=2)
        Vs = V.split(self.dk, dim=2)
        ipdb.set_trace()
        heads = []
        for q, k, v in zip(Qs, Ks, Vs):
            heads.append(self.scaled_dot_product_attention(q, k, v)[1])
        z = torch.cat(heads, dim=-1)
        
        output = self.linear_final(z)
        return heads
        
multihead = MultiHeadAttention(512, 8)

q = torch.randn(2, 10, 512)
k = torch.randn(2, 10, 512)
v = torch.randn(2, 10, 512)
x = multihead(q, k, v)

ModuleNotFoundError: No module named 'ipbd'

In [33]:
torch.cat(x, dim=-1)

tensor([[[1.1064e-10, 9.4015e-01, 9.0640e-12,  ..., 2.0927e-11,
          1.5080e-06, 4.7086e-07],
         [2.8416e-06, 5.6764e-10, 1.2079e-04,  ..., 8.9626e-03,
          1.5545e-04, 4.6506e-04],
         [6.9571e-01, 8.1735e-08, 1.8181e-03,  ..., 4.4118e-01,
          6.5938e-13, 1.4021e-16],
         ...,
         [1.5407e-03, 3.0334e-03, 5.5015e-01,  ..., 1.0975e-07,
          9.8197e-03, 4.2089e-04],
         [2.0353e-04, 1.6767e-04, 8.1614e-04,  ..., 5.8121e-09,
          2.5852e-04, 2.3458e-09],
         [4.7344e-06, 9.4880e-01, 1.7122e-09,  ..., 7.0488e-06,
          4.8286e-01, 1.1340e-05]],

        [[1.3956e-05, 6.2270e-03, 1.9188e-02,  ..., 4.9245e-01,
          7.1388e-07, 2.7141e-01],
         [2.5753e-13, 1.7032e-12, 3.5965e-16,  ..., 1.0277e-06,
          4.8744e-06, 4.2767e-10],
         [1.6295e-07, 5.0871e-14, 9.9985e-01,  ..., 1.7914e-08,
          3.0615e-11, 1.8504e-07],
         ...,
         [9.6641e-06, 2.7135e-05, 1.4340e-03,  ..., 2.8831e-13,
          5.614

In [26]:
x[0].shape

torch.Size([2, 10, 10])