In [1]:
import logging
import numpy as np
import random
from math import inf
from itertools import chain
from typing import Callable
from gx_utils import *

logging.basicConfig(format="%(message)s", level=logging.INFO)

Graph search for the Set Covering problem

In [2]:
class State:
    def __init__(self, data: list):
        self._list = sorted(data.copy())
        self.set_covered = set(chain(*self._list))

    def __hash__(self):
        # logging.debug('__hash__')
        return hash(bytes(chain(*self._list)))

    def __eq__(self, other):
        # logging.debug('__eq__')
        return bytes(self.set_covered) == bytes(other.set_covered)

    def __contains__(self, other):
        # logging.debug('__contains__')
        return set(other) in self.set_covered

    def __le__(self, other):
        return self.set_covered <= other._set

    def __lt__(self, other):
        return self.set_covered < other._set

    def __str__(self):
        return str(chain(*self._list))

    def __repr__(self):
        return repr(self._list)

    def covers(self, other: list):
        return set(other) <= self.set_covered

    # def is_alias(self, l, : list):
    #     return 

    @property
    def data(self):
        return self._list

    def copy_data(self):
        return self._list.copy()

In [3]:
def goal_test(state):
    return state.set_covered == goal

In [4]:
# def is_valid(list_of_lists: list, l: list):
#     return (not set(l) <= set(chain(*list_of_lists))) # If this was just a "<", duplicates are allowed

def possible_actions(state: State):
    return (l for l in all_lists if not state.covers(l))

In [5]:
def result(state, action):
    current_list = state.copy_data()
    current_list.append(action)
    return State(current_list)

In [6]:
def problem(N, seed=None):
    random.seed(seed)
    return [
        list(set(random.randint(0, N - 1) for n in range(random.randint(N // 5, N // 2))))
        for n in range(random.randint(N, N * 5))
    ]

In [7]:
def search_min(
    initial_state: State,
    goal_test: Callable,
    parent_state: dict,
    state_cost: dict,
    priority_function: Callable,
    unit_cost: Callable,
):
    frontier = PriorityQueue()
    parent_state.clear()
    state_cost.clear()

    state = initial_state
    parent_state[state] = None
    state_cost[state] = 0

    min_cost = inf
    min_state = None

    i = 0
    n_frontier = 0

    # covered = set()

    while state is not None:
        logging.debug(f'i= {i}')
        logging.debug(f'curr. state -> {state.data}')

        if goal_test(state):
            logging.debug(f'found a solution: {state.data}')
            # solutions.append(state)
            if state_cost[state] < min_cost:
                logging.debug(f'updating min cost -> {state_cost[state]}')
                min_cost = state_cost[state]
                min_state = state
        else:
            for a in possible_actions(state):
                new_state = result(state, a)
                cost = unit_cost(a)
                # logging.debug(f'entered actions -> {new_state.data}')
                # logging.debug(f'new action covered set: {new_state.set_covered}')
                if new_state not in state_cost and new_state not in frontier:
                    parent_state[new_state] = state
                    state_cost[new_state] = state_cost[state] + cost
                    # covered |= new_state.set_covered
                    # logging.debug(f'global covered set: {covered}')
                    frontier.push(new_state, p=priority_function(new_state))
                    n_frontier += 1
                    logging.debug(f"Added new node ({n_frontier}) to frontier (cost={state_cost[new_state]}) -> {new_state.data}")
                elif new_state in frontier and state_cost[new_state] > state_cost[state] + cost:
                    old_cost = state_cost[new_state]
                    parent_state[new_state] = state
                    state_cost[new_state] = state_cost[state] + cost
                    logging.debug(f"Updated node cost in frontier: {old_cost} -> {state_cost[new_state]}")

        if frontier:
            state = frontier.pop()
        else:
            state = None
        
        i += 1

    logging.debug(f'total nodes in frontier: {n_frontier}')

    path = list()
    s = min_state
    while s:
        path.append(s.copy_data())
        s = parent_state[s]

    logging.info(f'done in {i} iterations')
    print(list(enumerate(reversed(path))))

    return min_state

Breadth-First

In [8]:
logging.getLogger().setLevel(logging.INFO)

# for N in [5, 10, 20, 100, 500, 1000]:
for N in [10]:
    goal = set(range(N))
    initial_state = State(list()) # Empty list as initial state

    all_lists = problem(N, seed=42)

    # # Pre-process (remove duplicates)
    # all_lists = set(tuple(l) for l in all_lists)
    # all_lists = [list(t) for t in all_lists]
    
    # all_lists = [[1, 2], [2, 3], [2, 4], [3], [1, 4], [4], [0, 1], [0]]
    # all_lists = [[1, 2], [2, 3], [2], [3]]

    # solutions = list()

    print(len(all_lists))

    parent_state = dict()
    state_cost = dict()

    min_state = search_min(
        initial_state,
        goal_test=goal_test,
        parent_state=parent_state,
        state_cost=state_cost,
        priority_function=lambda s: len(state_cost),
        unit_cost=lambda a: len(a),
    )

    logging.info(
        f"Found min solution for N={N}: w={sum(len(_) for _ in min_state.data)} (bloat={(sum(len(_) for _ in min_state.data)-N)/N*100:.0f}%)"
    )

50


KeyboardInterrupt: 