Copyright **`(c)`** 2023 Giovanni Squillero `<giovanni.squillero@polito.it>`  
[`https://github.com/squillero/computational-intelligence`](https://github.com/squillero/computational-intelligence)  
Free for personal or classroom use; see [`LICENSE.md`](https://github.com/squillero/computational-intelligence/blob/master/LICENSE.md) for details.  

In [11]:
from random import random
from functools import reduce
from collections import namedtuple
from queue import PriorityQueue, SimpleQueue, LifoQueue
import numpy as np
import time

In [12]:
# Instance dimensions: PROBLEM_SIZE elements to cover, NUM_SETS candidate sets.
PROBLEM_SIZE = 20
NUM_SETS = 3500

# Every candidate set is a boolean mask of length PROBLEM_SIZE; each entry is
# True when a fresh random() draw falls below 0.3.
SETS = tuple(
    np.array([random() < 0.3 for _ in range(PROBLEM_SIZE)])
    for _ in range(NUM_SETS)
)

# Search state: the indices of the sets selected so far and of those still available.
State = namedtuple('State', 'taken not_taken')

In [13]:
def goal_check(state):
    """Return whether the sets indexed by ``state.taken`` together cover every element."""
    # Start from "nothing covered" and OR in each selected mask in turn.
    covered = np.zeros(PROBLEM_SIZE, dtype=bool)
    for index in state.taken:
        covered = np.logical_or(covered, SETS[index])
    return np.all(covered)

In [14]:
# Re-draw the whole collection until selecting every set covers all elements,
# i.e. until the instance admits at least one solution.
while not goal_check(State(taken=set(range(NUM_SETS)), not_taken=set())):
    SETS = tuple(
        np.array([random() < 0.3 for _ in range(PROBLEM_SIZE)])
        for _ in range(NUM_SETS)
    )


In [15]:
def cost(state):
    """Cost accumulated by ``state`` for the A* search: one unit per set already taken."""
    selected = state.taken
    return len(selected)

def heuristic(state):
    """Estimate the remaining cost of ``state`` as the number of uncovered elements.

    The masks of all sets in ``state.taken`` are OR-ed together and the
    elements that are still False are counted.

    Note: a single set may cover several elements at once, so this value can
    exceed the number of sets that are actually still needed. The estimate is
    therefore not admissible, and an A* search driven by it is not guaranteed
    to return a minimum-size cover.

    Args:
        state: a State whose ``taken`` field holds indices into SETS.

    Returns:
        int: how many of the PROBLEM_SIZE elements are not covered yet.
    """
    covered = reduce(
        np.logical_or,
        [SETS[i] for i in state.taken],
        np.zeros(PROBLEM_SIZE, dtype=bool),
    )
    # count_nonzero counts inside numpy; the builtin sum() would walk the array
    # one element at a time in Python, and this runs once per generated successor.
    return PROBLEM_SIZE - int(np.count_nonzero(covered))

In [16]:
def astar_search(cost, heuristic):
    """Best-first search for a set cover, ordering states by ``cost + heuristic``.

    Starting from the state with no set taken, the frontier state with the
    lowest priority is popped repeatedly. Unless it passes ``goal_check``, one
    successor is pushed for every set still in its ``not_taken`` field. The
    time spent expanding each step, and the total, are printed.

    As with any A*, the result is only guaranteed to be minimal when
    ``heuristic`` never overestimates the remaining cost.

    Args:
        cost: callable mapping a State to the cost accumulated so far.
        heuristic: callable mapping a State to the estimated remaining cost.

    Returns:
        The ``taken`` set (indices into SETS) of the first goal state popped
        from the frontier, or None when the frontier is exhausted before any
        goal state is found.
    """
    initial_state = State(set(), set(range(NUM_SETS)))
    frontier = PriorityQueue()
    # Entries are (priority, state); on equal priority the queue falls back to
    # comparing the State tuples themselves.
    frontier.put((cost(initial_state) + heuristic(initial_state), initial_state))

    total_time = 0
    counter = 0
    _, current_state = frontier.get()
    while not goal_check(current_state):
        counter += 1

        # perf_counter is a monotonic clock meant for measuring short intervals.
        start_time = time.perf_counter()
        for action in current_state.not_taken:
            # Move `action` from the available sets to the selected ones.
            new_state = State(
                current_state.taken | {action},
                current_state.not_taken - {action},
            )
            priority = cost(new_state) + heuristic(new_state)
            frontier.put((priority, new_state))
        elapsed = time.perf_counter() - start_time
        total_time += elapsed
        print(f"Step {counter:,} took {elapsed:.6f} seconds")

        if frontier.empty():
            # Nothing left to expand: a blocking get() would wait forever.
            return None
        _, current_state = frontier.get()

    print(f"Solved in {counter:,} steps and {total_time:.6f} seconds")
    return current_state.taken

In [17]:
astar_solution = astar_search(cost, heuristic)

Step 1 took 0.806225 seconds
Step 2 took 0.741642 seconds
Solved in 2 steps and 1.547866 seconds


In [18]:
def complete_search(cost_only=True):
    """Exhaustive breadth-first search for the minimum-size set covers.

    This is the least efficient but complete way to check how the solution
    found by A* compares with the optimum. States are explored in order of
    increasing number of taken sets, and the same subset is reached once per
    ordering of its elements, so it is not viable for large problems.

    Args:
        cost_only: when True (default) only the minimum cost is computed and
            the search stops at the first goal state; when False every
            distinct minimum-size solution is collected.

    Returns:
        tuple: ``(optimal_solutions, min_cost)``. ``optimal_solutions`` is a
        list of ``taken`` sets of size ``min_cost`` (empty when ``cost_only``
        is True); ``min_cost`` is ``float('inf')`` if no goal state exists.
    """
    optimal_solutions = []
    min_cost = float('inf')

    frontier = SimpleQueue()
    frontier.put(State(set(), set(range(NUM_SETS))))

    while not frontier.empty():
        current_state = frontier.get()
        current_cost = len(current_state.taken)

        if goal_check(current_state):
            if current_cost < min_cost:
                # Strictly better than anything seen so far: start over.
                min_cost = current_cost
                optimal_solutions = []
            if cost_only:
                # FIFO order visits states by non-decreasing size, so the
                # first goal reached already has the minimum cost.
                return optimal_solutions, min_cost
            # Record every distinct solution of minimum size, not just the first.
            if current_cost == min_cost and current_state.taken not in optimal_solutions:
                optimal_solutions.append(current_state.taken)

        # Every successor has exactly one more set; skip them all when that
        # would already exceed the best cost found.
        if current_cost + 1 > min_cost:
            continue
        for action in current_state.not_taken:
            frontier.put(
                State(
                    current_state.taken | {action},
                    current_state.not_taken - {action},
                )
            )
    return optimal_solutions, min_cost

In [19]:
def print_sets(sets_indexes):
    """Print one line per index in ``sets_indexes``.

    Each line shows the index, a colon and a tab, followed by ``'* '`` for
    every element the set contains and ``'_ '`` for every element it lacks.
    """
    for index in sets_indexes:
        row = ''.join('* ' if element else '_ ' for element in SETS[index])
        print(f"{index}:\t{row}")

In [20]:
# Optional cross-check against the exhaustive search. Left disabled here,
# presumably because that search is only practical on very small instances.
# optimal_solutions, min_cost = complete_search(cost_only=False)
# print("Optimal solutions has length:", min_cost)

# Show the indices chosen by A* followed by the coverage mask of each chosen set.
print("A* solution:")
print(astar_solution)
print_sets(astar_solution)

# Optional: print every optimal solution found by the exhaustive search
# (needs the disabled cross-check above to be enabled first).
# for i in optimal_solutions:
#     print(i, end=':\n')
#     print_sets(i)

# Optional: dump the whole collection of candidate sets.
# print("All the sets:")
# print_sets(range(NUM_SETS))

A* solution:
{2673, 716}
2673:	* * _ * * * _ _ * * * _ * _ * _ * * _ * 
716:	* * * _ * _ * * _ * _ * _ * * * * _ * _ 
