In [37]:
import numpy as np
import math
from random import random
from functools import reduce
from queue import PriorityQueue, LifoQueue
from collections import namedtuple

In [63]:
PROBLEM_SIZE = 6
NUM_SETS = 4
SETS = tuple(
    np.array([random() < 0.4 for _ in range(PROBLEM_SIZE)])
    for _ in range(NUM_SETS)
)
State = namedtuple('State', ['taken', 'not_taken'])

In [64]:
for i in SETS:
        print(i)

[False False False  True False False]
[ True False  True False False False]
[ True False False False False  True]
[False  True False False  True  True]


In [76]:
def goal_check(state):
    return np.all(reduce(
        np.logical_or,
        [SETS[i] for i in state.taken],
        np.array([False for _ in range(PROBLEM_SIZE)]),
    ))

# OLD H that wasn't suitable for A*
# def h(state):
#     # Gives an estimation on how far the current frontier is from the goal state
#     return PROBLEM_SIZE - sum(
#         reduce(
#             np.logical_or,
#             [SETS[i] for i in state.taken],
#             np.array([False for _ in range(PROBLEM_SIZE)]),
#         ))

def h(state):
    # This function works as follow:
    # 1. it looks among the not_taken sets and find which one has the highest number N of TRUE tiles
    # 2. it calculates how many FALSE tiles we still have to cover (let's call it M)
    # 3. it returns N/M, which is an the number of sets we need to take (at least, optimistically) to reach the goal state
    # 4. example: i still miss 5 tiles M=5, we find a set among not taken with N=2 => it returns ceil(5/2)=3, because
    #    optimistaclly we will just need 3 tiles to solve the problem 
    _sorted = sorted(state.not_taken, key=lambda i: sum(SETS[i]), reverse=True)

    return math.ceil((PROBLEM_SIZE - sum(
        reduce(
            np.logical_or,
            [SETS[i] for i in state.taken],
            np.array([False for _ in range(PROBLEM_SIZE)]),
        )))
        /
        sum(SETS[_sorted[0]]))

def g(state):
    # Gives 34the actual distance from the start state (in terms of number of node)
    return len(state.taken)

def f(state):
    return g(state) + h(state)

In [44]:
assert goal_check(State(set(range(NUM_SETS)), set())), "Problem is not solvable"

In [75]:
frontier = PriorityQueue()
state = State(set(), set(range(NUM_SETS)))
frontier.put((f(state), state))

counter = 0
_, current_state = frontier.get()
while not goal_check(current_state):
    counter += 1
    for action in current_state.not_taken:
        new_state = State(
            current_state.taken ^ {action},
            current_state.not_taken ^ {action},
        )
        frontier.put((f(new_state), new_state))
    _, current_state = frontier.get()

print(
    f"Solved in {counter:,} steps ({len(current_state.taken)} sets)"
)

print(
    f"The current sate is: {current_state}"
)


{0, 1, 2, 3}
TAKEN: set()
NOT_TAKEN: {0, 1, 2, 3}
[3, 1, 2, 0]
3
____________________
{1, 2, 3}
TAKEN: {0}
NOT_TAKEN: {1, 2, 3}
[3, 1, 2]
3
____________________
{0, 2, 3}
TAKEN: {1}
NOT_TAKEN: {0, 2, 3}
[3, 2, 0]
3
____________________
{0, 1, 3}
TAKEN: {2}
NOT_TAKEN: {0, 1, 3}
[3, 1, 0]
3
____________________
{0, 1, 2}
TAKEN: {3}
NOT_TAKEN: {0, 1, 2}
[1, 2, 0]
2
____________________
{2, 3}
TAKEN: {0, 1}
NOT_TAKEN: {2, 3}
[3, 2]
3
____________________
{1, 3}
TAKEN: {0, 2}
NOT_TAKEN: {1, 3}
[3, 1]
3
____________________
{1, 2}
TAKEN: {0, 3}
NOT_TAKEN: {1, 2}
[1, 2]
2
____________________
{1, 3}
TAKEN: {0, 2}
NOT_TAKEN: {1, 3}
[3, 1]
3
____________________
{0, 3}
TAKEN: {1, 2}
NOT_TAKEN: {0, 3}
[3, 0]
3
____________________
{0, 1}
TAKEN: {2, 3}
NOT_TAKEN: {0, 1}
[1, 0]
2
____________________
{1, 2}
TAKEN: {0, 3}
NOT_TAKEN: {1, 2}
[1, 2]
2
____________________
{0, 2}
TAKEN: {1, 3}
NOT_TAKEN: {0, 2}
[2, 0]
2
____________________
{0, 1}
TAKEN: {2, 3}
NOT_TAKEN: {0, 1}
[1, 0]
2
______________