In [1]:
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from HelperMethods import *

In [2]:
this = %pwd

In [3]:
PATH = (f'{this}/').replace('NoteBook/','')+'Data/'

In [4]:
column_names = ["sex", "length", "diameter", "height", "whole weight", 
                "shucked weight", "viscera weight", "shell weight", "rings"]
df = pd.read_csv(PATH + "abalone.data", names=column_names)

In [5]:
df.shape

(4177, 9)

In [6]:
df.head(50)

Unnamed: 0,sex,length,diameter,height,whole weight,shucked weight,viscera weight,shell weight,rings
0,M,0.455,0.365,0.095,0.514,0.2245,0.101,0.15,15
1,M,0.35,0.265,0.09,0.2255,0.0995,0.0485,0.07,7
2,F,0.53,0.42,0.135,0.677,0.2565,0.1415,0.21,9
3,M,0.44,0.365,0.125,0.516,0.2155,0.114,0.155,10
4,I,0.33,0.255,0.08,0.205,0.0895,0.0395,0.055,7
5,I,0.425,0.3,0.095,0.3515,0.141,0.0775,0.12,8
6,F,0.53,0.415,0.15,0.7775,0.237,0.1415,0.33,20
7,F,0.545,0.425,0.125,0.768,0.294,0.1495,0.26,16
8,M,0.475,0.37,0.125,0.5095,0.2165,0.1125,0.165,9
9,F,0.55,0.44,0.15,0.8945,0.3145,0.151,0.32,19


In [7]:
X = df.drop('rings', axis='columns')
y = df['rings']

In [8]:
d = {'M': 1, 'F': 2, 'I': 3}
X['sex'].replace(d,inplace = True)

X.head(10)

Unnamed: 0,sex,length,diameter,height,whole weight,shucked weight,viscera weight,shell weight
0,1,0.455,0.365,0.095,0.514,0.2245,0.101,0.15
1,1,0.35,0.265,0.09,0.2255,0.0995,0.0485,0.07
2,2,0.53,0.42,0.135,0.677,0.2565,0.1415,0.21
3,1,0.44,0.365,0.125,0.516,0.2155,0.114,0.155
4,3,0.33,0.255,0.08,0.205,0.0895,0.0395,0.055
5,3,0.425,0.3,0.095,0.3515,0.141,0.0775,0.12
6,2,0.53,0.415,0.15,0.7775,0.237,0.1415,0.33
7,2,0.545,0.425,0.125,0.768,0.294,0.1495,0.26
8,1,0.475,0.37,0.125,0.5095,0.2165,0.1125,0.165
9,2,0.55,0.44,0.15,0.8945,0.3145,0.151,0.32


In [9]:
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25, random_state=42)
print(X_test.values[:, 0])

[1. 1. 2. ... 1. 1. 2.]


In [10]:
def value_count(x, threshold):
    result = {}
    result[0] = 0
    result[1] = 0
    for value in x:
        if value <= threshold:
            result[0] += 1
        else:
            result[1]+=1
    return result

def target_value_count(y):
    result = {}
    for value in y:
        if value not in result:
            result[value] = 1
        else:
            result[value]+=1
    return result
    
    
def entropy(y, val_type, threshold = None):
    if val_type == "target":
        result = target_value_count(y)
        entropy = 0
        for value in result.values():
            p = value/len(y)
            entropy -= p*log2(p)
        return entropy
    else:
        result= value_count(y, threshold)
        entropy = 0
        for idx in range(2):
            p = result[idx]/len(y)
            entropy -= p*log2(p)
        return entropy
    

In [11]:
ja = X_train.index
print(ja.values)

[3823 3956 3623 ... 3092 3772  860]


# text

In [12]:
def information_gain(attribute, target, method):
    target_entropy = entropy(target, "target")
    #print(target_entropy)
    info_gains = []
    for attr in attribute:
        threshold = np.mean(attribute[attr])
        le_idx = np.where(attribute[attr] <= threshold)
        g_idx = np.where(attribute[attr] > threshold)
        y_le = target.values[le_idx]
        y_g = target.values[g_idx]
        # count hvor mange av hver ring som er i <= og >
        # ta count/len * log2 len/count
        le_entropy = entropy(y_le, "target")
        g_entropy = entropy(y_g, "target")
        
        
        attr_entropy = (len(y_le)/len(attribute[attr]))*le_entropy + (len(y_g)/len(attribute[attr]))*g_entropy
        information_gain = target_entropy - attr_entropy
        info_gains.append(information_gain)
    return np.argmax(info_gains)    

In [41]:
def make_tree(X, y, n, impurity_measure):
    if len(X) == 0:
        return
    else:
        x_copy = X.copy()
        y_copy = y.copy()
        
        top_ig = information_gain(x_copy, y_copy, "entropy")
        threshold = np.mean(x_copy[x_copy.columns[top_ig]])
        le_idx = np.where(x_copy[x_copy.columns[top_ig]] <= threshold)
        g_idx = np.where(x_copy[x_copy.columns[top_ig]] > threshold)
    
        n.category = top_ig
        n.data = threshold
    
        left_child = mnode()
        right_child = mnode()
        left_child.parent = n
        right_child.parent = n
        
        left_child.data = le_idx
        right_child.data = g_idx
        
        #node.children[0] = left_child
        #node.children[1] = right_child
        
        n.add_child(1, le_idx, left_child)
        n.add_child(2, g_idx, right_child)
        for child in n.children:
            X_copy = pd.DataFrame(x_copy.values[child.data])
            Y_copy = pd.Series(y_copy.values[child.data])
            
            if(len(X_copy) == 1):
                child.isLeaf = True
                child.data = Y_copy[X_copy.index]
            elif len(np.unique(Y_copy.values)) == 1:
                child.isLeaf = True
                child.data = Y_copy.sample(n = 1)
            elif len(target_value_count(X_copy)) == 1:
                child.isleaf = True
                child.data = Y_copy.value_counts().argmax()
            else:
                learn(X_copy, Y_copy, child, "entropy") 

In [25]:
class mnode(object):
    
    def __init__(self):
        self.data = None
        self.parent = None
        self.children = []
        self.category = None
        self.isLeaf = False
        self.category = None
    
    def add_child(self, name, threshold, child):
        child.data = threshold
        self.children.append(child)
    

In [26]:
n = mnode()
learn(X_train, y_train, n, "entropy")
print(n.data)

0.24035089399744528


In [27]:
def printer(n):
    print(n.data)
    for child in n.children:
        if child.children != None:
            print(child.data)
            print(child.category)
            printer(child)        

In [28]:
printer(n)

0.24035089399744528
0.13220343137254903
7
0.13220343137254903
0.2765394242803504
2
0.2765394242803504
0.22159663865546197
2
0.22159663865546197
0.024089403973509907
7
0.024089403973509907
0.018020833333333323
5
0.018020833333333323
0.008675675675675681
7
0.008675675675675681
0.007176470588235295
5
0.007176470588235295
0.10222222222222223
2
0.10222222222222223
0.11833333333333335
1
0.11833333333333335
0    1
dtype: int64
None
0    1
dtype: int64
0.14
1
0.14
0    3
dtype: int64
None
0    3
dtype: int64
0    2
dtype: int64
None
0    2
dtype: int64
0.005916666666666666
5
0.005916666666666666
0    4
dtype: int64
None
0    4
dtype: int64
2.3333333333333335
0
2.3333333333333335
0    3
dtype: int64
None
0    3
dtype: int64
0.11499999999999999
2
0.11499999999999999
0    4
dtype: int64
None
0    4
dtype: int64
0    3
dtype: int64
None
0    3
dtype: int64
0.0065625
7
0.0065625
2    4
dtype: int64
None
2    4
dtype: int64
0.12375
2
0.12375
0    5
dtype: int64
None
0    5
dtype: int64
0.04500000000

dtype: int64
0    5
dtype: int64
None
0    5
dtype: int64
0.07333333333333335
3
0.07333333333333335
0    7
dtype: int64
None
0    7
dtype: int64
0    5
dtype: int64
None
0    5
dtype: int64
0.15292857142857144
4
0.15292857142857144
0.3075
1
0.3075
0.2975
1
0.2975
0    9
dtype: int64
None
0    9
dtype: int64
0    6
dtype: int64
None
0    6
dtype: int64
1    6
dtype: int64
None
1    6
dtype: int64
0.32499999999999996
1
0.32499999999999996
0.3125
1
0.3125
0    7
dtype: int64
None
0    7
dtype: int64
0    13
dtype: int64
None
0    13
dtype: int64
0    6
dtype: int64
None
0    6
dtype: int64
0.08131578947368422
3
0.08131578947368422
0.04195
6
0.04195
0.07549999999999998
3
0.07549999999999998
0.074375
3
0.074375
0    7
dtype: int64
None
0    7
dtype: int64
0.17614285714285713
4
0.17614285714285713
0.3333333333333333
1
0.3333333333333333
0.325
1
0.325
0    7
dtype: int64
None
0    7
dtype: int64
0    6
dtype: int64
None
0    6
dtype: int64
0    6
dtype: int64
None
0    6
dtype: int64
0.18175


0    12
dtype: int64
None
0    12
dtype: int64
0    7
dtype: int64
None
0    7
dtype: int64
0    12
dtype: int64
None
0    12
dtype: int64
1    10
dtype: int64
None
1    10
dtype: int64
0.14970000000000003
5
0.14970000000000003
0.10977777777777778
7
0.10977777777777778
0.305
2
0.305
1    8
dtype: int64
None
1    8
dtype: int64
0    13
dtype: int64
None
0    13
dtype: int64
0.35033333333333333
4
0.35033333333333333
0.41333333333333333
1
0.41333333333333333
0    10
dtype: int64
None
0    10
dtype: int64
0.42
1
0.42
0    9
dtype: int64
None
0    9
dtype: int64
0    10
dtype: int64
None
0    10
dtype: int64
0.42333333333333334
1
0.42333333333333334
0    12
dtype: int64
None
0    12
dtype: int64
0.4325
1
0.4325
0    11
dtype: int64
None
0    11
dtype: int64
0    8
dtype: int64
None
0    8
dtype: int64
0.07491666666666667
6
0.07491666666666667
0    9
dtype: int64
None
0    9
dtype: int64
0.4275
1
0.4275
0.4125
1
0.4125
0    10
dtype: int64
None
0    10
dtype: int64
0    12
dtype: int64
None


dtype: int64
None
0    6
dtype: int64
0    8
dtype: int64
None
0    8
dtype: int64
0.43374999999999997
1
0.43374999999999997
1    6
dtype: int64
None
1    6
dtype: int64
0.4425
1
0.4425
0    6
dtype: int64
None
0    6
dtype: int64
0    7
dtype: int64
None
0    7
dtype: int64
0.11423701298701301
7
0.11423701298701301
0.3791184210526317
4
0.3791184210526317
0.34685555555555564
4
0.34685555555555564
0.1015625
3
0.1015625
0.106375
7
0.106375
0.1393
5
0.1393
0.0645
6
0.0645
1    8
dtype: int64
None
1    8
dtype: int64
0    6
dtype: int64
None
0    6
dtype: int64
0.425
1
0.425
0    10
dtype: int64
None
0    10
dtype: int64
0    9
dtype: int64
None
0    9
dtype: int64
0.4166666666666667
1
0.4166666666666667
0.4075
1
0.4075
0    6
dtype: int64
None
0    6
dtype: int64
0    7
dtype: int64
None
0    7
dtype: int64
0    7
dtype: int64
None
0    7
dtype: int64
0.064875
6
0.064875
0.3225
2
0.3225
0    10
dtype: int64
None
0    10
dtype: int64
0.05466666666666667
6
0.05466666666666667
0    7
dtype: 

dtype: int64
0    13
dtype: int64
None
0    13
dtype: int64
1    7
dtype: int64
None
1    7
dtype: int64
0.12
3
0.12
0.1175
3
0.1175
0    13
dtype: int64
None
0    13
dtype: int64
0.17900000000000002
5
0.17900000000000002
1    9
dtype: int64
None
1    9
dtype: int64
0    7
dtype: int64
None
0    7
dtype: int64
0.4525
1
0.4525
0    11
dtype: int64
None
0    11
dtype: int64
0    16
dtype: int64
None
0    16
dtype: int64
0.45
1
0.45
0.11785714285714285
3
0.11785714285714285
0.35333333333333333
2
0.35333333333333333
0    12
dtype: int64
None
0    12
dtype: int64
0.3575
2
0.3575
0    8
dtype: int64
None
0    8
dtype: int64
0    9
dtype: int64
None
0    9
dtype: int64
0.43875
1
0.43875
0.4275
1
0.4275
0    10
dtype: int64
None
0    10
dtype: int64
0    9
dtype: int64
None
0    9
dtype: int64
0.503
4
0.503
0    13
dtype: int64
None
0    13
dtype: int64
0    10
dtype: int64
None
0    10
dtype: int64
0.369
2
0.369
0    8
dtype: int64
None
0    8
dtype: int64
0.4583333333333333
1
0.4583333333333

2    8
dtype: int64
None
2    8
dtype: int64
0.5225
1
0.5225
0    9
dtype: int64
None
0    9
dtype: int64
0    8
dtype: int64
None
0    8
dtype: int64
0.1646881188118812
7
0.1646881188118812
2.0454545454545454
0
2.0454545454545454
0.149625
7
0.149625
0.4877083333333334
1
0.4877083333333334
0.31675000000000003
5
0.31675000000000003
1.2
0
1.2
0.135875
6
0.135875
1    7
dtype: int64
None
1    7
dtype: int64
0    9
dtype: int64
None
0    9
dtype: int64
0    8
dtype: int64
None
0    8
dtype: int64
3    8
dtype: int64
None
3    8
dtype: int64
0.12285714285714287
3
0.12285714285714287
0.5984999999999999
4
0.5984999999999999
0.3833333333333333
2
0.3833333333333333
1    7
dtype: int64
None
1    7
dtype: int64
0.5033333333333333
1
0.5033333333333333
0    9
dtype: int64
None
0    9
dtype: int64
0    8
dtype: int64
None
0    8
dtype: int64
0    9
dtype: int64
None
0    9
dtype: int64
0.14049999999999999
6
0.14049999999999999
1.3333333333333333
0
1.3333333333333333
0    8
dtype: int64
None
0    8
d

0
1.5
0    9
dtype: int64
None
0    9
dtype: int64
0    8
dtype: int64
None
0    8
dtype: int64
0.1641764705882353
6
0.1641764705882353
0.14830000000000002
6
0.14830000000000002
1.6666666666666667
0
1.6666666666666667
0.495
1
0.495
0    9
dtype: int64
None
0    9
dtype: int64
0    7
dtype: int64
None
0    7
dtype: int64
0    9
dtype: int64
None
0    9
dtype: int64
2.0
0
2.0
3    7
dtype: int64
None
3    7
dtype: int64
0.7316666666666666
4
0.7316666666666666
1    8
dtype: int64
None
1    8
dtype: int64
0    7
dtype: int64
None
0    7
dtype: int64
0.5371428571428571
1
0.5371428571428571
0.52875
1
0.52875
0    9
dtype: int64
None
0    9
dtype: int64
0    8
dtype: int64
None
0    8
dtype: int64
1    9
dtype: int64
None
1    9
dtype: int64
0.33663810741687983
5
0.33663810741687983
2.0878048780487806
0
2.0878048780487806
0.26672388059701496
5
0.26672388059701496
0.21047540983606564
7
0.21047540983606564
0.1976818181818182
7
0.1976818181818182
0.19217499999999996
7
0.19217499999999996
0.3745


dtype: int64
0    11
dtype: int64
None
0    11
dtype: int64
1    10
dtype: int64
None
1    10
dtype: int64
0.13125
3
0.13125
0.126
3
0.126
2    9
dtype: int64
None
2    9
dtype: int64
0    8
dtype: int64
None
0    8
dtype: int64
0.5766666666666667
1
0.5766666666666667
1    10
dtype: int64
None
1    10
dtype: int64
0    9
dtype: int64
None
0    9
dtype: int64
0.43526881720430133
2
0.43526881720430133
0.2157058823529412
7
0.2157058823529412
0.1385714285714285
3
0.1385714285714285
0.7734999999999999
4
0.7734999999999999
2.2142857142857144
0
2.2142857142857144
0.3674166666666667
5
0.3674166666666667
0    10
dtype: int64
None
0    10
dtype: int64
0.12875
3
0.12875
0    9
dtype: int64
None
0    9
dtype: int64
2    8
dtype: int64
None
2    8
dtype: int64
0.424375
2
0.424375
0.52
1
0.52
0    13
dtype: int64
None
0    13
dtype: int64
0    8
dtype: int64
None
0    8
dtype: int64
0.35966666666666675
5
0.35966666666666675
0.43
2
0.43
0.20425
7
0.20425
0    9
dtype: int64
None
0    9
dtype: int64
0

dtype: int64
0.4108437500000001
5
0.4108437500000001
0.4508
2
0.4508
0.8883823529411764
4
0.8883823529411764
2.2
0
2.2
0.18591666666666665
6
0.18591666666666665
0.379125
5
0.379125
0    12
dtype: int64
None
0    12
dtype: int64
1    10
dtype: int64
None
1    10
dtype: int64
1    11
dtype: int64
None
1    11
dtype: int64
0.15000000000000002
3
0.15000000000000002
0.565
1
0.565
0    10
dtype: int64
None
0    10
dtype: int64
0    8
dtype: int64
None
0    8
dtype: int64
0    9
dtype: int64
None
0    9
dtype: int64
0.39207142857142857
5
0.39207142857142857
1    9
dtype: int64
None
1    9
dtype: int64
1.6666666666666667
0
1.6666666666666667
0    10
dtype: int64
None
0    10
dtype: int64
0.5674999999999999
1
0.5674999999999999
0    15
dtype: int64
None
0    15
dtype: int64
0    12
dtype: int64
None
0    12
dtype: int64
0.3863750000000001
5
0.3863750000000001
2.3333333333333335
0
2.3333333333333335
0    9
dtype: int64
None
0    9
dtype: int64
0.57
1
0.57
0    10
dtype: int64
None
0    10
dtype:

0.47666666666666657
2
0.47666666666666657
0.471875
2
0.471875
0.1525
3
0.1525
1    9
dtype: int64
None
1    9
dtype: int64
0.58
1
0.58
0    10
dtype: int64
None
0    10
dtype: int64
0    9
dtype: int64
None
0    9
dtype: int64
1.5
0
1.5
0.6066666666666666
1
0.6066666666666666
0.1525
3
0.1525
0    10
dtype: int64
None
0    10
dtype: int64
0    9
dtype: int64
None
0    9
dtype: int64
0    9
dtype: int64
None
0    9
dtype: int64
0    10
dtype: int64
None
0    10
dtype: int64
0.62125
1
0.62125
0.6025
1
0.6025
0    8
dtype: int64
None
0    8
dtype: int64
0    14
dtype: int64
None
0    14
dtype: int64
0    9
dtype: int64
None
0    9
dtype: int64
0.15156249999999996
3
0.15156249999999996
0.28250000000000003
7
0.28250000000000003
0.6224999999999999
1
0.6224999999999999
1    8
dtype: int64
None
1    8
dtype: int64
0    10
dtype: int64
None
0    10
dtype: int64
1.5714285714285714
0
1.5714285714285714
0    9
dtype: int64
None
0    9
dtype: int64
1.1720000000000002
4
1.1720000000000002
0.615
1
0.6

0.6575
0    11
dtype: int64
None
0    11
dtype: int64
0    10
dtype: int64
None
0    10
dtype: int64
0    11
dtype: int64
None
0    11
dtype: int64
0.35308333333333325
7
0.35308333333333325
1.5
0
1.5
0    7
dtype: int64
None
0    7
dtype: int64
0    11
dtype: int64
None
0    11
dtype: int64
0.64625
1
0.64625
0.6275
1
0.6275
0    10
dtype: int64
None
0    10
dtype: int64
0    9
dtype: int64
None
0    9
dtype: int64
0    10
dtype: int64
None
0    10
dtype: int64
1.3314833333333334
4
1.3314833333333334
0.48117647058823537
2
0.48117647058823537
1.5
0
1.5
0.565
1
0.565
0    13
dtype: int64
None
0    13
dtype: int64
0    19
dtype: int64
None
0    19
dtype: int64
0    9
dtype: int64
None
0    9
dtype: int64
0.1792307692307692
3
0.1792307692307692
1.2570833333333333
4
1.2570833333333333
0    12
dtype: int64
None
0    12
dtype: int64
1.5
0
1.5
1    9
dtype: int64
None
1    9
dtype: int64
1    10
dtype: int64
None
1    10
dtype: int64
0.5185714285714286
5
0.5185714285714286
0.4883333333333333
2


None
0    20
dtype: int64
0.48272727272727267
7
0.48272727272727267
0.437
7
0.437
0.17833333333333332
3
0.17833333333333332
0    15
dtype: int64
None
0    15
dtype: int64
1    19
dtype: int64
None
1    19
dtype: int64
1    12
dtype: int64
None
1    12
dtype: int64
0.6366666666666667
1
0.6366666666666667
0.6050000000000001
1
0.6050000000000001
0    27
dtype: int64
None
0    27
dtype: int64
0.6325000000000001
1
0.6325000000000001
0    17
dtype: int64
None
0    17
dtype: int64
0    18
dtype: int64
None
0    18
dtype: int64
0.4871666666666667
5
0.4871666666666667
0    15
dtype: int64
None
0    15
dtype: int64
0    12
dtype: int64
None
0    12
dtype: int64
0.2871842105263158
6
0.2871842105263158
0.4565454545454545
5
0.4565454545454545
0.611
1
0.611
0.5825
1
0.5825
0    8
dtype: int64
None
0    8
dtype: int64
0    13
dtype: int64
None
0    13
dtype: int64
0.63
1
0.63
0.6225
1
0.6225
0    23
dtype: int64
None
0    23
dtype: int64
0    17
dtype: int64
None
0    17
dtype: int64
0    25
dtype: i

0    11
dtype: int64
None
0    11
dtype: int64
0.6525000000000001
1
0.6525000000000001
0    11
dtype: int64
None
0    11
dtype: int64
0    12
dtype: int64
None
0    12
dtype: int64
0.7220781250000001
5
0.7220781250000001
1.549058823529412
4
1.549058823529412
0.6525000000000001
1
0.6525000000000001
0.6880000000000001
5
0.6880000000000001
1.5
0
1.5
0    11
dtype: int64
None
0    11
dtype: int64
0.635
1
0.635
0    11
dtype: int64
None
0    11
dtype: int64
0    9
dtype: int64
None
0    9
dtype: int64
1.6666666666666667
0
1.6666666666666667
0    9
dtype: int64
None
0    9
dtype: int64
0    13
dtype: int64
None
0    13
dtype: int64
1.6666666666666667
0
1.6666666666666667
0    10
dtype: int64
None
0    10
dtype: int64
0    11
dtype: int64
None
0    11
dtype: int64
1.5714285714285714
0
1.5714285714285714
0.51625
2
0.51625
0.6475
1
0.6475
0    14
dtype: int64
None
0    14
dtype: int64
0    11
dtype: int64
None
0    11
dtype: int64
0.66
1
0.66
0    8
dtype: int64
None
0    8
dtype: int64
0    13

In [98]:
def predict_row(x,node):
    while len(x) > 0:
        while node.isLeaf == False: 
            attr_var = x[node.category] 
            if attr_var <= node.data:
                child_node = node.children[0]
            elif attr_var > node.data:
                child_node = node.children[1]
            if child_node.isLeaf: 
                return child_node.data.values.item()
            node = child_node

In [99]:
def predict(X, node):
    counter = 0
    copy = X.copy()
    result = {}
    for i in X.values: 
        result[counter] = (predict_row(i,node))
        counter += 1
    return result

In [100]:
pred = predict(X_test, n)
print(pred)

{0: 9, 1: 8, 2: 9, 3: 11, 4: 13, 5: 9, 6: 9, 7: 8, 8: 5, 9: 8, 10: 15, 11: 9, 12: 13, 13: 7, 14: 8, 15: 20, 16: 13, 17: 12, 18: 11, 19: 7, 20: 24, 21: 9, 22: 9, 23: 11, 24: 8, 25: 7, 26: 9, 27: 9, 28: 8, 29: 11, 30: 14, 31: 7, 32: 7, 33: 15, 34: 9, 35: 6, 36: 4, 37: 1, 38: 11, 39: 8, 40: 11, 41: 24, 42: 8, 43: 11, 44: 9, 45: 12, 46: 9, 47: 8, 48: 8, 49: 5, 50: 5, 51: 25, 52: 9, 53: 6, 54: 5, 55: 13, 56: 10, 57: 17, 58: 6, 59: 9, 60: 11, 61: 21, 62: 8, 63: 7, 64: 8, 65: 9, 66: 4, 67: 7, 68: 9, 69: 12, 70: 10, 71: 10, 72: 8, 73: 8, 74: 9, 75: 11, 76: 16, 77: 11, 78: 15, 79: 8, 80: 14, 81: 14, 82: 12, 83: 15, 84: 5, 85: 10, 86: 9, 87: 15, 88: 9, 89: 11, 90: 11, 91: 8, 92: 10, 93: 11, 94: 8, 95: 11, 96: 7, 97: 8, 98: 13, 99: 5, 100: 12, 101: 17, 102: 10, 103: 9, 104: 7, 105: 11, 106: 11, 107: 14, 108: 10, 109: 9, 110: 9, 111: 9, 112: 19, 113: 10, 114: 8, 115: 4, 116: 8, 117: 12, 118: 12, 119: 9, 120: 9, 121: 10, 122: 15, 123: 8, 124: 12, 125: 11, 126: 11, 127: 9, 128: 11, 129: 12, 130: 14,

In [82]:
def accuracy(y_true, y_pred): 
        result = 0
        for idx, y_ in enumerate(y_pred.values()): 
            if y_ == y_true.values[idx]: result+=1
        return (result/len(y_true))


In [83]:
accuracy(y_test, pred)

0.19808612440191387

In [84]:
def get_leaf(node):
    if node.isLeaf:
        return(node)
    else: 
        for child in node.children:
            return find_leaf(child)

In [104]:
def prune(X_pruning, y_pruning, current_node, prune_accuracy):
    print("1")
    if current_node.isLeaf: return
    for child in current_node.children:
        print(child.data)
        current_acc = prune_accuracy
        
        if child.isLeaf == True:
            y_ = predict(X_pruning, current_node)
            current_acc = accuracy(y_pruning, y_)
            dominant_class = (y_pruning.value_counts().argmax())/len(y_pruning)
            if dominant_class >= current_acc:
                current_node.isLeaf = True
                current_node.data = y_pruning.value_counts().argmax()
                current_node.children = []
                prune(X_pruning, y_pruning, current_node.parent, current_acc)
        else:
            print("jada")
            prune(X_pruning, y_pruning, child, current_acc)

In [105]:
def learn_new(X, y, impurity_measure='entropy', pruning=False):
    root = mnode()
    make_tree(X, y, root, impurity_measure)
    acc = predict(X, root)
    X_pruning, y_pruning = [],[]
    if pruning == True: 
        X, X_pruning, y, y_pruning = train_test_split(X,y, test_size=0.25, random_state=42)
        prune(X_pruning, y_pruning, root, acc)
    
    return root

In [106]:
n = learn_new(X_train, y_train, 'entropy', True)

1
0.13220343137254903
jada
1
0.2765394242803504
jada
1
0.22159663865546197
jada
1
0.024089403973509907
jada
1
0.018020833333333323
jada
1
0.008675675675675681
jada
1
0.007176470588235295
jada
1
0.10222222222222223
jada
1
0.11833333333333335
jada
1
0    1
dtype: int64
1
9
0.005916666666666666
jada
1
2    4
dtype: int64
2.3333333333333335
jada
1
0    3
dtype: int64
1
2    4
dtype: int64


The current behaviour of 'Series.argmax' is deprecated, use 'idxmax'
instead.
The behavior of 'argmax' will be corrected to return the positional
maximum in the future. For now, use 'series.values.argmax' or
'np.argmax(np.array(values))' to get the position of the maximum
row.
  # This is added back by InteractiveShellApp.init_path()
The current behaviour of 'Series.argmax' is deprecated, use 'idxmax'
instead.
The behavior of 'argmax' will be corrected to return the positional
maximum in the future. For now, use 'series.values.argmax' or
'np.argmax(np.array(values))' to get the position of the maximum
row.
  


AttributeError: 'numpy.int64' object has no attribute 'values'