In [1]:
import numpy as np
import networkx as nx
import random

In [2]:
from search_env import SearchEnv

In [46]:
class RandomGraphSearchEnv(SearchEnv):
    """
    Search environment on a random undirected, weighted graph.

    Nodes are the integers 0..N-1. An action is the id of the node to move
    to; it is valid only when that node is adjacent to the current node.
    """

    def __init__(self, N=10, d=3, start_node=0, goal_node=None, seed=None):
        """
        Build the random graph and place the agent on ``start_node``.

        Args:
            N: Number of nodes in the graph.
            d: Maximum degree allowed for any node.
            start_node: Node the agent starts from.
            goal_node: Target node; defaults to ``N - 1`` when None.
            seed: Seed forwarded to ``np.random.default_rng``; the same seed
                yields the same graph.
        """
        self.N = N
        self.d = d
        self.seed = seed
        self.rng = np.random.default_rng(seed)
        self.graph = self._generate_random_graph(N, d)
        if goal_node is None:
            goal_node = N - 1
        # NOTE(review): SearchEnv.__init__ presumably stores start_state and
        # goal_state, which step() and reset() read -- confirm in search_env.
        super().__init__(start_node, goal_node)
        self.state = start_node

    def _generate_random_graph(self, N, d):
        """
        Generate a random undirected graph with N nodes and max degree d.

        Nodes are visited in increasing order; each one is connected to
        randomly chosen partners until it has d edges or no eligible partner
        is left. Every edge gets a random weight drawn from 1.0..10.0
        (integer-valued floats). A node may end up with fewer than d edges,
        and no connectivity is enforced, so the goal is not guaranteed to be
        reachable from the start.

        Returns:
            The generated ``nx.Graph`` with a ``weight`` attribute per edge.
        """
        G = nx.Graph()
        G.add_nodes_from(range(N))
        for node in range(N):
            # Eligible partners: not the node itself, not already adjacent,
            # and still below the degree cap. Adding an edge from `node`
            # only changes the degree of `node` and of the chosen partner,
            # so this list stays valid once the chosen partner is removed.
            candidates = [
                other
                for other in range(N)
                if other != node
                and not G.has_edge(node, other)
                and G.degree[other] < d
            ]
            while candidates and G.degree[node] < d:
                # int(): keep node ids as plain Python ints rather than
                # np.int64 so neighbors() returns a homogeneous list.
                neighbor = int(self.rng.choice(candidates))
                candidates.remove(neighbor)
                weight = float(self.rng.integers(1, 11))  # 1..10 inclusive
                G.add_edge(node, neighbor, weight=weight)
        return G

    def get_reachable_states(self, state=None):
        """
        Return the list of nodes adjacent to ``state``.

        Only node ids are returned; use ``cost`` to look up the weight of
        the connecting edge. When ``state`` is None the current state is
        used.
        """
        if state is None:
            state = self.state
        return list(self.graph.neighbors(state))

    def step(self, action):
        """
        Move to the neighboring node ``action``.

        Returns:
            A ``(state, reward, done, info)`` tuple. For a valid move the
            reward is the weight of the traversed edge and ``info`` holds
            the nodes reachable from the new state. For an invalid move the
            state is unchanged, the reward is -1.0 and ``info`` is
            ``{'invalid': True}``.
        """
        # Action is the next node to move to
        if action not in self.get_reachable_states(self.state):
            # Penalty for invalid move; the agent stays where it is.
            return self.state, -1.0, False, {'invalid': True}
        # Look up the edge weight before moving: cost() needs both endpoints.
        reward = self.cost(self.state, action)
        self.state = action
        done = bool(self.state == self.goal_state)
        info = {'reachable': self.get_reachable_states(self.state)}
        return self.state, reward, done, info

    def reset(self):
        """
        Put the agent back on the start node.

        Returns:
            ``(state, 0, False, info)`` where ``info['reachable']`` lists the
            nodes adjacent to the start node.
        """
        self.state = self.start_state
        info = {'reachable': self.get_reachable_states(self.state)}
        return self.state, 0, False, info

    def render(self, mode='human'):
        """Print the current node, the goal node and the current neighbors."""
        print(f"Current node: {self.state}, Goal node: {self.goal_state}")
        print(f"Neighbors: {self.get_reachable_states(self.state)}")

    def cost(self, from_node, to_node):
        """
        Return the cost (edge weight) of moving from from_node to to_node.

        Falls back to 1.0 when the edge does not exist or carries no
        ``weight`` attribute.
        """
        edge_data = self.graph.get_edge_data(from_node, to_node)
        if not edge_data:
            return 1.0
        return edge_data.get('weight', 1.0)
