In [1]:
import copy 
import matplotlib.pyplot as plt
import mdptoolbox as mdpt
import numpy as np

In [2]:
class FirstPersonState:
    """Fork race as seen from one miner's point of view.

    The wrapped array has seven entries. Within this file they are used as:
    entries 0-2 are the fork lengths, entries 3-5 are the blocks the viewing
    miner has on each fork, and entry 6 is the index of the fork that honest
    miners currently extend.
    """

    def __init__(self, state, length=2):
        # A fork reaching `length` blocks makes the state terminal.
        self.terminal_length = length
        self.state = np.array(state)

    def __str__(self):
        """Render as 'LLL,MMM,H' (fork lengths, own blocks, honest fork)."""
        fork_part = ''.join(str(self.state[k]) for k in range(3))
        own_part = ''.join(str(self.state[k]) for k in range(3, 6))
        return fork_part + ',' + own_part + ',' + str(self.state[6])

    def __setitem__(self, ind, val):
        self.state[ind] = val

    def __getitem__(self, ind):
        return self.state[ind]

    def isTerminal(self):
        """True when any of the three fork lengths equals the terminal length."""
        fork_lengths = self.state[:3]
        return self.terminal_length in fork_lengths

    def getHonestFork(self):
        """Index of the fork currently followed by honest miners."""
        return self.state[-1]

    def updateHonestFork(self):
        """Move the honest pointer to the longest fork if some fork is strictly longer."""
        fork_lengths = self.state[:3]
        honest_length = self.state[self.getHonestFork()]
        if any(fork_lengths > honest_length):
            self.state[-1] = np.argmax(fork_lengths)

    def __eq__(self, other):
        # Element-wise comparison of the raw arrays; terminal_length is not compared.
        matches = self.state == other.state
        return all(matches)

In [3]:
def nextFirstPersonStates(state):
    """Return every state reachable from `state` by mining one more block.

    The result lists, in order: the three states where a miner extends the
    fork with its own index, then, for each fork that already has blocks,
    the states where one of the other two miners extends it. Index 0 is the
    viewing miner, so its blocks are also counted in slots 3-5. The input is
    never modified; duplicates are not removed.
    """
    def grow(fork, own_slot=None):
        # Clone, add a block to `fork`, optionally credit the viewing miner,
        # then refresh the honest-fork pointer.
        child = copy.deepcopy(state)
        child[fork] += 1
        if own_slot is not None:
            child[own_slot] += 1
        child.updateHonestFork()
        return child

    # anyone mining on their own fork
    successors = [grow(miner, 3 if miner == 0 else None) for miner in range(3)]

    # anyone mining on another fork.
    for fork in range(3):
        if state[fork] == 0:
            continue
        for miner in range(3):
            if miner == fork:
                continue
            successors.append(grow(fork, 3 + fork if miner == 0 else None))
    return successors

In [4]:
def checkIfIn(state, states):
    """Return True if `state == s` holds for some `s` in `states`, else False.

    Membership is decided with `state`'s own `==`, scanning left to right and
    stopping at the first match.
    """
    return any(state == candidate for candidate in states)

In [5]:
def printTransitions(transitions, rewards, si):
    """Print every non-zero transition, grouped by source state.

    Args:
        transitions: array indexed as [action, from_state, to_state];
            action 0 is printed as "selfish" and action 1 as "honest".
        rewards: array with the same indexing as `transitions`.
        si: index object whose `s(i)` returns the label of state `i`.

    The number of states is taken from `transitions` itself rather than from
    the notebook-level `total_states` list, so the function works on any
    matrices passed in.
    """
    num_states = transitions.shape[1]
    for i in range(num_states):
        print(si.s(i))
        # All selfish rows first, then all honest rows, as before.
        for action, label in enumerate(("selfish", "honest")):
            for j in range(num_states):
                if transitions[action,i,j] != 0:
                    print("    ", si.s(j), " {}, {:0.2f}, {:0.2f}".format(
                        label, transitions[action,i,j], rewards[action,i,j]))

def prettyPrintPolicy(policy, si):
    """Print one line per state: its label followed by 'selfish' or 'honest'.

    Action 0 is shown as 'selfish'; any other value is shown as 'honest'.
    `si.s(i)` supplies the label of state `i`.
    """
    for idx in range(len(policy)):
        label = si.s(idx)
        verdict = 'selfish' if policy[idx] == 0 else 'honest'
        print(label, verdict)

In [27]:
def getTotalStates(length = 2):
    """Enumerate the first-person states reachable in a race of the given length.

    Starting from the three one-block states, successors produced by
    `nextFirstPersonStates` are expanded depth-first (LIFO) until a fork
    reaches `length`. The returned list holds each distinct state once, in
    order of first discovery, with the all-zero initial state at index 0.

    Args:
        length: fork length at which a state becomes terminal.

    Returns:
        list of FirstPersonState objects.
    """
    init_state = FirstPersonState([0,0,0,0,0,0,0], length)
    oneblock_state_a = FirstPersonState([1,0,0,1,0,0,0], length)
    oneblock_state_b = FirstPersonState([0,1,0,0,0,0,1], length)
    oneblock_state_c = FirstPersonState([0,0,1,0,0,0,2], length)
    total_states = [init_state, oneblock_state_a, oneblock_state_b, oneblock_state_c]
    states_to_process = [oneblock_state_a, oneblock_state_b, oneblock_state_c]

    # Hashable mirror of `total_states`: equal tuples <=> equal state arrays,
    # so set membership replaces a linear scan of array comparisons.
    known = {tuple(s.state.tolist()) for s in total_states}
    # Labels of states that have been expanded (the initial state counts as done).
    explored = {str(init_state)}

    while states_to_process:
        elem = states_to_process.pop()
        explored.add(str(elem))
        for s in nextFirstPersonStates(elem):
            key = tuple(s.state.tolist())
            if key not in known:
                known.add(key)
                total_states.append(s)
            # Same push rule as before: a state may be queued more than once
            # until it has actually been expanded; this keeps discovery order.
            if not s.isTerminal() and str(s) not in explored:
                states_to_process.append(s)
    return total_states

In [28]:
# Enumerate the first-person state space for races that stop once a fork has 4 blocks.
total_states = getTotalStates(length=4)
# Show how many distinct states were found.
len(total_states)

1138

In [30]:
class Ind:
    """Two-way lookup between state labels (`str(state)`) and list positions."""

    def __init__(self, total_states):
        labels = [str(entry) for entry in total_states]
        # If two entries share a label, the later position wins.
        self.stringToInd = {label: pos for pos, label in enumerate(labels)}
        self.indToString = dict(enumerate(labels))

    def i(self, string):
        """Position of the state whose label is `string` (KeyError if unknown)."""
        return self.stringToInd[string]

    def s(self, ind):
        """Label of the state stored at position `ind` (KeyError if unknown)."""
        return self.indToString[ind]

In [31]:
class State:
    """Global view of the three-miner, three-fork race.

    `state[m, f]` counts the blocks miner `m` has mined on fork `f`, and
    `honest_fork` is the fork honest miners currently extend. Miner 0 is the
    agent being optimised; miners 1 and 2 act according to `policy`, looked
    up through `state_index` using their own first-person view of the race.
    """

    def __init__(self, state_index, policy, length=2):
        self.state = np.zeros((3, 3), dtype=int)
        self.honest_fork = 0
        self.state_index = state_index  # object with i(label) -> policy index
        self.policy = policy            # indexable; 0 = selfish, otherwise honest
        self.length = length            # fork length that ends the race
        self.visited = False

    def getState(self, miner_ind):
        """First-person 7-vector for `miner_ind`.

        Fork 0 and fork `miner_ind` are swapped so the viewing miner always
        sees its own fork in slot 0. Layout: fork lengths (3), the miner's own
        blocks per fork (3), honest fork index (1).
        """
        if np.sum(self.state) == 0:
            return np.array([0,0,0,0,0,0,0])

        fork_lens = np.sum(self.state, axis=0)
        fork_lens[0], fork_lens[miner_ind] = fork_lens[miner_ind], fork_lens[0]

        miner_lens = self.state[miner_ind].copy()
        miner_lens[0], miner_lens[miner_ind] = miner_lens[miner_ind], miner_lens[0]

        return np.concatenate((fork_lens, miner_lens, [self.getHonestFork(miner_ind)]))

    def getHonestFork(self, miner_ind):
        """Honest fork index after applying the 0 <-> `miner_ind` relabelling."""
        if self.honest_fork == miner_ind:
            return 0
        elif self.honest_fork == 0:
            return miner_ind
        return self.honest_fork

    def updateHonestFork(self):
        """Point `honest_fork` at the longest fork if some fork is strictly longer."""
        fork_lens = np.sum(self.state, axis=0)
        if any(fork_lens > fork_lens[self.getHonestFork(0)]):
            self.honest_fork = np.argmax(fork_lens)

    def isTerminal(self):
        """True when the longest fork has exactly `length` blocks."""
        fork_lens = np.sum(self.state, axis=0)
        return np.max(fork_lens) == self.length

    def __eq__(self, other):
        # Equal when both the block matrix and the honest pointer match.
        return bool(np.all(self.state == other.state)
                    and self.honest_fork == other.honest_fork)

    def getStrRep(self, miner_ind):
        """Label ('LLL,MMM,H') of the race as seen by `miner_ind`."""
        state = self.getState(miner_ind)
        return self.getStrRepState(state)

    def getStrRepState(self, state):
        """Format a first-person 7-vector as 'LLL,MMM,H'."""
        fork_part = ''.join(str(state[k]) for k in range(3))
        own_part = ''.join(str(state[k]) for k in range(3, 6))
        return fork_part + ',' + own_part + ',' + str(state[6])

    def getOtherAgentActions(self):
        """(miner, fork) pairs chosen by miners 1 and 2 under `policy`.

        A selfish miner (action 0) extends the fork with its own index; an
        honest miner extends the current honest fork.
        """
        actions = []
        for i in [1,2]:
            other_agent_state = self.getStrRep(i)
            action = self.policy[self.state_index.i(other_agent_state)]
            if action == 0: # selfish
                actions.append((i, i))
            else: # honest
                actions.append((i, self.honest_fork))
        return actions

    def getReward(self):
        """Reward for miner 0 on entering this state.

        Always includes -1/3. On terminal states it adds miner 0's blocks on
        the winning fork, plus the module-level WHALE_REWARD when fork 0 is
        the fork that reached `length`. WHALE_REWARD is read at call time and
        must be defined before a terminal state is evaluated.
        """
        reward = -1/3
        if self.isTerminal():
            state = self.getState(0)
            if state[0] == self.length:
                reward += WHALE_REWARD + state[3]
            else:
                win_ind = np.argmax(state[:3])
                reward += state[3 + win_ind]
        return reward

    def _advance(self, action):
        """Copy of this state after `action` = (miner, fork) adds one block.

        Only the block matrix is duplicated; `state_index` and `policy` are
        read-only lookups and are shared with the parent instead of being
        deep-copied for every successor.
        """
        child = copy.copy(self)
        child.state = self.state.copy()
        child.state[action] += 1
        child.updateHonestFork()
        return child

    def _successors(self, own_action):
        """Successors for the other miners' actions followed by `own_action`."""
        actions = self.getOtherAgentActions()
        actions.append(own_action)
        return [self._advance(a) for a in actions]

    def getNextStatesSelfish(self):
        """Successor states when miner 0 mines on its own fork (fork 0)."""
        return self._successors((0,0))

    def getNextStatesHonest(self):
        """Successor states when miner 0 mines on the current honest fork."""
        return self._successors((0,self.getHonestFork(0)))

In [32]:
def getTR(si, total_states, policy, length=2):
    """Build transition and reward tensors for miner 0 against a fixed policy.

    Starting from the empty race, states are expanded depth-first. For each
    expanded state and each of miner 0's two actions (0 = selfish,
    1 = honest) the three successors returned by `State` each add 1/3 to the
    transition weight, and the reward entry is set to the successor's
    `getReward()`. A first-person label is expanded only once, using the
    first global state that produced it.

    Terminal states, and states that were never expanded, get a self-loop of
    probability 1 under both actions so every row sums to one.

    Args:
        si: index with `i(label)` mapping a first-person label to a row/column.
        total_states: state list defining the matrix size and which indices
            are terminal (via `isTerminal()`).
        policy: actions followed by miners 1 and 2, indexed like `total_states`.
        length: fork length that ends the race.

    Returns:
        (t, r): arrays of shape (2, n, n) indexed [action, from, to].
    """
    num_states = len(total_states)
    t = np.zeros((2, num_states, num_states))
    r = np.zeros((2, num_states, num_states))

    init_state = State(si, policy, length)
    states_to_process = [init_state]
    # Equal State objects always share the same miner-0 label, so tracking
    # labels alone is equivalent to also tracking the State objects.
    seen_strings = {init_state.getStrRep(0)}
    while states_to_process:
        elem = states_to_process.pop()
        cur = si.i(elem.getStrRep(0))

        successors_by_action = (
            (0, elem.getNextStatesSelfish()),
            (1, elem.getNextStatesHonest()),
        )
        for action, successors in successors_by_action:
            for s in successors:
                next_string = s.getStrRep(0)
                nxt = si.i(next_string)
                t[action, cur, nxt] += 1/3
                r[action, cur, nxt] = s.getReward()

                if not s.isTerminal() and next_string not in seen_strings:
                    seen_strings.add(next_string)
                    states_to_process.append(s)

    # Terminal states are absorbing under both actions.
    for i in range(num_states):
        if total_states[i].isTerminal():
            t[:, i, i] = 1

    # States never reached keep all-zero rows; make them absorbing as well.
    sums = np.sum(t, axis = 2)
    for i in range(num_states):
        if sums[0,i] == 0:
            assert(sums[1,i] == 0)
            t[:, i, i] = 1

    return t, r

In [33]:
# Build the label <-> index lookup for the enumerated states.
si = Ind(total_states)

In [214]:
# Baseline policies, one entry per state index (0 = selfish, 1 = honest).
selfish_policy = np.zeros(len(total_states))
honest_policy = np.ones_like(selfish_policy)
# The first two states are set to selfish in the otherwise honest baseline.
honest_policy[:2] = 0
# honest_policy[2] = 0
# honest_policy[3] = 0

In [223]:
# Bonus read by State.getReward when fork 0 is the fork that reaches the
# terminal length. Must be defined before getTR evaluates terminal states.
# Previously tried value:
# WHALE_REWARD = 1.7476
WHALE_REWARD = 1.75

In [224]:
# Round 1: miner 0's best response when miners 1 and 2 follow honest_policy.
# Alternative starting point (opponents follow selfish_policy):
# t, r = getTR(si, total_states, selfish_policy, length=4)
t, r = getTR(si, total_states, honest_policy, length=4)
# Solve the resulting MDP with undiscounted value iteration.
val_iter = mdpt.mdp.ValueIteration(t, r, discount=1)
val_iter.run()
# Best-response action per state index.
policy = val_iter.policy



In [225]:
# List the round-1 best-response action for every state label.
prettyPrintPolicy(policy, si)

000,000,0 selfish
100,100,0 selfish
010,000,1 selfish
001,000,2 selfish
101,100,2 selfish
011,000,2 selfish
002,000,2 honest
002,001,2 honest
102,100,2 selfish
012,000,2 selfish
003,000,2 honest
003,001,2 honest
103,100,2 honest
013,000,2 selfish
004,000,2 selfish
004,001,2 selfish
113,100,2 selfish
023,000,2 selfish
014,000,2 selfish
023,010,2 selfish
014,001,2 selfish
123,100,2 selfish
033,000,2 selfish
024,000,2 selfish
033,010,2 selfish
024,001,2 selfish
133,100,2 selfish
043,000,1 selfish
034,000,2 selfish
043,010,1 selfish
034,001,2 selfish
233,200,2 selfish
143,100,1 selfish
134,100,2 selfish
233,100,2 selfish
143,110,1 selfish
134,101,2 selfish
333,200,2 selfish
243,100,1 selfish
234,100,2 selfish
333,100,2 selfish
243,110,1 selfish
234,101,2 selfish
433,200,0 selfish
343,100,1 selfish
334,100,2 selfish
433,100,0 selfish
343,110,1 selfish
334,101,2 selfish
433,300,0 selfish
343,200,1 selfish
334,200,2 selfish
343,210,1 selfish
334,201,2 selfish
333,300,2 selfish
243,200,1 selfi

303,202,0 selfish
302,301,0 selfish
402,401,0 selfish
303,302,0 selfish
111,100,2 selfish
021,000,1 selfish
021,010,1 selfish
121,100,1 selfish
031,000,1 selfish
022,000,1 selfish
031,010,1 selfish
022,001,1 selfish
122,100,1 selfish
222,200,1 selfish
222,100,1 selfish
122,101,1 selfish
222,201,1 selfish
222,101,1 selfish
131,100,1 selfish
041,000,1 selfish
041,010,1 selfish
231,200,1 selfish
141,100,1 selfish
231,100,1 selfish
141,110,1 selfish
331,200,1 selfish
241,100,1 selfish
331,100,1 selfish
241,110,1 selfish
431,200,0 selfish
341,100,1 selfish
431,100,0 selfish
341,110,1 selfish
431,300,0 selfish
341,200,1 selfish
341,210,1 selfish
331,300,1 selfish
241,200,1 selfish
241,210,1 selfish
431,400,0 selfish
341,300,1 selfish
341,310,1 selfish
131,110,1 selfish
041,020,1 selfish
231,210,1 selfish
231,110,1 selfish
141,120,1 selfish
331,210,1 selfish
331,110,1 selfish
241,120,1 selfish
431,210,0 selfish
431,110,0 selfish
341,120,1 selfish
431,310,0 selfish
341,220,1 selfish
331,310,1 

In [226]:
# Round 2: miners 1 and 2 now follow the round-1 best response `policy`.
t2, r2 = getTR(si, total_states, policy, length=4)
# Re-solve with undiscounted value iteration; `val_iter` is rebound here.
val_iter = mdpt.mdp.ValueIteration(t2, r2, discount=1)
val_iter.run()
policy2 = val_iter.policy



In [227]:
# List the round-2 best-response action for every state label.
prettyPrintPolicy(policy2, si)

000,000,0 selfish
100,100,0 selfish
010,000,1 honest
001,000,2 honest
101,100,2 selfish
011,000,2 honest
002,000,2 honest
002,001,2 honest
102,100,2 selfish
012,000,2 honest
003,000,2 honest
003,001,2 honest
103,100,2 honest
013,000,2 honest
004,000,2 selfish
004,001,2 selfish
113,100,2 honest
023,000,2 honest
014,000,2 selfish
023,010,2 honest
014,001,2 selfish
123,100,2 honest
033,000,2 honest
024,000,2 selfish
033,010,2 honest
024,001,2 selfish
133,100,2 honest
043,000,1 selfish
034,000,2 selfish
043,010,1 selfish
034,001,2 selfish
233,200,2 selfish
143,100,1 selfish
134,100,2 selfish
233,100,2 selfish
143,110,1 selfish
134,101,2 selfish
333,200,2 selfish
243,100,1 selfish
234,100,2 selfish
333,100,2 selfish
243,110,1 selfish
234,101,2 selfish
433,200,0 selfish
343,100,1 selfish
334,100,2 selfish
433,100,0 selfish
343,110,1 selfish
334,101,2 selfish
433,300,0 selfish
343,200,1 selfish
334,200,2 selfish
343,210,1 selfish
334,201,2 selfish
333,300,2 selfish
243,200,1 selfish
234,200,2

032,001,1 honest
032,011,1 selfish
132,101,1 honest
042,001,1 selfish
042,011,1 selfish
033,002,1 selfish
133,102,1 selfish
233,202,1 selfish
233,102,1 selfish
333,202,1 selfish
333,102,1 selfish
333,302,1 selfish
232,201,1 selfish
142,101,1 selfish
232,101,1 selfish
142,111,1 selfish
332,201,1 selfish
242,101,1 selfish
332,101,1 selfish
242,111,1 selfish
432,201,0 selfish
342,101,1 selfish
432,101,0 selfish
342,111,1 selfish
432,301,0 selfish
342,201,1 selfish
342,211,1 selfish
332,301,1 selfish
242,201,1 selfish
242,211,1 selfish
432,401,0 selfish
342,301,1 selfish
342,311,1 selfish
132,111,1 selfish
042,021,1 selfish
033,012,1 selfish
133,112,1 selfish
233,212,1 selfish
233,112,1 selfish
333,212,1 selfish
333,112,1 selfish
333,312,1 selfish
232,211,1 selfish
232,111,1 selfish
142,121,1 selfish
332,211,1 selfish
332,111,1 selfish
242,121,1 selfish
432,211,0 selfish
432,111,0 selfish
342,121,1 selfish
432,311,0 selfish
342,221,1 selfish
332,311,1 selfish
242,221,1 selfish
432,411,0 se

In [228]:
# Round 3: miners 1 and 2 now follow the round-2 best response `policy2`.
t3, r3 = getTR(si, total_states, policy2, length=4)
# Re-solve with undiscounted value iteration; `val_iter` is rebound here.
val_iter = mdpt.mdp.ValueIteration(t3, r3, discount=1)
val_iter.run()
policy3 = val_iter.policy



In [230]:
# List the round-3 best-response action for every state label.
prettyPrintPolicy(policy3, si)

000,000,0 selfish
100,100,0 selfish
010,000,1 honest
001,000,2 honest
101,100,2 selfish
011,000,2 selfish
002,000,2 honest
002,001,2 honest
102,100,2 selfish
012,000,2 selfish
003,000,2 honest
003,001,2 honest
103,100,2 honest
013,000,2 selfish
004,000,2 selfish
004,001,2 selfish
113,100,2 selfish
023,000,2 selfish
014,000,2 selfish
023,010,2 selfish
014,001,2 selfish
123,100,2 selfish
033,000,2 selfish
024,000,2 selfish
033,010,2 selfish
024,001,2 selfish
133,100,2 selfish
043,000,1 selfish
034,000,2 selfish
043,010,1 selfish
034,001,2 selfish
233,200,2 selfish
143,100,1 selfish
134,100,2 selfish
233,100,2 selfish
143,110,1 selfish
134,101,2 selfish
333,200,2 selfish
243,100,1 selfish
234,100,2 selfish
333,100,2 selfish
243,110,1 selfish
234,101,2 selfish
433,200,0 selfish
343,100,1 selfish
334,100,2 selfish
433,100,0 selfish
343,110,1 selfish
334,101,2 selfish
433,300,0 selfish
343,200,1 selfish
334,200,2 selfish
343,210,1 selfish
334,201,2 selfish
333,300,2 selfish
243,200,1 selfish

331,210,0 selfish
321,300,0 selfish
421,400,0 selfish
331,300,0 selfish
331,310,0 selfish
121,110,1 selfish
022,010,1 selfish
031,020,1 selfish
022,011,1 selfish
122,110,1 selfish
222,210,1 selfish
222,110,1 selfish
122,111,1 selfish
222,211,1 selfish
222,111,1 selfish
131,120,1 selfish
041,030,1 selfish
231,220,1 selfish
231,120,1 selfish
141,130,1 selfish
331,220,1 selfish
331,120,1 selfish
241,130,1 selfish
431,220,0 selfish
431,120,0 selfish
341,130,1 selfish
431,320,0 selfish
341,230,1 selfish
331,320,1 selfish
241,230,1 selfish
431,420,0 selfish
341,330,1 selfish
221,210,1 selfish
221,110,1 selfish
321,210,0 selfish
321,110,0 selfish
421,210,0 selfish
421,110,0 selfish
331,120,0 selfish
421,310,0 selfish
331,220,0 selfish
321,310,0 selfish
421,410,0 selfish
331,320,0 selfish
211,200,0 selfish
211,100,0 selfish
311,200,0 selfish
221,100,0 selfish
212,100,0 selfish
311,100,0 selfish
221,110,0 selfish
212,101,0 selfish
222,100,0 selfish
222,110,0 selfish
222,101,0 selfish
222,111,0 

In [231]:
# Fixed-point check over ALL states: does the round-3 best response equal the
# round-2 policy? (The recorded result was False, so no full fixed point yet.)
policy3 == policy2

False

In [212]:
# Best-response iteration, started from the all-selfish policy.
# Each round: opponents follow `construction_policy`, miner 0's best response
# is solved by value iteration, and the loop stops once both policies agree on
# every non-terminal state reachable from state 0 under the best response.
WHALE_REWARD = 1.75
construction_policy = copy.deepcopy(selfish_policy)

cont = True
count = 0  # never incremented; kept so the notebook-level name still exists
while cont:
    t, r = getTR(si, total_states, construction_policy, length=4)
    val_iter = mdpt.mdp.ValueIteration(t, r, discount=1)
    val_iter.run()
    solved_policy = np.array(val_iter.policy)
    
    # find reachable states
    reachable_states = []
    states_to_process = [0]
    # Every state ever queued. Without this, a state reached along several
    # paths was re-expanded each time, and a self-loop would never terminate.
    expanded_states = {0}
    while states_to_process:
        cur_state = states_to_process.pop()
        cur_opt_action = solved_policy[cur_state]
        for next_state in np.nonzero(t[cur_opt_action, cur_state])[0]:
            reachable_states.append(next_state)
            # A '4' among the three fork-length digits marks a terminal state (length=4).
            if next_state not in expanded_states and '4' not in si.s(next_state)[:3]:
                expanded_states.add(next_state)
                states_to_process.append(next_state)
    reachable_states = list(set(reachable_states))
    reachable_nonterminal_states = []
    for rs in reachable_states:
        if '4' not in si.s(rs)[:3]:
            reachable_nonterminal_states.append(rs)
    print(reachable_nonterminal_states)
    if np.all(construction_policy[reachable_nonterminal_states] == solved_policy[reachable_nonterminal_states]):
        print('converged! optimal policy is ...')
        for rs in reachable_nonterminal_states:
            if construction_policy[rs] == 0:
                print('  ', si.s(rs), ' selfish')
            else:
                print('  ', si.s(rs), ' honest')
        cont = False
    # The best response becomes the opponents' policy for the next round.
    construction_policy = copy.deepcopy(solved_policy)
#     prettyPrintPolicy(construction_policy, si)

[1024, 1, 2, 3, 5, 6, 7, 9, 10, 11, 8, 13, 16, 17, 19, 21, 22, 24, 543, 31, 545, 547, 1025, 1066, 558, 560, 563, 564, 1028, 566, 569, 1082, 573, 1113, 95, 1122, 1123, 1126, 1127, 1128, 110, 1135, 1136, 532, 1043, 26, 172, 187, 193, 719, 208, 722, 215, 218, 222, 54, 1077, 12, 1079, 370, 872, 873, 874, 875, 876, 877, 367, 368, 879, 880, 371, 372, 885, 375, 377, 379, 888, 373, 383, 378, 390, 903, 405, 412, 926, 937, 939, 942, 943, 442, 465, 467, 980, 468, 471, 476, 478, 995, 996, 997, 1000, 1005, 1014, 1015, 1019, 1020, 1021, 1022, 1023]
[1024, 1, 2, 3, 6, 7, 10, 11, 564, 1082, 1126, 1129, 1130, 1131, 1136, 1019, 1021, 1023]
converged! optimal policy is ...
   030,010,1  honest
   100,100,0  selfish
   010,000,1  honest
   001,000,2  honest
   002,000,2  honest
   002,001,2  honest
   003,000,2  honest
   003,001,2  honest
   003,002,2  honest
   030,020,1  honest
   200,200,0  selfish
   200,100,0  selfish
   300,200,0  selfish
   300,100,0  selfish
   300,300,0  selfish
   020,000,1  ho

In [213]:
# Best-response iteration, started from the (mostly) honest baseline policy.
# Each round: opponents follow `construction_policy`, miner 0's best response
# is solved by value iteration, and the loop stops once both policies agree on
# every non-terminal state reachable from state 0 under the best response.
# Relies on WHALE_REWARD having been set by an earlier cell.
construction_policy = copy.deepcopy(honest_policy)

cont = True
count = 0  # never incremented; kept so the notebook-level name still exists
while cont:
    t, r = getTR(si, total_states, construction_policy, length=4)
    val_iter = mdpt.mdp.ValueIteration(t, r, discount=1)
    val_iter.run()
    solved_policy = np.array(val_iter.policy)
    
    # find reachable states
    reachable_states = []
    states_to_process = [0]
    # Every state ever queued. Without this, a state reached along several
    # paths was re-expanded each time, and a self-loop would never terminate.
    expanded_states = {0}
    while states_to_process:
        cur_state = states_to_process.pop()
        cur_opt_action = solved_policy[cur_state]
        for next_state in np.nonzero(t[cur_opt_action, cur_state])[0]:
            reachable_states.append(next_state)
            # A '4' among the three fork-length digits marks a terminal state (length=4).
            if next_state not in expanded_states and '4' not in si.s(next_state)[:3]:
                expanded_states.add(next_state)
                states_to_process.append(next_state)
    reachable_states = list(set(reachable_states))
    reachable_nonterminal_states = []
    for rs in reachable_states:
        if '4' not in si.s(rs)[:3]:
            reachable_nonterminal_states.append(rs)
    print(reachable_nonterminal_states)
    if np.all(construction_policy[reachable_nonterminal_states] == solved_policy[reachable_nonterminal_states]):
        print('converged! optimal policy is ...')
        for rs in reachable_nonterminal_states:
            if construction_policy[rs] == 0:
                print('  ', si.s(rs), ' selfish')
            else:
                print('  ', si.s(rs), ' honest')
        cont = False
    # The best response becomes the opponents' policy for the next round.
    construction_policy = copy.deepcopy(solved_policy)

[1024, 1, 2, 3, 4, 6, 8, 1025, 10, 11, 12, 1028, 1043, 547, 1066, 558, 1077, 193, 208, 1113, 1115, 1122, 1126, 1129, 1130, 1131, 1005, 1007, 1136, 1014, 1018, 1019, 1022, 1023]
[1024, 1, 2, 3, 5, 6, 7, 9, 10, 11, 8, 13, 16, 17, 19, 21, 22, 24, 543, 31, 545, 547, 1025, 1066, 558, 560, 563, 564, 1028, 566, 569, 1082, 573, 1113, 95, 1122, 1123, 1126, 1127, 1128, 1130, 110, 1135, 1136, 532, 1043, 26, 172, 187, 193, 719, 208, 722, 215, 218, 222, 54, 1077, 12, 1079, 370, 872, 873, 874, 875, 876, 877, 367, 368, 879, 880, 371, 372, 885, 375, 377, 379, 888, 373, 383, 378, 390, 903, 405, 412, 926, 937, 939, 942, 943, 442, 465, 467, 980, 468, 471, 476, 478, 995, 996, 997, 1000, 1005, 1014, 1015, 1019, 1020, 1021, 1022, 1023]
[1024, 1, 2, 3, 6, 7, 10, 11, 564, 1082, 1126, 1129, 1130, 1131, 1136, 1019, 1021, 1023]
converged! optimal policy is ...
   030,010,1  honest
   100,100,0  selfish
   010,000,1  honest
   001,000,2  honest
   002,000,2  honest
   002,001,2  honest
   003,000,2  honest
   003