# Gemma 7b-it-quant on multi GPU

# Import libs

In [1]:
%%capture
%pip install -q -U transformers
%pip install -q -U accelerate
%pip install -q -U bitsandbytes
%pip install -q -U trl 
%pip install -q -U peft
%pip install -q datasets==2.16.0

In [2]:
import pandas as pd
import os
import numpy as np
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, HfArgumentParser, TrainingArguments, pipeline, logging
from peft import LoraConfig, PeftModel, prepare_model_for_kbit_training, get_peft_model
import os,torch, wandb
from datasets import load_dataset, Dataset
from trl import SFTTrainer

2024-08-28 16:55:53.678226: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-08-28 16:55:53.678339: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-08-28 16:55:53.846189: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [3]:
from kaggle_secrets import UserSecretsClient
user_secrets = UserSecretsClient()
secret_hf = user_secrets.get_secret("HUGGINGFACE_TOKEN")
secret_wandb = user_secrets.get_secret("WANDB_API_KEY")

In [4]:
!huggingface-cli login --token $secret_hf

The token has not been saved to the git credentials helper. Pass `add_to_git_credential=True` in this function directly or `--add-to-git-credential` if using via `huggingface-cli` if you want to set the git credential as well.
Token is valid (permission: write).
Your token has been saved to /root/.cache/huggingface/token
Login successful


In [5]:
wandb.login(key = secret_wandb)
run = wandb.init(
    project='Fine tuning mistral 7B', 
    job_type="training", 
    anonymous="allow"
)

[34m[1mwandb[0m: W&B API key is configured. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mdragoa389[0m. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: wandb version 0.17.7 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade
[34m[1mwandb[0m: Tracking run with wandb version 0.16.3
[34m[1mwandb[0m: Run data is saved locally in [35m[1m/kaggle/working/wandb/run-20240828_165610-7g78pnhp[0m
[34m[1mwandb[0m: Run [1m`wandb offline`[0m to turn off syncing.
[34m[1mwandb[0m: Syncing run [33msage-frost-13[0m
[34m[1mwandb[0m: ⭐️ View project at [34m[4mhttps://wandb.ai/dragoa389/Fine%20tuning%20mistral%207B[0m
[34m[1mwandb[0m: 🚀 View run at [34m[4mhttps://wandb.ai/dragoa389/Fine%20tuning%20mistral%207B/runs/7g78pnhp[0m


# Load Model 

In [6]:
model_name = "google/gemma-7b-it"

compute_dtype = getattr(torch, "float16")

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=False,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=compute_dtype,
)

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    quantization_config=bnb_config, 
)

model.config.use_cache = False
model.config.pretraining_tp = 1

max_seq_length = 2048
tokenizer = AutoTokenizer.from_pretrained(model_name, max_seq_length=max_seq_length)
EOS_TOKEN = tokenizer.eos_token

config.json:   0%|          | 0.00/694 [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/20.9k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/4 [00:00<?, ?it/s]

model-00001-of-00004.safetensors:   0%|          | 0.00/5.00G [00:00<?, ?B/s]

model-00002-of-00004.safetensors:   0%|          | 0.00/4.98G [00:00<?, ?B/s]

model-00003-of-00004.safetensors:   0%|          | 0.00/4.98G [00:00<?, ?B/s]

model-00004-of-00004.safetensors:   0%|          | 0.00/2.11G [00:00<?, ?B/s]

`config.hidden_act` is ignored, you should use `config.hidden_activation` instead.
Gemma's activation function will be set to `gelu_pytorch_tanh`. Please, use
`config.hidden_activation` if you want to override this behaviour.
See https://github.com/huggingface/transformers/pull/29402 for more details.


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/137 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/34.2k [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/4.24M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/17.5M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/636 [00:00<?, ?B/s]

In [7]:
model

GemmaForCausalLM(
  (model): GemmaModel(
    (embed_tokens): Embedding(256000, 3072, padding_idx=0)
    (layers): ModuleList(
      (0-27): 28 x GemmaDecoderLayer(
        (self_attn): GemmaSdpaAttention(
          (q_proj): Linear4bit(in_features=3072, out_features=4096, bias=False)
          (k_proj): Linear4bit(in_features=3072, out_features=4096, bias=False)
          (v_proj): Linear4bit(in_features=3072, out_features=4096, bias=False)
          (o_proj): Linear4bit(in_features=4096, out_features=3072, bias=False)
          (rotary_emb): GemmaRotaryEmbedding()
        )
        (mlp): GemmaMLP(
          (gate_proj): Linear4bit(in_features=3072, out_features=24576, bias=False)
          (up_proj): Linear4bit(in_features=3072, out_features=24576, bias=False)
          (down_proj): Linear4bit(in_features=24576, out_features=3072, bias=False)
          (act_fn): PytorchGELUTanh()
        )
        (input_layernorm): GemmaRMSNorm((3072,), eps=1e-06)
        (post_attention_layernorm):

# Load our data

In [8]:
train = pd.read_csv('/kaggle/input/llm-prompt-recovery-ground-truth-1/llm-prompt-recovery-train.csv')
train = train.sample(1500).reset_index(drop=True) 
train

Unnamed: 0.1,Unnamed: 0,id,original_text,rewrite_prompt,rewritten_text
0,180,35891980,Gusts of 40-50mph (64-80km/h) are forecast for...,Write this as an audit report summary.,**Audit Report Summary**\n\nThe Met Office has...
1,1308,35888944,Joan Barnett suffered injuries at her home in ...,Rewrite the message as a vintage fashion magaz...,**Vintage Fashion Magazine Feature**\n\n**The ...
2,676,35581289,"Club director Charlton, 78, joined United in 1...",Make it a detailed report by a private investi...,**Private Investigation Report**\n\n**Date:** ...
3,7011,32304725,"The Mountain Goat, which is based on one of th...",Turn this into a board game instruction.,## The Mountain Goat Board Game Instructions\n...
4,4682,32977194,"The body of Jed Allen, 21, was found in woodla...",Adapt it as an intense courtroom drama scene.,The courtroom erupted into a tense atmosphere ...
...,...,...,...,...,...
1495,1836,37461782,The Lady Glovers are top of WSL 2 ahead of hos...,Transform this into a formal business proposal.,**[Your Name]**\n**[Your Title]**\n**[Your Com...
1496,4552,39659216,James Hall qualified third in the all around a...,Make it a declaration of love in a cyberpunk c...,Neon lights flickered against the chrome skyli...
1497,5977,36194476,Rhyl's Sun Centre shut in 2014 after the trust...,Rephrase this as an etiquette guide for a fanc...,**Etiquette Guide for Attending a Fancy Gala**...
1498,5240,33425414,Media playback is unsupported on your device\n...,Make the text into an elegant wedding invitation,"In the glow of a starlit sky, we invite you to..."


In [9]:
test = pd.read_csv('/kaggle/input/llm-prompt-recovery/test.csv')
test['original_text'] = test['original_text'].fillna(' ')
test

Unnamed: 0,id,original_text,rewritten_text
0,-1,The competition dataset comprises text passage...,Here is your shanty: (Verse 1) The text is rew...


In [10]:
USER_CHAT_TEMPLATE ="""<start_of_turn>user\nTask: Your task is to compare the texts below to identify key changes and then, generate a concise rewrite \
prompt that can direct to transform the text in same way. \n\nOriginal Text :'{ot}'\n\nRewritten Text :'{rt}' \n 
Expected output format: \nPrompt:'concise rewrite prompt '<end_of_turn>\n<start_of_turn>model\n"""

print(USER_CHAT_TEMPLATE.format(ot=test.original_text[0], rt = test.rewritten_text[0]))

<start_of_turn>user
Task: Your task is to compare the texts below to identify key changes and then, generate a concise rewrite prompt that can direct to transform the text in same way. 

Original Text :'The competition dataset comprises text passages that have been rewritten by the Gemma LLM according to some rewrite_prompt instruction. The goal of the competition is to determine what prompt was used to rewrite each original text.  Please note that this is a Code Competition. When your submission is scored, this example test data will be replaced with the full test set. Expect roughly 2,000 original texts in the test set.'

Rewritten Text :'Here is your shanty: (Verse 1) The text is rewritten, the LLM has spun, With prompts so clever, they've been outrun. The goal is to find, the prompt so bright, To crack the code, and shine the light. (Chorus) Oh, this is a code competition, my dear, With text and prompts, we'll compete. Two thousand texts, a challenge grand, To guess the prompts, ha

In [11]:
def generate(ot, rt, model, device) -> str:
    USER_TEMPLATE=USER_CHAT_TEMPLATE.format(ot=ot, rt =rt)
    return model.generate(USER_TEMPLATE,device=device, output_len=50)

In [12]:
def gen_df(test):
    test['prompt'] = test.progress_apply(lambda row: generate(ot=row['original_text'], rt=row['rewritten_text'], model=model, device="cuda:0"), axis=1)

In [13]:
train

Unnamed: 0.1,Unnamed: 0,id,original_text,rewrite_prompt,rewritten_text
0,180,35891980,Gusts of 40-50mph (64-80km/h) are forecast for...,Write this as an audit report summary.,**Audit Report Summary**\n\nThe Met Office has...
1,1308,35888944,Joan Barnett suffered injuries at her home in ...,Rewrite the message as a vintage fashion magaz...,**Vintage Fashion Magazine Feature**\n\n**The ...
2,676,35581289,"Club director Charlton, 78, joined United in 1...",Make it a detailed report by a private investi...,**Private Investigation Report**\n\n**Date:** ...
3,7011,32304725,"The Mountain Goat, which is based on one of th...",Turn this into a board game instruction.,## The Mountain Goat Board Game Instructions\n...
4,4682,32977194,"The body of Jed Allen, 21, was found in woodla...",Adapt it as an intense courtroom drama scene.,The courtroom erupted into a tense atmosphere ...
...,...,...,...,...,...
1495,1836,37461782,The Lady Glovers are top of WSL 2 ahead of hos...,Transform this into a formal business proposal.,**[Your Name]**\n**[Your Title]**\n**[Your Com...
1496,4552,39659216,James Hall qualified third in the all around a...,Make it a declaration of love in a cyberpunk c...,Neon lights flickered against the chrome skyli...
1497,5977,36194476,Rhyl's Sun Centre shut in 2014 after the trust...,Rephrase this as an etiquette guide for a fanc...,**Etiquette Guide for Attending a Fancy Gala**...
1498,5240,33425414,Media playback is unsupported on your device\n...,Make the text into an elegant wedding invitation,"In the glow of a starlit sky, we invite you to..."


# Training

In [14]:
output = train.copy()
output['prompt'] = output.apply(lambda row: USER_CHAT_TEMPLATE.format(ot=row['original_text'], rt=row['rewritten_text'],rp=row['rewrite_prompt']), axis=1)

In [15]:
train_dataset = Dataset.from_pandas(output[['prompt']])
train_dataset

Dataset({
    features: ['prompt'],
    num_rows: 1500
})

In [16]:
model = prepare_model_for_kbit_training(model)
peft_config = LoraConfig(
    lora_alpha=16,
    lora_dropout=0.1,
    r=64,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj","gate_proj"]
)
model = get_peft_model(model, peft_config)

In [17]:
training_arguments = TrainingArguments(
    output_dir="./results",
    num_train_epochs=1,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=1,
    optim="paged_adamw_32bit",
    save_steps=25,
    logging_steps=25,
    learning_rate=2e-4,
    weight_decay=0.001,
    fp16=False,
    bf16=False,
    max_grad_norm=0.3,
    max_steps=-1,
    warmup_ratio=0.03,
    group_by_length=True,
    lr_scheduler_type="constant",
)

In [18]:
trainer = SFTTrainer(
    model=model,
    train_dataset=train_dataset,
    peft_config=peft_config,
    max_seq_length= None,
    dataset_text_field="prompt",
    tokenizer=tokenizer,
    args=training_arguments,
    packing= False,
)


Deprecated positional argument(s) used in SFTTrainer, please use the SFTConfig to set these arguments instead.


Map:   0%|          | 0/1500 [00:00<?, ? examples/s]

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


In [19]:
trainer.train()



Step,Training Loss
25,2.4777
50,1.3512
75,1.3719
100,1.3284
125,1.2899
150,1.2782
175,1.273
200,1.2325
225,1.293
250,1.2475




TrainOutput(global_step=375, training_loss=1.3444193929036459, metrics={'train_runtime': 10494.6485, 'train_samples_per_second': 0.143, 'train_steps_per_second': 0.036, 'total_flos': 3.0513989191458816e+16, 'train_loss': 1.3444193929036459, 'epoch': 1.0})

# Saving the Model

In [20]:
new_model_name = "gemma_prompt_recovery_finetuned"

trainer.model.save_pretrained(new_model_name)
wandb.finish()
model.config.use_cache = True

[34m[1mwandb[0m:                                                                                
[34m[1mwandb[0m: 
[34m[1mwandb[0m: Run history:
[34m[1mwandb[0m:         train/epoch ▁▂▂▃▃▄▄▅▅▆▆▇▇███
[34m[1mwandb[0m:   train/global_step ▁▁▂▃▃▃▄▅▅▅▆▇▇▇██
[34m[1mwandb[0m:     train/grad_norm ▃▆▂▅▂▆▂▄▁▃▁█▁▅▄
[34m[1mwandb[0m: train/learning_rate ▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
[34m[1mwandb[0m:          train/loss █▂▂▂▂▂▂▁▂▂▂▁▁▁▁
[34m[1mwandb[0m: 
[34m[1mwandb[0m: Run summary:
[34m[1mwandb[0m:               total_flos 3.0513989191458816e+16
[34m[1mwandb[0m:              train/epoch 1.0
[34m[1mwandb[0m:        train/global_step 375
[34m[1mwandb[0m:          train/grad_norm 0.522
[34m[1mwandb[0m:      train/learning_rate 0.0002
[34m[1mwandb[0m:               train/loss 1.1627
[34m[1mwandb[0m:               train_loss 1.34442
[34m[1mwandb[0m:            train_runtime 10494.6485
[34m[1mwandb[0m: train_samples_per_second 0.143
[34m[1mwandb[0m:   train_s

In [21]:
# trainer.model.push_to_hub(new_model_name, use_temp_dir=False)