In [145]:
from random import random
from functools import reduce
from collections import namedtuple
from queue import PriorityQueue

import numpy as np
from tqdm.auto import tqdm

In [146]:
# Definition of problem constants and generation of the sets
PROBLEM_SIZE = 20
NUM_SETS = 250
SETS = tuple(
    np.array([random() < 0.2 for _ in range(PROBLEM_SIZE)]) for _ in range(NUM_SETS)
)

In [147]:
# Define algorithmic state and helper functions
State = namedtuple("State", ["taken", "not_taken"])


def covered(state):
    """
    Returns the characteristic function of the set of covered elements at the current state.
    @param state: the current state
    """
    return reduce(
        np.logical_or,
        [SETS[i] for i in state.taken],
        np.array([False for _ in range(PROBLEM_SIZE)]),
    )


def goal_check(state):
    """
    Checks whether the current state satisfies the goal
    @param state: the current state
    """
    return np.all(covered(state))

In [148]:
class HeuristicAlgorithm:
    def __init__(
        self,
        heuristic,  # a scalar function of the state
        visit_state,  # a function that returns a collection of (neighboring) states from a state
        goal_check,  # a function that checks a state to interrupt execution
        total_cost=lambda _, h: h,  # a scalar function of the state and the heuristic
    ):
        """
        This class implements a heuristic algorithm.

        @param heuristic: a scalar function of the state
        @param visit_state: a function that returns a collection of states from a state
        @param goal_check: a function that checks a state to interrupt execution
        @param total_cost: a scalar function of the state and the heuristic,
            defaults to the heuristic itself
        """
        self.frontier = PriorityQueue()
        self.h = heuristic
        self.visit_state = visit_state
        self.goal_check = goal_check
        self.iterations_count = 0
        self.f = lambda s: total_cost(s, self.h(s))
        self.solution = None

    def get_iterations_count(self):
        return self.iterations_count

    def solve(self, start_state):
        self.frontier.put((self.f(start_state), start_state))
        self.iterations_count = 0
        _, current_state = self.frontier.get()
        while not goal_check(current_state):
            self.iterations_count += 1
            for new_state in self.visit_state(current_state):
                self.frontier.put((self.f(new_state), new_state))
            # TODO handle no solution case
            _, current_state = self.frontier.get()
        self.solution = current_state

    def get_solution(self):
        return self.solution

## Greedy best first


In [149]:
greedy_solver = HeuristicAlgorithm(
    lambda s: PROBLEM_SIZE - sum(covered(s)),
    lambda s: (State(s.taken ^ {i}, s.not_taken ^ {i}) for i in s[1]),
    goal_check,
)

greedy_solver.solve(State(set(), set(range(NUM_SETS))))
print("Solution found in", greedy_solver.get_iterations_count(), "steps.")
print("Number of sets:", len(greedy_solver.get_solution().taken))
print("Solution:", greedy_solver.get_solution().taken)

Solution found in 4 steps.
Number of sets: 4
Solution: {184, 130, 12, 20}


## A\*


In [150]:
def h(s):
    """
    This heuristic is admissible.
    It returns an optimistic estimate h of the number of sets necessary to complete the covering.

    Let n be the cardinality of the set of uncovered elements.
    h is equal to the minimum number of sets that cover the uncovered elements n times (counting
    repetitions).
    """
    already_covered = covered(s)
    if np.all(already_covered):
        return 0
    missing_size = PROBLEM_SIZE - sum(already_covered)
    candidates = sorted(
        (sum(np.logical_and(s, np.logical_not(already_covered))) for s in SETS),
        reverse=True,
    )
    taken = 1
    while sum(candidates[:taken]) < missing_size:
        taken += 1
    return taken

In [151]:
astar_solver = HeuristicAlgorithm(
    h,
    lambda s: (State(s.taken ^ {i}, s.not_taken ^ {i}) for i in s[1]),
    goal_check,
    lambda s, h: len(s.taken) + h,
)

astar_solver.solve(State(set(), set(range(NUM_SETS))))
print("Solution found in", astar_solver.get_iterations_count(), "steps.")
print("Number of sets:", len(astar_solver.get_solution().taken))
print("Solution:", astar_solver.get_solution().taken)

Solution found in 63 steps.
Number of sets: 3
Solution: {195, 213, 239}
