In [191]:
import numpy as np
import random as rnd
from matplotlib import pyplot as plt

In [192]:
class Garnet:
    """Randomly generated finite MDP (a "Garnet" problem) with a stochastic policy.

    Attributes set by the constructor and the init* methods:
        Ns, Na, Nb, gamma -- the constructor arguments, stored as given.
        P  -- transition probabilities, shape (Ns, Na, Ns), indexed [s, a, s'].
        R  -- rewards, shape (Ns, Na), drawn uniformly from [0, 1).
        pi -- policy, shape (Ns, Na); initialised to uniform over actions.

    Attributes set by calValue():
        P_pi -- state-to-state transition matrix under pi, shape (Ns, Ns).
        R_pi -- expected one-step reward under pi, shape (Ns,).
        V_pi -- discounted state values under pi, shape (Ns,).
    """

    def __init__(self, Ns, Na, Nb, gamma):
        """Store the problem sizes and immediately sample P, R and pi.

        Args:
            Ns: number of states.
            Na: number of actions in each state.
            Nb: branching factor, i.e. how many successor states get non-zero
                probability for every (state, action) pair.
            gamma: discount factor used by calValue().
        """
        self.Ns = Ns # number of states
        self.Na = Na # number of actions in each state
        self.Nb = Nb # branching factor
        self.gamma = gamma # discount factor

        self.initGarnet()

    # initialize garnet
    def initGarnet(self):
        """(Re)sample the transition kernel and rewards and reset the policy."""
        self.initTransition()
        self.initReward()
        self.initPi()

    #------------------------

    # initializers
    def initTransition(self):
        """Sample the transition tensor P (s x a x s').

        For every (state, action) pair, Nb distinct successor states are drawn
        without replacement, and the unit interval is cut at Nb - 1 uniform
        random points; the resulting segment lengths (which sum to 1) become
        the probabilities of those successors.

        Raises:
            ValueError: from the underlying samplers when Nb is larger than Ns,
                negative, or zero (for a non-empty state/action space).
        """
        P = np.zeros((self.Ns, self.Na, self.Ns))

        for s in range(self.Ns):
            for a in range(self.Na):
                successors = rnd.sample(range(self.Ns), k=self.Nb)
                cuts = np.random.rand(self.Nb - 1)
                # Sorted cut points including both ends of [0, 1]; consecutive
                # differences are the Nb segment lengths.
                edges = np.sort(np.concatenate((cuts, [0.0, 1.0])))
                P[s, a, successors] = np.diff(edges)

        self.P = P                                            # P: s x a x s'

    def initReward(self):
        """Sample rewards uniformly from [0, 1) for every (state, action)."""
        self.R = np.random.uniform(size=(self.Ns, self.Na))   # R: s x a

    def initPi(self):
        """Set the policy to uniform random over the Na actions."""
        self.pi = np.ones((self.Ns, self.Na)) / self.Na       # pi: s x a

    #------------------------

    # generate sequence
    def getNext(self, s):
        """Sample one transition from state s under the current policy.

        Returns:
            (a, r, s_next): sampled action, its reward R[s, a], and the
            sampled successor state.
        """
        a = self.getAction(s)
        r = self.R[s, a]
        s_next = self.getState(s, a)

        return a, r, s_next

    def getAction(self, s):
        """Sample an action index in state s according to pi[s, :]."""
        return np.random.choice(np.arange(self.Na), p=self.pi[s, :])

    def getState(self, s, a):
        """Sample a successor state index according to P[s, a, :]."""
        return np.random.choice(np.arange(self.Ns), p=self.P[s, a, :])

    #------------------------

    # cal & store true state values given a policy
    def calValue(self):
        """Compute and store the discounted state values of the current policy.

        Recomputes P_pi and R_pi from the current self.pi on every call, so it
        stays correct after the policy has been changed.

        Returns:
            V_pi, shape (Ns,): the solution of (I - gamma * P_pi) V = R_pi.

        Raises:
            numpy.linalg.LinAlgError: if (I - gamma * P_pi) is singular.
        """
        P_pi = np.einsum('ijk,ij->ik', self.P, self.pi)  # P_pi: s x s'
        R_pi = np.einsum('ij,ij->i', self.pi, self.R)    # R_pi: s

        self.P_pi = P_pi
        self.R_pi = R_pi

        # Solve the linear system directly instead of forming the inverse.
        A = np.identity(self.Ns) - self.gamma * P_pi
        self.V_pi = np.linalg.solve(A, R_pi)
        return self.V_pi

    # calculate the average reward given a policy
    def calAvgReward(self):
        """Return the long-run average reward of the current policy.

        Computed as R_pi . mu, where mu is the stationary distribution of
        P_pi (the left eigenvector for eigenvalue 1, scaled to sum to 1).

        NOTE: if the induced chain has several recurrent classes the
        stationary distribution is not unique and the result corresponds to
        whichever eigenvector the solver returns.
        """
        self.calValue()
        eigvals, eigvecs = np.linalg.eig(self.P_pi.T)

        # np.linalg.eig returns eigenvalues in no particular order, so pick
        # the one closest to 1 rather than assuming it is the first.
        idx = np.argmin(np.abs(eigvals - 1.0))
        mu = np.real(eigvecs[:, idx])
        # Eigenvectors come back with unit L2 norm and arbitrary sign;
        # dividing by the sum turns this one into a probability vector.
        mu = mu / mu.sum()

        return self.R_pi.dot(mu)

    #------------------------

    # MSE
    def MSE(self, V):
        """Mean squared error between V and the true values V_pi.

        Args:
            V: array-like of shape (Ns,) with estimated state values.

        Raises:
            Exception: if the shape of V differs from that of V_pi.
        """
        self.calValue()
        # Accept lists/tuples as well as arrays, and compare full shapes so a
        # (Ns, 1) input cannot silently broadcast to an (Ns, Ns) difference.
        V = np.asarray(V)
        if V.shape != self.V_pi.shape:
            raise Exception("The dimensions of V {} and V_pi {} do not match" \
                            .format(V.shape, self.V_pi.shape) )

        return np.mean((V - self.V_pi)**2)

In [193]:
# Example instance: 10 states, 10 actions per state, branching factor 10, discount 0.9.
g = Garnet(Ns=10, Na=10, Nb=10, gamma=0.9)