In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

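# 線性轉換層 wx+b 並初始化（Xavier 權重、零偏置）
+ `customizedModule.customizedLinear(in_dim, out_dim, activation=None, dropout=False)`：回傳 `nn.Sequential`（Linear → 可選的 activation → 可選的 Dropout）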

In [2]:
class customizedModule(nn.Module):
    """Base ``nn.Module`` providing a factory for initialised linear layers.

    Subclasses that request dropout from :meth:`customizedLinear` must set
    ``self.args`` (an object with a ``dropout`` attribute) beforehand.
    """

    def __init__(self):
        super(customizedModule, self).__init__()

    # linear transformation (w/ initialization) + activation + dropout
    def customizedLinear(self, in_dim, out_dim, activation=None, dropout=False):
        """Build ``Linear -> [activation] -> [Dropout]`` as an ``nn.Sequential``.

        The linear weight is Xavier-uniform initialised and its bias is zeroed.

        :param in_dim: input feature size of the linear layer
        :param out_dim: output feature size of the linear layer
        :param activation: optional module appended right after the linear layer
        :param dropout: if truthy, append ``nn.Dropout(p=self.args.dropout)``
            (requires ``self.args`` to be set)
        :return: ``nn.Sequential`` whose element 0 is always the ``nn.Linear``
        """
        cl = nn.Sequential(nn.Linear(in_dim, out_dim))
        # use the in-place initialisers; the un-suffixed names are deprecated
        nn.init.xavier_uniform_(cl[0].weight)
        nn.init.constant_(cl[0].bias, 0)

        if activation is not None:
            cl.add_module(str(len(cl)), activation)
        if dropout:
            cl.add_module(str(len(cl)), nn.Dropout(p=self.args.dropout))

        return cl

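# S2T（source2token self-attention）
+ `S2T(args, hidden_size)`：沿 sequence 維度做 feature-wise softmax，再對輸入 token 加權求和，把 (batch, (block_num), seq_len, hidden_size) 壓縮成 (batch, (block_num), hidden_size)；本 notebook 的其他 cell 以 `s2tSA` 為名使用它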

In [3]:
class S2T(customizedModule):
    """Source2token self-attention (s2tSA).

    Computes a feature-wise attention distribution over the sequence axis and
    returns the attention-weighted sum of the input tokens, i.e. it compresses
    ``(batch, (block_num), seq_len, hidden_size)`` into
    ``(batch, (block_num), hidden_size)``.

    :param args: experiment config, stored as ``self.args``
    :param hidden_size: feature size of the input tokens (and of the output)
    """

    def __init__(self, args, hidden_size):
        super(S2T, self).__init__()
        self.args = args
        self.s2t_W1 = self.customizedLinear(hidden_size, hidden_size, activation=nn.ReLU())
        self.s2t_Wt = self.customizedLinear(hidden_size, hidden_size)

    def forward(self, x):
        """
        source2token self-attention module
        :param x: (batch, (block_num), seq_len, hidden_size)
        :return: s: (batch, (block_num), hidden_size)
        """
        # attention logits: (batch, (block_num), seq_len, hidden_size)
        f = self.s2t_Wt(self.s2t_W1(x))

        # softmax along the sequence axis -> per-feature weight P_i of each token
        p = F.softmax(f, dim=-2)
        # weighted sum of the input tokens over the sequence axis
        s = torch.sum(p * x, dim=-2)
        return s

# Masked block self-attention

## 1. Inter-block
![](https://i.imgur.com/8hup1mf.png)
## 2. Feature fusion gate
![](https://i.imgur.com/Xzmq1WF.png)
## 3. Masked self attention
![](https://i.imgur.com/uFbeQ4y.png)

`data = tensor.narrow(dim, index, size)`
– 表示取出tensor中第dim维上索引从index开始到index+size-1的所有元素存放在data中（返回的是与原tensor共享存储的view，不复制数据）
举例：`torch.arange(6).narrow(0, 1, 3)` → `tensor([1, 2, 3])`

In [4]:
class mBloSA(customizedModule):
    """Masked block self-attention (mBloSA) for one mask direction.

    Pipeline in :meth:`forward`:
      1. split the sequence into blocks of length ``args.r`` (zero-padded),
      2. intra-block masked self-attention (:meth:`msa`) inside every block,
      3. compress each block to one vector with source2token attention,
      4. inter-block masked self-attention over the block vectors + gate,
      5. fuse the block context back into every token with a fusion gate.

    :param args: config; reads ``word_dim``, ``c`` and ``r``
    :param mask: ``'fw'`` or ``'bw'`` -- direction of the attention mask
    """

    def __init__(self, args, mask):
        super(mBloSA, self).__init__()
        self.args = args
        self.mask = mask

        # source2token attention that compresses each block to one vector
        self.s2tSA = S2T(self.args, self.args.word_dim)
        self.init_msa()
        self.init_mBloSA()

    # initialise the masked self-attention parameters
    def init_msa(self):
        """Create the parameters used by :meth:`msa`."""
        dim = self.args.word_dim
        self.m_W1 = self.customizedLinear(dim, dim)
        self.m_W2 = self.customizedLinear(dim, dim)
        self.m_b = nn.Parameter(torch.zeros(dim))

        # freeze the zero-initialised biases so m_W1 / m_W2 act as pure
        # weight matrices; the shared bias term is m_b
        self.m_W1[0].bias.requires_grad = False
        self.m_W2[0].bias.requires_grad = False

        # scale c of c * tanh(. / c); not trained (original note: default is 5)
        self.c = nn.Parameter(torch.Tensor([self.args.c]), requires_grad=False)

    def init_mBloSA(self):
        """Create the inter-block gate and context-fusion parameters."""
        dim = self.args.word_dim
        # inter-block gate
        self.g_W1 = self.customizedLinear(dim, dim)
        self.g_W2 = self.customizedLinear(dim, dim)
        self.g_b = nn.Parameter(torch.zeros(dim))

        # pure weight matrices; the shared bias term is g_b
        self.g_W1[0].bias.requires_grad = False
        self.g_W2[0].bias.requires_grad = False

        # feature fusion gate
        # NOTE(review): ReLU restored for the fusion activation -- the trailing
        # comma on the original line suggests the argument was dropped;
        # confirm against the fusion-gate formula above.
        self.f_W1 = self.customizedLinear(dim * 3, dim, activation=nn.ReLU())
        self.f_W2 = self.customizedLinear(dim * 3, dim)

    def msa(self, x):
        """
        masked self-attention module
        :param x: (batch, (block_num), seq_len, word_dim)
        :return: s: (batch, (block_num), seq_len, word_dim)
        """
        seq_len = x.size(-2)

        # query side i: (batch, (block_num), seq_len, 1, word_dim)
        x_i = self.m_W1(x).unsqueeze(-2)
        # key side j:   (batch, (block_num), 1, seq_len, word_dim)
        x_j = self.m_W2(x).unsqueeze(-3)

        # (seq_len, seq_len): -inf on and above the diagonal (triu), 0 below.
        # 'bw' uses it as is (query i sees keys j < i); 'fw' uses its
        # transpose (query i sees keys j > i). Built on x's device, no grad.
        M = torch.ones(seq_len, seq_len, device=x.device, dtype=x.dtype).triu()
        M[M == 1] = float('-inf')
        # (1, seq_len, seq_len, 1) so it broadcasts over batch and word_dim
        M = M.view(1, seq_len, seq_len, 1)

        # zero attention row for the query that can see no key (fw: last,
        # bw: first); a fully -inf softmax row would produce NaN instead
        # (batch, 1, seq_len, word_dim)
        pad = x.new_zeros(x.size(0), 1, seq_len, x.size(-1))

        if x.dim() == 4:
            # blocked input: add the block_num axis to mask and padding
            M = M.unsqueeze(1)
            pad = torch.stack([pad] * x.size(1), dim=1)

        # (batch, (block_num), seq_len, seq_len, word_dim)
        f = self.c * torch.tanh((x_i + x_j + self.m_b) / self.c)

        if seq_len > 1:
            if self.mask == 'fw':
                M = M.transpose(-2, -3)
                f = F.softmax((f + M).narrow(-3, 0, seq_len - 1), dim=-2)
                f = torch.cat([f, pad], dim=-3)
            elif self.mask == 'bw':
                f = F.softmax((f + M).narrow(-3, 1, seq_len - 1), dim=-2)
                f = torch.cat([pad, f], dim=-3)
            else:
                raise NotImplementedError('only fw or bw mask is allowed!')
        else:
            f = pad

        # s_i = sum_j f_ij * x_j -- weight the KEY tokens (unsqueeze(-3));
        # weighting x_i would just return x_i since sum_j f_ij == 1
        # (batch, (block_num), seq_len, word_dim)
        s = torch.sum(f * x.unsqueeze(-3), dim=-2)
        return s

    def forward(self, x):
        """
        :param x: (batch, seq_len(n), word_dim)
        :return: u: (batch, seq_len(n), word_dim)
        """
        r = self.args.r
        batch, n, word_dim = x.size()

        # zero-pad so the sequence splits into blocks of equal length r
        pad_len = (r - n % r) % r
        if pad_len:
            x = torch.cat([x, x.new_zeros(batch, pad_len, word_dim)], dim=1)
        m = x.size(1) // r

        # --- Intra-block self-attention ---
        # (batch, block_num(m), seq_len(r), word_dim)
        x = x.reshape(batch, m, r, word_dim)
        # (batch, block_num(m), seq_len(r), word_dim)
        h = self.msa(x)
        # (batch, block_num(m), word_dim)
        v = self.s2tSA(h)

        # --- Inter-block self-attention ---
        # (batch, m, word_dim)
        o = self.msa(v)
        G = torch.sigmoid(self.g_W1(o) + self.g_W2(v) + self.g_b)
        e = G * o + (1 - G) * v

        # --- Context fusion ---
        # repeat each block's context vector for its r tokens, then drop the
        # padded positions again: (batch, n, word_dim)
        E = e.repeat_interleave(r, dim=1).narrow(1, 0, n)
        x = x.reshape(batch, m * r, word_dim).narrow(1, 0, n)
        h = h.reshape(batch, m * r, word_dim).narrow(1, 0, n)

        # (batch, n, word_dim * 3) -> (batch, n, word_dim)
        xhE = torch.cat([x, h, E], dim=2)
        fusion = self.f_W1(xhE)
        G = torch.sigmoid(self.f_W2(xhE))
        u = G * fusion + (1 - G) * x

        return u

In [5]:
class BiBloSAN(customizedModule):
    """Bi-directional block self-attention network (Bi-BloSAN) encoder.

    Two untied ReLU projections feed a forward-masked and a backward-masked
    mBloSA. Their outputs are concatenated feature-wise and compressed to one
    vector per sequence with source2token self-attention.

    :param args: config; reads ``word_dim`` (plus what mBloSA / S2T read)
    """

    def __init__(self, args):
        super(BiBloSAN, self).__init__()

        self.args = args

        self.mBloSA_fw = mBloSA(self.args, 'fw')
        self.mBloSA_bw = mBloSA(self.args, 'bw')

        # two untied fully connected layers
        self.fc_fw = self.customizedLinear(self.args.word_dim, self.args.word_dim, activation=nn.ReLU())
        self.fc_bw = self.customizedLinear(self.args.word_dim, self.args.word_dim, activation=nn.ReLU())

        # source2token attention over the concatenated fw/bw features;
        # the class defined in this notebook is S2T (there is no s2tSA class)
        self.s2tSA = S2T(self.args, self.args.word_dim * 2)

    def forward(self, x):
        """
        :param x: (batch, seq_len, word_dim)
        :return: (batch, word_dim * 2) sequence encoding
        """
        input_fw = self.fc_fw(x)
        input_bw = self.fc_bw(x)

        # (batch, seq_len, word_dim)
        u_fw = self.mBloSA_fw(input_fw)
        u_bw = self.mBloSA_bw(input_bw)

        # (batch, seq_len, word_dim * 2) -> (batch, word_dim * 2)
        u_bi = torch.cat([u_fw, u_bw], dim=2)
        return self.s2tSA(u_bi)