In [None]:
import concurrent.futures
from tqdm.auto import tqdm, trange
import time
from openai import AzureOpenAI
import random
import os
import pandas as pd

# Shared Azure OpenAI client used by every request in this notebook.
# Credentials and endpoint are read from the environment (never hardcoded);
# a missing variable raises KeyError here, before any request is attempted.
client = AzureOpenAI(
    api_key=os.environ["AZURE_OPENAI_API_KEY"],
    azure_endpoint=os.environ["AZURE_OPENAI_ENDPOINT"],
    api_version="2024-05-01-preview",
)

In [None]:
def gen_responses(prompt_text):
    """Send ``prompt_text`` as a single user turn to the chat-completions API.

    Uses the module-level ``client`` and the ``gpt4o-2024-05-13`` model
    (for Azure OpenAI this is the deployment name) with fixed sampling
    settings: temperature 1, top_p 1, up to 2048 output tokens, and small
    frequency/presence penalties.

    Returns the raw completion object returned by ``client.chat.completions.create``.
    """
    user_turn = {
        "role": "user",
        "content": [{"type": "text", "text": prompt_text}],
    }
    return client.chat.completions.create(
        model="gpt4o-2024-05-13",
        messages=[user_turn],
        temperature=1,
        max_tokens=2048,
        top_p=1,
        frequency_penalty=0.04,
        presence_penalty=0.1,
    )


In [None]:
from ftlangdetect import detect

def parse_prompts_from_list(yagi_df):
    yagi_df.loc[:, "prompt_list"] = yagi_df.gen_prompts.str.split("\n").apply(
        lambda x: [y[len(str(i+1)):].strip(" .-。") for i, y in enumerate(x) if y[:len(str(i+1))] == str(i+1)]
    )

    yagi_df.loc[:, "prompt_list"] = yagi_df["prompt_list"].apply(lambda x: [y for y in x if len(y.strip()) > 0])

    yagi_df = yagi_df.drop_duplicates("language")[['language', 'finish_reason', 'prompt_list']].explode("prompt_list")

    yagi_df.loc[:, "prompt_list"] = yagi_df.prompt_list.str.strip()

    return yagi_df

def is_english_hi_confidence(text):
    """Return True when ftlangdetect's ``detect`` labels ``text`` "en" with score > 0.8."""
    prediction = detect(text=text, low_memory=False)
    if prediction["lang"] != "en":
        return False
    return prediction["score"] > 0.8

def is_japanese_hi_confidence(text):
    """Return True when ftlangdetect's ``detect`` labels ``text`` "ja" with score > 0.8."""
    prediction = detect(text=text, low_memory=False)
    if prediction["lang"] != "ja":
        return False
    return prediction["score"] > 0.8

def filter_out_wrong_languages(yagi_df):
    """Drop prompts whose confidently-detected language contradicts their label.

    Two sequential passes over the ``prompt_list`` column:
    1. remove rows that ``is_english_hi_confidence`` flags but whose ``language``
       is not one of "Old English", "Simple English" or "English";
    2. from what remains, remove rows that ``is_japanese_hi_confidence`` flags
       but whose ``language`` is not "Japanese".

    Returns the filtered frame; the index of surviving rows is kept.
    """
    english_labels = ["Old English", "Simple English", "English"]
    looks_english = yagi_df["prompt_list"].apply(is_english_hi_confidence)
    mislabeled_english = looks_english & ~yagi_df["language"].isin(english_labels)
    kept = yagi_df[~mislabeled_english]

    looks_japanese = kept["prompt_list"].apply(is_japanese_hi_confidence)
    mislabeled_japanese = looks_japanese & (kept["language"] != "Japanese")
    return kept[~mislabeled_japanese]

def pre_parse_df(yagi_df):
    """Turn raw generations into one prompt per row, then drop language-mismatched prompts.

    Equivalent to ``filter_out_wrong_languages(parse_prompts_from_list(yagi_df))``.
    """
    return filter_out_wrong_languages(parse_prompts_from_list(yagi_df))

In [None]:
from datasets import load_dataset

def run_gen_responses(prompt_text):
    """Best-effort wrapper around ``gen_responses`` for use from a thread pool.

    Sleeps a random 0-4 s first (presumably to stagger concurrent workers), then
    returns ``{"gen_prompts": <message content>, "finish_reason": <finish reason>}``
    taken from the first choice of the completion.

    Any exception is printed together with the prompt and swallowed, and ``None``
    is returned, so results stay positionally aligned with the submitted prompts.
    """
    try:
        time.sleep(random.random() * 4)
        first_choice = gen_responses(prompt_text).choices[0]
        outcome = {
            "gen_prompts": first_choice.message.content,
            "finish_reason": first_choice.finish_reason,
        }
    except Exception as err:
        print(f"Failed - {prompt_text}")
        print(err)
        print()
        return None
    return outcome

def save_dataset(yagi_df, dataset_name, chunk_size=500, max_workers=14, hub_org="lightblue"):
    """Generate a response for every prompt, checkpoint to parquet shards, and push to the Hub.

    Parameters
    ----------
    yagi_df : pandas.DataFrame
        Must contain a ``prompt_list`` column holding one prompt string per row.
    dataset_name : str
        Used both as the local shard directory ``~/<dataset_name>`` and as the Hub repo name.
    chunk_size : int
        Rows processed (and written) per parquet shard.
    max_workers : int
        Threads calling ``run_gen_responses`` concurrently within a chunk.
    hub_org : str
        Hub namespace the dataset is pushed to (``<hub_org>/<dataset_name>``).

    Each chunk is written to ``~/<dataset_name>/<start row, zero-padded to 6>.parquet``
    with an added ``responses`` column (``None`` where the API call failed). Finally
    every ``*.parquet`` file in that directory is loaded and pushed to the Hub.

    NOTE(review): shards left over from an earlier, larger run in the same directory
    also match the glob and will be pushed; clear the directory first if that matters.
    """
    # Expand "~" explicitly. The directory must exist before to_parquet can write into
    # it, and the data_files glob is handed to `datasets`, which may not expand "~" itself.
    output_dir = os.path.expanduser(f"~/{dataset_name}")
    os.makedirs(output_dir, exist_ok=True)

    for start in trange(0, yagi_df.shape[0], chunk_size):
        batch_df = yagi_df.iloc[start:start + chunk_size]
        prompt_texts = batch_df["prompt_list"].tolist()

        with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
            # executor.map yields results in input order, so they line up with batch_df rows.
            response_results = list(
                tqdm(executor.map(run_gen_responses, prompt_texts), total=len(prompt_texts))
            )

        # assign() returns a new frame instead of writing into an iloc slice (SettingWithCopy).
        batch_df = batch_df.assign(responses=response_results)
        batch_df.to_parquet(os.path.join(output_dir, f"{str(start).zfill(6)}.parquet"))

    yagi_dataset = load_dataset(
        "parquet",
        data_files={"train": os.path.join(output_dir, "*.parquet")},
        split="train",
    )
    yagi_dataset.push_to_hub(f"{hub_org}/{dataset_name}")

In [None]:
from datasets import load_dataset
# Load the raw generations; rows with any missing field are dropped before parsing.
yagi_297_raw = load_dataset("lightblue/yagi_297", split="train").to_pandas().dropna()
# Parse the frame loaded just above (one prompt per row, mislabeled languages removed).
yagi_297_df = pre_parse_df(yagi_297_raw)

save_dataset(yagi_297_df, "yagi_297")