In [1]:
import random
import heapq
import numpy as np
import time
from typing import Callable

In [2]:
class State:
    """An immutable, hashable wrapper around a NumPy array.

    The constructor stores a read-only copy of ``data``. Hashing, equality
    and ordering are all derived from one cached key:

    * object-dtype arrays use the ``repr`` of their nested-list form, so two
      states holding equal contents compare equal even when they were built
      from different Python objects (the raw buffer of an object array holds
      object references, not values);
    * every other dtype uses the raw bytes of the array buffer.
    """

    def __init__(self, data: np.ndarray):
        self._data = data.copy()
        self._data.flags.writeable = False
        # Computed once: the array is read-only, and the key is needed on
        # every dict/set lookup and heap comparison.
        self._key = self._make_key(self._data)

    @staticmethod
    def _make_key(data: np.ndarray) -> bytes:
        """Build the bytes key used for hashing, equality and ordering."""
        if data.dtype == object:
            return repr(data.tolist()).encode()
        return bytes(data)

    def __hash__(self):
        return hash(self._key)

    def __eq__(self, other):
        # Let Python fall back to its default handling for foreign types
        # instead of failing on a missing attribute.
        if not isinstance(other, State):
            return NotImplemented
        return self._key == other._key

    def __lt__(self, other):
        # Ordering is only a deterministic tie-breaker (e.g. for heap
        # entries with equal priority); it carries no domain meaning.
        if not isinstance(other, State):
            return NotImplemented
        return self._key < other._key

    def __str__(self):
        return str(self._data.tolist())

    def __repr__(self):
        return repr(self._data)

    @property
    def data(self):
        """The wrapped read-only array (not a copy)."""
        return self._data

    def copy_data(self):
        """Return a fresh copy of the wrapped array."""
        return self._data.copy()

class PriorityQueue:
    """Min-priority queue: a binary heap for ordering plus a set for fast membership tests."""

    def __init__(self):
        self._data_heap = []
        self._data_set = set()

    def __bool__(self):
        # Truthy while at least one item is queued.
        return len(self._data_set) > 0

    def __contains__(self, item):
        return item in self._data_set

    def push(self, item, p=None):
        """Queue ``item`` with priority ``p``; without ``p`` the current queue size is used."""
        assert item not in self, f"Duplicated element"
        priority = len(self._data_set) if p is None else p
        self._data_set.add(item)
        heapq.heappush(self._data_heap, (priority, item))

    def pop(self):
        """Remove and return the item with the smallest priority."""
        _, item = heapq.heappop(self._data_heap)
        self._data_set.remove(item)
        return item

In [3]:
def unit_cost(a):
    """Return the cost of taking block ``a``, defined as its length."""
    block_size = len(a)
    return block_size

def priority_function_astar_with_state_cost1(state, N, state_cost):
    """A* priority using the cached path cost (version 1: element-by-element set build).

    The result is ``state_cost[state]`` plus ``len(range(N))`` minus the
    number of distinct elements found in the state's blocks.
    """
    # A set comprehension keeps each element once, however often it occurs.
    covered = {element for block in state.data.tolist() for element in block}
    missing = len(range(N)) - len(covered)
    return state_cost[state] + missing

def priority_function_astar_with_state_cost2(state, N, state_cost):
    """A* priority using the cached path cost (version 2: whole-block set union).

    The result is ``state_cost[state]`` plus ``len(range(N))`` minus the
    number of distinct elements found in the state's blocks.
    """
    # One union call merges every block; elements already present are ignored.
    covered = set().union(*state.data.tolist())
    missing = len(range(N)) - len(covered)
    return state_cost[state] + missing


def priority_function_astar(state, N):
    """A* priority computed from scratch (no cached path cost).

    The path cost is the total length of the state's blocks; added to it is
    ``len(range(N))`` minus the number of distinct elements in those blocks.
    """
    blocks_in_state = state.data.tolist()
    path_cost = 0
    covered = set()
    for block in blocks_in_state:
        path_cost += len(block)
        # update() only adds the elements that the set does not hold yet
        covered.update(block)
    return path_cost + (len(range(N)) - len(covered))

def possible_blocks(blocks, state):
    """Return a copy of ``blocks`` without the blocks already contained in ``state``.

    ``blocks`` itself is left untouched. A block of ``state`` that does not
    occur in ``blocks`` makes ``list.remove`` raise ``ValueError``.
    """
    remaining = blocks.copy()
    used_blocks = state.data.tolist()
    for used in used_blocks:
        # remove() drops the first entry equal to ``used``
        remaining.remove(used)
    return remaining

def goal_test(state, N):
    """Return True when the state's blocks contain exactly the values of ``range(N)``."""
    # Merge all blocks into one set of distinct elements,
    # e.g. [[0, 1], [3], [3, 4, 5]] -> {0, 1, 3, 4, 5},
    # then compare it with the full set expected for a problem of size N.
    covered = set().union(*state.data.tolist())
    return covered == set(range(N))

In [4]:
def problem(N, seed=None):
    """Generate a random set-covering instance as a sorted list of unique blocks.

    Each block is a list of distinct integers drawn from ``0 .. N-1``.

    Args:
        N: problem size; block values are drawn from ``range(N)``.
        seed: value passed to ``random.seed`` (this reseeds the module-level
            generator, so the same seed yields the same instance).

    Returns:
        A list of blocks (each a list of ints), without duplicates, sorted
        in list order.
    """
    random.seed(seed)
    blocks_not_unique = [
        list(set(random.randint(0, N - 1) for _ in range(random.randint(N // 5, N // 2))))
        for _ in range(random.randint(N, N * 5))
    ]
    # Deduplicate in pure Python. Going through np.unique on an object array
    # breaks when all blocks happen to have the same length: the array then
    # becomes 2-D and np.unique flattens it into individual integers.
    unique_blocks = sorted(set(map(tuple, blocks_not_unique)))
    return [list(block) for block in unique_blocks]


def define_new_state(state, a):
    """Build the State made of the blocks in ``state`` followed by block ``a``.

    Args:
        state: the blocks of the current state, as a list of blocks.
        a: the block to add.

    Returns:
        A new ``State`` wrapping an object-dtype array of the extended
        block list.
    """
    # Build a new list instead of appending in place, so the caller's list
    # is never modified.
    return State(np.array([*state, a], dtype=object))

In [5]:
def search(
    blocks,
    initial_state: State,
    goal_test: Callable,
    parent_state: dict,
    state_cost: dict,
    priority_function: Callable,
    unit_cost: Callable,
    problem_size=None,
):
    """Best-first search over sequences of blocks, ordered by ``priority_function``.

    Args:
        blocks: all blocks that may be added to a state.
        initial_state: the state the search starts from.
        goal_test: called as ``goal_test(state, problem_size)``; returns True
            on a goal state.
        parent_state: dict filled with ``state -> predecessor``; cleared first.
        state_cost: dict filled with ``state -> accumulated cost``; cleared first.
        priority_function: called as ``priority_function(state)`` when a state
            is pushed onto the frontier.
        unit_cost: called as ``unit_cost(block)``; the cost of adding a block.
        problem_size: size handed to ``goal_test``. When omitted, the
            module-level ``N`` is used, as before.

    Returns:
        The list of states (as nested lists) from the initial state to the
        goal, or an empty list when the frontier runs out without a goal.
    """
    if problem_size is None:
        # Legacy behaviour: rely on the notebook-level global ``N``.
        problem_size = N

    frontier = PriorityQueue()
    parent_state.clear()
    state_cost.clear()
    state = initial_state
    parent_state[state] = None
    state_cost[state] = 0
    while state is not None and not goal_test(state, problem_size):
        for a in possible_blocks(blocks, state):
            new_state = define_new_state(state.data.tolist(), a)
            new_cost = state_cost[state] + unit_cost(a)
            if new_state not in state_cost and new_state not in frontier:
                parent_state[new_state] = state
                state_cost[new_state] = new_cost
                frontier.push(new_state, p=priority_function(new_state))
            elif new_state in frontier and state_cost[new_state] > new_cost:
                # A cheaper route to a queued state: record it. The queued
                # entry keeps the priority it was pushed with.
                parent_state[new_state] = state
                state_cost[new_state] = new_cost
        state = frontier.pop() if frontier else None

    if state is None:
        # Frontier exhausted without reaching a goal: there is no final
        # state to summarise.
        print(f"No solution found; visited {len(state_cost):,} states")
        return []

    # Walk the parent links back from the goal to the initial state.
    path = list()
    s = state
    while s is not None:
        path.append(s.copy_data().tolist())
        s = parent_state[s]
    print(f"Found a solution in {len(path):,} steps; visited {len(state_cost):,} states; w {sum(len(s) for s in state.data.tolist())}")

    return list(reversed(path))

In [6]:
# Bookkeeping dicts shared with search(); it clears and refills them on every call.
parent_state = {}
state_cost = {}
for i in range(1):
    # N has to stay the loop variable: search() reads it as a module-level name.
    # NOTE(review): 100 is listed twice - possibly 1000 was intended last; confirm.
    for N in [5, 10, 20, 100, 500, 100]:
        start = time.time()
        blocks = problem(N, seed=42)
        print(f"For N:{N}, iteration:{i+1}")

        def heuristic(s):
            """A* priority for the current N, reusing the costs tracked by search()."""
            return priority_function_astar_with_state_cost2(s, N, state_cost)

        # Shortest blocks first.
        blocks_by_size = sorted(blocks, key=len)
        final = search(
            blocks_by_size,
            State(np.array([])),
            goal_test=goal_test,
            parent_state=parent_state,
            state_cost=state_cost,
            priority_function=heuristic,
            unit_cost=unit_cost,
        )
        end = time.time()
        print(f"Time for N:{N} = {end-start}")

For N:5, iteration:1
Found a solution in 4 steps; visited 35 states; w 5
Time for N:5 = 0.0013780593872070312
For N:10, iteration:1
Found a solution in 5 steps; visited 4,323 states; w 10
Time for N:10 = 0.06314587593078613
For N:20, iteration:1
Found a solution in 6 steps; visited 446,640 states; w 23
Time for N:20 = 7.582447052001953
