<a href="https://colab.research.google.com/github/hululuzhu/llama-lora-chinese-couplet/blob/main/LLaMA_LoRA_Chinese_Couplet_demo_v1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# LLaMA + LoRA = Fine-tuning on a Consumer-Level GPU
- Inspired by [Alpaca](https://crfm.stanford.edu/2023/03/13/alpaca.html), [LoRA](https://arxiv.org/abs/2106.09685), and [Alpaca LoRA](https://github.com/tloen/alpaca-lora)
- [A related intro slide deck](https://github.com/hululuzhu/llama-lora-chinese-couplet/blob/main/llama-lora-v1.0.pdf)
- The example uses Chinese couplets to avoid a potential conflict of interest with my employer
  - A Tesla T4 (16 GB) takes about 35 minutes to produce the examples below
  - An A100 (40 GB) takes about 9 minutes to produce the examples below
- Last update: 04/30/2023
- Contact: hululu.zhu@gmail.com


Zero-shot Examples
- After 3 epochs on 5k couplet pairs, with a cap on max new tokens and greedy decoding
- Post-processed so the output has the same number of Chinese characters as the input
- Ideally, a well-trained model would emit the end-of-sequence (EOS) token on its own
- prompt: `对联：{上联}\n下联：`
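
The prompt template and the length-matching post-processing can be sketched in a few lines. This is a minimal illustration rather than the notebook's exact code: `build_prompt` and `trim_to_match` are hypothetical helper names, and the raw completion is written by hand instead of being generated by a model.

```python
PREFIX = "对联："      # prompt prefix used for training
SUFFIX = "\n下联："    # separator that introduces the second line


def build_prompt(first_line: str) -> str:
    """Wrap the first line of a couplet in the prompt template."""
    return f"{PREFIX}{first_line}{SUFFIX}"


def trim_to_match(first_line: str, completion: str) -> str:
    """Cut the completion to the character count of the first line."""
    return completion[:len(first_line)]


first_line = "鱼书千里梦"
prompt = build_prompt(first_line)
# A hand-written raw completion that runs past the end of the second line.
raw_completion = "鸟声万里声\n上联："
print(repr(prompt))
print(trim_to_match(first_line, raw_completion))
```

Trimming by character count hides the fact that the model did not stop by itself, which is why the bullet above calls a learned EOS the ideal outcome.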

|上联| Base LLaMA | LLaMa_LoRA_A100_9mins | LLaMa_LoRA_Tesla_T4_35mins |
| ----------- | ----------- | ----------- | ----------- |
|春风得意花铺路| 沉浸落泥\n上联 | 月光听声风吹梦 | 风雨吹梦浮浮� |
|美丽中国魅力北京| 美丽中国魅力北京\n上联： | 历史浓浅中华梦境 | 梦幻中国梦想宏碁|
|鱼书千里梦| 鱼肉烧肉\n | 鸟声万里声 | 鸟声万里声|
|日落晚霞临古寺| 晚霞临古寺\n上 | 月映晨雨满梦境 | 月映晨霜满梦境 |


## Prerequisites
- NVIDIA GPU with at least 10 GB of free GPU memory
- Install the required Python packages with pip

In [None]:
!nvidia-smi -L

GPU 0: Tesla T4 (UUID: GPU-ddf3f440-ca1f-dad7-f831-0092334cfc2f)


In [None]:
!pip install -q nvidia-ml-py3
import nvidia_smi
nvidia_smi.nvmlInit()
handle = nvidia_smi.nvmlDeviceGetHandleByIndex(0)
# GPU index 0 is hardcoded here; nvmlDeviceGetCount() returns the number of GPUs if you want to iterate
info = nvidia_smi.nvmlDeviceGetMemoryInfo(handle)
nvidia_smi.nvmlShutdown()

print("Total memory:", info.total)
print("Free memory:", info.free)
print("Used memory:", info.used)

assert info.free > 1e10, (
    "Your GPU is busy or has less than 10 GB of free memory")

Total memory: 16106127360
Free memory: 15835267072
Used memory: 270860288
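
The raw byte counts are hard to read at a glance. A small helper along these lines, with the hypothetical names `to_gb` and `has_enough_memory`, converts them to decimal gigabytes and applies the same 10 GB threshold as the assertion; the input is the "Free memory" value printed above.

```python
def to_gb(num_bytes: int) -> float:
    """Convert bytes to decimal gigabytes (1 GB = 1e9 bytes)."""
    return num_bytes / 1e9


def has_enough_memory(free_bytes: int, required_gb: float = 10.0) -> bool:
    """Mirror the notebook's check: strictly more than required_gb free."""
    return to_gb(free_bytes) > required_gb


free_bytes = 15835267072  # "Free memory" value printed above
print(f"Free: {to_gb(free_bytes):.2f} GB")
print("Enough for this notebook:", has_enough_memory(free_bytes))
```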


In [None]:
!pip install -q bitsandbytes
!pip install -q datasets loralib sentencepiece
!pip install -q peft
!pip install -q transformers


## All the Imports

In [None]:
# disable warnings unless needed
import warnings
warnings.filterwarnings('ignore')

In [None]:
from datasets import Dataset, load_dataset
import numpy as np
import os
import pandas as pd
import pathlib
from peft import PeftModel, get_peft_config, get_peft_model, LoraConfig, TaskType, prepare_model_for_int8_training
import pickle
import sys
import torch
import transformers
from transformers import LlamaTokenizer, LlamaForCausalLM, GenerationConfig, AutoModelForSeq2SeqLM, DataCollatorForLanguageModeling




## Define top-level configs

In [None]:
# Check out more details at https://huggingface.co/decapoda-research/llama-7b-hf
# Not for commercial use.
model_name_or_path = "decapoda-research/llama-7b-hf"
tokenizer_name_or_path = "decapoda-research/llama-7b-hf"

# Max number of tokens (prompt plus output). Chinese text usually needs more
# tokens than characters, as observed.
CUTOFF_LEN = 96
# Also compute the loss on the prompt tokens, as Alpaca-LoRA does, to improve quality.
# Turn off to speed up training, though quality might suffer.
TRAIN_ON_INPUT = True
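
To get a feel for how `CUTOFF_LEN` relates to couplet length, the sketch below computes a worst-case token count. It assumes that, in the worst case, every character is split into one token per UTF-8 byte (three bytes for most Chinese characters), plus an overhead of three tokens for the begin and end markers and a leading space token. `worst_case_tokens` is a hypothetical helper, not part of the notebook.

```python
def worst_case_tokens(text: str, overhead: int = 3) -> int:
    """Upper bound on token count if every character falls back to bytes.

    overhead is an assumed allowance for begin/end markers and a leading
    space token.
    """
    return len(text.encode("utf-8")) + overhead


CUTOFF_LEN = 96
sample = "对联：春风得意花铺路\n下联：月光听声风吹梦"
bound = worst_case_tokens(sample)
print(len(sample), "characters, at most", bound, "tokens")
print("fits within CUTOFF_LEN:", bound <= CUTOFF_LEN)
```

Under the same assumption, two 32-character lines could need well over 96 tokens, so long couplets are truncated during tokenization.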

## Load LLaMa 7B and tokenizer
- Takes about 5 minutes; more than 13 GB of model weights are downloaded
- After loading, GPU memory usage is 7.6 GB or more

In [None]:
original_8bit_llama_model = LlamaForCausalLM.from_pretrained(
    model_name_or_path,
    device_map="auto",
    load_in_8bit=True)

# set padding id and side based on https://github.com/tloen/alpaca-lora/blob/main/finetune.py#L121
tokenizer = LlamaTokenizer.from_pretrained(tokenizer_name_or_path)
tokenizer.pad_token_id = 0
tokenizer.padding_side = "left"  # Allow batched inference



The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'LLaMATokenizer'. 
The class this function is called from is 'LlamaTokenizer'.
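
Left padding matters for batched generation because a decoder-only model continues from the last position of each row; padding on the right would place pad tokens between the prompt and the generated text. The pure-Python sketch below shows what left padding produces for a toy batch. `left_pad` is a hypothetical helper and the token ids are made up; the real tokenizer does this internally when called with `padding=True`.

```python
from typing import List, Tuple

PAD_ID = 0  # same value as tokenizer.pad_token_id above


def left_pad(batch: List[List[int]], pad_id: int = PAD_ID
             ) -> Tuple[List[List[int]], List[List[int]]]:
    """Pad every sequence on the left to the longest length in the batch."""
    max_len = max(len(seq) for seq in batch)
    input_ids, attention_mask = [], []
    for seq in batch:
        n_pad = max_len - len(seq)
        input_ids.append([pad_id] * n_pad + seq)
        # 0 marks padding positions the model should ignore, 1 marks real tokens.
        attention_mask.append([0] * n_pad + [1] * len(seq))
    return input_ids, attention_mask


ids, mask = left_pad([[5, 6, 7], [8, 9]])
print(ids)   # [[5, 6, 7], [0, 8, 9]]
print(mask)  # [[1, 1, 1], [0, 1, 1]]
```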


## Load Training data

In [None]:
# Reuse my T5 couplet data code https://github.com/hululuzhu/chinese-ai-writing-share/blob/main/training/t5_finetune/Mengzi_T5_Finetune_Chinese_Couplet_V1.ipynb
working_dir = "/tmp/working_dir"
!mkdir -p {working_dir}
!wget https://github.com/wb14123/couplet-dataset/releases/download/1.0/couplet.tar.gz -P {working_dir}
!ls -l {working_dir}
!mkdir -p {working_dir}/couplet_files
!tar -xf {working_dir}/couplet.tar.gz -C {working_dir}/couplet_files
!head -1 {working_dir}/couplet_files/couplet/train/in.txt {working_dir}/couplet_files/couplet/train/out.txt

COUPLET_PATH = f'{working_dir}/couplet_files/couplet'
MAX_SEQ_LEN = 32  # Max 32 chinese char including punctuation marks

train_df, test_df = None, None
for t in ['train', 'test']:
  ins, outs = [], []
  for i in ['in', 'out']:
    with open(f"{COUPLET_PATH}/{t}/{i}.txt", "r", encoding="utf-8") as f:
      for line in f:
        clean_line = line.strip().replace(' ', '').replace('\n', '').replace('\r', '')[:MAX_SEQ_LEN]
        if i=='in':
          ins.append(clean_line)
        else:
          outs.append(clean_line)
  # The column names to match simpleT5
  data_dict = {
      'source_text': ins,
      'target_text': outs,
  }
  if t == 'train':
    train_df = pd.DataFrame(data_dict)
  else:
    test_df = pd.DataFrame(data_dict)

COUPLET_PROMPT = '对联：'
COUPLET_SUFFIX = '\n下联：'
train_df['source_text'] = COUPLET_PROMPT + train_df['source_text'] + COUPLET_SUFFIX
test_df['source_text'] = COUPLET_PROMPT + test_df['source_text'] + COUPLET_SUFFIX


In [None]:
# Sample 5k
train_df_sample = train_df[['source_text', 'target_text']].sample(5000)
train_df_sample

| | source_text | target_text |
| ----------- | ----------- | ----------- |
| 230554 | 对联：颠三倒四无头绪\n下联： | 去伪存真有窍门 |
| 194706 | 对联：受福寿无疆，周张仲簋\n下联： | 敬风夜勿废，殷师酉敦 |
| 457600 | 对联：看峰头情物，此地宏开讲舍\n下联： | 读壁上诗词，诸君景仰先型 |
| 88568 | 对联：我欲乘风逐日远\n下联： | 谁将化雨润花开 |
| 444863 | 对联：大肚里头装的啥呀？打赌\n下联： | 干喉开口想饮嘛呢？猜谜 |
| ... | ... | ... |
| 160859 | 对联：惟圣惟勤惟断\n下联： | 以德以功以言 |
| 479018 | 对联：玉树临风，挥舞几枝雪臂\n下联： | 娇花照水，梳妆一缕云鬟 |
| 678699 | 对联：凶残一世，仍端居庙社，神乎？鬼也\n下联： | 祸害四邻，竟雅号靖国，奇矣！怪哉 |
| 551707 | 对联：搭东台，唱西府，南腔北调\n下联： | 作春田，栽夏禾，秋收冬藏 |


## Convert Data to Training-friendly DataSet

In [None]:
# Copied from Alpaca-LoRA. Note that input_ids, attention_mask, and labels are
# the column names expected by default in the Hugging Face libraries
def tokenize(tokenizer, prompt, cutoff_len, add_eos_token=True):
  # there's probably a way to do this with the tokenizer settings
  # but again, gotta move fast
  result = tokenizer(
      prompt,
      truncation=True,
      max_length=cutoff_len,
      padding=False,
      return_tensors=None,
  )
  if (
      result["input_ids"][-1] != tokenizer.eos_token_id
      and len(result["input_ids"]) < cutoff_len
      and add_eos_token
  ):
    result["input_ids"].append(tokenizer.eos_token_id)
    result["attention_mask"].append(1)

  # result["labels"] = copy.deepcopy(result["input_ids"])
  result["labels"] = result["input_ids"].copy()
  return result


# Branched from Alpaca-LoRA
def tokenize_fn(data_point):
  prompt_in, prompt_out = data_point['source_text'], data_point['target_text']
  full_prompt = prompt_in + prompt_out
  tokenized_full_prompt = tokenize(tokenizer, full_prompt, CUTOFF_LEN)
  if not TRAIN_ON_INPUT:
    user_prompt = prompt_in
    tokenized_user_prompt = tokenize(tokenizer, user_prompt, CUTOFF_LEN, add_eos_token=False)
    user_prompt_len = len(tokenized_user_prompt["input_ids"])
    tokenized_full_prompt["labels"] = [
        -100 # special id for skipping
    ] * user_prompt_len + tokenized_full_prompt["labels"][user_prompt_len:]
  return tokenized_full_prompt


train_ds = Dataset.from_pandas(train_df_sample)
train_ds = train_ds.flatten()
tokenized_train_ds = train_ds.map(
    tokenize_fn,
    remove_columns=['source_text', 'target_text', '__index_level_0__'],
)
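
When `TRAIN_ON_INPUT` is off, `tokenize_fn` overwrites the labels of the prompt tokens with -100, which is the default `ignore_index` of PyTorch's cross-entropy loss, so only the second line of the couplet contributes to the loss. The sketch below isolates that step; `mask_prompt` is a hypothetical name and the token ids are made up.

```python
IGNORE_INDEX = -100  # default ignore_index of PyTorch's cross-entropy loss


def mask_prompt(labels, prompt_len):
    """Replace the first prompt_len labels so only the answer is scored."""
    return [IGNORE_INDEX] * prompt_len + labels[prompt_len:]


# Made-up token ids: four prompt tokens followed by three answer tokens.
full_ids = [1, 11, 12, 13, 21, 22, 2]
labels = mask_prompt(full_ids.copy(), prompt_len=4)
print(labels)  # [-100, -100, -100, -100, 21, 22, 2]
```

The input ids stay untouched; only the labels change, so the model still sees the full prompt as context.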


In [None]:
# Optionally check a few examples by decoding the inputs
for i in range(100, 103):
  print("token length", len(tokenized_train_ds['input_ids'][i]))
  print(tokenizer.decode(tokenized_train_ds['input_ids'][i]))
  print("Label ids", tokenized_train_ds['labels'][i])
  print()

token length 34
<unk> 对联：玉人微笑能倾国
下联：刚汉精忠可保邦<unk>
Label ids [0, 29871, 30783, 31986, 30383, 31395, 30313, 31935, 234, 175, 148, 30815, 232, 131, 193, 30356, 13, 30557, 31986, 30383, 232, 139, 157, 31593, 234, 181, 193, 31744, 30682, 30982, 236, 133, 169, 0]

token length 46
<unk> 对联：恩感风雷皆变化
下联：诗裁绵绣借光辉<unk>
Label ids [0, 29871, 30783, 31986, 30383, 233, 132, 172, 233, 135, 162, 236, 166, 145, 236, 158, 186, 234, 157, 137, 31462, 30705, 13, 30557, 31986, 30383, 235, 178, 154, 235, 166, 132, 234, 190, 184, 234, 190, 166, 232, 131, 162, 30867, 235, 193, 140, 0]

token length 96
<unk>对联：涪王兄弟、蕲王夫妇、鄂王父子，聚河岳精灵，仅留半壁
下联：两字君恩、四字母训、五字兵法，洒英雄涕泪，�
Label ids [0, 29871, 30783, 31986, 30383, 233, 185, 173, 30462, 232, 136, 135, 232, 191, 162, 30330, 235, 152, 181, 30462, 31176, 232, 169, 138, 30330, 236, 135, 133, 30462, 234, 139, 185, 30319, 30214, 235, 132, 157, 30828, 232, 181, 182, 234, 181, 193, 234, 132, 184, 30214, 231, 190, 136, 234, 152, 156, 232, 144, 141, 232, 166, 132, 13, 30557, 31986, 30383
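
The small ids in the label lists above (below roughly 260) look like byte-level tokens: characters that have no token of their own are split into their UTF-8 bytes. The sketch below assumes each byte token has id `byte + 3`, which is consistent with the newline (byte 10) appearing as id 13 in the output above. `decode_byte_tokens` is a hypothetical helper.

```python
BYTE_OFFSET = 3  # assumed: ids 0-2 are special tokens, byte tokens start at id 3


def decode_byte_tokens(token_ids):
    """Turn a run of byte-level token ids back into text."""
    raw = bytes(t - BYTE_OFFSET for t in token_ids)
    return raw.decode("utf-8")


# Ids taken from the first example above, where two characters span three ids each.
print(decode_byte_tokens([234, 175, 148]))
print(decode_byte_tokens([232, 131, 193]))
print(repr(decode_byte_tokens([13])))
```

This would also explain the `�` at the end of some outputs: if truncation cuts a three-byte character after its first or second byte, the remaining bytes no longer decode to a valid character.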

## LoRA setup
- See the [LoRA paper](https://arxiv.org/abs/2106.09685) for background
- Why q_proj and v_proj?
  - ![](https://lh3.googleusercontent.com/pG1o98-ZuTdvGaIuf3r0_GQ2wqZv1eAjaM13ki_AoipSm4Vo0v3JCynmU26PjE_6qKvyLdiDlZfQP8mGpvy0hG6TMDM-ROpup35WYmH3lBiGVC67tQL2kZnNIVzz0gviU88lq7yP126N1DCaKxkvXd1vE5TBwasBTH2waI_QbcyT324snp5iOCJXDrMa9bPokbM4w8PwJL61lqfGXmOvbP2Yqo5gOC7kA73aZPMOG3CnWzFujZpbQ5so7ZnHNOvmBWSjKUHvqI8UbvJvAy43SXL2UePcFg-KcWAA9gCUscNxKOOtai0_6ShZgZLCXCMLvLYpbqK6IqtYTS7-dwQMYQrRJ80IyPxMYwfSLaYn2UVd0I04ETFCqOH-pDtsToZ3eCGqQi-zxLdcDUpqhcXxj60PjxpMOyFK_wCK1tKEx7hu7nUDR4GYIIRFtNZS_jMhFaKhqcZf6d3Vora-2v0Sv_CVhDTy5cabVpSDEqBWpGMiCcj5IvnBIRAkPY5D_Mr5elWSCuanOXMp9riwK2-WobJoNvW7qATFAr3aiTA5MCQPqwvkOXhpj9YF7QudshxaplDzpBiLxJbdvzE-froAlxAup2yDEhEOb_xuvRBLetvL366GOEivlq577Y8MTusVcz_b9ex6TP77_XjRHAp4lQ7Bs7tR2tjY-n29bC1MhGB_t7Ta82MdLivR-T5lG4hvhGJ-rTsqMkUm0KY-Vqup-04eZHBMkY1RHjj7oNc8vDXHbTiFskLfne5Trr0_3MCZamyRZuwPeZXzFlzbif1lSBwXpSk0ckzPMGRFhiDZ0sa3QUrLeyvGA5UzHIhqHL0Ve-f03V0z48o_YoHSdWrhN8xZJb6ga-eGu0MM9f5VxE7Y9znQ4qE9_5neS6GBHvA0-YXjzZ7INP9KVgKpX_FTuAuegL7ARB1gG4lbXKWVKQS38g=w1577-h337-s-no?authuser=0)
- Why r=16 (>=8)?
  - ![](https://lh3.googleusercontent.com/sxLGQpoBbmjnZwK853wFOcgEgzvJIa7wOpaH72v1eNw9gI9VaMvhpWGhzPCPowSuG44wzO53ENrXGMrdoXXhTjPmy1jRVvAMqbFYiwcCU4sZ0jqOe2vP1I9hEw-syKqpPW1-Nr5TM10Qm8MYXuigatFPNl76FSxYXBRHNcZRjeluGPxMjz78SXzBa07j6YomCGQJCyx5QTRVfhWw7iy4dbb1rybldeodUvY1xI8XzTzQeclYhE8kLI6yN7J02LKpkhHmzMgFY0Qr73gvoINGmZyguJItJ0ZcaR3zBJNfIIaBSaYe3amB4qL7zWu6sYOxfdBk8v_lWLCCTys1_ThFoiUhLDrHRK1LX5QELQTf_MAlFVk5qisF2dZo3GFEt5bKga2CNwSH-I5FjL9Z6jnopRDHCs1JGVmYn6sgdLhrG7fbj83hAb5NLwSfebi5pjYRASAVVC18hNo1ZkG_TCTPJv2__KveWNjalDkSWWEVzMZO6ZlnoLMtwEA_KfqaDNEdVTs2wa_-dsbXijPDkF0bSdiDqtiAe6Nk_sL1iEoNMswMvCGOD5orD4oojigkteh-xdyaQ3W0mEqC3HXEaPAAKHQOp2V5XBIMi-wIo_M-bjK304SA68jvmizpyYTI-yTiAb_B8lUR0PeAMp1avZux4NdXVZZu1wWlzgpr3HC7pU7ZmqgO_xJrU-ICMrkeqy8eYOy-jy1NrK03Y9NsT7-JTxg5HHGBptMKkJORT7IIQRl_eCgw0WMu9Bc9ueSIbSCLQgZ_WdaMe3wSjLkj8NmRgQ83HW686Ww54xfwscMq8l97MgaJobKqRvagOzx2KG_cMthGWfkIqVFCdfTgBv3c8Mf7lBDwduxsyfNbLuPSxW_NI4UmxN7Tkx5xN_qBBI5prltAYX_jYhFtq6JNd2bnWgqDBfDDkHp96Wj3kMRF43A8sw=w1804-h507-s-no?authuser=0)

In [None]:
# Prepares the 8-bit model for training; the model object is modified in place
model = prepare_model_for_int8_training(original_8bit_llama_model)

config = LoraConfig(
    r=16,
    lora_alpha=32,  # scaling parameter tied to r, following Alpaca-LoRA
    target_modules=["q_proj", "v_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)

lora_model = get_peft_model(model, config)
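
LoRA keeps each targeted weight matrix frozen and learns a low-rank update made of two small matrices, scaled by `lora_alpha / r`. A rough count of the trainable parameters for this config is sketched below. It assumes LLaMA 7B has 32 decoder layers with a hidden size of 4096, and that `q_proj` and `v_proj` are square 4096 x 4096 matrices; `lora_trainable_params` is a hypothetical helper.

```python
def lora_trainable_params(r, hidden_size=4096, num_layers=32, modules_per_layer=2):
    """Count LoRA parameters: each adapted d x d matrix adds A (r x d) and B (d x r)."""
    per_module = r * hidden_size + hidden_size * r
    return per_module * modules_per_layer * num_layers


r, lora_alpha = 16, 32
print("trainable LoRA parameters:", lora_trainable_params(r))
print("update scaling (lora_alpha / r):", lora_alpha / r)
```

Under these assumptions the adapters hold about 8.4 million parameters, a tiny fraction of the roughly 7 billion frozen base weights, which is what makes training feasible on a single consumer GPU.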

## Training
- Show before and after!

In [None]:
# The couplets below are in Chinese; ask ChatGPT if you want a translation
# Note: training prompts start with "对联："; the first three examples use "上联：" instead
def eval_model(my_model, examples=["上联：春风得意花铺路\n下联：",
                                   "上联：美丽中国魅力北京\n下联：",
                                   "上联：鱼书千里梦\n下联：",
                                   "对联：日落晚霞临古寺\n下联：",]):
  for p_in in examples:
    batch = tokenizer(
        p_in,
        return_tensors='pt',
    )
    with torch.cuda.amp.autocast(): # required for mixed precision
      output_tokens = my_model.generate(
          **batch, max_new_tokens=batch['input_ids'].shape[-1])
    # print(output_tokens[0])
    out = tokenizer.decode(output_tokens[0], skip_special_tokens=True)
    # My own post-processing logic to "cheat" and align the character counts:
    # the prompt has 7 fixed chars (prefix + "\n下联："), so prompt plus an
    # answer of equal length is len(p_in) * 2 - 7 chars
    target_len = len(p_in) * 2 - 7
    if len(out) > target_len:
      out = out[:target_len] # perfectly match chars
    # show the last newline as a literal "\n" for visibility
    if out.count('\n') > 1:
      out = out[::-1].replace("\n", "n\\", 1)[::-1]
    print(out)
    print()

In [None]:
# Different GPUs may produce slightly different outputs below due to small numerical precision differences
print("Before training")
eval_model(lora_model)

Before training
上联：春风得意花铺路
下联：沉浸落泥\n上联

上联：美丽中国魅力北京
下联：美丽中国魅力北京

上联：鱼书千里梦
下联：鱼书千里梦

对联：日落晚霞临古寺
下联：晚霞临古寺\n上



In [None]:
# As you can tell, I even omitted eval_dataset for this demo :(
trainer = transformers.Trainer(
    model=lora_model, 
    train_dataset=tokenized_train_ds,
    args=transformers.TrainingArguments(
        # A larger batch size significantly increases GPU memory use
        # Decrease to 4 if you have less than 16G of VRAM
        # Batch = 4, probably 8.3-8.8G vram
        # Batch = 16, 9.5G+
        # Batch = 32, 11G+
        # Batch = 64, 14G+
        per_device_train_batch_size=32,
        gradient_accumulation_steps=2,
        warmup_steps=8,
        num_train_epochs=2,
        learning_rate=2e-4, 
        fp16=True,
        logging_steps=20,
        output_dir='outputs',
        remove_unused_columns=False,
    ),
    data_collator=transformers.DataCollatorForSeq2Seq(
        tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True,
    ),
)
lora_model.config.use_cache = False # Alpaca Lora sets this for training
trainer.train()

| Step | Training Loss |
| ----------- | ----------- |
| 20 | 4.2558 |
| 40 | 3.1916 |
| 60 | 2.9281 |
| 80 | 2.7971 |
| 100 | 2.7352 |
| 120 | 2.7313 |
| 140 | 2.712 |


TrainOutput(global_step=156, training_loss=3.016379796541654, metrics={'train_runtime': 2041.2573, 'train_samples_per_second': 4.899, 'train_steps_per_second': 0.076, 'total_flos': 3.5728764671361024e+16, 'train_loss': 3.016379796541654, 'epoch': 1.99})
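
The `global_step=156` in the output above can be reproduced from the training arguments. The sketch below assumes a single GPU and assumes that a trailing incomplete gradient-accumulation group in each epoch does not count as a step, which matches the logged value; `total_optimizer_steps` is a hypothetical helper, not a Trainer API.

```python
import math


def total_optimizer_steps(num_samples, per_device_batch, grad_accum, epochs):
    """Estimate optimizer steps for a single-GPU run."""
    batches_per_epoch = math.ceil(num_samples / per_device_batch)
    # Assumed: only complete accumulation groups count as optimizer steps.
    steps_per_epoch = max(batches_per_epoch // grad_accum, 1)
    return math.ceil(epochs * steps_per_epoch)


steps = total_optimizer_steps(
    num_samples=5000, per_device_batch=32, grad_accum=2, epochs=2)
print("effective batch size:", 32 * 2)
print("optimizer steps:", steps)
```

With `logging_steps=20`, that many steps yields the seven loss rows shown above (steps 20 through 140).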

In [None]:
# In quick empirical tests, results were passable once the training loss dropped below 3.0
print("After training")
eval_model(lora_model)

After training
上联：春风得意花铺路
下联：风雨吹梦浮浮�

上联：美丽中国魅力北京
下联：梦幻中国梦想宏碁

上联：鱼书千里梦
下联：鸟声万里声

对联：日落晚霞临古寺
下联：月映晨霜满梦境



## Suggested additional reading
- [Decoding algorithm by HF](https://huggingface.co/blog/how-to-generate)
- So far, this demo only uses greedy search (output the highest-probability token at each position, without looking ahead)
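
Greedy search can be illustrated without a model. The sketch below uses a toy lookup table of next-token probabilities, with values made up for illustration, and always follows the most probable continuation; `NEXT_TOKEN_PROBS` and `greedy_decode` are hypothetical names.

```python
# Toy next-token probabilities; the values are made up for illustration.
NEXT_TOKEN_PROBS = {
    "<s>": {"月": 0.6, "风": 0.4},
    "月": {"映": 0.7, "光": 0.3},
    "映": {"</s>": 0.9, "月": 0.1},
    "风": {"</s>": 1.0},
    "光": {"</s>": 1.0},
}


def greedy_decode(start="<s>", eos="</s>", max_new_tokens=10):
    """Pick the most probable next token at every step, with no look-ahead."""
    tokens, current = [], start
    for _ in range(max_new_tokens):
        candidates = NEXT_TOKEN_PROBS[current]
        current = max(candidates, key=candidates.get)
        if current == eos:
            break
        tokens.append(current)
    return "".join(tokens)


print(greedy_decode())  # 月映
```

Because each step commits to the locally best token, greedy search can miss sequences whose overall probability is higher; the linked Hugging Face post covers alternatives such as beam search and sampling.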

## Optional: Upload to HuggingFace and share with the world!
- And you should!

In [None]:
# from huggingface_hub import notebook_login
# notebook_login()
# YOUR_HF_ID = "YOUR_ID_PLZ"
# lora_model.push_to_hub(f"{YOUR_HF_ID}/chinese-couplet-llama-lora-test-v0.1",
#                        use_auth_token=True,
#                        create_pr=True)
# # Go to huggingface and merge the PR to share with the world!