In [1]:
import random
from jinja2 import Environment, Template, meta
import sys
from itertools import product
from datasets import load_dataset
import polars as pl
from vllm import LLM, SamplingParams
from typing import Dict
import numpy as np
import ray
from transformers import AutoTokenizer
from typing import Any, Dict
from vllm.lora.request import LoRARequest

In [2]:
# model_name = 'mistralai/Mistral-7B-Instruct-v0.2'
model_name = '/home/lawrence/.cache/huggingface/hub/models--mistralai--Mistral-7B-Instruct-v0.2/snapshots/41b61a33a2483885c981aa79e0df6b32407ed873'
# model_name = 'google/gemma-7b-it'

tokenizer = AutoTokenizer.from_pretrained(
    model_name
) 

## import the dataset

## load the dataset into ray

In [3]:
def get_chat_input(row: Dict[str, Any]) -> Dict[str, Any]:
    
    #original text prefix
    orig_prefix = "Original Text:"
    #modified text prefix
    rewrite_prefix = "Rewritten Text:"
    # response start
    response_start = "The prompt was:"

    sys_prompt = """You are an expert in "Reverse Prompt Engineering". You are able to reverse-engineer prompts used to rewrite text.\n\nI will be providing you with an "original text" and "rewritten text". Please try to be as specific as possible and come up with a prompt that is based on the tone, style, and any other properties you consider relevant."""

    messages = [
        #actual prompt
        {"role": "user", "content": f"{sys_prompt}\n{orig_prefix} {row['original_text']}\n{rewrite_prefix} {row['rewritten_text']}"},
        {"role": "assistant", "content": response_start},
    ]
        
    #give it to Mistral
    row['input'] = tokenizer.apply_chat_template(messages, tokenize=False)
    
    return row


In [4]:
import pandas as pd

df = pd.read_csv('data/predictions/2_above_65.csv')
df.rename(columns={'rewrite_prompt': 'old_rewrite_prompt'}, inplace=True)
ds = (
    ray.data.from_pandas(df)
    .map(get_chat_input)
).repartition(250) # need to force into multiple blocks <https://discuss.ray.io/t/single-node-4x-gpu-map-batches-only-using-1/12313/2>


Usage stats collection is enabled by default for nightly wheels. To disable this, run the following command: `ray disable-usage-stats` before starting Ray. See https://docs.ray.io/en/master/cluster/usage-stats.html for more details.


2024-04-12 20:53:38,429	INFO worker.py:1752 -- Started a local Ray instance.


In [None]:
df

In [5]:
# Create a sampling params object.
sampling_params = SamplingParams(max_tokens=30, n=2)


# Create a class to do batch inference.
class LLMPredictor:

    def __init__(self):
        self.llm = LLM(
            model=model_name,
            gpu_memory_utilization=0.85,
            max_model_len=1200,
            trust_remote_code=True,
            enable_lora=True 
        )
        self.lora_path = '/home/lawrence/Projects/my_models/mistral_pr_lora_over65'

    def __call__(self, batch: Dict[str, np.ndarray]) -> Dict[str, list]:
        outputs = self.llm.generate(
            batch["input"], 
            sampling_params,
            lora_request=LoRARequest("pr_adapter", 1, self.lora_path)
        )
        
        generated_text = []
        for output in outputs:
            generated_text.append([o.text for o in output.outputs])
        batch["rewrite_prompts"] = generated_text

        return batch

ds = ds.map_batches(
    LLMPredictor,
    # Set the concurrency to the number of LLM instances.
    concurrency=2,
    # Specify the number of GPUs required per LLM instance.
    # NOTE: Do NOT set `num_gpus` when using vLLM with tensor-parallelism
    # (i.e., `tensor_parallel_size`).
    num_gpus=1,
    # Specify the batch size for inference.
    batch_size=2,
)

In [6]:
ds.write_parquet("./data/exp_test/multi_n_with_lora/")

2024-04-12 20:53:56,858	INFO streaming_executor.py:115 -- Starting execution of Dataset. Full log is in /tmp/ray/session_2024-04-12_20-53-36_548769_115162/logs/ray-data.log
2024-04-12 20:53:56,859	INFO streaming_executor.py:116 -- Execution plan of Dataset: InputDataBuffer[Input] -> TaskPoolMapOperator[Map(get_chat_input)] -> AllToAllOperator[Repartition] -> ActorPoolMapOperator[MapBatches(LLMPredictor)] -> TaskPoolMapOperator[Write]



[36m(_MapWorker pid=116458)[0m INFO 04-12 20:54:00 llm_engine.py:87] Initializing an LLM engine with config: model='/home/lawrence/.cache/huggingface/hub/models--mistralai--Mistral-7B-Instruct-v0.2/snapshots/41b61a33a2483885c981aa79e0df6b32407ed873', tokenizer='/home/lawrence/.cache/huggingface/hub/models--mistralai--Mistral-7B-Instruct-v0.2/snapshots/41b61a33a2483885c981aa79e0df6b32407ed873', tokenizer_mode=auto, revision=None, tokenizer_revision=None, trust_remote_code=True, dtype=torch.bfloat16, max_seq_len=1200, download_dir=None, load_format=auto, tensor_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=False, kv_cache_dtype=auto, device_config=cuda, seed=0)
[36m(_MapWorker pid=116458)[0m INFO 04-12 20:54:09 llm_engine.py:357] # GPU blocks: 2799, # CPU blocks: 2048
[36m(_MapWorker pid=116457)[0m INFO 04-12 20:54:00 llm_engine.py:87] Initializing an LLM engine with config: model='/home/lawrence/.cache/huggingface/hub/models--mistralai--Mistral

- Map(get_chat_input) 1:   0%|          | 0/1 [00:00<?, ?it/s]

- Repartition 2:   0%|          | 0/250 [00:00<?, ?it/s]

Split Repartition 3:   0%|          | 0/250 [00:00<?, ?it/s]

- MapBatches(LLMPredictor) 4:   0%|          | 0/250 [00:00<?, ?it/s]

- Write 5:   0%|          | 0/250 [00:00<?, ?it/s]

Running 0:   0%|          | 0/250 [00:00<?, ?it/s]

Processed prompts:   0%|          | 0/2 [00:00<?, ?it/s]m 
Processed prompts:  50%|█████     | 1/2 [00:00<00:00,  1.22it/s]
Processed prompts: 100%|██████████| 2/2 [00:01<00:00,  1.96it/s]
[36m(MapWorker(MapBatches(LLMPredictor)) pid=116457)[0m Could not construct Arrow block from numpy array; encountered values of unsupported numpy type `17` in column named 'rewrite_prompts', which cannot be casted to an Arrow data type. Falling back to using pandas block type, which is slower and consumes more memory. For maximum performance, consider applying the following suggestions before ingesting into Ray Data in order to use native Arrow block types:
[36m(MapWorker(MapBatches(LLMPredictor)) pid=116457)[0m - Expand out each key-value pair in the dict column into its own column
[36m(MapWorker(MapBatches(LLMPredictor)) pid=116457)[0m - Replace `None` values with an Arrow supported data type
[36m(MapWorker(MapBatches(LLMPredictor)) pid=116457)[0m 
Processed prompts: 100%|██████████| 2/2 [0



Processed prompts: 100%|██████████| 2/2 [00:00<00:00,  3.04it/s][32m [repeated 3x across cluster][0m
Processed prompts: 100%|██████████| 2/2 [00:00<00:00,  2.04it/s]
Processed prompts:   0%|          | 0/2 [00:00<?, ?it/s][32m [repeated 12x across cluster][0m
Processed prompts:  50%|█████     | 1/2 [00:00<00:00,  1.44it/s][32m [repeated 8x across cluster][0m
Processed prompts: 100%|██████████| 2/2 [00:00<00:00,  2.06it/s][32m [repeated 8x across cluster][0m
Processed prompts: 100%|██████████| 2/2 [00:01<00:00,  2.00it/s][32m [repeated 2x across cluster][0m
Processed prompts:   0%|          | 0/2 [00:00<?, ?it/s][32m [repeated 10x across cluster][0m
Processed prompts:  50%|█████     | 1/2 [00:00<00:00,  1.44it/s][32m [repeated 9x across cluster][0m
Processed prompts: 100%|██████████| 2/2 [00:00<00:00,  2.28it/s][32m [repeated 9x across cluster][0m
Processed prompts: 100%|██████████| 2/2 [00:01<00:00,  1.93it/s]
Processed prompts: 100%|██████████| 2/2 [00:00<00:00,  3.35i



Processed prompts: 100%|██████████| 2/2 [00:01<00:00,  1.95it/s][32m [repeated 2x across cluster][0m
Processed prompts: 100%|██████████| 2/2 [00:00<00:00,  2.35it/s][32m [repeated 2x across cluster][0m
Processed prompts:   0%|          | 0/2 [00:00<?, ?it/s][32m [repeated 12x across cluster][0m
Processed prompts:  50%|█████     | 1/2 [00:00<00:00,  1.72it/s][32m [repeated 8x across cluster][0m
Processed prompts: 100%|██████████| 2/2 [00:00<00:00,  2.40it/s][32m [repeated 9x across cluster][0m
Processed prompts: 100%|██████████| 2/2 [00:00<00:00,  2.69it/s]
Processed prompts: 100%|██████████| 2/2 [00:01<00:00,  1.82it/s][32m [repeated 2x across cluster][0m
Processed prompts:   0%|          | 0/2 [00:00<?, ?it/s][32m [repeated 10x across cluster][0m
Processed prompts:  50%|█████     | 1/2 [00:00<00:00,  1.23it/s][32m [repeated 11x across cluster][0m
Processed prompts: 100%|██████████| 2/2 [00:00<00:00,  2.20it/s][32m [repeated 6x across cluster][0m
Processed prompts: 10



Processed prompts: 100%|██████████| 2/2 [00:00<00:00,  2.16it/s]
Processed prompts:   0%|          | 0/2 [00:00<?, ?it/s][32m [repeated 12x across cluster][0m
Processed prompts:  50%|█████     | 1/2 [00:00<00:00,  1.22it/s][32m [repeated 10x across cluster][0m
Processed prompts: 100%|██████████| 2/2 [00:01<00:00,  1.95it/s][32m [repeated 11x across cluster][0m
Processed prompts: 100%|██████████| 2/2 [00:01<00:00,  1.98it/s]
Processed prompts: 100%|██████████| 2/2 [00:00<00:00,  3.21it/s]
Processed prompts: 100%|██████████| 2/2 [00:00<00:00,  2.00it/s]
Processed prompts: 100%|██████████| 2/2 [00:01<00:00,  1.99it/s]
Processed prompts:   0%|          | 0/2 [00:00<?, ?it/s][32m [repeated 12x across cluster][0m
Processed prompts:  50%|█████     | 1/2 [00:00<00:00,  1.22it/s][32m [repeated 10x across cluster][0m
Processed prompts: 100%|██████████| 2/2 [00:00<00:00,  2.30it/s][32m [repeated 6x across cluster][0m
Processed prompts: 100%|██████████| 2/2 [00:00<00:00,  2.10it/s]
Pro



Processed prompts: 100%|██████████| 2/2 [00:00<00:00,  2.64it/s]
Processed prompts:   0%|          | 0/2 [00:00<?, ?it/s][32m [repeated 12x across cluster][0m
Processed prompts:  50%|█████     | 1/2 [00:00<00:00,  1.54it/s][32m [repeated 10x across cluster][0m
Processed prompts: 100%|██████████| 2/2 [00:00<00:00,  2.04it/s][32m [repeated 8x across cluster][0m
Processed prompts: 100%|██████████| 2/2 [00:00<00:00,  3.92it/s][32m [repeated 2x across cluster][0m
Processed prompts: 100%|██████████| 2/2 [00:00<00:00,  3.48it/s]
Processed prompts:   0%|          | 0/2 [00:00<?, ?it/s][32m [repeated 12x across cluster][0m
Processed prompts:  50%|█████     | 1/2 [00:00<00:00,  1.46it/s][32m [repeated 10x across cluster][0m
Processed prompts: 100%|██████████| 2/2 [00:00<00:00,  2.20it/s][32m [repeated 9x across cluster][0m
Processed prompts: 100%|██████████| 2/2 [00:01<00:00,  1.88it/s][32m [repeated 3x across cluster][0m
Processed prompts: 100%|██████████| 2/2 [00:00<00:00,  2.9



Processed prompts: 100%|██████████| 2/2 [00:01<00:00,  1.94it/s]
Processed prompts: 100%|██████████| 2/2 [00:00<00:00,  2.75it/s]
Processed prompts:   0%|          | 0/2 [00:00<?, ?it/s][32m [repeated 12x across cluster][0m
Processed prompts:  50%|█████     | 1/2 [00:00<00:00,  1.28it/s][32m [repeated 10x across cluster][0m
Processed prompts: 100%|██████████| 2/2 [00:01<00:00,  1.82it/s][32m [repeated 11x across cluster][0m
Processed prompts: 100%|██████████| 2/2 [00:00<00:00,  2.97it/s]
Processed prompts: 100%|██████████| 2/2 [00:00<00:00,  3.49it/s]
Processed prompts:   0%|          | 0/2 [00:00<?, ?it/s][32m [repeated 12x across cluster][0m
Processed prompts:  50%|█████     | 1/2 [00:00<00:00,  3.01it/s][32m [repeated 11x across cluster][0m
Processed prompts: 100%|██████████| 2/2 [00:00<00:00,  2.64it/s][32m [repeated 8x across cluster][0m
Processed prompts: 100%|██████████| 2/2 [00:00<00:00,  2.45it/s][32m [repeated 2x across cluster][0m
Processed prompts: 100%|██████



Processed prompts:   0%|          | 0/2 [00:00<?, ?it/s][32m [repeated 14x across cluster][0m
Processed prompts:  50%|█████     | 1/2 [00:00<00:00,  2.71it/s][32m [repeated 9x across cluster][0m
Processed prompts: 100%|██████████| 2/2 [00:00<00:00,  3.64it/s][32m [repeated 9x across cluster][0m
Processed prompts: 100%|██████████| 2/2 [00:01<00:00,  1.76it/s]
Processed prompts: 100%|██████████| 2/2 [00:00<00:00,  2.28it/s]
Processed prompts: 100%|██████████| 2/2 [00:00<00:00,  3.30it/s]
Processed prompts:   0%|          | 0/2 [00:00<?, ?it/s][32m [repeated 12x across cluster][0m
Processed prompts:  50%|█████     | 1/2 [00:00<00:00,  1.23it/s][32m [repeated 10x across cluster][0m
Processed prompts: 100%|██████████| 2/2 [00:01<00:00,  1.98it/s][32m [repeated 6x across cluster][0m




Processed prompts: 100%|██████████| 2/2 [00:00<00:00,  2.27it/s][32m [repeated 4x across cluster][0m
Processed prompts: 100%|██████████| 2/2 [00:00<00:00,  2.18it/s]
Processed prompts: 100%|██████████| 2/2 [00:00<00:00,  3.57it/s]
Processed prompts:   0%|          | 0/2 [00:00<?, ?it/s][32m [repeated 13x across cluster][0m
Processed prompts:  50%|█████     | 1/2 [00:00<00:00,  1.28it/s][32m [repeated 9x across cluster][0m
Processed prompts: 100%|██████████| 2/2 [00:00<00:00,  2.08it/s][32m [repeated 9x across cluster][0m
Processed prompts: 100%|██████████| 2/2 [00:00<00:00,  2.41it/s]
Processed prompts: 100%|██████████| 2/2 [00:00<00:00,  3.11it/s]
Processed prompts: 100%|██████████| 2/2 [00:01<00:00,  1.89it/s][32m [repeated 4x across cluster][0m
Processed prompts:   0%|          | 0/2 [00:00<?, ?it/s][32m [repeated 12x across cluster][0m
Processed prompts:  50%|█████     | 1/2 [00:00<00:00,  1.48it/s][32m [repeated 9x across cluster][0m
Processed prompts: 100%|█████████



Processed prompts:   0%|          | 0/2 [00:00<?, ?it/s][32m [repeated 13x across cluster][0m
Processed prompts:  50%|█████     | 1/2 [00:00<00:00,  6.82it/s][32m [repeated 10x across cluster][0m
Processed prompts: 100%|██████████| 2/2 [00:00<00:00,  2.50it/s][32m [repeated 9x across cluster][0m
Processed prompts: 100%|██████████| 2/2 [00:00<00:00,  2.30it/s]
Processed prompts: 100%|██████████| 2/2 [00:01<00:00,  1.69it/s]
Processed prompts: 100%|██████████| 2/2 [00:00<00:00,  2.17it/s][32m [repeated 2x across cluster][0m
Processed prompts:   0%|          | 0/2 [00:00<?, ?it/s][32m [repeated 11x across cluster][0m
Processed prompts:  50%|█████     | 1/2 [00:00<00:00,  2.55it/s][32m [repeated 10x across cluster][0m
Processed prompts: 100%|██████████| 2/2 [00:01<00:00,  1.90it/s][32m [repeated 7x across cluster][0m
Processed prompts: 100%|██████████| 2/2 [00:00<00:00,  2.60it/s][32m [repeated 2x across cluster][0m
Processed prompts: 100%|██████████| 2/2 [00:00<00:00,  2.1



Processed prompts: 100%|██████████| 2/2 [00:00<00:00,  3.08it/s]
Processed prompts:   0%|          | 0/2 [00:00<?, ?it/s][32m [repeated 13x across cluster][0m
Processed prompts:  50%|█████     | 1/2 [00:00<00:00,  1.50it/s][32m [repeated 11x across cluster][0m
Processed prompts: 100%|██████████| 2/2 [00:00<00:00,  2.34it/s][32m [repeated 12x across cluster][0m
Processed prompts: 100%|██████████| 2/2 [00:00<00:00,  2.52it/s]
Processed prompts: 100%|██████████| 2/2 [00:00<00:00,  2.82it/s]
Processed prompts: 100%|██████████| 2/2 [00:00<00:00,  2.24it/s]
Processed prompts:   0%|          | 0/2 [00:00<?, ?it/s][32m [repeated 11x across cluster][0m
Processed prompts:  50%|█████     | 1/2 [00:00<00:00,  1.66it/s][32m [repeated 8x across cluster][0m
Processed prompts: 100%|██████████| 2/2 [00:00<00:00,  2.38it/s][32m [repeated 6x across cluster][0m




Processed prompts: 100%|██████████| 2/2 [00:01<00:00,  1.82it/s][32m [repeated 2x across cluster][0m
Processed prompts: 100%|██████████| 2/2 [00:00<00:00,  2.07it/s][32m [repeated 3x across cluster][0m
Processed prompts:   0%|          | 0/2 [00:00<?, ?it/s][32m [repeated 14x across cluster][0m
Processed prompts:  50%|█████     | 1/2 [00:00<00:00,  1.46it/s][32m [repeated 10x across cluster][0m
Processed prompts: 100%|██████████| 2/2 [00:00<00:00,  2.09it/s][32m [repeated 11x across cluster][0m
Processed prompts: 100%|██████████| 2/2 [00:00<00:00,  3.53it/s]
Processed prompts: 100%|██████████| 2/2 [00:00<00:00,  2.25it/s][32m [repeated 4x across cluster][0m
Processed prompts:   0%|          | 0/2 [00:00<?, ?it/s][32m [repeated 12x across cluster][0m
Processed prompts:  50%|█████     | 1/2 [00:00<00:00,  1.51it/s][32m [repeated 8x across cluster][0m
Processed prompts: 100%|██████████| 1/1 [00:00<00:00,  1.74it/s][32m [repeated 7x across cluster][0m
Processed prompts: 1



Processed prompts: 100%|██████████| 1/1 [00:00<00:00, 5614.86it/s]
Processed prompts:   0%|          | 0/2 [00:00<?, ?it/s][32m [repeated 12x across cluster][0m
Processed prompts:  50%|█████     | 1/2 [00:00<00:00,  2.33it/s][32m [repeated 10x across cluster][0m
Processed prompts: 100%|██████████| 2/2 [00:00<00:00,  2.32it/s][32m [repeated 9x across cluster][0m
Processed prompts: 100%|██████████| 2/2 [00:00<00:00,  2.64it/s][32m [repeated 2x across cluster][0m




Processed prompts:   0%|          | 0/2 [00:00<?, ?it/s][32m [repeated 10x across cluster][0m
Processed prompts:  50%|█████     | 1/2 [00:01<00:01,  1.10s/it][32m [repeated 9x across cluster][0m
Processed prompts: 100%|██████████| 2/2 [00:01<00:00,  1.98it/s][32m [repeated 9x across cluster][0m
Processed prompts: 100%|██████████| 2/2 [00:00<00:00,  2.27it/s][32m [repeated 2x across cluster][0m
Processed prompts: 100%|██████████| 2/2 [00:00<00:00,  2.70it/s]
Processed prompts: 100%|██████████| 2/2 [00:00<00:00,  3.00it/s]
Processed prompts:   0%|          | 0/2 [00:00<?, ?it/s][32m [repeated 12x across cluster][0m
Processed prompts:  50%|█████     | 1/2 [00:00<00:00,  1.11it/s][32m [repeated 8x across cluster][0m
Processed prompts: 100%|██████████| 2/2 [00:01<00:00,  1.87it/s][32m [repeated 7x across cluster][0m
Processed prompts: 100%|██████████| 2/2 [00:00<00:00,  2.08it/s][32m [repeated 4x across cluster][0m
Processed prompts:   0%|          | 0/2 [00:00<?, ?it/s][32



Processed prompts: 100%|██████████| 2/2 [00:01<00:00,  1.55it/s]
Processed prompts: 100%|██████████| 2/2 [00:01<00:00,  1.83it/s]
Processed prompts:   0%|          | 0/2 [00:00<?, ?it/s][32m [repeated 10x across cluster][0m
Processed prompts:  50%|█████     | 1/2 [00:00<00:00,  1.29it/s][32m [repeated 8x across cluster][0m
Processed prompts: 100%|██████████| 2/2 [00:01<00:00,  1.72it/s][32m [repeated 8x across cluster][0m
Processed prompts: 100%|██████████| 2/2 [00:00<00:00,  2.19it/s]
Processed prompts: 100%|██████████| 2/2 [00:01<00:00,  1.84it/s][32m [repeated 3x across cluster][0m
Processed prompts:   0%|          | 0/2 [00:00<?, ?it/s][32m [repeated 9x across cluster][0m
Processed prompts:  50%|█████     | 1/2 [00:00<00:00,  1.86it/s][32m [repeated 7x across cluster][0m
Processed prompts: 100%|██████████| 2/2 [00:01<00:00,  1.71it/s][32m [repeated 5x across cluster][0m
Processed prompts: 100%|██████████| 2/2 [00:01<00:00,  1.88it/s]
Processed prompts: 100%|██████████



Processed prompts: 100%|██████████| 2/2 [00:00<00:00,  2.09it/s]
Processed prompts:   0%|          | 0/2 [00:00<?, ?it/s][32m [repeated 9x across cluster][0m
Processed prompts:  50%|█████     | 1/2 [00:00<00:00,  1.29it/s][32m [repeated 8x across cluster][0m
Processed prompts: 100%|██████████| 2/2 [00:01<00:00,  1.72it/s][32m [repeated 5x across cluster][0m
Processed prompts: 100%|██████████| 2/2 [00:01<00:00,  1.44it/s]
Processed prompts: 100%|██████████| 2/2 [00:01<00:00,  1.45it/s][32m [repeated 5x across cluster][0m
Processed prompts:   0%|          | 0/2 [00:00<?, ?it/s][32m [repeated 10x across cluster][0m
Processed prompts:  50%|█████     | 1/2 [00:00<00:00,  1.13it/s][32m [repeated 7x across cluster][0m
Processed prompts: 100%|██████████| 2/2 [00:01<00:00,  1.74it/s][32m [repeated 6x across cluster][0m
Processed prompts: 100%|██████████| 2/2 [00:01<00:00,  1.79it/s]
Processed prompts: 100%|██████████| 2/2 [00:00<00:00,  2.17it/s]
Processed prompts: 100%|██████████



Processed prompts: 100%|██████████| 2/2 [00:00<00:00,  2.17it/s][32m [repeated 3x across cluster][0m
Processed prompts: 100%|██████████| 2/2 [00:01<00:00,  1.64it/s]
Processed prompts:   0%|          | 0/2 [00:00<?, ?it/s][32m [repeated 10x across cluster][0m
Processed prompts:  50%|█████     | 1/2 [00:00<00:00,  1.71it/s][32m [repeated 9x across cluster][0m
Processed prompts: 100%|██████████| 2/2 [00:01<00:00,  1.55it/s][32m [repeated 7x across cluster][0m
Processed prompts: 100%|██████████| 2/2 [00:01<00:00,  1.83it/s]
Processed prompts: 100%|██████████| 2/2 [00:00<00:00,  2.12it/s]
Processed prompts: 100%|██████████| 2/2 [00:01<00:00,  1.84it/s]
Processed prompts:   0%|          | 0/2 [00:00<?, ?it/s][32m [repeated 10x across cluster][0m
Processed prompts:  50%|█████     | 1/2 [00:00<00:00,  1.19it/s][32m [repeated 8x across cluster][0m
Processed prompts: 100%|██████████| 2/2 [00:01<00:00,  1.91it/s][32m [repeated 7x across cluster][0m
Processed prompts: 100%|█████████



Processed prompts: 100%|██████████| 2/2 [00:00<00:00,  2.00it/s]
Processed prompts: 100%|██████████| 2/2 [00:01<00:00,  1.90it/s]
Processed prompts: 100%|██████████| 2/2 [00:01<00:00,  1.88it/s]
Processed prompts:   0%|          | 0/2 [00:00<?, ?it/s][32m [repeated 8x across cluster][0m
Processed prompts:  50%|█████     | 1/2 [00:00<00:00,  1.09it/s][32m [repeated 9x across cluster][0m
Processed prompts: 100%|██████████| 2/2 [00:01<00:00,  1.66it/s][32m [repeated 5x across cluster][0m
Processed prompts: 100%|██████████| 2/2 [00:01<00:00,  1.99it/s]
Processed prompts: 100%|██████████| 2/2 [00:01<00:00,  1.75it/s]
Processed prompts: 100%|██████████| 2/2 [00:00<00:00,  3.63it/s]
Processed prompts: 100%|██████████| 2/2 [00:00<00:00,  2.31it/s]
Processed prompts: 100%|██████████| 2/2 [00:01<00:00,  1.79it/s]
Processed prompts: 100%|██████████| 2/2 [00:01<00:00,  1.78it/s]
Processed prompts:   0%|          | 0/2 [00:00<?, ?it/s][32m [repeated 11x across cluster][0m
Processed prompts:



Processed prompts: 100%|██████████| 2/2 [00:01<00:00,  1.75it/s]
Processed prompts:   0%|          | 0/2 [00:00<?, ?it/s][32m [repeated 9x across cluster][0m
Processed prompts:  50%|█████     | 1/2 [00:00<00:00,  1.18it/s][32m [repeated 8x across cluster][0m
Processed prompts: 100%|██████████| 2/2 [00:00<00:00,  2.41it/s][32m [repeated 6x across cluster][0m
Processed prompts: 100%|██████████| 2/2 [00:01<00:00,  1.92it/s]
Processed prompts: 100%|██████████| 2/2 [00:00<00:00,  2.31it/s]
Processed prompts: 100%|██████████| 2/2 [00:00<00:00,  2.52it/s]
Processed prompts:   0%|          | 0/2 [00:00<?, ?it/s][32m [repeated 10x across cluster][0m
Processed prompts:  50%|█████     | 1/2 [00:01<00:01,  1.09s/it][32m [repeated 8x across cluster][0m
Processed prompts: 100%|██████████| 2/2 [00:01<00:00,  1.82it/s][32m [repeated 6x across cluster][0m
Processed prompts: 100%|██████████| 2/2 [00:00<00:00,  2.10it/s]
Processed prompts: 100%|██████████| 2/2 [00:01<00:00,  1.81it/s][32m [r



Processed prompts: 100%|██████████| 2/2 [00:00<00:00,  2.27it/s][32m [repeated 3x across cluster][0m




Processed prompts:   0%|          | 0/2 [00:00<?, ?it/s][32m [repeated 11x across cluster][0m
Processed prompts:  50%|█████     | 1/2 [00:01<00:01,  1.18s/it][32m [repeated 7x across cluster][0m
Processed prompts: 100%|██████████| 2/2 [00:00<00:00,  2.15it/s][32m [repeated 8x across cluster][0m
Processed prompts: 100%|██████████| 2/2 [00:01<00:00,  1.82it/s]
Processed prompts: 100%|██████████| 2/2 [00:01<00:00,  1.81it/s]
Processed prompts: 100%|██████████| 2/2 [00:01<00:00,  1.42it/s]
Processed prompts:   0%|          | 0/2 [00:00<?, ?it/s][32m [repeated 10x across cluster][0m
Processed prompts:  50%|█████     | 1/2 [00:00<00:00,  1.23it/s][32m [repeated 7x across cluster][0m
Processed prompts: 100%|██████████| 2/2 [00:01<00:00,  1.52it/s][32m [repeated 7x across cluster][0m
Processed prompts: 100%|██████████| 2/2 [00:00<00:00,  2.10it/s]
Processed prompts:   0%|          | 0/2 [00:00<?, ?it/s][32m [repeated 9x across cluster][0m
Processed prompts:  50%|█████     | 1/2 [



Processed prompts: 100%|██████████| 2/2 [00:01<00:00,  1.86it/s]
Processed prompts:   0%|          | 0/2 [00:00<?, ?it/s][32m [repeated 11x across cluster][0m
Processed prompts:  50%|█████     | 1/2 [00:00<00:00,  1.17it/s][32m [repeated 8x across cluster][0m
Processed prompts: 100%|██████████| 2/2 [00:01<00:00,  1.77it/s][32m [repeated 7x across cluster][0m
Processed prompts: 100%|██████████| 2/2 [00:00<00:00,  2.30it/s]
Processed prompts: 100%|██████████| 2/2 [00:01<00:00,  1.71it/s][32m [repeated 3x across cluster][0m
Processed prompts: 100%|██████████| 2/2 [00:00<00:00,  2.01it/s]
Processed prompts:   0%|          | 0/2 [00:00<?, ?it/s][32m [repeated 9x across cluster][0m
Processed prompts:  50%|█████     | 1/2 [00:00<00:00,  1.03it/s][32m [repeated 6x across cluster][0m
Processed prompts: 100%|██████████| 2/2 [00:01<00:00,  1.94it/s][32m [repeated 4x across cluster][0m
Processed prompts: 100%|██████████| 2/2 [00:01<00:00,  1.99it/s]
Processed prompts: 100%|██████████



Processed prompts:   0%|          | 0/2 [00:00<?, ?it/s][32m [repeated 10x across cluster][0m
Processed prompts:  50%|█████     | 1/2 [00:00<00:00,  3.98it/s][32m [repeated 8x across cluster][0m
Processed prompts: 100%|██████████| 2/2 [00:00<00:00,  2.56it/s][32m [repeated 6x across cluster][0m
Processed prompts: 100%|██████████| 2/2 [00:01<00:00,  1.50it/s]
Processed prompts: 100%|██████████| 2/2 [00:01<00:00,  1.88it/s]
Processed prompts: 100%|██████████| 2/2 [00:01<00:00,  1.83it/s]
Processed prompts:   0%|          | 0/2 [00:00<?, ?it/s][32m [repeated 9x across cluster][0m
Processed prompts:  50%|█████     | 1/2 [00:00<00:00,  1.93it/s][32m [repeated 9x across cluster][0m
Processed prompts: 100%|██████████| 2/2 [00:01<00:00,  1.58it/s][32m [repeated 7x across cluster][0m
Processed prompts: 100%|██████████| 2/2 [00:01<00:00,  1.61it/s]
Processed prompts: 100%|██████████| 2/2 [00:00<00:00,  2.06it/s]
Processed prompts:   0%|          | 0/2 [00:00<?, ?it/s][32m [repeated 

## score the dataset

In [7]:
df = pl.read_parquet("./data/exp_test/multi_n_with_lora/*.parquet")

In [8]:
df

original_text,gt_rewrite_prompt,rewritten_text,old_rewrite_prompt,score,prompt_select,input,rewrite_prompts
str,str,str,str,f64,str,str,list[str]
"""Google Opinion…","""Turn the conte…","""**Bridal Showe…","""Create a brida…",0.6502093,"""Create a brida…","""<s>[INST] You …","["" I."", "" Im unable to determine a specific prompt based on the given text.""]"
"""At last things…","""Turn the conte…","""**Music Produc…","""Rewrite this t…",0.6504193,"""Rewrite this t…","""<s>[INST] You …","[""I."", ""イTransform the original text into a motivational or inspirational piece for aspiring filmmakers or creatives.""]"
"""Indian employe…","""Present this a…","""**Lesson Plan*…","""I'm an element…",0.6504212,"""I'm an element…","""<s>[INST] You …","["" I. Create a lesson plan based on the given text for a 10th grade current affairs class."", ""._CREATE A LESSON PLAN BASED ON THE ORIGINAL TEXT_""]"
"""In 2011, accor…","""Present this a…",""" Hear ye, hear…","""Rewrite the te…",0.650467,"""Rewrite the te…","""<s>[INST] You …","[""ognize and transform the text into a passionate call to action for economic reform."", "" I cannot determine a prompt from the given text, as it seems to have undergone a significant transformation, transitioning from a factual account to""]"
"""Nov 27 (Reuter…","""Rewrite the me…",""" Sure, here is…","""Rewrite the fo…",0.6505102,"""Rewrite the fo…","""<s>[INST] You …","["" I'd be happy to help you rephrase the given text into passive voice."", "" gefä ""../../1""./fi_conversions.js"">Rewrite the following text in passive voice.""]"
…,…,…,…,…,…,…,…
"""There's nothin…","""Imagine this t…","""There's nothin…","""Rewrite the gi…",0.9268185,"""Rewrite the gi…","""<s>[INST] You …","["" I."", ""I. Rewrite the original text as an acrostic poem in the Roaring Twenties style.""]"
"""We are happy t…","""Imagine this t…","""We are happy t…","""I. Transform t…",0.6970332,"""Imagine this t…","""<s>[INST] You …","["" I. Transform the text into a magical, whimsical style II."", "" accomplish the following task while maintaining a fairy tale or magical theme: announce the new UKRDA men’s league rankings.""]"
"""There is a gen…","""Imagine this t…","""Painting Throu…","""Rewrite the fo…",0.695128,"""Rewrite the fo…","""<s>[INST] You …","["" I."", "" I.""]"
"""As more people…","""Imagine this t…","""The Tale of Sn…","""Rewrite the fo…",0.6803619,"""Rewrite the fo…","""<s>[INST] You …","["" I. Transform the given text into a fairy tale"", "" I. Transform the text into a captivating story or mythology""]"


In [9]:
df = df.with_columns(
    pl.col("rewrite_prompts").list.get(0).alias("rewrite_prompt_1"),
    pl.col("rewrite_prompts").list.get(1).alias("rewrite_prompt_2")
).to_pandas()

In [10]:
import numpy as np
from sentence_transformers import SentenceTransformer, util
from sklearn.metrics.pairwise import cosine_similarity

st_model = SentenceTransformer('sentence-transformers/sentence-t5-base')

def get_sharpened_cosine_similarity(text1, text2):
    embeddings1 = st_model.encode(text1)
    embeddings2 = st_model.encode(text2)
    cosine_score = util.cos_sim(embeddings1, embeddings2)
    # print(cosine_score) 
    return (cosine_score[0] ** 3).numpy()[0]

def calc_prompt_similarity(row, rewrite_prompt_col):
    return get_sharpened_cosine_similarity(row['gt_rewrite_prompt'], row['{}'.format(rewrite_prompt_col)])

df['score_1'] = df.apply(lambda row: calc_prompt_similarity(row, 'rewrite_prompt_1'), axis=1)
df['score_2'] = df.apply(lambda row: calc_prompt_similarity(row, 'rewrite_prompt_1'), axis=1)
df['prompt_select_score'] = df.apply(lambda row: calc_prompt_similarity(row, 'prompt_select'), axis=1)


In [14]:
df

Unnamed: 0_level_0,original_text,gt_rewrite_prompt,rewritten_text,old_rewrite_prompt,score,prompt_select,input,rewrite_prompts,rewrite_prompt_1,rewrite_prompt_2,score_1,score_2,prompt_select_score
i64,str,str,str,str,f64,str,str,str,str,str,f64,f64,f64
0,"""Google Opinion…","""Turn the conte…","""**Bridal Showe…","""Create a brida…",0.6502093,"""Create a brida…","""<s>[INST] You …","""[' I.'  ' Im u…",""" I.""",""" Im unable to …",0.373235,0.373235,0.71422
1,"""At last things…","""Turn the conte…","""**Music Produc…","""Rewrite this t…",0.6504193,"""Rewrite this t…","""<s>[INST] You …","""['I.'  'イTrans…","""I.""","""イTransform the…",0.380597,0.380597,0.549653
2,"""Indian employe…","""Present this a…","""**Lesson Plan*…","""I'm an element…",0.6504212,"""I'm an element…","""<s>[INST] You …","""[' I. Create a…",""" I. Create a l…","""._CREATE A LES…",0.720675,0.720675,0.690931
3,"""In 2011, accor…","""Present this a…",""" Hear ye, hear…","""Rewrite the te…",0.650467,"""Rewrite the te…","""<s>[INST] You …","""['ognize and t…","""ognize and tra…",""" I cannot dete…",0.6526809,0.6526809,0.6848809
4,"""Nov 27 (Reuter…","""Rewrite the me…",""" Sure, here is…","""Rewrite the fo…",0.6505102,"""Rewrite the fo…","""<s>[INST] You …","""["" I'd be happ…",""" I'd be happy …",""" gefä ""../../1…",0.655507,0.655507,0.806471
…,…,…,…,…,…,…,…,…,…,…,…,…,…
10583,"""There's nothin…","""Imagine this t…","""There's nothin…","""Rewrite the gi…",0.9268185,"""Rewrite the gi…","""<s>[INST] You …","""[' I.'  'I. Re…",""" I.""","""I. Rewrite the…",0.418582,0.418582,0.9268185
10584,"""We are happy t…","""Imagine this t…","""We are happy t…","""I. Transform t…",0.6970332,"""Imagine this t…","""<s>[INST] You …","""[' I. Transfor…",""" I. Transform …",""" accomplish th…",0.6455715,0.6455715,1.0
10585,"""There is a gen…","""Imagine this t…","""Painting Throu…","""Rewrite the fo…",0.695128,"""Rewrite the fo…","""<s>[INST] You …","""[' I.' ' I.']""",""" I.""",""" I.""",0.425178,0.425178,0.714615
10586,"""As more people…","""Imagine this t…","""The Tale of Sn…","""Rewrite the fo…",0.6803619,"""Rewrite the fo…","""<s>[INST] You …","""[' I. Transfor…",""" I. Transform …",""" I. Transform …",0.7085275,0.7085275,0.6803619


In [15]:
df.write_csv('./data/exp_test/multi_n_with_lora/multi_n_scored.csv')

## prepare for DPO training

In [19]:
df = pl.read_csv('./data/exp_test/multi_n_with_lora/multi_n_scored.csv')

In [20]:
df.filter(
    pl.col('prompt_select_score') < 0.99    
).with_columns(
    pl.when(
        pl.col("score") > pl.col("score_1"),
        pl.col("score") > pl.col("score_2")
    ).then(        
        pl.col("old_rewrite_prompt").alias("chosen")
    )
    .when(
        pl.col("score_1") > pl.col("score"),
        pl.col("score_1") > pl.col("score_2")
    ).then(
        pl.col("rewrite_prompt_1").alias("chosen")
    ).otherwise(
        pl.col("rewrite_prompt_2").alias("chosen")
    ),
    pl.when(
        pl.col("score") < pl.col("score_1"),
        pl.col("score") < pl.col("score_2")
    ).then(        
        pl.col("old_rewrite_prompt").alias("rejected")
    )
    .when(
        pl.col("score_1") < pl.col("score"),
        pl.col("score_1") < pl.col("score_2")
    ).then(
        pl.col("rewrite_prompt_1").alias("rejected")
    ).otherwise(
        pl.col("rewrite_prompt_2").alias("rejected")
    )
).write_csv('./data/exp_test/multi_n_with_lora/multi_n_selected.csv')

In [21]:
df

Unnamed: 0_level_0,original_text,gt_rewrite_prompt,rewritten_text,old_rewrite_prompt,score,prompt_select,input,rewrite_prompts,rewrite_prompt_1,rewrite_prompt_2,score_1,score_2,prompt_select_score
i64,str,str,str,str,f64,str,str,str,str,str,f64,f64,f64
0,"""Google Opinion…","""Turn the conte…","""**Bridal Showe…","""Create a brida…",0.6502093,"""Create a brida…","""<s>[INST] You …","""[' I.'  ' Im u…",""" I.""",""" Im unable to …",0.373235,0.373235,0.71422
1,"""At last things…","""Turn the conte…","""**Music Produc…","""Rewrite this t…",0.6504193,"""Rewrite this t…","""<s>[INST] You …","""['I.'  'イTrans…","""I.""","""イTransform the…",0.380597,0.380597,0.549653
2,"""Indian employe…","""Present this a…","""**Lesson Plan*…","""I'm an element…",0.6504212,"""I'm an element…","""<s>[INST] You …","""[' I. Create a…",""" I. Create a l…","""._CREATE A LES…",0.720675,0.720675,0.690931
3,"""In 2011, accor…","""Present this a…",""" Hear ye, hear…","""Rewrite the te…",0.650467,"""Rewrite the te…","""<s>[INST] You …","""['ognize and t…","""ognize and tra…",""" I cannot dete…",0.6526809,0.6526809,0.6848809
4,"""Nov 27 (Reuter…","""Rewrite the me…",""" Sure, here is…","""Rewrite the fo…",0.6505102,"""Rewrite the fo…","""<s>[INST] You …","""["" I'd be happ…",""" I'd be happy …",""" gefä ""../../1…",0.655507,0.655507,0.806471
…,…,…,…,…,…,…,…,…,…,…,…,…,…
10583,"""There's nothin…","""Imagine this t…","""There's nothin…","""Rewrite the gi…",0.9268185,"""Rewrite the gi…","""<s>[INST] You …","""[' I.'  'I. Re…",""" I.""","""I. Rewrite the…",0.418582,0.418582,0.9268185
10584,"""We are happy t…","""Imagine this t…","""We are happy t…","""I. Transform t…",0.6970332,"""Imagine this t…","""<s>[INST] You …","""[' I. Transfor…",""" I. Transform …",""" accomplish th…",0.6455715,0.6455715,1.0
10585,"""There is a gen…","""Imagine this t…","""Painting Throu…","""Rewrite the fo…",0.695128,"""Rewrite the fo…","""<s>[INST] You …","""[' I.' ' I.']""",""" I.""",""" I.""",0.425178,0.425178,0.714615
10586,"""As more people…","""Imagine this t…","""The Tale of Sn…","""Rewrite the fo…",0.6803619,"""Rewrite the fo…","""<s>[INST] You …","""[' I. Transfor…",""" I. Transform …",""" I. Transform …",0.7085275,0.7085275,0.6803619
