# The Transformer
In this lab scenario, you will implement causal attention for a Transformer Decoder model.
Transformer architecture was introduced in the [Attention Is All You Need](https://arxiv.org/abs/1706.03762)
and has dominated the field of language modeling.  
Here we will go through different parts of the transformer architecture and explain each of them briefly.


## Transformer Overview

### The input
Transformer decoder models (such as LLaMa 3.1 and Mistral) are popular text-processing models.   
One can distinguish two versions of such models: base and instruction-tuned.  
The base models are usually transformers trained on predicting the continuation of a given text (for each prefix they output a probability distribution over the next text fragments).  
In contrast, the instruction-tuned ones are base models that were additionally trained to follow instructions.  
The text is presented to the transformer as a sequence of tokens.   
Tokens are integers used to represent pieces of text.  
To be more precise to convert text to tokens we first prepare a dictionary of common text fragments.   
We usually want to have all possible letters in the dictionary so that all texts can be tokenized.   
We then assign to each text piece from the dictionary an integer and use the dictionary to convert text into a sequence of tokens (integers).  
The program that converts text into tokens is called a tokenizer.  

In this lab scenario, we will use OpenLLaMAv2 tokenizer and HuggingFace library to tokenize text.   
HuggingFace contains a vast collection of transformer model weights and implementations along with training and inference code.  

In [1]:
!pip3 install transformers==4.47.0
!pip3 install sentencepiece
!pip3 install accelerate

Collecting transformers==4.47.0
  Downloading transformers-4.47.0-py3-none-any.whl.metadata (43 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/43.5 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m43.5/43.5 kB[0m [31m3.4 MB/s[0m eta [36m0:00:00[0m
Collecting tokenizers<0.22,>=0.21 (from transformers==4.47.0)
  Downloading tokenizers-0.21.0-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (6.7 kB)
Downloading transformers-4.47.0-py3-none-any.whl (10.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m10.1/10.1 MB[0m [31m55.7 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading tokenizers-0.21.0-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (3.0 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.0/3.0 MB[0m [31m58.5 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: tokenizers, transformers
  Attempting uninstall: tokenizers
   

In [2]:
import torch
from transformers import LlamaTokenizer, LlamaForCausalLM
from tqdm import tqdm

tokenizer = LlamaTokenizer.from_pretrained("openlm-research/open_llama_3b_v2")


text = "This is an example text that we will tokenize"
tokens_mask = tokenizer(text)
print(tokens_mask)

detokenized = tokenizer.batch_decode(tokens_mask["input_ids"])
print(detokenized)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/593 [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/512k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/330 [00:00<?, ?B/s]

You are using the default legacy behaviour of the <class 'transformers.models.llama.tokenization_llama.LlamaTokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565 - if you loaded a llama tokenizer from a GGUF file you can ignore this message


{'input_ids': [1, 660, 325, 371, 1938, 1880, 347, 389, 477, 8206, 753], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}
['<s>', 'This', 'is', 'an', 'example', 'text', 'that', 'we', 'will', 'token', 'ize']


After tokenization the the HuggingFace tokenizer returns a sequence of tokens (`input_ids`) and information on whether the model should look at the ith element of the input (`attention_mask`).  
The other part is useful when we want to tokenize several sequences into one batch of elements of the same length. Then the attention mask can be used to hide the padding from the model.  
Consider the example below. Note how the second text is padded to match the length of the first one.

In [3]:
text = ["This is an example text that we will tokenize", "Hello"]
# We set the padding token to be the same as the end-of-sequence token (eos)
# eos token (</s> in this case) is used to mark the end of the sequence in training and can also be used by a model to finish its response
# bos token (here <s>) is used to mark the beginning of the input
tokenizer.pad_token = tokenizer.eos_token

tokens_mask = tokenizer(text, return_tensors="pt", padding=True, truncation=False)
print(tokens_mask)

detokenized = tokenizer.batch_decode(tokens_mask["input_ids"])
print(detokenized)

{'input_ids': tensor([[   1,  660,  325,  371, 1938, 1880,  347,  389,  477, 8206,  753],
        [   1, 8479,    2,    2,    2,    2,    2,    2,    2,    2,    2]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0]])}
['<s>This is an example text that we will tokenize', '<s> Hello</s></s></s></s></s></s></s></s></s>']


### The encoding
The input to the model is a batch of token sequences of the following shape   
`(batch, seq_len)`
where
* `batch` is the size of the batch
* `seq_len` is the length of the longest input sequence inside the batch (attention mask is used to handle the cases when sequences have different lengths)

Initially, the model assigns to each element of each sequence an embedding vector.  
To be more precise inside the model there is a matrix of shape `(num_dictionary_elements, hidden_size)` that is used to assign to each token from the dictionary a vector of length `hidden_size`.  
After the encoding step, we pass a tensor of shape `(batch, seq_len, hidden_size)` through the remaining layers of the model.

### Transformer layer
The internal parts of the transformer are grouped into transformer layers.  
Usually, each layer consists of layer norm, attention, layer norm, and a feed-forward layer.  
To be more precise the computation progresses roughly as presented below:
```python3
def transformer_layer(input, layer_norm_attn, attention, layer_norm_ff, feed_forward):
 x = attention(layer_norm_attn(x)) + x
 x = feed_forward(layer_norm_ff(x)) + x
 return x
```
Where:  
* `feed_forward` - This can be a linear projection followed by activation and another linear projection. For an input of shape `(batch, seq_len, hidden_size)` it treats the first two dimensions as batch and operates on the `hidden_size` dimension.
* `layer_norm` - Replaced by [RMSNorm](https://pytorch.org/docs/stable/generated/torch.nn.modules.normalization.RMSNorm.html) in LLaMa models. Similarly as `feed_forward` it operates only on the `hidden_size` dimension, treating other dimensions as a batch.
* `attention` - Causal multi-head attention that you will implement in further parts of this notebook. Let `t` be an input tensor of shape `(batch, seq_len, hidden_size)`. Attention will output a tensor `d` of the same shape with the following property:
 Calculation of `d[b, s, h]` depends only on values from `t[b, s', h']` such that  `s' <= s`. In other words, calculation is done independently per batch entry and dependency is causal (the past can influence the future but the future cannot influence the past).

### LM head
In the end, a linear projection is used to create weights for each element of the input dictionary.
To be more precise we take a tensor of shape `(batch, seq_len, hidden_dim)` and use norm + a linear projection from `hidden_dim` to `vocab_size`, in order to change it into tensor of shape `(batch, seq_len, vocab_size)`.  
Then we apply softmax over the last dimension (`vocab_size`) to get probability distribution over the next token in the sequence given the previous tokens.  
We can do this as all operations in our model were either done independently for each element (`layer_norm`, `feed_forward`, ...) or were causal (`attention`).  
The training loss of our model will be cross entropy over the next token prediction.  
That is we input a batch of token sequences into our model, the model outputs for each input token the probability distribution over the next token, and as we know the next token we use it as a ground truth label for the calculation of the model loss.

### Example
Below we show the steps described above using OpenLLaMAv2 3B.

In [4]:
## Input tokenization

tokenizer = LlamaTokenizer.from_pretrained("openlm-research/open_llama_3b_v2")
text = ["2 + 7 = "]

tokens_mask = tokenizer(text, return_tensors="pt")
tokens = tokens_mask["input_ids"]
attention_mask = tokens_mask["attention_mask"]
print(tokens_mask)

{'input_ids': tensor([[    1, 29500, 29536,   835, 29500, 29574,   419, 29500]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1]])}


In [5]:
## load model from huggingface

device = (
    torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
)

# takes around 6.85GB in bf16
model = LlamaForCausalLM.from_pretrained(
    "openlm-research/open_llama_3b_v2", torch_dtype=torch.bfloat16, device_map=device
)

# we disable gradient calculatoin as otherwise the memory usage can explode
for p in model.parameters():
    p.requires_grad = False

config.json:   0%|          | 0.00/506 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/6.85G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/137 [00:00<?, ?B/s]

In [6]:
## input encoding
embedded_tokens = model.model.embed_tokens(tokens.to(device))
print(f"tokens.shape {tokens.shape} embedded_tokens.shape {embedded_tokens.shape}")

tokens.shape torch.Size([1, 8]) embedded_tokens.shape torch.Size([1, 8, 3200])


In [7]:
## passing through the layers of the model

hidden_states = embedded_tokens
batch, seq_len, hidden_size = hidden_states.shape

# additional tensor to tell the model positions of each token
position_ids = torch.arange(seq_len, device=hidden_states.device)[None, ...]

# mask used to make attention causal
causal_mask = model.model._update_causal_mask(
    attention_mask, hidden_states, position_ids, None, False
)

# additional encoding of positions within the sequence, used by attention
position_embeddings = model.model.rotary_emb(hidden_states, position_ids)


for l in tqdm(model.model.layers):
    hidden_states = l(
        hidden_states,
        attention_mask=causal_mask,
        position_ids=position_ids,
        past_key_value=None,  # can be used to continue generation
        output_attentions=False,
        use_cache=False,
        cache_position=position_ids,
        position_embeddings=position_embeddings,
    )[0]

# apply norm before final linear
hidden_states = model.model.norm(hidden_states)
hidden_states = model.lm_head(hidden_states)
hidden_states = torch.nn.functional.softmax(hidden_states, dim=-1)
next_token = torch.argmax(hidden_states[0, -1])
print(next_token)
print(tokenizer.decode(next_token))

100%|██████████| 26/26 [00:00<00:00, 43.87it/s]


tensor(29567, device='cuda:0')
9


In [8]:
# Using HuggingFace Generate


text = "The largest animal on earth is"
tokens_mask = tokenizer(text, return_tensors="pt")
output = model.generate(
    inputs=tokens_mask["input_ids"].to(device),
    max_new_tokens=8,
    num_beams=1,
    do_sample=True, # sample from the distribution created by softmax
    temperature=0.7, # divide pre softmax score by this value
    top_p=0.9 # cut out improbable tokens from sampling
)

print(tokenizer.batch_decode(output))


['<s>The largest animal on earth is the elephant. The average elephant is ']


## Causal Attention Implementation

Your task is to finish the implementation of the attention mechanism below. In case of problems, you can refer to the original implementation that can be found [here](https://github.com/huggingface/transformers/blob/7f95372c6267d3163fd2aa74aeff9d84ddb6cc35/src/transformers/models/llama/modeling_llama.py#L258).
To be more precise. You are given query and key tensors with positional encoding already applied. You also get the value tensors.
Each of those tensors is of shape `(batch, seq_len, num_heads, head_size)`.  
Your task is to compute for each head a scaled dot product between each query and each key that is either at the same position as the query or precedes the query in the sequence.
To be more precise you want to calculate a tensor `a` of shape `(batch, num_heads, seq_len, seq_len)` where  
$$
    a[b, h, q, k]=
\begin{cases}
    \sum_{d}{\mathrm{query}[b, q, h, d] * \mathrm{key}[b, k, h, d]} / \sqrt{\mathrm{head\_size}}, & if k \leq q\\
     -\mathrm{large\_number},              & \text{otherwise}
\end{cases}
$$

Then you should calculate the softmax over the last dimension of `a` creating `p`.  

$$p = \mathrm{SoftMax}(a)$$
Then you should calculate
$$v[b, q, h, d] = \sum_{k}{a[b, h, q, k] * \mathrm{value}[b, k, h, d]}$$  
That is for each query you should gather the `value`s using the probability distribution defined by `p`.  
In the end, you should reshape `v` to `(batch, seq_len, num_heads * head_size)` and apply a linear projection `output_projection`.  

As you do not get the attention mask you can assume that it consists of ones only and that the attention is causal.
For simplicity, you can also assume that the number of queries is equal to the number of keys.  
This is not always true, for example when we run the generate from HuggingFace transformers library, then instead of computing the whole attention each time, the keys for previous tokens are cached and we create queries only for new tokens.



In [9]:
def attention_forward(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    output_projection: torch.nn.Linear,
) -> torch.Tensor:
    batch, q_seq_len, num_heads, head_dim = query.shape
    batch, k_seq_len, num_heads, head_dim = key.shape

    assert value.shape == key.shape

    assert q_seq_len <= k_seq_len
    assert query.shape[0] == key.shape[0]
    assert query.shape[2:] == key.shape[2:]

    # TODO {
    # Einsum
    score = torch.einsum('blhd,bkhd->bhlk', query, key) / (head_dim ** 0.5)
    mask = torch.triu(torch.ones(q_seq_len, k_seq_len, device = device), diagonal=k_seq_len - q_seq_len).bool()
    A_mask = score.masked_fill_(mask, float("-inf"))
    A = torch.nn.functional.softmax(A_mask, dim=-1)
    output = torch.einsum('bhlk,bkhd -> blhd', A, value)
    real_output = output.reshape(batch, q_seq_len, num_heads * head_dim)
    v = output_projection(real_output)


    #a_scores = torch.einsum('blhd,bkhd->bhlk', query, key) / (head_dim ** 0.5)


    # TODO }
    assert v.shape == (batch, q_seq_len, num_heads * head_dim)
    return v

### Integration with OpenLLaMA
The code below integrades your solution from above with OpenLLaMA.

In [10]:
from typing import Optional, Tuple


# Copied from  https://github.com/huggingface/transformers/blob/7f95372c6267d3163fd2aa74aeff9d84ddb6cc35/src/transformers/models/llama/modeling_llama.py
def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


# Copied from  https://github.com/huggingface/transformers/blob/7f95372c6267d3163fd2aa74aeff9d84ddb6cc35/src/transformers/models/llama/modeling_llama.py
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`, *optional*):
            Deprecated and unused.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(
        batch, num_key_value_heads, n_rep, slen, head_dim
    )
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


# modified version of https://github.com/huggingface/transformers/blob/7f95372c6267d3163fd2aa74aeff9d84ddb6cc35/src/transformers/models/llama/modeling_llama.py
def custom_attention_forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value=None,
    output_attentions: bool = False,
    use_cache: bool = False,
    cache_position: Optional[torch.LongTensor] = None,
    position_embeddings: Optional[
        Tuple[torch.Tensor, torch.Tensor]
    ] = None,  # will become mandatory in v4.46
    **kwargs,
):
    bsz, q_len, _ = hidden_states.size()

    query_states = self.q_proj(hidden_states)
    key_states = self.k_proj(hidden_states)
    value_states = self.v_proj(hidden_states)

    query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
    key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
    value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
    '''
    query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)
    key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)
    value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)
    '''

    if position_embeddings is None:
        cos, sin = self.rotary_emb(value_states, position_ids)
    else:
        cos, sin = position_embeddings
    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

    if past_key_value is not None:
        # sin and cos are specific to RoPE models; cache_position needed for the static cache
        cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
        key_states, value_states = past_key_value.update(
            key_states, value_states, self.layer_idx, cache_kwargs
        )

    key_states = repeat_kv(key_states, self.num_key_value_groups)
    value_states = repeat_kv(value_states, self.num_key_value_groups)

    # this is not memory optimal, can you tell why
    query_states = query_states.transpose(1, 2)
    key_states = key_states.transpose(1, 2)
    value_states = value_states.transpose(1, 2)

    attn_output = attention_forward(
        query=query_states,
        key=key_states,
        value=value_states,
        output_projection=self.o_proj,
    )

    return attn_output, None, past_key_value

### Testing
You can briefly test your solution below.

In [11]:
from functools import partial

model = LlamaForCausalLM.from_pretrained(
    "openlm-research/open_llama_3b_v2", torch_dtype=torch.bfloat16, device_map=device
)


for p in model.parameters():
    p.requires_grad = False

for l in model.model.layers:
    l.self_attn.forward = partial(custom_attention_forward, self=l.self_attn)


text = ["2 + 7 = "]

tokens_mask = tokenizer(text, return_tensors="pt")
tokens = tokens_mask["input_ids"]
attention_mask = tokens_mask["attention_mask"]


output = model(input_ids=tokens.to(device))
next_token = torch.argmax(output.logits[0, -1])
print(next_token)
decoded = tokenizer.decode(next_token)
print(f"Model answer: {decoded}")
assert decoded == "9"

tensor(0, device='cuda:0')
Model answer: <unk>


AssertionError: 

In [None]:
## If you have implemented the attention that can handle token by token generaion you can check your solution using the code below

text = "Solve x + 3 = 7"
tokens_mask = tokenizer(text, return_tensors="pt")
output = model.generate(
    inputs=tokens_mask["input_ids"].to(device),
    max_new_tokens=8,
    num_beams=1,
    do_sample=False,
)

print(tokenizer.batch_decode(output))