In [None]:
from collections import deque
from copy import deepcopy

import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import accuracy_score
from sklearn.tree import DecisionTreeClassifier

def _reachable_split_nodes(tree_structure, root=0):
    """Return the sorted indices of split (non-leaf) nodes reachable from ``root``.

    ``tree_structure`` is a fitted tree's ``tree_`` object. A node counts as a leaf
    when its left and right child pointers are equal (both -1). Nodes hanging below
    a pruned node stay in the arrays but are skipped here, because they can no
    longer affect predictions.
    """
    left = tree_structure.children_left
    right = tree_structure.children_right
    found = []
    stack = [root]
    while stack:
        node = stack.pop()
        if left[node] == right[node]:  # leaf: nothing below to prune
            continue
        found.append(node)
        stack.extend((int(left[node]), int(right[node])))
    # Sorted so candidates are generated in node-index order.
    return sorted(found)


def prune_heuristcly(tree, X_train, y_train, X_test, y_test, max_itr):
    """Breadth-first Reduced Error Pruning search over a fitted decision tree.

    Starting from ``tree``, each step collapses one reachable split node into a
    leaf (via ``prune_subtree``) on a deep copy. Each candidate is scored on
    (``X_test``, ``y_test``), and the most accurate tree seen is kept. Candidates
    are expanded breadth-first, so trees closer to the original are generated
    first. Trees that are structurally identical (the same pruning reached in a
    different order) are evaluated only once.

    Args:
        tree: fitted classifier exposing ``tree_`` and ``predict``; never modified.
        X_train, y_train: unused; kept for call compatibility.
        X_test, y_test: data used both to choose the pruning and to report its
            accuracy. NOTE(review): choosing on the same split you report on makes
            the returned accuracy optimistically biased. Prefer passing a held-out
            validation split here and scoring the result on a separate test set.
        max_itr: maximum number of distinct pruned candidates to evaluate.

    Returns:
        ``(best_accuracy, best_tree)``: the highest accuracy found, and a copy of
        the tree achieving it. If no candidate beats the unpruned tree (or
        ``max_itr <= 0``), this is the unpruned tree's accuracy and a deep copy of
        ``tree``. Ties keep the tree that was found first.
    """
    best_tree = deepcopy(tree)
    best_accuracy = accuracy_score(y_test, tree.predict(X_test))

    queue = deque([tree])
    # A tree's behaviour is fully determined by its set of reachable split nodes,
    # so that set identifies duplicates reached through different pruning orders.
    seen = {frozenset(_reachable_split_nodes(tree.tree_))}
    itr_count = 0
    while queue and itr_count < max_itr:
        current_tree = queue.popleft()
        current_splits = _reachable_split_nodes(current_tree.tree_)
        current_key = frozenset(current_splits)

        for node_index in current_splits:
            if itr_count >= max_itr:
                break
            # Pruning node_index removes it and every split node beneath it.
            removed = _reachable_split_nodes(current_tree.tree_, root=node_index)
            candidate_key = current_key - frozenset(removed)
            if candidate_key in seen:
                continue
            seen.add(candidate_key)

            candidate = deepcopy(current_tree)
            prune_subtree(candidate, node_index)
            pruned_accuracy = accuracy_score(y_test, candidate.predict(X_test))
            itr_count += 1

            # Compare as soon as it is scored, so candidates still queued when the
            # budget runs out are not silently discarded.
            if pruned_accuracy > best_accuracy:
                best_accuracy = pruned_accuracy
                # Queued trees are only ever deep-copied, never mutated, so
                # keeping a reference here is safe.
                best_tree = candidate
            queue.append(candidate)

    return best_accuracy, best_tree

def prune_subtree(tree, node_index):
    """Turn node ``node_index`` of a fitted tree into a leaf, in place.

    Both child pointers of the node are overwritten with -1, the value that marks
    "no child". The detached descendants remain in the node arrays but are no
    longer reachable from the root.

    NOTE(review): this relies on ``tree.tree_.children_left`` / ``children_right``
    being writable views into the fitted tree's own storage (true for current
    scikit-learn; re-check on upgrade). Returns None.
    """
    node_arrays = tree.tree_
    for child_pointers in (node_arrays.children_left, node_arrays.children_right):
        child_pointers[node_index] = -1


In [None]:
# Search budget for the pruning search: d*(d-1)/2 * 2**(d//2) candidate prunings
# for a tree of depth d. This is an ad-hoc heuristic; it evaluates to 0 for
# depth <= 1, i.e. no pruning candidates are tried for stumps.
# NOTE(review): dt_classifier must be a fitted DecisionTreeClassifier defined in an
# earlier cell (not visible here).
tree_depth = dt_classifier.tree_.max_depth
# d*(d-1) is always even, so floor division is exact and keeps the budget an int.
max_itr = (tree_depth * (tree_depth - 1) // 2) * 2 ** (tree_depth // 2)

In [None]:
# Run the pruning search on the fitted classifier and report the best accuracy found.
# NOTE(review): dt_classifier and the X/y splits come from cells not shown here.
best_accuracy, best_pruned_tree = prune_heuristcly(
    tree=dt_classifier,
    X_train=X_train,
    y_train=y_train,
    X_test=X_test,
    y_test=y_test,
    max_itr=max_itr,
)
summary_label = "The best retrieved accuracy from Reduced Error Pruning: "
print(summary_label, best_accuracy)