##### Copyright 2024 Google LLC.

In [None]:
# @title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

## Fine-Tuning Gemma for Retrieval-Augmented Generation with JORA

Scaling Large Language Models (LLMs) for retrieval-based tasks, particularly Retrieval-Augmented Generation (RAG), poses significant memory challenges, especially when fine-tuning on long prompt sequences.

[Gemma](https://ai.google.dev/gemma) is a family of lightweight, state-of-the-art open models from Google, built from the same research and technology used to create the Gemini models. They are text-to-text, decoder-only large language models, available in English, with open weights, pre-trained variants, and instruction-tuned variants. Gemma models are well-suited for a variety of text generation tasks, including question answering, summarization, and reasoning. Their relatively small size makes it possible to deploy them in environments with limited resources such as a laptop, desktop or your own cloud infrastructure, democratizing access to state of the art AI models and helping foster innovation for everyone.

Existing open-source libraries support full-model inference and fine-tuning across multiple GPUs but often fall short in efficiently distributing parameters required for retrieved context. To address this limitation, [JORA](https://github.com/aniquetahir/JORA) introduced a novel framework for Parameter-Efficient Fine-Tuning (PEFT) of Llama/Gemma models using distributed training, leveraging [JAX](https://jax.readthedocs.io/en/latest/). This framework uniquely utilizes JAX's just-in-time (JIT) compilation and tensor-sharding for efficient resource management, enabling accelerated fine-tuning with reduced memory requirements. This advancement significantly improves the scalability and feasibility of fine-tuning LLMs for complex RAG applications, even on systems with limited GPU resources.

The JORA authors' experiments show a more than **12x improvement in runtime** compared to [Hugging Face](https://huggingface.co/docs/transformers/en/main_classes/trainer)/[DeepSpeed](https://github.com/microsoft/DeepSpeed) implementations on four GPUs, while consuming less than half the VRAM per GPU.

In this tutorial, you will walk through the end-to-end process of fine-tuning a [Gemma](https://github.com/google/gemma) model using JORA and converting the trained model back to the [Hugging Face](https://huggingface.co/) format for inference.

<table align="left">
  <td>
    <a target="_blank" href="https://colab.research.google.com/github/google-gemini/gemma-cookbook/blob/main/Gemma/Finetune_with_JORA.ipynb"><img src="https://www.tensorflow.org/images/colab_logo_32px.png" />Run in Google Colab</a>
  </td>
</table>

## Setup

Before you begin, ensure you have access to a Colab notebook with a GPU runtime enabled.

### Select the Colab runtime
To complete this tutorial, you'll need a [**Colab Pro/Pro+**](https://colab.research.google.com/signup) runtime with sufficient resources to run the Gemma model. In this case, you can use an **A100** GPU:

1. In the upper-right of the Colab window, select **▾ (Additional connection options)**.
2. Select **Change runtime type**.
3. Under **Hardware accelerator**, select **A100 GPU**.

### **Kaggle** Gemma setup

To download the Gemma Flax checkpoints from Kaggle and fine-tune them, first complete the setup instructions at [Gemma setup](https://ai.google.dev/gemma/docs/setup). They show you how to do the following:

* Get access to Gemma on kaggle.com.
* Select a Colab runtime with sufficient resources to run
  the Gemma model.
* Generate a Kaggle username and an API key (you'll configure them as Colab secrets later in this guide).

### **Hugging Face** Gemma setup

You'll also be logging in to Hugging Face Hub to download the exact Gemma model used while fine-tuning. So, let's get you set up with Gemma:

1. **Hugging Face Account:**  If you don't already have one, you can create a free Hugging Face account by clicking [here](https://huggingface.co/join).
2. **Gemma Model Access:** Head over to the [Gemma model page](https://huggingface.co/collections/google/gemma-release-65d5efbccdbb8c4202ec078b) and accept the usage conditions.
3. **Colab with Gemma Power:**  For this tutorial, you'll need a Colab runtime with enough resources to handle the Gemma model. Choose an appropriate runtime when starting your Colab session.
4. **Hugging Face Token:**  Generate a Hugging Face access token (preferably with `write` permission) by clicking [here](https://huggingface.co/settings/tokens). This token will come in handy later.

After you've completed the Gemma setup, move on to the next section, where you'll set environment variables for your Colab environment.


### Configure your credentials

Add your **Hugging Face** (`HF_TOKEN`) and **Kaggle** tokens (`KAGGLE_USERNAME` and `KAGGLE_KEY`) to the Colab Secrets manager to securely store them.

1. Open your Google Colab notebook and click on the 🔑 Secrets tab in the left panel. <img src="https://storage.googleapis.com/generativeai-downloads/images/secrets.jpg" alt="The Secrets tab is found on the left panel." width=50%>
2. Create a new secret with the name `HF_TOKEN`.
3. Copy/paste your token key into the Value input box of `HF_TOKEN`.
4. Toggle the button on the left to allow notebook access to the secret.
5. Repeat these steps for the Kaggle secrets, named `KAGGLE_USERNAME` and `KAGGLE_KEY`.
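
Outside Colab there is no Secrets manager, so the same three names would typically be supplied as environment variables. The helper below is a hypothetical convenience (it is not part of Colab, Kaggle, or Hugging Face tooling) that reports which of them are still missing before any download starts:

```python
import os

# Names used by this notebook for the Hugging Face and Kaggle credentials.
REQUIRED_VARS = ("HF_TOKEN", "KAGGLE_USERNAME", "KAGGLE_KEY")


def missing_credentials(env=None):
    """Return the required variable names that are unset or empty.

    Hypothetical helper for illustration; `env` defaults to os.environ.
    """
    env = os.environ if env is None else env
    return [name for name in REQUIRED_VARS if not env.get(name)]


# Example with an explicit mapping instead of the real environment.
example_env = {"HF_TOKEN": "hf_example", "KAGGLE_USERNAME": "alice"}
print(missing_credentials(example_env))  # ['KAGGLE_KEY']
```

Calling `missing_credentials()` with no argument checks the real process environment.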


In [None]:
import os
from google.colab import userdata

# Note: `userdata.get` is a Colab API. If you're not using Colab, set the env
# vars as appropriate for your system.
os.environ["HF_TOKEN"] = userdata.get("HF_TOKEN")
os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')

# Disable progress bar to prevent verbose logging by kagglehub
os.environ["TQDM_DISABLE"] = "1"

### Clone **JORA** and install dependencies

In [None]:
# Clone the JORA repository and install the requirements
!git clone https://github.com/aniquetahir/JORA.git
%cd JORA
!pip install -q -e .

# Install google-deepmind/gemma as it's a required dependency for JORA
!pip install -q git+https://github.com/google-deepmind/gemma.git

Cloning into 'JORA'...
remote: Enumerating objects: 299, done.
remote: Counting objects: 100% (299/299), done.
remote: Compressing objects: 100% (216/216), done.
remote: Total 299 (delta 151), reused 203 (delta 71), pack-reused 0 (from 0)
Receiving objects: 100% (299/299), 6.99 MiB | 7.89 MiB/s, done.
Resolving deltas: 100% (151/151), done.
/content/JORA

### Import the dependencies

In [None]:
# Patch JORA's initialisation.py file to be compatible with the latest JAX version

!sed -i "s/jax\.config\.update('jax_default_matmul_precision', *jax\.lax\.Precision\.HIGHEST)/jax.config.update('jax_default_matmul_precision', 'bfloat16')/" jora/lib/proc_init_utils/initialisation.py
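
To make the effect of the `sed` command easier to see, here is a rough Python equivalent applied to a sample line rather than to the real file. The function name `patch_source` is made up for this sketch, and unlike `sed` without the `g` flag it would replace every match on a line, which makes no difference for a single call like this one:

```python
import re

# Same pattern as the sed expression, with the parentheses escaped for Python.
PATTERN = re.compile(
    r"jax\.config\.update\('jax_default_matmul_precision', *jax\.lax\.Precision\.HIGHEST\)"
)
REPLACEMENT = "jax.config.update('jax_default_matmul_precision', 'bfloat16')"


def patch_source(text):
    """Swap the HIGHEST matmul precision setting for the 'bfloat16' string."""
    return PATTERN.sub(REPLACEMENT, text)


sample = "jax.config.update('jax_default_matmul_precision', jax.lax.Precision.HIGHEST)\n"
patched = patch_source(sample)
print(patched, end="")
```

Lines that do not contain the original call are returned unchanged.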

In [None]:
import kagglehub
import jax
import jora
import pathlib
import torch

from transformers import AutoTokenizer, AutoModelForCausalLM
from huggingface_hub import snapshot_download

## Download the Gemma Model

Now, you can download the Gemma model using `kagglehub`:

In [None]:
VARIANT = "1.1-2b-it"
GEMMA_PATH = kagglehub.model_download(f'google/gemma/Flax/{VARIANT}')

print('GEMMA_PATH:', GEMMA_PATH)

Downloading 14 files:   0%|          | 0/14 [00:00<?, ?it/s]

Downloading from https://www.kaggle.com/api/v1/models/google/gemma/Flax/1.1-2b-it/1/download/2b-it/ocdbt.process_0/manifest.0000000000000002...
Downloading from https://www.kaggle.com/api/v1/models/google/gemma/Flax/1.1-2b-it/1/download/2b-it/d/76bc26e743d2a486ffd9322898d08435...
Downloading from https://www.kaggle.com/api/v1/models/google/gemma/Flax/1.1-2b-it/1/download/2b-it/ocdbt.process_0/manifest.ocdbt...
Downloading from https://www.kaggle.com/api/v1/models/google/gemma/Flax/1.1-2b-it/1/download/2b-it/ocdbt.process_0/manifest.0000000000000003...
Downloading from https://www.kaggle.com/api/v1/models/google/gemma/Flax/1.1-2b-it/1/download/2b-it/ocdbt.process_0/d/b09eb83b9b0c3b46e65c59c98c25f6bc...
Downloading from https://www.kaggle.com/api/v1/models/google/gemma/Flax/1.1-2b-it/1/download/2b-it/ocdbt.process_0/d/82f86d35f73a7dd200794e77ea112419...
Downloading from https://www.kaggle.com/api/v1/models/google/gemma/Flax/1.1-2b-it/1/download/2b-it/ocdbt.process_0/manifest.000000000000

**Note:** By default, `kagglehub` stores the model in the `~/.cache/kagglehub` directory.

Verify that JAX recognizes the GPU devices:

In [None]:
print(jax.devices())

[CudaDevice(id=0)]


## Configure JORA and Prepare the Dataset

Here, you'll configure the Gemma model and also the training process for **LoRA** fine-tuning.

In order to fine-tune Gemma, you will use the **Alpaca** dataset. Ensure you have the dataset file `alpaca_data_cleaned.json` in the appropriate directory. You can download it from [here](https://github.com/tatsu-lab/stanford_alpaca/blob/main/alpaca_data_cleaned.json) or use the one that's bundled in the repository. For demonstration purposes, let's use the bundled one.

**Credits:** [Stanford Alpaca](https://github.com/tatsu-lab/stanford_alpaca/blob/main/alpaca_data.json)

The `generate_alpaca_dataset_gemma` function generates the dataset from an Alpaca-format JSON file. This helps with instruction-format training, since the library handles dataset processing, tokenization, and batching. Alternatively, torch `Dataset` and `DataLoader` can be used for custom datasets.
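
For orientation, an Alpaca-format file is a JSON list of records with `instruction`, `input`, and `output` keys. The sketch below parses two made-up records and checks that the keys are present. The records and the `load_alpaca_records` helper are illustrative only and say nothing about how JORA parses the file internally:

```python
import json

# Two illustrative records in the Alpaca layout (not taken from the dataset).
raw = """
[
  {"instruction": "Give three tips for staying healthy.",
   "input": "",
   "output": "Eat well, exercise regularly, and sleep enough."},
  {"instruction": "Identify the odd one out.",
   "input": "Twitter, Instagram, Telegram",
   "output": "Telegram"}
]
"""

REQUIRED_KEYS = {"instruction", "input", "output"}


def load_alpaca_records(text):
    """Parse a JSON string and verify every record has the Alpaca keys."""
    records = json.loads(text)
    for i, rec in enumerate(records):
        missing = REQUIRED_KEYS - rec.keys()
        if missing:
            raise ValueError(f"record {i} is missing keys: {sorted(missing)}")
    return records


records = load_alpaca_records(raw)
with_input = sum(1 for r in records if r["input"])
print(len(records), with_input)  # 2 1
```

A custom dataset that follows the same three-key layout should be usable wherever the bundled file is.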


In [None]:
# Configure the model and training parameters
config = jora.ParagemmaConfig(
    # Feel free to tweak these parameters
    N_EPOCHS=1,
    LORA_R=8,
    # Note: The `LORA_DROPOUT` parameter is currently not configurable.
    # https://github.com/aniquetahir/JORA?tab=readme-ov-file#contributing
    LORA_ALPHA=16,
    LR=1e-5,
    BATCH_SIZE=2,
    N_ACCUMULATION_STEPS=8,
    GEMMA_MODEL_PATH=GEMMA_PATH,
    MAX_SEQ_LEN=512,
    # Set `MODEL_VERSION` to '2b-it'
    # https://github.com/aniquetahir/JORA/blob/master/jora/lib/gemma/common.py#L282
    MODEL_VERSION='2b-it'
)

# Path to the Alpaca dataset
dataset_path = 'jora/alpaca_data_cleaned.json'

# Generate the dataset
dataset = jora.generate_alpaca_dataset_gemma(
    dataset_path, 'train', config,
    split_percentage=0.2,
    alpaca_mix=0.3
)

Processing data...


The `ParagemmaConfig` class is used to set up the configuration for training while `generate_alpaca_dataset_gemma` processes the dataset, handles tokenization, and prepares it for training.
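
A few derived quantities help when tuning these values. The arithmetic below is a sketch that assumes gradient accumulation and LoRA scaling behave in the conventional way (gradients combined over `N_ACCUMULATION_STEPS` batches per update, low-rank update scaled by alpha / r). It does not inspect JORA itself:

```python
# Values copied from the configuration above.
BATCH_SIZE = 2
N_ACCUMULATION_STEPS = 8
MAX_SEQ_LEN = 512
LORA_R = 8
LORA_ALPHA = 16

# With conventional gradient accumulation, one optimizer update sees
# BATCH_SIZE * N_ACCUMULATION_STEPS examples.
effective_batch_size = BATCH_SIZE * N_ACCUMULATION_STEPS

# Upper bound on tokens contributing to a single update.
max_tokens_per_update = effective_batch_size * MAX_SEQ_LEN

# In the standard LoRA formulation the low-rank update is scaled by alpha / r.
lora_scale = LORA_ALPHA / LORA_R

print(effective_batch_size)   # 16
print(max_tokens_per_update)  # 8192
print(lora_scale)             # 2.0
```

Raising `N_ACCUMULATION_STEPS` instead of `BATCH_SIZE` increases the effective batch without increasing the number of sequences held in memory at once.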

In [None]:
config

ParagemmaConfig(GEMMA_MODEL_PATH='/root/.cache/kagglehub/models/google/gemma/Flax/1.1-2b-it/1', MODEL_VERSION='2b-it', NUM_SHARDS=None, LORA_R=8, LORA_ALPHA=16, LORA_DROPOUT=0.05, LR=1e-05, BATCH_SIZE=2, N_ACCUMULATION_STEPS=8, MAX_SEQ_LEN=512, N_EPOCHS=1, SEED=420, CACHE_SIZE=30)

## Fine-tune Gemma with **JORA**

Now, you can proceed to fine-tune the model using the `train_lora_gemma` function which initiates the fine-tuning process using LoRA (Low-Rank Adaptation).
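
To see why LoRA reduces memory requirements, compare the number of trainable parameters in a dense weight matrix with the number in its two low-rank factors. The layer size below is illustrative and is not read from the Gemma checkpoint; which layers JORA actually adapts is not covered here:

```python
def dense_param_count(d_in, d_out):
    """Parameters in a full d_out x d_in weight matrix."""
    return d_in * d_out


def lora_param_count(d_in, d_out, r):
    """Parameters in the LoRA factors A (r x d_in) and B (d_out x r)."""
    return r * d_in + d_out * r


d_in = d_out = 2048  # illustrative square projection (assumption)
r = 8                # matches LORA_R in the configuration

full = dense_param_count(d_in, d_out)
lora = lora_param_count(d_in, d_out, r)

print(full)          # 4194304
print(lora)          # 32768
print(full // lora)  # 128
```

For this example shape, the adapter trains 128 times fewer parameters than the dense matrix it modifies.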

In [None]:
# Directory where the trained LoRA weights will be saved
checkpoint_path = 'checkpoints'
jora.train_lora_gemma(config, dataset, checkpoint_path)

Checkpoints will be saved in the folder specified by `checkpoint_path`.
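
Before moving on, it can be useful to confirm that training produced the final checkpoint. The sketch below uses a hypothetical helper, `find_final_checkpoint`, and assumes the file name `jax_lora_final.pickle` that the conversion step in this notebook points to:

```python
from pathlib import Path


def find_final_checkpoint(checkpoint_dir, filename="jax_lora_final.pickle"):
    """Return the path to the final LoRA checkpoint, or None if it is absent.

    Hypothetical helper; the default file name is the one referenced by the
    conversion step in this notebook.
    """
    candidate = Path(checkpoint_dir) / filename
    return candidate if candidate.is_file() else None


final = find_final_checkpoint("checkpoints")
print(final if final is not None else "no final checkpoint found yet")
```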

## Convert the model to the **Hugging Face Format**

After fine-tuning, you need to convert the trained model to the Hugging Face format for compatibility with the Hugging Face ecosystem so that you can easily run inference later.

**Usage:**

```python
lorize_huggingface(HUGGINGFACE_PATH, JAX_PATH, SAVE_PATH, gemma=True)
```

- **HUGGINGFACE_PATH**: Path to the Hugging Face Gemma model (the base model before fine-tuning).
- **JAX_PATH**: Path to the LoRA merged parameters (the trained LoRA weights).
- **SAVE_PATH**: Path to save the updated Hugging Face Gemma fine-tuned model.
- **gemma**: Flag indicating you're working with a Gemma model.

First, specify the paths:

In [None]:
# Specify the repository
repo_id = "google/gemma-1.1-2b-it"
local_dir = 'pretrained'

snapshot_download(
    repo_id=repo_id,
    local_dir=local_dir,
    revision="main",
    ignore_patterns=['*.gguf']
)

HUGGINGFACE_PATH = local_dir  # Path to the base Hugging Face model
JAX_PATH = 'checkpoints/jax_lora_final.pickle'  # Path to the trained LoRA JAX weights
SAVE_PATH = 'gemma-ft'  # Path to save the converted Hugging Face model

Then, run the converter:

In [None]:
from jora.hf.__main__ import lorize_huggingface

lorize_huggingface(HUGGINGFACE_PATH, JAX_PATH, SAVE_PATH, gemma=True)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

model loaded
model saved to gemma-ft


- The `jora.hf` module converts the JAX-trained model back to the Hugging Face format.
- It merges the LoRA weights with the original model parameters.
- The converted model is saved in the specified `SAVE_PATH`.
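
Conceptually, merging LoRA weights adds the scaled low-rank product to each adapted base matrix: W' = W + (alpha / r) · B A. The pure-Python sketch below shows that arithmetic on toy matrices. It follows the standard LoRA formulation and is not JORA's actual conversion code:

```python
def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def merge_lora(w, lora_a, lora_b, alpha, r):
    """Return W + (alpha / r) * (B @ A) without modifying the inputs."""
    scale = alpha / r
    delta = matmul(lora_b, lora_a)
    return [
        [w_ij + scale * d_ij for w_ij, d_ij in zip(w_row, d_row)]
        for w_row, d_row in zip(w, delta)
    ]


# Toy shapes: W is 2x3, B is 2x1, A is 1x3, so the rank r is 1.
w = [[1.0, 0.0, 2.0],
     [0.0, 1.0, 0.0]]
lora_b = [[1.0], [2.0]]
lora_a = [[0.5, 0.0, -1.0]]

merged = merge_lora(w, lora_a, lora_b, alpha=2, r=1)
print(merged)  # [[2.0, 0.0, 0.0], [2.0, 1.0, -4.0]]
```

After merging, the model has the same shape as the base model, which is why it can be loaded with the regular Transformers classes.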

## Load the Model and Generate Text

Finally, you can load the converted model using Hugging Face's Transformers library.

In [None]:
tokenizer = AutoTokenizer.from_pretrained(HUGGINGFACE_PATH)
model = AutoModelForCausalLM.from_pretrained(SAVE_PATH, device_map="auto")

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]



Here, the tokenizer is loaded from the base model directory and the fine-tuned model from `SAVE_PATH`; `device_map="auto"` places the model on the available GPU.

Generate text using the model. To do this, you'll use the Alpaca prompt format:

In [None]:
# Define the Alpaca prompt template
alpaca_prompt = """\
Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.

### Instruction:
{}

### Input:
{}

### Response:
"""

# Function to generate response
def generate_response(instruction, input_text="", max_new_tokens=384):
    prompt = alpaca_prompt.format(instruction, input_text)
    # Send the inputs to whichever device `device_map="auto"` placed the model on
    inputs = tokenizer.encode(prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(inputs, max_new_tokens=max_new_tokens)
    text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    print(text)

In [None]:
generate_response(
    instruction="Make a prediction about what will happen in the next paragraph.",
    input_text="Mary had been living in the small town for many years and had never seen anything like what was coming.",
)

Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.

### Instruction:
Make a prediction about what will happen in the next paragraph.

### Input:
Mary had been living in the small town for many years and had never seen anything like what was coming.

### Response:
Mary will be surprised by what she will see.


In [None]:
generate_response(
    instruction="Identify a suitable <verb> in the following sentence.",
    input_text="Cat <verb> in the garden.",
)

Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.

### Instruction:
Identify a suitable <verb> in the following sentence.

### Input:
Cat <verb> in the garden.

### Response:
Cat sat in the garden.
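
As the outputs above show, the decoded text echoes the whole prompt before the answer. If only the answer is needed, one simple option is to keep what follows the last `### Response:` marker. `extract_response` is a hypothetical helper written for this sketch:

```python
RESPONSE_MARKER = "### Response:"


def extract_response(generated_text):
    """Return the text after the last response marker, or all text if absent."""
    _, marker, tail = generated_text.rpartition(RESPONSE_MARKER)
    return tail.strip() if marker else generated_text.strip()


sample = (
    "### Instruction:\nIdentify a suitable <verb> in the following sentence.\n\n"
    "### Input:\nCat <verb> in the garden.\n\n"
    "### Response:\nCat sat in the garden.\n"
)
print(extract_response(sample))  # Cat sat in the garden.
```

To use it with `generate_response`, that function would need to return the decoded text instead of printing it.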


## Push the model to your Hugging Face Hub


Optionally, you can store the trained model on the Hugging Face Hub.

In [None]:
# Note: The token needs to have "write" permission
#       You can check it here:
#       https://huggingface.co/settings/tokens
# Uncomment and run this if you wish to publish the model to Hugging Face Hub
# model.push_to_hub("my-gemma-finetuned-model")

In this tutorial, you have learnt how to fine-tune a Gemma model using JORA and convert it to the Hugging Face model format for inference. By leveraging JAX's JIT compilation and tensor-sharding capabilities, you can achieve efficient resource management, enabling accelerated fine-tuning with reduced memory requirements.

This framework improves the scalability and feasibility of fine-tuning LLMs for complex RAG applications, even on systems with limited GPU resources.