In [79]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.init import xavier_normal_
from constants import D_MODEL, STACKED_NUM,DK, DV, H, P_DROP, D_FF, MAX_SEQUENCE_LENGTH, MAX_NUM_WORDS, EMBEDDING_DIM
# environment
with_gpu = torch.cuda.is_available()
# with_gpu = False
device = torch.device("cuda:0" if with_gpu else "cpu")

def positional_encoding(pos):
    assert D_MODEL % 2 == 0
    pos = torch.tensor(pos, dtype=torch.float32, requires_grad=False)
    pe = torch.zeros([1,D_MODEL], dtype=torch.float32, requires_grad=False)
    for i in range(D_MODEL//2):
        a = torch.tensor(10000, dtype=torch.float32, requires_grad=False)
        b = torch.tensor(2.*i/float(D_MODEL), dtype=torch.float32, requires_grad=False)
        c = pos / torch.pow(a, b)
        pe[0, 2*i] = torch.sin(c)
        pe[0, 2*i+1] = torch.cos(c)
    return pe
def get_pos_mat(length):
    if length > MAX_SEQUENCE_LENGTH:
        print('sequence length reach PE_MAT_CACHE. %d ' % length)
        ret = torch.cat([positional_encoding(i) for i in range(length)], dim=0).to(device)
        ret.requires_grad = False
        global PE_CACHE_MATRIX
        PE_CACHE_MATRIX = ret
        return ret
    else:
        return PE_CACHE_MATRIX[:length]
    
PE_CACHE_MATRIX = torch.cat([positional_encoding(i) for i in range(0,MAX_SEQUENCE_LENGTH)], dim=0).to(device)
PE_CACHE_MATRIX.requires_grad = False

# construct neuron network

def scaled_dot_attention(Q, K, V, mask=None):
    assert Q.size()[-1] == K.size()[-1]
    assert len(Q.size()) == 3 and len(K.size()) == 3 and len(V.size()) == 3
    dk = torch.tensor(K.size()[-1], dtype=torch.float32, requires_grad=False).to(device)
    out = torch.matmul(Q,K.permute(0,2,1)) / torch.sqrt(dk) 
    if mask is not None:
        out = out.masked_fill_(mask, -float('inf'))
        
    return torch.matmul(F.softmax(out, dim=-1), V)
                            
class Transformer(nn.Module):

    def __init__(self, layer_num, dk, dv, dm, h):
        super(Transformer, self).__init__()
        
        self.emb = Word_Embedding()
        
        self.emb_drop = nn.Dropout(P_DROP)
        
        self.encoder = Stack_Encoder(layer_num, dk, dv, dm, h)
        self.decoder = Stack_Decoder(layer_num, dk, dv, dm, h)
        
        self.summary_linear = nn.Linear(dm, 1)
        self.summary_weight = nn.Parameter(torch.FloatTensor(1, dm))
        torch.nn.init.xavier_uniform_(self.summary_weight)
        
        self.output_linear = nn.Linear(3*dm, 1)

    def forward(self, Q, K, Q_fea, K_fea):
        batch, K_len = K.size()
#         encoder
        K = self.emb(K)
#         print(K.size(), get_pos_mat(MAX_SEQUENCE_LENGTH).size())

        K = K + get_pos_mat(K_len)
        K = self.emb_drop(K)
        
        en_out = self.encoder(K)
        
#         decoder
        batch, Q_len = Q.size()
        Q = self.emb(Q)
        
        Q = Q + get_pos_mat(Q_len)
        Q = self.emb_drop(Q)
        
        de_out = self.decoder(Q, en_out)
        
        
        summary_weight = F.softmax(self.summary_linear(de_out).squeeze(-1), dim=1).unsqueeze(1)
        
        summary = torch.matmul(summary_weight, de_out).squeeze(1)
        print summary.shape,Q_fea.shape,'lii2'
        
        x = torch.cat([summary, Q_fea, K_fea], dim=-1)
        out = self.output_linear(x)
        out = torch.sigmoid(out)
        

        return out

class Word_Embedding(nn.Module):
    def __init__(self):
        super(Word_Embedding, self).__init__()
        self.emb = nn.Embedding(MAX_NUM_WORDS, EMBEDDING_DIM, padding_idx=0)
        self.emb.weight.requires_grad_(False)
        
        self.linear = nn.Linear(EMBEDDING_DIM, D_MODEL, bias=False)
        


    def forward(self, x):
        x = self.emb(x)
        x = self.linear(x)
        return x
    
class Stack_Encoder(nn.Module):
    """
    Stacked Encoder
    """
    def __init__(self, layer_num, dk, dv, dm, h):
        super(Stack_Encoder, self).__init__()
        self.encoders = nn.ModuleList([Encoder(dk, dv, dm, h) for i in range(layer_num)])

    def forward(self, K):
        # ModuleList can act as an iterable, or be indexed using ints
        for lay in self.encoders:
            K = lay(K)
        return K               
class Encoder(nn.Module):
    def __init__(self, dk, dv, dm, h):
        super(Encoder, self).__init__()
#         attention residual block
        self.multi_head_attention_layer = Multi_Head_attention_layer(dk, dv, dm, h)
        self.attention_norm_lay = nn.LayerNorm([dm,])
        self.att_drop = nn.Dropout(P_DROP)
#         feed forward residual block
        self.fcn = PositionwiseFeedForward(D_MODEL, D_FF)
        self.linear_drop = nn.Dropout(P_DROP)
        self.ff_norm_lay = nn.LayerNorm([dm, ])
        

    def forward(self, K):
#         attention
        attention_out = self.multi_head_attention_layer(K, K, K)
        attention_out = self.att_drop(attention_out)
        att_out = self.attention_norm_lay(K + attention_out)
#         feed forward
        linear_out = self.fcn(att_out)
        linear_out = self.linear_drop(linear_out)
        out = self.ff_norm_lay(att_out + linear_out)
        out = att_out + linear_out
    
        return out
class Stack_Decoder(nn.Module):
    """
    Stacked Encoder
    """
    def __init__(self, layer_num, dk, dv, dm, h):
        super(Stack_Decoder, self).__init__()
        self.decoders = nn.ModuleList([Decoder(dk, dv, dm, h) for i in range(layer_num)])
        
        
    def forward(self, Q, encoder_out):
        # ModuleList can act as an iterable, or be indexed using ints
        bat, Q_len, d = Q.size()
        for lay in self.decoders:
            Q = lay(Q, encoder_out, mask=None)
        return Q           

class Decoder(nn.Module):
    def __init__(self, dk, dv, dm, h):
        super(Decoder, self).__init__()
#         query attention residual block
        self.Q_attention_lay = Multi_Head_attention_layer(dk, dv, dm, h)
        self.Q_attention_norm_lay = nn.LayerNorm([dm, ])
        self.Q_att_drop = nn.Dropout(P_DROP)
    
#         query key attention residual block
        self.QK_attention_lay = Multi_Head_attention_layer(dk, dv, dm, h)
        self.QK_attention_norm_lay = nn.LayerNorm([dm, ])
        self.QK_att_drop = nn.Dropout(P_DROP)
        
    
#         feed forward residual block
        self.fcn = PositionwiseFeedForward(D_MODEL, D_FF)
        self.ff_norm_lay = nn.LayerNorm([dm, ])
        self.linear_drop = nn.Dropout(P_DROP)
        

    def forward(self, Q, encoder_out, mask):
#         query attention
        Q_attention_out = self.Q_attention_lay(Q, Q, Q, mask)
        Q_attention_out = self.Q_att_drop(Q_attention_out)
        Q_att_out = self.Q_attention_norm_lay(Q + Q_attention_out)
#         query key attention
        QK_attention_out = self.QK_attention_lay(Q_att_out, encoder_out, encoder_out)
        QK_attention_out = self.QK_att_drop(QK_attention_out)
        QK_att_out = self.QK_attention_norm_lay(Q_att_out + QK_attention_out)
        
#         feed forward
        linear_out = self.fcn(QK_att_out)
        out = self.ff_norm_lay(QK_att_out + linear_out)
        return out

class Multi_Head_attention_layer(nn.Module):
    def __init__(self, dk, dv, dm, h):
        super(Multi_Head_attention_layer, self).__init__()
        self.Q_linears = nn.ModuleList([nn.Linear(dm, dk) for i in range(h)])
        self.K_linears = nn.ModuleList([nn.Linear(dm, dk) for i in range(h)])
        self.V_linears = nn.ModuleList([nn.Linear(dm, dv) for i in range(h)])
        self.output_linear = nn.Linear(h*dv, dm)
                            

    def forward(self, Q_input, K_input, V_input, mask=None):
        buf = []
        for Q_linear, K_linear, V_linear in zip(self.Q_linears, self.K_linears, self.V_linears):
            Q = Q_linear(Q_input)
            K = K_linear(K_input)
            V = V_linear(V_input)
            buf.append(scaled_dot_attention(Q, K, V, mask))
            
        buf = torch.cat(buf,dim=-1)
        out = self.output_linear(buf)
        
        return out      
class PositionwiseFeedForward(nn.Module):
    def __init__(self, d_model, d_ff):
        super(PositionwiseFeedForward, self).__init__()
        self.cnn1 = nn.Conv1d(d_model, d_ff, 1)
        self.cnn2 = nn.Conv1d(d_ff, d_model, 1)
                            

    def forward(self, x):
        bat,seq_len,_ = x.size()
        x = x.permute(0,2,1)
        x = self.cnn1(x)
        x = F.relu(x)
        x = self.cnn2(x)
        x = x.permute(0,2,1)
        
        return x      
    
# encoder = Stack_Encoder(6, 64,64,20,8)
# # print net
bat = 7
Q = torch.randint(10000,[bat, 13,], dtype=torch.long).to(device)
V = torch.randint(10000,[bat, 19,], dtype=torch.long).to(device)
print Q.shape
Q_fea = torch.rand([bat, D_MODEL,]).to(device)
K_fea = torch.rand([bat, D_MODEL,]).to(device)
net = Transformer(STACKED_NUM, DK, DV, D_MODEL, H).to(device)
print Q.shape, Q_fea.shape
print(Q.dtype)
o = net(Q, V, Q_fea, K_fea)
# print t
print(o.size())
# print o
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)
print(count_parameters(net))

torch.Size([7, 13])
torch.Size([7, 13]) torch.Size([7, 128])
torch.int64
sequence length reach PE_MAT_CACHE. 19 
torch.Size([7, 128]) torch.Size([7, 128]) lii2
torch.Size([7, 1])
4263554


In [65]:
def scaled_dot_attention(Q, K, V, mask=None):
    assert Q.size()[-1] == K.size()[-1]
    dk = torch.tensor(K.size()[-1], dtype=torch.float32, requires_grad=False)
    out = torch.matmul(Q,K.t()) / torch.sqrt(dk) 
    if mask is not None:
        out = out.masked_fill_(mask, -float('inf'))
        
    return torch.matmul(F.softmax(out, dim=-1), V)
def f2(Q, K, V, mask=None):
    assert Q.size()[-1] == K.size()[-1]
    assert len(Q.size()) == 3
    dk = torch.tensor(K.size()[-1], dtype=torch.float32, requires_grad=False)
    print Q.shape, K.shape,'qk'
    out = torch.matmul(Q,K.permute(0,2,1)) / torch.sqrt(dk) 
    print torch.matmul(Q,K.permute(0,2,1)).shape,'out'
    if mask is not None:
        out = out.masked_fill_(mask, -float('inf'))
        
    return torch.matmul(F.softmax(out, dim=-1), V)            

bat = 7
Q = torch.rand([bat, 3, 64,])
K = torch.rand([bat, 5, 64,])
R = f2(Q,K,K)
print R.shape
for i in range(bat):
    q = Q[i,:,:]
    k = K[i,:,:]
    r2 = scaled_dot_attention(q,k,k)
    r = R[i,:,:]
#     print r - r2


# print r.shape
# print r2.shape
# print r == r2.squeeze(0)

torch.Size([7, 3, 64]) torch.Size([7, 5, 64]) qk
torch.Size([7, 3, 5]) out
torch.Size([7, 3, 64])


In [75]:
a = torch.arange(24).view(2,4,3).type(torch.float)
b = torch.arange(56).view(2,7,4).type(torch.float)
print a,a.shape
print b,b.shape
c = torch.matmul(b,a)
print c.shape
print c
# print b,b.shape



tensor([[[ 0.,  1.,  2.],
         [ 3.,  4.,  5.],
         [ 6.,  7.,  8.],
         [ 9., 10., 11.]],

        [[12., 13., 14.],
         [15., 16., 17.],
         [18., 19., 20.],
         [21., 22., 23.]]]) torch.Size([2, 4, 3])
tensor([[[ 0.,  1.,  2.,  3.],
         [ 4.,  5.,  6.,  7.],
         [ 8.,  9., 10., 11.],
         [12., 13., 14., 15.],
         [16., 17., 18., 19.],
         [20., 21., 22., 23.],
         [24., 25., 26., 27.]],

        [[28., 29., 30., 31.],
         [32., 33., 34., 35.],
         [36., 37., 38., 39.],
         [40., 41., 42., 43.],
         [44., 45., 46., 47.],
         [48., 49., 50., 51.],
         [52., 53., 54., 55.]]]) torch.Size([2, 7, 4])
torch.Size([2, 7, 3])
tensor([[[  42.,   48.,   54.],
         [ 114.,  136.,  158.],
         [ 186.,  224.,  262.],
         [ 258.,  312.,  366.],
         [ 330.,  400.,  470.],
         [ 402.,  488.,  574.],
         [ 474.,  576.,  678.]],

        [[1962., 2080., 2198.],
         [2226., 2360., 24