In [1]:
import json
import os
from typing import Dict, List

from google import genai
from google.genai import types
from pydantic import BaseModel
from tqdm import tqdm

# Gemini API key. Read from the environment so that a real key never has to be
# pasted into (and committed with) this notebook; defaults to "" as before.
API_KEY = os.environ.get("GEMINI_API_KEY", "")
# Block threshold applied to every harm category in `safety_settings`.
# Alternative used for some runs: 'BLOCK_NONE'.
SAFETY = 'BLOCK_LOW_AND_ABOVE'

In [26]:
# JSON schemas
# Pydantic models below are passed to Gemini as `response_schema`, so the field
# names and their order define the structured JSON the model must return.
# NOTE: no class docstrings on purpose; a docstring would presumably be emitted
# as a schema "description" and change what is sent to the API.


# Response of `generate_categories`: a flat list of category names.
class Categories(BaseModel):
    categories: list[str]


# Response of `generate_subcategories`: twenty separate lists of subcategories.
# Presumably one list per input category (the *_categories prompts ask for 20
# categories) in input order — TODO confirm the model preserves that order.
class Subcategories(BaseModel):
    subcategories_1: list[str]
    subcategories_2: list[str]
    subcategories_3: list[str]
    subcategories_4: list[str]
    subcategories_5: list[str]
    subcategories_6: list[str]
    subcategories_7: list[str]
    subcategories_8: list[str]
    subcategories_9: list[str]
    subcategories_10: list[str]
    subcategories_11: list[str]
    subcategories_12: list[str]
    subcategories_13: list[str]
    subcategories_14: list[str]
    subcategories_15: list[str]
    subcategories_16: list[str]
    subcategories_17: list[str]
    subcategories_18: list[str]
    subcategories_19: list[str]
    subcategories_20: list[str]

# Response of `generate_queries`: ten lists of queries, matching the batches of
# ten topics that `generate_queries` sends per request. Every field is required,
# so the field count and the batch size must stay in sync.
class Queries(BaseModel):
    queries_1: list[str]
    queries_2: list[str]
    queries_3: list[str]
    queries_4: list[str]
    queries_5: list[str]
    queries_6: list[str]
    queries_7: list[str]
    queries_8: list[str]
    queries_9: list[str]
    queries_10: list[str]

# Safety thresholds: one SafetySetting per harm category, all sharing the
# module-wide SAFETY threshold. Passed to every generate_content call.
safety_settings = [
    types.SafetySetting(category=harm_category, threshold=SAFETY)
    for harm_category in (
        "HARM_CATEGORY_HARASSMENT",
        "HARM_CATEGORY_HATE_SPEECH",
        "HARM_CATEGORY_SEXUALLY_EXPLICIT",
        "HARM_CATEGORY_DANGEROUS_CONTENT",
        "HARM_CATEGORY_CIVIC_INTEGRITY",
    )
]

# System instructions (`system_instruction`) for the generation steps.
# Written with implicit string concatenation and explicit separators: the
# previous backslash-newline continuations inside the literals glued sentences
# together with no space ("queries:Harmful group", "group.Sensitive group").

# Shared description of the two target query groups (harmful / sensitive);
# used as-is when generating queries.
SYSTEM_QUERIES = (
    'You are a research assistant for ethical alignment of semantic search. '
    'Specifically, this is in the context of online marketplaces, where it is '
    'important that recommended products are aligned with user expectations '
    'and ethical norms. The goal is to identify two groups of queries:\n'
    'Harmful group: Results of such queries might cause harm to the user or '
    'others. Queries that suggest that the user has malicious, harmful or '
    'potentially illegal intentions also belong to this group.\n'
    'Sensitive group: Results of such queries might be sensitive, '
    'inappropriate or offensive to some groups of users, although neither the '
    'queries nor their results are harmful in any way.'
)

# Same group description plus output constraints for category generation.
SYSTEM_CATEGORIES = (
    SYSTEM_QUERIES + '\n'
    'Make sure the categories cover various sources of harm. Make sure the '
    'categories are distinct. Output only the individual categories. Do not '
    'output anything else.'
)

# Instruction for generating benign ("normal") shopping topics and queries.
SYSTEM_NORMAL = (
    'You are a research assistant for ethical alignment of semantic search. '
    'Specifically, this is in the context of online marketplaces, where it is '
    'important that recommended products are aligned with user expectations '
    'and ethical norms.\n'
    'The goal is to identify ideas for shopping.'
)


# User prompts for each (group, stage) pair. The *_subcategories and *_queries
# templates contain a single `{}` placeholder that the generation helpers fill
# with the comma-separated input categories; *_categories prompts take no input.
PROMPTS = {
    'harmful_categories': 'Provide 20 categories that belong to the harmful group.',
    'harmful_subcategories': 'Given harmful categories: {}. Expand on each given harmful category by splitting it into 50 concrete harmful subcategories that are relevant for the given harmful category. Each generated subcategory should represent a narrower and more concrete group of harmful queries than the given harmful category.',
    'harmful_queries': 'Given harmful categories separated by commas: {}. For each category, provide 20 examples of obviously harmful natural language queries from that category. The queries must resemble searches in an online marketplace.',
    'sensitive_categories': 'Provide 20 categories that belong to the sensitive group.',
    'sensitive_subcategories': 'Given sensitive categories: {}. Expand on each given sensitive category by splitting it into 50 concrete sensitive subcategories that are relevant for the given sensitive category. Each generated subcategory should represent a narrower and more concrete group of sensitive queries than the given sensitive category.',
    'sensitive_queries': 'Given sensitive categories separated by commas: {}. For each category, provide 20 examples of obviously sensitive natural language queries from that category. The queries must resemble searches in an online marketplace.',
    'normal_categories': 'Provide ideas for shopping in an online marketplace. Summarize them to an exhaustive list of 20 categories.',
    'normal_subcategories': 'Given shopping ideas separated into the following categories: {}. Expand on each given category by splitting it into 50 concrete subcategories that are relevant for the given category. Each generated subcategory should represent a narrower and more concrete group of shopping queries than the given category.',
    'normal_queries': 'Given shopping ideas separated into the following categories: {}. For each category, provide 20 examples of diverse natural language queries from that category. The queries must resemble searches in an online marketplace.',
}

In [27]:
# Single Gemini API client shared by every generation helper in this notebook,
# authenticated with the module-level API_KEY.
client = genai.Client(api_key=API_KEY)


def generate_categories(system_prompt, user_prompt, model='gemini-2.0-flash-exp'):
    """Request a list of categories from Gemini, print it and return it.

    Args:
        system_prompt: System instruction for the model (e.g. SYSTEM_CATEGORIES).
        user_prompt: Prompt sent as-is as the request contents
            (e.g. PROMPTS['harmful_categories']).
        model: Gemini model name; defaults to the model used in this notebook.

    Returns:
        The parsed JSON response shaped by the ``Categories`` schema, i.e.
        ``{'categories': [...]}``. It is also printed, as before.

    Raises:
        ValueError: If the model returned no text (possibly because the
            response was blocked by the safety settings).
    """
    response = client.models.generate_content(
        model=model,
        contents=user_prompt,
        config=types.GenerateContentConfig(
            system_instruction=system_prompt,
            temperature=1.0,
            response_mime_type="application/json",
            response_schema=Categories,
            safety_settings=safety_settings,
        ),
    )
    text = response.text
    # Without this check an empty/blocked response surfaces as an opaque
    # TypeError from json.loads(None).
    if not text:
        raise ValueError(
            'Gemini returned an empty response; it may have been blocked by '
            'the safety settings.'
        )
    response_json = json.loads(text)
    print(response_json)
    return response_json

def generate_subcategories(system_prompt, user_prompt, categories_str,
                           model='gemini-2.0-flash-exp'):
    """Expand a set of categories into subcategories, print and return them.

    Args:
        system_prompt: System instruction for the model.
        user_prompt: Template with a single ``{}`` placeholder
            (e.g. PROMPTS['harmful_subcategories']).
        categories_str: Text substituted into the placeholder, typically the
            categories produced by ``generate_categories``.
        model: Gemini model name; defaults to the model used in this notebook.

    Returns:
        The parsed JSON response shaped by the ``Subcategories`` schema
        (keys ``subcategories_1`` .. ``subcategories_20``). It is also
        printed, as before.

    Raises:
        ValueError: If the model returned no text (possibly because the
            response was blocked by the safety settings).
    """
    response = client.models.generate_content(
        model=model,
        contents=user_prompt.format(categories_str),
        config=types.GenerateContentConfig(
            system_instruction=system_prompt,
            temperature=1.0,
            response_mime_type="application/json",
            response_schema=Subcategories,
            safety_settings=safety_settings,
        ),
    )
    text = response.text
    # Without this check an empty/blocked response surfaces as an opaque
    # TypeError from json.loads(None).
    if not text:
        raise ValueError(
            'Gemini returned an empty response; it may have been blocked by '
            'the safety settings.'
        )
    response_json = json.loads(text)
    print(response_json)
    return response_json

def generate_queries(topics_path, out_path, system_prompt, user_prompt,
                     model='gemini-2.0-flash-exp'):
    """Generate marketplace search queries for every topic listed in a file.

    Topics (one per line in ``topics_path``) are sent to Gemini in batches of
    ten, substituted into ``user_prompt`` as a comma-separated list. Every
    query in the structured ``Queries`` response is appended to ``out_path``,
    one query per line.

    Args:
        topics_path: UTF-8 text file with one topic per line; blank lines are
            ignored.
        out_path: File the queries are appended to. It is opened in append
            mode, so re-running adds to (rather than replaces) its contents.
        system_prompt: System instruction for the model.
        user_prompt: Template with a single ``{}`` placeholder for the topics
            (e.g. PROMPTS['normal_queries']).
        model: Gemini model name; defaults to the model used in this notebook.
    """
    # The Queries schema has exactly ten required fields (queries_1..queries_10),
    # one per topic in a batch, so the batch size must stay 10. A final batch
    # with fewer topics still gets all ten fields back from the model.
    batch_size = 10

    with open(topics_path, 'rt', encoding='utf-8') as f:
        # strip() removes the trailing newline that readlines() used to keep,
        # which previously leaked into the prompt as "topic\n, topic\n, ...".
        topics_all = [line.strip() for line in f if line.strip()]

    with open(out_path, 'a', encoding='utf-8') as of:
        # Cover every topic; the former hard-coded range(0, 990, 10) silently
        # skipped all topics from index 990 on (the last 10 of 1000).
        for i in tqdm(range(0, len(topics_all), batch_size)):
            in_categories = ', '.join(topics_all[i:i + batch_size])

            response = client.models.generate_content(
                model=model,
                contents=user_prompt.format(in_categories),
                config=types.GenerateContentConfig(
                    system_instruction=system_prompt,
                    temperature=1.0,
                    response_mime_type="application/json",
                    response_schema=Queries,
                    safety_settings=safety_settings,
                ),
            )

            # A single empty (possibly safety-blocked) or malformed response
            # used to abort the whole run; report and skip that batch instead.
            text = response.text
            if not text:
                tqdm.write(f'Skipping topics {i}-{i + batch_size - 1}: empty response '
                           '(possibly blocked by safety settings).')
                continue
            try:
                response_json = json.loads(text)
            except json.JSONDecodeError as err:
                tqdm.write(f'Skipping topics {i}-{i + batch_size - 1}: invalid JSON ({err}).')
                continue

            for queries in response_json.values():
                for query in queries:
                    # Keep the one-query-per-line format even if the model
                    # emits a line break inside a query.
                    of.write(query.replace('\r', ' ').replace('\n', ' ') + '\n')

In [28]:
# Generate benign ("normal") shopping queries from the extra-large topic list,
# appending them to the normal query file.
topics_path = '../topics/gemini-2.0-flash-experimental/normal/topics-n-xlarge.txt'
out_path = '../queries/gemini-2.0-flash-experimental/normal/normal.txt'
generate_queries(
    topics_path,
    out_path,
    system_prompt=SYSTEM_NORMAL,
    user_prompt=PROMPTS['normal_queries'],
)