使用 BiLSTM + Attention 进行文本分类，依然使用 FlyAI 框架。

定义参数

In [None]:
# Command-line options: number of training epochs and mini-batch size.
parser = argparse.ArgumentParser()
parser.add_argument(
    "-e", "--EPOCHS",
    type=int,
    default=10,
    help="train epochs",
)
parser.add_argument(
    "-b", "--BATCH",
    type=int,
    default=64,
    help="batch size",
)
args = parser.parse_args()

# Data handling provided by the flyai library: it is told how many passes to
# make over the whole training set and how large each batch should be.
dataset = Dataset(epochs=args.EPOCHS, batch=args.BATCH)
# The flyai Model wrapper is built around that dataset.
model = Model(dataset)

# Hyper-parameters

# -- vocabulary / input
vocab_size = 20655      # total vocabulary size
embedding_dim = 64      # embedding layer width
max_seq_len = 34        # maximum sentence length

# -- BiLSTM encoder and attention
lstm_units = 64         # units per LSTM direction
atten_size = 128        # width of the attention projection

# -- output layer / optimisation
numclass = 3            # number of classes
learning_rate = 1e-3    # learning rate

# -- dense / convolution settings
# NOTE(review): these three are not referenced by the graph-building code
# visible in this file; confirm they are still needed before removing them.
hidden_dim = 128        # dense layer size
num_filters = 256       # number of convolution filters
kernel_size = 5         # convolution kernel size

# Graph inputs, supplied through feed_dict at run time.
# input_x: token ids, one row of max_seq_len ids per sentence.
input_x = tf.placeholder(dtype=tf.int32, shape=(None, max_seq_len), name='input_x')
# input_y: one integer class label per sentence.
input_y = tf.placeholder(dtype=tf.int32, shape=(None,), name='input_y')
# keep_prob: keep probability handed to the dropout wrappers.
keep_prob = tf.placeholder(dtype=tf.float32, name='keep_prob')


# Embedding layer
with tf.variable_scope('embedding'):
    # Trainable [vocab_size, embedding_dim] lookup table, initialised from a
    # truncated normal distribution with stddev 0.1.
    input_embedding = tf.Variable(
        initial_value=tf.truncated_normal([vocab_size, embedding_dim], stddev=0.1),
        name='encoder_embedding',
    )


定义 BiLSTM 网络

In [None]:
with tf.name_scope("bilstm"):
    # Replace every token id with its embedding vector:
    # [batch, max_seq_len, embedding_dim].
    x_input_embedded = tf.nn.embedding_lookup(input_embedding, input_x)

    # One LSTM cell per direction, each with dropout on its outputs.
    encode_fw = tf.nn.rnn_cell.DropoutWrapper(
        cell=tf.nn.rnn_cell.BasicLSTMCell(lstm_units),
        output_keep_prob=keep_prob,
    )
    encode_bw = tf.nn.rnn_cell.DropoutWrapper(
        cell=tf.nn.rnn_cell.BasicLSTMCell(lstm_units),
        output_keep_prob=keep_prob,
    )

    # Run the forward and backward cells over the embedded sequence.
    rnn_outputs, rnn_states = tf.nn.bidirectional_dynamic_rnn(
        cell_fw=encode_fw,
        cell_bw=encode_bw,
        inputs=x_input_embedded,
        dtype=tf.float32,
    )
    encode_fw_output, encode_bw_output = rnn_outputs
    encode_fw_state, encode_bw_state = rnn_states

    # Join both directions on the feature axis:
    # [batch, max_seq_len, 2 * lstm_units].
    output = tf.concat([encode_fw_output, encode_bw_output], axis=2)

定义 Attention 网络（注意力层及输出分类层）

In [None]:
with tf.name_scope('attention'):
    # Add a dummy width axis so that a 1x1 conv2d acts as the same dense layer
    # applied at every time step: [batch, max_seq_len, 1, 2 * lstm_units].
    att_in = tf.expand_dims(output, axis=2)
    w_att = tf.Variable(tf.random_normal([1, 1, 2 * lstm_units, atten_size], stddev=0.1))
    b_att = tf.Variable(tf.random_normal([atten_size], stddev=0.1))
    u_att = tf.Variable(tf.random_normal([1, 1, atten_size, 1], stddev=0.1))
    # Hidden attention representation tanh(W h_t + b):
    # [batch, max_seq_len, 1, atten_size].
    v_att = tf.tanh(tf.nn.conv2d(att_in, w_att, strides=[1, 1, 1, 1], padding='SAME') + b_att)
    # Unnormalised score for each time step: [batch, max_seq_len, 1, 1].
    betas = tf.nn.conv2d(v_att, u_att, strides=[1, 1, 1, 1], padding='SAME')
    # Softmax over the time axis so the weights of each sentence sum to 1.
    # Dividing exp(betas) by its mean (rather than its sum) would make the
    # weights sum to max_seq_len, and a raw tf.exp can overflow for large
    # scores; tf.nn.softmax avoids both problems.
    alphas = tf.nn.softmax(tf.reshape(betas, [-1, max_seq_len]))
    # Attention-weighted sum of the BiLSTM outputs: [batch, 2 * lstm_units].
    last = tf.reduce_sum(output * tf.reshape(alphas, [-1, max_seq_len, 1]), 1)
with tf.name_scope("score"):
    # Fully connected output layer mapping the attended sentence vector to
    # one logit per class (no activation or dropout is applied here).
    weight = tf.Variable(tf.random_normal([2 * lstm_units, numclass], stddev=0.1))
    bias = tf.Variable(tf.random_normal([numclass], stddev=0.1))
    y_ = (tf.matmul(last, weight) + bias)

    # Classifier: the predicted class is the index of the largest probability.
    y_pred_cls = tf.argmax(tf.nn.softmax(y_), 1, name='y_pred')

训练

In [None]:
# The logits `y_` and the prediction op `y_pred_cls` come from the "score"
# scope built together with the attention layer. They are deliberately NOT
# rebuilt here: building the scope a second time would create a second
# weight/bias pair and a second `y_pred` op, leaving the first pair untrained.

with tf.name_scope("optimize"):
    # Convert the integer labels to one-hot vectors.
    one_hot_labels = tf.one_hot(input_y, depth=numclass, dtype=tf.float32)
    # Loss: mean softmax cross-entropy over the batch.
    cross_entropy = tf.nn.softmax_cross_entropy_with_logits(logits=y_, labels=one_hot_labels)
    loss = tf.reduce_mean(cross_entropy)
    # Optimizer
    train_op = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(loss)

with tf.name_scope("accuracy"):
    # Fraction of samples whose predicted class equals the label.
    correct_pred = tf.equal(tf.argmax(one_hot_labels, 1), y_pred_cls)
    accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32), name='acc')

with tf.name_scope("summary"):
    tf.summary.scalar("loss", loss)
    tf.summary.scalar("accuracy", accuracy)
    merged_summary = tf.summary.merge_all()

best_score = 0
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    train_writer = tf.summary.FileWriter(LOG_PATH, sess.graph)

    # dataset.get_step() gives the total number of training iterations.
    total_steps = dataset.get_step()
    try:
        for step in range(total_steps):
            x_train, y_train = dataset.next_train_batch()
            # x_val, y_val = dataset.next_validation_batch()

            # Fetch the summary in the same run as the training step, so the
            # logged values match the printed ones and the forward pass is
            # not repeated with a different dropout mask.
            fetches = [loss, accuracy, merged_summary, train_op]
            feed_dict = {input_x: x_train, input_y: y_train, keep_prob: 0.9}
            loss_, accuracy_, summary, _ = sess.run(fetches, feed_dict=feed_dict)
            train_writer.add_summary(summary, step)

            cur_step = str(step + 1) + "/" + str(total_steps)
            print('The Current step per total: {} | The Current loss: {} | The Current ACC: {}'.
                  format(cur_step, loss_, accuracy_))
            # Checkpoint every 100 steps and also after the final step;
            # otherwise a run shorter than 101 steps would only ever keep the
            # model saved at step 0.
            if step % 100 == 0 or step == total_steps - 1:
                model.save_model(sess, MODEL_PATH, overwrite=True)
    finally:
        # Flush pending events and release the log file handle.
        train_writer.close()

查看训练结果

In [None]:
The Current step per total: 11/20 | The Current loss: 0.5396712422370911 | The Current ACC: 0.796875
The Current step per total: 12/20 | The Current loss: 0.5476027727127075 | The Current ACC: 0.796875
The Current step per total: 13/20 | The Current loss: 0.5575822591781616 | The Current ACC: 0.796875
The Current step per total: 14/20 | The Current loss: 0.5402854681015015 | The Current ACC: 0.796875
The Current step per total: 15/20 | The Current loss: 0.5028542280197144 | The Current ACC: 0.8125
The Current step per total: 16/20 | The Current loss: 0.4492661952972412 | The Current ACC: 0.828125
The Current step per total: 17/20 | The Current loss: 0.4068627953529358 | The Current ACC: 0.828125
The Current step per total: 18/20 | The Current loss: 0.3983329236507416 | The Current ACC: 0.859375
The Current step per total: 19/20 | The Current loss: 0.4089481234550476 | The Current ACC: 0.890625
The Current step per total: 20/20 | The Current loss: 0.38450390100479126 | The Current ACC: 0.921875


可以看出，训练到第 20 步时，训练批次上的损失已降到约 0.38，准确率升到约 0.92，收敛趋势良好。不过这只是训练批次上的指标（验证集评估在代码中被注释掉了），泛化效果还需在验证集上进一步确认。