<a href="https://colab.research.google.com/github/kjacone/swahili_cake_boss/blob/main/Gemma/spoken-language-tasks/k-gemma-it/spoken_language_tasks_with_gemma.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# How to Fine-tune Gemma for Spoken Language Tasks

This notebook demonstrates how to fine-tune Gemma for a specific task: replying to email requests that a Korean bakery business might receive.

<table align="left">
  <td>
    <a target="_blank" href="https://colab.research.google.com/github/kjacone/swahili_cake_boss/blob/main/Gemma/spoken-language-tasks/k-gemma-it/spoken_language_tasks_with_gemma.ipynb#scrollTo=YNDq8NbCY7oh"><img src="https://www.tensorflow.org/images/colab_logo_32px.png" />Run in Google Colab</a>
  </td>
</table>

## Setup

### Select the Colab runtime
To complete this tutorial, you'll need to have a Colab runtime with sufficient resources to run the Gemma model:

1. In the upper-right of the Colab window, select **▾ (Additional connection options)**.
2. Select **Change runtime type**.
3. Under **Hardware accelerator**, select **L4** or **A100 GPU**.


### Gemma setup on Kaggle
To complete this tutorial, you'll first need to complete the setup instructions at [Gemma setup](https://ai.google.dev/gemma/docs/setup). The Gemma setup instructions show you how to do the following:

* Get access to Gemma on kaggle.com.
* Select a Colab runtime with sufficient resources to run the Gemma 2B model.
* Generate and configure a Kaggle username and API key.

After you've completed the Gemma setup, move on to the next section, where you'll set environment variables for your Colab environment.

### Set environment variables

Set environment variables for ```KAGGLE_USERNAME``` and ```KAGGLE_KEY```.

In [7]:
import os
from google.colab import userdata, drive

# Note: `userdata.get` is a Colab API. If you're not using Colab, set the env
# vars as appropriate for your system.
os.environ["KAGGLE_USERNAME"] = userdata.get("KAGGLE_USERNAME")
os.environ["KAGGLE_KEY"] = userdata.get("KAGGLE_KEY")

# Mount Google Drive to store artifacts (the LoRA weights saved during training).
drive.mount("/content/drive")

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
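
Outside Colab, `userdata.get` is not available. One possible alternative, sketched below, reads the credentials from a JSON token file. It assumes the file holds `username` and `key` fields, as the API token downloaded from Kaggle does. The helper name `load_kaggle_credentials` and its default path are illustrative and not part of the original notebook.

```python
import json
import os
from pathlib import Path


def load_kaggle_credentials(path="~/.kaggle/kaggle.json", environ=None):
    """Copy Kaggle credentials from a JSON token file into environment variables.

    Hypothetical helper: assumes the file holds {"username": ..., "key": ...}.
    """
    environ = os.environ if environ is None else environ
    token = json.loads(Path(path).expanduser().read_text(encoding="utf-8"))
    environ["KAGGLE_USERNAME"] = token["username"]
    environ["KAGGLE_KEY"] = token["key"]
    return environ
```

Calling `load_kaggle_credentials()` with no arguments writes into `os.environ`; passing a plain dict as `environ` keeps the process environment untouched, which is convenient for testing.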


### Install dependencies

Install Keras, KerasNLP, and the Hugging Face `datasets` library.

In [None]:
!pip install -q -U keras-nlp datasets
!pip install -q -U keras

# Set the backend before importing Keras
os.environ["KERAS_BACKEND"] = "jax"
# Avoid memory fragmentation on JAX backend.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "1.00"

import keras_nlp
import keras

# Uncomment the next line to run at half precision.
# keras.config.set_floatx("bfloat16")

# Training Configurations
token_limit = 512
num_data_limit = 100
lora_name = "bankingassistant"
lora_rank = 4
lr_value = 1e-4
train_epoch = 20
model_id = "gemma2_instruct_2b_en"

## Load and Translate Data

In [None]:
!pip install transformers datasets


In [None]:
!pip install huggingface-hub

In [None]:
!huggingface-cli login

In [6]:
from datasets import load_dataset, Dataset, DatasetDict
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch
import json

# Load the original Korean Cake Boss dataset
ds = load_dataset("bebechien/korean_cake_boss")

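# Load the English-to-Swahili translation model and tokenizer.
# Note: the source dataset is Korean while this model expects English input,
# so review the translations before training on them.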
model_name = "Helsinki-NLP/opus-mt-en-sw"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

# Translate function for individual text fields
def translate_text(text, tokenizer, model):
    inputs = tokenizer(text, return_tensors="pt", max_length=512, truncation=True)
    with torch.no_grad():
        outputs = model.generate(**inputs)
    translated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return translated_text

# Translate and save function for datasets with "input" and "output" fields
def translate_and_save(dataset_split, filename, tokenizer, model):
    with open(filename, "w", encoding="utf-8") as f:
        for item in dataset_split:
            translated_input = translate_text(item["input"], tokenizer, model)
            translated_output = translate_text(item["output"], tokenizer, model)
            # Save each translated entry as a dictionary in JSON format
            json.dump({"input": translated_input, "output": translated_output}, f, ensure_ascii=False)
            f.write("\n")  # Newline to separate JSON objects

# Load a translated JSON Lines file and create a Dataset object
def load_translated_dataset(filename):
    with open(filename, "r", encoding="utf-8") as f:
        data = [json.loads(line) for line in f]
    return Dataset.from_dict({"input": [entry["input"] for entry in data], "output": [entry["output"] for entry in data]})

# Translate every split that actually exists. The source dataset ships only a
# "train" split, so indexing ds["test"] directly raises KeyError: 'test'.
translated_splits = {}
for split_name in ds.keys():
    filename = f"swahili_cake_boss_{split_name}_sw.json"
    translate_and_save(ds[split_name], filename, tokenizer, model)
    print(f"Translated '{split_name}' split saved to {filename}.")
    translated_splits[split_name] = load_translated_dataset(filename)

# Create DatasetDict from the translated splits
translated_dataset = DatasetDict(translated_splits)

# Push the translated dataset to Hugging Face Hub
translated_dataset.push_to_hub("jacone/swahili_cake_boss", private=False)

print("Swahili-translated dataset successfully pushed to Hugging Face Hub.")
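
The translation cell stores one JSON object per line (the JSON Lines layout) and later parses the file line by line. A minimal sketch of that round trip follows. It uses an in-memory buffer and a stand-in `fake_translate` function (an assumption for illustration) so that it runs without downloading a translation model.

```python
import io
import json


def write_jsonl(records, fp, translate):
    """Write one translated {"input", "output"} object per line."""
    for item in records:
        row = {"input": translate(item["input"]), "output": translate(item["output"])}
        fp.write(json.dumps(row, ensure_ascii=False) + "\n")


def read_jsonl(fp):
    """Parse the file back into a list of dicts, skipping blank lines."""
    return [json.loads(line) for line in fp if line.strip()]


# Stand-in for the real translation model (assumption for illustration only).
def fake_translate(text):
    return text.upper()


buffer = io.StringIO()
write_jsonl([{"input": "hello", "output": "hi"}], buffer, fake_translate)
buffer.seek(0)
rows = read_jsonl(buffer)
print(rows)
```

Passing the translator in as an argument makes it easy to swap in a different model or to test the file handling on its own.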

## Load Model

In [None]:
import keras
import keras_nlp

import time

gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset(model_id)
gemma_lm.summary()

tick_start = 0

def tick():
    global tick_start
    tick_start = time.time()

def tock():
    print(f"TOTAL TIME ELAPSED: {time.time() - tick_start:.2f}s")

def text_gen(prompt):
    tick()
    # Wrap the prompt in Gemma's chat turn markers.
    full_prompt = f"<start_of_turn>user\n{prompt}<end_of_turn>\n<start_of_turn>model\n"
    output = gemma_lm.generate(full_prompt, max_length=token_limit)
    print("\nGemma output:")
    print(output)
    tock()

# Inference before fine-tuning.
# English: "Please write an email reply to: 'Hello, I would like to order one
# number 3 cake for our wedding anniversary. Is that possible?'"
text_gen("Tafadhali andika jibu la barua pepe kwa:\n\"Hujambo, ningependa kuagiza keki moja nambari 3 kwa ajili ya maadhimisho ya harusi yetu. Je, hilo linawezekana?\"")
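
The turn markers used in `text_gen` can be factored into small helpers. The sketch below is one way to do it. `format_prompt` and `extract_reply` are hypothetical names, and `extract_reply` assumes the generated string starts with the prompt and may end with an `<end_of_turn>` marker.

```python
def format_prompt(user_text, model_text=None):
    """Wrap text in the turn markers used by `text_gen` in this notebook."""
    prompt = f"<start_of_turn>user\n{user_text}<end_of_turn>\n<start_of_turn>model\n"
    if model_text is not None:
        # Training examples also carry the expected answer and a closing marker.
        prompt += f"{model_text}<end_of_turn>"
    return prompt


def extract_reply(generated, prompt):
    """Return only the model's part of a generation that echoes the prompt."""
    reply = generated[len(prompt):] if generated.startswith(prompt) else generated
    return reply.split("<end_of_turn>")[0].strip()
```

With these helpers, the same function builds both inference prompts (`model_text=None`) and training examples (with the expected reply appended).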

## Load Dataset

In [None]:
import keras
import keras_nlp
import datasets

tokenizer = keras_nlp.models.GemmaTokenizer.from_preset(model_id)

# Prompt structure
# (The Korean instruction means "Write an email reply to the following.")
# <start_of_turn>user
# 다음에 대한 이메일 답장을 작성해줘.
# "{EMAIL CONTENT FROM THE CUSTOMER}"
# <end_of_turn>
# <start_of_turn>model
# {MODEL ANSWER}<end_of_turn>

# Each row has two columns: "input" (the customer email) and "output" (the reply).
from datasets import load_dataset
ds = load_dataset(
    "bebechien/korean_cake_boss",
    split="train",
)
print(ds)
data = ds.with_format("np", columns=["input", "output"], output_all_columns=False)
train = []

for x in data:
  item = f"<start_of_turn>user\n다음에 대한 이메일 답장을 작성해줘.\n\"{x['input']}\"<end_of_turn>\n<start_of_turn>model\n{x['output']}<end_of_turn>"
  length = len(tokenizer(item))
  # Skip examples whose token length reaches or exceeds the limit.
  if length < token_limit:
    train.append(item)
    if len(train) >= num_data_limit:
      break

print(len(train))
print(train[0])
print(train[1])
print(train[2])
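
The filtering loop above can be expressed as a reusable function. The following is a minimal sketch under stated assumptions: `build_training_examples` and `count_words` are hypothetical names, a whitespace word count stands in for the Gemma tokenizer so the example runs without downloading a model, and the instruction line is omitted from the prompt for brevity.

```python
def build_training_examples(rows, count_tokens, token_limit, max_examples):
    """Format rows as chat turns and keep those shorter than `token_limit`."""
    examples = []
    for row in rows:
        item = (
            f"<start_of_turn>user\n{row['input']}<end_of_turn>\n"
            f"<start_of_turn>model\n{row['output']}<end_of_turn>"
        )
        if count_tokens(item) < token_limit:
            examples.append(item)
            if len(examples) >= max_examples:
                break
    return examples


# Whitespace splitting stands in for the Gemma tokenizer (assumption for the demo).
def count_words(text):
    return len(text.split())


rows = [
    {"input": "short question", "output": "short answer"},
    {"input": "long " * 50, "output": "reply"},
    {"input": "another question", "output": "another answer"},
]
kept = build_training_examples(rows, count_words, token_limit=10, max_examples=2)
print(len(kept))
```

In the notebook, `count_tokens` would be something like `lambda text: len(tokenizer(text))`, matching the length check in the cell above.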

## LoRA Fine-tuning

In [None]:
# Enable LoRA for the model and set the LoRA rank to 4.
gemma_lm.backbone.enable_lora(rank=lora_rank)
gemma_lm.summary()

# Limit the input sequence length (to control memory usage).
gemma_lm.preprocessor.sequence_length = token_limit
# Use AdamW (a common optimizer for transformer models).
optimizer = keras.optimizers.AdamW(
    learning_rate=lr_value,
    weight_decay=0.01,
)
# Exclude layernorm and bias terms from decay.
optimizer.exclude_from_weight_decay(var_names=["bias", "scale"])

gemma_lm.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=optimizer,
    weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)


Note that enabling LoRA significantly reduces the number of trainable parameters. Compare the two `summary()` outputs, before and after `enable_lora`, to see the difference.
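
To see why the count drops, consider a single dense weight matrix of shape `d_in x d_out`. LoRA freezes it and trains two small factors of shapes `d_in x rank` and `rank x d_out` instead. The sketch below does this arithmetic for one matrix. The layer size is hypothetical and chosen only to make the numbers concrete; it is not taken from Gemma's actual configuration.

```python
def lora_param_counts(d_in, d_out, rank):
    """Return (full, lora) trainable parameter counts for one weight matrix."""
    full = d_in * d_out              # training the dense matrix directly
    lora = rank * (d_in + d_out)     # A is d_in x rank, B is rank x d_out
    return full, lora


# Hypothetical layer size, chosen only to make the arithmetic concrete.
full, lora = lora_param_counts(d_in=2048, d_out=2048, rank=4)
print(full, lora, f"{lora / full:.2%}")
```

Doubling `rank` doubles the number of LoRA parameters, so the rank trades adapter capacity against memory and training time.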

In [None]:
class CustomCallback(keras.callbacks.Callback):
  def on_epoch_end(self, epoch, logs=None):
    model_name = f"/content/drive/MyDrive/{lora_name}_{lora_rank}_epoch{epoch+1}.lora.h5"
    gemma_lm.backbone.save_lora_weights(model_name)

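    # Evaluate on a fixed prompt after every epoch to track progress.
    # English: "Write an email reply to the following. 'Hello, I would like to
    # order one size 3 cake for our wedding anniversary. Is that possible?'"
    text_gen("다음에 대한 이메일 답장을 작성해줘.\n\"안녕하세요, 결혼기념일을 위해 3호 케이크 1개를 주문하고 싶은데 가능할까요?\"")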

history = gemma_lm.fit(train, epochs=train_epoch, batch_size=2, callbacks=[CustomCallback()])

import matplotlib.pyplot as plt
plt.plot(history.history['loss'])
plt.show()

In [None]:
# Example code for loading saved LoRA weights
'''
train_epoch=17
gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset(model_id)
# Use the same LoRA rank that you trained with
gemma_lm.backbone.enable_lora(rank=lora_rank)

# Load the LoRA weights saved by the callback at the chosen epoch
gemma_lm.backbone.load_lora_weights(f"/content/drive/MyDrive/{lora_name}_{lora_rank}_epoch{train_epoch}.lora.h5")
'''
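
One simple way to choose which epoch's weights to load is to take the epoch with the lowest recorded training loss. The sketch below does that and rebuilds the file name used by the callback. `best_checkpoint` is a hypothetical helper and the loss values are made up. Lowest training loss does not guarantee the best replies, so treat the result as a starting point and compare the generations printed after each epoch.

```python
def best_checkpoint(losses, lora_name, lora_rank, directory="/content/drive/MyDrive"):
    """Return (epoch, path) for the epoch with the lowest recorded loss."""
    best_index = min(range(len(losses)), key=losses.__getitem__)
    epoch = best_index + 1  # the callback names files with 1-based epochs
    return epoch, f"{directory}/{lora_name}_{lora_rank}_epoch{epoch}.lora.h5"


# Made-up loss values for illustration; in the notebook use history.history["loss"].
epoch, path = best_checkpoint([1.9, 1.2, 0.8, 0.9], "bankingassistant", 4)
print(epoch, path)
```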

## Try a different sampler

The top-K sampler randomly picks the next token from the K tokens with the highest probability.

In [None]:
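gemma_lm.compile(sampler="top_k")
# Run the same prompt three times: with a random sampler, each reply can differ.
# English: "Write an email reply to the following. 'Hello, I would like to
# order one size 3 cake for our wedding anniversary. Is that possible?'"
text_gen("다음에 대한 이메일 답장을 작성해줘.\n\"안녕하세요, 결혼기념일을 위해 3호 케이크 1개를 주문하고 싶은데 가능할까요?\"")
text_gen("다음에 대한 이메일 답장을 작성해줘.\n\"안녕하세요, 결혼기념일을 위해 3호 케이크 1개를 주문하고 싶은데 가능할까요?\"")
text_gen("다음에 대한 이메일 답장을 작성해줘.\n\"안녕하세요, 결혼기념일을 위해 3호 케이크 1개를 주문하고 싶은데 가능할까요?\"")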
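
For intuition, top-K sampling over a single step can be sketched with the standard library alone. This is an illustrative version and not KerasNLP's implementation: it keeps the `k` largest logits, applies a softmax over just those, and draws one index. The logits are made-up values.

```python
import math
import random


def top_k_sample(logits, k, rng=random):
    """Sample a token index from the k highest-scoring logits."""
    # Keep the indices of the k largest logits.
    top = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:k]
    # Softmax over the survivors only (subtract the max for numerical stability).
    peak = max(logits[i] for i in top)
    weights = [math.exp(logits[i] - peak) for i in top]
    return rng.choices(top, weights=weights, k=1)[0]


rng = random.Random(0)
logits = [2.0, 0.5, 1.5, -1.0, 0.1]
samples = {top_k_sample(logits, k=2, rng=rng) for _ in range(200)}
print(samples)
```

With `k=2`, only indices 0 and 2 can ever be drawn because they hold the two largest logits. With `k=1`, the sampler reduces to greedy decoding.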

Try slightly different prompts.

In [None]:
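# The customer email stays the same; only the instruction wording changes.
# English: "Write a reply to the following."
text_gen("다음에 대한 답장을 작성해줘.\n\"안녕하세요, 결혼기념일을 위해 3호 케이크 1개를 주문하고 싶은데 가능할까요?\"")
# English: "Write an appropriate reply to the message below."
text_gen("아래에 적절한 답장을 써줘.\n\"안녕하세요, 결혼기념일을 위해 3호 케이크 1개를 주문하고 싶은데 가능할까요?\"")
# English: "Please write a reply regarding the following."
text_gen("다음에 관한 답장을 써주세요.\n\"안녕하세요, 결혼기념일을 위해 3호 케이크 1개를 주문하고 싶은데 가능할까요?\"")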

Try different email inputs.

In [None]:
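# English summary of the email below: a customer wants to bulk-order 50 cookie
# and muffin sets (2 cookies + 1 muffin each, individually wrapped) as return
# gifts for an event on June 15. They ask for a simple, elegant design with
# ribbon wrapping and a "Thank you" sticker, delivery on June 14, and want to
# know about bulk discounts, a quote, and whether delivery is possible.
# It is signed by Park Cheol-su.
text_gen("""다음에 대한 이메일 답장을 작성해줘.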
"안녕하세요,

6월 15일에 있을 행사 답례품으로 쿠키 & 머핀 세트를 대량 주문하고 싶습니다.

수량: 50세트
구성: 쿠키 2개 + 머핀 1개 (개별 포장)
디자인: 심플하고 고급스러운 디자인 (리본 포장 등)
문구: "감사합니다" 스티커 부착
배송 날짜: 6월 14일
대량 주문 할인 혜택이 있는지, 있다면 견적과 함께 배송 가능 여부를 알려주시면 감사하겠습니다.

감사합니다.

박철수 드림" """)
