In [2]:
import copy 
import matplotlib.pyplot as plt
import mdptoolbox as mdpt
import numpy as np

In [3]:
class State:
    """Mutable wrapper around a 7-entry integer state vector.

    Layout, as used by the methods below:
      * entries 0-2 -- compared against ``terminal_length`` in ``isTerminal``
        (one value per fork);
      * entries 3-5 -- a second group of three per-fork counters;
      * entry 6 (last) -- the index returned by ``getHonestFork``.

    Parameters
    ----------
    state : sequence of int
        Initial contents; copied into a numpy array.
    length : int, optional
        Value that, once present among entries 0-2, marks the state as
        terminal. Defaults to 2.

    Notes
    -----
    Equality compares only the vector contents, not ``terminal_length``.
    Because ``__eq__`` is defined without ``__hash__``, instances are
    unhashable and cannot be used as dict keys or set members.
    """

    def __init__(self, state, length=2):
        self.state = np.array(state)
        self.terminal_length = length

    def __str__(self):
        """Render as ``'abc,def,g'``: entries 0-2, entries 3-5, entry 6."""
        head = ''.join(str(self.state[i]) for i in range(3))
        middle = ''.join(str(self.state[i]) for i in range(3, 6))
        return ','.join((head, middle, str(self.state[6])))

    def __setitem__(self, ind, val):
        self.state[ind] = val

    def __getitem__(self, ind):
        return self.state[ind]

    def isTerminal(self):
        """Return True when any of entries 0-2 equals ``terminal_length``."""
        return self.terminal_length in self.state[:3]

    def getHonestFork(self):
        """Return the last entry of the vector (the current honest-fork index)."""
        return self.state[-1]

    def __eq__(self, other):
        """Element-wise equality of the underlying vectors.

        Returns ``NotImplemented`` for non-``State`` operands so that
        ``state == "text"`` or ``state == None`` evaluates to False instead
        of raising ``AttributeError``. ``np.array_equal`` also yields False
        (rather than an error) when the two vectors differ in shape.
        """
        if not isinstance(other, State):
            return NotImplemented
        return bool(np.array_equal(self.state, other.state))

In [6]:
# returns every state reachable from cur_state in one mining step
def nextStates(cur_state):
    """List all one-step successors of ``cur_state``.

    The input is never modified; every successor is built on a deep copy.

    The result is ordered as follows:
      1. three states in which fork ``i`` (i = 0, 1, 2) grows by one block
         mined by its own miner; only for ``i == 0`` is entry 3 incremented
         as well;
      2. for every fork ``i`` whose current length is non-zero, two states
         in which a different miner ``j`` extends fork ``i``; entry ``3 + i``
         is incremented only when ``j == 0``.

    After each increment the honest-fork pointer (last entry) is moved to
    ``argmax`` of the three fork lengths whenever some fork is strictly
    longer than the fork it currently points to.
    """

    def grow(fork, counter_slot):
        # Copy first so the caller's state stays untouched.
        child = copy.deepcopy(cur_state)
        child[fork] += 1
        if counter_slot is not None:
            child[counter_slot] += 1

        # Re-point the honest fork only if it has been strictly overtaken.
        lengths = child[:3]
        if any(lengths > child[child.getHonestFork()]):
            child[-1] = np.argmax(lengths)
        return child

    # Each miner extends its own fork.
    own_fork_moves = [
        grow(fork, 3 if fork == 0 else None)
        for fork in range(3)
    ]

    # A miner extends somebody else's (already started) fork.
    cross_fork_moves = [
        grow(fork, 3 + fork if miner == 0 else None)
        for fork in range(3)
        if cur_state[fork] != 0
        for miner in range(3)
        if miner != fork
    ]

    return own_fork_moves + cross_fork_moves

In [5]:
def checkIfInTotalStates(total_states, state):
    """Return True if ``state`` compares equal (``==``) to any element of ``total_states``.

    This is a linear scan; it stops at the first match.
    """
    return any(state == known for known in total_states)

In [7]:
# algorithm to enumerate possible states.
#
# Depth-first walk from the all-zero start state. `total_states` collects
# every distinct state encountered (terminal ones included);
# `states_to_process` is the stack of non-terminal states still to expand.
init_state = State([0,0,0,0,0,0,0], length=2)

total_states = [init_state]
states_to_process = [init_state]
while states_to_process:
    elem = states_to_process.pop()
    next_states = nextStates(elem)
    for s in next_states:
        # A state that is already known has been queued at the moment it was
        # discovered, so it is either still waiting on the stack or has been
        # expanded already. Queueing it again would only repeat an expansion
        # that cannot discover anything new.
        if checkIfInTotalStates(total_states, s):
            continue
        total_states.append(s)
        # Terminal states are recorded but never expanded.
        if not s.isTerminal():
            states_to_process.append(s)

In [15]:
# Show every enumerated state that is not terminal, in discovery order.
for s in total_states:
    if s.isTerminal():
        continue
    print(s)

000,000,0
100,100,0
010,000,1
001,000,2
101,100,2
011,000,2
111,100,2
110,100,1
011,000,1
111,100,1
110,100,0
101,100,0
111,100,0


In [9]:
# Number of distinct states found by the enumeration (terminal and non-terminal).
len(total_states)

37

In [10]:
class StateIndex:
    """Two-way lookup between a state's string form and its position.

    ``total_states`` is walked once; element ``k`` is registered under
    ``str(element)`` <-> ``k``. If two elements share the same string, the
    later position wins in the string-to-index direction.
    """

    def __init__(self, total_states):
        labels = [str(entry) for entry in total_states]
        self.indToString = dict(enumerate(labels))
        self.stringToInd = {label: pos for pos, label in enumerate(labels)}

    def getIndex(self, string):
        """Position registered for ``string``; raises ``KeyError`` if unknown."""
        return self.stringToInd[string]

    def getString(self, ind):
        """String registered at position ``ind``; raises ``KeyError`` if unknown."""
        return self.indToString[ind]

In [11]:
def getTerminalReward(state, whale_reward=1.5):
    """Reward paid out at a terminal state.

    The winning fork is the ``argmax`` of entries 0-2. The base reward is the
    counter stored at entry ``3 + winner``; when fork 0 wins, ``whale_reward``
    is added on top.

    Asserts that ``state.isTerminal()`` holds.
    """
    assert state.isTerminal()
    winner = np.argmax(state[:3])
    payout = state[3 + winner]
    return payout + whale_reward if winner == 0 else payout

In [12]:
# String <-> index lookup over the enumerated states; a state's index is its
# position in total_states.
state_ind = StateIndex(total_states)

In [13]:
# Both arrays are indexed [action, start state, next state] and start out
# all-zero. The notebook labels action 0 "selfish" and action 1 "honest"
# when printing transitions.
transitions = np.zeros((2, len(total_states), len(total_states)))
rewards = np.zeros((2, len(total_states), len(total_states)))

In [24]:
def isHonestState(state):
    """True when exactly two of entries 0-2 are non-zero and entries 0-2 sum to 2.

    For non-negative integer lengths this means two forks of length 1 and
    one fork of length 0.
    """
    fork_lengths = state[:3]
    return bool(np.count_nonzero(fork_lengths) == 2 and fork_lengths.sum() == 2)

In [33]:
# algorithm to construct transition and reward matrices.
#
# Walks the state graph from the all-zero start state (same stack-based
# traversal as the enumeration cell) and fills `transitions` / `rewards`
# in place, using `state_ind` to map each state to its row/column.
#
# FIXME: this cell is unfinished and does not parse -- the assignment to
# `next_ind_honest` below has no right-hand side, and the branch for
# isHonestState(elem) writes nothing into either matrix. Apart from the
# terminal self-loops at the bottom, nothing here fills transitions[1] or
# rewards[1].
init_state = State([0,0,0,0,0,0,0], length=2)

states_to_process = [init_state]
while states_to_process:
    elem = states_to_process.pop()
    next_states = nextStates(elem)
    start_ind = state_ind.getIndex(str(elem))
    if isHonestState(elem):
        # Index of the smallest of the three fork lengths.
        honest_ind = np.argmin(elem[:3])
        # FIXME: incomplete statement (SyntaxError); intended value unknown.
        next_ind_honest = 
        
    else:
        # Only the first three successors are used: nextStates() returns the
        # "miner extends own fork" outcomes first, one per fork.
        for i in range(3):
            next_ind = state_ind.getIndex(str(next_states[i]))
            transitions[0, start_ind, next_ind] = 1/3
            rewards[0, start_ind, next_ind] = -1/3
            if next_states[i].isTerminal():
                # Note: whale_reward=2 here overrides getTerminalReward's default of 1.5.
                rewards[0, start_ind, next_ind] += getTerminalReward(next_states[i], whale_reward=2)
    for s in next_states:
        # Queue non-terminal successors that are not already waiting on the stack.
        if not s.isTerminal() and not checkIfInTotalStates(states_to_process, s):
            states_to_process.append(s)
        # Terminal states loop back onto themselves under both actions.
        if s.isTerminal():
            index = state_ind.getIndex(str(s))
            transitions[0, index, index] = 1
            transitions[1, index, index] = 1

011,000,2 0
101,100,2 1
011,000,1 0
110,100,1 2
101,100,0 1
110,100,0 2


In [None]:
def printTransitions(transitions):
    """Print, for every state, its non-zero outgoing transitions.

    For each start state (in index order) the state's string is printed,
    followed by one indented line per non-zero entry of
    ``transitions[0, start, :]`` (labelled "selfish") and then one per
    non-zero entry of ``transitions[1, start, :]`` (labelled "honest").
    Reads the notebook-level ``total_states`` and ``state_ind``.
    """
    n_states = len(total_states)
    labelled_actions = ((0, " selfish, {:0.2f}"), (1, " honest, {:0.2f}"))
    for src in range(n_states):
        print(state_ind.getString(src))
        for action, template in labelled_actions:
            for dst in range(n_states):
                prob = transitions[action, src, dst]
                if prob != 0:
                    print("    ", state_ind.getString(dst), template.format(prob))