In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline

In [None]:
import pandas as pd
import numpy as np
from tqdm import tqdm

In [None]:
# Read the competition train/test splits and the supplementary rewrite dataset
# in one pass; the order of the paths matches the order of the target names.
train, test, s_test = (
    pd.read_csv(csv_path)
    for csv_path in (
        "/kaggle/input/llm-prompt-recovery/train.csv",
        "/kaggle/input/llm-prompt-recovery/test.csv",
        "/kaggle/input/gemma-rewrite-nbroad/nbroad-v1.csv",
    )
)

In [None]:
import gc

import torch

# Run the garbage collector first so tensors held only by unreachable objects
# are released, then ask PyTorch to give its unused cached CUDA blocks back.
# Emptying the cache before collecting would leave that memory cached.
gc.collect()
torch.cuda.empty_cache()

In [None]:
# def generate_prompt(example):
#     prompt_list = []
#     for i in range(len(example['original_text'])):
#         prompt_list.append(r"""<bos><start_of_turn>user
#         original text:
#         {},
#         rewritten_text:
#         {},
        
#         Try to understand how the original text was transformed into a new version.
#         Analyzing the changes in style, theme, etc., please come up with a prompt that might have been used to guide the proper transformation from the original to the rewritten text.
        
#         <end_of_turn>
        
#         <start_of_turn>model
#         {}<end_of_turn><eos>""".format(example['original_text'][i], example['rewritten_text'][i], example['rewrite_prompt'][i]))
#     return prompt_list

In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline

class AIAssistant():
    """Thin wrapper around a text-generation pipeline that proposes the prompt
    which may have turned an original text into its rewritten version.

    Sampling settings (temperature, top_k, top_p, max_new_tokens) are stored on
    the instance and applied on every call to `query`, so they can be changed
    between calls through the setter methods.
    """

    def __init__(self, model_name="/kaggle/input/doxgxxn-prompt/doxgxxn_gemma", tokenizer="/kaggle/input/gemma/transformers/2b-it/2", temperature=0.4, top_k=50, top_p=0.95, max_new_tokens=1024):
        """Load the model and tokenizer and build the generation pipeline.

        Args:
            model_name: Path or identifier passed to
                `AutoModelForCausalLM.from_pretrained`.
            tokenizer: Path or identifier passed to
                `AutoTokenizer.from_pretrained`.
            temperature: Sampling temperature used by `query`.
            top_k: Top-k sampling cutoff used by `query`.
            top_p: Nucleus sampling threshold used by `query`.
            max_new_tokens: Upper bound on generated tokens per query
                (defaults to the previously hard-coded 1024).
        """
        # device_map={"":0} places the whole model on device index 0.
        self.finetune_model = AutoModelForCausalLM.from_pretrained(model_name, device_map={"":0})
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer, add_special_tokens=True)
        self.tokenizer.padding_side = 'right'
        self.pipe_finetuned = pipeline("text-generation", model=self.finetune_model, tokenizer=self.tokenizer, max_new_tokens=max_new_tokens)
        self.temperature = temperature
        self.top_k = top_k
        self.top_p = top_p
        self.max_new_tokens = max_new_tokens

    def query(self, original_text, rewritten_text):
        """Generate a candidate rewrite prompt for one text pair.

        Args:
            original_text: The text before rewriting.
            rewritten_text: The text after rewriting.

        Returns:
            The newly generated text only (the input prompt is not included).
        """
        # The instruction text (including its indentation) is sent to the model
        # verbatim, so it is kept exactly as written.
        message = [
            {
                "role": "user",
                "content": """original text:
                                      {},
                                      rewritten_text:
                                      {},

                                    Try to understand how the original text was transformed into a new version.
                                    Analyzing the changes in style, theme, etc., please come up with a prompt that might have been used to guide the proper transformation from the original to the rewritten text.""".format(original_text, rewritten_text)
            }
        ]
        prompt = self.pipe_finetuned.tokenizer.apply_chat_template(message, tokenize=False, add_generation_prompt=True)
        # NOTE(review): the templated prompt is tokenized with
        # add_special_tokens=True; if the chat template already emits a BOS
        # token this yields a second one. Left unchanged because the fine-tuned
        # model may have been trained the same way — confirm before altering.
        outputs = self.pipe_finetuned(
            prompt,
            do_sample=True,
            temperature=self.temperature,
            top_k=self.top_k,
            top_p=self.top_p,
            max_new_tokens=self.max_new_tokens,
            add_special_tokens=True,
            # Let the pipeline drop the prompt instead of slicing by its length.
            return_full_text=False,
        )
        return outputs[0]['generated_text']

    def set_temperature(self, temperature):
        """Set the sampling temperature used by subsequent queries."""
        self.temperature = temperature

    def set_top_k(self, top_k):
        """Set the top-k cutoff used by subsequent queries."""
        self.top_k = top_k

    def set_top_p(self, top_p):
        """Set the nucleus sampling threshold used by subsequent queries."""
        self.top_p = top_p

    def set_max_new_tokens(self, max_new_tokens):
        """Set the generation length limit used by subsequent queries."""
        self.max_new_tokens = max_new_tokens

In [None]:
# Build the assistant with its default model/tokenizer paths and sampling settings.
ai_assistant = AIAssistant()

In [None]:
# Sample submission; its `id` column is copied onto `test` before inference.
sub= pd.read_csv("/kaggle/input/llm-prompt-recovery/sample_submission.csv")

In [None]:
test

In [None]:
s_test.iloc[3]

In [None]:
s_test['original_text'][5]

In [None]:
s_test['rewritten_text'][5]

In [None]:
s_test['rewrite_prompt'][5]

In [None]:
print(ai_assistant.query(s_test['original_text'][5],s_test['rewritten_text'][5]).strip())

In [None]:
ai_assistant.query(s_test['original_text'][2],s_test['rewritten_text'][2]).strip()

'Write like a 1950s housewife: Write with the optimism and domesticity of a 1950s housewife, emphasizing homemaking, family, and domestic bliss.'

In [None]:
ai_assistant.query(s_test['original_text'][5],s_test['rewritten_text'][5]).strip()

In [None]:
# from accelerate import Accelerator
# accelerator = Accelerator()

In [None]:
# device = accelerator.device
# Take the ids from the sample submission so the output rows line up with it.
test['id'] = sub['id'].copy()

# Fallback prompt used whenever generation fails or produces nothing.
DEFAULT_TEXT = "Please improve the following text using the writing style of, maintaining the original meaning but altering the tone, diction, and stylistic elements to match the new style.Enhance the clarity, elegance, and impact of the following text by adopting the writing style of , ensuring the core message remains intact while transforming the tone, word choice, and stylistic features to align with the specified style."

# One [id, rewrite_prompt] pair per test row, in test order.
res = []

for _, row in tqdm(test.iterrows(), total=test.shape[0]):
    try:
        decoded_output = ai_assistant.query(row['original_text'], row['rewritten_text']).strip()
    except Exception as e:
        # Best effort: a single failing row must not abort the whole run.
        print(f"ERROR: {e}")
        decoded_output = DEFAULT_TEXT
    else:
        print(decoded_output)

    # An empty generation would leave a blank prompt in the submission file.
    if not decoded_output:
        decoded_output = DEFAULT_TEXT

    res.append([row["id"], decoded_output])

In [None]:
# Assemble the collected [id, prompt] pairs into the submission table,
# write it without the index column, and display it as the cell output.
submission_columns = ['id', 'rewrite_prompt']
sub = pd.DataFrame(data=res, columns=submission_columns)
sub.to_csv("submission.csv", index=False)
sub