In [29]:
import torch
import torch.nn as nn
import numpy as np

import pandas as pd
import joblib

from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score

from transformers import BertTokenizer, BertConfig, BertModel, Trainer, TrainingArguments, BertForSequenceClassification


In [30]:
class CustomBertForSequenceClassification(BertForSequenceClassification):
    def __init__(self, config):
        super().__init__(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size + 2, config.num_labels)  # 增加了2个特征

    def forward(self, input_ids=None, attention_mask=None, month=None, hour=None, labels=None):
        outputs = self.bert(input_ids, attention_mask=attention_mask)
        pooled_output = outputs.pooler_output  # BERT模型的pooler_output
        
        # 将时间特征拼接到pooled_output中
        time_features = torch.stack((month, hour), dim=1).float()  # 创建时间特征张量，假设 month 和 hour 的形状都是 [batch_size]
        pooled_output = torch.cat((pooled_output, time_features), dim=1)  # 在最后一个维度上拼接
        
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)
        
        outputs = (logits,) + outputs[2:]  # 将 logits 与 BERT 模型的其他输出组合在一起
        
        if labels is not None:
            if self.num_labels == 1:
                loss_fct = nn.MSELoss()
                loss = loss_fct(logits.view(-1), labels.view(-1))
            else:
                loss_fct = nn.CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            outputs = (loss,) + outputs
        
        return outputs

In [31]:
# 加载BERT tokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


# 加载大类和小类编码器
major_encoder = joblib.load('major_encoder.pkl')
minor_encoder = joblib.load('minor_encoder.pkl')

# 初始化变量
num_major_labels = None
num_minor_labels = None
num_labels = None

# 打开并读取文件内容
with open('labels_info.txt', 'r') as f:
    lines = f.readlines()

# 解析文件内容并赋值给变量
for line in lines:
    if line.startswith("Number of major labels:"):
        num_major_labels = int(line.split(": ")[1].strip())
    elif line.startswith("Number of minor labels:"):
        num_minor_labels = int(line.split(": ")[1].strip())
    elif line.startswith("Total number of labels:"):
        num_labels = int(line.split(": ")[1].strip())

# 加载自定义模型权重
config = BertConfig.from_pretrained('bert-base-chinese', num_labels=num_labels)
model = CustomBertForSequenceClassification.from_pretrained('bert-base-chinese', config=config)  # 多标签分类，输出类别数量需适当调整
model.load_state_dict(torch.load('model.pth'))  # 加载训练好的模型权重
model.eval()

# 读取CSV文件的前5行数据
csv_file = "data/data_cleaned_enhanced.csv"  # 替换成你的CSV文件路径
df = pd.read_csv(csv_file, header=0)  # 读取前5行数据

def preprocess_data(data):
    # 提取日期和时间特征
    data['date'] = pd.to_datetime(data['date'])
    data['time'] = pd.to_datetime(data['time'], format='%H:%M:%S').dt.time
    data['month'] = data['date'].dt.month
    data['hour'] = data['time'].apply(lambda x: x.hour)

    # 对标签进行编码
    data['major_label_encoded'] = major_encoder.transform(data['bjlbmc'])
    data['minor_label_encoded'] = minor_encoder.transform(data['bjlxmc'])

    return data

# 执行推理
def predict(input_ids, attention_mask, month, hour, true_major_labels, true_minor_labels):
    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask, month=month, hour=hour)
        logits = outputs[0]

    major_preds = np.argmax(logits[:, :num_major_labels].cpu().numpy(), axis=1)
    minor_preds = np.argmax(logits[:, num_major_labels:].cpu().numpy(), axis=1)

    major_labels = major_encoder.inverse_transform(major_preds)
    minor_labels = minor_encoder.inverse_transform(minor_preds)

    # 计算准确率
    major_accuracy = accuracy_score(true_major_labels, major_preds)
    minor_accuracy = accuracy_score(true_minor_labels, minor_preds)

    return major_labels, minor_labels, major_accuracy, minor_accuracy

# 预处理数据
df = preprocess_data(df)

# 划分训练集和测试集
train_texts, test_texts, train_major, test_major, train_minor, test_minor, train_month, test_month, train_hour, test_hour = train_test_split(
    df['content'], df['major_label_encoded'], df['minor_label_encoded'], df['month'], df['hour'], test_size=0.2, random_state=42)

# 文本处理，tokenize
tokenized_test = tokenizer(list(test_texts), padding='max_length', truncation=True, max_length=64, return_tensors='pt')

# 转换为张量
input_ids = tokenized_test['input_ids'].to(device)
attention_mask = tokenized_test['attention_mask'].to(device)
month = torch.tensor(test_month.values).to(device)
hour = torch.tensor(test_hour.values).to(device)

# 执行推理过程
major_labels, minor_labels, major_accuracy, minor_accuracy = predict(input_ids, attention_mask, month, hour, test_major.values, test_minor.values)

# 输出预测结果和准确率
for i in range(len(test_texts)):
    print(f"Sample {i+1}:")
    print(f"Predicted Major Label: {major_labels[i]}")
    print(f"Predicted Minor Label: {minor_labels[i]}")
    print()

print(f"Major Label Accuracy: {major_accuracy}")
print(f"Minor Label Accuracy: {minor_accuracy}")

Some weights of CustomBertForSequenceClassification were not initialized from the model checkpoint at bert-base-chinese and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Sample 1:
Predicted Major Label: 治安警情
Predicted Minor Label: 盗窃

Sample 2:
Predicted Major Label: 治安警情
Predicted Minor Label: 盗窃

Sample 3:
Predicted Major Label: 群众求助
Predicted Minor Label: 开锁求助

Sample 4:
Predicted Major Label: 刑事案件
Predicted Minor Label: 诈骗

Sample 5:
Predicted Major Label: 治安警情
Predicted Minor Label: 抢夺

Sample 6:
Predicted Major Label: 社会联动
Predicted Minor Label: 安全生产监督

Sample 7:
Predicted Major Label: 社会联动
Predicted Minor Label: 其它社会联动

Sample 8:
Predicted Major Label: 群众求助
Predicted Minor Label: 自杀求助

Sample 9:
Predicted Major Label: 治安警情
Predicted Minor Label: 其它治安警情

Sample 10:
Predicted Major Label: 群众求助
Predicted Minor Label: 走失求助

Sample 11:
Predicted Major Label: 交通警情
Predicted Minor Label: 交通设施

Sample 12:
Predicted Major Label: 交通警情
Predicted Minor Label: 其它交通管理

Sample 13:
Predicted Major Label: 刑事案件
Predicted Minor Label: 诈骗

Sample 14:
Predicted Major Label: 交通警情
Predicted Minor Label: 交通事故

Sample 15:
Predicted Major Label: 群体事件
Predicted Minor Labe