This week, we will go through how the ELBO (evidence lower bound) concept is applied in model-based RL.

# 0. Jensen's inequality

One of the important prerequisites for the ELBO is the famous Jensen's inequality. It is a simple but powerful result that follows directly from the definition of a convex function.

A function $f(x)$ is convex if, for any $x_1, x_2$ and any $t \in [0,1]$,

$$
f(tx_1 + (1-t)x_2) \le tf(x_1) + (1-t)f(x_2)
$$
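A one-line worked example, using the convex $f(x) = x^2$ with $x_1 = 0$, $x_2 = 2$ and $t = \tfrac{1}{2}$ (arbitrary illustrative values):

$$
f\left(\tfrac{1}{2} \cdot 0 + \tfrac{1}{2} \cdot 2\right) = f(1) = 1 \le \tfrac{1}{2} f(0) + \tfrac{1}{2} f(2) = 0 + 2 = 2
$$

Geometrically, the chord connecting $(x_1, f(x_1))$ and $(x_2, f(x_2))$ never lies below the graph of $f$.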

This simple notion generalizes to a finite form. For a convex function $\phi$, points $x_1, \dots, x_n$, and positive weights $a_1, \dots, a_n$ (dividing by $\sum a_i$ normalizes them, so they need not sum to 1 beforehand),

$$
\phi \left(\frac{\sum a_i x_i}{\sum a_i} \right) \le \frac{\sum a_i \phi(x_i)}{\sum a_i}
$$
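As a quick sanity check of the finite form, here is a small Python sketch that evaluates both sides for a given convex function and positive, unnormalized weights. The helper name `jensen_sides`, the choice $\phi(x) = x^2$, and the sample points are illustrative assumptions.

```python
from typing import Callable, Sequence, Tuple


def jensen_sides(phi: Callable[[float], float],
                 xs: Sequence[float],
                 weights: Sequence[float]) -> Tuple[float, float]:
    """Return (lhs, rhs) of the finite Jensen inequality with positive weights.

    lhs = phi(sum a_i x_i / sum a_i), rhs = sum a_i phi(x_i) / sum a_i.
    """
    if any(a <= 0 for a in weights):
        raise ValueError("weights must be positive")
    total = sum(weights)
    lhs = phi(sum(a * x for a, x in zip(weights, xs)) / total)
    rhs = sum(a * phi(x) for a, x in zip(weights, xs)) / total
    return lhs, rhs


# phi(x) = x^2 is convex; the weights 2, 1, 1 are deliberately not normalized
lhs, rhs = jensen_sides(lambda x: x * x, xs=[-1.0, 0.5, 3.0], weights=[2.0, 1.0, 1.0])
print(lhs, rhs)  # prints 0.140625 2.8125
```

Because both sides divide by $\sum a_i$, normalizing the weights first would not change the result; with $\sum a_i = 1$ the denominators simply drop out.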

The finite form follows by induction on the number of points. From here, we can go one step further to the general, measure-theoretic form, which also covers the infinite (continuous) case.

For a probability space $(\Omega, A, \mu)$, let $f: \Omega \to \mathbb{R}$ be a $\mu$-integrable function, and let $\phi: \mathbb{R} \to \mathbb{R}$ be a convex function. Then,

$$
\phi \left(\int_{\Omega}f d\mu \right) \le \int_{\Omega} \phi \circ f d\mu
$$

In particular, this applies to a probability density function.

Say we have a pdf $f(x)$ such that $\int_{-\infty}^{\infty}f(x) dx = 1$ and $f(x) \ge 0$. Given a convex function $\phi$ and any real-valued measurable function $g(x)$ for which the integrals below exist,

$$
\phi \left(\int_{-\infty}^{\infty}g(x)f(x) dx \right) \le \int_{-\infty}^{\infty}\phi(g(x))f(x) dx
$$
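To see the density form at work, the sketch below approximates both integrals with a midpoint Riemann sum, assuming a standard normal density for $f$, $g(x) = x + 1$ and the convex $\phi = \exp$ (all illustrative choices). Here both sides have closed forms, $e^{1} \approx 2.718$ and $e^{1.5} \approx 4.482$, so the numbers can be checked by hand.

```python
import math


def std_normal_pdf(x: float) -> float:
    """Density f(x) of the standard normal distribution."""
    return math.exp(-0.5 * x * x) / math.sqrt(2.0 * math.pi)


def integrate(h, lo: float = -12.0, hi: float = 12.0, n: int = 48_000) -> float:
    """Midpoint-rule approximation of the integral of h over [lo, hi]."""
    dx = (hi - lo) / n
    return dx * sum(h(lo + (i + 0.5) * dx) for i in range(n))


def g(x: float) -> float:
    return x + 1.0


phi = math.exp  # convex

lhs = phi(integrate(lambda x: g(x) * std_normal_pdf(x)))  # phi(integral of g f)
rhs = integrate(lambda x: phi(g(x)) * std_normal_pdf(x))  # integral of phi(g) f
print(f"{lhs:.4f} <= {rhs:.4f}")  # prints 2.7183 <= 4.4817
```

The interval $[-12, 12]$ is wide enough that the truncated Gaussian tails are negligible at this precision.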

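In particular, setting $g(x) = x$ gives the familiar expectation form, $\phi(\mathbb{E}[X]) \le \mathbb{E}[\phi(X)]$:

$$
\phi \left(\int_{-\infty}^{\infty}xf(x) dx \right) \le \int_{-\infty}^{\infty}\phi(x)f(x) dx
$$

If $\phi$ is concave instead, then $-\phi$ is convex and the inequality flips: $\phi(\mathbb{E}[X]) \ge \mathbb{E}[\phi(X)]$. The ELBO uses exactly this case with the concave $\phi = \ln$, i.e. $\ln \mathbb{E}[X] \ge \mathbb{E}[\ln X]$.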
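The ELBO relies on the concave counterpart $\ln \mathbb{E}[X] \ge \mathbb{E}[\ln X]$. A tiny check with a discrete positive random variable (the values, probabilities, and the helper name `log_mean_and_mean_log` are arbitrary assumptions) shows the gap between the two sides, and that it closes when $X$ is constant:

```python
import math


def log_mean_and_mean_log(values, probs):
    """Return (ln E[X], E[ln X]) for a discrete positive random variable X."""
    if abs(sum(probs) - 1.0) > 1e-12 or any(v <= 0 for v in values):
        raise ValueError("probs must sum to 1 and values must be positive")
    mean = sum(p * v for p, v in zip(probs, values))
    mean_log = sum(p * math.log(v) for p, v in zip(probs, values))
    return math.log(mean), mean_log


# X uniform on {1, 2, 4}: ln E[X] = ln(7/3) ~ 0.847, while E[ln X] = ln 2 ~ 0.693
upper, lower = log_mean_and_mean_log([1.0, 2.0, 4.0], [1 / 3, 1 / 3, 1 / 3])
print(upper >= lower)  # prints True

# A constant X closes the gap: both sides equal ln 3 (up to rounding)
same_upper, same_lower = log_mean_and_mean_log([3.0, 3.0], [0.5, 0.5])
```

In the ELBO, $X$ plays the role of the importance weight $p(o_{1:T}, s_{1:T} \mid a_{1:T}) / q(s_{1:T} \mid o_{1:T}, a_{1:T})$ with $s_{1:T} \sim q$, so the gap shrinks as $q$ approaches the true posterior.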

# 1. One-step prediction ELBO derivation (from the PlaNet paper)

Now that we are equipped with Jensen's inequality, let's go down the rabbit hole, starting with the inputs and outputs we have in a POMDP setting.

As in RL 101, the agent observes the world and takes an action. If we can simulate how the world changes given a series of actions, then we have a good internal estimate of the world: the world model.

In an RL environment, things happen sequentially: usually you observe first, then you choose an action. This gives you a sequence $(o_1, a_1, o_2, a_2, \dots)$.

This is fine, but this point of view can cause a lot of confusion with PlaNet's time notation. PlaNet instead takes a generative point of view: given $a_{1:T}$, predict $s_{1:T}$ and $o_{1:T}$. The action that leads to $s_t$ (and hence $o_t$) is $a_{t-1}$, so the difference is just a one-step shift of the time index.
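To make the index shift concrete, here is a hypothetical helper (not taken from PlaNet's code) that turns an interaction log $(o_1, a_1, o_2, a_2, \dots)$ into the (previous action, next observation) pairs used by the generative notation, where $a_{t-1}$ feeds the transition into $s_t$, which in turn emits $o_t$:

```python
from typing import List, Sequence, Tuple, TypeVar

Obs = TypeVar("Obs")
Act = TypeVar("Act")


def generative_pairs(observations: Sequence[Obs],
                     actions: Sequence[Act]) -> List[Tuple[Act, Obs]]:
    """Pair each o_t (t >= 2) with the action a_{t-1} that preceded it.

    Interaction order: o_1, a_1, o_2, a_2, ..., o_T, a_T.
    Generative view:   a_{t-1} drives s_t via p(s_t | s_{t-1}, a_{t-1}); s_t emits o_t.
    o_1 has no preceding action in this log, and a_T has no later observation,
    so both are left out of the pairs.
    """
    if len(observations) != len(actions):
        raise ValueError("expected one action per observation")
    return [(actions[t - 1], observations[t]) for t in range(1, len(observations))]


print(generative_pairs(["o1", "o2", "o3"], ["a1", "a2", "a3"]))
# prints [('a1', 'o2'), ('a2', 'o3')]
```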

So, at the beginning, all we have is the likelihood $p(o_{1:T} | a_{1:T})$. How can we optimize this generative model?

Here comes the nice part of model-based RL. Informally, we can break the problem down as follows:

1. It would be nice to model the underlying generator, i.e. the dynamics of the world. In PlaNet this is a stochastic transition model $p(s_t | s_{t-1}, a_{t-1})$; the deterministic special case is $s_t = f(s_{t-1}, a_{t-1})$.
2. How latent states generate observations is also of interest. This is modeled by a decoder $p(o_t | s_t)$.
3. How latent states are estimated is also very important. This is done by an encoder $q(s_t | o_{\le t}, a_{< t})$.
4. The encoder is an interesting one: in PlaNet, we have $q(s_{1:T} | o_{1:T}, a_{1:T}) = \prod^T_{t=1} q(s_t | o_{\le t}, a_{< t})$, where each factor is shorthand for the filtering step $q(s_t | s_{t-1}, a_{t-1}, o_t)$ (the previous state $s_{t-1}$ already summarizes $o_{<t}$ and $a_{<t-1}$). Recursively:

$$
\begin{align}
q(s_{1:T} | o_{1:T}, a_{1:T}) &= q(s_T | s_{1:T-1}, o_{1:T}, a_{1:T}) \, q(s_{1:T-1} | o_{1:T}, a_{1:T}) \quad (\textnormal{chain rule}) \\
&= q(s_T | s_{T-1}, a_{T-1}, o_T) \, q(s_{1:T-1} | o_{1:T-1}, a_{1:T-1}) \quad (\textnormal{filtering: the first factor is our approx. posterior step; no future inputs}) \\
&= \prod^T_{t=1} q(s_t | s_{t-1}, a_{t-1}, o_t) \equiv \prod^T_{t=1} q(s_t | o_{\le t}, a_{<t}) \quad (\textnormal{unrolling; PlaNet's shorthand})
\end{align}
$$
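The factorization can be checked numerically on a toy problem. The sketch below assumes a hypothetical two-state model with made-up transition and observation tables, a fixed initial state $s_0 = 0$, and an encoder chosen for illustration as the one-step conditional $q(s_t \mid s_{t-1}, a_{t-1}, o_t) \propto p(s_t \mid s_{t-1}, a_{t-1})\, p(o_t \mid s_t)$ (PlaNet learns this step with a neural network instead). Multiplying the per-step factors and summing over every possible trajectory $s_{1:T}$ gives 1, i.e. the product is a properly normalized joint $q(s_{1:T} \mid o_{1:T}, a_{1:T})$:

```python
import itertools

# Hypothetical two-state model: TRANSITION[a][s_prev][s] = p(s | s_prev, a)
TRANSITION = [
    [[0.9, 0.1], [0.2, 0.8]],  # action 0
    [[0.3, 0.7], [0.6, 0.4]],  # action 1
]
EMISSION = [[0.8, 0.2], [0.25, 0.75]]  # EMISSION[s][o] = p(o | s)
S0 = 0  # fixed initial state s_0 (assumption)


def encoder_step(s_prev, a_prev, o):
    """q(s_t | s_{t-1}, a_{t-1}, o_t) as a normalized list over s_t in {0, 1}."""
    unnorm = [TRANSITION[a_prev][s_prev][s] * EMISSION[s][o] for s in (0, 1)]
    z = sum(unnorm)
    return [u / z for u in unnorm]


def joint_q(states, obs, acts):
    """q(s_{1:T} | o_{1:T}, a) as the product of the filtering factors."""
    prob, s_prev = 1.0, S0
    for s, o, a in zip(states, obs, acts):
        prob *= encoder_step(s_prev, a, o)[s]
        s_prev = s
    return prob


obs = [0, 1, 1]
acts = [1, 0, 1]  # acts[t] is the action taken before obs[t], i.e. a_{t-1}
trajectories = itertools.product((0, 1), repeat=len(obs))
total = sum(joint_q(traj, obs, acts) for traj in trajectories)
print(round(total, 12))  # prints 1.0
```

Each factor only looks at $s_{t-1}$, $a_{t-1}$ and $o_t$, which is the Markov filtering structure used in the recursion.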

Then, given the actions and observations, we can derive our training objective.

$$
\begin{align}
\ln p(o_{1:T} | a_{1:T}) &= \ln \int p(o_{1:T}, s_{1:T}| a_{1:T}) ds_{1:T} \quad (\textnormal{marginalization})\\
&= \ln \int p(o_{1:T} | a_{1:T}, s_{1:T}) p(s_{1:T} | a_{1:T}) ds_{1:T} \\
&= \ln \int p(o_{1:T} | s_{1:T}) p(s_{1:T} | a_{1:T}) ds_{1:T} \quad (\textnormal{assuming} \quad o \perp a \enspace | s) \\
&= \ln \mathbb{E}_{p(s_{1:T} | a_{1:T})} \left[ \prod^T_{t=1} p(o_t | s_t) \right] \quad (\textnormal{Decoder}) \\
&= \ln \mathbb{E}_{q(s_{1:T} | o_{1:T}, a_{1:T})} \left[ \frac{\prod^T_{t=1} p(o_t | s_t)p(s_t | s_{t-1}, a_{t-1})}{q(s_{1:T} | o_{1:T}, a_{1:T})} \right] \quad (\textnormal{Dynamics; importance weighting with } q) \\
&= \ln \mathbb{E}_{q(s_{1:T} | o_{1:T}, a_{1:T})} \left[ \prod^T_{t=1} \frac{p(o_t | s_t)p(s_t | s_{t-1}, a_{t-1})}{q(s_t | o_{\le t}, a_{<t})} \right] \quad (\textnormal{Encoder}) \\
&\ge \mathbb{E}_{q(s_{1:T} | o_{1:T}, a_{1:T})} \left[ \sum^T_{t=1} \left( \ln p(o_t | s_t)
- \ln \frac{ q(s_t | o_{\le t}, a_{<t})}{p(s_t | s_{t-1}, a_{t-1})} \right) \right] \quad (\textnormal {Jensen's Inequality}) \\
&= \sum^T_{t=1} \left( \mathbb{E}_{q(s_t | o_{\le t}, a_{<t})} \left[ \ln p(o_t | s_t) \right]
- \mathbb{E}_{q(s_{t-1} | o_{\le t-1}, a_{<t-1})} \left[ \int q(s_t | o_{\le t}, a_{<t}) \ln \frac{q(s_t | o_{\le t}, a_{<t})}{p(s_t | s_{t-1}, a_{t-1})} ds_t \right] \right)  \quad (\textnormal{marginalizing out the other states}) \\
&= \sum^T_{t=1} \left( \mathbb{E}_{q(s_t | o_{\le t}, a_{<t})} \left[ \ln p(o_t | s_t) \right]
- \mathbb{E}_{q(s_{t-1} | o_{\le t-1}, a_{<t-1})} \left[ KL[q(s_t | o_{\le t}, a_{<t}) \,||\, p(s_t | s_{t-1}, a_{t-1})] \right] \right)  \quad (\textnormal{KL divergence}) \\
&= \sum^T_{t=1} \left( \textnormal{reconstruction term} - \textnormal{one-step prediction (KL) term} \right) \quad (\textnormal{ELBO})
\end{align}
$$
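To make the final line concrete, the sketch below evaluates this ELBO for a hypothetical 1-D linear-Gaussian world model, where every term has a closed form and the exact $\ln p(o_{1:T} \mid a_{1:T})$ can be computed with a Kalman filter for comparison. The model, the encoder (the exact one-step conditional given $s_{t-1}$, $a_{t-1}$, $o_t$), the variances `VAR_P`/`VAR_O`, and the data are illustrative assumptions, not PlaNet's architecture. For each sampled $s_{t-1}$, the reconstruction term and the KL are computed analytically; the outer expectation over $q(s_{t-1} \mid \cdot)$ is a Monte Carlo average.

```python
import math
import random

# Hypothetical 1-D linear-Gaussian world model (all values are illustrative):
#   dynamics p(s_t | s_{t-1}, a_{t-1}) = N(s_{t-1} + a_{t-1}, VAR_P)
#   decoder  p(o_t | s_t)              = N(s_t, VAR_O)
VAR_P, VAR_O = 1.0, 0.25
S0 = 0.0  # fixed initial state s_0


def log_normal(x, mean, var):
    return -0.5 * (math.log(2.0 * math.pi * var) + (x - mean) ** 2 / var)


def kl_normal(mu_q, var_q, mu_p, var_p):
    """KL[N(mu_q, var_q) || N(mu_p, var_p)] for 1-D Gaussians."""
    return 0.5 * (math.log(var_p / var_q) + (var_q + (mu_q - mu_p) ** 2) / var_p - 1.0)


def encoder(s_prev, a_prev, o):
    """Hypothetical q(s_t | s_{t-1}, a_{t-1}, o_t): precision-weighted prior mean and o_t."""
    var_q = 1.0 / (1.0 / VAR_P + 1.0 / VAR_O)
    mu_q = var_q * ((s_prev + a_prev) / VAR_P + o / VAR_O)
    return mu_q, var_q


def elbo(obs, acts, num_samples, rng):
    """Estimate sum_t (E_q[ln p(o_t|s_t)] - E_q[KL(q(s_t|.) || p(s_t|s_{t-1}, a_{t-1}))]).

    acts[t] is the action taken before obs[t], i.e. a_{t-1} in the text.
    """
    total = 0.0
    for _ in range(num_samples):
        s_prev = S0
        for o, a in zip(obs, acts):
            mu_q, var_q = encoder(s_prev, a, o)
            # Reconstruction: E_{q(s_t)}[ln N(o_t; s_t, VAR_O)] in closed form
            recon = -0.5 * (math.log(2.0 * math.pi * VAR_O) + ((o - mu_q) ** 2 + var_q) / VAR_O)
            # One-step prediction: KL between the encoder and the dynamics prior
            kl = kl_normal(mu_q, var_q, s_prev + a, VAR_P)
            total += recon - kl
            s_prev = rng.gauss(mu_q, math.sqrt(var_q))  # sample s_t ~ q for step t+1
    return total / num_samples


def exact_log_likelihood(obs, acts):
    """ln p(o_{1:T} | a_{1:T}) for the same toy model, via a 1-D Kalman filter."""
    mean, var = S0, 0.0  # s_0 is known exactly
    total = 0.0
    for o, a in zip(obs, acts):
        mean, var = mean + a, var + VAR_P          # predict s_t
        total += log_normal(o, mean, var + VAR_O)  # predictive density of o_t
        gain = var / (var + VAR_O)                 # condition on o_t
        mean, var = mean + gain * (o - mean), (1.0 - gain) * var
    return total


obs = [0.6, 0.2, 0.5, 0.9, 0.7]
acts = [0.5, -0.2, 0.1, 0.3, 0.0]
bound = elbo(obs, acts, num_samples=5_000, rng=random.Random(0))
exact = exact_log_likelihood(obs, acts)
print(f"ELBO ~ {bound:.3f} <= ln p = {exact:.3f}")
```

Because this encoder ignores future observations, it generally differs from the true smoothing posterior, and the gap between the two printed numbers is $KL[q(s_{1:T} \mid o_{1:T}, a_{1:T}) \,||\, p(s_{1:T} \mid o_{1:T}, a_{1:T})]$. With a single time step the encoder equals the exact posterior, so the bound becomes tight.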

A couple of remarks on the derivation:

1. The assumption $o \perp a \enspace | s$ could seem a bit arbitrary, but in a POMDP the observation $o_t$ is emitted from the state $s_t$ alone, so once we know $s$, $a$ provides no further information about $o$.
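2. Since the sequence unfolds as $(o_1, a_1, o_2, a_2, \dots)$, why do we condition $o_1$ on the actions at all? In the generative model, actions only enter through the transitions $p(s_t | s_{t-1}, a_{t-1})$, so $a_{1:T}$ cannot affect the earlier observation $o_1$, and the conditioning is harmless: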

$$
p(o_{1:T} | a_{1:T}) = p(o_{2:T} | a_{1:T}, o_1) p(o_1 | a_{1:T}) = p(o_{2:T} | a_{1:T}, o_1) p(o_1)
$$