In [None]:
# default_exp datatypes.counttree

# Count Tree

> A tree data structure to count circuit and subset selections during protocol sampling.

This data structure is built on the Python package [`anytree`](https://github.com/c0fec0de/anytree), extended here with the node classes `Constant` and `Variable` and the container class `CountTree`.

In [None]:
# hide
from nbdev.showdoc import *

In [None]:
#export
import collections
import qsam.math as math
import numpy as np
import json
import pydot
from IPython.display import SVG, display, HTML
import base64
from anytree import NodeMixin, RenderTree, PreOrderIter

In [None]:
#export            
class Node(NodeMixin):
    """Base tree node: an anytree node with a `name` and an integer `id`.

    `id` is the Python object identity (``id(self)``); `CountTree.draw` uses
    it to build unique Graphviz node names.
    """

    def __init__(self, name, parent=None, children=None):
        self.name = name
        # identity does not depend on where the node sits in the tree
        self.id = id(self)
        # assigning `parent` attaches this node to the parent's `children`
        self.parent = parent
        if children:
            self.children = children

In [None]:
#export
class Variable(Node):
    """Tree node holding a tally that is compared against its parent's tally.

    Args:
        name: label of the node (unique among its siblings in `CountTree.add`).
        parent: parent node, or None.
        counts: how many times this node was recorded.
        marked: flags the node as a target outcome; `CountTree.rate` sums the
            path weights of marked leaves.
        **kwargs: accepted and ignored (e.g. the ``id`` field that
            `CountTree.load` passes back in).
    """

    def __init__(self, name, parent=None, counts=0, marked=False, **kwargs):
        super().__init__(name, parent)
        self.counts, self.marked = counts, marked

    @property
    def rate(self):
        """Share of the parent's counts held by this node (0 when the parent has none)."""
        total = self.parent.counts
        return 0 if total == 0 else self.counts / total

    def __str__(self):
        return "{} ({}/{})".format(self.name, self.counts, self.parent.counts)

In [None]:
#export
class Constant(Node):
    """Tree node whose multiplier is a fixed weight from `CountTree.constants`.

    Args:
        name: label of the node; `CountTree.is_invariant_path` sums the
            entries of Constant names, so they are expected to be sequences
            of numbers such as ``(1,)``.
        parent: parent node, or None (a parentless Constant is the root).
        counts: how many times this node was recorded.
        c_hash: first key into `CountTree.constants`.
        w_idx: second key (index) into ``CountTree.constants[c_hash]``.
        is_deterministic: when True, `CountTree.variance` and
            `CountTree.norm_variance` skip this node.
        **kwargs: accepted and ignored (e.g. the ``id`` field that
            `CountTree.load` passes back in). NOTE: a misspelled keyword is
            therefore silently dropped.
    """

    def __init__(self, name, parent=None, counts=0, c_hash=None, w_idx=None, is_deterministic=False, **kwargs):
        super().__init__(name, parent)
        # attribute order is kept stable: it is the key order of `CountTree.to_dict`
        self.counts = counts
        self.is_deterministic = is_deterministic
        self.c_hash = c_hash
        self.w_idx = w_idx

    @property
    def variance(self):
        """Variance estimate for this node's children (Variable) nodes.

        Delegates to ``qsam.math.Wilson_var(rate, counts)`` using the rate of
        the first child only. Returns 0 for the root, for a node with no
        counts, and for a node without children (previously an IndexError).
        """
        if self.is_root or self.counts == 0 or not self.children:
            return 0
        return math.Wilson_var(self.children[0].rate, self.counts)

    def __str__(self):
        return f"{self.name} ({self.counts}, {self.variance:.2e})"

In [None]:
#export 
class CountTree:
    """Tree of `Constant`/`Variable` nodes that tallies circuit and subset selections.

    Every root-to-node path is weighted by the product of its node
    multipliers (see `multiplier`):
    - the root contributes 1;
    - a `Constant` contributes ``constants[c_hash][w_idx]``;
    - a `Variable` contributes its ``rate`` (its counts divided by its
      parent's counts).

    Args:
        constants: mapping ``c_hash -> weights``. ``weights`` is indexed by
            ``w_idx`` and may be a dict, list, or numpy array. It may be None
            until `load` fills it in.
        min_path_weight: a node whose children are all leaves is treated as
            invariant (skipped in the variance sums) when the summed
            Constant names on its path are below this value (see
            `is_invariant_path`).
    """

    def __init__(self, constants=None, min_path_weight=1):
        self.constants = constants
        self.root = None
        self.min_path_weight = min_path_weight

    # -- construction -------------------------------------------------------

    def add(self, name, nodetype=None, parent=None, **kwargs):
        """Return the child of `parent` called `name`, creating it if needed.

        With ``parent=None`` the root is returned. It is created as a
        `Constant` on the first call; later calls return the existing root
        and ignore `name`/`kwargs`. Otherwise the first existing child with
        an equal name is returned unchanged (`kwargs` are not applied), or a
        new ``nodetype(name, parent=parent, **kwargs)`` is attached.

        Raises:
            TypeError: a new child must be created but `nodetype` is None.
        """
        if parent is None:
            if self.root is None:
                self.root = Constant(name, **kwargs)
            return self.root

        existing = next((child for child in parent.children if child.name == name), None)
        if existing is not None:
            return existing
        if nodetype is None:
            raise TypeError(f"nodetype is required to create child {name!r}.")
        return nodetype(name, parent=parent, **kwargs)

    def detach(self, node):
        """Remove `node` (and its subtree) from the tree."""
        node.parent = None

    # -- path weights -------------------------------------------------------

    def multiplier(self, node):
        """Weight a single node contributes to the products along a path.

        Raises:
            TypeError: `node` is neither the root, a Variable, nor a Constant.
        """
        if node.is_root: return 1
        elif isinstance(node, Variable): return node.rate
        elif isinstance(node, Constant): return self.constants[node.c_hash][node.w_idx]
        else: raise TypeError(f"Node of type {node.__class__} not recognized.")

    def twig(self, node):
        """Product of the multipliers from the root down to and including `node`."""
        product = 1.0
        for step in node.path:
            product *= self.multiplier(step)
        return product

    @property
    def norm(self):
        """Sum of the path products of all leaves (0.0 for a root-only tree)."""
        return sum((self.twig(leaf) for leaf in self.root.leaves if not leaf.is_root), 0.0)

    @property
    def delta2(self):
        """``1 - norm``."""
        return 1.0 - self.norm

    @property
    def delta(self):
        """Residual weight left by the Constant children of each branching node.

        Covers the root and every non-leaf Variable that has at least one
        non-Variable child. For each such node it adds
        ``twig(node) * (1 - sum of its children's multipliers)``.
        """
        branch_points = [n for n in self.root.descendants
                         if isinstance(n, Variable) and not n.is_leaf]
        branch_points.append(self.root)
        acc = 0
        for node in branch_points:
            if all(isinstance(child, Variable) for child in node.children):
                continue
            uncovered = 1
            for child in node.children:
                uncovered -= self.multiplier(child)
            # twig(root) == 1, so the root needs no special case
            acc += self.twig(node) * uncovered
        return acc

    @property
    def rate(self):
        """Sum of the path products of the marked leaves.

        Leaves without a ``marked`` attribute (Constants) count as unmarked
        instead of raising AttributeError.
        """
        return sum((self.twig(leaf) for leaf in self.root.leaves if getattr(leaf, "marked", False)), 0.0)

    def subtree(self, node, all_leaves=False):
        """Sum, over leaves below `node`, of the multiplier product strictly below `node`.

        Only marked leaves are included unless `all_leaves` is True. The
        multiplier of `node` itself (and of its ancestors) is excluded.
        """
        path_sum = 0.0
        for leaf in node.leaves:
            if not (all_leaves or getattr(leaf, "marked", False)):
                continue
            product = 1.0
            for step in leaf.iter_path_reverse():
                if step is node:
                    break
                product *= self.multiplier(step)
            path_sum += product
        return path_sum

    def subtreediff(self, node, all_leaves):
        """`subtree` of the only child, or the difference of the two children's `subtree`s.

        Raises:
            ValueError: `node` has neither 1 nor 2 children.
        """
        assert isinstance(node, Constant)
        subtrees = [self.subtree(child, all_leaves) for child in node.children]
        if len(subtrees) == 1:
            return subtrees[0]
        if len(subtrees) == 2:
            return subtrees[0] - subtrees[1]
        raise ValueError(f"Node {node.name} must have 1 or 2 children, got {len(subtrees)}.")

    def partial_variance(self, node, all_leaves=False):
        """``node.variance * subtreediff(node, all_leaves) ** 2`` for a Constant `node`."""
        assert isinstance(node, Constant)
        subtreediff = self.subtreediff(node, all_leaves)
        return node.variance * subtreediff**2

    def is_invariant_path(self, node):
        """True when all children of `node` are leaves and the path weight is below `min_path_weight`.

        The path weight is the sum of every entry of every Constant name on
        the path below the root. Constant names must therefore be sequences
        of numbers, e.g. ``(1,)``.
        """
        if not all(child.is_leaf for child in node.children):
            return False
        path_weight = sum(sum(step.name) for step in node.path[1:] if isinstance(step, Constant))
        return path_weight < self.min_path_weight

    def _stochastic_constants(self):
        """Yield Constant descendants that are neither deterministic nor on an invariant path."""
        for node in self.root.descendants:
            if isinstance(node, Constant) and not (node.is_deterministic or self.is_invariant_path(node)):
                yield node

    @property
    def variance(self):
        """Sum of ``twig(node)**2 * (node.variance or partial_variance(node))`` over stochastic Constants.

        A Constant whose single child is a leaf contributes its own variance.
        Any other Constant contributes `partial_variance` over marked leaves.
        """
        acc = 0.0
        for node in self._stochastic_constants():
            weight = self.twig(node) ** 2
            if len(node.children) == 1 and node.children[0].is_leaf:
                acc += weight * node.variance
            else:
                acc += weight * self.partial_variance(node)
        return acc

    @property
    def norm_variance(self):
        """Like `variance`, but over all leaves; Constants whose children are all leaves are skipped."""
        acc = 0.0
        for node in self._stochastic_constants():
            if all(child.is_leaf for child in node.children):
                continue
            acc += self.twig(node) ** 2 * self.partial_variance(node, all_leaves=True)
        return acc

    def __str__(self):
        return '\n'.join(f'{prefix}{node}' for prefix, _, node in RenderTree(self.root))

    # -- (de)serialization --------------------------------------------------

    def save(self, path):
        """Write the tree and its constants table to `path` as JSON.

        Numpy arrays and scalars are converted to plain Python values. Any
        other non-JSON type raises TypeError.
        """
        data = self.to_dict()
        constants = self.constants
        if constants is not None:
            constants = {k: v.tolist() if isinstance(v, np.ndarray) else v for k, v in constants.items()}
        data["constants"] = constants
        with open(path, 'w') as fp:
            json.dump(data, fp, default=CountTree._json_default)

    @staticmethod
    def _json_default(obj):
        """`json.dump` fallback for numpy values; rejects anything else with TypeError."""
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        if isinstance(obj, np.generic):
            return obj.item()
        # returning `obj` unchanged (the old behavior) made json raise a
        # misleading "Circular reference detected" ValueError
        raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")

    def load(self, path):
        """Replace this tree and its constants with the content of a `save` file.

        JSON turns integer dict keys into strings. Canonical integer strings
        in the per-`c_hash` weight tables are turned back into ints, so that
        ``constants[c_hash][w_idx]`` keeps working with integer `w_idx`.
        """
        with open(path, 'r') as fp:
            data = json.load(fp)
        self.constants = CountTree._decode_constants(data.pop("constants", None))
        self.from_dict(data)

    @staticmethod
    def _decode_constants(constants):
        """Restore integer keys of the inner weight dicts of a JSON-loaded constants table."""
        if constants is None:
            return None
        decoded = {}
        for c_hash, weights in constants.items():
            if isinstance(weights, dict):
                weights = {CountTree._decode_key(k): v for k, v in weights.items()}
            decoded[c_hash] = weights
        return decoded

    @staticmethod
    def _decode_key(key):
        """Return ``int(key)`` if `key` is the canonical text of an int (e.g. '2', not '02'), else `key`."""
        try:
            as_int = int(key)
        except ValueError:
            return key
        return as_int if str(as_int) == key else key

    def to_dict(self, node=None):
        """Nested-dict form of the subtree at `node` (default: the root)."""
        if node is None:
            node = self.root
        data = dict(self._iter_node_attr(node))
        data["node_cls"] = type(node).__name__
        children = [self.to_dict(child) for child in node.children]
        if children:
            data["children"] = children
        return data

    def from_dict(self, data):
        """Rebuild the tree from the output of `to_dict`."""
        self.root = self.__import(data)

    @staticmethod
    def __import(data, parent=None):
        """Recursively create the node described by `data` under `parent`.

        The input dict is not modified. Node names serialized as JSON lists
        are turned back into tuples, so they compare equal in `add` again.

        Raises:
            ValueError: ``node_cls`` does not name a Node subclass of this module.
        """
        assert isinstance(data, dict)
        assert "parent" not in data
        fields = dict(data)
        children = fields.pop("children", [])
        node_name = CountTree._as_tuple(fields.pop("name"))
        cls_name = fields.pop("node_cls")
        # only instantiate tree node classes, never arbitrary module globals
        node_cls = globals().get(cls_name)
        if not (isinstance(node_cls, type) and issubclass(node_cls, Node)):
            raise ValueError(f"Unknown node class {cls_name!r} in serialized tree.")
        node = node_cls(node_name, parent=parent, **fields)
        for child in children:
            CountTree.__import(child, parent=node)
        return node

    @staticmethod
    def _as_tuple(value):
        """Recursively convert lists (JSON arrays) into tuples; other values pass through."""
        if isinstance(value, list):
            return tuple(CountTree._as_tuple(item) for item in value)
        return value

    @staticmethod
    def _iter_node_attr(node):
        """Yield the node's instance attributes, skipping anytree's parent/children links."""
        for k, v in node.__dict__.items():
            if k in ("_NodeMixin__children", "_NodeMixin__parent"):
                continue
            yield k, v

    # -- rendering ----------------------------------------------------------

    def draw(self, save_path=None, scale_percent=100):
        """Render the tree with pydot/Graphviz.

        Writes a PNG to `save_path` if given. Otherwise returns an IPython
        `HTML` image scaled to `scale_percent` percent width. Marked
        Variables are filled red. Edge widths follow ``child.counts`` relative
        to the root's counts, clamped to [0.2, 5.0].
        """
        root = self.root

        def gen_node(node):
            if isinstance(node, Variable):
                color = "#ff9999" if node.marked else "white"
                return pydot.Node(hex(node.id), label=str(node.name), style="filled", fillcolor=color, shape="ellipse")
            if isinstance(node, Constant):
                return pydot.Node(hex(node.id), label=str(node.name), style="filled", fillcolor="white", shape="box")
            raise TypeError(f"Node of type {node.__class__} not recognized.")

        def edgeattrfunc(node, child):
            weight = 0 if root.counts == 0 else 10.0 * child.counts / root.counts
            weight = max(0.2, min(weight, 5.0))
            return {"penwidth": str(weight)}

        G = pydot.Dot(graph_type="digraph")

        for node in PreOrderIter(self.root):
            nodeA = gen_node(node)
            G.add_node(nodeA)
            for child in node.children:
                nodeB = gen_node(child)
                G.add_edge(pydot.Edge(nodeA, nodeB, **edgeattrfunc(node, child)))

        def svg_to_fixed_width_html_image(svg, width=f"{scale_percent}%"):
            b64 = base64.b64encode(svg).decode("utf-8")
            return HTML(f'<img width="{width}" src="data:image/svg+xml;base64,{b64}" >')

        if save_path is not None:
            G.write_png(save_path)
        else:
            return svg_to_fixed_width_html_image(G.create_svg())

In [None]:
# Weight table 'h1': w_idx -> weight of the Constant nodes that point at it.
constants = {'h1': {0: 0.9, 1: 0.01, 2: 0.001}}

tree = CountTree(constants)

# Each Constant is tied to its weight via `c_hash`/`w_idx`. These exact keyword
# names matter: Constant accepts and silently ignores unknown keywords, so a
# misnamed key would leave the lookup at constants[None] and raise KeyError.
a = tree.add('ghz', counts=100)
b1 = tree.add((0,), Constant, a, c_hash='h1', w_idx=0, is_deterministic=True, counts=50)
b2 = tree.add((1,), Constant, a, c_hash='h1', w_idx=1, counts=40)
b3 = tree.add((2,), Constant, a, c_hash='h1', w_idx=2, counts=10)

c1 = tree.add("None", Variable, b1, counts=50)
c2 = tree.add("None", Variable, b2, counts=20)
d1 = tree.add('ghz', Variable, b2, counts=20)
d2 = tree.add("None", Variable, b3, counts=10)

e1 = tree.add((0,), Constant, d1, c_hash='h1', w_idx=0, is_deterministic=True, counts=10)
e2 = tree.add((1,), Constant, d1, c_hash='h1', w_idx=1, counts=7)
e3 = tree.add((2,), Constant, d1, c_hash='h1', w_idx=2, counts=3)

f1 = tree.add("None", Variable, e1, counts=10)
f2 = tree.add("None", Variable, e2, counts=5)
f2a = tree.add("FAIL", Variable, e2, marked=True, counts=2)
f3 = tree.add("None", Variable, e3, counts=3)

print(tree)
# tree.draw()

ghz (100, 0.00e+00)
├── (0,) (50, 1.27e-03)
│   └── None (50/50)
├── (1,) (40, 2.19e-02)
│   ├── None (20/40)
│   └── ghz (20/40)
│       ├── (0,) (10, 1.93e-02)
│       │   └── None (10/10)
│       ├── (1,) (7, 7.81e-02)
│       │   ├── None (5/7)
│       │   └── FAIL (2/7)
│       └── (2,) (3, 7.88e-02)
│           └── None (3/3)
└── (2,) (10, 1.93e-02)
    └── None (10/10)


In [None]:
import os
import tempfile

# Round-trip the example tree through JSON inside a throwaway directory, so
# running the notebook does not leave a stray `test.json` behind.
with tempfile.TemporaryDirectory() as tmp_dir:
    json_path = os.path.join(tmp_dir, 'test.json')
    tree.save(json_path)
    t2 = CountTree()
    t2.load(json_path)

In [None]:
# Show the reloaded tree, then its constants table, one per line.
print(t2, t2.constants, sep="\n")

ghz (100, 0.00e+00)
├── [0] (50, 1.27e-03)
│   └── None (50/50)
├── [1] (40, 2.19e-02)
│   ├── None (20/40)
│   └── ghz (20/40)
│       ├── [0] (10, 1.93e-02)
│       │   └── None (10/10)
│       ├── [1] (7, 7.81e-02)
│       │   ├── None (5/7)
│       │   └── FAIL (2/7)
│       └── [2] (3, 7.88e-02)
│           └── None (3/3)
└── [2] (10, 1.93e-02)
    └── None (10/10)
{'h1': {'0': 0.9, '1': 0.01, '2': 0.001}}


In [None]:
# Summed leaf-path weight (`norm`) and residual branching weight (`delta`) of the example tree.
print(tree.norm, tree.delta)

KeyError: None

In [None]:
# Summed path weight of the marked leaves (here the "FAIL" outcome).
print(tree.rate)

In [None]:
# Variance accumulated over the non-deterministic, non-invariant Constant nodes (see CountTree.variance).
print(tree.variance)

In [None]:
# Variance-style sum using all leaves instead of only marked ones (see CountTree.norm_variance).
print(tree.norm_variance)

In [None]:
# Symbolic example tree: Constant labels (A*, B*, ...) and Variable labels
# (p*, 1-p*) are placeholders. No counts are given, so every count is 0 and
# every Variable.rate evaluates to 0.
# NOTE(review): the tree has no constants table (CountTree() -> constants=None)
# and its Constants carry no c_hash/w_idx. Evaluating CountTree.multiplier on
# any non-root Constant (norm, rate, variance, ...) therefore fails.
# is_invariant_path also sums Constant names, which fails for string labels
# such as 'A2'. TODO: supply a constants table and numeric-tuple names to
# make the cells below runnable.
tree = CountTree()

a = tree.add('ghz')
b1 = tree.add('A0', Constant, a, is_deterministic=True)
b2 = tree.add('A1', Constant, a)
b3 = tree.add('A2', Constant, a)

c1 = tree.add("1-p1", Variable, b1)
c2 = tree.add("1-p2", Variable, b2)
d1 = tree.add('p2', Variable, b2)
d2 = tree.add("1-p3", Variable, b3)

g1 = tree.add('B0', Constant, c2, is_deterministic=True)
g2 = tree.add('B1', Constant, c2)

h1 = tree.add("1-p4", Variable, g1)
h2 = tree.add("p5", Variable, g2)

e1 = tree.add('C0', Constant, d1, is_deterministic=True)
e2 = tree.add('C1', Constant, d1)
e3 = tree.add('C2', Constant, d1)

f1 = tree.add("1-p6", Variable, e1)
f2 = tree.add("1-p7", Variable, e2)
# NOTE(review): f2a is marked but gets Constant children (E0, E1) below, so it
# is not a leaf; CountTree.rate/subtree only look at leaves.
f2a = tree.add("p7", Variable, e2, marked=True)
f3 = tree.add("p8", Variable, e3)

k1 = tree.add('F1', Constant, f3)
k2 = tree.add('p13', Variable, k1, marked=True)

i1 = tree.add('E0', Constant, f2a, is_deterministic=True)
i2 = tree.add('E1', Constant, f2a)

j1 = tree.add('D0', Constant, h2, is_deterministic=True)
j2 = tree.add('D1', Constant, h2)

l1 = tree.add('1-p9', Variable, j1)
l2 = tree.add('1-p10', Variable, j2)
l3 = tree.add('p10', Variable, j2, marked=True)

m1 = tree.add('1-p11', Variable, i1)
m2 = tree.add('1-p12', Variable, i2)
m3 = tree.add('p12', Variable, i2, marked=True)

# Printing works here: Constant.variance returns 0 for zero counts.
print(tree)

In [None]:
# NOTE(review): this tree was created with CountTree() (constants=None) and its
# Constants have no c_hash/w_idx, so multiplier() fails on the first Constant.
tree.norm

In [None]:
# NOTE(review): the marked leaves' paths pass through Constants whose weights are
# looked up in a constants table that is None here, so this raises.
tree.rate

In [None]:
# NOTE(review): needs twig() of non-deterministic Constants, i.e. a constants
# table, which this tree (constants=None) does not have.
tree.variance

In [None]:
# NOTE(review): like `variance`, this evaluates twig() on Constants and fails
# while the tree has no constants table.
tree.norm_variance