# Keras Attention

Attention层：  
输入是 Q、K、V；  
使用 `K.batch_dot` 时需特别注意 `axes` 参数与各维度的对应关系；  
Attention 中各步操作后的张量维度已在注释中标注。

In [1]:
#! -*- coding: utf-8 -*-

import numpy as np

from keras import backend as K
from keras.engine.topology import Layer

class Position_Embedding(Layer):
    """Sinusoidal position embedding.

    For position ``i`` (0-based along axis 1) and frequency index
    ``j in [0, size/2)`` the angle is ``i / 10000 ** (2j / size)``. The cosines
    of all angles followed by their sines form a ``size``-wide embedding.

    Args:
        size: width of the position embedding; must be even (the cos and sin
            halves each contribute size/2 channels). In 'sum' mode it is
            overwritten at call time with the input's last dimension.
        mode: 'sum' returns ``embedding + x``; 'concat' returns
            ``concatenate([embedding, x])`` along the last axis.

    Raises:
        ValueError: if ``mode`` is not 'sum'/'concat', or the effective size is odd.
    """

    def __init__(self, size=None, mode='sum', **kwargs):
        if mode not in ('sum', 'concat'):
            raise ValueError("mode must be 'sum' or 'concat', got %r" % (mode,))
        self.size = size  # must be even
        self.mode = mode
        super(Position_Embedding, self).__init__(**kwargs)

    def call(self, x):
        # assumes x is [batch, time, features]: x[:, :, 0] and the axis-2 concat rely on rank 3
        if self.size is None or self.mode == 'sum':
            self.size = int(x.shape[-1])
        if self.size % 2 != 0:
            raise ValueError('Position_Embedding size must be even, got %d' % self.size)
        # Inverse frequencies 1 / 10000^(2j/size) -> [1, size/2]
        position_j = 1. / K.pow(10000., 2 * K.arange(self.size // 2, dtype='float32') / self.size)
        position_j = K.expand_dims(position_j, 0)
        # K.arange does not support a variable length, so positions 0..T-1 are
        # generated as a cumulative sum of ones instead.
        position_i = K.cumsum(K.ones_like(x[:, :, 0]), 1) - 1  # [B, T]
        position_i = K.expand_dims(position_i, 2)  # [B, T, 1]
        position_ij = K.dot(position_i, position_j)  # [B, T, size/2]
        position_ij = K.concatenate([K.cos(position_ij), K.sin(position_ij)], 2)  # [B, T, size]
        if self.mode == 'sum':
            return position_ij + x
        return K.concatenate([position_ij, x], 2)

    def compute_output_shape(self, input_shape):
        if self.mode == 'sum':
            return input_shape
        return (input_shape[0], input_shape[1], input_shape[2] + self.size)


class Attention(Layer):
    """Multi-head scaled dot-product attention.

    Call with ``[Q_seq, K_seq, V_seq]`` or ``[Q_seq, K_seq, V_seq, Q_len, V_len]``.
    Q/K/V are projected by ``WQ``/``WK``/``WV`` to ``nb_head * size_per_head``
    features, split into ``nb_head`` heads of ``size_per_head`` features,
    attended per head with ``softmax(Q K^T / sqrt(size_per_head)) V``, and the
    heads are concatenated back to ``[B, Tq, nb_head * size_per_head]``.

    With ``Q_len``/``V_len``: key positions at or beyond ``V_len`` are excluded
    from the softmax, and output positions at or beyond ``Q_len`` are zeroed.
    Lengths are read as ``seq_len[:, 0]`` (presumably int tensors of shape
    [B, 1] -- confirm against callers).

    Args:
        nb_head: number of attention heads.
        size_per_head: feature size of each head.
        mask_right: if True, apply a causal mask so query position i only
            attends to key positions j <= i.
    """

    def __init__(self, nb_head, size_per_head, mask_right=False, **kwargs):
        self.nb_head = nb_head
        self.size_per_head = size_per_head
        self.output_dim = nb_head * size_per_head
        self.mask_right = mask_right
        super(Attention, self).__init__(**kwargs)

    def build(self, input_shape):
        # One projection matrix per input, each mapping its own feature size to output_dim.
        self.WQ = self.add_weight(name='WQ',
                                  shape=(input_shape[0][-1], self.output_dim),
                                  initializer='glorot_uniform',
                                  trainable=True)
        self.WK = self.add_weight(name='WK',
                                  shape=(input_shape[1][-1], self.output_dim),
                                  initializer='glorot_uniform',
                                  trainable=True)
        self.WV = self.add_weight(name='WV',
                                  shape=(input_shape[2][-1], self.output_dim),
                                  initializer='glorot_uniform',
                                  trainable=True)
        super(Attention, self).build(input_shape)

    def Mask(self, inputs, seq_len, mode='mul'):
        """Mask positions at or beyond ``seq_len`` along axis 1 of ``inputs``.

        Args:
            inputs: tensor whose axis 1 is the axis to mask, e.g. [B, T, E].
            seq_len: valid lengths; only ``seq_len[:, 0]`` is read. ``None``
                disables masking.
            mode: 'mul' zeroes masked positions; 'add' subtracts 1e12 from them
                so they vanish after a following softmax.

        Returns:
            The masked tensor, or ``inputs`` unchanged when ``seq_len`` is None.

        Raises:
            ValueError: if ``mode`` is neither 'mul' nor 'add'.
        """
        if seq_len is None:
            return inputs
        if mode not in ('mul', 'add'):
            raise ValueError("mode must be 'mul' or 'add', got %r" % (mode,))
        # one_hot(len) then 1 - cumsum -> 1 for positions < len, 0 from len onward.
        mask = K.one_hot(seq_len[:, 0], K.shape(inputs)[1])
        mask = 1 - K.cumsum(mask, 1)
        # Broadcast the [B, T] mask over every trailing axis of inputs.
        for _ in range(len(inputs.shape) - 2):
            mask = K.expand_dims(mask, 2)
        if mode == 'mul':
            return inputs * mask
        return inputs - (1 - mask) * 1e12

    def _split_heads(self, seq):
        """Reshape a projected sequence [B, T, H*D] into per-head form [B, H, T, D]."""
        seq = K.reshape(seq, (-1, K.shape(seq)[1], self.nb_head, self.size_per_head))
        return K.permute_dimensions(seq, (0, 2, 1, 3))

    def call(self, x):
        # With only Q_seq, K_seq, V_seq no masking is done; with
        # Q_seq, K_seq, V_seq, Q_len, V_len the padded positions are masked.
        if len(x) == 3:
            Q_seq, K_seq, V_seq = x
            Q_len, V_len = None, None
        elif len(x) == 5:
            Q_seq, K_seq, V_seq, Q_len, V_len = x
        else:
            raise ValueError('Attention expects [Q, K, V] or [Q, K, V, Q_len, V_len], '
                             'got %d inputs' % len(x))
        # Linear projections, then split into heads.
        Q_seq = self._split_heads(K.dot(Q_seq, self.WQ))  # [B, H, Tq, D]
        K_seq = self._split_heads(K.dot(K_seq, self.WK))  # [B, H, Tk, D]
        V_seq = self._split_heads(K.dot(V_seq, self.WV))  # [B, H, Tk, D]
        # Scaled dot products, then mask, then softmax.
        A = K.batch_dot(Q_seq, K_seq, axes=[3, 3]) / self.size_per_head ** 0.5  # [B, H, Tq, Tk]
        # Mask works on axis 1, so move the key axis there to hide padded keys.
        A = K.permute_dimensions(A, (0, 3, 2, 1))  # [B, Tk, Tq, H]
        A = self.Mask(A, V_len, 'add')
        A = K.permute_dimensions(A, (0, 3, 2, 1))  # [B, H, Tq, Tk]
        if self.mask_right:
            ones = K.ones_like(A[:1, :1])  # [1, 1, Tq, Tk]
            # band_part(ones, -1, 0) keeps the lower triangle (j <= i); the rest is masked.
            mask = (ones - K.tf.matrix_band_part(ones, -1, 0)) * 1e12
            A = A - mask
        A = K.softmax(A)
        # Weighted sum of values, merge heads, mask padded query positions.
        O_seq = K.batch_dot(A, V_seq, axes=[3, 2])  # [B, H, Tq, D]
        O_seq = K.permute_dimensions(O_seq, (0, 2, 1, 3))  # [B, Tq, H, D]
        O_seq = K.reshape(O_seq, (-1, K.shape(O_seq)[1], self.output_dim))  # [B, Tq, H*D]
        O_seq = self.Mask(O_seq, Q_len, 'mul')
        return O_seq

    def compute_output_shape(self, input_shape):
        return (input_shape[0][0], input_shape[0][1], self.output_dim)

Using TensorFlow backend.


In [None]:
from __future__ import print_function

# Explicit imports instead of `from keras.layers import *`: the wildcard hides
# where names come from (later cells use `np`, which this notebook otherwise
# never imports -- presumably it leaked in through the wildcard) and can shadow
# names defined above, e.g. newer Keras / tf.keras versions ship their own
# `keras.layers.Attention`, which would replace the custom class.
from keras.datasets import imdb
from keras.layers import Dense, Dropout, Embedding, GlobalAveragePooling1D, Input
from keras.models import Model
from keras.preprocessing import sequence

max_features = 20000  # vocabulary size passed to imdb.load_data(num_words=...) and Embedding
maxlen = 80           # every review is padded/truncated to this many tokens (My_Attention relies on it)
batch_size = 32

print('Loading data...')
(x_train, y_train), (x_test, y_test) = imdb.load_data(num_words=max_features)
print(len(x_train), 'train sequences')
print(len(x_test), 'test sequences')

print('Pad sequences (samples x time)')
x_train = sequence.pad_sequences(x_train, maxlen=maxlen)
x_test = sequence.pad_sequences(x_test, maxlen=maxlen)
print('x_train shape:', x_train.shape)
print('x_test shape:', x_test.shape)

# 添加自己实现的Attention(Q,V)
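使用权重 $W\in\mathbb{R}^{E\times T}$：对每个样本 $Q, V\in\mathbb{R}^{T\times E}$，先计算 $A = \mathrm{softmax}(W Q)\in\mathbb{R}^{E\times E}$（softmax 作用于最后一维），然后计算 $O = V A^{\top}\in\mathbb{R}^{T\times E}$。  

此时没有对维度进行缩减；如果需要实现 multi-head，可以再加一层映射，将各个 head 的结果拼接后降维。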

In [28]:
class My_Attention(Layer):
    """Simplified single-weight attention: ``O = V softmax(W Q)^T``.

    Per sample, with ``Q`` of shape [T, E] and a learned ``W`` of shape [E, T]:
    ``A = softmax(W Q)`` has shape [E, E] (softmax over the last axis), and the
    output ``O = V A^T`` has shape [T, E] (rows follow V's time axis). ``K_seq``
    is accepted for interface parity with ``Attention`` but is not used.
    Because ``W``'s second dimension is the sequence length, queries must be
    padded/truncated to exactly that many steps.

    Args:
        nb_head, size_per_head: kept for interface parity with ``Attention``;
            only stored (``output_dim = nb_head * size_per_head``).
        mask_right: stored but not used -- no causal mask is applied.
        time_steps: sequence length used for ``W`` when the query's time axis
            is unknown at build time (e.g. ``Input(shape=(None,))``). Defaults
            to 80, the value previously hard-coded.
    """

    def __init__(self, nb_head, size_per_head, mask_right=False, time_steps=80, **kwargs):
        self.nb_head = nb_head
        self.size_per_head = size_per_head
        self.output_dim = nb_head * size_per_head
        self.mask_right = mask_right
        self.time_steps = time_steps
        super(My_Attention, self).__init__(**kwargs)

    def build(self, input_shape):
        # W = [E, T]: E from the query's feature axis (previously hard-coded 128),
        # T from the query's time axis, falling back to time_steps when dynamic.
        emb_dim = input_shape[0][-1]
        time_steps = input_shape[0][1]
        if time_steps is None:
            time_steps = self.time_steps
        self.WW = self.add_weight(name='WW',
                                  shape=(emb_dim, time_steps),
                                  initializer='glorot_uniform',
                                  trainable=True)
        super(My_Attention, self).build(input_shape)

    def Mask(self, inputs, seq_len, mode='mul'):
        """Mask positions at or beyond ``seq_len`` along axis 1 of ``inputs``.

        Args:
            inputs: tensor whose axis 1 is the axis to mask, e.g. [B, T, E].
            seq_len: valid lengths; only ``seq_len[:, 0]`` is read. ``None``
                disables masking.
            mode: 'mul' zeroes masked positions; 'add' subtracts 1e12 from them.

        Returns:
            The masked tensor, or ``inputs`` unchanged when ``seq_len`` is None.

        Raises:
            ValueError: if ``mode`` is neither 'mul' nor 'add'.
        """
        if seq_len is None:
            return inputs
        if mode not in ('mul', 'add'):
            raise ValueError("mode must be 'mul' or 'add', got %r" % (mode,))
        # one_hot(len) then 1 - cumsum -> 1 for positions < len, 0 from len onward.
        mask = K.one_hot(seq_len[:, 0], K.shape(inputs)[1])
        mask = 1 - K.cumsum(mask, 1)
        # Broadcast the [B, T] mask over every trailing axis of inputs.
        for _ in range(len(inputs.shape) - 2):
            mask = K.expand_dims(mask, 2)
        if mode == 'mul':
            return inputs * mask
        return inputs - (1 - mask) * 1e12

    def call(self, x):
        # With only Q_seq, K_seq, V_seq no masking is done; with
        # Q_seq, K_seq, V_seq, Q_len, V_len the padded positions are masked.
        if len(x) == 3:
            Q_seq, K_seq, V_seq = x
            Q_len, V_len = None, None
        elif len(x) == 5:
            Q_seq, K_seq, V_seq, Q_len, V_len = x
        else:
            raise ValueError('My_Attention expects [Q, K, V] or [Q, K, V, Q_len, V_len], '
                             'got %d inputs' % len(x))
        # W Q sums over Q's time axis, so zero padded query steps first.
        Q_seq = self.Mask(Q_seq, Q_len, 'mul')
        A = K.dot(self.WW, Q_seq)  # [E, T] . [B, T, E] -> [E, B, E]
        A = K.permute_dimensions(A, (1, 0, 2))  # [B, E, E]
        A = K.softmax(A)  # over the last axis
        O_seq = K.batch_dot(V_seq, A, axes=[2, 2])  # [B, T, E] x [B, E, E] -> [B, T, E]
        # Output rows follow V's time axis, so zero the padded value steps.
        O_seq = self.Mask(O_seq, V_len, 'mul')
        return O_seq

    def compute_output_shape(self, input_shape):
        # Rows follow V's time axis; the last axis is W's first axis (the query's E),
        # not nb_head * size_per_head (those only matched by coincidence: 8*16 == 128).
        return (input_shape[0][0], input_shape[2][1], input_shape[0][-1])

In [None]:
# Shape sanity check of the My_Attention computation on random tensors.
B = 10
T = 5
E = 12
inp = K.random_normal([B, T, E])
w = K.random_normal([E, T])

A = K.dot(w, inp)  # [E, T] . [B, T, E] -> [E, B, E]
A = K.permute_dimensions(A, (1, 0, 2))  # [B, E, E]
A = K.softmax(A)
o = K.batch_dot(inp, A, axes=[2, 2])  # [B, T, E] x [B, E, E] -> [B, T, E]
# K.int_shape gives a plain tuple for a symbolic tensor (np.shape returned a TensorShape).
assert K.int_shape(o) == (B, T, E), K.int_shape(o)
K.int_shape(o)

In [16]:
%%time
S_inputs = Input(shape=(None,), dtype='int32') 
embeddings = Embedding(max_features, 128)(S_inputs) # embeddings [batch_size, max_lens, embedding_dim(128)]
# embeddings = Position_Embedding()(embeddings) # 增加Position_Embedding能轻微提高准确率
O_seq = Attention(8,16)([embeddings,embeddings,embeddings])
O_seq = GlobalAveragePooling1D()(O_seq)
O_seq = Dropout(0.5)(O_seq)
outputs = Dense(1, activation='sigmoid')(O_seq)

model = Model(inputs=S_inputs, outputs=outputs)
# try using different optimizers and different optimizer configs
model.compile(loss='binary_crossentropy',
              optimizer='adam',
              metrics=['accuracy'])

print('Train...')
model.fit(x_train, y_train,
          batch_size=batch_size,
          epochs=1,
          validation_data=(x_test, y_test))

Train...
Train on 25000 samples, validate on 25000 samples
Epoch 1/1
CPU times: user 4min 51s, sys: 12 s, total: 5min 3s
Wall time: 55.6 s


In [35]:
%%time
S_inputs = Input(shape=(None,), dtype='int32')
embeddings = Embedding(max_features, 128)(S_inputs) # embeddings [batch_size, max_lens, embedding_dim(128)]
# embeddings = Position_Embedding()(embeddings) # 增加Position_Embedding能轻微提高准确率
O_seq = My_Attention(8,16)([embeddings,embeddings,embeddings])
O_seq = GlobalAveragePooling1D()(O_seq)
O_seq = Dropout(0.5)(O_seq)
outputs = Dense(1, activation='sigmoid')(O_seq)

all_his = {}
all_his['val_loss']=[]
all_his['val_acc']=[]
all_his['loss']=[]
all_his['acc']=[]
for i in range(10):
    model = Model(inputs=S_inputs, outputs=outputs)
    # try using different optimizers and different optimizer configs
    model.compile(loss='binary_crossentropy',
                  optimizer='adam',
                  metrics=['accuracy'])

    print('Train...')
    his = model.fit(x_train, y_train,
              batch_size=batch_size,
              epochs=1,
              validation_data=(x_test, y_test))
    
    all_his['val_loss'].append(his.history['val_loss'])
    all_his['val_acc'].append(his.history['val_acc'])
    all_his['loss'].append(his.history['loss'])
    all_his['acc'].append(his.history['acc'])
    print('----------')
for k,v in all_his.items():
    print(k,':',np.mean(v))

Train...
Train on 25000 samples, validate on 25000 samples
Epoch 1/1
----------
Train...
Train on 25000 samples, validate on 25000 samples
Epoch 1/1
----------
Train...
Train on 25000 samples, validate on 25000 samples
Epoch 1/1
----------
Train...
Train on 25000 samples, validate on 25000 samples
Epoch 1/1
----------
Train...
Train on 25000 samples, validate on 25000 samples
Epoch 1/1
----------
Train...
Train on 25000 samples, validate on 25000 samples
Epoch 1/1
----------
Train...
Train on 25000 samples, validate on 25000 samples
Epoch 1/1
----------
Train...
Train on 25000 samples, validate on 25000 samples
Epoch 1/1
----------
Train...
Train on 25000 samples, validate on 25000 samples
Epoch 1/1
----------
Train...
Train on 25000 samples, validate on 25000 samples
Epoch 1/1
----------
val_loss : 0.44497904971361163
val_acc : 0.8315439999999998
loss : 0.20958064949715136
acc : 0.915724
CPU times: user 31min 8s, sys: 1min 37s, total: 32min 45s
Wall time: 8min 13s


In [36]:
%%time
S_inputs = Input(shape=(None,), dtype='int32')
embeddings = Embedding(max_features, 128)(S_inputs) # embeddings [batch_size, max_lens, embedding_dim(128)]
# embeddings = Position_Embedding()(embeddings) # 增加Position_Embedding能轻微提高准确率
O_seq = Attention(8,16)([embeddings,embeddings,embeddings])
O_seq = GlobalAveragePooling1D()(O_seq)
O_seq = Dropout(0.5)(O_seq)
outputs = Dense(1, activation='sigmoid')(O_seq)

all_his = {}
all_his['val_loss']=[]
all_his['val_acc']=[]
all_his['loss']=[]
all_his['acc']=[]
for i in range(10):
    model = Model(inputs=S_inputs, outputs=outputs)
    # try using different optimizers and different optimizer configs
    model.compile(loss='binary_crossentropy',
                  optimizer='adam',
                  metrics=['accuracy'])

    print('Train...')
    his = model.fit(x_train, y_train,
              batch_size=batch_size,
              epochs=1,
              validation_data=(x_test, y_test))
    
    all_his['val_loss'].append(his.history['val_loss'])
    all_his['val_acc'].append(his.history['val_acc'])
    all_his['loss'].append(his.history['loss'])
    all_his['acc'].append(his.history['acc'])
    print('----------')
for k,v in all_his.items():
    print(k,':',np.mean(v))

Train...
Train on 25000 samples, validate on 25000 samples
Epoch 1/1
----------
Train...
Train on 25000 samples, validate on 25000 samples
Epoch 1/1
----------
Train...
Train on 25000 samples, validate on 25000 samples
Epoch 1/1
----------
Train...
Train on 25000 samples, validate on 25000 samples
Epoch 1/1
----------
Train...
Train on 25000 samples, validate on 25000 samples
Epoch 1/1
----------
Train...
Train on 25000 samples, validate on 25000 samples
Epoch 1/1
----------
Train...
Train on 25000 samples, validate on 25000 samples
Epoch 1/1
----------
Train...
Train on 25000 samples, validate on 25000 samples
Epoch 1/1
----------
Train...
Train on 25000 samples, validate on 25000 samples
Epoch 1/1
----------
Train...
Train on 25000 samples, validate on 25000 samples
Epoch 1/1
----------
val_loss : 0.831772185710907
val_acc : 0.799308
loss : 0.13848755510523247
acc : 0.94378
CPU times: user 45min 25s, sys: 2min 18s, total: 47min 44s
Wall time: 11min 27s


In [38]:
# Relative wall-time saving of the My_Attention cell (8min 13s) vs. the Attention cell
# (11min 27s); values are copied from the %%time outputs above, so update them after a re-run.
attention_wall_s = 11 * 60 + 27
my_attention_wall_s = 8 * 60 + 13
(attention_wall_s - my_attention_wall_s) / attention_wall_s

0.2823871906841339