## Decision Trees

In [2]:
import csv
import numpy as np
import matplotlib.pyplot as plt

In [3]:
# newline="" hands newline handling to the csv module, as its documentation
# recommends; without it, quoted fields containing line breaks can be misread.
with open("data/gift.csv", "r", newline="") as file:
    # Materialise every parsed row (header included) as a list of string lists.
    data = list(csv.reader(file))

In [4]:
data

[['Type', 'Color', 'Pattern', 'Liked'],
 ['Blouse', 'Green', 'Polka', 'Yes'],
 ['Blouse', 'Red', 'Polka', 'Yes'],
 ['Dress', 'Red', 'Solid', 'Yes'],
 ['Dress', 'Green', 'Checkers', 'Yes'],
 ['Blouse', 'Yellow', 'Checkers', 'No'],
 ['Dress', 'Red', 'Solid', 'No'],
 ['Dress', 'Yellow', 'Solid', 'Yes'],
 ['Blouse', 'Green', 'Checkers', 'Yes'],
 ['Blouse', 'Red', 'Polka', 'No'],
 ['Blouse', 'Pink', 'Solid', 'Yes']]

In [5]:
# The first parsed row is the header; every row after it is an observation.
# NOTE: this cell rebinds `data` from itself, so running it a second time
# without re-running the load cell discards one more row.
col_names = data[0]
data = data[1:]

In [6]:
data

[['Blouse', 'Green', 'Polka', 'Yes'],
 ['Blouse', 'Red', 'Polka', 'Yes'],
 ['Dress', 'Red', 'Solid', 'Yes'],
 ['Dress', 'Green', 'Checkers', 'Yes'],
 ['Blouse', 'Yellow', 'Checkers', 'No'],
 ['Dress', 'Red', 'Solid', 'No'],
 ['Dress', 'Yellow', 'Solid', 'Yes'],
 ['Blouse', 'Green', 'Checkers', 'Yes'],
 ['Blouse', 'Red', 'Polka', 'No'],
 ['Blouse', 'Pink', 'Solid', 'Yes']]

In [11]:
def unique(data, col: int):
    """Return the set of distinct values found at position `col` across the rows of `data`."""
    return {record[col] for record in data}

In [12]:
unique(data,0)

{'Blouse', 'Dress'}

In [16]:
def question(col, value):
    """Build a predicate that tells whether a row's entry at `col` equals `value`."""
    def matches(row):
        return row[col] == value

    return matches

In [18]:
question(0,"Dress")(['Blouse', 'Red', 'Polka', 'Yes'])

False

In [19]:
def split(data, question):
    """Partition `data` by a predicate.

    Returns a pair of lists: the rows for which `question(row)` is truthy,
    then the rows for which it is falsy. Row order is preserved in both.
    """
    matched, unmatched = [], []
    for record in data:
        bucket = matched if question(record) else unmatched
        bucket.append(record)
    return matched, unmatched

In [21]:
q = question(0,"Dress")
true_rows, false_rows = split(data,q)

In [22]:
true_rows

[['Dress', 'Red', 'Solid', 'Yes'],
 ['Dress', 'Green', 'Checkers', 'Yes'],
 ['Dress', 'Red', 'Solid', 'No'],
 ['Dress', 'Yellow', 'Solid', 'Yes']]

In [23]:
false_rows

[['Blouse', 'Green', 'Polka', 'Yes'],
 ['Blouse', 'Red', 'Polka', 'Yes'],
 ['Blouse', 'Yellow', 'Checkers', 'No'],
 ['Blouse', 'Green', 'Checkers', 'Yes'],
 ['Blouse', 'Red', 'Polka', 'No'],
 ['Blouse', 'Pink', 'Solid', 'Yes']]

In [24]:
def value_counts(data, col):
    """Count how many rows of `data` hold each distinct value at position `col`.

    Returns a dict mapping value -> number of occurrences, keyed in order of
    first appearance.
    """
    tally = {}
    for record in data:
        key = record[col]
        tally[key] = tally.get(key, 0) + 1
    return tally

In [28]:
value_counts(data,0)

{'Blouse': 6, 'Dress': 4}

In [40]:
def gini(data,col=-1):
    """Gini impurity of the values at position `col` (the last one by default).

    Computed as 1 minus the sum of the squared relative frequencies of each
    distinct value, using `value_counts` for the frequencies.
    """
    occurrences = value_counts(data, col).values()
    total = len(data)
    return 1 - sum((count / total) ** 2 for count in occurrences)

In [41]:
gini(data,-1)

0.42000000000000004

In [42]:
gini(true_rows,-1)

0.375

In [43]:
gini(false_rows,-1)

0.4444444444444444

In [44]:
(gini(true_rows,-1)+gini(false_rows,-1))/2

0.4097222222222222

In [73]:
def info_gain(data, true_rows, false_rows):
    """Impurity reduction obtained by dividing `data` into the two given parts.

    The Gini impurity of each part is weighted by its share of `data`
    (the false share is taken as 1 minus the true share) and both are
    subtracted from the impurity of `data` itself.
    """
    share_true = len(true_rows) / len(data)
    share_false = 1 - share_true
    before = gini(data)
    after_true = share_true * gini(true_rows)
    after_false = share_false * gini(false_rows)
    return before - after_true - after_false

In [74]:
info_gain(data,true_rows,false_rows)

0.05333333333333343

In [75]:
true_rows, false_rows = split(data,question(1,"Red"))

In [76]:
info_gain(data,true_rows,false_rows)

0.05333333333333343

In [77]:
true_rows

[['Blouse', 'Red', 'Polka', 'Yes'],
 ['Dress', 'Red', 'Solid', 'Yes'],
 ['Dress', 'Red', 'Solid', 'No'],
 ['Blouse', 'Red', 'Polka', 'No']]

In [78]:
false_rows

[['Blouse', 'Green', 'Polka', 'Yes'],
 ['Dress', 'Green', 'Checkers', 'Yes'],
 ['Blouse', 'Yellow', 'Checkers', 'No'],
 ['Dress', 'Yellow', 'Solid', 'Yes'],
 ['Blouse', 'Green', 'Checkers', 'Yes'],
 ['Blouse', 'Pink', 'Solid', 'Yes']]

In [99]:
def find_best_question(data):
    """Search every (column, value) equality test for the largest information gain.

    Every column except the last is tried as a feature; the last column is
    left to `info_gain`/`gini` as the label.

    Returns:
        (best_gain, best_question): the highest gain found and the predicate
        that produced it. `(0, None)` is returned when `data` is empty or when
        no candidate yields a strictly positive gain.
    """
    best_gain = 0
    best_question = None
    if len(data) == 0:
        # Nothing to split; also avoids indexing data[0] on an empty list.
        return best_gain, best_question

    n_features = len(data[0]) - 1
    for col in range(n_features):
        # Candidate values in order of first appearance. Iterating a set of
        # strings instead would let hash randomisation decide which of several
        # equal-gain questions wins, so results could differ between runs.
        candidates = dict.fromkeys(row[col] for row in data)
        for value in candidates:
            q = question(col, value)
            true, false = split(data, q)
            if not true or not false:
                # The question does not actually divide the rows.
                continue
            gain = info_gain(data, true, false)
            # Strict comparison: on ties the earliest candidate is kept.
            if gain > best_gain:
                best_gain = gain
                best_question = q
    return best_gain, best_question

In [100]:
g, q = find_best_question(data)

In [101]:
g

0.07714285714285712

In [102]:
split(data,q)

([['Blouse', 'Green', 'Polka', 'Yes'],
  ['Dress', 'Green', 'Checkers', 'Yes'],
  ['Blouse', 'Green', 'Checkers', 'Yes']],
 [['Blouse', 'Red', 'Polka', 'Yes'],
  ['Dress', 'Red', 'Solid', 'Yes'],
  ['Blouse', 'Yellow', 'Checkers', 'No'],
  ['Dress', 'Red', 'Solid', 'No'],
  ['Dress', 'Yellow', 'Solid', 'Yes'],
  ['Blouse', 'Red', 'Polka', 'No'],
  ['Blouse', 'Pink', 'Solid', 'Yes']])

In [103]:
true, false = split(data,q)

In [104]:
g, q = find_best_question(false)

In [105]:
g

0.10884353741496605

In [106]:
split(data,q)

([['Dress', 'Green', 'Checkers', 'Yes'],
  ['Blouse', 'Yellow', 'Checkers', 'No'],
  ['Blouse', 'Green', 'Checkers', 'Yes']],
 [['Blouse', 'Green', 'Polka', 'Yes'],
  ['Blouse', 'Red', 'Polka', 'Yes'],
  ['Dress', 'Red', 'Solid', 'Yes'],
  ['Dress', 'Red', 'Solid', 'No'],
  ['Dress', 'Yellow', 'Solid', 'Yes'],
  ['Blouse', 'Red', 'Polka', 'No'],
  ['Blouse', 'Pink', 'Solid', 'Yes']])

In [107]:
true, false = split(false,q)

In [108]:
g, q = find_best_question(false)

In [109]:
g

0.1111111111111111

In [110]:
split(data,q)

([['Blouse', 'Red', 'Polka', 'Yes'],
  ['Dress', 'Red', 'Solid', 'Yes'],
  ['Dress', 'Red', 'Solid', 'No'],
  ['Blouse', 'Red', 'Polka', 'No']],
 [['Blouse', 'Green', 'Polka', 'Yes'],
  ['Dress', 'Green', 'Checkers', 'Yes'],
  ['Blouse', 'Yellow', 'Checkers', 'No'],
  ['Dress', 'Yellow', 'Solid', 'Yes'],
  ['Blouse', 'Green', 'Checkers', 'Yes'],
  ['Blouse', 'Pink', 'Solid', 'Yes']])

In [111]:
true, false = split(false,q)

In [112]:
g, q = find_best_question(false)

In [113]:
g

0

In [114]:
print(q)

None


In [115]:
5/6

0.8333333333333334

In [112]:
def unique(data,col_number):
    """Collect the distinct values that appear at position `col_number` in the rows of `data`."""
    seen = set()
    for record in data:
        seen.add(record[col_number])
    return seen

In [113]:
unique(data,1)

{'Green', 'Pink', 'Red', 'Yellow'}

In [114]:
def value_counts(data, target=-1):
    """Tally the values found at position `target` (the last one by default).

    Returns a dict mapping each distinct value to the number of rows that
    contain it, keyed in order of first appearance.
    """
    tally = {}
    for record in data:
        label = record[target]
        if label in tally:
            tally[label] += 1
        else:
            tally[label] = 1
    return tally

In [115]:
value_counts(data)

{'Yes': 7, 'No': 3}

In [116]:
def question(col, value):
    """Return a predicate that checks whether item `col` of a row equals `value`."""
    # The returned lambda closes over `col` and `value`; call it with one row.
    return lambda x: x[col]==value

In [117]:
q = question(1,"Green")
q

<function __main__.question.<locals>.<lambda>(x)>

In [118]:
q(['Dress', 'Green', 'Checkers', 'Yes'])

True

In [119]:
q(['Dress', 'Yellow', 'Checkers', 'Yes'])

False

In [120]:
def split(data,question):
    """Divide `data` into the rows that satisfy `question` and the rows that do not.

    Returns `(passing_rows, failing_rows)`, each in the original row order.
    """
    buckets = {True: [], False: []}
    for record in data:
        buckets[bool(question(record))].append(record)
    return buckets[True], buckets[False]

In [121]:
true, false = split(data,question(0,"Dress"))

In [122]:
true

[['Dress', 'Red', 'Solid', 'Yes'],
 ['Dress', 'Green', 'Checkers', 'Yes'],
 ['Dress', 'Red', 'Solid', 'No'],
 ['Dress', 'Yellow', 'Solid', 'Yes']]

In [123]:
false

[['Blouse', 'Green', 'Polka', 'Yes'],
 ['Blouse', 'Red', 'Polka', 'Yes'],
 ['Blouse', 'Yellow', 'Checkers', 'No'],
 ['Blouse', 'Green', 'Checkers', 'Yes'],
 ['Blouse', 'Red', 'Polka', 'No'],
 ['Blouse', 'Pink', 'Solid', 'Yes']]

### What is the impurity of each division?

In [124]:
def gini(data):
    """Gini impurity of `data`, based on the default column counted by `value_counts`.

    Equals 1 minus the sum of the squared relative frequencies of the counted values.
    """
    label_counts = value_counts(data)
    size = len(data)
    squared_shares = ((count / size) ** 2 for count in label_counts.values())
    return 1 - sum(squared_shares)

In [125]:
gini(data)

0.42000000000000004

In [126]:
gini(true)

0.375

In [127]:
gini(false)

0.4444444444444444

### Is this division an improvement?

Information gain is the uncertainty (here, the Gini impurity) before the division minus the weighted uncertainty of the two resulting leaves, where each leaf is weighted by its share of the rows.

In [128]:
def info_gain(data, true, false):
    """Drop in Gini impurity achieved by dividing `data` into `true` and `false`.

    Each part's impurity is weighted by its share of `data` (the false share
    is 1 minus the true share) before being subtracted from the impurity of
    `data`.
    """
    true_share = len(true) / len(data)
    parent_impurity = gini(data)
    true_term = true_share * gini(true)
    false_term = (1 - true_share) * gini(false)
    return parent_impurity - true_term - false_term

In [129]:
true, false = split(data,question(1,"Green"))
info_gain(data, true, false)

0.07714285714285712

In [130]:
value_counts(true)

{'Yes': 3}

In [131]:
value_counts(false)

{'Yes': 4, 'No': 3}

In [132]:
true, false = split(data,question(1,"Yellow"))
info_gain(data, true, false)

0.020000000000000018

In [133]:
true, false = split(data,question(1,"Red"))
info_gain(data, true, false)

0.05333333333333343

In [134]:
true, false = split(data,question(1,"Pink"))
info_gain(data, true, false)

0.020000000000000073

### Finding the best split

In [135]:
def best_split(data):
    """Find the equality question with the highest information gain on `data`.

    Every column except the last is tried as a feature; each distinct value
    in a column gives one candidate question.

    Returns:
        (best_gain, best_question): the largest gain found and the predicate
        that achieved it. `(0, False)` is returned when `data` is empty or
        when no question gives a strictly positive gain.
    """
    best_gain = 0
    best_question = False
    if len(data) == 0:
        # Nothing to divide; also avoids indexing data[0] on an empty list.
        return best_gain, best_question

    n = len(data[0]) - 1
    for i in range(n):
        # Distinct values in order of first appearance. Iterating a set of
        # strings instead would let hash randomisation decide which of several
        # equal-gain questions wins, so results could differ between runs.
        values = dict.fromkeys(row[i] for row in data)
        for val in values:
            q = question(i, val)
            # Trying out the question
            true, false = split(data, q)
            # Checking for no division
            if not true or not false:
                continue
            # Checking for improvement (strict: ties keep the earliest question)
            gain = info_gain(data, true, false)
            if gain > best_gain:
                best_gain = gain
                best_question = q

    return best_gain, best_question

In [136]:
gain, q = best_split(data)

In [137]:
gain

0.07714285714285712

In [138]:
q

<function __main__.question.<locals>.<lambda>(x)>

In [139]:
split(data,q)

([['Blouse', 'Green', 'Polka', 'Yes'],
  ['Dress', 'Green', 'Checkers', 'Yes'],
  ['Blouse', 'Green', 'Checkers', 'Yes']],
 [['Blouse', 'Red', 'Polka', 'Yes'],
  ['Dress', 'Red', 'Solid', 'Yes'],
  ['Blouse', 'Yellow', 'Checkers', 'No'],
  ['Dress', 'Red', 'Solid', 'No'],
  ['Dress', 'Yellow', 'Solid', 'Yes'],
  ['Blouse', 'Red', 'Polka', 'No'],
  ['Blouse', 'Pink', 'Solid', 'Yes']])