In [18]:
from IPython.display import Image
import torch
import math
import numpy as np

## why position encoding in Transformer

- attention mechanism （Transformer 最特色的）
    - $X\in\mathbb R^{\ell\times d}$
    - $W_k\in\mathbb R^{d\times d_k},W_q\in\mathbb R^{d\times d_k},W_v\in\mathbb R^{d\times d_v}$
    - $Q=XW_q\in\mathbb R^{\ell\times d_k}, K=XW_k\in\mathbb R^{\ell\times d_k}, V=XW_v\in\mathbb R^{\ell\times d_v}$

$$
\text{Attention}(Q,K,V)=\text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V
$$

$$
A_{ij}=\frac{\exp(\frac{q^T_ik_j}{\sqrt{d_k}})}{\sum_{j'}\exp(\frac{q^T_ik_{j'}}{\sqrt{d_k}})}
$$

- $A_{ij}$ （attention weights， $QK^T$: attention scores） 表示的是位置 $i$ 的词（token）与位置 $j$ 的词（token）的注意力权重，
    - 如果 $x_i$/$x_j$ 或者 $q_i$/$q_j$（$k_i$/$k_j$）没有编码位置信息，那么这个 weight 就与位置无关（打乱 token 的顺序，只会把注意力权重做同样的置换），这在 seq modeling 中显然是很大的缺陷
    - 而一句话的含义，必然与 token 的排列顺序有关

## bert, gpt

In [2]:
Image(url='https://miro.medium.com/v2/resize:fit:786/format:webp/1*iJqlhZz-g6ZQJ53-rE9VvA.png', width=500)

- BERT: 加性 (absolute) position encoding （learnable position encoding）

    ```
    # modeling_bert.py
    embeddings = inputs_embeds + token_type_embeddings + self.position_embeddings(position_ids)
    ```

- GPT: 加性 （absolute）position encoding（learnable position encoding）

    ```
    # modeling_gpt.py
    if inputs_embeds is None:
        inputs_embeds = self.wte(input_ids)
    position_embeds = self.wpe(position_ids)
    hidden_states = inputs_embeds + position_embeds + token_type_embeds
    ```

### sin position encoding

- 无需训练; 依然是绝对位置编码
- Transformer 原始论文（Attention Is All You Need）

$$
\begin{split}
PE(t,2i)&=\sin(\frac{t}{10000^{\frac{2i}{d_{model}}}})\\
PE(t,2i+1)&=\cos(\frac{t}{10000^{\frac{2i}{d_{model}}}})\\
\Downarrow\\
PE(t,i)&=\sin(\frac{t}{10000^{\frac{i}{d_{model}}}}), \quad \text{i is even}\\
PE(t,i)&=\cos(\frac{t}{10000^{\frac{i-1}{d_{model}}}}), \quad \text{i is odd}\\
\end{split}
$$

例：$d_{model} = 4$

- pos = 0, $[\sin(0),\cos(0), \sin(0),\cos(0)]$
- pos = 1, $[\sin\left(\frac{1}{10000^{0/4}}\right),\cos\left(\frac{1}{10000^{0/4}}\right), \sin\left(\frac{1}{10000^{2/4}}\right), \cos\left(\frac{1}{10000^{2/4}}\right)]$
- pos = 2, $[\sin\left(\frac{2}{10000^{0/4}}\right),\cos\left(\frac{2}{10000^{0/4}}\right), \sin\left(\frac{2}{10000^{2/4}}\right), \cos\left(\frac{2}{10000^{2/4}}\right)]$
- pos = 3, $[\sin\left(\frac{3}{10000^{0/4}}\right),\cos\left(\frac{3}{10000^{0/4}}\right), \sin\left(\frac{3}{10000^{2/4}}\right), \cos\left(\frac{3}{10000^{2/4}}\right)]$




## llama RoPE

- 从绝对位置编码到相对位置编码
    - 绝对位置编码，位置 pos_i 的编码仅取决于 pos_i 的值；
    - 相对位置编码，（一般不需要对每个位置进行单独的编码），而是直接对位置之间的相对距离进行编码
        - pos=0 与 pos=1 的相对位置 $f(|0-1|)$
        - pos=1 与 pos=3 的相对位置 $f(|1-3|)$
        - 偏差构成的矩阵，称为 id 矩阵；
- RoPE
    - 旋转位置编码，为相对位置编码，非加性位置编码，直接嵌入到 attention mechanism 的计算中；
    - $R^d_{\Theta,m}$：位置 $m$ 对应的旋转矩阵；not learnable：非学习的，全局固定的；
        - $\theta_i$：frequency；$m\theta_i$：位置 $m$ 在第 $i$ 个二维子空间上的旋转角度


$$
\begin{split}
f(q,m)^Tf(k,n)&=(R_mq)^T(R_nk)\\
&=q^T(R^T_mR_n)k\\
&=q^TR_{n-m}k
\end{split}
$$


```
# freqs_cis 是一个全局的旋转矩阵
xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
xq, xk, xv
```

In [33]:
# d: dim, 
# m: position
Image(url='../../imgs/rope_paper.png', width=600)

In [4]:
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    """Precompute the complex RoPE rotation factors ``exp(j * m * f_i)``.

    Args:
        dim: Embedding size to rotate; channels are handled in pairs, so
            ``dim // 2`` frequencies are produced.
        end: Number of positions ``m = 0 .. end - 1`` to tabulate.
        theta: Base of the geometric frequency schedule.

    Returns:
        A ``complex64`` tensor of shape ``(end, dim // 2)`` whose entry
        ``[m, i]`` equals ``cos(m * f_i) + j * sin(m * f_i)`` with
        ``f_i = theta ** (-2 * i / dim)``.
    """
    # One frequency per channel pair: f_i = theta^(-2i / dim).
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    # Create the positions directly in the frequency dtype so torch.outer never
    # has to promote an int64 tensor against a float one.
    t = torch.arange(end, dtype=freqs.dtype)
    # Angle table: freqs[m, i] = m * f_i.
    freqs = torch.outer(t, freqs).float()
    # Unit-magnitude complex numbers with those angles.
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
    return freqs_cis

In [13]:
cis = precompute_freqs_cis(dim=4, end=3)
cis

tensor([[ 1.0000+0.0000j,  1.0000+0.0000j],
        [ 0.5403+0.8415j,  0.9999+0.0100j],
        [-0.4161+0.9093j,  0.9998+0.0200j]])

$$
\begin{split}
&\text{freqs}=[1,\frac{1}{\theta^{2/4}}]=[1., 0.01]\\
&t=[0,1,2]\\
&\text{freqs}=\begin{bmatrix}
0 & 0\\
1 & 0.01\\
2 & 0.02
\end{bmatrix}\\
&\text{freqs\_cis}=e^{j\cdot\text{freqs}}=\begin{bmatrix}
1 & 1\\
e^j & e^{j\cdot0.01}\\
e^{j\cdot2} & e^{j\cdot0.02}
\end{bmatrix}
\end{split}
$$

In [21]:
# torch.polar(torch.tensor([1.]), torch.tensor([0.]))
torch.polar(torch.tensor([1.]), torch.tensor([2.]))

tensor([-0.4161+0.9093j])

In [10]:
# Angle table m * f_i for positions m = 0, 1, 2 and frequencies f = (1, 0.01).
positions = torch.tensor([0.0, 1.0, 2.0])
base_freqs = torch.tensor([1.0, 0.01])
theta_matrix = torch.outer(positions, base_freqs)

# Magnitude matrix: all ones.
r_matrix = torch.ones(theta_matrix.shape, dtype=theta_matrix.dtype)

# Compute e^{j * theta_matrix} as unit-magnitude polar numbers.
e_j_theta_matrix = torch.polar(r_matrix, theta_matrix)
e_j_theta_matrix

tensor([[ 1.0000+0.0000j,  1.0000+0.0000j],
        [ 0.5403+0.8415j,  0.9999+0.0100j],
        [-0.4161+0.9093j,  0.9998+0.0200j]])

### https://spaces.ac.cn/archives/8265

- 二维向量的旋转矩阵是 $2\times 2$的正交矩阵
- $d$ 维向量的旋转矩阵是 $d\times d$ 的，依然是正交矩阵
    - 注意 $x_1,x_2$ 对应一个 $2\times 2$ 的旋转矩阵 $R_{1,2}$
    - $x_3,x_4$ 对应一个 $2\times 2$ 的旋转矩阵 $R_{3,4}$
    - ...
- $R=R_{1,2}R_{3,4}R_{5,6}\cdots R_{d-1,d}$

In [34]:
Image(url='../../imgs/4d_rotation.png', width=300)

In [3]:
Image(url='../../imgs/rope_1.png', width=450)

In [24]:
# Two independent plane rotations: pi/6 in the (x1, x2) plane, pi/3 in the (x3, x4) plane.
theta1 = torch.tensor(np.pi / 6)
theta2 = torch.tensor(np.pi / 3)

cos1, sin1 = torch.cos(theta1), torch.sin(theta1)
cos2, sin2 = torch.cos(theta2), torch.sin(theta2)
identity_2x2 = torch.eye(2)

# R12 rotates the first coordinate pair and leaves the second pair untouched.
R12 = torch.block_diag(
    torch.stack([cos1, -sin1, sin1, cos1]).reshape(2, 2),
    identity_2x2,
)

# R34 leaves the first coordinate pair untouched and rotates the second pair.
R34 = torch.block_diag(
    identity_2x2,
    torch.stack([cos2, -sin2, sin2, cos2]).reshape(2, 2),
)

# Composing both gives the full 4-d block-diagonal rotation.
R = R12 @ R34

R12, R34, R

(tensor([[ 0.8660, -0.5000,  0.0000,  0.0000],
         [ 0.5000,  0.8660,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  1.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  1.0000]]),
 tensor([[ 1.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  1.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.5000, -0.8660],
         [ 0.0000,  0.0000,  0.8660,  0.5000]]),
 tensor([[ 0.8660, -0.5000,  0.0000,  0.0000],
         [ 0.5000,  0.8660,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.5000, -0.8660],
         [ 0.0000,  0.0000,  0.8660,  0.5000]]))

In [30]:
# A rotation matrix is orthogonal: R R^T equals R^T R, and both equal the identity.
gram = R @ R.T
print(torch.allclose(gram, R.T @ R))
print(torch.allclose(gram, torch.eye(4)))

True
True


In [4]:
Image(url='../../imgs/rope_2.png', width=450)

In [10]:
# theta: the rotation angles m * theta_i, one per channel pair
# m: token position
def get_rope_matrix(d, theta):
    """Construct the d x d RoPE rotation matrix for one set of angles.

    Channel ``i`` is paired with channel ``i + d // 2`` and each pair is
    rotated by ``theta[i]``, which gives the block layout
    ``[[C, -S], [S, C]]`` with ``C = diag(cos(theta))`` and
    ``S = diag(sin(theta))``.

    Args:
        d: Embedding dimension; must be even so the four blocks tile the matrix.
        theta: 1-D tensor holding ``d // 2`` rotation angles in radians.

    Returns:
        A ``(d, d)`` tensor using the dtype and device of ``cos(theta)``.
    """
    cos_theta = torch.cos(theta)
    sin_theta = torch.sin(theta)
    # Allocate with the dtype/device of the trigonometric values so float64
    # angles are not silently downcast to float32 and the result stays on the
    # same device as the input.
    mat = torch.zeros(d, d, dtype=cos_theta.dtype, device=cos_theta.device)
    half = d // 2
    mat[:half, :half] = torch.diag(cos_theta)
    mat[half:, half:] = torch.diag(cos_theta)
    mat[:half, half:] = -torch.diag(sin_theta)
    mat[half:, :half] = torch.diag(sin_theta)
    return mat

In [5]:
# Dimension of the example vectors; it must be even.
d = 4
pair_ids = range(d // 2)
# One angle per channel pair, written in degrees and converted to radians;
# theta_n is theta_m shifted by one degree in every pair.
theta_m = torch.tensor([idx * math.pi / 180 for idx in pair_ids])
theta_n = torch.tensor([(idx + 1) * math.pi / 180 for idx in pair_ids])

In [7]:
theta_m, theta_n

(tensor([0.0000, 0.0175]), tensor([0.0175, 0.0349]))

In [8]:
theta_n - theta_m

tensor([0.0175, 0.0175])

In [9]:
torch.tensor([1 * math.pi / 180 for i in range(d//2)])

tensor([0.0175, 0.0175])

In [38]:
# Build the rotation matrices R_m and R_n from their angle vectors.
R_m, R_n = (get_rope_matrix(d, angles) for angles in (theta_m, theta_n))

In [42]:
# Rotation built directly from the angle difference theta_n - theta_m.
R_n_minus_m = get_rope_matrix(d, theta_n - theta_m)
# Transpose of R_m, used for the product R_m^T R_n.
R_m_T = R_m.transpose(0, 1)
R_n_minus_m

tensor([[ 0.9998,  0.0000, -0.0175, -0.0000],
        [ 0.0000,  0.9998, -0.0000, -0.0175],
        [ 0.0175,  0.0000,  0.9998,  0.0000],
        [ 0.0000,  0.0175,  0.0000,  0.9998]])

In [41]:
# R_m^T R_n, to be compared with R_{n-m}.
product = R_m_T @ R_n
product

tensor([[ 0.9998,  0.0000, -0.0175,  0.0000],
        [ 0.0000,  0.9998,  0.0000, -0.0175],
        [ 0.0175,  0.0000,  0.9998,  0.0000],
        [ 0.0000,  0.0175,  0.0000,  0.9998]])

In [43]:
# Test vectors q and k
q = torch.tensor([1.0, 2.0, 3.0, 4.0])
k = torch.tensor([4.0, 3.0, 2.0, 1.0])

# Compute (R_m q)^T (R_n k): rotate each vector with its own matrix, then take
# the inner product. This is the first term of the identity being checked, so it
# is evaluated here instead of leaving R_m_q / R_n_k unused.
R_m_q = torch.mv(R_m, q)
R_n_k = torch.mv(R_n, k)
result_1 = torch.dot(R_m_q, R_n_k)
result_1

In [46]:
# Compute q^T R_m^T R_n k
rotated_k = product @ k
result_2 = q @ rotated_k
result_2

tensor(20.3460)

In [47]:
# Compute q^T R_{n-m} k
relatively_rotated_k = R_n_minus_m @ k
result_3 = q @ relatively_rotated_k
result_3

tensor(20.3460)

## CoPE：Contextual Position Encoding: Learning to Count What’s Important

- https://arxiv.org/pdf/2405.18719
- https://www.zhihu.com/question/657761483/answer/3517582623
- 相对位置编码
    - 之前的哪怕是 rope，都是基于 token positions 的，独立于上下文
    - cope （$p_{ij}$ 的计算已经考虑了 $q_i,k_j$）不只可以 attend 到 token，还可以到 sentence 到 paragraph

$$
a_{ij}=\text{Softmax}(\mathbf q^T_i(\mathbf k_j+\mathbf e[i-j]))
$$

In [15]:
# rel p_{ij} 显然是整数
# cope, p_{ij} 显然不是整数
# 这个例子想说的是我们要关注最后一句话， relative pe 并不能做到很好的 attend 到最后一句话，cope 可以
Image(url='../../imgs/cope.png', width=500)

In [12]:
Image(url='../../imgs/rel_position_matrix.png', width=400)

- $g_{ij}=\sigma(q_i^Tk_j)$（gate value）, $i$ 称为 target，$j\lt i$, 也就是 $i$ 左边的；
    - 可以想象自回归的过程，$i$表示当前位置，$j$则表示过去的已经遍历/生成的 tokens；
- $p_{ij}=\sum_{k=j}^ig_{ik}$
    - 如果所有门控都为 1（$g_{ik}=1,\ j\le k\le i$），则 $p_{ij}=i-j+1$，退化为基于 token 的相对位置
    - 显然 $p_{ij}$ 大概率不是整数，它是我们希望刻画的相对位置；
- a learnable embedding vector $\mathbf e[p]$（$p\in[0,T]$），做插值；

    $$
    \mathbf{e}[p_{ij}] = (p_{ij} - \lfloor p_{ij} \rfloor) \mathbf{e} \left[ \lceil p_{ij} \rceil \right] + (1 - p_{ij} + \lfloor p_{ij} \rfloor) \mathbf{e} \left[ \lfloor p_{ij} \rfloor \right].
    $$

- attention weights

    $$
    a_{ij}=\text{Softmax}(\mathbf q_i^T(\mathbf k_j+\mathbf e[p_{ij}]))
    $$
  - computing and storing vectors $\mathbf e[p_{ij}]$ uses extra compute and memory.

$$
\begin{align}
z_i[p] &= \mathbf{q}_i^\top \mathbf{e}[p] \quad \text{for } p \in \{0, 1, \ldots, T\} \\
z_i[p_{ij}] &= (p_{ij} - \lfloor p_{ij} \rfloor) z_i \left[ \lceil p_{ij} \rceil \right] + (1 - p_{ij} + \lfloor p_{ij} \rfloor) z_i \left[ \lfloor p_{ij} \rfloor \right] \\
a_{ij} &= \text{Softmax}(\mathbf{q}_i^\top \mathbf{k}_j + z_i[p_{ij}]).
\end{align}
$$

```
class CoPE(nn.Module):
    """Contextual position encoding: turns gated attention logits into position logits.

    Positions are fractional sums of sigmoid gates, and the logit for a
    fractional position is linearly interpolated between the logits of the two
    neighbouring integer positions.
    """

    def __init__(self, npos_max, head_dim):
        """Create the learnable position embeddings.

        Args:
            npos_max: Number of integer positions that have an embedding
                (positions are clamped to ``npos_max - 1`` in ``forward``).
            head_dim: Size of the first non-singleton axis of ``pos_emb``; it is
                contracted against the last axis of ``query`` in ``forward``.
        """
        super().__init__()
        self.npos_max = npos_max
        # Zero-initialised parameter of shape (1, head_dim, npos_max): one
        # embedding column per integer position.
        self.pos_emb = nn.parameter.Parameter(torch.zeros(1, head_dim, npos_max))

    def forward(self, query, attn_logits):
        """Return position logits with the same shape as ``attn_logits``.

        Args:
            query: Query vectors -- assumed (batch, seq_len, head_dim); TODO confirm
                against the caller.
            attn_logits: Query-key logits, already masked by the caller
                -- assumed (batch, seq_len, seq_len); TODO confirm.
        """
        # compute positions
        # One gate in (0, 1) per query/key pair.
        gates = torch.sigmoid(attn_logits)
        # Reversed cumulative sum along the key axis: pos[..., j] is the sum of
        # gates[..., j:], i.e. the gates from key j through the last key.
        pos = gates.flip(-1).cumsum(dim=-1).flip(-1)
        # Keep positions inside the embedding table so the gathers below stay in range.
        pos = pos.clamp(max=self.npos_max - 1)
        # interpolate from integer positions
        pos_ceil = pos.ceil().long()
        pos_floor = pos.floor().long()
        # Logit of every integer position for every query (last axis has npos_max entries).
        logits_int = torch.matmul(query, self.pos_emb)
        logits_ceil = logits_int.gather(-1, pos_ceil)
        logits_floor = logits_int.gather(-1, pos_floor)
        # Fractional part of the position is the weight of the upper neighbour.
        w = pos - pos_floor
        return logits_ceil * w + logits_floor * (1 - w)

class SelfAttn(nn.Module):
    """Single-head scaled dot-product attention with CoPE position logits added."""

    def __init__(self, npos_max, head_dim):
        """Build the CoPE module and remember ``head_dim`` for logit scaling."""
        super().__init__()
        self.cope = CoPE(npos_max, head_dim)
        self.head_dim = head_dim

    def forward(self, query, key, val, mask):
        """Attend over ``val`` using content logits plus contextual position logits.

        Args:
            query, key, val: batch x seq_len x head_dim tensors.
            mask: Added to the logits as ``mask.log()`` -- presumably 1 for
                visible pairs and 0 for hidden ones; verify against the caller.
        """
        # q, k, v have dimensions batch x seq_len x head_dim
        attn_logits = torch.bmm(query, key.transpose(-1, -2))
        attn_logits = attn_logits / math.sqrt(self.head_dim)
        # log(mask) is applied before CoPE, so the gates CoPE derives from these
        # logits already reflect the mask.
        attn_logits += mask.log()
        # Position logits are computed from the masked content logits and added to them.
        attn_logits += self.cope(query, attn_logits)
        attn = torch.softmax(attn_logits, dim=-1)
        out = torch.bmm(attn, val)
        return out
```