In [1]:
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.preprocessing.sequence import pad_sequences
from tensorflow.keras.datasets import imdb

# 设置参数
max_features = 10000  # 词汇表大小
max_len = 200  # 最大序列长度
embedding_dim = 128  # 嵌入维度
filters = 64  # 卷积核数量
kernel_size = 5  # 卷积核大小

# 加载IMDB数据集
(x_train, y_train), (x_test, y_test) = imdb.load_data(num_words=max_features)

# 填充序列
x_train = pad_sequences(x_train, maxlen=max_len)
x_test = pad_sequences(x_test, maxlen=max_len)

# 构建CNN模型
model = keras.Sequential([
    layers.Embedding(max_features, embedding_dim, input_length=max_len),
    layers.Conv1D(filters, kernel_size, activation='relu'),
    layers.MaxPooling1D(pool_size=2),
    layers.Conv1D(filters * 2, kernel_size, activation='relu'),
    layers.MaxPooling1D(pool_size=2),
    layers.GlobalMaxPooling1D(),
    layers.Dense(1, activation='sigmoid')  # 二分类任务，使用sigmoid激活函数
])

# 编译模型
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

# 训练模型
model.fit(x_train, y_train, epochs=5, batch_size=32, validation_split=0.2)

# 评估模型
loss, accuracy = model.evaluate(x_test, y_test)
print(f"测试损失: {loss:.4f}, 测试准确率: {accuracy:.4f}")

# 示例：使用模型进行预测
sample_text = ["This movie was fantastic! I loved it."]
sample_sequence = imdb.get_word_index()
sample_sequence = [[sample_sequence.get(word, 0) for word in text.split()] for text in sample_text]
sample_sequence = pad_sequences(sample_sequence, maxlen=max_len)

predictions = model.predict(sample_sequence)
print(f"预测结果: {predictions[0][0]:.4f}")  # 输出预测的概率


Epoch 1/5




[1m625/625[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m21s[0m 32ms/step - accuracy: 0.6790 - loss: 0.5471 - val_accuracy: 0.8612 - val_loss: 0.3240
Epoch 2/5
[1m625/625[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m20s[0m 32ms/step - accuracy: 0.9173 - loss: 0.2120 - val_accuracy: 0.8768 - val_loss: 0.2914
Epoch 3/5
[1m625/625[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m18s[0m 29ms/step - accuracy: 0.9684 - loss: 0.0956 - val_accuracy: 0.8744 - val_loss: 0.3618
Epoch 4/5
[1m625/625[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m18s[0m 29ms/step - accuracy: 0.9912 - loss: 0.0308 - val_accuracy: 0.8692 - val_loss: 0.4936
Epoch 5/5
[1m625/625[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m17s[0m 27ms/step - accuracy: 0.9978 - loss: 0.0108 - val_accuracy: 0.8658 - val_loss: 0.6396
[1m782/782[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 7ms/step - accuracy: 0.8537 - loss: 0.7046
测试损失: 0.6933, 测试准确率: 0.8551
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s