# In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

from sklearn.metrics import classification_report, confusion_matrix

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, models, optimizers, callbacks
from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.preprocessing.sequence import pad_sequences

# =========================
# 1. Data loading
# =========================

# Locations of the pre-processed LIAR splits, relative to the working directory.
train_path = "../liar_dataset/processed_dataset/processed_train.csv"
valid_path = "../liar_dataset/processed_dataset/processed_valid.csv"
test_path = "../liar_dataset/processed_dataset/processed_test.csv"

# Read the three splits in train / valid / test order.
train_data, valid_data, test_data = (
    pd.read_csv(csv_path)
    for csv_path in (train_path, valid_path, test_path)
)

# Mapping from the textual truthfulness ratings to integer class ids (0-5).
# It is only applied to a split whose "label" column is not already numeric.
label_mapping = {
    "pants-fire":   0,
    "false":        1,
    "barely-true":  2,
    "half-true":    3,
    "mostly-true":  4,
    "true":         5
}


def _encode_label_column(frame, split_name):
    """Convert textual labels in ``frame["label"]`` to ids from ``label_mapping``.

    The frame is modified in place. Numeric label columns are left untouched.

    Args:
        frame: DataFrame holding a "label" column.
        split_name: Human-readable split name, used only in the error message.

    Raises:
        ValueError: If a non-numeric label has no entry in ``label_mapping``.
    """
    raw_labels = frame["label"]
    # Checked per split (not just on the training set) and via the dtype
    # helper so that pandas' dedicated string dtypes are handled as well as
    # plain ``object`` columns.
    if pd.api.types.is_numeric_dtype(raw_labels):
        return

    encoded = raw_labels.map(label_mapping)
    unmapped = encoded.isna()
    if unmapped.any():
        # Fail early: NaN targets would otherwise flow silently into training.
        unknown = raw_labels[unmapped].unique().tolist()
        raise ValueError(
            f"{split_name} split contains labels missing from label_mapping: {unknown}"
        )
    frame["label"] = encoded.astype("int64")


for _split_name, _split_frame in (
    ("train", train_data),
    ("valid", valid_data),
    ("test", test_data),
):
    _encode_label_column(_split_frame, _split_name)

# =========================
# 2. Tokenizer & Padding
# =========================

max_words = 5000    # vocabulary size: only the most frequent words are kept
max_len   = 100     # every statement is padded / truncated to this many tokens


def _statement_texts(frame):
    """Return the "clean_statement" column of ``frame`` as plain strings.

    ``pd.read_csv`` loads empty fields as NaN, so missing values are replaced
    with the empty string to keep float NaN away from the tokenizer.
    """
    return frame["clean_statement"].fillna("").astype(str)


def _to_padded_ids(frame):
    """Turn the statements of ``frame`` into a fixed-width array of token ids.

    Sequences are post-padded and post-truncated to ``max_len``; an empty
    statement therefore becomes a row consisting only of padding.
    """
    token_ids = tokenizer.texts_to_sequences(_statement_texts(frame))
    return pad_sequences(token_ids, maxlen=max_len, padding='post', truncating='post')


# The vocabulary is fitted on the training split only, so validation and test
# words never influence it; unseen words map to the "<OOV>" token.
tokenizer = Tokenizer(num_words=max_words, oov_token="<OOV>")
tokenizer.fit_on_texts(_statement_texts(train_data))

train_seq = _to_padded_ids(train_data)
valid_seq = _to_padded_ids(valid_data)
test_seq  = _to_padded_ids(test_data)

# Targets as NumPy arrays, in the same row order as the sequences above.
train_labels = train_data["label"].values
valid_labels = valid_data["label"].values
test_labels  = test_data["label"].values

# =========================
# 3. Build the LSTM model (6-class)
# =========================

# Layers are stacked one at a time: embedding -> bidirectional LSTM ->
# regularisation -> dense head with a 6-way softmax output.
# NOTE(review): the embedding does not mask the padding index, so the LSTM
# also reads the trailing pad tokens.
model_6class = models.Sequential()
model_6class.add(layers.Embedding(input_dim=max_words, output_dim=128))
model_6class.add(
    layers.Bidirectional(layers.LSTM(128, return_sequences=False))
)
model_6class.add(layers.Dropout(0.5))
model_6class.add(layers.BatchNormalization())
model_6class.add(layers.Dense(128, activation='relu'))
model_6class.add(layers.Dropout(0.3))
model_6class.add(layers.Dense(6, activation='softmax'))  # one unit per class

# Integer targets are used directly, hence the sparse cross-entropy loss.
_adam = optimizers.Adam(learning_rate=1e-3)
model_6class.compile(
    optimizer=_adam,
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'],
)

# =========================
# 4. Model training
# =========================

# Stop once the validation loss has not improved for 4 epochs and roll the
# model back to its best weights.
early_stop = callbacks.EarlyStopping(
    monitor='val_loss',
    patience=4,
    restore_best_weights=True,
)

history_6class = model_6class.fit(
    x=train_seq,
    y=train_labels,
    validation_data=(valid_seq, valid_labels),
    batch_size=128,
    epochs=20,
    callbacks=[early_stop],
    verbose=2,
)

# =========================
# 5. Visualise the training process
# =========================
plt.figure(figsize=(12, 4))

# One entry per panel: (train key, validation key, train legend label,
# validation legend label, panel title).
_panels = (
    ('loss', 'val_loss', 'Train Loss', 'Val Loss', '6-class LSTM - Loss'),
    ('accuracy', 'val_accuracy', 'Train Acc', 'Val Acc', '6-class LSTM - Accuracy'),
)

for _position, (_train_key, _val_key, _train_label, _val_label, _title) in enumerate(_panels, start=1):
    plt.subplot(1, 2, _position)
    plt.plot(history_6class.history[_train_key], label=_train_label)
    plt.plot(history_6class.history[_val_key], label=_val_label)
    plt.title(_title)
    plt.legend()

plt.tight_layout()
plt.show()

# =========================
# 6. Model evaluation and prediction
# =========================

# Class probabilities for every test statement; the predicted class is the
# index of the highest probability.
y_pred_prob_6class = model_6class.predict(test_seq)
y_pred_6class = np.argmax(y_pred_prob_6class, axis=1)

# Class names and ids in one fixed order, shared by the report and the matrix.
class_names = list(label_mapping.keys())
class_ids = list(label_mapping.values())

print("=== 6-Class Classification Report ===\n")
print(classification_report(
    test_labels,
    y_pred_6class,
    labels=class_ids,
    target_names=class_names,
))

# Passing ``labels`` pins the matrix to 6x6 in ``class_ids`` order. Without it
# a class that appears in neither the truth nor the predictions is dropped,
# and the tick labels below would no longer line up with rows and columns.
cm_6class = confusion_matrix(test_labels, y_pred_6class, labels=class_ids)
plt.figure(figsize=(6, 5))
sns.heatmap(cm_6class, annot=True, fmt="d", cmap="Blues",
            xticklabels=class_names,
            yticklabels=class_names)
plt.title("Confusion Matrix - 6 Class")
plt.xlabel("Predicted")
plt.ylabel("True")
plt.show()

# =========================
# 7. Export the predictions
# =========================

# One row per test statement: its id, the true class id and the predicted one.
result_df_6class = pd.DataFrame(
    dict(
        id=test_data["id"],
        true_label=test_labels,
        predicted_label=y_pred_6class,
    )
)

# Written without the DataFrame index so the CSV holds only the three columns.
result_df_6class.to_csv(
    path_or_buf="./predict_result_lstm_6class.csv",
    index=False,
)
print("✅ [6分类] 预测结果已保存至: ./predict_result_lstm_6class.csv")
