In [5]:
import re
from collections import Counter
from functools import cache
from itertools import product

from aoc import submit

DAY = 21  # puzzle day number, passed to the @submit(day=DAY) decorators

In [6]:
def parse_input(raw):
    """Extract the players' starting positions from the puzzle text.

    Every integer that follows ": " in ``raw`` is returned, in order of
    appearance, minus 1. Board spaces are 1-based in the text; the solvers
    work with 0-based spaces so movement wraps with a plain ``% 10`` (and add
    the 1 back when scoring).

    Returns a list of ints, one per match.
    """
    # Raw string: "\d" in a normal literal is an invalid escape sequence
    # (DeprecationWarning, and a SyntaxWarning since Python 3.12).
    return [int(p) - 1 for p in re.findall(r": (\d+)", raw)]


@submit(day=DAY)
def part_one(raw):
    """Play the deterministic-die game and return the Part 1 answer.

    The deterministic die rolls 1, 2, ..., 100, 1, 2, ... and each turn a
    player rolls it three times. Rather than tracking individual rolls we
    track ``d``, the middle roll of the current turn modulo 100: the three
    rolls always sum to ``3 * d`` modulo 10 (including the wrap-around turns
    such as 99+100+1 and 100+1+2), which is all the board position needs.

    Returns the losing player's score multiplied by the total number of die
    rolls made when the first player reaches 1000 points.
    """
    pos = parse_input(raw)
    score = [0, 0]
    d = 2  # middle roll of the first turn (rolls 1, 2, 3)
    turns = 0
    # Every turn adds at least 1 point, so the game always terminates. Do not
    # cap the number of turns: some starting positions need well over 420
    # turns (e.g. players starting on spaces 1 and 3 take 445).
    while True:
        i = turns % 2
        pos[i] = (pos[i] + d * 3) % 10
        score[i] += pos[i] + 1
        d = (d + 3) % 100
        turns += 1
        # Only score[i] changed this turn, so it is the only one to check.
        if score[i] >= 1000:
            return min(score) * (turns * 3)

part_one:
✅ example: 739785         (0.19 ms)
✅ input:   506466         (0.31 ms)


In [11]:
# How many of the 27 ordered outcomes of three 3-sided rolls give each total.
ROLL_FREQUENCY = Counter(a + b + c for a, b, c in product((1, 2, 3), repeat=3))

@cache
def simulate(p1, p2, s1=0, s2=0, p1_turn=True):
    """Count the universes each player wins from one Dirac Dice game state.

    p1, p2   -- 0-based board positions of players 1 and 2
    s1, s2   -- current scores of players 1 and 2
    p1_turn  -- truthy when player 1 moves next

    The mover's three rolls are grouped by total via ROLL_FREQUENCY, each
    total weighted by how many roll sequences produce it. Reaching 21 or
    more ends that universe as a win for the mover; otherwise the other
    player moves next. Results are memoised with functools.cache.

    Returns a tuple (player-1 wins, player-2 wins).
    """
    mover = 0 if p1_turn else 1
    wins = [0, 0]
    for total, ways in ROLL_FREQUENCY.items():
        # state layout: [pos1, pos2, score1, score2]
        state = [p1, p2, s1, s2]
        state[mover] = (state[mover] + total) % 10
        state[mover + 2] += state[mover] + 1
        if state[mover + 2] >= 21:
            wins[mover] += ways
            continue
        sub_p1, sub_p2 = simulate(*state, not p1_turn)
        wins[0] += ways * sub_p1
        wins[1] += ways * sub_p2
    return tuple(wins)

@submit(day=DAY)
def part_two(raw):
    """Solve Part 2: the number of universes won by the more successful player.

    The parsed starting positions are passed straight to ``simulate``, which
    returns the win counts for both players; the larger count is the answer.
    """
    starting_positions = parse_input(raw)
    # Not needed for correctness (results depend only on the arguments);
    # empties the memo so every run starts cold.
    simulate.cache_clear()
    return max(simulate(*starting_positions))

part_two:
✅ example: 444356092776315 (83.76 ms)
✅ input:   632979211251440 (75.19 ms)
