# Quantization Aware Policy Optimization

### Random Quantization (Unbiased but impractical)

Quantized parameters $\phi \sim \Phi_\theta$ are sampled from a distribution parameterized by the true parameters $\theta$. We want to maximize the following objective:

$$J(\theta) = \mathbb E_{\tau \sim \pi_\phi, \phi \sim \Phi_\theta}[r(\tau)] 
= \int \sum_\phi P(\phi|\theta)\pi_\phi(\tau)r(\tau)d\tau.$$

$\tau$ is a trajectory (LLM output), $\pi_\phi$ is the quantized policy (quantized LLM), and $r$ is the reward function (grade of LLM output).

Then,

$$\nabla_\theta J(\theta) = \nabla_\theta \int \sum_\phi P(\phi|\theta)\pi_\phi(\tau)r(\tau)d\tau 
= \mathbb E_{\tau \sim \pi_\phi, \phi \sim \Phi_\theta}\left[\frac{\nabla_\theta P(\phi|\theta)}{P(\phi|\theta)}r(\tau)\right].$$

This approach is impractical for two reasons:
- In practice, quantization is deterministic (for example, round-to-nearest), so the quantized parameters are not random.
- Estimating $\nabla_\theta J(\theta)$ requires sampling both $\phi$ and $\tau$, so every gradient sample needs a freshly quantized model in addition to a rollout.
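To make the estimator concrete, here is a minimal sketch for a single scalar parameter. It assumes stochastic rounding as the distribution $\Phi_\theta$ (the text above does not fix a particular one) and a hypothetical reward that depends only on $\phi$, so the trajectory sampling is folded away. The names `stochastic_round_probs`, `toy_reward`, `expected_score_gradient`, and `sampled_score_gradient` are illustrative.

```python
import math
import random

def stochastic_round_probs(theta, s):
    """Return [(phi, P(phi|theta), dP/dtheta)] on a grid of step 1/s.

    Assumption: theta rounds up with probability equal to its fractional
    position between the two neighbouring grid points.
    """
    x = s * theta
    lo = math.floor(x)
    p_hi = x - lo                      # probability of rounding up
    return [
        (lo / s, 1.0 - p_hi, -s),      # d(1 - p_hi)/dtheta = -s
        ((lo + 1) / s, p_hi, s),       # d(p_hi)/dtheta = +s
    ]

def toy_reward(phi):
    # Hypothetical stand-in for the expected reward of pi_phi.
    return -(phi - 1.0) ** 2

def expected_score_gradient(theta, s):
    """Exact expectation of (grad P / P) * r over the two outcomes."""
    total = 0.0
    for phi, p, dp in stochastic_round_probs(theta, s):
        if p > 0.0:
            total += p * (dp / p) * toy_reward(phi)
    return total

def sampled_score_gradient(theta, s, n, seed=0):
    """Monte Carlo estimate: sample phi, then average (grad P / P) * r."""
    rng = random.Random(seed)
    outcomes = stochastic_round_probs(theta, s)
    weights = [p for _, p, _ in outcomes]
    total = 0.0
    for phi, p, dp in rng.choices(outcomes, weights=weights, k=n):
        total += (dp / p) * toy_reward(phi)
    return total / n
```

In this toy setting the exact gradient is $s\,(r(\phi_\text{hi}) - r(\phi_\text{lo}))$, and the sampled estimate converges to it only slowly because the individual terms are scaled by $1/P(\phi|\theta)$.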

### Deterministic Quantization with Surrogate Gradients (More practical but biased)

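If quantization is deterministic, $P(\phi|\theta)$ collapses to a point mass on a single value of $\phi$. Its gradient with respect to $\theta$ is zero almost everywhere and undefined at the rounding boundaries, so we cannot optimize using the $\nabla_\theta J(\theta)$ described previously.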

Now, with $s$ acting as the quantization scale and $z$ as the offset,
$$\phi = \varphi(\theta) = \frac{\operatorname{round}(s(\theta - z))}{s} + z,$$
and 
$$J(\theta) = \mathbb E_{\tau \sim \pi_{\varphi(\theta)}}[r(\tau)] 
= \int \pi_{\varphi(\theta)}(\tau)r(\tau)d\tau.$$
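The quantizer $\varphi$ can be sketched in a few lines of plain Python. This is only an illustration of the formula above: the function name `quantize` and the example values of `s` and `z` are assumptions, and Python's built-in `round` rounds exact halves to the nearest even integer, which may differ from the rounding mode of a real quantization kernel.

```python
def quantize(theta, s, z):
    """Apply phi = round(s * (theta - z)) / s + z to each entry of a list."""
    # s is the scale (grid step is 1/s) and z is the offset; both are
    # treated as fixed constants here.
    return [round(s * (t - z)) / s + z for t in theta]

theta = [1.3, 0.5, -0.1]
phi = quantize(theta, s=4, z=0.5)
```

With `s=4` the grid step is `0.25`, so every entry of `phi` lies on the grid `z + k/4` for an integer `k`, and quantizing `phi` again leaves it unchanged.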

Since $\operatorname{round}$ has zero gradient almost everywhere and is not differentiable at the rounding boundaries, we use the straight-through estimator (STE) as a surrogate for its gradient,
$$\tilde\nabla_x\operatorname{round}(x) = 1.$$

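Then, applying the chain rule to $\varphi$ and treating $s$ and $z$ as constants, the factors $s$ and $1/s$ cancel:
$$\tilde\nabla_\theta\varphi(\theta) = \frac{1}{s}\cdot\tilde\nabla_x\operatorname{round}(x)\Big|_{x = s(\theta - z)}\cdot s = 1,$$
so, for any differentiable function $f$,
$$\tilde\nabla_\theta f(\varphi(\theta)) = \nabla_{\varphi(\theta)} f(\varphi(\theta))\tilde\nabla_\theta\varphi(\theta) = \nabla_{\varphi(\theta)} f(\varphi(\theta)).$$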
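The following sketch contrasts the true gradient of $f(\varphi(\theta))$, which a finite difference shows to be zero away from rounding boundaries, with the straight-through surrogate. The quadratic `f`, its target value `2.0`, the learning rate, and all function names are hypothetical choices made for illustration.

```python
def quantize(theta, s, z):
    # Scalar version of phi = round(s * (theta - z)) / s + z.
    return round(s * (theta - z)) / s + z

def f(phi):
    # Hypothetical differentiable loss with its minimum at phi = 2.0.
    return (phi - 2.0) ** 2

def grad_f(phi):
    return 2.0 * (phi - 2.0)

def finite_difference(theta, s, z, h=1e-6):
    """Central difference of f(quantize(theta)); zero away from boundaries."""
    return (f(quantize(theta + h, s, z)) - f(quantize(theta - h, s, z))) / (2 * h)

def ste_gradient(theta, s, z):
    """Surrogate gradient: grad of f at phi, times the STE factor of 1."""
    return grad_f(quantize(theta, s, z)) * 1.0

def train(theta, s, z, lr=0.1, steps=100):
    """Gradient descent on the full-precision theta using the STE gradient."""
    for _ in range(steps):
        theta -= lr * ste_gradient(theta, s, z)
    return theta
```

Starting from `theta = 0.3` with `s = 4` and `z = 0.0`, the finite difference gives no signal, while the STE updates keep moving `theta` until its quantized value reaches the minimizer on the grid.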

Finally,
$$\tilde\nabla_\theta J(\theta) = \tilde\nabla_\theta \int \pi_{\varphi(\theta)}(\tau)r(\tau)d\tau
= \mathbb E_{\tau \sim \pi_{\varphi(\theta)}}\left[
    \frac{\tilde\nabla_\theta\pi_{\varphi(\theta)}(\tau)}{\pi_{\varphi(\theta)}(\tau)}r(\tau)
\right]
= \mathbb E_{\tau \sim \pi_{\varphi(\theta)}}\left[\frac{\nabla_{\varphi(\theta)}\pi_{\varphi(\theta)}(\tau)}{\pi_{\varphi(\theta)}(\tau)}r(\tau)\right].$$
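As a toy instance of this estimator, consider a one-parameter Bernoulli "policy" whose logit is the quantized parameter. Everything in this sketch is an assumption made for illustration: the sigmoid policy, the reward that simply equals the sampled action, and the names `prob_one`, `grad_log_prob`, and `ste_policy_gradient`. The ratio $\nabla\pi/\pi$ is computed as $\nabla\log\pi$.

```python
import math
import random

def quantize(theta, s, z):
    # Scalar version of phi = round(s * (theta - z)) / s + z.
    return round(s * (theta - z)) / s + z

def prob_one(phi):
    """pi_phi(tau = 1) for a Bernoulli policy with logit phi."""
    return 1.0 / (1.0 + math.exp(-phi))

def grad_log_prob(tau, phi):
    """d/dphi of log pi_phi(tau) for the Bernoulli policy."""
    p = prob_one(phi)
    return (1.0 - p) if tau == 1 else -p

def reward(tau):
    # Hypothetical reward: 1 for action 1, 0 for action 0.
    return float(tau)

def ste_policy_gradient(theta, s, z, n, seed=0):
    """Average grad log pi_phi(tau) * r(tau) over rollouts of the quantized policy.

    Under the straight-through estimate this gradient with respect to phi
    is used directly as the gradient with respect to theta.
    """
    rng = random.Random(seed)
    phi = quantize(theta, s, z)
    p = prob_one(phi)
    total = 0.0
    for _ in range(n):
        tau = 1 if rng.random() < p else 0
        total += grad_log_prob(tau, phi) * reward(tau)
    return total / n
```

For this policy the expected reward is $p = \sigma(\phi)$, so the estimate should approach $p(1 - p)$ evaluated at $\phi = \varphi(\theta)$.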


### Augmented GRPO Objective

To reduce variance and add regularization, we can optimize a GRPO-style objective:

$$J(\theta) = \mathbb E_{\tau \sim \pi_{\varphi(\theta)}}\left[
    \min\left(
        \frac{\pi_{\varphi(\theta)}(\tau)}{\pi_{\phi}(\tau)}\hat A(\tau),
        \operatorname{clip}\left(
            \frac{\pi_{\varphi(\theta)}(\tau)}{\pi_{\phi}(\tau)}, 
            1 - \epsilon, 
            1 + \epsilon
        \right)\hat A(\tau)
    \right) - \beta \mathbb D_{\text{KL}}[\pi_{\varphi(\theta)}\|\pi_\text{ref}]
\right].$$

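Here $\phi = \varphi(\theta)$ denotes the quantized parameters of the sampling policy. They are treated as a constant for gradient computation, so no gradient flows through $\pi_\phi$ and only the numerator $\pi_{\varphi(\theta)}$ contributes, via the straight-through estimate.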
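The value of the clipped term can be sketched as below, given precomputed probability ratios, advantages, and a KL estimate. The helper names and the default values `eps=0.2` and `beta=0.04` are illustrative assumptions, not values taken from the text above, and a real implementation would compute this inside an autodiff framework so that gradients reach $\theta$.

```python
def clipped_surrogate(ratio, advantage, eps=0.2):
    """min(ratio * A, clip(ratio, 1 - eps, 1 + eps) * A) for one sample."""
    clipped = min(max(ratio, 1.0 - eps), 1.0 + eps)
    return min(ratio * advantage, clipped * advantage)

def grpo_objective(ratios, advantages, kl, eps=0.2, beta=0.04):
    """Sample mean of the clipped surrogate minus the KL penalty."""
    terms = [clipped_surrogate(r, a, eps) for r, a in zip(ratios, advantages)]
    return sum(terms) / len(terms) - beta * kl
```

The clip only binds in the direction that would increase the objective: a ratio of `1.5` with a positive advantage is capped at `1 + eps`, while a ratio of `0.5` with a negative advantage is held at `1 - eps`.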

The normalized advantage $\hat A(\tau)$ is given by
$$\hat A(\tau) = \frac{
    r(\tau) - \mathbb E_{\tau \sim \pi_{\varphi(\theta)}}[r(\tau)]
}{
    \sqrt{\operatorname{Var}_{\tau \sim \pi_{\varphi(\theta)}}[r(\tau)]}
}.$$
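In practice the expectation and variance are not available in closed form, so a common approach is to estimate them from a group of sampled rewards. The sketch below does this with the population standard deviation. The function name `normalized_advantages` and the small constant `eps`, added to avoid division by zero when all rewards in the group are equal, are assumptions that do not appear in the formula above.

```python
import statistics

def normalized_advantages(rewards, eps=1e-8):
    """Subtract the group mean and divide by the group standard deviation."""
    mean = statistics.fmean(rewards)
    std = statistics.pstdev(rewards)   # population std, matching Var[r]
    # eps is a hypothetical safeguard for groups with identical rewards.
    return [(r - mean) / (std + eps) for r in rewards]

advantages = normalized_advantages([1.0, 0.0, 0.0, 1.0])
```

For the rewards `[1.0, 0.0, 0.0, 1.0]` the group mean and standard deviation are both `0.5`, so the advantages come out close to `+1` and `-1`.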