In [1]:
# Run from the repository root so the relative "day-24/..." input paths used below resolve.
%cd ..

/home/fblanke/Private/git/advent-of-code-2022


In [2]:
from __future__ import annotations

from collections import deque
from dataclasses import dataclass
from functools import lru_cache
from itertools import product
from math import lcm
import numpy as np
from typing import Iterable

from search import AStarState, shortest_path

## First attempt

In [3]:
class State:
    valid_dict: dict[int, set[complex]]
    storm_dict: dict[int, tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]]

    def __init__(
        self,
        width: int,
        height: int,
        entry_idx: int,
        exit_idx: int,
        storms: tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray],
    ) -> None:
        self.width = width
        self.height = height
        self.entry_idx = entry_idx
        self.exit_idx = exit_idx

        self.start_point = self.entry_idx - 1j
        self.end_point = self.exit_idx + self.height * 1j

        self.valid_dict = {}
        self.storm_dict = {0: storms}

        self.all_points = {
            x + y * 1j for x, y in product(range(self.width), range(self.height))
        }
        self.movement_dirs = [1, 1j, -1, -1j, 0]
        self.dir_prefs = self.calculate_dir_prefs()

    def calculate_dir_prefs(self) -> dict[complex, tuple[complex, ...]]:
        dir_prefs = {}

        return {
            point: sorted(
                self.movement_dirs,
                key=lambda d: abs(point + d - self.end_point),
            )
            for point in self.all_points.union({self.start_point, self.end_point})
        }

    def find_valid_pos(self, time: int) -> None:
        if time not in self.storm_dict:
            self.run_simulation(time)

        storms = set(np.concatenate(self.storm_dict[time]))
        self.valid_dict[time] = self.all_points.difference(storms)

    def run_simulation(self, time: int) -> None:
        if (time - 1) not in self.storm_dict:
            self.run_simulation(time - 1)

    @lru_cache(maxsize=100000)
    def is_valid(self, time: int, pos: complex) -> bool:
        if pos == self.end_point:
            return True
        if time not in self.valid_dict:
            self.find_valid_pos(time)
        return pos in self.valid_dict[time]

    @lru_cache(maxsize=10000)
    def traverse_dfs(
        self, time: int = 0, pos: complex | None = None, curr_best=float("inf")
    ):
        if pos is None:
            pos = self.start_point
        elif pos == self.end_point:
            return time
        if time > curr_best:
            return curr_best
        for d in self.dir_prefs[pos]:
            if self.is_valid(time + 1, pos + d):
                val = self.traverse(time + 1, pos + d, curr_best=curr_best)
                if val < curr_best:
                    curr_best = val
        return curr_best

    def traverse_bfs(self):
        bfs_queue = deque()
        bfs_queue.append((0, self.start_point))

        while bfs_queue:
            time, pos = bfs_queue.popleft()
            for d in self.dir_prefs[pos]:
                new_pos = pos + d
                if new_pos == self.end_point:
                    return time + 1
                if self.is_valid(time + 1, new_pos):
                    bfs_queue.append((time + 1, pos + d))
        return None

    def board(self, time: int) -> str:
        ret_str = "    "
        for idx in range(-1, self.width + 1):
            ret_str += "." if idx == self.entry_idx else "#"
        ret_str += "\n"
        for row_idx in range(self.height):
            ret_str += f"{row_idx:2d}: #"
            for col_idx in range(self.width):
                if self.is_valid(time, row_idx, col_idx):
                    ret_str += "."
                else:
                    ret_str += "~"
            ret_str += "#\n"
        ret_str += "    "
        for idx in range(-1, self.width + 1):
            ret_str += "." if idx == self.exit_idx else "#"
        return ret_str

In [4]:
def parse_input(filename: str):
    """Parse a valley map file into a first-attempt ``State``.

    The first line is the top wall with the entrance gap. The following lines
    are interior rows (``>``, ``v``, ``<``, ``^`` are blizzards, ``.`` is open
    ground) up to the first line containing ``##``, which is the bottom wall
    with the exit gap.

    NOTE: a later cell redefines ``parse_input``; this version is only in
    effect until that cell runs.

    Raises:
        ValueError: if the file is empty, a wall row has no gap, the bottom
            wall is missing, or an interior row has an unexpected character.
    """
    blizzard_chars = ">v<^"

    def _find_hole(row_str: str) -> int:
        # Interior column of the first non-wall character (the left wall is column -1).
        for char_idx, char in enumerate(row_str.rstrip()):
            if char != "#":
                return char_idx - 1
        raise ValueError(f"{filename}: no gap in wall row {row_str!r}")

    storms = [[], [], [], []]
    # Parse everything up front so the file is closed before the State is built.
    with open(filename) as f:
        first_row = f.readline().rstrip()
        if not first_row:
            raise ValueError(f"{filename}: empty map")
        width = len(first_row) - 2
        entry_idx = _find_hole(first_row)

        for row_idx, row in enumerate(f):
            row = row.rstrip()
            if "##" in row:
                exit_idx = _find_hole(row)
                height = row_idx
                break
            for char_idx, char in enumerate(row):
                if char in "#.":
                    continue
                if char not in blizzard_chars:
                    raise ValueError(
                        f"{filename}: unexpected character {char!r} in row {row!r}"
                    )
                storms[blizzard_chars.index(char)].append(char_idx - 1 + row_idx * 1j)
        else:
            raise ValueError(f"{filename}: bottom wall not found")

    return State(
        width,
        height,
        entry_idx,
        exit_idx,
        tuple(np.array(storm, dtype=complex) for storm in storms),
    )

## Second attempt

In [5]:
@dataclass
class Constants:
    """Puzzle data that ``ValleyState`` reads through ``self.c``."""

    # Position the search starts from (``ValleyState.go_back`` swaps it with ``end_pos``).
    start_pos: complex
    # Position that finishes the search (checked by ``ValleyState.is_finished``).
    end_pos: complex
    # Every cell the expedition may occupy; ``parse_input`` builds it as the
    # interior grid plus the start and end gaps.
    ground: set[complex]
    # ``blizzards[t]``: cells covered by a blizzard at minute ``t``, as produced
    # by ``calculate_blizzard_pos``.
    blizzards: list[set[complex]]
    # Minutes after which the blizzard layout repeats (``parse_input`` uses
    # ``lcm(width, height)``); ``ValleyState`` indexes ``blizzards`` modulo this.
    cycle_len: int


class ValleyState(AStarState[Constants]):
    """A* search node: the expedition at ``pos`` before move number ``cycle``.

    Successors must avoid ``blizzards[cycle % cycle_len]`` and carry
    ``cycle + 1``; with the default ``cycle=1`` the first move is checked
    against the layout after minute 1.
    """

    def __init__(self, pos: complex | None = None, cycle: int = 1, **kwargs):
        # Without an explicit position, start at the configured start; the
        # ``constants`` keyword is only consulted in that case.
        self.pos = kwargs["constants"].start_pos if pos is None else pos
        self.cycle = cycle
        super().__init__(**kwargs)

    def go_back(self) -> ValleyState:
        """Swap start and end so the next search runs in the opposite direction.

        NOTE(review): this mutates the ``Constants`` object in place, so any
        other state referencing the same object (presumably every state
        created via ``move``; confirm in ``search``) sees the swap too.
        """
        consts = self.c
        consts.start_pos, consts.end_pos = consts.end_pos, consts.start_pos
        return self

    @property
    def is_finished(self) -> bool:
        """True once the expedition stands on the goal position."""
        return self.c.end_pos == self.pos

    @property
    def next_states(self) -> Iterable[ValleyState]:
        """Moves (right, down, left, up, wait) onto ground that no blizzard covers next minute."""
        occupied = self.c.blizzards[self.cycle % self.c.cycle_len]
        following_cycle = self.cycle + 1
        successors = []
        for step in (1, 1j, -1, -1j, 0):
            candidate = self.pos + step
            if candidate not in self.c.ground or candidate in occupied:
                continue
            successors.append(self.move(pos=candidate, cycle=following_cycle))
        return successors

    @property
    def heuristic(self) -> int:
        """Manhattan distance to the goal (a float value, despite the annotation)."""
        delta = self.pos - self.c.end_pos
        return abs(delta.real) + abs(delta.imag)

    def __hash__(self) -> int:
        state_key = (self.pos, self.cycle)
        return hash(state_key)

    def __repr__(self):
        return "{} - {}".format(self.pos, self.cycle)


def calculate_blizzard_pos(
    storms: tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray],
    num_timesteps: int,
    width: int,
    height: int,
) -> list[set[complex]]:
    """Return the set of blizzard-covered cells for each minute ``0..num_timesteps``.

    Args:
        storms: Complex positions ``x + y*1j`` at minute 0 of the blizzards
            moving right, down, left and up, in that order. Any array-like
            (ndarray or list) is accepted.
        num_timesteps: Last minute to compute; the result has
            ``num_timesteps + 1`` entries (at least one).
        width: Number of interior columns (horizontal wrap-around period).
        height: Number of interior rows (vertical wrap-around period).

    Returns:
        ``result[t]`` is the set of cells occupied by at least one blizzard
        after ``t`` minutes.
    """
    right0, down0, left0, up0 = (np.asarray(s, dtype=complex) for s in storms)

    def _occupied_at(t: int) -> set[complex]:
        # Each blizzard moves one cell per minute and wraps around the interior,
        # so after t minutes it sits at its start shifted by t along its axis,
        # modulo the interior size on that axis (np.mod is non-negative here).
        right = np.mod(right0.real + t, width) + 1j * right0.imag
        down = down0.real + 1j * np.mod(down0.imag + t, height)
        left = np.mod(left0.real - t, width) + 1j * left0.imag
        up = up0.real + 1j * np.mod(up0.imag - t, height)
        return set(np.concatenate((right, down, left, up)))

    return [_occupied_at(t) for t in range(max(num_timesteps, 0) + 1)]

In [6]:
def parse_input(filename: str):
    """Solve a valley map with the A* ``ValleyState`` search and print the results.

    The map's first line is the top wall with the entrance gap, followed by
    interior rows (``>``, ``v``, ``<``, ``^`` are blizzards, ``.`` is ground)
    up to the first line containing ``##``, the bottom wall with the exit gap.
    The search runs entrance -> exit, exit -> entrance, and entrance -> exit
    again, printing the first leg, the return leg and the overall length.
    Returns None.

    Raises:
        ValueError: if the file is empty, a wall row has no gap, the bottom
            wall is missing, or an interior row has an unexpected character.
    """
    blizzard_chars = ">v<^"

    def _find_hole(row_str: str) -> int:
        # Interior column of the first non-wall character (the left wall is column -1).
        for char_idx, char in enumerate(row_str.rstrip()):
            if char != "#":
                return char_idx - 1
        raise ValueError(f"{filename}: no gap in wall row {row_str!r}")

    storms = [[], [], [], []]
    # Parse up front so the file is closed before the (slow) searches run.
    with open(filename) as f:
        first_row = f.readline().rstrip()
        if not first_row:
            raise ValueError(f"{filename}: empty map")
        width = len(first_row) - 2
        entry_idx = _find_hole(first_row)

        for row_idx, row in enumerate(f):
            row = row.rstrip()
            if "##" in row:
                exit_idx = _find_hole(row)
                height = row_idx
                break
            for char_idx, char in enumerate(row):
                if char in "#.":
                    continue
                if char not in blizzard_chars:
                    raise ValueError(
                        f"{filename}: unexpected character {char!r} in row {row!r}"
                    )
                storms[blizzard_chars.index(char)].append(char_idx - 1 + row_idx * 1j)
        else:
            raise ValueError(f"{filename}: bottom wall not found")

    # The blizzard layout repeats after lcm(width, height) minutes.
    cycle_len = lcm(width, height)
    blizzards = calculate_blizzard_pos(
        tuple(np.array(storm, dtype=complex) for storm in storms),
        cycle_len,
        width,
        height,
    )
    start_pos = entry_idx - 1j
    end_pos = exit_idx + height * 1j
    const = Constants(
        start_pos,
        end_pos,
        ground={x + y * 1j for x, y in product(range(width), range(height))}.union(
            {start_pos, end_pos}
        ),
        blizzards=blizzards,
        cycle_len=cycle_len,
    )
    path = shortest_path(ValleyState(constants=const))
    print(f"Shortest path length: {path.length}")
    # go_back() swaps start and end in place, so each following search reverses direction.
    return_path = shortest_path(path.end_state.go_back())
    # Lengths appear to accumulate across chained searches (hence the
    # difference); NOTE(review): confirm in ``search.shortest_path``.
    print(f"Trip back length: {return_path.length - path.length}")
    path_2 = shortest_path(return_path.end_state.go_back())
    print(f"Overall length: {path_2.length}")

In [7]:
# Three-leg trip (to the exit, back to the entrance, to the exit again) on the first test map.
parse_input("day-24/test-input.txt")

Shortest path length: 10
Trip back length: 10
Overall length: 30


In [8]:
# Same three-leg trip on the second test map.
parse_input("day-24/test-input2.txt")

Shortest path length: 18
Trip back length: 23
Overall length: 54


In [9]:
# Same three-leg trip on the actual puzzle input.
parse_input("day-24/input.txt")

Shortest path length: 334
Trip back length: 309
Overall length: 934
