In [1]:
# Import necessary modules
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import LoraConfig, get_peft_model, TaskType
from datasets import Dataset
from trl import SFTTrainer
import transformers
import pandas as pd
import re
import gc
from tqdm import tqdm
import datetime

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Load the pre-trained model and tokenizer
MODEL_PATH = "google/gemma-7b-it"
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
model = AutoModelForCausalLM.from_pretrained(MODEL_PATH, device_map="auto", trust_remote_code=True)


Downloading shards: 100%|██████████| 4/4 [04:26<00:00, 66.61s/it]
Loading checkpoint shards: 100%|██████████| 4/4 [00:45<00:00, 11.38s/it]


In [4]:
# Define the LoRA configuration
lora_config = LoraConfig(
    r=32,
    lora_alpha=32,
    target_modules=["q_proj", "o_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "down_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type=TaskType.CAUSAL_LM,
)


In [5]:
# Apply LoRA to the model
peft_model = get_peft_model(model, lora_config)

In [7]:
# Load and preprocess the training data
TRAIN_DF_FILE = "../external/nbroad-v2.csv"
df = pd.read_csv(TRAIN_DF_FILE)
df_1 = df[:1000]

data = Dataset.from_pandas(df_1)
data = data.map(lambda samples: tokenizer(samples["original_text"]), batched=True)
data = data.map(lambda samples: tokenizer(samples["rewritten_text"]), batched=True)
data = data.map(lambda samples: tokenizer(samples["rewrite_prompt"]), batched=True)


Map: 100%|██████████| 1000/1000 [00:00<00:00, 2362.89 examples/s]
Map: 100%|██████████| 1000/1000 [00:00<00:00, 5849.01 examples/s]
Map: 100%|██████████| 1000/1000 [00:00<00:00, 33254.34 examples/s]


In [None]:
# Define the formatting function for training
def formatting_func(example):
    text = f"Original Essay:\n{example['original_text'][0]}\n\nRewritten Essay:\n{example['rewritten_text'][0]}\n\nInstruction:\n Given are 2 essays, the Rewritten essay was created from the Original essay using the google Gemma model. You are trying to understand how the original essay was transformed into a new version.Analyzing the changes in style, theme, etc., please come up with a prompt that must have been used to guide the transformation from the original to the rewritten essay. Only give me the PROMPT. Start directly with the prompt, that's all I need. Output should be only line ONLY.\n\nResponse: \n{example['rewrite_prompt'][0]}"
    return [text]

In [None]:
# Initialize the trainer
trainer = SFTTrainer(
    model=peft_model,
    train_dataset=data,
    args=transformers.TrainingArguments(
        per_device_train_batch_size=1,
        gradient_accumulation_steps=4,
        warmup_steps=2,
        max_steps=10,
        learning_rate=2e-4,
        fp16=True,
        logging_steps=1,
        output_dir="outputs",
        optim="paged_adamw_8bit",
    ),
    peft_config=lora_config,
    formatting_func=formatting_func,
)

In [None]:
# Train the model
trainer.train()

In [None]:
# Load the test data
TEST_DF_FILE = "/path/to/llm-prompt-recovery/test.csv"
SUB_DF_FILE = "/path/to/llm-prompt-recovery/sample_submission.csv"

tdf = pd.read_csv(TEST_DF_FILE, usecols=["id", "original_text", "rewritten_text"])
sub = pd.read_csv(SUB_DF_FILE, usecols=["id", "rewrite_prompt"])

In [None]:
# Generate prompts for the test data
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

pbar = tqdm(total=tdf.shape[0])

it = iter(tdf.iterrows())
idx, row = next(it, (None, None))

DEFAULT_TEXT = "Please improve the following text using the writing style of, maintaining the original meaning but altering the tone, diction, and stylistic elements to match the new style.Enhance the clarity, elegance, and impact of the following text by adopting the writing style of , ensuring the core message remains intact while transforming the tone, word choice, and stylistic features to align with the specified style."

res = []
start_time = datetime.datetime.now()

while idx is not None:
    if (datetime.datetime.now() - start_time) > datetime.timedelta(hours=8, minutes=30):
        res.append([row["id"], DEFAULT_TEXT])
        idx, row = next(it, (None, None))
        pbar.update(1)
        continue

    torch.cuda.empty_cache()
    gc.collect()

    try:
        messages = [{"role": "user", "content": formatting_func({"original_text": [row["original_text"]], "rewritten_text": [row["rewritten_text"]], "rewrite_prompt": [""]})
        }]
        encoded_input = tokenizer.apply_chat_template(messages, return_tensors="pt", add_generation_prompt=True).to(device)

        with torch.no_grad():
            encoded_output = peft_model.generate(encoded_input, max_new_tokens=200, do_sample=True, pad_token_id=tokenizer.eos_token_id)

        decoded_output = tokenizer.batch_decode(encoded_output, skip_special_tokens=True)[0]
        decoded_output = re.sub(r"[\s\S]*\[\/INST\]", "", decoded_output, 1)

        res.append([row["id"], decoded_output])

    except Exception as e:
        print(f"ERROR: {e}")
        res.append([row["id"], DEFAULT_TEXT])

    finally:
        idx, row = next(it, (None, None))
        pbar.update(1)

pbar.close()

In [None]:
# Save the generated prompts
sub = pd.DataFrame(res, columns=["id", "rewrite_prompt"])
sub.to_csv("submission.csv", index=False)