In [None]:
import tensorflow as tf
from sklearn.externals import joblib
import keras
import sys
sys.path.append('..')
from utils.word2vec_fast import *

In [None]:
# Load the pre-trained word vectors for CHIP 2018.
# NOTE(review): word_shape=300 presumably is the embedding dimension — confirm against Word2VecFast.
wv = Word2VecFast.load_word2vec_format(file_path='../data/chip2018/word_embedding.txt', word_shape=300)
print('word_embedding shape: ', wv.word_shape())

# word_size / embedding_size are the first and second dimensions of the embedding matrix.
word_size = wv.word_embeddings().shape[0]
embedding_size = wv.word_embeddings().shape[1]

In [None]:
# Load the pre-processed CHIP 2018 dataset (train/test features and labels).
# NOTE(review): joblib.load unpickles the file, which can execute arbitrary code —
# only load files from a trusted source.
dataset = joblib.load(filename='../data/chip2018/chip2018.data2')
X_train = dataset['X_train']
Y_train = dataset['Y_train']
X_test = dataset['X_test']
Y_test = dataset['Y_test']
print('X_train shape:', X_train.shape)

# One-hot encode both label vectors into 2 columns (num_classes).
Y_train, Y_test = (keras.utils.to_categorical(labels, 2) for labels in (Y_train, Y_test))

In [None]:
# --- Input / model shape ---
sequence_length = 45   # maximum sentence length (tokens per sentence)
num_classes = 2        # number of output classes
num_filters = 128      # convolution filters per CNN branch

# --- Training schedule ---
num_epoch = 5
batch_size = 500

In [None]:
# Graph inputs: token ids of the two sentences, one-hot labels, and the dropout keep probability.
input_x1 = tf.placeholder(dtype=tf.int32, shape=[None, sequence_length], name="input_x1")
input_x2 = tf.placeholder(dtype=tf.int32, shape=[None, sequence_length], name="input_x2")
input_y = tf.placeholder(dtype=tf.float32, shape=[None, num_classes], name="input_y")
dropout_keep_prob = tf.placeholder(dtype=tf.float32, name="dropout_keep_prob")

In [None]:
# Embedding layer: the pre-trained word vectors as a float32 constant
# (a tf.constant, not a tf.Variable, so training never updates it).
WL_word2vec = tf.constant(value=wv.word_embeddings(), dtype=tf.float32)

In [None]:
def _embed_and_expand(token_ids):
    """Look up word vectors for `token_ids` and add a trailing channel axis.

    For a (batch, sequence_length) id tensor the lookup is (batch, sequence_length,
    embedding_size); the extra size-1 axis gives the 4-D layout tf.nn.conv2d expects.

    Returns:
        (looked_up, expanded) — the raw lookup and its channel-expanded version.
    """
    looked_up = tf.nn.embedding_lookup(WL_word2vec, token_ids)
    return looked_up, tf.expand_dims(looked_up, -1)


WL_word_embedding_x1, WL_word_embedding_expanded_x1 = _embed_and_expand(input_x1)
print(WL_word_embedding_expanded_x1)

WL_word_embedding_x2, WL_word_embedding_expanded_x2 = _embed_and_expand(input_x2)
print(WL_word_embedding_expanded_x2)

In [None]:
def cnn(x_, filter_size=3):
    """Single-width text CNN: convolution, ReLU, then max pooling over every position.

    Args:
        x_: expected to be a 4-D float tensor laid out as
            (batch, sequence_length, embedding_size, 1), e.g. the output of
            ``tf.expand_dims(tf.nn.embedding_lookup(...), -1)``.
        filter_size: number of consecutive tokens each filter spans. Defaults to 3,
            the width that was previously hard-coded.

    Returns:
        Tensor of shape (batch, 1, 1, num_filters) holding each filter's maximum
        activation over all valid positions.

    Raises:
        ValueError: if filter_size is not in [1, sequence_length].

    Each call creates its own ``W``/``b`` variables, so calling it twice builds two
    independent branches that do NOT share weights.
    NOTE(review): Siamese models usually share weights — confirm this is intended.
    """
    if not 1 <= filter_size <= sequence_length:
        raise ValueError("filter_size must be in [1, %d], got %r" % (sequence_length, filter_size))

    W = tf.Variable(tf.truncated_normal([filter_size, embedding_size, 1, num_filters], stddev=0.1), dtype=tf.float32)
    b = tf.Variable(tf.constant(0.1, shape=[num_filters]))
    # The filter spans the full embedding width, so VALID padding leaves
    # sequence_length - filter_size + 1 positions along the time axis and width 1.
    conv = tf.nn.conv2d(input=x_, filter=W, strides=[1, 1, 1, 1], padding="VALID")
    h = tf.nn.relu(tf.nn.bias_add(conv, b))
    print(h)
    # Max-pool across all remaining time positions (derived from filter_size rather
    # than a hard-coded "- 2", so the two can never get out of sync).
    pooled_len = sequence_length - filter_size + 1
    pooled = tf.nn.max_pool(h, ksize=[1, pooled_len, 1, 1], strides=[1, 1, 1, 1], padding='VALID')
    return pooled

In [None]:
# Encode each sentence with its own CNN branch, then flatten the pooled
# (batch, 1, 1, num_filters) maps to (batch, num_filters).
pool1 = cnn(WL_word_embedding_expanded_x1)
pool2 = cnn(WL_word_embedding_expanded_x2)

h_pool_flat1 = tf.reshape(pool1, [-1, num_filters])
h_pool_flat2 = tf.reshape(pool2, [-1, num_filters])

h_drop1 = tf.nn.dropout(h_pool_flat1, dropout_keep_prob)
h_drop2 = tf.nn.dropout(h_pool_flat2, dropout_keep_prob)

# Element-wise product of the unit-normalised sentence vectors (summing a row
# would give the cosine similarity of that pair). Normalise along axis=1 so each
# example is scaled by its own norm. With the default axis=None, l2_normalize
# divides by the norm of the whole batch tensor, so every example's features
# would depend on batch size and on the other examples in the batch.
output = tf.nn.l2_normalize(h_drop1, axis=1) * tf.nn.l2_normalize(h_drop2, axis=1)

In [None]:
# Output layer: linear projection of the combined pair features to class logits.
W = tf.Variable(initial_value=tf.truncated_normal(shape=[num_filters, num_classes], stddev=0.1))
b = tf.Variable(initial_value=tf.constant(value=0.1, shape=[num_classes]))
logits = tf.nn.xw_plus_b(x=output, weights=W, biases=b)
predictions = tf.nn.softmax(logits)

# Accuracy: fraction of rows whose arg-max prediction matches the one-hot label's arg-max.
correct_predict = tf.equal(
    x=tf.argmax(input=predictions, axis=1),
    y=tf.argmax(input=input_y, axis=1),
)
acc_op = tf.reduce_mean(input_tensor=tf.cast(x=correct_predict, dtype=tf.float32))

In [None]:
# Mean softmax cross-entropy over the batch, minimised with Adam.
_cross_entropy = tf.nn.softmax_cross_entropy_with_logits_v2(labels=input_y, logits=logits)
loss_op = tf.reduce_mean(_cross_entropy)
_optimizer = tf.train.AdamOptimizer(learning_rate=0.0005)
train_op = _optimizer.minimize(loss_op)
# Created last so it also covers the Adam slot variables that minimize() adds.
init_op = tf.global_variables_initializer()

In [None]:
def next_batch(i):
    """Return the i-th consecutive (X, Y) mini-batch of the training set.

    Takes rows [i * batch_size, (i + 1) * batch_size) of the module-level X_train
    and Y_train. The last batch can be shorter (or empty) once the data runs out.
    """
    window = slice(i * batch_size, (i + 1) * batch_size)
    return X_train[window], Y_train[window]

In [None]:
# Number of mini-batches per epoch, rounded UP so the trailing partial batch is
# trained on too (next_batch just returns fewer rows for it). Truncating division
# silently skipped up to batch_size - 1 samples every epoch, and gave 0 batches
# (no training at all) when X_train had fewer than batch_size rows.
total_batch = (X_train.shape[0] + batch_size - 1) // batch_size

In [None]:
train_keep_prob = 0.8  # dropout keep probability used for gradient steps only
eval_keep_prob = 1.0   # dropout disabled when measuring test loss/accuracy


def _pair_feed(x_pairs, y, keep_prob):
    """Build a feed_dict, splitting each row of `x_pairs` into its two sentences.

    The first `sequence_length` columns feed input_x1 and the remaining columns
    feed input_x2 (previously hard-coded as column 45).
    """
    return {
        input_x1: x_pairs[:, :sequence_length],
        input_x2: x_pairs[:, sequence_length:],
        input_y: y,
        dropout_keep_prob: keep_prob,
    }


with tf.Session() as sess:
    sess.run(init_op)

    for epoch in range(num_epoch):
        batch_losses, batch_accs = [], []
        # NOTE(review): batches come in the same fixed order every epoch; consider shuffling.
        for i in range(total_batch):
            batch_xs, batch_ys = next_batch(i)
            # A single run applies the update and returns this batch's training
            # loss/accuracy (computed before the update, with dropout on). This
            # replaces the extra forward passes the loop used to make per batch.
            _, c, acc = sess.run([train_op, loss_op, acc_op],
                                 feed_dict=_pair_feed(batch_xs, batch_ys, train_keep_prob))
            batch_losses.append(c)
            batch_accs.append(acc)

        # Report epoch means; NaN if there were no batches at all.
        train_loss = sum(batch_losses) / len(batch_losses) if batch_losses else float('nan')
        train_acc = sum(batch_accs) / len(batch_accs) if batch_accs else float('nan')

        # Evaluate the held-out set once per epoch (not once per batch), with dropout off.
        c1, acc1 = sess.run([loss_op, acc_op], feed_dict=_pair_feed(X_test, Y_test, eval_keep_prob))
        print("Epoch:", '%04d' % (epoch+1), "loss=", "{:.9f}".format(train_loss), "accuracy=", "{:.9f}".format(train_acc), "test_loss=", "{:.9f}".format(c1), "test_accuracy=", "{:.9f}".format(acc1))
    print("Optimization Finished!")