In [2]:
import numpy as np
import train

In [231]:
class Prune:
    def __init__(self, training_data, validation_data):
        """Class to prune decision trees
        """
        dt = train.DecisionTreeTrain()
        self.tree, self.tree_depth = dt.decision_tree_learning(training_data)

    def get_optimum_pruned_tree(self):
        pass

    def prune_next(self, tree, classified_nodes, node_to_prune):
        """Method to prune and return updated tree:
        
        Arg:
            tree (dict): original or pruned tree of self.tree
            node_to_prune (dict): dictionary representing the node to prune
            classified_nodes (dict): nodes indexed by id

        Returns:
            pruned_tree (dict):
        """
        majority_value = max([node_to_prune["left"]["value"], node_to_prune["right"]["value"]])
        id_path=[node_to_prune["id"]]
        parent_id = node_to_prune["parent_id"]
        # Trace back parents of nodes step by step back to top of tree
        while parent_id != 0:
            id_path.insert(parent_id)
            parent_id = classified_nodes[str(parent_id)]["parent_id"]

        lr_path = []
        left = tree["left"]
        right = tree["right"]
        # Construct left - right path from top of tree to location of node to prune
        for item in id_path:
            if left == classified_nodes[str(item)]["tree"].pop("id"):
                lr_path.append("left")
            elif right == classified_nodes[str(item)]["tree"].pop("id"):
                lr_path.append("right")
            else:
                pass
        # Update node that need to be pruned
        for direction in lr_path:
            tree=tree[direction]
        
        return

    def get_list_to_prune(self, tree):
        """Method to identify node to prune based on criterias:
            1. Node directly connects to two leaves
            2. Node is at deepest level of tree

        Returns:
            classified_nodes (dict): dictionary of nodes with properties indexed by id
            list_to_prune (dict): the node to be pruned sorted by priority
        """
        classified_nodes = self.classify_nodes(tree)
        all_leaf_nodes = []

        for id, node in classified_nodes.items():
            if node["type"] == "leaf_node":
                all_leaf_nodes.append(node)

        list_to_prune = sorted(all_leaf_nodes, key=lambda k: k["tree"]["depth"])
        return (classified_nodes, list_to_prune)

    def classify_nodes(self, tree):
        """Method to classify_nodes, giving it an ID, identify leaf nodes and depth of nodes

        Returns:
            classified_nodes (dict): dictionary index by an id and sub dictionary of decision tree below the node
        """
        id = 0
        tree["id"] = id
        tree["parent_id"] = None
        nodes = [tree]
        classified_nodes = {}
        # First item in classified_nodes should be the entire tree initialised below
        properties_0 = {}
        properties_0["id"] = id
        properties_0["tree"] = tree
        properties_0["type"] = "origin"
        classified_nodes[str(id)] = properties_0

        # Iterate through items in nodes to classify them until empty   
        while nodes != []:
            add_nodes, properties = self.get_sub_tree_properties(nodes[0],id)
            for property in properties:
                if property == None:
                    pass
                else:
                    property
                    property["parent_id"] = nodes[0]["id"]
                    classified_nodes[str(property["id"])] = property
                    id += 1
            nodes.remove(nodes[0])
            for item in add_nodes:
                if item == None:
                    pass
                else:
                    nodes.append(item)
        return classified_nodes
        
    def get_sub_tree_properties(self, tree, id):
        """Method to get the properties of sub trees (brances) of a tree

        Args:
            tree (dict): dictionary representing part of decision tree
            previous_id: id for highest level node in the input tree

        Returns:
            sub_trees (set): list of left branch (dict) and right branch (dict) of a tree
            properties (set): list of properties (dict) for left and right branch
        """   
        left_key = "left"
        right_key = "right"
        sub_tree_left = tree[left_key]
        sub_tree_right = tree[right_key]
        branches = [sub_tree_left, sub_tree_right]
        properties = []
        for branch in branches:
            # When tree == None, this branch is already a leaf, not a node anymore
            if branch != None:
                node_properties={}
                id += 1
                # Give the ID to node itself as well so that can track parents later
                branch["id"] = id
                node_properties["id"] = id
                node_properties["tree"] = branch
                # If both branch are None, this node is a leaf
                if branch[left_key] == None and branch[right_key] == None:
                    node_properties["type"] = "leaf"
                # If node is connected to 2 leaves, it's a leaf_node
                elif (branch[left_key][left_key] == None and
                      branch[left_key][right_key] == None and
                      branch[right_key][left_key] == None and
                      branch[right_key][right_key] == None):
                    node_properties["type"] = "leaf_node"
                # Otherwise its a normal branch
                else:
                    node_properties["type"] = "branch_node"

                properties.append(node_properties)
            else:
                properties.append(None)
        
        return (branches, properties)
        
        
                     


In [232]:
data = np.loadtxt("..\wifi_db\clean_dataset.txt")

In [233]:
data

array([[-64., -56., -61., ..., -82., -81.,   1.],
       [-68., -57., -61., ..., -85., -85.,   1.],
       [-63., -60., -60., ..., -85., -84.,   1.],
       ...,
       [-62., -59., -46., ..., -87., -88.,   4.],
       [-62., -58., -52., ..., -90., -85.,   4.],
       [-59., -50., -45., ..., -88., -87.,   4.]])

In [234]:
dt = train.DecisionTreeTrain()

In [235]:
tree, depth = dt.decision_tree_learning(data)

In [236]:
pr = Prune(data, data)

In [183]:
pr.get_list_to_prune()

({'0': {'id': 0,
   'tree': {'attribute': 1.0,
    'value': -55.0,
    'left': {'attribute': 5.0,
     'value': -60.0,
     'left': {'attribute': 4.0,
      'value': -56.0,
      'left': {'attribute': 3.0,
       'value': -56.0,
       'left': {'attribute': None,
        'value': 1.0,
        'left': None,
        'right': None,
        'depth': 4},
       'right': {'attribute': 7.0,
        'value': -86.0,
        'left': {'attribute': 5.0,
         'value': -63.0,
         'left': {'attribute': 6.0,
          'value': -86.0,
          'left': {'attribute': 1.0,
           'value': -60.0,
           'left': {'attribute': None,
            'value': 4.0,
            'left': None,
            'right': None,
            'depth': 8},
           'right': {'attribute': None,
            'value': 3.0,
            'left': None,
            'right': None,
            'depth': 8},
           'depth': 8},
          'right': {'attribute': None,
           'value': 1.0,
           'left': None,
   

In [237]:
pr.classify_nodes(tree)

{'0': {'id': 0,
  'tree': {'attribute': 1.0,
   'value': -55.0,
   'left': {'attribute': 5.0,
    'value': -60.0,
    'left': {'attribute': 4.0,
     'value': -56.0,
     'left': {'attribute': 3.0,
      'value': -56.0,
      'left': {'attribute': None,
       'value': 1.0,
       'left': None,
       'right': None,
       'depth': 4,
       'id': 15},
      'right': {'attribute': 7.0,
       'value': -86.0,
       'left': {'attribute': 5.0,
        'value': -63.0,
        'left': {'attribute': 6.0,
         'value': -86.0,
         'left': {'attribute': 1.0,
          'value': -60.0,
          'left': {'attribute': None,
           'value': 4.0,
           'left': None,
           'right': None,
           'depth': 8,
           'id': 53},
          'right': {'attribute': None,
           'value': 3.0,
           'left': None,
           'right': None,
           'depth': 8,
           'id': 54},
          'depth': 8,
          'id': 47},
         'right': {'attribute': None,
        

In [243]:
a = {}
b = {}
a["first"] = 1
a["second"] = 2

b["first"] = 1
b["second"] = 2


In [245]:
tree = tree["left"]