In [5]:
import matplotlib.pyplot as plt

## batched_forward_pass(queries, responses, scores)

### input

- `queries`: list
    - `len(queries) == 1024` (batch size)
    - `[q.shape for q in queries]`
- `responses`: list
    - `len(responses) == len(queries)`
    - `[r.shape for r in responses]`
- `scores`: tensor
    - `scores.shape == torch.Size([1024])`
- `model_inputs = self.prepare_model_inputs(queries, responses)`
    - `model_inputs.keys() == dict_keys(['input_ids', 'attention_mask'])`
    - `model_inputs['attention_mask'].sum(dim=-1)`
        - per sample, equals `len(q) + len(r)` (the number of non-padding tokens)
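
The attention-mask identity above can be checked with a small sketch. It uses plain Python lists in place of tensors; `pad_and_mask` is a hypothetical helper written for illustration (not TRL's `prepare_model_inputs`), and left padding with `pad_id=0` is an assumption.

```python
def pad_and_mask(queries, responses, pad_id=0):
    """Concatenate each query with its response, then pad to a common length.

    Hypothetical helper for illustration only; not TRL's prepare_model_inputs.
    """
    sequences = [q + r for q, r in zip(queries, responses)]
    max_len = max(len(s) for s in sequences)
    input_ids, attention_mask = [], []
    for s in sequences:
        n_pad = max_len - len(s)
        input_ids.append([pad_id] * n_pad + s)             # left padding (assumed)
        attention_mask.append([0] * n_pad + [1] * len(s))  # 1 marks a real token
    return {"input_ids": input_ids, "attention_mask": attention_mask}


queries = [[11, 12, 13], [21, 22]]
responses = [[14, 15], [23, 24, 25, 26]]
model_inputs = pad_and_mask(queries, responses)

# Analogue of model_inputs['attention_mask'].sum(dim=-1)
mask_sums = [sum(row) for row in model_inputs["attention_mask"]]
```

Each entry of `mask_sums` matches `len(q) + len(r)` for the corresponding sample, regardless of how much padding was added.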

```
model_inputs = self.prepare_model_inputs(queries, responses)
all_logprobs, _, values, masks = self.batched_forward_pass(self.model, queries, responses, model_inputs, ...)
ref_logprobs, _, _, _ = self.batched_forward_pass(self.ref_model, queries, responses, model_inputs, ...)
```

- `all_logprobs.shape == torch.Size([1024, 21])`
    - per-token log-probabilities, with `gather` already applied
- `ref_logprobs.shape == torch.Size([1024, 21])`
    - per-token log-probabilities, with `gather` already applied
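
What "already gathered" means can be sketched in plain Python: at each position, take the log-softmax over the vocabulary and keep only the entry for the token that was actually produced. The helper names `log_softmax` and `gather_logprobs` and the toy logits are assumptions for illustration, not the trainer's own code.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax over one vocabulary-sized row."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def gather_logprobs(logits_per_step, token_ids):
    """Keep, at each position, the log-probability of the observed token."""
    return [log_softmax(row)[tok] for row, tok in zip(logits_per_step, token_ids)]


# Toy example: 2 positions, vocabulary of size 3 (values are made up).
logits_per_step = [[2.0, 0.0, 0.0], [0.0, 1.0, 3.0]]
token_ids = [0, 2]
logprobs = gather_logprobs(logits_per_step, token_ids)
```

The result has one value per position rather than one per vocabulary entry, which is why the shapes above are `[batch, seq_len]` and not `[batch, seq_len, vocab]`.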

## compute_rewards(scores, all_logprobs, ref_logprobs, masks)

- KL penalty

    $$
    D_{KL}(P||Q)=\sum_{x}P(x)\log\frac{P(x)}{Q(x)}
    $$

    - `logprob - ref_logprob` (a relative quantity: the per-token log-ratio)

    $$
    \log p-\log q=\log \frac{p}{q}
    $$
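
The link between the two formulas is that the KL divergence is the expectation of the log-ratio $\log p - \log q$ under $P$. A minimal sketch with made-up discrete distributions (the values of `p` and `q` are assumptions chosen for illustration):

```python
import math


def kl_divergence(p, q):
    """D_KL(P || Q) for two discrete distributions given as probability lists."""
    return sum(pi * (math.log(pi) - math.log(qi)) for pi, qi in zip(p, q) if pi > 0)


p = [0.5, 0.25, 0.25]
q = [0.25, 0.25, 0.5]

# Per-outcome log-ratios: the analogue of `logprob - ref_logprob`.
log_ratios = [math.log(pi) - math.log(qi) for pi, qi in zip(p, q)]

# Weighting the log-ratios by P recovers the KL divergence.
kl = kl_divergence(p, q)
kl_from_ratios = sum(pi * lr for pi, lr in zip(p, log_ratios))
```

Individual log-ratios can be negative, but their average under $P$ is never negative.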

### kl_ctl

```
# self.kl_ctl = AdaptiveKLController(0.02, 6, 10000)
self.kl_ctl = AdaptiveKLController(self.config.init_kl_coef, self.config.target, self.config.horizon)
```

In [3]:
from trl.trainer import AdaptiveKLController

In [8]:
# See https://arxiv.org/pdf/1909.08593.pdf, Section 2.2
# Arguments: init_kl_coef=0.02, target=6, horizon=10000
kl_ctl = AdaptiveKLController(0.02, 6, 10000)
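
To make the controller's behaviour concrete, here is a standalone sketch of the proportional update described in Section 2.2 of the paper linked above: the relative error `KL / target - 1` is clipped to `[-0.2, 0.2]` and used to scale the coefficient. The class name `SketchAdaptiveKLController`, the `update` signature, and the `n_steps / horizon` step size are assumptions for illustration; check them against the TRL source before relying on them.

```python
class SketchAdaptiveKLController:
    """Illustrative re-implementation (an assumption, not the TRL source)."""

    def __init__(self, init_kl_coef, target, horizon):
        self.value = init_kl_coef  # current KL coefficient (beta)
        self.target = target       # desired KL between policy and reference
        self.horizon = horizon     # controls how quickly beta adapts

    def update(self, current_kl, n_steps):
        # Clipped proportional error: positive when the KL is above target.
        error = max(-0.2, min(0.2, current_kl / self.target - 1.0))
        self.value *= 1.0 + error * n_steps / self.horizon


ctl = SketchAdaptiveKLController(0.02, 6, 10000)

ctl.update(current_kl=12.0, n_steps=1024)  # KL above target: beta grows
grown = ctl.value

ctl.update(current_kl=3.0, n_steps=1024)   # KL below target: beta shrinks
shrunk = ctl.value
```

Because the error is clipped, a single update changes the coefficient by at most a factor of `1 ± 0.2 * n_steps / horizon`.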

### $R(x,y)$

$$
R(x,y)=r(x,y)-\beta\log\frac{\pi(y|x)}{\rho(y|x)}
$$

```
reward = score - self.kl_ctl.value * kl
```
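
A per-token version of this formula can be sketched as follows, assuming the KL penalty is applied at every response token and the scalar score is added only at the last unmasked token. That placement, the masking of padded positions, and the helper name `compute_token_rewards` are assumptions for illustration rather than a copy of the trainer's code.

```python
def compute_token_rewards(score, logprobs, ref_logprobs, mask, kl_coef):
    """Per-token rewards: -kl_coef * (logprob - ref_logprob), plus score at the end."""
    kl = [lp - ref_lp for lp, ref_lp in zip(logprobs, ref_logprobs)]
    # Zero out masked (padding) positions; this masking is an assumption.
    rewards = [-kl_coef * k * m for k, m in zip(kl, mask)]
    last = max(i for i, m in enumerate(mask) if m == 1)
    rewards[last] += score  # the sequence-level score lands on the last real token
    return rewards


rewards = compute_token_rewards(
    score=1.0,
    logprobs=[-1.0, -2.0, -0.5],
    ref_logprobs=[-1.5, -2.0, -1.0],
    mask=[1, 1, 0],
    kl_coef=0.02,
)
total = sum(rewards)
```

Summing the per-token rewards gives the score minus `kl_coef` times the summed log-ratio over unmasked tokens, which is the sequence-level form $R(x,y)$ above.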

## compute_advantages(values, rewards, masks)