In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

In [2]:
def describe(s,x,t):
    """Debug helper: print ``x`` and ``x.shape`` under the label ``s``.

    Args:
        s: label printed in front of the opening separator line.
        x: object to dump; it must expose a ``shape`` attribute.
        t: switch; nothing is printed when it is falsy.

    Returns:
        None.
    """
    if not t:
        return
    opening = s+':==============================\n'
    closing = '==============================\n'
    print(opening)
    print(x,x.shape)
    print(closing)

In [3]:
class customizedModule(nn.Module):
    """Base ``nn.Module`` offering a factory for initialised linear layers."""

    def __init__(self):
        # This used to be spelled ``__init``: inside a class body that name is
        # mangled into a private method, so it was never run as the constructor
        # (and its ``super().__init()`` call would have failed if it had been).
        super(customizedModule,self).__init__()

    def customizedLinear(self,in_dim,out_dim,activation=None,dropout=False):
        """Build an ``nn.Sequential`` around a freshly initialised ``nn.Linear``.

        Args:
            in_dim: input feature size of the linear layer.
            out_dim: output feature size of the linear layer.
            activation: optional module appended after the linear layer.
            dropout: when truthy, an ``nn.Dropout`` is appended whose
                probability is read from ``self.args.dropout``.

        Returns:
            nn.Sequential: element 0 is the linear layer (Xavier-uniform
            weight, zero bias), followed by the optional activation and the
            optional dropout, in that order.

        Raises:
            AttributeError: if ``dropout`` is truthy and the instance has no
                ``args`` attribute (or ``args`` has no ``dropout``).
        """
        c1 = nn.Sequential(nn.Linear(in_dim,out_dim))
        nn.init.xavier_uniform_(c1[0].weight)
        nn.init.constant_(c1[0].bias,0)

        # Sub-modules are keyed by their position, continuing the numbering
        # that nn.Sequential started with '0' for the linear layer.
        if activation is not None:
            c1.add_module(str(len(c1)),activation)
        if dropout:
            c1.add_module(str(len(c1)),nn.Dropout(p=self.args.dropout))
        return c1

# CrossAttention

In [4]:
class CrossAttention(customizedModule):
    """Attention weights over the rows of ``x`` conditioned on a query ``q``.

    Both inputs are projected to ``dx`` features; the projected query gets a
    singleton axis at position -2 so it broadcasts against every row of the
    projected ``x``. Scores are then normalised with a softmax along axis -2.

    Two scoring modes are supported:
        * ``'mul'``: score = sum over features of ``w1(x) * w2(q)``.
        * ``'add'``: score = ``wt(w1(x) + w2(q) + bsa)``.
    """

    def __init__(self,dx,dq,mode):
        """
        Args:
            dx: feature size of ``x`` (also the size of the shared projection).
            dq: feature size of ``q``.
            mode: ``'mul'`` or ``'add'``; any other value makes ``forward``
                raise ``NotImplementedError``.
        """
        super(CrossAttention,self).__init__()
        self.w1 = self.customizedLinear(dx,dx)
        self.w2 = self.customizedLinear(dq,dx)
        # The zero-initialised biases are excluded from gradient computation.
        self.w1[0].bias.requires_grad = False
        self.w2[0].bias.requires_grad = False

        # Scoring layer and bias used by the additive mode only; they are
        # created in every mode so the parameter set does not depend on `mode`.
        self.wt = self.customizedLinear(dx,1)
        self.wt[0].bias.requires_grad = False
        self.bsa = nn.Parameter(torch.zeros(dx))
        # 'mul' or 'add'
        self.mode = mode
        self.debug = False

    def forward(self,x,q):
        """Return attention weights for ``x`` given ``q``.

        Args:
            x: tensor whose last axis has size ``dx``
                (presumably ``(batch, seq_len, dx)`` - confirm with callers).
            q: tensor whose last axis has size ``dq``
                (presumably ``(batch, dq)`` - confirm with callers).

        Returns:
            Tensor with a trailing axis of size 1, softmax-normalised along
            axis -2.

        Raises:
            NotImplementedError: if ``self.mode`` is neither ``'mul'`` nor
                ``'add'``.
        """
        # Compare by value: `is` tests object identity, which only happens to
        # work for interned string literals and warns on Python >= 3.8.
        if self.mode == 'mul':
            # W(1)x W(2)c
            wx = self.w1(x)
            wq = self.w2(q).unsqueeze(-2)
            describe('wx',wx,self.debug)
            describe('wq',wq,self.debug)
            # <x,q>
            p = wx*wq
            describe('wx * wq',p,self.debug)
            # p = [a0,a1,a2...]
            p = torch.sum(p,dim=-1,keepdim=True)
            describe('p after sum dim = -1',p,self.debug)
            # softmax along the row axis
            p = F.softmax(p,dim=-2)
            describe('p sm(row)',p,self.debug)
            return p

        if self.mode == 'add':
            describe('x is',x,self.debug)
            describe('q is',q,self.debug)
            wx = self.w1(x)
            wq = self.w2(q).unsqueeze(-2)
            describe('wx',wx,self.debug)
            describe('wq',wq,self.debug)
            # Compute the sums once and reuse them for both the debug dumps
            # and the scoring layer.
            summed = wx+wq
            describe('wx+wq',summed,self.debug)
            describe('bsa',self.bsa,self.debug)
            biased = summed+self.bsa
            describe('wx+wq+bsa',biased,self.debug)
            p = self.wt(biased)
            describe('wt',p,self.debug)
            p = F.softmax(p,dim = -2)
            describe('sm',p,self.debug)
            return p

        raise NotImplementedError('CrossAttention error:<mul or add>')

# Position-wise feed-forward network

In [5]:
class PositionwiseFeedForward(customizedModule):
    '''Two linear layers with ReLU, dropout, a residual connection and layer norm.'''

    def __init__(self, d_in, d_hid, dropout=0.1):
        '''
        Args:
            d_in: feature size of the input and of the output.
            d_hid: feature size of the hidden layer.
            dropout: dropout probability applied to the second layer's output.
        '''
        super().__init__()
        # d_in -> d_hid, then back d_hid -> d_in
        self.w_1 = self.customizedLinear(d_in, d_hid)
        self.w_2 = self.customizedLinear(d_hid, d_in)
        self.layer_norm = nn.LayerNorm(d_in, eps=1e-6)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        '''Return ``layer_norm(dropout(w_2(relu(w_1(x)))) + x)``.'''
        hidden = F.relu(self.w_1(x))
        out = self.dropout(self.w_2(hidden))
        # residual connection, accumulated in place on the freshly computed tensor
        out += x
        return self.layer_norm(out)

# CSA

In [8]:
class CSA(customizedModule):
    """Cross attention on ``(x, c)`` followed by pairwise self-attention over ``x``.

    Pipeline of ``forward``:
        1. ``p = crossAttention(x, c)`` weights every position of ``x``.
        2. ``h = x * p`` is projected by ``Wsa1`` / ``Wsa2`` and combined
           pairwise into a score ``fcsa[i, j]`` for every pair of positions.
        3. The diagonal (a position attending to itself) is masked out and the
           scores are softmax-normalised over ``j``.
        4. Each output position ``i`` is the ``fcsa[i, :]``-weighted sum of the
           rows of ``x``, passed through a position-wise feed-forward block.
    """

    def __init__(self,args,dx,dq):
        """
        Args:
            args: namespace; ``args.csa_mode`` must be ``'mul'`` or ``'add'``.
            dx: feature size of ``x``.
            dq: feature size of ``c``.

        Raises:
            NotImplementedError: if ``args.csa_mode`` is not ``'mul'``/``'add'``.
        """
        super(CSA,self).__init__()
        self.args = args
        self.dx = dx
        self.dq = dq
        # Compare by value: `is` checks identity and is unreliable for strings.
        if self.args.csa_mode not in ('mul','add'):
            raise NotImplementedError('CSA->CrossAttention error')
        self.crossAttention = CrossAttention(dx,dq,self.args.csa_mode)

        self.Wsa1 = self.customizedLinear(dx,dx)
        self.Wsa2 = self.customizedLinear(dx,dx)
        self.Wsa1[0].bias.requires_grad = False
        self.Wsa2[0].bias.requires_grad = False
        self.wsat = self.customizedLinear(dx,1)
        self.bsa1 = nn.Parameter(torch.zeros(dx))
        # NOTE(review): bsa2 is never read in forward; it is kept so the
        # parameter set (and any saved state_dict) stays unchanged.
        self.bsa2 = nn.Parameter(torch.zeros(dx))

        self.debug = False
        self.PFN = PositionwiseFeedForward(dx,dx)

    def forward(self,x,c):
        """Apply cross attention then masked self-attention.

        Args:
            x: ``(batch, seq_len, word_dim)`` tensor.
            c: ``(batch, word_dim)`` tensor.

        Returns:
            Tensor with the same shape as ``x``.

        Note:
            With ``seq_len == 1`` the only score is the masked diagonal, so the
            softmax runs over ``-inf`` alone and produces NaN.
        """
        # x(batch,seq_len,word_dim) c(batch,word_dim)
        seq_len = x.size(-2)
        p = self.crossAttention(x,c)
        describe('p',p,self.debug)
        h = x*p
        describe('h',h,self.debug)
        # Pairwise layout: hi varies along axis -3 (index i), hj along axis -2
        # (index j); broadcasting the sum yields a (seq_len, seq_len) grid.
        hi = self.Wsa1(h).unsqueeze(-2)
        hj = self.Wsa2(h).unsqueeze(-3)

        # fcsa(xi,xj|c)
        fcsa = hi+hj+self.bsa1
        describe('fcsa',fcsa,self.debug)
        fcsa = self.wsat(fcsa)
        describe('w(fcsa)',fcsa,self.debug)
        fcsa = torch.sigmoid(fcsa)
        describe('sigmoid fcsa',fcsa,self.debug)
        # Drop only the trailing size-1 score axis. A bare squeeze() would also
        # remove a batch or sequence axis that happens to have size 1.
        fcsa = fcsa.squeeze(-1)
        describe('squeeze(fcsa)',fcsa,self.debug)

        # Mask the diagonal so a position cannot attend to itself. The mask is
        # created on the scores' own device instead of a separately configured one.
        diagonal = torch.eye(seq_len,dtype=torch.bool,device=fcsa.device)
        fcsa = fcsa.masked_fill(diagonal,float('-inf'))
        describe('fcsa+M',fcsa,self.debug)

        # Normalise over j (the last axis).
        fcsa = F.softmax(fcsa,dim=-1)
        describe('fcsa after sm',fcsa,self.debug)

        fcsa = fcsa.unsqueeze(-1)
        describe('after pmatrix add one dim',fcsa,self.debug)
        # fcsa (batch,seq_len_i,seq_len_j,1)
        # x    (batch,1,seq_len_j,word_dim)
        keys = x.unsqueeze(-3)
        ui = fcsa*keys
        describe('unsqeeze x',keys,self.debug)
        describe('ui=pMatrix*x',ui,self.debug)
        # Reduce over j, the axis the softmax normalised. Reducing over i (as
        # the code did before) only rescales each x_j by a column sum and never
        # mixes information between positions.
        ui = torch.sum(ui,-2)
        describe('ui after sum dim -1',ui,self.debug)
        ui = self.PFN(ui)
        describe('ui after PFN',ui,self.debug)
        return  ui

In [9]:
import argparse


def _parse_device(spec):
    """Convert a ``--gpu`` command-line value into a ``torch.device``.

    A bare integer such as ``0`` selects that CUDA ordinal (what the previous
    ``type=int`` option meant); anything else, e.g. ``cpu`` or ``cuda:1``, is
    passed to ``torch.device`` unchanged.
    """
    return torch.device('cuda',int(spec)) if spec.isdigit() else torch.device(spec)


parser = argparse.ArgumentParser()
parser.add_argument('--batch-size', default=32, type=int)
# The default already falls back to the CPU when CUDA is unavailable.
parser.add_argument('--gpu', default=torch.device('cuda:0' if torch.cuda.is_available() else 'cpu'), type=_parse_device)
parser.add_argument('--csa-mode',default='add',type = str)
# An empty argv keeps the notebook kernel's own arguments out of the parser.
args = parser.parse_args(args=[])

# Run everything on the configured device rather than a hard-coded 'cuda:0',
# so the cell also works on CPU-only machines.
x = torch.randn(2,5,5).to(args.gpu)
q = torch.randn(2,10).to(args.gpu)
describe('x',x,True)
describe('q',q,True)
model = CSA(args,x.size(-1),q.size(-1)).to(args.gpu)
res = model(x,q)
res


tensor([[[-0.1559, -0.8875, -0.3132,  1.1609,  0.1527],
         [-0.7436,  0.4666,  0.2635, -2.8303,  2.1312],
         [-0.5296, -1.7525, -0.8819,  0.2569, -0.7854],
         [ 2.4432, -1.2736, -1.1742, -1.1947,  0.8393],
         [-0.4537,  1.0839,  0.1807,  0.9723,  0.0870]],

        [[ 0.6672,  0.6182, -0.9743, -0.0678, -0.3681],
         [-0.1335,  0.6805,  0.0808,  0.1792,  0.4225],
         [ 0.8382, -1.2488, -1.1949,  0.7132,  0.4723],
         [ 0.2085,  1.1006,  2.6164,  0.7486, -0.1758],
         [ 0.2328, -0.6261,  0.3748,  1.3817, -1.1102]]], device='cuda:0') torch.Size([2, 5, 5])


tensor([[ 0.4086, -1.5627,  0.1798, -0.6928,  0.4636,  1.0537, -0.2247, -0.8477,
          0.1056,  0.5593],
        [-1.2007, -1.0034, -1.6710,  1.1377,  0.1275, -1.3883,  0.5069, -0.1661,
          0.3861,  0.5068]], device='cuda:0') torch.Size([2, 10])



tensor([[[ 0.2242, -1.4380, -0.2297,  1.6671, -0.2235],
         [ 0.3127,  0.4635,  0.2927, -1.9540,  0.8851],
         [ 1.2697, -1.5485, -0.2690,  0.8934, -0.3456],
         [ 1.7717, -0.2562, -0.4033, -1.2683,  0.1561],
         [-1.6666,  1.0217, -0.1332,  1.0441, -0.2660]],

        [[ 0.6987,  0.7552, -1.8712, -0.2115,  0.6287],
         [-1.3995,  1.4177, -0.5941, -0.2184,  0.7942],
         [ 1.5427, -1.3442, -0.7842,  0.4079,  0.1778],
         [-0.7171,  0.2085,  1.7812, -0.1567, -1.1158],
         [-0.0534, -0.9665,  0.4804,  1.6274, -1.0880]]], device='cuda:0',
       grad_fn=<AddcmulBackward>)