In [26]:
import json
from collections import defaultdict

with open('../data/conifer_info.json', 'r') as fh:
    CONIFER_INFO = json.load(fh)
    
array_features = ['bundle_size']
size_features = ['needle_length', 'cone_length', 'diameter', 'height']
attribute_features = ['cone_scale_descriptors']

indistinguishable = defaultdict(set)
for reference_species, reference_info in CONIFER_INFO.items():
    for species, info in CONIFER_INFO.items():
        if species == reference_species: continue
        
        distinguishing_feature = False
        for feature in array_features:
            r_start, r_end = reference_info['features'][feature]
            start, end = info['features'][feature]
            distinguishing_feature |= r_start > end or start > r_end
        for feature in size_features:
            r_start, r_end = reference_info['features'][feature]
            start, end = info['features'][feature]
            distinguishing_feature |= r_start >= end or start >= r_end
        for feature in attribute_features:
            distinguishing_feature |= (
                set(info['features'][feature]) != set(reference_info['features'][feature])
            )
        if not distinguishing_feature:
            indistinguishable[reference_species].add(species)
            
indistinguishable

defaultdict(set, {})

In [75]:
def get_expectation(options, total=None):
    total = sum(options) if not total else total
    return sum(option * (total - option) / total for option in options)

def find_split(ranges):
    ends = defaultdict(int)
    for r in ranges:
        ends[r[1]] += 1
    starts = defaultdict(int)
    for r in ranges:
        starts[r[0]] += 1
        
    events = []
    for key in starts:
        if key in ends:
            events.append((key, 'both'))
        else:
            events.append((key, 'start'))
    for key in ends:
        if key not in starts:
            events.append((key, 'end'))
            
    fuzz = min(
        [
            e1 - e2
            for e1, e2 in zip(
                sorted(e[0] for e in list(events) + [(0, 'n/a')])[1:],
                sorted(e[0] for e in list(events) + [(0, 'n/a')])[:-1]
            ) if e1 != e2
        ]
    ) / 2.
            
        
    total = len(ranges)
    total_pos = 0
    total_neg = len(ranges)
    split_point = fuzz
    expectation = get_expectation([total_pos, total_neg])
    for event, kind in sorted(events):
        if kind in ('end', 'both'):
            total_pos += ends[event]
        if kind == ('start', 'both'):
            total_neg -= starts[event]
        new_expectation = get_expectation([total_pos, total_neg], total)
        if new_expectation < expectation:
            break
        expectation = new_expectation
        split_point = event + fuzz
    return split_point, expectation

def build_questions(species, categorical_questions=[], range_questions=[]):
    ranged_question_scores, categorical_question_scores = {}, {}
    for feature, category_id in categorical_questions:
        options = [0, 0]
        for info in species.values():
            index = 0 if category_id in info['features'][feature] else 1
            options[index] += 1
        categorical_question_scores[(feature, category_id)] = get_expectation(options)
    for feature in range_questions:
        ranges = [info['features'][feature] for info in species.values()]
        split, expectation = find_split(ranges)
        ranged_question_scores[(feature, split)] = expectation
    return categorical_question_scores, ranged_question_scores

build_questions(
    CONIFER_INFO, 
    categorical_questions=[
        ('cone_scale_descriptors', 'raised'),
        ('cone_scale_descriptors', 'keeled'),
        ('cone_scale_descriptors', 'bristled'),
        ('cone_scale_descriptors', 'spine'),
        ('cone_scale_descriptors', 'four_sided')
    ],
    range_questions=[
        'bundle_size',
        'height',
        'needle_length',
        'cone_length',
        'diameter'
    ]
)

({('cone_scale_descriptors', 'raised'): 16.61111111111111,
  ('cone_scale_descriptors', 'keeled'): 17.944444444444443,
  ('cone_scale_descriptors', 'bristled'): 16.0,
  ('cone_scale_descriptors', 'spine'): 13.5,
  ('cone_scale_descriptors', 'four_sided'): 11.277777777777779},
 {('bundle_size', 2.5): 8.0,
  ('height', 21.5): 8.972222222222221,
  ('needle_length', 10.5): 9.0,
  ('cone_length', 9.5): 8.972222222222221,
  ('diameter', 0.6499999999999999): 8.88888888888889})

In [80]:
c_questions=[
    ('cone_scale_descriptors', 'raised'),
    ('cone_scale_descriptors', 'keeled'),
    ('cone_scale_descriptors', 'bristled'),
    ('cone_scale_descriptors', 'spine'),
    ('cone_scale_descriptors', 'four_sided')
]
r_questions=[
    'bundle_size',
    'height',
    'needle_length',
    'cone_length',
    'diameter'
]
species = CONIFER_INFO
while r_questions or c_questions:
    print('species left', len(species))
    c_scores, r_scores = build_questions(species, c_questions, r_questions)
    best_c_score = 0
    best_r_score = 0
    if c_scores:
        best_c_score, best_c = sorted([(score, q) for q, score in c_scores.items()], reverse=True)[0]
    if r_scores:
        best_r_score, best_r = sorted([(score, q) for q, score in r_scores.items()], reverse=True)[0]
    if not best_c_score and not best_r_score:
        break
    if best_r_score > best_c_score:
        print(best_r, best_r_score)
        answer = input()
        if answer == 'y':
            species = {
                name: info
                for name, info in species.items()
                if info['features'][best_r[0]][1] > best_r[1]
            }
        elif answer == 'n':
            species = {
                name: info
                for name, info in species.items()
                if info['features'][best_r[0]][0] < best_r[1]
            }
        r_questions = [q for q in r_questions if q != best_r[0]]
    else:
        print(best_c, best_c_score)
        answer = input()
        if answer == 'y':
            species = {
                name: info
                for name, info in species.items()
                if best_c[1] in info['features'][best_c[0]]
            }
        elif answer == 'n':
            species = {
                name: info
                for name, info in species.items()
                if best_c[1] not in info['features'][best_c[0]]
            }
        c_questions = [q for q in c_questions if q != best_c]
species     

species left 36
('cone_scale_descriptors', 'keeled') 17.944444444444443


 n


species left 17
('cone_scale_descriptors', 'bristled') 7.764705882352941


 y


species left 6
('cone_scale_descriptors', 'four_sided') 3.0


 y


species left 3
('needle_length', 4.5) 0.6666666666666666


 n


species left 3
('height', 9.5) 0.6666666666666666


 y


species left 1


{'Pinus aristata': {'common_names': ['bristlecone pine'],
  'features': {'bundle_size': [5, 5],
   'needle_length': [2, 4],
   'cone_length': [6, 9],
   'diameter': [0.3, 0.8],
   'height': [6, 12],
   'cone_scale_descriptors': ['four_sided', 'bristled']}}}