In [1]:
import torch
from torch import nn

Attention

In [2]:
class Attention(nn.Module):
  """Single-head scaled dot-product attention with learned Q/K/V projections.

  Each of query, key and value is passed through its own ``emb_dim -> emb_dim``
  linear layer. Attention weights are ``softmax(Q K^T / sqrt(emb_dim))`` and
  the output is the (dropout-regularised) weights multiplied with V.

  Args:
    emb_dim: Size of the last dimension of the inputs and of the projections.
    drop_rate: Dropout probability applied to the attention weights before
      they are multiplied with the values.
  """

  def __init__(self, emb_dim, drop_rate = 0.5):
    super().__init__()

    self.emb_dim = emb_dim
    self.drop_rate = drop_rate
    self.dropout = nn.Dropout(self.drop_rate)

    self.W_query = nn.Linear(self.emb_dim, self.emb_dim)    # W_q
    self.W_key = nn.Linear(self.emb_dim, self.emb_dim)    # W_k
    self.W_value = nn.Linear(self.emb_dim, self.emb_dim)    # W_v

    # Scaling factor d_k = emb_dim (the energies are divided by sqrt(d_k)).
    # Registered as a buffer so it follows .to()/.double()/.cuda() together
    # with the parameters; non-persistent so the state_dict keys are unchanged.
    self.register_buffer('d_k', torch.tensor([float(self.emb_dim)]), persistent = False)

  def forward(self, query, key, value, mask = None, negative_inf = -1e10):
    """Compute attention weights and the attended values.

    Args:
      query, key, value: Tensors of shape
        [batch size, sequence length, embedding dimension].
      mask: Optional tensor broadcastable to
        [batch size, query length, key length]. Positions where ``mask == 0``
        are excluded from attention.
      negative_inf: Magnitude of the large negative value written into masked
        positions before the softmax. Its sign is ignored, so both ``-1e10``
        and ``1e10`` mask correctly.

    Returns:
      A tuple ``(attention_energy, attention_score)``:
        attention_energy: softmax-normalised attention weights (before
          dropout), shape [batch size, query length, key length].
        attention_score: attended values, shape
          [batch size, query length, embedding dimension].
    """
    batch_size = query.shape[0]
    query = self.W_query(query).view(batch_size, -1, self.emb_dim)
    key = self.W_key(key).view(batch_size, -1, self.emb_dim)
    value = self.W_value(value).view(batch_size, -1, self.emb_dim)
    # Shape : [batch size, sequence length, embedding dimension]

    # Attention energy
    # [batch size, seq len, emb dim] x [batch size, emb dim, seq_len]
    # ---> [batch size, seq len, seq len]
    attention_energy = torch.matmul(query, torch.transpose(key, 1, 2))
    # Out-of-place division: d_k shares the module's dtype/device, and no
    # in-place op is performed on a tensor autograd may need.
    attention_energy = attention_energy / torch.sqrt(self.d_k)
    if mask is not None:
      # Masked positions must receive a large NEGATIVE value so that the
      # softmax drives their weight to zero.
      attention_energy = attention_energy.masked_fill(mask == 0, -abs(negative_inf))
    attention_energy = torch.softmax(attention_energy, dim = -1)

    # Attention Score
    # [batch size, seq len, seq len] x [batch size, seq len, emb dim]
    # ---> [batch size, seq len, emb dim]
    attention_score = torch.matmul(self.dropout(attention_energy), value)

    return attention_energy, attention_score

Test

In [3]:
# Toy batch: the same 2-token sequence of 3-dimensional embeddings, repeated six times.
a = torch.tensor([[[1, 2, 3], [4, 5, 6]]] * 6)
print(f'batch size {a.shape[0]},  sequence length : {a.shape[1]},   embedding dim :{a.shape[2]}')

batch size 6,  sequence length : 2,   embedding dim :3


In [4]:
# Swap the sequence and embedding axes; the batch axis stays in place.
tr_a = a.transpose(1, 2)
print(f'transpose of a has shape : {tr_a.shape}')

transpose of a has shape : torch.Size([6, 3, 2])


In [5]:
# Batched matrix product: each [seq, emb] matrix times its own [emb, seq] transpose.
mat_mul = a @ tr_a
print(f'Matrix multiplication of a and its transpose has shape {mat_mul.shape}')

Matrix multiplication of a and its transpose has shape torch.Size([6, 2, 2])


In [6]:
# Build the attention layer and cast its floating-point parameters to float64.
at = Attention(emb_dim=3).to(torch.float64)

In [7]:
a = a.double()
# Run the forward pass once and unpack both results, instead of calling the
# module twice (each call would draw a separate dropout mask).
attention_energy, attention_score = at(a, a, a)
print(f'attention enegery matrix shape : {attention_energy.shape}')
# [batch size, sequence length, sequence length]
print(f'attention score shape : {attention_score.shape}')
# [batch size, sequence length, embedding dim]

attention enegery matrix shape : torch.Size([6, 2, 2])
attention score shape : torch.Size([6, 2, 3])
