In [None]:
#| default_exp sampler.tree

This data structure is based on the Python package [`anytree`](https://github.com/c0fec0de/anytree), which is extended here by the classes `CountNode`, `SubsetCountNode`, `CircuitCountNode` and `CountTree`.

In [None]:
#| include: false
from nbdev.showdoc import *

In [None]:
#| export

from anytree import NodeMixin, RenderTree, PreOrderIter
import qsample.math as math
import pydot
import dill as pickle

from fastcore.test import *

In [None]:
#| export

def draw_tree(tree, path=None):
    """Generate and return a PNG image of a `CountTree`.

    `CircuitCountNode` nodes are drawn as ellipses (outlined in red when they
    are marked leaves of `tree`), `SubsetCountNode` nodes as boxes. Every node
    is labelled ``"<name> : <count>"``. The thickness of an edge is
    proportional to the child's count relative to the root count, clamped to
    the range [0.2, 5.0].

    To display the image from the command line, call .show() on the returned
    PIL image object.

    Parameters
    ----------
    tree : CountTree
        Tree to draw
    path : str or None
        File path to save the png image to; if None the image is only returned

    Returns
    -------
    PIL.Image.Image
        PNG image of `CountTree`

    Raises
    ------
    TypeError
        If `tree` contains a node that is neither a `CircuitCountNode` nor a
        `SubsetCountNode`.
    """

    def gen_node(node):
        """Generate pydot.Node object from `SubsetCountNode` or `CircuitCountNode`

        Parameters
        ----------
        node : SubsetCountNode or CircuitCountNode
            `CountNode` to generate pydot.Node object from

        Returns
        -------
        pydot.Node
            Graphical representation of node. The node is named by the object
            id, so nodes with equal `name` stay distinct in the graph.

        Raises
        ------
        TypeError
            If `node` is of any other type
        """
        label = f"{node.name} : {node.count}"
        if isinstance(node, CircuitCountNode):
            color = "#ff0000" if node in tree.marked_leaves else "black"
            return pydot.Node(hex(id(node)), label=label, style="filled", color=color, fillcolor="white", shape="ellipse")
        if isinstance(node, SubsetCountNode):
            return pydot.Node(hex(id(node)), label=label, style="filled", fillcolor="white", shape="box")
        # Fail loudly here instead of handing None to pydot.
        raise TypeError(f"Unknown node type: {type(node)}")

    def edgeattrfunc(node, child):
        """Return pydot attributes for the edge between `node` and `child`.

        Parameters
        ----------
        node : SubsetCountNode or CircuitCountNode
            Parent node of a (node, child) pair
        child : SubsetCountNode or CircuitCountNode
            Child node of a (node, child) pair

        Returns
        -------
        dict
            Specification of the thickness of the edge between `node` and `child`
        """
        # A root count of 0 would divide by zero; treat it as the thinnest edge.
        weight = 0 if tree.root.count == 0 else 10.0 * child.count / tree.root.count
        # Clamp the pen width to [0.2, 5.0] so edges stay visible but not huge.
        weight = min(max(weight, 0.2), 5.0)
        return {"penwidth": str(weight)}

    G = pydot.Dot(graph_type="digraph")

    for node in PreOrderIter(tree.root):
        node_a = gen_node(node)
        G.add_node(node_a)
        for child in node.children:
            # Edges refer to their endpoints by name (the object id), so this
            # matches the graph node added when `child` itself is visited.
            node_b = gen_node(child)
            G.add_edge(pydot.Edge(node_a, node_b, **edgeattrfunc(node, child)))

    if path is not None:
        G.write_png(path)

    from PIL import Image
    from io import BytesIO
    return Image.open(BytesIO(G.create_png()))

In [None]:
#| export   

class CountNode(NodeMixin):
    """Tree node of a `CountTree` that carries a visit counter.

    The tree structure (at most one parent, any number of children; the root
    has no parent and leaf nodes have no children) comes from
    `anytree.NodeMixin`. On top of it, `CountNode` keeps a `count`, the number
    of times the node has been visited during sampling, and an `invariant`
    flag.

    Attributes
    ----------
    name : str or tuple
        The (not necessarily unique) name of the node
    count : int
        Node visit counter variable
    invariant : bool
        If true, influences calculation of variance
    parent : Node, optional
        Reference to parent node object
    children : list of Node, optional
        List of references to children Node objects
    """

    def __init__(self, name, count=0, invariant=False, parent=None, children=None):
        """
        Parameters
        ----------
        name : str or tuple
            The (not necessarily unique) name of the node
        count : int
            Initial value of visit counter variable
        invariant : bool
            Flag consulted when computing variances (see `CircuitCountNode`)
        parent : Node
            Reference to parent node object
        children : list of Node
            List of references to children Node objects
        """
        self.name = name
        # Assigning `parent` goes through anytree's setter and attaches this node.
        self.parent = parent
        self.count, self.invariant = count, invariant
        if children:
            self.children = children

    def __str__(self):
        return "{} ({})".format(self.name, self.count)
    
class SubsetCountNode(CountNode):
    """`CountNode` subclass representing a subset node of a circuit.

    Its value in a `CountTree` is looked up as
    ``constants[circuit_id][name]``.

    Attributes
    ----------
    circuit_id : str
        Reference to corresponding circuit
    """

    def __init__(self, name, circuit_id=None, invariant=False, count=0, parent=None, children=None):
        """
        Parameters
        ----------
        name : str or tuple
            The (not necessarily unique) name of the node
        circuit_id : str or None
            Reference to corresponding circuit
        invariant : bool
            If true, measurement outcome deterministic and successive circuit node
            has no variance.
        count : int
            Initial value of visit counter variable
        parent : Node
            Reference to parent node object
        children : list of Node
            List of references to children Node objects
        """
        super().__init__(name, count=count, invariant=invariant, parent=parent, children=children)
        self.circuit_id = circuit_id
        
    
class CircuitCountNode(CountNode):
    """Subclass of `CountNode` to represent a circuit node.

    Circuit nodes carry the differentiable rates of a `CountTree`.
    Differentiability refers to the partial derivatives in the
    propagation-of-uncertainty formula for the lower and upper bound
    of the logical failure rate.
    """

    @property
    def rate(self):
        """Transition rate from the parent node to this node.

        We consider `parent.count` coin flips of a Bernoulli random variable X
        and estimate the rate p of the transition parent -> this node (X=1):

        .. math:: p(X=1) = self.count / self.parent.count

        The root has rate 1.0, and a parent that was never visited yields 0.0.

        Returns
        -------
        float
            Value of transition rate in range [0,1]
        """
        if self.is_root:
            return 1.0
        parent_count = self.parent.count
        return self.count / parent_count if parent_count != 0 else 0.0

    @property
    def variance(self):
        """Variance of the transition rate between the parent node and this node.

        The Wald interval gives unrealistic variances at small sample sizes,
        so the variance of the Wilson interval is used instead. The root and
        nodes where this node or its parent is invariant have zero variance.

        Returns
        -------
        float
            Value of variance of transition rate
        """
        # Short-circuiting keeps `self.parent` untouched for the root.
        has_no_variance = self.is_root or self.invariant or self.parent.invariant
        if has_no_variance:
            return 0.0
        return math.Wilson_var(self.rate, self.parent.count)

    def __str__(self):
        return "{} ({}, {:.2e})".format(self.name, self.count, self.variance)

In [None]:
#| export            

class CountTree:
    """Class to represent a tree of `CircuitCountNode` and `SubsetCountNode` nodes.

    Attributes
    ----------
    constants : dict
        Nested mapping ``constants[circuit_id][subset_name]`` holding the
        constant value of each weight subset of each circuit
    root : CircuitCountNode or None
        Root node of the tree; None until the first node is added
    marked_leaves : set
        Leaf nodes marked by user
    """

    def __init__(self, constants):
        """
        Parameters
        ----------
        constants : dict
            Nested mapping ``constants[circuit_id][subset_name]`` holding the
            constant value of each weight subset of each circuit
        """
        self.constants = constants
        self.root = None
        self.marked_leaves = set()

    def add(self, node_type, name, parent=None, **kwargs):
        """Add node of `node_type` and name `name` as child of `parent`.

        Parameters
        ----------
        node_type : type
            `CircuitCountNode` or `SubsetCountNode`; ignored for the root,
            which is always a `CircuitCountNode`
        name : str or tuple
            Name of node to add
        parent : CircuitCountNode or SubsetCountNode, default: None
            Parent of node to add, only root node has no parent
        **kwargs
            Forwarded to the node constructor (e.g. `count`, `circuit_id`)

        Returns
        -------
        CircuitCountNode or SubsetCountNode
            Reference to the added node. If the root (for `parent=None`) or a
            child of `parent` with the same `name` already exists, that node is
            returned unchanged and `kwargs` are ignored.
        """
        if parent is None:
            if self.root is None:
                self.root = CircuitCountNode(name, **kwargs)
            return self.root
        existing = next((child for child in parent.children if child.name == name), None)
        if existing is not None:
            # tree node already exists
            return existing
        # create new tree node
        return node_type(name=name, parent=parent, **kwargs)

    def remove(self, node):
        """Detach `node` together with its subtree from the tree.

        Marked leaves inside the removed subtree are unmarked, so that
        `root_leaf_rate` no longer counts them.

        Parameters
        ----------
        node : CircuitCountNode or SubsetCountNode
            Non-root node to remove

        Raises
        ------
        ValueError
            If `node` has no parent (e.g. it is the root).
        """
        if node.parent is None:
            raise ValueError("Cannot remove a node without a parent (e.g. the root).")
        removed_leaves = node.leaves
        # anytree's `parent` setter detaches the node and keeps the remaining
        # siblings in their original order.
        node.parent = None
        for leaf in removed_leaves:
            if leaf in self.marked_leaves:
                self.marked_leaves.remove(leaf)

    @property
    def root_leaf_rate(self):
        """Sum of node.count / root.count over all marked leaves.

        Returns
        -------
        float
            Direct MC estimate of logical failure rate; 0.0 if the root was
            never visited (consistent with `CircuitCountNode.rate`)
        """
        if self.root.count == 0:
            return 0.0
        return sum(node.count / self.root.count for node in self.marked_leaves)

    @property
    def root_leaf_variance(self):
        """Variance of root leaf rate

        Returns
        -------
        float
            Variance of direct MC estimate of logical failure rate
        """
        p = self.root_leaf_rate
        return math.Wilson_var(p, self.root.count)

    def path_weight(self, end_node):
        """Calculate total path weight from `root` to `end_node`.

        The weight is the sum of the entries of the names of all
        non-circuit (subset) nodes on the path.

        Parameters
        ----------
        end_node : SubsetCountNode
            End node of a path starting from `root`

        Returns
        -------
        int
            Total weight of the path
        """
        return sum(sum(n.name) for n in end_node.path if not isinstance(n, CircuitCountNode))

    def __get_node_value(self, node):
        """Lookup the value of a node.

        For `CircuitCountNode` return its `rate`, for `SubsetCountNode` return the
        corresponding value from the `constants` dict.

        Parameters
        ----------
        node : SubsetCountNode or CircuitCountNode
            Node for which value is returned

        Raises
        ------
        TypeError
            If `node` has different type than `CircuitCountNode` or `SubsetCountNode`
        """
        if isinstance(node, CircuitCountNode):
            return node.rate
        elif isinstance(node, SubsetCountNode):
            return self.constants[node.circuit_id][node.name]
        else:
            raise TypeError(f"Unknown node type: {type(node)}")

    def __get_path_product(self, start_node, end_node):
        """Calculate product of node values from `start_node` (exclusive) to `end_node` (inclusive).

        `start_node` is expected to be an ancestor of `end_node`; otherwise the
        product runs up to and including the root.

        Parameters
        ----------
        start_node : SubsetCountNode or CircuitCountNode
            Start node of the path
        end_node : SubsetCountNode or CircuitCountNode
            End node of the path
        """
        if start_node == end_node:
            return 1
        prod = 1
        for node in end_node.iter_path_reverse():
            if node == start_node:
                # exclude start_node value
                break
            prod *= self.__get_node_value(node)
        return prod

    def path_sum(self, start_node, mode):
        """Calculate sum of all path products from `start_node` to its leaves

        Parameters
        ----------
        start_node : CircuitCountNode or SubsetCountNode
            Node from which to start calculating the sum of paths
        mode : int
            Paths to consider, 0: Only not marked paths, 1: only marked paths, 2: both

        Returns
        -------
        float
            Sum of paths

        Raises
        ------
        ValueError
            If mode different than 0,1,2
        """
        if mode == 0:
            # leaves not overlapping with marked leaves
            leaves = set(start_node.leaves).difference(self.marked_leaves)
        elif mode == 1:
            # leaves overlapping with marked leaves
            leaves = set(start_node.leaves).intersection(self.marked_leaves)
        elif mode == 2:
            # all leaves
            leaves = start_node.leaves
        else:
            raise ValueError(f"Unknown mode {mode}. Known modes: 0,1,2")

        return sum(self.__get_path_product(start_node=start_node, end_node=leaf) for leaf in leaves)

    def __partial_derivative(self, node, mode):
        """Calculates partial derivative of `node` in `CountTree`.

        We calculate the partial derivative as:

        .. math:: twig * subtree,

        where twig is the path product leading to but excluding `node`'s rate and
        subtree is the sum of all path products starting from the children of `node`.
        If `node` has a sibling its rate is 1-qi and therefore contributes a negative
        subtree.

        Parameters
        ----------
        node : CircuitCountNode
            Node with respect to which partial derivative is calculated
        mode : int
            Paths to consider, 0: Only not marked paths, 1: only marked paths

        Raises
        ------
        ValueError
            If a `CircuitCountNode` (circuit) node has more than one sibling.
        """
        if node.invariant or node.parent.invariant:
            return 0

        twig = self.__get_path_product(start_node=self.root, end_node=node.parent)
        subtree = self.path_sum(start_node=node, mode=mode)
        if len(node.siblings) == 1:
            subtree -= self.path_sum(start_node=node.siblings[0], mode=mode)
        elif len(node.siblings) > 1:
            raise ValueError("CircuitCountNode nodes can only have at most 1 sibling.")

        return twig * subtree

    def uncertainty_propagated_variance(self, mode):
        r"""Variance of `CountTree` by propagation of uncertainty.

        The formula for Gaussian propagation of uncertainty (ignoring the covariance terms) is:

        .. math:: Var(p_L) = \sum_n{\frac{\partial p_L}{\partial q_n}^2 V_n}

        We can ignore covariance contributions as the random variables underlying
        transition rates along a path depend on the occurrence of a certain outcome
        of the random variables underlying the transition rates of their parents.
        Therefore, there is no covariance amongst the rates along a path.

        Parameters
        ----------
        mode : int
            Paths to consider, 0: Only not marked paths, 1: only marked paths

        Returns
        -------
        float
            Propagated variance

        Raises
        ------
        ValueError
            If mode different than 0,1
        """
        if mode == 0:
            # all CircuitCountNode nodes, exclude marked leaves
            variables = [n for n in self.root.descendants if isinstance(n, CircuitCountNode) and set(n.leaves).difference(self.marked_leaves)]
        elif mode == 1:
            # all CircuitCountNode nodes, exclude not marked leaves
            variables = [n for n in self.root.descendants if isinstance(n, CircuitCountNode) and set(n.leaves).intersection(self.marked_leaves)]
        else:
            raise ValueError(f"Unknown mode {mode}. Known modes: 0,1")

        # Rate dependence constraint: keep only the first variable per parent;
        # its sibling is accounted for in the subtree difference of
        # `__partial_derivative`. A dict keeps this O(n) and deterministic.
        first_per_parent = {}
        for v in variables:
            first_per_parent.setdefault(v.parent, v)

        return sum(self.__partial_derivative(node, mode=mode) ** 2 * node.variance
                   for node in first_per_parent.values())

    def __str__(self):
        return '\n'.join([f'{pre}{node}' for pre, _, node in RenderTree(self.root)])

    def draw(self, path=None):
        """Render the tree as a PNG image using `draw_tree`.

        Parameters
        ----------
        path : str or None
            File path to save the png image to; if None the image is only returned

        Returns
        -------
        PIL.Image.Image
            Image of the tree
        """
        return draw_tree(self, path)

Test creating a 1-level tree.

In [None]:
# Test 1-level tree

constants = {0: {(0,): 0.8, (1,): 0.1, (2,): 0.05}}
tree = CountTree(constants)
root = tree.add(node_type=CircuitCountNode, name='root', count=100)

# weight subsets of circuit 0
root0 = tree.add(node_type=SubsetCountNode, name=(0,), parent=root, circuit_id=0, count=40)
root1 = tree.add(node_type=SubsetCountNode, name=(1,), parent=root, circuit_id=0, count=30)
root2 = tree.add(node_type=SubsetCountNode, name=(2,), parent=root, circuit_id=0, count=30)

# circuit outcomes below each subset; the single outcome of subset (0,) is invariant
none_0 = tree.add(node_type=CircuitCountNode, name='None', parent=root0, count=40, invariant=True)
fail_1 = tree.add(node_type=CircuitCountNode, name='fail', parent=root1, count=20)
none_1 = tree.add(node_type=CircuitCountNode, name='None', parent=root1, count=10)
fail_2 = tree.add(node_type=CircuitCountNode, name='fail', parent=root2, count=5)
none_2 = tree.add(node_type=CircuitCountNode, name='None', parent=root2, count=25)

tree.marked_leaves = {fail_1, fail_2}
print(tree)

root (100, 0.00e+00)
├── (0,) (40)
│   └── None (40, 0.00e+00)
├── (1,) (30)
│   ├── fail (20, 7.20e-03)
│   └── None (10, 7.20e-03)
└── (2,) (30)
    ├── fail (5, 4.60e-03)
    └── None (25, 4.60e-03)


Test the numerics of the 1-level tree.

In [None]:
# Test numerics 1-level tree
# The tree multiplies rates leaf-to-root and sums over sets, so the float
# results are not computed in the same order as the reference expressions;
# compare with a tolerance instead of exact equality.

test_close(tree.root_leaf_rate, 20/100 + 5/100, eps=1e-05)
test_close(tree.path_sum(tree.root, mode=1), 0.1 * 20/30 + 0.05 * 5/30, eps=1e-05)
test_close(tree.path_sum(tree.root, mode=2), 0.8 + 0.1 + 0.05, eps=1e-05)
test_close(tree.uncertainty_propagated_variance(mode=1), 0.1**2 * fail_1.variance + 0.05**2 * fail_2.variance, eps=1e-05)
test_close(tree.uncertainty_propagated_variance(mode=0), 0.1**2 * none_1.variance + 0.05**2 * none_2.variance, eps=1e-05)

Test creating a 2-level tree.

In [None]:
# Test 2-level tree

constants = {
    0: {(0,): 0.8, (1,): 0.1},
    1: {(0,): 0.7, (1,): 0.2},
    2: {(0,): 1.0},
}

tree = CountTree(constants)
root = tree.add(node_type=CircuitCountNode, name='c0', count=100)
root0 = tree.add(node_type=SubsetCountNode, name=(0,), parent=root, circuit_id=0, count=70)
root1 = tree.add(node_type=SubsetCountNode, name=(1,), parent=root, circuit_id=0, count=30)

# second-level circuits c1 and c2 below each subset of c0
c1_0 = tree.add(node_type=CircuitCountNode, name='c1', parent=root0, count=50, invariant=True)
c1_1 = tree.add(node_type=CircuitCountNode, name='c1', parent=root1, count=20)
c2_0 = tree.add(node_type=CircuitCountNode, name='c2', parent=root0, count=20, invariant=True)
c2_1 = tree.add(node_type=CircuitCountNode, name='c2', parent=root1, count=10)

# subsets of circuit c1
c1_0_0 = tree.add(node_type=SubsetCountNode, name=(0,), parent=c1_0, circuit_id=1, count=40)
c1_0_1 = tree.add(node_type=SubsetCountNode, name=(1,), parent=c1_0, circuit_id=1, count=10)
c1_1_0 = tree.add(node_type=SubsetCountNode, name=(0,), parent=c1_1, circuit_id=1, count=15)
c1_1_1 = tree.add(node_type=SubsetCountNode, name=(1,), parent=c1_1, circuit_id=1, count=5)

# subsets of circuit c2
c2_0_0 = tree.add(node_type=SubsetCountNode, name=(0,), parent=c2_0, circuit_id=2, count=20)
c2_1_0 = tree.add(node_type=SubsetCountNode, name=(0,), parent=c2_1, circuit_id=2, count=10)

# leaf outcomes below the c1 subsets
none_c1_0_0 = tree.add(node_type=CircuitCountNode, name='None', parent=c1_0_0, count=40, invariant=True)
none_c1_0_1 = tree.add(node_type=CircuitCountNode, name='None', parent=c1_0_1, count=5)
fail_c1_0_1 = tree.add(node_type=CircuitCountNode, name='fail', parent=c1_0_1, count=5)
none_c1_1_0 = tree.add(node_type=CircuitCountNode, name='None', parent=c1_1_0, count=15)
none_c1_1_1 = tree.add(node_type=CircuitCountNode, name='None', parent=c1_1_1, count=3)
fail_c1_1_1 = tree.add(node_type=CircuitCountNode, name='fail', parent=c1_1_1, count=2)

# leaf outcomes below the c2 subsets
none_c2_0_0 = tree.add(node_type=CircuitCountNode, name='None', parent=c2_0_0, count=20, invariant=True)
fail_c2_1_0 = tree.add(node_type=CircuitCountNode, name='fail', parent=c2_1_0, count=10)

tree.marked_leaves = {fail_c1_0_1, fail_c1_1_1, fail_c2_1_0}
print(tree)

c0 (100, 0.00e+00)
├── (0,) (70)
│   ├── c1 (50, 0.00e+00)
│   │   ├── (0,) (40)
│   │   │   └── None (40, 0.00e+00)
│   │   └── (1,) (10)
│   │       ├── None (5, 2.27e-02)
│   │       └── fail (5, 2.27e-02)
│   └── c2 (20, 0.00e+00)
│       └── (0,) (20)
│           └── None (20, 0.00e+00)
└── (1,) (30)
    ├── c1 (20, 7.20e-03)
    │   ├── (0,) (15)
    │   │   └── None (15, 9.77e-04)
    │   └── (1,) (5)
    │       ├── None (3, 4.03e-02)
    │       └── fail (2, 4.03e-02)
    └── c2 (10, 7.20e-03)
        └── (0,) (10)
            └── fail (10, 2.07e-03)


Test the numerics of the 2-level tree.

In [None]:
# Test numerics 2-level tree
# root_leaf_rate and path_sum accumulate floats in set-iteration order and
# float addition is not associative, so compare with a tolerance.

test_close(tree.root_leaf_rate, 5/100 + 2/100 + 10/100, eps=1e-05)
test_close(tree.path_sum(tree.root, mode=1),
           0.8 * 50/70 * 0.2 * 5/10 + 0.1 * (20/30 * 0.2 * 2/5 + 10/30 * 1.0 * 1),
           eps=1e-05)
test_close(tree.path_sum(tree.root, mode=2), 0.8 * (50/70 * 0.9 + 20/70) + 0.1 * (20/30 * 0.9 + 10/30), eps=1e-05)
var1 = ((0.8 * 0.2 * 5/10)**2 * c1_0.variance
        + (0.8 * 50/70 * 0.2)**2 * fail_c1_0_1.variance
        + (0.1 * 0.2 * 2/5 - 0.1 * 1)**2 * c1_1.variance
        + (0.1 * 0.2 * 20/30)**2 * fail_c1_1_1.variance)
test_close(tree.uncertainty_propagated_variance(mode=1), var1, eps=1e-05)
var2 = ((0.8 * (0.7 * 1 + 0.2 * 5/10 - 1))**2 * c1_0.variance
        + (0.8 * 50/70 * 0.2)**2 * fail_c1_0_1.variance
        + (0.1 * (0.7 + 0.2*3/5))**2 * c1_1.variance
        + (0.1 * 20/30 * 0.2)**2 * fail_c1_1_1.variance)
test_close(tree.uncertainty_propagated_variance(mode=0), var2, eps=1e-05)