<a href="https://www.kaggle.com/code/dbtmddn41/prompt-recovery-keras-tpu-train?scriptVersionId=170113483" target="_blank"><img align="left" alt="Kaggle" title="Open in Kaggle" src="https://kaggle.com/static/images/open-in-kaggle.svg"></a>

copied from https://www.kaggle.com/code/nilaychauhan/keras-gemma-distributed-finetuning-and-inference

# Config

In [None]:
import datetime

# Wall-clock start of the run; later compared against a time budget before
# each checkpoint evaluation.
start_time = datetime.datetime.now()

In [None]:
class CFG:
    """Notebook-wide configuration constants."""

    # Reproducibility and data location
    seed = 42
    dataset_path = "/kaggle/input/llm-prompt-recovery"

    # Name of the pretrained Gemma preset
    preset = "gemma_instruct_2b_en"
    # Max size of the input sequence for training
    sequence_length = 1024

    # Batch sizes for training and for validation generation
    train_batch = 4
    validation_batch = 8

    # Number of epochs to train
    epochs = 4
    # Fraction passed to train_test_split; 0.0 disables that split
    test_size = 0.0

    # dataset_id values selecting the training / validation rows
    train_datas = [
        'kishanvavdara',
        'newtonbaba12345_3',
        'newtonbaba12345_1',
        'aatiffraz',
        'newtonbaba12345_2',
        'host',
    ]
    validation_datas = ["winddude"]

    # Rank passed to backbone.enable_lora
    lora_rank = 128
    # Step interval intended for batch-level checkpointing
    save_freq_steps = 1379

# Setup

In [None]:
# Install Keras 3 last. See https://keras.io/getting_started/ for more details.
!pip install -q tensorflow-cpu
!pip install -q -U keras-nlp tensorflow-hub
!pip install -q -U keras>=3
!pip install -U tensorflow-text
!pip install parmap
# !pip install -U jax jaxlib
# !pip install -U sentence-transformers

In [None]:
import jax

# Show the devices JAX can see (cell output only; the result is not stored).
jax.devices()

In [None]:
import os

# The Keras 3 distribution API is only implemented for the JAX backend for now
os.environ["KERAS_BACKEND"] = "jax"
# Pre-allocate 100% of TPU memory (fraction "1.0") to minimize memory
# fragmentation and allocation overhead
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "1.0"

In [None]:
import keras
import keras_nlp
import pandas as pd

In [None]:
# Create a device mesh with (1, 8) shape so that the weights are sharded across
# all 8 TPUs. The first axis is named "batch" (size 1) and the second "model"
# (size 8).
# NOTE(review): the shape is hard-coded, so this assumes list_devices() returns
# exactly 8 devices - confirm when running on other hardware.
device_mesh = keras.distribution.DeviceMesh(
    (1, 8),
    ["batch", "model"],
    devices=keras.distribution.list_devices())

In [None]:
# Name of the mesh axis along which weights are sharded.
model_dim = "model"

# Maps weight-path regexes to per-axis layouts: each tuple has one entry per
# weight axis, `model_dim` where that axis is split over the "model" mesh axis
# and None where it is not split.
layout_map = keras.distribution.LayoutMap(device_mesh)

# Weights that match 'token_embedding/embeddings' will be sharded on 8 TPUs
layout_map["token_embedding/embeddings"] = (None, model_dim)
# Regex to match against the query, key and value matrices in the decoder
# attention layers
layout_map["decoder_block.*attention.*(query|key|value).*kernel"] = (
    None, model_dim, None)

# Attention output projection: split along its last axis.
layout_map["decoder_block.*attention_output.*kernel"] = (
    None, None, model_dim)
# Feed-forward kernels: gating is split on its first axis, linear on its last.
layout_map["decoder_block.*ffw_gating.*kernel"] = (model_dim, None)
layout_map["decoder_block.*ffw_linear.*kernel"] = (None, model_dim)

In [None]:
# Use the "mixed_bfloat16" dtype policy for every layer created from here on.
keras.mixed_precision.set_global_policy("mixed_bfloat16")

# Model

In [None]:
# Combine the device mesh and the layout map into a model-parallel
# distribution; "batch" is the mesh axis used for the data batch dimension.
model_parallel = keras.distribution.ModelParallel(
    device_mesh, layout_map, batch_dim_name="batch")

# The distribution must be set before the model is created so that the
# weights are laid out according to layout_map when they are loaded.
keras.distribution.set_distribution(model_parallel)
gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset(CFG.preset)
gemma_lm.summary()

# Datasets

In [None]:
import pandas as pd
import numpy as np
import gc

In [None]:
%%time
df = pd.read_csv("/kaggle/input/all-in-one-dataset-with-embedding/df_with_emb_20240402.csv")

# Downcast float64 columns to float32 to reduce the DataFrame's memory use.
float_cols = df.select_dtypes('float64').columns
df[float_cols] = df[float_cols].astype('float32')
# Keep only rows whose rewritten text is at least 20 characters long.
df = df[df['rewritten_text'].str.len() >= 20]
# Training rows: every dataset listed in CFG.train_datas.
train_df = df[df['dataset_id'].isin(CFG.train_datas)].copy(deep=True)
# Validation rows: a 2% sample of the datasets in CFG.validation_datas.
# The sample is seeded with CFG.seed so that reruns evaluate the same rows.
val_df = (
    df[df['dataset_id'].isin(CFG.validation_datas)]
    .sample(frac=0.02, random_state=CFG.seed)
    .copy(deep=True)
)

# df = df.drop(index=0)
gc.collect()
train_df.head()

In [None]:
# Print column dtypes, non-null counts and memory usage of both splits.
train_df.info()
val_df.info()

In [None]:
# df[df.duplicated(subset=["rewrite_prompt"], keep=False)]
# Number of distinct values in each of the three text columns.
display(train_df[["original_text", "rewrite_prompt", "rewritten_text"]].nunique())
# Count rows whose (original_text, rewrite_prompt, rewritten_text) triple
# occurs more than once; keep=False marks every member of a duplicate group.
print('duplicated row:', train_df.duplicated(subset=["original_text", "rewrite_prompt", "rewritten_text"], keep=False).sum())

각각은 중복된 text가 있지만 세 개 다 중복된 것은 없다.

In [None]:
import random
def display_random_row(df):
    """Print one randomly chosen row of ``df`` with one colour per column.

    The row is selected by position (``iloc``), so the function works for any
    index, including the non-contiguous index left behind by filtering or
    sampling. The chosen position is printed first, followed by each of the
    known text columns that is present in ``df``; missing columns (for
    example the prediction column before validation has run) are skipped.
    Nothing is printed for an empty DataFrame.

    Args:
        df: DataFrame that may contain "original_text", "rewrite_prompt",
            "rewritten_text" and "rewritte_prompt_predicted" columns.
    """
    if len(df) == 0:
        return

    random_idx = random.randrange(0, len(df))
    print(random_idx)
    row = df.iloc[random_idx]

    # (ANSI colour prefix, column name) pairs, printed in this order.
    colored_columns = (
        ("\033[38;2;255;0;0m", "original_text"),
        ("\033[35m\t", "rewrite_prompt"),
        ("\033[38;2;0;0;255m", "rewritten_text"),
        ('\033[36m', "rewritte_prompt_predicted"),
    )
    for color, column in colored_columns:
        if column in row.index:
            print(color, row[column])
# Spot-check one randomly chosen training row.
display_random_row(train_df)

In [None]:
import matplotlib.pyplot as plt
# Summary statistics of the character lengths of the three text columns.
display(pd.concat([train_df["original_text"].str.len(), train_df["rewrite_prompt"].str.len(), train_df["rewritten_text"].str.len()], axis=1).describe())
# 50 bin edges between 20 and 2000 characters, shared by all three histograms.
bins = np.linspace(20, 2000, 50)
plt.hist(train_df["original_text"].str.len(), bins, alpha=0.5, label='original_text')
plt.hist(train_df["rewrite_prompt"].str.len(), bins, alpha=0.5, label='rewrite_prompt')
plt.hist(train_df["rewritten_text"].str.len(), bins, alpha=0.5, label='rewritten_text')
plt.legend(loc="upper left")
plt.show()

In [None]:
def truncate_txt(text, length):
    """Limit ``text`` to its first ``length`` whitespace-separated words.

    Text that already has at most ``length`` words is returned unchanged
    (original spacing included). Longer text is cut to the first ``length``
    words, which are re-joined with single spaces.
    """
    words = text.split()
    return text if len(words) <= length else " ".join(words[:length])


def gen_val_prompt(df):
    """Build the generation prompt (without the answer) for a single row.

    ``df`` is one row, indexable by column name (e.g. the Series handed over
    by ``DataFrame.apply(axis=1)``), providing "original_text" and
    "rewritten_text". Each text is stripped and truncated to at most
    ``CFG.sequence_length // 3`` words before being placed in the template.
    The formatted prompt is stripped, so it ends right after the model turn
    marker with no trailing newline.
    """
    max_words = CFG.sequence_length//3
    og_text, rewritten_text = (
        truncate_txt(df[column].strip(), max_words)
        for column in ("original_text", "rewritten_text")
    )
    template = """<bos>Instruct: Original Text:{}\nRewritten Text:{}\nWrite a prompt that was likely given to the LLM to rewrite original text into rewritten text. Output:
<start_of_turn>model
"""
    prompt = template.format(og_text, rewritten_text)
    return prompt.strip()

def gen_prompt(df):
    """Build the full training example (prompt plus target) for a single row.

    ``df`` is one row, indexable by column name (e.g. the Series handed over
    by ``DataFrame.apply(axis=1)``), providing "original_text",
    "rewritten_text" and "rewrite_prompt". Each text is stripped and
    truncated to at most ``CFG.sequence_length // 3`` words. The rewrite
    prompt is placed after the model turn marker and followed by the
    end-of-turn and end-of-sequence markers.
    """
    max_words = CFG.sequence_length//3
    og_text, rewritten_text, rewrite_prompt = (
        truncate_txt(df[column].strip(), max_words)
        for column in ("original_text", "rewritten_text", "rewrite_prompt")
    )
    template = """<bos>Instruct: Original Text:{}\nRewritten Text:{}\nWrite a prompt that was likely given to the LLM to rewrite original text into rewritten text. Output:
<start_of_turn>model
{}<end_of_turn><eos>"""
    example = template.format(og_text, rewritten_text, rewrite_prompt)
    return example.strip()

In [None]:
from multiprocessing import cpu_count, Pool
from tqdm import tqdm
import parmap
# tqdm.pandas()
def parallel_apply(df, main_func, func, n_cores=None):
    """Run ``main_func(chunk, func=func)`` over chunks of ``df`` in parallel.

    Args:
        df: Data to process; it is split into ``n_cores`` chunks.
        main_func: Callable invoked once per chunk with ``func`` bound as a
            keyword argument.
        func: Callable forwarded to ``main_func``.
        n_cores: Number of worker processes and chunks; when falsy, every
            available CPU core is used.

    Returns:
        The per-chunk results concatenated with ``pd.concat``.
    """
    from functools import partial

    worker_count = n_cores if n_cores else cpu_count()

    # One chunk per worker.
    chunks = np.array_split(df, worker_count)

    # Bind `func` up front so the mapped callable takes only the chunk.
    bound_func = partial(main_func, func=func)

    # Process every chunk in its own worker, with a progress bar.
    results = parmap.map(bound_func, chunks, pm_pbar=True, pm_processes=worker_count)
    return pd.concat(results)
# Row-wise apply helper
def apply_function(data, func):
    """Apply ``func`` to every row of ``data`` and return the result."""
    row_wise_result = data.apply(func, axis=1)
    return row_wise_result

In [None]:
from sklearn.model_selection import train_test_split
from tqdm import tqdm
# Register DataFrame.progress_apply (apply with a tqdm progress bar).
tqdm.pandas()

# Optional stratified re-split of the full DataFrame. It has to run BEFORE the
# prompt columns are built: the frames returned by train_test_split(df, ...)
# come from `df`, which has no 'prompt' / 'val_prompt' column, so splitting
# afterwards would discard the prompts that the dataset pipelines read.
if CFG.test_size > 0.:
    train_df, val_df = train_test_split(df, test_size=CFG.test_size, random_state=CFG.seed, stratify=df['dataset_id'])
    # Own the split frames so adding columns below never writes into `df`.
    train_df, val_df = train_df.copy(), val_df.copy()

# Full training text (instruction + target prompt) for every training row.
train_df['prompt'] = train_df[["original_text", "rewritten_text", "rewrite_prompt"]].progress_apply(gen_prompt, axis=1)

# Generation prompt (instruction only) for every validation row.
val_df['val_prompt'] = val_df[["original_text", "rewritten_text"]].progress_apply(gen_val_prompt, axis=1)

# df['rewrite_prompt'] = df['rewrite_prompt'].progress_apply(lambda x: x.strip())

In [None]:
import tensorflow as tf

# Training pipeline over the raw prompt strings.
# Shuffle individual examples BEFORE batching: shuffling after `batch` only
# reorders whole batches, so the same examples would share a batch in every
# epoch. The buffer holds as many examples as the previous 8192-batch buffer.
train_ds = tf.data.Dataset.from_tensor_slices(train_df['prompt'])
train_ds = (train_ds
            .shuffle(8192 * CFG.train_batch)
            .batch(CFG.train_batch)
            .prefetch(tf.data.AUTOTUNE)
           )

# Validation pipeline: prompts only, kept in DataFrame order (no shuffle) so
# that generated outputs line up with the rows of val_df.
if val_df is not None:
    val_ds = tf.data.Dataset.from_tensor_slices(val_df['val_prompt'])
    val_ds = (val_ds
                .batch(CFG.validation_batch)
                .prefetch(tf.data.AUTOTUNE)
               )


# Train

In [None]:
# Enable LoRA for the model and set the LoRA rank to CFG.lora_rank.
gemma_lm.backbone.enable_lora(rank=CFG.lora_rank)
gemma_lm.summary()

In [None]:
# from sentence_transformers import SentenceTransformer
# from sklearn.metrics.pairwise import cosine_similarity

# class ValidationMetricCallback(keras.callbacks.Callback):
#     def on_epoch_end(self, epochs, logs):
#         prompt_preds = self.model.generate(self.validation_data, max_length=512)
        
class LoraCheckPointOnBatchs(keras.callbacks.Callback):
    """Save the backbone's LoRA weights at epoch ends and, optionally, every N batches.

    This class used to be defined twice under the same name, so the
    batch-frequency variant was silently replaced by the per-epoch variant.
    Both behaviours now live in one class. With the default arguments it
    behaves exactly like the per-epoch variant that was in effect.

    Args:
        start_epochs: First epoch index (0-based) whose end triggers a save
            to "{epoch}-lora_weights.lora.h5".
        save_freq_batchs: When set to a positive integer, additionally save
            to "{epochs}_{batch}-lora_weights.lora.h5" whenever the batch
            index within the epoch is a non-zero multiple of this value.
            ``None`` (default) disables batch-level saving.
    """

    def __init__(self, start_epochs=0, save_freq_batchs=None):
        super().__init__()
        self.start_epochs = start_epochs
        self.save_freq_batchs = save_freq_batchs
        # Number of completed epochs; used in batch-level checkpoint names.
        self.epochs = 0

    def on_train_batch_end(self, batch, logs=None):
        # Batch-level checkpointing is opt-in.
        if not self.save_freq_batchs:
            return
        if batch != 0 and batch % self.save_freq_batchs == 0:
            self.model.backbone.save_lora_weights(f"{self.epochs}_{batch}-lora_weights.lora.h5")

    def on_epoch_end(self, epoch, logs=None):
        if epoch >= self.start_epochs:
            self.model.backbone.save_lora_weights(f"{epoch}-lora_weights.lora.h5")
        self.epochs += 1



In [None]:
# Limit the input sequence length to CFG.sequence_length (to control memory usage).
gemma_lm.preprocessor.sequence_length = CFG.sequence_length 

# Compile the model with loss, optimizer, and metric
# Learning rate: warm up from 8e-6 to 8e-5 over 3000//train_batch steps, then
# follow a cosine decay over 10000//train_batch*10 steps.
lr = keras.optimizers.schedules.CosineDecay(
    initial_learning_rate=8e-6,
    decay_steps=10000//CFG.train_batch*10,
    warmup_target=8e-5,
    warmup_steps=3000//CFG.train_batch,
)
gemma_lm.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=keras.optimizers.Adam(learning_rate=lr),
    weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
    sampler="greedy"
)
# callbacks = [
#     ValidationMetricCallback()
# ]
# Train model; LoRA weights are checkpointed by the callback.
gemma_lm.fit(train_ds, epochs=CFG.epochs, callbacks=[LoraCheckPointOnBatchs(start_epochs=0)])

# Test and save

In [None]:
# Smoke test: generate a prompt for the first row of the competition test set.
test_df = pd.read_csv("/kaggle/input/llm-prompt-recovery/test.csv")
# Same instruction template as the validation prompts (no target appended).
template = """<bos>Instruct: Original Text:{}\nRewritten Text:{}\nWrite a prompt that was likely given to the LLM to rewrite original text into rewritten text. Output:
<start_of_turn>model
"""
# Columns are taken by position (2nd and 3rd) - presumably original_text and
# rewritten_text; verify against the test.csv header.
print(gemma_lm.generate(template.format(test_df.iloc[0,1], test_df.iloc[0,2]).strip(), max_length=512))

In [None]:
# final_norm_layer = gemma_lm.backbone.get_layer("final_normalization")
# Save the LoRA weights as they are at the end of training.
gemma_lm.backbone.save_lora_weights("final-lora_weights.lora.h5")
# gemma_lm.save_lora_weights("keras-gemma_instruct_7b_en-lora_weights.weights.h5")

In [None]:
import json
# Serialize the model configuration to JSON and write it to config.json.
json_string = gemma_lm.to_json()
data = json.loads(json_string)
with open("config.json", "w") as f:
    json.dump(data, f)
# json_string = gemma_lm.preprocessor.tokenizer.to_json()
# data = json.loads(json_string)
# with open("tokenizer.json", "w") as f:
#     json.dump(data, f)


In [None]:
%%time
import glob
# Persist the validation rows before any predictions are attached.
val_df.to_csv("validation.csv")

# Stop evaluating checkpoints once the run has used this much wall-clock time.
time_budget = datetime.timedelta(hours=9, minutes=30)

# Sort so checkpoints are evaluated in a deterministic order; glob alone
# returns files in arbitrary order.
lora_weights = sorted(glob.glob('*-lora_weights.lora.h5'))
for lora_weight in tqdm(lora_weights):
    # Once over budget every remaining checkpoint would be skipped too, so
    # leave the loop instead of iterating over the rest.
    if (datetime.datetime.now() - start_time) > time_budget:
        break
    gemma_lm.backbone.load_lora_weights(lora_weight)
    val_output = gemma_lm.generate(val_ds, max_length=CFG.sequence_length)
    val_df['rewritte_prompt_predicted'] = val_output
    # One CSV per checkpoint, named after the part of the file name before '-'.
    val_df.to_csv(f"{lora_weight.split('-')[0]}_validation.csv", columns = ["original_text", "rewrite_prompt", "rewritten_text", "rewritte_prompt_predicted"],)

In [None]:
# Replace the sampled index with a fresh 0..n-1 index (the old one is kept as
# an "index" column), then spot-check one random validation row.
val_df=val_df.reset_index()
display_random_row(val_df)