In [1]:
import copy 
import matplotlib.pyplot as plt
import mdptoolbox as mdpt
import numpy as np

In [5]:
class FirstPersonState:
    """Game state as seen by one miner ("me"), stored as a length-7 vector.

    Layout (as built by ``nextFirstPersonStates``):
      [0:3] lengths of fork 0 (my own fork), fork 1 and fork 2
      [3:6] how many blocks I mined on each of those forks
      [6]   index of the fork an honest miner would extend
    """

    def __init__(self, state, length=2):
        self.state = np.array(state)
        # The game is over once any fork reaches this many blocks.
        self.terminal_length = length

    def __str__(self):
        # e.g. "120,100,1": fork lengths, my blocks per fork, honest fork
        fork_part = ''.join(str(v) for v in self.state[:3])
        mine_part = ''.join(str(v) for v in self.state[3:6])
        return fork_part + ',' + mine_part + ',' + str(self.state[6])

    def __setitem__(self, ind, val):
        self.state[ind] = val

    def __getitem__(self, ind):
        return self.state[ind]

    def isTerminal(self):
        """True once some fork length equals ``terminal_length``."""
        return bool((self.state[:3] == self.terminal_length).any())

    def getHonestFork(self):
        """Index (0-2) of the fork an honest miner currently extends."""
        return self.state[-1]

    def updateHonestFork(self):
        """Switch the honest fork to the longest fork once the current one is outgrown.

        Ties keep the current honest fork; when switching, ties among the
        longest forks go to the lowest index.
        """
        forks = self.state[:3]
        current_len = self.state[self.getHonestFork()]
        if (forks > current_len).any():
            self.state[-1] = forks.argmax()

    def __eq__(self, other):
        # Element-wise comparison of the whole vector; terminal_length is ignored.
        same = self.state == other.state
        return all(same)

In [12]:
def nextFirstPersonStates(state):
    """Return every state reachable from ``state`` by one more block.

    Miner ``k`` owns fork ``k`` and miner 0 is "me".  Successors are listed
    first for each miner extending its own fork (forks 0, 1, 2), then for
    each miner extending some other, non-empty fork.  Identical successors
    are not merged.  ``state`` itself is left untouched.
    """
    successors = []

    def _extend(fork, mine):
        # Copy, add one block to `fork`, credit it to me if I mined it.
        child = copy.deepcopy(state)
        child[fork] += 1
        if mine:
            child[3 + fork] += 1
        child.updateHonestFork()
        successors.append(child)

    # Each miner extends its own fork.
    for fork in range(3):
        _extend(fork, mine=(fork == 0))

    # Each miner extends a non-empty fork that is not its own.
    for fork in range(3):
        if state[fork] == 0:
            continue
        for miner in range(3):
            if miner != fork:
                _extend(fork, mine=(miner == 0))

    return successors

In [76]:
def checkIfIn(state, states):
    """Return True if any element of ``states`` compares equal to ``state``.

    Equality goes through ``==`` (the classes' ``__eq__``) only; identity is
    not checked first, unlike the ``in`` operator.
    """
    return any(state == candidate for candidate in states)

In [118]:
def printTransitions(transitions, rewards, si):
    """Print every non-zero transition, grouped by source state.

    For each state ``i`` its string is printed, followed by one indented line
    per reachable ``j``: first all selfish (action 0) rows, then all honest
    (action 1) rows, each with probability and reward to two decimals.

    Args:
        transitions: array of shape (2, S, S), e.g. ``t`` from ``getTR``.
        rewards: array of the same shape, e.g. ``r`` from ``getTR``.
        si: index lookup providing ``s(index) -> str``.
    """
    # Size from the argument itself rather than the notebook-global
    # `total_states`, so the function works on any (2, S, S) input.
    num_states = transitions.shape[1]
    for i in range(num_states):
        print(si.s(i))
        for action, label in ((0, "selfish"), (1, "honest")):
            for j in range(num_states):
                if transitions[action, i, j] != 0:
                    print("    ", si.s(j), " {}, {:0.2f}, {:0.2f}".format(
                        label, transitions[action, i, j], rewards[action, i, j]))

def prettyPrintPolicy(policy, si):
    """Print one line per state: its string form, then the chosen action.

    Action 0 is shown as 'selfish'; any other value as 'honest'.
    """
    for idx, action in enumerate(policy):
        label = 'selfish' if action == 0 else 'honest'
        print(si.s(idx), label)

In [14]:
def getTotalStates(length=2):
    """Enumerate every first-person state reachable from the empty state.

    Depth-first expansion with ``nextFirstPersonStates``.  Terminal states
    (some fork of ``length`` blocks) are recorded but never expanded, and
    each state is expanded at most once.

    Args:
        length: terminal fork length (default 2).

    Returns:
        List of distinct ``FirstPersonState`` objects in discovery order;
        index 0 is always the empty state ``000,000,0``.
    """
    init_state = FirstPersonState([0, 0, 0, 0, 0, 0, 0], length)
    total_states = [init_state]
    # Seed with the empty state instead of hard-coding its three one-block
    # successors, so those successors are only expanded when they are not
    # already terminal (e.g. length=1).  The discovery order is unchanged.
    states_to_process = [] if init_state.isTerminal() else [init_state]

    while states_to_process:
        elem = states_to_process.pop()
        for s in nextFirstPersonStates(elem):
            if checkIfIn(s, total_states):
                # Already discovered: it is either still queued or has been
                # expanded, so re-queuing it would only redo work.
                continue
            total_states.append(s)
            if not s.isTerminal():
                states_to_process.append(s)
    return total_states

In [18]:
# Enumerate all first-person states reachable with the default terminal length (2).
# total_states[0] is the empty state "000,000,0".
total_states = getTotalStates()
len(total_states)

37

In [20]:
class Ind:
    """Two-way lookup between state indices and the states' string forms.

    Index ``k`` maps to ``str(total_states[k])``.  If two states share a
    string, the string maps to the later index.
    """

    def __init__(self, total_states):
        labels = [str(st) for st in total_states]
        self.stringToInd = {label: k for k, label in enumerate(labels)}
        self.indToString = dict(enumerate(labels))

    def i(self, string):
        """Return the index whose state prints as ``string`` (KeyError if none)."""
        return self.stringToInd[string]

    def s(self, ind):
        """Return the string form of the state at index ``ind``."""
        return self.indToString[ind]

In [263]:
class State:
    """Full three-miner game state, used by ``getTR`` to build miner 0's MDP.

    ``state[m, f]`` counts the blocks miner ``m`` has added to fork ``f``
    (rows are miners, columns are forks).  Miner 0 is the agent being
    optimized, and fork ``k`` is miner ``k``'s own fork.  The other two miners
    act according to ``policy`` (0 = selfish, 1 = honest), looked up through
    ``state_index.i`` with their own first-person string.  ``getReward``
    reads the notebook-level ``WHALE_REWARD`` unless a bonus is passed in.
    """

    def __init__(self, state_index, policy, length=2):
        self.state = np.zeros((3, 3), dtype=int)  # [miner, fork] block counts
        self.honest_fork = 0             # fork an honest miner extends (global index)
        self.state_index = state_index   # string -> index lookup (must provide .i())
        self.policy = policy             # other miners' action per state index
        self.length = length             # a fork of this length ends the game
        self.visited = False             # not read by the code shown here; kept for compatibility

    def getState(self, miner_ind):
        """Return miner ``miner_ind``'s 7-entry first-person view.

        Layout: [3 fork lengths, this miner's blocks on each fork, honest
        fork], with forks 0 and ``miner_ind`` swapped so that the miner's own
        fork comes first.
        """
        if np.sum(self.state) == 0:
            # The empty state has one canonical view ("000,000,0") for every
            # miner, regardless of what getHonestFork would report.
            return np.array([0, 0, 0, 0, 0, 0, 0])

        # Permutation swapping fork 0 with the miner's own fork.
        order = [0, 1, 2]
        order[0], order[miner_ind] = miner_ind, 0
        fork_lens = np.sum(self.state, axis=0)[order]
        miner_lens = self.state[miner_ind][order]
        return np.concatenate((fork_lens, miner_lens, [self.getHonestFork(miner_ind)]))

    def getHonestFork(self, miner_ind):
        """Return the honest fork's index in miner ``miner_ind``'s view (forks 0 and miner_ind swapped)."""
        if self.honest_fork == miner_ind:
            return 0
        elif self.honest_fork == 0:
            return miner_ind
        return self.honest_fork

    def updateHonestFork(self):
        """Move the honest fork to the longest fork once some fork is strictly longer."""
        fork_lens = np.sum(self.state, axis=0)
        # Miner 0's view is unpermuted, so getHonestFork(0) == self.honest_fork.
        if any(fork_lens > fork_lens[self.getHonestFork(0)]):
            self.honest_fork = np.argmax(fork_lens)  # ties -> lowest fork index

    def isTerminal(self):
        """True when the longest fork has exactly ``length`` blocks."""
        fork_lens = np.sum(self.state, axis=0)
        return np.max(fork_lens) == self.length

    def __eq__(self, other):
        return bool(np.all(self.state == other.state)
                    and self.honest_fork == other.honest_fork)

    def getStrRep(self, miner_ind):
        """String key (e.g. "110,100,0") of miner ``miner_ind``'s view."""
        state = self.getState(miner_ind)
        return self.getStrRepState(state)

    def getStrRepState(self, state):
        """Format a 7-entry view as "<forks>,<my blocks>,<honest fork>"."""
        res = ''
        for i in range(3):
            res += str(state[i])
        res += ','
        for i in range(3, 6):
            res += str(state[i])
        res += ','
        res += str(state[6])
        return res

    def getOtherAgentActions(self):
        """Return the (miner, fork) move of miners 1 and 2 under ``policy``."""
        actions = []
        for i in [1, 2]:
            other_agent_state = self.getStrRep(i)
            action = self.policy[self.state_index.i(other_agent_state)]
            if action == 0:  # selfish: extend its own fork
                actions.append((i, i))
            else:  # honest: extend the current honest fork
                actions.append((i, self.honest_fork))
        return actions

    def getReward(self, whale_reward=None):
        """Reward for miner 0 on entering this state.

        Every transition costs 1/3 (presumably a per-step mining cost -- TODO
        confirm).  In a terminal state miner 0 additionally earns its blocks
        on the winning fork, plus a bonus if the winning fork is its own.

        Args:
            whale_reward: bonus for winning with its own fork; ``None`` (the
                default) reads the notebook-level ``WHALE_REWARD`` constant.
        """
        reward = -1/3
        if not self.isTerminal():
            return reward
        state = self.getState(0)
        # Compare against the configured terminal length (was a hard-coded 2,
        # which rewarded the wrong winner for any other length).
        if state[0] == self.length:
            if whale_reward is None:
                whale_reward = WHALE_REWARD
            return reward + (whale_reward + state[3])
        win_ind = np.argmax(state[:3])
        return reward + state[3 + win_ind]

    def _successors(self, my_move):
        """Successor per miner: others move per ``policy``, miner 0 plays ``my_move``."""
        moves = self.getOtherAgentActions()
        moves.append(my_move)
        next_states = []
        for miner_fork in moves:
            # Shallow copy: state_index and policy are read-only and shared;
            # only the count matrix needs its own copy.
            nxt = copy.copy(self)
            nxt.state = self.state.copy()
            nxt.state[miner_fork] += 1
            nxt.updateHonestFork()
            next_states.append(nxt)
        return next_states

    def getNextStatesSelfish(self):
        """Successors when miner 0 extends its own fork (fork 0)."""
        return self._successors((0, 0))

    def getNextStatesHonest(self):
        """Successors when miner 0 extends the current honest fork."""
        return self._successors((0, self.getHonestFork(0)))

In [339]:
def getTR(si, total_states, policy, length=2):
    """Build the transition and reward tensors of miner 0's MDP.

    The other two miners follow ``policy``.  States are explored from the
    empty ``State`` and mapped to indices through ``si`` using miner 0's
    first-person string.  Rewards come from ``State.getReward`` of the
    entered state.

    Args:
        si: ``Ind`` lookup covering every reachable first-person string.
        total_states: enumerated first-person states; only its length is used.
        policy: the other miners' action per state index (0 = selfish,
            1 = honest).
        length: terminal fork length passed to ``State`` (default 2, the
            value previously hard-coded).

    Returns:
        ``(t, r)``, both of shape ``(2, S, S)`` with ``S = len(total_states)``;
        axis 0 is miner 0's action (0 = selfish, 1 = honest).

    Raises:
        ValueError: if a row of ``t`` does not sum to 1, e.g. because two
            distinct full states share one first-person string and were both
            expanded.
    """
    num_states = len(total_states)
    t = np.zeros((2, num_states, num_states))
    r = np.zeros((2, num_states, num_states))

    init_state = State(si, policy, length)
    states_to_process = [init_state]
    processed_states = [init_state]
    while states_to_process:
        elem = states_to_process.pop()
        current = si.i(elem.getStrRep(0))
        # Selfish (0) is expanded before honest (1), as before, so the push
        # order and any overwritten reward entries are unchanged.
        for action, expand in ((0, elem.getNextStatesSelfish),
                               (1, elem.getNextStatesHonest)):
            for s in expand():
                nxt = si.i(s.getStrRep(0))
                # Three successors per action (one per mining miner), 1/3 each.
                t[action, current, nxt] += 1/3
                r[action, current, nxt] = s.getReward()
                if not s.isTerminal() and not checkIfIn(s, processed_states):
                    states_to_process.append(s)
                    # `s` is never mutated after creation, so no copy is needed.
                    processed_states.append(s)

    # Terminal states are never expanded and unreachable states get no
    # outgoing mass: make both absorbing with zero-reward self-loops.
    row_mass = t.sum(axis=2)
    for i in range(num_states):
        if row_mass[0, i] == 0:
            assert row_mass[1, i] == 0
            t[:, i, i] = 1

    if not np.allclose(t.sum(axis=2), 1):
        raise ValueError("transition rows of t must each sum to 1")
    return t, r

In [340]:
# String <-> index lookup for the enumerated states (index k is str(total_states[k])).
si = Ind(total_states)

In [384]:
# Fixed strategies for the *other* two miners, one entry per state index
# (0 = selfish: extend their own fork, 1 = honest: extend the honest fork).
selfish_policy = np.zeros(len(total_states)) 
honest_policy = np.ones(len(total_states))
# Index 0 is the empty state "000,000,0". There `honest_fork` is still 0, so an
# "honest" miner would extend fork 0 (miner 0's fork); forcing "selfish" makes the
# other miners start their own forks instead.
honest_policy[0] = 0

In [392]:
# Bonus added by State.getReward when the game ends with miner 0's own fork
# reaching length 2 (the terminal length used here). getReward looks this global
# up at call time, so it must be defined before getTR runs.
WHALE_REWARD = 2.1

In [395]:
# Miner 0's best response when the other two miners follow `honest_policy`.
# Swap in the commented line to start from always-selfish opponents instead:
# t, r = getTR(si, total_states, selfish_policy)
t, r = getTR(si, total_states, honest_policy)
# discount=1: undiscounted total reward; episodes end in the absorbing
# self-loop states that getTR adds.
# NOTE(review): this cell ran as In[395], after the policy2..policy7 cells below
# (In[374]-In[389]), so their saved outputs were not computed from this `policy`.
# Restart and run all to refresh them.
val_iter = mdpt.mdp.ValueIteration(t, r, discount=1)
val_iter.run()
policy = val_iter.policy



In [396]:
prettyPrintPolicy(policy, si)

000,000,0 selfish
100,100,0 selfish
010,000,1 selfish
001,000,2 selfish
101,100,2 selfish
011,000,2 selfish
002,000,2 selfish
002,001,2 selfish
111,100,2 selfish
021,000,1 selfish
012,000,2 selfish
021,010,1 selfish
012,001,2 selfish
211,200,0 selfish
121,100,1 selfish
112,100,2 selfish
211,100,0 selfish
121,110,1 selfish
112,101,2 selfish
201,200,0 selfish
102,100,2 selfish
201,100,0 selfish
102,101,2 selfish
110,100,1 selfish
020,000,1 selfish
011,000,1 selfish
020,010,1 selfish
111,100,1 selfish
210,200,0 selfish
120,100,1 selfish
210,100,0 selfish
120,110,1 selfish
200,200,0 selfish
110,100,0 selfish
101,100,0 selfish
200,100,0 selfish
111,100,0 selfish


In [388]:
# Best-response iteration, step 2: the other miners now play `policy`.
# NOTE(review): executed as In[388], i.e. before `policy` was last recomputed
# (In[395]) and before WHALE_REWARD = 2.1 was (re)run (In[392]); the saved output
# below may be stale.
# TODO: replace these copy-pasted cells with a loop that stops once the policy
# no longer changes.
t, r = getTR(si, total_states, policy)
val_iter = mdpt.mdp.ValueIteration(t, r, discount=1)
val_iter.run()
policy2 = val_iter.policy



In [389]:
prettyPrintPolicy(policy2, si)

000,000,0 selfish
100,100,0 selfish
010,000,1 selfish
001,000,2 honest
101,100,2 selfish
011,000,2 selfish
002,000,2 selfish
002,001,2 selfish
111,100,2 selfish
021,000,1 selfish
012,000,2 selfish
021,010,1 selfish
012,001,2 selfish
211,200,0 selfish
121,100,1 selfish
112,100,2 selfish
211,100,0 selfish
121,110,1 selfish
112,101,2 selfish
201,200,0 selfish
102,100,2 selfish
201,100,0 selfish
102,101,2 selfish
110,100,1 selfish
020,000,1 selfish
011,000,1 honest
020,010,1 selfish
111,100,1 selfish
210,200,0 selfish
120,100,1 selfish
210,100,0 selfish
120,110,1 selfish
200,200,0 selfish
110,100,0 selfish
101,100,0 selfish
200,100,0 selfish
111,100,0 selfish


In [374]:
# Best-response iteration, step 3: the other miners now play `policy2`.
# NOTE(review): executed as In[374], before In[392]/In[395]; saved output may be stale.
t, r = getTR(si, total_states, policy2)
val_iter = mdpt.mdp.ValueIteration(t, r, discount=1)
val_iter.run()
policy3 = val_iter.policy



In [375]:
prettyPrintPolicy(policy3, si)

000,000,0 selfish
100,100,0 selfish
010,000,1 honest
001,000,2 selfish
101,100,2 selfish
011,000,2 selfish
002,000,2 selfish
002,001,2 selfish
111,100,2 selfish
021,000,1 selfish
012,000,2 selfish
021,010,1 selfish
012,001,2 selfish
211,200,0 selfish
121,100,1 selfish
112,100,2 selfish
211,100,0 selfish
121,110,1 selfish
112,101,2 selfish
201,200,0 selfish
102,100,2 selfish
201,100,0 selfish
102,101,2 selfish
110,100,1 selfish
020,000,1 selfish
011,000,1 honest
020,010,1 selfish
111,100,1 selfish
210,200,0 selfish
120,100,1 selfish
210,100,0 selfish
120,110,1 selfish
200,200,0 selfish
110,100,0 selfish
101,100,0 selfish
200,100,0 selfish
111,100,0 selfish


In [376]:
# Best-response iteration, step 4: the other miners now play `policy3`.
# NOTE(review): executed as In[376], before In[392]/In[395]; saved output may be stale.
t, r = getTR(si, total_states, policy3)
val_iter = mdpt.mdp.ValueIteration(t, r, discount=1)
val_iter.run()
policy4 = val_iter.policy



In [377]:
prettyPrintPolicy(policy4, si)

000,000,0 selfish
100,100,0 selfish
010,000,1 honest
001,000,2 selfish
101,100,2 selfish
011,000,2 honest
002,000,2 selfish
002,001,2 selfish
111,100,2 selfish
021,000,1 selfish
012,000,2 selfish
021,010,1 selfish
012,001,2 selfish
211,200,0 selfish
121,100,1 selfish
112,100,2 selfish
211,100,0 selfish
121,110,1 selfish
112,101,2 selfish
201,200,0 selfish
102,100,2 selfish
201,100,0 selfish
102,101,2 selfish
110,100,1 selfish
020,000,1 selfish
011,000,1 selfish
020,010,1 selfish
111,100,1 selfish
210,200,0 selfish
120,100,1 selfish
210,100,0 selfish
120,110,1 selfish
200,200,0 selfish
110,100,0 selfish
101,100,0 selfish
200,100,0 selfish
111,100,0 selfish


In [378]:
# Best-response iteration, step 5: the other miners now play `policy4`.
# NOTE(review): executed as In[378], before In[392]/In[395]; saved output may be stale.
t, r = getTR(si, total_states, policy4)
val_iter = mdpt.mdp.ValueIteration(t, r, discount=1)
val_iter.run()
policy5 = val_iter.policy



In [379]:
prettyPrintPolicy(policy5, si)

000,000,0 selfish
100,100,0 selfish
010,000,1 selfish
001,000,2 honest
101,100,2 selfish
011,000,2 honest
002,000,2 selfish
002,001,2 selfish
111,100,2 selfish
021,000,1 selfish
012,000,2 selfish
021,010,1 selfish
012,001,2 selfish
211,200,0 selfish
121,100,1 selfish
112,100,2 selfish
211,100,0 selfish
121,110,1 selfish
112,101,2 selfish
201,200,0 selfish
102,100,2 selfish
201,100,0 selfish
102,101,2 selfish
110,100,1 selfish
020,000,1 selfish
011,000,1 selfish
020,010,1 selfish
111,100,1 selfish
210,200,0 selfish
120,100,1 selfish
210,100,0 selfish
120,110,1 selfish
200,200,0 selfish
110,100,0 selfish
101,100,0 selfish
200,100,0 selfish
111,100,0 selfish


In [380]:
# Best-response iteration, step 6: the other miners now play `policy5`.
# NOTE(review): executed as In[380], before In[392]/In[395]; saved output may be stale.
t, r = getTR(si, total_states, policy5)
val_iter = mdpt.mdp.ValueIteration(t, r, discount=1)
val_iter.run()
policy6 = val_iter.policy



In [381]:
prettyPrintPolicy(policy6, si)

000,000,0 selfish
100,100,0 selfish
010,000,1 selfish
001,000,2 honest
101,100,2 selfish
011,000,2 selfish
002,000,2 selfish
002,001,2 selfish
111,100,2 selfish
021,000,1 selfish
012,000,2 selfish
021,010,1 selfish
012,001,2 selfish
211,200,0 selfish
121,100,1 selfish
112,100,2 selfish
211,100,0 selfish
121,110,1 selfish
112,101,2 selfish
201,200,0 selfish
102,100,2 selfish
201,100,0 selfish
102,101,2 selfish
110,100,1 selfish
020,000,1 selfish
011,000,1 honest
020,010,1 selfish
111,100,1 selfish
210,200,0 selfish
120,100,1 selfish
210,100,0 selfish
120,110,1 selfish
200,200,0 selfish
110,100,0 selfish
101,100,0 selfish
200,100,0 selfish
111,100,0 selfish


In [382]:
# Best-response iteration, step 7: the other miners now play `policy6`.
# NOTE(review): executed as In[382], before In[392]/In[395]; saved output may be stale.
# In the saved outputs policy7 still differs from policy6 (e.g. "010,000,1" and
# "001,000,2" flip between selfish and honest), so this iteration has not reached
# a fixed point.
t, r = getTR(si, total_states, policy6)
val_iter = mdpt.mdp.ValueIteration(t, r, discount=1)
val_iter.run()
policy7 = val_iter.policy



In [383]:
prettyPrintPolicy(policy7, si)

000,000,0 selfish
100,100,0 selfish
010,000,1 honest
001,000,2 selfish
101,100,2 selfish
011,000,2 selfish
002,000,2 selfish
002,001,2 selfish
111,100,2 selfish
021,000,1 selfish
012,000,2 selfish
021,010,1 selfish
012,001,2 selfish
211,200,0 selfish
121,100,1 selfish
112,100,2 selfish
211,100,0 selfish
121,110,1 selfish
112,101,2 selfish
201,200,0 selfish
102,100,2 selfish
201,100,0 selfish
102,101,2 selfish
110,100,1 selfish
020,000,1 selfish
011,000,1 honest
020,010,1 selfish
111,100,1 selfish
210,200,0 selfish
120,100,1 selfish
210,100,0 selfish
120,110,1 selfish
200,200,0 selfish
110,100,0 selfish
101,100,0 selfish
200,100,0 selfish
111,100,0 selfish
