- https://philliphaeusler.com/posts/aligning_tinystories/
    - https://github.com/pHaeusler/tinycatstories/tree/main

In [3]:
from IPython.display import Image

In [2]:
import torch
from torch.distributions import Categorical
from sentence_transformers import SentenceTransformer, util
from transformers import AutoModelForCausalLM, AutoTokenizer
from torch.optim import AdamW
import matplotlib.pyplot as plt

### TinyStories

- https://arxiv.org/abs/2305.07759
    - TinyStories is a fantastic paper showing that **a small transformer model** can be trained to produce **coherent stories**.
    - Their trick was to carefully **curate the training data by generating it synthetically (using GPT)**. It worked!
- The RL perspective
    - Rather than training on data (labeled or unlabeled), we can train against **another system that gives feedback**. That system could be a simple function that scores the model's state or action (perhaps from a simulator), or a deep-learning model trained specifically to give feedback. Learning from such feedback is Reinforcement Learning (RL).

### Embedding loss/rewards

In [2]:
# Workaround for "ImportError: cannot import name 'DEFAULT_CIPHERS' from 'urllib3.util.ssl_'"
# !pip install --upgrade 'urllib3==1.26.7'

In [3]:
# !pip install sentence_transformers

In [4]:
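from sentence_transformers import SentenceTransformer, util

# all-MiniLM-L6-v2 maps text to 384-dim sentence embeddings
embedding_model = SentenceTransformer("all-MiniLM-L6-v2", device="cuda")
reference_embedding = embedding_model.encode("cat", convert_to_tensor=True)

def compute_rewards(sequences):
    """Reward = cosine similarity between each generated text and the word "cat".

    `sequences` is a list of decoded strings; returns a 1-D tensor of shape (len(sequences),).
    encode() runs without gradients, so the rewards are constants w.r.t. the policy.
    """
    sequence_embeddings = embedding_model.encode(sequences, convert_to_tensor=True)
    cosine_similarities = util.pytorch_cos_sim(
        reference_embedding.unsqueeze(0), sequence_embeddings
    ).squeeze(0)  # (1, N) -> (N,); squeeze(0) keeps shape (1,) when N == 1
    return cosine_similarities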

In [5]:
reference_embedding.shape

torch.Size([384])
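The reward is just a batched cosine similarity between one reference vector and N sequence embeddings. As a shape check that runs without downloading the model, here is a minimal sketch that swaps in `torch.nn.functional.cosine_similarity` for `util.pytorch_cos_sim` and uses random tensors as stand-ins for real embeddings (384 matches the size printed above). `cosine_rewards` is a hypothetical helper, not part of `sentence_transformers`.

```python
import torch
import torch.nn.functional as F


def cosine_rewards(reference_embedding: torch.Tensor, sequence_embeddings: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: cosine similarity of a (D,) reference against an (N, D) batch.

    Returns shape (N,), one reward per sequence, even when N == 1.
    """
    # The (1, D) reference broadcasts against the (N, D) batch; similarity is taken over the feature dim.
    return F.cosine_similarity(reference_embedding.unsqueeze(0), sequence_embeddings, dim=-1)


# Random stand-ins for real sentence embeddings (assumption: 384-dim, as for all-MiniLM-L6-v2).
torch.manual_seed(0)
reference = torch.randn(384)
batch = torch.randn(4, 384)
batch[0] = 2.0 * reference  # same direction as the reference, so its reward is ~1.0

rewards = cosine_rewards(reference, batch)
print(rewards.shape)  # torch.Size([4])
```

Rewards lie in [-1, 1] and ignore vector length, so a generated text is rewarded only for pointing in the same semantic direction as "cat". Random 384-dim vectors score close to 0.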

### REINFORCE

- Reward function (sentence-level)
    $$
    R(s)=\cos\big(\text{Emb}(s),\ \text{Emb}(\text{"cat"})\big)
    $$
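- Loss function (the negative of the objective function)

    $$
    L(\theta)=L_{\text{pg}}(\theta)+\lambda D_{\text{KL}}(\theta)
    $$
    - $L_{\text{pg}}(\theta) = -\mathbb{E}_{s \sim \pi_\theta}[R(s) \cdot \log \pi_\theta(s)]$; in practice $R(s)$ and the samples are held fixed, so the gradient of $L_{\text{pg}}$ is the negated REINFORCE estimate $-\mathbb{E}_{s \sim \pi_\theta}[R(s)\,\nabla_\theta \log \pi_\theta(s)]$
    - $D_{\text{KL}}(\theta) = \mathbb{E}_{s \sim \pi_\theta}\big[\frac1T\sum_{t=1}^T D_{\text{KL}}(\pi_\theta(\cdot|s_{<t}) \,\|\, \pi_{\text{ref}}(\cdot|s_{<t}))\big]$, where $\pi_{\text{ref}}$ is the frozen reference model and $\lambda$ weights the penalty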
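- Sampling and optimization procedure
    - Sample batch_size sequences autoregressively (token by token)
        - $s_i \sim \pi_\theta,\ i \in \{1,\ldots,N\}$, with $N$ = batch_size
    - Compute each sequence's log-probability:
        - $\log \pi_\theta(s_i) = \frac1T\sum_{t=1}^T \log \pi_\theta(s_{i,t}|s_{i,<t})=\frac1T\log\prod_{t=1}^T\pi_\theta(s_{i,t}|s_{i,<t})$
            - the log of the whole sentence's joint probability, normalized by its length $T$ (without the $\frac1T$ factor it would be exactly the joint log-probability)
    - Compute the KL divergence (the KL constraint keeps the generated text fluent by stopping the policy from drifting far from the reference model)
        - $D_{\text{KL}}(s_i)=\frac1T\sum_{t=1}^T D_{\text{KL}}\big(\pi_\theta(\cdot|s_{i,<t})\,\|\,\pi_{\text{ref}}(\cdot|s_{i,<t})\big)$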

In [7]:
Image(url='./figs/training_metrics_0.png', width=600)

In [8]:
Image(url='./figs/training_metrics_6000.png', width=600)