In [1]:
!mkdir /kaggle/working/gemma/
!cp /kaggle/input/gemma-pytorch/gemma_pytorch-main/gemma/* /kaggle/working/gemma/
!pip install --no-index --no-deps /kaggle/input/immutabledict/immutabledict-4.1.0-py3-none-any.whl

Processing /kaggle/input/immutabledict/immutabledict-4.1.0-py3-none-any.whl
Installing collected packages: immutabledict
Successfully installed immutabledict-4.1.0


In [2]:
import sys 
import pandas as pd
sys.path.append("/kaggle/working/") 
from gemma.config import GemmaConfig, get_config_for_7b, get_config_for_2b
from gemma.model import GemmaForCausalLM
from gemma.tokenizer import Tokenizer
import contextlib
import os
import torch
from tqdm import tqdm
# Run configuration: which Gemma checkpoint to use and where to run it.
VARIANT = "7b-it-quant"
MACHINE_TYPE = "cuda"
# Version sub-directory of the mounted weights.
# NOTE(review): "2" is the version used for 7b-it-quant here; other variants
# may be published under a different version number — confirm before
# changing VARIANT.
WEIGHTS_VERSION = "2"
# Derived from VARIANT so the weights directory cannot drift from the variant
# that also selects the config preset and the checkpoint file name.
weights_dir = f"/kaggle/input/gemma/pytorch/{VARIANT}/{WEIGHTS_VERSION}"


@contextlib.contextmanager
def _set_default_tensor_type(dtype: torch.dtype):
    """Sets the default torch dtype to the given dtype."""
    torch.set_default_dtype(dtype)
    yield
    torch.set_default_dtype(torch.float)

# Model configuration: choose the preset from the variant name, then point
# the config at the tokenizer shipped with the weights and flag whether the
# checkpoint is a quantized one.
if "2b" in VARIANT:
    model_config = get_config_for_2b()
else:
    model_config = get_config_for_7b()
model_config.tokenizer = os.path.join(weights_dir, "tokenizer.model")
model_config.quant = "quant" in VARIANT

# Model: build it and load the checkpoint while the config's dtype is the
# torch default, then move it to the target device in eval mode.
device = torch.device(MACHINE_TYPE)
ckpt_path = os.path.join(weights_dir, f'gemma-{VARIANT}.ckpt')
with _set_default_tensor_type(model_config.get_dtype()):
    model = GemmaForCausalLM(model_config)
    model.load_weights(ckpt_path)
    model = model.to(device)
    model = model.eval()

  return self.fget.__get__(instance, owner)()


In [3]:
# Competition test set. Later cells read its `id`, `original_text` and
# `rewritten_text` columns.
TEST_CSV_PATH = "/kaggle/input/llm-prompt-recovery/test.csv"
test = pd.read_csv(TEST_CSV_PATH)

In [4]:
# Instruction shown to the model for each test row. `{original_text}` and
# `{rewritten_text}` are filled in with str.format. The section labels match
# the back-quoted names used in the instructions so the prompt is
# self-consistent.
template = """Below, the `Original Text` passage has been rewritten into `Rewritten Text` by the Gemma LLM with a certain prompt. 

Original Text:\n{original_text}

Rewritten Text:\n{rewritten_text}

Your task is to generate a prompt to rewrite `Original Text` as `Rewritten Text` directly, in just one line in the most simple way.
"""

In [5]:
# Chat wrapper: the instruction goes in a user turn and the string ends with
# an opened model turn for the model to complete.
USER_CHAT_TEMPLATE = "<start_of_turn>user\n{prompt}<end_of_turn>\n<start_of_turn>model\n"

# One [id, predicted rewrite prompt] pair per test row.
preds = []
for idx in tqdm(range(len(test))):
    example = test.iloc[idx]

    prompt = template.format(
        original_text=example["original_text"],
        rewritten_text=example["rewritten_text"],
    )
    chat_input = USER_CHAT_TEMPLATE.format(prompt=prompt)

    generated = model.generate(chat_input, device=device, output_len=100)
    # Drop any verbatim echo of the instruction from the generated text.
    # NOTE(review): the recorded output starts directly with the model's
    # answer, so this may be a no-op in practice — confirm.
    cleaned = generated.replace(prompt, "")

    preds.append([example["id"], cleaned])

100%|██████████| 1/1 [00:43<00:00, 43.23s/it]


In [6]:
# Prompt submitted for any row where the model produced nothing usable.
FALLBACK_PROMPT = "Improve the essay"

# Build the submission frame from the [id, prediction] pairs.
sub_df = pd.DataFrame(preds, columns=["id", "rewrite_prompt"])
generated_prompts = sub_df["rewrite_prompt"].fillna("")
# Treat whitespace-only generations as empty too, so they get the fallback
# instead of being submitted as an effectively blank prompt.
is_blank = generated_prompts.str.strip() == ""
sub_df["rewrite_prompt"] = generated_prompts.mask(is_blank, FALLBACK_PROMPT)
sub_df.to_csv("submission.csv", index=False)
sub_df.head()

Unnamed: 0,id,rewrite_prompt
0,-1,"Sure, here's the prompt to rewrite `Original T..."
