In [3]:
import pandas as pd
import numpy as np

from functions._gini import calculate_gini_impurity
from functions._best_threshold import find_optimal_threshold


In [4]:
data = np.array([4.9, 5.0, 5.5, 5.7, 6.0, 6.2, 6.5, 6.8])
labels = np.array(['A', 'A', 'B', 'A', 'B', 'B', 'B', 'A'])

# Create a DataFrame
df = pd.DataFrame({'feature': data, 'target': labels})
df

Unnamed: 0,feature,target
0,4.9,A
1,5.0,A
2,5.5,B
3,5.7,A
4,6.0,B
5,6.2,B
6,6.5,B
7,6.8,A


In [5]:
size_feature = df["feature"]

optimal_threshold = find_optimal_threshold(size_feature, df['target'])

print(f"Optimal Threshold : {optimal_threshold}")


Optimal Threshold : 5.0


In [6]:
from sklearn.datasets import load_iris

iris_data = load_iris()
iris_df = pd.DataFrame(data=iris_data.data, columns=iris_data.feature_names)
iris_df['target'] = iris_data.target

# Assuming you want to split based on the 'petal length (cm)' feature
feature_name = 'petal length (cm)'
size_feature = iris_df[feature_name]


In [31]:
from sklearn.model_selection import train_test_split

iris = iris_df.sample(frac=1, random_state=42)  # frac=1 means shuffling all rows

X_train, X_test, y_train, y_test = train_test_split(iris, iris[['target']], test_size=0.2, random_state=42)

In [8]:
def numeric_best_splits(data, target):
    print("################Splitting##################")
    lowest_impurity = calculate_gini_impurity(data, target)
    print(lowest_impurity)
    min_key, min_value = min(lowest_impurity.items(), key=lambda x: x[1])
    print(f"lowest impuirty feature is : {min_key}")
    
    threshold = find_optimal_threshold(data[min_key], data[target])
    print(f"{min_key} threshold is : {threshold}")
    
    left_mask = data[min_key] <= threshold
    right_mask = data[min_key] > threshold
    
    left_data = data[left_mask]
    right_data = data[right_mask]
    
    display(left_data)
    display(right_data)
    
    unique_left_values = left_data[target].unique()
    unique_right_values = right_data[target].unique()
    
    if len(unique_left_values) > 1:
        print("left node have impurites")
        numeric_best_splits(left_data, target)
        
    if len(unique_right_values) > 1:
        print("right node have impurites")
        numeric_best_splits(right_data, target)
        
numeric_best_splits(iris_df, 'target')

################Splitting##################
{'target': 0.6666666666666665, 'sepal length (cm)': 0.3194074074074074, 'sepal width (cm)': 0.4676749176749177, 'petal length (cm)': 0.06266666666666668, 'petal width (cm)': 0.06277777777777778}
lowest impuirty feature is : petal length (cm)
petal length (cm) threshold is : 1.9


Unnamed: 0,sepal length (cm),sepal width (cm),petal length (cm),petal width (cm),target
0,5.1,3.5,1.4,0.2,0
1,4.9,3.0,1.4,0.2,0
2,4.7,3.2,1.3,0.2,0
3,4.6,3.1,1.5,0.2,0
4,5.0,3.6,1.4,0.2,0
5,5.4,3.9,1.7,0.4,0
6,4.6,3.4,1.4,0.3,0
7,5.0,3.4,1.5,0.2,0
8,4.4,2.9,1.4,0.2,0
9,4.9,3.1,1.5,0.1,0


Unnamed: 0,sepal length (cm),sepal width (cm),petal length (cm),petal width (cm),target
50,7.0,3.2,4.7,1.4,1
51,6.4,3.2,4.5,1.5,1
52,6.9,3.1,4.9,1.5,1
53,5.5,2.3,4.0,1.3,1
54,6.5,2.8,4.6,1.5,1
...,...,...,...,...,...
145,6.7,3.0,5.2,2.3,2
146,6.3,2.5,5.0,1.9,2
147,6.5,3.0,5.2,2.0,2
148,6.2,3.4,5.4,2.3,2


right node have impurites
################Splitting##################
{'target': 0.5, 'sepal length (cm)': 0.3104047619047619, 'sepal width (cm)': 0.41757936507936505, 'petal length (cm)': 0.094, 'petal width (cm)': 0.09416666666666666}
lowest impuirty feature is : petal length (cm)
petal length (cm) threshold is : 4.7


Unnamed: 0,sepal length (cm),sepal width (cm),petal length (cm),petal width (cm),target
50,7.0,3.2,4.7,1.4,1
51,6.4,3.2,4.5,1.5,1
53,5.5,2.3,4.0,1.3,1
54,6.5,2.8,4.6,1.5,1
55,5.7,2.8,4.5,1.3,1
56,6.3,3.3,4.7,1.6,1
57,4.9,2.4,3.3,1.0,1
58,6.6,2.9,4.6,1.3,1
59,5.2,2.7,3.9,1.4,1
60,5.0,2.0,3.5,1.0,1


Unnamed: 0,sepal length (cm),sepal width (cm),petal length (cm),petal width (cm),target
52,6.9,3.1,4.9,1.5,1
70,5.9,3.2,4.8,1.8,1
72,6.3,2.5,4.9,1.5,1
76,6.8,2.8,4.8,1.4,1
77,6.7,3.0,5.0,1.7,1
83,6.0,2.7,5.1,1.6,1
100,6.3,3.3,6.0,2.5,2
101,5.8,2.7,5.1,1.9,2
102,7.1,3.0,5.9,2.1,2
103,6.3,2.9,5.6,1.8,2


left node have impurites
################Splitting##################
{'target': 0.04345679012345684, 'sepal length (cm)': 0.022222222222222223, 'sepal width (cm)': 0.03333333333333333, 'petal length (cm)': 0.03888888888888889, 'petal width (cm)': 0.0}
lowest impuirty feature is : petal width (cm)
petal width (cm) threshold is : 1.6


Unnamed: 0,sepal length (cm),sepal width (cm),petal length (cm),petal width (cm),target
50,7.0,3.2,4.7,1.4,1
51,6.4,3.2,4.5,1.5,1
53,5.5,2.3,4.0,1.3,1
54,6.5,2.8,4.6,1.5,1
55,5.7,2.8,4.5,1.3,1
56,6.3,3.3,4.7,1.6,1
57,4.9,2.4,3.3,1.0,1
58,6.6,2.9,4.6,1.3,1
59,5.2,2.7,3.9,1.4,1
60,5.0,2.0,3.5,1.0,1


Unnamed: 0,sepal length (cm),sepal width (cm),petal length (cm),petal width (cm),target
106,4.9,2.5,4.5,1.7,2


right node have impurites
################Splitting##################
{'target': 0.19438016528925628, 'sepal length (cm)': 0.1554112554112554, 'sepal width (cm)': 0.18164724164724158, 'petal length (cm)': 0.13909090909090907, 'petal width (cm)': 0.1060606060606061}
lowest impuirty feature is : petal width (cm)
petal width (cm) threshold is : 1.7


Unnamed: 0,sepal length (cm),sepal width (cm),petal length (cm),petal width (cm),target
52,6.9,3.1,4.9,1.5,1
72,6.3,2.5,4.9,1.5,1
76,6.8,2.8,4.8,1.4,1
77,6.7,3.0,5.0,1.7,1
83,6.0,2.7,5.1,1.6,1
119,6.0,2.2,5.0,1.5,2
129,7.2,3.0,5.8,1.6,2
133,6.3,2.8,5.1,1.5,2
134,6.1,2.6,5.6,1.4,2


Unnamed: 0,sepal length (cm),sepal width (cm),petal length (cm),petal width (cm),target
70,5.9,3.2,4.8,1.8,1
100,6.3,3.3,6.0,2.5,2
101,5.8,2.7,5.1,1.9,2
102,7.1,3.0,5.9,2.1,2
103,6.3,2.9,5.6,1.8,2
104,6.5,3.0,5.8,2.2,2
105,7.6,3.0,6.6,2.1,2
107,7.3,2.9,6.3,1.8,2
108,6.7,2.5,5.8,1.8,2
109,7.2,3.6,6.1,2.5,2


left node have impurites
################Splitting##################
{'target': 0.49382716049382713, 'sepal length (cm)': 0.2222222222222222, 'sepal width (cm)': 0.2222222222222222, 'petal length (cm)': 0.2222222222222222, 'petal width (cm)': 0.4444444444444444}
lowest impuirty feature is : sepal length (cm)
sepal length (cm) threshold is : 6.9


Unnamed: 0,sepal length (cm),sepal width (cm),petal length (cm),petal width (cm),target
52,6.9,3.1,4.9,1.5,1
72,6.3,2.5,4.9,1.5,1
76,6.8,2.8,4.8,1.4,1
77,6.7,3.0,5.0,1.7,1
83,6.0,2.7,5.1,1.6,1
119,6.0,2.2,5.0,1.5,2
133,6.3,2.8,5.1,1.5,2
134,6.1,2.6,5.6,1.4,2


Unnamed: 0,sepal length (cm),sepal width (cm),petal length (cm),petal width (cm),target
129,7.2,3.0,5.8,1.6,2


left node have impurites
################Splitting##################
{'target': 0.46875, 'sepal length (cm)': 0.25, 'sepal width (cm)': 0.125, 'petal length (cm)': 0.25, 'petal width (cm)': 0.375}
lowest impuirty feature is : sepal width (cm)
sepal width (cm) threshold is : 2.2


Unnamed: 0,sepal length (cm),sepal width (cm),petal length (cm),petal width (cm),target
119,6.0,2.2,5.0,1.5,2


Unnamed: 0,sepal length (cm),sepal width (cm),petal length (cm),petal width (cm),target
52,6.9,3.1,4.9,1.5,1
72,6.3,2.5,4.9,1.5,1
76,6.8,2.8,4.8,1.4,1
77,6.7,3.0,5.0,1.7,1
83,6.0,2.7,5.1,1.6,1
133,6.3,2.8,5.1,1.5,2
134,6.1,2.6,5.6,1.4,2


right node have impurites
################Splitting##################
{'target': 0.40816326530612246, 'sepal length (cm)': 0.14285714285714285, 'sepal width (cm)': 0.14285714285714285, 'petal length (cm)': 0.14285714285714285, 'petal width (cm)': 0.3333333333333333}
lowest impuirty feature is : sepal length (cm)
sepal length (cm) threshold is : 6.3


Unnamed: 0,sepal length (cm),sepal width (cm),petal length (cm),petal width (cm),target
72,6.3,2.5,4.9,1.5,1
83,6.0,2.7,5.1,1.6,1
133,6.3,2.8,5.1,1.5,2
134,6.1,2.6,5.6,1.4,2


Unnamed: 0,sepal length (cm),sepal width (cm),petal length (cm),petal width (cm),target
52,6.9,3.1,4.9,1.5,1
76,6.8,2.8,4.8,1.4,1
77,6.7,3.0,5.0,1.7,1


left node have impurites
################Splitting##################
{'target': 0.5, 'sepal length (cm)': 0.25, 'sepal width (cm)': 0.0, 'petal length (cm)': 0.25, 'petal width (cm)': 0.25}
lowest impuirty feature is : sepal width (cm)
sepal width (cm) threshold is : 2.5


Unnamed: 0,sepal length (cm),sepal width (cm),petal length (cm),petal width (cm),target
72,6.3,2.5,4.9,1.5,1


Unnamed: 0,sepal length (cm),sepal width (cm),petal length (cm),petal width (cm),target
83,6.0,2.7,5.1,1.6,1
133,6.3,2.8,5.1,1.5,2
134,6.1,2.6,5.6,1.4,2


right node have impurites
################Splitting##################
{'target': 0.4444444444444444, 'sepal length (cm)': 0.0, 'sepal width (cm)': 0.0, 'petal length (cm)': 0.3333333333333333, 'petal width (cm)': 0.0}
lowest impuirty feature is : sepal length (cm)
sepal length (cm) threshold is : 6.0


Unnamed: 0,sepal length (cm),sepal width (cm),petal length (cm),petal width (cm),target
83,6.0,2.7,5.1,1.6,1


Unnamed: 0,sepal length (cm),sepal width (cm),petal length (cm),petal width (cm),target
133,6.3,2.8,5.1,1.5,2
134,6.1,2.6,5.6,1.4,2


right node have impurites
################Splitting##################
{'target': 0.04253308128544431, 'sepal length (cm)': 0.021739130434782608, 'sepal width (cm)': 0.036231884057971, 'petal length (cm)': 0.02898550724637681, 'petal width (cm)': 0.03985507246376815}
lowest impuirty feature is : sepal length (cm)
sepal length (cm) threshold is : 5.9


Unnamed: 0,sepal length (cm),sepal width (cm),petal length (cm),petal width (cm),target
70,5.9,3.2,4.8,1.8,1
101,5.8,2.7,5.1,1.9,2
113,5.7,2.5,5.0,2.0,2
114,5.8,2.8,5.1,2.4,2
121,5.6,2.8,4.9,2.0,2
142,5.8,2.7,5.1,1.9,2
149,5.9,3.0,5.1,1.8,2


Unnamed: 0,sepal length (cm),sepal width (cm),petal length (cm),petal width (cm),target
100,6.3,3.3,6.0,2.5,2
102,7.1,3.0,5.9,2.1,2
103,6.3,2.9,5.6,1.8,2
104,6.5,3.0,5.8,2.2,2
105,7.6,3.0,6.6,2.1,2
107,7.3,2.9,6.3,1.8,2
108,6.7,2.5,5.8,1.8,2
109,7.2,3.6,6.1,2.5,2
110,6.5,3.2,5.1,2.0,2
111,6.4,2.7,5.3,1.9,2


left node have impurites
################Splitting##################
{'target': 0.24489795918367352, 'sepal length (cm)': 0.14285714285714285, 'sepal width (cm)': 0.0, 'petal length (cm)': 0.0, 'petal width (cm)': 0.14285714285714285}
lowest impuirty feature is : sepal width (cm)
sepal width (cm) threshold is : 3.0


Unnamed: 0,sepal length (cm),sepal width (cm),petal length (cm),petal width (cm),target
101,5.8,2.7,5.1,1.9,2
113,5.7,2.5,5.0,2.0,2
114,5.8,2.8,5.1,2.4,2
121,5.6,2.8,4.9,2.0,2
142,5.8,2.7,5.1,1.9,2
149,5.9,3.0,5.1,1.8,2


Unnamed: 0,sepal length (cm),sepal width (cm),petal length (cm),petal width (cm),target
70,5.9,3.2,4.8,1.8,1


In [57]:
def numeric_best_splits(data, target):
    tree = {}  # Initialize the tree dictionary
    
    lowest_impurity = calculate_gini_impurity(data, target)
    
    min_key, min_value = min(lowest_impurity.items(), key=lambda x: x[1])
    tree['type'] = 'node'  # Node type
    tree['feature'] = min_key  # Feature used for splitting
    
    threshold = find_optimal_threshold(data[min_key], data[target])
    tree['threshold'] = threshold  # Threshold value for splitting
    
    left_mask = data[min_key] <= threshold
    right_mask = data[min_key] > threshold
    
    left_data = data[left_mask]
    right_data = data[right_mask]
    
    unique_left_values = left_data[target].unique()
    unique_right_values = right_data[target].unique()
    
    tree['subtrees'] = {}  # Subtree dictionary
    
    if len(unique_left_values) > 1:
        tree['subtrees'][unique_left_values[0]] = numeric_best_splits(left_data, target)  # Recursively build the left subtree
        
    else:
        tree['subtrees'][unique_left_values[0]] = {'type': 'leaf', 'prediction': unique_left_values[0]}  # Leaf node
    
    if len(unique_right_values) > 1:
        tree['subtrees'][unique_right_values[0]] = numeric_best_splits(right_data, target)  # Recursively build the right subtree
        
    else:
        tree['subtrees'][unique_right_values[0]] = {'type': 'leaf', 'prediction': unique_right_values[0]}  # Leaf node
        
    return tree  # Return the constructed tree dictionary

# Call the function and store the resulting tree
tree = numeric_best_splits(df, 'target')
tree

{'type': 'node',
 'feature': 'feature',
 'threshold': 5.0,
 'subtrees': {'A': {'type': 'leaf', 'prediction': 'A'},
  'B': {'type': 'node',
   'feature': 'feature',
   'threshold': 6.5,
   'subtrees': {'B': {'type': 'node',
     'feature': 'feature',
     'threshold': 5.7,
     'subtrees': {'B': {'type': 'leaf', 'prediction': 'B'}}},
    'A': {'type': 'leaf', 'prediction': 'A'}}}}}

In [44]:
// Decision Tree
digraph {
	Root [label="petal length (cm) <= 1.9"]
	Root_0 [label="Prediction: 0"]
	Root -> Root_0 [label=0]
	Root_1 [label="petal length (cm) <= 4.7"]
	Root -> Root_1 [label=1]
	Root_1_1 [label="petal width (cm) <= 1.7"]
	Root_1 -> Root_1_1 [label=1]
	Root_1_1_1 [label="sepal length (cm) <= 5.9"]
	Root_1_1 -> Root_1_1_1 [label=1]
	Root_1_1_1_1 [label="sepal width (cm) <= 3.0"]
	Root_1_1_1 -> Root_1_1_1_1 [label=1]
	Root_1_1_1_2 [label="Prediction: 2"]
	Root_1_1_1 -> Root_1_1_1_2 [label=2]
	Root_1_1_2 [label="Prediction: 2"]
	Root_1_1 -> Root_1_1_2 [label=1]
	Root_1_2 [label="Prediction: 1"]
	Root_1 -> Root_1_2 [label=2]
}


SyntaxError: invalid syntax (2784727860.py, line 1)

In [56]:
def numeric_best_splits(data, target):
    tree = {} 
    
    lowest_impurity = calculate_gini_impurity(data, target)
    min_key, min_value = min(lowest_impurity.items(), key=lambda x: x[1])
    
    tree['type'] = 'node'  
    tree['feature'] = min_key 
    
    threshold = find_optimal_threshold(data[min_key], data[target])
    tree['threshold'] = threshold 
    tree['subtrees'] = {} 
    
    for direction in ['left', 'right']:
        mask = data[min_key] <= threshold if direction == 'left' else data[min_key] > threshold
        subset_data = data[mask]
        
        unique_values = subset_data[target].unique()
        
        if len(unique_values) > 1:
            tree['subtrees'][direction] = numeric_best_splits(subset_data, target) 
        else:
            tree['subtrees'][direction] = {'type': 'leaf', 'prediction': unique_values[0]} 
        
    return tree

tree = numeric_best_splits(X_train, 'target')
tree

{'type': 'node',
 'feature': 'petal width (cm)',
 'threshold': 0.6,
 'subtrees': {'left': {'type': 'leaf', 'prediction': 0},
  'right': {'type': 'node',
   'feature': 'petal width (cm)',
   'threshold': 1.7,
   'subtrees': {'left': {'type': 'node',
     'feature': 'petal length (cm)',
     'threshold': 4.9,
     'subtrees': {'left': {'type': 'node',
       'feature': 'petal width (cm)',
       'threshold': 1.6,
       'subtrees': {'left': {'type': 'leaf', 'prediction': 1},
        'right': {'type': 'leaf', 'prediction': 2}}},
      'right': {'type': 'node',
       'feature': 'sepal width (cm)',
       'threshold': 2.6,
       'subtrees': {'left': {'type': 'leaf', 'prediction': 2},
        'right': {'type': 'node',
         'feature': 'sepal length (cm)',
         'threshold': 6.3,
         'subtrees': {'left': {'type': 'node',
           'feature': 'sepal length (cm)',
           'threshold': 6.0,
           'subtrees': {'left': {'type': 'leaf', 'prediction': 1},
            'right': {

In [47]:
# Define your prediction function
def predict_sample(sample, tree):
    if tree['type'] == 'leaf':
        return tree['prediction']
    
    feature = tree['feature']
    threshold = tree['threshold']
    
    if sample[feature] <= threshold:
        subtree_key = 'left'
    else:
        subtree_key = 'right'
    
    next_subtree = tree['subtrees'][subtree_key]
    return predict_sample(sample, next_subtree)

# Define a function to make predictions for a DataFrame
def predict_dataframe(data, tree):
    predictions = []
    for index, row in data.iterrows():
        sample = row.to_dict()
        prediction = predict_sample(sample, tree)
        predictions.append(prediction)
    return predictions

# Define a function to calculate accuracy
def calculate_accuracy(predictions, true_labels):
    correct_predictions = sum(pred == true_label for pred, true_label in zip(predictions, true_labels))
    accuracy = correct_predictions / len(predictions) * 100
    return accuracy

# Make predictions on X_test
test_predictions = predict_dataframe(X_test, tree)

# Calculate accuracy on X_test
test_accuracy = calculate_accuracy(test_predictions, y_test['target'])

# Create a DataFrame of predictions
predictions_df = pd.DataFrame({'Prediction': test_predictions}, index=X_test.index)

# Concatenate predictions_df with y_test
result_df = pd.concat([y_test, predictions_df], axis=1)

# Display the result DataFrame and accuracy
print(result_df)
print("Accuracy on X_test:", test_accuracy, "%")


     target  Prediction
101       2           2
55        1           1
79        1           1
5         0           0
148       2           2
15        0           0
94        1           1
74        1           1
47        0           0
35        0           0
54        1           1
36        0           0
51        1           1
82        1           1
132       2           2
28        0           0
136       2           2
97        1           1
67        1           1
37        0           0
30        0           0
124       2           2
108       2           2
21        0           0
129       2           1
71        1           1
2         0           0
149       2           1
81        1           1
22        0           0
Accuracy on X_test: 93.33333333333333 %


In [92]:
from graphviz import Digraph

def visualize_tree(tree_dict, parent=None, parent_label=None):
    dot = Digraph(comment='Decision Tree')
    
    if parent is None:
        if tree_dict['type'] == 'node':
            dot.node('Root', f"{tree_dict['feature']} <= {tree_dict['threshold']}")
        else:
            dot.node('Root', f"Prediction: {tree_dict['prediction']}")
        parent = 'Root'
        parent_label = ''
    else:
        dot.node(parent, parent_label)
    
    if 'subtrees' in tree_dict:  # Check if 'subtrees' key is present
        for key, subtree in tree_dict['subtrees'].items():
            if subtree['type'] == 'leaf':
                label = f"Prediction: {subtree['prediction']}"
            else:
                label = f"{subtree['feature']} <= {subtree['threshold']}"
            
            dot.node(f"{parent}_{key}", label)
            dot.edge(parent, f"{parent}_{key}", label=str(key))
            
            # Expand the visualization for each split recursively
            visualize_tree(subtree, parent=f"{parent}_{key}", parent_label=label)
    
    return dot
# Call the function to visualize the tree
dot = visualize_tree(tree)

# Save the visualization to a file (in PDF format)
dot.format = 'pdf'
dot.render("expanded_decision_tree_visualization")


'expanded_decision_tree_visualization.pdf'