# Imports

In [1]:
import torch
import scipy
import numpy as np

In [2]:
int(scipy.special.binom(14,6))

3003

# State space
We need to specify the size of the state space.
Suppose the value of a worm is $6$.

We define two kinds of state, two kinds of action, and two kinds of transition matrices : the state in which we need to choose whether or not to throw the remaining dice, and the state in which we need to decide which dice to pick.


A state can be described by:
1. The current sum of values of the dice drawn.
2. The values of dice already picked
3. Nb de dés restants
4. Le fait qu'on ait fait un tirage ou pas, et la valeur du tirage le cas échéant

For state $s_1$:

1. Current value: $49$ possible values + $1$ absorbing state
2. Current dice values picked: $2^6$ possible values
3. Number of remaining dice: $9$ possible values

total : $28224$

For state $s_2$:
1. Current value: $49$ possible values
2. Current dice values picked: $2^6$ possible values
3. Number of remaining dice: $9$ possible values
4. Dice drawn: $\binom{14}{6}$ possible values (sticks and stones)

total : $49\times 2^6 \times 9 \times \binom{14}{6}= 84,756,672$

Possible actions after $s_1$:

1. Throw the dice or not

total: $2$

Possible actions after $s_2$:

1. Pick any of the numbers. If we pick an unavailable value, we lose (negative reward + we get into the absorbing state with $0$ dice remaining)

total: $6$


On ne peut pas matérialiser la matrice de transition de l'état $s_1$ à l'état $s_2$: ça ferait
$(49\times 2^6 \times 9)^2 \times \binom{14}{6}*2 = 4.784344621056 × 10^{12}$

Mais on n'est pas obligés d'avoir $2$ matrices d'états différents : on peut obliger l'agent à prendre la décision sans connaître le résultat du lancer.

On suppose qu'un tour est composé de $6\times 2$ actions, car il y a seulement $6$ valeurs de dés possibles.

Il faut donc calculer pour chacun des $6$ tours, la valeur de chaque état récursivement.

On doit créer un état absorbant qui signifie que le joueur a perdu suite à une action interdite

In [34]:
current_values_size = 50 # 0-48
current_dice_values_picked_size = 2**6 # subsets of 6 dice
remaining_dice = 9 #0-9
draw_result_size = int(scipy.special.binom(14,6))

action_shape_1 = 2

action_shape_2 = 6

c = -3

In [4]:
interval_1 = int(scipy.special.binom(9,1))# Etats où on a de 0 à 8 fois le dé 1
interval_2 = int(scipy.special.binom(10,2))# Etats où on a de 0 à 8 fois le dé 1 et de 0 à 8 fois le dé 2
interval_3 = int(scipy.special.binom(11,3))# Etats où on a de 0 à 8 fois le dé 1 et de 0 à 8 fois le dé 2 et de 0 à 8 fois le dé 3
interval_4 = int(scipy.special.binom(12,4))
interval_5 = int(scipy.special.binom(13,5))
interval_6 = int(scipy.special.binom(14,6))

In [89]:
intervals_0 = [1,1,1,1,1,1,1,1]
intervals_1 = [int(scipy.special.binom(i,0)) for i in range(8,0,-1)]
intervals_2 = [int(scipy.special.binom(i,1)) for i in range(9,1,-1)]
intervals_3 = [int(scipy.special.binom(i,2)) for i in range(10,2,-1)]
intervals_4 = [int(scipy.special.binom(i,3)) for i in range(11,3,-1)]
intervals_5 = [int(scipy.special.binom(i,4)) for i in range(12,4,-1)]
intervals_6 = [int(scipy.special.binom(i,5)) for i in range(13,5,-1)]
c_intervals_0 = np.array([0,1,2,3,4,5,6,7,8])
c_intervals_1 = np.cumsum(intervals_1)
c_intervals_2 = np.cumsum(intervals_2)
c_intervals_3 = np.cumsum(intervals_3)
c_intervals_4 = np.cumsum(intervals_4)
c_intervals_5 = np.cumsum(intervals_5)
c_intervals_6 = np.cumsum(intervals_6)

In [90]:
print(intervals_0)
print(intervals_1)
print(intervals_2)
print(intervals_3)
print(intervals_4)
print(intervals_5)
print(intervals_6)

[1, 1, 1, 1, 1, 1, 1, 1]
[1, 1, 1, 1, 1, 1, 1, 1]
[9, 8, 7, 6, 5, 4, 3, 2]
[45, 36, 28, 21, 15, 10, 6, 3]
[165, 120, 84, 56, 35, 20, 10, 4]
[495, 330, 210, 126, 70, 35, 15, 5]
[1287, 792, 462, 252, 126, 56, 21, 6]


In [91]:
scipy.special.binom(13,6) + scipy.special.binom(12,5) + scipy.special.binom(11,4) + scipy.special.binom(10,3) + scipy.special.binom(9,2) + scipy.special.binom(8,1) + scipy.special.binom(7,0)

3003.0

In [92]:
print(c_intervals_0)
print(c_intervals_1)
print(c_intervals_2)
print(c_intervals_3)
print(c_intervals_4)
print(c_intervals_5)
print(c_intervals_6)

[0 1 2 3 4 5 6 7 8]
[1 2 3 4 5 6 7 8]
[ 9 17 24 30 35 39 42 44]
[ 45  81 109 130 145 155 161 164]
[165 285 369 425 460 480 490 494]
[ 495  825 1035 1161 1231 1266 1281 1286]
[1287 2079 2541 2793 2919 2975 2996 3002]


In [83]:
c_intervals_0[-1] + c_intervals_1[-1] + c_intervals_2[-1] + c_intervals_3[-1] + c_intervals_4[-1] + c_intervals_5[-1]

3002

In [46]:
c_intervals_5

array([ 792, 1254, 1506, 1632, 1688, 1709, 1715, 1716])

In [98]:
c_intervals_6 = [0] + list(c_intervals_6)

In [99]:
c_intervals_6

[0, 1287, 2079, 2541, 2793, 2919, 2975, 2996, 3002]

In [9]:
def index_partition():
    """Via la hockey stick identity, on construit l'ensemble des jetés de dés."""
    dict_index = {}
    for i_0 in range(9):
        for i_1 in range(9-i_0):
            for i_2 in range(9-i_0-i_1):
                for i_3 in range(9-i_0-i_1-i_2):
                    for i_4 in range(9-i_0-i_1-i_2-i_3):
                        for i_5 in range(9-i_0-i_1-i_2-i_3-i_4):
                            i_6 = 8-i_0-i_1-i_2-i_3-i_4-i_5
                            dict_index[((i_0,i_1,i_2,i_3,i_4,i_5,i_6))] = len(dict_index)
    return dict_index

In [13]:
dict_index = index_partition()

In [19]:
dict_number = {v: k for k, v in dict_index.items()}

In [51]:
list_numbers = list(dict_number.values())

In [64]:
for ind, tup in enumerate(list_numbers):
    if tup[0] == 1:
        print(ind, tup)
        break

1287 (1, 0, 0, 0, 0, 0, 7)


In [29]:
reward_vector = torch.zeros(50)
reward_vector[21:25] = 1
reward_vector[25:29] = 2
reward_vector[29:33] = 3
reward_vector[33:37] = 4

In [35]:
state_shape = (current_values_size, current_dice_values_picked_size, remaining_dice, draw_result_size)

Il y a deux cas : soit on choisit un nombre inacceptable et on va dans l'état $50$, soit on choisit un nombre acceptable, et en fonction du nombre choisi, on doit calculer la nouvelle valeur totale que l'on a, le nouveau nombre de dés restants, et mettre à jour les dés déjà obtenus (en passant simplement de l'état $n$ à $n+2^{c-1}$ où $c$ est la valeur du dé choisi)

In [123]:
list_numbers

[(0, 0, 0, 0, 0, 0, 8),
 (0, 0, 0, 0, 0, 1, 7),
 (0, 0, 0, 0, 0, 2, 6),
 (0, 0, 0, 0, 0, 3, 5),
 (0, 0, 0, 0, 0, 4, 4),
 (0, 0, 0, 0, 0, 5, 3),
 (0, 0, 0, 0, 0, 6, 2),
 (0, 0, 0, 0, 0, 7, 1),
 (0, 0, 0, 0, 0, 8, 0),
 (0, 0, 0, 0, 1, 0, 7),
 (0, 0, 0, 0, 1, 1, 6),
 (0, 0, 0, 0, 1, 2, 5),
 (0, 0, 0, 0, 1, 3, 4),
 (0, 0, 0, 0, 1, 4, 3),
 (0, 0, 0, 0, 1, 5, 2),
 (0, 0, 0, 0, 1, 6, 1),
 (0, 0, 0, 0, 1, 7, 0),
 (0, 0, 0, 0, 2, 0, 6),
 (0, 0, 0, 0, 2, 1, 5),
 (0, 0, 0, 0, 2, 2, 4),
 (0, 0, 0, 0, 2, 3, 3),
 (0, 0, 0, 0, 2, 4, 2),
 (0, 0, 0, 0, 2, 5, 1),
 (0, 0, 0, 0, 2, 6, 0),
 (0, 0, 0, 0, 3, 0, 5),
 (0, 0, 0, 0, 3, 1, 4),
 (0, 0, 0, 0, 3, 2, 3),
 (0, 0, 0, 0, 3, 3, 2),
 (0, 0, 0, 0, 3, 4, 1),
 (0, 0, 0, 0, 3, 5, 0),
 (0, 0, 0, 0, 4, 0, 4),
 (0, 0, 0, 0, 4, 1, 3),
 (0, 0, 0, 0, 4, 2, 2),
 (0, 0, 0, 0, 4, 3, 1),
 (0, 0, 0, 0, 4, 4, 0),
 (0, 0, 0, 0, 5, 0, 3),
 (0, 0, 0, 0, 5, 1, 2),
 (0, 0, 0, 0, 5, 2, 1),
 (0, 0, 0, 0, 5, 3, 0),
 (0, 0, 0, 0, 6, 0, 2),
 (0, 0, 0, 0, 6, 1, 1),
 (0, 0, 0, 0, 6,

In [132]:
def optimal_choice_1(state, action):
    # state is a vector (value, dice_values_picked, remaining_dice)
    # action is a value 0-1, representing whether to continue or not.
    value, dice_values_picked, dice_owned, draw_result = state
    dice_left = 8 - dice_owned
    print(dice_left)
    if action == 0 or dice_left == 0:
        if dice_values_picked[-1] == 0:
            # We choose to stop but have no pickominos
            return c
        else:
            # We have found at least one pickomino
            return reward_vector[value]
    elif action == 1:
        # We choose to continue and have remaining dice
        list_draws = list_numbers[c_intervals_6[dice_owned]:c_intervals_6[dice_owned+1]]
        list_values = []
        for draw in list_draws:
            new_state = value, dice_values_picked, dice_owned, draw
            for action in range(1,7):
                list_values.append(optimal_choice_2(new_state, action))
        return max(list_values)
        

def optimal_choice_2(state, action):
    # state is a vector (value, dice_values_picked, remaining_dice, draw_result)
    # here dice_values_picked is a vector of 6 values 0-1, representing whether the dice has been picked or not.
    # action is a number 0-6 representing the choice of dice to keep, or whether to not pick any dice.
    value, dice_values_picked, dice_owned, draw_result = state
    draw_values = draw_result
    if draw_values[action] == 0:
        return c
    else:
        if dice_values_picked[action-1] == 1:
            return c
        else:
            dice_values_picked[action-1] = 1
            new_state = value + min(action,5)*draw_values[action], dice_values_picked, dice_owned + draw_values[action], draw_result
            choices = [optimal_choice_1(new_state, 0), optimal_choice_1(new_state, 1)]
            return max(choices)


In [107]:
c_intervals_6

[0, 1287, 2079, 2541, 2793, 2919, 2975, 2996, 3002]

In [147]:
c = -3

In [148]:
value = 0
dice_values_picked = [0,0,0,0,0,0]
dice_owned = 0
draw_result = None
state = value, dice_values_picked, dice_owned, draw_result

optimal_choice_1(state, 1)

8
0
0
7
7
6
6
5
5
4
4
3
3


tensor(0.)

# Matrice de l'état
On code une matrice représentant l'état du système

In [24]:
terminal_matrix = torch.zeros(state_shape) # shape (total, picked_dice, remaining_dice, draw_result)

In [30]:
for i in range(current_values_size):
    terminal_matrix[i, :, :, :] = reward_vector[i] # Si on est à la fin, on compte les points pour attribuer les récompenses
    terminal_matrix[:,:current_dice_values_picked_size//2,:,:] = c # Si on n'a pas trouvé le pickomino, on a perdu et on a donc une récompense de -3

In [32]:
terminal_matrix[30]

tensor([[3., 3., 3.,  ..., 3., 3., 3.],
        [3., 3., 3.,  ..., 3., 3., 3.],
        [3., 3., 3.,  ..., 3., 3., 3.],
        ...,
        [3., 3., 3.,  ..., 3., 3., 3.],
        [3., 3., 3.,  ..., 3., 3., 3.],
        [3., 3., 3.,  ..., 3., 3., 3.]])

In [25]:
terminal_matrix.numel()

86486400