In [None]:
!pip install requirements

In [None]:
import random
import sys

from loguru import logger

In [None]:
# Drop the handlers registered so far and log to stderr only,
# showing messages of level INFO and above.
logger.remove()

stderr_sink_options = {
    "level": "INFO",
    "colorize": True,
    "diagnose": True,
    "backtrace": True,
    "enqueue": True,
}
logger.add(sys.stderr, **stderr_sink_options)

In [None]:
from enum import IntEnum


class SchemaType(IntEnum):
    """Integer-valued categories a node can belong to.

    Adjust the members as needed; ``first`` and ``last`` follow the
    definition order of the members.
    """

    partnership = 1
    feature = 2
    value_prop = 3
    customer_seg = 4
    channel = 5
    market = 6

    @classmethod
    def first(cls):
        """Return the member that is defined first."""
        return next(iter(cls))

    @classmethod
    def last(cls):
        """Return the member that is defined last."""
        return next(reversed(cls))

In [None]:
from threading import Lock


class Node:
    """Graph node with a static weight, a schema type, a name and a value.

    The value starts at 0 and is read/written through the ``value``
    property, which takes an internal lock on every access.
    """

    def __init__(self, weight: float, _type: "SchemaType", name: str) -> None:
        """Create a node whose value is 0.

        Args:
            weight (float): Weight of the node
            _type (SchemaType): Schema type of the node
            name (str): Name used in logs and string representations
        """
        self.weight = weight
        self._value: float = 0
        # taken by the `value` property on every single read and write
        self.__value_lock = Lock()
        # public lock for callers that must do a multi-step update of the node
        self.value_lock_ext = Lock()
        self.type: "SchemaType" = _type
        self.name: str = name

    @property
    def value(self) -> float:
        """Current value of the node, read under the internal lock."""
        with self.__value_lock:
            return self._value

    @value.setter
    def value(self, value: float) -> None:
        with self.__value_lock:
            self._value = value

    def __repr__(self) -> str:
        # debugging form: every field is labelled with its attribute name
        return (
            "Node("
            f"{self.name=}, {self.weight=:.4f}, {self.value=:.4f}, {self.type=}"
            ")"
        )

    def __str__(self) -> str:
        # compact form: name, weight, value and type without labels
        return (
            "Node("
            f"{self.name!r}, {self.weight:.4f}, {self.value:.4f}, {self.type}"
            ")"
        )

    def __eq__(self, other) -> bool:
        """Nodes are equal when weight, type, name and value all match."""
        if not isinstance(other, Node):
            # `NotImplemented` is a singleton, not a callable: returning it
            # lets Python try the reflected comparison and finally fall back
            # to "not equal" instead of raising a TypeError.
            return NotImplemented

        return (
            self.weight == other.weight
            and self.type == other.type
            and self.name == other.name
            and self.value == other.value
        )

    def __hash__(self) -> int:
        # `value` is left out on purpose: it changes during activation, and
        # equal nodes still hash equally because the hashed fields are a
        # subset of the compared ones.
        return hash((self.weight, self.type, self.name))


In [None]:
class Edge:
    """Directed, weighted connection from ``node_a`` to ``node_b``."""

    def __init__(self, node_a, node_b, weight: float) -> None:
        """Initializes an edge between two nodes.
            Directed edge from node_a to node_b with a weight.

        Args:
            node_a (Node): Starting node
            node_b (Node): Ending node
            weight (float): Weight of the edge
        """
        self.node_a = node_a
        self.node_b = node_b
        self.weight = weight

    def __repr__(self) -> str:
        return f"Edge({self.node_a.name} -> {self.node_b.name}, {self.weight=:.4f})"

    def __str__(self) -> str:
        return f"Edge({self.node_a.name} -> {self.node_b.name}, {self.weight})"

    def __eq__(self, other) -> bool:
        """Edges are equal when both endpoints (in order) and the weight match."""
        # NOTE: This is a directed edge, so the order of the nodes matters
        if not isinstance(other, Edge):
            # `NotImplemented` must be returned as-is; calling it raises a
            # TypeError because the singleton is not callable.
            return NotImplemented

        return (
            self.node_a == other.node_a
            and self.node_b == other.node_b
            and self.weight == other.weight
        )

    def __hash__(self) -> int:
        return hash((self.node_a, self.node_b, self.weight))


In [None]:
class Graph:
    """Weighted graph kept as a node list plus per-node adjacency lists.

    ``add_edge`` stores every connection twice (``a -> b`` under ``a`` and
    ``b -> a`` under ``b``), so ``edges[node]`` holds the edges that start
    at ``node``. Nodes without any edge have no entry in ``edges``.
    """

    def __init__(self) -> None:
        self.nodes: list[Node] = []
        self.edges: dict[Node, list[Edge]] = dict()

    def add_node(self, node: Node) -> None:
        """Append ``node`` to the graph; no edges are created."""
        self.nodes.append(node)

    def add_edge(self, node_a: Node, node_b: Node, weight: float) -> None:
        """Connect two nodes in both directions with the same weight.

        Args:
            node_a (Node): First endpoint
            node_b (Node): Second endpoint
            weight (float): Weight stored on both directed edges
        """
        self.edges.setdefault(node_a, []).append(Edge(node_a, node_b, weight))
        self.edges.setdefault(node_b, []).append(Edge(node_b, node_a, weight))

    def remove_node(self, node: Node) -> None:
        """Remove ``node`` together with every edge from or to it.

        Raises:
            ValueError: If ``node`` is not part of the graph.
        """
        self.nodes.remove(node)

        # A node that never got an edge has no adjacency entry, hence the default.
        for edge in self.edges.pop(node, []):
            neighbour_edges = self.edges.get(edge.node_b)
            if neighbour_edges is None:
                # self-loop (entry already popped) or neighbour without edges
                continue
            # drop the mirrored edges that point back at the removed node
            neighbour_edges[:] = [
                back_edge for back_edge in neighbour_edges if back_edge.node_b != node
            ]

    def remove_edge(self, edge: Edge) -> None:
        """Remove ``edge`` and its mirrored counterpart, if they are stored.

        Missing edges or endpoints are ignored (best effort).
        """
        self._discard_edge(edge.node_a, edge)
        # add_edge stored the opposite direction under the other endpoint
        self._discard_edge(edge.node_b, Edge(edge.node_b, edge.node_a, edge.weight))

    def _discard_edge(self, owner: Node, edge: Edge) -> None:
        """Remove ``edge`` from ``owner``'s adjacency list when present."""
        adjacent = self.edges.get(owner)
        if adjacent is None:
            return
        try:
            adjacent.remove(edge)
        except ValueError:
            pass

    def get_adjacent_edges(self, node: Node) -> list[Edge]:
        """Return the edges starting at ``node`` (empty list if it has none)."""
        return self.edges.get(node, [])

    def _render(self, format_node) -> str:
        """Build the multi-line text shared by ``__str__`` and ``__repr__``."""
        lines = []
        for node in self.nodes:
            neighbour_names = ", ".join(
                edge.node_b.name for edge in self.get_adjacent_edges(node)
            )
            lines.append(f"{format_node(node)} -> {neighbour_names}")
        return "Graph(\n" + ",\n".join(lines) + "\n)"

    def __str__(self) -> str:
        # one line per node: compact node text and the names of its neighbours
        return self._render(str)

    def __repr__(self) -> str:
        # one line per node: labelled node text and the names of its neighbours
        return self._render(repr)

In [None]:
from concurrent.futures import ThreadPoolExecutor
from queue import PriorityQueue
from loguru import logger


class GraphOps:
    """Two-phase spreading of node values over a ``Graph``.

    ``create_priority_queue`` walks the graph from every node with a
    non-zero value and queues the traversable edges;
    ``activate_priority_queue`` then drains the queue and computes the
    value of every destination node.
    """

    def __init__(
        self,
        graph: Graph,
        spread_schema: SchemaType,
        decay_factor: float = 0.8,
        threshold: float = 0.6,
        max_steps: int = 10,
        self_activation: bool = False,
        orderly_spreading: bool = False,
        directional_spreading: bool = True,
        stop_at_leafs: bool = True,
        num_processes: int = 4,
    ):
        """Store the graph and the spreading parameters.

        Args:
            graph (Graph): Graph to operate on
            spread_schema (SchemaType): Schema whose ``first()``/``last()``
                members are treated as leaf types
            decay_factor (float): Base of the per-step decay; an edge gets the
                priority ``decay_factor ** step * edge.weight``
            threshold (float): Minimum priority for an edge to be queued and
                minimum value for a node to stay activated
            max_steps (int): Step budget of a traversal
            self_activation (bool): If False, edges between nodes of the same
                type are skipped
            orderly_spreading (bool): If True, edges whose source type is not
                lower than the destination type are skipped
            directional_spreading (bool): If True, only edges from type ``t``
                to type ``t + 1`` are followed
            stop_at_leafs (bool): If True, destinations of the first or last
                schema type are skipped
            num_processes (int): Number of worker threads in the pools
        """
        # function-scope import: the tie-breaker is only needed by this class
        from itertools import count

        self.__graph: Graph = graph
        self.__spread_schema: SchemaType = spread_schema
        # Unbounded on purpose: the same edge can be queued from several
        # traversals, and `put()` on a full bounded queue would block forever
        # because nothing consumes the queue while it is being filled.
        self.__activation_queue = PriorityQueue()
        # Monotonic tie-breaker placed in front of the nodes in every queue
        # entry, so that entries with equal step and priority are never
        # ordered by comparing Node objects (which define no ordering).
        self.__entry_counter = count()
        self.decay_factor = decay_factor
        self.__threshold = threshold
        self.__max_steps = max_steps
        self.__self_activation = self_activation
        self.__orderly_spreading = orderly_spreading
        self.__directional_spreading = directional_spreading
        self.__stop_at_leafs = stop_at_leafs
        self.__n_proc = num_processes

    def _activate_node(
        self,
        node: Node,
        parent_node: Node,
        pre_value: float,
    ):
        """Set the value of ``node`` from its parent's value.

        The candidate value is ``parent_node.value * pre_value``; it is capped
        at 1.0 when above the threshold and replaced by 0 otherwise.

        Args:
            node (Node): Node to activate
            parent_node (Node): Node the activation comes from
            pre_value (float): Priority value of the connecting edge
        """
        with node.value_lock_ext:
            # Work on a local and store once, so other threads never read a
            # half-finished (unclamped) value through `node.value`.
            activation = parent_node.value * pre_value
            if activation > self.__threshold:
                node.value = min(activation, 1.0)
            else:
                logger.trace(
                    f"Node {node.name} calculated value {activation:.4f} "
                    "is below the threshold"
                )
                node.value = 0

    def _is_traversable(
        self, node_a: Node, node_b: Node, priority_value: float
    ) -> bool:
        """Check if node b can be activated from node a.

        Args:
            node_a (Node): Source node
            node_b (Node): Destination node
            priority_value (float): Decayed weight of the edge a -> b

        Returns:
            bool: True if the edge should be queued for activation
        """
        # parameter-based checks
        if not self.__self_activation and node_a.type == node_b.type:
            logger.trace(
                f"Node {node_a.name} and {node_b.name} are of the same type. Skipping."
            )
            return False
        if self.__stop_at_leafs and node_b.type in (
            self.__spread_schema.last(),
            self.__spread_schema.first(),
        ):
            logger.trace(f"Node {node_b.name} is a leaf node. Skipping.")
            return False
        if self.__directional_spreading and node_a.type + 1 != node_b.type:
            logger.trace(
                f"Node {node_a.name} and {node_b.name} are not adjacent in the schema. Skipping."
            )
            return False
        # Skip when the source is NOT of a lower type than the destination,
        # as the log message states; this keeps the check compatible with
        # directional spreading, which goes from type t to type t + 1.
        if self.__orderly_spreading and node_a.type >= node_b.type:
            logger.trace(
                f"Node {node_a.name} is not of a lower type than {node_b.name}. Skipping."
            )
            return False

        # regular checks
        if node_b.value != 0:
            logger.trace(f"Node {node_b.name} is already activated. Skipping.")
            return False
        if priority_value < self.__threshold:
            logger.trace(
                f"Priority value {priority_value:.4f} is below the threshold. Skipping."
            )
            return False

        # TODO: add more checks

        return True

    def activate_priority_queue(self):
        """Drain the activation queue and activate the queued destinations.

        Entries leave the queue lowest step first, then highest priority
        first; each one is handed to a worker thread that computes
        "priority value" * `node_a.value` for the destination node.
        Errors raised by a worker are logged, not re-raised.
        """
        # NOTE(review): entries are submitted in queue order, but the workers
        # run concurrently, so the order in which nodes actually get their
        # value is not guaranteed to match the queue order.
        with ThreadPoolExecutor(self.__n_proc) as pool:
            results = []
            while not self.__activation_queue.empty():
                # the third field is the tie-breaker and carries no meaning
                step_num, priority_value, _, node_a, node_b = (
                    self.__activation_queue.get()
                )
                # priorities are stored negated (see `_traverse`)
                priority_value *= -1
                logger.debug(
                    f"Activating {node_b.name} from {node_a.name} on step "
                    f"{step_num} with priority value {priority_value:.4f}"
                )

                results.append(
                    pool.submit(
                        self._activate_node,
                        node=node_b,
                        parent_node=node_a,
                        pre_value=priority_value,
                    )
                )
            for result in results:
                try:
                    result.result()
                except Exception:
                    logger.exception("Error in activating the node")

    def _traverse(
        self,
        node: Node,
        steps_left: int,
    ):
        """Traverse the graph starting from the given node.

        Every traversable edge leaving ``node`` is queued for activation and
        then followed recursively with one step less.

        Args:
            node (Node): Starting node
            steps_left (int): Number of steps left to traverse
        """
        # `<=` also stops the recursion when max_steps was configured as <= 0
        if steps_left <= 0:
            logger.warning(f"Reached max steps at the {node.name}")
            return

        cur_step = self.__max_steps - steps_left

        for edge in self.__graph.get_adjacent_edges(node):

            priority_value = (self.decay_factor ** (cur_step)) * edge.weight

            if not self._is_traversable(node, edge.node_b, priority_value):
                continue

            # populate the queue with the edge weight and the node
            logger.trace(
                f"Pushing {node.name} -> {edge.node_b.name} to the queue: {priority_value}"
            )
            self.__activation_queue.put(
                (
                    # queue is sorted by cur_step, then by priority_value
                    # so we want the lowest step and the highest priority value to be at the top
                    cur_step,
                    priority_value * -1,
                    # insertion counter: decides ties before the nodes are compared
                    next(self.__entry_counter),
                    node,
                    edge.node_b,
                )
            )

            self._traverse(edge.node_b, steps_left - 1)

    def create_priority_queue(self):
        """Fill the activation queue from every node with a non-zero value.

        One traversal per starting node is submitted to a thread pool with
        ``max_steps - 1`` steps left. Errors raised by a traversal are
        logged, not re-raised.
        """
        with ThreadPoolExecutor(self.__n_proc) as pool:
            results = []
            for node in self.__graph.nodes:
                if node.value != 0:
                    logger.trace(f"Starting traversal from node {node.name}")
                    results.append(
                        pool.submit(
                            self._traverse,
                            node=node,
                            steps_left=self.__max_steps - 1,
                        )
                    )
            for result in results:
                try:
                    result.result()
                except Exception:
                    logger.exception("Error in traversing the graph")


In [None]:
def generate_random_graph(size: int, seed: "int | None" = None) -> Graph:
    """Build a random graph of ``size`` nodes named ``n0``, ``n1``, ...

    Node types cycle through the ``SchemaType`` members in definition order,
    weights are drawn from 0.1 .. 1.0 in steps of 0.1, and roughly half of
    the nodes start with the value 1.0. From the third node on, every new
    node is connected to its predecessor and to one randomly chosen older
    node, each with a random edge weight.

    Args:
        size (int): Number of nodes to create
        seed (int | None): Seed for a private random generator; with the
            default ``None`` the module-level ``random`` state is used

    Returns:
        Graph: The generated graph
    """
    # a private generator keeps seeded runs reproducible without touching
    # the global random state
    rng = random if seed is None else random.Random(seed)
    # index the members instead of adding ints, so `Node.type` stays a
    # SchemaType member rather than degrading to a plain int
    schema_types = list(SchemaType)

    g = Graph()
    for i in range(size):
        n = Node(
            rng.randint(1, 10) / 10,
            schema_types[i % len(schema_types)],
            f"n{i}",
        )
        if rng.random() > 0.5:
            n.value = 1.0
        g.add_node(n)
        # the first two nodes have no "older" node to pick from
        if i > 1:
            g.add_edge(n, g.nodes[-2], weight=rng.random())
            g.add_edge(rng.choice(g.nodes[:-2]), n, weight=rng.random())

    return g


In [None]:
# Build a 30-node random graph and log its initial state.
g = generate_random_graph(30)
logger.info(str(g))

# Spreading parameters for this run; everything else keeps its default.
ops_settings = dict(
    spread_schema=SchemaType,
    decay_factor=0.9,
    threshold=0.3,
    max_steps=10,
)
graph_ops = GraphOps(g, **ops_settings)

# Phase 1: queue the traversable edges; phase 2: activate the queued nodes.
logger.info("Starting the graph operations")
graph_ops.create_priority_queue()
logger.info("Finished traversing the graph")
graph_ops.activate_priority_queue()
logger.info("Finished the graph operations")

# Log the graph again to show the values after spreading.
logger.info(str(g))
