# Decision tree

In this program, we take a matrix of items and responses to questions about them, and construct a decision tree based on that matrix, that allows to identify the item by just a few responses.
The algorithm uses the heuristic that the most optimal way is to ask questions that leave roughly half items available.

In [1]:
import numpy as np
import matplotlib.pyplot as plt
import random

In [2]:
questions = [
    'red?',
    'sweet?',
    'yellow?',
    'peel?',
]

In [3]:
# 0 -- no knowledge (unacceptable)
# 1 -- we know no
# 2 -- we know yes
# 3 -- both yes or/and no
answers = {
    'banana': (1,  2,  2,  2),
    'apple' : (2,  2,  1,  1),
    'berry' : (2,  1,  1,  1),
    'orange': (1,  3,  1,  2),
    'grape' : (1,  1,  1,  2),
    'lemon' : (1,  1,  2,  2),
}

In [4]:
is_yes = lambda x: (x & 2) != 0
is_no  = lambda x: (x & 1) != 0

In [5]:
print(questions)
print(answers)

['red?', 'sweet?', 'yellow?', 'peel?']
{'banana': (1, 2, 2, 2), 'apple': (2, 2, 1, 1), 'berry': (2, 1, 1, 1), 'orange': (1, 3, 1, 2), 'grape': (1, 1, 1, 2), 'lemon': (1, 1, 2, 2)}


In [6]:
yeses = [ set(k for k,v in answers.items() if is_yes(v[i])) for i in range(len(questions)) ]
noes  = [ set(k for k,v in answers.items() if is_no (v[i])) for i in range(len(questions)) ]
yes_lens = [len(s) for s in yeses]
no_lens  = [len(s) for s in noes ]
min_lens = [abs(a-b) for a, b in zip(yes_lens, no_lens)]
# return (list(i for i in range(len(min_lens)) if min_lens[i] == min(min_lens)))[0]

print(yes_lens)
print(no_lens)
print(min_lens)
print()

[2, 3, 2, 4]
[4, 4, 4, 2]
[2, 1, 2, 2]



In [7]:
def eliminate_questions(set_elim, questions, answers):
    nq = len(questions)
    questions = [questions[i] for i in range(nq) if i not in set_elim]
    answers = {k: tuple(v[i] for i in range(nq) if i not in set_elim) \
                   for k,v in answers.items()}
    return questions, answers

In [8]:
# return sets where the ith question is answered yes or no
def split_fruits(i, answers):
    return set(k for k,v in answers.items() if is_yes(v[i])), \
           set(k for k,v in answers.items() if is_no (v[i]))

In [9]:
def compare_answer_lists(A, B):
    return all(a == b or a == -1 or b == -1 for a, b in zip(A, B))
(compare_answer_lists((0,0,1,0), (0,0,1,0)),
compare_answer_lists((0,0,1,0), (0,0,1,1)),
compare_answer_lists((0,0,-1,0), (0,0,1,0)))

(True, False, True)

In [10]:
def whatever(questions, answers):
    
#     print('calling whatever with {} questions and {} answers'.format(len(questions), len(answers)))
#     print(questions)
#     print(answers)
        
    answers_list = list(answers.values())
    if len(answers) == 1 or all(compare_answer_lists(answers_list[i], answers_list[0]) for i in range(len(answers))): 
#         print('i have found {}'.format(set(answers.keys())))
        return set(answers.keys())

    to_remove = set()
    
    for i in range(len(questions)):
        branch_yes, branch_no = split_fruits(i, answers)

        if len(branch_yes) == 0 or len(branch_no) == 0: 
            to_remove.add(i)

    if len(to_remove) > 0:
        questions, answers = eliminate_questions(to_remove, questions, answers)
#         print('cleaned up')
#         print(questions)
#         print(answers)
    
    # we decide which question to branch from.
    # in a better than random heuristic, we should have equal amount of hard yes and noes
    ibranch = random.randrange(len(questions))
    branch_yes, branch_no = split_fruits(ibranch, answers)
    answers_yes = {k: answers[k] for k in branch_yes}
    answers_no  = {k: answers[k] for k in branch_no }
    
    # we should assume yes/no wherever a 'do not know' occurs
    # or we can just remove this question and do not wait for it to be removed by
    # a recursive call. both are equivalent
    questions_yes, answers_yes = eliminate_questions([ibranch], questions, answers_yes)
    questions_no , answers_no  = eliminate_questions([ibranch], questions, answers_no )
    
    returnv = (questions[ibranch], whatever(questions_yes, answers_yes), 
            whatever(questions_no, answers_no))
#     print('trying to return {}'.format(returnv))
    return returnv
    
    
result = whatever(questions, answers)

In [11]:
def print_result(result, ind=''):
    question, result_yes, result_no = result
    
    print(ind + question)
    
    print(ind + '|  yes:')
    
    if type(result_yes) == tuple:
        print_result(result_yes, ind=ind+'|    ')
    else:
        print(ind + '|---\u2192 {}'.format(result_yes))
        
    print(ind + '|  no:')
    
    if type(result_no) == tuple:
        print_result(result_no, ind=ind+'|    ')
    else:
        print(ind + '|---\u2192 {}'.format(result_no))
        
print_result(result)

yellow?
|  yes:
|    sweet?
|    |  yes:
|    |---→ {'banana'}
|    |  no:
|    |---→ {'lemon'}
|  no:
|    red?
|    |  yes:
|    |    sweet?
|    |    |  yes:
|    |    |---→ {'apple'}
|    |    |  no:
|    |    |---→ {'berry'}
|    |  no:
|    |    sweet?
|    |    |  yes:
|    |    |---→ {'orange'}
|    |    |  no:
|    |    |---→ {'orange', 'grape'}


In [12]:
print(questions)
print(answers)

['red?', 'sweet?', 'yellow?', 'peel?']
{'banana': (1, 2, 2, 2), 'apple': (2, 2, 1, 1), 'berry': (2, 1, 1, 1), 'orange': (1, 3, 1, 2), 'grape': (1, 1, 1, 2), 'lemon': (1, 1, 2, 2)}
