In [23]:
import time
from collections import namedtuple
from functools import reduce
from queue import SimpleQueue, PriorityQueue
from random import random

import numpy as np

In [24]:
PROBLEM_SIZE = 10
NUM_SETS = 20
SETS = tuple(np.array([random() < .3 for _ in range(PROBLEM_SIZE)]) for _ in range(NUM_SETS))
State = namedtuple('State', ['taken', 'not_taken'])
SETS

(array([ True, False, False,  True,  True,  True,  True, False, False,
         True]),
 array([False, False,  True,  True,  True, False,  True, False, False,
        False]),
 array([False, False, False, False, False,  True,  True, False, False,
        False]),
 array([False,  True, False, False, False, False, False, False, False,
         True]),
 array([False, False,  True, False, False, False,  True, False,  True,
        False]),
 array([ True,  True, False,  True,  True,  True, False, False, False,
        False]),
 array([False, False,  True, False,  True, False, False, False,  True,
        False]),
 array([False, False, False,  True, False, False, False,  True, False,
        False]),
 array([ True, False, False, False, False, False, False, False, False,
         True]),
 array([False,  True, False,  True, False, False, False, False, False,
        False]),
 array([False, False, False, False, False,  True,  True, False, False,
         True]),
 array([False,  True, False, Fal

In [25]:
def current_cover(state):
    return reduce(np.logical_or, [SETS[i] for i in state.taken], np.array([False for _ in range(PROBLEM_SIZE)]))


def goal_check(state):
    return np.all(current_cover(state))


def actual_cost(current_state):
    return len(current_state.taken)

In [26]:
assert goal_check(State(set(range(NUM_SETS)), set())), "Problem not solvable"

## A* with cover-weighted sets

In [27]:
cell_score = np.sum(SETS, axis=0)
set_costs = np.sort([np.dot(cell_score, SETS[i]) for i in range(NUM_SETS)])
least_covering_set_score = set_costs[:1]

i = 2
prev_i = 1
while least_covering_set_score == 0:
    least_covering_set_score = set_costs[prev_i:i]
    
goal_score = np.sum(cell_score)

def expand_sets(state, taken=False):
    if taken:
        return [SETS[i] for i in state.taken]
    else:
        return [SETS[i] for i in state.not_taken]
    

def distance(state):
    set_score = np.dot(cell_score, current_cover(state))
    return np.floor((goal_score - set_score) / least_covering_set_score)[0]


def a_star(state, current_state):
    return actual_cost(current_state) + distance(state)


In [28]:
frontier = PriorityQueue()
state = State(set(), set(range(NUM_SETS)))
frontier.put((a_star(state, state), state))

start_time = time.time()
counter = 0
_, current_state = frontier.get()
while not goal_check(current_state):
    counter += 1
    for action in current_state.not_taken:
        # Skip empty sets
        if sum(current_cover(State(set() ^ {action}, set(range(NUM_SETS)) ^ {action}))) != 0:
            new_state = State(current_state.taken ^ {action}, current_state.not_taken ^ {action})
            frontier.put((a_star(new_state, current_state), new_state))
    _, current_state = frontier.get()

print(f"Solved in {counter} steps ({len(current_state.taken)} tiles) in {time.time() - start_time}s")

Solved in 23 steps (3 tiles) in 0.008954763412475586s


In [29]:
print(current_state)
print(goal_check(current_state), distance(current_state))

State(taken={0, 4, 15}, not_taken={1, 2, 3, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 16, 17, 18, 19})
True 0.0


## Breadth-first

In [30]:
# Breadth-first to check if I found the optimal solution
frontier = SimpleQueue()
state = State(set(), set(range(NUM_SETS)))
frontier.put(state)

start_time = time.time()
counter = 0
current_state = frontier.get()
while not goal_check(current_state):
    counter += 1
    for action in current_state.not_taken:
        new_state = State(current_state.taken ^ {action}, current_state.not_taken ^ {action})
        frontier.put(new_state)
    current_state = frontier.get()

print(f"Solved in {counter} steps ({len(current_state.taken)} tiles) in {time.time() - start_time}s")

Solved in 468 steps (3 tiles) in 0.01621413230895996s


In [31]:
print(goal_check(current_state), distance(current_state))

True 0.0


In [32]:
current_state

State(taken={0, 4, 15}, not_taken={1, 2, 3, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 16, 17, 18, 19})

In [33]:
[SETS[i] for i in current_state.taken]

[array([ True, False, False,  True,  True,  True,  True, False, False,
         True]),
 array([False, False,  True, False, False, False,  True, False,  True,
        False]),
 array([False,  True, False, False, False,  True, False,  True, False,
        False])]