## LLM Prompt Recovery

##### Copyright 2024 Google LLC.

In [1]:
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Fine-tune Gemma models in Keras using LoRA

## Setup

### Install dependencies

Install Keras, KerasNLP, and other dependencies.

In [1]:
import tensorflow
from tensorflow.python.client import device_lib
print(tensorflow.__version__)
# print(device_lib.list_local_devices())
print("Num GPUs Available: ", len(tensorflow.config.list_physical_devices('GPU')))

2024-04-02 02:37:39.406428: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-04-02 02:37:39.406452: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-04-02 02:37:39.407614: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-04-02 02:37:39.413823: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


2.15.0
Num GPUs Available:  2


2024-04-02 02:37:40.907592: I external/local_xla/xla/stream_executor/cuda/cuda_executor.cc:901] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
2024-04-02 02:37:40.907959: I external/local_xla/xla/stream_executor/cuda/cuda_executor.cc:901] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
2024-04-02 02:37:40.948266: I external/local_xla/xla/stream_executor/cuda/cuda_executor.cc:901] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-

In [2]:
# Install Keras 3 last. See https://keras.io/getting_started/ for more details.
%pip install -q -U keras-nlp
%pip install -q -U keras>=3

[0mNote: you may need to restart the kernel to use updated packages.
[0mNote: you may need to restart the kernel to use updated packages.


In [None]:
!pip install jax
!pip install -U jax jaxliab
!pip install -U "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.htmls

### Select a backend

Keras is a high-level, multi-framework deep learning API designed for simplicity and ease of use. Using Keras 3, you can run workflows on one of three backends: TensorFlow, JAX, or PyTorch.

For this tutorial, configure the backend for JAX.

In [2]:
import os

os.environ["KERAS_BACKEND"] = "torch"  # Or "jax" or "torch" or "tensorflow".
# Avoid memory fragmentation on JAX backend.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"]="1.00"

### Import packages

Import Keras and KerasNLP.

In [3]:
import keras
import keras_nlp
import glob
import pandas as pd

## Load Dataset

In [None]:
total_df = pd.DataFrame({'original_text':['text'], 'rewrite_prompt':['text'], 'rewritten_text':['text']})
file_types = ['gemma-rewrite-nbroad', 'wikipedia-first-paragraph', 'gemma1000_w7b']
df_paths = glob.glob("workspace/llm-prompt-recovery-data/"+ "/**/*.csv")

for df_path in df_paths:
    if any(file_type in df_path for file_type in file_types): 
        data_df = pd.read_csv(df_path)

        try:
            data_df = data_df[['original_text', 'rewrite_prompt', 'rewritten_text']]
            total_df = pd.concat([total_df, data_df], axis=0)
        except:
            pass

total_df = total_df.iloc[1:, :].reset_index(drop=True)
total_df["id"] = total_df.index
total_df = total_df[['id', 'original_text', 'rewrite_prompt', 'rewritten_text']]
total_df.to_csv("workspace/llm-prompt-recovery-data/finetuning.csv", index=False)

In [None]:
total_df

Unnamed: 0,id,original_text,rewrite_prompt,rewritten_text
0,0,"`` Well, there are healthier ways to tell me y...",Rewrite the story where the writer asks the re...,"Well, there are healthier ways to tell me you ..."
1,1,Rory ran his shaky fingers through his wife's ...,Rewrite the essay as a dramatic play,## The Final Curtain\n\n[FADE IN]\n\n**Setting...
2,2,As I made my way on foot across town to the Po...,Rewrite the story with all the themes and sett...,As I made my way through the Tatooine desert o...
3,3,`` Hello. We come in peace.'' \n \n The first ...,Rewrite the essay if the advanced aliens didn'...,`` Hello. We come in peace.''\n\nThe first enc...
4,4,"`` Karen, what the helllllll izzz...'' says my...",Rewrite the story as a court room drama starri...,The courtroom erupted in an uproar as District...
...,...,...,...,...
7561,7561,Výškovce is a village and municipality in Stro...,Write it as the last chapter of a book that ch...,## The Whispering Walls of Výškovce\n\nThe win...
7562,7562,The World I Want to Leave Behind is the fourth...,Adapt it as a solemn monastic chant.,"Sure, here's the adapted chant:\n\n""The world ..."
7563,7563,The Akademie Rudolph-Antoniana was an early mo...,Convert it into a narrative of the first sunri...,A pall of darkness cast the Akademie Rudolph-A...
7564,7564,"A prisoner transport vehicle, informally known...",Style it as a proclamation by a newly crowned ...,"**Hear ye, hear ye, gathered mortals,**\n\nI, ..."


In [4]:
import pandas as pd

train_df = pd.read_csv('workspace/kaggle/input/llm-prompt-recovery/train.csv')
prompt_for_llm = (
    "<start_of_turn>user\nGenerate a rewrite_prompt that effectively transforms the given original_text into the provided rewritten_text."
    "Capture the essence and context of the content while improving the language, coherence, and expressiveness."
    "Pay attention to detail, clarity, and overall quality in your generated rewrite_prompt."
    "Here is an example sample: original text-" + train_df.loc[0, 'original_text'] +
    "rewritten_text-" + train_df.loc[0, 'rewritten_text'] +
    "and this is the right rewrite_prompt-" + train_df.loc[0, 'rewrite_prompt'] +
    "Now, You will output in text the most suitable rewrite_prompt. For the given original_text- {ot}" +
    "and rewritten_text- {rt}" + 
    "<end_of_turn>\n<start_of_turn>model\n"
)

In [5]:
data = []
total_df = pd.read_csv("workspace/llm-prompt-recovery-data/finetuning.csv")
rewritten_texts = total_df['rewritten_text']
original_texts = total_df['original_text']
for idx, (rewritten_text, original_text) in enumerate(zip(rewritten_texts, original_texts)):
    template = prompt_for_llm.format(ot=original_text, rt=rewritten_text)
    data.append(template)

## Load Model

In [7]:
gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_2b_en")
gemma_lm.summary()

2024-04-02 02:37:53.791221: I external/local_xla/xla/stream_executor/cuda/cuda_executor.cc:901] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
2024-04-02 02:37:53.791475: I external/local_xla/xla/stream_executor/cuda/cuda_executor.cc:901] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
2024-04-02 02:37:53.791703: I external/local_xla/xla/stream_executor/cuda/cuda_executor.cc:901] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-

## LoRA Fine-tuning

In [8]:
# Enable LoRA for the model and set the LoRA rank to 4.
gemma_lm.backbone.enable_lora(rank=4)
gemma_lm.summary()

In [9]:
# Limit the input sequence length to 512 (to control memory usage).
gemma_lm.preprocessor.sequence_length = 512
# Use AdamW (a common optimizer for transformer models).
optimizer = keras.optimizers.AdamW(
    learning_rate=5e-5,
    weight_decay=0.01,
)
# Exclude layernorm and bias terms from decay.
optimizer.exclude_from_weight_decay(var_names=["bias", "scale"])

gemma_lm.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=optimizer,
    weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
gemma_lm.fit(data, epochs=3, batch_size=1)

Epoch 1/3


2024-04-02 02:38:18.036623: I external/local_xla/xla/service/service.cc:168] XLA service 0x4473b2f0 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
2024-04-02 02:38:18.036658: I external/local_xla/xla/service/service.cc:176]   StreamExecutor device (0): NVIDIA GeForce RTX 4090, Compute Capability 8.9
2024-04-02 02:38:18.036666: I external/local_xla/xla/service/service.cc:176]   StreamExecutor device (1): NVIDIA GeForce RTX 4090, Compute Capability 8.9
2024-04-02 02:38:18.667221: I tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc:269] disabling MLIR crash reproducer, set env var `MLIR_CRASH_REPRODUCER_DIRECTORY` to enable.
2024-04-02 02:38:19.894224: W tensorflow/compiler/tf2xla/kernels/assert_op.cc:38] Ignoring Assert operator compile_loss/sparse_categorical_crossentropy/SparseSoftmaxCrossEntropyWithLogits/assert_equal_1/Assert/Assert
2024-04-02 02:38:23.173421: I external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:454] Loaded cuDN

[1m   1/7566[0m [37m━━━━━━━━━━━━━━━━━━━━[0m [1m46:55:13[0m 22s/step - loss: 2.6712 - sparse_categorical_accuracy: 0.5078

I0000 00:00:1712025509.288688  956397 device_compiler.h:186] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.
W0000 00:00:1712025509.333967  956397 graph_launch.cc:671] Fallback to op-by-op mode because memset node breaks graph update


[1m7566/7566[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m663s[0m 85ms/step - loss: 0.9074 - sparse_categorical_accuracy: 0.8296
Epoch 2/3
[1m7566/7566[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m641s[0m 85ms/step - loss: 0.6578 - sparse_categorical_accuracy: 0.8743
Epoch 3/3
[1m7566/7566[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m641s[0m 85ms/step - loss: 0.6480 - sparse_categorical_accuracy: 0.8750


<keras.src.callbacks.history.History at 0x79978441e190>

each epoch running Time = 11 minute
<br>batch size = 1
<br>lora_rank = 4
<br>model = gemma-2b-en
<br>the number of data = 7566

## Inference after fine-tuning

In [15]:
test = pd.read_csv('workspace/kaggle/input/llm-prompt-recovery/test.csv')
sample = pd.read_csv('workspace/kaggle/input/llm-prompt-recovery/sample_submission.csv')

In [17]:
test

Unnamed: 0,id,original_text,rewritten_text
0,-1,The competition dataset comprises text passage...,Here is your shanty: (Verse 1) The text is rew...


In [16]:
predictions = []

for i in range(len(test)):
    original_text = test.loc[i, 'original_text']
    rewritten_text = test.loc[i, 'rewritten_text']
    
    rewrite_prompt = gemma_lm.generate(prompt_for_llm.format(ot=original_text, rt=rewritten_text),
                      max_length=256)

    predictions.append(rewrite_prompt)

sample['rewrite_prompt'] = predictions   
sample.to_csv('workspace/submission.csv',index=False)

### Reference

Google
<br>[Original Code](https://ai.google.dev/gemma/docs/lora_tuning)
<br>[Distributed tuning on a Gemma 7B model](https://ai.google.dev/gemma/docs/distributed_tuning) 
<br>[Generate text with a Gemma model](https://ai.google.dev/gemma/docs/get_started).
<br>[Use Gemma open models with Vertex AI](https://cloud.google.com/vertex-ai/docs/generative-ai/open-models/use-gemma).
<br>[Fine-tune Gemma using KerasNLP and deploy to Vertex AI](https://github.com/GoogleCloudPlatform/vertex-ai-samples/blob/main/notebooks/community/model_garden/model_garden_gemma_kerasnlp_to_vertexai.ipynb)
<br><br>Keras
<br>[Model Architectures](https://keras.io/api/keras_nlp/models/)
<br><br>Paper
<br>[Low Rank Adaptation (LoRA)](https://arxiv.org/abs/2106.09685)
<br><br>Data
<br>[LLM Prompt Recovery - Synthetic Datastore](https://www.kaggle.com/datasets/dschettler8845/llm-prompt-recovery-synthetic-datastore)
<br>[3000 Rewritten texts - Prompt recovery Challenge](https://www.kaggle.com/datasets/dipamc77/3000-rewritten-texts-prompt-recovery-challenge)
<br>[gemma-rewrite-nbroad](https://www.kaggle.com/datasets/nbroad/gemma-rewrite-nbroad)
