<a href="https://colab.research.google.com/github/hululuzhu/chinese-ai-writing-share/blob/main/further_finetune_example/gemma_lora_finetune.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Gemma + LoRA = Finetune on a Consumer-Level GPU
- branched from my own [llama finetune repo](https://github.com/hululuzhu/llama-lora-chinese-couplet)
- Last update: 02/21/2024
- Contact: hululu.zhu@gmail.com
- **TODO: make sure the tokens match Gemma settings**


Zero-shot Examples
- Generated after 2 epochs on 5k couplet pairs, with a cap on max new tokens and greedy decoding
- Post-processed so the second line matches the number of Chinese characters in the first
- Ideally, a well-trained model would emit the end-of-sentence (EOS) token by itself
- Prompt: `对联：{上联}\n下联：`



## Prerequisites
- NVIDIA GPU with at least 10 GB of free GPU memory (checked below)
- Install the required packages with pip

In [None]:
!nvidia-smi -L

GPU 0: Tesla T4 (UUID: GPU-0ef16b0b-1c3e-c7e9-13d1-bc4d3c9ba230)


In [None]:
!pip install -q nvidia-ml-py3
import nvidia_smi
nvidia_smi.nvmlInit()
handle = nvidia_smi.nvmlDeviceGetHandleByIndex(0)
# card id 0 hardcoded here, there is also a call to get all available card ids, so we could iterate
info = nvidia_smi.nvmlDeviceGetMemoryInfo(handle)
nvidia_smi.nvmlShutdown()

print("Total memory:", info.total)
print("Free memory:", info.free)
print("Used memory:", info.used)

assert info.free > 1e10, (
    "Your GPU looks busy or has less than 10 GB of free memory; cannot continue")

Total memory: 16106127360
Free memory: 15835529216
Used memory: 270598144
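
The `assert` above compares raw byte counts against `1e10`, which is 10 GB in decimal units, or roughly 9.3 GiB. A small helper, sketched here with hypothetical names (`to_gib`, `has_enough_memory`) that are not part of the notebook, makes the printed numbers easier to read:

```python
GIB = 2**30  # bytes per gibibyte

def to_gib(num_bytes: int) -> float:
    """Convert a raw byte count (as reported by NVML) to GiB."""
    return num_bytes / GIB

def has_enough_memory(free_bytes: int, required_bytes: float = 1e10) -> bool:
    """Mirror the notebook's check: more than 10 GB (decimal) must be free."""
    return free_bytes > required_bytes

# Values copied from the output above (Tesla T4)
total, free, used = 16106127360, 15835529216, 270598144
print(f"Total: {to_gib(total):.2f} GiB")
print(f"Free:  {to_gib(free):.2f} GiB")
print(f"Used:  {to_gib(used):.2f} GiB")
print("Enough for this notebook:", has_enough_memory(free))
```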


In [None]:
# As of 02/21/2024, we need the latest transformers to pick up the Gemma tokenizer
# Note: we might need to restart the runtime to pick up the changes
!pip install git+https://github.com/huggingface/transformers
!pip install -q bitsandbytes
!pip install -q datasets loralib sentencepiece
!pip install -q peft

## All the Imports

In [None]:
# disable warnings unless needed
import warnings
warnings.filterwarnings('ignore')

In [None]:
from datasets import Dataset, load_dataset
import numpy as np
import os
import pandas as pd
import pathlib
from peft import PeftModel, get_peft_config, get_peft_model, LoraConfig, TaskType, prepare_model_for_int8_training
import pickle
import sys
import torch
import transformers
from transformers import LlamaTokenizer, LlamaForCausalLM, GenerationConfig, AutoModelForSeq2SeqLM, DataCollatorForLanguageModeling

## Define top-level configs

In [None]:
# Select your model
# MODEL = "llama-1-7b"  #@param ["llama-1-7b", "llama-2-7b", "llama-2-7b-chat"]
MODEL = "google/gemma-7b-it"  # public checkpoint id, same as the one loaded below
model_name_or_path = MODEL
tokenizer_name_or_path = MODEL

# Max number of tokens (prompt plus output); Chinese text can tokenize into
# more tokens than characters, as observed
CUTOFF_LEN = 96
# Predict training prompt as well to increase quality as Alpaca Lora does.
# Turn off to speedup, but might affect quality.
TRAIN_ON_INPUT = True

In [None]:
# Log in to Hugging Face; required to download the Gemma checkpoints
!huggingface-cli login


    A token is already saved on your machine. Run `huggingface-cli whoami` to get more information or `huggingface-cli logout` if you want to log out.
    Setting a new token will erase the existing one.
    To login, `huggingface_hub` requires a token generated from https://huggingface.co/settings/tokens .
Token: 
Add token as git credential? (Y/n) Y
Token is valid (permission: read).

In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

# Gemma takes noticeably more memory than Llama 2 here, so load the weights in 4-bit
quantization_config = BitsAndBytesConfig(load_in_4bit=True)

tokenizer = AutoTokenizer.from_pretrained("google/gemma-7b-it")
model = AutoModelForCausalLM.from_pretrained("google/gemma-7b-it",
                                            #  device_map="auto",
                                            #  load_in_8bit=True, # 8bit seems to take more than 9G memory...
                                            quantization_config=quantization_config,)


`low_cpu_mem_usage` was None, now set to True since model is quantized.


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]
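
As a rough back-of-the-envelope check on why 4-bit loading is used above, the memory for the weights alone can be estimated from the parameter count. This sketch assumes roughly 8.5 billion parameters for Gemma-7B (an approximate figure, not measured in this notebook) and ignores activations, LoRA weights, optimizer state, and any layers that bitsandbytes keeps in higher precision, so real usage will be higher:

```python
def weight_memory_gb(num_params: float, bits_per_param: int) -> float:
    """Approximate memory for the weights alone, in decimal GB."""
    return num_params * bits_per_param / 8 / 1e9

ASSUMED_GEMMA_7B_PARAMS = 8.5e9  # assumption: approximate parameter count

for bits in (16, 8, 4):
    gb = weight_memory_gb(ASSUMED_GEMMA_7B_PARAMS, bits)
    print(f"{bits:>2}-bit weights: ~{gb:.1f} GB")
```

Under this assumption, 8-bit weights alone come to about 8.5 GB, which is in line with the comment that 8-bit loading took more than 9 GB, while 4-bit roughly halves that.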

In [None]:
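# Copied from the Llama version. Gemma's tokenizer appears to reserve id 0 for
# <pad> as well (the decoded examples below show <bos>=2 and <eos>=1), so this
# should be safe; TODO: double-check against the released tokenizer config
tokenizer.pad_token_id = 0
tokenizer.padding_side = "left"  # Allow batched inference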

## Load Training data

In [None]:
# Reuse my T5 couplet data code https://github.com/hululuzhu/chinese-ai-writing-share/blob/main/training/t5_finetune/Mengzi_T5_Finetune_Chinese_Couplet_V1.ipynb
working_dir = "/tmp/working_dir"
!mkdir -p {working_dir}
!wget https://github.com/wb14123/couplet-dataset/releases/download/1.0/couplet.tar.gz -P {working_dir}
!ls -l {working_dir}
!mkdir -p {working_dir}/couplet_files
!tar -xf {working_dir}/couplet.tar.gz -C {working_dir}/couplet_files
!head -1 {working_dir}/couplet_files/couplet/train/in.txt {working_dir}/couplet_files/couplet/train/out.txt

COUPLET_PATH = f'{working_dir}/couplet_files/couplet'
MAX_SEQ_LEN = 32  # Max 32 chinese char including punctuation marks

train_df, test_df = None, None
for t in ['train', 'test']:
  ins, outs = [], []
  for i in ['in', 'out']:
    with open(f"{COUPLET_PATH}/{t}/{i}.txt", "r", encoding="utf-8") as f:
      for line in f:
        clean_line = line.strip().replace(' ', '').replace('\n', '').replace('\r', '')[:MAX_SEQ_LEN]
        if i=='in':
          ins.append(clean_line)
        else:
          outs.append(clean_line)
  # The column names to match simpleT5
  data_dict = {
      'source_text': ins,
      'target_text': outs,
  }
  if t == 'train':
    train_df = pd.DataFrame(data_dict)
  else:
    test_df = pd.DataFrame(data_dict)

COUPLET_PROMPT = '对联：'
COUPLET_SUFFIX = '\n下联：'
train_df['source_text'] = COUPLET_PROMPT + train_df['source_text'] + COUPLET_SUFFIX
test_df['source_text'] = COUPLET_PROMPT + test_df['source_text'] + COUPLET_SUFFIX


In [None]:
# Sample 5k
train_df_sample = train_df[['source_text', 'target_text']].sample(5000)
train_df_sample

| | source_text | target_text |
|---|---|---|
| 138048 | 对联：柴扉半掩云深处\n下联： | 僧客迟归日暮时 |
| 552529 | 对联：卅年梦寐惊回首\n下联： | 一样亭台只断肠 |
| 597006 | 对联：竹报平安户\n下联： | 花开富贵门 |
| 567235 | 对联：豪啸松间风壮意\n下联： | 苦吟月下酒抒怀 |
| 734209 | 对联：竹林踏月独醉我\n下联： | 梅海迎风自迷予 |
| ... | ... | ... |
| 326710 | 对联：莫姿高情求逸思\n下联： | 只缘幽恨在新诗 |
| 42032 | 对联：沉毅自从容，生死关头，结成屏障横山势\n下联： | 舒张原有度，烽烟散后，响起驼铃大漠声 |
| 75045 | 对联：眸前旋转通天塔，撩人望日\n下联： | 域外飞来作地标，煮海听涛 |
| 235626 | 对联：茶艺道行深，红炉煮出一壶诗，谁为高手\n下联： | 茗乡文化厚，绿袖斟来千尺瀑，我上奖台 |


## Convert Data to Training-friendly DataSet

In [None]:
# Copied from Alpaca-LoRA, notice input_ids, attention_mask, and labels are
# default expected columns in huggingface dataset lib
def tokenize(tokenizer, prompt, cutoff_len, add_eos_token=True):
  # there's probably a way to do this with the tokenizer settings
  # but again, gotta move fast
  result = tokenizer(
      prompt,
      truncation=True,
      max_length=cutoff_len,
      padding=False,
      return_tensors=None,
  )
  if (
      result["input_ids"][-1] != tokenizer.eos_token_id
      and len(result["input_ids"]) < cutoff_len
      and add_eos_token
  ):
    result["input_ids"].append(tokenizer.eos_token_id)
    result["attention_mask"].append(1)

  # result["labels"] = copy.deepcopy(result["input_ids"])
  result["labels"] = result["input_ids"].copy()
  return result


# Branched from Alpaca-LoRA
def tokenize_fn(data_point):
  prompt_in, prompt_out = data_point['source_text'], data_point['target_text']
  full_prompt = prompt_in + prompt_out
  tokenized_full_prompt = tokenize(tokenizer, full_prompt, CUTOFF_LEN)
  if not TRAIN_ON_INPUT:
    user_prompt = prompt_in
    tokenized_user_prompt = tokenize(tokenizer, user_prompt, CUTOFF_LEN, add_eos_token=False)
    user_prompt_len = len(tokenized_user_prompt["input_ids"])
    tokenized_full_prompt["labels"] = [
        -100 # special id for skipping
    ] * user_prompt_len + tokenized_full_prompt["labels"][user_prompt_len:]
  return tokenized_full_prompt


train_ds = Dataset.from_pandas(train_df_sample)
train_ds = train_ds.flatten()
tokenized_train_ds = train_ds.map(
    tokenize_fn,
    remove_columns=['source_text', 'target_text', '__index_level_0__'],
)

Map:   0%|          | 0/5000 [00:00<?, ? examples/s]
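
When `TRAIN_ON_INPUT` is `False`, `tokenize_fn` replaces the prompt positions in `labels` with `-100`, the index that PyTorch's cross-entropy loss ignores by default, so only the second line contributes to the loss. A minimal sketch of that masking step, using the token ids of the first decoded example printed in the next cell and assuming the prompt (`<bos>对联：闪电\n下联：`) covers the first 9 tokens:

```python
IGNORE_INDEX = -100  # label value skipped by the loss

def mask_prompt_labels(input_ids, prompt_len):
    """Return labels where the first `prompt_len` positions are ignored."""
    return [IGNORE_INDEX] * prompt_len + list(input_ids[prompt_len:])

input_ids = [2, 235735, 236612, 235465, 199171, 108,
             235543, 236612, 235465, 239369, 237184, 1]
prompt_len = 9  # assumption: <bos> plus the prompt tokens
labels = mask_prompt_labels(input_ids, prompt_len)
print(labels)
```

Only the last three positions (the two answer tokens and `<eos>`) keep their ids; everything before them is excluded from the loss.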

In [None]:
# Optionally check a few examples by decoding the inputs
for i in range(100, 103):
  example = tokenized_train_ds[i]  # fetch one row instead of a whole column
  print("token length", len(example['input_ids']))
  print(tokenizer.decode(example['input_ids']))
  print("Label ids", example['labels'])
  print()

token length 12
<bos>对联：闪电
下联：鸣雷<eos>
Label ids [2, 235735, 236612, 235465, 199171, 108, 235543, 236612, 235465, 239369, 237184, 1]

token length 22
<bos>对联：群蜂起舞争风采
下联：一马当先赛水平<eos>
Label ids [2, 235735, 236612, 235465, 236783, 239137, 235803, 236993, 237124, 236229, 237266, 108, 235543, 236612, 235465, 235411, 236502, 235804, 235938, 237013, 50499, 1]

token length 21
<bos>对联：二分明月怀中寄
下联：十里春风面上柔<eos>
Label ids [2, 235735, 236612, 235465, 235796, 186144, 235612, 237810, 235493, 237585, 108, 235543, 236612, 235465, 235905, 235792, 236544, 236229, 100079, 237640, 1]
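
The decoded examples above give a quick way to sanity-check `CUTOFF_LEN`. The sketch below compares the logged token counts with the character counts of the same three examples, then estimates a worst case for `MAX_SEQ_LEN = 32` under the assumption (which may not hold for rare characters) that each character maps to at most one token:

```python
# (text, logged token length) pairs copied from the output above;
# <bos> and <eos> are counted in the token length but are not part of the text
examples = [
    ("对联：闪电\n下联：鸣雷", 12),
    ("对联：群蜂起舞争风采\n下联：一马当先赛水平", 22),
    ("对联：二分明月怀中寄\n下联：十里春风面上柔", 21),
]
SPECIAL_TOKENS = 2  # <bos> and <eos>

for text, n_tokens in examples:
    print(f"chars={len(text):>2}  tokens={n_tokens:>2}  "
          f"chars+special={len(text) + SPECIAL_TOKENS}")

MAX_SEQ_LEN = 32
CUTOFF_LEN = 96
# prefix + first line + suffix + second line + special tokens
worst_case = (len("对联：") + MAX_SEQ_LEN + len("\n下联：")
              + MAX_SEQ_LEN + SPECIAL_TOKENS)
print("Estimated worst case:", worst_case, "tokens")
print("Fits in CUTOFF_LEN:", worst_case <= CUTOFF_LEN)
```

In these three samples the token count never exceeds the character count plus the two special tokens, so the one-token-per-character assumption holds for them.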



## LoRA setup
- Check out the [LoRA paper](https://arxiv.org/abs/2106.09685)
- Why `q_proj` and `v_proj`?
  - ![](https://lh3.googleusercontent.com/pG1o98-ZuTdvGaIuf3r0_GQ2wqZv1eAjaM13ki_AoipSm4Vo0v3JCynmU26PjE_6qKvyLdiDlZfQP8mGpvy0hG6TMDM-ROpup35WYmH3lBiGVC67tQL2kZnNIVzz0gviU88lq7yP126N1DCaKxkvXd1vE5TBwasBTH2waI_QbcyT324snp5iOCJXDrMa9bPokbM4w8PwJL61lqfGXmOvbP2Yqo5gOC7kA73aZPMOG3CnWzFujZpbQ5so7ZnHNOvmBWSjKUHvqI8UbvJvAy43SXL2UePcFg-KcWAA9gCUscNxKOOtai0_6ShZgZLCXCMLvLYpbqK6IqtYTS7-dwQMYQrRJ80IyPxMYwfSLaYn2UVd0I04ETFCqOH-pDtsToZ3eCGqQi-zxLdcDUpqhcXxj60PjxpMOyFK_wCK1tKEx7hu7nUDR4GYIIRFtNZS_jMhFaKhqcZf6d3Vora-2v0Sv_CVhDTy5cabVpSDEqBWpGMiCcj5IvnBIRAkPY5D_Mr5elWSCuanOXMp9riwK2-WobJoNvW7qATFAr3aiTA5MCQPqwvkOXhpj9YF7QudshxaplDzpBiLxJbdvzE-froAlxAup2yDEhEOb_xuvRBLetvL366GOEivlq577Y8MTusVcz_b9ex6TP77_XjRHAp4lQ7Bs7tR2tjY-n29bC1MhGB_t7Ta82MdLivR-T5lG4hvhGJ-rTsqMkUm0KY-Vqup-04eZHBMkY1RHjj7oNc8vDXHbTiFskLfne5Trr0_3MCZamyRZuwPeZXzFlzbif1lSBwXpSk0ckzPMGRFhiDZ0sa3QUrLeyvGA5UzHIhqHL0Ve-f03V0z48o_YoHSdWrhN8xZJb6ga-eGu0MM9f5VxE7Y9znQ4qE9_5neS6GBHvA0-YXjzZ7INP9KVgKpX_FTuAuegL7ARB1gG4lbXKWVKQS38g=w1577-h337-s-no?authuser=0)
- Why `r=16` (>= 8)?
  - ![](https://lh3.googleusercontent.com/sxLGQpoBbmjnZwK853wFOcgEgzvJIa7wOpaH72v1eNw9gI9VaMvhpWGhzPCPowSuG44wzO53ENrXGMrdoXXhTjPmy1jRVvAMqbFYiwcCU4sZ0jqOe2vP1I9hEw-syKqpPW1-Nr5TM10Qm8MYXuigatFPNl76FSxYXBRHNcZRjeluGPxMjz78SXzBa07j6YomCGQJCyx5QTRVfhWw7iy4dbb1rybldeodUvY1xI8XzTzQeclYhE8kLI6yN7J02LKpkhHmzMgFY0Qr73gvoINGmZyguJItJ0ZcaR3zBJNfIIaBSaYe3amB4qL7zWu6sYOxfdBk8v_lWLCCTys1_ThFoiUhLDrHRK1LX5QELQTf_MAlFVk5qisF2dZo3GFEt5bKga2CNwSH-I5FjL9Z6jnopRDHCs1JGVmYn6sgdLhrG7fbj83hAb5NLwSfebi5pjYRASAVVC18hNo1ZkG_TCTPJv2__KveWNjalDkSWWEVzMZO6ZlnoLMtwEA_KfqaDNEdVTs2wa_-dsbXijPDkF0bSdiDqtiAe6Nk_sL1iEoNMswMvCGOD5orD4oojigkteh-xdyaQ3W0mEqC3HXEaPAAKHQOp2V5XBIMi-wIo_M-bjK304SA68jvmizpyYTI-yTiAb_B8lUR0PeAMp1avZux4NdXVZZu1wWlzgpr3HC7pU7ZmqgO_xJrU-ICMrkeqy8eYOy-jy1NrK03Y9NsT7-JTxg5HHGBptMKkJORT7IIQRl_eCgw0WMu9Bc9ueSIbSCLQgZ_WdaMe3wSjLkj8NmRgQ83HW686Ww54xfwscMq8l97MgaJobKqRvagOzx2KG_cMthGWfkIqVFCdfTgBv3c8Mf7lBDwduxsyfNbLuPSxW_NI4UmxN7Tkx5xN_qBBI5prltAYX_jYhFtq6JNd2bnWgqDBfDDkHp96Wj3kMRF43A8sw=w1804-h507-s-no?authuser=0)
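
To make the `r` and `lora_alpha` settings concrete: LoRA keeps the pretrained weight `W` frozen and learns a low-rank update, so the effective weight is `W + (lora_alpha / r) * B @ A`, where `A` has shape `(r, in_features)` and `B` has shape `(out_features, r)`. The sketch below counts the trainable parameters this adds. The layer dimensions are assumptions based on the published Gemma-7B configuration (hidden size 3072, 16 heads of dimension 256, 28 layers) and are not read from the loaded model:

```python
def lora_param_count(in_features: int, out_features: int, r: int) -> int:
    """Parameters in one LoRA adapter: A is (r, in), B is (out, r)."""
    return r * in_features + out_features * r

R, LORA_ALPHA = 16, 32
HIDDEN = 3072          # assumed Gemma-7B hidden size
ATTN_DIM = 16 * 256    # assumed num_heads * head_dim
NUM_LAYERS = 28        # assumed number of decoder layers

per_module = lora_param_count(HIDDEN, ATTN_DIM, R)
total = per_module * 2 * NUM_LAYERS  # q_proj and v_proj in every layer
print("LoRA scaling factor:", LORA_ALPHA / R)
print("Parameters per adapted module:", per_module)
print("Trainable LoRA parameters:", total)
```

Once the PEFT model has been created, `lora_model.print_trainable_parameters()` reports the actual count for the loaded model, which is the number to trust over this estimate.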

In [None]:
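# The base model is loaded in 4-bit, so use the k-bit helper, which replaces
# the deprecated prepare_model_for_int8_training
from peft import prepare_model_for_kbit_training
model = prepare_model_for_kbit_training(model)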

config = LoraConfig(
    r=16,
    lora_alpha=32, # scaling param related to r, reuse alpaca-lora
    target_modules=["q_proj", "v_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM"
)

lora_model = get_peft_model(model, config)

## Training
- Show before and after!

In [None]:
# The example first lines are in Chinese; ChatGPT or a translator can explain their meaning
def eval_model(my_model, examples=("上联：春风得意花铺路\n下联：",
                                   "上联：美丽中国魅力北京\n下联：",
                                   "上联：鱼书千里梦\n下联：",
                                   "对联：日落晚霞临古寺\n下联：",)):
  for p_in in examples:
    batch = tokenizer(
        p_in,
        return_tensors='pt',
    )
    with torch.cuda.amp.autocast(): # required for mixed precision
      # Allow at most as many new tokens as the prompt has
      output_tokens = my_model.generate(
          **batch, max_new_tokens=batch['input_ids'].shape[-1])
    # print(output_tokens[0])
    out = tokenizer.decode(output_tokens[0], skip_special_tokens=True)
    # Post-processing "cheat": the prompt has 7 fixed chars (a 3-char prefix and
    # a 4-char suffix), so a second line as long as the first gives
    # len(p_in) * 2 - 7 chars in total; drop anything beyond that
    out = out[:len(p_in) * 2 - 7]
    # Show the last newline as a literal "\n" so it is visible in the output
    if out.count('\n') > 1:
      out = out[::-1].replace("\n", "n\\", 1)[::-1]
    print(out)
    print()
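
The truncation step in `eval_model` relies on simple length arithmetic: the prompt is a 3-character prefix, the first line, and a 4-character suffix, so the decoded text should be exactly `2 * len(prompt) - 7` characters once the second line has as many characters as the first. A standalone sketch of that rule (the function name `trim_to_first_line_length` and the over-long generation are made up for illustration):

```python
PREFIX = "对联："
SUFFIX = "\n下联："

def trim_to_first_line_length(prompt: str, decoded: str) -> str:
    """Cut `decoded` so the second line has as many characters as the first."""
    first_line_len = len(prompt) - len(PREFIX) - len(SUFFIX)
    return decoded[:len(prompt) + first_line_len]

prompt = PREFIX + "竹报平安户" + SUFFIX
decoded = prompt + "花开富贵门，春满人间"  # hypothetical generation that runs too long
trimmed = trim_to_first_line_length(prompt, decoded)
print(trimmed)
```

This only fixes the length; it does not make the second line a better match, which is why it is a "cheat" rather than a substitute for a model that stops on its own.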

In [None]:
# Different GPUs may give out slightly different answers below due to very small precision difference
print("Before training")
eval_model(lora_model)

Before training
上联：春风得意花铺路
下联：花香满鼻，春暖

上联：美丽中国魅力北京
下联：中国传统建筑风格

上联：鱼书千里梦
下联：笑口在笑。

对联：日落晚霞临古寺
下联：花香满院，迎风



In [None]:
torch.cuda.empty_cache()

In [None]:
# As you can tell, I even omitted eval_dataset for this demo :(
trainer = transformers.Trainer(
    model=lora_model,
    train_dataset=tokenized_train_ds,
    args=transformers.TrainingArguments(
        # A larger batch size significantly increases GPU memory use.
        # Keep it at 4 if you have less than 16 GB of VRAM. Rough estimates:
        # Batch = 4, probably 8.3-8.8 GB of VRAM
        # Batch = 16, 9.5 GB+
        # Batch = 32, 11 GB+
        # Batch = 64, 14 GB+
        per_device_train_batch_size=4,
        gradient_accumulation_steps=2,
        warmup_steps=8,
        num_train_epochs=2,
        learning_rate=2e-4,
        fp16=True,
        logging_steps=20,
        output_dir='outputs',
        remove_unused_columns=False,
    ),
    data_collator=transformers.DataCollatorForSeq2Seq(
        tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True,
    ),
)
lora_model.config.use_cache = False # Alpaca Lora sets this for training
trainer.train()

| Step | Training Loss |
|---|---|
| 20 | 30.7854 |
| 40 | 14.3419 |
| 60 | 5.215 |
| 80 | 4.8392 |
| 100 | 4.5699 |
| 120 | 4.4319 |
| 140 | 4.5515 |
| 160 | 4.4646 |
| 180 | 4.3566 |
| 200 | 4.3172 |
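
For reference, the number of optimizer steps follows from the training settings above. This sketch assumes a single GPU and the 5,000-pair sample, and mirrors how the step count is usually derived from batch size and gradient accumulation:

```python
import math

num_examples = 5000
per_device_train_batch_size = 4
gradient_accumulation_steps = 2
num_train_epochs = 2

effective_batch_size = per_device_train_batch_size * gradient_accumulation_steps
batches_per_epoch = math.ceil(num_examples / per_device_train_batch_size)
steps_per_epoch = batches_per_epoch // gradient_accumulation_steps
total_steps = steps_per_epoch * num_train_epochs

print("Effective batch size:", effective_batch_size)
print("Optimizer steps per epoch:", steps_per_epoch)
print("Total optimizer steps:", total_steps)
```

Under these assumptions a full run would take 1,250 steps, so the log above (up to step 200) covers only the early part of training.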


In [None]:
# Empirically, quick tests gave "somewhat OK" results once the training loss dropped below a threshold
print("After training")
eval_model(lora_model)

## Suggested additional reading
- [Decoding algorithm by HF](https://huggingface.co/blog/how-to-generate)
- So far, I have only demoed greedy search (pick the highest-probability token at each position, without looking ahead)
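
To illustrate what greedy search does, here is a toy sketch that does not involve the real model. The next-token table `TOY_MODEL`, its vocabulary, and its probabilities are all made up; the point is only that each step takes the single most likely next token and never revisits earlier choices:

```python
# Hypothetical next-token probabilities, keyed by the previous token
TOY_MODEL = {
    "<bos>": {"春": 0.6, "花": 0.4},
    "春": {"风": 0.7, "雨": 0.3},
    "风": {"<eos>": 0.9, "春": 0.1},
    "花": {"<eos>": 1.0},
    "雨": {"<eos>": 1.0},
}

def greedy_decode(start="<bos>", max_new_tokens=8):
    """Repeatedly append the most probable next token until <eos> or the cap."""
    tokens = [start]
    for _ in range(max_new_tokens):
        candidates = TOY_MODEL[tokens[-1]]
        next_token = max(candidates, key=candidates.get)  # argmax over probabilities
        tokens.append(next_token)
        if next_token == "<eos>":
            break
    return tokens

decoded = greedy_decode()
print(decoded)
```

Sampling or beam search, covered in the linked post, would explore alternatives such as `花` or `雨` that greedy search never considers.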

## Optional: Upload to HuggingFace and share with the world!
- And you should!

In [None]:
# from huggingface_hub import notebook_login
# notebook_login()
# YOUR_HF_ID = "YOUR_ID_PLZ"
# lora_model.push_to_hub(f"{YOUR_HF_ID}/chinese-couplet-gemma-lora-test-v0.1",
#                        token=True,  # use_auth_token is deprecated
#                        create_pr=True)
# # Go to Hugging Face and merge the PR to share with the world!