In [14]:
from math import log2
import operator
import pickle

In [31]:
def cleanDataset(dataset):
    """Remove, in place, every row that contains a missing value ('?' or '').

    Each removed row is printed once. The original implementation called
    ``dataset.remove(row)`` while iterating over ``dataset``, which skipped
    the row following every removal and could delete unrelated equal rows;
    building a filtered list avoids both problems.

    Args:
        dataset: list of rows (lists); modified in place.
    """
    kept = []
    for row in dataset:
        if any(col == '?' or col == '' for col in row):
            print(row)
        else:
            kept.append(row)
    # Slice-assign so the caller's list object is updated, as before.
    dataset[:] = kept

def createDataset(filename):
    """Load a comma+space separated data file into a list of rows.

    Cells consisting only of digits become ``int``; everything else stays a
    string. Blank lines are ignored. The original deleted the last row
    unconditionally, which dropped a real sample whenever the file did not
    end with a blank line. Rows containing '?' are kept; call cleanDataset
    on the result to drop them.

    Args:
        filename: path of the data file.

    Returns:
        (dataset, labels, labelType): the rows, the 14 feature names, and a
        parallel list marking each feature 'continuous' or 'uncontinuous'.
    """
    with open(filename, 'r') as csvfile:
        rows = [line.strip().split(', ') for line in csvfile if line.strip()]
    dataset = [[int(i) if i.isdigit() else i for i in row] for row in rows]
    labels=['age','workclass','fnlwgt','education','education-num',
            'marital-status','occupation',
             'relationship','race','sex','capital-gain','capital-loss','hours-per-week',
             'native-country']
    labelType = ['continuous', 'uncontinuous', 'continuous',
                  'uncontinuous',
                  'continuous', 'uncontinuous',
                  'uncontinuous', 'uncontinuous', 'uncontinuous',
                  'uncontinuous', 'continuous', 'continuous',
                  'continuous', 'uncontinuous']
    return dataset,labels,labelType
        
#dataset,labels,labelType = createDataset('adult/adult.data')



In [16]:
def calculateEnt(dataset):
    """Return the base-2 Shannon entropy of the class column (last element of each row).

    An empty dataset yields 0.0.
    """
    total = len(dataset)
    counts = {}
    for row in dataset:
        label = row[-1]
        counts[label] = counts.get(label, 0) + 1
    entropy = 0.0
    for count in counts.values():
        prob = count / total
        entropy -= prob * log2(prob)
    return entropy

In [17]:
def splitDataset(dataset,labelIdx,value):
    """Return the rows whose column ``labelIdx`` equals ``value``, with that column removed.

    The input rows are not modified; each returned row is a new list.
    """
    def _withoutColumn(row):
        # Copy the prefix, then append everything after the dropped column.
        trimmed = row[:labelIdx]
        trimmed.extend(row[labelIdx+1:])
        return trimmed

    return [_withoutColumn(row) for row in dataset if row[labelIdx] == value]

def splitContinuousDataset(dataset,labelIdx,value):
    """Partition rows by column ``labelIdx`` against threshold ``value``.

    Returns (rows with column > value, rows with column <= value); in both
    lists the threshold column has been removed from each row.
    """
    above = []
    atOrBelow = []
    for row in dataset:
        reduced = row[:labelIdx]
        reduced.extend(row[labelIdx+1:])
        target = above if row[labelIdx] > value else atOrBelow
        target.append(reduced)
    return above,atOrBelow

In [18]:
def calGainRatioUnContinuous(dataset,labelIdx,ent):
    """Gain ratio of splitting ``dataset`` on the categorical column ``labelIdx``.

    Args:
        dataset: rows whose last element is the class label.
        labelIdx: index of the categorical feature column.
        ent: entropy of the whole dataset (from calculateEnt).

    Returns:
        (ent - conditional entropy) / split information; the split
        information is replaced by 1 when it is 0 (single-valued feature).
    """
    total = len(dataset)
    conditionalEnt = 0.0
    splitInfo = 0.0
    distinctValues = set([row[labelIdx] for row in dataset])
    for featureValue in distinctValues:
        subset = splitDataset(dataset,labelIdx,featureValue)
        weight = len(subset) / total
        splitInfo += -1*(weight*log2(weight))
        conditionalEnt += weight*calculateEnt(subset)
    gain = ent - conditionalEnt
    # Avoid dividing by zero when every row shares one feature value.
    divisor = 1 if splitInfo == 0 else splitInfo
    return gain / divisor


In [19]:
def calGainRatioContinuous(dataset,labelIdx,ent):
    """Best binary threshold and its gain ratio for continuous column ``labelIdx``.

    Candidate thresholds are the midpoints between consecutive distinct
    sorted values. The threshold with the largest raw information gain is
    chosen. That gain is then reduced once by the C4.5 penalty
    log2(N - 1) / |D| (N = number of distinct values) and divided by the
    split information of the chosen partition.

    The original subtracted the penalty from ``bestGain`` inside the loop.
    Later raw gains were then compared against an already-penalised value,
    so a worse threshold could replace the true best one.

    Args:
        dataset: rows whose last element is the class label.
        labelIdx: index of the continuous feature column.
        ent: entropy of the whole dataset (from calculateEnt).

    Returns:
        (gainRatio, splitPoint). Returns (0.0, 0.0) when no threshold has
        positive gain. gainRatio may be negative after the penalty.
    """
    total = len(dataset)
    sortedValues = sorted(set(vec[labelIdx] for vec in dataset))
    # n distinct values give n - 1 candidate midpoints.
    splitPoints = [(low + high) / 2.0 for low, high in zip(sortedValues, sortedValues[1:])]

    bestGain = 0.0
    bestIv = 0.0
    bestSplitPoint = 0.0
    for point in splitPoints:
        biggerDataset,smallerDataset = splitContinuousDataset(dataset,labelIdx,point)
        Db = len(biggerDataset) / total
        Ds = len(smallerDataset) / total
        entc = Db*calculateEnt(biggerDataset) + Ds*calculateEnt(smallerDataset)
        gain = ent - entc
        # Compare raw gains only; the penalty is applied after the search.
        if gain > bestGain:
            bestGain = gain
            bestSplitPoint = point
            # Both sides are non-empty (midpoint between distinct values), so log2 is safe.
            bestIv = -(Db*log2(Db) + Ds*log2(Ds))

    if bestGain <= 0.0:
        return 0.0, 0.0

    penalty = log2(len(sortedValues) - 1) / total
    bestGainRatio = (bestGain - penalty) / bestIv
    return bestGainRatio,bestSplitPoint


In [20]:
#字典转换元组列表
def dict2list(dic):
  """Return the dict's (key, value) pairs as a list of tuples, in iteration order."""
  return list(dic.items())

In [21]:
#特征用完后，叶子节点未分类成功，选择出现次数最多的分类
def majority(classList):
  """Return the most frequent class in ``classList``.

  This is used when the features are exhausted but a leaf is still impure.
  Ties go to the class that appears first. The original incremented the
  dict itself (``classficationCount += 1``), which raised TypeError for
  every non-empty input.

  Raises:
    ValueError: if ``classList`` is empty.
  """
  classficationCount = {}
  for i in classList:
    classficationCount[i] = classficationCount.get(i, 0) + 1
  # max returns the first maximal key in insertion order, i.e. ties go to the first seen.
  return max(classficationCount, key=classficationCount.get)

In [23]:
def chooseBestSplit(dataset,labelType):
  """Pick the feature column with the highest positive gain ratio.

  Args:
    dataset: rows whose last element is the class label.
    labelType: 'uncontinuous' for categorical columns; any other value is
      treated as continuous.

  Returns:
    (bestFeatureIdx, bestSplitPoint, isContinuous). bestFeatureIdx is -1
    when no feature has a gain ratio > 0. bestSplitPoint is only updated
    when a continuous feature becomes the best, so it is meaningful only
    when isContinuous is True.
  """
  baseEnt = calculateEnt(dataset)
  bestFeatureIdx = -1
  bestSplitPoint = 0.0
  isContinuous = False
  bestGainRatio = 0.0
  featureCount = len(dataset[0]) - 1
  for idx in range(featureCount):
    continuous = labelType[idx] != 'uncontinuous'
    if continuous:
      ratio, point = calGainRatioContinuous(dataset,idx,baseEnt)
    else:
      ratio = calGainRatioUnContinuous(dataset,idx,baseEnt)
    if not ratio > bestGainRatio:
      continue
    bestGainRatio = ratio
    bestFeatureIdx = idx
    isContinuous = continuous
    if continuous:
      bestSplitPoint = point
  return bestFeatureIdx,bestSplitPoint,isContinuous


In [24]:
def _mostCommonClass(classList):
  """Return the most frequent value in classList (ties go to the first seen)."""
  counts = {}
  for label in classList:
    counts[label] = counts.get(label, 0) + 1
  return max(counts, key=counts.get)

def createTree(dataset,labels,labelType):
  """Recursively build a C4.5-style decision tree.

  Args:
    dataset: rows whose last element is the class label; the other columns
      line up with ``labels`` / ``labelType``.
    labels: names of the remaining feature columns. The chosen feature is
      deleted from this list in place; children receive copies.
    labelType: 'uncontinuous' / 'continuous' per column, mutated the same way.

  Returns:
    Either a class label (a leaf) or ``{featureName: {branchKey: subtree}}``.
    Continuous branches are keyed ``'>' + str(threshold)`` and
    ``'<=' + str(threshold)``; categorical branches are keyed by feature value.
  """
  classificationList = [row[-1] for row in dataset]
  # Leaf: every row has the same class.
  if classificationList.count(classificationList[0]) == len(classificationList):
    return classificationList[0]
  # Leaf: only the class column is left.
  if len(dataset[0]) == 1:
    return _mostCommonClass(classificationList)
  bestFeatureIdx,bestSplitPoint,isContinuous = chooseBestSplit(dataset,labelType)
  # Leaf: no feature has a positive gain ratio. Previously -1 was used as a
  # real column index here, producing nodes keyed by class labels.
  if bestFeatureIdx == -1:
    return _mostCommonClass(classificationList)
  bestFeature = labels[bestFeatureIdx]
  tree = {bestFeature:{}}
  del labels[bestFeatureIdx]
  del labelType[bestFeatureIdx]
  if isContinuous:
    biggerDataset,smallerDataset = splitContinuousDataset(dataset,bestFeatureIdx,bestSplitPoint)
    tree[bestFeature]['>'+str(bestSplitPoint)] = createTree(biggerDataset,labels[:],labelType[:])
    tree[bestFeature]['<='+str(bestSplitPoint)] = createTree(smallerDataset,labels[:],labelType[:])
  else:
    # One branch per distinct value of the chosen categorical feature.
    uniqueFeature = set([row[bestFeatureIdx] for row in dataset])
    for feature in uniqueFeature:
      subDataset = splitDataset(dataset,bestFeatureIdx,feature)
      tree[bestFeature][feature] = createTree(subDataset,labels[:],labelType[:])
  return tree


In [57]:
def classify(tree,data,labels,labelType):
  """Walk the decision tree for one sample and return its predicted class.

  Args:
    tree: nested dict ``{featureName: {branchKey: subtree-or-label}}``.
    data: one sample, indexed like ``labels``.
    labels: full list of feature names, used to find each node's column.
    labelType: 'uncontinuous' or another value (treated as continuous) per column.

  Returns:
    The leaf label reached, or '' when a categorical value has no branch.
  """
  featureName = list(tree.keys())[0]
  branches = tree[featureName]
  column = labels.index(featureName)

  if labelType[column] == 'uncontinuous':
    value = data[column]
    missing = object()
    key = next((k for k in branches.keys() if k == value), missing)
    if key is missing:
      return ''
    child = branches[key]
  else:
    value = float(data[column])
    firstKey = list(branches.keys())[0]
    # Keys look like '>T' or '<=T'; recover the threshold text T.
    threshold = firstKey[1:] if str(firstKey).startswith('>') else firstKey[2:]
    prefix = '>' if value > float(threshold) else '<='
    child = branches[prefix+str(threshold)]

  if type(child).__name__ == 'dict':
    return classify(child,data,labels,labelType)
  return child

In [60]:
def test(tree,testFilePath,labels,labelType):
  """Classify every row of a test file and print per-row errors and accuracy.

  Blank lines are ignored. The first remaining line is dropped (assumed to
  be a header line, as the original did). Expected labels in the file end
  with '.', so '.' is appended to each prediction before comparing.
  An empty file now reports accuracy 0.0 instead of raising.
  """
  with open(testFilePath,'r') as csvfile:
    rows = [line.strip().split(', ') for line in csvfile if line.strip()]
  dataset = [[int(i) if i.isdigit() else i for i in row] for row in rows]
  # Drop the header line; the slice form is a no-op on an empty file.
  del dataset[:1]
  total = len(dataset)
  correct = 0
  error = 0
  for line in dataset:
    result = classify(tree,line,labels,labelType)+'.'
    if result == line[-1]:
      correct += 1
    else:
      print('{} is error;result:{},correct:{}'.format(line,result,line[-1]))
      error += 1
  accuracy = correct/total if total else 0.0
  print('load {} lines data'.format(total))
  print('Correct: {},Error: {},Accuracy: {}'.format(correct,error,accuracy))


In [47]:
def train():
  """Build a tree from ./adult/adult.data, print it, and return it.

  Returns:
    (tree, labels, labelType). The label lists are the unmodified ones from
    createDataset, because createTree only ever sees copies.
  """
  dataset,labels,labelType = createDataset(r'./adult/adult.data')
  # createTree deletes entries from the lists it receives, so pass copies.
  tree = createTree(dataset,list(labels),list(labelType))
  print(tree)
  return tree,labels,labelType


In [28]:
def storeTree(tree,fileName):
  """Serialize ``tree`` to ``fileName`` with pickle.

  The file is closed even if pickling fails; the original leaked the handle
  when ``pickle.dump`` raised.
  """
  with open(fileName,'wb') as f:
    pickle.dump(tree,f)

def getTree(fileName):
  """Load a tree previously written by storeTree.

  The file is closed even if unpickling fails.
  SECURITY: pickle.load can execute arbitrary code. Only load model files
  from a trusted source.
  """
  with open(fileName,'rb') as f:
    return pickle.load(f)

In [61]:
def main(retrain=False):
  """Evaluate the decision tree on ./adult/adult.test.

  The original referenced module-level ``labels`` / ``labelType`` that only
  existed from an earlier kernel session, so it raised NameError on a fresh
  kernel. They are now obtained explicitly.

  Args:
    retrain: if True, build a new tree with train() and store it at
      ./model/tree.txt; otherwise load the stored tree.
  """
  modelPath = './model/tree.txt'
  if retrain:
    tree,labels,labelType = train()
    storeTree(tree,modelPath)
  else:
    tree = getTree(modelPath)
    # The pickle holds only the tree, so rebuild the feature names/types.
    _,labels,labelType = createDataset('./adult/adult.data')
  test(tree,'./adult/adult.test',labels,labelType)

# Run the evaluation defined in main() when this cell executes.
main()

1
2
[38, 'Private', 89814, 'HS-grad', 9, 'Married-civ-spouse', 'Farming-fishing', 'Husband', 'White', 'Male', 0, 0, 50, 'United-States', '<=50K.'] is error;result:.,correct:<=50K.
3
[28, 'Local-gov', 336951, 'Assoc-acdm', 12, 'Married-civ-spouse', 'Protective-serv', 'Husband', 'White', 'Male', 0, 0, 40, 'United-States', '>50K.'] is error;result:<=50K.,correct:>50K.
4
5
6
7
8
[63, 'Self-emp-not-inc', 104626, 'Prof-school', 15, 'Married-civ-spouse', 'Prof-specialty', 'Husband', 'White', 'Male', 3103, 0, 32, 'United-States', '>50K.'] is error;result:.,correct:>50K.
9
10
11
[65, 'Private', 184454, 'HS-grad', 9, 'Married-civ-spouse', 'Machine-op-inspct', 'Husband', 'White', 'Male', 6418, 0, 40, 'United-States', '>50K.'] is error;result:.,correct:>50K.
12
[36, 'Federal-gov', 212465, 'Bachelors', 13, 'Married-civ-spouse', 'Adm-clerical', 'Husband', 'White', 'Male', 0, 0, 40, 'United-States', '<=50K.'] is error;result:.,correct:<=50K.
13
14
[58, '?', 299831, 'HS-grad', 9, 'Married-civ-spouse',

ValueError: could not convert string to float: '50K'