In [81]:
# import networkx as nx
import rustworkx as rx

from dataclasses import dataclass, field
from typing import Iterable, Optional

from functools import reduce
from itertools import chain

from enum import Enum

class Type(Enum):
    COMMUNE = 0
    EPCI = 1
    DEP = 2
    REGION = 4
    PAYS = 5

@dataclass(frozen=True)
class Entity:
    name: str
    atomic: bool = True
    type: Type = Type.COMMUNE
    geo_bound: None = None
    es_code: Optional[str] = None
    tree_id: Optional[int] = field(default=None, compare=False)

    def __repr__(self) -> str:
        return self.name

    def contains(self, other, tree: rx.PyDiGraph) -> bool:
        return (self == other) or (self in rx.ancestors(tree, other))
    
        
    def __and__(self, other):
        if other is None:
            return None
        if self in other:
            return self
        if self.is_disjoint(other):
            return None
        return reduce(lambda x, y: x | y, [other & child for child in self.entities])



lyon = Entity("Lyon")
marseille = Entity("Marseille")
paris = Entity("Paris")
nogent = Entity("Nogent")
pantin = Entity("Pantin")
villeurbane = Entity("Villeurbane")
sté = Entity("Saint Etienne")

metropole = Entity("Grand Lyon", False, Type.EPCI)

sud = Entity("Sud", False, Type.REGION)
idf = Entity("Île-de-France", False, Type.REGION)
rhone = Entity("Rhône", False, Type.DEP)

france = Entity("France", False, Type.PAYS)



def build_tree() -> rx.PyDiGraph:
    print("BUILDING TREE : this is a very long operation")
    entities = (france, sud, idf, rhone, metropole, nogent, pantin, paris, marseille, sté, villeurbane, lyon)

    tree= rx.PyDiGraph()
    entities_indices = tree.add_nodes_from(entities)
    mapper = {o : idx for o, idx in zip(entities, entities_indices)}
    edges = [
        (france, idf),
        (france, sud),
        
        (idf, nogent),
        (idf, pantin),
        (idf, paris),

        (sud, marseille),
        (sud, rhone),

        (rhone, metropole),
        (rhone, sté),

        (metropole, villeurbane),
        (metropole, lyon),
        ]

    tree.add_edges_from([
        (mapper[parent], mapper[child], None) for parent, child in edges
    ])

    return tree

tree = build_tree()


BUILDING TREE : this is a very long operation


In [82]:
t = tree.get_node_data(0)
print(t.tree_id)
t.tree_id = 0
print(tree.get_node_data(0))
assert False

None


FrozenInstanceError: cannot assign to field 'tree_id'

In [71]:
list(tree.successor_indices(0))

[]

In [70]:
tree.predecessors(3)

[Sud]

In [67]:
from __future__ import annotations
from typing import Callable

from perfect_hash import generate_hash, Format


class Territory:
    tree: Optional[rx.DiGraph] = None
    perfect_hash_fct: Optional[Callable[[str], int]] = None

    @classmethod
    def assign_tree(cls, tree):
        cls.tree = tree
        # create perfect hash table
        elements: list[Entity] = [tree.get_node_data(i) for i in tree.node_indices()]
        for i, e in enumerate(elements):
            e.tree_id  = i

        names = [e.name for e in elements]

        f1, f2, G = generate_hash(names)

        fmt = Format()
        NG = len(G)
        NS = len(f1.salt)
        S1 = fmt(f1.salt)
        S2 = fmt(f2.salt)

        def hash_f(key, T):
            return sum(ord(T[i % NS]) * ord(c) for i, c in enumerate(key)) % NG

        def perfect_hash(key):
            return (G[hash_f(key, S1)] + G[hash_f(key, S2)]) % NG

        cls.perfect_hash_fct = perfect_hash        

        for name in names:
            i = cls.perfect_hash_fct(name)
            assert name == tree.get_node_data(i).name


    @classmethod
    def hash(cls, name: str) -> int:
        return cls.perfect_hash_fct(name)


    @staticmethod
    def contains(a: int, b: int, tree: rx.PyDiGraph) -> bool:
        return (a == b) or (a in rx.ancestors(tree, b))


    @classmethod
    def minimize(cls, node: int, items: Iterable[int]) -> set[int]:
        """evaluate complexity of this method

        Args:
            node (Entity): _description_
            items (Iterable[Entity]): _description_

        Returns:
            set[Entity]: _description_
        """
        if len(items) == 0:
            return set()
        if node in items:
            return {node}
        children = set(cls.tree.successor_indices(node))
        if children == set(items):
            return {node}

        gen = (cls.minimize(child, tuple(item for item in items if cls.contains(child, item, cls.tree))) for child in children)
        # print(type(iter(gen)))
        union =  set.union(*gen)

        if union == children:
            return {node}
        return union
    

    @classmethod
    def union(csl, *others):
        return reduce(lambda x, y: x + y, iter(others))


    @classmethod
    def intersection(csl, *others):
        return reduce(lambda x, y: x & y, iter(others))


    @classmethod
    def _sub(cls, a: Entity, b: Entity) -> set[Entity]:
        if a == b:
            return set()
        if a in rx.ancestors(cls.tree, b):
            return set.union(*(cls._sub(child, b) for child in cls.tree.successors(a)))
        return {a}
    

    @classmethod
    def _and(cls, a: Entity, b: Entity) -> set[Entity]:
        if a == b:
            return {a}
        # if a in b
        if a in rx.ancestors(cls.tree, b):
            return {b}
        # if b in a
        if b in rx.ancestors(cls.tree, a):
            return {a}
        return set()


    def __init__(self, *args: Iterable[Entity]) -> None:
        if self.tree is None:
            raise Exception('Tree is not initialized')
        entities = set(args)
        if entities:
            # root = next(tree.get_node_data(i) for i in tree.node_indices() if tree.in_degree(i) == 0)
            root_index = next(i for i in tree.node_indices() if tree.in_degree(i) == 0)
            entities_idxs = [self.hash(e.name) for e in entities]
            #  guarantee the Territory is always represented in minimal form
            self.entities = {self.tree.get_node_data(i) for i in self.minimize(root_index, entities_idxs)}
            # print(self.entities)
        else:
            self.entities = set()


    def __eq__(self, value: Territory) -> bool:
        return self.entities == value.entities


    def __add__(self, other: Territory) -> Territory:
        return Territory(
            *(self.entities | other.entities)
        )
    

    def is_contained(self, other: Territory | Entity) -> bool:
        for entity in self.entities:
            parents = rx.ancestors(self.tree, entity) | {entity}
            if isinstance(other, Entity):
                if other not in parents:
                    return False
            else:
                if not any(other_entity in parents for other_entity in other.entities):
                    return False
        return True
    

    def __contains__(self, other: Territory | Entity) -> bool:
        if isinstance(other, Entity):
            other = Territory(other)
        return other.is_contained(self)
    

    def is_disjoint(self, other: Territory) -> bool:
        pass


    def __or__(self, other: Territory | Entity) -> Territory:
        if not self.entities:
            entities = tuple()
        else:
            entities = self.entities
        if isinstance(other, Entity):
            return Territory(*chain(entities, [other]))
        if other.entities is not None:
            return Territory(*chain(entities, other.entities))
        return self
    


    def __and__(self, other: Territory | Entity) -> Territory:
        if isinstance(other, Entity):
            return  Territory(*chain(*(self._and(child, other) for child in self.entities)))
        if (not other.entities) or (not self.entities):
            return Entity()
        if self in other:
            return self

        return Territory.union(*(self & child for child in other.entities))
     

    def __sub__(self, other: Territory | Entity) -> Territory:
        if isinstance(other, Entity):
            return Territory(*chain(*(self._sub(child, other) for child in self.entities)))
        if (not other.entities) or (not self.entities):
            return self
        if self in other:
            return Territory()

        return Territory.intersection(*(self - child for child in other.entities))


    def __repr__(self) -> str:
        if self.entities:
            return '|'.join(str(e) for e in self.entities)
        return '{}'


    def to_es_filter(self) -> list[str]:
        return [{"match_phrase" : {"tu_zone" : e.es_code}} for e in self.entities]

    
    @classmethod
    def from_name(cls, *args: Iterable[str]):
        pass

Territory.assign_tree(tree=build_tree())

# a = Territory(sté, marseille)

BUILDING TREE : this is a very long operation


In [68]:
# tests
Territory.assign_tree(tree=build_tree())


a = Territory(sté, marseille)
b = Territory(lyon, france)
c = Territory(paris, nogent, pantin, lyon, lyon, metropole)
d = Territory(lyon, villeurbane, marseille)
e = Territory(rhone, idf)
f = Territory(idf, marseille, metropole)

assert b == Territory(france, )
assert d + a == Territory(sud, )
assert a + d == Territory(sud, )
assert c + a == Territory(idf, sud)
assert d + c == Territory(metropole, marseille, idf )

assert a in a
assert b in b
assert c in c
assert d in d
assert a in b
assert a in c + a
assert a not in d

assert a | d == Territory(sud, )
assert c | d == Territory(idf, marseille, metropole)
assert d | c == Territory(idf, marseille, metropole)


assert a & b == a
assert a & d == Territory(marseille, )
assert e & f == Territory(idf, metropole)
assert f & e == Territory(idf, metropole)


assert a - b == Territory()
assert b - a == Territory(metropole, idf)

BUILDING TREE : this is a very long operation


TypeError: argument 'node': 'Entity' object cannot be interpreted as an integer