# Senior Project - Decision Tree Classification
**Firdaus Bisma Suryakusuma**  
**19/444051/TK/49247**

In [171]:
import pandas
import math

In [172]:
raw_data = pandas.read_csv("drug200.csv", delimiter=",")
print(raw_data)

     Age Sex      BP Cholesterol  Na_to_K   Drug
0     23   F    HIGH        HIGH   25.355  drugY
1     47   M     LOW        HIGH   13.093  drugC
2     47   M     LOW        HIGH   10.114  drugC
3     28   F  NORMAL        HIGH    7.798  drugX
4     61   F     LOW        HIGH   18.043  drugY
..   ...  ..     ...         ...      ...    ...
195   56   F     LOW        HIGH   11.567  drugC
196   16   M     LOW        HIGH   12.006  drugC
197   52   M  NORMAL        HIGH    9.894  drugX
198   23   M  NORMAL      NORMAL   14.020  drugX
199   40   F     LOW      NORMAL   11.349  drugX

[200 rows x 6 columns]


In [173]:
raw_data['Sex'] = raw_data['Sex'].replace(['F', 'M'], [0, 1])
raw_data['BP'] = raw_data['BP'].replace(['LOW', 'NORMAL', 'HIGH'], [0, 1, 2])
raw_data['Cholesterol'] = raw_data['Cholesterol'].replace(['NORMAL', 'HIGH'], [0, 1])
print(raw_data)

     Age  Sex  BP  Cholesterol  Na_to_K   Drug
0     23    0   2            1   25.355  drugY
1     47    1   0            1   13.093  drugC
2     47    1   0            1   10.114  drugC
3     28    0   1            1    7.798  drugX
4     61    0   0            1   18.043  drugY
..   ...  ...  ..          ...      ...    ...
195   56    0   0            1   11.567  drugC
196   16    1   0            1   12.006  drugC
197   52    1   1            1    9.894  drugX
198   23    1   1            0   14.020  drugX
199   40    0   0            0   11.349  drugX

[200 rows x 6 columns]


In [174]:
# split raw_data dataframe into training 70% and testing 30%
training_data = raw_data.sample(frac=0.7)
test_data = raw_data.drop(training_data.index)
print(training_data)
print(test_data)

     Age  Sex  BP  Cholesterol  Na_to_K   Drug
144   39    1   2            1    9.664  drugA
11    34    0   2            0   19.199  drugY
29    45    1   0            1   17.951  drugY
64    60    0   2            1   13.303  drugB
70    70    1   2            1   13.967  drugB
..   ...  ...  ..          ...      ...    ...
39    15    1   1            1    9.084  drugX
81    64    1   1            1    7.761  drugX
71    28    0   1            1   19.675  drugY
173   41    0   0            0   18.739  drugY
170   28    0   1            1   12.879  drugX

[140 rows x 6 columns]
     Age  Sex  BP  Cholesterol  Na_to_K   Drug
0     23    0   2            1   25.355  drugY
2     47    1   0            1   10.114  drugC
3     28    0   1            1    7.798  drugX
4     61    0   0            1   18.043  drugY
5     22    0   1            1    8.607  drugX
9     43    1   0            0   19.368  drugY
15    16    0   2            0   15.516  drugY
16    69    1   0            0   11.

In [175]:
def calculate_gini_index(data, label_column):
  labels = data[label_column].unique()
  probabilities = []
  # Calculate the probability of each class.
  for label in labels:
    probability = data[label_column].value_counts()[label] / len(data)
    probabilities.append(probability)

  # Calculate the gini index.
  gini = 1
  for probability in probabilities:
    gini -= probability**2

  return gini

print(calculate_gini_index(training_data, 'Drug'))

0.7029591836734694


In [176]:
def generate_split_combinations(data, label_column):
  non_label_columns = list(data.columns.values)
  non_label_columns.remove(label_column)

  split_combinations = {
    'column': [],
    'split_point': []
  }
  # Generate all possible split combinations.
  for column in non_label_columns:
    # sort and append the midpoint of each value.
    split_points = sorted(data[column].unique())
    split_points = [split_points[i] + (split_points[i + 1] - split_points[i]) / 2 for i in range(len(split_points) - 1)]
    split_combinations['column'].extend([column] * len(split_points))
    split_combinations['split_point'].extend(split_points)
  
  split_combinations = pandas.DataFrame(split_combinations)

  return split_combinations

def calculate_information_gain(data, label_column, split_column, split_point):
  # Calculate the gini index for the current split.
  gini_current = calculate_gini_index(data, label_column)
  # Calculate the gini index for the left split.
  data_left = data[data[split_column] <= split_point]
  gini_left = calculate_gini_index(data_left, label_column)
  # Calculate the gini index for the right split.
  data_right = data[data[split_column] > split_point]
  gini_right = calculate_gini_index(data_right, label_column)
  # Calculate the information gain.
  information_gain = gini_current - (len(data_left) / len(data)) * gini_left - (len(data_right) / len(data)) * gini_right

  return information_gain

def find_best_split(data, label_column):
  split_combinations = generate_split_combinations(data, label_column)

  # calculate the information gain from each split combination
  information_gain = []
  for _, row in split_combinations.iterrows():
    information_gain.append(calculate_information_gain(data, 'Drug', row['column'], row['split_point']))

  split_combinations = split_combinations.assign(information_gain=information_gain)
  split_combinations.sort_values('information_gain', ascending=False, inplace=True)

  return { 
    'column': split_combinations['column'].iloc[0], 
    'split_point': split_combinations['split_point'].iloc[0],
    'information_gain': split_combinations['information_gain'].iloc[0] 
  }

print(find_best_split(training_data, 'Drug'))

{'column': 'Na_to_K', 'split_point': 14.8285, 'information_gain': 0.32232008592910855}


In [177]:
class Node:
  def __init__(self, children, column, split_value, label):
    self.children = children
    self.column = column
    self.split_value = split_value
    self.label = label

  def predict(self, row):
    if self.label is None:
      if row[self.column] <= self.split_value:
        return self.children[0].predict(row)
      else:
        return self.children[1].predict(row)
    else:
      return self.label

In [178]:
def split_data(best_split, data):
  return {
    'left': data[data[best_split['column']] <= best_split['split_point']],
    'right': data[data[best_split['column']] > best_split['split_point']]
  }

def find_label(data, label_column):
  return data[label_column].value_counts().idxmax()

def construct_decision_tree(data, label_column, depth=0):
  gini = calculate_gini_index(data, label_column)

  if (gini == 0):
    return Node([], None, None, find_label(data, label_column))
  else:
    best_split = find_best_split(data, label_column)
    print('\t'*depth + str(best_split))
    splitted_data = split_data(best_split, data)
    left_child = construct_decision_tree(splitted_data['left'], label_column, depth+1)
    right_child = construct_decision_tree(splitted_data['right'], label_column, depth+1)
    return Node([left_child, right_child], best_split['column'], best_split['split_point'], None)

tree = construct_decision_tree(training_data, 'Drug')
print('Tree constructed!')

# test accuracy
test_data['Prediction'] = test_data.apply(lambda row: tree.predict(row), axis=1)
correct_data = test_data[test_data['Drug'] == test_data['Prediction']]
print('Accuracy:', len(correct_data) / len(test_data))
print(test_data)

{'column': 'Na_to_K', 'split_point': 14.8285, 'information_gain': 0.32232008592910855}
	{'column': 'BP', 'split_point': 1.5, 'information_gain': 0.26628026014693484}
		{'column': 'Cholesterol', 'split_point': 0.5, 'information_gain': 0.14642407057340895}
			{'column': 'BP', 'split_point': 0.5, 'information_gain': 0.4965277777777779}
		{'column': 'Age', 'split_point': 51.5, 'information_gain': 0.48}
Tree constructed!
Accuracy: 0.9833333333333333
     Age  Sex  BP  Cholesterol  Na_to_K   Drug Prediction
0     23    0   2            1   25.355  drugY      drugY
2     47    1   0            1   10.114  drugC      drugC
3     28    0   1            1    7.798  drugX      drugX
4     61    0   0            1   18.043  drugY      drugY
5     22    0   1            1    8.607  drugX      drugX
9     43    1   0            0   19.368  drugY      drugY
15    16    0   2            0   15.516  drugY      drugY
16    69    1   0            0   11.455  drugX      drugX
27    49    0   1            