# Medicine Testing - Gittins Index

In [10]:
# Number of patients; also the largest value either state coordinate can take.
L = 50

# Discount factor alpha.
alpha = 0.9

# x and y each range over the integers 0..L inclusive. The "_minus" variants
# leave out the boundary value L.
X = [*range(L + 1)]
X_minus = X[:L]
Y = [*range(L + 1)]
Y_minus = Y[:L]

from typing import Tuple

# Value iteration step
def value_iteration_step(selected: Tuple[int, int], V_old):
    """Apply one Bellman backup of the restart-in-``selected`` problem.

    Acting in a state ``(a, b)`` earns reward 1 and moves to ``(a + 1, b)``
    with probability ``(a + 1) / (a + b + 2)``; otherwise it earns nothing and
    moves to ``(a, b + 1)``. Future values are discounted by ``alpha``.

    For every interior state ``(x, y)`` (both coordinates below ``L``) the new
    value is the larger of continuing from ``(x, y)`` and restarting from
    ``selected``. States with ``x == L`` or ``y == L`` are copied unchanged
    from ``V_old``.

    Args:
        selected: Restart state ``(x, y)``. Its two successor states must be
            keys of ``V_old``.
        V_old: Mapping from ``(x, y)`` to the current value estimate, with an
            entry for every pair in ``X`` x ``Y``.

    Returns:
        A new dict with the updated value of every state; ``V_old`` is not
        modified.
    """

    def expected_value(a, b):
        # Expected discounted return of acting once in (a, b), then following V_old.
        total = a + b + 2
        return (
            ((a + 1) / total) * (1 + alpha * V_old[(a + 1, b)]) +
            ((b + 1) / total) * alpha * V_old[(a, b + 1)]
        )

    V_new = {}
    # The guard keeps the restart value from being evaluated when there are
    # no interior states to update.
    if X_minus and Y_minus:
        # The restart value does not depend on (x, y), so compute it once
        # per sweep instead of once per state.
        V_restart = expected_value(selected[0], selected[1])
        for x in X_minus:
            for y in Y_minus:
                V_new[(x, y)] = max(expected_value(x, y), V_restart)

    # Boundary states are carried over as they are.
    for x in X:
        V_new[(x, L)] = V_old[(x, L)]

    for y in Y:
        V_new[(L, y)] = V_old[(L, y)]

    return V_new


def diff(V_old, V_new):
    """Return the largest absolute entry-wise gap between two value tables.

    The comparison runs over the keys of ``V_new``; each of them must also be
    a key of ``V_old``. The result is 0 when ``V_new`` is empty.
    """
    # Seeding with 0 gives the floor used when there is nothing to compare
    # or when every gap is zero.
    gaps = [0]
    gaps.extend(abs(V_new[key] - V_old[key]) for key in V_new)
    return max(gaps)


def value_iteration(selected: Tuple[int, int], max_n: int = 1000, eps: float = 0.001):
    """Run value iteration for the restart-in-``selected`` problem.

    Starts from the all-zero value function on ``X`` x ``Y`` and repeatedly
    applies ``value_iteration_step`` until the largest change in one sweep
    falls below ``eps`` or ``max_n`` sweeps have been performed.

    Args:
        selected: Restart state ``(x, y)`` passed to every sweep.
        max_n: Maximum number of sweeps; must be at least 1.
        eps: Convergence threshold on the largest absolute change in a sweep.

    Returns:
        A tuple ``(V, n, diff_value)``: the value function from the last
        sweep, the zero-based index of that sweep, and the largest absolute
        change it produced.

    Raises:
        ValueError: If ``max_n`` is less than 1, since no sweep would run and
            there would be nothing to return.
    """
    if max_n < 1:
        raise ValueError("max_n must be at least 1")

    # Initial value function: zero everywhere.
    V_old = {(x, y): 0 for x in X for y in Y}

    for n in range(max_n):
        V_new = value_iteration_step(selected, V_old)
        diff_value = diff(V_old, V_new)
        if diff_value < eps:
            break

        V_old = V_new

    return V_new, n, diff_value


In [19]:
# States (x, y) for which the index and the success probability are computed.
selected_states = [
    (0, 1), (4, 6), (3, 7), (8, 12), (7, 13), (12, 18),
    (11, 19), (16, 24), (15, 26), (20, 30), (19, 31),
]

g_index_selected_states = []
success_prob_selected_states = []
for selected_state in selected_states:
    # Solve the restart problem for this state. The sweep index and final
    # change are returned as well but are not used below.
    V_selected, n, diff_value = value_iteration(selected_state)

    # The index is the value at the selected state scaled by (1 - alpha).
    g_index_selected = (1 - alpha) * V_selected[selected_state]
    g_index_selected_states += [g_index_selected]

    # Success probability (x + 1) / (x + y + 2) in the selected state.
    x_sel, y_sel = selected_state
    success_prob = (x_sel + 1) / (x_sel + y_sel + 2)
    success_prob_selected_states += [success_prob]

In [20]:
g_index_selected_states

[0.49914613287272824,
 0.47395540384703405,
 0.3892527519336255,
 0.44241447444918786,
 0.39648757971192344,
 0.42966503050433796,
 0.39817759491281435,
 0.42273734392358947,
 0.38943377922067524,
 0.41830401481252216,
 0.3989274600293537]

In [21]:
success_prob_selected_states

[0.3333333333333333,
 0.4166666666666667,
 0.3333333333333333,
 0.4090909090909091,
 0.36363636363636365,
 0.40625,
 0.375,
 0.40476190476190477,
 0.37209302325581395,
 0.40384615384615385,
 0.38461538461538464]