In [34]:
import json
import uuid
from openai import OpenAI
from dotenv import dotenv_values

from typing import List, Dict

In [35]:
env = dotenv_values()

API_KEY = env['OPENAI_API_KEY']
MODEL = env['OPENAI_MODEL']

summaries_filename = "01_summaries.json"
queries_filename = "02_queries.json"
judgements_filename = "03_judgements.json"

In [36]:
client = OpenAI(api_key=API_KEY)

In [37]:
prompt_base = "For the user search query provided below, return a subset of documents provided above, which should be retrieved by the query. Format your response as a JSON array of strings, where each string represents the document ID of a document that should be retrieved. If you believe that no documents should be retrieved, return an empty array."

# prompt_base = "For the user search query provided below, return a subset of documents provided above, which should be retrieved by the query. Format your response as a JSON array of objects, where each object consists of the document ID and document title of a document that should be retrieved. If you believe that no documents should be retrieved, return an empty array."


def get_retrieval_prompt(query):
    return f"""
      {prompt_base}

      Query: "{query}"
    """


# Define a helper function to extract the only key-value pair from a dictionary no matter what the key is.

def get_value_from_dict(dict):
    keys = list(dict.keys())

    if len(keys) == 0:
        return None

    key = keys[0]
    return dict[key]


# Checks if a string is a valid UUID to avoid AI hallucinations

def is_valid_uuid(id):
    try:
        uuid.UUID(id)
    except:
        return False

    return True


# Helper function to merge judgements objects

def merge_judgements(judgements, new_judgements):
    for key, value in new_judgements.items():
        uuids = []

        for id in value:
            if is_valid_uuid(id):
                uuids.append(id)

        if key in judgements:
            judgements[key] = list(set([*judgements[key], *uuids]))
        else:
            judgements[key] = value

        judgements[key].sort()

In [38]:
# Load summaries, leaving out those with None/null value

BATCH_SIZE = 20

with open(summaries_filename, "r") as f:
    summaries = json.load(f)

summaries = [s for s in summaries if s["summary"] is not None]
batches = [summaries[i:i + BATCH_SIZE]
           for i in range(0, len(summaries), BATCH_SIZE)]

In [39]:
with open(queries_filename, "r") as f:
    queries = json.load(f)

queries

['marine protected areas Arctic',
 'Russian IT capitalism networks',
 'VLBI radio astronomy techniques',
 'academic book digital age',
 'Venezuelan grassroots oil politics',
 'Pentecostalism witchcraft Africa',
 'anti-vivisection British medicine',
 'Norway white-collar crime',
 'youth crisis Britain',
 'Vietnam food anxiety globalization',
 'Twitter research methodologies challenges',
 'X-ray contrast media evolution',
 'Distributed denial-of-service blockchain',
 'Iskandar Malaysia low carbon',
 'Philosophy of mathematics education',
 'Butterfly mimicry population dynamics',
 'Islamic bioethics medical practices',
 'Nature-based urban climate solutions',
 'Microfinance sustainability social outreach',
 'Sago palm food security',
 'agile methodology scrum framework',
 'personality traits team performance',
 'data science human-centric AI',
 'parallel entrepreneurship Africa',
 'environmental governance Baltic Sea',
 'endoprothesenversorgung Germany',
 'saving investment Latin America'

In [40]:
def get_judgements(batch, queries, _log):
    judgements: Dict[str, List[str]] = {}

    _log("Creating a new assistant...")
    assistant = client.beta.assistants.create(
        name="asst_bachelor_judgements",
        instructions="You are a helpful assistant that generates relevance judgements for user search queries. You will be given a list of documents to choose from.",
        response_format={"type": "json_object"},
        temperature=0.7,
        model=MODEL
    )

    _log("Creating a new thread...")
    thread = client.beta.threads.create()

    _log("Creating batch message...")
    client.beta.threads.messages.create(
        thread_id=thread.id,
        content=json.dumps(
            batch,
            indent=2,
            ensure_ascii=False
        ),
        role="user"
    )

    for query_idx, query in enumerate(queries):
        log_prefix = f"QUERY[{query_idx + 1}]"

        def log(*args):
            _log(log_prefix, *args)

        try:
            log("Creating prompt message...")
            prompt_message = client.beta.threads.messages.create(
                thread_id=thread.id,
                content=get_retrieval_prompt(query),
                role="user"
            )

            log("Starting completion...")
            run = client.beta.threads.runs.create_and_poll(
                thread_id=thread.id,
                assistant_id=assistant.id,
            )

            log(run)

            if run.status != "completed":
                raise Exception("Run failed:", run.status)

            messages_cursor = client.beta.threads.messages.list(
                thread.id, limit=1, order="desc"
            )

            response_message = [message for message in messages_cursor][0]
            response_content = response_message.content[0].text.value
            log("Completion content", response_content)

            response_dict = json.loads(response_content)
            document_ids: List[str] | None = get_value_from_dict(response_dict)

            if document_ids is None:
                document_ids = []

            log("Document ids", document_ids)

            if query in judgements:
                judgements[query].extend(document_ids)
            else:
                judgements[query] = document_ids
        except Exception as e:
            log("! Failed to obtain judgements query.", e)
        finally:
            try:
                log("Deleting prompt message...")
                client.beta.threads.messages.delete(
                    prompt_message.id, thread_id=thread.id)
            except:
                log("! Failed to delete prompt message.")

            try:
                log("Deleting completion message...")
                client.beta.threads.messages.delete(
                    response_message.id, thread_id=thread.id)
            except:
                log("! Failed to delete completion message.")

    try:
        log("Deleting an assistant...")
        client.beta.assistants.delete(assistant.id)
    except:
        log("! Failed to delete assistant.")

    log("Finished.")

    return judgements

In [50]:
judgements = {}

In [52]:
for idx, batch in enumerate(batches):
    batch_number = idx + 1

    new_judgements = get_judgements(
        batch,
        queries,
        lambda *args: print(f"[BATCH {batch_number}]", *args)
    )

    merge_judgements(judgements, new_judgements)

[BATCH 1] Creating a new assistant...
[BATCH 1] Creating a new thread...
[BATCH 1] Creating batch message...
[BATCH 1] QUERY[1] Creating prompt message...
[BATCH 1] QUERY[1] Starting completion...
[BATCH 1] QUERY[1] Run(id='run_TkOziTU4HZHEtepjRnKxx3c0', assistant_id='asst_8RcS8c6esUJSq77QcFqxPVNg', cancelled_at=None, completed_at=1736805039, created_at=1736805037, expires_at=None, failed_at=None, incomplete_details=None, instructions='You are a helpful assistant that generates relevance judgements for user search queries. You will be given a list of documents to choose from.', last_error=None, max_completion_tokens=None, max_prompt_tokens=None, metadata={}, model='gpt-4o-mini', object='thread.run', parallel_tool_calls=True, required_action=None, response_format=ResponseFormatJSONObject(type='json_object'), started_at=1736805038, status='completed', thread_id='thread_4GE02adKsE8zY28yovhXZcqQ', tool_choice='auto', tools=[], truncation_strategy=TruncationStrategy(type='auto', last_messag

In [54]:
# Merge already saved judgements with the new ones

with open(judgements_filename, "r") as f_read:
    stored_judgements = json.load(f_read)

merge_judgements(stored_judgements, judgements)

with open(judgements_filename, "w") as f_read:
    json.dump(stored_judgements, f_read, indent=2, ensure_ascii=False)

In [55]:
# Ensure that all returned relevant documents for each query are in the original document set

document_ids = set([document["id"] for document in summaries])

for key in judgements.keys():
    rel_docs = judgements[key]
    for rel_doc_id in rel_docs:
        if rel_doc_id not in document_ids:
            print("Document not found:", rel_doc_id)

Document not found: 948ef698-08d0-45ba-b7b8-3ace3d271df2
