# Setup

## Load API Keys

In [1]:
from dotenv import load_dotenv

# Load variables from a local .env file into the process environment
# (per the section title, this is where the API keys are expected to live).
load_dotenv()

True

## Tracing

In [2]:
from langfuse.openai import AsyncOpenAI  # auto-instrumentation: client imported from langfuse, not openai

## Setup OpenAI

In [3]:
# Shared async client used by every helper below. No arguments are passed,
# so credentials are presumably picked up from the environment - confirm.
client = AsyncOpenAI()

In [4]:
# Model name passed to the embedding helper in the sanity check further down.
EMBED_MODEL = "text-embedding-3-large"

In [5]:
# NOTE(review): despite the constant's name, the value is an "o4-mini"
# snapshot id, not a "gpt-4o-mini" one - confirm which model is intended.
GPT4O_MINI = "o4-mini-2025-04-16"

## LLM Call Helpers

In [6]:
def _msg(role, content):
    """Build one chat message: a dict with ``role`` and ``content`` keys."""
    return dict(role=role, content=content)


def system(content):
    """Wrap *content* as a message with role ``system``."""
    return _msg(role='system', content=content)


def user(content):
    """Wrap *content* as a message with role ``user``."""
    return _msg(role='user', content=content)


def assistant(content):
    """Wrap *content* as a message with role ``assistant``."""
    return _msg(role='assistant', content=content)

## Embedding Call Helpers

In [7]:
def get_embedding(e) -> list[float]:
    """Return the ``embedding`` attribute of the first item in ``e.data``."""
    first_item = e.data[0]
    return first_item.embedding

## Compute Cosine Similarity

In [8]:
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np
from openai.types.create_embedding_response import CreateEmbeddingResponse

In [9]:
def embedding_cosine_sim(e1: CreateEmbeddingResponse, e2: CreateEmbeddingResponse) -> float:
    """Cosine similarity between the first embedding of two responses.

    Each vector is pulled out with ``get_embedding`` and reshaped into a
    single-row 2-D array before being handed to ``cosine_similarity``;
    the only entry of the resulting matrix is returned.
    """
    vec_a, vec_b = get_embedding(e1), get_embedding(e2)
    row_a = np.array(vec_a).reshape(1, -1)
    row_b = np.array(vec_b).reshape(1, -1)
    similarity_matrix = cosine_similarity(row_a, row_b)
    return similarity_matrix[0][0]

## Cache System 

In [10]:
from diskcache import Cache

In [11]:
# Disk-backed cache shared by the helpers below; data is stored under the
# relative directory ".cache_course".
cache = Cache(directory=".cache_course")

In [12]:
import asyncio

In [13]:
async def set_async(key, val, **kwargs):
    """Store *val* under *key* in the module-level ``cache`` from async code.

    ``cache.set`` is executed through ``asyncio.to_thread``; extra keyword
    arguments are forwarded to it unchanged and its result is returned.
    """
    outcome = await asyncio.to_thread(cache.set, key, val, **kwargs)
    return outcome

async def get_async(key, default=None, **kwargs):
    """Look up *key* in the module-level ``cache`` from async code.

    ``cache.get`` is executed through ``asyncio.to_thread`` with *default*
    passed positionally and any extra keyword arguments forwarded as-is.
    """
    lookup = asyncio.to_thread(cache.get, key, default, **kwargs)
    return await lookup

In [14]:
import json
from hashlib import md5

def make_cache_key(key_name, **kwargs):
    """Build a cache key of the form ``<key_name>__<md5 hex digest>``.

    The digest is taken over the JSON serialisation of *kwargs* with keys
    sorted, so the order in which keyword arguments are passed does not
    affect the key. All values must be JSON-serialisable.
    """
    canonical_bytes = json.dumps(kwargs, sort_keys=True).encode('utf-8')
    digest = md5(canonical_bytes).hexdigest()
    return "{}__{}".format(key_name, digest)

## [EMBEDDING] Cached and Retried Calls

In [15]:
from pydantic import BaseModel

def _make_key_for_cached_embedding_with_retry(
    *,
    model,
    input,
    **kwargs,
):
    """Derive the cache key for an embedding request.

    ``model``, ``input`` and every extra keyword argument are folded into
    the key via ``make_cache_key``.
    """
    # NOTE(review): the prefix is "openai_parsed_chat", the same literal the
    # chat-completion key builder uses; it looks copy-pasted. It is kept
    # as-is because changing it would orphan already cached embeddings.
    request_params = {"model": model, "input": input, **kwargs}
    return make_cache_key("openai_parsed_chat", **request_params)

In [16]:
from openai.types.create_embedding_response import CreateEmbeddingResponse
from functools import wraps
from openai import APITimeoutError, RateLimitError
from pydantic import BaseModel
import backoff


# Unique marker passed as the cache-lookup default so that a miss can be
# told apart from any value actually stored in the cache.
CACHE_MISS_SENTINEL = object()


@wraps(client.embeddings.create)
async def cached_embedding_with_retry(
    *,
    model,
    input,
    **kwargs,
) -> CreateEmbeddingResponse:
    """Return an embedding response, served from the disk cache when possible.

    On a hit the stored JSON string is parsed and validated back into a
    ``CreateEmbeddingResponse``. On a miss ``client.embeddings.create`` is
    called with the same arguments, retried with exponential backoff on
    ``APITimeoutError`` / ``RateLimitError``, and the response is cached as
    JSON before being returned.
    """
    key = _make_key_for_cached_embedding_with_retry(
        model=model,
        input=input,
        **kwargs
    )
    stored = await get_async(key, default=CACHE_MISS_SENTINEL)

    # CACHE HIT: rebuild the response object from the cached JSON.
    if stored is not CACHE_MISS_SENTINEL:
        return CreateEmbeddingResponse.model_validate(json.loads(stored))

    # CACHE MISS: call the API.
    # NOTE(review): no max_tries / max_time is given to backoff, so the
    # retry loop is presumably unbounded - confirm this is intended.
    @backoff.on_exception(backoff.expo, (APITimeoutError, RateLimitError))
    async def _request():
        return await client.embeddings.create(model=model, input=input, **kwargs)

    response = await _request()
    await set_async(key, response.model_dump_json())
    return response
        
        

## [LLM] Cached, Retried, and Traced Calls

In [17]:
from pydantic import BaseModel

def _make_key_for_cached_chat_completion_parsed_with_retry(
    *,
    model,
    messages,
    response_format: type[BaseModel],
    **kwargs,
):
    """Derive the cache key for a parsed chat-completion request.

    Args:
        model: Model name sent with the request.
        messages: Chat messages; must be JSON-serialisable.
        response_format: The Pydantic model *class* (not an instance) used
            for structured output. Its JSON schema is part of the key, so
            changing the schema produces a different key.
        **kwargs: Any other request parameters; all are folded into the key.

    Returns:
        The key string produced by ``make_cache_key``.
    """
    return make_cache_key(
        "openai_parsed_chat",
        model=model,
        messages=messages,
        # model_json_schema() is called on the class itself.
        response_format=response_format.model_json_schema(),
        **kwargs
    )

In [18]:
from openai.types.chat import ParsedChatCompletion
from functools import wraps
from openai import APITimeoutError, RateLimitError
from pydantic import BaseModel
from typing_extensions import TypeVar
import backoff

# Type variable for the Pydantic model class passed as ``response_format``.
ResponseFormatT = TypeVar("ResponseFormatT", bound=BaseModel)

# Unique marker passed as the cache-lookup default so that a miss can be
# told apart from any value actually stored in the cache.
CACHE_MISS_SENTINEL = object()


@wraps(client.chat.completions.parse)
async def cached_chat_completion_parsed_with_retry(
    *,
    model,
    messages,
    response_format: type[ResponseFormatT],
    **kwargs,
) -> ParsedChatCompletion[ResponseFormatT]:
    """Return a parsed chat completion, served from the disk cache when possible.

    Args:
        model: Model name sent with the request.
        messages: Chat messages for the request.
        response_format: Pydantic model class describing the structured
            output; also used to rebuild ``message.parsed`` on a cache hit.
        **kwargs: Forwarded unchanged to ``client.chat.completions.parse``
            and included in the cache key.

    Returns:
        The completion. On a miss it is the object returned by the client
        (retried with exponential backoff on ``APITimeoutError`` /
        ``RateLimitError``) and is cached as JSON. On a hit it is rebuilt
        from the cached JSON.
    """
    # CREATE CACHE KEY
    cache_key = _make_key_for_cached_chat_completion_parsed_with_retry(
        model=model,
        messages=messages,
        response_format=response_format,
        **kwargs
    )

    cached_value = await get_async(cache_key, default=CACHE_MISS_SENTINEL)
    # CACHE MISS
    if cached_value is CACHE_MISS_SENTINEL:
        @backoff.on_exception(
            backoff.expo,
            (APITimeoutError, RateLimitError)
        )
        async def do_call():
            return await client.chat.completions.parse(
                model=model,
                messages=messages,
                response_format=response_format,
                **kwargs
            )
        completion = await do_call()
        await set_async(cache_key, completion.model_dump_json())
        return completion
    # CACHE HIT
    else:
        # TODO: Tracing Code (next section)
        # return
        completion = ParsedChatCompletion.model_validate(json.loads(cached_value))
        for choice in completion.choices:
            message = choice.message
            # After the JSON round-trip ``parsed`` is plain data, so it is
            # re-validated into ``response_format``. Refusals are skipped,
            # and so is a missing payload: validating ``None`` against the
            # model would raise instead of returning the cached completion.
            if message.refusal or message.parsed is None:
                continue
            message.parsed = response_format.model_validate(message.parsed)
        return completion
        
        

## Sanity Checks

In [None]:
# Sanity check: embed one string through the cached helper, then compare the
# response with itself. Identical vectors should give a cosine similarity
# of (approximately) 1.0.
embedding = await cached_embedding_with_retry(
    input="input: 'Union[str, List[str], Iterable[int], Iterable[Iterable[int]]]'",
    model=EMBED_MODEL
)
embedding_cosine_sim(embedding, embedding)

In [None]:
# Sanity check: request a structured (parsed) completion through the cached
# helper and display the parsed object of the first choice.
from pydantic import BaseModel

# Example structured-output schema used only by this check.
class CalendarEvent(BaseModel):
    name: str
    date: str
    participants: list[str]

completion = await cached_chat_completion_parsed_with_retry(
    model=GPT4O_MINI,
    messages=[
        {"role": "system", "content": "Extract the event information."},
        {"role": "user", "content": "Alice and Bob are going to a science fair on Friday."},
    ],
    response_format=CalendarEvent,
)

# Expected to be a CalendarEvent instance when the model did not refuse.
event = completion.choices[0].message.parsed
event

---

In [None]:
import json

# Load the deduplicated questions. The loop that builds the prompts unpacks
# each entry as an (email_id, question) pair.
# The encoding is given explicitly so that reading does not depend on the
# platform's locale default.
with open('deduplicated_questions.json', 'r', encoding='utf-8') as f:
    questions = json.load(f)

In [None]:
import pandas as pd

# Sent-email table, indexed by "Message-ID" so rows can be looked up by id.
emails = pd.read_csv('paul_allen_sent_email_with_questions_v1.csv').set_index("Message-ID")

In [None]:
from pydantic import BaseModel

# Structured-output schema for the question-rewrite step.
# (Documented with comments rather than a docstring so the generated JSON
# schema, which feeds the cache key, is left untouched.)
class RewriteQuestionsSchema(BaseModel):
    # Short justification for the decision.
    concise_reasoning: str
    # True when the question should be replaced by the rewritten text.
    should_rewrite: bool
    # Replacement text; a single string despite the plural field name.
    rewritten_questions: str

In [None]:
from textwrap import dedent
from jinja2 import Template

In [None]:
# System prompt for the question-rewrite step.
# The keys advertised here must match the field names of the model passed
# as ``response_format`` (``RewriteQuestionsSchema``), hence the plural
# ``rewritten_questions``.
system_prompt = dedent(
    """\
    Your task is to assess if a question could stand alone in a retrieval system that searches through thousands of emails, then to rewrite it in a way that it is if it isn't.

    You'll answer in JSON by respecting the following schema:
    ```ts
    {
        concise_reasoning: str
        should_rewrite: true | false
        rewritten_questions: str
    }
    ```
    """
)

In [None]:
# User-prompt template for the question-rewrite step.
# Rendered with three variables: ``correspondants``, ``email`` and
# ``question``. The text is sent to the model verbatim, so its wording is
# left untouched here.
prompt_template = Template(dedent(
    """\
    Based on the following email sent to the following correspondant

    <correspondants>
    {{ correspondants }}
    </correspondants>

    <email>
    {{ email }}
    </email>

    Can the following question stand alone in a retrieval system that searches through thousands of emails?

    For example:
    - If it mentions "in the email", it won't stand alone as this is not specific enough
    - If it's a direct question but it doesn't include information about the email the question is from, then we likely won't be able to retrieve the associated email

    <question>
    {{ question }}
    </question>
    
    If the question doesn't stand alone, rewrite it so that it is, for example you can add "In the email about ...," in front of the question.

    First concisely reason step by step if the question is standalone or need rewriting, then decide if it should be rewritten.
    If it should be, then rewrite it.
    """
))

In [None]:
import traceback
import asyncio

semaphore = asyncio.Semaphore(50)  # at most 50 concurrent requests

# Model used for the rewrite calls below.
LLM_MODEL = GPT4O_MINI
#LLM_MODEL = "gpt-5"

async def rewrite_question(prompt):
    """Ask the model whether a question needs rewriting, at bounded concurrency.

    The call runs while holding the module-level ``semaphore``.

    Args:
        prompt: Fully rendered user prompt.

    Returns:
        The parsed completion on success. On any exception, a tuple
        ``(prompt, traceback_text, exception)`` is returned instead of
        raising, so that one failure does not abort a whole batch; callers
        detect failures with ``isinstance(result, tuple)``.
    """
    async with semaphore:
        try:
            return await cached_chat_completion_parsed_with_retry(
                model=LLM_MODEL,
                messages=[
                    system(system_prompt),
                    user(prompt)
                ],
                max_completion_tokens=5000,
                temperature=1.,
                response_format=RewriteQuestionsSchema
            )
        except Exception as e:
            # format_exc must be *called*: the bare attribute is the function
            # object, not the formatted traceback string.
            return (prompt, traceback.format_exc(), e)

async def limited_call(prompt):
    """Concurrency-limited entry point; delegates to ``rewrite_question``.

    ``rewrite_question`` already runs under the module-level ``semaphore``.
    asyncio semaphores are not re-entrant, so acquiring it here as well
    would make every call hold two permits (halving the effective limit)
    and can deadlock once all permits are held by callers that are still
    waiting for their second one. The limit is therefore enforced in one
    place only.
    """
    return await rewrite_question(prompt)

# One pending rewrite coroutine per (email_id, question) pair, in the same
# order as ``questions``.
tasks = []

for email_id, question in questions:
    row = emails.loc[email_id]
    email = row["content"]
    correspondants = row["Correspondants"]

    prompt = prompt_template.render(
        email=email,
        correspondants=correspondants,
        question=question,
    )
    tasks.append(rewrite_question(prompt))


In [None]:
# Number of queued rewrite coroutines (one per question).
len(tasks)

In [None]:
from tqdm.asyncio import tqdm_asyncio

# Await every queued coroutine with a progress bar. The later index-based
# pairing of results with questions relies on results coming back in task
# order - presumably guaranteed by gather; confirm.
results = await tqdm_asyncio.gather(*tasks)

In [None]:
# rewrite_question returns a tuple instead of a completion when the call
# raised, so tuples mark the failed requests.
failed = [r for r in results if isinstance(r, tuple)]
len(failed)

In [None]:
from copy import deepcopy

# Work on a copy so the originally loaded questions stay untouched.
new_questions = deepcopy(questions)

# results are positionally aligned with questions (one task per question).
for entry, result in zip(new_questions, results):
    # A failed call is reported as a tuple rather than a completion; it has
    # no ``choices``, so the original question is kept.
    if isinstance(result, tuple):
        continue
    parsed = result.choices[0].message.parsed
    # ``parsed`` can be missing (e.g. the model refused); keep the original.
    if parsed is not None and parsed.should_rewrite:
        entry[1] = parsed.rewritten_questions

In [None]:
# Persist the (possibly rewritten) questions as JSON.
with open('rewritten_questions.json', 'w') as out_file:
    json.dump(new_questions, out_file)

In [None]:
# Read the file back to verify the round trip.
with open('rewritten_questions.json', 'r') as in_file:
    data = json.load(in_file)

In [None]:
# True when the data read back from disk equals what was written.
data == new_questions