In [1]:
from sklearn.datasets import load_breast_cancer
import numpy as np
from tqdm import tqdm


## load data
From sklearn, we know:
212(M, 0), 357(B, 1)  
https://scikit-learn.org/stable/modules/generated/sklearn.datasets.load_breast_cancer.html

In [2]:
data = load_breast_cancer()

# some basic info about data
x, labels = data.data, data.target
print(x.shape, labels.shape)
print(data.target_names)
print(data.feature_names)


(569, 30) (569,)
['malignant' 'benign']
['mean radius' 'mean texture' 'mean perimeter' 'mean area'
 'mean smoothness' 'mean compactness' 'mean concavity'
 'mean concave points' 'mean symmetry' 'mean fractal dimension'
 'radius error' 'texture error' 'perimeter error' 'area error'
 'smoothness error' 'compactness error' 'concavity error'
 'concave points error' 'symmetry error' 'fractal dimension error'
 'worst radius' 'worst texture' 'worst perimeter' 'worst area'
 'worst smoothness' 'worst compactness' 'worst concavity'
 'worst concave points' 'worst symmetry' 'worst fractal dimension']


In [3]:
import pandas as pd

In [4]:
# reorganzie to pandas
df_data = {}
for i, feature in enumerate(data.feature_names):
    df_data[feature] = list(x[:, i])
    
df_data['label'] = list(labels)
df = pd.DataFrame(df_data)

df.head()


Unnamed: 0,mean radius,mean texture,mean perimeter,mean area,mean smoothness,mean compactness,mean concavity,mean concave points,mean symmetry,mean fractal dimension,...,worst texture,worst perimeter,worst area,worst smoothness,worst compactness,worst concavity,worst concave points,worst symmetry,worst fractal dimension,label
0,17.99,10.38,122.8,1001.0,0.1184,0.2776,0.3001,0.1471,0.2419,0.07871,...,17.33,184.6,2019.0,0.1622,0.6656,0.7119,0.2654,0.4601,0.1189,0
1,20.57,17.77,132.9,1326.0,0.08474,0.07864,0.0869,0.07017,0.1812,0.05667,...,23.41,158.8,1956.0,0.1238,0.1866,0.2416,0.186,0.275,0.08902,0
2,19.69,21.25,130.0,1203.0,0.1096,0.1599,0.1974,0.1279,0.2069,0.05999,...,25.53,152.5,1709.0,0.1444,0.4245,0.4504,0.243,0.3613,0.08758,0
3,11.42,20.38,77.58,386.1,0.1425,0.2839,0.2414,0.1052,0.2597,0.09744,...,26.5,98.87,567.7,0.2098,0.8663,0.6869,0.2575,0.6638,0.173,0
4,20.29,14.34,135.1,1297.0,0.1003,0.1328,0.198,0.1043,0.1809,0.05883,...,16.67,152.2,1575.0,0.1374,0.205,0.4,0.1625,0.2364,0.07678,0


In [5]:
df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 569 entries, 0 to 568
Data columns (total 31 columns):
mean radius                569 non-null float64
mean texture               569 non-null float64
mean perimeter             569 non-null float64
mean area                  569 non-null float64
mean smoothness            569 non-null float64
mean compactness           569 non-null float64
mean concavity             569 non-null float64
mean concave points        569 non-null float64
mean symmetry              569 non-null float64
mean fractal dimension     569 non-null float64
radius error               569 non-null float64
texture error              569 non-null float64
perimeter error            569 non-null float64
area error                 569 non-null float64
smoothness error           569 non-null float64
compactness error          569 non-null float64
concavity error            569 non-null float64
concave points error       569 non-null float64
symmetry error             569 

## Build CART tree

In [6]:
# add label to the matrix X as last col
labels = np.expand_dims(labels, 1)
data = np.c_[x, labels]
print(data.shape)

(569, 31)


In [7]:
def gini_loss(data, col, sep_value):
    """
    Calculate gini loss, return float value.
    
    Args:
      data, data set;
      col, the feature column index;
      sep_value, the speration value to devide the data into two groups;
      
    Returns:
      loss, float gini value;
    """
    assert col<31, "Invalid feature col index"
    # count symbol, eg. d1_m means d1-less than spe_value and the lable is benign
    n_data = data.shape[0]
    d1_m, d1_b, d2_m, d2_b = 0, 0, 0, 0
    d1 = data[list(data[:, col]<sep_value)]
    d2 = data[list(data[:, col]>=sep_value)]
    
    # count
    n_d1, n_d2 = d1.shape[0], d2.shape[0]
    d1_m, d1_b = sum(d1[:, -1]<0.5), sum(d1[:, -1]>0.5)
    d2_m, d2_b = sum(d2[:, -1]<0.5), sum(d2[:, -1]>0.5)
#     print(n_data, n_d1, n_d2, d1_m, d1_b, d2_m, d2_b)
    
    # gini function, deal with zero division
    if (n_d1 == 0):
        loss = 2*(d2_m/n_d2*d2_b/n_d2)    # n_d2 equals n_data, since n_data=n_d1+n_d2
    elif (n_d2 == 0):
        loss = 2*(d1_m/n_d1*d1_b/n_d1)
    else:
        loss = n_d1/n_data*2*(d1_m/n_d1*d1_b/n_d1) + n_d2/n_data*2*(d2_m/n_d2*d2_b/n_d2)
    return loss, d1_m>d1_b, d2_m>d2_b

if __name__ == '__main__':
    # test, where I choose label col as selected feature
    index = 23
    test_sep = (max(data[:, index])+min(data[:,index]))/2
#     test_sep = 868.2
    print('sep_point:', test_sep)
    print('gini_loss:', gini_loss(data, index, test_sep)[0])
    

sep_point: 868.2
gini_loss: 0.14502085890681446


In [8]:
n_features = 30
n_stop_samples = 2    # here, I select last number of samples more than 2 as stop condition

# recursive stop condtion
def searchBestSegPoint(data, n_stop_samples):
    n_samples_p = data.shape[0]     # number of samples at TreeNode p
    if (n_samples_p<=n_stop_samples):
        return
    
    best_seg_point = [1.0, 0, 0]    # record 3 values: gini_loss, index, seg_value
    for i in tqdm(range(n_features)):
        features = np.sort(data[:, i])    # sort from low to high
        for j in range(n_samples_p-1):
            seg_value = (features[j] + features[j+1])/2+1e-5
#             print(i, j, seg_value, data.shape)
            loss, left_label, right_label = gini_loss(data, i, seg_value)
            # update best_seg_point
            if (loss < best_seg_point[0]):
                best_seg_point[0] = loss
                best_seg_point[1] = i
                best_seg_point[2] = seg_value
                
    return best_seg_point

# test
if __name__ == '__main__':
    test = searchBestSegPoint(data, n_stop_samples)
    print(test)

100%|██████████| 30/30 [00:49<00:00,  1.64s/it]

[0.1423191809182917, 20, 16.79501]





In [9]:
def split2subData(data, best_seg_point):
    """
    According to the best seg point, split data to two sub-data at node p.
    """
    _, index, seg_value = best_seg_point
    left_data = data[list(data[:, index]<seg_value)]
    right_data = data[list(data[:, index]>seg_value)]
    return left_data, right_data

# test
# if __name__ == '__main__':
#     left, right = split2subData(data, a)
#     print(left.shape, right.shape)

In [10]:
# define tree node
class TreeNode:
    
    def __init__(self, best_seg_point):
        self.feature_col = best_seg_point[1]
        self.seg_val = best_seg_point[2]
        self.label = 0
        self.left_node = None
        self.right_node = None

# if __name__ == '__main__':
#     best_seg_point = searchBestSegPoint(data[:2, :], n_stop_samples)
#     if best_seg_point:
#         root = TreeNode(best_seg_point)
#     else:
#         print('None')
        

In [11]:
def build_CART(data, n_stop_samples, depth):
    # the tree limist to 3 layer depth
    if depth > 2:
        return None
    
    best_seg_point = searchBestSegPoint(data, n_stop_samples)
    # mark label
    label = 0 if (sum(data[:, -1]<0.5)>sum(data[:, -1]>0.5)) else 1
    
    if best_seg_point:
        gini_loss = best_seg_point[0]
        if gini_loss>5e-3:
            node = TreeNode(best_seg_point)
            node.label = label
            # split to two data-set
            left_data, right_data = split2subData(data, best_seg_point)
            print(gini_loss, left_data.shape, right_data.shape)
            # recursive
            depth += 1
            node.left_node = build_CART(left_data, n_stop_samples, depth)
            node.right_node = build_CART(right_data, n_stop_samples, depth)
        else:
            return None

    else:
        return None
    
    return node

# if __name__ == '__main__':
#     root = build_CART(tt, 10, 0)
    
        

In [13]:
def test(val_set, decision_tree):
    n_samples = val_set.shape[0]
    pred = np.zeros(n_samples)    # predict array
    target = val_set[:, -1].astype('int')
    
    for i in range(n_samples):
        node = decision_tree
        sample = val_set[i, :]
        while(node):
            pred[i] = node.label
            if sample[node.feature_col]<node.seg_val:
                node = node.left_node
            else:
                node = node.right_node
    acc = sum(pred==target)/n_samples
    return acc

# test
# if __name__ == '__main__':
#     print('Test_scores:', test(data, root))

In [15]:
# five-fold cross-validation
k_fold = 5
n_samples = data.shape[0]
n_sample_fold = n_samples//k_fold
acc_list = []

for i in range(k_fold):
    begin_index, end_index = i*n_sample_fold, min((i+1)*n_sample_fold, n_samples)
    val_set = data[begin_index: end_index]
    train_set = np.r_[data[:begin_index], data[end_index:]]
    print('Loop:', i, begin_index, end_index, val_set.shape[0], train_set.shape[0])
    
#     assert val_set.shape[0]+ train_set.shape[0] == n_samples, print('Error:', i, val_set.shape[0], train_set.shape[0])
    
    # TODO, main program
    root = build_CART(train_set, 10, 0)
    acc = test(val_set, root)
    acc_list.append(acc)

  0%|          | 0/30 [00:00<?, ?it/s]

Loop: 0 0 113 113 456


100%|██████████| 30/30 [00:31<00:00,  1.03s/it]
  0%|          | 0/30 [00:00<?, ?it/s]

0.10253899599214104 (333, 31) (123, 31)


100%|██████████| 30/30 [00:16<00:00,  1.78it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

0.08045082119156194 (324, 31) (9, 31)


100%|██████████| 30/30 [00:15<00:00,  1.89it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

0.0591272614333411 (318, 31) (6, 31)


100%|██████████| 30/30 [00:15<00:00,  1.97it/s]
  7%|▋         | 2/30 [00:00<00:02, 12.50it/s]

0.04795411630016207 (282, 31) (36, 31)


100%|██████████| 30/30 [00:02<00:00, 12.58it/s]
  7%|▋         | 2/30 [00:00<00:02, 13.13it/s]

0.01084010840108401 (3, 31) (120, 31)


100%|██████████| 30/30 [00:02<00:00, 13.17it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Loop: 1 113 226 113 456


100%|██████████| 30/30 [00:31<00:00,  1.05s/it]
  0%|          | 0/30 [00:00<?, ?it/s]

0.12292225094542479 (303, 31) (153, 31)


100%|██████████| 30/30 [00:13<00:00,  2.15it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

0.07772566730357246 (285, 31) (18, 31)


100%|██████████| 30/30 [00:12<00:00,  2.44it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

0.04450398724082935 (275, 31) (10, 31)


100%|██████████| 30/30 [00:11<00:00,  2.59it/s]
100%|██████████| 30/30 [00:00<00:00, 442.02it/s]
  3%|▎         | 1/30 [00:00<00:03,  8.39it/s]

0.0231129476584022 (264, 31) (11, 31)


100%|██████████| 30/30 [00:03<00:00,  8.15it/s]
  3%|▎         | 1/30 [00:00<00:03,  9.21it/s]

0.06584395996160702 (10, 31) (143, 31)


100%|██████████| 30/30 [00:03<00:00,  9.29it/s]
  3%|▎         | 1/30 [00:00<00:03,  9.27it/s]

0.027578055747069834 (1, 31) (142, 31)


100%|██████████| 30/30 [00:03<00:00,  9.36it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

0.013984616921386474 (1, 31) (141, 31)
Loop: 2 226 339 113 456


100%|██████████| 30/30 [00:31<00:00,  1.04s/it]
  0%|          | 0/30 [00:00<?, ?it/s]

0.14437771771737454 (305, 31) (151, 31)


100%|██████████| 30/30 [00:14<00:00,  2.08it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

0.0979370931053416 (286, 31) (19, 31)


100%|██████████| 30/30 [00:12<00:00,  2.39it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

0.07463604943443652 (224, 31) (62, 31)


100%|██████████| 30/30 [00:07<00:00,  3.87it/s]
 17%|█▋        | 5/30 [00:00<00:00, 46.65it/s]

0.008035714285714285 (214, 31) (10, 31)


100%|██████████| 30/30 [00:00<00:00, 46.20it/s]
100%|██████████| 30/30 [00:00<00:00, 408.62it/s]
100%|██████████| 30/30 [00:00<00:00, 581.03it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

0.19441244239631336 (48, 31) (14, 31)
0.07894736842105263 (4, 31) (15, 31)


100%|██████████| 30/30 [00:03<00:00,  8.39it/s]
  3%|▎         | 1/30 [00:00<00:03,  8.37it/s]

0.038924178943100414 (4, 31) (147, 31)


100%|██████████| 30/30 [00:03<00:00,  8.58it/s]
  3%|▎         | 1/30 [00:00<00:03,  8.70it/s]

0.026838132513279284 (1, 31) (146, 31)


100%|██████████| 30/30 [00:03<00:00,  8.90it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

0.013604156825696741 (1, 31) (145, 31)
Loop: 3 339 452 113 456


100%|██████████| 30/30 [00:31<00:00,  1.05s/it]
  0%|          | 0/30 [00:00<?, ?it/s]

0.14392413631729373 (268, 31) (188, 31)


100%|██████████| 30/30 [00:10<00:00,  2.73it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

0.06145163510510688 (253, 31) (15, 31)


100%|██████████| 30/30 [00:09<00:00,  3.07it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

0.031118639814291985 (252, 31) (1, 31)


100%|██████████| 30/30 [00:09<00:00,  3.05it/s]
100%|██████████| 30/30 [00:00<00:00, 616.52it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

0.028813667367884235 (249, 31) (3, 31)
0.1111111111111111 (6, 31) (9, 31)


100%|██████████| 30/30 [00:05<00:00,  5.47it/s]
100%|██████████| 30/30 [00:00<00:00, 220.54it/s]
100%|██████████| 30/30 [00:00<00:00, 592.09it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

0.13237302321483635 (27, 31) (161, 31)
0.2518518518518519 (15, 31) (12, 31)


100%|██████████| 30/30 [00:00<00:00, 893.94it/s]
100%|██████████| 30/30 [00:04<00:00,  7.44it/s]
  3%|▎         | 1/30 [00:00<00:03,  7.67it/s]

0.048439392163756395 (2, 31) (159, 31)


100%|██████████| 30/30 [00:03<00:00,  7.56it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

0.03608021991753093 (8, 31) (151, 31)
Loop: 4 452 565 113 456


100%|██████████| 30/30 [00:31<00:00,  1.05s/it]
  0%|          | 0/30 [00:00<?, ?it/s]

0.13744573327024961 (265, 31) (191, 31)


100%|██████████| 30/30 [00:10<00:00,  2.76it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

0.058497907642610675 (257, 31) (8, 31)


100%|██████████| 30/30 [00:10<00:00,  2.98it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

0.047476993391390276 (252, 31) (5, 31)


100%|██████████| 30/30 [00:09<00:00,  3.06it/s]
  3%|▎         | 1/30 [00:00<00:05,  5.40it/s]

0.031240118889521277 (251, 31) (1, 31)


100%|██████████| 30/30 [00:05<00:00,  5.29it/s]
 60%|██████    | 18/30 [00:00<00:00, 88.30it/s]

0.11963366974586122 (44, 31) (147, 31)


100%|██████████| 30/30 [00:00<00:00, 87.55it/s]
100%|██████████| 30/30 [00:00<00:00, 438.26it/s]
100%|██████████| 30/30 [00:00<00:00, 231.90it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

0.19755244755244755 (18, 31) (26, 31)


100%|██████████| 30/30 [00:03<00:00,  8.74it/s]


In [17]:
print('5-fold:', acc_list)
print('Final cross_validation acc:', sum(acc_list)/len(acc_list))

5-fold: [0.7876106194690266, 0.8584070796460177, 0.9646017699115044, 0.9380530973451328, 0.8938053097345132]
Final cross_validation acc: 0.888495575221239


## BenchMark comparison

Applying sklearn.tree.DecisionTreeClassifier I also build a decision tree 

In [None]:
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import cross_val_score, train_test_split


In [None]:
# define CART tree
clf = DecisionTreeClassifier()

# cross-validataion
scores = cross_val_score(clf, x, labels, cv=5)
print(scores)
print('The avg acc is: {}'.format(sum(scores)/len(scores)))

In [None]:
# 可视化展示
import pydotplus
from sklearn.tree import export_graphviz
from IPython.display import Image

# fit data
train_x, test_x, train_y, test_y = train_test_split(x, labels, test_size=0.2)
clf = DecisionTreeClassifier(criterion='gini',  max_depth=3)
clf = clf.fit(train_x, train_y)
score = clf.score(test_x, test_y)
print('Test scores: {}'.format(score))

dot_Data = export_graphviz(clf, out_file=None)
graph_clf = pydotplus.graph_from_dot_data(dot_Data)
Image(graph_clf.create_png())

In [None]:
def first_order_traverse(node, i):
    if node:
        print(i, node.feature_col, node.seg_val, node.label)
        first_order_traverse(node.left_node, i+1)
        first_order_traverse(node.right_node, i+2)

first_order_traverse(root, 0)

In [None]:
root.left_node.feature_col