In [None]:
# Shell escape: run the wandb CLI to disable Weights & Biases before anything else executes.
!wandb disabled

In [None]:
from accelerate import Accelerator

from transformers import TrainingArguments

from peft import LoraConfig, PeftModel, prepare_model_for_kbit_training, get_peft_model

from datasets import Dataset

In [None]:
import datetime
# Wall-clock reference for this run. The Prediction section re-assigns start_time
# right before the inference loop, which is the only place it is compared against.
start_time = datetime.datetime.now()

In [None]:
import os
# Set WANDB_DISABLED for this process, in addition to the `wandb disabled` CLI call in the first cell.
os.environ["WANDB_DISABLED"] = "true"

In [None]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, AutoConfig

accelerator = Accelerator()

# Comment/Uncomment and use as per wish
# Never paste a literal Hugging Face token into the notebook; read it from the environment.
# (A real token used to be written on the next line - it should be revoked.)
#access_token = os.environ["HF_TOKEN"]
#MODEL_PATH = "../models/gemma_hf/gemma-7b-it"
MODEL_PATH = "../models/gemma_hf/gemma-2b-it"

# Optional 4-bit (QLoRA-style) quantization config; currently disabled.
# quantization_config = BitsAndBytesConfig(
#     load_in_4bit = True,
#     bnb_4bit_quant_type="nf4",
#     bnb_4bit_compute_dtype=torch.bfloat16,
#     bnb_4bit_use_double_quant=True,
# )

# Tokenizer and model are both loaded from the same local checkpoint directory.
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)


# Quantized variant of the model load; pairs with the quantization_config block above.
# model = AutoModelForCausalLM.from_pretrained(
#     MODEL_PATH,
#     device_map = "auto",
#     trust_remote_code = True,
#     quantization_config=quantization_config,
# )

# Active (non-quantized) load. No dtype is passed, so the library default is used.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_PATH,
    device_map = "auto",
    trust_remote_code = True
)

# NOTE(review): the model was already placed by device_map="auto"; confirm that
# passing it through accelerator.prepare as well is intended.
model = accelerator.prepare(model)

In [None]:
import pandas as pd
from tqdm import tqdm

# Input locations and row cap for a quick smoke run.
# The Prediction section re-assigns TEST_DF_FILE, NROWS, tdf and sub before inference.
# TEST_DF_FILE = '/kaggle/input/llm-prompt-recovery/test.csv'
TEST_DF_FILE = '../data/test.csv'
SUB_DF_FILE = '../data/sample_submission.csv'
NROWS = 1

# Only the columns used later are read; nrows limits both frames to the first NROWS rows.
tdf = pd.read_csv(TEST_DF_FILE, nrows=NROWS, usecols=['id', 'original_text', 'rewritten_text'])
sub = pd.read_csv(SUB_DF_FILE, nrows=NROWS, usecols=['id', 'rewrite_prompt'])

In [None]:
def truncate_txt(text, length):
    """Cap `text` at `length` whitespace-separated words.

    Text that already fits is handed back untouched (original spacing kept);
    longer text is rebuilt from its first `length` words joined by single spaces.
    """
    words = text.split()
    return " ".join(words[:length]) if len(words) > length else text


def gen_prompt(og_text, rewritten_text):
    """Compose the instruction sent to the model for one (original, rewritten) pair.

    Both passages are first cut to 200 words with `truncate_txt`. The result is
    the fixed instruction paragraph followed by the two passages, each wrapped
    in triple double-quotes on its own labelled line.
    """
    og_text, rewritten_text = (
        truncate_txt(passage, 200) for passage in (og_text, rewritten_text)
    )

    instruction = "Given below are 2 texts, the Rewritten text was created from the Original text using the google Gemma model. You are trying to understand how the original text was transformed into a new version. Analyze the changes in style and theme and come up with a prompt that must have been used to transform the Original Text to the Rewritten Text. Start your output with the string \"Prompt:\". Also limit your output to just the rewrite prompt. Do not include the original text and the rewritten text in your response.\n"
    passages = f"""Original Text:\"""{og_text}\"""\nRewritten Text:\"""{rewritten_text}\"""\n"""

    return instruction + passages

# LoRA Fine-tuning (the QLoRA 4-bit quantization config is currently commented out)

In [None]:
# Prepare the loaded model for adapter training, then wrap it with LoRA adapters.
# NOTE(review): the BitsAndBytes quantization config is commented out in the loading cell,
# so the model reaching this call is not k-bit quantized - confirm this call is still wanted.
model = prepare_model_for_kbit_training(model)
peft_config = LoraConfig(
    lora_alpha=16,
    lora_dropout=0.1,
    r=64,
    bias="none",
    task_type="CAUSAL_LM",
    # Adapters are attached to the modules with these names.
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj","gate_proj"]
)
# From here on `model` is the PEFT-wrapped model.
model = get_peft_model(model, peft_config)

In [None]:
# Candidate hyper-parameters for fine-tuning.
# NOTE: the Trainer cell further down builds its own TrainingArguments inline and does
# not read this object; pass `args=training_arguments` there to make these values apply.
training_arguments = TrainingArguments(
    output_dir="./results",
    num_train_epochs=1,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=1,
    optim="paged_adamw_32bit",
    save_steps=25,
    logging_steps=25,
    learning_rate=2e-4,
    weight_decay=0.001,
    fp16=False,
    bf16=False,
    max_grad_norm=0.3,
    max_steps=-1,  # -1: let num_train_epochs decide the number of steps
    warmup_ratio=0.03,
    group_by_length=True,
    lr_scheduler_type="constant",
    # Weights & Biases is switched off at the top of the notebook (CLI + WANDB_DISABLED),
    # so asking for the wandb reporter here would contradict that setup.
    report_to="none"
)

In [None]:
# `LLM Prompt Recovery - Synthetic Datastore dataset` by @dschettler8845
# Keep the three columns used for training and rename the Gemma-7b rewrite column so
# it matches the competition schema.
df1 = pd.read_csv("../data/other/gemma1000_w7b.csv")
df1 = df1[["original_text", "rewrite_prompt", "gemma_7b_rewritten_text_temp0"]]
df1 = df1.rename(columns={"gemma_7b_rewritten_text_temp0":"rewritten_text"})

# `3000 Rewritten texts - Prompt recovery Challenge` by @dipamc77
# NOTE(review): this frame is concatenated as-is; it is assumed to already expose
# original_text / rewrite_prompt / rewritten_text - confirm against the CSV.
df2 = pd.read_csv("../data/other/prompts_0_500_wiki_first_para_3000.csv")

# Merge all datasets
df = pd.concat([df1, df2], axis=0)

# Shuffle and cap the training set at 4000 rows to bound training time.
# min(...) keeps this from raising when fewer rows are available, and the fixed
# random_state makes the selection reproducible across runs.
df = df.sample(n=min(4000, len(df)), random_state=42).reset_index(drop=True)
df.head(5)

In [None]:
# Column-oriented payload for the Dataset built in the next cell: a running string id
# plus the three text columns of `df`, every value coerced to str.
# NOTE: the name `dict` shadows the builtin for the rest of the session; it is kept
# because the next cell reads it under this name.
dict = {
    "id": list(map(str, range(len(df)))),
    **{
        column: list(map(str, df[column]))
        for column in ("original_text", "rewrite_prompt", "rewritten_text")
    },
}
# Spot-check a single rewritten text.
dict['rewritten_text'][6]

In [None]:
# Build the dataset from the column dictionary created in the previous cell
# (`dict` here is that variable, not the builtin type).
ds = Dataset.from_dict(dict)
# Show the field names of the first record as a sanity check.
ds[0].keys()

In [None]:
def truncate_text(text, max_len):
    """Limit `text` to its first `max_len` whitespace-separated words.

    Returns `text` unchanged when it has at most `max_len` words; otherwise
    returns the leading `max_len` words joined with single spaces.
    """
    tokens = text.split()
    if len(tokens) > max_len:
        return " ".join(tokens[:max_len])
    return text

def tokenize(example):
    """Build the fine-tuning text for a row (or batch of rows) and tokenize it.

    Works with both `ds.map(tokenize)` (one row per call, every field is a `str`)
    and `ds.map(tokenize, batched=True)` (every field is a list of `str`).

    The training text is the instruction paragraph, followed by the original and
    rewritten passages, followed by the target rewrite prompt in double quotes.

    Args:
        example: mapping with "original_text", "rewritten_text" and
            "rewrite_prompt" entries.

    Returns:
        A plain dict holding whatever the module-level `tokenizer` returns for the
        text(s): flat lists for a single row, lists of lists for a batch.
    """
    batched = not isinstance(example["original_text"], str)
    if batched:
        rows = zip(example["original_text"], example["rewritten_text"], example["rewrite_prompt"])
    else:
        # A single row: wrap it so the loop below sees exactly one record instead of
        # iterating over the characters of the strings.
        rows = [(example["original_text"], example["rewritten_text"], example["rewrite_prompt"])]

    # NOTE(review): inputs are not truncated here, whereas gen_prompt cuts both passages
    # to 200 words at inference time - confirm that difference is intended.
    templates = []
    for original_text, rewritten_text, rewrite_prompt in rows:
        # The parentheses matter: both literals must be joined into ONE string, otherwise
        # the passages and the target prompt never reach the tokenizer.
        templates.append(
            f"Given below are 2 texts, the Rewritten text was created from the Original text using the Google Gemma model. You are trying to understand how the original text was transformed into a new version. Analyze the changes in style and theme and come up with a prompt that must have been used to transform the Original Text to the Rewritten Text. Start your output with the string \"Prompt:\". Also limit your output to just the rewrite prompt. Do not include the original text and the rewritten text in your response.\n"
            f"""Original Text:\"""{original_text}\"""\nRewritten Text:\"""{rewritten_text}\"""\nPrompt:\n\"{rewrite_prompt}\""""
        )

    # No padding here: sequences are padded per batch by the data collator at training time.
    tkn = tokenizer(templates if batched else templates[0])
    return {**tkn}

# Make sure the tokenizer can pad. Adding a brand-new '[PAD]' token grows the vocabulary,
# so the model's embedding matrix has to grow with it; otherwise the new token id would
# index past the end of the embedding table.
if tokenizer.pad_token is None:
    tokenizer.add_special_tokens({'pad_token': '[PAD]'})
    # Only ever grow the table; never shrink a model whose embedding is larger than the vocab.
    if len(tokenizer) > model.get_input_embeddings().weight.shape[0]:
        model.resize_token_embeddings(len(tokenizer))

# Tokenize every row (one row per call) and show the resulting field names.
ds_tokenize = ds.map(tokenize)
ds_tokenize[0].keys()

In [None]:
# Hold out 20% of the tokenized rows. The fixed seed makes the split identical on
# every run, so results stay comparable across kernel restarts.
ds_split = ds_tokenize.train_test_split(test_size=0.20, seed=42)

train_ds = ds_split['train']
test_ds = ds_split['test']

In [None]:
from transformers import Trainer, DataCollatorForLanguageModeling
# Fine-tune the LoRA-wrapped model with a causal-LM objective (mlm=False).
# Training uses the train split only, so test_ds stays a genuine hold-out set;
# it is registered as eval_dataset so trainer.evaluate() can be called on it.
# NOTE: the `training_arguments` object defined in an earlier cell is not used here;
# the arguments below are the ones that take effect.
trainer = Trainer(
    model=model, # lora enabled
    train_dataset=train_ds,
    eval_dataset=test_ds,
    args=TrainingArguments(
        per_device_train_batch_size=1,
        learning_rate=2e-5,
        fp16=True,
        output_dir=".",
        optim="paged_adamw_8bit",
        num_train_epochs=1
    ),
    tokenizer=tokenizer,
    data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
    #compute_metrics=compute_metrics,
)

In [None]:
# Run the fine-tuning loop configured in the previous cell.
trainer.train()

In [None]:
# Save the PEFT-wrapped model and the tokenizer side by side; the Prediction section
# loads from this same directory name.
model_path = "gemma_2b_finetuned_model"
model.save_pretrained(model_path)
tokenizer.save_pretrained(model_path)

In [None]:
import shutil

# Zip the saved directory into gemma-twob-it-finetuned.zip in the working directory.
shutil.make_archive("gemma-twob-it-finetuned", "zip", model_path)

In [None]:
# Second copy of the tokenizer, written to "finetuned_model".
trainer.tokenizer.save_pretrained("finetuned_model")

In [None]:
# Second copy of the trained model, written next to the tokenizer in "finetuned_model".
trainer.model.save_pretrained("finetuned_model")

In [None]:
# `model` is already the trained PEFT model, so the adapter can be folded into the base
# weights directly. Loading the saved adapter onto it again with PeftModel.from_pretrained
# would wrap an already-wrapped model instead of a plain base model.
# merge_and_unload() modifies the underlying base model in place; the adapter itself was
# saved in the previous cells, so nothing is lost.
merged_model = model.merge_and_unload()

# Save the merged model
merged_model.save_pretrained("merged_model", safe_serialization=True)
tokenizer.save_pretrained("merged_model")

# NOTE(review): these two settings are applied after the tokenizer was saved, so they only
# affect the in-memory tokenizer - confirm whether they were meant to be persisted.
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

# Test / Validate Fine-tuning

In [None]:
# test_data = pd.read_csv(open("/kaggle/input/combined/test2.csv"))
# test_data

In [None]:
# test_ds = Dataset.from_dict({
#     "id": [str(x) for x in range(0,len(df))],
#     "original_text": [x for x in df.original_text],
#     "rewritten_text": [x for x in df.rewritten_text],
# })
# test_ds

In [None]:
# def tokenize_test(example):
#     for ot, rt in zip(example["original_text"], example["rewritten_text"]):
#         original_text = truncate_text(ot, 200)
#         rewritten_text = truncate_text(rt, 200)
    
#         template = f"Instruction:\nGiven 2 texts: Original Text and Rewritten Text. An LLM is received a prompt from a user asking to rewrite the given Original Text. Based on this prompt, the LLM generated the given Rewritten Text from the given Original Text in a certain way specified in the prompt. Your task is to guess the prompt with which the LLM generated the Rewritten Text.\nOriginal Text: \"{original_text}\".\nRewritten Text: \"{rewritten_text}\".\n\nResponse:\n\"\""
#         tkn = tokenizer(template, padding=True, return_tensors="pt")
#     return {**tkn}
    
# test_ds_tokenized = test_ds.map(tokenize_test)
# test_ds_tokenized[0].keys()

In [None]:
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# model = model.to(device)

# for item in test_ds:
#     print(item) # only 1 item so its okay to print out   
#     inputs = tokenize_test(item).to(device)
#     print(inputs)
    
#     #model.config.use_cache = False  # silence the warnings.
#     outputs = model.generate(**inputs)
#     print(outputs)

#     text_outputs = tokenizer.decode(outputs[0], skip_special_tokens=True)
#     print('\n\n', text_outputs)


In [None]:
# text = "Improve the text"
# inputs = tokenizer(text, return_tensors="pt").to(device)
# print(inputs)

# model = model.to(device)
# model.config.use_cache = False  # silence the warnings.
# outputs = model.generate(**inputs)
# print(outputs)

# print('\n\n', tokenizer.decode(outputs[0], skip_special_tokens=True))

# Prediction

In [None]:
import pandas as pd
import datetime

# Restart the clock: the inference loop measures its time budget from this moment.
start_time = datetime.datetime.now()

# Inputs for the prediction run; these re-assign the values set near the top of the notebook.
TEST_DF_FILE = '../data/combined/val.csv'
#TEST_DF_FILE = '../data/test.csv'
SUB_DF_FILE = '../data/sample_submission.csv'
NROWS = 5

# Only the first NROWS rows and the listed columns are read from each file.
tdf = pd.read_csv(TEST_DF_FILE, nrows=NROWS, usecols=['id', 'original_text', 'rewritten_text'])
sub = pd.read_csv(SUB_DF_FILE, nrows=NROWS, usecols=['id', 'rewrite_prompt'])

In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer,  BitsAndBytesConfig, AutoConfig
from peft import PeftModel
from tqdm import tqdm
import torch
from accelerate import Accelerator

accelerator = Accelerator()
# Device chosen by Accelerate; the inference loop moves the encoded prompt onto it.
device = accelerator.device

# Optional 4-bit quantization config; currently disabled.
# quantization_config = BitsAndBytesConfig(
#     load_in_4bit = True,
#     bnb_4bit_quant_type="nf4",
#     bnb_4bit_compute_dtype=torch.bfloat16,
#     bnb_4bit_use_double_quant=True,
# )

# Directory written by the fine-tuning section (PEFT model + tokenizer).
model_path = "gemma_2b_finetuned_model"

# model = AutoModelForCausalLM.from_pretrained(
#     model_path,
#     device_map = "auto",
#     quantization_config=quantization_config)

# NOTE(review): model_path is used both as the checkpoint here and as the adapter source
# below - confirm this loads the base weights plus the adapter exactly once, or load the
# "merged_model" directory instead and drop the PeftModel step.
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    device_map = "auto",
)

# This replaces the in-memory tokenizer from the training section.
tokenizer = AutoTokenizer.from_pretrained(model_path)

model = PeftModel.from_pretrained(model, model_path)

In [None]:
import gc
import re

#model = accelerator.prepare(model)

# Ids are taken positionally from the sample-submission frame.
# NOTE(review): if `sub` has fewer rows than `tdf`, the extra rows get NaN ids - confirm.
tdf['id'] = sub['id'].copy()

# Fallback prompt used whenever generation is skipped, fails, or yields no "Prompt:" marker.
# https://www.kaggle.com/competitions/llm-prompt-recovery/discussion/481116
DEFAULT_TEXT = "Please improve the following text using the writing style of, maintaining the original meaning but altering the tone, diction, and stylistic elements to match the new style.Enhance the clarity, elegance, and impact of the following text by adopting the writing style of , ensuring the core message remains intact while transforming the tone, word choice, and stylistic features to align with the specified style."

# Wall-clock budget measured from start_time; once exceeded, remaining rows get DEFAULT_TEXT.
TIME_BUDGET = datetime.timedelta(hours=8, minutes=30)


def _extract_prompt(generated_text):
    """Return the text after the first "Prompt:" marker, or None if there is none.

    Whitespace between the marker and the prompt, and around the prompt, is
    dropped. A marker followed only by whitespace counts as "no prompt".
    """
    match = re.search(r'Prompt:\s*(.+)', generated_text, re.DOTALL)
    if match is None:
        return None
    return match.group(1).strip() or None


# One [id, rewrite_prompt] pair per input row, in input order.
res = []

for _, row in tqdm(tdf.iterrows(), total=tdf.shape[0]):

    if (datetime.datetime.now() - start_time) > TIME_BUDGET:
        res.append([row["id"], DEFAULT_TEXT])
        continue

    # Release cached GPU memory and Python garbage between generations.
    torch.cuda.empty_cache()
    gc.collect()

    try:
        messages = [
            {
                "role": "user",
                "content": gen_prompt(row["original_text"], row["rewritten_text"])
            }
        ]
        encoded_input = tokenizer.apply_chat_template(messages, return_tensors="pt", add_generation_prompt=True).to(device)

        with torch.no_grad():
            encoded_output = model.generate(encoded_input, max_new_tokens=100, do_sample=True, pad_token_id=tokenizer.eos_token_id)

        # generate() returns the prompt tokens followed by the new tokens. Decode only the
        # new ones, so neither the instruction's own "Prompt:" text nor the chat-template
        # role names can be mistaken for model output (and a generated prompt that contains
        # the word "model" is no longer cut short).
        generated_ids = encoded_output[:, encoded_input.shape[-1]:]
        decoded_output = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]

        prompt = _extract_prompt(decoded_output)
        res.append([row["id"], prompt if prompt is not None else DEFAULT_TEXT])

    except Exception as e:
        # Best effort: a failing row must not abort the run; it falls back to DEFAULT_TEXT.
        print(f"ERROR: {e}")
        res.append([row["id"], DEFAULT_TEXT])

In [None]:
# sub["rewrite_prompt"] = tdf['rewrite_prompt'].copy()
# sub.to_csv("submission.csv", index=False)
# Rebuild `sub` from the [id, rewrite_prompt] pairs collected by the inference loop.
sub = pd.DataFrame(res, columns=['id', 'rewrite_prompt'])

# The same frame is written twice, under two file names, in the working directory.
sub.to_csv("sample_submission.csv", index=False)
sub.to_csv("submission.csv", index=False)

In [None]:
# Display the submission frame that was just written.
sub

In [None]:
# Display the raw [id, rewrite_prompt] pairs (one entry per processed row).
res