In [48]:
import pandas as pd
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.feature_extraction.text import TfidfTransformer
from sklearn.model_selection import train_test_split
from sklearn.naive_bayes import MultinomialNB
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score

# 读取数据
data = pd.read_csv('data.csv')

# 数据预处理
data['text'] = data['text'].apply(lambda x: ' '.join(x.split()))

# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(data['text'], data['category'], test_size=0.2, random_state=42)

# 文本特征提取
vectorizer = CountVectorizer()
tf_transformer = TfidfTransformer()

X_train_counts = vectorizer.fit_transform(X_train)
X_train_tfidf = tf_transformer.fit_transform(X_train_counts)

# 训练分类器
clf = MultinomialNB().fit(X_train_tfidf, y_train)

# 预测测试集
X_test_counts = vectorizer.transform(X_test)
X_test_tfidf = tf_transformer.transform(X_test_counts)
y_pred = clf.predict(X_test_tfidf)

# 评估模型性能
print("分类报告：\n", classification_report(y_test, y_pred))
print("混淆矩阵：\n", confusion_matrix(y_test, y_pred))
print("准确率：\n", accuracy_score(y_test, y_pred))

# 预测新文本
def predict_new_text(text):
    preprocessed_text = ' '.join(text.split())
    text_counts = vectorizer.transform([preprocessed_text])
    text_tfidf = tf_transformer.transform(text_counts)
    predicted_category = clf.predict(text_tfidf)
    return predicted_category[0]

new_text = "你好"
predicted_category = predict_new_text(new_text)
print("预测的分类:", predicted_category)


分类报告：
               precision    recall  f1-score   support

          健康       0.00      0.00      0.00         4
        健康健身       0.49      1.00      0.66        19
          其他       1.00      0.06      0.11        17

    accuracy                           0.50        40
   macro avg       0.50      0.35      0.26        40
weighted avg       0.66      0.50      0.36        40

混淆矩阵：
 [[ 0  4  0]
 [ 0 19  0]
 [ 0 16  1]]
准确率：
 0.5
预测的分类: 健康健身


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
