# RoFormer: Enhanced Transformer with Rotary Position Embedding.

Su, Jianlin, Yu Lu, Shengfeng Pan, Ahmed Murtadha, Bo Wen, and Yunfeng Liu. "RoFormer: Enhanced Transformer with Rotary Position Embedding." arXiv:2104.09864. Preprint, arXiv, November 8, 2023. https://doi.org/10.48550/arXiv.2104.09864.


In [None]:
!pip install torchinfo

In [96]:
from typing import Optional

import math
from dataclasses import dataclass

import torch
import torch.nn as nn
from torch import Tensor
from torch.nn import functional as F

from torchinfo import summary

In [3]:
# Prefer the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Seed PyTorch's global RNG so the random weights and inputs created below are reproducible.
torch.manual_seed(42)

In [12]:
@dataclass
class ModelArgs:
    """Hyper-parameters for the RoFormer model assembled in this notebook.

    The field order defines the positional constructor signature, so it is
    kept exactly as it was.  Fields marked "not read" are not referenced by
    any code visible in this part of the notebook.
    """

    # Maximum sequence length; also the number of rows in the rotary cache.
    block_size: int = 1024
    # Number of rows of the token embedding and width of the output logits.
    vocab_size: int = 32_000
    # Embedding / model width.
    hidden_size: int = 768
    # Not read in the visible code.
    type_vocab_size: int = 2
    # Inner width of the feed-forward block (4 x hidden_size).
    intermediate_size: int = 3072
    # Not read in the visible code (the attention module uses ``n_heads``).
    num_attention_heads: int = 8
    # Number of stacked encoder layers.
    n_layers: int = 4
    # Number of attention heads; must divide ``hidden_size``.
    n_heads: int = 4
    # Not read in the visible code.
    dim: int = 1024
    # Not read in the visible code: the attention module derives the per-head
    # width as ``hidden_size // n_heads`` instead.
    head_dim: int = 64
    # Base of the rotary frequency schedule (one million).
    rope_base: float = 1_000_000.0
    # Not read in the visible code.
    norm_eps: float = 1e-5
    # Dropout probability inside the feed-forward block.
    dropout: float = 0.1
    # Dropout probability applied to the attention probabilities.
    attn_dropout: float = 0.1
    # Epsilon passed to every LayerNorm.
    layer_norm_eps: float = 1e-12

In [13]:
# Instantiate the configuration with every hyper-parameter left at its default.
config = ModelArgs()

In [14]:
class FeedForward(nn.Module):
    """Position-wise feed-forward block with its own residual and LayerNorm.

    Computes ``LayerNorm(x + Dropout(c_proj(GELU(c_fc(x)))))``; the output has
    the same shape as the input.

    Args:
        config: object exposing ``hidden_size``, ``intermediate_size``,
            ``layer_norm_eps`` and ``dropout``.
    """

    def __init__(self, config):
        super().__init__()

        model_width = config.hidden_size
        inner_width = config.intermediate_size

        # Submodules are registered in a fixed order so parameter names and
        # the order of random initialisation stay stable.
        self.c_fc = nn.Linear(model_width, inner_width)      # expand
        self.c_proj = nn.Linear(inner_width, model_width)    # project back
        self.act = nn.GELU()

        self.layer_norm = nn.LayerNorm(model_width, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply expand -> GELU -> project -> dropout, add the input, normalise."""
        residual = x
        projected = self.c_proj(self.act(self.c_fc(x)))
        return self.layer_norm(self.dropout(projected) + residual)

A (szög)frekvenciák listája:

$$
\omega_k =\frac{1}{\operatorname{base}^{2 k / d}}, \quad k=0,1, \ldots, d / 2-1
$$


A pozíciók $t$ és a frekvenciák $\omega_k$ külső szorzata adja a szögeket:

$$
\Theta = t \otimes \omega = t \, \omega^T =
\left[
    \begin{array}{cccc}
        t_1 \omega_0 & t_1 \omega_1 & \ldots & t_1 \omega_{d/2-1} \\
        t_2 \omega_0 & t_2 \omega_1 & \ldots & t_2 \omega_{d/2-1} \\
        \vdots & \vdots & \ddots & \vdots \\
        t_m \omega_0 & t_m \omega_1 & \ldots & t_m \omega_{d/2-1}
    \end{array}
\right] \quad \in \mathbb{R}^{m \times(d / 2)},
\qquad \theta_{t, k} = t \cdot \omega_k
$$

ahol $t \in \mathbb{R}^m $ és $ \omega \in \mathbb{R}^{d / 2} $.

Ezek lesznek a szögek (radiánban), amelyekkel az $ ( x_{2k}, x_{2k+1} ) $ dimenziópárt később elforgatjuk. Azaz a mátrix minden sora egy $t$ pozícióhoz, minden oszlopa egy $k$ dimenziópárhoz tartozik, az elemei pedig a hozzájuk tartozó forgatási szögek radiánban.

A komplex szám trigonometrikus alakja:
$$
z=|z|(\cos \theta + i \sin \theta)
$$

ahol $|z|$ a komplex szám abszolút értéke (hossza) és ha $z$ az egységsugarú körön van, akkor $|z|=1$. Tehát:

$$z=\cos \theta+i \sin \theta$$

Egy komplex számokból (trigonometrikus alak) álló tenzor kerül létrehozásra, amelynek valós része a $\cos \theta$ és imaginárius része a $\sin \theta$. A gyakorlatban a számításhoz ezt két valós komponensre bontjuk, hogy valós műveletekkel építsük fel a forgatást.

In [15]:
def precompute_freqs_cis(
        seq_len: int,
        head_dim: int,
        base: float = 10_000,
        dtype: torch.dtype = torch.bfloat16,
) -> torch.Tensor:
    """Build the rotary-embedding lookup table of cosines and sines.

    For every position ``t`` in ``[0, seq_len)`` and every feature pair ``k``
    in ``[0, head_dim // 2)`` the rotation angle is
    ``t / base ** (2 * k / head_dim)``.

    Args:
        seq_len: number of positions to tabulate.
        head_dim: per-head feature width; one angle is produced per pair of
            features, i.e. ``head_dim // 2`` angles per position.
        base: base of the geometric frequency schedule (int or float).
        dtype: dtype of the returned table.  Defaults to ``torch.bfloat16``
            (the previously hard-coded value); pass ``torch.float32`` to keep
            the full precision the table is computed in.

    Returns:
        Tensor of shape ``(seq_len, head_dim // 2, 2)`` whose last axis holds
        ``(cos(angle), sin(angle))``.
    """
    # omega_k = 1 / base^(2k / d), k = 0 .. d/2 - 1                  -> (d // 2,)
    exponents = torch.arange(0, head_dim, 2)[: (head_dim // 2)].float() / head_dim
    inv_freq = 1.0 / (base ** exponents)

    # Positions 0 .. seq_len-1, created directly in the frequency dtype so the
    # outer product does not rely on implicit int -> float promotion.
    positions = torch.arange(seq_len, device=inv_freq.device, dtype=inv_freq.dtype)

    # angle[t, k] = t * omega_k                                      -> (seq_len, d // 2)
    angles = torch.outer(positions, inv_freq)

    # Unit-modulus complex numbers cos(angle) + i * sin(angle).
    unit_circle = torch.polar(torch.ones_like(angles), angles)

    # Split into real components so the rotation can use real arithmetic.
    cache = torch.stack([unit_circle.real, unit_circle.imag], dim=-1)  # (seq_len, d // 2, 2)

    return cache.to(dtype=dtype)

A rotációs mátrix ($R$) olyan transzformációs mátrix, amely az euklideszi térben forgatás végrehajtására szolgál.

$$
R=\left[\begin{array}{cc}
\cos \theta & -\sin \theta \\
\sin \theta & \cos \theta
\end{array}\right]
$$

In [16]:
def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> torch.Tensor:
    """Rotate consecutive feature pairs of ``x`` by position-dependent angles.

    Args:
        x: 4-D tensor; axis 1 is matched against the positions in
            ``freqs_cis`` and the last axis is split into pairs
            ``(x[2k], x[2k + 1])``.  The attention module passes
            ``(batch, seq_len, n_heads, head_dim)``.
        freqs_cis: cos/sin table that can be viewed as
            ``(seq_len, head_dim // 2, 2)`` with ``(cos, sin)`` on the last
            axis.

    Returns:
        Tensor with the shape and dtype of ``x`` in which every pair has been
        rotated: ``(a, b) -> (a*cos - b*sin, b*cos + a*sin)``.
    """
    # Work in float32 and expose the pairs on a trailing axis of size 2.
    pairs = x.float().reshape(*x.shape[:-1], -1, 2)
    seq_len, n_pairs = pairs.size(1), pairs.size(3)

    # Singleton axes let the table broadcast over axis 0 and axis 2 of x.
    table = freqs_cis.view(1, seq_len, 1, n_pairs, 2)

    first, second = pairs[..., 0], pairs[..., 1]
    cos, sin = table[..., 0], table[..., 1]

    rotated = torch.stack(
        [first * cos - second * sin, second * cos + first * sin],
        dim=-1,
    )

    # Merge (pair, 2) back into a single feature axis and restore x's dtype.
    return rotated.flatten(3).type_as(x)

In [81]:
class MultiHeadAttention(nn.Module):
    """Multi-head self-attention that applies rotary embeddings to Q and K.

    Args:
        config: object exposing ``hidden_size``, ``n_heads`` and
            ``attn_dropout``; ``hidden_size`` must be divisible by
            ``n_heads``.
    """

    def __init__(self, config):
        super().__init__()

        assert config.hidden_size % config.n_heads == 0, "hidden size is indivisible by n_heads (number of attention heads)"

        width = config.hidden_size

        self.n_heads = config.n_heads                    # h
        self.head_dim = width // config.n_heads          # d_h
        self.all_head_size = config.n_heads * self.head_dim

        # Separate input projections for queries, keys and values.
        self.query = nn.Linear(width, self.all_head_size)
        self.key = nn.Linear(width, self.all_head_size)
        self.value = nn.Linear(width, self.all_head_size)

        self.dropout = nn.Dropout(config.attn_dropout)

        # Output projection, (d, d).
        self.w_o = nn.Linear(width, width)

    def forward(self, x: torch.Tensor, freqs_cis: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Attend over the sequence.

        Args:
            x: input of shape ``(batch, seq_len, hidden_size)``.
            freqs_cis: rotary cos/sin table forwarded to ``apply_rotary_emb``.
            attention_mask: optional tensor added to the raw attention scores
                before the softmax; it must broadcast against
                ``(batch, n_heads, seq_len, seq_len)``.

        Returns:
            Tensor of shape ``(batch, seq_len, hidden_size)``.
        """
        batch_size, seq_len, hidden_size = x.shape  # (b, t, d)
        per_head = (batch_size, seq_len, self.n_heads, self.head_dim)

        # Project and split the feature axis into heads: (b, t, h, d_h).
        queries = self.query(x).view(per_head)
        keys = self.key(x).view(per_head)
        values = self.value(x).view(per_head)

        # Position information is injected by rotating queries and keys only.
        queries = apply_rotary_emb(queries, freqs_cis)
        keys = apply_rotary_emb(keys, freqs_cis)

        # (b, t, h, d_h) -> (b, h, t, d_h)
        queries = queries.transpose(1, 2)
        keys = keys.transpose(1, 2)
        values = values.transpose(1, 2)

        # Scaled dot product: (b, h, t, d_h) @ (b, h, d_h, t) -> (b, h, t, t)
        scores = torch.matmul(queries, keys.transpose(-2, -1)) / math.sqrt(self.head_dim)

        if attention_mask is not None:
            scores = scores + attention_mask

        weights = self.dropout(F.softmax(scores, dim=-1))

        # (b, h, t, t) @ (b, h, t, d_h) -> (b, h, t, d_h) -> (b, t, h, d_h)
        mixed = torch.matmul(weights, values).transpose(1, 2).contiguous()

        # Merge the heads back into one feature axis before the output projection.
        merged = mixed.view(batch_size, seq_len, hidden_size)
        return self.w_o(merged)


In [46]:
class EncoderLayer(nn.Module):
    """One encoder block: normalised attention, then normalised feed-forward,
    each wrapped in a residual connection.

    NOTE(review): ``FeedForward.forward`` already adds its own input back and
    applies a LayerNorm, so the feed-forward path here is normalised and
    residual-connected twice -- confirm this is intended.

    Args:
        config: object exposing ``hidden_size`` and ``layer_norm_eps`` plus
            whatever ``MultiHeadAttention`` and ``FeedForward`` read.
    """

    def __init__(self, config):
        super().__init__()
        width, eps = config.hidden_size, config.layer_norm_eps

        self.attention = MultiHeadAttention(config)
        self.feed_forward = FeedForward(config)
        self.ffn_norm = nn.LayerNorm(width, eps=eps)
        self.attention_norm = nn.LayerNorm(width, eps=eps)

    def forward(self, x: torch.Tensor, freqs_cis: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Return ``h + feed_forward(ffn_norm(h))`` where
        ``h = x + attention(attention_norm(x))``."""
        attended = self.attention(self.attention_norm(x), freqs_cis, attention_mask)
        h = x + attended

        transformed = self.feed_forward(self.ffn_norm(h))
        return h + transformed

In [122]:
class RoFormer(nn.Module):
    """Stack of rotary-position encoder layers with a token-level output head.

    Args:
        config: object exposing ``vocab_size``, ``hidden_size``, ``n_layers``,
            ``n_heads``, ``block_size`` and ``rope_base`` plus whatever
            ``EncoderLayer`` reads.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

        self.token_embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
        self.layers = nn.ModuleList([EncoderLayer(config) for _ in range(config.n_layers)])
        self.output = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Registered as a buffer so the table follows the module through
        # .to(device) / .cuda(); a plain attribute would stay on the CPU and
        # clash with GPU activations.  persistent=False keeps it out of the
        # state_dict, so existing checkpoints still load.
        self.register_buffer(
            "freqs_cis",
            precompute_freqs_cis(config.block_size, config.hidden_size // config.n_heads, config.rope_base),
            persistent=False,
        )

    def forward(
            self,
            x: torch.Tensor,
            attention_mask: Optional[torch.Tensor] = None,
            input_pos: Optional[torch.Tensor] = None # (0, seq_len-1)
    ) -> torch.Tensor:
        """Compute per-token logits.

        Args:
            x: token ids of shape ``(batch, seq_len)``.
            attention_mask: optional ``(batch, seq_len)`` mask, 1 for tokens
                to attend to and 0 for masked tokens; any numeric or bool
                dtype.  Defaults to attending everywhere.
            input_pos: optional index tensor selecting one row of the rotary
                table per token position.  Defaults to positions
                ``0 .. seq_len - 1``.

        Returns:
            Logits of shape ``(batch, seq_len, vocab_size)``.

        Raises:
            ValueError: if ``input_pos`` is omitted and ``seq_len`` exceeds
                the number of precomputed positions.
        """
        batch_size, seq_length = x.size()

        if input_pos is None:
            max_positions = self.freqs_cis.size(0)
            if seq_length > max_positions:
                raise ValueError(
                    f"sequence length {seq_length} exceeds the {max_positions} precomputed rotary positions"
                )
            # Indexing with None would return the whole table and break for
            # any sequence shorter than block_size; take just the leading rows.
            freqs_cis = self.freqs_cis[:seq_length]
        else:
            freqs_cis = self.freqs_cis[input_pos]

        if attention_mask is None:
            attention_mask = torch.ones((batch_size, seq_length), device=x.device)
        # Build the additive mask in the dtype of the activations it is added to.
        attention_mask = self.get_extended_attention_mask(
            attention_mask, x.shape, dtype=self.token_embeddings.weight.dtype
        )

        x = self.token_embeddings(x)

        for layer in self.layers:
            x = layer(x, freqs_cis, attention_mask)
        return self.output(x)

    def get_extended_attention_mask(
            self,
            attention_mask: torch.Tensor,
            input_shape,
            dtype: Optional[torch.dtype] = None,
    ) -> torch.Tensor:
        """Turn a 0/1 padding mask into an additive attention bias.

        Args:
            attention_mask: ``(batch, seq_len)`` tensor, 1 (or True) for
                positions to attend to and 0 (or False) for masked positions.
            input_shape: unused; kept for interface compatibility.
            dtype: floating dtype of the result.  When omitted, a floating
                mask keeps its own dtype and an integer/bool mask uses
                PyTorch's default floating dtype.

        Returns:
            Tensor of shape ``(batch, 1, 1, seq_len)`` that is 0 where the
            mask is 1 and the most negative finite value of ``dtype`` where
            the mask is 0.
        """
        if dtype is None:
            # torch.finfo only accepts floating dtypes, so integer and bool
            # masks need a floating target.
            if attention_mask.is_floating_point():
                dtype = attention_mask.dtype
            else:
                dtype = torch.get_default_dtype()

        # (batch, seq_len) -> (batch, 1, 1, seq_len): broadcasts over heads and query positions.
        extended_attention_mask = attention_mask[:, None, None, :].to(dtype=dtype)
        return (1.0 - extended_attention_mask) * torch.finfo(dtype).min

In [124]:
# Build the model from `config` and move it to the selected device.
roformer = RoFormer(config).to(device)

In [125]:
# Dummy batch of random token ids: 4 sequences of block_size tokens each, created on `device`.
x = torch.randint(0, config.vocab_size, (4, config.block_size), device=device).to(torch.long)

In [126]:
# Run the dummy batch through the model and tabulate per-layer output shapes and parameter counts.
summary(roformer, input_data=x)

Layer (type:depth-idx)                   Output Shape              Param #
RoFormer                                 [4, 1024, 32000]          --
├─Embedding: 1-1                         [4, 1024, 768]            24,576,000
├─ModuleList: 1-2                        --                        --
│    └─EncoderLayer: 2-1                 [4, 1024, 768]            --
│    │    └─LayerNorm: 3-1               [4, 1024, 768]            1,536
│    │    └─MultiHeadAttention: 3-2      [4, 1024, 768]            2,362,368
│    │    └─LayerNorm: 3-3               [4, 1024, 768]            1,536
│    │    └─FeedForward: 3-4             [4, 1024, 768]            4,723,968
│    └─EncoderLayer: 2-2                 [4, 1024, 768]            --
│    │    └─LayerNorm: 3-5               [4, 1024, 768]            1,536
│    │    └─MultiHeadAttention: 3-6      [4, 1024, 768]            2,362,368
│    │    └─LayerNorm: 3-7               [4, 1024, 768]            1,536
│    │    └─FeedForward: 3-8             [4,