In [1]:
import jax

In [2]:
# List the devices visible to JAX in this kernel.
jax.devices()

[GpuDevice(id=0, process_index=0)]

In [56]:
import numpy as np
import matplotlib.pyplot as plt

## Imitation Learning

#### Formulation in Population

We are interested in offline imitation learning with an unknown return function $r(y) \in \{\theta^T\phi(y): \theta\in\Theta\}$ for $\Theta:=\{\theta: \theta_k \geq 0, \sum_k\theta_k=1\}$ with discrete treatment $t\in\mathcal{T}$.
The expert policy depends on an unobserved confounder $U$ while the learned policy does not, so that they are $\pi_\text{exp}(t|x, u)$ and $\pi(t|x)$, respectively. Under these assumptions, imitation learning (i.e. multi-objective policy improvement) from an offline dataset becomes:
$$\hat\pi 
= \arg\max_\pi\min_{\theta\in\Theta}[\min_{w\in \mathcal{W}}V_w^\pi - V^{\pi_\text{exp}}].$$
Here, the objective is defined as:
$$
\min_{w\in \mathcal{W}}V_w^\pi - V^{\pi_\text{exp}}
= \min_{w\in\mathcal{W}}
\mathbb{E}_X\mathbb{E}_U\left[
\sum_{t\in\mathcal{T}}
  \frac{
    \int\mathrm{d}p(y|t, X, U)\pi_\text{exp}(t|X, U)w(t, X, U)\pi(t|X) \theta^T\phi(y)
  }{
    \mathbb{E}_X\mathbb{E}_U[\pi_\text{exp}(t|X, U) w(t, X, U)]
  }
  - \sum_{t\in\mathcal{T}}\int\mathrm{d}p(y|t, X, U)\pi_\text{exp}(t|X, U) \theta^T\phi(y) 
\right],
$$
where $\mathcal{W}$ is an uncertainty set for the inverse treatment probability $\frac{1}{\pi_\text{exp}(t|x, u)}$, whose elements satisfy $\frac{1}{\Gamma} \leq \frac{(1-\pi_\text{exp}(t|x))\pi_\text{exp}(t|x, u)}{\pi_\text{exp}(t|x)(1-\pi_\text{exp}(t|x, u))} \leq \Gamma$ for the marginalized policy $\pi_\text{exp}(t|x) := \mathbb{E}_U\pi_\text{exp}(t|x, U)$ (note that $w$ is not required to satisfy the normalization condition $\sum_t \frac{1}{w(t, x, u)}=1$ for every $x, u$).
The first condition can be reformulated as 
$a_{t, x} \leq w(t, x, u) \leq b_{t, x}$ with $a_{t, x}:=1 + \Gamma^{-1} \cdot \left(\frac{1}{\pi_\text{exp}(t|x)}-1\right)$ and $b_{t, x}:=1 + \Gamma\cdot\left(\frac{1}{\pi_\text{exp}(t|x)}-1\right)$.

#### Re-formulation in Sample 
Given samples $\{(X_i, T_i, Y_i) \}_{i=1}^n$ taken from the expert policy $\pi_\text{exp}$, we can approximate the above objective as follows:
$$
\min_{w\in \mathcal{W}}V_w^\pi - V^{\pi_\text{exp}}
\approx \min_{w\in\mathcal{W}_n}
\sum_{t\in\mathcal{T}} \sum_{i=1}^n \left[
  \frac{
    w_i\pi(t|X_i) \theta^T\phi(Y_i)\mathbb{1}[T_i=t]
  }{
    \sum_{j=1}^n w_j  \mathbb{1}[T_j=t]
  }
  - \theta^T\phi(Y_i)\mathbb{1}[T_i=t]
\right],
$$
where 
$$\mathcal{W}_n:= \left\{(w_1, \ldots, w_n) :
a_{T_i, X_i} \leq w_i \leq b_{T_i, X_i} \text{ for any }i=1, \ldots, n
\right\}.
$$
Therefore, our imitation policy is given by
$$
\hat \pi = 
\arg\max_\pi\min_{\theta\in\Theta}\min_{w, \psi}
\frac{1}{n}\sum_{i=1}^n \left[
  w_i \pi(T_i|X_i) \theta^T\phi(Y_i) - \theta^T\phi(Y_i) 
\right]
$$
such that 
$$
\psi_t > 0 \text{ and }\sum_{i=1}^n w_i \mathbb{1}[T_i=t] = 1 \text{ for any }t\in\mathcal{T}
\text{ and }\psi_{T_i} a_{T_i, X_i} \leq w_i \leq \psi_{T_i} b_{T_i, X_i} \text{ for any }i=1, \ldots, n.
$$
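Here, $\psi_t$ represents the value of $\frac{1}{\sum_i w_i\mathbb{1}[T_i=t]}$ computed with the original weights $w\in\mathcal{W}_n$. The $w_i$ appearing in the optimization problem above are those original weights rescaled by $\psi_{T_i}$ (a Charnes–Cooper-type change of variables that turns the ratio into a linear objective), which is why they sum to one within each treatment group and why their bounds are scaled by $\psi_{T_i}$.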

## Multi-Objective Policy Improvement

More details need to be added.

#### Re-formulation in Sample 
Given samples $\{(X_i, T_i, Y_i) \}_{i=1}^n$ taken from the expert policy $\pi_\text{exp}$, we can approximate the above objective as follows:
$$
\min_{w\in \mathcal{W}} [V_w^\pi - V_w^{\pi_0}]
\approx \min_{w\in\mathcal{W}_n}
\sum_{t\in\mathcal{T}} \sum_{i=1}^n \left[
  \frac{
    w_i (\pi(t|X_i) - \pi_0(t| X_i)) \theta^T\phi(Y_i) \mathbb{1}[T_i=t]
  }{
    \sum_{j=1}^n w_j  \mathbb{1}[T_j=t]
  }
\right],
$$
where 
$$\mathcal{W}_n:= \left\{(w_1, \ldots, w_n) :
a_{T_i, X_i} \leq w_i \leq b_{T_i, X_i} \text{ for any }i=1, \ldots, n
\right\}.
$$
Therefore, our imitation policy is given by
$$
\hat \pi = 
\arg\max_\pi\min_{\theta\in\Theta}\min_{w, \psi}
\frac{1}{n}\sum_{i=1}^n \left[
  w_i (\pi(T_i|X_i) - \pi_0(T_i|X_i)) \theta^T\phi(Y_i)
\right]
$$
such that 
$$
\psi_t > 0 \text{ and }\sum_{i=1}^n w_i \mathbb{1}[T_i=t] = 1 \text{ for any }t\in\mathcal{T}
\text{ and }\psi_{T_i} a_{T_i, X_i} \leq w_i \leq \psi_{T_i} b_{T_i, X_i} \text{ for any }i=1, \ldots, n.
$$
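Here, $\psi_t$ represents the value of $\frac{1}{\sum_i w_i\mathbb{1}[T_i=t]}$ computed with the original weights $w\in\mathcal{W}_n$. The $w_i$ appearing in the optimization problem above are those original weights rescaled by $\psi_{T_i}$ (a Charnes–Cooper-type change of variables that turns the ratio into a linear objective), which is why they sum to one within each treatment group and why their bounds are scaled by $\psi_{T_i}$.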

## Toy Example
We consider the following 2-dimensional model:
\begin{align*}
X_i&\sim N(0, I_2)
\\
U_i&\sim N(0, I_2)
\\
T_i &\sim \pi_\text{exp}(t|X_i, U_i)
\\
Y_i&\sim N(X_i + U_i + T_i, I_2)
\end{align*}
where $t =(t_1, t_2)^T \in \mathcal{T}:=\{-1, 0, 1\}^2$. As for the class of return function, we define $\phi(y):= (- \|y_1\|, - \|y_2\|)^T$ and assume the true parameter is $\theta^*:=(0, 0)^T$. 

In [51]:
def return_func(Y):
    """Return the negated Euclidean norm of each row of ``Y``.

    Parameters
    ----------
    Y : array-like, 2-D
        One outcome per row; the norm is taken along axis 1.

    Returns
    -------
    numpy.ndarray
        One value per row of ``Y``, equal to ``-||Y[i]||_2``.
    """
    row_norms = np.linalg.norm(Y, axis=1)
    return np.negative(row_norms)

In [52]:
def generate_data(n=1, pi=lambda x, u: np.random.choice((-1, 0, 1), 2 * x.shape[0]).reshape(x.shape), seed=None):
    """Sample ``n`` draws from the 2-dimensional toy model.

    The model is ``X ~ N(0, I_2)``, ``U ~ N(0, I_2)``, ``T = pi(X, U)`` and
    ``Y = X + U + T + noise`` with ``noise ~ N(0, I_2)``.

    Parameters
    ----------
    n : int
        Number of samples to draw.
    pi : callable
        Policy mapping ``(X, U)`` (both of shape ``(n, 2)``) to a treatment
        array that can be added to ``X + U``. The default ignores its inputs
        and draws every coordinate uniformly from ``{-1, 0, 1}``.
    seed : int or None
        When given, NumPy's global random state is re-seeded with this value
        before sampling so that the draw is reproducible. ``None`` (default)
        leaves the global state untouched, which is the previous behaviour.

    Returns
    -------
    tuple of numpy.ndarray
        ``(X, U, T, Y)``; ``X``, ``U`` and ``Y`` have shape ``(n, 2)``.
    """
    if seed is not None:
        # The default policy also draws from the global state, so seed that
        # state rather than a private Generator.
        np.random.seed(seed)
    X = np.random.randn(n, 2)
    U = np.random.randn(n, 2)
    T = pi(X, U)
    noise = np.random.randn(n, 2)
    Y = X + U + T + noise
    return X, U, T, Y

In [53]:
# Draw 1000 samples using generate_data's default policy, then score each outcome.
# NOTE(review): no random seed is set in the visible cells, so these values
# (and every number derived from them) change from run to run.
X, U, T, Y = generate_data(1000)
R = return_func(Y)

def plot_sequence(*args):
    """Scatter a sequence of 2-D point sets on the current axes.

    Each positional argument is an array whose first two columns are used as
    x/y coordinates. Every array is drawn as blue points, and the axes are
    clamped to ``[-3, 3]`` in both directions.

    NOTE(review): the original body was syntactically incomplete (an ``if``
    with no condition and a malformed ``xy_prev`` assignment). The
    reconstruction below assumes the intent was to link each point to its
    counterpart in the previous array; confirm against the author's intent.
    Linking requires consecutive arrays to have the same number of rows.
    """
    xy_prev = None
    for xy in args:
        plt.scatter(xy[:, 0], xy[:, 1], color='b')
        if xy_prev is not None:
            # Passing 2 x n coordinate arrays makes matplotlib draw one
            # two-point segment per column, i.e. previous point -> current.
            plt.plot(
                [xy_prev[:, 0], xy[:, 0]],
                [xy_prev[:, 1], xy[:, 1]],
                color='b',
                alpha=0.2,
            )
        plt.xlim(-3, 3)
        plt.ylim(-3, 3)
        xy_prev = xy
        
def plot(X, U, T, Y):
    """Scatter X, U, T and Y in a 2x2 grid on the current figure, then show it.

    Each argument is an array whose first two columns are used as x/y
    coordinates. Panels are filled row by row in the order X, U, T, Y and
    titled accordingly.

    The original left panel 1 empty and drew both X and U in panel 2; each
    array now gets its own panel. X keeps its ``alpha=0.1`` transparency.
    """
    panels = (
        ("X", X, 0.1),
        ("U", U, None),
        ("T", T, None),
        ("Y", Y, None),
    )
    for position, (label, points, alpha) in enumerate(panels, start=1):
        plt.subplot(2, 2, position)
        plt.scatter(points[:, 0], points[:, 1], alpha=alpha)
        plt.title(label)
    plt.show()

# Open a figure of size (9, 8) and let plot() draw the sampled arrays into it.
plt.figure(figsize=(9, 8))
plot(X, U, T, Y)

In [55]:
# Policy value of the random policy: the mean of R, i.e. of return_func(Y) over
# the sample generated with generate_data's default policy.
R.mean()

-2.380767529128187

### Multi-Objective Policy Improvement
$\pi_0(t|x) := - \text{sign}(x) \cdot \mathbb{1}[\|x\|>0.5]$

### Imitation Learning
We assume that the expert policy is given by
$$\pi_\text{exp}:= \arg\max_\pi \mathbb{E}_\pi[{\theta^*}^T\phi(Y) + \lambda \log \pi(T|X, U)]$$
so that it is soft-optimal.