In [1]:
#1.加载所需模块
from sklearn import datasets            #Sklearn数据模块
from sklearn import metrics             #Sklearn衡量模块
from sklearn import feature_extraction
from sklearn import model_selection as ms     #Sklearn划分训练和测试模块
from sklearn import naive_bayes        #Sklearn朴素贝叶斯模块
import os
import numpy as np
import pandas as pd
import cv2                              #cv2模块

In [2]:
#2.定义函数
# Read the body (everything after the header) of a single file
def read_single_file(filename):
    """Return ``(filename, content)`` for one file, skipping its header.

    The header is taken to be everything up to and including the first
    completely empty line. Only the lines after it are kept.

    Parameters
    ----------
    filename : str
        Path of the file to read. It is decoded as latin-1.

    Returns
    -------
    tuple of (str, str)
        The unchanged ``filename`` and the body text. The body is an empty
        string when ``filename`` is not a regular file or when the file has
        no empty line separating a header from a body.
    """
    past_header, lines = False, []
    if os.path.isfile(filename):
        # The context manager closes the handle even if reading raises.
        with open(filename, encoding="latin-1") as f:
            for line in f:
                if past_header:
                    lines.append(line)
                elif line == '\n':
                    # First empty line: everything after it is body.
                    past_header = True
    # Each kept line still carries its own trailing newline, so joining with
    # '\n' leaves an empty line between consecutive body lines. This is kept
    # as-is so the returned text is unchanged.
    content = '\n'.join(lines)
    return filename, content

# Read the body of every file found under a folder
def read_files(path):
    """Lazily yield ``(filepath, content)`` for each file below ``path``.

    The tree is traversed with ``os.walk``, so files in nested sub-folders
    are included. Each file is read with ``read_single_file``.
    """
    for folder, _subfolders, names in os.walk(path):
        yield from (
            read_single_file(os.path.join(folder, name)) for name in names
        )

# Build a pandas DataFrame from all files in one folder
def build_data_frame(extractdir, classification):
    """Return a DataFrame with one row per file found under ``extractdir``.

    Each row has a ``'text'`` column (the file body) and a ``'class'`` column
    (always ``classification``). The row index is the file path.
    """
    entries = list(read_files(extractdir))
    paths = [file_path for file_path, _ in entries]
    records = [
        {'text': body, 'class': classification} for _, body in entries
    ]
    return pd.DataFrame(records, index=paths)

In [3]:
# 3. Load the data
# 3.1 Class labels and the location of the data set
HAM = 0    # label for legitimate (non-spam) e-mail
SPAM = 1   # label for spam e-mail
# Root folder of the data set. The original absolute path remains the
# default; set the SPAM_DATA_DIR environment variable to use another location
# without editing the notebook.
datadir = os.environ.get(
    'SPAM_DATA_DIR',
    '/home/retoo/Desktop/实验/数据集/2.机器学习/7-1 data'
)
# (archive file name, label) pairs. The loading step removes the '.tar.gz'
# suffix to obtain the folder name under datadir.
sources = [
    ('beck-s.tar.gz', HAM),
    ('farmer-d.tar.gz', HAM),
    ('kaminski-v.tar.gz', HAM),
    ('kitchen-l.tar.gz', HAM),
    ('lokay-m.tar.gz', HAM),
    ('williams-w3.tar.gz', HAM),
    ('BG.tar.gz', SPAM),
    ('GP.tar.gz', SPAM),
    ('SH.tar.gz', SPAM)
]

# Lookup table: each naive Bayes variant next to the constructor call
# (stored as text) that creates it.
data1 = pd.DataFrame(
    [
        ('Normal Bayes', 'cv2.ml.NormalBayesClassifier_create()'),
        ('Multinomial Bayes', 'sklearn.naive_bayes.MultinomialNB()'),
        ('Bernoulli Bayes', 'sklearn.naive_bayes.BernoulliNB()'),
    ],
    columns=['model', 'class'],
)

# 3.2 Read every source folder into one DataFrame with 'text' and 'class' columns
# DataFrame.append was removed in pandas 2.0, and appending inside a loop
# re-copies the growing frame on every pass. Collect the per-source frames and
# concatenate them once instead.
frames = []
for source, classification in sources:
    # source[:-7] drops the 7-character '.tar.gz' suffix, giving the name of
    # the extracted folder.
    extractdir = '%s/%s' % (datadir, source[:-7])
    frames.append(build_data_frame(extractdir, classification))
# Skip folders that produced no rows so they cannot affect the result.
frames = [frame for frame in frames if not frame.empty]
if frames:
    data = pd.concat(frames)
else:
    # Nothing was read: keep the empty two-column frame the loop used to
    # start from, so later cells still find 'text' and 'class'.
    data = pd.DataFrame({'text': [], 'class': []})

# 3.3 Split the loaded data into the label vector y and the feature matrix X
y = data['class'].values
# Turn each text into a row of token counts.
counts = feature_extraction.text.CountVectorizer()
X = counts.fit_transform(data['text'].values)
print('数据的特征信息的格式:',X.shape)
print('结果信息的格式:',y.shape)

数据的特征信息的格式: (52050, 643186)
结果信息的格式: (52050,)


In [4]:
# 4. Split the data into a training part and a test part (20% held out for
#    testing, with a fixed random_state)
X_train, X_test, y_train, y_test = ms.train_test_split(
    X, y, random_state=42, test_size=0.2
)
# Report the shape of each of the four parts.
for caption, part in (
    ('训练数据的特征信息的格式:', X_train),
    ('训练结果信息的格式:', y_train),
    ('测试数据的特征信息的格式:', X_test),
    ('测试结果信息的格式:', y_test),
):
    print(caption, part.shape)

训练数据的特征信息的格式: (41640, 643186)
训练结果信息的格式: (41640,)
测试数据的特征信息的格式: (10410, 643186)
测试结果信息的格式: (10410,)


In [6]:
# 6. Fit a multinomial naive Bayes classifier on the training part
model_naive = naive_bayes.MultinomialNB()
model_naive.fit(X_train, y_train)
# Mean accuracy on the training part (score1) and on the test part (score2)
score1, score2 = (
    model_naive.score(X_train, y_train),
    model_naive.score(X_test, y_test),
)
print('在训练集上的准确率:',score1)
print('在测试集上的准确率:',score2)

在训练集上的准确率: 0.9497838616714698
在测试集上的准确率: 0.94447646493756


In [7]:
# 7. Try to improve the result with n-grams
# ngram_range=(1, 2) counts single tokens and pairs of adjacent tokens.
counts = feature_extraction.text.CountVectorizer(ngram_range=(1, 2))
X = counts.fit_transform(data['text'].values)
# Same split settings as before, applied to the new feature matrix.
X_train, X_test, y_train, y_test = ms.train_test_split(
    X, y, random_state=42, test_size=0.2
)
model_naive = naive_bayes.MultinomialNB()
model_naive.fit(X_train, y_train)
print("使用n-gram后朴素贝叶斯分类器的准确率:",model_naive.score(X_test, y_test))

使用n-gram后朴素贝叶斯分类器的准确率: 0.9701248799231508


In [8]:
# 8. Try to improve the result with TF-IDF weighting
# X is whatever count matrix is currently in scope (the n-gram counts from
# step 7 when the cells are run in order).
tfidf = feature_extraction.text.TfidfTransformer()
X_new = tfidf.fit_transform(X)
X_train, X_test, y_train, y_test = ms.train_test_split(
    X_new, y, test_size=0.2, random_state=42
)
model_naive = naive_bayes.MultinomialNB()
model_naive.fit(X_train, y_train)
# The printed label previously read "TD-IDF"; the technique is TF-IDF.
print("使用TF-IDF后朴素贝叶斯分类器的准确率:",model_naive.score(X_test, y_test))
# Last expression of the cell: the confusion matrix on the test part.
metrics.confusion_matrix(y_test, model_naive.predict(X_test))

使用TD-IDF后朴素贝叶斯分类器的准确率: 0.9907780979827089


array([[3717,   91],
       [   5, 6597]])