In [5]:
import itertools
import numpy as np 
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
import functools 
from itertools import product
import sys
sys.path.append('../Environments/')
from ColoredGridWorld import SetUpInferenceSpace as setUp
from ColoredGridWorld.MDP import MDP
from ColoredGridWorld import OBMDP
from ColoredGridWorld import visualizations
sys.path.append('../Algorithms/')
from ActionInterpretation import ActionInterpretation
from ValueIteration import ValueIteration
import pprint

In [None]:
def getJointTrajectory(jointState, policyTable, goal, jointTransitionFn, maxSteps=None):
    """Roll out the greedy trajectory from ``jointState`` until ``goal`` is reached.

    At every step the highest-probability action of ``policyTable[jointState]``
    is taken, and the highest-probability successor returned by
    ``jointTransitionFn(jointState, action)`` becomes the next joint state.

    Args:
        jointState: starting joint state; ``jointState[0]`` is the component
            compared against ``goal`` and recorded in the trajectory.
        policyTable: mapping ``jointState -> {action: probability}``.
        goal: value of ``jointState[0]`` at which the rollout stops.
        jointTransitionFn: callable ``(jointState, action) ->
            {nextJointState: probability}``.
        maxSteps: optional cap on the number of transitions. ``None`` (the
            default) keeps the original unbounded behaviour.

    Returns:
        List of the ``jointState[0]`` components visited, starting state
        included.

    Raises:
        RuntimeError: if ``maxSteps`` transitions were taken without reaching
            ``goal`` (only when ``maxSteps`` is given).
        KeyError: if a visited joint state is missing from ``policyTable``.
    """
    traj = [jointState[0]]
    stepsTaken = 0
    while jointState[0] != goal:
        if maxSteps is not None and stepsTaken >= maxSteps:
            raise RuntimeError(
                "goal " + str(goal) + " not reached within "
                + str(maxSteps) + " steps"
            )
        actionDist = policyTable[jointState]
        action = max(actionDist, key=actionDist.get)
        # Evaluate the transition once per step; the original called it twice
        # (once for the keys, once for the lookup function).
        nextStateDist = jointTransitionFn(jointState, action)
        jointState = max(nextStateDist, key=nextStateDist.get)
        traj.append(jointState[0])
        stepsTaken += 1
    return traj
        
       
def viewPolicyStructure(d, levels, indent=0):
    """Print a nested dictionary as a tab-indented tree, one level per line.

    Args:
        d: nested dictionary to print. At ``indent == 0`` every key must be a
            pair ``(state, belief)`` where ``belief()`` returns a mapping
            ``{env: probability}`` and each ``env`` is a pair whose first
            element is callable (presumably the HashableBelief /
            HashableWorld wrappers used elsewhere in this file -- confirm).
        levels: sequence of labels, one per nesting depth; ``levels[indent]``
            labels the keys at this depth and ``levels[indent + 1]`` labels
            leaf values.
        indent: current nesting depth (number of leading tabs).

    Returns:
        None. Output is written with ``print``.
    """
    prefix = '\t' * indent
    label = str(levels[indent])
    for key, value in d.items():
        if indent == 0:
            beliefDict = key[1]()
            # NOTE(review): every entry closes a ")" that only the first entry
            # opens; the per-entry format is kept byte-for-byte so existing
            # output for non-empty beliefs does not change.
            entries = [
                "%s, %s) :%s" % (env[0](), env[1], prob)
                for env, prob in beliefDict.items()
            ]
            # Joining (instead of trimming a trailing ", ") keeps the opening
            # parenthesis intact when the belief is empty: "()" rather than ")".
            strkey = "(" + ", ".join(entries) + ")"
            print(prefix + label + ": " + str(key[0]) + strkey)
        else:
            print(prefix + label + ": " + str(key))
        if isinstance(value, dict):
            viewPolicyStructure(value, levels, indent + 1)
        else:
            print('\t' * (indent + 1) + str(levels[indent + 1]) + ": " + str(value))

def main():
    """Build the coloured grid world, solve it twice, and visualize the results.

    Steps, in order:
      1. Define the grid (dimensions, goal, actions, tile colours, rewards)
         and build the object-level reward/transition functions via ``MDP``.
      2. Build the space of candidate environments and their policies for
         action interpretation.
      3. Build a discretized belief space over those environments and the
         joint (state, belief) OBMDP.
      4. Run value iteration on the object-level MDP and visualize the value
         table and policy.
      5. Run value iteration on the OBMDP, roll out the greedy trajectory from
         ``(0, 2)`` with a fixed initial belief, and visualize it.

    Takes no arguments and returns None; results are shown through the
    ``visualizations`` helpers.
    """
    # Local import: only the belief-space filter below needs it.
    import math

    dimensions = (5,6)
    goals = [(5,2)]
    goalState = goals[0]  # single source of truth for the goal cell
    actions = {(-1,0),(0,1),(0,-1),(1,0)}
    goalNameDictionary = {(5,2):'goal'}
    colourReward = {'white': 0, 'orange': -2, 'purple': 0, 'blue':0, 'yellow':10}
    # Tile colour per cell; insertion order is kept identical to the original
    # literal because the joint state space below is derived from these keys.
    stateSpace = {
        (0,0): 'white', (0,1): 'white', (0,2): 'white', (0,3): 'white', (0,4): 'white',
        (1,0): 'blue', (1,1): 'orange', (1,2): 'orange', (1,3): 'orange', (1,4): 'orange',
        (2,0): 'blue', (2,1): 'purple', (2,2): 'purple', (2,3): 'purple', (2,4): 'orange',
        (3,0): 'blue', (3,1): 'purple', (3,2): 'blue', (3,3): 'purple', (3,4): 'orange',
        (4,0): 'blue', (4,1): 'blue', (4,2): 'blue', (4,3): 'purple', (4,4): 'orange',
        (5,0): 'white', (5,1): 'white', (5,2): 'yellow', (5,3): 'white', (5,4): 'white',
    }
    getMDP = MDP(dimensions, stateSpace, colourReward)
    objectRewardFn, objectTransitionFn = getMDP()

    # set up for value-iteration and inference
    # NOTE(review): 10e-4 evaluates to 1e-3; confirm this is the intended tolerance.
    convergenceTolerance = 10e-4
    gamma = 0.95
    alpha = 20
    eps = 0.05
    hyperparameters = (convergenceTolerance, gamma, alpha, eps)
    variableColours = ['orange', 'purple', 'blue']
    variableReward = [0, -2]
    constantRewardDict = {'white': 0, 'yellow': 10}

    utilitySpace = setUp.buildUtilitySpace(variableColours, variableReward, constantRewardDict)
    transitionSpace = [True]
    worlds = setUp.buildWorldSpace(utilitySpace, transitionSpace)
    envSpace = [(world, goal) for world, goal in product(worlds, goals)]
    envMDPsAndPolicies = setUp.buildEnvPolicySpace(dimensions, stateSpace, actions, envSpace, hyperparameters)
    actionInterpretation = ActionInterpretation(envMDPsAndPolicies)

    # set up for OBMDP and its value iteration
    beta = 10
    beliefGamma = 0.95
    beliefAlpha = 20
    beliefEps = 0.05
    bins = [0,0.25,0.5,0.75,1]
    # Filter the weight tuples *before* turning them into dictionaries, so only
    # the beliefs that sum to 1 are ever materialized (the original built a
    # dict for every one of the len(bins) ** len(envSpace) combinations first).
    # math.isclose replaces the exact `== 1` comparison; for the dyadic bins
    # above both select the same tuples, in the same order.
    # Assumes the entries of envSpace are distinct keys, so the sum of a
    # weight tuple equals the sum of the resulting dict's values.
    discreteBeliefSpace = [
        dict(zip(envSpace, weights))
        for weights in product(bins, repeat=len(envSpace))
        if math.isclose(sum(weights), 1)
    ]
    hashableDiscreteBeliefSpace = [setUp.HashableBelief(beliefDict) for beliefDict in discreteBeliefSpace]
    beliefUtilityFn = OBMDP.getBeliefUtility()
    literalObserver = OBMDP.LiteralObserver(actionInterpretation)
    getNextBelief = literalObserver(discreteBeliefSpace, True)
    jointStateSpace = list(product(stateSpace.keys(), hashableDiscreteBeliefSpace))
    getOBMDP = OBMDP.OBMDP(jointStateSpace, (setUp.HashableWorld(colourReward, True), goals[0]), True, beta)
    jointRewardFn, jointTransitionFn = getOBMDP(objectTransitionFn, objectRewardFn, getNextBelief, beliefUtilityFn)

    # implementing object-level MDP
    valueTable = {key: 0 for key in stateSpace.keys()}
    performValueIteration = ValueIteration(actions, objectTransitionFn, objectRewardFn, valueTable, [goalState], convergenceTolerance, gamma, alpha, eps, True)
    optimalValues, policyTable = performValueIteration()
    # Cells whose colour carries a negative reward.
    trapStates = [s for s in stateSpace if(colourReward[stateSpace[s]]<0)]
    m,n = dimensions
    visualizations.visualizeValueTable(m, n, goalState, trapStates, optimalValues)
    # NOTE(review): trapStates is computed above but an empty list is passed
    # here; kept as in the original -- confirm whether that is intended.
    visualizations.visualizePolicy(stateSpace, policyTable, goalState, otherGoals=[], trapStates=[], arrowScale = .3)

    # implementing OBMDP
    valueTable = {key: 0 for key in jointStateSpace}
    # Every joint state whose grid position is the goal cell is terminal.
    jointGoalStates = [jointState for jointState in jointStateSpace if jointState[0] == goalState]
    performValueIteration = ValueIteration(actions, jointTransitionFn, jointRewardFn, valueTable, jointGoalStates, convergenceTolerance, beliefGamma, beliefAlpha, beliefEps, True)
    optimalValues, policyTable = performValueIteration()
    # Start at (0, 2) with probability 0.25 on each of the four worlds in which
    # purple has reward 0, and probability 0 on the four with purple at -2.
    initialState1 = ( (0,2), setUp.HashableBelief({(setUp.HashableWorld({'orange':0, 'white':0, 'blue':0, 'purple':0, 'yellow':10}, True),(5,2)): 0.25, 
                                                   (setUp.HashableWorld({'orange':0, 'white':0, 'blue':0, 'purple':-2, 'yellow':10}, True),(5,2)): 0, 
                                                   (setUp.HashableWorld({'orange':0, 'white':0, 'blue':-2, 'purple':0, 'yellow':10}, True),(5,2)): 0.25, 
                                                   (setUp.HashableWorld({'orange':0, 'white':0, 'blue':-2, 'purple':-2, 'yellow':10}, True),(5,2)): 0,
                                                   (setUp.HashableWorld({'orange':-2, 'white':0, 'blue':0, 'purple':0, 'yellow':10}, True),(5,2)): 0.25,
                                                   (setUp.HashableWorld({'orange':-2, 'white':0, 'blue':0, 'purple':-2, 'yellow':10}, True),(5,2)): 0,
                                                   (setUp.HashableWorld({'orange':-2, 'white':0, 'blue':-2, 'purple':0, 'yellow':10}, True),(5,2)): 0.25,
                                                   (setUp.HashableWorld({'orange':-2, 'white':0, 'blue':-2, 'purple':-2, 'yellow':10}, True),(5,2)): 0,
                                                   }) )
    trajectory1 = getJointTrajectory(initialState1, policyTable, goalState, jointTransitionFn)
    #viewPolicyStructure(policyTable, ["state", "action", "probability"])
    visualizations.visualizeEnvironmentByState(stateSpace, [goalState], [], trajectory1, goalNameDictionary, 2.5)
    

# Run the demo only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
    