In [8]:


import numpy as np
from collections import defaultdict, deque
from graphs import graph2 as graph
class ProbabilisticGraphSampler:
    """Forward sampler for a directed acyclic probabilistic graph.

    Nodes are evaluated in topological order, so every node's parents have
    already been sampled by the time its expression runs.
    """

    def __init__(self, graph):
        """
        Initialize the sampler with a graph.

        :param graph: mapping describing the model.
            keys : random-variable names.
            values : ``(parents, expression)`` where ``parents`` is a list of
                keys of ``graph`` and ``expression`` is a callable taking ONE
                argument -- a dict mapping each parent name to its sampled
                value -- and returning this node's sample. Root nodes receive
                an empty dict, so their expression must still accept it.
                Example:
                {
                    "z": ([], lambda values: np.random.binomial(1, 0.5)),
                    "y": (["z"], lambda values: np.random.normal(-1.0 if values["z"] == 0 else 1.0, 1.0))
                }
        :raises ValueError: if a node lists a parent that is not a key of
            ``graph``, or if the graph contains a cycle.
        """
        self.graph = graph
        # Most recent trace; rebound to a fresh dict on every sample_trace call.
        self.values = {}
        # Computed once here; assumes the graph structure is not modified later.
        self.sorted_nodes = self.topological_sort()

    def topological_sort(self):
        """
        Order the nodes so every parent precedes its children (Kahn's algorithm).

        Ties are broken by the iteration order of ``self.graph``.

        :return: A list of nodes in topological order.
        :raises ValueError: if a node lists a parent that is not a key of the
            graph, or if the graph contains a cycle.
        """
        # Reject dangling parent references up front; otherwise the child would
        # never reach in-degree 0 and would be misreported as a cycle.
        for node, (parents, _) in self.graph.items():
            missing = [p for p in parents if p not in self.graph]
            if missing:
                raise ValueError(f"Node {node!r} has unknown parent(s): {missing}")

        in_degree = {node: 0 for node in self.graph}  # incoming edges per node
        children = defaultdict(list)  # parent -> children that depend on it

        for node, (parents, _) in self.graph.items():
            for parent in parents:
                in_degree[node] += 1
                children[parent].append(node)

        # Start from the roots (nodes with no parents), in graph order.
        queue = deque(node for node, degree in in_degree.items() if degree == 0)
        sorted_nodes = []

        while queue:
            node = queue.popleft()
            sorted_nodes.append(node)
            # Emitting `node` satisfies one incoming edge of each child. Edges
            # are counted per occurrence, so duplicated parents stay consistent.
            for child in children[node]:
                in_degree[child] -= 1
                if in_degree[child] == 0:
                    queue.append(child)

        if len(sorted_nodes) != len(self.graph):
            # Leftover nodes lie on a cycle or downstream of one.
            unresolved = [node for node in self.graph if in_degree[node] > 0]
            raise ValueError(f"Graph contains a cycle! Unresolved nodes: {unresolved}")

        return sorted_nodes

    def sample_trace(self):
        """
        Draw one joint sample by evaluating every node in topological order.

        Each call builds and returns a NEW dict, so traces kept from earlier
        calls are not overwritten. ``self.values`` is rebound to that dict.

        :return: A dictionary mapping every node to its sampled value.
        """
        # Rebind (instead of clear()) so previously returned traces stay intact.
        self.values = values = {}
        for node in self.sorted_nodes:
            parents, expression = self.graph[node]
            # Parents were already sampled in this loop thanks to the ordering.
            values[node] = expression({p: values[p] for p in parents})
        return values




# Build a sampler over the model graph imported at the top of the cell,
# then print five independent forward-sampled traces.
sampler = ProbabilisticGraphSampler(graph)

for _ in range(5):
    # Each trace maps every node name to its sampled value.
    trace = sampler.sample_trace()
    print(trace)



{'r': 0.03212986157243387, 'z_1': 1.6169434188685357, 'z_2': -0.43688081957052116, 'z_3': -3.114472407723614, 'z_4': -1.087661879939931, 'z_5': 0.016773834588778153}
{'r': 0.2075049017947947, 'z_1': -0.642786145217531, 'z_2': 1.2967908065394425, 'z_3': 0.41705980571356704, 'z_4': 0.05892492997936558, 'z_5': -0.8167970034538782}
{'r': -0.18477519709753007, 'z_1': 0.944832628558011, 'z_2': -0.10772325725632813, 'z_3': -0.6313882345047575, 'z_4': 0.178671210279216, 'z_5': 0.0039034525474052006}
{'r': 0.44941642041969865, 'z_1': -0.9987338875660379, 'z_2': 0.38860095095794156, 'z_3': -1.1361635806030577, 'z_4': 3.6639238347154137, 'z_5': 0.10796488861953465}
{'r': 0.985679511484372, 'z_1': 0.07302813980309728, 'z_2': 2.2653150524192256, 'z_3': -1.0715420226916788, 'z_4': -1.6210852205791428, 'z_5': 3.808257867143737}
