The decision tree is built from scratch; however, I use some helper functions from other libraries: pandas for reading the data into a dataframe, random for shuffling the data before splitting, scipy.stats for the entropy function, sklearn.metrics for the F1 score, and matplotlib.pyplot for optionally plotting the chosen splits.

In [1]:
import pandas as pd
import random
import scipy.stats
import sklearn.metrics
import matplotlib.pyplot as plt

Reading the data into a dataframe

In [2]:
# Load the pulsar dataset. Later cells use its 'Class' column as the 0/1 target
# and treat the first 8 columns as the candidate features.
original_data = pd.read_csv('HTRU_2.csv')

Here I split the data into pulsars (data_1) and non-pulsars (data_0). I then shuffle each group separately and use the first 90% of each shuffled group as training data; the remaining 10% is used as test data. Splitting the two groups separately (a stratified split) keeps the proportion of positive samples the same in the training and test data. Finally, the combined training and test index lists are each shuffled again so the classes are interleaved.

In [3]:
# Row positions of the whole dataset.
# NOTE(review): `labels` is not read by any later cell in this notebook.
labels = list(range(len(original_data)))

# Boolean masks per class, then the index labels of the rows in each class
# (kept in the dataframe's original order).
data_0 = original_data.Class == 0
data_1 = original_data.Class == 1
labels_0 = original_data.index[data_0.to_numpy()].tolist()
labels_1 = original_data.index[data_1.to_numpy()].tolist()

# Fixed seed so the split is reproducible; each class is shuffled on its own
# so that the 90/10 cut below is stratified by class.
random.seed(503)
random.shuffle(labels_0)
random.shuffle(labels_1)

split_point_0 = int(len(labels_0) * 0.9)
split_point_1 = int(len(labels_1) * 0.9)

train_labels = [*labels_0[:split_point_0], *labels_1[:split_point_1]]
test_labels = [*labels_0[split_point_0:], *labels_1[split_point_1:]]

# Interleave the two classes within each split.
random.shuffle(train_labels)
random.shuffle(test_labels)

# Independent (deep) copies so later in-place column writes don't touch original_data.
train_data = original_data.loc[train_labels].copy()
test_data = original_data.loc[test_labels].copy()

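The class DecisionNode represents an internal node of the decision tree: it stores a split (column, threshold and the information gain of that split) and has a left and a right child.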

PredictionNode represents a node with no children, i.e. a leaf node, which holds the predicted class value.

DecisionTree would represent a collection of nodes, beginning with the root node, and keep track of how many nodes are in the tree. It is currently commented out and unused; the tree is handled through its root node instead.

In [4]:
class DecisionNode:
    """Internal node of the decision tree: a split on ``column`` at ``threshold``.

    Children (DecisionNode or PredictionNode instances) are attached after
    construction with ``addLeftChild`` / ``addRightChild``.
    """

    def __init__(self, column, threshold, info_gain):
        self.column = column        # name of the feature column this node splits on
        self.threshold = threshold  # split point for ``column``
        self.info_gain = info_gain  # information gain achieved by this split
        # Initialise the children so the attributes always exist, even before
        # the tree builder attaches them.
        self.left_child = None
        self.right_child = None

    def addLeftChild(self, child):
        """Attach ``child`` as the left child, replacing any existing one."""
        self.left_child = child

    def addRightChild(self, child):
        """Attach ``child`` as the right child, replacing any existing one."""
        self.right_child = child

    def print_node(self):
        """Print the split column, threshold and information gain."""
        print("Column: ", self.column)
        print("Threshold: ", self.threshold)
        print("Info gain: ", self.info_gain)
        
class PredictionNode:
    """Leaf node of the decision tree; holds the class value it predicts."""

    def __init__(self, value):
        self.value = value  # predicted class label

    def prediction(self):
        """Return the class value stored in this leaf."""
        return self.value

    def print_node(self):
        """Print this leaf's predicted value."""
        # Print the stored value itself; referencing ``self.prediction``
        # without calling it would print the bound-method object.
        print("Prediction: ", self.value)
        
# class DecisionTree:
#     def __init__(self, root_node):
#         self.root_node = root_node
#         self.count = 1
#     def addNode(self):
#         self.count += 1
#     def countNodes(self):
#         return count        

split_column() creates (or overwrites), in place, a temporary column 'temp' based on the value of the provided column: if the value in the provided column is above the threshold, 'temp' is 1, otherwise 0. information_gain() then reads this column to calculate the information gain obtained by splitting on a particular variable at a particular threshold.

In [5]:
def split_column(column_name, threshold, data):
    """Write a 0/1 indicator column ``'temp'`` into ``data`` in place.

    ``temp`` is 1 where ``data[column_name] > threshold`` and 0 everywhere
    else (values that do not compare greater, such as NaN, get 0). Any
    existing ``'temp'`` column is overwritten. Returns None.
    """
    above_threshold = data[column_name] > threshold
    data['temp'] = above_threshold.astype('int64')
    
def information_gain(entropy, data):
    """Return the information gain of splitting ``data`` on its ``'temp'`` column.

    Parameters
    ----------
    entropy : float
        Entropy of ``data['Class']`` before the split (callers compute it with
        ``scipy.stats.entropy`` on the class counts).
    data : pandas.DataFrame
        Must contain ``'Class'`` and the 0/1 ``'temp'`` column written by
        ``split_column``.

    Returns
    -------
    float
        ``entropy`` minus the size-weighted entropies of the ``temp == 0`` and
        ``temp == 1`` partitions. Returns 0 for empty data or when every row
        has the same class.
    """
    orig_size = data.shape[0]
    # A pure (or empty) node cannot be improved by splitting; checking first
    # also avoids dividing by a zero row count below.
    if orig_size == 0 or data['Class'].nunique() <= 1:
        return 0

    weighted_child_entropy = 0.0
    for side in (0, 1):
        in_side = data['temp'] == side
        size = int(in_side.sum())
        if size == 0:
            # An empty partition contributes nothing (weight 0).
            continue
        counts = data.loc[in_side, 'Class'].value_counts()
        weighted_child_entropy += (size / orig_size) * scipy.stats.entropy(counts)

    return entropy - weighted_child_entropy

This takes one variable, tries 9 evenly spaced threshold values (dividing the variable's range into 10 equal intervals), and returns the largest information gain for the target variable together with the threshold that produces it.

In [6]:
def max_info_gain_per_variable(column, data, n_splits=10):
    """Find the threshold on ``column`` that maximises information gain.

    Candidate thresholds are the ``n_splits - 1`` interior points that divide
    the range [min, max] of ``data[column]`` into ``n_splits`` equal intervals
    (9 candidates with the default of 10).

    Side effect: ``split_column`` overwrites ``data['temp']`` for each candidate.

    Parameters
    ----------
    column : str
        Feature column to evaluate.
    data : pandas.DataFrame
        Rows to split; ``data.Class`` is the target.
    n_splits : int
        Number of equal-width intervals used to place candidate thresholds.

    Returns
    -------
    tuple
        ``(max_info_gain, threshold)``. If no candidate has positive gain,
        returns ``(0, 0)``.
    """
    current_entropy = scipy.stats.entropy(data.Class.value_counts())
    # min/max do not change between candidates, so compute them once.
    col_min = data[column].min()
    value_range = data[column].max() - col_min
    step = value_range / n_splits

    max_info_gain = 0
    max_info_gain_threshold = 0
    for i in range(1, n_splits):
        threshold = col_min + i*step
        split_column(column, threshold, data)
        info_gain = information_gain(current_entropy, data)
        if info_gain > max_info_gain:
            max_info_gain = info_gain
            max_info_gain_threshold = threshold
    return (max_info_gain, max_info_gain_threshold)

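This loops over all candidate variables (skipping those already used on the current path from the root) and finds the split with the maximum information gain at the current point in the decision tree. If that maximum information gain is below a minimum value (min_value, 0.005 by default), the function returns 0 instead of a node.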

In [7]:
def max_info_gain_overall(data, entropy, used_columns, min_value=0.005, vis=False, n_features=8):
    """Choose the best (column, threshold) split for ``data`` across all features.

    Parameters
    ----------
    data : pandas.DataFrame
        Rows reaching the current node. The first ``n_features`` columns are
        the candidate features; ``data.Class`` is the target.
    entropy : float
        Entropy of ``data.Class``. Not used here, because
        ``max_info_gain_per_variable`` recomputes it; kept for call compatibility.
    used_columns : list
        Columns already split on along the path from the root; they are skipped.
    min_value : float
        Minimum information gain required to create a split.
    vis : bool
        If True, scatter-plot the chosen column against the class and draw
        the chosen threshold as a vertical line.
    n_features : int
        Number of leading columns of ``data`` to consider (default 8).

    Returns
    -------
    DecisionNode or 0
        The best split, or ``0`` if its gain is below ``min_value``.
    """
    used = set(used_columns)
    candidates = [c for c in data.columns[:n_features] if c not in used]

    best_gain = 0
    best_column = ''
    best_threshold = 0
    for column in candidates:
        gain, threshold = max_info_gain_per_variable(column, data)
        # Strict '>' keeps the earliest column when gains tie.
        if gain > best_gain:
            best_gain = gain
            best_column = column
            best_threshold = threshold

    if best_gain < min_value:
        return 0

    best_node = DecisionNode(best_column, best_threshold, best_gain)
    if vis:
        plt.scatter(data[best_column], data.Class)
        plt.axvline(x=best_threshold)
        plt.show()
    return best_node

build_subtree is a recursive function which builds the decision tree in pre-order.
build_ID3_tree finds the root node, uses build_subtree to create the rest of the tree structure, and then prunes subtrees whose leaves all predict the same class.
The tree structure is printed separately by visualise_tree below.

In [8]:
def build_subtree(root, data, level, used_columns):
    """Recursively attach children to ``root`` (pre-order: left subtree, then right).

    Rows with ``root.column <= root.threshold`` form the left (false) branch
    and rows with ``root.column > root.threshold`` form the right (true)
    branch. This matches how ``split_column`` scores a split (``> threshold``
    is 1) and how ``make_prediction`` routes a sample. Each branch becomes a
    DecisionNode if ``max_info_gain_overall`` finds a worthwhile split, and
    otherwise a PredictionNode holding the branch's most frequent class.

    Parameters
    ----------
    root : DecisionNode
        Node whose children are built.
    data : pandas.DataFrame
        Rows reaching ``root``.
    level : int
        Depth of ``root``; passed down (incremented) but not otherwise used.
    used_columns : list
        Columns split on along the path to ``root``. It is temporarily
        extended for each child and restored before returning.
    """
    values = data[root.column]
    branches = (
        (values <= root.threshold, root.addLeftChild),   # left / false branch
        (values > root.threshold, root.addRightChild),   # right / true branch
    )
    for filt, attach in branches:
        data_subset = data.loc[filt].copy(deep=True)
        entropy = scipy.stats.entropy(data_subset.Class.value_counts())
        child = max_info_gain_overall(data_subset, entropy, used_columns)
        if child == 0:
            attach(PredictionNode(data_subset.Class.mode()[0]))
            continue
        attach(child)
        # Mark the column as used only within this child's subtree, then undo
        # it so the sibling branch and the caller see the list unchanged.
        used_columns.append(child.column)
        build_subtree(child, data_subset, level + 1, used_columns)
        used_columns.pop()
        
# Collects the values of every leaf below a node. build_ID3_tree uses it to
# detect redundant subtrees, i.e. subtrees where every PredictionNode has the
# same value, so that they can be pruned to a single leaf.
def check_subtree(node, values):
    """Append the value of every leaf under ``node`` to ``values`` and return it.

    Leaves are visited left-to-right. ``values`` is mutated in place, and the
    same list object is returned.
    """
    pending = [node]
    while pending:
        current = pending.pop()
        if type(current) is PredictionNode:
            values.append(current.value)
        else:
            # Push the right child first so the left subtree is visited first.
            pending.append(current.right_child)
            pending.append(current.left_child)
    return values

def build_ID3_tree(data):
    """Grow a decision tree on ``data`` and prune redundant subtrees.

    The root split comes from ``max_info_gain_overall``, and ``build_subtree``
    grows the rest. Afterwards, every child subtree whose leaves all predict
    the same class is replaced by a single PredictionNode with that class.

    Returns
    -------
    DecisionNode or PredictionNode
        The root DecisionNode, or a single PredictionNode holding the most
        frequent class when no split reaches the minimum information gain.
    """
    entropy = scipy.stats.entropy(data.Class.value_counts())
    used_columns = []
    root_node = max_info_gain_overall(data, entropy, used_columns)
    if root_node == 0:
        # No worthwhile split: the whole tree is one majority-class leaf.
        return PredictionNode(data.Class.mode()[0])

    used_columns.append(root_node.column)
    build_subtree(root_node, data, 0, used_columns)
    _prune_uniform_children(root_node)
    return root_node


def _prune_uniform_children(node):
    """Replace, top-down, each child subtree of ``node`` whose leaves all agree by one leaf."""
    children = (
        (node.left_child, node.addLeftChild),
        (node.right_child, node.addRightChild),
    )
    for child, attach in children:
        if type(child) == PredictionNode:
            continue
        leaf_values = set(check_subtree(child, []))
        if len(leaf_values) == 1:
            attach(PredictionNode(leaf_values.pop()))
        else:
            _prune_uniform_children(child)
        
# Grow the tree on the training split only; test_data is held out for evaluation below.
ID3_tree = build_ID3_tree(train_data)

In [9]:
def visualise_tree(node, level):
    """Print the subtree rooted at ``node`` as indented IF/ELSE rules.

    ``level`` is the indentation depth (one tab per level). The left branch
    holds rows with ``column <= threshold`` and the right branch rows with
    ``column > threshold``, matching ``make_prediction``. A PredictionNode is
    printed as a single THEN line, so a tree that is just one leaf can also
    be printed.
    """
    if type(node) == PredictionNode:
        print(level*"\t", "THEN ", node.value)
        return

    branches = (
        ("IF ", " <= ", node.left_child),
        ("ELSE ", " > ", node.right_child),
    )
    for keyword, comparison, child in branches:
        print(level*"\t", keyword, node.column, comparison, node.threshold)
        visualise_tree(child, level + 1)
        
# Print the learned rules, starting at indentation level 0.
visualise_tree(ID3_tree, 0)

 IF  Excess kurtosis of the integrated profile  <  1.1076487870999996
	 THEN  0
 ELSE  Excess kurtosis of the integrated profile  >  1.1076487870999996
	 IF  Skewness of the DM-SNR curve  <  100.04292962330001
		 THEN  1
	 ELSE  Skewness of the DM-SNR curve  >  100.04292962330001
		 IF  Mean of the DM-SNR curve  <  2.5133779262000004
			 THEN  0
		 ELSE  Mean of the DM-SNR curve  >  2.5133779262000004
			 THEN  1


In [10]:
def make_prediction(node, data_input):
    """Walk ``data_input`` down the tree from ``node`` and return the leaf's prediction.

    At each DecisionNode the sample goes right when its value in
    ``node.column`` is strictly greater than ``node.threshold``, and left
    otherwise (including equality). ``data_input`` only needs to be indexable
    by column name, e.g. a dataframe row.
    """
    current = node
    while not type(current) == PredictionNode:
        goes_right = data_input[current.column] > current.threshold
        current = current.right_child if goes_right else current.left_child
    return current.prediction()

Finding the accuracy of the ID3_tree on the test data

In [11]:
# Compare predictions with the true labels of the held-out test set.
# test_data was built with .loc[test_labels], so iterating test_labels visits
# the rows in the same order as test_data.Class.
real_values = [val for val in test_data.Class]
pred_values = []
for i in test_labels:
    pred = make_prediction(ID3_tree, test_data.loc[i])
    pred_values.append(pred)
# Fraction of test rows predicted correctly (i.e. the accuracy, despite the name).
correct = len([1 for r,p in zip(real_values, pred_values) if (r==p)]) / float(len(pred_values))
print("Test set accuracy: ", correct)

Test set accuracy:  0.9743016759776536


In [12]:
# F1 score of the positive class (sklearn's default pos_label=1, i.e. pulsars),
# which is more informative than accuracy on this imbalanced dataset.
f1_score = sklearn.metrics.f1_score(real_values, pred_values)
print("Test set F1 score: ", f1_score)

Test set F1 score:  0.8516129032258065
