# Day 20 - Reassembling images

* https://adventofcode.com/2020/day/20

I've treated this as an A* search (states are prioritised by the fraction of grid cells still empty); otherwise the search space is far too large. The search state only needs to track the top and left edge values of the remaining, unplaced tiles.

Of course, you need to be able to determine what the edge values are. There are 8 different ways a tile can be oriented through rotation or flipping, and for each of them we can determine up front which part of the tile data (and in which reading direction) forms each edge.

In [1]:
import operator
from collections.abc import Iterator, MutableSet, Sequence, Set
from dataclasses import dataclass, field
from enum import Enum
from functools import cached_property, partial, reduce
from heapq import heapify, heappop, heappush
from itertools import count, product
from typing import Dict, Generic, List, Mapping, Optional, Tuple, TypeVar

import numpy as np


T = TypeVar("T")


class PriorityQueue(Generic[T]):
    """Min-priority queue built on :mod:`heapq`.

    Items with equal priority come out in insertion (FIFO) order. A
    monotonically increasing counter is stored in every entry, so ties never
    fall through to comparing the items themselves.
    """

    def __init__(self, *initial: Tuple[float, T]) -> None:
        """Create a queue pre-filled with ``(priority, item)`` pairs."""
        self._count = count()
        # Build every entry first and heapify once (O(n)). Pushing each one and
        # then re-heapifying an already valid heap is redundant work.
        self._queue: List[Tuple[float, int, T]] = [
            (pri, next(self._count), item) for pri, item in initial
        ]
        heapify(self._queue)

    def __len__(self) -> int:
        """Return the number of queued items."""
        return len(self._queue)

    def put(self, pri: float, item: T) -> None:
        """Queue *item* with priority *pri*; lower priorities are returned first."""
        heappush(self._queue, (pri, next(self._count), item))

    def get(self) -> T:
        """Remove and return the item with the lowest priority.

        Raises:
            ValueError: if the queue is empty.
        """
        if not self._queue:
            raise ValueError("Queue is empty")
        return heappop(self._queue)[-1]


class Edge(Enum):
    """The four sides of a placed tile, clockwise from the top.

    ``value`` is the position of that side within an ``Orientation`` tuple.
    """

    top, right, bottom, left = range(4)


class EdgeDir(Enum):
    """Index expressions that pull one edge out of ``Tile.matrix``.

    The first letter names the edge (t/r/b/l). The second says whether it is
    read forwards (``f``, slice step 1) or reversed (``r``, slice step -1).
    Each value is used directly as ``Tile.matrix[value]``.

    NOTE(review): numpy indexes ``[row, col]`` and ``Tile.matrix`` has one row
    per input line. So ``tf`` (``[:, 0]``) actually selects the first *column*
    and ``lf`` (``[0, :]``) the first *row*; the naming treats the first axis
    as x. The mapping is applied uniformly to every tile, so the solver
    effectively assembles the transposed image. That should leave the
    corner-id checksum unchanged, but confirm before reusing these names for
    image reconstruction.
    """

    tf = (slice(None, None, 1), 0)
    tr = (slice(None, None, -1), 0)
    rf = (-1, slice(None, None, 1))
    rr = (-1, slice(None, None, -1))
    bf = (slice(None, None, 1), -1)
    br = (slice(None, None, -1), -1)
    lf = (0, slice(None, None, 1))
    lr = (0, slice(None, None, -1))


class Orientation(Enum):
    """The eight ways a tile can be laid down: four rotations, each optionally mirrored.

    Each value is a ``(top, right, bottom, left)`` tuple of ``EdgeDir``s, in
    ``Edge.value`` order. ``OrientedTile`` resolves a side as
    ``orientation.value[edge.value]``, which gives the stored edge (and the
    direction it is read in) that appears on that side. All eight tuples are
    distinct, so no member is an alias of another.
    """

    # Rotations of the tile as stored.
    orig = (EdgeDir.tf, EdgeDir.rf, EdgeDir.bf, EdgeDir.lf)
    rot90 = (EdgeDir.lr, EdgeDir.tf, EdgeDir.rr, EdgeDir.bf)
    rot180 = (EdgeDir.br, EdgeDir.lr, EdgeDir.tr, EdgeDir.rr)
    rot270 = (EdgeDir.rf, EdgeDir.br, EdgeDir.lf, EdgeDir.tr)
    # The same four rotations applied to the mirror image (left/right swapped).
    mirror = (EdgeDir.tr, EdgeDir.lf, EdgeDir.br, EdgeDir.rf)
    mrt90 = (EdgeDir.rr, EdgeDir.tr, EdgeDir.lr, EdgeDir.br)
    mrt180 = (EdgeDir.bf, EdgeDir.rr, EdgeDir.tf, EdgeDir.lr)
    mrt270 = (EdgeDir.lf, EdgeDir.bf, EdgeDir.rf, EdgeDir.tf)


@dataclass(frozen=True)
class Tile:
    """One image tile parsed from the puzzle input.

    ``data`` holds the pixel rows without the ``Tile N:`` header, one line per
    row. ``#`` is a set pixel; any other character is unset.
    """

    id: int
    data: str

    @cached_property
    def matrix(self) -> np.ndarray:
        """Boolean pixel grid with one row per line of ``data``."""
        rows = self.data.splitlines()
        return np.array([[c == "#" for c in row] for row in rows], dtype=bool)

    @classmethod
    def from_data(cls, data: str) -> "Tile":
        """Parse one ``Tile N:`` header line followed by its pixel rows."""
        tile_line, data = data.split("\n", 1)
        tile_id = int(tile_line.split()[1].strip(":"))
        return cls(tile_id, data)

    def __str__(self):
        """Render the tile in the same format it was parsed from."""
        return f"Tile {self.id}:\n{self.data}"

    def __getitem__(self, edgedir: "EdgeDir") -> int:
        """Return the edge selected by *edgedir* as an integer.

        ``edgedir.value`` is used as an index into ``matrix``. The selected
        pixels are read in order as binary digits, most significant first, so
        two same-length edges read in the same direction give equal integers
        exactly when their pixels match.
        """
        bits = self.matrix[edgedir.value]
        # Folding bit by bit works for any edge length; packing into a
        # fixed-width integer would only fit one specific tile size.
        return reduce(lambda acc, bit: (acc << 1) | int(bit), bits, 0)


@dataclass(frozen=True)
class OrientedTile:
    """A ``Tile`` together with the ``Orientation`` it is placed in."""

    tile: Tile
    orientation: Orientation

    @property
    def id(self) -> int:
        """Id of the underlying tile, shared by all of its orientations."""
        underlying = self.tile
        return underlying.id

    def __getitem__(self, edge: Edge) -> int:
        """Integer value of side *edge* as it appears in this orientation."""
        side_dirs = self.orientation.value
        chosen = side_dirs[edge.value]
        return self.tile[chosen]


@dataclass(frozen=True)
class SearchState:
    """One node of the tile-placement search.

    The grid is ``size`` x ``size`` and is stored row-major in ``state``: cell
    ``(x, y)`` lives at index ``x + size * y`` and ``None`` marks an empty
    cell. Cells are filled one anti-diagonal at a time, and ``diag``/``x``/``y``
    record the most recently filled cell. ``tops`` and ``lefts`` map an edge
    value to the still-unplaced oriented tiles whose top / left edge has that
    value.
    """

    size: int = 0
    diag: int = 0
    x: int = 0
    y: int = 0
    state: Sequence[Optional[OrientedTile]] = ()
    # The lookup tables follow from the input tiles and which of them
    # ``state`` has already placed. They therefore take no part in hashing or
    # equality: comparing dicts of sets on every set-membership check would be
    # wasted work.
    tops: Mapping[int, Set[OrientedTile]] = field(
        hash=False, compare=False, default_factory=dict
    )
    lefts: Mapping[int, Set[OrientedTile]] = field(
        hash=False, compare=False, default_factory=dict
    )

    @classmethod
    def from_data(cls, data: str) -> "SearchState":
        """Parse the raw puzzle input into the initial, empty-grid state.

        Every tile is registered in all eight orientations: by its top edge in
        ``tops`` and by its left edge in ``lefts``.
        """
        # strip() so that leading/trailing blank lines don't produce empty
        # chunks, which Tile.from_data cannot parse.
        chunks = data.strip().split("\n\n")
        tile_set = frozenset(Tile.from_data(chunk) for chunk in chunks)
        size = int(len(tile_set) ** 0.5)  # square root of the initial set
        tops: Dict[int, MutableSet[OrientedTile]] = {}
        lefts: Dict[int, MutableSet[OrientedTile]] = {}
        for tile in tile_set:
            for orientation in Orientation:
                orientated = OrientedTile(tile, orientation)
                tops.setdefault(orientated[Edge.top], set()).add(orientated)
                lefts.setdefault(orientated[Edge.left], set()).add(orientated)
        return cls(size, tops=tops, lefts=lefts)

    @property
    def complete(self) -> bool:
        """True once the last (bottom-right) cell has been filled."""
        # Requiring a non-empty ``state`` stops the initial state of a 1x1
        # puzzle (where x == y == 0 == size - 1) from counting as complete
        # before any tile has been placed.
        return bool(self.state) and self.size - 1 == self.x == self.y

    @property
    def checksum(self) -> int:
        """Product of the ids of the four corner tiles.

        Raises:
            ValueError: if the grid is not yet complete.
        """
        if not self.complete:
            raise ValueError("Incomplete state")
        state, size = self.state, self.size
        corners = (state[x + size * y] for x, y in product((0, size - 1), repeat=2))
        return reduce(operator.mul, (s.id for s in corners if s is not None))

    @property
    def to_place(self) -> float:
        """Fraction of grid cells still empty (1 for the initial state).

        Used as the search priority, so fuller grids are expanded first.
        """
        total, not_placed = len(self.state), sum(1 for o in self.state if o is None)
        return (not_placed / total) if total else 1

    def __str__(self):
        """Render fill progress followed by a grid of placed tile ids."""
        size = self.size
        # The initial state has an empty ``state``; show it as an all-empty grid.
        cells = self.state or (None,) * (size * size)
        rows = [[cells[x + y * size] for x in range(size)] for y in range(size)]
        formatted = "\n".join(
            "    ".join("None" if s is None else format(s.id, ">4d") for s in row)
            for row in rows
        )
        return f"Filled: {1 - self.to_place:.2%}\n{formatted}"

    @staticmethod
    def _without(
        table: Mapping[int, Set[OrientedTile]], tile_id: int
    ) -> Dict[int, Set[OrientedTile]]:
        """Return a copy of *table* with every orientation of *tile_id* removed."""
        return {v: {o for o in os if o.id != tile_id} for v, os in table.items()}

    def _starting_states(self) -> Iterator["SearchState"]:
        """Yield one state per oriented tile placed in the top-left cell."""
        tops, lefts = self.tops, self.lefts
        state_tail = (None,) * (self.size ** 2 - 1)
        # set().union(...) also copes with an empty table.
        for ort in set().union(*tops.values()):
            yield SearchState(
                self.size,
                state=(ort, *state_tail),
                tops=self._without(tops, ort.id),
                lefts=self._without(lefts, ort.id),
            )

    def neighbors(self) -> Iterator["SearchState"]:
        """Yield every state reachable by placing one more tile.

        The initial (empty) state expands to every oriented tile in the top-left
        cell. Otherwise the next cell in anti-diagonal order is filled with each
        unplaced oriented tile that satisfies both constraints that apply:
        its top matches the bottom of the tile above (if any), and its left
        matches the right of the tile to its left (if any).
        """
        if not (state := self.state):
            yield from self._starting_states()
            return

        size = self.size
        diag, x, y = self.diag, self.x, self.y
        # Step down-left along the current anti-diagonal, or, at its end,
        # start the next diagonal at its top-right cell.
        if x == (0 if diag < size else diag - size + 1):
            diag += 1
            x, y = min(diag, size - 1), max(0, diag - size + 1)
        else:
            x -= 1
            y += 1

        matched = None
        tops, lefts = self.tops, self.lefts
        if y and (above := state[x + size * (y - 1)]):
            # An empty default keeps the top constraint in force even when no
            # tile has this top value.
            matched = tops.get(above[Edge.bottom], frozenset())
        if x and (next_to := state[x - 1 + size * y]):
            matched_left = lefts.get(next_to[Edge.right], frozenset())
            # Build a new set: ``&=`` would intersect in place and mutate the
            # set stored in ``self.tops``.
            matched = matched_left if matched is None else matched & matched_left
        if not matched:
            return

        index = x + size * y
        state_lead, state_tail = state[:index], state[index + 1 :]
        new_searchstate = partial(SearchState, size=size, diag=diag, x=x, y=y)
        for ort in matched:
            yield new_searchstate(
                state=(*state_lead, ort, *state_tail),
                tops=self._without(tops, ort.id),
                lefts=self._without(lefts, ort.id),
            )


def reconstruct_image(tiledata: str) -> SearchState:
    """Arrange the tiles described by *tiledata* into a complete square grid.

    Runs a best-first search over ``SearchState``s, always expanding the queued
    state with the smallest fraction of empty cells. Each state is queued at
    most once, and the first complete state dequeued is returned.

    Raises:
        AssertionError: if the search is exhausted without finding a solution.
    """
    visited = set()
    frontier = PriorityQueue((1, SearchState.from_data(tiledata)))

    while frontier:
        candidate = frontier.get()
        if candidate.complete:
            return candidate
        for successor in candidate.neighbors():
            if successor not in visited:
                visited.add(successor)
                frontier.put(successor.to_place, successor)

    raise AssertionError("No solution found")


# Worked example: nine 10x10 tiles that assemble into a 3x3 image.
# NOTE(review): presumably the example from the AoC 2020 day 20 puzzle text.
testdata = """\
Tile 2311:
..##.#..#.
##..#.....
#...##..#.
####.#...#
##.##.###.
##...#.###
.#.#.#..##
..#....#..
###...#.#.
..###..###

Tile 1951:
#.##...##.
#.####...#
.....#..##
#...######
.##.#....#
.###.#####
###.##.##.
.###....#.
..#.#..#.#
#...##.#..

Tile 1171:
####...##.
#..##.#..#
##.#..#.#.
.###.####.
..###.####
.##....##.
.#...####.
#.##.####.
####..#...
.....##...

Tile 1427:
###.##.#..
.#..#.##..
.#.##.#..#
#.#.#.##.#
....#...##
...##..##.
...#.#####
.#.####.#.
..#..###.#
..##.#..#.

Tile 1489:
##.#.#....
..##...#..
.##..##...
..#...#...
#####...#.
#..#.#.#.#
...#.#.#..
##.#...##.
..##.##.##
###.##.#..

Tile 2473:
#....####.
#..#.##...
#.##..#...
######.#.#
.#...#.#.#
.#########
.###.#..#.
########.#
##...##.#.
..###.#.#.

Tile 2971:
..#.#....#
#...###...
#.#.###...
##.##..#..
.#####..##
.#..####.#
#..#.#..#.
..####.###
..#.#.###.
...#.#.#.#

Tile 2729:
...#.#.#.#
####.#....
..#.#.....
....#..#.#
.##..##.#.
.#.####...
####.#.#..
##.####...
##..#.##..
#.##...##.

Tile 3079:
#.#.#####.
.#..######
..#.......
######....
####.#..#.
.#...#.##.
#.#####.##
..#.###...
..#.......
..#.###...
"""

# The product of the four corner tile ids must equal the expected checksum.
assert reconstruct_image(testdata).checksum == 20899048083289

In [2]:
# Download the puzzle input for 2020, day 20 with the aocd package.
import aocd

data = aocd.get_data(year=2020, day=20)

In [3]:
print(f"Part 1: {reconstruct_image(data).checksum}")

Part 1: 11788777383197
