In [1]:
import argparse

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

In [2]:
def describe(s, x):
    """Pretty-print a debug value between two ``=`` rules.

    Performs three ``print`` calls: the label ``s`` followed by a colon and a
    rule, then ``x`` itself, then a closing rule. Both rule strings end in a
    newline, so each rule is followed by an extra blank line.

    Args:
        s: label for the value; must be a ``str`` because it is joined with ``+``.
        x: the object to show (typically a tensor); rendered via ``print``.
    """
    opening = s + ':==============================\n'
    closing = '==============================\n'
    for part in (opening, x, closing):
        print(part)

In [3]:
class customizedModule(nn.Module):
    """Base ``nn.Module`` providing a factory for initialised linear layers.

    Subclasses call :meth:`customizedLinear` to build ``nn.Sequential``
    blocks whose first element is an ``nn.Linear`` with Xavier-uniform
    weights and a zero bias.
    """

    def __init__(self):
        # Properly spelled constructor (the previous ``__init`` was
        # name-mangled to ``_customizedModule__init`` and never ran).
        super().__init__()

    def customizedLinear(self, in_dim, out_dim, activation=None, dropout=False,
                         dropout_p=None):
        """Build ``nn.Sequential(Linear[, activation][, Dropout])``.

        Args:
            in_dim: input features of the ``nn.Linear``.
            out_dim: output features of the ``nn.Linear``.
            activation: optional module appended after the linear layer.
            dropout: if truthy, an ``nn.Dropout`` is appended last.
            dropout_p: dropout probability to use when ``dropout`` is truthy.
                When ``None`` (default) it is read from ``self.args.dropout``,
                which the subclass must then define.

        Returns:
            ``nn.Sequential`` whose index 0 is the ``nn.Linear``; submodules
            are named ``'0'``, ``'1'``, ... in insertion order.
        """
        layers = nn.Sequential(nn.Linear(in_dim, out_dim))
        nn.init.xavier_uniform_(layers[0].weight)
        nn.init.constant_(layers[0].bias, 0)

        if activation is not None:
            layers.add_module(str(len(layers)), activation)
        if dropout:
            p = self.args.dropout if dropout_p is None else dropout_p
            layers.add_module(str(len(layers)), nn.Dropout(p=p))
        return layers

# CrossAttention

In [8]:
class CrossAttention(customizedModule):
    """Softmax-normalised attention weights of the rows of ``x`` w.r.t. ``q``.

    ``mode`` selects the scoring function:

    * ``'mul'``: ``score = sum(w1(x) * w2(q), dim=1)``;
    * ``'add'``: ``score = wt(w1(x) + w2(q) + bsa)``, where ``wt`` is a
      ``Linear(dx, 1)`` followed by a ``Sigmoid``.

    In both modes the scores are softmax-normalised over dim 0 and then
    reshaped to ``(score.size(0), -1)``. With a 2-D ``x`` of shape
    ``(seq_len, dx)`` and a 1-D ``q`` of shape ``(dq,)`` the result is
    ``(seq_len, 1)``; other input ranks rely on broadcasting between
    ``w1(x)`` and ``w2(q)``.

    The biases of ``w1``, ``w2`` (and ``wt``) are frozen
    (``requires_grad = False``); ``bsa`` is a trainable, zero-initialised
    bias used only in ``'add'`` mode.

    Args:
        dx: feature size of ``x`` (also the projection size).
        dq: feature size of ``q``.
        mode: ``'mul'`` or ``'add'``.

    Raises:
        NotImplementedError: for any other ``mode``.
    """

    def __init__(self, dx, dq, mode):
        super(CrossAttention, self).__init__()
        self.w1 = self.customizedLinear(dx, dx)
        self.w2 = self.customizedLinear(dq, dx)
        self.w1[0].bias.requires_grad = False
        self.w2[0].bias.requires_grad = False
        self.mode = mode

        # Compare with ``==``: ``is`` tests object identity and fails for an
        # equal string that is a different object (e.g. one read from argv).
        if mode == 'add':
            self.wt = self.customizedLinear(dx, 1, activation=nn.Sigmoid())
            self.wt[0].bias.requires_grad = False
            # Extra bias for additive attention.
            self.bsa = nn.Parameter(torch.zeros(dx))
        elif mode != 'mul':
            raise NotImplementedError('crossattention mode error')

    def forward(self, x, q):
        """Return the attention weights of ``x`` against ``q`` (see class doc)."""
        if self.mode not in ('mul', 'add'):
            raise NotImplementedError('CrossAttention error:<mul or add>')

        # Project both inputs into the dx-dimensional attention space.
        wx = self.w1(x)
        wq = self.w2(q)

        if self.mode == 'mul':
            # Dot product <W1 x_i, W2 q> for every row i.
            p = torch.sum(wx * wq, dim=1)
        else:
            p = self.wt(wx + wq + self.bsa)

        # Normalise over dim 0 (the rows of x in the 2-D case).
        p = F.softmax(p, dim=0)
        return torch.reshape(p, (p.size(0), -1))

# Position-wise feed-forward network

In [9]:
class PositionwiseFeedForward(customizedModule):
    """Two-layer feed-forward sublayer with a residual connection and LayerNorm.

    Computes ``LayerNorm(x + Dropout(w_2(relu(w_1(x)))))`` where ``w_1`` maps
    ``d_in -> d_hid`` and ``w_2`` maps ``d_hid -> d_in``. The linear layers
    act on the last dimension of ``x``, so every position is transformed
    with the same weights.

    Args:
        d_in: size of the last dimension of the input (and of the output).
        d_hid: hidden size of the inner layer.
        dropout: dropout probability applied to the sublayer output.
    """

    def __init__(self, d_in, d_hid, dropout=0.1):
        super().__init__()
        # Registration order (w_1, w_2, layer_norm, dropout) is kept so the
        # state_dict keys are unchanged.
        self.w_1 = self.customizedLinear(d_in, d_hid)
        self.w_2 = self.customizedLinear(d_hid, d_in)
        self.layer_norm = nn.LayerNorm(d_in, eps=1e-6)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        hidden = F.relu(self.w_1(x))
        sublayer_out = self.dropout(self.w_2(hidden))
        # Residual connection, then layer normalisation.
        return self.layer_norm(sublayer_out + x)

# CSA

In [12]:
class CSA(customizedModule):
    """Context-guided self attention over a sequence.

    Given ``x`` (one row per sequence position) and a context vector ``c``,
    ``forward`` does the following:

    1. Re-weights each row of ``x`` by its cross-attention weight with
       respect to ``c`` (``h = x * crossAttention(x, c)``).
    2. Scores every pair of positions of ``h`` with an additive
       ``CrossAttention`` (``h.unsqueeze(0)`` vs ``h.unsqueeze(1)``),
       giving a ``(seq_len, seq_len)`` matrix.
    3. Masks the diagonal with ``-inf`` and applies a softmax over the last
       dimension, so a position never attends to itself.
    4. Computes ``ui[i] = sum_j p[i, j] * x[j]`` and passes the result
       through a ``PositionwiseFeedForward`` layer.

    Shapes as exercised by the demo cell: ``x`` is ``(seq_len, dx)``, ``c``
    is ``(dq,)`` and the output is ``(seq_len, dx)``. Other shapes (e.g. a
    leading batch dimension) are untested — TODO confirm before use.

    Caveat: with ``seq_len == 1`` the only score is masked, so the softmax
    yields NaN.

    Args:
        args: namespace with ``csa_mode`` (``'mul'`` or ``'add'``) and,
            optionally, ``debug`` (default True) to print intermediates.
        dx: feature size of ``x``.
        dq: feature size of ``c``.

    Raises:
        NotImplementedError: if ``args.csa_mode`` is neither ``'mul'`` nor
            ``'add'``.
    """

    def __init__(self, args, dx, dq):
        super(CSA, self).__init__()
        self.args = args
        self.dx = dx
        self.dq = dq
        # Compare with ``==``: ``is`` tests object identity and fails for an
        # equal string that is a different object (e.g. parsed from argv).
        mode = self.args.csa_mode
        if mode == 'mul':
            self.crossAttention = CrossAttention(dx, dq, 'mul')
        elif mode == 'add':
            self.crossAttention = CrossAttention(dx, dq, 'add')
        else:
            raise NotImplementedError('CSA->CrossAttention error')
        # Pairwise attention between positions of the re-weighted input.
        self.addCrossAttention = CrossAttention(dx, dx, 'add')
        # Print intermediates in forward(); can be switched off via args.debug.
        self.debug = getattr(args, 'debug', True)
        self.PFN = PositionwiseFeedForward(dx, dx)

    def _debug_describe(self, label, value):
        # Print ``value`` via describe() only when debugging is enabled.
        if self.debug:
            describe(label, value)

    def forward(self, x, c):
        seq_len = x.size(-2)

        # Step 1: cross attention against the context vector.
        h = self.crossAttention(x, c)
        self._debug_describe('h', h)
        h = x * h
        self._debug_describe('x*h', h)

        # Step 2: pMatrix[i, j] scores position i against position j.
        pMatrix = self.addCrossAttention(h.unsqueeze(0), h.unsqueeze(1))
        self._debug_describe('pMatrix = addcross', pMatrix)

        # Step 3: mask the diagonal (no attention to self). The mask is built
        # on pMatrix's own device rather than args.gpu so the two always match.
        diagonal = torch.eye(seq_len, dtype=torch.bool, device=pMatrix.device)
        pMatrix = pMatrix.masked_fill(diagonal, float('-inf'))
        self._debug_describe('pMatrix+M', pMatrix)
        pMatrix = F.softmax(pMatrix, dim=-1)
        self._debug_describe('pMatrix after sm', pMatrix)

        # Step 4: ui[i] = sum_j pMatrix[i, j] * x[j], then the feed-forward layer.
        pMatrix = pMatrix.reshape(seq_len, seq_len, 1)
        self._debug_describe('after pmatrix add one dim', pMatrix)
        ui = pMatrix * x
        self._debug_describe('x', x)
        self._debug_describe('ui=pMatrix*x', ui)
        ui = torch.sum(ui, 1)
        self._debug_describe('ui after sum dim 1', ui)
        ui = self.PFN(ui)
        self._debug_describe('ui after PFN', ui)
        return ui

In [13]:
# Demo: run CSA once on random inputs.
parser = argparse.ArgumentParser()
parser.add_argument('--batch-size', default=32, type=int)
# ``type=torch.device`` so e.g. ``--gpu cpu`` / ``--gpu cuda:1`` parse to the
# same kind of value as the default (``type=int`` would yield a bare int).
parser.add_argument(
    '--gpu',
    default=torch.device('cuda:0' if torch.cuda.is_available() else 'cpu'),
    type=torch.device,
)
parser.add_argument('--csa-mode', default='mul', type=str, choices=['mul', 'add'])
args = parser.parse_args(args=[])

# Seed so the random inputs and the initial weights are reproducible.
torch.manual_seed(0)

# Use args.gpu everywhere: a hard-coded 'cuda:0' crashes on CPU-only
# machines even though args.gpu already falls back to 'cpu'.
x = torch.randn(5, 5).to(args.gpu)
q = torch.randn(10).to(args.gpu)
describe('x', x)
describe('q', q)
model = CSA(args, x.size(-1), q.size(-1)).to(args.gpu)

res = model(x, q)
res


tensor([[ 0.7294,  0.4527,  0.7317, -0.2087, -0.0841],
        [-1.7427, -0.7813, -1.2078,  0.2955,  1.3490],
        [ 0.3239, -0.8199,  1.8621, -1.4259, -1.7141],
        [-0.0750,  1.5943,  0.1255, -0.2731,  0.0634],
        [-0.5715,  0.9066,  3.0784,  0.2984,  0.1899]], device='cuda:0')


tensor([ 0.1861, -1.5704, -0.3783, -0.6402,  2.2979, -0.1522,  0.0749, -0.7025,
        -0.3740,  0.4002], device='cuda:0')


tensor([[0.0233],
        [0.0277],
        [0.7749],
        [0.0030],
        [0.1711]], device='cuda:0', grad_fn=<AsStridedBackward>)


tensor([[ 1.6967e-02,  1.0531e-02,  1.7021e-02, -4.8554e-03, -1.9553e-03],
        [-4.8315e-02, -2.1660e-02, -3.3485e-02,  8.1930e-03,  3.7399e-02],
        [ 2.5100e-01, -6.3533e-01,  1.4429e+00, -1.1048e+00, -1.3282e+00],
        [-2.2573e-04,  4.8000e-03,  3.7788e-04, -8.2223e-04,  1.9075e-04],
        [-9.7807e-02,  1.5516e-01,  5.2685e-01,  5.1075e-02,  3.2506e-02]],
       device='cuda:0', grad_fn=<MulBackward0>)


tensor([[0.20

tensor([[-1.1455,  0.1088,  1.8027, -0.6409, -0.1251],
        [-0.1937,  0.3532,  1.7410, -0.9311, -0.9694],
        [-1.8874,  0.7674,  0.6772, -0.1609,  0.6037],
        [-0.6130, -0.5235,  1.9949, -0.4602, -0.3982],
        [-0.5669,  0.5699,  1.5677, -1.3609, -0.2098]], device='cuda:0',
       grad_fn=<AddcmulBackward>)