In [1]:
from numpy import *
from copy import deepcopy
random.seed(1)

In [2]:
######################################################################################################
## Set up an Environment to work on
######################################################################################################

class Environment:
    
    # MDP = Markov Chain + One-step Decision Theory

    nS = 3
    nA = 2

    def __init__(self):
        
        # Transition function: P(s'|s, a)
        self.P = random.rand(self.nS,self.nA,self.nS)
        self.P = self.P * (self.P > 0.9)

        for s in range(self.nS):
            for a in range(self.nA):
                self.P[s,a,:] = random.rand(self.nS) 
                self.P[s,a,:] = self.P[s,a,:] / sum(self.P[s,a,:])
        
        # reward function: R(s, a, s') 
        self.R = zeros((self.nS,self.nA,self.nS)) 
        self.R[:,0,self.nS-1] = 1
        self.R[:,1,self.nS-1] = 1

    def draw_graph(self, fname):
        ''' Draw the graph.
        '''
        with open(fname, "w") as text_file:
            text_file.write("digraph MDP {\n")
            for s in range(self.nS):
                text_file.write("\ts_%s [style=filled shape=circle fillcolor=lightblue] ;\n" % (s+1)) 
                for a in range(self.nA):
                    text_file.write("\ta_%s%s [label=\"a_%d\", style=filled, shape=diamond, fillcolor=indianred1, fontsize=10, fixedsize=true, width=0.5, height=0.5] ;\n" % (s+1,a+1,a+1)) 
                    text_file.write("\ts_%s -> a_%s%s ;\n" % (s+1,a+1,a+1)) 
                    for s_ in range(self.nS):
                        if self.R[s,a,s_] > 0:
                            text_file.write("\ta_%s%s -> s_%s [label=\"%2.1f (r=%2.1f)\" color=green] ;\n" % (s+1,a+1,s_+1,self.P[s,a,s_],self.R[s,a,s_])) 
                        else:
                            text_file.write("\ta_%s%s -> s_%s [label=\"%2.1f\"] ;\n" % (s+1,a+1,s_+1,self.P[s,a,s_])) 
            text_file.write("}")

In [3]:
env = Environment()
env.draw_graph("markov_decision_process.dot")

In [4]:
######################################################################################################
## Task 1 : Value Iteration
######################################################################################################

# Objective: obtain a policy pi*: S -> A, indicating which action to take in a given state
# Apply Bellman optimality equation Eq. iteratively, until convergence.

epsilon = 0.05
gamma   = 0.9

def value_iteration():

    V = zeros(env.nS)
    pi = zeros(env.nS, dtype='int64')
    W = zeros(env.nS)
    values = zeros(env.nA)

    while "No convergence":

        V = deepcopy(W)

        for s in range(env.nS):
            for a in range(env.nA):
                values[a] = sum([env.P[s,a,t]*(env.R[s,a,t] + gamma*V[t]) for t in range(env.nS)])

            pi[s] = argmax(values)
            W[s] = values[pi[s]]

        if linalg.norm(V-W) < epsilon:
            break

    return W,pi

print(value_iteration())

(array([6.67835378, 6.77772066, 6.63588305]), array([0, 1, 0]))


In [6]:
######################################################################################################
## Task 2 : Policty iteration
######################################################################################################

def policy_iteration():

    V = zeros(env.nS)
    pi = zeros(env.nS, dtype='int64')
    W = zeros(env.nS)
    i = 0
    
    updated = True
    
    while updated:
        updated = False
        i += 1
        V = deepcopy(W)
        
        # Policty evaluation
        for s in range(env.nS):
            W[s] = sum([env.P[s,pi[s],t]*(env.R[s,pi[s],t] + gamma*V[t]) for t in range(env.nS)])
        
        # Greedy Pi update
        for s in range(env.nS):
            q_best = W[s]

            for a in range(env.nA):
                q_sa = sum([env.P[s,a,t]*(env.R[s,a,t] + gamma*W[t]) for t in range(env.nS)])

                if q_sa > q_best:
                    pi[s] = a
                    q_best = q_sa
                    updated = True

    return W,pi

print(policy_iteration())

(array([6.91564801, 7.0150149 , 6.87317729]), array([0, 1, 0]))
