In [2]:
import sys
stdout = sys.stdout
from functools import reduce
sys.stdout = stdout

Define the MDP (environment)

In [3]:
def getAllPossibleActions(currentState):
  if currentState == "forget" or currentState == "done":
    return []
  
  else:
    allValidTaskActions = [index+1 for index, value in enumerate(currentState) if value == 0]
    otherActions = [0, 6]
    allActions = allValidTaskActions+otherActions
    return allActions

def transitionFunction(currentState, action):
  if action is 0:
    return {"forget": 0.025, currentState: 0.975}

  elif action is 6:
    if currentState == (1, 1, 1, 1, 1):
      return {"done": 1}
    else:
      return {currentState: 1}

  else:
    newState = list(currentState)
    newState[action-1] = 1
    return {tuple(newState): 1}

def rewardFunction(currentState, action, nextState):
  if action is 6 and currentState != (1, 1, 1, 1, 1):
    return 0

  if action is 0 and nextState == "forget":
      forgottenBonusEachTask = (0.25, 0.5, 0.75, 1, 1.75)
      totalBonus = 0
      for task in range(0, 5):
        if currentState[task] == 1:
          totalBonus += forgottenBonusEachTask[task]
      return totalBonus + 0.5

  else:
    taskFairPrices = (0.67, 1.33, 2, 2.67, 3.33)
    # taskFairPrices = (3, 3.25, 2.25, 3, 1)
    taskRewards = [-fairPrice for fairPrice in taskFairPrices]
    allRewards = [0.5] + taskRewards + [20]
    return allRewards[action]

Code to generate all possible states 

In [4]:
def generateAllPermutations(startingArray, numToDoTasks):
    if numToDoTasks == 0:
      return [tuple(startingArray)]

    allStates = []
    for bit in range(2):
      startingArrayWithNewBit = startingArray.copy()
      startingArrayWithNewBit.append(bit)
      allStates += generateAllPermutations(startingArrayWithNewBit, numToDoTasks-1)

    return allStates

generateAllStates = lambda numToDoTasks: generateAllPermutations([], numToDoTasks) + ["forget", "done"]

In [5]:
allStates = generateAllStates(5)
print(allStates)
print(len(allStates))

[(0, 0, 0, 0, 0), (0, 0, 0, 0, 1), (0, 0, 0, 1, 0), (0, 0, 0, 1, 1), (0, 0, 1, 0, 0), (0, 0, 1, 0, 1), (0, 0, 1, 1, 0), (0, 0, 1, 1, 1), (0, 1, 0, 0, 0), (0, 1, 0, 0, 1), (0, 1, 0, 1, 0), (0, 1, 0, 1, 1), (0, 1, 1, 0, 0), (0, 1, 1, 0, 1), (0, 1, 1, 1, 0), (0, 1, 1, 1, 1), (1, 0, 0, 0, 0), (1, 0, 0, 0, 1), (1, 0, 0, 1, 0), (1, 0, 0, 1, 1), (1, 0, 1, 0, 0), (1, 0, 1, 0, 1), (1, 0, 1, 1, 0), (1, 0, 1, 1, 1), (1, 1, 0, 0, 0), (1, 1, 0, 0, 1), (1, 1, 0, 1, 0), (1, 1, 0, 1, 1), (1, 1, 1, 0, 0), (1, 1, 1, 0, 1), (1, 1, 1, 1, 0), (1, 1, 1, 1, 1), 'forget', 'done']
34


Value iteration algorithm to solve the MDP

In [6]:
class ValueIteration(object):
    def __init__(self, allStates, rewardFunction, transitionFunction, getAllPossibleActions, valueTable, convergenceTolerance, gamma):
        self.allStates = allStates.copy()
        self.rewardFunction = rewardFunction
        self.transitionFunction = transitionFunction
        self.getAllPossibleActions = getAllPossibleActions
        self.valueTable = valueTable.copy()
        self.convergenceTolerance = convergenceTolerance
        self.gamma = gamma

    def getValueOfState(self, state):
        return self.valueTable[(state)]

    def computeQValue(self, currentState, actionTaken):
        allNextStatesWithProbabilities = self.transitionFunction(currentState, actionTaken)
        possibleNextStates = list(allNextStatesWithProbabilities.keys())
        probabilitiesOfAllNextStates = list(allNextStatesWithProbabilities.values())
        return reduce(lambda x, y: x + y, [probabilityOfNextState * (self.rewardFunction(currentState, actionTaken, nextState) + self.gamma * self.getValueOfState(nextState)) for probabilityOfNextState, nextState in zip(probabilitiesOfAllNextStates, possibleNextStates)])

    def updateValueOfState(self, state):
        possibleActions = self.getAllPossibleActions(state)
        if len(possibleActions) == 0:
          return 0
        QValuesForAllActions = [self.computeQValue(state, actionTaken) for actionTaken in possibleActions]
        return max(list(QValuesForAllActions))

    def __call__(self):
        counter = 0
        while True:
            maxChangeInValue = 0

            for state in self.allStates:
                currentState = state
                currentValueOfState = self.valueTable[currentState]
                updatedValueOfState = self.updateValueOfState(currentState)
                self.valueTable[state] = updatedValueOfState
                if abs(currentValueOfState - updatedValueOfState) > maxChangeInValue:
                    maxChangeInValue = abs(currentValueOfState - updatedValueOfState)

            if maxChangeInValue < self.convergenceTolerance:
                break

        stateValues = self.valueTable
        return stateValues

Run the value iteration algorithm 

In [7]:
initialValueTable = {state: 0 for state in allStates}
convergenceTolerance = 1e-10
gamma = 0.95
valueIteration = ValueIteration(allStates, rewardFunction, transitionFunction, getAllPossibleActions, initialValueTable, convergenceTolerance, gamma)

In [8]:
finalValueTable = valueIteration()

# **Function to compute optimal rewards**

In [9]:
def computeOptimalIncentive(task):
  nextState = [0]*5
  nextState[task-1] = 1
  immediateReward = rewardFunction((0, 0, 0, 0, 0), task, nextState)
  currentStateValue = finalValueTable[(0, 0, 0, 0, 0)]
  nextStateValue = finalValueTable[tuple(nextState)]

  return gamma*nextStateValue - currentStateValue

In [10]:
print("task: \t\toptimal incentive")
[print("task {}: \t\t{}".format(task, round(computeOptimalIncentive(task), 3))) for task in range(1, 6)]

task: 		optimal incentive
task 1: 		0.626
task 2: 		1.253
task 3: 		1.858
task 4: 		2.432
task 5: 		2.97


[None, None, None, None, None]

# **TESTS**

In [11]:
!pip install ddt
import unittest
from ddt import ddt, data, unpack

Collecting ddt
  Downloading https://files.pythonhosted.org/packages/31/3b/a38bb1606c0b912cd53976369ac10334f6b5e96fa260eebc46fabe6a43bf/ddt-1.4.1-py2.py3-none-any.whl
Installing collected packages: ddt
Successfully installed ddt-1.4.1


In [12]:
@ddt
class TestEnvironment(unittest.TestCase):
  @data(((0, 0, 0, 0, 0), [1, 2, 3, 4, 5, 0, 6]),
        ((1, 0, 1, 0, 1), [2, 4, 0, 6]),
        ((1, 1, 1, 1, 1), [0, 6]))
  @unpack
  def testGetAllPossibleActions(self, currentState, trueAllPossibleActions):
    allPossibleActions = getAllPossibleActions(currentState)
    assert len(trueAllPossibleActions) == len(allPossibleActions)
    assert ([a == b for a, b in zip(trueAllPossibleActions, allPossibleActions)])

  @data(((0, 0, 0, 0, 0), 0, {(0, 0, 0, 0, 0): 0.975, "forget": 0.025}),
        ((0, 0, 0, 0, 0), 1, {(1, 0, 0, 0, 0): 1}),
        ((0, 0, 0, 0, 0), 6, {(0, 0, 0, 0, 0): 1}),
        ((1, 0, 1, 0, 1), 2, {(1, 1, 1, 0, 1): 1}),
        ((1, 1, 1, 1, 1), 0, {(1, 1, 1, 1, 1): 0.975, "forget": 0.025}),
        ((1, 1, 1, 1, 1), 6, {"done": 1}))
  @unpack
  def testTransitionFunction(self, currentState, action, trueNextStateDict):
    nextStateDict = transitionFunction(currentState, action)
    self.assertDictEqual(trueNextStateDict, nextStateDict)

  @data(((0, 0, 0, 0, 0), 2, (0, 1, 0, 0, 0), -1.33),
        ((0, 0, 0, 0, 0), 0, (0, 0, 0, 0, 0), 0.5),
        ((0, 0, 0, 0, 0), 0, "forget", 0.5),
        ((1, 0, 1, 0, 1), 0, (1, 0, 1, 0, 1), 0.5),
        ((1, 0, 1, 0, 1), 0, "forget", 3.25),
        ((1, 0, 1, 0, 1), 6, (1, 0, 1, 0, 1), 0),
        ((1, 1, 1, 1, 1), 6, "done", 20))
  @unpack
  def testRewardFunction(self, currentState, action, nextState, trueReward):
    reward = rewardFunction(currentState, action, nextState)
    self.assertAlmostEqual(reward, trueReward)

In [13]:
@ddt
class TestValueIterationFunctions(unittest.TestCase):
  @data(((0, 0, 0, 0, 0), 2, {state: 0 for state in allStates}, -1.33),
        ((0, 1, 0, 0, 0), 0, {state: 0 for state in allStates}, 0.5125),
        ((1, 1, 1, 1, 1), 6, {state: 0 for state in allStates}, 20),
        ((0, 0, 0, 0, 0), 2, {state: 3 for state in allStates}, 1.52),
        ((1, 0, 1, 0, 1), 0, {state: 3 for state in allStates}, 3.41875),
        ((1, 0, 1, 0, 1), 6, {state: 3 for state in allStates}, 2.85))
  @unpack
  def testComputeQValue(self, currentState, actionTaken, valueTable, trueQValue):
    valueIteration = ValueIteration(allStates, rewardFunction, transitionFunction, getAllPossibleActions, valueTable, convergenceTolerance, gamma)
    QValue = valueIteration.computeQValue(currentState, actionTaken)
    self.assertAlmostEqual(trueQValue, QValue)

  @data(((0, 0, 0, 0, 0), {state: 0 for state in allStates}, 0.5),
        ((1, 0, 1, 0, 1), {state: 3 for state in allStates}, 3.41875),
        ((1, 1, 1, 1, 1), {state: 3 for state in allStates}, 22.85),
        ("forget", {state: 3 for state in allStates}, 0),
        ("done", {state: 3 for state in allStates}, 0))
  @unpack
  def testUpdateValueOfState(self, state, valueTable, trueUpdatedValue):
    valueIteration = ValueIteration(allStates, rewardFunction, transitionFunction, getAllPossibleActions, valueTable, convergenceTolerance, gamma)
    updatedValue = valueIteration.updateValueOfState(state)
    self.assertAlmostEqual(trueUpdatedValue, updatedValue)

unittest.main(argv=[''], verbosity=2, exit=False)

testGetAllPossibleActions_1___0__0__0__0__0____1__2__3__4__5__0__6__ (__main__.TestEnvironment) ... ok
testGetAllPossibleActions_2___1__0__1__0__1____2__4__0__6__ (__main__.TestEnvironment) ... ok
testGetAllPossibleActions_3___1__1__1__1__1____0__6__ (__main__.TestEnvironment) ... ok
testRewardFunction_1___0__0__0__0__0___2___0__1__0__0__0____1_33_ (__main__.TestEnvironment) ... ok
testRewardFunction_2___0__0__0__0__0___0___0__0__0__0__0___0_5_ (__main__.TestEnvironment) ... ok
testRewardFunction_3___0__0__0__0__0___0___forget___0_5_ (__main__.TestEnvironment) ... ok
testRewardFunction_4___1__0__1__0__1___0___1__0__1__0__1___0_5_ (__main__.TestEnvironment) ... ok
testRewardFunction_5___1__0__1__0__1___0___forget___3_25_ (__main__.TestEnvironment) ... ok
testRewardFunction_6___1__0__1__0__1___6___1__0__1__0__1___0_ (__main__.TestEnvironment) ... ok
testRewardFunction_7___1__1__1__1__1___6___done___20_ (__main__.TestEnvironment) ... ok
testTransitionFunction_1 (__main__.TestEnvironment) 

<unittest.main.TestProgram at 0x7f808927f748>