In [3]:
import random
from jinja2 import Environment, Template, meta
import sys
from itertools import product
from datasets import load_dataset
import polars as pl


In [None]:
# Log this kernel in to the Hugging Face Hub via the interactive notebook prompt.
# NOTE(review): presumably required because the model loaded later
# (`google/gemma-7b-it`) needs an authenticated download — confirm.
from huggingface_hub import notebook_login

notebook_login()

In [4]:
# Load the movie-plot corpus and keep a single column: the `Plot` text,
# renamed to `original_text` and shuffled with a fixed seed.
plots_csv = 'data/exp_6/wiki_movie_plots_deduped.csv'
shuffled_plot = pl.col('Plot').shuffle(seed=6541).alias('original_text')
df = pl.read_csv(plots_csv).select(shuffled_plot)

In [5]:
# Keep only the first 20 shuffled plots, then display them.
df = df.head(20)
df

original_text
str
"""The Lion Has W…"
"""San Francisco …"
"""As described i…"
"""Rashad is a te…"
"""Mary Kirk Loga…"
…
"""Mr. Tucker (Pl…"
"""During the Pro…"
"""He Dashang (Wu…"
"""Kavitha (Sujat…"


In [6]:
# built from `https://www.kaggle.com/datasets/ilanmeissonnier/chatgpt-rewrite-promts/data` & `https://www.kaggle.com/datasets/richolson/600-gpt4-re-write-prompts`

# Stack the rewrite instructions from three CSV sources into one `prompt` column.
# NOTE(review): `extend` needs matching schemas, so gpt4_prompts.csv is assumed
# to contain exactly one `prompt` column — confirm.
prompts_df = pl.read_csv('./data/exp_6/prompts/prompts.csv').select(
    pl.col('rewrite_prompt').alias('prompt')
).extend(
    pl.read_csv('./data/exp_6/prompts/gpt4_prompts.csv')
).extend(
    # `unique` does not guarantee row order unless `maintain_order=True`;
    # keep first-seen order so this frame is identical on every run.
    pl.read_csv('./data/3rd_party_ds/rewritten_texts_csv_v3.csv', ignore_errors=True)
    .select('prompt')
    .unique(maintain_order=True)
)

# De-duplicate across all sources, again preserving first-seen order: the
# position of a prompt in this list decides which text it is paired with,
# so a stable order is required for a reproducible dataset.
prompts = prompts_df['prompt'].unique(maintain_order=True).to_list()

In [7]:
import ray


# Pair every plot with a rewrite instruction and render the chat-turn input
# (`<start_of_turn>` / `<end_of_turn>` markup) that the model will complete.
if not prompts:
    # Fail with a clear message instead of an index error inside the loop.
    raise ValueError("`prompts` is empty: no rewrite instruction to pair with the texts")

prompt_ds = []

for idx, row in enumerate(df.iter_rows(named=True)):
    # Modulo cycles back to the first prompt when there are more rows than prompts.
    prompt = prompts[idx % len(prompts)]
    prompt_ds.append({
        'original_text': row['original_text'],
        'prompt': prompt,
        'input': f'''<start_of_turn>user
{prompt}: """{row["original_text"]}"""<end_of_turn>
<start_of_turn>model
'''
    })

# One Ray Dataset row per record built above.
ds = ray.data.from_items(prompt_ds)

Usage stats collection is enabled by default for nightly wheels. To disable this, run the following command: `ray disable-usage-stats` before starting Ray. See https://docs.ray.io/en/master/cluster/usage-stats.html for more details.


2024-04-05 23:12:51,974	INFO worker.py:1752 -- Started a local Ray instance.


In [8]:
from vllm import LLM, SamplingParams
from typing import Dict
import numpy as np

In [9]:
# Sampling configuration shared by all generation calls: only the length of
# each completion is capped; no other sampling option is passed.
MAX_NEW_TOKENS = 200
sampling_params = SamplingParams(max_tokens=MAX_NEW_TOKENS)


# Create a class to do batch inference.
class LLMPredictor:
    """Stateful batch callable: builds a vLLM engine once, then generates text per batch."""

    def __init__(self):
        """Create the `google/gemma-7b-it` vLLM engine and keep it on the instance."""
        engine_options = dict(
            model="google/gemma-7b-it",
            gpu_memory_utilization=0.95,
            max_model_len=1500,
        )
        self.llm = LLM(**engine_options)

    def __call__(self, batch: Dict[str, np.ndarray]) -> Dict[str, list]:
        """Generate a completion for every entry of ``batch["input"]``.

        Args:
            batch: Mapping of column name to values; only the ``"input"``
                column is read here.

        Returns:
            The same ``batch`` object, with a ``"generated_text"`` column added
            that holds one string per input. When a request yields several
            candidate completions their texts are joined with a single space.
        """
        # `sampling_params` is the module-level sampling configuration.
        request_outputs = self.llm.generate(batch["input"], sampling_params)
        batch["generated_text"] = [
            ' '.join(candidate.text for candidate in request.outputs)
            for request in request_outputs
        ]
        return batch

# Attach the inference stage: each batch of rows is passed to an LLMPredictor instance.
inference_options = dict(
    # Number of LLM instances to run concurrently.
    concurrency=2,
    # GPUs required by each LLM instance.
    # NOTE: Do NOT set `num_gpus` when using vLLM with tensor-parallelism
    # (i.e., `tensor_parallel_size`).
    num_gpus=1,
    # Rows per inference batch.
    batch_size=4,
)
ds = ds.map_batches(LLMPredictor, **inference_options)

In [10]:
# Writing the dataset is what starts execution: the logged execution plan below
# runs MapBatches(LLMPredictor) followed by the Write stage. Results are stored
# as parquet under this directory.
ds.write_parquet("./data/exp_test/train_data/")

2024-04-05 23:13:20,974	INFO streaming_executor.py:115 -- Starting execution of Dataset. Full log is in /tmp/ray/session_2024-04-05_23-12-50_051505_6888/logs/ray-data.log
2024-04-05 23:13:20,975	INFO streaming_executor.py:116 -- Execution plan of Dataset: InputDataBuffer[Input] -> ActorPoolMapOperator[MapBatches(LLMPredictor)] -> TaskPoolMapOperator[Write]



[36m(_MapWorker pid=8286)[0m INFO 04-05 23:13:24 llm_engine.py:87] Initializing an LLM engine with config: model='google/gemma-7b-it', tokenizer='google/gemma-7b-it', tokenizer_mode=auto, revision=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.bfloat16, max_seq_len=1500, download_dir=None, load_format=auto, tensor_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=False, kv_cache_dtype=auto, device_config=cuda, seed=0)
[36m(_MapWorker pid=8286)[0m INFO 04-05 23:13:29 weight_utils.py:163] Using model weights format ['*.safetensors']
[36m(_MapWorker pid=8285)[0m INFO 04-05 23:13:24 llm_engine.py:87] Initializing an LLM engine with config: model='google/gemma-7b-it', tokenizer='google/gemma-7b-it', tokenizer_mode=auto, revision=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.bfloat16, max_seq_len=1500, download_dir=None, load_format=auto, tensor_parallel_size=1, disable_custom_all_reduce=False, quantization=Non

- MapBatches(LLMPredictor) 1:   0%|          | 0/20 [00:00<?, ?it/s]

- Write 2:   0%|          | 0/20 [00:00<?, ?it/s]

Running 0:   0%|          | 0/20 [00:00<?, ?it/s]

Processed prompts:   0%|          | 0/4 [00:00<?, ?it/s]




Processed prompts:  25%|██▌       | 1/4 [00:00<00:01,  2.43it/s]
Processed prompts:  50%|█████     | 2/4 [00:02<00:02,  1.26s/it]
Processed prompts: 100%|██████████| 4/4 [00:05<00:00,  1.41s/it]
Processed prompts:   0%|          | 0/4 [00:00<?, ?it/s][32m [repeated 2x across cluster] (Ray deduplicates logs by default. Set RAY_DEDUP_LOGS=0 to disable log deduplication, or see https://docs.ray.io/en/master/ray-observability/ray-logging.html#log-deduplication for more options.)[0m
Processed prompts:  25%|██▌       | 1/4 [00:02<00:07,  2.60s/it]
Processed prompts: 100%|██████████| 4/4 [00:05<00:00,  1.44s/it][32m [repeated 2x across cluster][0m
Processed prompts:   0%|          | 0/4 [00:00<?, ?it/s][32m [repeated 2x across cluster][0m
Processed prompts:  25%|██▌       | 1/4 [00:02<00:08,  2.98s/it]
Processed prompts: 100%|██████████| 4/4 [00:05<00:00,  1.43s/it][32m [repeated 2x across cluster][0m


In [11]:
# Read every parquet file in the inference output directory into one polars frame.
train_df = pl.read_parquet("./data/exp_test/train_data/*.parquet")

In [13]:
# Export the (still unfiltered) frame as a CSV copy.
train_df.write_csv("./data/exp_test/train_data.csv")

In [27]:
generated = pl.col('generated_text')

# A row is kept only when the generation:
#   - does not match 'I am unable',
#   - is non-empty,
#   - does not contain the original text verbatim,
#   - does not contain the literal 'Sure'.
keep_row = (
    ~generated.str.contains('I am unable')
    & (generated.str.len_chars() > 0)
    & ~generated.str.contains(pl.col('original_text'), literal=True)
    & ~generated.str.contains('Sure', literal=True)
)

train_df.filter(keep_row).write_parquet('./data/exp_6/train_data/complete_1/complete_ds.parquet')

In [None]:
## do some post processing

import pathlib

# Destination of the cleaned dataset, as a real Path so the annotation-free
# value matches its type and `.parent` is available.
path = pathlib.Path("./data/exp_6/train_data/complete/complete_ds.parquet")
# Make sure the target directory exists before writing, so the cell also
# works on a fresh checkout where `complete/` has not been created yet.
path.parent.mkdir(parents=True, exist_ok=True)

# Load all generation shards (top-level parquet files only; the glob is not recursive).
train_df = pl.read_parquet("./data/exp_6/train_data/*.parquet")

# Drop rows where the generation matches 'I am unable', is empty, contains the
# original text verbatim, or contains the literal 'Sure'; write what remains.
train_df.filter(
    ~pl.col('generated_text').str.contains('I am unable'),
    pl.col('generated_text').str.len_chars() > 0,
    ~pl.col('generated_text').str.contains(pl.col('original_text'), literal=True),
    ~pl.col('generated_text').str.contains('Sure', literal=True),
).write_parquet(path)