In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

In [None]:
# Load the data set from a CSV file in the current working directory.
# NOTE(review): the helpers below default to a target column named 'class';
# confirm that iris.csv actually carries that header.
iris = pd.read_csv("iris.csv")

In [None]:
def gini_impurity(df,targ='class'):
    """Return the Gini impurity of column `targ` in `df`.

    The impurity is the sum, over every distinct label, of p * (1 - p) where
    p is the fraction of rows carrying that label.

    Parameters
    ----------
    df : pandas.DataFrame
        Rows to measure.
    targ : str
        Name of the label column.

    Returns
    -------
    0 for an empty frame, otherwise the impurity as a float.
    """
    n_rows = len(df)
    if n_rows == 0:
        return 0
    labels = df[targ]
    total = 0.0
    # Labels are visited in first-appearance order, one term per label.
    for label in labels.unique():
        share = (labels == label).sum() / n_rows
        total += share * (1 - share)
    return total

In [None]:
def split_gini(df,attrib,split,targ='class'):
    """Return the weighted Gini impurity of splitting `df` on `attrib < split`.

    The frame is divided into the rows where ``df[attrib] < split`` and the
    rows where ``df[attrib] >= split``; the impurity of each side is weighted
    by ``ln`` (the fraction of rows on the left) and ``1 - ln`` respectively.

    Parameters
    ----------
    df : pandas.DataFrame
        Rows to split.
    attrib : str
        Column whose values are compared against `split`.
    split : scalar
        Threshold value.
    targ : str
        Name of the label column.

    Returns
    -------
    float
        Weighted impurity; 0.0 when `df` has no rows.
    """
    n_rows = len(df)
    if n_rows == 0:
        # Avoid a ZeroDivisionError: an empty frame has nothing to split.
        return 0.0
    # Build each mask once and reuse it for both the weight and the subset.
    below = df[attrib] < split
    # Kept as an explicit >= test (not ~below) so missing values, which fail
    # both comparisons, stay out of both subsets exactly as before.
    at_or_above = df[attrib] >= split
    # Vectorized count instead of the builtin sum() looping over the Series.
    ln = below.sum()/n_rows
    gini = gini_impurity(df[below],targ)
    gini_inv = gini_impurity(df[at_or_above],targ)
    return ln*gini+(1-ln)*gini_inv

In [None]:
def best_split(df,f,targ='class'):
    """Find the threshold on column `f` with the lowest weighted Gini impurity.

    Every distinct, non-missing value present in ``df[f]`` is tried as a
    threshold via `split_gini`.

    Parameters
    ----------
    df : pandas.DataFrame
        Rows to split.
    f : str
        Feature column to search.
    targ : str
        Name of the label column.

    Returns
    -------
    tuple
        ``(best_sp, best_gini)``. If no candidate scores below 1.0 (for
        example when the column has no usable values) this is ``(0, 1.0)``.
    """
    best_gini = 1.0
    best_sp = 0
    # dropna(): a NaN threshold fails both `<` and `>=`, leaving two empty
    # subsets whose impurity is 0, so it would masquerade as a perfect split.
    # drop_duplicates(): re-scoring a repeated value cannot change the result;
    # first-appearance order is kept so ties resolve to the same candidate.
    for sp in df[f].dropna().drop_duplicates():
        gini = split_gini(df,f,sp,targ)
        if gini < best_gini:
            best_gini = gini
            best_sp = sp
    return best_sp,best_gini

In [None]:
class Node:
    """Base class for tree nodes; holds only a link to the parent node."""

    def __init__(self, parent=None):
        # Node above this one in the tree, or None when there is none.
        self.parent = parent

    def type(self):
        """Return the string tag for this kind of node ('node' for the base)."""
        kind = 'node'
        return kind

In [None]:
class SplitNode(Node):
    """Internal node describing a test of one attribute against a value.

    Attributes are stored as given: `attrib` (column name), `val` (threshold),
    and the `left` / `right` child nodes.
    """

    def __init__(self, attrib, val, parent=None, left=None, right=None):
        super().__init__(parent)
        self.attrib, self.val = attrib, val
        self.left, self.right = left, right

    def type(self):
        """Return the string tag for this kind of node."""
        return 'split'

In [None]:
class LeafNode(Node):
    """Terminal node that stores a single class label in `cls`."""

    def __init__(self, cls, parent=None):
        super().__init__(parent)
        # Label associated with this leaf.
        self.cls = cls

    def type(self):
        """Return the string tag for this kind of node."""
        return 'leaf'

In [None]:
def best_split_all(df,targ='class'):
    """Pick the best (feature, threshold) split over every feature column.

    Every column except `targ` is searched with `best_split`; the pair with
    the lowest weighted Gini impurity wins (the first one found on ties).

    Parameters
    ----------
    df : pandas.DataFrame
        Rows to split.
    targ : str
        Name of the label column; it is excluded from the candidate features.

    Returns
    -------
    tuple
        ``(node, left_df, right_df, best_gini)`` where `node` is a fresh
        `SplitNode` for the chosen feature and threshold, `left_df` holds the
        rows with ``feature < threshold`` and `right_df` the rows with
        ``feature >= threshold``.
    """
    # Exclude the column named by `targ` rather than a hard-coded 'class', so
    # a custom target is neither searched as a feature nor triggers a KeyError.
    features = df.columns.drop(targ)
    best_gini = 1.0
    # Fall back to a real feature column when one exists.
    best_feat = features[0] if len(features) else df.columns[0]
    best_sp = 0
    for f in features:
        bs,bg = best_split(df,f,targ)
        if bg < best_gini:
            best_gini = bg
            best_sp = bs
            best_feat = f
    # Evaluate each comparison once and reuse it for the partition.
    below = df[best_feat] < best_sp
    at_or_above = df[best_feat] >= best_sp
    return SplitNode(best_feat,best_sp),df[below],df[at_or_above],best_gini

In [None]:
def build_tree(df,parent=None,targ='class'):
    """Recursively build a decision tree from `df`.

    Parameters
    ----------
    df : pandas.DataFrame
        Training rows, including the label column.
    parent : Node or None
        Node the new subtree hangs from.
    targ : str
        Name of the label column.

    Returns
    -------
    Node
        A bare `Node` for an empty frame; a `LeafNode` carrying the most
        frequent label when the frame has only one column, only one distinct
        label, or cannot be divided any further; otherwise a `SplitNode`
        whose children are built from the two partitions.
    """
    if len(df) == 0:
        return Node(parent)
    if len(df.columns) == 1 or len(df[targ].unique()) == 1:
        return LeafNode(df[targ].value_counts().idxmax(),parent)
    root,ldf,rdf,_ = best_split_all(df,targ)
    if len(ldf) == len(df) or len(rdf) == len(df):
        # The best split leaves every row on one side (e.g. identical feature
        # values with conflicting labels). Recursing on the same rows would
        # never terminate, so settle for the majority label instead.
        return LeafNode(df[targ].value_counts().idxmax(),parent)
    root.parent = parent
    root.left = build_tree(ldf,root,targ)
    root.right = build_tree(rdf,root,targ)
    return root

In [None]:
# Fit the tree on the whole data set, using the default 'class' target column.
tree = build_tree(iris)

In [None]:
def print_tree(tree,level=0):
    """Print `tree` to stdout, one tab of indentation per depth `level`.

    A node tagged 'node' prints "Empty", a 'leaf' prints its class label, and
    any other node prints its attribute and value followed by its left and
    right subtrees one level deeper.
    """
    indent = '\t'*level
    print(indent,end="")
    kind = tree.type()
    if kind == 'node':
        print("Empty")
        return
    if kind == 'leaf':
        print("leaf:",tree.cls)
        return
    print("split",tree.attrib,"with",tree.val)
    # Children are looked up lazily, left first, just before they are printed.
    for label, side in (("left:", "left"), ("right:", "right")):
        print(indent + label)
        print_tree(getattr(tree, side),level+1)

In [None]:
%%capture ptree
# The cell magic above must stay on the first line of the cell. It stores
# everything this cell prints in `ptree` instead of displaying it.
print_tree(tree)

In [None]:
# Save the captured text rendering of the tree to a file in the working
# directory; an existing file of the same name is overwritten ("w" mode).
with open("cart_undrop.txt","w") as t:
    t.write(ptree.stdout)

In [None]:
# Score the tree against the same rows it was built from, counting correct
# and incorrect predictions and reporting the index of every miss.
right = wrong = 0
for i, row in iris.iterrows():
    node = tree
    # Descend: go left when the row's value is below the node's threshold,
    # right otherwise, until something other than a split node is reached.
    while node.type() == 'split':
        node = node.left if row[node.attrib] < node.val else node.right
    if node.type() != 'leaf':
        # Ended on a node that is not a leaf: counted as neither right nor wrong.
        continue
    if node.cls == row['class']:
        right += 1
    else:
        wrong += 1
        print("Wrong at",i)

In [None]:
# Report the tallies from the evaluation loop: correct count, then incorrect count.
print(right,wrong)