In [1]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
%matplotlib inline
%config InlineBackend.figure_format='retina'
# Report the installed PyTorch build, then select the compute device:
# the first CUDA GPU when one is available, otherwise the CPU.
print("PyTorch version:[%s]." % torch.__version__)
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
print("device:[%s]." % device)

PyTorch version:[1.10.0+cu111].
device:[cuda:0].


In [2]:
# Scaled Dot-Product Attention

class ScaledDotProductAttention(nn.Module):
    """Scaled dot-product attention: softmax(Q K^T / sqrt(d_K)) V."""

    def forward(self, Q, K, V, mask=None):
        """
        param Q : [..., n_Q, d_K]
        param K : [..., n_K, d_K]
        param V : [..., n_K, d_V]
        param mask : optional, broadcastable to [..., n_Q, n_K];
                     positions where mask == 0 receive no attention
        return : (out [..., n_Q, d_V], attention [..., n_Q, n_K])
        """
        d_K = K.size()[-1]
        scores = Q.matmul(K.transpose(-2,-1)) / np.sqrt(d_K)
        if mask is not None :
            # Fill with the most negative value representable in the scores' dtype
            # instead of a hard-coded -1e9, which is out of range for float16.
            scores = scores.masked_fill(mask==0, torch.finfo(scores.dtype).min)
        attention = F.softmax(scores, dim=-1)
        out = attention.matmul(V)
        return out, attention
    
# Single-head demo: run the attention module on random inputs.
SPDA = ScaledDotProductAttention()
n_batch, d_K, d_V = 3, 128, 256 # d_K and d_V may differ
n_Q, n_K, n_V = 30, 50, 50 # n_Q and n_K may differ (n_K must match n_V)
Q = torch.rand(n_batch, n_Q, d_K)
K = torch.rand(n_batch, n_K, d_K)
V = torch.rand(n_batch, n_V, d_V)
# Call the module itself rather than .forward() so nn.Module.__call__
# runs any registered hooks.
out, attention = SPDA(Q,K,V,mask=None)
def sh(x):
    """Return the shape of x formatted as a list string, e.g. "[3, 30, 128]".

    Built from list(x.shape) rather than slicing the "torch.Size(...)" repr,
    so it works for any object whose .shape is a sequence of ints.
    """
    return str(list(x.shape))
print(f"SPDA | Q : {sh(Q)} | K : {sh(K)} | V : {sh(V)} | out : {sh(out)} | attention : {sh(attention)}")

# Multi-head demo: the same module handles an extra leading head dimension,
# because it only operates on the last two axes.
n_batch, n_head, d_K, d_V = 3, 5, 128, 256
n_Q, n_K, n_V = 30, 50, 50
Q = torch.rand(n_batch, n_head, n_Q, d_K)
K = torch.rand(n_batch, n_head, n_K, d_K)
V = torch.rand(n_batch, n_head, n_V, d_V)
# Call the module itself rather than .forward() so nn.Module.__call__
# runs any registered hooks.
out, attention = SPDA(Q,K,V,mask=None)
print(f"(M)SPDA | Q : {sh(Q)} | K : {sh(K)} | V : {sh(V)} | out : {sh(out)} | attention : {sh(attention)}")


SPDA | Q : [3, 30, 128] | K : [3, 50, 128] | V : [3, 50, 256] | out : [3, 30, 256] | attention : [3, 30, 50]
(M)SPDA | Q : [3, 5, 30, 128] | K : [3, 5, 50, 128] | V : [3, 5, 50, 256] | out : [3, 5, 30, 256] | attention : [3, 5, 30, 50]


In [None]:
# Multi-Head Attention

class MultiHeadedAttention(nn.Module):
    """Multi-head attention: project Q/K/V, attend per head, merge the heads, project out."""

    def __init__(self, d_feat=128, n_head=5, actv=F.relu, USE_BIAS=True, dropout_p=0.1, device=None):
        """
        d_feat : feature dimension (must be divisible by n_head)
        n_head : number of heads
        actv : activation function; stored on the module but not applied in forward
        USE_BIAS : whether the linear layers use a bias term
        dropout_p : dropout rate applied to the attention weights
        device : accepted for interface compatibility; currently unused

        Raises ValueError if d_feat is not divisible by n_head.
        """
        super().__init__()
        if (d_feat%n_head) != 0 :
            raise ValueError(f"d_feat({d_feat:d}) should be divisible by n_head({n_head})")
        self.d_feat = d_feat
        self.n_head = n_head
        self.d_head = self.d_feat // self.n_head  # per-head feature width
        self.actv = actv
        self.USE_BIAS = USE_BIAS
        self.dropout_p = dropout_p

        self.lin_Q = nn.Linear(self.d_feat, self.d_feat, bias=self.USE_BIAS)
        self.lin_K = nn.Linear(self.d_feat, self.d_feat, bias=self.USE_BIAS)
        self.lin_V = nn.Linear(self.d_feat, self.d_feat, bias=self.USE_BIAS)
        self.lin_O = nn.Linear(self.d_feat, self.d_feat, bias=self.USE_BIAS)

        self.dropout = nn.Dropout(p=self.dropout_p)

    def forward(self, Q, K, V, mask=None):
        """
        param Q : [n_batch, n_Q, d_feat]
        param K : [n_batch, n_K, d_feat]
        param V : [n_batch, n_V, d_feat] (n_V must equal n_K)
        param mask : optional, broadcastable to [n_batch, n_head, n_Q, n_K];
                     positions where mask == 0 receive no attention
        return : (x, out) where x is the attended output [n_batch, n_Q, d_feat]
                 and out is a dict of the intermediate tensors keyed by name
        """
        n_batch = Q.shape[0]
        Q_feat = self.lin_Q(Q)
        K_feat = self.lin_K(K)
        V_feat = self.lin_V(V)

        # Multi-head split of Q,K,V: [n_batch, n, d_feat] -> [n_batch, n_head, n, d_head]
        Q_split = Q_feat.view(n_batch, -1, self.n_head, self.d_head).permute(0,2,1,3)
        K_split = K_feat.view(n_batch, -1, self.n_head, self.d_head).permute(0,2,1,3)
        V_split = V_feat.view(n_batch, -1, self.n_head, self.d_head).permute(0,2,1,3)

        # Multi-headed Attention
        # Scale by the per-head key width (self.d_head), not by a notebook-level global.
        scores = torch.matmul(Q_split, K_split.permute(0,1,3,2)) / np.sqrt(self.d_head)
        if mask is not None :
            # Most negative value representable in the scores' dtype; a hard-coded
            # -1e9 is out of range for float16.
            scores = scores.masked_fill(mask==0, torch.finfo(scores.dtype).min)
        attention = torch.softmax(scores, dim=-1)
        x_raw = torch.matmul(self.dropout(attention), V_split)

        # Merge heads back: [n_batch, n_head, n_Q, d_head] -> [n_batch, n_Q, d_feat]
        x_rsh1 = x_raw.permute(0,2,1,3).contiguous()
        x_rsh2 = x_rsh1.view(n_batch, -1, self.d_feat)

        x = self.lin_O(x_rsh2)
        # Expose every intermediate so callers can inspect shapes and values.
        out = {'Q_feat':Q_feat, 'K_feat':K_feat, 'V_feat':V_feat,
               'Q_split':Q_split, 'K_split':K_split, 'V_split':V_split,
               'scores':scores, 'attention':attention,
               'x_raw':x_raw, 'x_rsh1':x_rsh1, 'x_rsh2':x_rsh2, 'x':x}
        return x, out

SyntaxError: ignored