# 載入 IMDB 資料集

In [1]:
import tensorflow as tf

dataset = tf.keras.utils.get_file(
    fname="aclImdb.tar.gz", 
    origin="http://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz", 
    extract=True,
)

Downloading data from http://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz


# 資料集處理
    正負評價分開、並且編碼。

In [0]:
import os
import glob
import pandas as pd

def read(fp):
    with open(fp, "r", encoding="utf-8") as f:
        content = f.read()
    return content

def read_data(base):
    pos = glob.glob(os.path.join(base, "pos", "*"))
    neg = glob.glob(os.path.join(base, "neg", "*"))
    df = pd.DataFrame({
        "path":(neg + pos),
        "target":([0]*len(neg) + [1] * len(pos))
    })
    df["content"] = df["path"].apply(read)
    return df

In [3]:
dirname = os.path.dirname(dataset)
base = os.path.join(dirname, "aclImdb", "train")
train_df = read_data(base)
base = os.path.join(dirname, "aclImdb", "test")
test_df = read_data(base)
test_df

Unnamed: 0,path,target,content
0,/root/.keras/datasets/aclImdb/test/neg/11722_4...,0,I saw this film for one reason: the tagline is...
1,/root/.keras/datasets/aclImdb/test/neg/1571_4.txt,0,The film was disappointing. I saw it on Broadw...
2,/root/.keras/datasets/aclImdb/test/neg/8224_1.txt,0,"OK, so I gotta start this review by saying i w..."
3,/root/.keras/datasets/aclImdb/test/neg/10426_4...,0,"This movie narrate the story of John Belushi,b..."
4,/root/.keras/datasets/aclImdb/test/neg/6390_2.txt,0,This movie just arrived to Mexico and since I ...
...,...,...,...
24995,/root/.keras/datasets/aclImdb/test/pos/11181_8...,1,On account of my unfortunately not being able ...
24996,/root/.keras/datasets/aclImdb/test/pos/8188_10...,1,This movie is beautiful in many ways: the plot...
24997,/root/.keras/datasets/aclImdb/test/pos/3098_8.txt,1,One of the most entertaining of all silent com...
24998,/root/.keras/datasets/aclImdb/test/pos/10878_1...,1,"Two years after its initial release, Goldeneye..."


# 文字預處理 (text to number)

In [0]:
# 預處理1. 先把文字化成數字
from tensorflow.keras.preprocessing.text import Tokenizer
# 出現太少的詞, 你可以選擇不看, 只留出現次數最高的2000(num_words)
tok = Tokenizer(num_words=2000)
tok.fit_on_texts(train_df["content"])

In [0]:
# 想要看每個單詞被給的編號: 
# tok.word_index

In [5]:
x_train_seq = tok.texts_to_sequences(train_df["content"])
x_test_seq = tok.texts_to_sequences(test_df["content"])
pd.DataFrame(x_train_seq)

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,...,1671,1672,1673,1674,1675,1676,1677,1678,1679,1680,1681,1682,1683,1684,1685,1686,1687,1688,1689,1690,1691,1692,1693,1694,1695,1696,1697,1698,1699,1700,1701,1702,1703,1704,1705,1706,1707,1708,1709,1710
0,2,104,4,58,219,1636,153,4,11.0,1114.0,765.0,53.0,8.0,48.0,6.0,176.0,866.0,1.0,586.0,30.0,1031.0,1.0,457.0,1.0,959.0,6.0,35.0,750.0,2.0,35.0,1250.0,12.0,1383.0,27.0,125.0,122.0,128.0,36.0,1.0,1453.0,...,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
1,1,436,16,1,6,53,31,1,633.0,133.0,1.0,178.0,32.0,1957.0,348.0,35.0,325.0,588.0,1.0,6.0,30.0,1.0,151.0,8.0,1311.0,60.0,1.0,884.0,135.0,78.0,33.0,178.0,5.0,122.0,3.0,45.0,237.0,167.0,5.0,1139.0,...,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
2,51,10,11,17,36,3,461,10,119.0,21.0,815.0,9.0,20.0,1.0,60.0,368.0,606.0,75.0,2.0,1555.0,13.0,1.0,1359.0,186.0,100.0,146.0,9.0,10.0,194.0,12.0,1.0,13.0,239.0,50.0,71.0,1.0,19.0,407.0,1.0,62.0,...,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
3,3,220,16,1243,1627,155,288,151,2.0,41.0,339.0,288.0,151.0,844.0,5.0,11.0,304.0,310.0,8.0,1570.0,4.0,1.0,12.0,177.0,1.0,8.0,1005.0,4.0,65.0,310.0,172.0,1549.0,1571.0,39.0,35.0,1.0,635.0,1.0,93.0,1.0,...,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
4,259,10,358,11,176,5,64,86,82.0,81.0,1146.0,48.0,10.0,101.0,23.0,39.0,40.0,1041.0,734.0,105.0,100.0,146.0,11.0,17.0,20.0,19.0,685.0,233.0,311.0,10.0,417.0,5.0,898.0,139.0,177.0,57.0,45.0,9.0,40.0,1520.0,...,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
24995,540,6,3,640,1841,170,4,989,110.0,2.0,170.0,784.0,1540.0,3.0,1407.0,16.0,1303.0,406.0,32.0,208.0,39.0,236.0,222.0,3.0,376.0,208.0,51.0,26.0,1.0,1382.0,2.0,8.0,1.0,418.0,4.0,540.0,11.0,6.0,118.0,26.0,...,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
24996,10,1038,16,313,34,555,12,11,198.0,13.0,1.0,115.0,4.0,1.0,509.0,58.0,656.0,2.0,10.0,68.0,31.0,9.0,51.0,9.0,83.0,8.0,1.0,175.0,2.0,293.0,172.0,387.0,10.0,801.0,30.0,12.0,55.0,5.0,1.0,198.0,...,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
24997,320,672,52,49,861,22,97,137,43.0,15.0,3.0,302.0,795.0,30.0,29.0,7.0,7.0,1568.0,53.0,608.0,1862.0,1825.0,5.0,1.0,1311.0,5.0,1272.0,3.0,1370.0,5.0,190.0,38.0,5.0,909.0,739.0,187.0,100.0,3.0,777.0,248.0,...,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
24998,28,4,58,1636,99,10,89,121,45.0,47.0,6.0,98.0,411.0,8.0,260.0,701.0,34.0,1478.0,107.0,11.0,17.0,42.0,160.0,2.0,616.0,30.0,46.0,385.0,10.0,89.0,121.0,86.0,119.0,81.0,184.0,1.0,179.0,34.0,66.0,1432.0,...,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,


# 句子 NAN 填充

In [6]:
from tensorflow.keras.preprocessing.sequence import pad_sequences

x_train_pad = pad_sequences(x_train_seq, maxlen=256)
x_test_pad = pad_sequences(x_test_seq, maxlen=256)
pd.DataFrame(x_train_pad)

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,...,216,217,218,219,220,221,222,223,224,225,226,227,228,229,230,231,232,233,234,235,236,237,238,239,240,241,242,243,244,245,246,247,248,249,250,251,252,253,254,255
0,215,14,335,4,32,520,14,88,199,38,1105,15,2,6,6,866,12,38,113,23,109,177,31,3,335,226,45,35,92,12,59,1257,48,16,3,226,12,59,899,12,...,2,573,993,5,80,91,664,860,1,127,956,6,63,176,2,1481,23,1,52,219,4,126,14,1,105,290,1835,8,91,5,91,308,6,3,19,487,88,1326,1656,140
1,1,357,4,24,110,8,237,167,5,27,43,4,342,93,158,1,25,3,576,156,26,13,45,31,46,576,26,211,3,21,92,513,87,222,54,279,5,513,291,34,...,18,171,237,291,115,570,15,8,727,5,285,131,501,2,237,688,392,3,214,550,1,1229,4,200,137,20,5,410,191,1838,8,434,18,431,1849,5,273,1,20,65
2,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,...,4,155,2,61,85,1,827,606,240,4,642,10,89,178,5,19,1534,39,99,1224,18,204,107,2,204,107,3,888,349,17,36,1534,12,13,744,208,125,71,11,592
3,1026,1,664,31,1657,708,13,49,9,13,193,18,9,1384,102,2,1181,12,90,22,456,48,571,11,17,29,1,106,939,2,40,295,53,1,596,2,564,1385,23,553,...,2,13,40,31,48,10,216,3,1189,186,19,3,297,29,1,93,7,7,51,22,101,42,29,117,2,188,76,430,1,1938,3,422,610,10,89,27,8,3,368,721
4,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,...,391,10,479,148,332,61,14,49,14,1,226,22,23,345,2,471,37,5,1289,1,924,1,1177,1,164,2,29,4,1,174,15,1549,232,4,58,110,2,46,946,49
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
24995,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,...,1586,1964,382,122,14,3,861,33,132,22,188,545,540,14,1,268,243,9,128,188,567,22,36,533,48,235,25,74,45,33,66,40,53,1,878,2,1790,4,11,17
24996,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,...,2,329,1,198,8,170,5,388,1,111,45,1,201,307,123,263,1434,20,285,634,27,790,1,83,8,344,5,53,3,1036,1143,4,1,18,13,477,26,6,207,477
24997,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,320,672,52,49,861,22,97,137,43,15,3,302,795,30,29,...,326,626,39,1131,8,98,93,143,21,249,45,1,274,13,978,5,94,9,213,122,14,46,429,4,294,2,42,21,615,401,18,42,431,52,1131,2,42,1117,52,70
24998,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,...,64,130,47,6,1843,36,34,293,9,2,26,421,9,7,7,10,101,12,634,350,5,273,29,49,36,1,17,20,11,18,83,5,166,43,86,5,78,12,7,7


# 模型建立

In [7]:
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Embedding
from tensorflow.keras.layers import Flatten, Dropout, Dense

layers = [
    # 2001 * 64 = 128064
    Embedding(2001, 64, mask_zero=True, input_length=256),
    Flatten(),
    Dense(256, activation="relu"),
    Dropout(0.25),
    Dense(2, activation="softmax")
]
model = Sequential(layers)
model.summary()

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
embedding (Embedding)        (None, 256, 64)           128064    
_________________________________________________________________
flatten (Flatten)            (None, 16384)             0         
_________________________________________________________________
dense (Dense)                (None, 256)               4194560   
_________________________________________________________________
dropout (Dropout)            (None, 256)               0         
_________________________________________________________________
dense_1 (Dense)              (None, 2)                 514       
Total params: 4,323,138
Trainable params: 4,323,138
Non-trainable params: 0
_________________________________________________________________


# 模型編譯

In [0]:
from tensorflow.keras.losses import SparseCategoricalCrossentropy
model.compile(loss=SparseCategoricalCrossentropy(),
              optimizer="adam",
              metrics=["accuracy"])

# 模型訓練

In [9]:
import numpy as np
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
callbacks = [
    EarlyStopping(patience=3, restore_best_weights=True),
]
model.fit(x_train_pad,
          np.array(train_df["target"]),
          batch_size=200,
          epochs=100,
          validation_split=0.1,
          verbose=2,
          callbacks=callbacks)

Epoch 1/100
113/113 - 15s - loss: 0.5282 - accuracy: 0.7081 - val_loss: 0.4840 - val_accuracy: 0.7960
Epoch 2/100
113/113 - 15s - loss: 0.2111 - accuracy: 0.9176 - val_loss: 0.3553 - val_accuracy: 0.8552
Epoch 3/100
113/113 - 15s - loss: 0.0736 - accuracy: 0.9773 - val_loss: 0.4488 - val_accuracy: 0.8444
Epoch 4/100
113/113 - 15s - loss: 0.0167 - accuracy: 0.9972 - val_loss: 0.6650 - val_accuracy: 0.8304
Epoch 5/100
113/113 - 15s - loss: 0.0042 - accuracy: 0.9996 - val_loss: 0.7824 - val_accuracy: 0.8208


<tensorflow.python.keras.callbacks.History at 0x7f73e5761b38>

# 模型驗證

In [10]:
model.evaluate(x_test_pad, np.array(test_df["target"]))



[0.33653995394706726, 0.858959972858429]

# 預測資料

In [11]:
layers = [
    # 2001 * 64 = 128064
    Embedding(2001, 64, mask_zero=True),
]
embedding = Sequential(layers)
w = model.layers[0].get_weights()
embedding.set_weights(w)
embedding.summary()

Model: "sequential_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
embedding_1 (Embedding)      (None, None, 64)          128064    
Total params: 128,064
Trainable params: 128,064
Non-trainable params: 0
_________________________________________________________________


In [12]:
embedding.predict([[1]])

array([[[-0.0180213 ,  0.01985464,  0.02356286, -0.03132242,
          0.00996363, -0.03841089,  0.03119331, -0.01875136,
          0.01503539,  0.0453583 , -0.03884904,  0.00555581,
          0.01781305, -0.01935773, -0.03558391, -0.01774406,
         -0.0171569 ,  0.00849898,  0.02612113,  0.00717588,
          0.02580583,  0.02318492, -0.0097551 ,  0.02422469,
         -0.0281295 ,  0.00059863,  0.01478149,  0.00800881,
         -0.0003116 , -0.03357462, -0.00507779, -0.03138198,
          0.01264676, -0.03298078, -0.02271628,  0.0205271 ,
         -0.02204996,  0.03441257,  0.0160019 , -0.0340315 ,
          0.01385654, -0.0067904 ,  0.01141077, -0.01508256,
          0.00136489, -0.01414987,  0.00065778, -0.02574585,
         -0.04049654, -0.01858634,  0.00316737, -0.02689118,
         -0.01500137,  0.02210342,  0.01340085,  0.01668786,
          0.02157978,  0.00375725,  0.01285978, -0.03685961,
         -0.00817791, -0.01694725, -0.00168056,  0.03617594]]],
      dtype=float32)