In [None]:
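# 🔗 Link to GitHub - https://github.com/facebookresearch/llama/blob/6c7fe276574e78057f917549435a2554000a876d/llama/model.py#L64

from typing import Tuple

import torch


def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    # Split the last dimension into (real, imag) pairs and view each pair as one complex number.
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    # reshape_for_broadcast is defined next to this function in the linked model.py.
    freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
    # Complex multiplication rotates each pair; flatten(3) merges the pairs back into the last dimension.
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
    # Cast back to the dtype of the inputs.
    return xq_out.type_as(xq), xk_out.type_as(xk)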

## Rotary Embeddings (RoPE) are one of the fundamental building blocks of the Llama 2 implementation 🚀

Let's walk through a compact and elegant piece of Llama's source code: the `apply_rotary_emb()` function 🔥

First, a brief look at the concept of Rotary Embedding.

👉 Rotary Embedding (Rotary Position Embedding, RoPE) is a kind of positional encoding introduced in the paper "RoFormer: Enhanced Transformer with Rotary Position Embedding". Unlike traditional positional encodings, which add a position vector to each token embedding, RoPE rotates each query and key vector by an angle proportional to the token's position. Because the attention score is a dot product of a rotated query and a rotated key, it ends up depending on the relative offset between the two tokens rather than on their absolute positions.
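👉 To make the "relative, not absolute" idea concrete, here is a minimal sketch using only Python's standard library. It treats a single 2-D query and key as complex numbers; the values of `theta`, `q` and `k` and the helper names `rotate` and `score` are made up for illustration and do not come from Llama.

```python
import cmath


def rotate(v: complex, position: int, theta: float) -> complex:
    """Rotate a 2-D vector (stored as a complex number) by position * theta."""
    return v * cmath.exp(1j * position * theta)


def score(q: complex, k: complex) -> float:
    """2-D dot product written with complex numbers: Re(q * conj(k))."""
    return (q * k.conjugate()).real


theta = 0.1                      # hypothetical rotation frequency
q, k = 1.0 + 2.0j, 0.5 - 1.5j    # hypothetical query / key pair

# Two (query position, key position) pairs that are both 2 positions apart.
near_start = score(rotate(q, 3, theta), rotate(k, 1, theta))
further_on = score(rotate(q, 7, theta), rotate(k, 5, theta))
# A pair that is only 1 position apart.
closer = score(rotate(q, 3, theta), rotate(k, 2, theta))

print(abs(near_start - further_on) < 1e-9)  # same offset, same score
print(abs(near_start - closer) > 1e-3)      # different offset, different score
```

Both checks should print `True`: shifting the query and the key by the same number of positions leaves the score unchanged.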

-------

**Key steps of the RoPE implementation in Llama 2:**

👉 The tensors `xq` and `xk` (representing queries and keys) are reshaped to facilitate the creation of complex numbers. The last dimension of each tensor (the embedding dimension) is split into consecutive pairs, which adds a trailing dimension of size 2, and each pair of values is treated as the real and imaginary parts of a complex number.

---------

👉 Workings of `view_as_complex`

👉 The rotation relies on complex numbers because of their intrinsic rotational properties. In the complex plane, multiplying two complex numbers multiplies their magnitudes and adds their angles. When one of the factors has magnitude 1, as the entries of `freqs_cis` do in Llama's implementation, the multiplication is a pure rotation of the other number by that factor's angle.
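👉 A quick sanity check of that property, as a small standard-library sketch. The numbers (magnitude 2, angles of 30 and 45 degrees) are arbitrary choices for illustration.

```python
import cmath
import math

z = cmath.rect(2.0, math.radians(30))   # magnitude 2, angle 30 degrees
w = cmath.rect(1.0, math.radians(45))   # unit magnitude, angle 45 degrees

product = z * w
magnitude = abs(product)
angle_deg = math.degrees(cmath.phase(product))

# The magnitude of z is preserved and the angles add up (30 + 45).
print(math.isclose(magnitude, 2.0), math.isclose(angle_deg, 75.0))
```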

👉 So the function first converts the last dimension of the real-valued input tensors `xq` (query) and `xk` (key) into a complex representation. That last dimension holds the components of each embedding vector.

In this line of code, `xq.float()` casts the tensor to float32 and `.reshape(*xq.shape[:-1], -1, 2)` splits the last dimension of size `d` into two dimensions of shape `(d/2, 2)`. Each `(real, imaginary)` pair along the new last dimension represents one complex number.

👉 Then, `torch.view_as_complex` takes these pairs of real numbers and treats them as complex numbers, where the first element of each pair is the real part and the second is the imaginary part.
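👉 The reshape and `torch.view_as_complex` steps can be traced on a tiny tensor. This is a sketch with made-up sizes (batch 1, 2 tokens, 1 head, embedding dimension 4); only the two calls themselves are taken from the function above.

```python
import torch

# Hypothetical input: values 0..7 so the pairing is easy to read off.
xq = torch.arange(8, dtype=torch.float16).reshape(1, 2, 1, 4)

# Cast to float32, then split the last dimension (4) into 2 pairs of 2 values.
pairs = xq.float().reshape(*xq.shape[:-1], -1, 2)   # shape (1, 2, 1, 2, 2)

# Interpret each (real, imag) pair as one complex number; the size-2 dimension disappears.
xq_ = torch.view_as_complex(pairs)                  # shape (1, 2, 1, 2), dtype complex64

print(pairs.shape, xq_.shape, xq_.dtype)
print(xq_[0, 0, 0])  # first token: the pairs (0, 1) and (2, 3) become 0+1j and 2+3j
```

Note that the number of complex entries per vector is half the original embedding dimension, which is why that dimension has to be even.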

-----------

👉 `reshape_for_broadcast()` is a helper that reshapes the frequency embeddings `freqs_cis` so that they broadcast against the complex views of `xq` and `xk`: the sequence and last dimensions are kept, and every other dimension is set to size 1. This ensures that the rotation operation (which is an element-wise multiplication) can be performed.
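👉 The helper is not shown in this notebook, so here is a minimal sketch of what such a function can look like, modelled on the one in the linked `model.py`. It assumes the input is laid out as `(batch, seq_len, ..., dim)` and that `freqs_cis` has shape `(seq_len, dim)`; the tensor sizes in the usage lines are made up.

```python
import torch


def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    # Assumption: dimension 1 of x is the sequence axis and the last one is the (complex) feature axis.
    assert freqs_cis.shape == (x.shape[1], x.shape[-1])
    # Keep the sequence and feature sizes, use 1 everywhere else so broadcasting fills them in.
    shape = [d if i in (1, x.ndim - 1) else 1 for i, d in enumerate(x.shape)]
    return freqs_cis.view(*shape)


# Hypothetical sizes: batch 2, 5 tokens, 4 heads, 8 complex entries per head.
x = torch.zeros(2, 5, 4, 8, dtype=torch.complex64)
freqs_cis = torch.ones(5, 8, dtype=torch.complex64)

out = reshape_for_broadcast(freqs_cis, x)
print(out.shape)        # torch.Size([1, 5, 1, 8])
print((x * out).shape)  # torch.Size([2, 5, 4, 8])
```

Every batch element and every head therefore reuses the same rotation factors for a given position.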

👉 Here, the rotation is applied by element-wise multiplication of the complex views of the inputs (`xq_` and `xk_`) with the frequency embeddings (`freqs_cis`). Each complex pair at position `m` is turned by that position's angle, which is how the positional information gets encoded into the queries and keys.
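👉 A small sketch of the rotation step on its own. The way `freqs_cis` is built here (base 10000 frequency schedule, `torch.polar` with unit magnitudes) is an assumption for illustration, since this notebook does not show how it is precomputed; the sizes are made up too.

```python
import torch

torch.manual_seed(0)
seq_len, dim = 4, 8  # hypothetical sizes

# Assumed frequency schedule: one angle per (position, pair of features).
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, dim, 2).float() / dim))
angles = torch.outer(torch.arange(seq_len).float(), inv_freq)   # (seq_len, dim // 2)
freqs_cis = torch.polar(torch.ones_like(angles), angles)        # unit-magnitude complex numbers

x = torch.randn(1, seq_len, 1, dim)
x_ = torch.view_as_complex(x.reshape(1, seq_len, 1, -1, 2))     # (1, seq_len, 1, dim // 2)

# Element-wise complex multiplication = rotation of every pair.
rotated = x_ * freqs_cis.view(1, seq_len, 1, -1)

# Position 0 has angle 0, so it is left unchanged.
unchanged_at_zero = (rotated[:, 0] - x_[:, 0]).abs().max().item() < 1e-6
# A rotation never changes the length of a pair.
lengths_kept = torch.allclose(rotated.abs(), x_.abs(), atol=1e-6)
print(unchanged_at_zero, lengths_kept)
```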

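👉 Then, `torch.view_as_real` converts the result from complex numbers back into real numbers, which brings back a trailing dimension of size 2, and `.flatten(3)` merges everything from dimension 3 onward so that the output matches the original shape of `xq` and `xk`. Note that `flatten(3)` assumes 4-dimensional inputs.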
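👉 The shapes along the way back can be checked with a round trip that skips the rotation. This is a sketch with made-up sizes; without the multiplication in between, the result should be exactly the input.

```python
import torch

xq = torch.randn(2, 5, 4, 8)  # hypothetical (batch, seq_len, heads, dim)

xq_ = torch.view_as_complex(xq.reshape(*xq.shape[:-1], -1, 2))  # (2, 5, 4, 4), complex
as_real = torch.view_as_real(xq_)                               # (2, 5, 4, 4, 2), real again
restored = as_real.flatten(3)                                   # (2, 5, 4, 8)

print(xq_.shape, as_real.shape, restored.shape)
print(torch.equal(restored, xq))  # True: the pairs land back in their original slots
```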

--------

👉 Finally, because the rotation is computed in float32, the function uses `type_as` to cast the rotated embeddings back to the original data type of the input tensors before returning them.
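👉 The dtype handling can be seen in isolation with a half-precision input. In this sketch the multiplication by 2 is only a stand-in for the float32 rotation, and the sizes are made up.

```python
import torch

xq = torch.randn(1, 2, 1, 4).half()   # hypothetical half-precision input

work = xq.float() * 2.0               # stand-in for the rotation, computed in float32
out = work.type_as(xq)                # cast back to the dtype of the input

print(work.dtype, out.dtype)          # torch.float32 torch.float16
```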

🔗 Link to the source code in the Llama GitHub repository - https://github.com/facebookresearch/llama/blob/6c7fe276574e78057f917549435a2554000a876d/llama/model.py#L64