In [1]:
import math
import random
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import itertools
import string
import seaborn as sns
import scipy as sp
from toydown import GeoUnit, ToyDown

In [2]:
def create_tree_from_leaves(leaf_dict):
    """ Given a dictionary, where the keys are the names of leaf nodes (labeled by their path)
        and the corresponding value is the associated attribute counts, this function returns
        the list of GeoUnits that defines the corresponding tree.
    """
    nodes = leaf_dict.copy()
    h = len(list(leaf_dict.keys())[0])
    n = len(list(leaf_dict.values())[0])
    
    for i in range(2, h+1):
        level_names = list(set(list(map(lambda s: s[:-(i-1)], leaf_dict.keys()))))
        level_counts = [np.zeros(n)]*len(level_names)
        for node in level_names:
            nodes[node] = np.array([v for k, v in leaf_dict.items() if k.startswith(node)]).sum(axis=0)
        
    return [GeoUnit(k, k[:-1], v) if k != "1" else GeoUnit(k, None, v) for k, v in nodes.items()]

In [9]:
leaves = {   '111': np.array([100, 40, 60]),
             '112': np.array([100, 50, 50]),
             '113': np.array([100, 70, 30]),
             '114': np.array([100, 50, 50]),
             '121': np.array([100, 40, 60]),
             '122': np.array([100, 30, 70]),
             '123': np.array([100, 40, 60]),
             '124': np.array([100, 60, 40]),
             '131': np.array([100, 70, 30]),
             '132': np.array([100, 50, 50]),
             '133': np.array([100, 40, 60]),
             '134': np.array([100, 50, 50]),
             '141': np.array([100, 40, 60]),
             '142': np.array([100, 60, 40]),
             '143': np.array([100, 40, 60]),
             '144': np.array([100, 30, 70]),}

In [6]:
leaf_names = ["1" + "".join(a) for a in itertools.product(string.hexdigits[1:4+1], repeat=3-1)]

In [10]:
def leaf_counts(leaf_names, leaves):
    num_leaves = len(leaf_names)
    counts = np.zeros((num_leaves,3))
    for i,l in enumerate(leaf_names):
        counts[i] = leaves[l]
    return counts

In [15]:
geounits = create_tree_from_leaves(leaves)
geounits.reverse()
eps = 0.2
eps_split = [1/3, 1/3, 1/3]

In [16]:
model = ToyDown(geounits, 3, eps, eps_split)
model.show()

1
├── 11
│   ├── 111
│   ├── 112
│   ├── 113
│   └── 114
├── 12
│   ├── 121
│   ├── 122
│   ├── 123
│   └── 124
├── 13
│   ├── 131
│   ├── 132
│   ├── 133
│   └── 134
└── 14
    ├── 141
    ├── 142
    ├── 143
    └── 144



In [17]:
cons_0_diff = lambda n: [{'type': 'eq', 'fun': lambda x, i=i:  x[i] - np.sum([x[j] for j in range(i+1,i+3)])} 
                         for i in np.arange(n*3, step=3)]

In [20]:
def toydown_noise(leaves, model, cons=cons_0_diff, n_leaves=None):
    n = n_leaves if n_leaves else len(leaves)
    noised_counts = np.zeros((n,3))
    model.noise_and_adjust(node_cons=cons)
    for i,l in enumerate(leaves):
        noised_counts[i] = model.get_node(l).data.adjusted
    return noised_counts

In [35]:
grid = np.array([["111","112","121","122"],["114","113","124","123"],["141","142","131","132"],["144","143","134","133"]])

In [41]:
f = np.vectorize(lambda x: leaves[x])
f(grid)

ValueError: setting an array element with a sequence.

In [26]:
true_count = leaf_counts(leaf_names, leaves)

In [27]:
noise_count = toydown_noise(leaf_names, model)

In [28]:
true_count.sum(axis=0)

array([1600.,  760.,  840.])

In [29]:
noise_count.sum(axis=0)

array([1599.99926942,  759.99850219,  840.00076723])

In [30]:
noise_count

array([[ 92.38999482,  42.21323626,  50.17675856],
       [109.03651913,  49.35990062,  59.67661851],
       [106.18318349,  72.21323626,  33.96994723],
       [ 92.38999482,  52.21323626,  40.17675856],
       [ 67.94232443,  54.14498872,  13.79733571],
       [129.24008371,  19.227581  , 110.01250271],
       [100.16042633,  26.81672028,  73.34370605],
       [110.6572657 ,  67.81039902,  42.84686668],
       [ 92.95354599,  79.33665113,  13.61689487],
       [102.60592628,  52.43340793,  50.17251835],
       [113.45687252,  28.8333246 ,  84.62354792],
       [ 86.98313103,  47.3958212 ,  39.58730983],
       [ 95.9610242 ,  42.24568743,  53.71533677],
       [106.40179383,  64.71419207,  41.68760176],
       [ 93.98217217,  47.90693171,  46.07524046],
       [ 99.65501096,  13.13318771,  86.52182325]])