# Gridworld with Linearly-solvable Markov Decision Problems (LMDP)

In [194]:
%matplotlib widget
%load_ext autoreload
%autoreload 2
import sys
sys.path.insert(0, '..')



In [195]:
import numpy as np
from lmdps import lmdps
from sklearn.preprocessing import MinMaxScaler
from lmdps.hierarchical_gridworld import create_gridworld_lmdp
from lmdps import plotting
import matplotlib.pyplot as plt
from scipy.linalg import eig
from scipy.special import kl_div

The standard MDP formalism consists of:
- A **finite** set of states $\mathcal S$.
- A set of admissible controls $U(i)$ at each state $i$ (the action set).
- A cost $c(i, u)$ for being at state $i \in \mathcal S$ and taking control $u \in U(i)$ (or a reward $r(i, u)$).
- A stochastic matrix $P(u)$ whose element $p_{ij}(u)$ is the probability of transitioning from state $i$ to state $j$ under control $u$.

#### <span style='color: red;'> Todorov's setting </span>
- A **non-empty** subset $\mathcal A \subseteq \mathcal S$ of absorbing states, with $p_{ij}(u) = \delta_i^j$ and $c(i,u) = 0 \ \forall i \in \mathcal A$.
- The undiscounted, infinite-horizon optimal value function satisfies $v(i) = \min_{u \in U(i)}\{c(i,u) + \sum_j p_{ij}(u) v(j)\}$.
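The Bellman equation above can be solved by value iteration. A minimal sketch, assuming a tiny hypothetical 4-state chain in which state 3 is absorbing and every other state can either "stay" or "move right" at unit cost (states, controls, and costs are made up for illustration):

```python
import numpy as np

# Hypothetical 4-state chain; state 3 is absorbing (p_33 = 1, zero cost).
N_STATES = 4

# P[u][i, j]: transition probabilities under control u
P = {
    "stay": np.eye(N_STATES),
    "right": np.eye(N_STATES, k=1),
}
P["right"][3, 3] = 1.0  # the absorbing state keeps itself under every control

# c[u][i]: cost of taking control u in state i (zero at the absorbing state)
c = {
    "stay": np.array([1.0, 1.0, 1.0, 0.0]),
    "right": np.array([1.0, 1.0, 1.0, 0.0]),
}


def value_iteration(P, c, n_iter=100):
    """Iterate v(i) <- min_u { c(i,u) + sum_j p_ij(u) v(j) } starting from v = 0."""
    v = np.zeros(N_STATES)
    for _ in range(n_iter):
        q_values = np.stack([c[u] + P[u] @ v for u in P])  # shape (|U|, |S|)
        v = q_values.min(axis=0)
    return v


v = value_iteration(P, c)
print(v)  # [3. 2. 1. 0.]
```

The "stay" control is never optimal outside the absorbing state because it only accumulates cost, so the value is simply the number of steps to the exit.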

# <span style='color: #8d44ab;'>Original formulation (Todorov, 2006)</span>

### Linear MDPs
- We assume the existence of passive, uncontrolled dynamics expressed by the transition probability matrix $\overline{P}$.
- Now the control is a real-valued vector $\mathbf u$ that modifies the transition probabilities: $p_{ij}(\mathbf{u}) = \overline{p}_{ij} e^{u_j}$.
- It follows that $P(0) = \overline{P}$ and that $\overline{p}_{ij} = 0 \Rightarrow p_{ij}(\mathbf u) = 0$, i.e. controls cannot create transitions that the passive dynamics forbid.
- $P(\mathbf u)$ must have row sums equal to 1.

Thus, we derive the constraints:
\begin{align}
U(i) &= \left\{\mathbf{u} \in \mathbb R^{|\mathcal{S}|} \;:\; \sum_j \overline{p}_{ij}e^{u_j} = 1,\ \ \overline{p}_{ij} = 0 \Rightarrow u_j = 0\right\} \\
\end{align}
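A small numeric check of these constraints, using a made-up passive row $\overline{\mathbf p}_i$. Any target distribution that is zero wherever $\overline{\mathbf p}_i$ is zero (and strictly positive on its support) is reachable by setting $u_j = \log(p_j / \overline p_{ij})$ on the support and $u_j = 0$ elsewhere. The helper names below are hypothetical, chosen for this sketch only:

```python
import numpy as np


def controlled_row(p_bar_row, u):
    """p_ij(u) = p_bar_ij * exp(u_j) for a single row i."""
    return p_bar_row * np.exp(u)


def control_for_target(p_bar_row, target):
    """Hypothetical helper: u_j = log(target_j / p_bar_ij) on the support, 0 off it."""
    u = np.zeros_like(p_bar_row)
    support = p_bar_row > 0
    u[support] = np.log(target[support] / p_bar_row[support])
    return u


# Passive dynamics for one state: the transition to j = 3 is impossible
p_bar_row = np.array([0.25, 0.5, 0.25, 0.0])
target = np.array([0.7, 0.2, 0.1, 0.0])  # must also be zero where p_bar is zero

u = control_for_target(p_bar_row, target)
p_u = controlled_row(p_bar_row, u)

print(np.isclose(p_u.sum(), 1.0))  # row-sum constraint holds: True
print(u[3])                        # p_bar_i3 = 0, so u_3 = 0: 0.0
print(np.allclose(p_u, target))    # True
```

With $\mathbf u = 0$ the helper returns $\overline{\mathbf p}_i$ unchanged, matching $P(0) = \overline P$.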


Real-valued controls make it possible to define a natural control cost that measures how far the controlled dynamics diverge from the uncontrolled dynamics. The KL divergence is used for this measurement:
\begin{align}
r(i, \mathbf{u}) &= \mathrm{KL}\left(\mathbf{p}_i(\mathbf{u}) \,\|\, \mathbf{p}_i (0)\right) = \sum_{j:\,\overline{p}_{ij} \neq 0} p_{ij}(\mathbf{u}) \log{\frac{p_{ij}(\mathbf{u})}{p_{ij}(0)}} \\
\end{align}

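Substituting $p_{ij}(\mathbf u) = \overline{p}_{ij} e^{u_j}$ and $p_{ij}(0) = \overline{p}_{ij}$, the log-ratio inside the sum reduces to $u_j$, so the control cost becomes: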
\begin{align}
r(i, \mathbf{u}) &= \sum_j p_{ij}(\mathbf{u}) u_j \\
\end{align}
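The substitution can be checked numerically. A minimal sketch, assuming an arbitrary passive row and an arbitrary raw control that is first shifted by a constant on the support so that it satisfies $\sum_j \overline p_{ij} e^{u_j} = 1$ (all numbers are made up):

```python
import numpy as np

p_bar_row = np.array([0.25, 0.5, 0.25, 0.0])  # passive row p_i(0)
u = np.array([0.9, -0.6, -0.3, 0.0])          # raw control, not yet admissible
support = p_bar_row > 0

# Shift u on the support so that sum_j p_bar_ij * exp(u_j) = 1; u_j stays 0 off the support
shift = np.log(np.sum(p_bar_row * np.exp(u)))
u[support] -= shift
p_u = p_bar_row * np.exp(u)

# KL(p_i(u) || p_i(0)) summed over the support of the passive row
kl = np.sum(p_u[support] * np.log(p_u[support] / p_bar_row[support]))
expected_cost = np.sum(p_u * u)  # sum_j p_ij(u) u_j

print(np.isclose(kl, expected_cost))  # True
```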
This control cost has an interesting interpretation: <span style='color: red;'>although the MDP prefers to behave according to $\overline P$, it can be paid to act according to $P(\mathbf{u})$</span>. We can also add an arbitrary state cost $q(i) \geq 0$ such that:
\begin{align}
c(i, \mathbf{u}) &= q(i) + r(i, \mathbf{u}) \\
\end{align}
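Todorov (2006) shows that with this cost the minimization in the Bellman equation can be carried out analytically. Writing the desirability $z(i) = e^{-v(i)}$, the optimal controlled transitions are $p^*_{ij} = \overline p_{ij} z(j) / \sum_k \overline p_{ik} z(k)$, and $z$ satisfies the linear equation $z = \mathrm{diag}(e^{-q})\,\overline P z$. The sketch below solves this on a hypothetical 1-D chain by fixed-point iteration and checks that $v = -\log z$ satisfies the original Bellman equation with $c = q + r$. It is an independent illustration and does not reproduce the internals of `lmdps.power_method` or `lmdps.get_policy`, whose exact input conventions (e.g. whether `q` is passed as a cost or a reward) are not shown here:

```python
import numpy as np

# Hypothetical 1-D chain of 6 states; state 5 is the absorbing goal.
N = 6
GOAL = N - 1

# Passive dynamics: stay / left / right uniformly (reflecting at the left wall)
P_bar = np.zeros((N, N))
for i in range(N - 1):
    for j in [max(i - 1, 0), i, i + 1]:
        P_bar[i, j] += 1.0 / 3.0
P_bar[GOAL, GOAL] = 1.0  # absorbing: p_ij = delta_ij

q = np.full(N, 0.5)  # state cost on non-absorbing states
q[GOAL] = 0.0        # zero cost at the absorbing state

G = np.diag(np.exp(-q))

# Fixed-point iteration z <- G P_bar z; the goal entry stays at exp(-0) * 1 = 1
z = np.ones(N)
for _ in range(1000):
    z = G @ P_bar @ z

v = -np.log(z)

# Optimal controlled transitions: p*_ij proportional to p_bar_ij * z_j
p_star = P_bar * z[None, :]
p_star /= p_star.sum(axis=1, keepdims=True)

# Check v(i) = q(i) + KL(p*_i || p_bar_i) + sum_j p*_ij v(j)
with np.errstate(divide="ignore", invalid="ignore"):
    log_ratio = np.where(P_bar > 0, np.log(p_star / P_bar), 0.0)
kl = np.sum(p_star * log_ratio, axis=1)
bellman_rhs = q + kl + p_star @ v

print(np.allclose(v, bellman_rhs))  # True
```

The iteration converges because every non-absorbing row of `G @ P_bar` is scaled by $e^{-0.5} < 1$, which makes the interior part a contraction.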

# <span style='color: #0066ff;'> Experiments</span>

#### Setting
- Blue Square (0,0)
- Red Triangle (0,5)
- Red Square (5,0)
- Blue Triangle (5,5)

<img height="400" width="400" src='pictures/composed_setting.png'/>


In [141]:
sol = {}

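N_DIM = 10  # the grid has N_DIM x N_DIM states

# Object -> flat state index (row-major: index = row * N_DIM + col).
# Keys are colour + shape, e.g. 'BS' = blue square, 'RC' = red circle.
TS = {'RT': 0, 'BS': 9, 'BC': 23, 'RC': 67, 'RS': 90, 'BT': 99}

P_ = None

scl = MinMaxScaler()

for task in TS:
    # One task per object: that object is the goal, every other object is a non-goal
    goals = [TS[task]]
    non_goals = list(set(TS.values()) - set(goals))

    P, q, goal, non_goal = create_gridworld_lmdp(N_DIM, goals, non_goals, goal=0, non_goal=-1, self_move=True)

    # Diagonal matrix built from the per-state vector q
    G = np.diagflat(q)

    # z: solution returned by lmdps.power_method, used below as the Z map
    z = lmdps.power_method(P, G)

    Z = z.reshape(N_DIM, N_DIM)

    sol[task + '_T'] = {'q': q, 'Z': Z}

# Note: P from the last iteration is reused below by lmdps.get_policy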
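`TS` stores flat state indices. A short sketch for reading them as coordinates on the `N_DIM` x `N_DIM` grid used in the code, assuming the row-major layout implied by `z.reshape(N_DIM, N_DIM)` (NumPy's default C order):

```python
import numpy as np

N_DIM = 10
TS = {'RT': 0, 'BS': 9, 'BC': 23, 'RC': 67, 'RS': 90, 'BT': 99}

# Row-major layout: index = row * N_DIM + col
coords = {
    name: tuple(int(k) for k in np.unravel_index(idx, (N_DIM, N_DIM)))
    for name, idx in TS.items()
}
print(coords)
# {'RT': (0, 0), 'BS': (0, 9), 'BC': (2, 3), 'RC': (6, 7), 'RS': (9, 0), 'BT': (9, 9)}
```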

## BLUE

In [142]:
BLUE = (sol['BT_T']['Z'] + sol['BS_T']['Z'] + sol['BC_T']['Z'])/3
ss = 'Blue: BT (or) BS (or) BC'
fig = plotting.plot_Z_function(ss, N_DIM, BLUE, BLUE, True)
fig.show()
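Averaging `Z` maps is meaningful because the desirability equation is linear. A minimal sketch of this property, assuming (hypothetically) two tasks that share passive dynamics and interior costs and differ only in the desirability fixed at their absorbing boundary states: solving for the interior with the averaged boundary gives exactly the average of the two solutions. Whether the gridworld tasks built by `create_gridworld_lmdp` meet these conditions exactly is not shown here:

```python
import numpy as np

# Hypothetical 1-D chain: states 0 and 6 are absorbing boundary states
N = 7
boundary = np.array([0, N - 1])
interior = np.arange(1, N - 1)

# Passive random walk over interior states: stay / left / right
P_bar = np.zeros((N, N))
for i in interior:
    P_bar[i, [i - 1, i, i + 1]] = 1.0 / 3.0
P_bar[boundary, boundary] = 1.0

q_interior = np.full(len(interior), 0.2)  # interior state cost shared by both tasks
G_I = np.diag(np.exp(-q_interior))


def solve_interior(z_boundary):
    """Solve z_I = G_I (P_II z_I + P_IB z_B) for the interior desirability."""
    P_II = P_bar[np.ix_(interior, interior)]
    P_IB = P_bar[np.ix_(interior, boundary)]
    A = np.eye(len(interior)) - G_I @ P_II
    return np.linalg.solve(A, G_I @ P_IB @ z_boundary)


zB_left = np.array([1.0, np.exp(-3.0)])   # task preferring the left exit
zB_right = np.array([np.exp(-3.0), 1.0])  # task preferring the right exit

z_left_task = solve_interior(zB_left)
z_right_task = solve_interior(zB_right)
z_composed = solve_interior((zB_left + zB_right) / 2)

print(np.allclose(z_composed, (z_left_task + z_right_task) / 2))  # True
```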


## CIRCLE

In [143]:
CIRCLE = (sol['BC_T']['Z'] + sol['RC_T']['Z'])/2
ss = 'Circle: BC (or) RC'
fig = plotting.plot_Z_function(ss, N_DIM, CIRCLE, CIRCLE, True)
fig.show()


## COMBINATION

In [144]:
C = np.array([BLUE, CIRCLE]).max(axis=0)
ss = 'min(BLUE, CIRCLE)'
fig = plotting.plot_Z_function(ss, N_DIM, C, -np.log(C), True)
fig.show()
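The plot title reads `min(...)` while the code takes an element-wise `max` of desirabilities. These agree because $-\log$ is strictly decreasing, so $-\log\max(Z_1, Z_2) = \min(-\log Z_1, -\log Z_2)$ element-wise. A quick check with made-up arrays:

```python
import numpy as np

# Two made-up desirability maps with values in (0, 1]
Z1 = np.array([[1.0, 0.5], [0.2, 0.8]])
Z2 = np.array([[0.3, 0.9], [0.6, 0.1]])

C = np.array([Z1, Z2]).max(axis=0)

# -log is decreasing: the cost of the element-wise max is the element-wise min of the costs
print(np.allclose(-np.log(C), np.minimum(-np.log(Z1), -np.log(Z2))))  # True
```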


## TARGET

In [145]:
T = sol['BC_T']['Z'].copy()
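ss = 'Target: BC'
# T contains a zero entry, so silence the divide-by-zero warning from log(0)
with np.errstate(divide='ignore'):
    fig = plotting.plot_Z_function(ss, N_DIM, T, -np.log(T), True)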
fig.show()


In [146]:
T

matrix([[1.        , 0.72621709, 0.55312344, 0.4599038 , 0.48484315,
         0.56079965, 0.6524441 , 0.74935101, 0.8568017 , 1.        ],
        [0.7542509 , 0.62552782, 0.47324943, 0.34174481, 0.43382601,
         0.54511168, 0.64718166, 0.73880721, 0.82105411, 0.88714758],
        [0.63722489, 0.54839385, 0.37260167, 0.        , 0.36360439,
         0.53863941, 0.65236364, 0.73764207, 0.80145993, 0.84038864],
        [0.60902991, 0.55822102, 0.46876339, 0.37481816, 0.48195214,
         0.59347794, 0.68599142, 0.7579375 , 0.80675492, 0.8325584 ],
        [0.63164382, 0.60669693, 0.56941271, 0.54855712, 0.59590806,
         0.66732878, 0.7401866 , 0.80136157, 0.83506384, 0.85053165],
        [0.67920462, 0.66751017, 0.65363341, 0.65408955, 0.6857942 ,
         0.73974254, 0.80606462, 0.87225836, 0.88160724, 0.88397269],
        [0.73845987, 0.73050573, 0.7235212 , 0.72837346, 0.75343665,
         0.79978256, 0.87207098, 1.        , 0.93513405, 0.9197792 ],
        [0.80566926, 0.7925

In [253]:
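# Policies induced by the target Z (T) and the composed Z (C)
p0 = lmdps.get_policy(P, T)
p1 = lmdps.get_policy(P, C)
# Element-wise x*log(x/y) - x + y; a row sum is KL(p0[i] || p1[i]) when both rows sum to 1
div = kl_div(p0, p1)


a = 'SRC State'
d = 'KL-Div'

# Per-state KL divergence between the target and composed policies
for i in range(div.shape[0]):
    print(i, div[i, :].sum())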

0 0.0
1 0.012174803215771662
2 0.00935304946289972
3 0.009798306247917582
4 0.004837461012002581
5 0.006131920698319665
6 0.003951076321064229
7 0.0012383238436225463
8 0.0002715345615473541
9 0.0
10 0.008785380804763848
11 0.007890685232607803
12 0.01614794800764452
13 0.2002380611303357
14 0.013411846693670931
15 0.0074400118516830815
16 0.004422196821179031
17 0.0015259622250969052
18 0.0003117254693300875
19 0.0002495055489817599
20 0.0035309765092728096
21 0.010136692997809421
22 0.194786034509661
23 0.0
24 0.20533501474423418
25 0.013797779545038136
26 0.004466875169937157
27 0.002294951008871504
28 0.0007399294146221702
29 9.681601361166692e-05
30 0.000745998347516541
31 0.0026529641453362507
32 0.011256678697585953
33 0.19790679170274225
34 0.016246688756675398
35 0.006026237535474405
36 0.0027034015387787103
37 0.0013296471096543583
38 0.0006824297322308714
39 0.00022643293553542554
40 0.0006528730245726477
41 0.0013362507689293757
42 0.003841108067043908
43 0.0106718456749642
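`scipy.special.kl_div(x, y)` is element-wise and returns $x\log(x/y) - x + y$, so the row sums printed above equal $\mathrm{KL}(\mathbf p_0(i)\,\|\,\mathbf p_1(i))$ only because both policy rows sum to 1 (the extra $-x + y$ terms then cancel). A minimal check against `scipy.stats.entropy`, which returns the KL divergence when given two distributions, using made-up three-state policies:

```python
import numpy as np
from scipy.special import kl_div
from scipy.stats import entropy

# Made-up policies: row i is a distribution over next states from state i
p0 = np.array([[0.5, 0.5], [0.9, 0.1], [1.0, 0.0]])
p1 = np.array([[0.5, 0.5], [0.5, 0.5], [0.8, 0.2]])

per_state = kl_div(p0, p1).sum(axis=1)
reference = np.array([entropy(p0[i], p1[i]) for i in range(p0.shape[0])])

print(np.allclose(per_state, reference))  # True
```

A zero in `p1` where `p0` is positive makes the corresponding `kl_div` entry `inf`, which is one reason the zero-KL states in the output above (e.g. goal states) are worth inspecting separately.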