# Domain Sorting 
## Author: Damien 

For this implementation, of domain sorting, we assume that the domain and all sub-domains are hyperrectangles.

In [69]:
import numpy as np
from anytree import NodeMixin, RenderTree

In [147]:
class Domain_Tree:
    """
    Author:         Damien Beecroft
    Domain_Tree:    A super of the Domain class that stores important functions for working with
                    the domain decomposition tree.
    """
    def __init__(self):
        pass

    def find_interior_pts(self): # find which of the parent's points are in the current domain
        parent_pts = self.parent.pts # get the parent's points
        verts = self.vertices
        mask = [((verts[0] <= pt) & (pt <= verts[1])).all() for pt in parent_pts] # find which of the parent's points are in the current domain
        domain_pts = parent_pts[mask]
        return domain_pts
    
class Domain(Domain_Tree,NodeMixin):
    """
    Author:     Damien Beecroft
    Domain:     A class that defines properties of the relevant domains used for tracking where each
                neural network has support.
    """
    def __init__(self,vertices,parent=None,children=None,pts=None,root=False):
        super(Domain,self).__init__()
        self.vertices = vertices # two opposite vertices that define the n-dimensional box
        self.parent = parent # parent domain of the current domain

        if parent: # set level
            self.lvl = parent.lvl + 1
        else:
            self.lvl = 0

        if children: # set children
            self.children = children

        if root: # determine points on the interior of the domain
            self.pts = pts # if this is the root node, just assign points
        else:
            self.pts = self.find_interior_pts() # otherwise, find which points are in the domain of the current domain

In [148]:
dom0 = np.array([0.,1.])
dom10 = np.array([0.,0.6])
dom11 = np.array([0.4,1.])
dom20 = np.array([0.,0.35])
dom21 = np.array([0.25,0.6])

In [149]:
D0 = Domain(dom0,pts=np.array([[0.2],[0.6],[0.9],[0.4],[0.8]]),root=True)
D10 = Domain(dom10,parent=D0)
D11 = Domain(dom10,parent=D0)
D20 = Domain(dom20,parent=D10)
D21 = Domain(dom20,parent=D10)

In [158]:
print(D10.children)
print(D10.vertices)
print(D10.pts)

(<__main__.Domain object at 0x0000024A3E42D910>, <__main__.Domain object at 0x0000024A3E49D450>)
[0.  0.6]
[[0.2]
 [0.6]
 [0.4]]
