In [1]:
import jax

In [2]:
jax.devices()

[GpuDevice(id=0, process_index=0)]

## Imitation Learning

#### Formulation in Population

We are interested in offline imitation learning with an unknown return function $r(y) \in \{\theta^T\phi(y): \theta\in\Theta\}$, where $\Theta:=\{\theta: \theta_k \geq 0, \sum_k\theta_k=1\}$, and a discrete treatment $t\in\mathcal{T}$.
The expert policy depends on an unobserved confounder $U$ while the learned policy does not, so they take the forms $\pi_\text{exp}(t|x, u)$ and $\pi(t|x)$, respectively. Under these assumptions, imitation learning (i.e. multiobjective policy improvement) from an offline dataset becomes:
$$\hat\pi 
= \arg\max_\pi\min_{\theta\in\Theta}[\min_{w\in \mathcal{W}}V_w^\pi - V^{\pi_\text{exp}}].$$
Here, the objective is defined as:
$$
\min_{w\in \mathcal{W}}V_w^\pi - V^{\pi_\text{exp}}
= \min_{w\in\mathcal{W}}
\mathbb{E}_X\mathbb{E}_U\left[
  \sum_t\int\mathrm{d}p(y|t, X, U)\pi_\text{exp}(t|X, U)w(t, X, U)\pi(t|X) \theta^T\phi(y) - \sum_t\int\mathrm{d}p(y|t, X, U)\pi_\text{exp}(t|X, U) \theta^T\phi(y) 
\right],
$$
where $\mathcal{W}$ is an uncertainty set for the inverse treatment probability $\frac{1}{\pi_\text{exp}(t|X, U)}$, consisting of weights that satisfy two conditions: the odds-ratio bound $\frac{1}{\Gamma} \leq \frac{(1-\pi_\text{exp}(t|x))\pi_\text{exp}(t|x, u)}{\pi_\text{exp}(t|x)(1-\pi_\text{exp}(t|x, u))} \leq \Gamma$ for the marginalized policy $\pi_\text{exp}(t|x) := \mathbb{E}_U\pi_\text{exp}(t|x, U)$, and the normalization $\mathbb{E}\left[\mathbb{1}[T=t]w(T, X, U)\right]=1$ for any $t\in\mathcal{T}$. The first condition can be reformulated as 
$a_{t, x} \leq w(t, x, u) \leq b_{t, x}$ with $a_{t, x}:=1 + \Gamma^{-1} \cdot \left(\frac{1}{\pi_\text{exp}(t|x)}-1\right)$ and $b_{t, x}:=1 + \Gamma\cdot\left(\frac{1}{\pi_\text{exp}(t|x)}-1\right)$.
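
The box bounds above can be computed directly from the marginal propensity and $\Gamma$. The following is a minimal sketch, not the notebook's own implementation; the function name `weight_bounds` and the example propensities are hypothetical.

```python
import jax.numpy as jnp

def weight_bounds(marginal_propensity, gamma):
    """Return (a, b), the lower and upper bounds on the inverse propensity weight."""
    # (1 - pi_exp(t|x)) / pi_exp(t|x), written as 1 / pi_exp(t|x) - 1
    inv_odds = 1.0 / marginal_propensity - 1.0
    a = 1.0 + inv_odds / gamma
    b = 1.0 + inv_odds * gamma
    return a, b

# Hypothetical values of pi_exp(T_i | X_i) for three samples.
propensity = jnp.array([0.5, 0.25, 0.8])
a, b = weight_bounds(propensity, gamma=2.0)
```

With $\Gamma = 1$ both bounds collapse to $1/\pi_\text{exp}(t|x)$, which recovers ordinary inverse propensity weighting without unobserved confounding.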

#### Re-formulation in Sample 
Given samples $\{(X_i, T_i, Y_i) \}_{i=1}^n$ taken from the expert policy $\pi_\text{exp}$, we can approximate the above objective as follows:
$$
\min_{w\in \mathcal{W}}V_w^\pi - V^{\pi_\text{exp}}
\approx \min_{w\in\mathcal{W}_n}
\frac{1}{n}\sum_{i=1}^n \left[
  w_i\pi(T_i|X_i) \theta^T\phi(Y_i) - \theta^T\phi(Y_i) 
\right],
$$
where 
$$\mathcal{W}_n:= \left\{(w_1, \ldots, w_n) :
a_{T_i, X_i} \leq w_i \leq b_{T_i, X_i} \text{ for any }i=1, \ldots, n \text{ and }\frac{1}{n}\sum_{i=1}^n\mathbb{1}[T_i=t]w_i=1 \text{ for any }t\in\mathcal{T}
\right\}.
$$
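
For a fixed $\theta$ and policy $\pi$, minimizing over $\mathcal{W}_n$ is a linear program that separates across treatments, because each normalization constraint only involves the samples with $T_i=t$. Within one treatment group it can therefore be solved greedily, in the manner of a fractional knapsack: start every weight at its lower bound and spend the remaining mass on the smallest cost coefficients first. The sketch below illustrates this under the assumption that the constraint set is feasible; `min_weights_one_group`, `cost` (standing for $\pi(T_i|X_i)\theta^T\phi(Y_i)$) and `budget` (standing for $n$) are hypothetical names, and the numbers are made up.

```python
import jax.numpy as jnp

def min_weights_one_group(cost, a, b, budget):
    """Minimize sum_i cost_i * w_i subject to a_i <= w_i <= b_i and sum_i w_i = budget.

    Assumes feasibility: sum(a) <= budget <= sum(b).
    """
    order = jnp.argsort(cost)                # cheapest coefficients first
    slack = (b - a)[order]                   # room to raise each weight above a_i
    used_before = jnp.cumsum(slack) - slack  # slack consumed by cheaper samples
    remaining = budget - jnp.sum(a)          # mass left after setting w_i = a_i
    extra = jnp.minimum(jnp.maximum(remaining - used_before, 0.0), slack)
    return a.at[order].add(extra)

# Hypothetical treatment group of three samples out of n = 6 in total.
cost = jnp.array([0.3, 0.1, 0.2])
a = jnp.array([1.0, 1.0, 1.0])
b = jnp.array([3.0, 3.0, 3.0])
w = min_weights_one_group(cost, a, b, budget=6.0)
```

Applying the function to each treatment group in turn and concatenating the results gives a minimizer over the whole of $\mathcal{W}_n$ for that $\theta$ and $\pi$.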
Therefore, our imitation policy is given by
$$
\hat \pi = 
\arg\max_\pi\min_{\theta\in\Theta}\min_{w\in\mathcal{W}_n}
\frac{1}{n}\sum_{i=1}^n \left[
  w_i\pi(T_i|X_i) \theta^T\phi(Y_i) - \theta^T\phi(Y_i) 
\right]
.$$
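
For fixed weights $w$, the sample objective is linear in $\theta$: it equals $\theta^T v$ with $v = \frac{1}{n}\sum_i (w_i\pi(T_i|X_i) - 1)\phi(Y_i)$. A linear function on the simplex $\Theta$ attains its minimum at a vertex, so the minimum over $\theta$ is simply the smallest component of $v$. The sketch below illustrates this step only; the function name `worst_case_over_theta` and the input values are hypothetical.

```python
import jax.numpy as jnp

def worst_case_over_theta(w, pi_prob, phi_y):
    """Return (min over the simplex of theta^T v, the minimizing vertex theta).

    w:       (n,)   weights w_i
    pi_prob: (n,)   learned policy probabilities pi(T_i | X_i)
    phi_y:   (n, K) features phi(Y_i)
    """
    v = jnp.mean((w * pi_prob - 1.0)[:, None] * phi_y, axis=0)
    k = jnp.argmin(v)
    theta = jnp.zeros_like(v).at[k].set(1.0)  # one-hot vertex of the simplex
    return v[k], theta

# Hypothetical inputs for three samples and K = 2 return features.
w = jnp.array([2.0, 2.0, 4.0])
pi_prob = jnp.array([0.5, 0.25, 0.25])
phi_y = jnp.array([[1.0, 2.0], [2.0, 0.0], [1.0, 1.0]])
value, theta = worst_case_over_theta(w, pi_prob, phi_y)
```

Because the two inner minimizations commute, the joint minimum over $(\theta, w)$ can also be obtained by minimizing each component of $v$ over $w$ separately and taking the smallest result.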

## Policy Improvement

## Toy Example
We consider the following 2-dimensional model:
\begin{align*}
X_i&\sim N(0, I_2)
\\
U_i&\sim N(0, I_2)
\\
T_i &\sim \pi_\text{exp}(t|X_i, U_i)
\\
Y_i&\sim N(X_i + U_i + T_i, I_2)
\end{align*}
where $t =(t_1, t_2)^T \in \mathcal{T}:=\{-1, 0, 1\}^2$. For the class of return functions, we define $\phi(y):= (\|y_1\|, \|y_2\|)^T$ and assume the true parameter is $\theta^*:=(0, 0)^T$. We assume that the expert policy is given by
$$\pi_\text{exp}:= \arg\max_\pi \mathbb{E}_\pi[{\theta^*}^T\phi(Y) + \lambda \log \pi(T|X, U)]$$
so that it is soft-optimal.
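
The generative model above can be simulated along the following lines. This is only a sketch: the soft-optimal expert policy is not derived here, so `uniform_logits` is a placeholder assumption that assigns equal probability to all nine treatments, and `sample_toy_data`, `TREATMENTS` and the `policy_logits` argument are hypothetical names. Any function that maps `(x, u)` to one logit per treatment can be passed instead.

```python
import itertools

import jax
import jax.numpy as jnp

# All 9 treatments in {-1, 0, 1}^2, one per row.
TREATMENTS = jnp.array(list(itertools.product([-1.0, 0.0, 1.0], repeat=2)))

def uniform_logits(x, u):
    """Placeholder expert policy (an assumption): equal logits for every treatment."""
    return jnp.zeros((x.shape[0], TREATMENTS.shape[0]))

def sample_toy_data(key, n, policy_logits=uniform_logits):
    """Draw n samples (X, U, T, Y) from the toy model and compute phi(Y)."""
    key_x, key_u, key_t, key_y = jax.random.split(key, 4)
    x = jax.random.normal(key_x, (n, 2))   # X_i ~ N(0, I_2)
    u = jax.random.normal(key_u, (n, 2))   # U_i ~ N(0, I_2), hidden from the learner
    t_index = jax.random.categorical(key_t, policy_logits(x, u), axis=-1)
    t = TREATMENTS[t_index]                # T_i in {-1, 0, 1}^2
    y = x + u + t + jax.random.normal(key_y, (n, 2))  # Y_i ~ N(X_i + U_i + T_i, I_2)
    phi = jnp.abs(y)                       # phi(y) = (|y_1|, |y_2|)
    return x, u, t, y, phi

x, u, t, y, phi = sample_toy_data(jax.random.PRNGKey(0), n=1000)
```

Only `x`, `t` and `y` would be available to the learner; `u` is returned so that the effect of confounding can be inspected.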