In [1]:
import numpy as np
import pandas as pd 
import torch
from ihrl.taxicab import TaxiMDP, Root, taxi_state


In [2]:
layout_str = """
A--B
----
----
C--D 
"""
mdp = TaxiMDP(layout_str)
state_length = len(mdp.list_all_possible_states())
test_root = Root(mdp)
init_state = taxi_state(0,0,mdp.width-1,0)
init_state_index = mdp.list_all_possible_states().index(init_state)
action_length = mdp.action_length

In [3]:
#an object just for holding the true transition and reward functions
#used for value iteration since the cache cant take in matrices
tr = mdp.get_transition_reward_obj()

In [4]:
#get the state and action indices of the total trajectories 
state_indices, action_indices = mdp.get_state_action_indices(total_trajectories=1000,max_t_length=110,transition_reward_obj=tr,deterministic=False)

In [5]:
transition_matrix_true = torch.tensor(mdp.get_transition_matrix())
reward_weights_true = torch.tensor(mdp.get_reward_matrix(),dtype=torch.float32)
state_indices = torch.tensor(state_indices)
action_indices = torch.tensor(action_indices)

In [6]:
def soft_vi(
    transition_matrix : torch.Tensor, #Input to be learened
    reward_weights : torch.Tensor,
    discount_rate : float,
    entropy_bonus : float,
    precision : float = 1e-2,
) -> torch.Tensor:
    assert 0 <= discount_rate < 1
    state_values = torch.zeros_like(reward_weights[:,0])
    action_value = torch.zeros_like(transition_matrix[:, :, 0])
    for i in range(1): 
        next_state_value = torch.einsum("san,n->sa", transition_matrix, state_values) #T(s'| s,a)vb(s')
        #compute val of each [(p(s'| a1,s) * vb(s')) , ... , (p(s'| a6,s) * vb(s'))]

        action_value = reward_weights + discount_rate * next_state_value #R(s,a) + gamma * T(s'| s,a)vb(s')
        #[R(s,a1) + V(s'), ... , R(s,a6) * V(s')]
        new_state_values = entropy_bonus * torch.logsumexp((1/entropy_bonus)*action_value, dim=1) #log(sum(a)exp(1/b * Q(s,a)))

        state_values = new_state_values
        if torch.max(abs(new_state_values - state_values)) <  precision:
            break
    policy = torch.softmax(action_value*(1/entropy_bonus), dim=1)
    return policy, state_values

def maximum_likelihood_irl(
    mdp : TaxiMDP,
    discount_rate : float,
    state_indices : torch.Tensor, 
    action_indices :   torch.Tensor,
    iterations: int,
    entropy_bonus: float
):
    state_length = len(mdp.list_all_possible_states())
    action_length = len(mdp.actions(init_state))

    transition_weights = torch.rand(state_length, action_length, state_length)
    transition_weights.requires_grad_(True)

    #rewards should just be a reward for each state
    reward_tensor = -torch.rand(state_length,action_length)
    reward_tensor.requires_grad_(True)

    discount_rate = discount_rate
    optimizer = torch.optim.Adam([reward_tensor,transition_weights], lr=1e-3)#, weight_decay=1e-4)

    for i in range(iterations):
        #reward_weights = reward_weights.clamp(min = -10, max = 15) #same min,max as actual r to make comparing easier
        #transition is always bw 0 and 1
        #transition_matrix = transition_matrix_true# uncomment and change other parts to use true transition matrix
        transition_matrix = torch.softmax(transition_weights,dim=-1)
        reward_weights = reward_tensor
        policy_matrix, _ = soft_vi(
            transition_matrix,
            reward_weights=reward_weights,
            discount_rate=discount_rate,
            entropy_bonus=entropy_bonus
        )

        policy_loss = -torch.log(policy_matrix[state_indices, action_indices]).sum() / (len(state_indices))
        #print('policy loss', policy_loss)
        transition_loss = -torch.log((transition_matrix[state_indices[:-1],action_indices[:-1],state_indices[1:]])+1e-6).sum() / len(state_indices)
        #print('t loss',transition_loss)
        loss = policy_loss + transition_loss

        optimizer.zero_grad()
        loss.backward()
        if torch.isnan(loss):
            break
        optimizer.step()

        if i % 500 == 0:
            print(f"Loss: {loss.item()} at iteration {i}")
    print(f"Final Loss: {loss.item()}")
    return reward_weights.detach(), transition_matrix.detach()

In [7]:
reward_weights_est,transition_matrix_est = maximum_likelihood_irl(mdp,0.9,state_indices,action_indices,7000,0.1)
policy_matrix_est, state_values_est = soft_vi(transition_matrix_est,reward_weights_est,0.9,0.1)


Loss: 8.621004104614258 at iteration 0
Loss: 4.308089256286621 at iteration 500
Loss: 3.34889817237854 at iteration 1000
Loss: 2.5795114040374756 at iteration 1500
Loss: 1.9620102643966675 at iteration 2000
Loss: 1.525149941444397 at iteration 2500
Loss: 1.2516618967056274 at iteration 3000
Loss: 1.090242862701416 at iteration 3500
Loss: 0.994340181350708 at iteration 4000
Loss: 0.9351127743721008 at iteration 4500
Loss: 0.896856427192688 at iteration 5000
Loss: 0.8711298704147339 at iteration 5500
Loss: 0.8532452583312988 at iteration 6000
Loss: 0.8404781818389893 at iteration 6500
Final Loss: 0.8311876058578491


The policy matrix appears to overfit. It shouldn't be able to perfectly predict actions, or be over 90% accurate because of the randomness in trajectories when multiple best actions are available. 

In [8]:
policy_matrix_est[state_indices,action_indices].sum() / len(state_indices)

tensor(0.9143)

In [9]:
matches = 0
for i in range(len(mdp.list_all_possible_states())):
    true_max_action = mdp.value_iteration(i,110,tr)[1]
    if true_max_action == None:
        print(mdp.list_all_possible_states()[i]) #these are the terminal states that will never have an action follow them
        continue
    true_max_action_index = mdp.actions().index(true_max_action)
    if torch.argmax(policy_matrix_est[i]) == true_max_action_index:
        matches += 1
    else:
        print('mismatch')
        print(np.round(policy_matrix_est[i].detach().numpy(),2))
        print(mdp.list_all_possible_states()[i])
        print(f"True: {true_max_action_index} Est: {torch.argmax(policy_matrix_est[i])}")
print(f"Accuracy: {matches/len(mdp.list_all_possible_states())}")


mismatch
[0. 1. 0. 0. 0. 0.]
TaxiCabState(taxi=Taxi(location=Location(x=3, y=1), passenger=Passenger(location=None, destination=Location(x=0, y=0))), waiting_passengers=())
True: 3 Est: 1
mismatch
[0.   0.31 0.   0.69 0.   0.  ]
TaxiCabState(taxi=Taxi(location=Location(x=3, y=3), passenger=Passenger(location=None, destination=Location(x=0, y=0))), waiting_passengers=())
True: 1 Est: 3
mismatch
[1. 0. 0. 0. 0. 0.]
TaxiCabState(taxi=Taxi(location=Location(x=1, y=3), passenger=None), waiting_passengers=(Passenger(location=Location(x=3, y=0), destination=None),))
True: 3 Est: 0
mismatch
[0.   0.58 0.42 0.   0.   0.  ]
TaxiCabState(taxi=Taxi(location=Location(x=2, y=2), passenger=None), waiting_passengers=(Passenger(location=Location(x=0, y=3), destination=None),))
True: 2 Est: 1
mismatch
[0.   0.31 0.69 0.   0.   0.  ]
TaxiCabState(taxi=Taxi(location=Location(x=2, y=1), passenger=Passenger(location=None, destination=Location(x=0, y=3))), waiting_passengers=())
True: 1 Est: 2
mismatch
[0.61

In [10]:
action_val_est = (reward_weights_est + torch.einsum("san,n->sa", transition_matrix_true, state_values_est))
sv_est = torch.logsumexp((1/0.1)*action_val_est,dim=1)

In [11]:
true_state_vals = []
for i in range(len(mdp.list_all_possible_states())):
    true_state_vals.append(mdp.value_iteration(i,100,tr)[0])
true_state_vals = torch.tensor(true_state_vals)

In [None]:
#there doesn't seem to be a correspondence between the state values

In [12]:
pd.Series(true_state_vals).sort_values() 

110    -2.741294
102    -2.741294
44     -2.741294
128    -2.741294
81     -1.934771
         ...    
14     11.272727
26     13.636364
126    13.636364
60     13.636364
68     13.636364
Length: 132, dtype: float64

In [13]:
pd.Series(sv_est).sort_values()

1     -3.555149
76    -2.688258
10    -2.605184
118   -2.545925
21    -2.523063
         ...   
77     6.395243
54     6.564583
128    7.248069
117    7.346713
113    7.571374
Length: 132, dtype: float32

<b>
<ol>
<li>Reward argmax matches</li>
<li>[Transition matrix * Reward Weight] argmax matches</li>
<li>Transition matrix argmax matches</li>
</ol>
</b>

In [14]:
matches = 0
for i in range(len(reward_weights_est)):
    if reward_weights_est[i].argmax() == reward_weights_true[i].argmax():
        matches += 1
print(f"Matches: {matches}/{len(reward_weights_est)}")

Matches: 39/132


In [15]:
#transition * reward weights est 
#state 0, action 0
matches = 0
for i in range(state_length): 
    for a in range(action_length):
        if (transition_matrix_est[i][a] * reward_weights_est[i][a]).argmax() == (transition_matrix_true[i][a] * reward_weights_true[i][a]).argmax():
            matches += 1
print(f"Matches: {matches}/{state_length*action_length}")

Matches: 3/792


In [16]:
#transition_matrix
matches = 0
for i in range(state_length): 
    for a in range(action_length):
        if (transition_matrix_est[i][a]).argmax() == (transition_matrix_true[i][a]).argmax():
            matches += 1
print(f"Matches: {matches}/{state_length*action_length}")

Matches: 161/792
