In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math

In [176]:
# Toy batch of embeddings: 2 sequences x 3 positions x 4 features, filled
# with the values 0..23 as float32 so individual rows are easy to recognise.
seq_embs = torch.arange(2 * 3 * 4, dtype=torch.float32).reshape(2, 3, 4)
seq_embs

tensor([[[ 0.,  1.,  2.,  3.],
         [ 4.,  5.,  6.,  7.],
         [ 8.,  9., 10., 11.]],

        [[12., 13., 14., 15.],
         [16., 17., 18., 19.],
         [20., 21., 22., 23.]]])

In [177]:
# One row per sequence, one entry per position: 1 keeps the position,
# 0 excludes it from attention.
padding_mask = torch.tensor(
    [
        [1, 0, 1],
        [1, 1, 0],
    ]
)
padding_mask

tensor([[1, 0, 1],
        [1, 1, 0]])

In [162]:
class MAB_MaskedAttention(nn.Module):
    """Multihead attention block whose attention can ignore padded keys.

    Queries are projected from ``Q`` and keys/values from ``K``. The
    projections are split into heads along the feature axis, attended with a
    softmax over the key axis, re-joined, and passed through a residual
    feed-forward layer. When ``ln=True`` a ``LayerNorm`` follows each of the
    two residual steps.

    Args:
        dim_Q: Feature size of the query input.
        dim_K: Feature size of the key/value input.
        dim_V: Feature size of the projections and of the output.
        num_heads: Number of heads; each head receives ``dim_V // num_heads``
            features (``dim_V`` is expected to be divisible by ``num_heads``).
        ln: If True, add ``LayerNorm(dim_V)`` after each residual step.
    """

    def __init__(self, dim_Q, dim_K, dim_V, num_heads, ln=False):
        super().__init__()
        self.dim_V = dim_V
        self.num_heads = num_heads
        self.fc_q = nn.Linear(dim_Q, dim_V)
        self.fc_k = nn.Linear(dim_K, dim_V)
        self.fc_v = nn.Linear(dim_K, dim_V)
        if ln:
            self.ln0 = nn.LayerNorm(dim_V)
            self.ln1 = nn.LayerNorm(dim_V)
        self.fc_o = nn.Linear(dim_V, dim_V)

        # Large constant subtracted from the scores of masked keys so that
        # their softmax weight underflows to zero.
        self.mask_val = 1.0e10

    def forward(self, Q, K, padding_mask=None):
        """Attend from ``Q`` to ``K``, optionally ignoring padded key positions.

        Args:
            Q: Query tensor of shape ``(batch, n_queries, dim_Q)``.
            K: Key/value tensor of shape ``(batch, n_keys, dim_K)``.
            padding_mask: Optional ``(batch, n_keys)`` tensor in which 1/True
                keeps a key and 0/False excludes it. Integer, float and bool
                dtypes are accepted.

        Returns:
            Tensor of shape ``(batch, n_queries, dim_V)``.
        """
        Q = self.fc_q(Q)
        K, V = self.fc_k(K), self.fc_v(K)

        # Fold the heads into the batch axis: (batch * heads, n, dim_split).
        dim_split = self.dim_V // self.num_heads
        Q_ = torch.cat(Q.split(dim_split, 2), 0)
        K_ = torch.cat(K.split(dim_split, 2), 0)
        V_ = torch.cat(V.split(dim_split, 2), 0)

        scores = Q_.bmm(K_.transpose(1, 2))
        if padding_mask is not None:
            # Cast first: `1 - mask` is not defined for bool tensors, and the
            # mask may live on a different device than the activations.
            keep = padding_mask.to(device=scores.device, dtype=scores.dtype)
            # Heads are stacked head-major along the batch axis, so tiling the
            # mask `num_heads` times lines it up with Q_/K_; the singleton
            # axis broadcasts over the queries.
            keep = keep.repeat(self.num_heads, 1).unsqueeze(1)
            scores = scores - self.mask_val * (1 - keep)

        # Scores are scaled by sqrt(dim_V), the full projection width.
        A = torch.softmax(scores / math.sqrt(self.dim_V), 2)

        # Residual on the projected queries, then un-fold the heads back
        # onto the feature axis.
        O = torch.cat((Q_ + A.bmm(V_)).split(Q.size(0), 0), 2)
        O = O if getattr(self, 'ln0', None) is None else self.ln0(O)
        O = O + F.relu(self.fc_o(O))
        O = O if getattr(self, 'ln1', None) is None else self.ln1(O)
        return O
    

class PMA_MaskedAttention(nn.Module):
    """Pooling by attention: learned seed vectors attend over a set.

    Holds ``num_seeds`` learnable seed vectors of width ``dim`` and uses them
    as the queries of a ``MAB_MaskedAttention`` block whose keys/values are
    the input set.
    """

    def __init__(self, dim, num_heads, num_seeds, ln=False):
        super().__init__()
        # Seeds are allocated uninitialised and filled by Xavier-uniform init.
        seeds = torch.Tensor(1, num_seeds, dim)
        self.S = nn.Parameter(seeds)
        nn.init.xavier_uniform_(self.S)
        self.mab = MAB_MaskedAttention(dim, dim, dim, num_heads, ln=ln)

    def forward(self, X, padding_mask=None):
        """Pool ``X`` with the seed queries.

        ``padding_mask`` is forwarded unchanged to the attention block. The
        seed axis (dim 1) is squeezed from the result, which only has an
        effect when there is a single seed.
        """
        batch_size = X.size(0)
        # One copy of the seeds per batch element.
        queries = self.S.repeat(batch_size, 1, 1)
        pooled = self.mab(queries, X, padding_mask=padding_mask)
        return pooled.squeeze(1)


In [163]:
# The pooling weights are randomly initialised; fix the seed right before
# construction so the numeric outputs below are reproducible across runs.
torch.manual_seed(0)
# dim=4 (matches the feature size of seq_embs), num_heads=2, num_seeds=1.
mpma = PMA_MaskedAttention(4, 2, 1)

In [165]:
# Pool with the padding mask: positions whose mask entry is 0 get their
# attention score lowered by a large constant before the softmax.
mpma(seq_embs, padding_mask=padding_mask)

tensor([[ 7.6991,  2.4427,  3.0977,  0.8994],
        [16.5145,  5.2091,  6.7333, -1.4213]], grad_fn=<SqueezeBackward1>)

In [166]:
# Same input without a mask, for comparison with the masked result.
# NOTE(review): the recorded output (In [166]) equals the later
# `mpma(seq_embs2)` output (In [181]) and differs from the later
# `mpma(seq_embs)` output (In [182]) -- presumably `seq_embs[0, 1]` was
# zeroed in the kernel when this ran. Re-run on a fresh kernel to confirm.
mpma(seq_embs)

tensor([[ 6.5866,  2.0569,  2.5016,  0.9922],
        [19.2559,  6.1450,  7.7827, -1.8514]], grad_fn=<SqueezeBackward1>)

In [81]:
# Display the padding mask again for reference (1 = keep, 0 = ignore).
padding_mask

tensor([[1, 0, 1],
        [1, 1, 0]])

In [82]:
# NOTE(review): this repeats the `seq_embs` definition from the top of the
# notebook (older execution count); consider removing the duplicate cell.
# Toy batch: 2 sequences x 3 positions x 4 features, values 0..23 as float32.
seq_embs = torch.arange(2 * 3 * 4, dtype=torch.float32).reshape(2, 3, 4)
seq_embs

tensor([[[ 0.,  1.,  2.,  3.],
         [ 4.,  5.,  6.,  7.],
         [ 8.,  9., 10., 11.]],

        [[12., 13., 14., 15.],
         [16., 17., 18., 19.],
         [20., 21., 22., 23.]]])

In [83]:
# NOTE(review): stale cell. Its execution count (83) predates the class
# definition cell (162), and the recorded output has shape (2, 4), whereas the
# current single-seed `mpma` already returns (2, 4), so `.mean(1)` would give
# shape (2,). Presumably this ran against an earlier `mpma`; re-run or remove.
mpma(seq_embs, padding_mask=padding_mask).mean(1)

tensor([[ 5.2991,  3.4109,  1.0523,  7.3991],
        [18.8012, 11.3980,  1.0347, 28.8672]], grad_fn=<MeanBackward1>)

In [84]:
# NOTE(review): stale cell (execution count 84, before the class definition
# cell at 162). The recorded (2, 4) output cannot come from the current
# single-seed `mpma`, whose result averaged over dim 1 would have shape (2,).
# Presumably this ran against an earlier `mpma`; re-run or remove.
mpma(seq_embs).mean(1)

tensor([[ 4.0328,  2.0207,  1.0524,  7.3961],
        [14.1593,  8.0110,  1.0391, 23.4972]], grad_fn=<MeanBackward1>)

In [178]:
# Work on a copy so `seq_embs` keeps its original values, then blank out the
# embedding at batch 0, position 1 -- the position `padding_mask` marks as 0.
seq_embs2 = seq_embs.clone()
seq_embs2[0, 1, :] = 0.0
seq_embs2

tensor([[[ 0.,  1.,  2.,  3.],
         [ 0.,  0.,  0.,  0.],
         [ 8.,  9., 10., 11.]],

        [[12., 13., 14., 15.],
         [16., 17., 18., 19.],
         [20., 21., 22., 23.]]])

In [179]:
# Baseline for the masking check: original embeddings, mask applied.
mpma(seq_embs, padding_mask=padding_mask)

tensor([[ 7.6991,  2.4427,  3.0977,  0.8994],
        [16.5145,  5.2091,  6.7333, -1.4213]], grad_fn=<SqueezeBackward1>)

In [180]:
# Same call on the copy whose masked row (batch 0, position 1) was zeroed.
# The recorded output matches the baseline, i.e. the masked position does not
# contribute to the pooled result.
mpma(seq_embs2, padding_mask=padding_mask)

tensor([[ 7.6991,  2.4427,  3.0977,  0.8994],
        [16.5145,  5.2091,  6.7333, -1.4213]], grad_fn=<SqueezeBackward1>)

In [181]:
# Without the mask the zeroed row does take part in the attention, so the
# batch-0 result differs from the unmasked result on the original embeddings.
mpma(seq_embs2)

tensor([[ 6.5866,  2.0569,  2.5016,  0.9922],
        [19.2559,  6.1450,  7.7827, -1.8514]], grad_fn=<SqueezeBackward1>)

In [182]:
# Unmasked result on the original embeddings. Only batch 0 differs from the
# `seq_embs2` run, because only batch 0 had a row changed.
mpma(seq_embs)

tensor([[ 7.1887,  2.2432,  2.9295,  0.8819],
        [19.2559,  6.1450,  7.7827, -1.8514]], grad_fn=<SqueezeBackward1>)

In [183]:
# Confirm that `seq_embs` itself still holds its original values after the
# edits made to the `seq_embs2` copy.
seq_embs

tensor([[[ 0.,  1.,  2.,  3.],
         [ 4.,  5.,  6.,  7.],
         [ 8.,  9., 10., 11.]],

        [[12., 13., 14., 15.],
         [16., 17., 18., 19.],
         [20., 21., 22., 23.]]])