In [2]:
import torch
import requests
import random
from collections import Counter

# Download the iris data set and save it as 'iris.data' in the working directory.
iris_url = "https://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data"
# A timeout keeps the cell from hanging indefinitely on a stalled connection.
r = requests.get(iris_url, timeout=30)
# Fail loudly on an HTTP error instead of saving an error page as 'iris.data'.
r.raise_for_status()
with open('iris.data', 'wb') as f:
    f.write(r.content)

In [6]:
# Parse 'iris.data' into feature vectors (tuples of four floats) and labels.
vectors = []
answers = []

with open('iris.data', 'r') as f:
    for line in f:
        items = line.strip().split(",")
        # Only lines with exactly five comma-separated fields are records;
        # this also skips blank lines.
        if len(items) == 5:
            x = tuple(float(i) for i in items[:4])
            y = items[4]
            vectors.append(x)
            answers.append(y)

# Shuffle with a fixed seed so the train/test split, and therefore the
# accuracies reported further down, are the same on every run.
SPLIT_SEED = 42
zipped = list(zip(vectors, answers))
random.Random(SPLIT_SEED).shuffle(zipped)

# The first 80% of the shuffled pairs are for training, the rest for testing.
train_size = int(len(vectors) * 0.8)
train_x, train_y = zip(*zipped[:train_size])
test_x, test_y = zip(*zipped[train_size:])

In [54]:
def gini_score(items):
    """Return the Gini impurity of a collection of labels.

    Args:
        items: Iterable of hashable labels. It is consumed once, so a
            generator is accepted.

    Returns:
        1 minus the sum of the squared relative frequencies of each distinct
        label. An empty iterable yields 1 because nothing is subtracted.
    """
    label_counts = Counter(items)
    n_labels = sum(label_counts.values())

    impurity = 1
    # Only the counts matter here, so iterate over the values directly.
    for occurrences in label_counts.values():
        impurity -= (occurrences / n_labels) ** 2
    return impurity

def find_split_point_of_a_field(pairs):
    """Find the threshold on one field that minimises the weighted Gini score.

    Args:
        pairs: Iterable of (field_value, label) items.

    Returns:
        A (split_point, gini) tuple. The split point is the midpoint between
        two neighbouring, distinct field values in sorted order, and gini is
        the size-weighted sum of the Gini scores of the two sides. When no
        two neighbouring values differ, the sentinel (-1, 99) is returned.
    """
    ordered = list(pairs)
    ordered.sort(key=lambda pair: pair[0])
    n = len(ordered)
    labels = [pair[1] for pair in ordered]

    best_gini, best_threshold = 99, -1
    for cut in range(1, n):
        lower, upper = ordered[cut - 1][0], ordered[cut][0]
        # A threshold can only sit between two different values.
        if upper == lower:
            continue

        left_gini = gini_score(labels[:cut])
        right_gini = gini_score(labels[cut:])
        weighted = left_gini * cut / n + right_gini * (n - cut) / n
        midpoint = (upper + lower) / 2

        # Strict comparison: on ties the earliest candidate is kept.
        if weighted < best_gini:
            best_gini, best_threshold = weighted, midpoint

    return best_threshold, best_gini

def split_data(X, Y):
    """Choose the best single-field split of (X, Y) and partition the data.

    Every field index of the first row of X is tried with
    find_split_point_of_a_field; the field with the lowest weighted Gini
    score wins (the earliest field on ties).

    Args:
        X: Sequence of indexable feature rows; X[0] must exist.
        Y: Labels aligned with X.

    Returns:
        A dict with keys 'fid' (chosen field index), 'sp' (split point),
        'gini' (weighted Gini score of the split), and 'left' / 'right'.
        'left' holds the rows whose chosen field is < 'sp' and 'right' those
        that are >= 'sp', each as a (rows, labels) pair of tuples, or an
        empty tuple when no row falls on that side. If no field offers a
        split, 'fid' and 'sp' stay -1 and 'gini' stays 99.
    """
    best_gini, best_sp, best_fid = 99, -1, -1
    for fid in range(len(X[0])):
        column = [(row[fid], label) for row, label in zip(X, Y)]
        sp, gini = find_split_point_of_a_field(column)
        if gini < best_gini:
            best_gini, best_sp, best_fid = gini, sp, fid

    below, at_or_above = [], []
    for row, label in zip(X, Y):
        value = row[best_fid]
        if value < best_sp:
            below.append((row, label))
        elif value >= best_sp:
            # Tested explicitly rather than with a bare else, so a value for
            # which both comparisons are false ends up on neither side.
            at_or_above.append((row, label))

    return {
        'fid': best_fid,
        'sp': best_sp,
        'gini': best_gini,
        'left': tuple(zip(*below)),
        'right': tuple(zip(*at_or_above)),
    }

def decision_tree(X, Y, threshold):
    """Recursively build a decision tree from feature rows and labels.

    Args:
        X: Non-empty sequence of indexable feature rows (split_data reads
            X[0]).
        Y: Labels aligned with X.
        threshold: A split is kept only if it lowers the Gini score of Y by
            more than this amount.

    Returns:
        Either a leaf tuple (majority_label, fraction_of_Y_with_that_label),
        or the node dict from split_data with its 'left' and 'right' entries
        replaced by the subtrees built from each side.
    """
    original_gini = gini_score(Y)
    node = split_data(X, Y)

    gain = original_gini - node['gini']
    # split_data yields an empty tuple for a side that received no rows
    # (e.g. feature values for which both `<` and `>=` are false). Unpacking
    # such a side below would raise ValueError, so stop with a leaf instead.
    degenerate = not node['left'] or not node['right']

    if gain <= threshold or degenerate:
        counter = Counter(Y)
        total_count = sum(counter.values())
        ans, c = counter.most_common(1)[0]
        return (ans, c / total_count)

    XL, YL = node['left']
    node['left'] = decision_tree(XL, YL, threshold)
    XR, YR = node['right']
    node['right'] = decision_tree(XR, YR, threshold)
    return node

def predict(x, tree):
    """Walk a tree built by decision_tree and return the leaf reached by x.

    Args:
        x: Indexable feature row.
        tree: Either an internal node (a dict with 'fid', 'sp', 'left' and
            'right') or a leaf.

    Returns:
        The leaf object itself; for trees built by decision_tree this is a
        (label, fraction) tuple.
    """
    node = tree
    # Internal nodes are dicts carrying a 'fid' key; anything else is a leaf.
    # The type check is required because `'fid' in leaf_tuple` is an element
    # test, so a leaf whose label is the string 'fid' would otherwise be
    # mistaken for an internal node.
    while isinstance(node, dict) and 'fid' in node:
        branch = 'left' if x[node['fid']] < node['sp'] else 'right'
        node = node[branch]
    return node

def accuracy(X, Y, tree):
    """Return the fraction of rows in X whose predicted label equals Y.

    Args:
        X: Non-empty sequence of feature rows.
        Y: True labels aligned with X.
        tree: Tree accepted by predict; each leaf's first element is taken
            as the predicted label.
    """
    hits = sum(1 for x, y in zip(X, Y) if predict(x, tree)[0] == y)
    return hits / len(X)


# Fit on the training split; 0.1 is the minimum Gini reduction for a split.
tree = decision_tree(train_x, train_y, 0.1)

train_accuracy = accuracy(train_x, train_y, tree)
print("train_accuracy:", train_accuracy)

test_accuracy = accuracy(test_x, test_y, tree)
print("test_accuracy:", test_accuracy)

train_accuracy: 0.95
test_accuracy: 0.9666666666666667
