In [17]:
import numpy as np
import pandas as pd
from tqdm.auto import tqdm

import torch

from peft import PeftConfig, PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

In [18]:
# Directory handed to every `from_pretrained` call as `cache_dir`.
cache_dir = "/tmp/k7/"

# Token budgets forwarded to `text_generate`:
# prompt tokens kept after truncation, and tokens the model may generate.
input_token_len = 1024
output_token_len = 100

In [19]:
# Test set; later cells read its `id`, `original_text` and `rewritten_text` columns.
test_df = pd.read_csv(filepath_or_buffer='../data/test.csv')

In [20]:
# Adapter loaded on top of the base checkpoint via PeftModel.from_pretrained.
adapter_model_name = "gemma_public_data_sft_adapter"
# Base checkpoint; also used to load the tokenizer.
base_model_name = "google/gemma-7b-it"

In [21]:
# Prefer the GPU when one is visible to torch; otherwise run on CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [22]:
# Tokenizer comes from the base checkpoint.
# NOTE(review): trust_remote_code=True allows repo-provided Python to run at load
# time — confirm it is actually required for this checkpoint.
tokenizer = AutoTokenizer.from_pretrained(
    base_model_name,
    cache_dir=cache_dir,
    trust_remote_code=True,
)

# text_generate passes tokenizer.pad_token_id to generate(), so make sure a pad
# token exists; fall back to EOS when the tokenizer defines none.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

In [23]:
# Load the base causal LM and wrap it with the fine-tuned PEFT adapter in one step.
model = PeftModel.from_pretrained(
    AutoModelForCausalLM.from_pretrained(
        base_model_name,
        trust_remote_code=True,
        cache_dir=cache_dir,
    ),
    adapter_model_name,
    cache_dir=cache_dir,
)

Loading checkpoint shards: 100%|██████████| 4/4 [00:01<00:00,  2.49it/s]


In [24]:
# Move the weights to the selected device and switch to inference mode.
model.to(device).eval()
print('model loaded !!')

model loaded !!


In [25]:
def text_generate(ori_text, rew_text, model, tokenizer, input_max_len=512, output_len=20, device='cuda'):
    """Ask the model which prompt turned ``ori_text`` into ``rew_text``.

    Builds an "Instruct: ... Output:" prompt from the two texts, runs greedy
    decoding and returns only the text generated after the prompt.

    Args:
        ori_text: Text inserted after "Original Text:".
        rew_text: Text inserted after "Rewritten Text:".
        model: Object exposing ``generate(**inputs, ...)``.
        tokenizer: Callable tokenizer that also provides ``decode`` and
            ``pad_token_id``.
        input_max_len: Maximum number of prompt tokens; longer prompts are
            truncated by the tokenizer.
        output_len: Maximum number of tokens to generate.
        device: Device the tokenized inputs are moved to.

    Returns:
        The generated continuation, decoded without special tokens and with
        surrounding whitespace stripped.
    """
    prompt = f"Instruct: Original Text:{ori_text}\nRewritten Text:{rew_text}\nWrite a prompt that was likely given to the LLM to rewrite original text to rewritten text.\nOutput:"
    inputs = tokenizer(prompt, max_length=input_max_len, truncation=True, return_tensors="pt", return_attention_mask=False)

    # Number of prompt tokens actually fed to the model (i.e. after truncation).
    prompt_token_count = inputs["input_ids"].shape[1]
    inputs = {k: v.to(device) for k, v in inputs.items()}

    outputs = model.generate(**inputs,
                             do_sample=False,
                             max_new_tokens=output_len,
                             pad_token_id=tokenizer.pad_token_id,
                             )

    # Keep only the tokens produced after the prompt. Searching the decoded text
    # for 'Output:' breaks when the input texts contain that marker themselves,
    # or when truncation removed the trailing 'Output:' from the prompt.
    new_tokens = outputs[0][prompt_token_count:]
    generated_text = tokenizer.decode(new_tokens,
                                      skip_special_tokens=True,
                                      clean_up_tokenization_spaces=False)
    return generated_text.strip()

In [26]:
# Generic fallback prompt, used for a row when generation does not succeed.
# The surrounding single quotes are part of the string itself.
mean_prompt = (
    "'Rewrite the following text in the style of [author/style], "
    "while preserving the original meaning. "
    "Adapt the tone, diction, and stylistic elements to match the specified writing style, "
    "aiming to enhance clarity, elegance, and impact.'"
)

In [27]:
# Collects one predicted prompt per test row, in row order.
rewrite_prompts = list()

In [34]:
# Start from an empty list so that re-running this cell never leaves duplicate
# entries behind (the list must end up with exactly one prompt per test row).
rewrite_prompts = []

# Only echo the first few rows; printing every row floods the output on a large test set.
n_preview_rows = 3

for row_pos, (_, row) in enumerate(tqdm(test_df.iterrows(), total=len(test_df))):
    prompt = mean_prompt  # fallback when generation fails or produces nothing
    try:
        generated = text_generate(row['original_text'],
                                  row['rewritten_text'],
                                  model,
                                  tokenizer,
                                  input_token_len,
                                  output_token_len,
                                  device,
                                  )
        # An empty prediction would leave a blank cell in the submission; keep the fallback.
        if generated:
            prompt = generated
        if row_pos < n_preview_rows:
            print(row['original_text'])
            print(row['rewritten_text'])
            print(prompt)
    except Exception as exc:
        # Best effort: one bad row must not abort the whole run, but the failure
        # should be visible. KeyboardInterrupt is deliberately not caught.
        print(f"row {row_pos}: generation failed ({exc!r}); using fallback prompt")

    rewrite_prompts.append(prompt)

100%|██████████| 1/1 [00:02<00:00,  2.83s/it]

The competition dataset comprises text passages that have been rewritten by the Gemma LLM according to some rewrite_prompt instruction. The goal of the competition is to determine what prompt was used to rewrite each original text.  Please note that this is a Code Competition. When your submission is scored, this example test data will be replaced with the full test set. Expect roughly 2,000 original texts in the test set.
Here is your shanty: (Verse 1) The text is rewritten, the LLM has spun, With prompts so clever, they've been outrun. The goal is to find, the prompt so bright, To crack the code, and shine the light. (Chorus) Oh, this is a code competition, my dear, With text and prompts, we'll compete. Two thousand texts, a challenge grand, To guess the prompts, hand over hand.(Verse 2) The original text, a treasure lost, The rewrite prompt, a secret to be
The prompt that was given to the LLM to rewrite the original text into the rewritten text.

**Note:** This is a fictional prompt




In [29]:
# Attach the predictions as a new column. This mutates `test_df` in place and
# requires `rewrite_prompts` to hold exactly one entry per row of `test_df`.
test_df['rewrite_prompt'] = rewrite_prompts

In [30]:
# Keep only the two submission columns and write them without the index column.
sub_df = test_df.loc[:, ['id', 'rewrite_prompt']]
sub_df.to_csv(path_or_buf='submission.csv', index=False)