# Advent of Code 2021

> I mean, if 10 years from now, when you are doing something quick and dirty, you suddenly visualize that I am looking over your shoulders and say to yourself "Dijkstra would not have liked this," well, that would be enough immortality for me.

-- Edsger W. Dijkstra

## Imports and definitions

In [232]:
#type: ignore
from math import *
from collections import Counter
from itertools import tee, repeat, count
from functools import reduce
from statistics import median, mean
from dataclasses import dataclass
from operator import mul
import heapq
import numpy as np

# Older Pythons (before 3.8) lack math.prod; this fallback mirrors it,
# including returning ``start`` (1 by default) for an empty iterable.


def prod(u, *, start=1):
    """Return the product of the items of *u*, multiplied onto *start*."""
    return reduce(mul, u, start)

# Read inputs


def inputfunc(day, kind='lines', testing=False):
    """Decorator factory that reads an input file and injects it as ``f``.

    The file (``test.txt`` when *testing*, otherwise ``input/{day}.txt``) is
    read once, when the decorator is applied. Every call of the decorated
    function receives that same pre-parsed value as the keyword ``f``.

    kind:
      'lines'  -- list of stripped lines
      'chunks' -- list of stripped, non-empty blocks separated by blank lines
      'single' -- the whole file as one stripped string
      'raw'    -- the open file object itself (left open for the caller)

    Raises ValueError for an unknown *kind*.
    """
    filename = 'test.txt' if testing else f"input/{day}.txt"

    def gen(func):
        if kind == 'lines':
            with open(filename) as fh:
                text = [x.strip() for x in fh]
        elif kind == 'chunks':
            with open(filename) as fh:
                text = [
                    x.strip()
                    for x in fh.read().split('\n\n')
                    if x.strip()
                ]
        elif kind == 'single':
            with open(filename) as fh:
                text = fh.read().strip()
        elif kind == 'raw':
            # The decorated function consumes the handle; it cannot be closed here.
            text = open(filename)
        else:
            raise ValueError(f"unknown input kind: {kind!r}")

        def inner():
            return func(f=text)
        return inner
    return gen


## Day 1

In [233]:
@inputfunc(1)
def input_1(*, f):
    """Parse one integer per input line."""
    return list(map(int, f))


def count_increasing_slides(a, n):
    """Count the indices k for which a[k] < a[k + n].

    Comparing sums of sliding windows of width n reduces to this, because
    two consecutive windows share everything except their first and last
    element.
    """
    return sum(later > earlier for earlier, later in zip(a, a[n:]))


In [234]:
# Part 1: how many values are larger than the previous one.
count_increasing_slides(input_1(), 1)


1624

In [235]:
# Part 2: sliding windows of three, i.e. compare a[k] with a[k + 3].
count_increasing_slides(input_1(), 3)


1653

## Day 2

In [236]:
@inputfunc(2)
def input_2(*, f):
    """Parse ``<command> <amount>`` lines into (command, int amount) pairs."""
    pairs = []
    for line in f:
        command, amount = line.split()
        pairs.append((command, int(amount)))
    return pairs


def pilot_1(l):
    """Apply (command, amount) pairs where up/down move vertically directly.

    Returns the horizontal position times the vertical position. Unknown
    commands are ignored.
    """
    horizontal = vertical = 0
    for command, amount in l:
        if command == 'forward':
            horizontal += amount
        elif command in ('down', 'up'):
            vertical += amount if command == 'down' else -amount
    return horizontal * vertical


def pilot_2(l):
    """Apply (command, amount) pairs where up/down steer an aim instead.

    ``forward`` moves horizontally and changes the vertical position by
    aim * amount. Returns horizontal times vertical position; unknown
    commands are ignored.
    """
    horizontal = vertical = aim = 0
    for command, amount in l:
        if command == 'down':
            aim += amount
        elif command == 'up':
            aim -= amount
        elif command == 'forward':
            horizontal += amount
            vertical += aim * amount
    return horizontal * vertical


In [237]:
# Part 1: up/down change the vertical position directly.
pilot_1(input_2())


1654760

In [238]:
# Part 2: up/down adjust the aim; forward moves by the aim.
pilot_2(input_2())


1956047400

## Day 3

In [239]:
@inputfunc(3)
def input_3(*, f):
    """Return the report as a list of bit strings.

    ``inputfunc`` hands every call the same pre-read list, so return a copy
    to stop one caller's mutations from leaking into later calls.
    """
    return list(f)


def power_consumption(l):
    """Return gamma * epsilon for a non-empty list of equal-width bit strings.

    Gamma takes, at each bit position, the value '1' when strictly more than
    half of the numbers have a '1' there (ties count as '0'). Epsilon is its
    bitwise complement within the word width.
    """
    half = len(l) / 2
    width = len(l[0])
    ones_at = [0] * width  # ones_at[e]: how many numbers have '1' at weight 2**e
    for bits in l:
        for weight, bit in enumerate(bits[::-1]):
            ones_at[weight] += bit == '1'
    gamma = sum(1 << weight for weight, ones in enumerate(ones_at) if ones > half)
    epsilon = ((1 << width) - 1) - gamma
    return gamma * epsilon


def life_support_rating(l):
    """Return oxygen generator rating * CO2 scrubber rating.

    Each rating repeatedly filters *l* (a list of equal-width bit strings)
    by the bit at the next position until one number is left:
      * oxygen keeps the most common bit, preferring '1' on ties;
      * CO2 keeps the least common bit, preferring '0' on ties.
    When every remaining number has the same bit at a position, nothing is
    discarded there. If identical numbers survive every position, the first
    one is used.

    Raises ValueError if *l* is empty.
    """
    def bisect_on(l, predicate):
        # Single pass: (items where predicate holds, items where it doesn't).
        trues, falses = [], []
        for u in l:
            (trues if predicate(u) else falses).append(u)
        return trues, falses

    # This uses the naive algorithm, a literal transcription of the problem
    # description. A far superior approach would be to sort the numbers
    # first, which can be done in O(nW) for n numbers of W bits in a myriad
    # of ways (for example repeatedly applying a stable counting sort). At
    # that point, finding which segment to keep at each iteration is
    # trivial. However, because of the problem size it just won't matter:
    # both run in milliseconds anyway.
    def rating(l, keep_ones):
        # keep_ones(n_zeroes, n_ones) -> True to keep the '1' partition.
        index = 0
        while len(l) > 1 and index < len(l[0]):
            zeroes, ones = bisect_on(l, lambda u: u[index] == '0')
            # Only filter when both partitions are non-empty; otherwise the
            # "least common" criterion would select an empty list.
            if zeroes and ones:
                l = ones if keep_ones(len(zeroes), len(ones)) else zeroes
            index += 1
        if not l:
            raise ValueError("life_support_rating() needs at least one number")
        return int(l[0], 2)

    oxygen = rating(l, lambda n_zeroes, n_ones: n_ones >= n_zeroes)
    co2 = rating(l, lambda n_zeroes, n_ones: n_ones < n_zeroes)
    return oxygen * co2


In [240]:
# Part 1: gamma times epsilon.
power_consumption(input_3())


3320834

In [241]:
# Part 2: oxygen generator rating times CO2 scrubber rating.
life_support_rating(input_3())


4481199

## Day 4

In [242]:
class BingoCard:
    """A bingo card that tracks marked numbers and reports when it wins.

    ``reverse_map`` maps every still-unmarked number to its (row, col)
    position; marking a number removes it from the map.
    """

    def __init__(self, m):
        self.matrix = m
        self.num_rows = len(m)
        self.num_cols = len(m[0])
        # How many numbers have been marked in each row / column so far.
        self.hit_by_row = [0] * self.num_rows
        self.hit_by_col = [0] * self.num_cols
        self.reverse_map = {}
        for nr, r in enumerate(m):
            for nc, num in enumerate(r):
                self.reverse_map[num] = (nr, nc)

    def hit_num(self, n):
        """Mark *n*; return True iff this completes its row or its column."""
        if n not in self.reverse_map:
            return False
        nr, nc = self.reverse_map.pop(n)
        # Update both tallies before checking, so the column count stays
        # correct even when the row is the one that completes.
        self.hit_by_row[nr] += 1
        self.hit_by_col[nc] += 1
        return (self.hit_by_row[nr] == self.num_cols
                or self.hit_by_col[nc] == self.num_rows)

    def sum_remaining(self):
        """Return the sum of all numbers not marked yet."""
        return sum(self.reverse_map)


@inputfunc(4, kind='chunks')
def input_4(*, f):
    """Parse the drawn numbers (first chunk) and the bingo cards (the rest).

    Returns ``(cards, seq)``: fresh BingoCard objects on every call, and the
    list of drawn numbers.
    """
    seq = list(map(int, f[0].split(',')))
    cards = [
        BingoCard([[int(u) for u in row.split()] for row in chunk.split('\n')])
        for chunk in f[1:]
    ]
    return cards, seq


def play_bingo_1(cards, seq):
    """Score of the first card to win: winning number times its unmarked sum.

    Cards after the winner are not marked with the winning number. Returns
    None if no card ever wins.
    """
    for num in seq:
        winner = next((card for card in cards if card.hit_num(num)), None)
        if winner is not None:
            return num * winner.sum_remaining()


def play_bingo_2(cards, seq):
    """Score of the last card to win: winning number times its unmarked sum.

    Cards are dropped as soon as they win. If the final remaining cards all
    win on the same number, the last of them in input order is scored.
    Returns None if some card never wins (or there are no cards).
    """
    alive = list(cards)  # a list, not a set: deterministic (input) order
    for num in seq:
        still_alive, last_winner = [], None
        for card in alive:
            if card.hit_num(num):
                last_winner = card
            else:
                still_alive.append(card)
        if not still_alive and last_winner is not None:
            return num * last_winner.sum_remaining()
        alive = still_alive


In [243]:
# Part 1: score of the first winning card (input_4() builds fresh cards).
play_bingo_1(*input_4())


33348

In [244]:
# Part 2: score of the last card to win.
play_bingo_2(*input_4())


8112

## Day 5

Disappointingly, the naive algorithm. I have thought about it for a while and short of reimplementing from scratch a full-blown R-tree-backed database table, I am completely clueless as to how you would solve this.

In [245]:
class Board:
    """Counts how many line segments cover each integer grid point."""

    def __init__(self):
        self._point_count = Counter()

    def put(self, x1, y1, x2, y2, with_diagonals=False):
        """Record the segment from (x1, y1) to (x2, y2), endpoints included.

        Segments that are neither horizontal nor vertical are skipped unless
        *with_diagonals* is true. Such segments are walked one step per axis
        at a time, so they are only meaningful at exactly 45 degrees.
        """
        if not with_diagonals and x1 != x2 and y1 != y2:
            return

        if x1 == x2 and y1 == y2:
            # Both coordinate sequences below would be endless repeat()s.
            self._point_count[x1, y1] += 1
            return

        def get_range_for(c1, c2):
            if c1 == c2:
                return repeat(c1)
            step = 1 if c2 > c1 else -1
            return range(c1, c2 + step, step)

        for x, y in zip(get_range_for(x1, x2), get_range_for(y1, y2)):
            self._point_count[x, y] += 1

    def overlaps(self):
        """Return the set of points covered by two or more segments."""
        return {p for p, n in self._point_count.items() if n >= 2}


@inputfunc(5)
def input_5(*, f):
    """Parse ``x1,y1 -> x2,y2`` lines into (x1, y1, x2, y2) integer tuples."""
    def parse(line):
        return tuple(map(int, line.replace('->', ',').split(',')))
    return [parse(line) for line in f]


def find_overlaps(l, with_diagonals=False):
    """Number of points covered by two or more of the segments in *l*.

    Each element of *l* is an ``(x1, y1, x2, y2)`` tuple. Diagonal segments
    are drawn only when *with_diagonals* is true.
    """
    board = Board()
    draw = board.put
    for segment in l:
        draw(*segment, with_diagonals=with_diagonals)
    return len(board.overlaps())


In [246]:
# Part 1: points covered by two or more horizontal/vertical segments.
find_overlaps(input_5())


5092

In [247]:
# Part 2: diagonal segments count too.
find_overlaps(input_5(), with_diagonals=True)


20484

## Day 6

numpy == chet

Automorphisms of finite-dimensional real vector spaces are what plebeians call matrices.

In [248]:
@inputfunc(6, kind='single')
def input_6(*, f):
    """Count how often each comma-separated integer occurs."""
    return Counter(map(int, f.split(',')))


def get_lanternfish(state, itn):
    """Return the population after *itn* iterations as a Python int.

    *state* maps each value 0..8 (presumably a fish's internal timer) to how
    many fish have it; missing values count as zero. One iteration is the
    linear map T below, so *itn* iterations are ``T ** itn``.
    """
    # New count for value i = sum over j of T[i, j] * old count for value j:
    # every value decreases by one, and a 0 both re-enters at 6 and spawns an 8.
    # Explicit int64: the default integer dtype can be 32-bit on some
    # platforms, which overflows long before 256 iterations.
    T = np.array([
        [0, 1, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 1, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 1, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 1, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 1, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 1, 0, 0],
        [1, 0, 0, 0, 0, 0, 0, 1, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 1],
        [1, 0, 0, 0, 0, 0, 0, 0, 0]
    ], np.int64)

    A = np.array([state.get(i, 0) for i in range(9)], np.int64)
    return int(np.sum(np.linalg.matrix_power(T, itn) @ A))


In [249]:
# Part 1: population after 80 iterations.
get_lanternfish(input_6(), 80)


359344

In [250]:
# Part 2: population after 256 iterations.
get_lanternfish(input_6(), 256)


1629570219571

## Day 7

math == chet

Let $f_k(x)$ be the cost for the $k$-th crab to meet at $x$. Then the total cost $c(x) = \sum_k f_k(x)$ is minimized only if:

$$ 
    \frac{\partial}{\partial x} \sum_k f_k(x) = 0
$$

#### Case 1

Let the initial position of the $k$-th crab be $p_k$, and let $N$ be the number of crabs. Then $f_k(x) = \left| x - p_k\right|$. The derivative of the total cost function is then:

$$
    \frac{\partial}{\partial x} \sum_k \left| x - p_k \right| =
    \sum_k \sigma(x - p_k)
$$

Where $\sigma$ is the sign function, that is, $\sigma(0) = 0$ and $\sigma(x) = x / |x|$ for all $x \neq 0$. It trivially follows that $c^{\prime}(x) = 0$ if and only if $x$ is a median of $\{p_k\}_k$. Before mathematicians get a fucking aneurysm from this: yes, of course $c$ isn't differentiable everywhere, but it's differentiable everywhere except on a finite subset of $\mathbb{R}$. It's trivial to check everything still works on those isolated points.

#### Case 2

Using the same notation as case 1, $f_k(x) = 1/2 \left( \left| x - p_k\right| + \left( x - p_k \right)^{2} \right)$. Therefore:

$$
    \frac{\partial c(x)}{\partial x} = 0 \Leftrightarrow 
    \sum_k { \sigma(x - p_k) + 2x - 2p_k} = 0 \Leftrightarrow
    \sum_k { \sigma(x - p_k) } + 2Nx - 2 \sum_k p_k = 0
$$

Shifting around the terms we get:

$$
    \frac{\partial c(x)}{\partial x} = 0 \Leftrightarrow
    x = \frac{1}{N} \sum_k p_k  - \frac{1}{2N} \sum_k \sigma(x - p_k)
$$

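At this point we have two options. We either proceed to prove in detail that because we are limited to choosing integral values of $x$ then it is possible to bound the set of candidate solutions to the equation to a compact subset of $\mathbb{R}$ with inequalities and fancy stuff I'm supposed to know, or we 360 noscope yoloswag trust that the last term is _just too small to matter_.
Hold my beer.

(Sober addendum: the beer was safe. Since $\left| \sum_k \sigma(x - p_k) \right| \le N$, the last term is at most $1/2$ in absolute value, so the real minimiser $x^*$ lies within $1/2$ of the mean $\mu = \frac{1}{N} \sum_k p_k$. The $p_k$ are integers, so between two consecutive integers $c$ is a plain parabola with leading coefficient $N/2$. As $c$ is convex, the best integer is the one nearest to $x^*$. Every integer nearest to a point of $\left[\mu - \frac{1}{2}, \mu + \frac{1}{2}\right]$ is $\lfloor \mu \rfloor$ or $\lceil \mu \rceil$, up to ties of equal cost, and those are exactly the two candidates the code below tries.)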

In [251]:
@inputfunc(7, kind='single')
def input_7(*, f):
    """Parse the comma-separated crab positions."""
    return list(map(int, f.split(',')))


def cost_linear(i):
    """Minimal total cost when each unit of distance costs 1.

    The optimum is any median; truncating a half-integer median still lands
    between the two middle values, so it stays optimal.
    """
    target = int(median(i))
    return sum(map(lambda x: abs(x - target), i))


def cost_quadratic(i):
    """Minimal total cost when moving d steps costs d * (d + 1) / 2.

    The real-valued optimum lies within 1/2 of the mean, so the best integer
    target is the floor or the ceiling of the mean; both are tried.
    """
    def triangular(distance):
        return distance * (distance + 1) // 2

    centre = mean(i)
    candidates = (floor(centre), ceil(centre))
    return min(sum(triangular(abs(x - c)) for x in i) for c in candidates)


In [252]:
# Part 1: each crab pays |x - m|.
cost_linear(input_7())


344605

In [253]:
# Part 2: each crab pays d(d + 1)/2 for a distance d.
cost_quadratic(input_7())


93699985

## Day 8

In [254]:
@dataclass
class DisplayInstance:
    """One scrambled display: ten digit patterns plus the displayed outputs.

    Every pattern and output is the set of segment letters lit for one digit.
    """
    patterns: list[set[str]]  # the ten distinct digit patterns
    outputs: list[set[str]]  # displayed digits, most significant first

    def get_output_value(self):
        """Deduce which pattern is which digit and decode the outputs.

        Returns the outputs read as a decimal number. Raises ValueError if an
        output matches none of the deduced digit patterns.
        """
        patterns_bylen = {
            n: [set(x) for x in self.patterns if len(x) == n]
            for n in {2, 3, 4, 5, 6, 7}
        }

        x = [None] * 10

        # Digits with a unique number of segments: 1 (2), 7 (3), 4 (4), 8 (7).
        x[1] = next(iter(patterns_bylen[2]))
        x[7] = next(iter(patterns_bylen[3]))
        x[4] = next(iter(patterns_bylen[4]))
        x[8] = next(iter(patterns_bylen[7]))

        # The patterns with 6 segments must represent zero, six and nine.
        for p in patterns_bylen[6]:
            # Nine is the only one that "contains" a four.
            if p > x[4]:
                x[9] = p
            # Otherwise it's a zero or a six: zero "contains" a one, six doesn't.
            elif p > x[1]:
                x[0] = p
            else:
                x[6] = p

        # The patterns with 5 segments must represent two, three and five.
        for p in patterns_bylen[5]:
            # Three is the only one that "contains" a one.
            if p > x[1]:
                x[3] = p
            # Otherwise it's a five or a two: six "contains" a five, not a two.
            elif x[6] > p:
                x[5] = p
            else:
                x[2] = p

        # frozenset keys give an O(1) lookup per output digit.
        digit_of = {frozenset(p): n for n, p in enumerate(x) if p is not None}
        value = 0
        for p in self.outputs:
            digit = digit_of.get(frozenset(p))
            if digit is None:
                raise ValueError(
                    f"unrecognised output pattern: {''.join(sorted(p))!r}")
            value = value * 10 + digit
        return value


@inputfunc(8)
def input_8(*, f):
    """Parse ``patterns | outputs`` lines into DisplayInstance objects."""
    def parse(line):
        halves = [
            [set(word) for word in part.strip().split()]
            for part in line.split('|')
        ]
        return DisplayInstance(*halves)
    return [parse(line) for line in f]


In [255]:
# Part 1: output digits with a unique segment count (1, 7, 4 and 8).
sum(sum(1 for u in x.outputs if len(u) in {2, 3, 4, 7}) for x in input_8())


521

In [256]:
# Part 2: sum of all decoded output values.
sum(x.get_output_value() for x in input_8())


1016804

## Day 9

In [257]:
class HeightMap:
    """A grid of heights, addressed as ``point(x, y) == matrix[x][y]``.

    ``x`` selects the row and ``y`` the column.
    """

    def __init__(self, matrix):
        self._matrix = matrix
        self._rows = len(matrix)
        self._cols = len(matrix[0])

    def point(self, x, y):
        return self._matrix[x][y]

    def get_neighbors_of(self, x, y):
        """Return ``(x, y, height)`` for the orthogonal neighbours inside the grid."""
        return [
            (x + dx, y + dy, self.point(x + dx, y + dy))
            for dx, dy in ((1, 0), (-1, 0), (0, 1), (0, -1))
            if 0 <= x + dx < self._rows and 0 <= y + dy < self._cols
        ]

    def get_basin_of(self, x, y):
        """Return the coordinates of the basin containing (x, y).

        The basin is the connected region of cells with height other than 9
        around (x, y). Every such neighbour is followed, not only strictly
        higher ones, which would stop at flat stretches of equal height.
        """
        seen = {(x, y)}
        stack = [(x, y)]
        while stack:
            cx, cy = stack.pop()
            for nx, ny, height in self.get_neighbors_of(cx, cy):
                if height != 9 and (nx, ny) not in seen:
                    seen.add((nx, ny))
                    stack.append((nx, ny))
        return seen

    def get_lowpoints(self):
        """Return ``(x, y, height)`` for every cell lower than all its neighbours."""
        return [
            (x, y, self.point(x, y))
            for x in range(self._rows)
            for y in range(self._cols)
            if all(self.point(x, y) < v for _, _, v in self.get_neighbors_of(x, y))
        ]


@inputfunc(9)
def input_9(*, f):
    """Parse each line of digits into a row of integer heights."""
    return [list(map(int, row)) for row in f]


In [258]:
def sum_lowpoints(i):
    """Sum of (height + 1) over all low points of grid *i*."""
    return sum(height + 1 for *_, height in HeightMap(i).get_lowpoints())


sum_lowpoints(input_9())


456

In [259]:
def largest_basins(i, n):
    """Product of the sizes of the *n* largest basins of grid *i*."""
    hm = HeightMap(i)
    sizes = sorted(
        (len(hm.get_basin_of(x, y)) for x, y, _ in hm.get_lowpoints()),
        reverse=True,
    )
    return prod(sizes[:n])


largest_basins(input_9(), 3)


1047744

## Day 10

In [260]:
@inputfunc(10)
def input_10(*, f):
    """Return the lines to check as a list of strings.

    ``inputfunc`` hands every call the same pre-read list, so return a copy
    to stop one caller's mutations from leaking into later calls.
    """
    return list(f)


@dataclass
class UnmatchedChar:
    """check_line result for a corrupted line: the first illegal character."""
    char: str

    def score(self):
        """Syntax-error score of the illegal closing character."""
        scores = {
            ')': 3,
            ']': 57,
            '}': 1197,
            '>': 25137
        }
        return scores[self.char]


@dataclass
class IncompleteLine:
    """check_line result for a line that ends with brackets still open.

    ``char`` is the stack of still-open brackets, innermost last.
    """
    char: list[str]

    def score(self):
        """Autocomplete score of the missing closing sequence.

        Closers are emitted innermost first, and each updates the score as
        ``score * 5 + value`` (values 1-4). That is a base-5 number whose
        digits are the reversed stack. An empty stack (a complete line)
        scores 0.
        """
        if not self.char:
            return 0
        scr = str.maketrans('([{<', '1234')
        return int(''.join(self.char).translate(scr)[::-1], 5)


def check_line(i):
    """Classify line *i* as UnmatchedChar (corrupted) or IncompleteLine.

    A closing character that arrives when nothing is open is reported as
    unmatched. A complete line yields ``IncompleteLine([])``.
    """
    chars = {
        '(': ')',
        '[': ']',
        '{': '}',
        '<': '>'
    }

    s = []
    for c in i:
        if c in chars:
            s.append(c)
        elif not s or chars[s.pop()] != c:
            return UnmatchedChar(c)

    return IncompleteLine(s)


In [261]:
def syntax_all(i):
    """Total syntax-error score of the corrupted lines in *i*."""
    results = [check_line(line) for line in i]
    corrupted = filter(lambda r: isinstance(r, UnmatchedChar), results)
    return sum(r.score() for r in corrupted)


syntax_all(input_10())


319329

In [262]:
def autocomplete_all(i):
    """Median autocomplete score over the incomplete lines of *i*.

    Corrupted lines are discarded. Lines that are already complete (no
    bracket left open) are not incomplete either, so they are skipped rather
    than scored.
    """
    results = [check_line(line) for line in i]
    return median(
        u.score() for u in results if isinstance(u, IncompleteLine) and u.char
    )


autocomplete_all(input_10())


3515583998

## Day 11

In [263]:
@inputfunc(11)
def input_11(*, f):
    """Parse the digit grid into a 2-D integer numpy array."""
    return np.array([list(map(int, s.strip())) for s in f])


def convolution(M, K, T=np.int64):
    """Return the 'full' 2-D cross-correlation of *M* with kernel *K*.

    ``O[x, y] = sum(K[dx, dy] * P[x + dx, y + dy])``, where *P* is *M*
    zero-padded by ``kx - 1`` / ``ky - 1`` on every side. The result has
    shape ``(mx + kx - 1, my + ky - 1)`` and dtype *T*. The kernel is not
    flipped, so this equals a true convolution only for point-symmetric
    kernels (such as the neighbour kernel used by time_step).
    """
    mx, my = M.shape
    kx, ky = K.shape
    ox, oy = mx + kx - 1, my + ky - 1
    P = np.zeros((mx + 2 * (kx - 1), my + 2 * (ky - 1)), T)
    P[kx-1:mx+kx-1, ky-1:my+ky-1] = M

    # One vectorised pass per kernel cell instead of a Python loop per output cell.
    acc = np.zeros((ox, oy), np.result_type(K, P))
    for dx in range(kx):
        for dy in range(ky):
            acc += K[dx, dy] * P[dx:dx + ox, dy:dy + oy]

    O = np.zeros((ox, oy), T)
    O[...] = acc
    return O


def time_step(i, T=np.int64):
    """Advance grid *i* by one step; return ``(new_grid, cells_fired)``.

    Every cell gains 1. Each cell above 9 fires once per step, adding 1 to
    its eight neighbours, which may make them fire in turn. Cells that fired
    end the step at 0. *i* itself is not modified.
    """
    K = np.array([[1, 1, 1], [1, 0, 1], [1, 1, 1]], T)
    grid = np.array(i, T) + 1  # a new array: the caller's grid stays untouched
    eligible = np.ones(grid.shape, bool)  # cells that have not fired yet this step
    count_fired = 0
    fired = (grid > 9) & eligible
    while fired.any():
        count_fired += int(fired.sum())
        eligible &= ~fired
        # Spread +1 from every newly fired cell to its neighbours; [1:-1, 1:-1]
        # trims the 'full' correlation back to the grid's shape.
        grid += convolution(fired.astype(T), K, T)[1:-1, 1:-1]
        fired = (grid > 9) & eligible

    return np.where(grid > 9, 0, grid), count_fired


In [264]:
def count_fired_over_ts(i, n):
    """Total number of cells fired over *n* steps, starting from grid *i*."""
    total = 0
    grid = i
    for _ in range(n):
        # Keep the per-step count separate from the step-count parameter n.
        grid, fired = time_step(grid)
        total += fired
    return total


count_fired_over_ts(input_11(), 100)


1661

In [265]:
def find_first_sync(i):
    """Return the first step (1-based) after which the whole grid sums to 0."""
    step = 0
    while True:
        step += 1
        i, _ = time_step(i)
        if i.sum() == 0:
            return step


find_first_sync(input_11())


334

## Day 12

In [266]:
@dataclass
class Node:
    """One node of the graph built by input_12."""
    label: str  # the node's name as written in the input
    kind: str  # 'start', 'end', 'big' (isupper() name) or 'small'
    adj: list  # labels of the adjacent nodes


@inputfunc(12)
def input_12(*, f):
    """Build the undirected graph as a dict mapping label -> Node.

    'start' and 'end' are pre-created. Other labels are 'big' when
    ``str.isupper()`` holds and 'small' otherwise. Each ``u-v`` line adds v
    to u's adjacency list and u to v's.
    """
    def node_kind(s):
        if s in ('start', 'end'):
            return s
        return 'big' if s.isupper() else 'small'

    nodes = {label: Node(label, label, []) for label in ('start', 'end')}
    for edge in f:
        a, b = edge.split('-')
        node_a = nodes.get(a) or Node(a, node_kind(a), [])
        node_b = nodes.get(b) or Node(b, node_kind(b), [])
        node_a.adj.append(b)
        node_b.adj.append(a)
        nodes[a], nodes[b] = node_a, node_b
    return nodes


In [267]:
def breadth_search_1(graph):
    """Count start-to-end paths that enter each 'small' node at most once.

    Despite the name this is a depth-first worklist (``Q.pop()`` takes the
    newest entry). A state is ``(bitmask of small nodes used, node)``, and
    ``N`` holds how many distinct paths reach each pending state, so paths
    that converge on the same state are expanded together. Assumes no two
    'big' nodes are adjacent: otherwise the state never changes along the
    big-big cycle and the loop never terminates.
    """
    # Give every small node its own bit.
    small = {}
    for n, k in enumerate(k for k, v in graph.items() if v.kind == 'small'):
        small[k] = 2 ** n

    Q = [(0, 'start')]
    N = {(0, 'start'): 1}
    acc = 0

    while Q:
        state, nodeid = Q.pop()
        n = N.pop((state, nodeid))
        for v in graph[nodeid].adj:
            if v == 'start':
                continue
            if v == 'end':
                acc += n
                continue
            bit = small.get(v, 0)  # 0 for big nodes: never "used up"
            if state & bit:
                continue
            nstate = state | bit
            if (nstate, v) in N:
                N[(nstate, v)] += n
            else:
                Q.append((nstate, v))
                N[(nstate, v)] = n

    return acc


# Part 1: start-to-end paths entering each small node at most once.
breadth_search_1(input_12())


3495

In [268]:
def breadth_search_2(graph):
    """Count start-to-end paths where one 'small' node may be entered twice.

    Same worklist scheme as breadth_search_1, with an extra flag in the
    state recording whether the single repeat visit has been spent. Paths
    never return to 'start'. Assumes no two 'big' nodes are adjacent.
    """
    # Give every small node its own bit.
    small = {}
    for n, k in enumerate(k for k, v in graph.items() if v.kind == 'small'):
        small[k] = 2 ** n

    Q = [(0, False, 'start')]
    N = {(0, False, 'start'): 1}
    acc = 0

    while Q:
        state, double, nodeid = Q.pop()
        n = N.pop((state, double, nodeid))
        for v in graph[nodeid].adj:
            if v == 'start':
                continue
            if v == 'end':
                acc += n
                continue
            bit = small.get(v, 0)  # 0 for big nodes: never "used up"
            if state & bit:
                if double:
                    continue
                ndouble = True  # spend the single allowed repeat visit
            else:
                ndouble = double
            key = (state | bit, ndouble, v)
            if key in N:
                N[key] += n
            else:
                Q.append(key)
                N[key] = n

    return acc


# Part 2: a single small node may be entered twice.
breadth_search_2(input_12())


94849

## Day 13

In [269]:
@inputfunc(13)
def input_13(*, f):
    """Parse the dots and the fold instructions.

    Returns ``(M, folds)``. ``M`` is a boolean matrix with ``M[row, col]``
    set for every dot; the first field of a dot line becomes the column and
    the second the row. ``folds`` is a list of ``(axis, position)`` with axis
    'x' or 'y'; fold_along folds rows for 'y' and columns for 'x'.

    The matrix extent along an axis is ``2 * n + 1`` for the first fold
    ``n`` on that axis, so that fold sits exactly in the middle. It is
    enlarged if the dots need more room, or sized to the dots alone if that
    axis has no fold.
    """
    dots, folds = [], []
    for s in f:
        if s.startswith('fold'):
            _, _, d, n = s.replace('=', ' ').split(' ')
            folds.append((d, int(n)))
        elif s != '':
            y, x = s.split(',')
            dots.append((int(x), int(y)))

    def extent(axis, coords):
        size = 1 + max(coords, default=-1)
        first_fold = next((n for d, n in folds if d == axis), None)
        if first_fold is not None:
            size = max(size, 2 * first_fold + 1)
        return size

    rows = extent('y', [x for x, _ in dots])
    cols = extent('x', [y for _, y in dots])
    M = np.zeros((rows, cols), bool)
    for p in dots:
        M[p] = True

    return M, folds


def fold_along(M, s, n):
    """Fold boolean matrix *M* along line *n* of axis *s*, OR-ing the halves.

    'y' folds rows (the part below row n is flipped up onto the top n rows).
    'x' folds columns (the part right of column n is flipped left). The fold
    line itself is dropped. A folded-over part shorter than n is aligned to
    the fold line instead of being broadcast. A longer one, or an unknown
    axis, raises ValueError.
    """
    if s == 'x':
        # Folding columns is folding the rows of the transpose.
        return fold_along(M.T, 'y', n).T
    if s != 'y':
        raise ValueError(f"unknown fold axis: {s!r}")
    folded = M[:n, :].copy()
    flipped = np.flipud(M[n + 1:, :])
    if flipped.shape[0] > n:
        raise ValueError(f"fold at {n} leaves more below the line than above it")
    # Row n + 1 + k lands on row n - 1 - k, so the flipped part ends at row n - 1.
    folded[n - flipped.shape[0]:, :] |= flipped
    return folded


def fold_along_all(M, l):
    """Apply every ``(axis, position)`` fold in *l* to *M*, in order."""
    return reduce(lambda X, f: fold_along(X, *f), l, M)


In [270]:
# Part 1: number of dots left after applying only the first fold.
M, folds = input_13()
fold_along(M, *folds[0]).sum()


666

In [271]:
# Part 2: apply every fold and print the dots ('O') row by row.
M, folds = input_13()
folded = fold_along_all(M, folds)
for row in folded:
    print(''.join('O' if dot else ' ' for dot in row))


 OO    OO O  O  OO  OOOO O  O O  O O  O 
O  O    O O  O O  O    O O  O O O  O  O 
O       O OOOO O  O   O  OOOO OO   O  O 
O       O O  O OOOO  O   O  O O O  O  O 
O  O O  O O  O O  O O    O  O O O  O  O 
 OO   OO  O  O O  O OOOO O  O O  O  OO  


This is absolutely metal, as always, advent of code doesn't disappoint!

## Day 14

Yet again, linear algebra. Pairs of consecutive letters form the basis of the vector space, the number of occurrences of each pair at a given iteration is a vector, and the transformation map defines a linear map over vectors. Let $m$ be the size of the alphabet and let $n$ be the number of iterations. There are two solutions with an interesting tradeoff:

* Represent the initial vector $v$ as an array, and manually apply $n$ times the morphism. This has a time complexity of $\Theta(m^2n)$.

* Represent the linear map as an $m^2 \times m^2$ matrix (one row and column per pair), compute $M^n$, then multiply it by the initial vector. This has a time complexity of $\Theta(m^6 \log{n})$.

The first solution needs $\Theta(m^2)$ space for the vector, the second $\Theta(m^4)$ for the matrix. I'll do the second because it's more interesting. I also believe it's much more efficient in practice as it can fully take advantage of vectorization intrinsics and is generally very tight: just plain math, no hashtable lookups.

The actual solution is one line, everything else is just shuffling around data.

In [272]:
@inputfunc(14)
def input_14(*, f):
    """Build the pair-insertion linear map for the polymer.

    Returns ``(M, V, l_start, l_end, d)``:
      * ``M`` -- int64 matrix; column j counts the pairs that pair j turns
        into after one insertion step (row = produced pair);
      * ``V`` -- how many times each pair occurs in the template;
      * ``l_start`` / ``l_end`` -- first and last letter of the template;
      * ``d`` -- index -> pair, the inverse of the pair numbering.
    Pairs are numbered in rule order. Every template pair must have a rule,
    otherwise a KeyError is raised.
    """
    m = {}
    init, l_start, l_end = None, None, None
    for l in f:
        if init is None:
            # The first line is the template. Test against None rather than
            # falsiness: a one-letter template has no pairs (an empty list).
            init = [*zip(l, l[1:])]
            l_start, l_end = l[0], l[-1]
            continue
        if l == '':
            continue
        (u1, u2), v = l.replace('->', ' ').split()
        m[(u1, u2)] = ((u1, v), (v, u2))

    num = {e: n for n, e in enumerate(m)}

    M = np.zeros((len(num), len(num)), np.int64)
    V = np.zeros(len(num), np.int64)

    for u, (v1, v2) in m.items():
        M[num[v1], num[u]] += 1
        M[num[v2], num[u]] += 1

    for u in init:
        V[num[u]] += 1

    return M, V, l_start, l_end, {n: e for e, n in num.items()}


def polymer_count_after(steps, M, V, l_start, l_end, d):
    """Return a Counter holding twice each letter's frequency after *steps*.

    Every letter belongs to two adjacent pairs except the first and last
    letters of the polymer, which belong to one. Summing both letters of
    every pair and adding one for each end therefore counts every letter
    exactly twice; callers halve the result.
    """
    pair_counts = np.linalg.matrix_power(M, steps) @ V
    # Counter of a list, not a dict literal: when the first and last letters
    # coincide, that letter must receive both end corrections.
    C = Counter([l_start, l_end])
    for index, occurrences in enumerate(pair_counts):
        first, second = d[index]
        C[first] += occurrences
        C[second] += occurrences

    return C


In [273]:
# Part 1: C holds doubled letter counts; most common minus least common after 10 steps.
C = polymer_count_after(10, *input_14())
max(C.values()) // 2 - min(C.values()) // 2


3906

In [274]:
# Part 2: the same difference after 40 steps.
C = polymer_count_after(40, *input_14())
max(C.values()) // 2 - min(C.values()) // 2


4441317262452

## Day 15

In [275]:
class WrappingHeightMap:
    """Cost grid tiled ``wrap_max`` times along each axis.

    Tile (i, j) is the original grid with every value raised by ``i + j``,
    wrapping from 9 back round to 1.
    """

    def __init__(self, matrix, wrap_max=1):
        self._matrix = np.array(matrix, np.uint32)
        rows, cols = self._matrix.shape
        self._wrap_matrix = np.empty((rows * wrap_max, cols * wrap_max), np.uint32)
        for tile_row in range(wrap_max):
            for tile_col in range(wrap_max):
                shift = tile_row + tile_col
                top, left = tile_row * rows, tile_col * cols
                # (value + shift - 1) % 9 + 1 keeps 1..9 inputs within 1..9.
                self._wrap_matrix[top:top + rows, left:left + cols] = (
                    (self._matrix + shift - 1) % 9 + 1
                )

    def shape(self):
        """Shape of the original, untiled grid."""
        return self._matrix.shape

    def wrap_shape(self):
        """Shape of the tiled grid."""
        return self._wrap_matrix.shape

    def get_neighbors_of(self, x, y):
        """Return ``(x, y, value)`` for the orthogonal neighbours in the tiled grid.

        Order: previous row, next row, previous column, next column.
        """
        rows, cols = self._wrap_matrix.shape
        candidates = (
            (x != 0, x - 1, y),
            (x != rows - 1, x + 1, y),
            (y != 0, x, y - 1),
            (y != cols - 1, x, y + 1),
        )
        return [
            (nx, ny, self._wrap_matrix[nx, ny])
            for inside, nx, ny in candidates
            if inside
        ]


class PriorityQueue:
    """Binary min-heap of ``(prio, val)`` pairs with decrease-key support.

    ``_H`` is the heap array and ``_map`` maps each stored ``val`` to its
    index in ``_H``, so membership tests and ``decrease_prio`` need no
    search. Each ``val`` must be hashable and may be stored at most once.
    """

    def __init__(self):
        self._H = []
        self._map = {}
        self._len = 0

    def __len__(self):
        return self._len

    def _parent(self, n):
        """Index of the parent of heap slot *n*, or None for the root."""
        if n == 0:
            return None
        return (n - 1) // 2

    def _children(self, n):
        """Indices ``(left, right)`` of slot *n*'s children; None where absent."""
        left, right = 2 * n + 1, 2 * n + 2
        return (left if left < self._len else None,
                right if right < self._len else None)

    def _swap(self, a, b):
        self._map[self._H[a][1]], self._map[self._H[b][1]] = b, a
        self._H[a], self._H[b] = self._H[b], self._H[a]

    def _heap_decrease(self, n, r):
        """Store entry *r* at slot *n* and sift it up to restore heap order."""
        self._H[n] = r
        p = self._parent(n)

        while p is not None and self._H[n] < self._H[p]:
            self._swap(n, p)
            n, p = p, self._parent(p)

    def __contains__(self, val):
        return val in self._map

    def insert(self, prio, val):
        """Add *val* with priority *prio*; raise ValueError if already queued."""
        if val in self._map:
            raise ValueError(f"{val!r} is already queued; use decrease_prio")
        self._H.append(None)
        self._map[val] = self._len
        self._heap_decrease(self._len, (prio, val))
        self._len += 1

    def peek(self):
        """Return the minimum ``(prio, val)`` without removing it."""
        return self._H[0]

    def extract(self):
        """Remove and return the minimum ``(prio, val)``."""
        top_prio, top_val = self._H[0]
        self._swap(0, self._len - 1)
        del self._H[-1]
        del self._map[top_val]
        self._len -= 1

        # Sift the former last entry down from the root.
        c = 0
        l, r = self._children(c)
        while l is not None:
            cprio = self._H[c][0]
            lprio = self._H[l][0]
            rprio = self._H[r][0] if r is not None else inf

            if cprio <= lprio and cprio <= rprio:
                break
            nxt = l if lprio < rprio else r
            self._swap(c, nxt)
            c = nxt
            l, r = self._children(c)

        return top_prio, top_val

    def decrease_prio(self, newprio, val):
        """Lower the priority of queued *val* to *newprio*.

        Raises KeyError if *val* is not queued, and ValueError if *newprio*
        exceeds its current priority (that would need a sift-down, which this
        method does not do).
        """
        idx = self._map[val]
        prio, _ = self._H[idx]
        if newprio > prio:
            raise ValueError(
                f"decrease_prio cannot raise priority {prio!r} to {newprio!r}")
        self._heap_decrease(idx, (newprio, val))
        

def dijkstra(f, wrap_max):
    """Cheapest path cost from the top-left to the bottom-right corner.

    *f* is the cost grid, tiled *wrap_max* times per axis via
    WrappingHeightMap. Entering a cell costs its value; the start cell is
    free. Returns ``inf`` if the target is unreachable.
    """
    grid = WrappingHeightMap(f, wrap_max=wrap_max)
    start = (0, 0)
    rows, cols = grid.wrap_shape()
    target = (rows - 1, cols - 1)

    best = {start: 0}  # lowest known total cost per cell
    frontier = PriorityQueue()
    frontier.insert(0, start)

    while frontier:
        cost, (x, y) = frontier.extract()

        for nx, ny, step in grid.get_neighbors_of(x, y):
            candidate = cost + step
            if candidate >= best.get((nx, ny), inf):
                continue
            best[(nx, ny)] = candidate
            if (nx, ny) in frontier:
                frontier.decrease_prio(candidate, (nx, ny))
            else:
                frontier.insert(candidate, (nx, ny))

        if (x, y) == target:
            return best[(x, y)]

    return inf


@inputfunc(15)
def input_15(*, f):
    """Parse each line of digits into a row of integer costs."""
    return [list(map(int, row)) for row in f]


In [276]:
# Part 1: cheapest path across the untiled grid.
dijkstra(input_15(), 1)


619

In [277]:
# Part 2: the same, on the grid tiled 5 times along each axis.
dijkstra(input_15(), 5)


2922