In [1]:
from typing import Tuple, Mapping, Dict, Sequence
from rl.markov_process import NonTerminal
from rl.markov_decision_process import FiniteMarkovDecisionProcess
from rl.dynamic_programming import value_iteration_result
from rl.distribution import Categorical
from scipy.stats import multinomial
import numpy as np
import itertools

# A state is a triple (table, c, s):
#   table - per-face dice counts; entry i is the number of dice on the table showing face i+1
#   c     - number of face-1 dice taken so far (the game caps it at C)
#   s     - running sum of the face values of the dice taken so far
DiceGameState = Tuple[Sequence[int], int, int]
# An action is the set of dice to take from the table, in the same per-face-count form as a table
DiceGameAction = Sequence[int]

# state -> action -> distribution over (next state, reward) pairs
DiceGameTransitionsMap = Mapping[DiceGameState, Mapping[
    DiceGameAction,
    Categorical[Tuple[DiceGameState, float]]
]]

In [2]:
class DiceGame(FiniteMarkovDecisionProcess[DiceGameState, DiceGameAction]):
    '''
    Finite MDP for a dice game played with N dice that each have K faces.

    This implementation of DiceGame was heavily based on the implementation of the
    CareerOptimization from the 2021 midterm.

    A state is a triple (table, c, s):
      * table - length-K tuple; table[i] is the number of dice on the table showing face i+1
      * c     - number of face-1 dice taken so far, capped at C
      * s     - sum of the face values of all dice taken so far

    An action is a non-empty selection of dice from the table, written in the same
    per-face-count form as a table. The dice that are not taken are re-rolled, each
    face being equally likely. Taking the last die moves the game to TERMINAL and
    pays s if c has reached C, and 0 otherwise. Every other transition pays 0.
    '''
    def __init__(
        self,
        N: int,
        K: int,
        C: int
    ):
        '''
        N: number of dice
        K: number of faces per die
        C: number of face-1 dice that must have been taken for the final score to count
        '''
        self.N: int = N
        self.K: int = K
        self.C: int = C
        # Must be set before get_transitions() runs, which reads it
        self.nonTerminalTables: np.ndarray = self.get_nonTerminalTables(N, K)
        self.TERMINAL: DiceGameState = ((0,)*K, 0, 0)

        super().__init__(self.get_transitions())

    def get_nonTerminalTables(self, N, K):
        ''' Returns a numpy array whose rows are the non-terminal table configurations '''
        # All length-K lists whose elements are in (0,...,N)
        Z_N_K = np.array(list(np.ndindex((N+1,)*K)))
        nDice = np.sum(Z_N_K, axis=1)
        # Keep the configurations holding between 1 and N dice
        return Z_N_K[(nDice <= N) & (nDice > 0)]

    def get_nDiceTables(self, n):
        ''' Returns a numpy array whose rows are the (non-terminal) table configurations with exactly n dice'''
        return self.nonTerminalTables[np.sum(self.nonTerminalTables, axis=1) == n]

    def _get_rerollOutcomes(self, n):
        ''' Returns (table, probability) pairs for every way n freshly rolled dice can land '''
        rv = multinomial(n, (1/self.K,)*self.K)
        return [
            (tuple(table.tolist()), float(rv.pmf(table)))
            for table in self.get_nDiceTables(n)
        ]

    def get_transitions(self) -> DiceGameTransitionsMap:
        '''
        Builds the full state -> action -> distribution over (next state, reward) mapping.

        States are enumerated as every non-terminal table with n dice, every c in
        0..C and every s in 0..K*(N-n).
        '''
        d: Dict[DiceGameState, Mapping[DiceGameAction, Categorical[Tuple[DiceGameState, float]]]] = {}

        # The re-roll distribution depends only on how many dice are left, so it is
        # built once per dice count rather than once per (state, action) pair.
        rerolls: Dict[int, Sequence[Tuple[Tuple[int, ...], float]]] = {
            n_left: self._get_rerollOutcomes(n_left) for n_left in range(1, self.N)
        }

        for initTable in self.nonTerminalTables:
            n = int(initTable.sum())
            table_key = tuple(initTable.tolist())

            # An action is a non-empty table configuration whose values don't exceed the current values.
            # This depends on the table only, so it is computed outside the c and s loops.
            actionList = self.nonTerminalTables[
                np.all(self.nonTerminalTables <= initTable, axis=1)
            ]
            # For each action: (key, face-1 dice taken, points gained, dice left on the table)
            action_effects = [
                (
                    tuple(action),
                    action[0],
                    sum((i+1)*ai for i, ai in enumerate(action)),
                    n - sum(action)
                )
                for action in actionList.tolist()
            ]

            for c in range(self.C+1):
                for s in range(self.K*(self.N - n)+1):
                    d1: Dict[DiceGameAction, Categorical[Tuple[DiceGameState, float]]] = {}

                    for action, ones_taken, points, n_new in action_effects:
                        c_new = min(c + ones_taken, self.C)
                        s_new = s + points

                        if n_new == 0:
                            # Last die taken: the score only counts once c has reached C
                            reward = s_new if c_new == self.C else 0
                            d1[action] = Categorical({(self.TERMINAL, reward): 1.0})
                        else:
                            # Remaining dice are re-rolled; no reward until the game ends
                            d1[action] = Categorical({
                                ((table, c_new, s_new), 0): p
                                for table, p in rerolls[n_new]
                            })
                    d[(table_key, c, s)] = d1
        return d

In [6]:
# Game instance: 6 dice with 4 faces each, and C=1
dg = DiceGame(N=6,K=4,C=1)

# Solve the MDP without discounting (gamma=1); `values` is looked up by NonTerminal state
# and `policy` by raw state in the cells that follow
values, policy = value_iteration_result(dg,gamma=1)

In [4]:
# Our expected score is the sum of p(S)*v(S) where p(S) is the probability of S being 
# the initial state, and v(S) is the optimal value function evaluated on S
# In our case, the initial tables are distributed as a multinomial(n=N,p=(1/K,)*K), and 
# the initial values of s and c are both zero
# N and K are read from the game instance so this cell runs on a fresh kernel
# (no notebook-level N or K is ever defined)
initialTables = dg.get_nDiceTables(dg.N)
initialProbabilities = [multinomial.pmf(table, n=dg.N, p=(1/dg.K,)*dg.K) for table in initialTables]
initialValues = [values[NonTerminal((tuple(table),0,0))] for table in initialTables]
expectedScore = sum(p*v for p,v in zip(initialProbabilities, initialValues))
print(expectedScore)

18.39039025377679


In [5]:
# A starting roll of {1,2,2,3,3,4} is encoded in our state space as ((1,2,2,1), c=0, s=0):
# one 1, two 2s, two 3s and one 4 on the table, with no dice taken yet.
# The action that comes back uses the same per-face counts, e.g. (0, 0, 0, 1) means "take one die showing 4".
policy.action_for[((1,2,2,1),0,0)]

(0, 0, 0, 1)