# BiBloSA

source: https://github.com/zhaoguangxiang/BiBloSA-pytorch/blob/master/model/model.py

# Function reference
+ https://kknews.cc/zh-tw/code/yegrgnb.html

# First, define a customized Linear-layer helper

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

In [2]:
class customizedModule(nn.Module):
    """Base module that provides a helper for building initialized Linear stacks.

    Subclasses that request dropout from ``customizedLinear`` must set
    ``self.args`` to an object with a ``dropout`` attribute beforehand.
    """

    def __init__(self):
        super(customizedModule, self).__init__()

    def customizedLinear(self, in_dim, out_dim, activation=None, dropout=False):
        """Build ``Linear -> [activation] -> [Dropout]`` as an ``nn.Sequential``.

        The Linear weight is Xavier-normal initialized and its bias is zeroed.

        :param in_dim: input feature size of the Linear layer.
        :param out_dim: output feature size of the Linear layer.
        :param activation: optional module appended right after the Linear layer.
        :param dropout: if truthy, append ``nn.Dropout(p=self.args.dropout)``.
        :return: ``nn.Sequential`` whose element ``[0]`` is the Linear layer.
        """
        c1 = nn.Sequential(nn.Linear(in_dim, out_dim))

        # initialize the weight & bias
        nn.init.xavier_normal_(c1[0].weight)
        # `nn.init.constant` (no underscore) is deprecated; `constant_` is the in-place API.
        nn.init.constant_(c1[0].bias, 0)

        # Modules are named "1", "2", ... so they can be indexed like a list.
        if activation is not None:
            c1.add_module(str(len(c1)), activation)
        if dropout:
            c1.add_module(str(len(c1)), nn.Dropout(p=self.args.dropout))
        return c1

# Build source-to-token self-attention (s2tSA)

In [3]:
class s2tSA(customizedModule):
    """Source-to-token self-attention: pool a sequence into a single vector.

    A two-layer scorer (Linear + ReLU, then Linear) yields one score per token
    and per feature.  Scores are softmax-normalized over the sequence axis and
    used as weights for a feature-wise weighted sum of the inputs.
    """

    def __init__(self, args, hidden_size):
        super().__init__()

        self.args = args
        self.s2t_W1 = self.customizedLinear(
            hidden_size, hidden_size, activation=nn.ReLU()
        )
        self.s2t_W = self.customizedLinear(hidden_size, hidden_size)

    def forward(self, x):
        """
        source2token self-attention module
        :param x: (batch, (block_num), seq_len, hidden_size)
        :return: (batch, (block_num), hidden_size)
        """
        scores = self.s2t_W(self.s2t_W1(x))
        # normalize over seq_len, independently for every feature
        weights = F.softmax(scores, dim=-2)
        return (weights * x).sum(dim=-2)

In [4]:
args = '123'  # placeholder config object; s2tSA only stores it
data = torch.randn(5, 10)  # 5 tokens x 10 features, no batch axis
print(data)
model = s2tSA(args, hidden_size=10)
B = model(data)  # attention over the 5 tokens pools them into one 10-dim vector
print(B)

tensor([[-2.4584,  2.0789,  0.5668, -0.6074,  1.6688,  0.0831, -0.0713,  0.1733,
          0.2433, -0.3455],
        [ 1.1394, -1.6121,  1.3858,  0.0028, -1.2394, -0.0347,  0.9353,  0.1221,
         -0.1670,  2.3652],
        [ 0.3302, -1.1033, -0.1094, -1.3479,  0.7760, -0.5299, -0.7292,  0.2932,
         -0.0578,  0.2370],
        [-0.2266, -0.1741, -1.0779, -2.5518, -1.4624, -1.3721, -1.2072,  0.1465,
         -0.1631,  0.9194],
        [ 0.4715,  1.6250, -0.5335,  1.1893, -0.1670,  0.2464,  0.8059, -2.1331,
          1.1223,  1.4092]])


  # Remove the CWD from sys.path while we load stuff.


tensor([ 0.2086,  0.2824, -0.5255, -1.1239,  0.0892, -0.4929, -0.0658, -0.1166,
         0.0715,  0.6283], grad_fn=<SumBackward1>)


# Build masked block self-attention (mBloSA)

In [5]:
class mBloSA(customizedModule):
    """Masked block self-attention (mBloSA) in one direction.

    The sequence is split into blocks of ``args.r`` tokens.  Masked
    self-attention (mSA) runs inside each block, each block is pooled with
    source2token attention, a second mSA runs across the pooled blocks, and a
    gated fusion layer merges token, intra-block and inter-block features.

    :param args: config object; this module reads ``word_dim``, ``c`` and ``r``.
    :param mask: ``'fw'`` or ``'bw'`` -- direction of the attention mask.
    """

    def __init__(self, args, mask):
        super(mBloSA, self).__init__()
        self.args = args
        self.mask = mask

        self.s2tSA = s2tSA(self.args, self.args.word_dim)
        self.init_mSA()
        self.init_mBloSA()

    def init_mSA(self):
        """Create the parameters of the masked self-attention (mSA)."""
        self.m_W1 = self.customizedLinear(self.args.word_dim, self.args.word_dim)
        self.m_W2 = self.customizedLinear(self.args.word_dim, self.args.word_dim)
        self.m_b = nn.Parameter(torch.zeros(self.args.word_dim))

        # The shared bias m_b is used instead, so the Linear biases stay frozen at 0.
        self.m_W1[0].bias.requires_grad = False
        self.m_W2[0].bias.requires_grad = False

        # Non-trainable scale c for the c * tanh(score / c) compatibility function.
        self.c = nn.Parameter(torch.Tensor([float(self.args.c)]), requires_grad=False)

    def init_mBloSA(self):
        """Create the inter-block gate and fusion-layer parameters."""
        self.g_w1 = self.customizedLinear(self.args.word_dim, self.args.word_dim)
        self.g_w2 = self.customizedLinear(self.args.word_dim, self.args.word_dim)
        self.g_b = nn.Parameter(torch.zeros(self.args.word_dim))

        # customizedLinear returns nn.Sequential; the Linear layer is element [0].
        self.g_w1[0].bias.requires_grad = False
        self.g_w2[0].bias.requires_grad = False

        self.f_w1 = self.customizedLinear(self.args.word_dim * 3, self.args.word_dim, activation=nn.ReLU())
        self.f_w2 = self.customizedLinear(self.args.word_dim * 3, self.args.word_dim)

    def mSA(self, x):
        """
        masked self-attention module
        :param x: (batch, (block_num), seq_len, word_dim)
        :return: s: (batch, (block_num), seq_len, word_dim)
        """
        # number of incoming tokens (n)
        seq_len = x.size(-2)

        # (batch, (block_num), seq_len, 1, word_dim)
        x_i = self.m_W1(x).unsqueeze(-2)
        # (batch, (block_num), 1, seq_len, word_dim)
        x_j = self.m_W2(x).unsqueeze(-3)

        # Upper-triangular (diagonal included) positions get -inf.  Built from
        # constants on x's device, so it carries no gradient.
        M = torch.ones(seq_len, seq_len, dtype=x.dtype, device=x.device).triu()
        M[M == 1] = float('-inf')

        # CASE 1 - x: (batch, seq_len, word_dim)
        # (1, seq_len, seq_len, 1)
        M = M.view(1, seq_len, seq_len, 1)

        # All-zero attention row that replaces the single fully-masked row
        # (softmax over all -inf would produce NaN).
        # (batch, (block_num), 1, seq_len, word_dim)
        pad = x.new_zeros(*x.shape[:-2], 1, seq_len, x.size(-1))

        # CASE 2 - x: (batch, block_num, seq_len, word_dim)
        if x.dim() == 4:
            M = M.unsqueeze(1)

        # (batch, (block_num), seq_len_i, seq_len_j, word_dim)
        f = self.c * torch.tanh((x_i + x_j + self.m_b) / self.c)

        # fw or bw masking
        if seq_len > 1:
            if self.mask == 'fw':
                # row i attends only to j > i, so the last row is fully masked
                M = M.transpose(-2, -3)
                f = F.softmax((f + M).narrow(-3, 0, seq_len - 1), dim=-2)
                f = torch.cat([f, pad], dim=-3)
            elif self.mask == 'bw':
                # row i attends only to j < i, so the first row is fully masked
                f = F.softmax((f + M).narrow(-3, 1, seq_len - 1), dim=-2)
                f = torch.cat([pad, f], dim=-3)
            else:
                raise NotImplementedError('only fw or bw mask is allowed!')
        else:
            f = pad

        # Weight token j by f[i, j] and sum over j.
        # (batch, (block_num), seq_len, word_dim)
        s = torch.sum(f * x.unsqueeze(-3), dim=-2)
        return s

    def forward(self, x):
        """
        masked block self-attention module
        :param x: (batch, seq_len, word_dim)
        :return: (batch, seq_len, word_dim)
        """
        r = self.args.r
        n = x.size(1)

        # Right-pad the sequence (dim 1) with zeros to a multiple of r.
        pad_len = (r - n % r) % r
        x_pad = x
        if pad_len:
            pad = x.new_zeros(x.size(0), pad_len, x.size(2))
            x_pad = torch.cat([x, pad], dim=1)

        # --- Intra-block self-attention ---
        # (batch, block_num(m), r, word_dim)
        blocks = x_pad.reshape(x_pad.size(0), -1, r, x_pad.size(-1))
        h = self.mSA(blocks)
        # (batch, m, word_dim)
        v = self.s2tSA(h)

        # --- Inter-block self-attention ---
        o = self.mSA(v)
        G = torch.sigmoid(self.g_w1(o) + self.g_w2(v) + self.g_b)
        e = G * o + (1 - G) * v

        # Repeat each block's context for its r tokens, then drop the padding.
        # (batch, n, word_dim)
        E = e.repeat_interleave(r, dim=1).narrow(1, 0, n)
        # -1 lets reshape infer the (padded) sequence length
        h = h.reshape(h.size(0), -1, h.size(-1)).narrow(1, 0, n)

        # --- Fusion layer ---
        xhE = torch.cat([x, h, E], dim=2)
        fusion = self.f_w1(xhE)
        G = torch.sigmoid(self.f_w2(xhE))
        # (batch, n, word_dim)
        u = G * fusion + (1 - G) * x

        return u
        
        
        
        
        

# Build BiBloSAN

In [6]:
class BiBloSAN(customizedModule):
    """Bi-directional block self-attention network (BiBloSAN) encoder.

    The input goes through two untied ReLU fully-connected layers.  One output
    feeds a forward-masked mBloSA and the other a backward-masked mBloSA.  The
    two results are concatenated on the feature axis and pooled with
    source2token attention.

    :param args: config object; this module reads ``word_dim`` (the mBloSA
        sub-modules read further fields).
    """

    def __init__(self, args):
        super(BiBloSAN, self).__init__()
        self.args = args

        self.mBloSA_fw = mBloSA(args, 'fw')
        self.mBloSA_bw = mBloSA(args, 'bw')

        # two untied fully connected layers
        self.fc_fw = self.customizedLinear(self.args.word_dim, self.args.word_dim, activation=nn.ReLU())
        self.fc_bw = self.customizedLinear(self.args.word_dim, self.args.word_dim, activation=nn.ReLU())

        self.s2tSA = s2tSA(self.args, self.args.word_dim * 2)

    def forward(self, x):
        """
        :param x: (batch, seq_len, word_dim)
        :return: (batch, word_dim * 2)
        """
        ufw = self.fc_fw(x)
        ubw = self.fc_bw(x)

        # Each direction goes through its own masked block self-attention.
        # (The backward branch previously re-applied fc_bw and skipped mBloSA_bw.)
        ufw = self.mBloSA_fw(ufw)
        ubw = self.mBloSA_bw(ubw)

        # concatenate along the word (feature) dimension: (batch, seq_len, 2 * word_dim)
        ubi = torch.cat([ufw, ubw], dim=-1)
        s = self.s2tSA(ubi)

        return s
        
        
        
        