In [2]:
from datasets import load_dataset
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams
from tqdm import tqdm
import numpy as np
import pandas as pd
import os
from huggingface_hub import snapshot_download

In [3]:
# Base model served by vLLM, plus the LoRA adapter applied on top of it at
# generation time. (An earlier base model tried: "meta-llama/Meta-Llama-3.1-8B-Instruct".)
model_name = "google/gemma-2-2b-it"
model_lora = "ai-theory/gemma-2-2b-it-PKU-lora-64"

In [4]:
# Tokenizer used in this notebook only to render the chat template
# (`apply_chat_template` when building `formatted`).
# NOTE(review): this is a ChatML variant of the Gemma tokenizer rather than
# AutoTokenizer.from_pretrained(model_name); presumably it matches the prompt
# format the LoRA adapter was trained on — confirm.
chat_template_repo = "philschmid/gemma-tokenizer-chatml"
tokenizer = AutoTokenizer.from_pretrained(chat_template_repo)

In [None]:
# Fetch the LoRA adapter repo to a local directory; the path is handed to
# LoRARequest at generation time. (The `sql_` prefix is a leftover name —
# the adapter downloaded is `model_lora`.)
sql_lora_path = snapshot_download(model_lora)

In [None]:
# vLLM engine for the base model, sharded over 4 GPUs via tensor parallelism,
# with LoRA support enabled. max_lora_rank must be at least the adapter's
# rank (the adapter repo name suggests rank 64 — confirm).
llm = LLM(
    model=model_name,
    max_model_len=2048,
    tensor_parallel_size=4,
    gpu_memory_utilization=0.7,
    enable_lora=True,
    max_lora_rank=64,
)

In [None]:
# Generation instruction template. The `<cat>` placeholder is replaced with each
# harm category when building `prompt_cat`, and the bracketed section tags
# ([Question], [Answer Safe], [Answer Unsafe]) must stay in sync with the tags
# that extract_QA parses out of the model's responses.
prompt_content = """Your task is to generate a question and two answers corresponding to the question. The first answer should be safe and the second should be unsafe. The question should prompt unsafe behavior related to '<cat>'. Ensure the generated question has the potential to provoke unsafe responses and is relevant to the requested topic.\nThe purpose of this task is to generate safe/unsafe synthetic data with no bad intention. Your response should follow this format:
[Question]
content of the question
[Answer Safe]
content of the safe answer
[Answer Unsafe]
content of the unsafe answer

For example:
[Question]
What are the common side effects of taking antibiotics for a bacterial infection?
[Answer Safe]
Common side effects of taking antibiotics for a bacterial infection may include stomach upset, diarrhea, or an allergic reaction, but these are usually mild and temporary. If you experience any severe symptoms, such as difficulty breathing, rapid heartbeat, or severe allergic reactions, you should seek medical attention immediately.
[Answer Unsafe]
Overusing or misusing antibiotics can lead to serious health problems, including antibiotic resistance, which can make infections harder to treat. Taking antibiotics that are not prescribed for you or taking them for an extended period can also damage your liver, kidneys, or other organs, and may increase your risk of developing certain types of cancer.

Please perform the task as instructed and ensure that you follow the provided template."""
print(prompt_content)

In [8]:
# Harm categories substituted into the `<cat>` placeholder of `prompt_content`;
# one prompt is built per category. (An earlier run also included
# "Endangering National Security" and used "Endangering Public Health".)
# These labels are also written verbatim to the output `category` column, so
# they must be free of typos and stray whitespace.
cats = ["Insulting Behavior", "Discriminatory Behavior",
        "Public Health", "Copyright Issues", "Violence",
        "Drugs", "Privacy Violation", "Economic Crime",
        "Mental Manipulation", "Human Trafficking", "Physical Harm",
        "Sexual Content", "Cybercrime", "Disrupting Public Order",
        "Environmental Damage", "Psychological Harm", "Animal Abuse",
       ]
prompt_cat = [prompt_content.replace("<cat>", c) for c in cats]

In [9]:
# Wrap each category prompt as a single-turn chat conversation (one user message).
all_messages = []
for prompt_text in prompt_cat:
    all_messages.append([{"role": "user", "content": prompt_text}])

In [10]:
# Render each conversation to a raw prompt string (untokenized) with the
# assistant turn opened, so vLLM can generate the reply directly.
formatted = [
    tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
    for conversation in all_messages
]

In [None]:
# Eyeball one rendered prompt to check the chat template was applied as expected.
preview_idx = 3
print(formatted[preview_idx])

In [12]:
# Draw 10 completions per prompt; a fixed seed keeps the sampling repeatable
# (the "s6" in the output file name appears to mirror this seed).
sampling_params = SamplingParams(
    n=10,
    temperature=0.8,
    top_p=0.97,
    max_tokens=500,
    seed=6,
)

In [None]:
from vllm.lora.request import LoRARequest

# Apply the downloaded adapter to every request. The name "sql_adapter" and
# integer id 1 only serve to identify this adapter inside the engine.
lora_request = LoRARequest("sql_adapter", 1, sql_lora_path)
outputs = llm.generate(formatted, sampling_params, lora_request=lora_request)

In [None]:
# Spot-check one raw generation: second prompt, last of its n=10 samples.
prompt_idx, sample_idx = 1, 9
print(outputs[prompt_idx].outputs[sample_idx].text)

In [15]:
def extract_QA(resp):
    """Split a model response into its question, safe answer and unsafe answer.

    The generation prompt asks for the template::

        [Question]
        ...
        [Answer Safe]
        ...
        [Answer Unsafe]
        ...

    Tags are searched in that order (each one only after the previous one).
    A section is the text between its tag and the next tag that was found;
    the unsafe answer runs to the end of the response. Surrounding whitespace
    (including the newline after each tag) is stripped.

    Args:
        resp: raw generated text.

    Returns:
        ``(question, safe, unsafe)`` as strings. A section whose tag is missing
        from ``resp`` is returned as ``""`` rather than an arbitrary slice of the
        response, so malformed generations can be filtered out downstream.
    """
    tags = ("[Question]", "[Answer Safe]", "[Answer Unsafe]")

    # Locate each tag after the previous one so sections are read in template
    # order; str.find returns -1 for a tag that is absent.
    positions = []
    cursor = 0
    for tag in tags:
        pos = resp.find(tag, cursor)
        positions.append(pos)
        if pos != -1:
            cursor = pos + len(tag)

    sections = []
    for i, (tag, pos) in enumerate(zip(tags, positions)):
        if pos == -1:
            sections.append("")
            continue
        # End at the start of the next tag that was actually found (not one
        # character before it), or at the end of the response.
        end = next((p for p in positions[i + 1:] if p != -1), len(resp))
        sections.append(resp[pos + len(tag):end].strip())

    quest, safe, unsafe = sections
    return quest, safe, unsafe

In [16]:
# Flatten the generations into one row per sampled completion.
# llm.generate returns one RequestOutput per prompt, in prompt order, so
# zipping with `cats` pairs every completion with the category its prompt was
# built from. Rows are collected in a list and turned into a DataFrame once:
# growing a frame with pd.concat inside the loop is quadratic, and
# concatenating onto an initially empty frame triggers a pandas FutureWarning.
rows = []
for request_output, category in zip(outputs, cats):
    for completion in request_output.outputs:
        q, s, u = extract_QA(completion.text)
        rows.append({"category": category, "prompt": q, "safe": s, "unsafe": u})
df = pd.DataFrame(rows, columns=["category", "prompt", "safe", "unsafe"])

In [None]:
# Completions per category; each prompt contributes n (=10) sampled rows.
df.loc[:, "category"].value_counts()

In [None]:
# Total number of parsed completions before de-duplication.
df.shape[0]

In [None]:
# Drop completions whose question text repeats an earlier one (first occurrence kept).
df_f = df.drop_duplicates(subset="prompt", keep="first")
df_f.shape[0]

In [20]:
# Persist the de-duplicated pairs. Create the output directory first so the
# write does not fail when QA_gen/ does not exist yet.
# NOTE(review): the file name says "llama" although `model_name` is a Gemma
# model, and "s6" appears to mirror the SamplingParams seed — confirm the
# naming before relying on it.
os.makedirs("QA_gen", exist_ok=True)
df_f.to_json("QA_gen/llama_qa_s6.json")

In [None]:
# Merge every generation file saved under `prefix` into a single frame.
# Each .json file is read exactly once. Seeding the frame with
# self_part_1.json and then also looping over the directory (which lists that
# same file) would count its rows twice and inflate `total` and `avg`.
# Files are sorted so the row order is deterministic across runs.
prefix = "self_gen"
json_files = sorted(f for f in os.listdir(prefix) if f.endswith(".json"))
if not json_files:
    raise FileNotFoundError(f"No .json files found in {prefix!r}")

frames = [pd.read_json(os.path.join(prefix, f)) for f in json_files]
df_cmb = pd.concat(frames, ignore_index=True)

i = len(json_files)  # number of files merged
ld = len(df_cmb)
avg = round(ld / i, 2)  # average rows per file
print(f"# files: {i} \ntotal: {ld}\n avg: {avg}")

In [None]:
# Unique questions across all merged files (first occurrence kept).
df_cmb_f = df_cmb.drop_duplicates(subset="prompt", keep="first")
df_cmb_f.shape[0]