### Sketch of hierarchical VI

In [2]:
from typing import Any, Tuple, Union, TypeVar
from tasks import get, put, root
from rllib.taxicab import TaxiCab, TaxiCabRenderer, Location, Passenger, Taxi  
import numpy as np
layout_str = """
A--B
----
C--D 
"""
taxicab = TaxiCab(layout_str)

east,west, north, south, pickup, putdown = taxicab.actions()


Action = TypeVar("Action")
State = TypeVar("State")
StateDist = list[Tuple[State, float]]
StateRewardDist = list[Tuple[Tuple[State, float], float]]

class MDP:
    state_list : list[State]
    def actions(self, state: State) -> list[Action]: pass
    def next_state_reward_dist(self, state: State, action: Action) -> StateRewardDist: pass #1
    def is_terminal(self, state: State) -> bool: pass

class SubTask(MDP):
    mdp: MDP
    #def child_subtasks(self, state: State) -> list[Union["SubTask", Action]]: pass
    def continuation_prob(self, state: State) -> float: pass #2, just for whole mdp?
    def exit_reward(self, state: State) -> float: pass #3 eq 7
    def exit_distribution(self, state: State) -> StateDist: pass #eq 8 #4

#import taxicab state
from rllib.taxicab import TaxiCabState
def taxi_state(taxi_x,taxi_y,waiting_passenger_x,waiting_passenger_y,passenger_x = None, passenger_y = None):
    taxi = Taxi(Location(taxi_x,taxi_y),None)
    waiting_passenger = Passenger(Location(waiting_passenger_x,waiting_passenger_y),None)
    return TaxiCabState(taxi, (waiting_passenger,))
def taxi_put_state(taxi_x,taxi_y,dest_x,dest_y):
    passenger = Passenger(None,Location(dest_x,dest_y))
    taxi = Taxi(Location(taxi_x,taxi_y),passenger)
    return TaxiCabState(taxi, ())


#all the code below will be polished up later

def pickup_transition(state,next_state):
    if state.taxi.passenger is None and state.waiting_passengers[0].location == state.taxi.location:
        if next_state.taxi.location == state.taxi.location and next_state.taxi.passenger is not None:
            return True
    return False
def putdown_transition(state,next_state):
    if state.taxi.passenger is not None and state.taxi.passenger.destination == state.taxi.location:
        if next_state.taxi.location == state.taxi.passenger.destination and next_state.taxi.passenger is None and len(next_state.waiting_passengers) == 0:
            return True
    return False
class TaxiMDP(MDP):
    def __init__(self,taxicab):
        state_list = taxicab.list_all_possible_states()
        self.tcab = taxicab
    def actions(self, state):
        return self.tcab.actions
    def next_state_reward_dist(self,state,action): #state,action,probability
    #only for prim actions
        state_reward_dist = []
        possibilities = 0
        if action == pickup:
            #this method rewards failed pickups if there's already a passenger
            for next_state in taxicab.list_all_possible_states():
                reward = 15 if taxicab.next_state_sample(state,action).taxi.passenger is not None else -10
                #if it's setup so that waiting passengers already have destinations, then this is 1 if true | will setup later
                probability = 1 if pickup_transition(state,next_state) else 0
                if probability > 0:
                    possibilities += 1
                    state_reward_dist.append(((next_state,reward),probability))
        elif action == putdown:
            for next_state in taxicab.list_all_possible_states():
                reward = 15 if taxicab.next_state_sample(state,action).taxi.passenger is None else -10
                #p is 1 since if successful since dropoff results in no new waiting passengers
                probability = 1 if putdown_transition(state,next_state) else 0
                if probability > 0:
                    possibilities += 1
                    state_reward_dist.append(((next_state,reward),probability))
        else:
            for next_state in taxicab.list_all_possible_states():
                reward = -1 #always -1 for nav
                #this is for primitive actions, not nav itself
                probability = 1 if next_state == taxicab.next_state_sample(state,action) else 0
                if probability > 0:
                    possibilities += 1
                    #simpler computaion
                    state_reward_dist.append(((next_state,reward),probability))
        if possibilities > 0:
            for i in range(len(state_reward_dist)):
                state_reward_dist[i] = (state_reward_dist[i][0],state_reward_dist[i][1]/possibilities)
        return state_reward_dist
    def is_terminal(self, state):
        if state.taxi.passenger is None and len(state.waiting_passengers) == 0:
            return True
        return False
def get_terminal(s):
    if s.taxi.passenger is not None:
        return True
    return False
def put_terminal(s):
    if s.taxi.passenger is None:
        return True
    return False
root_terminal = put_terminal

class Root(SubTask,MDP):
    def __init__(self,MDP):
        self.mdp = MDP
        self.child_subtasks = [Get(self.mdp),Put(self.mdp)] #lowercase is the original
        self.continuation_prob = None
        self.exit_reward = None 
    def is_terminal(self,s):
        if s.taxi.passenger is None and len(s.waiting_passengers) == 0:
            return True
        return False 
    def exit_distribution(self, s: State):
        distribution = []
        possibilities = 0
        for state in taxicab.list_all_possible_states():
            if self.is_terminal(state) and len(s.waiting_passengers) == 0 and s.passenger is None and s.taxi.location == state.taxi.location: #
                distribution.append((state,1))
                possibilities += 1
            else:
                pass
        if possibilities > 0:
            for i in range(len(distribution)):
                distribution[i] = (distribution[i][0],distribution[i][1]/possibilities)
        return distribution
class Get(SubTask,MDP):
    def __init__(self,MDP):
        self.mdp = MDP
        self.child_subtasks = [Nav(self.mdp),pickup]
        self.continuation_prob = None
        self.exit_reward = None 
    def is_terminal(self,s):
        if s.taxi.passenger is not None:
            return True
        return False
    def exit_distribution(self, s: State):
        distribution = []
        possibilities = 0
        for state in taxicab.list_all_possible_states():
            if self.is_terminal(state)  and len(s.waiting_passengers) == 1:
                if s.waiting_passengers[0].location == state.taxi.location:
                    if len(state.waiting_passengers) == 0:
                        if state.taxi.passenger is not None: #
                            distribution.append((state,1))
                            possibilities += 1
            else:
                pass
        if possibilities > 0:
            for i in range(len(distribution)):
                distribution[i] = (distribution[i][0],distribution[i][1]/possibilities)
        return distribution
        
class Put(SubTask,MDP):
    def __init__(self,MDP):
        self.mdp = MDP
        self.child_subtasks = [Nav(self.mdp),putdown]
        self.continuation_prob = None
        self.exit_reward = None
    def is_terminal(self,s):
        if s.taxi.passenger is None and len(s.waiting_passengers) == 0:
            return True
        return False 
    def exit_distribution(self, s):
        distribution = []
        possibilities = 0
        for state in taxicab.list_all_possible_states():
            if self.is_terminal(state) and len(state.waiting_passengers) == 0 and state.taxi.passenger is None and s.taxi.location == state.taxi.location: #
                distribution.append((state,1))
                possibilities += 1
            else:
                pass
        if possibilities > 0:
            for i in range(len(distribution)):
                distribution[i] = (distribution[i][0],distribution[i][1]/possibilities)
        return distribution
class Nav(SubTask,MDP):
    def __init__(self,MDP):
        self.mdp = MDP
        self.child_subtasks = [east,west,north,south]

        self.continuation_prob = None
        self.exit_reward = None

    def is_terminal(self,s): #return true if at pass dest or waiting pass location 
        if s.taxi.passenger is not None:
            if s.taxi.location == s.taxi.passenger.destination:
                return True
            else:
                return False
        elif len(s.waiting_passengers) == 0:
            return False #hack for now
        elif s.taxi.location == s.waiting_passengers[0].location:
            return True
        return False 
    def exit_distribution(self, s: State) -> list[Tuple[Any, float]]:
        distribution = []
        possibilities = 0
        for state in taxicab.list_all_possible_states():
            if s.taxi.passenger is not None and state.taxi.passenger is not None:
                if s.taxi.passenger.destination != state.taxi.passenger.destination:
                    continue
            if self.is_terminal(state) and s.waiting_passengers == state.waiting_passengers and (s.taxi.passenger is None) == (state.taxi.passenger is None):
                #need to also not include if starting state dest != end state dest
                distribution.append((state,1))
                possibilities += 1
            else:
                pass
                #distribution.append((state,0))
        if possibilities > 0:
            for i in range(len(distribution)):
                distribution[i] = (distribution[i][0],distribution[i][1]/possibilities)
        return distribution
    def name(subtask: SubTask) -> str:
        if isinstance(subtask, Nav):
            return "Nav"
        elif isinstance(subtask, Get):
            return "Get"
        elif isinstance(subtask, Put):
            return "Put"
        elif isinstance(subtask, Root):
            return "Root"
        elif 'dropoff=True' in str(subtask):
            return "Dropoff"
        elif 'pickup=True' in str(subtask):
            return "Pickup"
        else:
            return "nav_action"

In [3]:
len(taxicab.list_all_possible_states())

100

In [52]:
value_cache = {}
seen_depths = [0]
seen_states = {}
#function scales at n^n time
#this will take too long for depths > 10 (as per testing)
#bc we know that it is optimal to never revisit a state, by baking in this bias the function is albe to run quickly with 
#correct results

#added depth input only for inspection/debugging
def value(subtask: SubTask, s: State,depth = 0):
    if subtask not in seen_states:
        seen_states[subtask] = []
    if s in seen_states[subtask]:
        #aformentioned bias
        return -5
    seen_states[subtask].append(s)
    if subtask.is_terminal(s) == True:
        return 0  # HACK: do we need to add pseudoreward???
    
 

    continue_prob = 1  # subtask.continuation_prob(s) #for now, same as just if terminal. bc first if will exit out this can equal 1 for now
    max_qval = float("-inf")
    max_action = None #just for debugging it
    for a in subtask.child_subtasks:
        # Get next-state/reward distribution for semi-MDP
        if a in taxicab.actions():  # if isinstance(a, Action):
            ns_r_prob = subtask.mdp.next_state_reward_dist(s, a)  # lookup for state
        else: #will polish later but this if/else is same functionality as designating each a as action class # elif isinstance(a, SubTask):
            ns_r_prob = []
            for ns, prob in a.exit_distribution(s):
                #to speed things up greatly, we don't enter recursion if prob is 0 since it wont affect result
                if prob == 0:
                    continue
                #dont want to continue seen states if new task
                r = value(a, s,depth) - 0  #0 is exit reward | a.exit_reward(ns) can implement this function later
                ns_r_prob.append(((ns, r),(prob)))
        # Calculate expected value of action Q(subtask, s, a)

        qval = 0

        c = 0 #count
        for ns_r, prob in ns_r_prob:
            if prob == 0: #need to add, or s == s
                continue
            
            ns, r = ns_r

            if (subtask,s) in value_cache:
                v = value_cache[(subtask,s)]
            else:
                v = value(subtask, ns,depth+1)
            #useful traceback to see how v_root(s) is calculated    
            if isinstance(subtask,Root):
                print('depth',depth)
                print(f'root v_{name(a)}(s_{c})' ,ns_r[1]*prob)
                #print('action', a)
                print(ns_r[0])
                print('v next', v)
                c+=1

            #pickup isn't being reached bc all its probs are 0
            qval += prob * (r + continue_prob * v)
            if qval > max_qval:
                max_action = a
            max_qval = max(max_qval, qval)


   # print(max_action)
    value_cache[(subtask,s)] = max_qval

    return max_qval


In [6]:
Get(TaxiMDP(taxicab)).exit_distribution(taxi_state(0,1,3,0))

[(TaxiCabState(taxi=Taxi(location=Location(x=3, y=0), passenger=Passenger(location=None, destination=Location(x=3, y=2))), waiting_passengers=()),
  0.25),
 (TaxiCabState(taxi=Taxi(location=Location(x=3, y=0), passenger=Passenger(location=None, destination=Location(x=0, y=2))), waiting_passengers=()),
  0.25),
 (TaxiCabState(taxi=Taxi(location=Location(x=3, y=0), passenger=Passenger(location=None, destination=Location(x=0, y=0))), waiting_passengers=()),
  0.25),
 (TaxiCabState(taxi=Taxi(location=Location(x=3, y=0), passenger=Passenger(location=None, destination=Location(x=3, y=0))), waiting_passengers=()),
  0.25)]

In [51]:
print(taxicab.width, taxicab.height)
#using a smaller taxi domain for testing
value(Root(TaxiMDP(taxicab)), taxi_state(0,0,0,2)) #taxi state of taxiloc at 0,0 and waiting pass at 0,2 (no pass in taxi -- that is taxi_put_state)
#for root should be 13+v_root(put,(3,0,0,0))
#For root, p(s') = 0.25 since 4 possible destinations picked up passenger can have
#r = v_Get(s_0) = 9
#0.25[9 +v_root(s')]
#v_root(s') = v_get(s') = 12
#v_root = 0.25[9+12] = 6.25

4 3
len ns_r_prob 4
[((TaxiCabState(taxi=Taxi(location=Location(x=0, y=2), passenger=Passenger(location=None, destination=Location(x=3, y=2))), waiting_passengers=()), 9.0), 0.25), ((TaxiCabState(taxi=Taxi(location=Location(x=0, y=2), passenger=Passenger(location=None, destination=Location(x=3, y=0))), waiting_passengers=()), -5), 0.25), ((TaxiCabState(taxi=Taxi(location=Location(x=0, y=2), passenger=Passenger(location=None, destination=Location(x=0, y=2))), waiting_passengers=()), -5), 0.25), ((TaxiCabState(taxi=Taxi(location=Location(x=0, y=2), passenger=Passenger(location=None, destination=Location(x=0, y=0))), waiting_passengers=()), -5), 0.25)]
0
len ns_r_prob 0
[]
1
len ns_r_prob 1
[((TaxiCabState(taxi=Taxi(location=Location(x=0, y=2), passenger=None), waiting_passengers=()), 12.0), 1.0)]
1
depth 1
root v_Put(s_0) 12.0
TaxiCabState(taxi=Taxi(location=Location(x=0, y=2), passenger=None), waiting_passengers=())
v next 0
depth 0
root v_Get(s_0) 2.25
TaxiCabState(taxi=Taxi(location=L

6.25