In [1]:
import copy 
import matplotlib.pyplot as plt
import mdptoolbox as mdpt
import numpy as np

In [2]:
class FirstPersonState:
    """Seven-entry game state as seen from one miner's point of view.

    Layout of ``state``:
      * entries 0-2 -- length of each of the three forks,
      * entries 3-5 -- a per-fork counter (presumably this miner's own blocks
        on that fork; confirm against the code that increments it),
      * entry 6     -- index (0-2) of the fork currently treated as honest.

    ``terminal_length`` is the fork length at which the game ends.
    """

    def __init__(self, state, length=2):
        self.state = np.array(state)
        self.terminal_length = length

    def __str__(self):
        """Render as ``'LLL,CCC,H'`` (fork lengths, counters, honest fork)."""
        cells = [str(self.state[k]) for k in range(7)]
        return ','.join((''.join(cells[:3]), ''.join(cells[3:6]), cells[6]))

    def __setitem__(self, ind, val):
        self.state[ind] = val

    def __getitem__(self, ind):
        return self.state[ind]

    def isTerminal(self):
        """True as soon as any fork has reached the terminal length."""
        return bool((self.state[:3] == self.terminal_length).any())

    def getHonestFork(self):
        """Index of the fork currently recorded as the honest one."""
        return self.state[-1]

    def updateHonestFork(self):
        """Move the honest marker to the longest fork if one is strictly longer.

        Ties leave the marker where it is; ``np.argmax`` picks the lowest
        index when several forks share the new maximum.
        """
        fork_lengths = self.state[:3]
        current_best = self.state[self.getHonestFork()]
        if (fork_lengths > current_best).any():
            self.state[-1] = np.argmax(fork_lengths)

    def __eq__(self, other):
        """Two states are equal when all seven entries match."""
        matches = self.state == other.state
        return all(matches)

In [3]:
def nextFirstPersonStates(state):
    """Return every successor of ``state`` after exactly one block is mined.

    The result lists, in order:
      1. each of the three miners extending the fork with its own index
         (only miner 0 also bumps counter slot 3);
      2. for every non-empty fork, each of the two *other* miners extending
         it (slot ``3 + fork`` is bumped only when that miner is miner 0).

    Duplicates are not removed; ``state`` itself is never modified because
    each successor starts from a deep copy.
    """
    def _extend(fork, counter_slot=None):
        # Deep-copy, grow the fork, optionally bump a counter slot, then let
        # the copy re-evaluate which fork is honest.
        successor = copy.deepcopy(state)
        successor[fork] += 1
        if counter_slot is not None:
            successor[counter_slot] += 1
        successor.updateHonestFork()
        return successor

    # Everyone mining on their own fork.
    successors = [_extend(fork, 3 if fork == 0 else None) for fork in range(3)]

    # Anyone mining on somebody else's (already started) fork.
    for fork in range(3):
        if state[fork] == 0:
            continue
        for miner in range(3):
            if miner == fork:
                continue
            successors.append(_extend(fork, 3 + fork if miner == 0 else None))
    return successors

In [4]:
def checkIfIn(state, states):
    """Return True if any element of ``states`` compares equal to ``state``.

    Uses ``==`` (not identity or hashing), so it works for objects that
    define ``__eq__`` without ``__hash__``.
    """
    return any(state == candidate for candidate in states)

In [5]:
def printTransitions(transitions, rewards, si):
    """Print every non-zero transition, grouped by source state.

    Parameters
    ----------
    transitions, rewards : arrays indexed as ``[action, from_state, to_state]``
        Action 0 is labelled "selfish" and action 1 "honest".
    si : object with an ``s(index)`` method returning a state's label.

    For each source state the selfish transitions are listed first, then the
    honest ones, each with its probability and reward to two decimals.

    The number of states is taken from ``transitions`` itself rather than
    from the notebook-level ``total_states`` global, so the function works
    for any matrices passed in.
    """
    n_states = transitions.shape[1]
    for i in range(n_states):
        print(si.s(i))
        for action, label in ((0, 'selfish'), (1, 'honest')):
            for j in range(n_states):
                if transitions[action, i, j] != 0:
                    print("    ", si.s(j), " {}, {:0.2f}, {:0.2f}".format(
                        label, transitions[action, i, j], rewards[action, i, j]))

def prettyPrintPolicy(policy, si):
    """Print one line per state: its label followed by the chosen action.

    Action 0 is shown as ``selfish``; any other value as ``honest``.
    ``si.s(index)`` supplies the label for each state index.
    """
    for idx, action in enumerate(policy):
        label = 'selfish' if action == 0 else 'honest'
        print(si.s(idx), label)

In [6]:
def getTotalStates(length = 2):
    """Enumerate every first-person state reachable from the empty state.

    Parameters
    ----------
    length : int
        Fork length at which a state is terminal (terminal states are
        recorded but not expanded).

    Returns
    -------
    list of FirstPersonState
        The empty state, the three one-block states, then every further state
        in order of first discovery.  A state's position in this list is the
        index other code uses for it, so the order matters.

    Only *newly discovered* states are queued for expansion.  The previous
    version re-queued any non-terminal successor that was not currently on
    the work stack, including ones already expanded, so shared successors
    had their whole subtree re-expanded on every rediscovery.  Re-expanding
    an already expanded state can never discover anything new, so the
    returned list (contents and order) is unchanged.
    """
    init_state = FirstPersonState([0,0,0,0,0,0,0], length)
    oneblock_state_a = FirstPersonState([1,0,0,1,0,0,0], length)
    oneblock_state_b = FirstPersonState([0,1,0,0,0,0,1], length)
    oneblock_state_c = FirstPersonState([0,0,1,0,0,0,2], length)
    total_states = [init_state, oneblock_state_a, oneblock_state_b, oneblock_state_c]
    states_to_process = [oneblock_state_a, oneblock_state_b, oneblock_state_c]

    while states_to_process:
        elem = states_to_process.pop()
        for s in nextFirstPersonStates(elem):
            if checkIfIn(s, total_states):
                # Already known: it is either still queued or fully expanded.
                continue
            total_states.append(s)
            if not s.isTerminal():
                states_to_process.append(s)
    return total_states

In [7]:
# Enumerate all first-person states for a terminal fork length of 3.  A
# state's position in this list is the index used by the policies and by the
# transition/reward matrices built later in the notebook.
total_states = getTotalStates(length=3)
len(total_states)

253

In [8]:
class Ind:
    """Two-way lookup between state indices and their string labels.

    ``total_states`` is a sequence of objects whose ``str()`` is the label;
    the position in the sequence is the index.
    """

    def __init__(self, total_states):
        labels = [str(entry) for entry in total_states]
        self.stringToInd = {label: pos for pos, label in enumerate(labels)}
        self.indToString = dict(enumerate(labels))

    def i(self, string):
        """Index of the state whose label is ``string`` (KeyError if unknown)."""
        return self.stringToInd[string]

    def s(self, ind):
        """Label of the state at index ``ind`` (KeyError if unknown)."""
        return self.indToString[ind]

In [9]:
class State:
    """Full three-miner game state used to build the MDP for miner 0.

    ``state[m, f]`` is the number of blocks miner ``m`` has mined on fork
    ``f`` (column sums are the fork lengths).  ``honest_fork`` is the index of
    the fork currently treated as honest.  ``policy`` (indexed through
    ``state_index.i(label)``) decides what miners 1 and 2 do: 0 means mine on
    their own fork, anything else means mine on the honest fork.

    Successor states share ``state_index`` and ``policy`` with their parent
    (neither is modified by this class) and own a private copy of ``state``.
    """

    def __init__(self, state_index, policy, length=2):
        self.state = np.zeros((3, 3), dtype=int)
        self.honest_fork = 0
        self.state_index = state_index
        self.policy = policy
        self.length = length
        self.visited = False

    def getState(self, miner_ind):
        """Return the 7-entry first-person view of this state for one miner.

        The viewing miner's fork is moved to position 0 (swapped with fork 0)
        in both the fork lengths and the miner's own per-fork block counts;
        the last entry is the honest fork index under the same relabelling.
        """
        if np.sum(self.state) == 0:
            return np.array([0,0,0,0,0,0,0])

        fork_lens = np.sum(self.state, axis=0)
        fork_lens[0], fork_lens[miner_ind] = fork_lens[miner_ind], fork_lens[0]

        miner_lens = self.state[miner_ind].copy()
        miner_lens[0], miner_lens[miner_ind] = miner_lens[miner_ind], miner_lens[0]

        return np.concatenate((fork_lens, miner_lens, [self.getHonestFork(miner_ind)]))

    def getHonestFork(self, miner_ind):
        """Honest fork index after swapping labels 0 and ``miner_ind``."""
        if self.honest_fork == miner_ind:
            return 0
        elif self.honest_fork == 0:
            return miner_ind
        return self.honest_fork

    def updateHonestFork(self):
        """Move the honest marker to the longest fork if one is strictly longer."""
        fork_lens = np.sum(self.state, axis=0)
        if any(fork_lens > fork_lens[self.getHonestFork(0)]):
            self.honest_fork = np.argmax(fork_lens)

    def isTerminal(self):
        """True when the longest fork has exactly the terminal length."""
        fork_lens = np.sum(self.state, axis=0)
        return np.max(fork_lens) == self.length

    def __eq__(self, other):
        """Equal when the block matrix and the honest fork both match."""
        if not isinstance(other, State):
            return NotImplemented
        return bool(np.all(self.state == other.state)
                    and self.honest_fork == other.honest_fork)

    def getStrRep(self, miner_ind):
        """String label of this state from ``miner_ind``'s point of view."""
        return self.getStrRepState(self.getState(miner_ind))

    def getStrRepState(self, state):
        """Format a 7-entry view as ``'LLL,CCC,H'``."""
        fork_part = ''.join(str(state[k]) for k in range(3))
        own_part = ''.join(str(state[k]) for k in range(3, 6))
        return fork_part + ',' + own_part + ',' + str(state[6])

    def getOtherAgentActions(self):
        """Return ``(miner, fork)`` pairs chosen by miners 1 and 2.

        Each miner looks up its own first-person label in ``policy``:
        action 0 mines on the fork with the miner's index, any other action
        mines on the current honest fork.
        """
        actions = []
        for i in [1,2]:
            other_agent_state = self.getStrRep(i)
            action = self.policy[self.state_index.i(other_agent_state)]
            if action == 0: # selfish
                actions.append((i, i))
            else: # honest
                actions.append((i, self.honest_fork))
        return actions

    def getReward(self):
        """Reward to miner 0 for arriving in this state.

        Every step costs 1/3.  In a terminal state miner 0 additionally earns
        its own block count on the winning fork, plus the notebook-level
        global ``WHALE_REWARD`` (read at call time) when fork 0 is the winner.
        """
        reward = -1/3
        if self.isTerminal():
            state = self.getState(0)
            if state[0] == self.length:
                reward += WHALE_REWARD + state[3]
            else:
                win_ind = np.argmax(state[:3])
                reward += state[3 + win_ind]
        return reward

    def _applyAction(self, action):
        """Return a successor in which ``action = (miner, fork)`` mined a block."""
        # Shallow copy: state_index and policy are shared (never mutated
        # here); only the block matrix needs to be private to the successor.
        successor = copy.copy(self)
        successor.state = self.state.copy()
        successor.state[action] += 1
        successor.updateHonestFork()
        return successor

    def _nextStates(self, own_fork):
        """Successors for miners 1, 2 (per policy) and miner 0 on ``own_fork``."""
        actions = self.getOtherAgentActions()
        actions.append((0, own_fork))
        return [self._applyAction(a) for a in actions]

    def getNextStatesSelfish(self):
        """Three equally likely successors when miner 0 mines on fork 0."""
        return self._nextStates(0)

    def getNextStatesHonest(self):
        """Three equally likely successors when miner 0 mines on the honest fork."""
        return self._nextStates(self.getHonestFork(0))

In [10]:
def getTR(si, total_states, policy, length=2):
    """Build transition and reward tensors for miner 0's two-action MDP.

    Parameters
    ----------
    si : index object with ``i(label) -> int``.
    total_states : sequence of enumerated states (each with ``isTerminal()``);
        its length fixes the matrix size.
    policy : policy followed by miners 1 and 2 (passed on to ``State``).
    length : terminal fork length.

    Returns
    -------
    (t, r) : two arrays of shape ``(2, n, n)`` indexed
        ``[action, from_state, to_state]``; action 0 is selfish, 1 is honest.
        Each of the three successors of a state gets probability 1/3 (equal
        successors accumulate); ``r`` holds the reward of the arrival state.
        Terminal states, and states never expanded, get a self-loop of
        probability 1 under both actions with reward 0.

    NOTE(review): a game state is expanded only the first time its miner-0
    label is seen.  Different full ``State`` matrices can share one label, so
    the rows of ``t`` depend on which representative is met first; the
    traversal order below is therefore kept exactly as before.
    """
    n_states = len(total_states)
    t = np.zeros((2, n_states, n_states))
    r = np.zeros((2, n_states, n_states))

    init_state = State(si, policy, length)
    states_to_process = [init_state]
    # Labels already queued or expanded; a set gives O(1) membership tests.
    processed_strings = {init_state.getStrRep(0)}
    while states_to_process:
        elem = states_to_process.pop()
        current_string = elem.getStrRep(0)

        # Selfish successors first, then honest: this order determines which
        # states are pushed (and hence expanded) first.
        for action, expand in ((0, elem.getNextStatesSelfish),
                               (1, elem.getNextStatesHonest)):
            for s in expand():
                next_string = s.getStrRep(0)
                t[action, si.i(current_string), si.i(next_string)] += 1/3
                r[action, si.i(current_string), si.i(next_string)] = s.getReward()

                if not s.isTerminal() and next_string not in processed_strings:
                    states_to_process.append(s)
                    processed_strings.add(next_string)

    # Terminal states are absorbing under both actions.
    for i in range(n_states):
        if total_states[i].isTerminal():
            t[0,i,i] = 1
            t[1,i,i] = 1

    # States never reached by the traversal still need stochastic rows.
    sums = np.sum(t, axis = 2)
    for i in range(n_states):
        if sums[0,i] == 0:
            assert(sums[1,i] == 0)
            t[0,i,i] = 1
            t[1,i,i] = 1

    return t, r

In [11]:
# Two-way map between a state's position in total_states and its string label.
si = Ind(total_states)

In [12]:
# Baseline policies, one entry per state: 0 = selfish action, 1 = honest action.
selfish_policy = np.zeros(len(total_states)) 
honest_policy = np.ones(len(total_states))
# Indices 0 and 1 are the first two entries of total_states (the empty state
# and the state where fork 0 has one block); they are set to the selfish action.
honest_policy[0] = 0
honest_policy[1] = 0
# honest_policy[2] = 0
# honest_policy[3] = 0

In [36]:
# Bonus added to miner 0's terminal reward when fork 0 reaches the terminal
# length.  State.getReward reads this global at call time, so it must be set
# before getTR is called.
WHALE_REWARD = 1.82

In [37]:
# Build miner 0's MDP when the other two miners follow honest_policy, then
# solve it with undiscounted value iteration; `policy` is miner 0's response.
# t, r = getTR(si, total_states, selfish_policy, length=3)
t, r = getTR(si, total_states, honest_policy, length=3)
val_iter = mdpt.mdp.ValueIteration(t, r, discount=1)
val_iter.run()
policy = val_iter.policy



In [38]:
prettyPrintPolicy(policy, si)

000,000,0 selfish
100,100,0 selfish
010,000,1 selfish
001,000,2 selfish
101,100,2 selfish
011,000,2 selfish
002,000,2 honest
002,001,2 honest
102,101,2 honest
012,001,2 selfish
003,001,2 selfish
003,002,2 selfish
112,101,2 selfish
022,001,2 selfish
013,001,2 selfish
022,011,2 selfish
013,002,2 selfish
122,111,2 selfish
032,011,1 selfish
023,011,2 selfish
032,021,1 selfish
023,012,2 selfish
222,211,2 selfish
132,111,1 selfish
123,111,2 selfish
222,111,2 selfish
132,121,1 selfish
123,112,2 selfish
322,211,0 selfish
232,111,1 selfish
223,111,2 selfish
322,111,0 selfish
232,121,1 selfish
223,112,2 selfish
322,311,0 selfish
232,211,1 selfish
223,211,2 selfish
232,221,1 selfish
223,212,2 selfish
122,101,2 selfish
032,001,1 selfish
023,001,2 selfish
023,002,2 selfish
222,201,2 selfish
132,101,1 selfish
123,101,2 selfish
222,101,2 selfish
123,102,2 selfish
322,201,0 selfish
232,101,1 selfish
223,101,2 selfish
322,101,0 selfish
223,102,2 selfish
322,301,0 selfish
232,201,1 selfish
223,201,2 sel

In [175]:
# Second best-response step: the other miners now follow `policy`.  The MDP
# must be solved on the freshly built t2/r2 (the earlier version passed the
# stale t/r from the previous step, so policy2 ignored `policy` entirely).
t2, r2 = getTR(si, total_states, policy, length=3)
val_iter = mdpt.mdp.ValueIteration(t2, r2, discount=1)
val_iter.run()
policy2 = val_iter.policy



In [41]:
# Best-response iteration, seeded with the all-selfish policy: repeatedly let
# the other two miners follow construction_policy, solve miner 0's MDP, and
# stop once the two policies agree on every reachable non-terminal state.
WHALE_REWARD = 1.82
construction_policy = copy.deepcopy(selfish_policy)

cont = True
count = 0  # NOTE(review): never read or updated below
while cont:
    t, r = getTR(si, total_states, construction_policy, length=3)
    val_iter = mdpt.mdp.ValueIteration(t, r, discount=1)
    val_iter.run()
    solved_policy = np.array(val_iter.policy)
    
    # find reachable states
    # Depth-first walk from state 0, following solved_policy's action in each
    # state along every non-zero transition of t.
    reachable_states = []
    states_to_process = [0]
    while states_to_process:
        cur_state = states_to_process.pop()
        cur_opt_action = solved_policy[cur_state]
        for next_state in np.nonzero(t[cur_opt_action, cur_state])[0]:
            reachable_states.append(next_state)
            # A '3' among the first three label characters (the fork lengths)
            # marks a terminal state for length=3; those are not expanded.
            if next_state not in states_to_process and '3' not in si.s(next_state)[:3]:
                states_to_process.append(next_state)
    reachable_states = list(set(reachable_states))
    reachable_nonterminal_states = []
    for rs in reachable_states:
        if '3' not in si.s(rs)[:3]:
            reachable_nonterminal_states.append(rs)
    print(reachable_nonterminal_states)
    if np.all(construction_policy[reachable_nonterminal_states] == solved_policy[reachable_nonterminal_states]):
        print('converged! optimal policy is ...')
        for rs in reachable_nonterminal_states:
            if construction_policy[rs] == 0:
                print('  ', si.s(rs), ' selfish')
            else:
                print('  ', si.s(rs), ' honest')
        cont = False
    # Runs on the converged pass too, so construction_policy ends up equal to
    # solved_policy on every state, not only the reachable ones.
    construction_policy = copy.deepcopy(solved_policy)

[1, 2, 3, 5, 6, 7, 9, 13, 79, 80, 82, 83, 104, 107, 118, 126, 135, 136, 139, 159, 161, 166, 167, 169, 177, 190, 191, 197, 198, 205, 208, 209, 210, 226, 228, 236, 243, 245, 246, 247, 251]
[1, 2, 3, 6, 7, 208, 210, 245, 248]
converged! optimal policy is ...
   100,100,0  selfish
   010,000,1  honest
   001,000,2  honest
   002,000,2  honest
   002,001,2  honest
   020,000,1  honest
   020,010,1  honest
   200,200,0  selfish
   200,100,0  selfish


In [45]:
np.all(solved_policy == construction_policy)

True

In [88]:
def getReachableStates(policy, si, states=None, length=3):
    """Indices of non-terminal states reachable from state 0 when all three
    miners follow ``policy``.

    Parameters
    ----------
    policy : per-state action array (0 = selfish, otherwise honest).  It is
        used both for the other miners (via ``getTR``) and for miner 0's own
        action during the walk.
    si : index object with ``s(index) -> label``.
    states : enumerated state list; defaults to the notebook-level
        ``total_states`` so existing two-argument calls behave as before.
    length : terminal fork length (default 3, the value previously hard-coded).
        A state is treated as terminal when ``str(length)`` appears among the
        first three label characters, which is only meaningful for
        single-digit lengths.

    Returns
    -------
    list of reachable non-terminal state indices, in set-iteration order.

    Each state is queued at most once.  The previous worklist re-queued
    already expanded states and would never terminate on a non-terminal
    state with a self-loop.
    """
    if states is None:
        states = total_states
    t, _ = getTR(si, states, policy, length=length)
    terminal_mark = str(length)

    def _is_terminal(index):
        return terminal_mark in si.s(index)[:3]

    reachable_states = []
    queued = {0}
    states_to_process = [0]
    while states_to_process:
        cur_state = states_to_process.pop()
        cur_opt_action = int(policy[cur_state])
        for next_state in np.nonzero(t[cur_opt_action, cur_state])[0]:
            reachable_states.append(next_state)
            if next_state not in queued and not _is_terminal(next_state):
                queued.add(next_state)
                states_to_process.append(next_state)
    return [rs for rs in set(reachable_states) if not _is_terminal(rs)]

def prettyPrintPolicyReachableStates(policy, si):
    """Print label and action for each state returned by getReachableStates.

    Action 0 is shown as ``selfish``; any other value as ``honest``.
    """
    for idx in getReachableStates(policy, si):
        label = 'selfish' if policy[idx] == 0 else 'honest'
        print(si.s(idx), label)

In [117]:
prettyPrintPolicyReachableStates(construction_policy, si)

100,100,0 selfish
010,000,1 honest
001,000,2 honest
002,000,2 honest
002,001,2 honest
020,000,1 honest
020,010,1 honest
200,200,0 selfish
200,100,0 selfish


In [118]:
prettyPrintPolicyReachableStates(honest_policy, si)

100,100,0 selfish
010,000,1 honest
001,000,2 honest
002,000,2 honest
002,001,2 honest
020,000,1 honest
020,010,1 honest
200,200,0 selfish
200,100,0 selfish


In [96]:
si.i('200,100,0')

248

In [112]:
# Switch two states of honest_policy to the selfish action (0): the ones
# where fork 0 has length 2 and the other forks are empty.  They are looked
# up by label instead of by raw index (previously 245 and 248) so the edit
# does not silently hit the wrong states if the enumeration order changes.
honest_policy[si.i('200,200,0')] = 0
honest_policy[si.i('200,100,0')] = 0

In [140]:
# Re-solve miner 0's MDP with the other miners following construction_policy
# (the commented line is the same check against honest_policy).
# test_t, test_r = getTR(si, total_states, honest_policy, length=3)
test_t, test_r = getTR(si, total_states, construction_policy, length=3)
val_iter = mdpt.mdp.ValueIteration(test_t, test_r, discount=1)
val_iter.run()



In [141]:
prettyPrintPolicy(val_iter.policy, si)

000,000,0 selfish
100,100,0 selfish
010,000,1 honest
001,000,2 honest
101,100,2 selfish
011,000,2 selfish
002,000,2 honest
002,001,2 honest
102,101,2 selfish
012,001,2 selfish
003,001,2 selfish
003,002,2 selfish
112,101,2 selfish
022,001,2 selfish
013,001,2 selfish
022,011,2 selfish
013,002,2 selfish
122,111,2 selfish
032,011,1 selfish
023,011,2 selfish
032,021,1 selfish
023,012,2 selfish
222,211,2 selfish
132,111,1 selfish
123,111,2 selfish
222,111,2 selfish
132,121,1 selfish
123,112,2 selfish
322,211,0 selfish
232,111,1 selfish
223,111,2 selfish
322,111,0 selfish
232,121,1 selfish
223,112,2 selfish
322,311,0 selfish
232,211,1 selfish
223,211,2 selfish
232,221,1 selfish
223,212,2 selfish
122,101,2 honest
032,001,1 selfish
023,001,2 selfish
023,002,2 selfish
222,201,2 selfish
132,101,1 selfish
123,101,2 selfish
222,101,2 selfish
123,102,2 selfish
322,201,0 selfish
232,101,1 selfish
223,101,2 selfish
322,101,0 selfish
223,102,2 selfish
322,301,0 selfish
232,201,1 selfish
223,201,2 selfi

In [135]:
prettyPrintPolicyReachableStates(val_iter.policy, si)

100,100,0 selfish
010,000,1 selfish
001,000,2 selfish
101,100,2 selfish
011,000,2 selfish
002,000,2 honest
111,100,2 selfish
021,000,1 selfish
121,100,1 selfish
022,000,1 selfish
122,100,1 selfish
222,200,1 selfish
221,200,1 selfish
211,200,0 selfish
221,200,0 selfish
212,200,0 selfish
222,200,0 selfish
201,200,0 selfish
202,200,0 selfish
102,100,2 selfish
012,000,2 selfish
112,100,2 selfish
022,000,2 selfish
110,100,1 selfish
020,000,1 honest
011,000,1 selfish
101,100,0 selfish
111,100,1 selfish
120,100,1 selfish
220,200,1 selfish
122,100,2 selfish
222,200,2 selfish
210,200,0 selfish
212,200,2 selfish
220,200,0 selfish
200,200,0 selfish
110,100,0 selfish
111,100,0 selfish
202,200,2 selfish


In [50]:
np.where(val_iter.policy != construction_policy)

(array([  8,  39, 145, 211]),)

In [53]:
si.s(0)

'000,000,0'

In [297]:
# Same best-response solve with a larger bonus.  Rebinding WHALE_REWARD here
# changes every later getTR call as well, because State.getReward reads the
# global at call time.
WHALE_REWARD = 2.2
test_t, test_r = getTR(si, total_states, solved_policy, length=3)
val_iter = mdpt.mdp.ValueIteration(test_t, test_r, discount=1)
val_iter.run()
prettyPrintPolicy(val_iter.policy, si)

000,000,0 selfish
100,100,0 selfish
010,000,1 honest
001,000,2 honest
101,100,2 selfish
011,000,2 honest
002,000,2 honest
002,001,2 honest
102,101,2 selfish
012,001,2 honest
003,001,2 selfish
003,002,2 selfish
112,101,2 selfish
022,001,2 honest
013,001,2 selfish
022,011,2 selfish
013,002,2 selfish
122,111,2 selfish
032,011,1 selfish
023,011,2 selfish
032,021,1 selfish
023,012,2 selfish
222,211,2 selfish
132,111,1 selfish
123,111,2 selfish
222,111,2 selfish
132,121,1 selfish
123,112,2 selfish
322,211,0 selfish
232,111,1 selfish
223,111,2 selfish
322,111,0 selfish
232,121,1 selfish
223,112,2 selfish
322,311,0 selfish
232,211,1 selfish
223,211,2 selfish
232,221,1 selfish
223,212,2 selfish
122,101,2 honest
032,001,1 selfish
023,001,2 selfish
023,002,2 selfish
222,201,2 selfish
132,101,1 selfish
123,101,2 selfish
222,101,2 selfish
123,102,2 selfish
322,201,0 selfish
232,101,1 selfish
223,101,2 selfish
322,101,0 selfish
223,102,2 selfish
322,301,0 selfish
232,201,1 selfish
223,201,2 selfish


In [158]:
# Best-response iteration seeded with honest_policy: repeatedly let the other
# two miners follow construction_policy, solve miner 0's MDP, and stop once
# the two policies agree on every reachable non-terminal state.
# NOTE(review): WHALE_REWARD is not set in this cell; whatever value the
# kernel currently holds is used.
construction_policy = copy.deepcopy(honest_policy)

cont = True
count = 0  # NOTE(review): never read or updated below
while cont:
    t, r = getTR(si, total_states, construction_policy, length=3)
    val_iter = mdpt.mdp.ValueIteration(t, r, discount=1)
    val_iter.run()
    solved_policy = np.array(val_iter.policy)
    
    # find reachable states
    # Depth-first walk from state 0, following solved_policy's action in each
    # state along every non-zero transition of t.
    reachable_states = []
    states_to_process = [0]
    while states_to_process:
        cur_state = states_to_process.pop()
        cur_opt_action = solved_policy[cur_state]
        for next_state in np.nonzero(t[cur_opt_action, cur_state])[0]:
            reachable_states.append(next_state)
            # A '3' among the first three label characters (the fork lengths)
            # marks a terminal state for length=3; those are not expanded.
            if next_state not in states_to_process and '3' not in si.s(next_state)[:3]:
                states_to_process.append(next_state)
    reachable_states = list(set(reachable_states))
    reachable_nonterminal_states = []
    for rs in reachable_states:
        if '3' not in si.s(rs)[:3]:
            reachable_nonterminal_states.append(rs)
    print(reachable_nonterminal_states)
    if np.all(construction_policy[reachable_nonterminal_states] == solved_policy[reachable_nonterminal_states]):
        print('converged! optimal policy is ...')
        for rs in reachable_nonterminal_states:
            if construction_policy[rs] == 0:
                print('  ', si.s(rs), ' selfish')
            else:
                print('  ', si.s(rs), ' honest')
        cont = False
    # Runs on the converged pass too, so construction_policy ends up equal to
    # solved_policy on every state, not only the reachable ones.
    construction_policy = copy.deepcopy(solved_policy)

[1, 2, 3, 4, 6, 198, 79, 207, 208, 226, 228, 236, 245, 248, 126]
[1, 2, 3, 4, 5, 6, 9, 13, 79, 80, 82, 83, 104, 107, 118, 126, 134, 135, 136, 139, 159, 161, 166, 167, 169, 177, 190, 191, 197, 198, 205, 207, 208, 209, 225, 226, 228, 236, 243, 245, 246, 247, 251]
[1, 2, 3, 4, 5, 6, 135, 9, 13, 136, 139, 161, 198, 199, 201, 205, 79, 80, 207, 83, 208, 209, 226, 228, 236, 237, 239, 243, 245, 246, 247, 126]
converged! optimal policy is ...
   100,100,0  selfish
   010,000,1  selfish
   001,000,2  selfish
   101,100,2  selfish
   011,000,2  honest
   002,000,2  honest
   021,000,1  honest
   012,001,2  honest
   022,001,2  honest
   021,010,1  honest
   022,010,1  honest
   022,000,1  honest
   201,200,0  selfish
   201,100,0  selfish
   202,100,0  selfish
   202,200,0  selfish
   102,100,2  selfish
   012,000,2  honest
   110,100,1  selfish
   022,000,2  honest
   020,000,1  honest
   011,000,1  honest
   120,100,1  selfish
   220,200,1  selfish
   210,200,0  selfish
   210,100,0  selfish
  