# LLMs for everyone

<img src="https://www.marktechpost.com/wp-content/uploads/2023/05/Blog-Banner-3.jpg" width="60%" />

<a href="https://colab.research.google.com/github/deep-learning-indaba/indaba-pracs-2023/blob/main/practicals/large_language_models.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

© Deep Learning Indaba 2023. Apache License 2.0.

**Authors: Ruan van der Merwe, Marianne Monteiro, Everlyn Asiko Chimoto**

**Reviewers: Natasha Latysheva, Amrit Purshotam, Tom Makkink**

**Introduction:**

Welcome to "LLMs for Everyone," a practical exploration into the captivating world of Large Language Models (LLMs)! This entire introductory block of text was crafted solely by ChatGPT, showcasing the remarkable capabilities of these models. Throughout this tutorial, we will delve into the underlying fundamentals of transformers, the powerful technology that drives models like GPT, and learn how to fine-tune and train our very own Language Models. Let's embark on this exciting journey of understanding and creating LLMs, and discover how such impressive AI text generation is made possible! 🚀📚

**Topics:**

Content: [<font color='green'>Attention Mechanism</font>, <font color='green'>Transformer Architecture</font>, <font color='blue'>LoRA</font>]

Level: <font color='orange'>Beginner</font>, <font color='green'>Intermediate</font>, <font color='blue'>Advanced</font>

**Aims/Learning Objectives:**

* Understand the idea behind [Attention](https://arxiv.org/abs/1706.03762) and why it is used.
* Present and describe the fundamental building blocks of the [Transformer Architecture](https://arxiv.org/abs/1706.03762) along with an intuition on such an architecture design.
* Present and explain the intuition behind [LoRA](https://arxiv.org/abs/2106.09685) along with a simplified demo of how to finetune an LLM using LoRA on a single GPU using [Hugging Face](https://huggingface.co).

**Prerequisites:**

* Introductory knowledge of Deep Learning.
* Introductory knowledge of NLP.
* Introductory knowledge of sequence to sequence models.
* Basic understanding of linear algebra.

**Outline:**

>[LLMs for everyone](#scrollTo=m2s4kN_QPQVe)

>>[Installations, Imports and Helper Functions](#scrollTo=6EqhIg1odqg0)

>>[Let's kick things off with a Hugging Face Demo! Beginner](#scrollTo=4zu5cg-YG4XU)

>>>[Hugging Face](#scrollTo=AwjIIipOG4fz)

>>>[Time for a demo! ⏰⚡ Load Hugging Face model and run sample](#scrollTo=eq46TV_0G4f0)

>>[1. Attention](#scrollTo=-ZUp8i37dFbU)

>>>[Intuition - Beginner](#scrollTo=ygdi884ugGcu)

>>>[Sequence to sequence attention mechanisms - Intermediate](#scrollTo=aQfqM1EJyDXI)

>>>[Self-attention to Multihead Attention - Intermediate](#scrollTo=J-MU6rrny8Nj)

>>>>[Self-attention](#scrollTo=0AFUEFZGzCTv)

>>>>>[Queries, keys and values](#scrollTo=pwOIMtdZzdTf)

>>>>>[Scaled dot product attention](#scrollTo=OhGZHFsHz_Qp)

>>>>>[Masked attention](#scrollTo=D7B-AgO80gIt)

>>>>[Multihead Attention - Advanced](#scrollTo=hNHklaSV1Tej)

>>[2. Building your own LLM](#scrollTo=e9NW58_3hAg2)

>>>[2.1 High-level overview Beginner](#scrollTo=bA_2coZvhAg3)

>>>[2.2 Tokenization + Positional encoding Beginner](#scrollTo=fbTsk0MdhAhC)

>>>>[2.2.1 Tokenization](#scrollTo=DehUpfym_RF8)

>>>>[2.2.2 Positional encodings](#scrollTo=639s7Zuk_RF9)

>>>>>[Sine and cosine functions](#scrollTo=rklY-aL-_RF9)

>>>[2.3 Transformer block   Intermediate](#scrollTo=SdNPg0pnhAhG)

>>>>[2.3.1 Feed Forward Network (FFN) / Multilayer perceptron (MLP) Beginner](#scrollTo=kTURbfr__RF-)

>>>>[2.3.2 Add and Norm block Beginner](#scrollTo=Sts5Vr4i_RF-)

>>>[2.4 Building the Transformer Decoder / LLM Intermediate](#scrollTo=91dXd29b_RF_)

>>>[2.5 Training your LLM](#scrollTo=wmt3tp38G90A)

>>>>[2.5.1 Training objective Intermediate](#scrollTo=agLIpsoh_RGA)

>>>>[2.5.2 Training models Advanced](#scrollTo=4CSfvGj__RGA)

>>>>[2.5.3 Inspecting the trained LLM Beginner](#scrollTo=pGv9c2AFmF4V)

>>[Efficiently Finetuning LLMs with Hugging Face](#scrollTo=C4hKnTFbHtdM)

>>>[3.1 Adapter and Fine-Tuning methods  Intermediate](#scrollTo=KoTvhvap_RGC)

>>>>[3.1.1 Prefix tuning](#scrollTo=znctvjrE_RGC)

>>>>[3.2.1 Adapter Methods](#scrollTo=U4sLxSol_RGD)

>>>[3.2 LoRA Beginner, Intermediate, Advanced](#scrollTo=MoBc08xY_RGD)

>>>>[3.2.1 LoRA implementation Advanced](#scrollTo=ri1FGEh6_RGE)

>>>>[3.2.3 🤗 Deep dive into LoRA with Hugging Face! 🤗 Beginner](#scrollTo=mpCz5otl_RGE)

>>>>>[Gathering and processing data (optional)](#scrollTo=NwZbrcFY_RGF)

>>>>>[Finetune a model with LoRA](#scrollTo=qUAjpRx3_RGG)

>>>[⏰⚡ Demo Time with our trained model🚀😰](#scrollTo=JkWj5bxd_RGG)

>>[Conclusion](#scrollTo=fV3YG7QOZD-B)

>[Feedback](#scrollTo=o1ndpYE50BpG)




**Before you start:**

For this practical, you will need to use a GPU to speed up training. To do this, go to the "Runtime" menu in Colab, select "Change runtime type" and then in the popup menu, choose "GPU" in the "Hardware accelerator" box.


**Suggested experience level in this topic:**

| Level | Experience |
| --- | --- |
| `Beginner` | It is my first time being introduced to this work. |
| `Intermediate` | I have done some basic courses/intros on this topic. |
| `Advanced` | I work in this area/topic daily. |

In [None]:
# @title **Paths to follow:** What is your level of experience in the topics presented in this notebook? (Run Cell)
experience = "advanced" #@param ["beginner", "intermediate", "advanced"]
sections_to_follow=""


if experience == "beginner": sections_to_follow = """we recommend you to not attempt to do every coding task but instead, skip through to every section and ensure you interact with the LoRA finetuned LLM presented in the last section as well as with the pretrained LLM to get a practical understanding of how these models behave"""

elif experience == "intermediate": sections_to_follow = """we recommend you go through every section in this notebook and try the coding tasks tagged as beginner or intermediate. If you get stuck on the code ask a tutor for help or move on to better use the time of the practical"""

elif experience == "advanced": sections_to_follow = """we recommend you go through every section and try every coding task until you get it to work"""


print(f"Based on your experience, {sections_to_follow}.\nNote: this is just a guideline, feel free to explore the colab as you'd like if you feel comfortable!")

## Installations, Imports and Helper Functions

In [None]:
# @title Install and import required packages. (Run Cell)
# The -q flag keeps pip's output short.

!pip install -q transformers[torch] datasets
!pip install -q seaborn umap-learn
!pip install -q livelossplot
!pip install -q -U accelerate
!pip install -q peft

# Python utils
!pip install -q ipdb      # debugging.
!pip install -q colorama  # print colors :).

import os
import math
import urllib.request

# https://stackoverflow.com/questions/68340858/in-google-colab-is-there-a-programing-way-to-check-which-runtime-like-gpu-or-tpu
# Use .get() so the check does not raise a KeyError when COLAB_GPU is not set.
if int(os.environ.get("COLAB_GPU", "0") or 0) > 0:
    print("A GPU is connected.")
elif "COLAB_TPU_ADDR" in os.environ and os.environ["COLAB_TPU_ADDR"]:
    print("A TPU is connected.")
    import jax.tools.colab_tpu

    jax.tools.colab_tpu.setup_tpu()
else:
    print("Only CPU accelerator is connected.")

# https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html#gpu-memory-allocation
# Avoid GPU memory allocation to be done by JAX.
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = "false"

import chex
import flax
import flax.linen as nn
import jax
import jax.numpy as jnp
from jax import grad, jit, vmap
import optax

import transformers
from transformers import pipeline, AutoTokenizer, AutoModel
import datasets
import peft

from PIL import Image
from livelossplot import PlotLosses

# Utils.
import colorama

import torch
import torchvision

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

import itertools
import random

# download images used in notebook
urllib.request.urlretrieve(
    "https://images.unsplash.com/photo-1529778873920-4da4926a72c2?ixlib=rb-1.2.1&ixid=MnwxMjA3fDB8MHxzZWFyY2h8MXx8Y3V0ZSUyMGNhdHxlbnwwfHwwfHw%3D&w=1000&q=80",
    "cat.png",
)

import copy

import gensim
from nltk.data import find
import nltk

nltk.download("word2vec_sample")

import huggingface_hub
import ipywidgets as widgets
from IPython.display import display

In [None]:
# @title Helper Plotting Functions. (Run Cell)
def plot_position_encodings(P, max_tokens, d_model):
    """Function that takes in a position encoding matrix and plots it."""

    plt.figure(figsize=(20, np.min([8, max_tokens])))
    im = plt.imshow(P, aspect="auto", cmap="Blues_r")
    # The colorbar takes its colormap from the image it is attached to.
    plt.colorbar(im)

    if d_model <= 64:
        plt.xticks(range(d_model))
    if max_tokens <= 32:
        plt.yticks(range(max_tokens))
    plt.xlabel("Embedding index")
    plt.ylabel("Position index")
    plt.show()


def plot_image_patches(patches):
    """Function that takes in a list of patches and plots them."""
    axes = []
    fig = plt.figure(figsize=(25, 25))
    for a in range(patches.shape[1]):
        axes.append(fig.add_subplot(1, patches.shape[1], a + 1))
        plt.imshow(patches[0][a])
    fig.tight_layout()
    plt.show()


def plot_projected_embeddings(embeddings, labels):
    """Function that takes in a list of embeddings, projects them onto a 2D space using UMAP and plots them."""
    import umap
    import seaborn as sns

    projected_embeddings = umap.UMAP().fit_transform(embeddings)

    plt.figure(figsize=(15, 8))
    plt.title("Projected text embeddings")
    sns.scatterplot(
        x=projected_embeddings[:, 0], y=projected_embeddings[:, 1], hue=labels
    )
    plt.show()


def plot_attention_weight_matrix(weight_matrix, x_ticks, y_ticks):
    """Function that takes in a weight matrix and plots it with custom axis ticks"""
    plt.figure(figsize=(15, 7))
    ax = sns.heatmap(weight_matrix, cmap="Blues")
    plt.xticks(np.arange(weight_matrix.shape[1]) + 0.5, x_ticks)
    plt.yticks(np.arange(weight_matrix.shape[0]) + 0.5, y_ticks)
    plt.title("Attention matrix")
    plt.xlabel("Attention score")
    plt.show()

In [None]:
# @title Helper Text Processing Functions. (Run Cell)

def get_word2vec_embedding(words):
    """
    Function that takes in a list of words and returns a list of their embeddings,
    based on a pretrained word2vec encoder.
    """
    word2vec_sample = str(find("models/word2vec_sample/pruned.word2vec.txt"))
    model = gensim.models.KeyedVectors.load_word2vec_format(
        word2vec_sample, binary=False
    )

    output = []
    words_pass = []
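    for word in words:
        try:
            # Index the KeyedVectors directly; this works across gensim versions.
            output.append(jnp.array(model[word]))
            words_pass.append(word)
        except KeyError:
            # Skip words that are not in the word2vec vocabulary.
            pass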

    embeddings = jnp.array(output)
    del model  # free up space again
    return embeddings, words_pass


def remove_punctuation(text):
    """Function that takes in a string and removes all punctuation."""
    import re

    text = re.sub(r"[^\w\s]", "", text)
    return text

def print_sample(prompt: str, sample: str):
  print(colorama.Fore.MAGENTA + prompt, end="")
  print(colorama.Fore.BLUE + sample)
  print(colorama.Fore.RESET)

## Let's kick things off with a Hugging Face Demo! <font color='orange'>Beginner</font>

### Hugging Face


<img src="https://huggingface.co/datasets/huggingface/brand-assets/resolve/main/hf-logo.png" width="10%">


[Hugging Face](https://huggingface.co/) is a startup founded in 2016 that is, in their own words, "on a mission to democratize good machine learning, one commit at a time." Currently, they are a treasure trove of tools for working on and with Large Language Models (LLMs).

They have developed various open-source packages and allow users to easily interact with a large corpus of pretrained transformer models (across all modalities) and datasets to train or fine-tune pre-trained transformers. Their software is used widely in industry and research. For more details on them and usage, refer to [last year's transformer practical](https://colab.research.google.com/github/deep-learning-indaba/indaba-pracs-2022/blob/main/practicals/attention_and_transformers.ipynb#scrollTo=qFBw8kRx-4Mk).


In this colab we print prompts in <font color='HotPink'><b>pink</b></font> and samples generated from a model in <font color='blue'><b>blue</b></font>  like in the example below:

In [None]:
print_sample(prompt='My fake prompt', sample=' is awesome!')

### Time for a demo! ⏰⚡ Load Hugging Face model and run sample

Here we show how easy it is to sample from a model loaded from Hugging Face!

In this colab we pre-configured two options for models:
* `gpt-neo-125M`: 125M parameters (faster and uses less memory! We recommend trying this one out first! If you'd like to try `gpt2-medium` restart the colab kernel and change the model name in the cell below).
* `gpt2-medium`: 355M parameters

**Note**: What we do below can even work on models such as the 13-billion-parameter Llama model!

In [None]:
model_name = "EleutherAI/gpt-neo-125M" # @param ["gpt2-medium", "EleutherAI/gpt-neo-125M"]

test_prompt = 'What is love?' # @param {type: "string"}
generator = transformers.pipeline('text-generation', model=model_name)
generator(test_prompt, do_sample=True, min_length=20)

**Tip:** Try running the code above with different prompts, or with the same prompt more than once!

**Discussion:** Why do you think the generated text changes every time, even with the same prompt?

Let's implement our own `generator`-like function to make it easier to load different weights for the model and configure how generation is done. Just run the cells below 😀!

In [None]:
if 'gpt2' in model_name:
  tokenizer = transformers.GPT2Tokenizer.from_pretrained(model_name)
  model = transformers.GPT2LMHeadModel.from_pretrained(model_name)
elif model_name == "EleutherAI/gpt-neo-125M":
  tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
  model = transformers.AutoModelForCausalLM.from_pretrained(model_name)
else:
  raise NotImplementedError

if torch.cuda.is_available():
  model = model.to("cuda")

tokenizer.pad_token_id = tokenizer.eos_token_id

In [None]:
def run_sample(
    model,
    tokenizer,
    prompt: str,
    seed: int | None = None,
    temperature: float = 0.6,
    top_p: float = 0.9,
    max_new_tokens: int = 64,
) -> str:
    inputs = tokenizer(prompt, return_tensors="pt")

    input_ids = inputs["input_ids"]
    attention_mask = inputs["attention_mask"]
    input_ids = input_ids.to(model.device)
    attention_mask = attention_mask.to(model.device)

    generation_config = transformers.GenerationConfig(
      do_sample=True,
      temperature=temperature,
      top_p=top_p,
      pad_token_id=tokenizer.pad_token_id,
      top_k=0,
    )

    if seed is not None:
      torch.manual_seed(seed)

    generation_output = model.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        return_dict_in_generate=True,
        output_scores=True,
        max_new_tokens=max_new_tokens,
        generation_config=generation_config,
    )

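    # We assume a single sample is returned to make things simpler.
    assert len(generation_output.sequences) == 1
    output_sequence = generation_output.sequences[0]
    # Decode only the newly generated tokens, so we do not rely on the decoded
    # text reproducing the prompt character for character.
    new_tokens = output_sequence[input_ids.shape[-1]:]
    response = tokenizer.decode(new_tokens, skip_special_tokens=True).rstrip()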
    print_sample(prompt, response)
    return response

In [None]:
_ = run_sample(model, tokenizer, prompt="What is love?", seed=2)
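
The `temperature` and `top_p` arguments of `run_sample` control how the next token is drawn from the model's output scores. The block below is a minimal sketch of that idea in plain Python, not the Hugging Face implementation: the helper names (`softmax`, `top_p_filter`, `sample_next_token`) and the four made-up logits are assumptions for illustration only.

```python
import math
import random


def softmax(logits, temperature=1.0):
    """Turn logits into probabilities; a lower temperature sharpens the distribution."""
    scaled = [x / temperature for x in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def top_p_filter(probs, top_p):
    """Keep the smallest set of most likely tokens whose total probability reaches top_p."""
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    kept, cumulative = [], 0.0
    for i in order:
        kept.append(i)
        cumulative += probs[i]
        if cumulative >= top_p:
            break
    total = sum(probs[i] for i in kept)
    # Renormalise so the kept probabilities sum to one again.
    return {i: probs[i] / total for i in kept}


def sample_next_token(logits, temperature=0.6, top_p=0.9, rng=random):
    """Draw one token index using temperature scaling followed by top-p filtering."""
    probs = softmax(logits, temperature)
    candidates = top_p_filter(probs, top_p)
    tokens = list(candidates)
    weights = [candidates[t] for t in tokens]
    return rng.choices(tokens, weights=weights, k=1)[0]


toy_logits = [2.0, 1.0, 0.5, -1.0]  # hypothetical scores for a 4-token vocabulary
rng = random.Random(0)              # fixed seed, similar in spirit to the `seed` argument
samples = [sample_next_token(toy_logits, rng=rng) for _ in range(10)]
print(samples)
```

Because each token is drawn at random from the filtered distribution, repeated calls can give different continuations unless the random seed is fixed.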

By the end of this practical, we'll see how to **finetune** a transformer model in an **efficient way** with a state-of-the-art method called LoRA, so that it performs better at a specific task out of the box. An example task is generating song lyrics in the style of artists like The Beatles, Michael Jackson or even Tyler the Creator!

But before we do so, let's build an understanding of what Large Language Models are and which fundamental Machine Learning building blocks make this amazing technology possible. At the core of SoTA (state-of-the-art) Large Language Models are the Attention Mechanism and the Transformer Architecture. Let's have a look at these concepts in the next sections of this tutorial.

## **1. Attention**


The attention mechanism is inspired by how humans would look at an image or read a sentence.

Let us take the image of the dog in human clothes below (image and example [source](https://lilianweng.github.io/posts/2018-06-24-attention/)). When paying *attention* to the red blocks of pixels, we will say that the yellow block of pointy ears is something we expected (correlated) but that the grey blocks of human clothes are unexpected for us (uncorrelated). This is *based on what we have seen in the past* when looking at pictures of dogs, specifically one of a Shiba Inu.

<img src="https://drive.google.com/uc?export=view&id=1iEU7Cph2D2PCXp3YEHj30-EndhHAeB5T" alt="drawing" width="450"/>

Assume we want to identify the dog breed in this image. When we look at the red pixels, we tend to pay more *attention* to the pixels that are most relevant to them, which could be the ones in the yellow box. We almost completely ignore the snow in the background and the human clothing for this task. However, when we begin looking at the background in an attempt to identify what is in it, we will fade out the dog pixels because they are irrelevant to the current task.

The same thing happens when we read. In order to understand the entire sentence, we will learn to correlate and *attend to* certain words based on the context of the entire sentence.

<img src="https://drive.google.com/uc?export=view&id=1j23kcfu_c3wINU6DUvxzMYNmp4alhHc9" alt="drawing" width="350"/>

For instance, in the first sentence in the image above, when looking at the word "coding", we pay more attention to the words "Apple" and "computer" because we know that when we speak about coding, "Apple" is actually referring to the company. However, in the second sentence, we realise we should not consider "apple" when looking at "code" because, given the context of the rest of the sentence, we know that this apple is referring to an actual apple and not a computer.

We can build better models by developing mechanisms that mimic attention. Doing so enables our models to learn better representations of the input data by contextualising what they know about some parts of the input based on other parts. In the following sections, we will delve deeper into the mechanisms that enable us to train our deep learning models to attend to input data in the context of other input data.

### Intuition - <font color='orange'>Beginner</font>

Imagine attention as a mechanism that allows a neural network to focus more on certain parts of data. By doing this, the network can enhance its grasp of the problem it's working on, updating its understanding or representations accordingly.

One method for implementing attention involves representing each word (or even parts of a word) using different vectors [1]. These vectors are used to measure similarity, often through a process like calculating the dot product. This similarity becomes the "attention" measure, which then influences the update of our original vector. To put it simply, when two word representations are similar, they're likely relevant to each other. As a result, they impact each other's representations within our neural network.

To illustrate how the dot product can create meaningful attention weights, we'll use pre-trained [word2vec](https://jalammar.github.io/illustrated-word2vec/) embeddings. These word2vec embeddings are generated by a neural network that learned to create similar embeddings for words with similar meanings.

Even though we might not be sequentially processing contextual information, the attention matrix should still indicate which words are correlated and therefore should "attend" to each other.

[1] You can find more details about how this is done for LLMs in the "Building your own LLM" section.

**Code task** <font color='green'>Intermediate</font>: Complete the dot product attention function below.

In [None]:
def dot_product_attention(hidden_states, previous_state):
  """
  Calculate the dot product between the hidden states and previous states.

  Args:
    hidden_states: A tensor with shape [T_hidden, dm]
    previous_state: A tensor with shape [T_previous, dm]
  """

  scores = # FINISH ME
  w_n = # FINISH ME
  c_t = jnp.matmul(w_n, hidden_states)

  return w_n, c_t

In [None]:
# @title Run me to test your code

key = jax.random.PRNGKey(42)
x = jax.random.normal(key, [2, 2])
w_n, c_t = dot_product_attention(x, x)

w_n_correct = jnp.array([[0.9567678, 0.04323225], [0.00121029, 0.99878967]])
c_t_correct = jnp.array([[0.11144122, 0.95290256], [-1.5571996, -1.5321486]])

assert jnp.allclose(w_n_correct, w_n), "w_n is not calculated correctly"
assert jnp.allclose(c_t_correct, c_t), "c_t is not calculated correctly"

print("It seems correct. Look at the answer below to compare methods.")

In [None]:
# @title Answer to code task (Try not to peek until you've given it a good try!)
def dot_product_attention(hidden_states, previous_state):
  # [T,d]*[d,N] -> [T,N]
  scores = jnp.matmul(previous_state, hidden_states.T)
  w_n = jax.nn.softmax(scores)
  # [T,N]*[N,d] -> [T,d]
  c_t = jnp.matmul(w_n, hidden_states)
  return w_n, c_t

In [None]:
# When changing these words, note that a word that is not in the original
# training corpus will not be shown in the weight matrix plot.
words = ["king", "queen", "royalty", "food", "apple", "pear", "computers"]
word_embeddings, words = get_word2vec_embedding(words)
weights, _ = dot_product_attention(word_embeddings, word_embeddings)
plot_attention_weight_matrix(weights, words, words)

Looking at the matrix, we can see which words have similar meanings. The "royal" words have higher attention scores with each other than with the "food" words, which in turn attend mostly to one another. We also see that "computers" has very low attention scores with all the other words, which shows that it is not closely related to either the "royal" or the "food" words.

**Group task:**
  - Play with the word selections above. See if you can find word combinations whose attention values seem counter-intuitive. Think of possible explanations. Which sense of a word did the attention scores capture?
  - Ask your friend if they found examples.

**Note**: Dot product is only one of the ways to implement the scoring function for attention mechanisms. There is a more extensive list in this [blog post](https://lilianweng.github.io/posts/2018-06-24-attention/#summary) by Dr Lilian Weng.
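
To make the idea of interchangeable scoring functions concrete, here is a small sketch in plain Python that compares three common choices: the dot product, the scaled dot product and cosine similarity. The function names and the toy vectors are assumptions chosen for illustration; they are not part of the practical's code.

```python
import math


def dot_score(s, h):
    """Plain dot product between a query state s and a hidden state h."""
    return sum(a * b for a, b in zip(s, h))


def scaled_dot_score(s, h):
    """Dot product divided by the square root of the vector dimension."""
    return dot_score(s, h) / math.sqrt(len(h))


def cosine_score(s, h):
    """Dot product of the two vectors after normalising their lengths."""
    norm = math.sqrt(dot_score(s, s)) * math.sqrt(dot_score(h, h))
    return dot_score(s, h) / norm


query = [1.0, 0.0, 1.0, 0.0]
short_key = [1.0, 0.0, 1.0, 0.0]
long_key = [3.0, 0.0, 3.0, 0.0]  # same direction as short_key, three times the length

for name, fn in [("dot", dot_score), ("scaled dot", scaled_dot_score), ("cosine", cosine_score)]:
    print(name, fn(query, short_key), fn(query, long_key))
```

The dot product and its scaled version reward longer vectors, while cosine similarity only looks at direction, so the two keys receive the same cosine score.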

More resources:

[A basic encoder-decoder model for machine translation](https://www.youtube.com/watch?v=gHk2IWivt_8&list=PLmZlBIcArwhPHmHzyM_cZJQ8_v5paQJTV&index=1)

[Training and loss for encoder-decoder models](https://www.youtube.com/watch?v=aBZUTuT1Izs&list=PLmZlBIcArwhPHmHzyM_cZJQ8_v5paQJTV&index=2)

[Basic attention](https://www.youtube.com/watch?v=BSSoEtv5jvQ&list=PLmZlBIcArwhPHmHzyM_cZJQ8_v5paQJTV&index=6)

### Sequence to sequence attention mechanisms - <font color='green'>Intermediate</font>




The first attention mechanisms were used in sequence-to-sequence models. These models were usually RNN encoder and decoder structures. The input sequence was processed sequentially by an RNN, encoding the sequence in a single context vector, which is then fed into another RNN that generates a new sequence. Below is an example of this ([source](https://lilianweng.github.io/posts/2018-06-24-attention/)).


<img src="https://drive.google.com/uc?export=view&id=1FKfaArN1rsLjzVWaJGpMLEcxEshSLXd6" alt="drawing" width="600"/>

Because there is only one context vector, information is often lost for longer input sequences, as the encoder struggles to remember everything in a single vector. The attention mechanism introduced in [Bahdanau et al., 2015](https://arxiv.org/pdf/1409.0473.pdf) was proposed to solve this.

Here, instead of relying on one static context vector, which is also only used once in the decoding process, let us provide information on the entire input sequence at every decoding step using a dynamic context vector. By doing this, the decoder can access a larger "bank" of memory and attend to the input's required information based on the most recent decoder RNN state, $s_{t-1}$. This is shown below.

<img src="https://drive.google.com/uc?export=view&id=1fB5KObXcKo5x35xlIDIcjHTq1q75ejIB" alt="drawing" width="600"/>

In deep learning, attention can be interpreted as a vector of "importance." To predict or infer one element, such as a pixel in an image or a word in a sentence, we estimate how strongly it is correlated with, or "attends to," other elements using the attention vector/weights. These attention weights are then used to generate a new weighted sum of the remaining elements, which represents the target [(source)](https://lilianweng.github.io/posts/2018-06-24-attention/).


This, usually, consists of two steps for each decoding step $t$:

1. Calculate the score (importance) for each $h_n$, given $s_{t-1}$ and generate an attention vector, $w_{n}$.
  - $\text{score} = a(s_{t-1}, h_{n})$, where $a$ can be any differentiable function
  - $w_{n} = \frac{\exp \left\{a\left(s_{t-1}, h_{n}\right)\right\}}{\sum_{j=1}^{N} \exp \left\{a\left(s_{t-1}, h_{j}\right)\right\}}$, where we use the softmax function to generate relative attention weights
2. Generate the final context vector, $c_t$
  - $c_t=\sum_{n=1}^{N} w_n h_{n}$

The new decoder state is then computed from the context vector and the previous state, where $f$ can again be any combination method.

$s_{t} = f\left ( c_t, s_{t-1} \right)$

In Bahdanau et al., 2015, the score function $a(s_{t-1}, h_{n})$ was a small learned feedforward network (often called additive attention), and $f$ was the decoder's recurrent update, which also receives the previously generated token. A simpler choice for $a$ is the dot product, which we build on here.

In dot product attention, the score is given by

$a(s_{t-1}, h_n)=s_{t-1} h_n^\top$
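
The two steps above, with the dot product as the score, can be sketched for a single decoding step as follows. This is a plain Python illustration under assumed toy values: the three encoder states and the decoder state are made up, and `attention_step` is a hypothetical helper name.

```python
import math


def softmax(xs):
    """Normalise scores into weights that sum to one."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def attention_step(s_prev, encoder_states):
    """Compute attention weights w_n and the context vector c_t for one decoding step."""
    # Step 1: score every encoder state h_n against the previous decoder state.
    scores = [sum(a * b for a, b in zip(s_prev, h)) for h in encoder_states]
    weights = softmax(scores)
    # Step 2: the context vector is the weighted sum of the encoder states.
    dim = len(encoder_states[0])
    context = [sum(w * h[d] for w, h in zip(weights, encoder_states)) for d in range(dim)]
    return weights, context


encoder_states = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # toy h_1, h_2, h_3
s_prev = [2.0, 0.0]                                    # toy decoder state s_{t-1}

weights, context = attention_step(s_prev, encoder_states)
print("weights:", [round(w, 3) for w in weights])
print("context:", [round(c, 3) for c in context])
```

The first and third encoder states align with the decoder state, so they receive most of the weight and dominate the context vector.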

This is the score we used in the intuition section above with pretrained [word2vec](https://jalammar.github.io/illustrated-word2vec/) embeddings, which are generated by an encoder network trained to produce similar embeddings for words with similar meanings.

Even though those words were not processed sequentially with context, the attention matrix indicated which words are correlated and would thus attend to each other.


### Self-attention to Multihead Attention - <font color='green'>Intermediate</font>



Self-attention and multi-head attention (MHA) are the core building blocks of the transformer architecture. We will build up the intuition and implementation here in detail. Then, in the **Building your own LLM** section, you will see how this mechanism is used to build an attention-only model.


Going forward in this section, we will represent a sentence by splitting it into a list of words and then using the word2vec model from above to encode each word. In the **Building your own LLM** section, we will dive deeper into how we transform an input into a sequence of vectors.

In [None]:
def embed_sentence(sentence):
    # Embed a sentence using word2vec; for example use cases only.
    sentence = remove_punctuation(sentence)
    words = sentence.split()
    word_vector_sequence, words = get_word2vec_embedding(words)
    return jnp.expand_dims(word_vector_sequence, axis=0), words

#### Self-attention

Self-attention is an attention mechanism where each vector of a given input sequence attends to the entire sequence. To gain an intuition for why self-attention is important, let us think about the following sentence (example taken from [source](https://jalammar.github.io/illustrated-transformer/)):

`"The animal didn't cross the street because it was too tired."`

A simple question about this sentence is: what does the word "it" refer to? Even though it might look simple, it can be tough for an algorithm to learn this. This is where self-attention comes in, as it can learn an attention matrix for the word "it" where a large weight is assigned to the word "animal".

Self-attention also allows the model to learn how to interpret words with the same embeddings, such as apple, which can be a company or food, depending on the context. This is very similar to the hidden state found within an RNN, but this process, as you will see, allows the model to attend over the entire sequence in parallel, allowing longer sequences to be utilised.

Self-attention consists of three concepts:

- Queries, keys and values
- Scaled dot product attention
- Masks

##### **Queries, keys and values**

Typically all attention mechanisms can be written in terms of `key-value` pairs and `queries` to calculate the attention matrix and new context vector.

To gain intuition, one can interpret the `query` vector as containing the information we are interested in. It is compared against the `keys` (each of which is paired with a `value`) to determine which `values` we should attend to. The similarity between the `queries` and `keys` thus gives us our attention scores, and those scores determine how much each `value` contributes to the output. Or as [Lena Voita](https://lena-voita.github.io/nlp_course/seq2seq_and_attention.html) puts it:

- Query: asking for information
- Key: saying that it has some information
- Value: giving the information
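
One way to picture this is as a "soft" dictionary lookup: instead of returning the single value whose key matches exactly, we return a blend of all values, weighted by how well each key matches the query. The sketch below is an illustration in plain Python with made-up two-dimensional keys and values; `soft_lookup` is a hypothetical helper, not a function from this practical.

```python
import math

# Toy key and value vectors for two words (hypothetical numbers).
keys = {"animal": [1.0, 0.0], "street": [0.0, 1.0]}
values = {"animal": [10.0, 0.0], "street": [0.0, 10.0]}


def soft_lookup(query, keys, values):
    """Return per-key attention weights and the weighted blend of the values."""
    names = list(keys)
    # Similarity between the query and every key (dot product).
    scores = [sum(q * k for q, k in zip(query, keys[n])) for n in names]
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    weights = [e / total for e in exps]
    dim = len(values[names[0]])
    # Blend the values according to the attention weights.
    output = [sum(w * values[n][d] for w, n in zip(weights, names)) for d in range(dim)]
    return dict(zip(names, weights)), output


query = [4.0, 0.0]  # a query that "asks for" something close to the "animal" key
weights, output = soft_lookup(query, keys, values)
print(weights)
print(output)
```

The query matches the "animal" key much better than the "street" key, so the output is dominated by the "animal" value.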

In transformer architectures, we use learnable weight matrices, represented as $W_Q,W_K,W_V$, to project each sequence vector to unique $q$, $k$, and $v$ vectors.

<img src="https://drive.google.com/uc?export=view&id=1-96YjPxhcqW6FczUYwErGXHp6YpoLltq" alt="drawing" width="600"/>

You will notice that the vectors $q,k,v$ are smaller in size than the input vectors. This will be covered at a later stage, but just know that it is a design choice for transformers and not a requirement for attention to work.

This process can also be parallelised, as the input sequence can be represented as a matrix $X$ with one sequence vector per row, which can be transformed into query, key, and value matrices $Q$, $K$, and $V$ respectively:

$Q=XW_Q, \quad K=XW_K, \quad V=XW_V$

Below we show the code that creates three linear layers, which project the input data to the $Q,K,V$ matrices, where the output size can be adjusted.

In [None]:
class SequenceToQKV(nn.Module):
  output_size: int

  @nn.compact
  def __call__(self, X):
    initializer = nn.initializers.variance_scaling(scale=0.5, mode="fan_in", distribution="truncated_normal")

    # this can also be one layer, how do you think you would do it?
    q_layer = nn.Dense(self.output_size, kernel_init=initializer)
    k_layer = nn.Dense(self.output_size, kernel_init=initializer)
    v_layer = nn.Dense(self.output_size, kernel_init=initializer)

    Q = q_layer(X)
    K = k_layer(X)
    V = v_layer(X)

    return Q, K, V
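
The comment in the cell above asks how the three projections could be done with one layer. One possible answer, sketched here in plain Python with nested lists instead of Flax, is to place the three weight matrices side by side, do a single matrix multiplication and split the result. The sizes, the random weights and the omission of bias terms are assumptions made to keep the example short.

```python
import random


def matmul(A, B):
    """Multiply an [n, k] matrix by a [k, m] matrix, both stored as nested lists."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]


rng = random.Random(0)
seq_len, d_model, d_out = 3, 4, 2  # toy sizes


def random_matrix(rows, cols):
    return [[rng.gauss(0.0, 1.0) for _ in range(cols)] for _ in range(rows)]


X = random_matrix(seq_len, d_model)  # one row per sequence element
W_q, W_k, W_v = (random_matrix(d_model, d_out) for _ in range(3))

# Concatenate the three weight matrices column-wise: shape [d_model, 3 * d_out].
W_qkv = [q_row + k_row + v_row for q_row, k_row, v_row in zip(W_q, W_k, W_v)]

fused = matmul(X, W_qkv)  # a single matrix multiplication
Q = [row[:d_out] for row in fused]
K = [row[d_out:2 * d_out] for row in fused]
V = [row[2 * d_out:] for row in fused]

print(len(Q), len(Q[0]))  # Q has one row per sequence element and d_out columns
```

In Flax, the same idea would presumably be a single `nn.Dense` with `3 * output_size` features followed by `jnp.split` along the last axis, which trades three small matrix multiplications for one larger one.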

##### **Scaled dot product attention**


Now that we have our `query`, `key` and `value` matrices, it is time to calculate the attention matrix. Remember, in attention mechanisms, we must first find a score for each sequence vector and then use these scores to create a new context vector. We do this in self-attention using scaled dot product attention with the formula below.

$\operatorname{Attention}(Q, K, V)=\operatorname{softmax}\left(\frac{Q K^{T}}{\sqrt{d_{k}}}\right) V$

What happens here is similar to what we did in the dot product attention in the previous section, just applying the mechanism to the sequence itself. For each element $i$ in the sequence, we calculate the attention weights between its query $q_i$ and all keys $K$. We then multiply each value vector in $V$ by its weight and sum the weighted vectors to form a new representation for element $i$. By doing this, we are essentially drowning out irrelevant vectors and bringing up important vectors in the sequence when our focus is on $q_i$.

$QK^\top$ is scaled by the square root of the dimension of the vectors, $\sqrt{d_k}$, to ensure more stable gradients during training.
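To see why the scaling matters, the sketch below (an illustration, not part of the original practical) draws random queries and keys with unit-variance entries and compares the spread of their dot products before and after dividing by $\sqrt{d_k}$. The value of `d_k` and the sample count are arbitrary.

```python
import jax
import jax.numpy as jnp

d_k, n_samples = 64, 10_000                   # illustrative values
k_q, k_k = jax.random.split(jax.random.PRNGKey(0))
q = jax.random.normal(k_q, (n_samples, d_k))
k = jax.random.normal(k_k, (n_samples, d_k))

raw_scores = jnp.sum(q * k, axis=-1)          # one dot product per (q, k) pair
scaled_scores = raw_scores / jnp.sqrt(d_k)

print("variance of raw scores:   ", float(jnp.var(raw_scores)))     # close to d_k
print("variance of scaled scores:", float(jnp.var(scaled_scores)))  # close to 1
```

With a variance that grows with $d_k$, a few logits tend to dominate and the softmax saturates, which leads to very small gradients. Scaling keeps the logits in a range where the softmax stays responsive.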


In [None]:
def scaled_dot_product_attention(query, key, value):
    d_k = key.shape[-1]
    logits = jnp.matmul(query, jnp.swapaxes(key, -2, -1))
    scaled_logits = logits / jnp.sqrt(d_k)
    attention_weights = jax.nn.softmax(scaled_logits, axis=-1)
    value = jnp.matmul(attention_weights, value)
    return value, attention_weights

Let's now see scaled dot product attention in action. We will take a sentence, embed each word using word2vec, and see what the final self-attention weights look like.

We will not use the linear projection layers as they are not trained. Instead, we are going to make $X=Q=V=K$.

In [None]:
sentence = "I drink coke, but eat steak"
word_embeddings, words = embed_sentence(sentence)
Q = K = V = word_embeddings

# calculate weights and plot
values, attention_weights = scaled_dot_product_attention(Q, K, V)
words = remove_punctuation(sentence).split()
plot_attention_weight_matrix(attention_weights[0], words, words)

Keep in mind that we have not trained our attention matrix yet. However, by using the word2vec vectors as our sequence, we can see that scaled dot product attention is already capable of attending to "eat" when "steak" is our query, and that the query "drink" attends more to "coke" and "eat".

More resources:

[Attention with Q,K,V](https://www.youtube.com/watch?v=k-5QMalS8bQ&list=PLmZlBIcArwhPHmHzyM_cZJQ8_v5paQJTV&index=7)

##### **Masked attention**

There are cases where applying self-attention over the entire sequence is not practical. These can include:

- Uneven length sequences batched together.
  - When sending a batch of sequences through a network, the self-attention expects each sequence to be the same length. One handles this by padding the sequence. When calculating attention, ideally, these padding tokens should not be taken into consideration.
- Training a decoding model.
  - When training decoder models, such as GPT-3, the decoder has access to the entire target sequence during training (as training is done in parallel). To prevent the model from cheating by looking at future tokens, we have to mask the future sequence data so that earlier positions cannot attend to it.

By applying a mask to the final score calculated between queries and keys, we mitigate the influence of the unwanted sequence vectors. The vectors are masked by making the score between the query and their respective keys a VERY large negative value. This results in the softmax function pushing the attention weight very close to zero, and the resulting value will be summed out and not influence the final representation.


Putting everything together, masked scaled dot product attention visually looks like this:

<img src="https://windmissing.github.io/NLP-important-papers/AIAYN/assets/5.png" alt="drawing" width="200"/>


In [None]:
# example of building a mask for tokens of size 32
mask = jnp.tril(jnp.ones((32, 32)))
sns.heatmap(mask, cmap="Blues")
plt.title("Example of mask that can be applied");

Let's now adapt our scaled dot product attention function to implement masked attention.

In [None]:
def scaled_dot_product_attention(query, key, value, mask=None):
    d_k = key.shape[-1]
    T_k = key.shape[-2]
    T_q = query.shape[-2]
    logits = jnp.matmul(query, jnp.swapaxes(key, -2, -1))
    scaled_logits = logits / jnp.sqrt(d_k)

    if mask is not None:
        scaled_logits = jnp.where(mask[:T_q, :T_k], scaled_logits, -jnp.inf)

    attention_weights = jax.nn.softmax(scaled_logits, axis=-1)
    attention = jnp.matmul(attention_weights, value)
    return attention, attention_weights
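A quick way to convince yourself that the mask works is to feed in random data with a causal mask and inspect the weights. The sketch below re-declares a compact copy of the function above (under the hypothetical name `masked_attention`) so that it runs on its own; the tensor sizes are arbitrary.

```python
import jax
import jax.numpy as jnp

def masked_attention(query, key, value, mask=None):
    # compact copy of the masked scaled dot product attention defined above
    d_k = key.shape[-1]
    logits = jnp.matmul(query, jnp.swapaxes(key, -2, -1)) / jnp.sqrt(d_k)
    if mask is not None:
        logits = jnp.where(mask, logits, -jnp.inf)   # blocked positions get -inf
    weights = jax.nn.softmax(logits, axis=-1)
    return jnp.matmul(weights, value), weights

T, d = 4, 8                                          # illustrative sizes
x = jax.random.normal(jax.random.PRNGKey(0), (1, T, d))
causal_mask = jnp.tril(jnp.ones((T, T))).astype(bool)  # True = may attend

_, weights = masked_attention(x, x, x, causal_mask)
print(jnp.round(weights[0], 2))
```

Every entry above the diagonal is zero and each row still sums to one, so the first token can only attend to itself while the last token can attend to the whole sequence.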

#### Multihead Attention - <font color='blue'>Advanced</font>

Rather than only computing the attention once, the multi-head attention (MHA) mechanism runs through the scaled dot-product attention multiple times in parallel. According to the paper, Attention is all you need, "multi-head attention allows the model to jointly attend to information from different representation subspaces at different positions. With a single attention head, averaging inhibits this".

Multi-head attention can be viewed as a similar strategy to stacking convolution kernels in a CNN layer. This allows the kernels to focus on and learn different features and rules, which is why multiple heads of attention also work. The process for MHA is given below.

<img src="https://drive.google.com/uc?export=view&id=1q0Oq6IVEkkMfVSpY4LkHBP866mcoIFsh" alt="drawing" width="1000"/>

As can be seen from the figure, the scaled dot product attention discussed earlier is just repeated $N$ times, with three learnable matrices ($W_{Qn}, W_{Kn}, W_{Vn}$) for each head, giving $3N$ in total. The outputs from the different heads are then concatenated and fed through a linear projection, which produces the final representation.

Because this requires a large amount of computation and memory, a common design choice is to have the $W_{Qn}, W_{Kn}, W_{Vn}$ matrices produce embeddings of length $d_m/N$, where $d_m$ is the input sequence embedding size and $N$ is the number of heads. By doing this, the MHA function is similar computation-wise to using a single head of attention.

**Code Task:** Finish the implementation of MHA below. Hint: `swapaxes` and `reshape` are a good place to start.

In [None]:
class MultiHeadAttention(nn.Module):
  num_heads: int
  d_m: int

  def setup(self):
    self.sequence_to_qkv = SequenceToQKV(self.d_m)
    initializer = nn.initializers.variance_scaling(
        scale=0.5, mode="fan_in", distribution="truncated_normal")
    self.Wo = nn.Dense(self.d_m, kernel_init=initializer)

  def __call__(self, X=None, Q=None, K=None, V=None, mask=None, return_weights=False):
    if Q is None or K is None or V is None:
      assert X is not None, "X has to be provided if any of Q, K, V is not provided"

      # project all data to Q, K, V
      Q, K, V = self.sequence_to_qkv(X)

    # get the batch size, sequence length and embedding size
    B, T, d_m = K.shape

    # calculate heads embedding size (d_m/N)
    head_size = d_m // self.num_heads

    # B,T,d_m -> B, T, N, dm//N -> B, N, T, dm//N
    q_heads = Q.reshape(B, -1, self.num_heads, head_size).swapaxes(1, 2)
    k_heads = K.reshape(B, -1, self.num_heads, head_size).swapaxes(1, 2)
    v_heads = V.reshape(B, -1, self.num_heads, head_size).swapaxes(1, 2)

    attention, attention_weights = scaled_dot_product_attention(
        q_heads, k_heads, v_heads, mask
    )

    # (B, nh, T, hs) -> (B, T, nh, hs) -> (B, T, d_m) - re-assemble all head outputs
    attention = # FINISH ME

    # apply Wo
    X_new = self.Wo(attention)

    if return_weights:
      return X_new, attention_weights
    else:
      return X_new

In [None]:
# @title Run me to test your code

mha = MultiHeadAttention(2, 8)
# initialise model
key = jax.random.PRNGKey(42)
x = jax.random.normal(key, [1, 2, 8])
params = mha.init(key, x)
x_new = mha.apply(params, x)

x_correct = jnp.array(
    [
        [
            [
              -0.59349924, -0.79245573,  0.64649045, -0.52850205,
              -0.4793459 , -0.34167248, -0.45467672,  0.8619362
            ],
            [
              -0.7895622 , -0.9945788 ,  0.7638061 , -0.65239996,
              -0.56319916, -0.2351217 , -0.39363512,  0.9993293
            ],
        ]
    ]
)


assert jnp.allclose(x_correct, x_new), "Not returning the correct value"
print(
    "It seems correct. Look at the answer below to compare methods then move to the transformers section."
)

In [None]:
# @title Answer to code task (Try not to peek until you've given it a good try!)
class MultiHeadAttention(nn.Module):
  num_heads: int
  d_m: int

  def setup(self):
    self.sequence_to_qkv = SequenceToQKV(self.d_m)
    initializer = nn.initializers.variance_scaling(
        scale=0.5, mode="fan_in", distribution="truncated_normal")
    self.Wo = nn.Dense(self.d_m, kernel_init=initializer)

  def __call__(self, X=None, Q=None, K=None, V=None, mask=None, return_weights=False):
    if Q is None or K is None or V is None:
      assert X is not None, "X has to be provided if any of Q, K, V is not provided"

      # project all data to Q, K, V
      Q, K, V = self.sequence_to_qkv(X)

    # get the batch size, sequence length and embedding size
    B, T, d_m = K.shape

    # calculate heads embedding size (d_m/N)
    head_size = d_m // self.num_heads

    # B,T,d_m -> B, T, N, dm//N -> B, N, T, dm//N
    q_heads = Q.reshape(B, -1, self.num_heads, head_size).swapaxes(1, 2)
    k_heads = K.reshape(B, -1, self.num_heads, head_size).swapaxes(1, 2)
    v_heads = V.reshape(B, -1, self.num_heads, head_size).swapaxes(1, 2)

    attention, attention_weights = scaled_dot_product_attention(
        q_heads, k_heads, v_heads, mask
    )

    # (B, nh, T, hs) -> (B, T, nh, hs) -> (B, T, d_m) - re-assemble all head outputs
    attention = attention.swapaxes(1, 2).reshape(B, -1, d_m)

    # apply Wo
    X_new = self.Wo(attention)

    if return_weights:
      return X_new, attention_weights
    else:
      return X_new
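The reshape and swapaxes steps in the answer can be checked in isolation. The sketch below uses the hypothetical helper names `split_heads` and `merge_heads` (they do not appear in the practical) and confirms that merging undoes splitting and that each head receives a contiguous slice of the embedding.

```python
import jax.numpy as jnp

def split_heads(x, num_heads):
    """[B, T, d_m] -> [B, num_heads, T, d_m // num_heads]"""
    B, T, d_m = x.shape
    return x.reshape(B, T, num_heads, d_m // num_heads).swapaxes(1, 2)

def merge_heads(x):
    """[B, num_heads, T, head_size] -> [B, T, num_heads * head_size]"""
    B, N, T, hs = x.shape
    return x.swapaxes(1, 2).reshape(B, T, N * hs)

B, T, d_m, num_heads = 2, 3, 8, 2             # illustrative sizes
x = jnp.arange(B * T * d_m, dtype=jnp.float32).reshape(B, T, d_m)

heads = split_heads(x, num_heads)             # [2, 2, 3, 4]
restored = merge_heads(heads)                 # [2, 3, 8]
print(heads.shape, restored.shape)
```

Head 0 sees embedding dimensions 0 to 3 of every token and head 1 sees dimensions 4 to 7, which is why the order of `reshape` and `swapaxes` matters when re-assembling the outputs.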

Everything covered up to this point is rarely used on its own when constructing LLMs. However, these pieces constitute the underlying mechanisms that enable such models to perform at a high level. By comprehending these mechanisms, you can gain a better understanding of why LLMs may occasionally exhibit peculiar behavior and pinpoint potential starting points for debugging them.


There have also been many optimisations to the MHA structure described above, as machine learning engineers find more and more ways to optimise this compute-heavy step in order to scale models. These include multi-query attention (MQA) and grouped-query attention (GQA).




## **2. Building your own LLM**  

### 2.1 High-level overview <font color='orange'>Beginner</font>

The Transformer Architecture was famously introduced in the paper entitled [Attention is all you need](https://proceedings.neurips.cc/paper_files/paper/2017/file/3f5ee243547dee91fbd053c1c4a845aa-Paper.pdf) by Vaswani et al.

As the title of the paper suggests, such an architecture consists of basically only attention mechanisms along with feed-forward layers and linear layers, as shown in the diagram below.

<img src="https://machinelearningmastery.com/wp-content/uploads/2021/08/attention_research_1.png" width="350" />

The Transformer and its variations are at the core of Large Language Models, and it is not an exaggeration to say that almost all language models out there are Transformer-based architectures.

As you can see in the diagram, the original Transformer architecture consists of two parts: one that receives inputs, usually called the encoder, and another that receives outputs (i.e. targets), called the decoder. This is because the transformer was designed for machine translation.

The encoder will receive an input sentence in one language and process it through multiple stacked `encoder blocks`. This creates a final representation, which contains helpful information necessary for the decoding task. This output is then fed into stacked `decoder blocks` that produce new outputs in an autoregressive manner.

The encoder consists of $N$ identical blocks, which process a sequence of token vectors sequentially. These blocks consist of 3 parts:

1. A multi-head attention block. These are the transformer architecture's backbone. They process the data to generate representations for each token, ensuring that the necessary information for the task at hand is represented in the vectors. These are exactly the MHA we covered in the attention section previously.
2. An MLP is applied to each input token separately and identically.
3. Residual connections: one adds the input tokens to the attended representations, and another connects the input of the MLP to its output. For both these connections, the result is normalized using layernorm. In certain implementations, these normalization steps are applied to the inputs rather than the outputs. Just like a ResNet, transformers are designed to be very deep models, so these add and norm blocks are essential for a smooth gradient flow.

Similarly, the decoder consists of $N$ identical blocks; however, there is some variation within these blocks. Concretely, the different parts are:

1. A masked multi-head attention block. This is an MHA block that performs _self-attention_ on the output sequence; however, this computation is restricted to the tokens that have already been seen. In other words, future tokens are blocked when making predictions.
2. A multi-head attention block. This block receives the output of the final encoder block, the transformed tokens, and uses that as the key-value pairs, while using the output of the first MHA block as the query. In doing this, the model attends over the input required to perform the sequence task. This MHA block thus performs _cross-attention_ by looking at the encoder outputs.
3. An MLP, the same as in the encoder.
4. Residual connections, the same as in the encoder.

Given this original architecture, there have been several variations, some focusing on the encoder only and others on the **decoder only**. Large language models (LLMs) such as GPT-2, GPT-3 and Turing-NLG were born out of decoder-only architectures. These architectures look like:

<img src="https://drive.google.com/uc?export=view&id=1MubUcshJTHCqOPTRHixLhrYYLXX9vP_h" alt="drawing" width="260"/>

with the cross attention block missing as no encoder output is available. So to build a language model, we will focus on the decoder only architecture as seen above.


### 2.2 Tokenization + Positional encoding <font color='orange'>Beginner</font>



#### 2.2.1 Tokenization


Transformers cannot handle raw strings of text. So to process text, the text is first split up into tokens. The tokens are then indexed and each token is assigned an embedding of size $d_{model}$. These embeddings can be learned during training or can come from a pretrained vocabulary of embeddings. This new sequence of token embeddings is then fed into the transformer architecture. This idea is visualised below.

\\

<img src="https://drive.google.com/uc?export=view&id=16euh4LADP_mcXywFwKKY3QQQkVplepiI" alt="drawing" width="450"/>


These token IDs are typically predicted when a model generates text, fills in missing words, etc.

This process of splitting up text into tokens and assigning an ID to each token is called [tokenisation](https://huggingface.co/docs/transformers/tokenizer_summary). There are various ways to tokenise text, with some methods being trained directly from the data. When using pre-trained transformers, it is crucial to use the same tokeniser that was used to train the model. The previous link has in-depth descriptions of many widely known techniques.
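As a toy illustration of the split, index and embed pipeline (not how the BERT or GPT tokenisers actually work, since those use learned subword vocabularies), the sketch below builds a whitespace vocabulary and looks up a random embedding for each token ID. The corpus, `d_model` and the embedding table are made up for the example.

```python
import jax
import jax.numpy as jnp

corpus = "the practical is so much fun the end"
# assign an ID to every unique word, in alphabetical order
vocab = {tok: idx for idx, tok in enumerate(sorted(set(corpus.split())))}

def tokenise(text):
    """Whitespace tokeniser: split into words and map each word to its ID."""
    return [vocab[tok] for tok in text.split()]

d_model = 4                                   # illustrative embedding size
embedding_table = jax.random.normal(jax.random.PRNGKey(0), (len(vocab), d_model))

token_ids = jnp.array(tokenise("the practical is fun"))
token_embeddings = embedding_table[token_ids]  # one row per token: [T, d_model]
print(token_ids, token_embeddings.shape)
```

A word that is missing from `vocab` would raise a `KeyError` here, which is one reason real tokenisers fall back to subword pieces or a dedicated unknown token.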

Below we show how the [BERT](https://arxiv.org/abs/1810.04805) model's tokeniser tokenises a sentence. We use [Hugging Face](https://huggingface.co/) for this part.


In [None]:
import transformers
from transformers import pipeline, AutoTokenizer, AutoModel

bert_tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
encoded_input = bert_tokenizer("The practical is so much fun")
print(f"Token IDs: {encoded_input['input_ids']}")

Here we can see that the tokeniser returns the IDs for each token, as shown in the figure. But counting the number of IDs, we see that it is larger than the number of words in the sentence. Let's print the tokens associated with each ID.


In [None]:
print(f"Tokens: {bert_tokenizer.decode(encoded_input['input_ids'])}")

We can see the tokeniser attaches new tokens, `[CLS]` and `[SEP]`, to the start and end of the sequence. This is a BERT-specific requirement for training and inference. Adding special tokens is a very common thing to do. Using special tokens, we can tell a model when a sentence starts or ends or when a new part of the input starts. This can be helpful when performing different tasks.

For instance, some transformers are pretrained with what is known as masked prediction. For this, random tokens in a sequence are replaced by the `[MASK]` token, and the model is trained to predict the correct token ID for each replaced token.

**Drawback of using raw tokens**:

One drawback of using raw tokens is that they lack any indication of the word's position in the sequence. This is evident when considering sentences like "I am happy" and "Am I happy" - these two phrases have distinct meanings, and the model needs to grasp the word order to understand the intended message accurately.

To address this, when converting the inputs into vectors, position vectors are introduced and added to these vectors to indicate the **position** of each word.
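That self-attention ignores order can be checked directly: shuffling the input tokens just shuffles the outputs in the same way, so without positional information the model has no way to tell the two orderings apart. A minimal sketch, with random vectors standing in for token embeddings and an arbitrary permutation:

```python
import jax
import jax.numpy as jnp

def self_attention(x):
    """Unmasked scaled dot product self-attention with Q = K = V = x."""
    logits = x @ x.T / jnp.sqrt(x.shape[-1])
    return jax.nn.softmax(logits, axis=-1) @ x

T, d = 3, 8                                   # e.g. the tokens "I", "am", "happy"
x = jax.random.normal(jax.random.PRNGKey(0), (T, d))
perm = jnp.array([1, 0, 2])                   # reorder to "am", "I", "happy"

out_original = self_attention(x)
out_shuffled = self_attention(x[perm])

# the shuffled output is just the original output with its rows reordered
print(jnp.allclose(out_shuffled, out_original[perm], atol=1e-4))
```

Each token ends up with exactly the same representation in both orderings, which is the gap that positional information is meant to close.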


#### 2.2.2 Positional encodings

In most domains where a transformer can be utilised, there is an underlying order to the tokens produced, be it the order of words in a sentence, the location from which patches are taken in an image or even the steps taken in an RL environment. This order is very important in all cases; just imagine you interpret the sentence "I have to read this book." as "I have this book to read.". Both sentences contain the exact same words, yet they have completely different meanings based on the order.

As both the encoder and the decoder blocks process all tokens in parallel, the order of tokens is lost in these calculations. To cope with this, the sequence order has to be injected into the tokens directly. This can be done by adding *positional encodings* to the tokens at the start of the encoder and decoder blocks (though some of the latest techniques add positional information in the attention blocks). An example of how positional encodings alter the tokens is shown below.


\\

<img src="https://drive.google.com/uc?export=view&id=1eSgnVN2hnEsrjdHygDGwk1kxEi8-dcFo" alt="drawing" width="650"/>

Ideally, these encodings should have these characteristics ([source](https://kazemnejad.com/blog/transformer_architecture_positional_encoding/)):
* Each time-step should have a unique value
* The distance between any two time steps should be consistent across sequences of different lengths.
* The encoding should be able to generalise to longer sequences than seen during training.
* The encoding must be deterministic.

##### **Sine and cosine functions**


In Attention is All you Need, the authors used a method that can satisfy all these requirements. This involves summing a combination of sine and cosine waves at different frequencies, with the formula for a position encoding at position $D$ shown below, where $i$ is the embedding index and $d_m$ is the token embedding size.

\\

$P_{D}= \begin{cases}\sin \left(\frac{D}{10000^{i/d_{m}}}\right), & \text { if } i \bmod 2=0 \\ \cos \left(\frac{D}{10000^{(i-1)/d_{m}}}\right), & \text { otherwise } \end{cases}$

\

Assuming our model has $d_m=8$, the position embedding will look like this:

\
$P_{D}=\left[\begin{array}{c}\sin \left(\frac{D}{10000^{0/8}}\right)\\ \cos \left(\frac{D}{10000^{0/8}}\right)\\ \sin \left(\frac{D}{10000^{2/8}}\right)\\ \cos \left(\frac{D}{10000^{2/8}}\right)\\ \sin \left(\frac{D}{10000^{4/8}}\right)\\ \cos \left(\frac{D}{10000^{4/8}}\right)\\ \sin \left(\frac{D}{10000^{6/8}}\right)\\ \cos \left(\frac{D}{10000^{6/8}}\right)\end{array}\right]$

\\

Let's first create a function that can return these encodings to understand why this will work.

In [None]:
def return_frequency_pe_matrix(token_sequence_length, token_embedding):

  assert token_embedding % 2 == 0, "token_embedding should be divisible by two"

  P = jnp.zeros((token_sequence_length, token_embedding))
  positions = jnp.arange(0, token_sequence_length)[:, jnp.newaxis]

  i = jnp.arange(0, token_embedding, 2)
  frequency_steps = jnp.exp(i * (-math.log(10000.0) / token_embedding))
  frequencies = positions * frequency_steps

  P = P.at[:, 0::2].set(jnp.sin(frequencies))
  P = P.at[:, 1::2].set(jnp.cos(frequencies))

  return P

In [None]:
token_sequence_length = 50  # Number of tokens the model will need to process
token_embedding = 10000  # token embedding (and positional encoding) dimensions, ensure it is divisible by two
P = return_frequency_pe_matrix(token_sequence_length, token_embedding)
plot_position_encodings(P, token_sequence_length, token_embedding)

Looking at the graph above, we can see that for each position index, there is a unique pattern forming, where each position index will always have the same encoding.
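One way to probe the characteristics listed earlier is to look at dot products between encodings. For sinusoidal encodings, the dot product of two positions depends only on how far apart they are, not on where they sit in the sequence. The sketch below repeats the construction from `return_frequency_pe_matrix` under the hypothetical name `sinusoidal_pe` so that it runs on its own; the sizes are arbitrary.

```python
import math
import jax.numpy as jnp

def sinusoidal_pe(seq_len, d_m):
    """Same construction as return_frequency_pe_matrix, repeated to run standalone."""
    positions = jnp.arange(seq_len)[:, None]
    i = jnp.arange(0, d_m, 2)
    freqs = positions * jnp.exp(i * (-math.log(10000.0) / d_m))
    P = jnp.zeros((seq_len, d_m))
    P = P.at[:, 0::2].set(jnp.sin(freqs))
    P = P.at[:, 1::2].set(jnp.cos(freqs))
    return P

P = sinusoidal_pe(seq_len=50, d_m=16)         # illustrative sizes
similarity = P @ P.T                          # similarity[t, s] = P[t] . P[s]

# three pairs of positions that are each 3 steps apart
print(float(similarity[0, 3]), float(similarity[10, 13]), float(similarity[40, 43]))
```

The three printed values agree up to floating point error, because each sine and cosine pair contributes $\cos(\omega k)$ for an offset of $k$ steps, regardless of the starting position.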

**Group task**:

- Discuss with your friend why we are seeing that specific pattern when `token_sequence_length` is 1000, and `token_embedding` is 768.
- You can try playing around with smaller values for `token_sequence_length` and  `token_embedding` to get a better intuition for the above discussion.
- Ask your friend why they think the 10000 constant is used in the functions above.
- Make `token_sequence_length` to be 50 and `token_embedding` something large, like 10000. What do you notice? Is a large token embedding always needed?


### 2.3 Transformer block   <font color='green'>Intermediate</font>

Just as an MLP or CNN is a stack of layers, a transformer is composed of a stack of transformer blocks. In this section we build each of the components required to form a transformer block.


#### 2.3.1 Feed Forward Network (FFN) / Multilayer perceptron (MLP) <font color='orange'>Beginner</font>


<img src="https://drive.google.com/uc?export=view&id=1H1pVFxJiSpM_Ozj1eKWNdcFQ5Hn5XsZz" alt="drawing" width="260"/>

Each of these blocks is a 2-layer MLP that uses ReLU activation in the original model. GeLU has also become very popular, and we will be using it throughout the practical. The formula below represents the feedforward neural network (FFN) with ReLU activation, where input `x` is transformed through two linear layers with weights `W1` and `W2` and bias terms `b1` and `b2`, and the `max` function represents the ReLU activation function.

$$
\operatorname{FFN}(x)=\max \left(0, x W_{1}+b_{1}\right) W_{2}+b_{2}
$$

One can interpret this block as processing what the MHA block has produced and then projecting these new token representations to a space that the next block can use more optimally. Usually, the first layer is very wide, in the range of 2-8 times the size of the token representations. This is done because it is easier to parallelize computations for a single wider layer during training than to parallelize a feedforward block with multiple layers. Complexity can thus be added while keeping training and inference optimized.

**Code task:** Code up a Flax Module that implements the feed forward block.

In [None]:
class FeedForwardBlock(nn.Module):
  """
  A 2-layer MLP which widens then narrows the input.

  Args:
    widening_factor [optional, default=4]: The size of the hidden layer will be d_model * widening_factor.
  """

  widening_factor: int = 4
  init_scale: float = 0.25

  @nn.compact
  def __call__(self, x):
    '''
    Args:
      x: [B, T, d_m]

    Return:
      x: [B, T, d_m]
    '''
    d_m = x.shape[-1]
    layer1_size = self.widening_factor * d_m

    initializer = nn.initializers.variance_scaling(
        scale=self.init_scale, mode='fan_in', distribution='truncated_normal',
    )
    layer1 = # FINISH ME
    layer2 = # FINISH ME

    x = jax.nn.gelu(layer1(x))
    x = layer2(x)
    return x

In [None]:
# @title Answer to code task (Try not to peek until you've given it a good try!)


class FeedForwardBlock(nn.Module):
  """A 2-layer MLP which widens then narrows the input."""
  widening_factor: int = 4
  init_scale: float = 0.25

  @nn.compact
  def __call__(self, x):
    d_m = x.shape[-1]
    layer1_size = self.widening_factor * d_m

    initializer = nn.initializers.variance_scaling(
        scale=self.init_scale, mode='fan_in', distribution='truncated_normal',
    )
    layer1 = nn.Dense(layer1_size, kernel_init=initializer)
    layer2 = nn.Dense(d_m, kernel_init=initializer)

    x = jax.nn.gelu(layer1(x))
    x = layer2(x)
    return x
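That the feed forward block treats every token separately and identically can be verified without Flax. The sketch below implements the FFN formula with explicit, randomly initialised weights (`W1`, `b1`, `W2` and `b2` are stand-ins, not trained parameters) and GELU in place of ReLU, then checks that processing one token alone gives the same result as processing it inside the batch.

```python
import jax
import jax.numpy as jnp

def ffn(x, W1, b1, W2, b2):
    """Two-layer MLP with GELU, applied along the last axis."""
    return jax.nn.gelu(x @ W1 + b1) @ W2 + b2

B, T, d_m, widening_factor = 2, 5, 8, 4       # illustrative sizes
k_x, k_1, k_2 = jax.random.split(jax.random.PRNGKey(0), 3)
x = jax.random.normal(k_x, (B, T, d_m))
W1 = jax.random.normal(k_1, (d_m, widening_factor * d_m)) * 0.1
b1 = jnp.zeros(widening_factor * d_m)
W2 = jax.random.normal(k_2, (widening_factor * d_m, d_m)) * 0.1
b2 = jnp.zeros(d_m)

out_batch = ffn(x, W1, b1, W2, b2)            # whole batch: [B, T, d_m]
out_single = ffn(x[0, 2], W1, b1, W2, b2)     # one token on its own: [d_m]
print(out_batch.shape, out_single.shape)
```

Since no information flows between tokens in this block, all mixing across positions has to come from the attention layers.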

#### 2.3.2 Add and Norm block <font color='orange'>Beginner</font>

In order to get transformers to go deeper, the residual connections are very important to allow an easier flow of gradients through the network. For normalisation, `layer norm` is used. This normalises each token vector independently in the batch. It is found that normalising the vectors improves the convergence and stability of transformers.

There are two learnable parameters in layernorm, `scale` and `bias`, which rescale the normalised value. Thus, for each input token in a batch, we calculate the mean $\mu_{i}$ and variance $\sigma_i^2$. We then normalise the token with:

$\hat{x}_i = \frac{x_i-\mu_{i}}{\sqrt{\sigma_i^2 + \epsilon}}$.

Then $\hat{x}$ is rescaled using the learned `scale`, $γ$, and `bias` $β$, with:

$y_i = γ\hat{x}_i + β = LN_{γ,β}(x_i)$.

So our add norm block can be represented as $LN(x+f(x))$, where $f(x)$ is either a MLP or MHA block.
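To make the normalisation step concrete, the sketch below writes layer norm by hand with `jax.numpy`, using the standard definition that divides by the standard deviation $\sqrt{\sigma^2 + \epsilon}$. The function name `layer_norm`, the value of `eps` and the tensor sizes are choices made for this illustration; in the practical itself a library layer does this for you.

```python
import jax
import jax.numpy as jnp

def layer_norm(x, gamma, beta, eps=1e-6):
    """Normalise each token vector over its last axis, then rescale and shift."""
    mu = jnp.mean(x, axis=-1, keepdims=True)
    var = jnp.var(x, axis=-1, keepdims=True)
    x_hat = (x - mu) / jnp.sqrt(var + eps)
    return gamma * x_hat + beta

B, T, d_m = 2, 3, 16                          # illustrative sizes
# inputs deliberately shifted and stretched away from zero mean and unit variance
x = 5.0 + 3.0 * jax.random.normal(jax.random.PRNGKey(0), (B, T, d_m))
gamma, beta = jnp.ones(d_m), jnp.zeros(d_m)   # identity rescaling for the demo

y = layer_norm(x, gamma, beta)
print(jnp.mean(y, axis=-1))                   # roughly 0 for every token
print(jnp.std(y, axis=-1))                    # roughly 1 for every token
```

Each token is normalised using only its own statistics, so the result does not depend on the other tokens or the other sequences in the batch.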

**Code task:** Code up a Flax Module that implements the add norm block. It should take as input the processed and unprocessed tokens. Hint: `nn.LayerNorm`

In [None]:
class AddNorm(nn.Module):
  """A block that implements the add and norm block"""

  @nn.compact
  def __call__(self, x, processed_x):
    '''
    Args:
      x: Sequence of tokens before feeding into MHA or FF blocks, with shape [B, T, d_m]
      processed_x: Sequence of tokens after being processed by MHA or FF blocks, with shape [B, T, d_m]

    Return:
      add_norm_x: Transformed tokens with shape [B, T, d_m]
    '''

    added = # FINISH ME
    normalised = #FINISH ME
    return normalised(added)

In [None]:
# @title Answer to code task (Try not to peek until you've given it a good try!)


class AddNorm(nn.Module):
  """A block that implements the add and norm block"""

  @nn.compact
  def __call__(self, x, processed_x):

    added = x + processed_x
    normalised = nn.LayerNorm(reduction_axes=-1, use_scale=True, use_bias=True)
    return normalised(added)

### 2.4 Building the Transformer Decoder / LLM <font color='green'>Intermediate</font>

<img src="https://drive.google.com/uc?export=view&id=1MubUcshJTHCqOPTRHixLhrYYLXX9vP_h" alt="drawing" width="260"/>

Most of the groundwork has happened. We have built the positional encoding block, the MHA block, the feed-forward block and the add&norm block.

The only part needed is passing inputs to each decoder block and applying the masked MHA block found in the decoder blocks.

**Code task:** Code up a Flax Module that implements the decoder block, i.e. `AddNorm(Y, FFN(Y))` with `Y = AddNorm(X, MHA(X))`.

In [None]:
class DecoderBlock(nn.Module):
  """
  Transformer decoder block.

  Args:
    num_heads: The number of heads to be used in the MHA block.
    d_m: Token embedding size
    widening_factor: The size of the hidden layer will be d_m * widening_factor.
  """

  num_heads: int
  d_m: int
  widening_factor: int = 4

  def setup(self):
    self.mha = MultiHeadAttention(self.num_heads, self.d_m)
    self.add_norm1 = AddNorm()
    self.add_norm2 = AddNorm()
    self.MLP = FeedForwardBlock(widening_factor=self.widening_factor)

  def __call__(self, X, mask=None, return_att_weight=True):
    """
    Args:
      X: Batch of tokens being fed into the decoder, with shape [B, T_decoder, d_m]
      mask [optional, default=None]: Mask to be applied, with shape [T_decoder, T_decoder].
      return_att_weight [optional, default=True]: Whether to return the attention weights.
    """

    attention, attention_weights_1 = # FINISH ME

    X = # FINISH ME

    projection = # FINISH ME
    X = # FINISH ME

    return (X, attention_weights_1) if return_att_weight else X

In [None]:
# @title Answer to code task (Try not to peek until you've given it a good try!)

class DecoderBlock(nn.Module):
  """
  Transformer decoder block.

  Args:
    num_heads: The number of heads to be used in the MHA block.
    d_m: Token embedding size
    widening_factor: The size of the hidden layer will be d_m * widening_factor.
  """

  num_heads: int
  d_m: int
  widening_factor: int = 4

  def setup(self):
    self.mha = MultiHeadAttention(self.num_heads, self.d_m)
    self.add_norm1 = AddNorm()
    self.add_norm2 = AddNorm()
    self.MLP = FeedForwardBlock(widening_factor=self.widening_factor)

  def __call__(self, X, mask=None, return_att_weight=True):
    """
    Args:
      X: Batch of tokens being fed into the decoder, with shape [B, T_decoder, d_m]
      mask [optional, default=None]: Mask to be applied, with shape [T_decoder, T_decoder].
      return_att_weight [optional, default=True]: Whether to return the attention weights.
    """

    attention, attention_weights_1 = self.mha(X, mask=mask, return_weights=True)

    X = self.add_norm1(X, attention)

    projection = self.MLP(X)
    X = self.add_norm2(X, projection)

    return (X, attention_weights_1) if return_att_weight else X

Next, we just put everything together, adding in the positional encodings as well as stacking multiple transformer blocks and adding our prediction layer.

In [None]:
class LLM(nn.Module):
  """
  Decoder-only transformer consisting of several layers of decoder blocks.

  Args:
    num_heads: The number of heads to be used in the MHA block.
    num_layers: The number of decoder blocks to be used.
    d_m: Token embedding size
    vocab_size: The size of the vocabulary
    widening_factor: The size of the hidden layer will be d_m * widening_factor.
  """
  num_heads: int
  num_layers: int
  d_m: int
  vocab_size: int
  widening_factor: int = 4

  def setup(self):
    self.blocks = [
        DecoderBlock(self.num_heads, self.d_m, self.widening_factor)
        for _ in range(self.num_layers)
    ]
    self.embedding = nn.Embed(num_embeddings=self.vocab_size, features=self.d_m) # convert tokens to embedding
    self.pred_layer = nn.Dense(self.vocab_size)

  def __call__(self, X, mask=None, return_att_weights=False):
    """
    Args:
      X: Batch of token IDs being fed into the decoder, with shape [B, T_decoder]
      mask [optional, default=None]: Mask to be applied, with shape [T_decoder, T_decoder].
      return_att_weights [optional, default=False]: Whether to return the attention weights.
    """

    # convert a token id to a d_m dimensional vector
    X = self.embedding(X)
    sequence_len = X.shape[-2]
    positions = return_frequency_pe_matrix(sequence_len, self.d_m)
    X = X + positions

    if return_att_weights:
        att_weights = []
    for block in self.blocks:
        out = block(X, mask, return_att_weights)
        if return_att_weights:
            X = out[0]
            att_weights.append(out[1])
        else:
            X = out

    # apply a linear layer and log-softmax to get log-probabilities over tokens (kept under the name `logits`)
    logits = nn.log_softmax(self.pred_layer(X))

    return (
        logits if not return_att_weights else (logits, jnp.array(att_weights).swapaxes(0, 1))
    )


If everything is correct, the code below should run without any issues.

In [None]:
B, T, d_m, N, vocab_size = 18, 32, 16, 8, 25670

llm = LLM(num_heads=1, num_layers=1, d_m=d_m, vocab_size=vocab_size, widening_factor=4)
mask = jnp.tril(np.ones((T, T)))

# initialise module and get dummy output
key = jax.random.PRNGKey(42)
X = jax.random.randint(key, [B, T], 0, vocab_size)
params = llm.init(key, X, mask=mask)

# extract output from decoder
logits, decoder_att_weights = llm.apply(
    params,
    X,
    mask=mask,
    return_att_weights=True,
)

As a final sanity check, we can see that our attention weights behave as expected: because of the causal mask, each token in our decoder only attends to itself and to previous tokens.

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(10, 5))
plt.suptitle("LLM attention weights")
sns.heatmap(decoder_att_weights[0, 0, 0, ...], ax=ax, cmap="Blues")
fig.show()

### 2.5 Training your LLM

#### 2.5.1 Training objective <font color='green'>Intermediate</font>


A sentence is simply a sequence of words. An LLM aims to predict the next word given the current context, namely the words that have come before.

Here's the basic idea:

To calculate the probability of a full sentence "word1, word2, ..., last word" appearing in a given context $c$, the procedure is to break down the sentence into individual words and consider the probability of each word given the words that precede it. These individual probabilities are then multiplied together:

$$\text{Probability of sentence} = \text{Probability of word1} \times \text{Probability of word2 given word1} \times \ldots \times \text{Probability of last word given all previous words}$$

This method is akin to building up a narrative one piece at a time based on the preceding storyline.

Mathematically, this is expressed as the likelihood (probability) of a sequence of words $y_1, y_2, ..., y_n$ in a given context $c$, which is achieved by multiplying the probabilities of each word $y_t$ calculated given the predecessors ($y_{<t}$) and the context $c$:

$$
P\left(y_{1}, y_{2}, \ldots, y_{n} \mid c\right)=\prod_{t=1}^{n} P\left(y_{t} \mid y_{<t}, c\right)
$$

Here $y_{<t}$ stands for the sequence $y_1, y_2, ..., y_{t-1}$, while $c$ represents the context.
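
As a small worked example of the product above, the sketch below multiplies a few made-up conditional probabilities and compares the result with the equivalent sum of log-probabilities. The numbers in `conditional_probs` are hypothetical and do not come from any trained model.

```python
import math

# Hypothetical conditional probabilities P(y_t | y_<t, c) for a 3-token sentence.
conditional_probs = [0.5, 0.2, 0.1]

# Chain rule: multiply the per-token conditionals.
sentence_prob = math.prod(conditional_probs)

# In practice we usually sum log-probabilities instead, which avoids
# numerical underflow for long sequences.
sentence_log_prob = sum(math.log(p) for p in conditional_probs)

print(sentence_prob)
print(math.exp(sentence_log_prob))
```

Both printed values agree up to floating point error, which is why training objectives are normally written in terms of log-probabilities.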

This is analogous to solving a jigsaw puzzle where the next piece is predictively placed based on what's already in place.

Remember that when training a transformer, we do not work with words but with tokens. During training, the model's parameters are updated by computing the cross-entropy loss between the predicted token distribution and the correct token, and then performing backpropagation. The loss for time step $t$ is computed as:

$$ \text{Loss}_t = - \sum_{w \in V} y_{t,w}\log (\hat{y}_{t,w}) $$

Here $y_{t,w}$ is 1 if $w$ is the actual token at time step $t$ and 0 otherwise (a one-hot encoding), and $\hat{y}_{t,w}$ is the probability the model assigns to token $w$ at the same time step. The loss for the entire sentence is then computed as:

$$ \text{Sentence Loss} = \frac{1}{n} \sum^{n}_{t=1} \text{Loss}_t $$

where $n$ is the length of the sequence.
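
To see what the sum over the vocabulary does at a single time step, here is a tiny numeric sketch. The 4-token vocabulary, the predicted distribution `y_hat` and the extra per-step losses are all made-up values for illustration; this is not the solution to the code task below.

```python
import numpy as np

# Hypothetical predicted distribution over a 4-token vocabulary at one time step.
y_hat = np.array([0.1, 0.6, 0.2, 0.1])
correct_token = 1

# One-hot encoding of the correct token.
y = np.zeros_like(y_hat)
y[correct_token] = 1.0

# Cross-entropy: only the correct token's term survives the sum.
loss_t = -np.sum(y * np.log(y_hat))
print(loss_t, -np.log(y_hat[correct_token]))

# Sentence loss: average the per-step losses (here, three hypothetical steps).
step_losses = np.array([loss_t, 0.3, 1.2])
sentence_loss = step_losses.mean()
print(sentence_loss)
```

Because the target is one-hot, the loss at each step reduces to the negative log-probability assigned to the correct token.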

This iterative process ultimately hones the model's predictive capabilities over time.

**Code task**: Implement the cross-entropy loss function below.

In [None]:
def sequence_loss_fn(logits, targets):
  '''
  Compute the cross-entropy loss between the predicted token IDs and the true IDs.

  Args:
    logits: An array of shape [batch_size, sequence_length, vocab_size]
    targets: The true token IDs, with shape [batch_size, sequence_length]

  Returns:
    loss: A scalar value representing the mean batch loss
  '''

  target_labels = jax.nn.one_hot(targets, VOCAB_SIZE)
  assert logits.shape == target_labels.shape

  mask = jnp.greater(targets, 0)
  loss = #FINISH ME
  return loss

In [None]:
# @title Run me to test your code
VOCAB_SIZE = 25670
targets = jnp.array([[0, 2, 0]])
key = jax.random.PRNGKey(42)
X = jax.random.normal(key, [1, 3, VOCAB_SIZE])
loss = sequence_loss_fn(X, targets)
real_loss = jnp.array(10.966118)
assert jnp.allclose(real_loss, loss), "Not returning the correct value"
print("It seems correct. Look at the answer below to compare methods.")

In [None]:
# @title Answer to code task (Try not to peek until you've given it a good try!)
def sequence_loss_fn(logits, targets):
  """Compute the loss on data wrt params."""
  target_labels = jax.nn.one_hot(targets, VOCAB_SIZE)
  assert logits.shape == target_labels.shape
  mask = jnp.greater(targets, 0)
  loss = -jnp.sum(target_labels * jax.nn.log_softmax(logits), axis=-1)
  loss = jnp.sum(loss * mask) / jnp.sum(mask)

  return loss

#### 2.5.2 Training models <font color='blue'>Advanced</font>

In this section, we define everything required to train the model with the objective described above. Much of it is the boilerplate needed for training with Flax.

Below we load the dataset we will train on, Karpathy's tiny Shakespeare dataset. It's not important to understand this code in detail, so either just run the cell to load the data, or read through it if you are curious.


In [None]:
# @title Create Shakespeare dataset and iterator (optional, but run the cell)

# Trick to avoid errors when downloading tinyshakespeare.
import locale
locale.getpreferredencoding = lambda: "UTF-8"

!wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt -O input.txt

class WordBasedAsciiDatasetForLLM:
    """In-memory dataset of a single-file ASCII dataset for language-like model."""

    def __init__(self, path: str, batch_size: int, sequence_length: int):
        """Load a single-file ASCII dataset in memory."""
        self._batch_size = batch_size

        with open(path, "r") as f:
            corpus = f.read()

        # Tokenize by splitting the text into words
        words = corpus.split()
        self.vocab_size = len(set(words))  # Number of unique words

        # Create a mapping from words to unique IDs
        self.word_to_id = {word: i for i, word in enumerate(set(words))}

        # Store the inverse mapping from IDs to words
        self.id_to_word = {i: word for word, i in self.word_to_id.items()}

        # Convert the words in the corpus to their corresponding IDs
        corpus = np.array([self.word_to_id[word] for word in words]).astype(np.int32)

        crop_len = sequence_length + 1
        num_batches, ragged = divmod(corpus.size, batch_size * crop_len)
        if ragged:
            corpus = corpus[:-ragged]
        corpus = corpus.reshape([-1, crop_len])

        if num_batches < 10:
            raise ValueError(
                f"Only {num_batches} batches; consider a shorter "
                "sequence or a smaller batch."
            )

        self._ds = WordBasedAsciiDatasetForLLM._infinite_shuffle(
            corpus, batch_size * 10
        )

    def __iter__(self):
        return self

    def __next__(self):
        """Yield next mini-batch."""
        batch = [next(self._ds) for _ in range(self._batch_size)]
        batch = np.stack(batch)
        # Create the language modeling observation/target pairs.
        return dict(
            input=batch[:, :-1], target=batch[:, 1:]
        )

    def ids_to_words(self, ids):
        """Convert a sequence of word IDs to words."""
        return [self.id_to_word[id] for id in ids]

    @staticmethod
    def _infinite_shuffle(iterable, buffer_size):
        """Infinitely repeat and shuffle data from iterable."""
        ds = itertools.cycle(iterable)
        buf = [next(ds) for _ in range(buffer_size)]
        random.shuffle(buf)
        while True:
            item = next(ds)
            idx = random.randint(0, buffer_size - 1)  # Inclusive.
            result, buf[idx] = buf[idx], item
            yield result


Let's now look at how our data is structured for training.

In [None]:
# sample and look at the data
batch_size = 2
seq_length = 32
train_dataset = WordBasedAsciiDatasetForLLM("input.txt", batch_size, seq_length)

batch = next(train_dataset)

for obs, target in zip(batch["input"], batch["target"]):
    print("-" * 10, "Input", "-" * 11)
    print("TEXT:", ' '.join(train_dataset.ids_to_words(obs)))
    print("IDS:", obs)
    print("-" * 10, "Target", "-" * 10)
    print("TEXT:", ' '.join(train_dataset.ids_to_words(target)))
    print("IDS:", target)

print(f"\n Total vocabulary size: {train_dataset.vocab_size}")

VOCAB_SIZE = train_dataset.vocab_size
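
The input/target pairs printed above are the same sequence shifted by one position. The sketch below shows that shift on a small made-up batch of token IDs; the values in `batch` are hypothetical stand-ins for what the dataset iterator yields.

```python
import numpy as np

# A hypothetical batch of token IDs with shape [batch_size, sequence_length + 1].
batch = np.array([[5, 8, 2, 9, 4],
                  [7, 1, 3, 6, 0]])

inputs = batch[:, :-1]   # everything except the last token
targets = batch[:, 1:]   # everything except the first token

# At position t the model reads inputs[:, :t + 1] and must predict targets[:, t].
for t in range(inputs.shape[1]):
    print(f"context={inputs[0, :t + 1].tolist()} -> target={targets[0, t]}")
```

With the causal mask in place, one forward pass therefore provides a training signal for every position in the sequence at once.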

Next, let us train our LLM and see how it performs in producing Shakespearian text. First, we will define what happens for every training step.

In [None]:
import functools

@functools.partial(jax.jit, static_argnums=(3, 4))
def train_step(params, optimizer_state, batch, apply_fn, update_fn):
  def loss_fn(params):
    T = batch['input'].shape[1]
    logits = apply_fn(params, batch['input'], jnp.tril(np.ones((T, T))))
    loss = sequence_loss_fn(logits, batch['target'])
    return loss

  loss, gradients = jax.value_and_grad(loss_fn)(params)
  updates, optimizer_state = update_fn(gradients, optimizer_state)
  params = optax.apply_updates(params, updates)
  return params, optimizer_state, loss

Next we initialise our optimizer and model. Feel free to play with the hyperparameters during the practical.

In [None]:
# all hyperparameters
d_model = 128
num_heads = 4
num_layers = 1
widening_factor = 2
LR = 2e-3
batch_size = 32
seq_length = 64

# set up the data
train_dataset = WordBasedAsciiDatasetForLLM("input.txt", batch_size, seq_length)
vocab_size = train_dataset.vocab_size
batch = next(train_dataset)

rng = jax.random.PRNGKey(42)

# initialise model
llm = LLM(num_heads=num_heads, num_layers=num_layers, d_m=d_model, vocab_size=vocab_size, widening_factor=widening_factor)
mask = jnp.tril(np.ones((batch['input'].shape[1], batch['input'].shape[1])))
params = llm.init(rng, batch['input'], mask)

# set up the optimiser
optimizer = optax.adam(LR, b1=0.9, b2=0.99)
optimizer_state = optimizer.init(params)

Now we train! This will take a few minutes. While it trains, have you greeted your neighbour yet?

In [None]:
plotlosses = PlotLosses()

MAX_STEPS = 3500
LOG_EVERY = 32
losses = []
VOCAB_SIZE = train_dataset.vocab_size  # must match the vocabulary the model was built with

# Training loop
for step in range(MAX_STEPS):
    batch = next(train_dataset)
    params, optimizer_state, loss = train_step(
        params, optimizer_state, batch, llm.apply, optimizer.update)
    losses.append(loss)
    if step % LOG_EVERY == 0:
        loss_ = jnp.array(losses).mean()
        plotlosses.update(
            {
                "loss": loss_,
            }
        )
        plotlosses.send()
        losses = []

#### 2.5.3 Inspecting the trained LLM <font color='orange'>Beginner</font>


**Reminder:** remember to run all code presented so far in this section before running the cells below!

Let's generate some text now and see how our model did. DO NOT STOP THE CELL ONCE IT IS RUNNING, THIS WILL CRASH THE SESSION.

In [None]:
import functools

@functools.partial(jax.jit, static_argnums=(2, ))
def generate_prediction(params, input, apply_fn):
  logits = apply_fn(params, input)
  argmax_out = jnp.argmax(logits, axis=-1)
  return argmax_out[0][-1].astype(int)

def generate_random_shakespeare(llm, params, id_2_word, word_2_id):
    '''
    Greedily extend the prompt one word at a time using the trained model.
    '''

    prompt = "Love"
    print(prompt, end="")
    tokens = prompt.split()

    # predict and append
    for i in range(15):
      input = jnp.array([[word_2_id[t] for t in tokens]]).astype(int)
      prediction = generate_prediction(params, input, llm.apply)
      prediction = id_2_word[int(prediction)]
      tokens.append(prediction)
      print(" "+prediction, end="")

    return " ".join(tokens)

id_2_word = train_dataset.id_to_word
word_2_id = train_dataset.word_to_id

generated_shakespeare = generate_random_shakespeare(llm, params, id_2_word, word_2_id)

Finally, note that we generated the text above by always taking the token ID with the highest predicted probability. This is called greedy decoding, as we only take the most likely token at each step. It worked well in this use case, but greedy decoding can degrade output quality, specifically when we are interested in generating realistic text.

Other methods exist for sampling from the decoder, with a famous algorithm being beam search. We provide resources below for anyone interested in learning more about this.

[Greedy Decoding](https://www.youtube.com/watch?v=DW5C3eqAFQM&list=PLmZlBIcArwhPHmHzyM_cZJQ8_v5paQJTV&index=4)

[Beam Search](https://www.youtube.com/watch?v=uG3xoYNo3HM&list=PLmZlBIcArwhPHmHzyM_cZJQ8_v5paQJTV&index=5)
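
As a rough illustration of the difference between greedy decoding and sampling, the sketch below picks a token from a made-up distribution in both ways. The helper names `greedy_token` and `sample_token`, the temperature value and the 4-token distribution are assumptions for illustration only. Beam search is not covered by this sketch.

```python
import numpy as np

def greedy_token(log_probs):
    """Pick the single most likely token."""
    return int(np.argmax(log_probs))

def sample_token(log_probs, temperature, rng):
    """Sample a token after rescaling log-probabilities by 1 / temperature."""
    scaled = log_probs / temperature
    scaled = scaled - scaled.max()          # numerical stability
    probs = np.exp(scaled)
    probs = probs / probs.sum()
    return int(rng.choice(len(probs), p=probs))

# Hypothetical log-probabilities over a 4-token vocabulary.
log_probs = np.log(np.array([0.5, 0.3, 0.15, 0.05]))
rng = np.random.default_rng(0)

print("greedy:", greedy_token(log_probs))
print("sampled:", [sample_token(log_probs, 0.8, rng) for _ in range(10)])
```

Greedy decoding always returns the same token for the same context, while sampling can return less likely tokens. Lower temperatures push sampling closer to the greedy choice.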

## 3. **Efficiently Finetuning LLMs with Hugging Face**


The availability of open source pretrained large language models (LLMs), such as [LLAMA](https://github.com/facebookresearch/llama) and [FALCON](https://falconllm.tii.ae/), has been a game-changer in the field of natural language processing. These models, often comprising billions of parameters, offer unprecedented language understanding capabilities. However, as their sizes have grown significantly, fine-tuning them for specific tasks has become more challenging than before.


For our exploration in this tutorial, we will predominantly utilize open source code from Hugging Face, primarily drawing from the transformers, datasets, and PEFT libraries. The [transformers](https://github.com/huggingface/transformers) library grants us access to pretrained LLMs, the [datasets](https://github.com/huggingface/datasets) library provides convenient access to various datasets for training, and the [PEFT (Parameter-Efficient Fine-Tuning)](https://github.com/huggingface/peft) library encompasses implementations of the training and adaptation methods we'll be discussing below.

In this section, we will explore the intricacies of custom adaptation techniques, understanding how to effectively fine-tune these large LLMs and make the most of their extraordinary potential in our research and applications.



### 3.1 Adapter and Fine-Tuning methods  <font color='green'>Intermediate</font>


The world of open source LLMs brings forth an exciting range of possibilities, but their sheer size often poses a challenge for fine-tuning using consumer-grade hardware. Consequently, conventional adaptation methods fall short in such scenarios. To address this, innovative techniques have emerged to overcome these limitations.

A significant proportion of these techniques involve either keeping the model parameters fixed, as seen in prompt engineering, where the input text acts as an agent to adapt the LLM's behavior, or altering only a tiny subset of model parameters. In this tutorial, our focus will be on the latter approach, presenting methods that modify a small portion of the model parameters, or add extra parameters to an LLM.

However, we encourage readers to explore [OpenAI's cookbook](https://github.com/openai/openai-cookbook), which hosts transferable recipes and links for prompting LLMs (as well as other models) for further insights and possibilities.

A lot of content here is inspired by [Lightning AI blogs](https://lightning.ai/pages/community/article/understanding-llama-adapters/).

**Discussion**: Before diving into methods established in the literature, let's first think about why finetuning the whole model is so costly.

#### 3.1.1 Prefix tuning

Prefix tuning works by introducing a trainable token/tensor into each transformer block along with the input tokens, as opposed to solely modifying the input tokens (prompt engineering) or finetuning the entire transformer block. The contrast between a standard transformer block and a transformer block enhanced with a prefix is depicted in the following figure. This was first introduced in the ["Prefix-Tuning: Optimizing Continuous Prompts for Generation" paper](https://arxiv.org/abs/2101.00190) by Xiang Lisa Li and Percy Liang.

By only training the "Trainable tokens" and the newly introduced MLP layer, we are able to adapt a model to our domain by training close to 0.1% of the parameters of a full model and achieve performance comparable to fine-tuning the entire model.

<img src="https://drive.google.com/uc?export=view&id=1fSnk9MkoPN6KbmbP71iU9EViU9avvOHb" alt="drawing" width="230"/>

Below we show pseudo code for this method, as well as a normal block to showcase the differences. Note running the code will *not* work.

In [None]:
def normal_transformer_block(tokens):
  """
  Example of pseudo code for a normal transformer.
  """
  original_tokens = tokens
  x = MHA(tokens)
  x = LayerNorm(x + original_tokens)
  original_tokens = x
  x = FF(x)
  transformed_tokens = LayerNorm(x + original_tokens)
  return transformed_tokens

def transformer_block_with_prefix(tokens, trainable_tokens):
  """
  Example of pseudo code of transformer block with prefix tuning.
  """
  prefix = FF(trainable_tokens)  # Trainable FF and tokens.
  tokens = concat([prefix, tokens])
  original_tokens = tokens
  x = MHA(tokens)
  x = LayerNorm(x + original_tokens)
  original_tokens = x
  x = FF(x)
  transformed_tokens = LayerNorm(x+original_tokens)
  return transformed_tokens

#### 3.1.2 Adapter Methods

Adapter methods are closely related to prefix tuning. Introduced in the ["Parameter-Efficient Transfer Learning for NLP" paper](https://arxiv.org/abs/1902) by Houlsby et al., they consist of adding a new block of weights, called an "Adapter", between the transformer blocks.

<img src="https://drive.google.com/uc?export=view&id=1t521Q3_yAuUDsoakJmv7cgQyF5-VvjgX" alt="drawing" width="450"/>

During adapter tuning, the green layers are trained on the downstream data. This includes the adapter, the layer normalization parameters, and the final classification layer (not shown in the figure).

It has been shown to achieve performance similar to updating the entire network while only training 3.6% of the total model parameters.

Below is pseudo code again, highlighting where and how this works. Note that running the code will not work.

**Code exercise**: In a similar style, write a pseudo-code implementation of the `Adapter` block shown in the diagram above.

In [None]:
def transformer_block_with_adapters(tokens):
  """
  Example of pseudo code of transformer block with adapter layers.
  """

  # finish me

In [None]:
# @title Answer to code task (Try not to peek until you've given it a good try!)

def transformer_block_with_adapters(tokens):
  """
  Example of pseudo code of transformer block with adapter layers.
  """

  original_tokens = tokens
  adapted_tokens = AdapterLayer(tokens) # trainable
  x = MHA(adapted_tokens)
  x = LayerNorm(x + original_tokens)
  original_tokens = x
  x = AdapterLayer(x) # trainable
  x = FF(x)
  transformed_tokens = LayerNorm(x+original_tokens)
  return transformed_tokens

To see how ideas from adapters and prefix tuning can be combined, refer to the [LLAMA-Adapter paper](https://arxiv.org/abs/2303.16199).
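
The `AdapterLayer` in the pseudo code above is left abstract. A common way to build one is a small bottleneck network with a skip connection, and the NumPy sketch below shows that idea in runnable form. The sizes, the ReLU non-linearity and the zero initialization of the up-projection are assumptions chosen for simplicity, not details taken from this practical.

```python
import numpy as np

def adapter_layer(x, w_down, w_up):
    """Bottleneck adapter: project down, apply a non-linearity, project up, add a skip connection."""
    hidden = np.maximum(x @ w_down, 0.0)   # ReLU chosen for simplicity (assumption)
    return x + hidden @ w_up               # residual keeps the original features

d_model, bottleneck = 16, 4                # hypothetical sizes
rng = np.random.default_rng(0)
w_down = rng.normal(scale=0.02, size=(d_model, bottleneck))
w_up = np.zeros((bottleneck, d_model))     # zero init, so the adapter starts as the identity

tokens = rng.normal(size=(2, 5, d_model))  # [batch, sequence, d_model]
out = adapter_layer(tokens, w_down, w_up)

print(out.shape)
print("trainable adapter parameters:", w_down.size + w_up.size)
```

Because the bottleneck is much narrower than `d_model`, the adapter adds only a small number of trainable parameters, and the skip connection means the frozen model's behaviour is unchanged at the start of training.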

### 3.2 LoRA <font color='orange'>Beginner</font>, <font color='green'>Intermediate</font>, <font color='blue'>Advanced</font>

> This section is a summarized copy of the [LoRA paper](https://arxiv.org/abs/2106.09685), read the paper for more details! We also recommend that beginners focus more on finetuning and interacting with the trained model than on the details presented here.

Finally, let's talk about one of the most widely used methods for efficient finetuning, LoRA, introduced in the paper ["LoRA: Low-Rank Adaptation of Large Language Models"](https://arxiv.org/abs/2106.09685) by Edward J. Hu et al.

The idea behind Parameter-Efficient Fine-Tuning (PEFT) is to finetune these large models **efficiently** (in terms of both memory and speed) while significantly improving the quality of the outputs produced, or making the model behave in a different manner.

It is also important to be efficient not only during finetuning but also during **inference** (e.g. for text generation models we're usually most interested in sampling time).

LoRA is a method that's not only efficient during finetuning but also **can be applied during inference without any additional cost when compared to the original model!**


**Task**: To build intuition for what is to come, prove that $y_1$ is equal, or not equal, to $y_2$. Note that $X$, $W$, $W_1$ and $W_2$ are matrices.

> $W = W_1 + W_2$
>
> $y_1 = WX$
>
> $y_2 = W_1X + W_2X$

In [None]:
# @title Answer to math task (Try not to run until you've given it a good try!)
%%latex
\begin{aligned}
y_1 &= WX \\
y_1 &= (W_1+W_2)X \\
y_1 &= W_1X+W_2X \\
y_1 &= y_2
\end{aligned}
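
The same identity can be checked numerically. The sketch below uses small random matrices with arbitrarily chosen shapes; it is a sanity check, not a replacement for the proof.

```python
import numpy as np

rng = np.random.default_rng(0)
W1 = rng.normal(size=(3, 4))
W2 = rng.normal(size=(3, 4))
X = rng.normal(size=(4, 2))

y1 = (W1 + W2) @ X        # apply the summed weights once
y2 = W1 @ X + W2 @ X      # apply each weight matrix separately, then add

print(np.allclose(y1, y2))
```

This distributive property is what allows a separately stored weight update to be folded back into the original weights.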

**LoRA details**


In the LoRA paper, the authors implement an innovative strategy when using the GPT-3 model, which has 175 billion parameters. Instead of fine-tuning all those parameters, they opt to freeze the pre-trained weights.

Instead, they decompose the weight update matrix ($\Delta W$), which represents the changes applied to the original model weights ($W_0$), into two smaller matrices, referred to as $A$ and $B$, such that $BA=\Delta W$.

For a weight matrix $W_0$ of shape $(d, k)$, the shapes of $B$ and $A$ are chosen as $(d,r)$ and $(r,k)$ respectively, where the rank $r$ is much smaller than $d$ and $k$. The choice of $r$ plays a crucial role in determining the degree of reduction in the number of parameters.

This translates into a significant reduction in computational requirements, as illustrated in the figure below when comparing the blue and orange areas.

Thus the output of a finetuned layer, $y$, is reformulated as follows, where $x$ is the input data:

$y = W_0x + \Delta Wx = W_0x + BAx$

During the fine-tuning process, we then only adapt the weights of $A$ and $B$, which are substantially smaller compared to the initial model weights $W_0$. $A$ and $B$ are called: *trainable rank decomposition matrices* and can be added in between each/any layer of the Transformer architecture.


<img src="https://miro.medium.com/v2/resize:fit:1046/format:webp/1*F7uWJePoMc6Qc1O2WxmQqQ.png">

Translated into simpler terms, the authors essentially create an "encoder" matrix $A$ that converts the original input data into a highly condensed hidden vector. They then use a "decoder" matrix $B$ to reconstruct this hidden vector back to its original dimensionality. The final model output is then generated by adding together the original model output and the output created through this projection process. This way, the initially enormous model can be efficiently fine-tuned for specific tasks, saving computational resources and time, by training only these layer-specific encoder-decoders.


**Question**: How many trainable parameters are introduced by these matrices?

**Question:** How could LoRA not add any extra costs during inference?

**Why this works?**

LoRA takes inspiration from [Li et al. (2018)](https://arxiv.org/abs/1804.08838) and [Aghajanyan et al. (2020)](https://arxiv.org/abs/2012.13255), both of which demonstrate that learned over-parametrized models actually reside on a low intrinsic dimension. In other words, it should be possible to extract the main content of a large attention weight with dimension `D` into a vector with a much smaller dimension `d`, where `d` << `D`.

The authors of LoRA then made the following assumption: changes in weights during model adaptation/updates also exhibit this property of residing on a low intrinsic dimension. In theory, then, it should be possible to learn this smaller-dimensional update instead of directly updating the large matrix of dimension `D`.

**Does this come for free?**

Even though this method substantially reduces cost and increases speed during training, it does come with some additional cost during inference if one keeps the learned matrices and the original model separate. However, it is possible to merge the learned LoRA weights and the original model back into a single model, which gives zero extra cost.

Why would one consider keeping them separate? If one has many different tasks to finetune towards, keeping one large base model together with a small "LoRA model" per task is much easier to maintain and run in production than many large models.
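
The trade-off between keeping the LoRA matrices separate and merging them can be sketched in a few lines of NumPy. The layer sizes and rank below are hypothetical. The sketch assumes `B` has shape `(d, r)` and `A` has shape `(r, k)`, so that `B @ A` matches the shape of `W0`, and it leaves out the `lora_alpha / lora_rank` scaling for brevity.

```python
import numpy as np

d, k, r = 64, 64, 4                        # hypothetical layer sizes and LoRA rank
rng = np.random.default_rng(0)

W0 = rng.normal(size=(d, k))               # frozen pretrained weights
B = rng.normal(size=(d, r))                # trained LoRA factors (random stand-ins here)
A = rng.normal(size=(r, k))
x = rng.normal(size=(k,))

# Separate: two extra small matrix multiplications at inference time.
y_separate = W0 @ x + B @ (A @ x)

# Merged: fold the update into the weights once, then run a single matmul.
W_merged = W0 + B @ A
y_merged = W_merged @ x

print(np.allclose(y_separate, y_merged))
print("full fine-tuning parameters:", W0.size)
print("LoRA parameters:", A.size + B.size)
```

The merged layer has exactly the same shape as the original one, so inference costs the same as for the base model, while the number of trained values grows as `r * (d + k)` rather than `d * k`.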

#### 3.2.1 LoRA implementation <font color='blue'>Advanced</font>

Now let's implement a simple LoRA module together!

Before doing so, there are a few parameters / configs associated with LoRA:

1. **target_modules**: Which layers or transformer matrices (Q, K, V) LoRA should be applied to.  
2. **lora_rank**: The rank to use for the decomposition.  
3. **lora_alpha**: This is used for scaling. This scaling helps to reduce the need to retune hyperparameters when we vary `lora_rank`. When optimizing with Adam, tuning `lora_alpha` is roughly the same as tuning the learning rate if we scale the initialization appropriately.  
4. **initialization**: We use a random Gaussian initialization for `A` and zero for `B`.  
5. **dropout**: as usually applied in Deep Learning models.

**Group task**: Can you think of why the initialization proposed in 4. is used?

**Code task**: Finish the implementation below. Hint: If you are unfamiliar with einsum, refer to the following [documentation](https://numpy.org/doc/stable/reference/generated/numpy.einsum.html)
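
For readers new to einsum, the sketch below checks the two einsum patterns that already appear in the skeleton against ordinary matrix products. The array shapes are arbitrary illustration values, and the sketch does not fill in any of the `FINISH ME` lines.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=(2, 5, 8))     # [batch, sequence, input_dims]
a = rng.normal(size=(8, 3))        # [input_dims, rank]
b = rng.normal(size=(6, 3))        # [output_dims, rank]

# '...i,ij->...j' contracts the last axis of x with the first axis of a.
low_rank = np.einsum('...i,ij->...j', x, a)
assert np.allclose(low_rank, x @ a)

# 'ij,kj->ik' contracts the shared rank axis j, which is the same as a @ b.T.
product = np.einsum('ij,kj->ik', a, b)
assert np.allclose(product, a @ b.T)

print(low_rank.shape, product.shape)
```

Repeated index letters are summed over, and the `...` stands for any number of leading batch axes that are carried through unchanged.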


In [None]:
class Lora(nn.Module):
  # Depend on Module we're applying LoRA to.
  input_dims: int
  output_dims: int

  lora_rank: int
  lora_alpha: float
  lodra_dropout: float

  a_init: nn.initializers.Initializer = nn.initializers.normal()
  b_init: nn.initializers.Initializer =  nn.initializers.zeros_init()

  def setup(self):
    self.a_weights = # FINISH ME
    self.b_weights = # FINISH ME

  def __call__(self, input_array: chex.Array, attn_h: chex.Array, training: bool):
    """Implements LoRA technique.

      Args:
        input_array: Shaped[..., input_dims]
        attn_h: Shaped[..., output_dims]
      Returns:
        output_array: Shaped[..., output_dims]
    """
    low_rank = jnp.einsum('...i,ij->...j', input_array, self.a_weights)
    output = # FINISH ME
    return output + attn_h

  def weights_for_inference(self, weights: chex.Array):
    """Return original weights + `LoRA` weights for no added costs during inference."""
    return jnp.einsum('ij,kj->ik', self.a_weights, self.b_weights) + # FINISH ME

In [None]:
# @title Answer to code task (Try not to peek until you've given it a good try!)
class Lora(nn.Module):
  # Depend on Module we're applying LoRA to.
  input_dims: int
  output_dims: int

  lora_rank: int
  lora_alpha: float
  lodra_dropout: float

  a_init: nn.initializers.Initializer = nn.initializers.normal()
  b_init: nn.initializers.Initializer = nn.initializers.zeros_init()

  def setup(self):
    self.a_weights = self.param('a_weights', self.a_init, (self.input_dims, self.lora_rank,))
    self.b_weights = self.param('b_weights', self.b_init, (self.output_dims, self.lora_rank,))

  def __call__(self, input_array: chex.Array, attn_h: chex.Array, training: bool):
    """
      Args:
        input_array: Shaped[..., input_dims]
        attn_h: Shaped[..., output_dims]
      Returns:
        output_array: Shaped[..., output_dims]
    """
    low_rank = jnp.einsum('...i,ij->...j', input_array, self.a_weights)
    output = jnp.einsum('...j,...kj->...k', low_rank, self.b_weights)


    scaling = self.lora_alpha / self.lora_rank
    output = output * scaling
    return output + attn_h

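  def weights_for_inference(self, weights: chex.Array):
    # Apply the same scaling as in `__call__` so the merged weights match training.
    scaling = self.lora_alpha / self.lora_rank
    return weights + scaling * jnp.einsum('ij,kj->ik', self.a_weights, self.b_weights)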

#### 3.2.3 🤗 Deep dive into LoRA with Hugging Face! 🤗 <font color='orange'>Beginner</font>

We now have a working LoRA implementation, but some additional work would be needed to make it match the method described in the paper exactly. Fortunately, the open-source community moves quickly: LoRA and most of its variants are now fully implemented behind a user-friendly interface, so we do not need to build them from scratch.

We'll now finetune `gpt2-medium` to generate song lyrics from the artist of your choice! We will do this by:

* Loading a pretrained model using Hugging Face transformers
* Gathering the dataset using datasets
* Fine tune using LoRA

Code in this section is based on: https://github.com/22-hours/cabrita/blob/main/notebooks/train_lora.ipynb from [piEsposito](https://github.com/piEsposito) and
[pedrogengo](https://github.com/pentrogengo).

##### Gathering and processing data (optional)

To train a model, we first need data. We will load a dataset of song lyrics from Hugging Face.

In [None]:
dataset_author = "huggingartists"
artist_name = "the-beatles"

# List all available datasets.
all_datasets = {}
for ds in huggingface_hub.list_datasets(author=dataset_author):
  music_artist = ds.id.replace(f'{dataset_author}/', '')
  all_datasets[music_artist] = ds.id

dataset_name = all_datasets[artist_name]

print(f'Choose an artist available in {dataset_author}/ (careful! Some lyrics might contain offensive language)')
Dropdown_ = widgets.Dropdown(
    options=all_datasets.keys(),
    value=artist_name,
)
output = widgets.Output()


def on_change(change):
  global dataset_name
  global artist_name
  artist_name = change["new"]
  dataset_name = all_datasets[artist_name]
  print(f'`dataset_name` is now {dataset_name}.')

Dropdown_.observe(on_change, names='value')
display(Dropdown_)

In [None]:
formatted_artist_name = artist_name.replace('-', ' ').title()
prompt = f'This is a song by {formatted_artist_name}. It goes like this:\n\n' # @param
print(prompt)

In [None]:
# Load the training split of the chosen dataset.
train_dataset = datasets.load_dataset(dataset_name, split='train')
# Let's use up to 500 examples so it doesn't take too long to finetune.
train_dataset = train_dataset.filter(lambda example, idx: idx < 500, with_indices=True)
train_dataset

To train the model we usually set a fixed `sequence_length` for all inputs. A way to achieve this is by:
* Padding inputs shorter than `sequence_length`.
* Truncating inputs longer than `sequence_length`.

**Think about this**: How do we know which `sequence_length` to set?

In [None]:
max_num_tokens = 256 # @param

num_chars, num_tokens = [], []

for i, text in enumerate(train_dataset['text']):
  text = prompt + text
  num_tokens.append(len(tokenizer.tokenize(text)))
  num_chars.append(len(text))

num_chars = np.array(num_chars)
num_tokens = np.array(num_tokens)
print('Median #chars', np.median(num_chars))
print('Max #chars', np.max(num_chars))
print('Median #tokens', np.median(num_tokens))
print('Max #tokens', np.max(num_tokens))
num_truncations = np.sum(num_tokens > max_num_tokens)
num_truncated_tokens = num_tokens - max_num_tokens
median_num_truncated_tokens = np.median(
    np.where(num_truncated_tokens > 0, num_truncated_tokens, 0),
)
print(f'Number of examples that will be truncated: {num_truncations} ({num_truncations/len(num_tokens) * 100:.2f} %)')
print('Median #truncated tokens', median_num_truncated_tokens)
plt.boxplot(num_truncated_tokens)
plt.title('#truncated tokens')
plt.show()

**Discussion:** By only truncating the data we're losing A LOT of relevant text. What could we do to avoid this?
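
One option worth raising in this discussion is to split long lyrics into several windows instead of discarding everything past the cutoff. The sketch below is a plain-Python illustration of that idea. The helper `chunk_token_ids` and its `stride` argument are hypothetical and are not used anywhere else in this practical.

```python
def chunk_token_ids(token_ids, max_len, stride=None):
    """Split a long list of token IDs into windows of at most `max_len` tokens.

    `stride` is the step between window starts; a stride smaller than
    `max_len` makes consecutive windows overlap.
    """
    stride = stride or max_len
    chunks = []
    for start in range(0, len(token_ids), stride):
        chunks.append(token_ids[start:start + max_len])
        if start + max_len >= len(token_ids):
            break   # the current window already reaches the end
    return chunks

token_ids = list(range(10))                  # stand-in for real token IDs
non_overlapping = chunk_token_ids(token_ids, max_len=4)
overlapping = chunk_token_ids(token_ids, max_len=4, stride=2)
print(non_overlapping)
print(overlapping)
```

Every token now ends up in at least one training example, at the cost of more examples per song and, for the later windows, a missing prompt unless it is added back to each chunk.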

Now we need to preprocess the data for training, and create a function to tokenize our data for the model input.

In [None]:
# Drop empty lyrics.
train_dataset = train_dataset.filter(lambda x: len(x["text"]) > 0)
print('Here is a data example before tokenization')
print_sample(prompt, train_dataset[0]["text"])
# Add prompt
train_dataset = train_dataset.map(lambda x: {"text": prompt + x["text"]})

def tokenize(prompt):
  result = tokenizer(
      prompt,
      truncation=True,
      max_length=max_num_tokens,
      padding="max_length",
  )
  return {
      "input_ids": result["input_ids"],
      "attention_mask": result["attention_mask"],
  }

train_dataset = train_dataset.shuffle().map(lambda x: tokenize(x["text"]))

##### Finetune a model with LoRA

Below we fine tune our model with LoRA. The first thing to do now is to set our hyperparameters for training.

In [None]:
# @title Hyper-parameters
MICRO_BATCH_SIZE = 4 # @param
BATCH_SIZE = 32 # @param
GRADIENT_ACCUMULATION_STEPS = BATCH_SIZE // MICRO_BATCH_SIZE
EPOCHS = 40 # @param
LEARNING_RATE = 3e-4 # @param
CUTOFF_LEN = max_num_tokens
LORA_R = 12 # @param
LORA_ALPHA = 12 # @param
LORA_DROPOUT = 0.2 # @param
WARMUP_STEPS = 20 # @param
QUERY_USED_DURING_TRAINING = "Dreams of llamas" # @param
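
`GRADIENT_ACCUMULATION_STEPS` lets us reach an effective batch of `BATCH_SIZE` examples while only holding `MICRO_BATCH_SIZE` examples in memory at a time. The NumPy sketch below illustrates why this works, using a toy linear model with a mean squared error loss. The model, data and the helper `grad_mse` are assumptions for illustration and say nothing about how the Hugging Face trainer implements accumulation internally.

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(32, 3))            # one full batch of 32 examples
y = rng.normal(size=(32,))
w = np.zeros(3)

def grad_mse(w, X, y):
    """Gradient of the mean squared error of a linear model."""
    return 2.0 * X.T @ (X @ w - y) / len(y)

full_grad = grad_mse(w, X, y)

micro_batch_size = 4
accumulation_steps = len(y) // micro_batch_size   # 32 // 4 = 8
accumulated = np.zeros_like(w)
for i in range(accumulation_steps):
    rows = slice(i * micro_batch_size, (i + 1) * micro_batch_size)
    # Average the micro-batch gradients so the total matches the full batch.
    accumulated += grad_mse(w, X[rows], y[rows]) / accumulation_steps

print(np.allclose(full_grad, accumulated))
```

For a loss that is a mean over examples and micro-batches of equal size, the averaged micro-batch gradients equal the full-batch gradient, so the optimizer step is the same and only peak memory changes.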

Now we load our model using the PEFT library. Remember, the PEFT library is where all these efficient training methods are implemented.

**Discussion:** Do you know what each of the hyper-parameters above does? Do you understand the intuition behind each of them? Try thinking about this and discussing it with your colleagues and tutors!

In [None]:
peft_config = peft.LoraConfig(
    task_type=peft.TaskType.CAUSAL_LM,
    inference_mode=False,
    r=LORA_R,
    lora_alpha=LORA_ALPHA,
    lora_dropout=LORA_DROPOUT,
    # Default is to apply lora to Q and V projection matrices.
    # These are based on results of Table 5 of the paper.
    target_modules=None,
)
peft_model = peft.get_peft_model(copy.deepcopy(model), peft_config)
peft_model.print_trainable_parameters()

With all of the groundwork finally laid out, we can train our model! We are using a bit of hacking to better showcase what is happening. It is not too important to understand everything below.

In [None]:
class PlotLossCalback(transformers.TrainerCallback):
  def on_epoch_end(self, args, state, control, model=None, tokenizer=None, logs=None, **kwargs):
    # Collect the training loss and learning rate recorded at every logging step.
    losses, learning_rates, steps = [], [], []
    for curr_state in state.log_history:
      if 'loss' not in curr_state:  # Evaluation from `HackyTrainerThatRunsSampleInTheLoop`.
        continue
      losses.append(curr_state['loss'])
      learning_rates.append(curr_state['learning_rate'])
      steps.append(curr_state['step'])

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 3))
    ax1.plot(steps, losses, '-ob')
    ax1.set_title('Steps vs Loss')

    ax2.plot(steps, learning_rates, '-or')
    ax2.set_title('Step vs Learning Rate')
    plt.show()


class HackyTrainerThatRunsSampleInTheLoop(transformers.Trainer):
  def prediction_step(
      self,
      model,
      inputs,
      prediction_loss_only: bool,
      ignore_keys: list[str] | None = None,
      # Return: loss, logits, labels
  ) -> tuple[torch.Tensor | None, torch.Tensor | None, torch.Tensor | None]:
    del inputs, prediction_loss_only, ignore_keys  # unused.
    _ = run_sample(
        model,
        tokenizer,
        prompt=prompt + QUERY_USED_DURING_TRAINING,
        seed=1,
        temperature=0.8,
        top_p=0.9,
    )
    return (None, None, None)

training_arguments = transformers.TrainingArguments(
    per_device_train_batch_size=MICRO_BATCH_SIZE,
    gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
    num_train_epochs=EPOCHS,
    learning_rate=LEARNING_RATE,
    fp16=True,
    logging_steps=GRADIENT_ACCUMULATION_STEPS,
    warmup_steps=WARMUP_STEPS,
    # lr_scheduler_type="cosine",
    output_dir="tmp",
    save_strategy="no",
    evaluation_strategy="epoch",
    logging_strategy="epoch",
)
trainer = HackyTrainerThatRunsSampleInTheLoop(
    model=peft_model,
    train_dataset=train_dataset,
    # Unused. But needed to run hacky inference.
    eval_dataset=[{'input_ids': [], 'attention_mask': []}],
    args=training_arguments,
    callbacks=[PlotLossCalback],
    data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False),
    tokenizer=tokenizer,
)
peft_model.config.use_cache = False
trainer.train(resume_from_checkpoint=False)
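The learning-rate plot produced during training should rise and then fall. Since the cosine scheduler is commented out, we assume the trainer falls back to a linear schedule: a linear warmup from zero up to the peak learning rate, followed by a linear decay back to zero. The function below is our own minimal sketch of that shape, not the library's code, and the step counts are illustrative.

```python
def linear_warmup_then_decay(step, peak_lr, warmup_steps, total_steps):
    """Learning rate at `step` for linear warmup followed by linear decay to 0."""
    if step < warmup_steps:
        # Warmup: grow linearly from 0 to peak_lr.
        return peak_lr * step / max(1, warmup_steps)
    # Decay: shrink linearly from peak_lr to 0 at total_steps.
    remaining = max(0, total_steps - step)
    return peak_lr * remaining / max(1, total_steps - warmup_steps)


# Illustrative values only; the real total depends on dataset size and EPOCHS.
peak_lr, warmup_steps, total_steps = 3e-4, 20, 200
for step in (0, 10, 20, 110, 200):
    lr = linear_warmup_then_decay(step, peak_lr, warmup_steps, total_steps)
    print(f"step {step:3d}: lr = {lr:.2e}")
```

Warmup avoids large, noisy updates while the optimizer statistics are still being estimated. The decay helps the model settle towards the end of training.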

### ⏰⚡ Demo Time with our trained model🚀😰

In [None]:
seed = 2
query = "You"
final_prompt = prompt + query
temperature = 1.0
top_p = 0.9
print('LoRA model')
_ = run_sample(
    peft_model,
    tokenizer,
    prompt=final_prompt,
    seed=seed,
    temperature=temperature,
    top_p=top_p,
)

print('Original model')
_ = run_sample(
    model,
    tokenizer,
    prompt=final_prompt,
    seed=seed,
    temperature=temperature,
    top_p=top_p,
)
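The `temperature` and `top_p` arguments control how adventurous the generated text is. Assuming `run_sample` passes them on to standard temperature and nucleus (top-p) sampling, the pure-Python sketch below shows the idea on a toy vocabulary of four tokens. The helper names and logits are made up, and library implementations differ in their details.

```python
import math
import random


def top_p_filter(logits, temperature=1.0, top_p=0.9):
    """Return {token_index: probability} for the tokens kept by top-p sampling."""
    # Temperature rescales the logits: < 1 sharpens, > 1 flattens the distribution.
    scaled = [logit / temperature for logit in logits]
    # Numerically stable softmax.
    largest = max(scaled)
    exps = [math.exp(s - largest) for s in scaled]
    total = sum(exps)
    probs = [e / total for e in exps]
    # Keep the most likely tokens until their cumulative probability reaches top_p.
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    kept, cumulative = [], 0.0
    for i in order:
        kept.append(i)
        cumulative += probs[i]
        if cumulative >= top_p:
            break
    # Renormalise so the kept probabilities sum to 1.
    mass = sum(probs[i] for i in kept)
    return {i: probs[i] / mass for i in kept}


def sample_token(logits, temperature=1.0, top_p=0.9, seed=None):
    """Draw one token index from the filtered distribution."""
    rng = random.Random(seed)
    nucleus = top_p_filter(logits, temperature, top_p)
    tokens = list(nucleus)
    return rng.choices(tokens, weights=[nucleus[t] for t in tokens], k=1)[0]


toy_logits = [2.0, 1.0, 0.0, -1.0]  # hypothetical scores for a 4-token vocabulary
print(top_p_filter(toy_logits, temperature=1.0, top_p=0.9))
print(top_p_filter(toy_logits, temperature=0.1, top_p=0.9))
print(sample_token(toy_logits, temperature=1.0, top_p=0.9, seed=2))
```

With a low temperature almost all the probability mass sits on the top token, so the nucleus shrinks and the output becomes close to deterministic.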

That is pretty awesome, isn't it?

As a challenge, play with the hyperparameters above, as well as the choice of model, and see if you can beat your friends at building the best lyric generator with your own custom LLM trained using LoRA.

## **Conclusion**
**Summary:**

You have now learned the basics of how an LLM works, all the way from the fundamentals to finetuning a GPT architecture with LoRA. These are powerful tools that apply to many tasks, but just like any other deep learning model, they are just models and should be used for the correct problem and data.

**Next Steps:**

Follow the links provided in this practical, and read up on the Llama 2 and Falcon architectures to see how the latest techniques are utilised.


**References:** For further references, check the links given throughout the specific sections of this colab.

* Attention is all you need paper: https://arxiv.org/abs/1706.03762
* LoRA paper: https://arxiv.org/abs/2106.09685
* RLHF (how ChatGPT was trained): https://huggingface.co/blog/rlhf
* Extending context length: https://kaiokendev.github.io/context


For other practicals from the Deep Learning Indaba, please visit [here](https://github.com/deep-learning-indaba/indaba-pracs-2023).

# Feedback

Please provide feedback that we can use to improve our practicals in the future.

In [None]:
# @title Generate Feedback Form. (Run Cell)
from IPython.display import HTML

HTML(
    """
<iframe
  src="https://forms.gle/Cg9aoa7czoZCYqxF7"
  width="80%"
  height="1200px" >
	Loading...
</iframe>
"""
)

<img src="https://baobab.deeplearningindaba.com/static/media/indaba-logo-dark.d5a6196d.png" width="50%" />