In [None]:
# default_exp samplers.datatypes

# Datatypes

> Collection of datatypes used by Sampler classes.

In [None]:
# hide
from nbdev.showdoc import *

In [None]:
# export
from anytree import RenderTree, NodeMixin
from anytree.exporter import JsonExporter
from anytree.importer import JsonImporter, DictImporter

import numpy as np
import qsam.math as math

In [None]:
#export
class CountNode(NodeMixin):
    """Tree node that stores a sample count, used as the node type of `SampleTree`.

    Args:
        name: label of the node; `SampleTree.add` matches children by it.
        parent: parent `CountNode`, or None for a root node.
        children: optional iterable of child nodes to attach.
        counts: number of samples recorded at this node.
        ckey: optional two-element key. Nodes with a truthy `ckey` get their
            weight from an external `const` table in `SampleTree`
            (`const[ckey[0]][ckey[1]]`) and have no `rate`/`var`.
        is_deterministic: when True, `SampleTree.rate` and `SampleTree.var`
            skip this node's children.
        is_fail: marks a leaf whose path weight is summed by `SampleTree.rate`.
    """

    def __init__(self, name, parent=None, children=None, counts=0,
                 ckey=None, is_deterministic=False, is_fail=False):
        super().__init__()

        # NOTE: the assignment order below is also the key order in the JSON
        # written by `SampleTree.save` (JsonExporter reads `__dict__`), so it
        # is kept stable. Do not add further public instance attributes here:
        # they would be exported and passed back to this constructor on load.
        self.name = name
        self.ckey = ckey
        self.counts = counts
        self.is_deterministic = is_deterministic
        self.is_fail = is_fail

        self.parent = parent
        if children:
            self.children = children

    @property
    def rate(self):
        """Fraction of the parent's samples that reached this node.

        Returns 0.5 while the parent has fewer than 2 samples, i.e. before the
        ratio is considered informative. Only valid for non-root nodes
        without a `ckey`.
        """
        assert not self.ckey and not self.is_root
        if self.parent.counts < 2:
            return 0.5
        return self.counts / self.parent.counts

    def variance(self, var_fn=math.Wilson_var):
        """Return the variance of `rate` computed by `var_fn`.

        Args:
            var_fn: callable invoked as `var_fn(rate, n)`, where `n` is the
                parent's sample count. Defaults to `qsam.math.Wilson_var`.

        Returns:
            1.0 while the parent has fewer than 2 samples, otherwise the value
            returned by `var_fn`. Only valid for non-root nodes without a `ckey`.
        """
        assert not self.ckey and not self.is_root
        if self.parent.counts < 2:
            return 1.0
        return var_fn(self.rate, self.parent.counts)

    @property
    def var(self, var_fn=math.Wilson_var):
        """Variance of `rate` using the default estimator.

        Property access cannot pass `var_fn`. Call `variance(var_fn)` to use a
        different variance function. The parameter is kept only for
        backward compatibility of the getter's signature.
        """
        return self.variance(var_fn)

    def __str__(self):
        # Root and ckey nodes have no rate/variance, so they show counts only.
        if self.is_root or self.ckey:
            return f'{self.name} ({self.counts})'
        else:
            return f'{self.name} ({self.counts}/{self.parent.counts}, {self.var:.2e})'
    
class SampleTree:
    """Tree of `CountNode`s that combines per-node sample counts into estimates.

    Build the tree with `add`, persist it as JSON with `save` and `load`, and
    evaluate it with `rate`, `var` and `delta`.

    `rate`, `var` and `delta` take a `const` argument. A node with a truthy
    `ckey` contributes the weight `const[ckey[0]][ckey[1]]`; every other
    non-root node contributes its count-based `node.rate`.
    """

    def __init__(self, root=None):
        self.root = root

    @staticmethod
    def _path_prod(nodes, const):
        """Multiply the weights of `nodes` in iteration order (see class docstring)."""
        prod = 1.0
        for n in nodes:
            prod *= const[n.ckey[0]][n.ckey[1]] if n.ckey else n.rate
        return prod

    def save(self, fname):
        """Write the tree rooted at `self.root` to `fname` as indented JSON."""
        exporter = JsonExporter(indent=2, sort_keys=False)
        with open(fname, 'w', encoding='utf-8') as f:
            exporter.write(self.root, f)

    def load(self, fname):
        """Replace `self.root` with the tree read from the JSON file `fname`.

        Nodes are rebuilt as `CountNode`s. Returns `self` so calls can be chained.
        """
        dictimporter = DictImporter(CountNode)
        importer = JsonImporter(dictimporter)
        with open(fname, 'r', encoding='utf-8') as f:
            self.root = importer.read(f)
        return self

    def add(self, name, parent=None, **kwargs):
        """Return the node called `name` under `parent`, creating it if absent.

        With `parent=None` the root is returned. It is created from `name` and
        `kwargs` only if the tree is empty. An existing root is returned as is,
        without checking its name.
        Otherwise the first child of `parent` whose `name` matches is returned.
        If none matches, a new `CountNode(name, parent=parent, **kwargs)` is
        attached and returned. `kwargs` are ignored for existing nodes.
        """
        if parent is None:
            if self.root is None:
                self.root = CountNode(name, **kwargs)
            return self.root
        for child in parent.children:
            if child.name == name:
                return child
        return CountNode(name, parent=parent, **kwargs)

    def detach(self, node):
        """Remove `node` (and its subtree) from the tree by clearing its parent."""
        node.parent = None

    def rate(self, const):
        """Sum the path weights of all failure leaves.

        A leaf counts if it is not the root, has `is_fail` set, and its parent
        is not deterministic. Its weight is the product of node weights from
        the root's child down to the leaf.
        """
        p = 0
        for leaf in self.root.leaves:
            if leaf.is_root or leaf.parent.is_deterministic or not leaf.is_fail:
                continue
            p += self._path_prod(leaf.path[1:], const)
        return p

    def var(self, const):
        """Propagate per-node rate variances into a variance for `rate`.

        Each non-root node without a `ckey`, and whose parent is not
        deterministic, contributes:

            node.var * twig**2 * subtree**2

        `twig` is the product of weights strictly between the root and the
        node. `subtree` is the sum, over failure leaves below the node, of
        the weights strictly below the node. This appears to be first-order
        (delta-method) error propagation; confirm against the derivation.

        NOTE(review): `rate` ignores failure leaves whose parent is
        deterministic, but `subtree` here includes them. Confirm the
        asymmetry is intended.
        """
        total = 0
        for node in self.root.descendants:
            if node.ckey or node.parent.is_deterministic:
                continue
            twig = self._path_prod(node.path[1:-1], const)

            subtree = 0
            for leaf in node.leaves:
                if not leaf.is_fail:
                    continue
                # Nodes strictly below `node` on the leaf's path, walked from
                # the leaf upwards; `node` itself sits at index `node.depth`.
                below = reversed(leaf.path[node.depth + 1:])
                subtree += self._path_prod(below, const)

            total += node.var * twig**2 * subtree**2
        return total

    def delta(self, const):
        """Return 1 minus the summed path weight of every non-root leaf.

        All leaves count here, failing or not. The result is presumably the
        probability mass not yet covered by the tree; confirm with callers.
        """
        path_sum = 0.0
        for leaf in self.root.leaves:
            if not leaf.is_root:
                path_sum += self._path_prod(leaf.path[1:], const)
        return 1.0 - path_sum

    def __str__(self):
        return '\n'.join([f'{pre}{node}' for pre, _, node in RenderTree(self.root)])