##### Copyright 2024 Google LLC.

In [1]:
# @title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Gemma - finetune with Axolotl

This notebook demonstrates how to finetune Gemma with Axolotl. [Axolotl](https://github.com/OpenAccess-AI-Collective/axolotl) is a tool designed to streamline the fine-tuning of various AI models, offering support for multiple configurations and architectures. It wraps the Hugging Face finetuning functionality and provides a simple interface for finetuning.
Finetuning Gemma with Axolotl takes only a YAML config file and a couple of commands. This notebook closely follows the [official Colab notebook](https://github.com/OpenAccess-AI-Collective/axolotl/blob/main/examples/colab-notebooks/colab-axolotl-example.ipynb).

<table align="left">
  <td>
    <a target="_blank" href="https://colab.research.google.com/github/google-gemini/gemma-cookbook/blob/main/Gemma/Finetune_with_Axolotl.ipynb"><img src="https://www.tensorflow.org/images/colab_logo_32px.png" />Run in Google Colab</a>
  </td>
</table>

## Setup

### Select the Colab runtime
To complete this tutorial, you'll need to have a Colab runtime with sufficient resources to run the Gemma model. In this case, you can use a T4 GPU:

1. In the upper-right of the Colab window, select **▾ (Additional connection options)**.
2. Select **Change runtime type**.
3. Under **Hardware accelerator**, select **T4 GPU**.
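
The T4 choice matters later in this notebook, where the config turns off BF16 and flash attention because the T4 does not support them. The sketch below encodes that as a rule of thumb: it assumes both features need a GPU with CUDA compute capability 8.0 or higher (Ampere or newer), and that the T4 reports 7.5. The function name `gpu_feature_support` is hypothetical and only for illustration.

```python
def gpu_feature_support(capability):
    """Rough feature check from a (major, minor) CUDA compute capability.

    Assumption: BF16 and flash attention both need Ampere-class hardware
    or newer, which corresponds to a major version of 8 or above.
    """
    major, _minor = capability
    ampere_or_newer = major >= 8
    return {"bf16": ampere_or_newer, "flash_attention": ampere_or_newer}


# Assumed compute capability of the T4 (Turing generation).
T4_CAPABILITY = (7, 5)

print(gpu_feature_support(T4_CAPABILITY))
# {'bf16': False, 'flash_attention': False}
```

Once PyTorch is installed, `torch.cuda.get_device_capability()` returns the same kind of tuple for the GPU attached to the runtime.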

### Install PyTorch

In [2]:
!pip install torch=="2.1.2"

Collecting torch==2.1.2
  Downloading torch-2.1.2-cp310-cp310-manylinux1_x86_64.whl (670.2 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 670.2/670.2 MB 2.7 MB/s eta 0:00:00
Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch==2.1.2)
  Using cached nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (23.7 MB)
Collecting nvidia-cuda-runtime-cu12==12.1.105 (from torch==2.1.2)
  Using cached nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (823 kB)
Collecting nvidia-cuda-cupti-cu12==12.1.105 (from torch==2.1.2)
  Using cached nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (14.1 MB)
Collecting nvidia-cudnn-cu12==8.9.2.26 (from torch==2.1.2)
  Using cached nvidia_cudnn_cu12-8.9.2.26-py3-none-manylinux1_x86_64.whl (731.7 MB)
Collecting nvidia-cublas-cu12==12.1.3.1 (from torch==2.1.2)
  Using cached nvidia_cublas_cu12-12.1.3.1-py3-none-manylinux1_x86_64.whl (410.6 MB)
Collecting nvidia-cufft-cu12==11.0.2

### Restart the Colab runtime

At this point, restart your Colab runtime for the newly installed PyTorch version to take effect.
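
As a quick sanity check after the restart, the sketch below reads the installed PyTorch version from the package metadata and compares it with the pinned `2.1.2`. The helper `installed_version` is a hypothetical name. Note that this reports what is installed on disk; running `import torch; print(torch.__version__)` in the restarted runtime should show the same version.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_version(package):
    """Return the installed version string of `package`, or None if it is missing."""
    try:
        return version(package)
    except PackageNotFoundError:
        return None


EXPECTED_TORCH = "2.1.2"  # the version pinned in the install cell above

found = installed_version("torch")
if found is None:
    print("torch is not installed in this environment")
elif found.split("+")[0] != EXPECTED_TORCH:
    # Strip a local suffix such as "+cu121" before comparing.
    print(f"torch {found} is installed, expected {EXPECTED_TORCH}")
else:
    print(f"torch {found} matches the pinned version")
```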

### Install Axolotl


In [3]:
!pip install -e git+https://github.com/OpenAccess-AI-Collective/axolotl#egg=axolotl
# T4 does not support flash attention
# !pip install flash-attn=="2.5.0"
!pip install deepspeed=="0.13.1"
!pip install mlflow=="2.13.0"

Obtaining axolotl from git+https://github.com/OpenAccess-AI-Collective/axolotl#egg=axolotl
  Cloning https://github.com/OpenAccess-AI-Collective/axolotl to ./src/axolotl
  Running command git clone --filter=blob:none --quiet https://github.com/OpenAccess-AI-Collective/axolotl /content/src/axolotl
  Resolved https://github.com/OpenAccess-AI-Collective/axolotl to commit 1f151c0d52d2d4c78c5e1b1a4ff4fb64cba1f45d
  Preparing metadata (setup.py) ... done
Collecting fschat@ git+https://github.com/lm-sys/FastChat.git@27a05b04a35510afb1d767ae7e5990cbd278f8fe (from axolotl)
  Cloning https://github.com/lm-sys/FastChat.git (to revision 27a05b04a35510afb1d767ae7e5990cbd278f8fe) to /tmp/pip-install-j71xwbh1/fschat_d8a49c65bcf3440d862b4934c0307124
  Running command git clone --filter=blob:none --quiet https://github.com/lm-sys/FastChat.git /tmp/pip-install-j71xwbh1/fschat_d8a49c65bcf3440d862b4934c0307124
  Running command git rev-parse -q --verify 'sha^27a05b04a35510afb1d767ae7e5990cbd27

## Finetune Gemma

Axolotl uses YAML config files to specify finetuning parameters. The YAML file below is adapted from the official [Gemma QLoRA example](https://github.com/OpenAccess-AI-Collective/axolotl/blob/main/examples/gemma/qlora.yml).


In [4]:
import yaml

# Your YAML string
yaml_string = """
base_model: google/gemma-2b
model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer

load_in_8bit: false
load_in_4bit: true
strict: false

# huggingface repo
datasets:
  - path: mhenrichsen/alpaca_2k_test
    type: alpaca
val_set_size: 0.1
output_dir: ./outputs/out

adapter: qlora
lora_r: 4
lora_alpha: 4
lora_dropout: 0.05
lora_target_linear: true

sequence_len: 2048
sample_packing: true
eval_sample_packing: false
pad_to_sequence_len: true

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:


gradient_accumulation_steps: 3
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002

train_on_inputs: false
group_by_length: false
# T4 does not support BF16
bf16: false
fp16:
tf32: false

gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
# T4 does not support flash attention
flash_attention: false

warmup_ratio: 0.1
evals_per_epoch: 4
eval_table_size:
eval_max_new_tokens: 128
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:

"""

# Convert the YAML string to a Python dictionary
yaml_dict = yaml.safe_load(yaml_string)

# Specify your file path
file_path = "gemma_axolotl.yaml"

# Write the YAML file
with open(file_path, "w") as file:
    yaml.dump(yaml_dict, file)
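
Two of the settings in this config are easier to read as derived quantities. The sketch below computes the effective batch size (micro batch size times gradient accumulation steps times number of GPUs) and the LoRA scaling factor `lora_alpha / lora_r` from the standard LoRA formulation. It assumes a single GPU, as on a Colab T4, and assumes Axolotl applies the usual LoRA scaling. The helper names are hypothetical.

```python
# Values copied from the YAML config above.
config = {
    "micro_batch_size": 1,
    "gradient_accumulation_steps": 3,
    "lora_r": 4,
    "lora_alpha": 4,
}


def effective_batch_size(cfg, num_gpus=1):
    """Number of samples contributing to each optimizer step."""
    return cfg["micro_batch_size"] * cfg["gradient_accumulation_steps"] * num_gpus


def lora_scaling(cfg):
    """Multiplier applied to the low-rank update in standard LoRA."""
    return cfg["lora_alpha"] / cfg["lora_r"]


print(effective_batch_size(config))  # 3
print(lora_scaling(config))          # 1.0
```

With `sample_packing: true`, each sample can hold several packed examples, so the number of training examples per step can be higher than this figure suggests.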

### Kick off finetuning

### Gemma setup on Hugging Face
Axolotl uses Hugging Face under the hood. So you will need to:

* Get access to Gemma on [huggingface.co](https://huggingface.co) by accepting the Gemma license on the Hugging Face page of the specific model, i.e., [Gemma 2B](https://huggingface.co/google/gemma-2b).
* Generate a [Hugging Face access token](https://huggingface.co/docs/hub/en/security-tokens) and configure it as a Colab secret named `HF_TOKEN`.
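
Outside Colab there is no secrets panel, so the token has to come from the environment. The following is a minimal sketch of a guard that fails early with a clear message when `HF_TOKEN` is missing, rather than partway through the model download. The function `get_hf_token` is a hypothetical helper, not part of Axolotl or Hugging Face.

```python
import os


def get_hf_token(env=None):
    """Return the HF_TOKEN value from `env` (defaults to os.environ), or raise."""
    env = os.environ if env is None else env
    token = (env.get("HF_TOKEN") or "").strip()
    if not token:
        raise RuntimeError("HF_TOKEN is not set. Export it before launching training.")
    return token


# Demonstration with a stand-in mapping and a placeholder value;
# in real use, call get_hf_token() with no arguments.
demo_env = {"HF_TOKEN": "  hf_placeholder_value  "}
print(get_hf_token(demo_env) == "hf_placeholder_value")
```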

In [5]:
import os
from google.colab import userdata
# Note: `userdata.get` is a Colab API. If you're not using Colab, set the env
# vars as appropriate for your system.
os.environ["HF_TOKEN"] = userdata.get("HF_TOKEN")

Now kick off finetuning.

In [6]:
!python -m axolotl.cli.train /content/gemma_axolotl.yaml

[2024-06-04 04:29:36,530] [INFO] [numexpr.utils._init_num_threads:161] [PID:3985] NumExpr defaulting to 8 threads.
[2024-06-04 04:29:36,731] [INFO] [datasets.<module>:58] [PID:3985] PyTorch version 2.1.2 available.
[2024-06-04 04:29:36,733] [INFO] [datasets.<module>:70] [PID:3985] Polars version 0.20.2 available.
[2024-06-04 04:29:36,733] [INFO] [datasets.<module>:105] [PID:3985] TensorFlow version 2.15.0 available.
[2024-06-04 04:29:36,734] [INFO] [datasets.<module>:118] [PID:3985] JAX version 0.4.26 available.
2024-06-04 04:29:38.359720: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-06-04 04:29:38.359772: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-06-04 04:29:38.361162: E external/local_xla/xla/stream_e

## Upload finetuned model to Hugging Face
### Merge LoRA adapter
Merging the adapter takes quite a bit of memory, so you may need to use a high-RAM Colab instance to avoid crashing.
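
A back-of-the-envelope estimate shows why. The sketch below assumes Gemma 2B has roughly 2.5 billion parameters (an approximate figure) and that the merge holds the full weights at a higher precision than the 4-bit format used during training. Under the float32 assumption the weights alone come to about 10 GB, which is in line with the total size of the three shards uploaded later in this notebook.

```python
def weights_size_gb(num_params, bytes_per_param):
    """Approximate size of the model weights in gigabytes (1 GB = 1e9 bytes)."""
    return num_params * bytes_per_param / 1e9


# Assumption: rough parameter count for Gemma 2B.
APPROX_PARAMS = 2.5e9

for dtype_name, nbytes in [("float32", 4), ("float16", 2), ("4-bit", 0.5)]:
    size = weights_size_gb(APPROX_PARAMS, nbytes)
    print(f"{dtype_name}: about {size:.2f} GB of weights")
```

This counts weights only; the process also needs headroom for the adapter, temporary buffers, and writing the merged checkpoint.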

In [7]:
!python -m axolotl.cli.merge_lora /content/gemma_axolotl.yaml --lora_model_dir="./outputs/out"

[2024-06-04 05:40:00,263] [INFO] [numexpr.utils._init_num_threads:161] [PID:22017] NumExpr defaulting to 8 threads.
[2024-06-04 05:40:00,421] [INFO] [datasets.<module>:58] [PID:22017] PyTorch version 2.1.2 available.
[2024-06-04 05:40:00,422] [INFO] [datasets.<module>:70] [PID:22017] Polars version 0.20.2 available.
[2024-06-04 05:40:00,423] [INFO] [datasets.<module>:105] [PID:22017] TensorFlow version 2.15.0 available.
[2024-06-04 05:40:00,424] [INFO] [datasets.<module>:118] [PID:22017] JAX version 0.4.26 available.
2024-06-04 05:40:01.500291: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-06-04 05:40:01.500346: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-06-04 05:40:01.501636: E external/local_xla/xla/str

### Push model to Hugging Face Hub

In [8]:
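from transformers import AutoModelForCausalLM

# Load the merged weights as a causal language model so the uploaded
# checkpoint can be used directly for text generation.
model = AutoModelForCausalLM.from_pretrained("./outputs/out/merged", local_files_only=True)
model.push_to_hub("gemma-2-finetuned-model-axolotl")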

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

  return self.fget.__get__(instance, owner)()


[2024-06-04 05:40:52,151] [INFO] [real_accelerator.py:191:get_accelerator] Setting ds_accelerator to cuda (auto detect)


model-00003-of-00003.safetensors:   0%|          | 0.00/134M [00:00<?, ?B/s]

model-00002-of-00003.safetensors:   0%|          | 0.00/4.98G [00:00<?, ?B/s]

Upload 3 LFS files:   0%|          | 0/3 [00:00<?, ?it/s]

model-00001-of-00003.safetensors:   0%|          | 0.00/4.91G [00:00<?, ?B/s]

CommitInfo(commit_url='https://huggingface.co/windmaple/gemma-2-finetuned-model-axolotl/commit/eafbcf827d2bc9d77f177d39009e65ae10c455ff', commit_message='Upload model', commit_description='', oid='eafbcf827d2bc9d77f177d39009e65ae10c455ff', pr_url=None, pr_revision=None, pr_num=None)

## Conclusion

This notebook demonstrates how to use Axolotl to do instruction tuning for the Gemma 2B model. If you want to finetune with another dataset, please check out the Axolotl documentation on [Dataset Formats](https://openaccess-ai-collective.github.io/axolotl/docs/dataset-formats/).
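
For a sense of what the `alpaca` dataset type used in this notebook expects, here is a minimal sketch that writes a small JSON Lines file with `instruction`, `input`, and `output` fields. The records, the file name, and the helper `write_alpaca_jsonl` are made up for illustration; how to point the `datasets:` section of the config at a local file is covered in the Axolotl documentation.

```python
import json
import tempfile
from pathlib import Path

# Fields used by the Alpaca instruction format.
REQUIRED_KEYS = {"instruction", "input", "output"}

# Made-up example records; "input" may be empty when no extra context is needed.
records = [
    {"instruction": "Translate to French.", "input": "Good morning", "output": "Bonjour"},
    {"instruction": "Name a primary color.", "input": "", "output": "Red"},
]


def write_alpaca_jsonl(rows, path):
    """Validate each row and write one JSON object per line."""
    with open(path, "w", encoding="utf-8") as f:
        for row in rows:
            missing = REQUIRED_KEYS - row.keys()
            if missing:
                raise ValueError(f"record is missing fields: {sorted(missing)}")
            f.write(json.dumps(row, ensure_ascii=False) + "\n")
    return path


out_path = Path(tempfile.gettempdir()) / "my_alpaca_data.jsonl"
write_alpaca_jsonl(records, out_path)
print(f"wrote {len(records)} records to {out_path}")
```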