In [1]:
import copy 
import matplotlib.pyplot as plt
import mdptoolbox as mdpt
import numpy as np

In [2]:
class State:
    """One node of the fork-race MDP, backed by a numpy array of 7 entries.

    Layout (the same order ``__str__`` prints, e.g. ``"101,100,2"``):

    * ``state[0:3]`` -- current length of each of the three forks;
    * ``state[3:6]`` -- how many of those blocks were mined by "me" on each fork;
    * ``state[6]``   -- index of the current honest fork.

    ``length`` is the fork length at which the race ends (see ``isTerminal``).
    It is not part of equality: two States with equal arrays compare equal.
    """

    def __init__(self, state, length=2):
        self.state = np.array(state)
        self.terminal_length = length

    def __str__(self):
        # "aaa,bbb,h": fork lengths, my blocks per fork, honest-fork index.
        forks = ''.join(str(v) for v in self.state[:3])
        mine = ''.join(str(v) for v in self.state[3:6])
        return forks + ',' + mine + ',' + str(self.state[6])

    def __setitem__(self, ind, val):
        self.state[ind] = val

    def __getitem__(self, ind):
        return self.state[ind]

    def isTerminal(self):
        """True once any fork has reached ``terminal_length`` blocks."""
        return self.terminal_length in self.state[:3]

    def getHonestFork(self):
        """Index (0-2) of the fork currently considered honest."""
        return self.state[-1]

    def __eq__(self, other):
        # array_equal also handles mismatched shapes (plain `==` raises on
        # recent NumPy); non-States defer to Python's default comparison.
        if not isinstance(other, State):
            return NotImplemented
        return np.array_equal(self.state, other.state)

In [3]:
def nextStates(cur_state):
    """Return the three one-step successors of ``cur_state``, in this order.

    1. I mine on fork 0 ("my fork"): that fork and my count on it both grow.
    2. I mine on the current honest fork: that fork and my count on it grow.
    3. Someone else mines on the honest fork: only that fork's length grows.

    After each move, if some fork is now strictly longer than the honest fork,
    the longest fork (lowest index on ties, per ``np.argmax``) becomes honest.
    ``cur_state`` itself is left untouched; each successor is a deep copy.
    """
    def _advance(fork, mined_by_me):
        successor = copy.deepcopy(cur_state)
        successor[fork] += 1
        if mined_by_me:
            successor[3 + fork] += 1
        # A strictly longer fork overtakes the honest one.
        if any(successor[:3] > successor[successor.getHonestFork()]):
            successor[-1] = np.argmax(successor[:3])
        return successor

    honest = cur_state.getHonestFork()
    return [
        _advance(0, True),
        _advance(honest, True),
        _advance(honest, False),
    ]

In [4]:
def checkIfInTotalStates(total_states, state):
    """Linear membership test using ``==`` (State objects are unhashable).

    Returns True if any element of ``total_states`` compares equal to ``state``.
    """
    return any(state == candidate for candidate in total_states)

In [27]:
def printTransitions(transitions, rewards, si):
    """Print every non-zero transition of a two-action MDP, grouped by source state.

    Parameters
    ----------
    transitions, rewards : numpy arrays indexed ``[action, source, destination]``
        Action 0 is labelled "selfish", action 1 "honest".
    si : object with an ``s(index)`` method
        Maps a state index to its printable name (e.g. an ``Ind``).

    For each source state, its name is printed, then its selfish successors,
    then its honest successors, each with probability and reward.
    The number of states comes from ``transitions`` itself, not from any
    notebook-global state list.
    """
    n_states = transitions.shape[1]
    action_formats = (
        (0, " selfish, {:0.2f}, {:0.2f}"),
        (1, " honest, {:0.2f}, {:0.2f}"),
    )
    for i in range(n_states):
        print(si.s(i))
        for action, fmt in action_formats:
            for j in range(n_states):
                if transitions[action, i, j] != 0:
                    print("    ", si.s(j), fmt.format(
                        transitions[action, i, j], rewards[action, i, j]))
                
def prettyPrintPolicy(policy, si):
    """Print one line per state: its name followed by the chosen action.

    ``policy`` holds one action index per state; 0 is shown as "selfish" and
    any other value as "honest". ``si.s(index)`` supplies the state names.
    """
    for idx, action in enumerate(policy):
        label = 'selfish' if action == 0 else 'honest'
        print(si.s(idx), end=' ')
        print(label)

In [17]:
# Enumerate every state reachable from the empty start state.
# The start state's honest fork is 0, so nextStates would only ever extend
# fork 0 from it; the three one-block states (first block on fork 0, 1 or 2,
# which then becomes the honest fork) are therefore listed explicitly.
init_state = State([0,0,0,0,0,0,0])
oneblockstate0 = State([1,0,0,1,0,0,0])
oneblockstate1 = State([0,1,0,0,0,0,1])
oneblockstate2 = State([0,0,1,0,0,0,2])

total_states = [init_state, oneblockstate0, oneblockstate1, oneblockstate2]
states_to_process = [oneblockstate0, oneblockstate1, oneblockstate2]
# Depth-first (pop from the end). A state is queued only the first time it
# is discovered, so no state is ever expanded twice.
while states_to_process:
    elem = states_to_process.pop()
    next_states = nextStates(elem)
    for s in next_states:
        if checkIfInTotalStates(total_states, s):
            continue  # already discovered, hence already queued or expanded
        total_states.append(s)
        if not s.isTerminal():
            states_to_process.append(s)

In [18]:
# List every discovered state; terminal states are indented beneath the others.
for s in total_states:
    if s.isTerminal():
        print('  ', s)
        continue
    print(s)

000,000,0
100,100,0
010,000,1
001,000,2
101,100,2
   002,001,2
   002,000,2
   201,200,0
   102,101,2
   102,100,2
110,100,1
   020,010,1
   020,000,1
   210,200,0
   120,110,1
   120,100,1
   200,200,0
   200,100,0


In [19]:
class Ind:
    """Two-way lookup between states' string labels and their list positions.

    ``i(label)`` returns the position of a label, and ``s(position)`` returns
    the label at that position. If two entries share a label, ``i`` returns
    the later position.
    """

    def __init__(self, total_states):
        labels = [str(entry) for entry in total_states]
        self.stringToInd = {label: pos for pos, label in enumerate(labels)}
        self.indToString = dict(enumerate(labels))

    def i(self, string):
        """Position of the state whose ``str()`` is ``string`` (KeyError if unknown)."""
        return self.stringToInd[string]

    def s(self, ind):
        """Label of the state at position ``ind`` (KeyError if out of range)."""
        return self.indToString[ind]

In [20]:
si = Ind(total_states=total_states)  # label <-> index lookup used to fill the MDP tensors

In [79]:
WHALE_REWARD = 2.0  # extra payoff added in the reward table whenever fork 0 reaches the terminal length

In [80]:
# MDP tensors for mdptoolbox, indexed [action, source state, destination state].
# Action 0 = "selfish" (I mine on fork 0); action 1 = "honest" (I mine on the honest fork).
transitions = np.zeros((2, len(total_states), len(total_states)))
rewards = np.zeros((2, len(total_states), len(total_states)))

# (action, source, destination, probability, reward) for every non-terminal source state.
mdp_edges = [
    # selfish
    (0, '000,000,0', '100,100,0', 1/3, -1/3),
    (0, '000,000,0', '010,000,1', 1/3, -1/3),
    (0, '000,000,0', '001,000,2', 1/3, -1/3),

    (0, '100,100,0', '200,200,0', 1/3, WHALE_REWARD + 2 - 1/3),
    (0, '100,100,0', '200,100,0', 2/3, WHALE_REWARD + 1 - 1/3),

    (0, '010,000,1', '110,100,1', 1/3, -1/3),
    (0, '010,000,1', '020,000,1', 2/3, -1/3),

    (0, '001,000,2', '101,100,2', 1/3, -1/3),
    (0, '001,000,2', '002,000,2', 2/3, -1/3),

    (0, '101,100,2', '201,200,0', 1/3, WHALE_REWARD + 2 - 1/3),
    (0, '101,100,2', '102,100,2', 2/3, -1/3),

    (0, '110,100,1', '210,200,0', 1/3, WHALE_REWARD + 2 - 1/3),
    (0, '110,100,1', '120,100,1', 2/3, -1/3),

    # honest
    (1, '000,000,0', '100,100,0', 1/3, -1/3),
    (1, '000,000,0', '010,000,1', 1/3, -1/3),
    (1, '000,000,0', '001,000,2', 1/3, -1/3),

    (1, '100,100,0', '200,200,0', 1/3, WHALE_REWARD + 2 - 1/3),
    (1, '100,100,0', '200,100,0', 2/3, WHALE_REWARD + 1 - 1/3),

    (1, '010,000,1', '020,010,1', 1/3, 2/3),
    (1, '010,000,1', '020,000,1', 2/3, -1/3),

    (1, '001,000,2', '002,001,2', 1/3, 2/3),
    (1, '001,000,2', '002,000,2', 2/3, -1/3),

    (1, '101,100,2', '102,101,2', 1/3, 2/3),
    (1, '101,100,2', '102,100,2', 2/3, -1/3),

    (1, '110,100,1', '120,110,1', 1/3, 2/3),
    (1, '110,100,1', '120,100,1', 2/3, -1/3),
]
for action, src, dst, prob, reward in mdp_edges:
    transitions[action, si.i(src), si.i(dst)] = prob
    rewards[action, si.i(src), si.i(dst)] = reward

# Terminal states are absorbing under both actions; their rewards stay 0.
for i, st in enumerate(total_states):
    if st.isTerminal():
        transitions[:, i, i] = 1

In [81]:
# Each (action, source-state) row of the transition tensor must be a probability distribution.
row_sums = np.sum(transitions, axis=2)
assert np.allclose(row_sums, 1), "every row of `transitions` must sum to 1"
row_sums

array([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1.],
       [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1.]])

In [82]:
# Solve the MDP by value iteration (action 0 = selfish, action 1 = honest, as in the tensors above).
# discount=1 (undiscounted): every move out of a non-terminal state lengthens a fork, and terminal
# states loop to themselves with probability 1 and zero reward, so accumulated reward stays finite.
# NOTE(review): pymdptoolbox appears to print a convergence warning when discount == 1 — confirm the result.
val_iter = mdpt.mdp.ValueIteration(transitions, rewards, discount=1)
val_iter.run()
# policy[i] is the chosen action index for state i (0 = selfish, otherwise honest).
policy = val_iter.policy



In [83]:
prettyPrintPolicy(policy=policy, si=si)  # one line per state: label, then chosen action

000,000,0 selfish
100,100,0 selfish
010,000,1 honest
001,000,2 honest
101,100,2 selfish
002,001,2 selfish
002,000,2 selfish
201,200,0 selfish
102,101,2 selfish
102,100,2 selfish
110,100,1 selfish
020,010,1 selfish
020,000,1 selfish
210,200,0 selfish
120,110,1 selfish
120,100,1 selfish
200,200,0 selfish
200,100,0 selfish
