In [2]:
#from https://www.reddit.com/r/adventofcode/comments/rl6p8y/2021_day_21_solutions/hpf0wy8/?context=3
from collections import Counter
from dataclasses import dataclass
from functools import lru_cache
from itertools import product
from typing import Iterable, List, Tuple

Pair = Tuple[int, int]


# Three-roll total of a 3-sided die -> how many of the 27 roll combinations give it.
FREQS = Counter(map(sum, product([1, 2, 3], repeat=3)))


@dataclass(frozen=True)
class State:
    """Immutable, hashable game snapshot: each player's position and score.

    Frozen (and therefore hashable) so it can serve as a memoisation key.
    Whatever iterables are passed in are normalised to tuples; because the
    instance is frozen, the fields are written via object.__setattr__.
    """

    positions: Pair  # current board space of each player
    scores: Pair  # accumulated score of each player

    def __init__(self, positions: Iterable[int], scores: Iterable[int]):
        for field_name, values in (("positions", positions), ("scores", scores)):
            object.__setattr__(self, field_name, tuple(values))


def vadd(a: Pair, b: Pair) -> Pair:
    """Return the component-wise sum of two pairs."""
    first = a[0] + b[0]
    second = a[1] + b[1]
    return first, second


def vmul(x: int, b: Pair) -> Pair:
    """Return pair `b` with both components multiplied by the scalar `x`."""
    left, right = b[0], b[1]
    return (x * left, x * right)


def move(val: int, p: int, state: State) -> State:
    """Advance player `p` by `val` spaces on the 1-10 ring and add the landing space to their score.

    Returns a new State; the input state is left untouched.
    """
    new_positions = list(state.positions)
    new_scores = list(state.scores)
    # shift to 0-based, wrap mod 10, shift back so 10 stays 10 instead of 0
    landing = (new_positions[p] + val - 1) % 10 + 1
    new_positions[p] = landing
    new_scores[p] += landing
    return State(new_positions, new_scores)


@lru_cache(maxsize=None)
def play(p: int, state: State) -> Pair:
    """Count the universes in which each player wins, starting from `state`.

    `p` is the index (0 or 1) of the player about to move. Returns
    (wins for player 0, wins for player 1). Each of the player's three-roll
    totals is weighted by how many roll combinations produce it (FREQS).
    Results are memoised on (p, state).
    """
    scores = state.scores
    # Only the player who just moved can have crossed 21, so checking both is safe.
    if scores[0] >= 21:
        return (1, 0)
    if scores[1] >= 21:
        return (0, 1)

    other = 0 if p else 1
    wins0 = wins1 = 0
    for total, ways in FREQS.items():
        sub0, sub1 = play(other, move(total, p, state))
        wins0 += ways * sub0
        wins1 += ways * sub1
    return (wins0, wins1)


def part1(positions: List[int]) -> int:
    """Play the deterministic-die game to 1000 points.

    Returns the losing player's score multiplied by the number of die rolls.
    On turn t (0-based) the rolls are 3t+1, 3t+2, 3t+3, whose total is 9t+6.
    The puzzle's die wraps at 100, but that does not change positions on a
    10-space board, so the wrap is not modelled.
    """
    limit = 1000
    state = State(positions, [0, 0])
    turns = 0

    while max(state.scores) < limit:
        roll_total = 9 * turns + 6
        state = move(roll_total, turns % 2, state)
        turns += 1

    loser_score = next(score for score in state.scores if score < limit)
    return loser_score * turns * 3  # three rolls per turn


def part2(positions: List[int]) -> int:
    """Return the number of universes won by whichever player wins more often (player 0 starts)."""
    start = State(positions, (0, 0))
    return max(play(0, start))


if __name__ == "__main__":
    # problem = [4, 8]  # test
    problem = [1, 3]  # input

    for label, solve in (("Part 1:", part1), ("Part 2:", part2)):
        print(label, solve(problem))

Part 1: 897798
Part 2: 48868319769358


In [3]:
# Number of the 27 combinations of three 1-3 rolls that produce each three-roll total.
FREQS

Counter({3: 1, 4: 3, 5: 6, 6: 7, 7: 6, 8: 3, 9: 1})

In [None]:
#part 2
#question: In how many universes does each player win?
#no cigar
from itertools import product, permutations
from functools import cache

@cache
def play(game, pos1, pos2):
    """Replay one explicit sequence of Dirac-dice rolls and report the winner.

    Args:
        game: string of roll values, one character per roll ('1'-'3').
        pos1, pos2: starting spaces (1-10) of players one and two.

    Returns:
        1 or 2 for the player who first reaches 21 points, or 0 if the
        sequence runs out before either player does.

    A turn is three rolls. The pawn is advanced roll by roll, which lands on
    the same space as moving by the three-roll total, but the score is only
    increased after the third roll. Player one moves first.
    """
    positions = [pos1, pos2]
    scores = [0, 0]
    for roll_no, die in enumerate(map(int, game)):
        player = (roll_no // 3) % 2  # 0 = player one, 1 = player two; 3 rolls per turn
        # shift to 0-based, wrap on the 10-space ring, shift back (so 10 stays 10)
        positions[player] = (positions[player] + die - 1) % 10 + 1
        if roll_no % 3 == 2:  # third roll of the turn: score the landing space
            scores[player] += positions[player]
            if scores[player] >= 21:
                return player + 1
    return 0

def main(start_p1=4, start_p2=8):
    """Count in how many universes each player wins Dirac Dice (part 2).

    Every roll splits the universe in three, so enumerating roll sequences
    one by one (3**r strings per round) cannot finish. Instead, the
    still-open universes are grouped by game state in a Counter
    (state -> number of universes in that state). Each round advances every
    state by one full turn (three rolls) of the player to move, alternating
    players. Universes in which the mover reaches 21 are added to that
    player's win count and leave the frontier.

    Args:
        start_p1, start_p2: starting spaces (1-10) of players one and two.
            The defaults are the example positions previously hard-coded.

    Returns:
        A printable summary with the number of wins for each player.
    """
    from collections import Counter
    from itertools import product

    # Ways each three-roll total can occur: {3: 1, 4: 3, 5: 6, 6: 7, 7: 6, 8: 3, 9: 1}
    roll_freqs = Counter(sum(rolls) for rolls in product((1, 2, 3), repeat=3))

    wins = [0, 0]
    # ((pos1, pos2), (sc1, sc2)) -> number of universes still playing in that state
    opengames = Counter({((start_p1, start_p2), (0, 0)): 1})
    mover = 0  # index of the player whose turn it is; player one starts
    while opengames:
        newgames = Counter()
        for (positions, scores), universes in opengames.items():
            for total, ways in roll_freqs.items():
                count = universes * ways
                pos = (positions[mover] + total - 1) % 10 + 1
                score = scores[mover] + pos
                if score >= 21:
                    wins[mover] += count
                elif mover == 0:
                    newgames[((pos, positions[1]), (score, scores[1]))] += count
                else:
                    newgames[((positions[0], pos), (scores[0], score))] += count
        opengames = newgames
        mover = 1 - mover
        print(f"Update: p1w == {wins[0]} and p2w == {wins[1]}")

    p1w, p2w = wins
    return f"\n\nFinal result\nWins for player one: {p1w}\nWins for player two: {p2w}."

if __name__ == "__main__":
    summary = main()
    print(summary)

In [None]:
#part 1
#question: after the game ends, multiply the losing player's score by the number of times the die was rolled

class Dirac:
    """Simulate a game of Dirac Dice with the deterministic die, https://adventofcode.com/2021/day/21

    Attributes:
        pos1, pos2: current board space (1-10) of each player.
        sc1, sc2:   accumulated score of each player.
        player:     whose turn it is (1 or 2); player 1 starts.
        die:        value of the next roll. It is never wrapped, so die - 1
                    also equals the number of rolls made so far.
        show:       when true, every turn is printed.
    """

    def __init__(self, pos1, pos2, show=False):
        self.player = 1
        self.sc1, self.sc2 = 0, 0
        self.die = 1
        self.pos1, self.pos2 = pos1, pos2
        self.show = show

    def play(self):
        """Play one turn for the current player, then pass the turn on.

        The three rolls are die, die+1 and die+2, so the pawn moves
        3*die + 3 spaces. Not wrapping the die at 100 leaves positions
        unchanged, because 100 is a multiple of the 10-space board.
        """
        steps = 3 * (self.die + 1)
        self.die += 3
        if self.player == 1:
            self.pos1 = (self.pos1 + steps - 1) % 10 + 1
            self.sc1 += self.pos1
            landed, total, next_player = self.pos1, self.sc1, 2
        else:
            self.pos2 = (self.pos2 + steps - 1) % 10 + 1
            self.sc2 += self.pos2
            landed, total, next_player = self.pos2, self.sc2, 1
        if self.show:
            print(f"Player {self.player} rolls {steps} and moves to space {landed} for a total score of {total}")
        self.player = next_player

def main(start_p1, start_p2, show=False):
    """Play part 1 to completion and return the loser's score times the number of die rolls.

    Prints a start message and a summary; with `show` every turn is printed as well.
    """
    game = Dirac(start_p1, start_p2, show)
    print("starting to play")
    while max(game.sc1, game.sc2) < 1000:
        game.play()
    rolls = game.die - 1  # die starts at 1 and is never wrapped
    loser = min(game.sc1, game.sc2)
    result = loser * rolls
    print(f"The game has ended.\nThe die has been rolled {rolls} times.\nThe losing player had {loser} points.\nMultiplying these gives {result}")
    return result

if __name__ == "__main__":
    # Worked example: player 1 starts on space 4, player 2 on space 8.
    main(start_p1=4, start_p2=8)

In [None]:
# Replay the part-1 example by hand and inspect how many times the die was rolled.
D = Dirac(4, 8)
# The game ends as soon as EITHER player reaches 1000 points.
while D.sc1 < 1000 and D.sc2 < 1000:
    D.play()

# Dirac keeps no round counter (there is no `rnd` attribute). The die value
# starts at 1 and advances by 3 per turn, so die - 1 is the number of rolls.
D.die - 1