In [16]:
import copy 
import matplotlib.pyplot as plt
import mdptoolbox as mdpt
import numpy as np

In [219]:
class State:
    """Game state as a flat vector of seven integers.

    Layout, as used by the rest of this notebook:
      * entries 0-2: length of each of the three forks,
      * entries 3-5: per-fork block counts credited to the observing miner,
      * entry 6:     index of the fork currently extended by honest mining.

    `length` is the fork length at which the game is considered over.
    """

    def __init__(self, state, length=2):
        # np.array copies its input, so the caller's list is never aliased.
        self.state = np.array(state)
        self.terminal_length = length

    def __str__(self):
        # Rendered as 'abc,def,g', e.g. '021,000,2'; this string is also the
        # key used to look states up in StateIndex.
        digits = [str(self.state[i]) for i in range(7)]
        return ''.join(digits[:3]) + ',' + ''.join(digits[3:6]) + ',' + digits[6]

    def __setitem__(self, ind, val):
        self.state[ind] = val

    def __getitem__(self, ind):
        return self.state[ind]

    def isTerminal(self):
        """True when some fork length equals `terminal_length`."""
        return self.terminal_length in self.state[:3]

    def getHonestFork(self):
        """Index of the honest fork (last entry of the vector)."""
        return self.state[-1]

    def __eq__(self, other):
        # Comparing against a foreign object used to raise AttributeError and a
        # shape mismatch used to raise TypeError; both now simply compare unequal.
        if not isinstance(other, State):
            return NotImplemented
        # terminal_length is deliberately not part of equality (unchanged).
        return np.array_equal(self.state, other.state)

In [220]:
# Sanity check of State: fork 1 has length 2 (the default terminal length),
# so this state should render as '021,000,2' and report itself as terminal.
init_state = State([0,2,1,0,0,0,2])
str(init_state), init_state.isTerminal()

('021,000,2', True)

In [221]:
def _updateHonestFork(state):
    """Move the honest-fork marker to the longest fork if it was overtaken."""
    fork_lens = state[:3]
    if any(fork_lens > state[state.getHonestFork()]):
        # np.argmax returns the first maximum, so ties go to the lowest index.
        state[-1] = np.argmax(fork_lens)


# returns possible next states, given an action
def nextStates(cur_state, action):
    """Return the three successor states of `cur_state` under `action`.

    One successor is produced per possible block winner (miner 0, 1, 2), in
    that order. Miners 1 and 2 always extend their own fork. Miner 0 extends
    fork 0 when `action == 'selfish'` and the current honest fork when
    `action == 'honest'`; in both cases the matching entry in slots 3-5 is
    incremented as well. `cur_state` itself is never modified.

    Raises:
        ValueError: if `action` is neither 'selfish' nor 'honest' (previously
            any other value was silently treated as 'honest').
    """
    if action not in ('selfish', 'honest'):
        raise ValueError("action must be 'selfish' or 'honest', got {!r}".format(action))

    # Fork that miner 0 extends if it wins the next block.
    our_fork = 0 if action == 'selfish' else cur_state[-1]

    new_states = []
    for winner in range(3):
        temp = copy.deepcopy(cur_state)
        if winner == 0:
            temp[our_fork] += 1
            temp[3 + our_fork] += 1
        else:
            temp[winner] += 1

        # new honest fork.
        _updateHonestFork(temp)
        new_states.append(temp)

    return new_states

In [232]:
# Spot-check nextStates: expand a state where forks 1 and 2 both have one
# block and fork 2 is the honest fork, with miner 0 acting selfishly.
init_state = State([0,1,1,0,0,0,2])
res = nextStates(init_state, 'selfish')

In [233]:
# Show the three successors (winner = miner 0, 1, 2 respectively).
str(res[0]), str(res[1]), str(res[2]) 

('111,100,2', '021,000,1', '012,000,2')

In [234]:
def checkIfInTotalStates(total_states, state):
    """Return True if any element of `total_states` compares equal to `state`.

    Membership is decided with `==`, scanning the collection in order.
    """
    return any(state == known for known in total_states)

In [292]:
# algorithm to enumerate possible states.
# Depth-first exploration from the empty state: every state reachable through
# nextStates (under either action) is recorded once in total_states, in
# discovery order. Terminal states are recorded but never expanded.
init_state = State([0,0,0,0,0,0,0], length=2)

total_states = [init_state]
states_to_process = [init_state]
while states_to_process:
    elem = states_to_process.pop()
    next_states = nextStates(elem, 'selfish')
    next_states.extend(nextStates(elem, 'honest'))
    for s in next_states:
        if checkIfInTotalStates(total_states, s):
            # Already discovered: it is either still pending on the stack or
            # has been expanded before. The previous version re-queued states
            # that had already been expanded, repeating work without ever
            # adding anything new to total_states.
            continue
        total_states.append(s)
        if not s.isTerminal():
            states_to_process.append(s)

In [293]:
# List every enumerated state in discovery order; a state's position in this
# list is the index later assigned to it by StateIndex.
for s in total_states:
    print(s)

000,000,0
100,100,0
010,000,1
001,000,2
101,100,2
011,000,2
002,000,2
002,001,2
111,100,2
021,000,1
012,000,2
012,001,2
211,200,0
121,100,1
112,100,2
112,101,2
201,200,0
102,100,2
102,101,2
110,100,1
020,000,1
011,000,1
020,010,1
111,100,1
021,010,1
121,110,1
210,200,0
120,100,1
120,110,1
200,200,0
110,100,0
101,100,0
111,100,0


In [295]:
class StateIndex:
    """Two-way lookup between a state's string form and its list position."""

    def __init__(self, total_states):
        # Position in `total_states` is the index; str(state) is the key.
        labels = [str(state) for state in total_states]
        self.stringToInd = {label: pos for pos, label in enumerate(labels)}
        self.indToString = dict(enumerate(labels))

    def getIndex(self, string):
        """Index for a state string; raises KeyError if it was never indexed."""
        return self.stringToInd[string]

    def getString(self, ind):
        """State string stored at position `ind`; raises KeyError if unknown."""
        return self.indToString[ind]

In [296]:
def getTerminalReward(state, whale_reward=1.5):
    """Reward collected when the game ends in `state`.

    The winning fork is the first longest one among entries 0-2. The payout is
    the entry three slots after it (the miner's count on that fork), plus
    `whale_reward` when the winning fork is fork 0. `state` must be terminal.
    """
    assert state.isTerminal()
    winner = np.argmax(state[:3])
    if winner == 0:
        return state[3 + winner] + whale_reward
    return state[3 + winner]

In [309]:
def prettyPrintPolicy(policy, state_ind):
    """Print one line per state: its string label, then the chosen action.

    Action 0 is shown as 'selfish'; any other value is shown as 'honest'.
    """
    for ind, action in enumerate(policy):
        label = 'selfish' if action == 0 else 'honest'
        print(state_ind.getString(ind), label)

In [297]:
# Index the enumerated states; position in total_states becomes the MDP state id.
state_ind = StateIndex(total_states)

In [298]:
# Dense (action, state, next_state) arrays; action 0 = selfish, 1 = honest.
# Both start at zero and are filled in place by the construction cell below.
transitions = np.zeros((2, len(total_states), len(total_states)))
rewards = np.zeros((2, len(total_states), len(total_states)))

In [310]:
# algorithm to construct transition and reward matrices.
# Every reachable non-terminal state is expanded once. Each of its three
# successors (one per possible block winner) gets probability 1/3 and a reward
# of -1/3, plus the terminal payout when the successor ends the game.
# Terminal states are made absorbing under both actions; their self-loop
# reward stays at the 0 the rewards array was initialised with.
init_state = State([0,0,0,0,0,0,0], length=2)

states_to_process = [init_state]
# String forms of every state queued so far. The earlier version only checked
# the pending stack, so states that had already been expanded were queued and
# expanded again (writing the same values each time).
queued = {str(init_state)}
while states_to_process:
    elem = states_to_process.pop()
    start_ind = state_ind.getIndex(str(elem))

    # Plane 0 of the arrays is the 'selfish' action, plane 1 is 'honest'.
    for action_ind, action in enumerate(('selfish', 'honest')):
        for s in nextStates(elem, action):
            next_ind = state_ind.getIndex(str(s))
            transitions[action_ind, start_ind, next_ind] = 1/3
            rewards[action_ind, start_ind, next_ind] = -1/3

            if s.isTerminal():
                rewards[action_ind, start_ind, next_ind] += getTerminalReward(s, whale_reward=2)
                # Absorbing state: stay put regardless of the action taken.
                transitions[0, next_ind, next_ind] = 1
                transitions[1, next_ind, next_ind] = 1
            elif str(s) not in queued:
                queued.add(str(s))
                states_to_process.append(s)

In [311]:
# Row sums per (action, state); each should be 1 for a valid transition matrix.
np.sum(transitions, axis=2)

array([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1.],
       [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1.]])

In [312]:
# Solve the MDP with value iteration and no discounting (discount=1).
# NOTE(review): without discounting, convergence presumably relies on every
# trajectory ending in an absorbing terminal state whose self-loop reward is 0
# (as built above) — confirm against the mdptoolbox documentation.
val_iter = mdpt.mdp.ValueIteration(transitions, rewards, discount=1)
val_iter.run()



In [313]:
# Value of each state, ordered by state index (same order as total_states).
val_iter.V

(0.3703703703703701,
 1.8888888888888884,
 0.111111111111111,
 0.111111111111111,
 1.333333333333333,
 2.7755575615628914e-17,
 0.0,
 0.0,
 0.9999999999999998,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 1.333333333333333,
 0.0,
 2.7755575615628914e-17,
 0.0,
 0.9999999999999998,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 1.333333333333333,
 1.333333333333333,
 0.9999999999999998)

In [314]:
# Print the action chosen by value iteration for every state.
prettyPrintPolicy(val_iter.policy, state_ind)

000,000,0 selfish
100,100,0 selfish
010,000,1 selfish
001,000,2 selfish
101,100,2 selfish
011,000,2 honest
002,000,2 selfish
002,001,2 selfish
111,100,2 selfish
021,000,1 selfish
012,000,2 selfish
012,001,2 selfish
211,200,0 selfish
121,100,1 selfish
112,100,2 selfish
112,101,2 selfish
201,200,0 selfish
102,100,2 selfish
102,101,2 selfish
110,100,1 selfish
020,000,1 selfish
011,000,1 honest
020,010,1 selfish
111,100,1 selfish
021,010,1 selfish
121,110,1 selfish
210,200,0 selfish
120,100,1 selfish
120,110,1 selfish
200,200,0 selfish
110,100,0 selfish
101,100,0 selfish
111,100,0 selfish


In [317]:
# Keep the computed policy; the cell that builds new_transitions reads it
# to decide how miners 1 and 2 act.
policy = val_iter.policy

In [427]:
class MinerBlocks:
    """Block counts for the three-miner game, tracked per miner and per fork.

    `miner_blocks[m, f]` is the number of blocks miner `m` has mined on fork
    `f`, so column sums are fork lengths. `honest_fork` is the (absolute) index
    of the fork honest mining extends. `terminal_length` is the fork length at
    which the game ends; it defaults to 2, the value used by State.
    """

    def __init__(self, terminal_length=2):
        self.miner_blocks = np.array([
            [0,0,0],
            [0,0,0],
            [0,0,0]
        ])
        self.honest_fork = 0
        self.terminal_length = terminal_length

    def getState(self, miner_ind):
        """Return the 7-entry State vector as seen by miner `miner_ind`.

        Forks are relabelled by swapping index 0 with `miner_ind`, so the
        observing miner's own fork is always reported as fork 0.
        """
        if np.sum(self.miner_blocks) == 0:
            # The empty game is reported as the all-zero state for every miner
            # (the honest fork is not relabelled here).
            return np.array([0,0,0,0,0,0,0])

        fork_lens = np.sum(self.miner_blocks, axis=0)
        fork_lens[0], fork_lens[miner_ind] = fork_lens[miner_ind], fork_lens[0]

        # Copy the row so the swap below does not modify miner_blocks itself.
        miner_lens = self.miner_blocks[miner_ind].copy()
        miner_lens[0], miner_lens[miner_ind] = miner_lens[miner_ind], miner_lens[0]

        return np.concatenate((fork_lens, miner_lens, [self.getHonestFork(miner_ind)]))

    def getHonestFork(self, miner_ind):
        """Honest fork index after swapping labels 0 and `miner_ind`."""
        if self.honest_fork == miner_ind:
            return 0
        elif self.honest_fork == 0:
            return miner_ind
        return self.honest_fork

    def isTerminal(self):
        """True once any fork has reached `terminal_length` blocks.

        The check is on fork lengths (column sums), matching what
        State.isTerminal tests on the vector produced by getState. The earlier
        check looked at the largest single (miner, fork) entry and therefore
        missed forks that reached the terminal length through blocks mined by
        different miners.
        """
        fork_lens = np.sum(self.miner_blocks, axis=0)
        return bool(np.max(fork_lens) >= self.terminal_length)

    def __eq__(self, other):
        if not isinstance(other, MinerBlocks):
            return NotImplemented
        if np.all(self.miner_blocks == other.miner_blocks) and self.honest_fork == other.honest_fork:
            return True
        return False
    
def checkForDuplicates(current_set, candidate):
    """Return True if `candidate` equals (via `==`) any element of `current_set`."""
    return any(candidate == existing for existing in current_set)

In [408]:
# A fresh MinerBlocks should map to the all-zero state for any miner.
mb = MinerBlocks()
print(State(mb.getState(1)))

000,000,0


In [457]:
# construct MDP from optimal policy.
# Miner 0 is "us"; miners 1 and 2 act according to `policy`, evaluated on the
# game as seen from their own point of view. Each of the three miners wins the
# next block with probability 1/3.
#
# NOTE(review): this cell is recorded as failing with KeyError '201,100,0'.
# That string describes fork 0 holding two blocks of which only one is ours,
# i.e. another miner extended our fork. nextStates only ever lets miners 1 and
# 2 extend their own forks, so such states were never enumerated into
# state_ind. The state space needs extending before this cell can complete.
# NOTE(review): new_rewards is allocated here but not filled in yet.
init_state = MinerBlocks()

new_transitions = np.zeros((2, len(total_states), len(total_states)))
new_rewards = np.zeros((2, len(total_states), len(total_states)))


def _stateIndexFor(blocks, miner_ind=0):
    """Index in state_ind of `blocks` as seen by miner `miner_ind`."""
    return state_ind.getIndex(str(State(blocks.getState(miner_ind))))


def _mineBlock(blocks, agent, fork):
    """Return a copy of `blocks` in which `agent` has mined one block on `fork`."""
    successor = copy.deepcopy(blocks)
    successor.miner_blocks[(agent, fork)] += 1

    # new honest fork.
    fork_lens = np.sum(successor.miner_blocks, axis=0)
    if any(fork_lens > fork_lens[successor.honest_fork]):
        successor.honest_fork = np.argmax(fork_lens)
    return successor


states_to_process = [init_state]
while states_to_process:
    elem = states_to_process.pop()

    actions = []
    # get actions of other 2 agents, from the policy.
    for i in range(1,3):
        policy_action = policy[_stateIndexFor(elem, i)]
        if policy_action == 0: # selfish
            actions.append((i,i))
        else: # honest
            actions.append((i,elem.honest_fork))

    cur_state_ind = _stateIndexFor(elem)

    # other agents actions, given the policy.
    # Their outcome does not depend on what we chose, so it is written to both
    # action planes.
    for agent, fork in actions:
        temp_state = _mineBlock(elem, agent, fork)
        next_state_ind = _stateIndexFor(temp_state)
        new_transitions[0, cur_state_ind, next_state_ind] = 1/3
        new_transitions[1, cur_state_ind, next_state_ind] = 1/3

        if not temp_state.isTerminal() and not checkForDuplicates(states_to_process, temp_state):
            states_to_process.append(temp_state)

    # We win block.
    # Action 0 (selfish) extends our own fork 0; action 1 (honest) extends the
    # honest fork. Each outcome belongs to its own action plane only; writing
    # both outcomes to both planes, as before, makes a row sum to 4/3 whenever
    # the two forks differ.
    for action_ind, fork in enumerate([0, elem.honest_fork]):
        temp_state = _mineBlock(elem, 0, fork)
        next_state_ind = _stateIndexFor(temp_state)
        new_transitions[action_ind, cur_state_ind, next_state_ind] = 1/3

        if not temp_state.isTerminal() and not checkForDuplicates(states_to_process, temp_state):
            states_to_process.append(temp_state)

KeyError: '201,100,0'

In [453]:
def printTransitions(transitions):
    """Print, for every indexed state, its non-zero outgoing transitions.

    For each state the selfish-action successors (plane 0) are listed first,
    then the honest-action successors (plane 1), each with its probability
    formatted to two decimals.
    """
    n_states = len(total_states)
    planes = ((0, " selfish, {:0.2f}"), (1, " honest, {:0.2f}"))
    for src in range(n_states):
        print(state_ind.getString(src))
        for plane, template in planes:
            for dst in range(n_states):
                prob = transitions[plane, src, dst]
                if prob != 0:
                    print("    ", state_ind.getString(dst), template.format(prob))

In [454]:
# Inspect whatever part of new_transitions was filled before the cell above stopped.
printTransitions(new_transitions)

000,000,0
     010,000,1  selfish, 0.33
     001,000,2  selfish, 0.33
     010,000,1  honest, 0.33
     001,000,2  honest, 0.33
100,100,0
010,000,1
     020,000,1  selfish, 0.33
     011,000,1  selfish, 0.33
     020,000,1  honest, 0.33
     011,000,1  honest, 0.33
001,000,2
     011,000,2  selfish, 0.33
     002,000,2  selfish, 0.33
     011,000,2  honest, 0.33
     002,000,2  honest, 0.33
101,100,2
011,000,2
     021,000,1  selfish, 0.33
     012,000,2  selfish, 0.33
     021,000,1  honest, 0.33
     012,000,2  honest, 0.33
002,000,2
002,001,2
111,100,2
021,000,1
012,000,2
012,001,2
211,200,0
121,100,1
112,100,2
112,101,2
201,200,0
102,100,2
102,101,2
110,100,1
020,000,1
011,000,1
     021,000,1  selfish, 0.33
     012,000,2  selfish, 0.33
     021,000,1  honest, 0.33
     012,000,2  honest, 0.33
020,010,1
111,100,1
021,010,1
121,110,1
210,200,0
120,100,1
120,110,1
200,200,0
110,100,0
101,100,0
111,100,0
