In [1]:
import tensorflow as tf
from attrdict import AttrDict
tf.__version__



'1.8.0'

读取训练和测试文件

In [2]:
import re
import tarfile
import random

class ReviewReader:
    
    TOKEN_REGEX = re.compile(r'[A-Za-z]+|[!?.:,()]')
    
    def __init__(self):
        filepath="aclImdb_v1.tar.gz"
        self.train_data = []
        self.test_data = []
        self.max_length = 0
        
        with tarfile.open(filepath) as archive:
            for filename in archive.getnames():
                if filename.startswith('aclImdb/train/pos/'):
                    self.train_data.append((self._read(archive, filename), True))
                elif filename.startswith('aclImdb/train/neg/'):
                    self.train_data.append((self._read(archive, filename), False))
                elif filename.startswith('aclImdb/test/pos/'):
                    self.test_data.append((self._read(archive, filename), True))
                elif filename.startswith('aclImdb/test/neg/'):
                    self.test_data.append((self._read(archive, filename), False))
        random.shuffle(self.train_data)
        random.shuffle(self.test_data)
    
    def _read(self, archive, filename):
        with archive.extractfile(filename) as file_:
            data = file_.read().decode('utf-8')
            data = type(self).TOKEN_REGEX.findall(data)
            data = [x.lower() for x in data]
            if len(data) > self.max_length:
                self.max_length = len(data)
            return data
    
    

    

In [3]:
rr = ReviewReader()
test_data = rr.test_data
train_data = rr.train_data

print("test.shape:", len(test_data))
print("train.shape:", len(train_data))
print("max_len:", rr.max_length)

test.shape: 25000
train.shape: 25000
max_len: 2758


读取embedding

In [4]:
from load_word2vec import load_embedding
import numpy as np
import logging

class Embedding:
    def __init__(self, vocab_path="", embedding_path= "", length=1500):
        vocab, index_vocab, embed = load_embedding()
        self._length = length
        self._embedding = embed
        self._vocab = vocab

    def __call__(self, sequence):
        data = np.zeros((self._length, self._embedding.shape[1]))
        indices = [self._vocab.get(x, 0) for x in sequence]
        # print(indices)
        logging.info("indices:", indices)
        embedded = self._embedding[indices, :]
        data[:len(sequence)] = embedded
        return data

    @property
    def dimensions(self):
        return self._embedding.shape[1]

定义运算的图


In [5]:

embed = None
try :
    embed = Embedding(length = rr.max_length)
except Exception as ex:
    print(dir(ex))
    print(ex.message.decode("utf8"))



b'3000000 300\n'
3000000 300


In [6]:
print(embed)
data_batch = np.zeros((3, rr.max_length , embed.dimensions), dtype="float32")
label_batch = np.zeros((3, 2))
index = 0
for sentence, label in train_data[:3]:
    data = embed(sentence)
    data_batch[index] = data
    print("train:", label)
    label_batch[index] = [1, 0] if label else [0, 1]
    index += 1
    
    print(data)

print("data_batch:", data_batch)
print("label_batch:", label_batch)

<__main__.Embedding object at 0x111e5eeb8>
train: False
[[ 0.20800781  0.03979492  0.25       ... -0.08789062 -0.00418091
  -0.09179688]
 [ 0.03515625 -0.08300781 -0.07177734 ...  0.01965332  0.09863281
  -0.04760742]
 [-0.03442383  0.10351562  0.02160645 ...  0.07324219  0.03320312
   0.03833008]
 ...
 [ 0.          0.          0.         ...  0.          0.
   0.        ]
 [ 0.          0.          0.         ...  0.          0.
   0.        ]
 [ 0.          0.          0.         ...  0.          0.
   0.        ]]
train: False
[[ 0.109375    0.140625   -0.03173828 ...  0.00765991  0.12011719
  -0.1796875 ]
 [ 0.04394531 -0.03344727 -0.29882812 ... -0.03271484 -0.02514648
   0.41015625]
 [ 0.23730469 -0.10107422 -0.07275391 ... -0.10107422 -0.1015625
   0.01501465]
 ...
 [ 0.          0.          0.         ...  0.          0.
   0.        ]
 [ 0.          0.          0.         ...  0.          0.
   0.        ]
 [ 0.          0.          0.         ...  0.          0.
   0.       

获取序列的长度

In [7]:
def data_input_length(data_input):
    tmp_abs = tf.abs(data_input)
    tmp_max = tf.reduce_max(tmp_abs, axis = 2)
    tmp_sign = tf.sign(tmp_max)
    temp_len = tf.reduce_sum(tmp_sign, axis = 1)
    tmp_len = tf.cast(temp_len, tf.int32)
    return tmp_len

def test_data_input_length():
    """
    测试求序列长度
    """
    temp_graph = tf.Graph()
    with temp_graph.as_default():
        data_input = tf.placeholder(tf.float32, [None, length, embedding_dimensions], name="data")
        data_input_len =  data_input_length(data_input)
        sess = tf.Session()
        le = sess.run(data_input_len, feed_dict={data_input:data_batch})
        print(le)
        
# test_data_input_length()

从output中获取最后一个有效输出

In [8]:
def last_output(output, seq_len):
    output_shape = tf.shape(output)
    batch_size = output_shape[0]
    max_len = output_shape[1]
    output_size = output_shape[2]
    index = tf.range(0, batch_size) * max_len + (seq_len - 1)
    flat = tf.reshape(output, [-1, output_size])
    last_out = tf.gather(flat, index)
    return last_out

In [9]:
def my_optimizer(loss):
    # 计算梯度，后向传播
    optimizer = tf.train.GradientDescentOptimizer(0.01, name="my_optimizer")
#     grads = optimizer.compute_gradients(loss)
#     app_grads = optimizer.apply_gradients(grads)
    app_grads = optimizer.minimize(loss)
    return app_grads

In [10]:
def my_error(target, prediction):
    target_argmax = tf.argmax(target, 1)
    
    predic_argmax = tf.argmax(prediction, 1, name="predic_argmax")
    mistakes = tf.not_equal(target_argmax, predic_argmax)
    print(mistakes)
    return tf.reduce_mean(tf.cast(mistakes, tf.float32))

def test_error(target, prediction):
    print('target:', target)
    with tf.Graph().as_default():
        data_input = tf.placeholder(tf.float32, [None, 2])
        target_argmax = tf.argmax(data_input, 1)
        sess = tf.Session()
        ret = sess.run(target_argmax, feed_dict={data_input:target})
        print(ret)
test_error(label_batch, 0)

target: [[0. 1.]
 [0. 1.]
 [0. 1.]]
[1 1 1]


定义好数据流图

In [11]:
main_graph = tf.Graph()

class_num = 2
length = rr.max_length  # 每个句子的最大长度
embedding_dimensions = embed.dimensions   # 每个word embedding的维数
hidden_size = 150


with main_graph.as_default():


    data_input = tf.placeholder(tf.float32, [None, length, embedding_dimensions], name="input_data")
    data_label = tf.placeholder(tf.float32, [None, 2], name = "input_target")
    data_input_len = data_input_length(data_input)

    rnn_cell = tf.nn.rnn_cell.BasicRNNCell(hidden_size, name="rnn_cell")
    output, state = tf.nn.dynamic_rnn(
        rnn_cell,
        data_input,
        dtype = tf.float32,

    )
    
    W = tf.Variable(tf.truncated_normal([hidden_size, class_num], stddev = 0.01), name="weight")
    bias = tf.Variable(tf.truncated_normal([class_num], stddev = 0.01), name="bias")
    
    last_out = last_output(output, data_input_len)
    class_output = tf.matmul(last_out, W) + bias 
    prediction = tf.nn.softmax(class_output)
    cross_entropy = -tf.reduce_sum(data_label * tf.log(prediction))
    optimizer = my_optimizer(cross_entropy)
    mode_error = my_error(data_label, prediction)
    initia = tf.global_variables_initializer()
    


    
    








Tensor("NotEqual:0", shape=(?,), dtype=bool)


  "Converting sparse IndexedSlices to a dense Tensor of unknown shape. "


训练

In [12]:


def batch_reader(batch_size):
    iterator = iter(train_data)
    while True:
        data_batch = np.zeros((batch_size, rr.max_length , embed.dimensions), dtype="float32")
        label_batch = np.zeros((batch_size, 2))
        
        for index in range(batch_size):
            text, label =next(iterator)
            data_batch[index] = embed(text)
            label_batch[index] = [1, 0] if label else [0, 1]
    
    
        yield data_batch, label_batch
        
        

In [13]:
import random
def model_test_data():
    test_count = 2000
    test_batch = np.zeros((test_count, rr.max_length , embed.dimensions), dtype="float32")
    test_label_batch = np.zeros((test_count, 2))
    tmp_test = random.sample(test_data, test_count)
    for index in range(test_count):
        (text, label)= tmp_test[index]
        test_batch[index] = embed(text)
        test_label_batch[index] = [1, 0] if label else [0, 1]
        
    return test_batch, test_label_batch

In [14]:
session = tf.Session(graph=main_graph)
session.run(initia)

batch_size = 20


    

for index, data_train in enumerate(batch_reader(batch_size)):
    # print("label:", data_train[1])
    op_data, err = session.run([optimizer, mode_error], feed_dict={data_input:data_train[0], data_label:data_train[1]})
    print("{} error:{}".format(index, err ))
    if err < 1e-6:
        break

t_b, t_lb = model_test_data()
test_error = session.run([ mode_error], feed_dict={data_input:t_b, data_label:t_lb})
        
print("test error:",test_error )


writer = tf.summary.FileWriter("mygraph", main_graph)
writer.close()




0 error:0.44999998807907104
1 error:0.3499999940395355
2 error:0.4000000059604645
3 error:0.550000011920929
4 error:0.4000000059604645
5 error:0.6000000238418579
6 error:0.5
7 error:0.550000011920929
8 error:0.44999998807907104
9 error:0.4000000059604645
10 error:0.44999998807907104
11 error:0.44999998807907104
12 error:0.75
13 error:0.5
14 error:0.5
15 error:0.20000000298023224
16 error:0.44999998807907104
17 error:0.6000000238418579
18 error:0.6499999761581421
19 error:0.5
20 error:0.4000000059604645
21 error:0.30000001192092896
22 error:0.550000011920929
23 error:0.6000000238418579
24 error:0.6000000238418579
25 error:0.4000000059604645
26 error:0.6000000238418579
27 error:0.25
28 error:0.5
29 error:0.6000000238418579
30 error:0.3499999940395355
31 error:0.5
32 error:0.5
33 error:0.30000001192092896
34 error:0.44999998807907104
35 error:0.8500000238418579
36 error:0.5
37 error:0.6499999761581421
38 error:0.30000001192092896
39 error:0.6000000238418579
40 error:0.699999988079071
41 e

318 error:0.30000001192092896
319 error:0.6499999761581421
320 error:0.3499999940395355
321 error:0.5
322 error:0.4000000059604645
323 error:0.4000000059604645
324 error:0.5
325 error:0.3499999940395355
326 error:0.4000000059604645
327 error:0.4000000059604645
328 error:0.30000001192092896
329 error:0.6000000238418579
330 error:0.6499999761581421
331 error:0.5
332 error:0.4000000059604645
333 error:0.550000011920929
334 error:0.4000000059604645
335 error:0.6000000238418579
336 error:0.44999998807907104
337 error:0.30000001192092896
338 error:0.6000000238418579
339 error:0.550000011920929
340 error:0.5
341 error:0.44999998807907104
342 error:0.5
343 error:0.3499999940395355
344 error:0.550000011920929
345 error:0.25
346 error:0.4000000059604645
347 error:0.5
348 error:0.30000001192092896
349 error:0.44999998807907104
350 error:0.44999998807907104
351 error:0.4000000059604645
352 error:0.44999998807907104
353 error:0.3499999940395355
354 error:0.44999998807907104
355 error:0.600000023841

631 error:0.4000000059604645
632 error:0.4000000059604645
633 error:0.699999988079071
634 error:0.550000011920929
635 error:0.6000000238418579
636 error:0.44999998807907104
637 error:0.5
638 error:0.25
639 error:0.6000000238418579
640 error:0.44999998807907104
641 error:0.44999998807907104
642 error:0.5
643 error:0.5
644 error:0.6000000238418579
645 error:0.5
646 error:0.6499999761581421
647 error:0.44999998807907104
648 error:0.6000000238418579
649 error:0.6499999761581421
650 error:0.5
651 error:0.44999998807907104
652 error:0.5
653 error:0.30000001192092896
654 error:0.5
655 error:0.550000011920929
656 error:0.4000000059604645
657 error:0.5
658 error:0.5
659 error:0.550000011920929
660 error:0.44999998807907104
661 error:0.550000011920929
662 error:0.6000000238418579
663 error:0.5
664 error:0.44999998807907104
665 error:0.4000000059604645
666 error:0.44999998807907104
667 error:0.6000000238418579
668 error:0.6499999761581421
669 error:0.550000011920929
670 error:0.550000011920929
67

943 error:0.550000011920929
944 error:0.5
945 error:0.6000000238418579
946 error:0.6000000238418579
947 error:0.4000000059604645
948 error:0.25
949 error:0.3499999940395355
950 error:0.699999988079071
951 error:0.6499999761581421
952 error:0.6000000238418579
953 error:0.6000000238418579
954 error:0.5
955 error:0.5
956 error:0.44999998807907104
957 error:0.550000011920929
958 error:0.44999998807907104
959 error:0.550000011920929
960 error:0.44999998807907104
961 error:0.44999998807907104
962 error:0.4000000059604645
963 error:0.699999988079071
964 error:0.5
965 error:0.44999998807907104
966 error:0.550000011920929
967 error:0.6499999761581421
968 error:0.5
969 error:0.44999998807907104
970 error:0.3499999940395355
971 error:0.30000001192092896
972 error:0.3499999940395355
973 error:0.6000000238418579
974 error:0.6000000238418579
975 error:0.550000011920929
976 error:0.550000011920929
977 error:0.6499999761581421
978 error:0.44999998807907104
979 error:0.44999998807907104
980 error:0.300

  if __name__ == '__main__':


test error: [0.4975]
