In [None]:
# A brute-force implementation of Algorithm1

In [4]:
import pandas as pd
import numpy as np
from random import sample,seed

In [5]:
# read_csv already returns a DataFrame, so no extra wrapping is needed
df = pd.read_csv('../data/compas-binary.csv')

# fix the seed so the shuffled row order below is reproducible
seed(10)

# random permutation of all row indices; later cells slice it to draw sub-samples
idx = sample(range(df.shape[0]),df.shape[0])

In [6]:
# DataFrame.as_matrix() was removed in pandas 1.0; .values returns the same
# ndarray and is available in both old and new pandas versions.
# The first 13 columns are used as rule features, column 13 as the label.
x = df.values[:,:13]

y = df.values[:,13]

In [7]:
def bound(tree, x, y, l, lamb = 0.02):
    """
    Bound obtained by keeping every leaf except leaf l fixed (dp = all leaves
    except leaf l): the fraction of samples misclassified by the fixed leaves
    plus lamb times the current number of leaves.

    x is accepted for signature parity with the other helpers but is not used.
    """
    captured = tree.get_cap()
    correct = tree.get_ncc()
    fixed_captured = sum(captured[:l] + captured[l+1:])
    fixed_correct = sum(correct[:l] + correct[l+1:])
    return (fixed_captured - fixed_correct)/len(y) + lamb*len(tree.get_prefix())

def Risk(tree, x, y, lamb = 0.02):
    """
    Regularized risk of a tree: one minus the fraction of correctly captured
    samples, plus lamb times the number of leaves.

    x is accepted for signature parity with the other helpers but is not used.
    """
    accuracy = sum(tree.get_ncc())/len(y)
    leaf_penalty = lamb*len(tree.get_prefix())
    return 1 - accuracy + leaf_penalty

In [8]:
def calcul(prefix, x, y):
    """
    Function for calculating the predictions, number of data captured,
    and number of data correctly captured by the leaves.

    prefix: iterable of leaves; each leaf is a tuple of signed, 1-based rule
            indices. A literal k > 0 requires x[:, k-1] == 1, a literal
            k < 0 requires x[:, |k|-1] == 0.
    x:      2-d numpy array of rule values (one column per rule)
    y:      1-d numpy array of labels, indexable by a boolean mask

    Returns three lists aligned with prefix: (prediction, num_captured,
    num_captured_correct). A leaf predicts 1 only when strictly more than
    half of its captured labels sum to 1 (ties predict 0); a leaf capturing
    no data predicts 0 with zero counts. Counts are plain Python ints.
    """
    prediction = []
    num_captured = []
    num_captured_correct = []
    n_samples = len(y)

    for leaf in prefix:
        # start with every sample captured, then intersect with each literal
        captured = np.ones(n_samples, dtype=bool)
        for literal in leaf:
            rule_index = abs(literal) - 1
            rule_label = int(literal > 0)
            captured &= (x[:, rule_index] == rule_label)

        # the y's of the data captured by this leaf
        y_leaf = y[captured]
        num_cap = len(y_leaf)
        num_captured.append(num_cap)

        if num_cap > 0:
            pred = int(y_leaf.sum()/num_cap > 0.5)
            prediction.append(pred)
            # vectorised count instead of iterating the array with builtin sum
            num_captured_correct.append(int(np.count_nonzero(y_leaf == pred)))
        else:
            prediction.append(0)
            num_captured_correct.append(0)

    return prediction, num_captured, num_captured_correct

In [9]:
class CacheTree:
    """
    A tree data structure.
    prefix: a 2-d tuple to record the prefixes of leaves
    prediction: a list to record the predictions of leaves
    num_captured: a list to record number of data captured by the leaves
    num_captured_correct: a list to record number of data correctly captured by the leaves

    When prediction is not supplied, all three per-leaf statistics are
    computed from (x, y) with calcul(); x and y are not stored on the tree.
    """
    def __init__(self, prefix=None,
                 prediction=None,
                 num_captured=None,
                 num_captured_correct=None,
                 x=None, y=None):
        self.prefix = prefix
        self.prediction = prediction
        self.num_captured = num_captured
        self.num_captured_correct = num_captured_correct
        # identity test: `== None` is evaluated element-wise on numpy arrays
        # and would make this `if` raise for array-valued predictions
        if prediction is None:
            self.prediction, self.num_captured, self.num_captured_correct = calcul(self.prefix, x, y)

    def get_prefix(self):
        """Return the tuple of leaves."""
        return self.prefix

    def get_pred(self):
        """Return the per-leaf predictions."""
        return self.prediction

    def get_cap(self):
        """Return the per-leaf counts of captured data."""
        return self.num_captured

    def get_ncc(self):
        """Return the per-leaf counts of correctly captured data."""
        return self.num_captured_correct

In [10]:
class Eliminate:
    """
    A data structure to record and identify
    whether a tree has been visited/pruned

    elim_dict:  dict whose keys are the prefixes of trees already visited
    eliminated: list of leaf collections dp; any tree whose prefix contains
                every leaf of some dp is treated as eliminated

    Prefixes and leaves are compared as tuples, so the order of leaves in a
    prefix and of rules within a leaf matters for these lookups.
    """
    def __init__(self, elim_dict = None,
                 eliminated = None):
        # honour initial state when supplied (copied so the caller's
        # containers are not mutated); otherwise start empty
        self.elim_dict = dict(elim_dict) if elim_dict is not None else {}
        self.eliminated = list(eliminated) if eliminated is not None else []

    def eliminate(self, prefix):
        """Mark one exact prefix as visited."""
        self.elim_dict[prefix] = 1

    def eliminate_children(self, dp):
        """Mark every tree that contains all leaves of dp as eliminated."""
        self.eliminated.append(dp)

    def is_eliminated(self, prefix):
        """Return True if prefix was visited or contains all leaves of a recorded dp."""
        if prefix in self.elim_dict:
            return True

        return any(all(leaf in prefix for leaf in dp) for dp in self.eliminated)

In [11]:
import time

In [12]:
def algorithm1(x, y, lamb):
    """
    A brute force implementation of Algorithm1

    Breadth-first search over trees, starting from every one-split tree.
    A tree is a tuple of leaves and a leaf is a tuple of signed, 1-based
    rule indices. For each visited tree and each leaf i, the leaf is split
    on every rule it does not use yet when bound(tree, ..., i) is below the
    best risk so far; otherwise every tree containing the remaining leaves
    is eliminated.

    x:    2-d numpy array of rule values (one column per rule)
    y:    1-d numpy array of labels
    lamb: per-leaf regularization weight passed to Risk and bound

    Prints the best tree found and its risk; returns None. The best risk
    starts at 1, so only trees with risk strictly below 1 are recorded.
    """
    # local import keeps this cell runnable on its own
    from collections import deque

    d_c = None
    R_c = 1

    nrule = x.shape[1]

    # initialize the queue to include all trees of just one split;
    # a deque gives O(1) pops from the front (list.pop(0) is O(n))
    queue = deque(CacheTree(prefix = ((-r,),(r,)), x = x, y = y) for r in range(1, nrule+1))

    E = Eliminate()

    while queue:
        tree = queue.popleft()
        d = tree.get_prefix()
        if E.is_eliminated(d):
            continue
        E.eliminate(d)

        R = Risk(tree,x,y,lamb)
        if R<R_c:
            d_c = d
            R_c = R

        pred, cap, ncc = tree.get_pred(), tree.get_cap(), tree.get_ncc()

        for i in range(len(d)):
            d0 = d[i]
            dp = d[:i]+d[i+1:]
            if bound(tree,x,y,i,lamb)<R_c:
                # statistics of the untouched leaves do not depend on the split rule
                pred_dp = pred[:i]+pred[i+1:]
                cap_dp = cap[:i]+cap[i+1:]
                ncc_dp = ncc[:i]+ncc[i+1:]
                for j in range(1, nrule+1):
                    # only split on rules that leaf d0 does not use yet
                    if (j not in d0) and (-j not in d0):
                        l1 = d0+(-j,)
                        l2 = d0+(j,)
                        t = dp+(l1, l2)
                        if not E.is_eliminated(t):
                            pred_l, cap_l, corr_l = calcul((l1,l2),x,y)
                            queue.append(
                                CacheTree(prefix = t,
                                          prediction = pred_dp+pred_l,
                                          num_captured = cap_dp+cap_l,
                                          num_captured_correct = ncc_dp+corr_l)
                            )
            else:
                E.eliminate_children(dp)

    print(d_c)
    print("R_c", R_c)

In [27]:
def algorithm1_2splits(x, y, lamb):
    """
    A brute force implementation of Algorithm1 that can also split two
    leaves of a tree in one step.

    Same breadth-first search as algorithm1. In addition, when the bound
    test fails for leaf i and for some earlier leaf lv of the same tree
    (collected in V), the trees obtained by splitting both leaves at once
    are enqueued, for every pair of rules not already used by each leaf.

    x:    2-d numpy array of rule values (one column per rule)
    y:    1-d numpy array of labels
    lamb: per-leaf regularization weight passed to Risk and bound

    Prints every visited tree plus trace lines for the two-leaf splits, then
    the best tree found and its risk; returns None.
    """
    # local import keeps this cell runnable on its own
    from collections import deque

    d_c = None
    R_c = 1

    nrule = x.shape[1]

    # initialize the queue to include all trees of just one split;
    # a deque gives O(1) pops from the front (list.pop(0) is O(n))
    queue = deque(CacheTree(prefix = ((-r,),(r,)), x = x, y = y) for r in range(1, nrule+1))

    E = Eliminate()

    while queue:
        tree = queue.popleft()
        d = tree.get_prefix()
        if E.is_eliminated(d):
            continue
        E.eliminate(d)

        R = Risk(tree,x,y,lamb)
        if R<R_c:
            d_c = d
            R_c = R

        print("d",d)

        pred, cap, ncc = tree.get_pred(), tree.get_cap(), tree.get_ncc()

        # leaves of this tree whose bound test has failed so far
        V = []
        for i in range(len(d)):
            d0 = d[i]
            dp = d[:i]+d[i+1:]
            # statistics of the untouched leaves, aligned with dp
            pred_dp = pred[:i]+pred[i+1:]
            cap_dp = cap[:i]+cap[i+1:]
            ncc_dp = ncc[:i]+ncc[i+1:]
            if bound(tree,x,y,i,lamb)<R_c:
                for j in range(1, nrule+1):
                    # only split on rules that leaf d0 does not use yet
                    if (j not in d0) and (-j not in d0):
                        l1 = d0+(-j,)
                        l2 = d0+(j,)
                        t = dp+(l1, l2)
                        if not E.is_eliminated(t):
                            pred_l, cap_l, corr_l = calcul((l1,l2),x,y)
                            queue.append(
                                CacheTree(prefix = t,
                                          prediction = pred_dp+pred_l,
                                          num_captured = cap_dp+cap_l,
                                          num_captured_correct = ncc_dp+corr_l)
                            )
            else:
                E.eliminate_children(dp)
                if len(V)>0:
                    for j in range(1, nrule+1):
                        if (j not in d0) and (-j not in d0):
                            l1 = d0+(-j,)
                            l2 = d0+(j,)
                            t = dp+(l1, l2)

                            # statistics of the intermediate tree t (leaf d0 split on rule j)
                            pred_l, cap_l, corr_l = calcul((l1,l2),x,y)
                            t_prediction = pred_dp+pred_l
                            t_num_captured = cap_dp+cap_l
                            t_num_captured_correct = ncc_dp+corr_l

                            for lv in V:
                                # position of the second leaf to split inside t
                                idx_lv = t.index(lv)

                                for jv in range(1, nrule+1):
                                    print("jv, lv",jv, lv)
                                    if (jv not in lv) and (-jv not in lv):
                                        lv1 = lv+(-jv,)
                                        lv2 = lv+(jv,)
                                        t1 = t[:idx_lv]+t[idx_lv+1:]+(lv1, lv2)
                                        print("t1",t1)
                                        if not E.is_eliminated(t1):
                                            pred_lv, cap_lv, corr_lv = calcul((lv1,lv2),x,y)
                                            queue.append(
                                                CacheTree(prefix = t1,
                                                          prediction = t_prediction[:idx_lv]+t_prediction[idx_lv+1:]+pred_lv,
                                                          num_captured = t_num_captured[:idx_lv]+t_num_captured[idx_lv+1:]+cap_lv,
                                                          num_captured_correct = t_num_captured_correct[:idx_lv]+t_num_captured_correct[idx_lv+1:]+corr_lv)
                                            )
                V.append(d0)

    print(d_c)
    print("R_c", R_c)

In [1]:
# Scratch value: a tuple of tuples, the same shape the search functions use for a tree's leaves
aaaaaa = ((1,2),(3,4))


In [3]:
# Scratch check: tuple.index returns the position of the first element equal to the argument
aaaaaa.index((1,2))

0

In [14]:
# random permutation of the rule (column) indices of x; x has 13 columns,
# so this draws the same permutation as the previously hard-coded 13.
# NOTE(review): the result depends on the state of the `random` generator
# when this cell runs - re-seed before it if a fixed rule order is required.
rule_idx = sample(range(x.shape[1]),x.shape[1])

In [32]:
%%time
# algorithm1_2splits, 3 rules, 25 data, lambda = 0.04

algorithm1_2splits(x[:,rule_idx[:3]][idx[:25]],y[idx[:25]],lamb=0.04)

d ((-1,), (1,))
d ((-2,), (2,))
d ((-3,), (3,))
d ((1,), (-1, -2), (-1, 2))
d ((1,), (-1, -3), (-1, 3))
d ((-1,), (1, -2), (1, 2))
d ((-1,), (1, -3), (1, 3))
d ((2,), (-2, -1), (-2, 1))
d ((2,), (-2, -3), (-2, 3))
d ((-2,), (2, -1), (2, 1))
d ((-2,), (2, -3), (2, 3))
d ((3,), (-3, -1), (-3, 1))
d ((3,), (-3, -2), (-3, 2))
d ((-1, -2), (-1, 2), (1, -2), (1, 2))
d ((-1, -2), (-1, 2), (1, -3), (1, 3))
d ((1,), (-1, 2), (-1, -2, -3), (-1, -2, 3))
d ((1,), (-1, -2), (-1, 2, -3), (-1, 2, 3))
d ((-1, -3), (-1, 3), (1, -2), (1, 2))
jv, lv 1 (-1, 3)
jv, lv 2 (-1, 3)
t1 ((-1, -3), (1, -2), (1, 2, -3), (1, 2, 3), (-1, 3, -2), (-1, 3, 2))
jv, lv 3 (-1, 3)
d ((-1, -3), (-1, 3), (1, -3), (1, 3))
jv, lv 1 (-1, 3)
jv, lv 2 (-1, 3)
t1 ((-1, -3), (1, -3), (1, 3, -2), (1, 3, 2), (-1, 3, -2), (-1, 3, 2))
jv, lv 3 (-1, 3)
d ((1,), (-1, 3), (-1, -3, -2), (-1, -3, 2))
d ((-1,), (1, 2), (1, -2, -3), (1, -2, 3))
d ((-1,), (1, 3), (1, -3, -2), (1, -3, 2))
d ((-2, -1), (-2, 1), (2, -1), (2, 1))
d ((-2, -1), (-2,

In [15]:
%%time
# 3 rules, 25 data, lambda = 0.04

algorithm1(x[:,rule_idx[:3]][idx[:25]],y[idx[:25]],lamb=0.04)

((-1,), (1,))
R_c 0.48000000000000004
CPU times: user 20 ms, sys: 4 ms, total: 24 ms
Wall time: 23.3 ms


In [14]:
%%time
# 4 rules, 25 data, lambda = 0.04

algorithm1(x[:,rule_idx[:4]][idx[:25]],y[idx[:25]],lamb=0.04)

((-1,), (1,))
R_c 0.48000000000000004
CPU times: user 392 ms, sys: 0 ns, total: 392 ms
Wall time: 394 ms


In [15]:
%%time
# 5 rules, 25 data, lambda = 0.04

algorithm1(x[:,rule_idx[:5]][idx[:25]],y[idx[:25]],lamb=0.04)

((-1,), (1,))
R_c 0.48000000000000004
CPU times: user 4.64 s, sys: 0 ns, total: 4.64 s
Wall time: 4.64 s


In [16]:
%%time
# 6 rules, 25 data, lambda = 0.04

algorithm1(x[:,rule_idx[:6]][idx[:25]],y[idx[:25]],lamb=0.04)

((-6,), (6,))
R_c 0.36000000000000004
CPU times: user 12.4 s, sys: 12 ms, total: 12.4 s
Wall time: 12.5 s


In [17]:
%%time
# 7 rules, 25 data, lambda = 0.04

algorithm1(x[:,rule_idx[:7]][idx[:25]],y[idx[:25]],lamb=0.04)

((-6,), (6,))
R_c 0.36000000000000004
CPU times: user 2min 29s, sys: 36 ms, total: 2min 29s
Wall time: 2min 30s


In [18]:
%%time
# 8 rules, 25 data, lambda = 0.04

algorithm1(x[:,rule_idx[:8]][idx[:25]],y[idx[:25]],lamb=0.04)

((-6,), (6,))
R_c 0.36000000000000004
CPU times: user 22min 48s, sys: 544 ms, total: 22min 48s
Wall time: 22min 49s


In [19]:
%%time
# 9 rules, 25 data, lambda = 0.04

algorithm1(x[:,rule_idx[:9]][idx[:25]],y[idx[:25]],lamb=0.04)

KeyboardInterrupt: 

In [20]:
%%time
# all 13 rules, 25 data, lambda = 0.04

algorithm1(x[idx[:25]],y[idx[:25]],lamb=0.04)

KeyboardInterrupt: 

In [21]:
%%time
# 6 rules, 25 data, lambda = 0.03

algorithm1(x[:,rule_idx[:6]][idx[:25]],y[idx[:25]],lamb=0.03)

((-6,), (6,))
R_c 0.34
CPU times: user 15.2 s, sys: 4 ms, total: 15.2 s
Wall time: 15.2 s


In [22]:
%%time
# 6 rules, 25 data, lambda = 0.02

algorithm1(x[:,rule_idx[:6]][idx[:25]],y[idx[:25]],lamb=0.02)

((-5, -6), (-5, 6), (5, 1), (5, -1, -2), (5, -1, 2))
R_c 0.29999999999999993
CPU times: user 27 s, sys: 0 ns, total: 27 s
Wall time: 27 s


In [23]:
%%time
# 6 rules, 25 data, lambda = 0.01

algorithm1(x[:,rule_idx[:6]][idx[:25]],y[idx[:25]],lamb=0.01)

((-5, -6), (-5, 6), (5, 1), (5, -1, -2), (5, -1, 2))
R_c 0.24999999999999994
CPU times: user 31.4 s, sys: 20 ms, total: 31.4 s
Wall time: 31.4 s


In [16]:
def algorithm1_nobound(x, y, lamb):
    """
    A brute force implementation of Algorithm1 without the bound of Theorem1

    Breadth-first search over trees, starting from every one-split tree.
    Every leaf of every visited tree is split on every rule it does not use
    yet; the only pruning is skipping prefixes that were already visited.

    x:    2-d numpy array of rule values (one column per rule)
    y:    1-d numpy array of labels
    lamb: per-leaf regularization weight passed to Risk

    Prints the best tree found and its risk; returns None. The best risk
    starts at 1, so only trees with risk strictly below 1 are recorded.
    """
    # local import keeps this cell runnable on its own
    from collections import deque

    d_c = None
    R_c = 1

    nrule = x.shape[1]

    # initialize the queue to include all trees of just one split;
    # a deque gives O(1) pops from the front (list.pop(0) is O(n))
    queue = deque(CacheTree(prefix = ((-r,),(r,)), x = x, y = y) for r in range(1, nrule+1))

    E = Eliminate()

    while queue:
        tree = queue.popleft()
        d = tree.get_prefix()
        if E.is_eliminated(d):
            continue
        E.eliminate(d)

        R = Risk(tree,x,y,lamb)
        if R<R_c:
            d_c = d
            R_c = R

        pred, cap, ncc = tree.get_pred(), tree.get_cap(), tree.get_ncc()

        for i in range(len(d)):
            d0 = d[i]
            dp = d[:i]+d[i+1:]
            # statistics of the untouched leaves do not depend on the split rule
            pred_dp = pred[:i]+pred[i+1:]
            cap_dp = cap[:i]+cap[i+1:]
            ncc_dp = ncc[:i]+ncc[i+1:]
            # no bound test here: every leaf is always expanded
            for j in range(1, nrule+1):
                # only split on rules that leaf d0 does not use yet
                if (j not in d0) and (-j not in d0):
                    l1 = d0+(-j,)
                    l2 = d0+(j,)
                    t = dp+(l1, l2)
                    if not E.is_eliminated(t):
                        pred_l, cap_l, corr_l = calcul((l1,l2),x,y)
                        queue.append(
                            CacheTree(prefix = t,
                                      prediction = pred_dp+pred_l,
                                      num_captured = cap_dp+cap_l,
                                      num_captured_correct = ncc_dp+corr_l)
                        )

    print(d_c)
    print("R_c", R_c)

In [33]:
%%time
# algorithm1 without the bound, 3 rules, 25 data, lambda = 0.04

algorithm1_nobound(x[:,rule_idx[:3]][idx[:25]],y[idx[:25]],lamb=0.04)

((-1,), (1,))
R_c 0.48000000000000004
CPU times: user 212 ms, sys: 0 ns, total: 212 ms
Wall time: 212 ms


In [29]:
%%time
# algorithm1 without the bound, 4 rules, 25 data, lambda = 0.04

algorithm1_nobound(x[:,rule_idx[:4]][idx[:25]],y[idx[:25]],lamb=0.04)

KeyboardInterrupt: 