In [1]:
import pandas as pd

In [2]:
# Importing the generated ground truth
test = pd.read_csv('/kaggle/input/llm-prompt-recovery-ground-truth/test.csv')
test.head()
test = test.head(195)

# Predictions 
Predictions of the prompt using a finetuned keras version of Gemma

In [3]:
import os
import torch
import pandas as pd

os.environ["KERAS_BACKEND"] = "jax"  # Or "torch" or "tensorflow".
# Avoid memory fragmentation on JAX backend.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"]="0.9"

In [4]:
import keras
import keras_nlp

2024-08-27 21:43:26.355051: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-08-27 21:43:26.355156: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-08-27 21:43:26.478739: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [5]:
import jax
import os

# The Keras 3 distribution API is only implemented for the JAX backend for now
os.environ["KERAS_BACKEND"] = "jax"
# Pre-allocate 90% of TPU memory to minimize memory fragmentation and allocation overhead
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.9"

print(jax.devices())

import keras
keras.distribution.list_devices()

[cuda(id=0), cuda(id=1)]


['gpu:0', 'gpu:1']

In [6]:
from keras import layers
from tensorflow import data as tf_data  # For dataset input.

num_gpu=len(keras.distribution.list_devices())
# there are 2 cudas in TPU T4x2 - VMv3-8 has 8
gpu_mesh = keras.distribution.DeviceMesh(
    shape=(1,num_gpu),
    axis_names=["batch", "model"], 
    devices=keras.distribution.list_devices()
)
gpu_mesh

<DeviceMesh shape=(1, 2), axis_names=['batch', 'model']>

In [7]:
# Create a LayoutMap instance
layout = keras.distribution.LayoutMap(device_mesh=gpu_mesh)

layout["token_embedding/embeddings"] = (None, "model")
layout["decoder_block.*attention.*(query|key|value).*kernel"] = (None, "model", None)
layout["decoder_block.*attention_output.*kernel"] = (None, None, "model")
layout["decoder_block.*ffw_gating.*kernel"] = ("model", None)
layout["decoder_block.*ffw_linear.*kernel"] = (None, "model")

# The rule means that for any weights that match with token_embedding/embeddings'
# will be sharded with model dimensions defined in the mesh (e.g., 2 devices), etc.

# Define the model parallel instance - Distribution that shards model variables.
model_parallel = keras.distribution.ModelParallel(gpu_mesh, layout, batch_dim_name="batch")

# Finally set the distribution globally
keras.distribution.set_distribution(model_parallel)

In [8]:
# Import model from Keras
gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_2b_en")
gemma_lm.summary()

Attaching 'config.json' from model 'keras/gemma/keras/gemma_2b_en/2' to your Kaggle notebook...
Attaching 'config.json' from model 'keras/gemma/keras/gemma_2b_en/2' to your Kaggle notebook...
Attaching 'model.weights.h5' from model 'keras/gemma/keras/gemma_2b_en/2' to your Kaggle notebook...
Attaching 'tokenizer.json' from model 'keras/gemma/keras/gemma_2b_en/2' to your Kaggle notebook...
Attaching 'assets/tokenizer/vocabulary.spm' from model 'keras/gemma/keras/gemma_2b_en/2' to your Kaggle notebook...
normalizer.cc(51) LOG(INFO) precompiled_charsmap is empty. use identity normalization.


In [9]:
template = """Instruction:\nBelow, the `Original Text` passage has been rewritten/transformed/improved into `Rewritten Text` by the `Gemma 7b-it` LLM with a certain prompt/instruction. Your task is to carefully analyze the differences between the `Original Text` and `Rewritten Text`, and try to infer the specific prompt or instruction that was likely given to the LLM to rewrite/transform/improve the text in this way.\n\nOriginal Text:\n{original_text}\n\nRewriten Text:\n{rewritten_text}\n\nResponse:\n{rewrite_prompt}"""

In [10]:
preds = []
for i in range(len(test)):
    row = test.iloc[i]

    # Generate Prompt using template
    prompt = template.format(
        original_text=row.original_text,
        rewritten_text=row.rewritten_text,
        rewrite_prompt=""
    )

    # Infer
    output = gemma_lm.generate(prompt, max_length=1000)
    pred = output.replace(prompt, "") # remove the prompt from output
    
    # Store predictions
    preds.append(pred)

In [11]:
preds[0]

'The prompt or instruction that was likely given to the LLM to rewrite/transform/improve the text in this way is:\n\n"The LLM was likely given the prompt or instruction to rewrite/transform/improve the text in this way to highlight the importance of equality and inclusion in our society, and to emphasize the need for empathy, understanding, and respect for all individuals."'

In [12]:
df = pd.DataFrame({'pred': preds})
df

Unnamed: 0,pred
0,The prompt or instruction that was likely give...
1,The prompt for this LLM was to rewrite the ori...
2,"Dear grandchild,\n\nI am so sorry to hear abou..."
3,The prompt or instruction that was likely give...
4,"Instruction:\nBelow, the `Original Text` passa..."
...,...
190,The prompt or instruction that was likely give...
191,The prompt or instruction that was likely give...
192,"Dear [Company Name] Employees,\n\nThank you fo..."
193,The prompt or instruction that was likely give...


In [13]:
test = pd.concat([test, df], axis=1)
test

Unnamed: 0,id,original_text,rewrite_prompt,rewritten_text,pred
0,35244642,Roy Moore's administrative order defies a US S...,Frame this as a message from the future.,"**Message from the Future:**\n\n""Greetings fro...",The prompt or instruction that was likely give...
1,UUsPjrHjwH,"Her cool, magic filled hands caressed his stro...",Rewrite the story as a romcom / love story,In the quaint halls of the Pigfreckles School ...,The prompt for this LLM was to rewrite the ori...
2,29540559,"Jill Hutchinson-Grigg, 54, accidentally hit a ...",Convert the text into a grandparent's advice t...,"My dear grandchild,\n\nI know you're going thr...","Dear grandchild,\n\nI am so sorry to hear abou..."
3,40666889,Torrential downpours affected properties in Bl...,Describe this as an Olympic sport commentary.,"""Good evening, ladies and gentlefolk, and welc...",The prompt or instruction that was likely give...
4,VMtuViHepp,`` You people have to be kidding me. Magic doe...,Rewrite the story as an action movie with a gr...,"""The roar of the alien spacecraft echoed throu...","Instruction:\nBelow, the `Original Text` passa..."
...,...,...,...,...,...
190,34841098,The 21-year-old local man was in the front sea...,Adapt it as an ancient Egyptian hieroglyphic m...,**Hieroglyphic Message:**\n\nThe serpent's ton...,The prompt or instruction that was likely give...
191,36425660,The 26-year-old will join the Shrimps on a two...,Write the text as if it were a vintage travel ...,**Journey along with the Shrimps to Paradise!*...,The prompt or instruction that was likely give...
192,37677698,Ben Gwynne captured the sight on the moors abo...,Adapt this into a company newsletter article.,**Subject: Rare Lunar Rainbow Spotted in North...,"Dear [Company Name] Employees,\n\nThank you fo..."
193,40056423,A high of 25.8C was recorded at Magilligan in ...,Turn this into a story about a molecule that d...,"In the heart of the tiniest star dust, where c...",The prompt or instruction that was likely give...


In [14]:
!pip install -Uq /kaggle/input/sentence-transformers-2-4-0/sentence_transformers-2.4.0-py3-none-any.whl

In [15]:
import numpy as np
import pandas as pd
from tqdm import tqdm
tqdm.pandas()

import warnings 
warnings.filterwarnings('ignore')

from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity

In [16]:
def CVScore(test):
    
    scs = lambda row: abs((cosine_similarity(row["actual_embeddings"], row["pred_embeddings"])) ** 3)
    
    model = SentenceTransformer('/kaggle/input/sentence-t5-base-hf/sentence-t5-base')
    
    print(test["rewrite_prompt"])

    test["actual_embeddings"] = test["rewrite_prompt"].progress_apply(lambda x: model.encode(x, normalize_embeddings=True, show_progress_bar=False).reshape(1, -1))
    print(test["actual_embeddings"])
    test["pred_embeddings"] = test["pred"].progress_apply(lambda x: model.encode(x, normalize_embeddings=True, show_progress_bar=False).reshape(1, -1))
    
    test["score"] = test.apply(scs, axis=1)
    
    return np.mean(test['score'])[0][0]
    
print(f"CV Score: {CVScore(test)}")

0               Frame this as a message from the future.
1             Rewrite the story as a romcom / love story
2      Convert the text into a grandparent's advice t...
3          Describe this as an Olympic sport commentary.
4      Rewrite the story as an action movie with a gr...
                             ...                        
190    Adapt it as an ancient Egyptian hieroglyphic m...
191    Write the text as if it were a vintage travel ...
192        Adapt this into a company newsletter article.
193    Turn this into a story about a molecule that d...
194    Rewrite the message as a chess master's strate...
Name: rewrite_prompt, Length: 195, dtype: object


100%|██████████| 195/195 [00:03<00:00, 59.63it/s]


0      [[-0.018073421, -0.030402295, 0.003190786, 0.0...
1      [[-0.027386943, 0.014050057, 0.040082913, 0.03...
2      [[0.002568368, -0.0461105, 0.04272737, 0.06538...
3      [[-0.02944915, -0.010447316, -0.00022800422, 0...
4      [[-0.034108885, -0.0333137, 0.035215627, 0.005...
                             ...                        
190    [[-0.036863532, -0.023867188, 0.02022138, 0.04...
191    [[-0.02942239, -0.016223524, 0.0029148825, 0.0...
192    [[-0.00045411813, -0.010155952, -0.011178258, ...
193    [[-0.019369064, -0.02324807, 0.029670076, 0.05...
194    [[-0.04206502, -0.033440344, 0.030094832, 0.06...
Name: actual_embeddings, Length: 195, dtype: object


100%|██████████| 195/195 [00:03<00:00, 51.26it/s]

CV Score: 0.4877890646457672



