In [1]:
SAMPLE_TEXT = """
Player 1 starting position: 4
Player 2 starting position: 8
"""

In [2]:
def tokenize_line(line):
    return int(line.split(" ")[-1])

def parse_text(raw_text):
    return [tokenize_line(l) for l in raw_text.split("\n") if l]

def read_input():
    with open("input.txt", "rt") as f:
        return f.read()

def split_input(lines):
    return lines[0], lines[1:]

In [115]:
from functools import cache
from dataclasses import dataclass
from collections import defaultdict

In [99]:
class DeterministicDice:
    def __init__(self):
        self.value = 0
        self.roll_count = 0

    def roll(self):
        self.value += 1
        self.roll_count += 1
        return self.value

    def roll3(self):
        return sum([self.roll(), self.roll(), self.roll()])

class Player:
    def __init__(self, initial_position):
        self.real_position = initial_position - 1
        self.score = 0

    def display_position(self):
        return self.real_position % 10 + 1

    def advance(self, spaces):
        self.real_position += spaces
        ending_position = self.display_position()
        self.score += ending_position


In [91]:
def play_game(loc1, loc2):
    p1 = Player(loc1)
    p2 = Player(loc2)
    dice = DeterministicDice()
    while True:
        p1.advance(dice.roll3())
        if p1.score >= 1000:
            break
        p2.advance(dice.roll3())
        if p2.score >= 1000:
            break
    if p1.score > p2.score:
        return p1, p2, dice
    else:
        return p2, p1, dice

In [63]:
loc1, loc2 = parse_text(SAMPLE_TEXT)
p1 = Player(loc1)
p2 = Player(loc2)
dice = DeterministicDice()
while True:
    p1.advance(dice.roll3())
    if p1.score >= 1000:
        break
    p2.advance(dice.roll3())
    if p2.score >= 1000:
        break

print(p1.score, p1.display_position(), p2.score, p2.display_position(), dice.roll_count)

1000 10 745 3 993


In [22]:
winner, loser, dice = play_game(*parse_text(SAMPLE_TEXT))
loser.score * dice.roll_count

739785

In [24]:
winner, loser, dice = play_game(*parse_text(read_input()))
loser.score * dice.roll_count

671580

In [111]:
DIRAC_ROLL_3 = [
            (3, 1),
            (4, 3),
            (5, 6),
            (6, 7),
            (7, 6),
            (8, 3),
            (9, 1)
        ]

# Inspired by https://github.com/SwampThingTom/AoC2021/blob/main/Python/21-DiracDice/DiracDice.py
@cache
def count_quantum_wins(p1_location, p2_location, p1_score, p2_score, player):
    if p1_score >= 21:
        return 1, 0
    if p2_score >= 21:
        return 0, 1

    total_p1_wins = 0
    total_p2_wins = 0
    next_player = 'p2' if player == 'p1' else 'p1'

    for distance, count in DIRAC_ROLL_3:
        next_p1_location = p1_location
        next_p1_score = p1_score
        next_p2_location = p2_location
        next_p2_score = p2_score

        if player == 'p1':
            next_p1_location += distance
            if next_p1_location > 10:
                next_p1_location -= 10
            next_p1_score += next_p1_location
        else:
            next_p2_location += distance
            if next_p2_location > 10:
                next_p2_location -= 10
            next_p2_score += next_p2_location

        p1_wins, p2_wins = count_quantum_wins(next_p1_location, next_p2_location, next_p1_score, next_p2_score, next_player)
        total_p1_wins += p1_wins * count
        total_p2_wins += p2_wins * count

    return total_p1_wins, total_p2_wins

In [112]:
p1_loc, p2_loc = parse_text(SAMPLE_TEXT)
count_quantum_wins(p1_loc, p2_loc, 0, 0, 'p1')

(444356092776315, 341960390180808)

In [113]:
p1_loc, p2_loc = parse_text(read_input())
count_quantum_wins(p1_loc, p2_loc, 0, 0, 'p1')

(912857726749764, 598173233581909)

In [158]:
# This is an iterative approach that doesn't quite work. :sadpanda:
# The idea is to keep track of a count of unique game states and keep iterating through the game
# until there are no more game states.

@dataclass(frozen=True)
class GameState:
    p1_location: int
    p2_location: int
    p1_score: int
    p2_score: int


def play_quantum_game(p1_location, p2_location):

    wins = {
        '1': 0,
        '2': 0,
    }

    game_states = defaultdict(int)
    game_states[GameState(p1_location, p2_location, 0, 0)] = 1

    def expand(game_state: GameState, player):
        result = []
        for distance, count in DIRAC_ROLL_3:
            if player == '1':
                new_location = game_state.p1_location + distance
                if new_location > 10:
                    new_location -= 10
                new_score = game_state.p1_score + new_location
                result.append((GameState(new_location, game_state.p2_location, new_score, game_state.p2_score), count))
            else:
                new_location = game_state.p2_location + distance
                if new_location > 10:
                    new_location -= 10
                new_score = game_state.p2_score + new_location
                result.append((GameState(game_state.p1_location, new_location, game_state.p1_score, new_score), count))
        return result

    player = '1'
    it = 0
    while game_states:
        it += 1
        if it == 100:
            print(wins)
            print(len(game_states))
            it = 0
        for game_state in list(game_states.keys()):
            count = game_states[game_state]
            if count == 0:
                continue
            game_states[game_state] -= count
            del game_states[game_state]
            for new_state, new_count in expand(game_state, player):
                if new_state.p1_score >= 21:
                    wins['1'] += count * new_count
                elif new_state.p2_score >= 21:
                    wins['2'] += count * new_count
                else:
                    game_states[new_state] += count * new_count
        if player == '1':
            player = '2'
        else:
            player = '1'

    return wins

In [159]:
play_quantum_game(*parse_text(SAMPLE_TEXT))

{'1': 339084638879874, '2': 197600034241633}