# Day 21: Dirac Dice

In [1]:
from typing import Iterable, Callable
from itertools import repeat, chain
from more_itertools import take
import re
from functools import cache

## Puzzle input data

In [2]:
def parse_line(line: str) -> int:
    """Extract the starting position from a line like 'Player 1 starting position: 4'.

    The first integer on the line is the player number and the second is the
    starting position, which is what gets returned.
    """
    # \d+ rather than [1-9]+: the latter cannot match '0', so a starting
    # position of 10 would have been read back as 1.
    return int(re.findall(r'\d+', line)[1])

# Test data: the example the assertions in the Part 1 / Part 2 cells check against.
tdata = tuple(parse_line(line) for line in (
    'Player 1 starting position: 4',
    'Player 2 starting position: 8',
))

# Input data: the actual puzzle input.
data = tuple(parse_line(line) for line in (
    'Player 1 starting position: 6',
    'Player 2 starting position: 7',
))
data

(6, 7)

## Puzzle answers
### Part 1

In [3]:
# One starting position per player. `tuple[int]` would mean a 1-tuple, so the
# variable-length form `tuple[int, ...]` is the correct annotation.
Input = tuple[int, ...]


def determ_die(nsides: int = 100) -> Iterable[int]:
    """Deterministic die: an endless iterator yielding 1, 2, ..., nsides, 1, 2, ..."""
    one_cycle = range(1, nsides + 1)
    endless_cycles = repeat(one_cycle)
    return chain.from_iterable(endless_cycles)


def update_pos(pos: int, nsteps: int) -> int:
    """Move `nsteps` spaces forward from `pos` on the circular board.

    Positions run 1..10 and wrap from 10 back to 1.
    """
    zero_based = pos - 1 + nsteps  # shift to 0..9 so the modulo wraps cleanly
    return 1 + zero_based % 10


def dirac_game(starts: tuple[int, ...], die: Callable = determ_die,
               win_score: int = 1000) -> Iterable[tuple[tuple[int, ...], tuple[int, ...]]]:
    """Play Dirac Dice with a deterministic die, yielding the state after each turn.

    Players move in index order. Each turn sums the next three rolls from
    `die()`, advances that player with `update_pos` and adds the new
    position to their score.

    Args:
        starts: Starting board position of each player.
        die: Zero-argument factory returning an iterator of die rolls.
        win_score: The game ends as soon as a player's score reaches this value.

    Yields:
        `(scores, positions)` tuples, one per turn, taken after that turn is
        applied. The last tuple is the one in which a player reached
        `win_score`.

    Raises:
        ValueError: On first iteration, if `starts` is empty. Without this
            check the game loop would spin forever without yielding.

    Note: the Part 2 cell redefines `dirac_game` and `solution`.
    """
    if not starts:
        raise ValueError('dirac_game needs at least one player')
    nplayers = len(starts)
    positions = list(starts)
    scores = [0] * nplayers
    rolls = die()
    while True:
        for i in range(nplayers):
            # Each turn consumes three consecutive rolls of the shared die.
            rollsum = sum(take(3, rolls))
            positions[i] = update_pos(positions[i], rollsum)
            scores[i] += positions[i]
            yield tuple(scores), tuple(positions)
            if scores[i] >= win_score:
                return


def solution(data: Input) -> int:
    """Part 1 answer: the losing player's final score times the total number of die rolls."""
    turn_states = list(dirac_game(data))
    final_scores, _final_positions = turn_states[-1]
    total_rolls = 3 * len(turn_states)  # every turn rolls the die three times
    return total_rolls * min(final_scores)


# Sanity checks on the Part 1 helpers and on the example data (tdata).
assert take(7, determ_die(5)) == [1, 2, 3, 4, 5, 1, 2]
assert [update_pos(1, steps) for steps in (8, 9, 10, 11)] == [9, 10, 1, 2]
assert sum(1 for _turn in dirac_game(tdata)) == 331
assert solution(tdata) == 739_785

In [4]:
# Part 1 answer for the puzzle input.
n = solution(data=data)
print(f'Multiplying the score of the losing player by the number of times the die was rolled: {n!s}')

Multiplying the score of the losing player by the number of times the die was rolled: 921585


### Part 2

In [5]:
def quantum_die(nsides: int = 3) -> Iterable[tuple[int]]:
    """Quantum die: an endless iterator whose every item is the tuple of all faces (1..nsides)."""
    outcomes = tuple(range(1, nsides + 1))
    return repeat(outcomes)


def next_player(current: int) -> int:
    """Index (0 or 1) of the player who moves after `current` in a two-player game."""
    return 1 - current % 2


@cache
def dirac_game(init_ps: tuple[int, ...], init_ss: tuple[int, ...], player: int, rolls: tuple[int, ...],
               die: Callable = quantum_die, win_score: int = 21) -> tuple[int, int]:
    """Count the universes in which each player wins a game of quantum Dirac Dice.

    A turn is complete once three rolls have been collected in `rolls`. Until
    then, each call branches once per face of the quantum die and adds that
    face to `rolls`.

    Args:
        init_ps: Board positions of both players before the pending turn.
        init_ss: Scores of both players before the pending turn.
        player: Index (0 or 1) of the player whose turn is in progress.
        rolls: Rolls collected so far in the current turn (0-3 values). They
            are kept sorted so that permutations of the same rolls share one
            cache entry.
        die: Zero-argument factory; `next(die())` must give the tuple of all
            faces of the die.
        win_score: The game ends once either score reaches this value.

    Returns:
        `(p1wins, p2wins)`: the number of universes won by each player.
    """
    if len(rolls) == 3:
        ps = list(init_ps)
        ss = list(init_ss)
        ps[player] = update_pos(ps[player], sum(rolls))
        ss[player] += ps[player]
        if max(ss) >= win_score:
            return (1, 0) if ss[0] > ss[1] else (0, 1)
        init_ps = tuple(ps)
        init_ss = tuple(ss)
        player = next_player(player)
        rolls = tuple()
    faces = next(die())
    # Pass die and win_score down explicitly. Previously the recursive call
    # fell back to the defaults and silently ignored a caller-supplied die.
    p1wins, p2wins = zip(*(
        dirac_game(init_ps, init_ss, player, tuple(sorted(rolls + (roll,))), die, win_score)
        for roll in faces
    ))
    return (sum(p1wins), sum(p2wins))


def solution(data: Input) -> int:
    """Part 2 answer: the number of universes won by the player who wins in more universes."""
    p1_universes, p2_universes = dirac_game(init_ps=data, init_ss=(0, 0), player=0, rolls=())
    return max(p1_universes, p2_universes)


# Sanity checks on the Part 2 helpers and on the example data (tdata).
assert take(2, quantum_die(3)) == [(1, 2, 3)] * 2
assert [next_player(p) for p in (0, 1)] == [1, 0]
assert solution(tdata) == 444_356_092_776_315

In [6]:
# Part 2 answer for the puzzle input.
n = solution(data=data)
print(f'The player that wins in more universes wins in {n!s} universes.')

The player that wins in more universes wins in 911090395997650 universes.
