In [32]:
import random
from jinja2 import Environment, Template, meta
import sys
from itertools import product
from datasets import load_dataset
import polars as pl
from vllm import LLM, SamplingParams
from typing import Dict
import numpy as np
import ray
from transformers import AutoTokenizer
from typing import Any, Dict
from vllm.lora.request import LoRARequest

In [2]:
# # Load model directly
# from transformers import AutoTokenizer, AutoModelForCausalLM

# tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2", cache_dir="/gpfs/home/yiyayu/scratch/cache")
# model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2", cache_dir="/gpfs/home/yiyayu/scratch/cache")

In [3]:
import os

# Model to load, as a local snapshot directory or a Hugging Face Hub id.
# Defaults to the local snapshot of mistralai/Mistral-7B-Instruct-v0.2; set the
# MODEL_NAME_OR_PATH environment variable to use another location, e.g.
#   'mistralai/Mistral-7B-Instruct-v0.2' (Hub id),
#   '/home/lawrence/.cache/huggingface/hub/models--mistralai--Mistral-7B-Instruct-v0.2/snapshots/41b61a33a2483885c981aa79e0df6b32407ed873',
#   'google/gemma-7b-it' (another model listed as an alternative).
model_name = os.environ.get(
    "MODEL_NAME_OR_PATH",
    '/gpfs/home/yiyayu/scratch/cache/models--mistralai--Mistral-7B-Instruct-v0.2/snapshots/41b61a33a2483885c981aa79e0df6b32407ed873',
)

# Tokenizer used only to render chat-formatted prompts; the same model_name is
# what the vLLM engine loads later in the notebook.
tokenizer = AutoTokenizer.from_pretrained(model_name)

## Import the dataset

## Build the reverse-prompt inputs and load the dataset into Ray

In [4]:
def get_chat_input(row: Dict[str, Any], chat_tokenizer: Any = None) -> Dict[str, Any]:
    """Add a chat-formatted reverse-prompt-engineering prompt to ``row['input']``.

    The prompt asks the model to recover the instruction that turned
    ``row['original_text']`` into ``row['rewritten_text']``. It ends with the
    partial answer "The prompt was:" so the model's generation continues it.

    Args:
        row: Record with at least ``'original_text'`` and ``'rewritten_text'``.
            It is modified in place.
        chat_tokenizer: Object exposing ``apply_chat_template``. Defaults to the
            notebook-level ``tokenizer``, looked up at call time.

    Returns:
        The same ``row`` with ``row['input']`` set to the prompt string.
    """
    if chat_tokenizer is None:
        chat_tokenizer = tokenizer

    # original text prefix
    orig_prefix = "Original Text:"
    # modified text prefix
    rewrite_prefix = "Rewritten Text:"
    # start of the answer the model should continue
    response_start = "The prompt was:"

    sys_prompt = """You are an expert in "Reverse Prompt Engineering". You are able to reverse-engineer prompts used to rewrite text.\n\nI will be providing you with an "original text" and "rewritten text". Please try to be as specific as possible and come up with a prompt that is based on the tone, style, and any other properties you consider relevant."""

    messages = [
        {"role": "user", "content": f"{sys_prompt}\n{orig_prefix} {row['original_text']}\n{rewrite_prefix} {row['rewritten_text']}"},
    ]

    # Render only the user turn, then append the answer prefix ourselves.
    # Passing the prefix as an assistant message makes templates such as
    # Mistral-Instruct's close that turn with the EOS token. The model would then
    # generate after "</s>" instead of continuing "The prompt was:".
    prompt = chat_tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    # NOTE(review): the rendered prompt already starts with the BOS token text;
    # vLLM may add a second BOS when it tokenizes the string — confirm.
    row['input'] = prompt + response_start

    return row


In [None]:
import pandas as pd

# Source rows (original_text / rewritten_text / previous prediction). The earlier
# model's prediction is renamed so it is not confused with the prompts generated
# below.
df = pd.read_csv('data/predictions/llm_dataset_1.csv').rename(  # 2_above_65
    columns={'rewrite_prompt': 'old_rewrite_prompt'}
)

# Format every row into a chat prompt, then force the data into many blocks so
# map_batches can spread work over all GPU actors; see
# <https://discuss.ray.io/t/single-node-4x-gpu-map-batches-only-using-1/12313/2>
ds = (
    ray.data.from_pandas(df)
    .map(get_chat_input)
    .repartition(250)
)


In [6]:
# Preview of the source rows loaded from CSV (pandas). The Ray pipeline built
# above is lazy and does not modify this frame.
df

Unnamed: 0.2,Unnamed: 0.1,Unnamed: 0,original_text,gt_rewrite_prompt,rewritten_text,input,old_rewrite_prompt,score
0,0,202,Refer again to water’s phase diagram ( Figure ...,Rewrite my text with emphasis on the emotions ...,"Sure, here's the rewritten text with emphasis ...","<s>[INST] You are an expert in ""Reverse Prompt...",Rewrite the scientific text into a poetic and...,0.478315
1,1,709,Water pollution comes from many sources. One o...,Present this topic as a series of experimental...,## Memory's Canvas: A Tapestry of Water Pollut...,"<s>[INST] You are an expert in ""Reverse Prompt...",hoodwink the audience and present the given t...,0.604024
2,2,117,Amphibians have a relatively complex circulato...,Explain this to me like I'm five.,Sure. Here's an explanation that a five-year-o...,"<s>[INST] You are an expert in ""Reverse Prompt...",Simplify the following text for young childre...,0.488786
3,3,325,The mode is the most common value. It is the v...,Transform this paragraph into a series of emoj...,"Sure, here's the emojis to capture the essence...","<s>[INST] You are an expert in ""Reverse Prompt...",Creatively represent the information using em...,0.696663
4,4,81,Electromagnetism is magnetism produced by an e...,Rewrite the following sentence in simpler terms,"Sure, here's the rewritten sentence in simpler...","<s>[INST] You are an expert in ""Reverse Prompt...",Simplify and clarify the given complex scient...,0.498370
...,...,...,...,...,...,...,...,...
9978,9978,1,"All amphibians have digestive, excretory, and ...",Translate this technical sentence into plain E...,"Sure, here's the translated sentence into plai...","<s>[INST] You are an expert in ""Reverse Prompt...",Simplify and clarify this scientific text for...,0.484933
9979,9979,262,Facilitated diffusion is the diffusion of solu...,Transform this information into the outline fo...,"**Outline of a New Social Movement: ""Facilitat...","<s>[INST] You are an expert in ""Reverse Prompt...",Transform the text into a proposal for a new ...,0.645402
9980,9980,95,"Mitosis occurs in four phases, called prophase...",Rephrase this sentence to be more reader-friendly,"Sure, here's the rephrased sentence to be more...","<s>[INST] You are an expert in ""Reverse Prompt...",Make this sentence more conversational and en...,0.652722
9981,9981,452,This odd-looking creature is a fish called a p...,Express this concept as if it were a conversat...,"Sure, here's the conversation:\n\n**Alien A:**...","<s>[INST] You are an expert in ""Reverse Prompt...",Transform this text into a dialogue between al...,0.624627


In [7]:
# Sampling settings shared by every predictor actor: for each prompt, draw n=2
# completions of at most 30 new tokens.
sampling_params = SamplingParams(max_tokens=30, n=2)

# Optional LoRA adapter directory (e.g. 'mistral_pr_lora_over65').
# None runs the base model only. That matches the previous version of this
# cell, which set enable_lora=True but never passed a LoRARequest, so no
# adapter was ever applied.
LORA_PATH = None


# Create a class to do batch inference.
class LLMPredictor:
    """Ray Data callable that runs vLLM generation over batches of prompts.

    Ray creates one instance per actor (see ``concurrency`` below), so each
    actor owns one vLLM engine.
    """

    def __init__(self, lora_path=None):
        """Load the vLLM engine for ``model_name``.

        Args:
            lora_path: Optional LoRA adapter directory. When given, LoRA support
                is enabled and the adapter is applied to every request. When
                None, LoRA support is left off.
        """
        self.lora_path = lora_path
        self.llm = LLM(
            model=model_name,
            gpu_memory_utilization=0.85,
            # NOTE(review): prompts longer than max_model_len are presumably
            # skipped by vLLM, possibly with fewer than n completions. Check for
            # rows whose rewrite_prompts has fewer than two entries.
            max_model_len=1200,
            trust_remote_code=True,
            enable_lora=lora_path is not None,
        )
        self.lora_request = (
            LoRARequest("pr_adapter", 1, lora_path) if lora_path is not None else None
        )

    def __call__(self, batch: Dict[str, np.ndarray]) -> Dict[str, list]:
        """Generate completions for ``batch['input']``.

        Adds ``batch['rewrite_prompts']``: one entry per input prompt, each a
        list of the generated texts for that prompt.
        """
        outputs = self.llm.generate(
            list(batch["input"]),
            sampling_params,
            lora_request=self.lora_request,
        )
        batch["rewrite_prompts"] = [
            [completion.text for completion in output.outputs] for output in outputs
        ]
        return batch


ds = ds.map_batches(
    LLMPredictor,
    # Constructor arguments passed to each LLMPredictor actor.
    fn_constructor_kwargs={"lora_path": LORA_PATH},
    # Set the concurrency to the number of LLM instances.
    concurrency=2,
    # Specify the number of GPUs required per LLM instance.
    # NOTE: Do NOT set `num_gpus` when using vLLM with tensor-parallelism
    # (i.e., `tensor_parallel_size`).
    num_gpus=1,
    # Specify the batch size for inference.
    batch_size=4,
)

In [28]:
# Writing triggers execution of the lazy Ray pipeline:
# map(get_chat_input) -> repartition -> LLMPredictor batches -> parquet files.
# NOTE(review): this writes to multi_n_with_lora2/, but the scoring section reads
# multi_n_with_lora3/. Align the paths, or confirm lora3 is intentionally a
# different run, before Restart & Run All.
# NOTE(review): presumably Ray adds new files to an existing directory rather
# than clearing it, so re-running may duplicate rows for a `*.parquet` glob —
# confirm.
ds.write_parquet("./data/exp_test/multi_n_with_lora2/")

2024-04-14 17:44:06,320	INFO streaming_executor.py:115 -- Starting execution of Dataset. Full log is in /tmp/ray/session_2024-04-14_17-25-16_146156_1192853/logs/ray-data.log
2024-04-14 17:44:06,320	INFO streaming_executor.py:116 -- Execution plan of Dataset: InputDataBuffer[Input] -> TaskPoolMapOperator[Map(get_chat_input)] -> AllToAllOperator[Repartition] -> ActorPoolMapOperator[MapBatches(LLMPredictor)] -> TaskPoolMapOperator[Write]



- Map(get_chat_input) 1:   0%|          | 0/1 [00:00<?, ?it/s]

- Repartition 2:   0%|          | 0/250 [00:00<?, ?it/s]

Split Repartition 3:   0%|          | 0/250 [00:00<?, ?it/s]

- MapBatches(LLMPredictor) 4:   0%|          | 0/250 [00:00<?, ?it/s]

- Write 5:   0%|          | 0/250 [00:00<?, ?it/s]

Running 0:   0%|          | 0/250 [00:00<?, ?it/s]

## Score the generated prompts against the ground-truth prompts

In [33]:
# Load the generated prompts as a polars frame; `rewrite_prompts` is a list[str]
# column holding the sampled completions for each row.
# NOTE(review): reads multi_n_with_lora3/, not the multi_n_with_lora2/ directory
# written by the generation cell — confirm this is the intended run.
df = pl.read_parquet("./data/exp_test/multi_n_with_lora3/*.parquet")

In [34]:
# Preview of the generated completions (`rewrite_prompts`: n=2 samples per row).
df

Unnamed: 0.1,Unnamed: 0,original_text,gt_rewrite_prompt,rewritten_text,input,old_rewrite_prompt,score,rewrite_prompts
i64,i64,str,str,str,str,str,f64,list[str]
39,394,"""Figure 9.3 Thi…","""Turn this para…","""The text does …","""<s>[INST] You …",""" I cannot fulf…",0.5094235,"["" ""Given a text describing a physical phenomenon, create guidelines for training dragons in a fantasy realm based on the information provided."" Based on"", "" ""Given a text describing physics concepts related to a moving vehicle, generate guidelines for training dragons in a fantasy realm based on the information provided.""""]"
40,453,"""Science rests …","""Create a dialo…","""Sure, here's t…","""<s>[INST] You …",""" Write a scrip…",0.588464,"["" ""Write a dialog between two characters discussing the limits of science in terms of observable phenomena and the existence of supernatural beings or events."" "", "" ""Write a dialogue between two characters discussing the definition and scope of science in relation to observable phenomena and supernatural beings or events.""""]"
41,53,"""Introduction T…","""Rewrite the fo…","""Sure, here's t…","""<s>[INST] You …",""" Simplify the …",0.5603066,"["" ""Rewrite the following text into simpler language while maintaining the original tone and focus on the role of the immune system in protecting the body from path"", "" ""_Describe the function and components of the immune system in simpler terms, focusing on its role in protecting the body from germs and its response when germ""]"
42,388,"""Figure 21.13 E…","""Translate this…","""Sure, here's t…","""<s>[INST] You …",""" Rewrite the t…",0.462264,"["" ""Translate the following academic text into a conversational style while maintaining the tone and preserving the essential information."" For Figure 21"", "" ""Translate the following scientific passage into a conversational and simplified style while maintaining the original tone and intent."" I. Passage:""]"
43,695,"""Not only salt,…","""Turn this into…","""## The Salty K…","""<s>[INST] You …",""" Provide a dia…",0.6735528,"["" ""Develop a script for a scene in a kitchen setting where a chef explains how water can dissolve various solutes, emphasizing it"", "" ""_Create a title and a short scene in a kitchen setting, where a chef explains that water can dissolve many solutes, making it a""]"
…,…,…,…,…,…,…,…,…
9978,1,"""All amphibians…","""Translate this…","""Sure, here's t…","""<s>[INST] You …",""" Simplify and …",0.484933,"["" ""Translate the given text about amphibians' organ systems into simpler, plain English while maintaining the original tone and preserving the key information"", "" ""Translate the given text into plain English, maintaining the original tone and using simpler language to explain the shared organ system and its functions in amph""]"
9979,262,"""Facilitated di…","""Transform this…","""**Outline of a…","""<s>[INST] You …",""" Transform the…",0.645402,"["" ""Rewrite the given text in the form of a presentation for a new social movement called 'Facilitated Diffusion Awareness'."""", "" distinctive tone of a call to action or advocacy, and a shift in context from a biological explanation to a social movement focused on ""Facilitated""]"
9980,95,"""Mitosis occurs…","""Rephrase this …","""Sure, here's t…","""<s>[INST] You …",""" Make this sen…",0.6527217,"["" ""Rewrite the given sentence to make it more reader-friendly by using includes instead of occurs in and keeping the number and names of the phases"", "" ""Rewrite the given sentence in a more reader-friendly way, maintaining the original meaning."" This prompt focuses on making the sentence more""]"
9981,452,"""This odd-looki…","""Express this c…","""Sure, here's t…","""<s>[INST] You …","""Transform this…",0.6246274,"[""""Write a conversation between two aliens discussing the defense mechanisms of the puffer fish on Earth, focusing on its ability to expand its belly and produce"", "" ""_Write a conversational text between two aliens discussing the unique defensive mechanisms of the Earth's puffer fish._"" This prompt focuses""]"


In [35]:
# Give each sampled completion its own column, then switch to pandas for the
# row-wise scoring that follows.
# NOTE(review): list.get(1) on a row with fewer than two completions may raise
# or yield null depending on the polars version.
df = df.with_columns(
    [
        pl.col("rewrite_prompts").list.get(position).alias(column_name)
        for column_name, position in (("rewrite_prompt_1", 0), ("rewrite_prompt_2", 1))
    ]
).to_pandas()

In [None]:
import numpy as np
from sentence_transformers import SentenceTransformer, util
from sklearn.metrics.pairwise import cosine_similarity  # NOTE(review): unused in this cell

st_model = SentenceTransformer('sentence-transformers/sentence-t5-base')


def get_sharpened_cosine_similarity(text1, text2):
    """Return cos(emb(text1), emb(text2)) ** 3 for two single strings.

    Cubing keeps the sign but pushes mid-range similarities towards 0, so only
    close matches score highly.
    """
    embeddings1 = st_model.encode(text1)
    embeddings2 = st_model.encode(text2)
    cosine_score = util.cos_sim(embeddings1, embeddings2)
    return (cosine_score[0] ** 3).numpy()[0]


def calc_prompt_similarity(row, rewrite_prompt_col):
    """Sharpened similarity between ``row['gt_rewrite_prompt']`` and ``row[rewrite_prompt_col]``."""
    return get_sharpened_cosine_similarity(row['gt_rewrite_prompt'], row[rewrite_prompt_col])


def _encode_texts(texts, batch_size=64):
    """Batch-encode ``texts`` into unit-length embeddings (non-strings such as None/NaN become '')."""
    cleaned = [text if isinstance(text, str) else '' for text in texts]
    return st_model.encode(
        cleaned,
        batch_size=batch_size,
        convert_to_numpy=True,
        normalize_embeddings=True,
    )


def sharpened_similarity_column(frame, rewrite_prompt_col, gt_embeddings=None, batch_size=64):
    """Vectorised ``calc_prompt_similarity`` for every row of ``frame``.

    Args:
        frame: pandas DataFrame with ``gt_rewrite_prompt`` and ``rewrite_prompt_col``.
        rewrite_prompt_col: Column holding the candidate prompts to score.
        gt_embeddings: Optional precomputed ``_encode_texts(frame['gt_rewrite_prompt'])``,
            so the ground truth is encoded only once across several columns.
        batch_size: Encoding batch size.

    Returns:
        numpy array of cos ** 3 scores, one per row.
    """
    if gt_embeddings is None:
        gt_embeddings = _encode_texts(frame['gt_rewrite_prompt'], batch_size)
    candidate_embeddings = _encode_texts(frame[rewrite_prompt_col], batch_size)
    # Embeddings are unit length, so the row-wise dot product is the cosine.
    cosine = np.einsum('ij,ij->i', gt_embeddings, candidate_embeddings)
    return cosine ** 3


# Encode the ground truth once and reuse it; per-row scoring re-encoded it for
# every candidate column.
gt_embeddings = _encode_texts(df['gt_rewrite_prompt'])
df['score_1'] = sharpened_similarity_column(df, 'rewrite_prompt_1', gt_embeddings)
df['score_2'] = sharpened_similarity_column(df, 'rewrite_prompt_2', gt_embeddings)
# `prompt_select` is not produced by any earlier cell; only score it when present
# instead of failing with a KeyError.
if 'prompt_select' in df.columns:
    df['prompt_select_score'] = sharpened_similarity_column(df, 'prompt_select', gt_embeddings)


In [46]:
# Preview after scoring (score_1 / score_2 added).
# NOTE(review): on a fresh run `df` is a pandas frame here (.to_pandas() above),
# but the saved output below shows a polars schema row. It was presumably
# produced from different kernel state; re-run to refresh.
df

Unnamed: 0_level_0,Unnamed: 0.1,Unnamed: 0,original_text,gt_rewrite_prompt,rewritten_text,input,old_rewrite_prompt,score,rewrite_prompts,rewrite_prompt_1,rewrite_prompt_2,score_1,score_2
i64,i64,i64,str,str,str,str,str,f64,str,str,str,f64,f64
0,39,394,"""Figure 9.3 Thi…","""Turn this para…","""The text does …","""<s>[INST] You …",""" I cannot fulf…",0.5094235,"""[' ""Given a te…",""" ""Given a text…",""" ""Given a text…",0.8248,0.7688948
1,40,453,"""Science rests …","""Create a dialo…","""Sure, here's t…","""<s>[INST] You …",""" Write a scrip…",0.588464,"""[' ""Write a di…",""" ""Write a dial…",""" ""Write a dial…",0.719603,0.7275696
2,41,53,"""Introduction T…","""Rewrite the fo…","""Sure, here's t…","""<s>[INST] You …",""" Simplify the …",0.5603066,"""[' ""Rewrite th…",""" ""Rewrite the …",""" ""_Describe th…",0.7252815,0.582107
3,42,388,"""Figure 21.13 E…","""Translate this…","""Sure, here's t…","""<s>[INST] You …",""" Rewrite the t…",0.462264,"""[' ""Translate …",""" ""Translate th…",""" ""Translate th…",0.488263,0.570618
4,43,695,"""Not only salt,…","""Turn this into…","""## The Salty K…","""<s>[INST] You …",""" Provide a dia…",0.6735528,"""[' ""Develop a …",""" ""Develop a sc…",""" ""_Create a ti…",0.7012007,0.581864
…,…,…,…,…,…,…,…,…,…,…,…,…,…
9978,9978,1,"""All amphibians…","""Translate this…","""Sure, here's t…","""<s>[INST] You …",""" Simplify and …",0.484933,"""[' ""Translate …",""" ""Translate th…",""" ""Translate th…",0.579571,0.648444
9979,9979,262,"""Facilitated di…","""Transform this…","""**Outline of a…","""<s>[INST] You …",""" Transform the…",0.645402,"""[' ""Rewrite th…",""" ""Rewrite the …",""" distinctive t…",0.7594546,0.602455
9980,9980,95,"""Mitosis occurs…","""Rephrase this …","""Sure, here's t…","""<s>[INST] You …",""" Make this sen…",0.6527217,"""[' ""Rewrite th…",""" ""Rewrite the …",""" ""Rewrite the …",0.7855944,0.8298072
9981,9981,452,"""This odd-looki…","""Express this c…","""Sure, here's t…","""<s>[INST] You …","""Transform this…",0.6246274,"""['""Write a con…","""""Write a conve…",""" ""_Write a con…",0.666589,0.6311319


In [47]:
# `df` is a pandas DataFrame at this point (converted with .to_pandas() before
# scoring). pandas has no `write_csv`, so use `to_csv`. index=False avoids an
# extra unnamed index column, matching what polars' write_csv would produce.
# List-valued cells such as rewrite_prompts are written as their string form.
df.to_csv('./data/exp_test/multi_n_with_lora3/multi_n_scored.csv', index=False)

## Prepare chosen/rejected prompt pairs for DPO training

In [48]:
# Reload the scored results as a polars frame for building DPO pairs.
# Columns that held lists (e.g. rewrite_prompts) come back from CSV as plain
# strings, as the schema in the output below shows.
df = pl.read_csv('./data/exp_test/multi_n_with_lora3/multi_n_scored.csv')

In [None]:
# Build DPO preference pairs. Each row has three candidate prompts, each scored
# against the ground truth:
#   old_rewrite_prompt -> score, rewrite_prompt_1 -> score_1, rewrite_prompt_2 -> score_2
# The best-scoring candidate becomes "chosen" and the worst "rejected".
#
# The comparisons use >= / <= with a fixed priority (old, then 1, then 2), so
# ties still resolve to a true maximum / minimum. With strict > / <, a tie such
# as score == score_1 > score_2 fell through to rewrite_prompt_2, the worst one.
pairs_source = df
if 'prompt_select_score' in pairs_source.columns:
    # Drop rows whose selected prompt is a near-exact match (>= 0.99 sharpened
    # similarity) of the ground-truth prompt.
    pairs_source = pairs_source.filter(pl.col('prompt_select_score') < 0.99)

chosen_expr = (
    pl.when((pl.col("score") >= pl.col("score_1")) & (pl.col("score") >= pl.col("score_2")))
    .then(pl.col("old_rewrite_prompt"))
    # Reached only when `score` is not the maximum, so the larger of the two
    # new scores is the maximum.
    .when(pl.col("score_1") >= pl.col("score_2"))
    .then(pl.col("rewrite_prompt_1"))
    .otherwise(pl.col("rewrite_prompt_2"))
    .alias("chosen")
)

rejected_expr = (
    pl.when((pl.col("score") <= pl.col("score_1")) & (pl.col("score") <= pl.col("score_2")))
    .then(pl.col("old_rewrite_prompt"))
    # Reached only when `score` is not the minimum, so the smaller of the two
    # new scores is the minimum.
    .when(pl.col("score_1") <= pl.col("score_2"))
    .then(pl.col("rewrite_prompt_1"))
    .otherwise(pl.col("rewrite_prompt_2"))
    .alias("rejected")
)

dpo_pairs = (
    pairs_source
    .with_columns(chosen_expr, rejected_expr)
    # A pair only carries signal when the two sides differ. This drops rows where
    # all three scores tie (both sides pick the same candidate), rows where the
    # two texts are identical, and rows where either side ended up null.
    .filter(
        pl.col("chosen").is_not_null()
        & pl.col("rejected").is_not_null()
        & (pl.col("chosen") != pl.col("rejected"))
    )
)
dpo_pairs.write_csv('./data/exp_test/multi_n_with_lora3/multi_n_selected.csv')

In [49]:
# NOTE(review): this displays the scored frame loaded from multi_n_scored.csv,
# not the DPO pairs written to multi_n_selected.csv by the previous cell.
df

Unnamed: 0_level_0,Unnamed: 0.1,Unnamed: 0,original_text,gt_rewrite_prompt,rewritten_text,input,old_rewrite_prompt,score,rewrite_prompts,rewrite_prompt_1,rewrite_prompt_2,score_1,score_2
i64,i64,i64,str,str,str,str,str,f64,str,str,str,f64,f64
0,39,394,"""Figure 9.3 Thi…","""Turn this para…","""The text does …","""<s>[INST] You …",""" I cannot fulf…",0.5094235,"""[' ""Given a te…",""" ""Given a text…",""" ""Given a text…",0.8248,0.7688948
1,40,453,"""Science rests …","""Create a dialo…","""Sure, here's t…","""<s>[INST] You …",""" Write a scrip…",0.588464,"""[' ""Write a di…",""" ""Write a dial…",""" ""Write a dial…",0.719603,0.7275696
2,41,53,"""Introduction T…","""Rewrite the fo…","""Sure, here's t…","""<s>[INST] You …",""" Simplify the …",0.5603066,"""[' ""Rewrite th…",""" ""Rewrite the …",""" ""_Describe th…",0.7252815,0.582107
3,42,388,"""Figure 21.13 E…","""Translate this…","""Sure, here's t…","""<s>[INST] You …",""" Rewrite the t…",0.462264,"""[' ""Translate …",""" ""Translate th…",""" ""Translate th…",0.488263,0.570618
4,43,695,"""Not only salt,…","""Turn this into…","""## The Salty K…","""<s>[INST] You …",""" Provide a dia…",0.6735528,"""[' ""Develop a …",""" ""Develop a sc…",""" ""_Create a ti…",0.7012007,0.581864
…,…,…,…,…,…,…,…,…,…,…,…,…,…
9978,9978,1,"""All amphibians…","""Translate this…","""Sure, here's t…","""<s>[INST] You …",""" Simplify and …",0.484933,"""[' ""Translate …",""" ""Translate th…",""" ""Translate th…",0.579571,0.648444
9979,9979,262,"""Facilitated di…","""Transform this…","""**Outline of a…","""<s>[INST] You …",""" Transform the…",0.645402,"""[' ""Rewrite th…",""" ""Rewrite the …",""" distinctive t…",0.7594546,0.602455
9980,9980,95,"""Mitosis occurs…","""Rephrase this …","""Sure, here's t…","""<s>[INST] You …",""" Make this sen…",0.6527217,"""[' ""Rewrite th…",""" ""Rewrite the …",""" ""Rewrite the …",0.7855944,0.8298072
9981,9981,452,"""This odd-looki…","""Express this c…","""Sure, here's t…","""<s>[INST] You …","""Transform this…",0.6246274,"""['""Write a con…","""""Write a conve…",""" ""_Write a con…",0.666589,0.6311319
