In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

In [2]:
def describe(s,x,t):
    if t:
        print(s+':==============================\n')
        print(x,x.shape)
        print('==============================\n')

In [3]:
class customizedModule(nn.Module):
    def __init(self):
        super(customizedModule,self).__init()
    def customizedLinear(self,in_dim,out_dim,activation=None,dropout=False):
        c1 = nn.Sequential(nn.Linear(in_dim,out_dim))
        nn.init.xavier_uniform_(c1[0].weight)
        nn.init.constant_(c1[0].bias,0)
        
        if activation is not None:
            c1.add_module(str(len(c1)),activation)
        if dropout:
            c1.add_module(str(len(c1)),nn.Dropout(p=self.args.dropout))  
        return c1

# CrossAttention

In [4]:
class CrossAttention(customizedModule):
    def __init__(self,dx,dq,mode):
        super(CrossAttention,self).__init__()
        self.w1 = self.customizedLinear(dx,dx)
        self.w2 = self.customizedLinear(dq,dx)   
        self.w1[0].bias.requires_grad = False
        self.w2[0].bias.requires_grad = False
        
        # bias for add attention
        self.wt = self.customizedLinear(dx,1)
        self.wt[0].bias.requires_grad = False
        self.bsa = nn.Parameter(torch.zeros(dx))  
        # 'mul' or 'add'
        self.mode = mode  
        self.debug = False
    def forward(self,x,q):
        if self.mode is 'mul':     
            # W(1)x W(2)c
            wx = self.w1(x)
            wq = self.w2(q)
            wq = wq.unsqueeze(-2)    
            describe('wx',wx,self.debug)
            describe('wq',wq,self.debug)         
            # <x,q>
            p = wx*wq
            describe('wx * wq',p,self.debug)               
            # p = [a0,a1,a2...]
            p = torch.sum(p,dim=-1,keepdim=True)
            describe('p after sum dim = -1',p,self.debug)        
            # softmax along row       
            p = F.softmax(p,dim=-2)
            describe('p sm(row)',p,self.debug)        
            #p = torch.reshape(p,(p.size(0),-1))
            return p
        
        elif self.mode is 'add':   
            describe('x is',x,self.debug)
            describe('q is',q,self.debug)
            wx = self.w1(x)
            wq = self.w2(q) 
            wq = wq.unsqueeze(-2)
            describe('wx',wx,self.debug)
            describe('wq',wq,self.debug)
            describe('wx+wq',wx+wq,self.debug)
            describe('bsa',self.bsa,self.debug)
            describe('wx+wq+bsa',wx+wq+self.bsa,self.debug)
            p = self.wt(wx+wq+self.bsa)
            describe('wt',p,self.debug)  
            p = F.softmax(p,dim = -2)
            describe('sm',p,self.debug)
            return p
        else:
            raise NotImplementedError('CrossAttention error:<mul or add>')

# position wise feedforward network

In [5]:
class PositionwiseFeedForward(customizedModule):
    ''' A two-feed-forward-layer module '''

    def __init__(self, d_in, d_hid, dropout=0.1):
        super().__init__()
        self.w_1 = self.customizedLinear(d_in, d_hid) # position-wise
        self.w_2 = self.customizedLinear(d_hid, d_in) # position-wise
        self.layer_norm = nn.LayerNorm(d_in, eps=1e-6)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        residual = x
        x = self.w_2(F.relu(self.w_1(x)))
        x = self.dropout(x)
        x += residual
        x = self.layer_norm(x)
        return x

# CSA

In [6]:
class CSA(customizedModule):
    def __init__(self,args,dx,dq):
        super(CSA,self).__init__()
        self.args = args
        self.dx = dx
        self.dq = dq  
        if self.args.csa_mode is 'mul':
            self.crossAttention = CrossAttention(dx,dq,'mul')
        elif self.args.csa_mode is 'add':
            self.crossAttention = CrossAttention(dx,dq,'add')
        else:
            raise NotImplementedError('CSA->CrossAttention error')
        
        self.Wsa1 = self.customizedLinear(dx,dx)
        self.Wsa2 = self.customizedLinear(dx,dx)
        self.Wsa1[0].bias.requires_grad = False
        self.Wsa2[0].bias.requires_grad = False
        self.wsat = self.customizedLinear(dx,1)
        self.bsa1 = nn.Parameter(torch.zeros(dx))  
        self.bsa2 = nn.Parameter(torch.zeros(dx)) 
        
        self.debug = False
        self.PFN = PositionwiseFeedForward(dx,dx)
    def forward(self,x,c):
        # x(batch,seq_len,word_dim) c(batch,word_dim)
        seq_len = x.size(-2)
        p = self.crossAttention(x,c)
        describe('p',p,True)
        h = x*p
        describe('h',h,self.debug)
        # p = (seq_len*seq_len): the attention of xi to xj
        hi = self.Wsa1(h)
        hj = self.Wsa2(h)
        hi = hi.unsqueeze(-2)
        hj = hj.unsqueeze(-3)
        
        #fcsa(xi,xj|c)
        fcsa = hi+hj+self.bsa1
        describe('fcsa',fcsa,self.debug)
        fcsa = self.wsat(fcsa)
        describe('w(fcsa)',fcsa,self.debug)
        fcsa = torch.sigmoid(fcsa)
        describe('sigmoid fcsa',fcsa,self.debug)
        fcsa = fcsa.squeeze()
        describe('squeeze(fcsa)',fcsa,self.debug)     
        
        # mask 對角
        M = Variable(torch.eye(seq_len)).to(self.args.gpu).detach()
        M[M==1]= float('-inf')
        fcsa = fcsa+M
        describe('fcsa+M',fcsa,self.debug)
          
            
        fcsa = F.softmax(fcsa,dim=-1)  
        describe('fcsa after sm',fcsa,self.debug)
        
        
       
        fcsa = fcsa.unsqueeze(-1)
        describe('after pmatrix add one dim',fcsa,self.debug)
        # fcsa (batch,sqlen,sqlen,fcsa(xi,xj))
        # x (batch,1,sqlen,word_dim)
        ui = fcsa*x.unsqueeze(1) 
        describe('unsqeeze x',x.unsqueeze(1),self.debug)
        describe('ui=pMatrix*x',ui,self.debug)
        ui = torch.sum(ui,1)
        describe('ui after sum dim -1',ui,self.debug)   
        ui = self.PFN(ui)
        describe('ui after PFN',ui,self.debug)   
        return  ui

# this is test

In [7]:
from torchtext import data
from torchtext.vocab import GloVe
import torch
import spacy
from torchtext.data import Iterator, BucketIterator

In [8]:
class getHotpotData():
    def __init__(self,args,trainPath,devPath,):
        self.nlp = spacy.load('en_core_web_sm', disable=['ner', 'parser', 'tagger'])   
        self.trainpath= trainPath
        self.devpath= devPath
        
        self.ANSWER  = data.Field(tokenize = self.tokenizer)
        self.QUESTION = data.Field(tokenize = self.tokenizer)
        self.CONTEXT = data.Field(tokenize = self.tokenizer)
        
        fields = {'context':('Context', self.CONTEXT),'answer':('Answer', self.ANSWER),'question':('Question', self.QUESTION)}
        
        self.train = data.TabularDataset(path = self.trainpath,format='csv',fields=fields)
        self.dev = data.TabularDataset(path = self.devpath,format='csv',fields=fields)
        
        self.CONTEXT.build_vocab(self.train, vectors=GloVe(name='6B', dim=300))  
        self.QUESTION.build_vocab(self.train, vectors=GloVe(name='6B', dim=300)) 
        self.ANSWER.build_vocab(self.train)
        
        self.train_iter,self.dev_iter = data.BucketIterator.splits((self.train,self.dev),sort_key=lambda x: len(x.Question),sort_within_batch=True,shuffle=True,batch_size=args.batch_size,device=args.gpu)
       
        print('load hotpot data done')
        
    def tokenizer(self,text):
        return [str(token) for token in self.nlp(text)]
    
    def calculate_block_size(self, B):
        data_lengths = []
        for e in self.train.examples:
            data_lengths.append(len(e.premise))
            data_lengths.append(len(e.hypothesis))

        mean = np.mean(data_lengths)
        std = np.std(data_lengths)
        self.block_size = int((2 * (std * ((2 * np.log(B)) ** (1/2)) + mean)) ** (1/3))
    



# Test model

In [9]:
class CSATransformer(customizedModule):
    def __init__(self, args, data):
        super(CSATransformer, self).__init__()

        self.args = args
        self.word_emb = nn.Embedding(len(data.CONTEXT.vocab.vectors), len(data.CONTEXT.vocab.vectors[0]))
        # initialize word embedding with GloVe
        self.word_emb.weight.data.copy_(data.CONTEXT.vocab.vectors)
        # fine-tune the word embedding
        self.word_emb.weight.requires_grad = True

        self.csa = CSA(args,args.word_dim, args.word_dim)

    def forward(self, batch):
        q = self.word_emb(batch.Question)
        a = self.word_emb(batch.Answer)
        #print(q,q.size)
        a = torch.sum(q,dim = -2)
        x = self.csa(q,a)
        print(x.shape)
        return x

# ARGS

In [10]:
import argparse
parser = argparse.ArgumentParser()
parser.add_argument('--batch-size', default=2, type=int)
parser.add_argument('--gpu', default=torch.device('cpu' if torch.cuda.is_available() else 'cpu'), type=int)
parser.add_argument('--csa-mode',default='add',type = str)
parser.add_argument('--word-dim',default=300,type = int)
args = parser.parse_args(args=[])
trainpath = 'C:/Users/User/Documents/3.NLP/Dataset/HotpotQA/small/smalltrain100.csv'
devpath = 'C:/Users/User/Documents/3.NLP/Dataset/HotpotQA/small/smalldev100.csv'
mydata = getHotpotData(args,trainpath,devpath)
model = CSATransformer(args,mydata)

print('start')

load hotpot data done
start


In [11]:
iterator = mydata.train_iter
for i, batch in enumerate(iterator):
    #print('i= '+ str(i))
    model.train()
    batch.Question = batch.Question.transpose(0,1)
    batch.Answer = batch.Answer.transpose(0,1)
    batch.Context = batch.Context.transpose(0,1)
    #print(batch.Question)
    #print( batch.Question.shape,batch.Answer.shape,batch.Context.shape)
    if i > 10:
        break
    x = model(batch)



tensor([[[0.0322],
         [0.0224],
         [0.0288],
         [0.0316],
         [0.0340],
         [0.0316],
         [0.0404],
         [0.0400],
         [0.0351],
         [0.0312],
         [0.0359],
         [0.0151],
         [0.0243],
         [0.0340],
         [0.0424],
         [0.0145],
         [0.0212],
         [0.0159],
         [0.0340],
         [0.0316],
         [0.0442],
         [0.0364],
         [0.0340],
         [0.0376],
         [0.0263],
         [0.0340],
         [0.0172],
         [0.0364],
         [0.0404],
         [0.0340],
         [0.0172],
         [0.0461]],

        [[0.0290],
         [0.0425],
         [0.0353],
         [0.0324],
         [0.0384],
         [0.0324],
         [0.0187],
         [0.0303],
         [0.0324],
         [0.0353],
         [0.0324],
         [0.0301],
         [0.0289],
         [0.0690],
         [0.0358],
         [0.0541],
         [0.0280],
         [0.0151],
         [0.0358],
         [0.0305],
         

         [0.0367]]], grad_fn=<SoftmaxBackward>) torch.Size([2, 24, 1])

torch.Size([2, 24, 300])

tensor([[[0.0745],
         [0.0745],
         [0.0411],
         [0.0655],
         [0.1589],
         [0.0802],
         [0.0745],
         [0.0930],
         [0.0745],
         [0.0745],
         [0.0397],
         [0.0679],
         [0.0811]],

        [[0.0747],
         [0.0589],
         [0.0890],
         [0.0675],
         [0.0788],
         [0.0935],
         [0.0788],
         [0.0455],
         [0.0914],
         [0.0788],
         [0.0788],
         [0.0857],
         [0.0788]]], grad_fn=<SoftmaxBackward>) torch.Size([2, 13, 1])

torch.Size([2, 13, 300])

tensor([[[0.0556],
         [0.0662],
         [0.0648],
         [0.0601],
         [0.0338],
         [0.0797],
         [0.0414],
         [0.0680],
         [0.0491],
         [0.0244],
         [0.0586],
         [0.0403],
         [0.0491],
         [0.1090],
         [0.0438],
         [0.0335],
         [0.0586],
    

# Encoder
+ self-attention
+ S2T

In [12]:
#class selfAttention(customizedModule):
#    super(selfAttention,self):
        
#    def forward(x):