In [3]:
#setup kaggle cli by using 
#kg config -g -u $USER -p $PASSWoRD -c titanic
#kg download

In [4]:
import pandas as pd

In [5]:
data = pd.read_csv('data/train.csv')

In [6]:
data.head()

Unnamed: 0,PassengerId,Survived,Pclass,Name,Sex,Age,SibSp,Parch,Ticket,Fare,Cabin,Embarked
0,1,0,3,"Braund, Mr. Owen Harris",male,22.0,1,0,A/5 21171,7.25,,S
1,2,1,1,"Cumings, Mrs. John Bradley (Florence Briggs Th...",female,38.0,1,0,PC 17599,71.2833,C85,C
2,3,1,3,"Heikkinen, Miss. Laina",female,26.0,0,0,STON/O2. 3101282,7.925,,S
3,4,1,1,"Futrelle, Mrs. Jacques Heath (Lily May Peel)",female,35.0,1,0,113803,53.1,C123,S
4,5,0,3,"Allen, Mr. William Henry",male,35.0,0,0,373450,8.05,,S


In [136]:
X_df = data[['Age', 'Sex', 'PassengerId']]
Y_df = data[ ['Survived', 'PassengerId']]

In [137]:
X_df = X_df.set_index('PassengerId')
Y_df = Y_df.set_index('PassengerId')

In [9]:
X_df.head()
#consider throwing out those that miss AGE (are equal to NaN)

Unnamed: 0_level_0,Age,Sex
PassengerId,Unnamed: 1_level_1,Unnamed: 2_level_1
1,22.0,male
2,38.0,female
3,26.0,female
4,35.0,female
5,35.0,male


In [153]:
Y_df.head()

Unnamed: 0_level_0,Survived
PassengerId,Unnamed: 1_level_1
1,0
2,1
3,1
4,1
5,0


Goal is to create a decision tree based off questions:    is  Age > x and Sex === y

1) create a question class with signature (question_type, x) => question in natural lang


2) create a function that inputs the records (rows) and computes impurity at a node
3 create a function  (node1,node2, question)  => information_gain

then a function that trains the tree

question

partition (question, dataset) => dataset1 dataset2

gini: rows => gini measure

info_gain: (false_rows, left_rows_ gini_impurity at begining row) => info_gain

find best split

 

In [138]:
class Question:

    def __init__(self, question_type, value):
        self.question_type = question_type
       
        self.value = value
        

    def __repr__(self):
        
        if self.question_type =="Age" :
            return "Age >= {} ?".format(self.value)
        elif self.question_type == "Sex":
            return "Sex == {} ?".format(self.value)
        
    def match(self, passenger):
        #passenger will be a DataFrame row
        if self.question_type =="Age" :
            return passenger.Age >= self.value
        elif self.question_type == "Sex":
            return self.value == passenger.Sex

Some demos below to see that it works as intended.

In [139]:
q_age = Question("Age",1)

In [140]:
q_age

Age >= 1 ?

In [141]:
q_sex = Question("Sex", "male")

In [142]:
q_sex

Sex == male ?

In [143]:
passenger = X_df.ix[1]

In [144]:
passenger

Age      22
Sex    male
Name: 1, dtype: object

In [145]:
q_age.match(passenger)

True

In [146]:
q_age2 = Question("Age",23)
q_age2.match(passenger)

False

In [147]:
q_sex.match(passenger)

True

In [148]:
Y_df[Y_df==0].count()

Survived    549
dtype: int64

In [198]:
def gini_impurity_function(Y_df):
    unique_values = Y_df.Survived.unique()
    #will be 0 (died) or 1 (survived)  
    died_count = Y_df[Y_df==0].count()
    survived_count = Y_df[Y_df==1].count()
    total_count = Y_df.count()
    
    result = 1 - (survived_count/total_count)**2 - (died_count/total_count)**2
    return result[0]
    

In [199]:
gini_impurity_function(Y_df)

0.47301295786144276

In [156]:
gini_impurity_function(Y_df[Y_df==1])

Survived    0.0
dtype: float64

Seems like it works.

Now, for a given question, find the rows for which that question is true and also false

In [157]:
def partititioner(question, X_df):
    
    true_passengers = []
    false_passengers = []
    
    
    for passenger in X_df.itertuples():
        if question.match(passenger):
            true_passengers.append(passenger[0])
            #print(question ,passenger)
        else:
            #print("did not match", passenger)
            false_passengers.append(passenger[0])
            
            
    #turn arrays (that contain ID_s into DF's)        
    true_passengers = pd.DataFrame(data={"PassengerId": true_passengers})        
    false_passengers = pd.DataFrame(data={"PassengerId": false_passengers})
    
    true_passengers_df = X_df.merge(true_passengers, how='inner', left_index=True, right_on='PassengerId').set_index('PassengerId')
    false_passengers_df = X_df.merge(false_passengers, how='inner', left_index=True, right_on='PassengerId').set_index('PassengerId')        
    return true_passengers_df, false_passengers_df        

In [158]:
true_passengers_df, false_passengers_df = partititioner(q_age2, X_df)

In [170]:
true_passengers_df.head()

Unnamed: 0_level_0,Age,Sex
PassengerId,Unnamed: 1_level_1,Unnamed: 2_level_1
2,38.0,female
3,26.0,female
4,35.0,female
5,35.0,male
7,54.0,male


891

In [217]:
def find_best_split(X_df, Y_df ):
    
    ages = X_df.Age.unique()
    sexes = X_df.Sex.unique()
    
    gini_impurity = gini_impurity_function(Y_df)
    
    information_gain_age_max = 0
    information_gain_sex_max = 0
    
    for age in ages:
        age_question = Question("Age", age)
        
        true_passengers_df, false_passengers_df = partititioner(age_question, X_df)
        
        Y_df_of_true_passengers = Y_df.merge(true_passengers_df, how='inner', left_index=True, right_index=True)['Survived'].to_frame()
        #now need to do
        Y_df_of_false_passengers = Y_df.merge(false_passengers_df, how='inner', left_index=True, right_index=True)['Survived'].to_frame()
        
        true_passengers_fraction = len(true_passengers_df) / (len(true_passengers_df) + len(false_passengers_df))
        
        gini_impurity_post_split = true_passengers_fraction * gini_impurity_function(Y_df_of_true_passengers ) + (1 - true_passengers_fraction) * gini_impurity_function(Y_df_of_false_passengers )
        
        information_gain_age = gini_impurity - gini_impurity_post_split
        
        if  information_gain_age > information_gain_age_max:
            information_gain_age_max = information_gain_age
            age_question_max = age_question
            
        
        print(information_gain_age)
    print(information_gain_age_max, age_question_max)    
    return information_gain_age_max, age_question_max     
#         for sex in sexes:
#             sex_question = Question("Sex", sex)
#             print(age_question, sex_question)
    #plan here: 
    #1) ask all possible questions
    #2) find out which one gives the best information gain.

In [218]:
X_df.Sex.unique()

array(['male', 'female'], dtype=object)

In [219]:
find_best_split(X_df, Y_df)

0.000539886870539
2.04193593858e-05
0.000566037361193
0.00018713869416
nan
0.000484582550666
0.00129744824017
0.000717524969848
1.78180825775e-05
0.000784957029323
0.000502303532607
1.31652135527e-07
5.95237913587e-05
0.000546411189949
0.000665042797828
0.00016646504483
9.5235927477e-06
0.000190209847497
1.6982503597e-05
5.50301579311e-06
4.13146444009e-05
0.00121399427312
4.25953294111e-05
9.63368945905e-05
3.26527737305e-05
0.00148609503023
1.28507721302e-05
2.64511623493e-06
0.000619202510201
0.0021451892664
0.000500261322772
0.000252010035883
8.01184555743e-05
2.91982573337e-05
5.38593455152e-06
0.000447506783595
1.76137092048e-05
0.000217659378071
0.00282575394146
0.000580973129993
0.000182882673473
0.000442902171182
0.000606272119942
1.47719669039e-05
0.00116235449288
0.000161557727066
8.98472421674e-05
3.79164384556e-06
3.0281585997e-06
0.000381450918423
0.000204448242603
0.000116813757868
5.83030886064e-06
0.000119567726673
0.000248737666751
0.000660763770118
0.000119379704707


(0.004020704223127336, Age >= 0.42 ?)

In [125]:
X_df

Unnamed: 0_level_0,Age,Sex
PassengerId,Unnamed: 1_level_1,Unnamed: 2_level_1
1,22.0,male
2,38.0,female
3,26.0,female
4,35.0,female
5,35.0,male
6,,male
7,54.0,male
8,2.0,male
9,27.0,female
10,14.0,female


0      0
1      1
2      1
3      1
4      0
5      0
6      0
7      0
8      1
9      1
10     1
11     1
12     0
13     0
14     0
15     1
16     0
17     1
18     0
19     1
20     0
21     1
22     1
23     1
24     0
25     1
26     0
27     0
28     1
29     0
      ..
861    0
862    1
863    0
864    0
865    1
866    1
867    0
868    0
869    1
870    0
871    1
872    0
873    0
874    1
875    1
876    0
877    0
878    0
879    1
880    1
881    0
882    0
883    0
884    0
885    0
886    0
887    1
888    0
889    1
890    0
Name: Survived, dtype: int64