In [1]:
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.preprocessing.sequence import pad_sequences
from tensorflow.keras.datasets import imdb

# Hyperparameters
max_features = 10000  # vocabulary size passed to imdb.load_data(num_words=...)
max_len = 200  # every review is padded/truncated to this many tokens
embedding_dim = 128  # size of each word-embedding vector
gru_units = 64  # number of GRU units
seed = 42  # fixed seed so that Restart & Run All gives comparable results

# Seed Python's `random`, NumPy and TensorFlow in one call; without this the
# weight initialisation and batch shuffling differ on every run.
keras.utils.set_random_seed(seed)

# Load the IMDB reviews, already encoded as lists of integer word ids.
(x_train, y_train), (x_test, y_test) = imdb.load_data(num_words=max_features)

# Pad (or truncate) every review to exactly max_len ids.
x_train = pad_sequences(x_train, maxlen=max_len)
x_test = pad_sequences(x_test, maxlen=max_len)

# Build the GRU classifier.
# The input shape is declared with keras.Input instead of the Embedding
# `input_length` argument, which is deprecated in Keras 3.
model = keras.Sequential([
    keras.Input(shape=(max_len,), dtype="int32"),
    layers.Embedding(max_features, embedding_dim),
    layers.GRU(gru_units, return_sequences=False),  # only the last state feeds the classifier
    layers.Dense(1, activation='sigmoid'),  # binary task: a single sigmoid output
])

# Compile the model.
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

# Train, holding out the last 20% of the training data for validation.
model.fit(x_train, y_train, epochs=5, batch_size=32, validation_split=0.2)

# Evaluate on the held-out test set.
loss, accuracy = model.evaluate(x_test, y_test)
print(f"测试损失: {loss:.4f}, 测试准确率: {accuracy:.4f}")

# Example: score a raw review with the trained model.
sample_text = ["This movie was fantastic! I loved it."]

# imdb.get_word_index() maps word -> frequency rank, but imdb.load_data()
# (with its default arguments, as used for training) reserves id 1 for the
# start-of-sequence marker and id 2 for out-of-vocabulary words, and shifts
# every rank by index_from=3. Raw text must be encoded the same way, otherwise
# each word is looked up in the wrong embedding row.
start_id = 1
oov_id = 2
index_from = 3
word_index = imdb.get_word_index()


def encode_review(text):
    """Convert a raw review string into the id sequence the model was trained on.

    Args:
        text: the review as plain text.

    Returns:
        A list of ints beginning with the start marker. Words that are missing
        from the word index, or whose shifted id is >= max_features, become
        the out-of-vocabulary id so that no id falls outside the embedding
        table.
    """
    # The word index holds lower-case tokens without attached punctuation, so
    # lower-case the text and turn everything except letters, digits and
    # apostrophes into a separator (an approximation of the dataset's own
    # tokenisation).
    cleaned = "".join(ch if ch.isalnum() or ch == "'" else " " for ch in text.lower())
    ids = [start_id]
    for word in cleaned.split():
        rank = word_index.get(word)
        token_id = oov_id if rank is None else rank + index_from
        ids.append(token_id if token_id < max_features else oov_id)
    return ids


sample_sequence = pad_sequences(
    [encode_review(text) for text in sample_text], maxlen=max_len
)

predictions = model.predict(sample_sequence)
print(f"预测结果: {predictions[0][0]:.4f}")  # sigmoid output: predicted probability of class 1


Epoch 1/5




625/625 ━━━━━━━━━━━━━━━━━━━━ 35s 55ms/step - accuracy: 0.6978 - loss: 0.5399 - val_accuracy: 0.8388 - val_loss: 0.3779
Epoch 2/5
625/625 ━━━━━━━━━━━━━━━━━━━━ 37s 58ms/step - accuracy: 0.8942 - loss: 0.2663 - val_accuracy: 0.8792 - val_loss: 0.3033
Epoch 3/5
625/625 ━━━━━━━━━━━━━━━━━━━━ 42s 67ms/step - accuracy: 0.9389 - loss: 0.1650 - val_accuracy: 0.8736 - val_loss: 0.3306
Epoch 4/5
625/625 ━━━━━━━━━━━━━━━━━━━━ 56s 89ms/step - accuracy: 0.9638 - loss: 0.1094 - val_accuracy: 0.8688 - val_loss: 0.3591
Epoch 5/5
625/625 ━━━━━━━━━━━━━━━━━━━━ 62s 100ms/step - accuracy: 0.9794 - loss: 0.0625 - val_accuracy: 0.8578 - val_loss: 0.3981
782/782 ━━━━━━━━━━━━━━━━━━━━ 12s 16ms/step - accuracy: 0.8474 - loss: 0.4343
测试损失: 0.4300, 测试准确率: 0.8488
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m