#### Policy learning and Policy Gradients

This is a recap of policy learning and of how it differs when one assumes a stochastic versus a deterministic state transition model. If these concepts are unfamiliar to you, there are many great resources [online](https://spinningup.openai.com/en/latest/spinningup/rl_intro.html). This recap contextualizes how we can use MJX's differentiability for policy learning.

The goal of policy learning is to learn a control policy $\pi$ that outputs actions $a_t \sim \pi(\cdot| x_t, \theta)$ maximizing the total reward $\sum_t r_t$, where $r_t = r(x_t, a_t)$ is shorthand for the reward function evaluated at the state and action at time $t$. $\theta$ are the parameters of the policy; in the common case that the policy is a neural network, $\theta$ are its weights. **Policy gradient methods** estimate the gradient of the expected total reward $J(\pi_\theta)$ with respect to $\theta$ and use this estimate in a first-order optimization algorithm such as gradient descent (applied to $-J$) or [Adam](https://arxiv.org/abs/1412.6980).
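To make the update step concrete, here is a minimal JAX sketch of Adam used for gradient *ascent* on $\theta$. It assumes an estimate of $\nabla_\theta J$ is already available; how that estimate is obtained is exactly where the methods below differ. The toy objective `toy_J` (maximized at $\theta = [1, -2]$), the hyperparameters, and all function names are hypothetical; in practice a library such as Optax would typically supply the optimizer.

```python
import jax
import jax.numpy as jnp

def adam_init(theta):
    # First- and second-moment estimates start at zero; step counter starts at 0.
    return (jnp.zeros_like(theta), jnp.zeros_like(theta), 0)

def adam_ascent_step(theta, grad, state, lr=0.05, b1=0.9, b2=0.999, eps=1e-8):
    """One Adam update that *maximizes* J, given an estimate `grad` of grad_theta J."""
    m, v, t = state
    t = t + 1
    m = b1 * m + (1.0 - b1) * grad
    v = b2 * v + (1.0 - b2) * grad**2
    m_hat = m / (1.0 - b1**t)  # bias-corrected first moment
    v_hat = v / (1.0 - b2**t)  # bias-corrected second moment
    theta = theta + lr * m_hat / (jnp.sqrt(v_hat) + eps)  # "+" because we ascend J
    return theta, (m, v, t)

# Hypothetical stand-in for the expected total reward J(theta), maximized at [1, -2].
def toy_J(theta):
    return -jnp.sum((theta - jnp.array([1.0, -2.0])) ** 2)

# Exact gradient of the toy objective; a policy gradient method would plug in an estimate.
grad_J = jax.jit(jax.grad(toy_J))

theta = jnp.zeros(2)
state = adam_init(theta)
for _ in range(1000):
    theta, state = adam_ascent_step(theta, grad_J(theta), state)
```

The optimizer only ever sees the gradient estimate, so swapping a zeroth-order estimator for a first-order one changes nothing in this loop.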

How you estimate the policy gradient depends on what state transition model you assume. 

#### Zeroth-Order Policy Gradients (ZoPG)

Referring to `mjx.step` as the simulation function $f$, we borrow some [terminology](https://arxiv.org/abs/2202.00817) to distinguish between zeroth-order gradients, which only depend on values of $f$, and first-order gradients, which depend on its Jacobian.
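As a small illustration of the terminology (not of MJX itself), the JAX sketch below computes the gradient of a toy function `f` in two ways: a zeroth-order estimate that only evaluates `f` at randomly perturbed inputs (Gaussian smoothing with antithetic samples, one common choice of zeroth-order estimator), and a first-order gradient from `jax.grad`, which differentiates through `f`. The quadratic `f`, `sigma`, and the sample count are assumptions chosen for illustration.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Toy stand-in for a differentiable simulation output (assumption).
    return jnp.sum(x**2)

def zeroth_order_grad(f, x, key, num_samples=20_000, sigma=0.1):
    """Gaussian-smoothing gradient estimate: uses only *values* of f."""
    eps = jax.random.normal(key, (num_samples,) + x.shape)
    f_plus = jax.vmap(lambda e: f(x + sigma * e))(eps)
    f_minus = jax.vmap(lambda e: f(x - sigma * e))(eps)
    # Antithetic finite difference along each random direction...
    weights = (f_plus - f_minus) / (2.0 * sigma)
    # ...times that direction, averaged over all samples.
    weights = weights.reshape((-1,) + (1,) * x.ndim)
    return jnp.mean(weights * eps, axis=0)

def first_order_grad(f, x):
    """Gradient from automatic differentiation: uses the derivatives of f."""
    return jax.grad(f)(x)

x = jnp.array([1.0, -2.0, 0.5])
g_zo = zeroth_order_grad(f, x, jax.random.PRNGKey(0))
g_fo = first_order_grad(f, x)  # equals 2 * x for this f
```

The zeroth-order estimate needs thousands of function evaluations and is still noisy, while the first-order gradient is exact from a single backward pass; this is the same trade-off that separates ZoPGs from FoPGs.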

Reinforcement learning algorithms such as the standard [PPO](https://github.com/google/brax/blob/main/brax/training/agents/ppo/train.py) assume a stochastic state transition model $x_{t+1} \sim P(\cdot | x_t, a_t)$. Writing $R(\tau) = \sum_t r_t$ for the total reward of a trajectory $\tau$, this leads to a ZoPG of the form:

$$
\nabla_\theta J(\pi_\theta) = \mathbb{E}_{\tau \sim \pi_\theta}\left[ \sum_t \nabla_\theta \log\pi_\theta (a_t | x_t) R(\tau) \right]
$$
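This estimator is usually computed with a "surrogate" trick: sample trajectories, treat the sampled actions and returns as constants, and differentiate the sample mean of $\log\pi_\theta(a_t|x_t)\,R(\tau)$ with `jax.grad`. The minimal JAX sketch below assumes one-step episodes with no state, a 1-D Gaussian policy $a \sim \mathcal{N}(\theta, \sigma^2)$ and a toy reward $r(a) = -(a-2)^2$; these choices and every name in the code are hypothetical, picked because the true gradient $\nabla_\theta J = -2(\theta - 2)$ is known in closed form.

```python
import jax
import jax.numpy as jnp

SIGMA = 1.0   # fixed standard deviation of the toy Gaussian policy (assumption)
TARGET = 2.0  # action that maximizes the toy reward (assumption)

def log_prob(theta, a):
    # log N(a; theta, SIGMA^2), dropping the constant that does not depend on theta.
    return -0.5 * ((a - theta) / SIGMA) ** 2

def episode_return(a):
    # R(tau) for a one-step episode: the single reward r(a) = -(a - TARGET)^2.
    return -((a - TARGET) ** 2)

def zopg_estimate(theta, key, num_samples):
    # Sample actions from the current policy. They are treated as data:
    # no gradient flows through the sampling itself, only through log pi.
    eps = jax.random.normal(key, (num_samples,))
    a = jax.lax.stop_gradient(theta + SIGMA * eps)
    R = jax.lax.stop_gradient(episode_return(a))
    # Surrogate whose gradient is the Monte Carlo estimate of
    # E[grad log pi(a | theta) * R(tau)].
    surrogate = lambda th: jnp.mean(log_prob(th, a) * R)
    return jax.grad(surrogate)(theta)

theta = 0.0
estimate = zopg_estimate(theta, jax.random.PRNGKey(0), num_samples=100_000)
exact = -2.0 * (theta - TARGET)  # closed form for this toy: 4.0 at theta = 0
```

Only values of the reward enter the estimate, never its derivative. For this toy at $\theta = 0$, a single-sample estimate has a standard deviation of $\sqrt{87} \approx 9.3$ (computed analytically), more than twice the true gradient of 4, which is why many samples are averaged.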

Despite this method's popularity and extensive research into refining it, a fundamental shortcoming is that the gradient estimate has high variance. This variance lets the optimizer explore the space of policies thoroughly, which has led to the robust and often surprisingly good policies that have been achieved. However, it comes at the cost of requiring many samples $(x_t, a_t)$ to converge.

#### First-Order Policy Gradients (FoPG)
On the other hand, if you assume a deterministic state transition model $x_{t+1} = f(x_t, a_t)$, you end up with the first-order policy gradient. Unlike ZoPG methods, which model the state evolution as a probabilistic black box, the FoPG explicitly contains the Jacobians of the simulation function $f$. For example, consider the gradient of $r_t$ in the case that it only depends on the state:
$$
\frac{\partial r_t}{\partial \theta} = \frac{\partial r_t}{\partial x_t}\frac{\partial x_t}{\partial \theta} 
$$

$$
\frac{\partial x_t}{\partial \theta} = \textcolor{Navy}{\frac{\partial f(x_{t-1}, a_{t-1})}{\partial x_{t-1}}}\frac{\partial x_{t-1}}{\partial \theta} + \textcolor{Navy}{\frac{\partial f(x_{t-1}, a_{t-1})}{\partial a_{t-1}}} \frac{\partial a_{t-1}}{\partial \theta}
$$
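A minimal JAX sketch of this recursion: roll out a deterministic toy system $x_{t+1} = f(x_t, a_t)$ with a deterministic linear policy $a_t = \pi(x_t; \theta)$ inside `jax.lax.scan`, sum a state-only reward, and let `jax.grad` apply the chain rule backward through every step. The point-mass dynamics, the linear policy, the reward and the horizon are assumptions standing in for `mjx.step` and a neural-network policy.

```python
import jax
import jax.numpy as jnp

DT = 0.05  # integration step of the toy system (assumption)

def f(x, a):
    # Toy deterministic dynamics standing in for mjx.step:
    # 1-D point mass with state x = [position, velocity] and force a.
    vel = x[1] + DT * a
    pos = x[0] + DT * vel
    return jnp.stack([pos, vel])

def policy(theta, x):
    # Deterministic linear feedback policy a_t = pi(x_t; theta) (assumption).
    return theta @ x

def reward(x):
    # State-only reward: stay near the origin.
    return -jnp.sum(x**2)

def total_reward(theta, x0, horizon=50):
    def step(x, _):
        x_next = f(x, policy(theta, x))
        return x_next, reward(x_next)
    _, rewards = jax.lax.scan(step, x0, None, length=horizon)
    return jnp.sum(rewards)

theta = jnp.array([-1.0, -1.0])
x0 = jnp.array([1.0, 0.0])
# First-order policy gradient: reverse-mode autodiff applies the chain rule
# backward through all `horizon` applications of f and the policy.
fopg = jax.grad(total_reward)(theta, x0)
```

No sampling is involved: one forward rollout and one backward pass give the gradient of this deterministic objective, which can be checked against finite differences.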

The navy-colored terms in the above expression are enabled by MJX's differentiability and are the key difference between FoPGs and ZoPGs. An important consideration is what these Jacobians look like near contact. To see why certain entries of the Jacobian can be pathological, imagine a hard sphere falling toward a block of marble, with state $x_t = [z_t, \dot{z}_t]$. How does its next velocity change with respect to its height, $\frac{\partial \dot{z}_{t+1}}{\partial z_t}$, the instant before it touches the ground? With perfectly rigid contact this derivative is zero on either side of the impact, even though the velocity reverses abruptly there, so it carries no information about the upcoming collision. This is a case of an **uninformative gradient** due to **hard contact**. In practice, however, the default contact settings in MuJoCo are sufficiently soft for learning via FoPGs. Soft contacts resolve the above scenario by modelling the ground as applying a force on the ball that increases as the ball penetrates it, so the velocity changes continuously with penetration depth.
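The falling-ball example can be reproduced with a toy JAX sketch. It assumes a 1-D ball of unit mass, a discrete-time "hard" contact that reflects the velocity with a restitution coefficient, and a linear penalty spring as a simple stand-in for soft contact. This is not MuJoCo's contact model (its soft constraints are more involved), and the constants and function names are hypothetical.

```python
import jax
import jax.numpy as jnp

DT = 0.002         # time step (assumption)
G = 9.81           # gravitational acceleration
K = 1e4            # stiffness per unit mass of the toy penalty contact (assumption)
RESTITUTION = 0.5  # bounce coefficient of the toy hard contact (assumption)

def hard_next_velocity(z, zdot):
    # Rigid contact: if the ball would end the step below the ground,
    # its velocity is reflected instantaneously.
    zdot_new = zdot - DT * G
    z_new = z + DT * zdot_new
    return jnp.where(z_new < 0.0, -RESTITUTION * zdot_new, zdot_new)

def soft_next_velocity(z, zdot):
    # Soft (penalty) contact: the ground pushes back with an acceleration that
    # grows linearly with the penetration depth max(0, -z).
    contact_accel = K * jnp.maximum(0.0, -z)
    return zdot + DT * (contact_accel - G)

# d zdot_{t+1} / d z_t for each contact model.
d_hard = jax.grad(hard_next_velocity, argnums=0)
d_soft = jax.grad(soft_next_velocity, argnums=0)
```

For a ball falling at 1 m/s, the hard model's next velocity flips sign between heights of 0.01 and 0.001, yet `d_hard` returns zero at both heights. The penalty model's derivative is also zero in free flight but equals `-K * DT` once the ball penetrates the ground, so the gradient reflects how deep the contact is.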

A helpful way to think about FoPGs is via the chain rule, as illustrated below for how $r_2$ influences the parameter update, again for the case that the reward does not depend on the action:
<img src="../doc/images/mjx/apg_diagram.png" alt="drawing" width="300"/>

Note that there are three distinct gradient chains in this example. The red pathway does not use the simulator's differentiability. The blue path is the most intuitive use of this feature and captures how actions affect downstream rewards. The least intuitive may be the green chain, which captures how the reward depends on the way actions depend on previous actions; experience shows that blocking this pathway via `jax.lax.stop_gradient` can badly hinder policy learning. As the length of the $x_t$ backbone (the number of steps differentiated through) increases, [gradient explosion](https://arxiv.org/abs/2111.05803) becomes a crucial consideration. In practice, this can be mitigated by decaying gradients as they propagate backward in time or by periodically truncating the gradient.
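The JAX sketch below shows where `jax.lax.stop_gradient` could be placed in a rollout for each of these choices: stopping the gradient of the policy's state input (one plausible way to block the dependence of actions on earlier actions through the state), truncating the gradient every `truncate_every` steps, and decaying it by a factor `decay` per step. The decay uses the identity $x = \gamma x + (1-\gamma)\,\mathrm{sg}(x)$, which leaves the forward value unchanged while scaling the backward gradient by $\gamma$. This trick, the toy dynamics, and all names are assumptions for illustration, not MJX or Brax APIs.

```python
import jax
import jax.numpy as jnp

DT = 0.05  # toy point-mass integration step (assumption)
sg = jax.lax.stop_gradient

def f(x, a):
    # Toy deterministic dynamics, x = [position, velocity] (stand-in for mjx.step).
    vel = x[1] + DT * a
    pos = x[0] + DT * vel
    return jnp.stack([pos, vel])

def rollout_return(theta, x0, horizon=50, block_policy_input=False,
                   truncate_every=None, decay=1.0):
    def step(x, t):
        # Optionally stop the gradient of the policy's state input, so that for
        # differentiation actions no longer depend on earlier actions via the state.
        policy_input = sg(x) if block_policy_input else x
        a = theta @ policy_input  # deterministic linear policy (assumption)
        x_next = f(x, a)
        r = -jnp.sum(x_next**2)   # reward sees the undecayed, untruncated state
        # Decay: same forward value, but each step backward in time scales the
        # gradient carried through the state by `decay`.
        carry = decay * x_next + (1.0 - decay) * sg(x_next)
        if truncate_every is not None:
            # Truncation: every `truncate_every` steps, block gradients from
            # flowing further back in time through the state.
            carry = jnp.where((t + 1) % truncate_every == 0, sg(carry), carry)
        return carry, r

    _, rewards = jax.lax.scan(step, x0, jnp.arange(horizon))
    return jnp.sum(rewards)
```

None of these options changes the rollout itself, only which terms of the chain rule reach $\theta$; with `decay=1.0` and no truncation the full first-order gradient is recovered.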

**The sharp bits of FoPGs**

While FoPGs have been shown to be very sample efficient, especially as the [dimension of the state space increases](https://arxiv.org/abs/2204.07137), they can still struggle with wall-clock time. Because their gradients have low variance, they do not benefit significantly from massive parallelization of data collection, unlike ZoPG-based [RL](https://arxiv.org/abs/2109.11978). Additionally, the policy gradient is typically calculated via automatic differentiation. This can be 3-5x slower than unrolling the simulation forward and is memory intensive, with memory requirements scaling as $O((m+n) \cdot m \cdot T)$, where $m$ and $n$ are the state and control dimensions and $T$ is the number of steps propagated through.
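Reverse-mode differentiation through a `jax.lax.scan` rollout keeps intermediate values from every step for the backward pass, which is why memory grows with $T$. One common mitigation in JAX, not something prescribed above, is rematerialization: wrapping the step function in `jax.checkpoint` so only its inputs are saved per step and its internals are recomputed on the backward pass, trading extra compute for memory. A minimal sketch on a hypothetical toy system (dynamics, policy and reward are assumptions):

```python
import jax
import jax.numpy as jnp

DT = 0.05  # toy integration step (assumption)

def step(theta, x):
    # One toy rollout step: linear policy, point-mass dynamics, state-only reward.
    a = theta @ x
    vel = x[1] + DT * a
    pos = x[0] + DT * vel
    x_next = jnp.stack([pos, vel])
    return x_next, -jnp.sum(x_next**2)

def total_reward(theta, x0, horizon=200, remat=False):
    # With remat=True, the step's intermediates are recomputed during the
    # backward pass instead of being stored for every step.
    step_fn = jax.checkpoint(step) if remat else step
    _, rewards = jax.lax.scan(lambda x, _: step_fn(theta, x), x0, None, length=horizon)
    return jnp.sum(rewards)

theta = jnp.array([-1.0, -1.0])
x0 = jnp.array([1.0, 0.0])
g_stored = jax.grad(lambda th: total_reward(th, x0))(theta)
g_remat = jax.grad(lambda th: total_reward(th, x0, remat=True))(theta)
```

Both calls return the same gradient; the checkpointed version simply recomputes each step's internals on the backward pass, which adds compute and so does not help with the wall-clock concern.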

Due to their lower gradient variance, FoPGs also have less exploration power than ZoPGs and benefit from the practitioner being more explicit in the problem formulation.

In this tutorial, we hope to convey *when* to use FoPGs and *how* to use them through three case studies:


| Study | Variant | Case Study                        |
| ----- | ------- | --------------------------------- |
| 1   |     | *Imitating Kinematics*            |
|     | a   | 1 Hz Trot                         |
|     | b   | 3 Hz Trot                         |
|     | c   | 3 Hz Trot; PPO                    |
| 2   |     | *Quadruped Locomotion Design*     |
|     | a   | 0.75 m/s trot with long strides   |
|     | b   | 1.5 m/s trot with short strides   |
|     | c   | 0.75 m/s trot without reference   |

(TODO)
Study 1a demonstrates the sample efficiency and degree of refinement possible from a FoPG algorithm. Studies 1a, 2a and 2b show that FoPG algorithms excel when the reward is specified clearly. Study 3 shows that, as with classical methods such as MPC and trajectory optimization, the learned policy benefits greatly from a good "initial guess".