# Quality metrics

There are two different pruning methods:
- validation,
- direct.

The first group works on trees that is already built. The direct method works while building the tree. In both cases we need to set a testing data set to validate the accuracy.

In [1]:
%store -r labels
%store -r data_set

test_labels = [1,1,-1,-1,1,1,1,-1]
test_data_set = [[1,1,2,2],[3,2,1,2],[2,3,1,2],
                [2,2,1,2],[1,3,2,2],[2,1,1,2],
                [3,1,2,1],[2,1,2,2]]

## Validation pruning - Reduced Error Pruning

This method checks the tree after it's build for leafs that does not impact on the accuracy or impact on the accuracy by reducing it.

Let's build the tree first.

In [16]:
import math
import numpy as np
import pydot
import copy
from math import log

class BinaryLeaf:

    def __init__(self, elements, labels, ids):
        self.L = None
        self.R = None
        self.elements = elements
        self.split_feature = None
        self.split_value = None
        self.labels = labels
        self.completed = False
        self.ids = ids
        self.validated = False

    def set_R(self, Rleaf):
        self.R = Rleaf

    def set_L(self, Lleaf):
        self.L = Lleaf

    def set_elements(self, elements):
        self.elements = elements

    def get_elements(self):
        return self.elements

    def set_p(self, threshold):
        self.p = threshold

    def get_L(self):
        return self.L

    def get_R(self):
        return self.R

    def set_completed(self):
        self.completed = True

    def is_completed(self):
        return self.completed

    def get_labels(self):
        return self.labels

    def set_split(self, feature):
        self.split_feature = feature

    def get_split(self):
        return self.split_feature

    def set_split_value(self, value):
        self.split_value = value

    def get_split_value(self):
        return self.split_value

    def set_validated(self):
        self.validated = True

    def is_validated(self):
        return self.validated

    def set_ids(self, ids):
        self.ids = ids

    def get_ids(self):
        return self.ids
    
labels_count = len(np.unique(labels))

ids = list(range(len(data_set)))
root = BinaryLeaf(data_set, labels, ids)
current_node = root    

def get_unique_labels(labels):
    return np.unique(np.array(labels)).tolist()

def get_unique_values(elements):
    features_number = len(elements[0])
    unique = []
    for i in range(features_number):
        features_list = []
        for j in range(len(elements)):
            features_list.append(elements[j][i])
        unique.append(np.unique(np.array(features_list)))
    return unique

def is_leaf_completed(node):
    if node.is_completed():
        if node.get_L() != None and not node.get_L().is_completed():
            return node.get_L()
        elif node.get_R() != None and not node.get_R().is_completed():
            return node.get_R()
        elif node.get_L() == None and node.get_R() == None:
            return None
        elif node.get_L().is_completed() or node.get_R().is_completed():
            new_node = is_leaf_completed(node.get_L())
            if new_node == None:
                return is_leaf_completed(node.get_R())
            else:
                return new_node
        else:
            return None
    return node

def find_leaf_not_completed(root):
    return is_leaf_completed(root)

def get_split_candidates(unique_values):
    split_list = []
    for i in range(len(unique_values)):
        current_list = []
        temp_list = copy.deepcopy(unique_values)
        current_list.append(temp_list[i])
        del temp_list[i]
        current_list.append(temp_list)
        split_list.append(current_list)
    return split_list


def get_number_of_labels_for_value(elements, column_id, label):
    count = 0
    if not isinstance(elements, list):
        elements_list = [elements]
    else:
        elements_list = elements

    column_elements = get_node_elements_column(column_id)

    for i in range(len(elements_list)):
        for j in range(len(column_elements)):
            if column_elements[j] == elements_list[i]:
                if current_node.labels[j] == label:
                    count = count + 1
    return count

def get_node_elements_column(column_id):
    return np.array(current_node.elements)[..., column_id].tolist()

def count_number_of_elements(elements, column_id):
    count = 0
    if isinstance(elements, list):
        column_elements = get_node_elements_column(column_id)
        for i in range(len(elements)):
            count = count + column_elements.count(elements[i])
    else:
        count = count + get_node_elements_column(column_id).count(elements)
    return count

def calculate_omega(elements, column_id):
    t_l = count_number_of_elements(elements[0], column_id)
    t_r = count_number_of_elements(elements[1], column_id)
    p_l = t_l * 1.0 / len(current_node.elements) * 1.0
    p_r = t_r * 1.0 / len(current_node.elements) * 1.0

    sum_p = 0
    labels = get_unique_labels(current_node.labels)
    for i in range(labels_count):
        p_class_t_l = (get_number_of_labels_for_value(elements[0], column_id, labels[i]) * 1.0) / (
                count_number_of_elements(elements[0], column_id) * 1.0)
        p_class_t_r = (get_number_of_labels_for_value(elements[1], column_id, labels[i]) * 1.0) / (
                count_number_of_elements(elements[1], column_id) * 1.0)
        sum_p = sum_p + math.fabs(p_class_t_l - p_class_t_r)
    return 2.0 * p_l * p_r * sum_p

def check_completed(labels, elements):
    ratio = len(get_unique_labels(labels))
    if ratio == 1:
        return True
    elements = sorted(elements)
    duplicated = [elements[i] for i in range(len(elements)) if i == 0 or elements[i] != elements[i - 1]]
    if len(duplicated) == 1:
        return True
    return False

def split_node(current_node, value, split_id, split_history):
    left_leaf = []
    left_leaf_labels = []
    left_leaf_ids = []
    right_leaf = []
    right_leaf_labels = []
    right_leaf_ids = []
    for i in range(len(current_node.elements)):
        if current_node.elements[i][split_id] == value:
            left_leaf.append(current_node.elements[i])
            left_leaf_labels.append(current_node.labels[i])
            left_leaf_ids.append(current_node.ids[i])
        else:
            right_leaf.append(current_node.elements[i])
            right_leaf_labels.append(current_node.labels[i])
            right_leaf_ids.append(current_node.ids[i])
    if len(right_leaf_labels) == 0 or len(left_leaf_labels) == 0:
        current_node.set_completed()
        return current_node, split_history
    split_history.append([str(current_node.ids), str(left_leaf_ids)])
    split_history.append([str(current_node.ids), str(right_leaf_ids)])
    current_node.set_L(BinaryLeaf(left_leaf, left_leaf_labels, left_leaf_ids))
    current_node.set_R(BinaryLeaf(right_leaf, right_leaf_labels, right_leaf_ids))
    current_node.set_split(split_id)
    current_node.set_completed()
    if check_completed(left_leaf_labels, left_leaf):
        current_node.L.set_completed()
    if check_completed(right_leaf_labels, right_leaf):
        current_node.R.set_completed()
    return current_node, split_history

def get_current_node():
    return find_leaf_not_completed()

def build(root_node):
    current_node = root_node
    stop_criterion = False
    split_history = []
    while stop_criterion == False:
        unique_values = get_unique_values(current_node.get_elements())
        max_unique_id = 0
        max_split_id = 0
        max_value = 0
        for i in range(len(unique_values)):
            if len(unique_values[i]) == 1:
                continue
            split_candidates = get_split_candidates(unique_values[i].tolist())
            for j in range(len(split_candidates)):
                current_value = calculate_omega(split_candidates[j], i)
                if max_value < current_value:
                    max_unique_id = i
                    max_split_id = j
                    max_value = current_value
        current_node, split_history = split_node(current_node, unique_values[max_unique_id][max_split_id], max_unique_id, split_history)
        new_node = find_leaf_not_completed(root_node)
        if new_node != None:
            current_node = new_node
        else:
            stop_criterion = True
    return root_node, split_history

In [24]:
cart_tree, split_history_cart = build(current_node)

The current level methods returns the leafs of a given node:

In [25]:
def get_current_level(node):
    if type(node) is not list:
        return [node]
    level = []
    for leaf in node:
        if leaf.get_R() != None:
            level.append(leaf.get_R())
        if leaf.get_L() != None:
            level.append(leaf.get_L())
    return level

Accuracy is calcualated on the tree that is temporarly pruned (changed) to check if the accuracy is greater or less compared to the full version of the tree.

In [26]:
def get_accuracy(cart_tree, test_data_set, test_labels):
    predictions = []
    for sample in test_data_set:
        current_node = cart_tree

        while current_node.get_R() != None or current_node.get_L() != None:
            split_feature = current_node.get_split()
            split_value = current_node.get_split_value()

            if sample[split_feature] == split_value:
                current_node = current_node.get_L()
            else:
                current_node = current_node.get_R()

        prediction = int(np.sign(np.sum(current_node.get_labels())))

        if prediction == 0:
            prediction = -1
        predictions.append(prediction)

    accuracy = np.sum(np.array(predictions) == np.array(test_labels))

    return predictions, accuracy / len(test_labels)


The validation method goes through the tree and cut/prune nodes on a given level. Next, it check the accuracy change with such a pruned tree.

In [27]:
def validate_rep(cart_tree, test_data_set, test_labels):
    old_prediction, old_accuracy = get_accuracy(cart_tree, data_set, labels)
    print("Train accuracy: "+ str(old_accuracy))

    old_accuracy = 0.0

    level = [cart_tree]
    levels = [level]

    while level != []:
        level = get_current_level(levels[-1])
        if level != []:
            levels.append(level)

    for i, level in enumerate(levels):
        print("level ", i)

        for j, leaf in enumerate(level):
            print(" leaf ", j, ", ", leaf.ids)

            if leaf.get_L() != None:

                right_child = leaf.get_R()
                left_child = leaf.get_L()

                leaf.set_R(None)
                leaf.set_L(None)

                prediction, accuracy = get_accuracy(cart_tree, test_data_set, test_labels)

                if i != 0:
                    print("Leaf: " + str(leaf.ids)+": post prunning accuracy is greater or equal than pre prunning accuracy: " + str(accuracy) + ">=" + str(old_accuracy))
                else:
                    if accuracy < old_accuracy:
                        leaf.set_R(right_child)
                        leaf.set_L(left_child)

                        prediction, old_accuracy = get_accuracy(cart_tree, test_data_set, test_labels)

                    else:
                        old_accuracy = accuracy
                        leaf.set_completed()


In [28]:
validate_rep(cart_tree, test_data_set, test_labels)

Train accuracy: 0.6
level  0
 leaf  0 ,  [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]
level  1
 leaf  0 ,  [0, 2, 3, 6, 7, 8, 10, 11]
Leaf: [0, 2, 3, 6, 7, 8, 10, 11]: post prunning accuracy is greater or equal than pre prunning accuracy: 0.625>=0.625
 leaf  1 ,  [1, 4, 5, 9, 12, 13, 14]
Leaf: [1, 4, 5, 9, 12, 13, 14]: post prunning accuracy is greater or equal than pre prunning accuracy: 0.625>=0.625
level  2
 leaf  0 ,  [0, 6, 7, 8]
 leaf  1 ,  [2, 3, 10, 11]
Leaf: [2, 3, 10, 11]: post prunning accuracy is greater or equal than pre prunning accuracy: 0.625>=0.625
 leaf  2 ,  [1, 4]
 leaf  3 ,  [5, 9, 12, 13, 14]
Leaf: [5, 9, 12, 13, 14]: post prunning accuracy is greater or equal than pre prunning accuracy: 0.625>=0.625
level  3
 leaf  0 ,  [2, 3]
 leaf  1 ,  [10, 11]
 leaf  2 ,  [5, 9, 13]
 leaf  3 ,  [12, 14]
