In [25]:
from csv import reader

In [26]:
# Load a CSV file
def load_csv(filename):
    file = open(filename, "r")
    lines = reader(file)
    dataset = list(lines)
    return dataset
 
# Convert string column to float
def str_column_to_float(dataset, column):
    for row in dataset:
        row[column] = float(row[column].strip())


In [27]:
def gini_index(groups, classes):
    # count all samples at split point
    n_instances = float(sum([len(group) for group in groups]))
    # sum weighted Gini index for each group
    gini = 0.0
    
    for group in groups:
        size = float(len(group))
        # avoid divide by zero
        if size == 0:
            continue
        score = 0.0
        
        # score the group based on the score for each class
        for class_val in classes:
            p = [row[-1] for row in group].count(class_val) / size
            score += p * p
        
        # weight the group score by its relative size
        gini += (1.0 - score) * (size / n_instances)
    
    return gini


In [28]:
# Split a dataset based on an attribute and an attribute value
def test_split(index, value, dataset):
    left, right = list(), list()
    
    for row in dataset:
        if row[index] < value:
            left.append(row)
        else:
            right.append(row)
    
    return left, right


In [29]:
# Select the best split point for a dataset
def get_split(dataset):
    class_values = list(set(row[-1] for row in dataset))
    b_index, b_value, b_score, b_groups = 999, 999, 999, None
    
    for index in range(len(dataset[0])-1):
        
        for row in dataset:
            groups = test_split(index, row[index], dataset)
            gini = gini_index(groups, class_values)
            print('X%d < %.3f Gini=%.3f' % ((index+1), row[index], gini))
            
            if gini < b_score:
                b_index, b_value, b_score, b_groups = index, row[index], gini, groups
    
    return {'index':b_index, 'value':b_value, 'groups':b_groups}


In [30]:
def to_terminal(group):
    outcomes = [row[-1] for row in group]
    return max(set(outcomes), key=outcomes.count)


In [31]:
def split(node, max_depth, min_size, depth):
    
    left, right = node['groups']
    del(node['groups'])
    
    # check for a no split
    if not left or not right:
        node['left'] = node['right'] = to_terminal(left + right)
        return
    
    # check for max depth
    if depth >= max_depth:
        node['left'], node['right'] = to_terminal(left), to_terminal(right)
        return
    
    # process left child
    if len(left) <= min_size:
        node['left'] = to_terminal(left)
    else:
        node['left'] = get_split(left)
        split(node['left'], max_depth, min_size, depth+1)
    
    # process right child
    if len(right) <= min_size:
        node['right'] = to_terminal(right)
    else:
        node['right'] = get_split(right)
        split(node['right'], max_depth, min_size, depth + 1)


In [32]:
# Build a decision tree
def build_tree(train, max_depth, min_size):
    root = get_split(train)
    split(root, max_depth, min_size, 1)
    return root


In [33]:
# Print a decision tree
def predict(node, row):
    
    if row[node['index']] < node['value']:
        if isinstance(node['left'], dict):
            return predict(node['left'], row)
        else:
            return node['left']
    
    else:
        if isinstance(node['right'], dict):
            return predict(node['right'], row)
        else:
            return node['right']


In [34]:
dataset = load_csv('haberman.csv')
del(dataset[-1])

for i in range(len(dataset[0])):
    str_column_to_float(dataset, i)


In [36]:
def decision_tree(train, test, max_depth, min_size):
    tree = build_tree(train, max_depth, min_size)
    predictions = list()
    for row in test:
        prediction = predict(tree, row)
        predictions.append(prediction)
    return(predictions)

In [37]:
train = dataset
test = dataset
predicted_classes = decision_tree(train, test, 10, 2)

X1 < 30.000 Gini=0.389
X1 < 30.000 Gini=0.389
X1 < 30.000 Gini=0.389
X1 < 31.000 Gini=0.388
X1 < 31.000 Gini=0.388
X1 < 33.000 Gini=0.387
X1 < 33.000 Gini=0.387
X1 < 34.000 Gini=0.386
X1 < 34.000 Gini=0.386
X1 < 34.000 Gini=0.386
X1 < 34.000 Gini=0.386
X1 < 34.000 Gini=0.386
X1 < 34.000 Gini=0.386
X1 < 34.000 Gini=0.386
X1 < 35.000 Gini=0.388
X1 < 35.000 Gini=0.388
X1 < 36.000 Gini=0.387
X1 < 36.000 Gini=0.387
X1 < 37.000 Gini=0.386
X1 < 37.000 Gini=0.386
X1 < 37.000 Gini=0.386
X1 < 37.000 Gini=0.386
X1 < 37.000 Gini=0.386
X1 < 37.000 Gini=0.386
X1 < 38.000 Gini=0.384
X1 < 38.000 Gini=0.384
X1 < 38.000 Gini=0.384
X1 < 38.000 Gini=0.384
X1 < 38.000 Gini=0.384
X1 < 38.000 Gini=0.384
X1 < 38.000 Gini=0.384
X1 < 38.000 Gini=0.384
X1 < 38.000 Gini=0.384
X1 < 38.000 Gini=0.384
X1 < 39.000 Gini=0.381
X1 < 39.000 Gini=0.381
X1 < 39.000 Gini=0.381
X1 < 39.000 Gini=0.381
X1 < 39.000 Gini=0.381
X1 < 39.000 Gini=0.381
X1 < 40.000 Gini=0.381
X1 < 40.000 Gini=0.381
X1 < 40.000 Gini=0.381
X1 < 41.000

X3 < 2.000 Gini=0.355
X3 < 3.000 Gini=0.351
X3 < 4.000 Gini=0.354
X3 < 0.000 Gini=0.389
X3 < 4.000 Gini=0.354
X3 < 0.000 Gini=0.389
X3 < 4.000 Gini=0.354
X3 < 5.000 Gini=0.348
X3 < 0.000 Gini=0.389
X3 < 1.000 Gini=0.364
X3 < 0.000 Gini=0.389
X3 < 0.000 Gini=0.389
X3 < 0.000 Gini=0.389
X3 < 4.000 Gini=0.354
X3 < 1.000 Gini=0.364
X3 < 3.000 Gini=0.351
X3 < 9.000 Gini=0.353
X3 < 24.000 Gini=0.388
X3 < 12.000 Gini=0.367
X3 < 1.000 Gini=0.364
X3 < 1.000 Gini=0.364
X3 < 2.000 Gini=0.355
X3 < 1.000 Gini=0.364
X3 < 0.000 Gini=0.389
X3 < 11.000 Gini=0.360
X3 < 23.000 Gini=0.382
X3 < 5.000 Gini=0.348
X3 < 7.000 Gini=0.359
X3 < 7.000 Gini=0.359
X3 < 3.000 Gini=0.351
X3 < 0.000 Gini=0.389
X3 < 46.000 Gini=0.389
X3 < 0.000 Gini=0.389
X3 < 7.000 Gini=0.359
X3 < 19.000 Gini=0.376
X3 < 1.000 Gini=0.364
X3 < 0.000 Gini=0.389
X3 < 6.000 Gini=0.356
X3 < 15.000 Gini=0.374
X3 < 1.000 Gini=0.364
X3 < 0.000 Gini=0.389
X3 < 1.000 Gini=0.364
X3 < 18.000 Gini=0.378
X3 < 0.000 Gini=0.389
X3 < 3.000 Gini=0.351
X3

X3 < 0.000 Gini=0.260
X3 < 0.000 Gini=0.260
X3 < 0.000 Gini=0.260
X3 < 0.000 Gini=0.260
X3 < 0.000 Gini=0.260
X3 < 0.000 Gini=0.260
X3 < 1.000 Gini=0.259
X3 < 0.000 Gini=0.260
X3 < 0.000 Gini=0.260
X3 < 0.000 Gini=0.260
X3 < 0.000 Gini=0.260
X3 < 0.000 Gini=0.260
X3 < 0.000 Gini=0.260
X3 < 0.000 Gini=0.260
X3 < 0.000 Gini=0.260
X3 < 0.000 Gini=0.260
X3 < 0.000 Gini=0.260
X3 < 2.000 Gini=0.260
X3 < 0.000 Gini=0.260
X3 < 0.000 Gini=0.260
X3 < 0.000 Gini=0.260
X3 < 2.000 Gini=0.260
X3 < 0.000 Gini=0.260
X3 < 1.000 Gini=0.259
X3 < 0.000 Gini=0.260
X3 < 0.000 Gini=0.260
X3 < 1.000 Gini=0.259
X3 < 0.000 Gini=0.260
X3 < 1.000 Gini=0.259
X3 < 0.000 Gini=0.260
X3 < 0.000 Gini=0.260
X3 < 0.000 Gini=0.260
X3 < 0.000 Gini=0.260
X3 < 0.000 Gini=0.260
X3 < 0.000 Gini=0.260
X3 < 0.000 Gini=0.260
X3 < 0.000 Gini=0.260
X3 < 0.000 Gini=0.260
X3 < 0.000 Gini=0.260
X3 < 0.000 Gini=0.260
X3 < 0.000 Gini=0.260
X3 < 0.000 Gini=0.260
X3 < 2.000 Gini=0.260
X3 < 0.000 Gini=0.260
X3 < 0.000 Gini=0.260
X3 < 0.000

X2 < 58.000 Gini=0.480
X2 < 58.000 Gini=0.480
X2 < 58.000 Gini=0.480
X3 < 0.000 Gini=0.480
X3 < 0.000 Gini=0.480
X3 < 0.000 Gini=0.480
X3 < 0.000 Gini=0.480
X3 < 0.000 Gini=0.480
X3 < 0.000 Gini=0.480
X3 < 0.000 Gini=0.480
X3 < 1.000 Gini=0.444
X3 < 0.000 Gini=0.480
X3 < 0.000 Gini=0.480
X1 < 62.000 Gini=0.494
X1 < 62.000 Gini=0.494
X1 < 64.000 Gini=0.492
X1 < 65.000 Gini=0.481
X1 < 65.000 Gini=0.481
X1 < 66.000 Gini=0.489
X1 < 66.000 Gini=0.489
X1 < 66.000 Gini=0.489
X1 < 70.000 Gini=0.417
X2 < 58.000 Gini=0.494
X2 < 58.000 Gini=0.494
X2 < 58.000 Gini=0.494
X2 < 58.000 Gini=0.494
X2 < 58.000 Gini=0.494
X2 < 58.000 Gini=0.494
X2 < 58.000 Gini=0.494
X2 < 58.000 Gini=0.494
X2 < 58.000 Gini=0.494
X3 < 0.000 Gini=0.494
X3 < 0.000 Gini=0.494
X3 < 0.000 Gini=0.494
X3 < 0.000 Gini=0.494
X3 < 0.000 Gini=0.494
X3 < 0.000 Gini=0.494
X3 < 0.000 Gini=0.494
X3 < 1.000 Gini=0.444
X3 < 0.000 Gini=0.494
X1 < 62.000 Gini=0.469
X1 < 62.000 Gini=0.469
X1 < 64.000 Gini=0.458
X1 < 65.000 Gini=0.467
X1 < 65

X3 < 25.000 Gini=0.218
X3 < 13.000 Gini=0.245
X3 < 19.000 Gini=0.231
X1 < 46.000 Gini=0.000
X1 < 47.000 Gini=0.000
X1 < 48.000 Gini=0.000
X1 < 48.000 Gini=0.000
X1 < 49.000 Gini=0.000
X1 < 50.000 Gini=0.000
X1 < 51.000 Gini=0.000
X1 < 53.000 Gini=0.000
X1 < 53.000 Gini=0.000
X1 < 53.000 Gini=0.000
X2 < 65.000 Gini=0.000
X2 < 63.000 Gini=0.000
X2 < 58.000 Gini=0.000
X2 < 58.000 Gini=0.000
X2 < 64.000 Gini=0.000
X2 < 63.000 Gini=0.000
X2 < 59.000 Gini=0.000
X2 < 60.000 Gini=0.000
X2 < 63.000 Gini=0.000
X2 < 65.000 Gini=0.000
X3 < 20.000 Gini=0.000
X3 < 23.000 Gini=0.000
X3 < 11.000 Gini=0.000
X3 < 11.000 Gini=0.000
X3 < 10.000 Gini=0.000
X3 < 13.000 Gini=0.000
X3 < 13.000 Gini=0.000
X3 < 9.000 Gini=0.000
X3 < 24.000 Gini=0.000
X3 < 12.000 Gini=0.000
X1 < 54.000 Gini=0.397
X1 < 54.000 Gini=0.397
X1 < 54.000 Gini=0.397
X1 < 56.000 Gini=0.394
X1 < 57.000 Gini=0.396
X1 < 57.000 Gini=0.396
X1 < 59.000 Gini=0.388
X1 < 60.000 Gini=0.396
X1 < 60.000 Gini=0.396
X1 < 62.000 Gini=0.364
X1 < 62.000 

In [38]:
def accuracy(actual, predicted_classes) :
    correct = 0
    
    for row_ind in range(len(actual)) :
        if test[row_ind][-1] == predicted_classes[row_ind] :
            correct += 1
    
    return(correct * 100 / len(actual))

In [39]:
print('Accuracy : ' + str(accuracy(test, predicted_classes)) + '%')

Accuracy : 92.48366013071896%
