In [161]:
from random import random
from functools import reduce
from collections import namedtuple, deque
from queue import PriorityQueue, SimpleQueue, LifoQueue
from math import ceil,floor

import numpy as np

In [162]:
PROBLEM_SIZE = 20
NUM_SETS = 40
SETS = tuple(
    np.array([random() < 0.3 for _ in range(PROBLEM_SIZE)])
    for _ in range(NUM_SETS)
)
State = namedtuple('State', ['taken', 'not_taken'])

In [163]:
SETS

(array([False, False, False, False, False,  True, False,  True, False,
        False, False,  True,  True, False, False, False, False, False,
        False, False]),
 array([False, False, False, False, False, False, False,  True, False,
        False, False, False, False, False, False, False,  True, False,
        False, False]),
 array([False,  True,  True, False,  True,  True, False, False, False,
        False, False, False, False,  True, False, False, False, False,
        False, False]),
 array([False, False, False, False,  True, False, False, False, False,
         True, False, False, False, False, False, False, False,  True,
         True, False]),
 array([False, False, False, False, False, False, False,  True, False,
        False,  True, False, False,  True, False,  True,  True, False,
        False, False]),
 array([False, False, False, False,  True, False, False, False, False,
        False, False, False,  True, False, False, False, False,  True,
        False,  True]),
 arr

In [164]:
def goal_check(state):
    return np.all(reduce(
        np.logical_or,
        [SETS[i] for i in state.taken],
        np.array([False for _ in range(PROBLEM_SIZE)]),
    ))

assert goal_check(
    State(set(range(NUM_SETS)), set())
), "Probelm not solvable"

As we know, the A* algorithm is based on addition between the actual cost and the estimated (heuristic) one, which calculus function should be admissible (it should never overstimate the cost to reach the goal), consistent and monotonic.
Therefore, we need to define the functions g() and h().

The actual state is rapresented by the number of sets already taken.

In [171]:
def g(state):
    return len(state.taken)

Due to the fact that we have probability of true less than 0.3, we could calculate the heuristic function as the product between the lower bound of the states not taken and the best probability to have true value (that is 0.3).
For example, if PROBLEM_SIZE is 10, state.not_taken is equal to 7, then h will be 7/0.3 = 2,1 = 2, and this will rapresent the optimistic number of steps to reach the goal.

In [166]:
def h(state):
    return floor(len(state.not_taken)*0.3)

In [167]:
def f(state):
    return g(state) + h(state)

### A* algorithm

In [168]:
frontier = PriorityQueue()
state = State(set(), set(range(NUM_SETS)))
frontier.put((f(state), state))

counter = 0
_, current_state = frontier.get()
while not goal_check(current_state):
    counter += 1
    for action in current_state[1]:
        new_state = State(
            current_state.taken ^ {action},
            current_state.not_taken ^ {action},
        )
        frontier.put((f(new_state), new_state))
    _, current_state = frontier.get()

print(f"Solved in {counter:,} steps ({len(current_state.taken)} tiles)")

Solved in 2,941 steps (4 tiles)


To better understand the efficiency of an informed algorithm such as A*, we compare the results of one uninformed like Breadth-First.

In [169]:
frontier = deque()
state = State(set(), set(range(NUM_SETS)))
frontier.append(state)

counter = 0
current_state = frontier.popleft()
while not goal_check(current_state):
    counter += 1
    for action in current_state[1]:
        new_state = State(
            current_state.taken ^ {action},
            current_state.not_taken ^ {action},
        )
        frontier.append(new_state)
    current_state = frontier.popleft()

print(f"Solved in {counter:,} steps ({len(current_state.taken)} tiles)")

Solved in 188,116 steps (4 tiles)


As we can see, there's an evident difference between number of steps performed by the two algorithms.

In [170]:
current_state

State(taken={32, 2, 20, 13}, not_taken={0, 1, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 15, 16, 17, 18, 19, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 33, 34, 35, 36, 37, 38, 39})