## Gemma2 2Billion
Fine Tuning Gemma 2 2B model for Roman Nepali Ecommerce Customer Support Question Answer.

In [1]:
# Installing required libraries

**transformers** : Library for working with models
**datasets** : For loading dataset of huffingface
**bitsandbytes** : Memory efficient quantization of large models
**accelerate** : For distributed training and faster inference
**peft** : For Parameter Efficient Fine Tuning using LORA
**trl** : Transformer Reinforcement Learning
**py7zr** : For handling specific file type

In [2]:
!pip install transformers datasets bitsandbytes accelerate peft trl py7zr

Collecting datasets
  Downloading datasets-3.2.0-py3-none-any.whl.metadata (20 kB)
Collecting bitsandbytes
  Downloading bitsandbytes-0.45.0-py3-none-manylinux_2_24_x86_64.whl.metadata (2.9 kB)
Collecting trl
  Downloading trl-0.13.0-py3-none-any.whl.metadata (11 kB)
Collecting py7zr
  Downloading py7zr-0.22.0-py3-none-any.whl.metadata (16 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess<0.70.17 (from datasets)
  Downloading multiprocess-0.70.16-py310-none-any.whl.metadata (7.2 kB)
Collecting fsspec<=2024.9.0,>=2023.1.0 (from fsspec[http]<=2024.9.0,>=2023.1.0->datasets)
  Downloading fsspec-2024.9.0-py3-none-any.whl.metadata (11 kB)
Collecting texttable (from py7zr)
  Downloading texttable-1.7.0-py2.py3-none-any.whl.metadata (9.8 kB)
Collecting pycryptodomex>=3.16.0 (f

## Import Required Libraries

In [3]:
import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import prepare_model_for_kbit_training, LoraConfig, get_peft_model
from transformers import TrainingArguments, Trainer

### Loading dataset

In [4]:
from huggingface_hub import notebook_login

notebook_login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [5]:
dataset = load_dataset("manojbaniya/ecommerce_qna")

README.md:   0%|          | 0.00/180 [00:00<?, ?B/s]

question_answer.csv:   0%|          | 0.00/3.78M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/33753 [00:00<?, ? examples/s]

In [6]:
dataset["train"][:10]

{'query': ['Kasari account banaune?',
  'Account kholna ke ke lagcha?',
  'Account kholne process kasto huncha?',
  'Kati samaya lagcha account kholna?',
  'Account khole pachi k k garna sakincha?',
  'Account kholna lai kun kun info chahinchha?',
  'Account banaune sanga kun benefits huncha?',
  'Kahile pani account kholeko chaina, kasari suru garne?',
  'Account kholna lai email address chahincha?',
  'Account kholeko bhane kasari confirm garne?'],
 'response': ['Account banauna yaha click garnus.',
  'Email ra password chahincha account kholna.',
  'Form fill gare pachi, confirmation email auncha.',
  'Account kholna ko lagi matra 5 minutes lagcha.',
  'Account khole pachi, shopping garna sakincha.',
  'Account kholna, name, email ra password chahincha.',
  'Account banauda special discounts pani paincha.',
  'Yadi pehli account kholeko chaina bhane, new user bhanera sign up garnus.',
  'Yes, account kholna lai email address chahincha.',
  'Account kholeko bhaye, confirmation email 

### Loading the Gemma From Huggingface

In [7]:
model_checkpoint = "google/gemma-2-2b"

In [8]:
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

tokenizer_config.json:   0%|          | 0.00/46.4k [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/4.24M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/17.5M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/636 [00:00<?, ?B/s]

In [9]:
tokenizer.pad_token = tokenizer.eos_token

In [10]:
# Configure Quantization
bnb_config = BitsAndBytesConfig(
    load_in_4bits=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16
)

Unused kwargs: ['load_in_4bits']. These kwargs are not used in <class 'transformers.utils.quantization_config.BitsAndBytesConfig'>.


In [11]:
model = AutoModelForCausalLM.from_pretrained(
    model_checkpoint,
    quantization_config=bnb_config,
    device_map="auto",
    trust_remote_code=True,
    use_auth_token=True
)



config.json:   0%|          | 0.00/818 [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/24.2k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/3 [00:00<?, ?it/s]

model-00001-of-00003.safetensors:   0%|          | 0.00/4.99G [00:00<?, ?B/s]

model-00002-of-00003.safetensors:   0%|          | 0.00/4.98G [00:00<?, ?B/s]

model-00003-of-00003.safetensors:   0%|          | 0.00/481M [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/168 [00:00<?, ?B/s]

### Prepare model for kbit training

In [12]:
peft_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM"
)

In [13]:
peft_model = get_peft_model(model, peft_config)

In [14]:
peft_model

PeftModelForCausalLM(
  (base_model): LoraModel(
    (model): Gemma2ForCausalLM(
      (model): Gemma2Model(
        (embed_tokens): Embedding(256000, 2304, padding_idx=0)
        (layers): ModuleList(
          (0-25): 26 x Gemma2DecoderLayer(
            (self_attn): Gemma2Attention(
              (q_proj): lora.Linear4bit(
                (base_layer): Linear4bit(in_features=2304, out_features=2048, bias=False)
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.05, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=2304, out_features=16, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=16, out_features=2048, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
                (lora_magnitude_vector): ModuleDict()
              )
        

In [15]:
dataset["train"]
# filter any None type
dataset = dataset.filter(lambda x: x["query"] is not None and x["response"] is not None)

Filter:   0%|          | 0/33753 [00:00<?, ? examples/s]

### Prepare the dataset

In [16]:
def prepare_dataset(examples):
  """Prepare the dataset for the training with prompt"""
  prompt = "Answer the following user query of Ecommerce customer support in Roman Nepali Language:\n\n"
  query = [prompt + query for query in examples["query"]]
  response = [answer for answer  in examples["response"]]

  model_inputs = tokenizer(query, max_length=200, padding="max_length", truncation=True, return_tensors="pt")
  labels = tokenizer(response, max_length=200, padding="max_length", truncation=True, return_tensors="pt")

  # replace padding token with -100 to ignore them
  labels["input_ids"] = labels["input_ids"].masked_fill(labels["input_ids"] == tokenizer.pad_token_id, -100)

  model_inputs["labels"] = labels["input_ids"]
  return model_inputs

In [32]:
dataset_train = dataset["train"].map(prepare_dataset, batched=True)

Map:   0%|          | 0/33715 [00:00<?, ? examples/s]

In [33]:
dataset_train

Dataset({
    features: ['query', 'response', 'input_ids', 'attention_mask', 'labels'],
    num_rows: 33715
})

In [34]:
len(dataset_train[10]["input_ids"])

200

In [35]:
len(dataset_train[10]["labels"])

200

In [36]:
# only get 10000 of data
dataset_train = dataset_train.select(range(10000))

In [37]:
# split the train and eval dataset
dataset_train = dataset_train.train_test_split(test_size=0.2, shuffle=True, seed=42)


In [38]:
dataset_train

DatasetDict({
    train: Dataset({
        features: ['query', 'response', 'input_ids', 'attention_mask', 'labels'],
        num_rows: 8000
    })
    test: Dataset({
        features: ['query', 'response', 'input_ids', 'attention_mask', 'labels'],
        num_rows: 2000
    })
})

### Training Arguments

In [43]:
training_args = TrainingArguments(
    output_dir="./results",
    num_train_epochs=1,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    warmup_steps=500,
    weight_decay=0.01,
    logging_dir="./logs",
    logging_steps=10,
    evaluation_strategy="steps",
    eval_steps=500,
    save_strategy="steps",
    save_steps=500,
    gradient_accumulation_steps=8
)



In [44]:
trainer = Trainer(
    model=peft_model,
    args=training_args,
    train_dataset=dataset_train["train"],
    eval_dataset=dataset_train["test"],
)

In [45]:
trainer.train()

Step,Training Loss,Validation Loss


TrainOutput(global_step=250, training_loss=8.876617324829102, metrics={'train_runtime': 6632.8913, 'train_samples_per_second': 1.206, 'train_steps_per_second': 0.038, 'total_flos': 1.94660425728e+16, 'train_loss': 8.876617324829102, 'epoch': 1.0})

### Evaluate

In [46]:
evaluation = trainer.evaluate()

print(evaluation)

The 'batch_size' argument of HybridCache is deprecated and will be removed in v4.49. Use the more precisely named 'max_batch_size' argument instead.
The 'batch_size' attribute of HybridCache is deprecated and will be removed in v4.49. Use the more precisely named 'self.max_batch_size' attribute instead.


{'eval_loss': 6.052026271820068, 'eval_runtime': 796.0309, 'eval_samples_per_second': 2.512, 'eval_steps_per_second': 0.628, 'epoch': 1.0}


### Test the model

In [57]:
test_input = "Answer the question of Ecommerce Customer Query in Roman Nepali.\n\n kasari account banaune"

In [58]:
inputs = tokenizer(test_input, return_tensors="pt").to("cuda")
outputs = model.generate(**inputs, max_new_tokens=100)

tokenizer.decode(outputs[0], skip_special_tokens=True)

'Answer the question of Ecommerce Customer Query in Roman Nepali.\n\n kasari account banaune, account banna, account banna, account banna, account banna, account banna, account banna, account banna, account banna, account banna, account banna, account banna, account banna, account banna, account banna, account banna, account banna, account banna, account banna, account banna, account banna, account banna, account banna, account banna, account banna, account banna'

### Save the model checkpoint to huggingface

In [52]:
# save the model to huggingface hub
# peft_model.push_to_hub("manojbaniya/gemma-2-2b-ecommerce-qna", use_auth_token=True)