In [31]:
import logging
import random
import heapq
import numpy as np
from typing import Callable
logging.basicConfig(format="%(message)s", level=logging.INFO)

In [32]:
class State:
    def __init__(self, data: np.ndarray):
        self._data = data.copy()
        self._data.flags.writeable = False

    def __hash__(self):
        return hash(bytes(self._data))

    def __eq__(self, other):
        return bytes(self._data) == bytes(other._data)

    def __lt__(self, other):
        return bytes(self._data) < bytes(other._data)

    def __str__(self):
        return str(self._data)

    def __repr__(self):
        return repr(self._data)

    @property
    def data(self):
        return self._data

    def copy_data(self):
        return self._data.copy()

class PriorityQueue:
    """A basic Priority Queue with simple performance optimizations"""

    def __init__(self):
        self._data_heap = list()
        self._data_set = set()

    def __bool__(self):
        return bool(self._data_set)

    def __contains__(self, item):
        return item in self._data_set

    def push(self, item, p=None):
        assert item not in self, f"Duplicated element"
        if p is None:
            p = len(self._data_set)
        self._data_set.add(item)
        heapq.heappush(self._data_heap, (p, item))

    def pop(self):
        p, item = heapq.heappop(self._data_heap)
        self._data_set.remove(item)
        return item

In [33]:
def unit_cost(a):
    return len(a)

def priority_function(state):
    state_list = state.data.tolist()
    return sum(len(s) for s in state_list)

def possible_blocks(block, state):
    pb = block.copy()
    for s in state:
        pb.remove(s)

    return pb

def goal_test(state, N):
    state_list = state.data.tolist()
    print(state_list)
    new_list = set()
    for lst in state_list:
        for s in lst:
            new_list.add(s)
    return new_list == set(range(N))

In [34]:
def problem(N, seed=None):
    random.seed(seed)
    return [
        list(set(random.randint(0, N - 1) for n in range(random.randint(N // 5, N // 2))))
        for n in range(random.randint(N, N * 5))
    ]

def define_new_state(state, a):
    state.append(a)
    return State(np.array(state, dtype=object))

N = 7

In [35]:
def search(
    blocks,
    initial_state: State,
    goal_test: Callable,
    parent_state: dict,
    state_cost: dict,
    priority_function: Callable,
    unit_cost: Callable,
):
    frontier = PriorityQueue()
    parent_state.clear()
    state_cost.clear()
    state = initial_state
    parent_state[state] = None
    state_cost[state] = 0
    while state is not None and not goal_test(state, N):
        for a in possible_blocks(blocks, state.data.tolist()):
            # print("state:", state)
            # print("possible blocks:", possible_blocks(blocks, state.data.tolist()))
            new_state = define_new_state(state.data.tolist(), a)
            # print("new state", new_state)
            cost = unit_cost(a)
            if new_state not in state_cost and new_state not in frontier:
                parent_state[new_state] = state
                state_cost[new_state] = state_cost[state] + cost
                frontier.push(new_state, p=priority_function(new_state))
                # logging.debug(f"Added new node to frontier (cost={state_cost[new_state]})")
                # print(f"Added new node to frontier (cost={state_cost[new_state]})")
            elif new_state in frontier and state_cost[new_state] > state_cost[state] + cost:
                old_cost = state_cost[new_state]
                parent_state[new_state] = state
                state_cost[new_state] = state_cost[state] + cost
                # logging.debug(f"Updated node cost in frontier: {old_cost} -> {state_cost[new_state]}")
                # print(f"Updated node cost in frontier: {old_cost} -> {state_cost[new_state]}")
        if frontier:
            state = frontier.pop()
        else:
            state = None

    path = list()
    s = state
    while s:
        path.append(s.copy_data())
        s = parent_state[s]
    # logging.info(f"Found a solution in {len(path):,} steps; visited {len(state_cost):,} states")
    print(f"Found a solution in {len(path):,} steps; visited {len(state_cost):,} states")
    print("initial blocks", blocks)
    return list(reversed(path))

In [36]:
parent_state = dict()
state_cost = dict()
blocks = problem(N)
print("blocks", blocks)
final = search(
    sorted(blocks, key=lambda l: len(l)),
    State(np.array([])),
    goal_test=goal_test,
    parent_state=parent_state,
    state_cost=state_cost,
    priority_function=priority_function,
    unit_cost=unit_cost
)

blocks [[6], [2, 4, 5], [0, 1], [2], [3, 5], [1, 3], [4], [0, 6], [1, 4], [2], [2], [6], [5], [0, 2], [0, 1], [3], [2, 3], [3, 6], [3, 6], [5], [5], [1, 6], [6], [2, 6], [6], [6]]
[]
[[2]]
[[3]]
[[4]]
[[5]]
[[6]]
[[0, 1]]
[[0, 2]]
[[0, 6]]
[[1, 3]]
[[1, 4]]
[[1, 6]]
[[2], [2]]
[[2, 3]]
[[2], [4]]
[[2], [5]]
[[2, 6]]
[[3], [2]]
[[3], [4]]
[[3, 5]]
[[3, 6]]
[[4], [2]]
[[4], [3]]
[[4], [5]]
[[4], [6]]
[[5], [2]]
[[5], [3]]
[[5], [4]]
[[5], [5]]
[[5], [6]]
[[6], [2]]
[[6], [3]]
[[6], [4]]
[[6], [5]]
[[6], [6]]
[[1, 4], [5]]
[[6], [0, 1]]
[[1, 4], [4]]
[[4], [2, 6]]
[[1, 6], [6]]
[[6], [3, 6]]
[[1, 3], [6]]
[[3], [1, 6]]
[[4], [3, 5]]
[[3], [3, 6]]
[[2, 3], [6]]
[[2, 3], [6]]
[[2, 3], [5]]
[[2, 3], [6]]
[[4], [3, 6]]
[[3, 6], [6]]
[[0, 6], [5]]
[[4], [0, 1]]
[[0, 1], [6]]
[[4], [1, 6]]
[[3, 6], [6]]
[[3, 5], [6]]
[[3, 5], [5]]
[[5], [3, 6]]
[[6], [0, 1]]
[[5], [2, 3]]
[[5], [2, 6]]
[[3, 5], [4]]
[[4], [1, 4]]
[[0, 2], [6]]
[[0, 2], [6]]
[[0, 2], [5]]
[[0, 2], [5]]
[[0, 2], [5]]
[[0, 2], [6]

KeyboardInterrupt: 