# In [None]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import BertModel, BertTokenizer, get_linear_schedule_with_warmup
from torch.optim import AdamW
from torch.nn.utils import clip_grad_norm_
import matplotlib.pyplot as plt
from tqdm import tqdm
from sklearn.preprocessing import LabelEncoder

# Multi-class label loading/encoding helper (no fake/real binary mapping is applied)
def load_and_prepare_multiclass_data(train_path, val_path, test_path):
    """Read the three CSV splits and integer-encode their ``label`` column.

    The encoder is fitted on the training split only and then applied to all
    three splits, so every split shares one label-to-index mapping.

    Args:
        train_path: Path of the training CSV.
        val_path: Path of the validation CSV.
        test_path: Path of the test CSV.

    Returns:
        A ``(train_df, val_df, test_df, classes)`` tuple, where ``classes`` is
        the fitted encoder's ``classes_`` attribute.
    """
    train_df, val_df, test_df = (
        pd.read_csv(csv_path) for csv_path in (train_path, val_path, test_path)
    )

    # Fit on the training labels only, then reuse the same mapping everywhere.
    encoder = LabelEncoder().fit(train_df["label"])
    for frame in (train_df, val_df, test_df):
        frame["label"] = encoder.transform(frame["label"])

    return train_df, val_df, test_df, encoder.classes_

from torch.utils.data import DataLoader
import torch

def predict(model, dataset, batch_size=16):
    """Run ``model`` over ``dataset`` and return the arg-max class per sample.

    The model is switched to eval mode and moved to CUDA when available
    (otherwise CPU). Any ``'labels'`` entry in a batch is dropped before the
    remaining entries are passed to the model as keyword arguments.

    Args:
        model: Module whose output carries the class scores along dim 1.
        dataset: Dataset yielding dicts of tensors.
        batch_size: Number of samples per forward pass.

    Returns:
        A flat list with one predicted class index per sample, in the order
        the (unshuffled) loader yields them.
    """
    target = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    batches = DataLoader(dataset, batch_size=batch_size)

    model.eval()
    model.to(target)

    collected = []
    with torch.no_grad():
        for batch in batches:
            # Everything except the targets is a model input.
            features = {
                name: tensor.to(target)
                for name, tensor in batch.items()
                if name != 'labels'
            }
            scores = model(**features)
            collected.extend(torch.argmax(scores, dim=1).cpu().numpy())

    return collected

import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.metrics import classification_report, confusion_matrix

def analyze_predictions(preds, test_df, label_col="label", label_map=None, top_k=5,
                        output_file="misclassified_samples.csv",
                        text_col="clean_statement_bert"):
    """Report classification quality and export the misclassified samples.

    Prints a classification report, shows a confusion-matrix heatmap, prints
    the first ``top_k`` misclassified rows and writes all misclassified rows
    to ``output_file``.

    Note: ``test_df`` is modified in place; the columns ``true_label``,
    ``predicted_label`` and ``correct`` are added to it.

    Args:
        preds: Predicted class indices, one per row of ``test_df``.
        test_df: DataFrame holding the true labels and the statement text.
        label_col: Column of ``test_df`` containing the true class indices.
        label_map: Mapping from class index to display name. When ``None``,
            every class index seen in the true labels or in ``preds`` is
            mapped to its string form.
        top_k: Number of misclassified samples to print.
        output_file: CSV path the misclassified samples are written to.
        text_col: Column of ``test_df`` containing the statement text.

    Returns:
        ``test_df`` with the three added columns.
    """
    true_labels = test_df[label_col].tolist()

    if label_map is None:
        # Include predicted classes too: a class the model predicts but that
        # never occurs in the test split must still get a name and a row.
        label_map = {i: str(i) for i in sorted(set(true_labels) | set(preds))}

    # Pin the class order explicitly so report rows and heatmap ticks always
    # line up with the names, whichever classes happen to be present.
    class_ids = sorted(label_map)
    class_names = [label_map[i] for i in class_ids]

    print("📊 Classification Report:")
    print(classification_report(true_labels, preds, labels=class_ids, target_names=class_names))

    # Confusion matrix
    cm = confusion_matrix(true_labels, preds, labels=class_ids)
    plt.figure(figsize=(6, 5))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=class_names,
                yticklabels=class_names)
    plt.xlabel("Predicted")
    plt.ylabel("Actual")
    plt.title("Confusion Matrix")
    plt.show()

    # Misclassified samples
    test_df["true_label"] = test_df[label_col]
    test_df["predicted_label"] = preds
    test_df["correct"] = test_df["true_label"] == test_df["predicted_label"]

    mistakes = test_df[~test_df["correct"]].copy()
    mistakes["true_label_name"] = mistakes["true_label"].map(label_map)
    mistakes["predicted_label_name"] = mistakes["predicted_label"].map(label_map)

    print(f"\n❌ Top {top_k} Misclassified Samples:\n")
    for i, row in mistakes.head(top_k).iterrows():
        print(f"[{i}] Statement: {row[text_col]}")
        print(f"    ➤ True: {row['true_label_name']} | Pred: {row['predicted_label_name']}")
        print("")

    # Save to file
    mistakes[[text_col, "true_label_name", "predicted_label_name"]].to_csv(output_file, index=False)
    print(f"\n📝 Misclassified samples saved to: {output_file}")

    return test_df

# ---------- Dataset ----------
class MultiClassFakeNewsDataset(Dataset):
    """Torch dataset that tokenizes one statement per item on demand.

    Texts are read from the ``clean_statement_bert`` column and targets from
    the ``label`` column of the given dataframe.
    """

    def __init__(self, dataframe, tokenizer, max_len):
        """Store the texts, labels, tokenizer and fixed sequence length."""
        self.texts = dataframe["clean_statement_bert"].tolist()
        self.labels = dataframe["label"].tolist()
        self.max_len = max_len
        self.tokenizer = tokenizer

    def __len__(self):
        """Return the number of labelled samples."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Tokenize sample ``idx`` and return its tensors.

        Returns:
            Dict with ``input_ids`` and ``attention_mask`` (dimension 0 of the
            tokenizer output squeezed away) and a ``torch.long`` ``labels``
            tensor.
        """
        encoded = self.tokenizer(
            self.texts[idx],
            max_length=self.max_len,
            padding='max_length',
            truncation=True,
            return_tensors='pt',
        )
        sample = {
            field: encoded[field].squeeze(0)
            for field in ('input_ids', 'attention_mask')
        }
        sample['labels'] = torch.tensor(self.labels[idx], dtype=torch.long)
        return sample


# ---------- Multi-Class BERT Model ----------
class MultiClassBERTClassifier(nn.Module):
    """BERT encoder followed by dropout and one linear classification layer."""

    def __init__(self, model_name, num_labels=6, dropout=0.3):
        """Load the pretrained encoder and build the classification head.

        Args:
            model_name: Identifier handed to ``BertModel.from_pretrained``.
            num_labels: Output size of the linear layer.
            dropout: Dropout probability applied to the pooled output.
        """
        super().__init__()
        self.bert = BertModel.from_pretrained(model_name)
        hidden_size = self.bert.config.hidden_size
        self.dropout = nn.Dropout(dropout)
        self.classifier = nn.Linear(hidden_size, num_labels)

    def forward(self, input_ids, attention_mask):
        """Return the linear layer's output for the encoder's pooled output."""
        encoded = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        pooled = self.dropout(encoded.pooler_output)
        logits = self.classifier(pooled)
        return logits
# ---------- Training Function for Multi-Class ----------
def _validation_accuracy(model, loader, device):
    """Return the fraction of samples in ``loader`` that ``model`` classifies correctly."""
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for batch in loader:
            inputs = {k: v.to(device) for k, v in batch.items() if k != 'labels'}
            labels = batch['labels'].to(device)
            preds = torch.argmax(model(**inputs), dim=1)
            correct += (preds == labels).sum().item()
            total += len(labels)
    # An empty validation set yields 0.0 instead of a ZeroDivisionError.
    return correct / total if total else 0.0


def train_multiclass_model(train_dataset, val_dataset, model, epochs=6, batch_size=16, lr=2e-5,
                           warmup_ratio=0.1, patience=2, save_path="best_multiclass_model.pt"):
    """Fine-tune ``model`` with AdamW, linear warmup/decay and early stopping.

    After each epoch the validation accuracy is computed; whenever it
    improves, the model's ``state_dict`` is saved to ``save_path``. Training
    stops early once the accuracy has failed to improve for ``patience``
    consecutive epochs. A loss/accuracy curve is plotted at the end.

    Args:
        train_dataset: Dataset yielding dicts of tensors including ``'labels'``.
        val_dataset: Dataset in the same format, used for validation accuracy.
        model: Module called with every batch entry except ``'labels'`` as
            keyword arguments; its output is fed to ``nn.CrossEntropyLoss``.
        epochs: Maximum number of training epochs.
        batch_size: Batch size for both loaders.
        lr: AdamW learning rate.
        warmup_ratio: Fraction of all optimisation steps used for warmup.
        patience: Number of non-improving epochs tolerated before stopping.
        save_path: File the best ``state_dict`` is written to.

    Returns:
        The model in its state after the last executed epoch (not necessarily
        the best checkpoint; load ``save_path`` to restore that).

    Raises:
        ValueError: If the training loader contains no batches.
    """
    optimizer = AdamW(model.parameters(), lr=lr)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size)

    if len(train_loader) == 0:
        # Fail with a clear message rather than a ZeroDivisionError when
        # averaging the epoch loss.
        raise ValueError("train_dataset is empty; nothing to train on.")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    loss_fn = nn.CrossEntropyLoss()
    total_steps = len(train_loader) * epochs
    warmup_steps = int(warmup_ratio * total_steps)

    scheduler = get_linear_schedule_with_warmup(optimizer, warmup_steps, total_steps)

    train_losses, val_accuracies = [], []
    # Start below any reachable accuracy so the first epoch always writes a
    # checkpoint, even if its accuracy is exactly 0.
    best_val_acc = -1.0
    patience_counter = 0

    for epoch in range(epochs):
        model.train()
        total_loss = 0

        for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}"):
            inputs = {k: v.to(device) for k, v in batch.items() if k != 'labels'}
            labels = batch['labels'].to(device)

            optimizer.zero_grad()
            outputs = model(**inputs)
            loss = loss_fn(outputs, labels)
            loss.backward()
            # Clip the gradient norm to 1.0 before the optimizer step.
            clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            scheduler.step()  # the schedule advances once per batch
            total_loss += loss.item()

        avg_loss = total_loss / len(train_loader)
        train_losses.append(avg_loss)

        # Validation
        acc = _validation_accuracy(model, val_loader, device)
        val_accuracies.append(acc)

        print(f"Epoch {epoch+1}: Train loss={avg_loss:.4f}, Val acc={acc:.4f}")

        if acc > best_val_acc:
            best_val_acc = acc
            patience_counter = 0
            torch.save(model.state_dict(), save_path)
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print("Early stopping.")
                break

    plt.plot(range(1, len(train_losses)+1), train_losses, label="Train Loss")
    plt.plot(range(1, len(val_accuracies)+1), val_accuracies, label="Validation Acc")
    plt.xlabel("Epoch")
    plt.legend()
    plt.title("Training Curve")
    plt.show()

    return model

from transformers import BertTokenizer

# Load the data; class_names is the label encoder's classes_ array.
train_df, val_df, test_df, class_names = load_and_prepare_multiclass_data(
    "/content/processed_train (2).csv",
    "/content/processed_valid (2).csv",
    "/content/processed_test (2).csv"
)

# Tokenizer & Dataset
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
train_set = MultiClassFakeNewsDataset(train_df, tokenizer, max_len=128)
val_set = MultiClassFakeNewsDataset(val_df, tokenizer, max_len=128)

# Model: size the output layer from the data instead of hard-coding 6 classes.
model = MultiClassBERTClassifier("bert-base-uncased", num_labels=len(class_names))

# Train
model = train_multiclass_model(train_set, val_set, model, epochs=10, batch_size=16)

# Build the test dataset
test_set = MultiClassFakeNewsDataset(test_df, tokenizer, max_len=128)


# Load the best checkpoint written during training (optional)
model.load_state_dict(torch.load("best_multiclass_model.pt"))

# Predict
preds = predict(model, test_set)

# Evaluate and visualise, showing the original class names rather than the
# encoded integer indices.
label_map = {index: str(name) for index, name in enumerate(class_names)}
analyze_predictions(preds, test_df, label_map=label_map)