<a href="https://colab.research.google.com/github/profteachkids/StemUnleashed/blob/main/CART.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import numpy as np
import pandas as pd

In [None]:
# Recipe for the synthetic data set, one row per class:
#   [class label, number of samples, per-feature settings]
# In the per-feature list a float is the mean of a unit-variance normal feature,
# while a string is a constant categorical value shared by the whole class.
data_gen_param = [
    ['a', 10, [1., 3., 'ee']],
    ['b', 8, [1., 3., 'ff']],
    ['c', 5, [1.5, 3., 'ee']],
]

In [None]:
# Build the synthetic data set described by data_gen_param.
rand=np.random.RandomState(123)  # fixed seed: the generated samples are identical on every run
labels=[]
xs=[]  # not used in the visible cells; kept in case another cell relies on the name
dfs=[]
categorical_cols=['label']  # columns converted to pandas 'category' dtype below
for label, n, avgs in data_gen_param:
    labels.append(label)
    # Give the frame its n rows up front so that a constant (categorical) value
    # broadcasts to n rows even when it is the first column to be assigned.
    class_df=pd.DataFrame(index=range(n))
    for i,avg in enumerate(avgs):
        col=f'P{i}'
        # Any real number (int or float, but not bool) is the mean of a
        # unit-variance normal feature; everything else is a categorical constant.
        if isinstance(avg,(int,float,np.number)) and not isinstance(avg,bool):
            class_df[col]=rand.normal(avg,1.,size=n)
        else:
            class_df[col]=avg
            if col not in categorical_cols:
                categorical_cols.append(col)
    dfs.append(class_df)

# Stack the per-class frames; the concat keys become the first column, 'label'.
df=(pd.concat(dfs, keys=labels)
      .droplevel(1)
      .reset_index()
      .rename(columns={'index':'label'})
      .astype({col:'category' for col in categorical_cols}))

# Category codes follow the (sorted) category order, not the order of
# data_gen_param, so take the code -> name lookup from the categories themselves.
labels=df['label'].cat.categories.tolist()
features=df.iloc[:,1:]
ilabels=df['label'].cat.codes.values


In [None]:
class Node():
    """A node of a binary classification tree.

    An internal node carries the name of the feature it splits on, the split
    value and its two children. A leaf carries ``leafCounts``, a pair
    ``(label codes, counts)`` describing the samples that reached it.
    """

    def __init__(self,depth):
        self.depth=depth
        self.feature=None
        self.split_value=None
        self.leftNode=None
        self.rightNode=None
        self.leafCounts=None

    def __str__(self):
        """Render this node and everything below it, indented three spaces per level.

        Uses the module-level ``labels`` (label code -> name) and ``features``
        (to tell categorical columns from the others).
        """
        prefix='\n' + self.depth*'   '

        # Leaf: list "<label>:<count> " for every class present.
        if self.leafCounts is not None:
            tallies=''.join(f'{labels[code]}:{how_many} '
                            for code,how_many in zip(*self.leafCounts))
            return prefix + f'counts: {tallies}\n'

        # Internal node: one condition line per branch, each followed by its subtree.
        if features[self.feature].dtype.name=='category':
            left_rule=f'{self.feature} is {self.split_value}'
            right_rule=f'{self.feature} not {self.split_value}'
        else:
            left_rule=f'{self.feature} < {self.split_value}'
            right_rule=f'{self.feature} > {self.split_value}'
        pieces=[prefix, left_rule, str(self.leftNode),
                prefix, right_rule, str(self.rightNode)]
        return ''.join(pieces)




In [None]:
def find_split(idxs):
    """Search every feature for the split of rows ``idxs`` with the lowest weighted Gini impurity.

    Reads two module-level objects: ``features`` (a DataFrame with one column
    per feature) and ``ilabels`` (integer class codes, indexed by row position).

    * ``category`` columns are split one-vs-rest: rows equal to the chosen
      category go left, all other rows go right.
    * Every other column is treated as numeric: rows are sorted by value and
      each boundary between two *different* consecutive values is a candidate;
      the reported threshold is the midpoint of those two values.

    Parameters
    ----------
    idxs : array-like of int
        Row positions (into ``features`` / ``ilabels``) that belong to the node.

    Returns
    -------
    tuple
        ``(best_feature, best_feature_split_value, idxsL, idxsR)``. A split is
        reported only when it is strictly better than leaving the node unsplit.
        When there is none (empty input, a pure node, or no feature that can
        separate the rows) the result is ``(None, None, <empty array>, idxs)``.
    """
    idxs=np.asarray(idxs)
    nidxs=idxs.size

    # "No split" defaults, so the return value is always defined.
    best_feature=None
    best_feature_split_value=None
    idxsL, idxsR = idxs[:0], idxs
    if nidxs==0:
        return best_feature,best_feature_split_value, idxsL, idxsR

    # Class histogram of the node (floats, because it is used in divisions below).
    parent_counts=np.bincount(ilabels[idxs]).astype(float)
    nclasses=parent_counts.size

    # Impurity of the unsplit node: a split must beat this to be accepted.
    best_gini=1-np.sum((parent_counts/nidxs)**2)

    for feature_col in features.columns:
        feature=features[feature_col].values[idxs]
        if features[feature_col].dtype.name=='category':
            categories=np.unique(feature)
            if len(categories)<2:  #if less than 2, can't split
                continue
            for category in categories:
                iscategory = (feature == category)
                # Rows equal to `category` form the left child, the rest the right child.
                in_counts=np.bincount(ilabels[idxs[iscategory]],minlength=nclasses).astype(float)
                out_counts=parent_counts-in_counts
                n_in, n_out = np.sum(in_counts), np.sum(out_counts)
                gini = n_out/nidxs*(1-np.sum((out_counts/n_out)**2)) + n_in/nidxs*(1-np.sum((in_counts/n_in)**2))
                if gini < best_gini:
                    best_gini=gini
                    idxsL = idxs[iscategory]
                    idxsR = idxs[np.logical_not(iscategory)]
                    best_feature=feature_col
                    best_feature_split_value = category

        else:
            # Sweep the rows in increasing feature order, moving one row at a
            # time from the right child to the left child.
            rcounts=parent_counts.copy()
            lcounts=np.zeros_like(rcounts)

            sorted_feature_idx=np.argsort(feature)
            sorted_values=feature[sorted_feature_idx]
            sorted_ilabels=ilabels[idxs[sorted_feature_idx]]
            # gini[k] is the impurity of splitting after the (k+1)-th sorted row;
            # positions that are not valid candidates keep the current best and so never win.
            gini=np.full(nidxs,best_gini)
            for nl,i in enumerate(sorted_ilabels[:-1],1):
                lcounts[i]+=1
                rcounts[i]-=1
                if sorted_values[nl-1]==sorted_values[nl]:
                    # Equal values cannot be separated by a threshold.
                    continue
                nr=nidxs-nl
                xl=nl/nidxs
                gini[nl-1]=xl*(1-np.sum((lcounts/nl)**2))+ (1-xl)*(1-np.sum((rcounts/nr)**2))
            amin_gini = np.argmin(gini)
            if gini[amin_gini]<best_gini:
                best_gini=gini[amin_gini]
                idxsL= idxs[sorted_feature_idx[:amin_gini+1]]
                idxsR= idxs[sorted_feature_idx[amin_gini+1:]]
                best_feature=feature_col
                best_feature_split_value=(sorted_values[amin_gini]+sorted_values[amin_gini+1])/2

    return best_feature,best_feature_split_value, idxsL, idxsR



In [None]:
# Grow the tree iteratively: each stack entry is (row positions, node to fill in, node depth).
MAX_DEPTH=4  # a node deeper than this is turned into a leaf without trying to split it
depth=0
root=Node(depth)
stack=[(df.index.values,root,depth)]

while stack:
    idxs,parent,depth=stack.pop()
    parent.depth=depth
    bins,counts=np.unique(ilabels[idxs],return_counts=True)

    # Leaf when there is nothing to separate (at most one label) or the depth limit is exceeded.
    if bins.size<=1 or depth>MAX_DEPTH:
        parent.leafCounts=(bins,counts)
        continue

    feature, split_value, idxsL, idxsR = find_split(idxs)

    # An empty side means no usable split was reported: keep the node as a
    # leaf so that it still carries its class counts and can be printed.
    if len(idxsL)==0 or len(idxsR)==0:
        parent.leafCounts=(bins,counts)
        continue

    parent.feature=feature
    parent.split_value=split_value

    parent.leftNode=Node(depth+1)
    stack.append((idxsL, parent.leftNode, depth+1))
    parent.rightNode=Node(depth+1)
    stack.append((idxsR, parent.rightNode, depth+1))



In [None]:
# Show the fitted tree as indented text; each level is indented by three spaces.
print(str(root))


P2 is ee
   P0 < 0.5957755140673854
      counts: a:6 

   P0 > 0.5957755140673854
      P0 < 2.2270794849039097
         P1 < 4.032597744264673
            P0 < 1.8904863851954383
               counts: c:4 

            P0 > 1.8904863851954383
               counts: a:1 c:1 

         P1 > 4.032597744264673
            counts: a:1 

      P0 > 2.2270794849039097
         counts: a:2 

P2 not ee
   counts: b:8 

