In [1]:
import numpy as np                                                                                                                                                                                                                                                          
import math

In [2]:
maxGroup = 6	#最大分组数
minInfoThreshold = 0.5		#停止划分的最小熵
result = dict()		#保存划分结果

In [3]:
data = np.genfromtxt('price-label.csv',skip_header=1,delimiter=',')

In [4]:
#计算按照数据指定数据分组后的香农熵
def calEntropy(data):
    numData = len(data)
    labelCounts = {}
    for feature in data:
        #获得标签,这里只有0或者1
        oneLabel = feature[-1]#设置字典中，标签的默认值
        #如果标签步骤新定义的字典里则创建该标签
        labelCounts.setdefault(oneLabel,0)
        #该类标签下含有数据的个数
        labelCounts[oneLabel] += 1
    shannonEnt = 0.0
    for key in labelCounts:
        #同类标签出现的概率,某一标签出现的次数除以所有标签的数量
        #求熵，以2为底，取对数
        prob = float(labelCounts[key]) / numData
        shannonEnt -= prob * math.log(prob, 2)
    return shannonEnt
shannonEnt = calEntropy(data)
shannonEnt

0.9709505944546686

In [5]:
def split(data):
    #inf为正无穷大
    minEntropy = np.inf
    #记录最终分割索引
    index = -1
    #按照第一列对数据进行排序
    sortData = data[np.argsort(data[:,0])]
    #初始化最终分割数据后的熵
    lastE1,lastE2 = -1,-1
    #返回的数据结构，包含数据和对应的熵
    S1 = dict()
    S2 = dict()
    for i in range(len(sortData)):
        #分割数据集
        splitData1,splitData2 = sortData[: i + 1],sortData[i + 1 :]#计算信息熵
        entropy1,entropy2 = (
            calEntropy(splitData1),
            calEntropy(splitData2),
        ) #计算调和平均熵
        entropy = entropy1 * len(splitData1) / len(sortData) + entropy2 * len(splitData2) / len(sortData)
        #如果调和平均熵小于最小值
        if entropy < minEntropy:
            minEntropy = entropy
            index = i
            lastE1 = entropy1
            lastE2 = entropy2
    S1["entropy"] = lastE1
    S1["data"] = sortData[: index + 1]
    S1["entropy"] = lastE2
    S2["data"] = sortData[index + 1 :]
    return S1,S2,minEntropy
S1,S2,minEntropy = split(data)

In [6]:
#需要遍历的key
needSplitKey = [0]
#将整个数据作为一组
result.setdefault(0,{})
result[0]["entropy"] = np.inf
result[0]["data"] = data
group = 1
for key in needSplitKey:
    S1,S2,entropy = split(result[key]["data"])
    #如果满足条件，熵大于“最小熵”，并且分组数小于“最大分组数”
    if entropy > minInfoThreshold and group < maxGroup:
        result[key] = S1
        newKey = max(result.keys()) + 1
        result[newKey] = S2
        needSplitKey.extend([key])
        needSplitKey.extend([newKey])
        group += 1
    else:
        break

#打印结果
print("result is {}".format(result))

result is {0: {'entropy': 0.0, 'data': array([[  9.,   0.],
       [ 10.,   1.],
       [ 23.,   0.],
       [ 56.,   1.],
       [ 63.,   0.],
       [ 87.,   1.],
       [ 88.,   1.],
       [ 97.,   0.],
       [121.,   1.]])}, 1: {'data': array([[3.420e+02, 1.000e+00],
       [4.530e+02, 1.000e+00],
       [5.610e+02, 1.000e+00],
       [5.920e+02, 1.000e+00],
       [6.410e+02, 1.000e+00],
       [7.640e+02, 0.000e+00],
       [2.323e+03, 0.000e+00],
       [2.398e+03, 1.000e+00],
       [2.764e+03, 1.000e+00]])}, 2: {'data': array([[129.,   0.],
       [222.,   0.]])}}


In [7]:
# 下载数据，保存在 20newsgroups 的文件夹中
from sklearn.datasets import fetch_20newsgroups
path = '20newsgroups'
data = fetch_20newsgroups(data_home=path, # 文件下载的路径
                   subset='train', # 加载那一部分数据集 train/test
                   categories=['alt.atheism','comp.graphics'], # 选取哪一类数据集[类别列表]，默认20类，这里只用两类
                   shuffle=True,  # 将数据集随机排序
                   random_state=42, # 随机数生成器
                   remove=(), # ('headers','footers','quotes') 去除部分文本
                   download_if_missing=True # 如果没有下载过，重新下载
                   )

In [8]:
print(data)

       '20newsgroups/20news_home/20news-bydate-train/comp.graphics/38654',
       '20newsgroups/20news_home/20news-bydate-train/comp.graphics/38240',
       ...,
       '20newsgroups/20news_home/20news-bydate-train/alt.atheism/51296',
       '20newsgroups/20news_home/20news-bydate-train/alt.atheism/51262',
       '20newsgroups/20news_home/20news-bydate-train/comp.graphics/38316'],


In [12]:
d = data.data[0]
def pro(data):
    data = data.lower()
    data.replace('.', ' ')
    data.replace(',', ' ')
    data.replace('\n', ' ')
    result = data.split(' ')
    return result
table = []
for line in data.data:
    table += pro(line)
table = set(table)
table

{'',
 'rotation',
 '<1993apr15.163317.20805@cs.nott.ac.uk>',
 'determinant',
 ':-)\n\ngood',
 'not)',
 'come\nfrom?\n\nwe',
 '25286\ndenver',
 '6011',
 'white\nmandlebrot',
 'compensate',
 'jpl',
 'widening',
 'revelation.',
 'evidence\n>',
 '1.4\n\n',
 'behaviour',
 'a.a.\n\njohn\nthe',
 'call,',
 'works,',
 'windows!',
 'utopia,',
 'atoms,',
 '1988\n\n-',
 'kandolf)',
 'frame\n',
 'believers,\nbut',
 'lifestyle"',
 'necessary?"\n\nyes,',
 '"life,',
 'geoffm\n',
 '>"objectively"',
 'algorithms!\narticle-i.d.:',
 'incidents',
 'cusps,',
 '(jim',
 'significance".\n\n>',
 'exist.\n>\t3)',
 'reexamined...\n\nkeith\n',
 'kuoppala)\ngeoff.arnold@east.sun.com',
 'reputation.',
 'have\nnot',
 'yours',
 'scientific\n:p>fact',
 '"multiply',
 'doctrine',
 'mouths',
 'an\neffort',
 'writes:\n>benedikt',
 'traits',
 '80303.',
 'reduction:',
 ">it's",
 'religion.',
 'levels.\n\nmark\n\n[although',
 'books\nan',
 'photographs',
 'program\nthat',
 'occam',
 'arising',
 'atoms',
 'haberj@informatik.tu