In [2]:
import numpy as np
import random as rnd
from matplotlib import pyplot as plt

In [254]:
class Garnet:
    """Randomly generated finite MDP with a tabular stochastic policy.

    Attributes set by the constructor:
        Ns, Na, Nb, gamma -- the constructor arguments, stored unchanged.
        P  -- array of shape (Ns, Na, Ns); P[s, a, s'] is the probability of
              moving from state s to state s' under action a.
        R  -- array of shape (Ns, Na); reward for taking action a in state s.
        pi -- array of shape (Na, Ns); pi[a, s] is the probability of choosing
              action a in state s (note the action-major layout).

    Attribute set by calValue():
        V_star -- state values of the current policy `pi`.

    Randomness comes from the global `random` and `numpy.random` generators,
    so both must be seeded for reproducible instances.
    """

    def __init__(self, Ns, Na, Nb, gamma):
        """Store the problem sizes and draw a random problem instance.

        Args:
            Ns: number of states.
            Na: number of actions in each state.
            Nb: branching factor, i.e. how many successor states get non-zero
                probability for each (state, action) pair. Expected to satisfy
                1 <= Nb <= Ns; sampling raises ValueError otherwise.
            gamma: discount factor.
        """
        self.Ns = Ns  # number of states
        self.Na = Na  # number of actions in each state
        self.Nb = Nb  # branching factor
        self.gamma = gamma  # discount factor

        self.initGarnet()

    # initialize garnet
    def initGarnet(self):
        """(Re)draw transitions and rewards and reset the policy to uniform."""
        self.initTransition()
        self.initReward()
        self.initPi()

    #------------------------

    # initializers
    def initTransition(self):
        """Draw the transition tensor `self.P` of shape (Ns, Na, Ns).

        For every (state, action) pair, Nb distinct successor states are
        sampled without replacement, and the unit interval is cut at Nb - 1
        uniform random points; the resulting segment lengths become the
        successor probabilities, so each row sums to 1.
        """
        P = np.zeros((self.Ns, self.Na, self.Ns))

        for s in range(self.Ns):
            for a in range(self.Na):
                successors = rnd.sample(range(self.Ns), k=self.Nb)
                # Sorted cut points 0 = c_0 <= c_1 <= ... <= c_Nb = 1; the
                # consecutive differences are Nb non-negative weights summing to 1.
                cuts = np.sort(np.concatenate(([0.0], np.random.rand(self.Nb - 1), [1.0])))
                P[s, a, successors] = np.diff(cuts)

        self.P = P

    def initReward(self):
        """Draw `self.R` of shape (Ns, Na) uniformly from [0, 1)."""
        self.R = np.random.uniform(size=(self.Ns, self.Na))

    def initPi(self):
        """Set `self.pi` (shape (Na, Ns)) to the uniform random policy."""
        self.pi = np.ones((self.Na, self.Ns)) / self.Na

    #------------------------

    # generate sequence
    def getNext(self, s):
        """Sample one transition from state `s` under the current policy.

        Returns:
            (a, r, s_next): the sampled action, the reward R[s, a], and the
            sampled successor state.
        """
        a = self.getAction(s)
        r = self.R[s, a]
        s_next = self.getState(s, a)

        return a, r, s_next

    def getAction(self, s):
        """Sample an action index in state `s` according to `self.pi[:, s]`."""
        return np.random.choice(np.arange(self.Na), p=self.pi[:, s])

    def getState(self, s, a):
        """Sample a successor state index according to `self.P[s, a, :]`."""
        return np.random.choice(np.arange(self.Ns), p=self.P[s, a, :])

    #------------------------

    # cal & store true state values
    def calValue(self):
        """Compute, store and return the state values of the current policy.

        Solves the linear system (I - gamma * P_pi) V = R_pi, where P_pi and
        R_pi are the transition matrix and expected reward induced by
        `self.pi`. For the default uniform policy this equals averaging P and
        R over actions. The result is stored in `self.V_star`.

        Returns:
            Array of length Ns with the value of each state.
        """
        # pi is stored as (Na, Ns), while P is (Ns, Na, Ns) and R is (Ns, Na).
        P_pi = np.einsum('as,sat->st', self.pi, self.P)
        R_pi = np.einsum('as,sa->s', self.pi, self.R)

        # Solving the system avoids forming an explicit matrix inverse.
        A = np.identity(self.Ns) - self.gamma * P_pi
        self.V_star = np.linalg.solve(A, R_pi)
        return self.V_star

    #------------------------

    # MSE
    def MSE(self, V):
        """Mean squared error between `V` and the current policy's state values.

        The reference values are recomputed through calValue() on every call.

        Args:
            V: sequence of Ns value estimates.

        Returns:
            The mean of the squared differences, or the sentinel -1 when
            `V` does not have the same length as the reference values.
        """
        self.calValue()
        if len(V) != len(self.V_star):
            return -1  # sentinel kept for callers that check for it

        return np.mean((np.asarray(V) - self.V_star) ** 2)

In [259]:
# Build a problem instance: 10 states, 5 actions per state,
# branching factor 3, discount factor 0.9.
g = Garnet(Ns=10, Na=5, Nb=3, gamma=0.9)