In [None]:
import numpy as np
import pandas as pd
from openai import OpenAI
from datetime import datetime
from pathlib import Path
from typing import List
import re
from dotenv import load_dotenv
import os
from pydantic import BaseModel
import json
import matplotlib.pyplot as plt

## Helper Functions

In [None]:
def save_df_to_csv(samples_df: pd.DataFrame, name: str = 'samples', directory: str = './data/'):
    timestamp_str = datetime.now().strftime("%Y%m%d%H%M%S")
    samples_df.to_csv(f"{directory}{name}_{timestamp_str}.csv")

In [None]:
def list_files(directory: str) -> List[str]:
    """Return all file names (not directories) in the given directory."""
    path = Path(directory).expanduser()
    return [p.name for p in path.iterdir() if p.is_file()]

In [None]:
def get_last_saved_timestamp(directory: str = './data/'):
    filenames = list_files(directory)
    matches = [ re.search(r'.*?(\d{14})\.csv$', fn) for fn in filenames ]
    timestamps = [ int(m.group(1)) for m in matches if m is not None ]
    return str(max(timestamps))

# get_last_saved_timestamp()

In [None]:
def load_df(timestamp: str, name: str = 'samples', directory: str = './data/') -> pd.DataFrame:
    return pd.read_csv(f"{directory}{name}_{timestamp}.csv", index_col=0)

In [None]:
def load_latest_df(name: str = 'samples', directory: str = './data/') -> pd.DataFrame:
    timestamp = get_last_saved_timestamp(directory=directory)
    return load_df(timestamp=timestamp, name=name, directory=directory)

## Do Stuff

In [None]:
# read in env vars from the .env file
load_dotenv(override=True)

PROMPT_ID_SYNTH_GEN = os.environ['PROMPT_ID_SYNTH_GEN']
PROMPT_ID_SUFFIX_LIST = ['SMALL', 'MEDIUM', 'LARGE']

# connect and get an OpenAI client
client = OpenAI()

In [None]:
class BioSnippet(BaseModel):
    text_snippet: str
    source: str
    is_match: bool

In [None]:
class UserWithLabeledSnippets(BaseModel):
    first_name: str
    last_name: str
    city: str
    state: str
    birth_year: int
    snippets: list[BioSnippet]

In [None]:
class ListOfUsersWithLabeledSnippets(BaseModel):
    users: list[UserWithLabeledSnippets]

In [None]:
def generate_synthetic_samples(prompt_id, n_people=10) -> pd.DataFrame:
    resp = client.responses.parse(
        prompt={"id": prompt_id},
        input=[
            {
                "role": "user",
                "content": f"""
                    Generate {n_people} example fictional people. 
                    - Each person should have 5 to 10 snippets. 
                    - Each snippet should represent a brief bio of the fictional person found via googling. e.g. a whitepages.com listing, an instagram bio, a reddit post, etc.
                    - Each snippet should either actually match the person, and hence have a label of is_match=True, or it should be a close, but not actual, match, and have is_match=False.
                    - About half of the snippets per person should be true matches, and half should be false.
                    - Each snippet should have a source (e.g. whitepages, reddit, instagram)
                """,
            },
        ],
        text_format=ListOfUsersWithLabeledSnippets,
    )   
    users_list = json.loads(resp.output_text)['users']

    samples_df = pd.DataFrame(columns=['First Name', 'Last Name', 'City', 'State', 'Birth Year', 'Text Snippet', 'Source', 'Is Match'])
    for user in users_list:
        for snippet in user['snippets']:
            new_row = {
                "First Name": user['first_name'],
                "Last Name": user['last_name'],
                "City": user['city'],
                "State": user['state'],
                "Birth Year": user['birth_year'],
                "Text Snippet": snippet['text_snippet'],
                "Source": snippet['source'],
                "Is Match": snippet['is_match'],
            }
            samples_df = pd.concat([samples_df, pd.DataFrame([new_row])], ignore_index=True)

    save_df_to_csv(samples_df)

    return samples_df


In [None]:
samples_df = generate_synthetic_samples(os.environ['PROMPT_ID_SYNTH_GEN'], n_people=10)

In [None]:
class MatchScoreResponse(BaseModel):
    match_prob: float

In [None]:
def get_match_prob(row, prompt_id_suffix):
    first_name = row['First Name']
    last_name = row['Last Name']
    city = row['City']
    state = row['State']
    birth_year = row['Birth Year']
    snippet = row['Text Snippet']
    source = row['Source']

    input_message = f"""
        Here is the profile of a person. (The current year is 2025). 
        - First Name: {first_name}
        - Last Name: {last_name}
        - City: {city}
        - State: {state}
        - Birth Year: {birth_year}
        Here is a snippet found online via googling that person: "{snippet}"

        It is from {source}.

        Please output the probability (a number from 0 to 1 rounded to two decimal places) that the snippet is 
        referring to that specific person. 1 means you are absolutely sure that it is a match, and 0 means that you
        are absolutely sure it is not.
    """

    resp = client.responses.parse(
        prompt={"id": os.environ[f"PROMPT_ID_SCORER_{prompt_id_suffix}"]},
        input=[
            {
                "role": "user",
                "content": input_message
            },
        ],
        text_format=MatchScoreResponse,
    )   

    return float(json.loads(resp.output_text)['match_prob'])

In [None]:
def get_all_match_probs(df, prompt_id_suffix):
    match_probs_list = []
    for _, row in df.iterrows():
        match_prob = get_match_prob(row, prompt_id_suffix)
        match_probs_list.append(match_prob)
    
    df[f"Match Prob {prompt_id_suffix}"] = pd.Series(match_probs_list, index=df.index)

In [None]:
def run_all_prompt_scorers(df):
    for prompt_id_suffix in PROMPT_ID_SUFFIX_LIST:
        get_all_match_probs(df, prompt_id_suffix=prompt_id_suffix)

In [None]:
run_all_prompt_scorers(samples_df)

In [None]:
samples_df[['Is Match', 'Match Prob SMALL', 'Match Prob MEDIUM', 'Match Prob LARGE']].head()

In [None]:
cols = ["Match Prob SMALL", "Match Prob MEDIUM", "Match Prob LARGE"]
df_true = samples_df[samples_df["Is Match"] == True]
df_false = samples_df[samples_df["Is Match"] == False]

fig, axes = plt.subplots(nrows=3, ncols=1, figsize=(8, 10), sharex=False)
for ax, col in zip(axes, cols):
    bins = np.linspace(0.0, 1.0, 21)  # 20 bins
    ax.hist(df_true.loc[df_true[col] >= 0.0, col], bins=bins, alpha=0.6, label="Is Match = True", color="tab:blue")
    ax.hist(df_false.loc[df_false[col] >= 0.0, col], bins=bins, alpha=0.6, label="Is Match = False", color="tab:orange")
    ax.set_title(col)
    ax.set_ylabel("Count")
    ax.legend()

axes[-1].set_xlabel("Probability")
fig.tight_layout()
plt.show()

In [None]:
for prompt_id_suffix in PROMPT_ID_SUFFIX_LIST:
    col = f"Match Prob {prompt_id_suffix}"
    probs_true = samples_df.loc[samples_df['Is Match'], col]
    probs_false = samples_df.loc[~(samples_df['Is Match'].astype(bool)), col]
    
    print(f"Using {prompt_id_suffix} Model:")
    print(f"True Matches:\t{probs_true.mean():.2f}")
    print(f"False Matches:\t{probs_false.mean():.2f}\n")


In [None]:
def get_accuracy(df, prompt_id_suffix, thresh: float = 0.5) -> float:
    predictions = df[f"Match Prob {prompt_id_suffix}"] > thresh
    print(f"Total observations:\t{len(df)}")
    print(f"True match count:\t{df['Is Match'].sum()}")
    print(f"Predicted true matches:\t{predictions.sum()}")
    print(f"Correct predcitions:\t{(predictions == df['Is Match']).sum()}")

In [None]:
get_accuracy(samples_df, "SMALL")

In [None]:
get_accuracy(samples_df, "MEDIUM")

In [None]:
get_accuracy(samples_df, "LARGE")

In [114]:
def bce_loss(df, prompt_id_suffix) -> float:
    epsilon = 1e-6
    y = df['Is Match'].astype(float)
    probs = np.clip(df[f"Match Prob {prompt_id_suffix}"], epsilon, 1-epsilon)
    loss = (y * np.log(probs) + (1 - y) * np.log(1 - probs)).sum()
    # print(loss)
    return loss

In [127]:
def calc_and_print_bce_loss(df, prompt_id_suffix):
    loss = bce_loss(df, prompt_id_suffix)
    print(f"{prompt_id_suffix} BCE Loss: \t{loss:.2f}")

In [128]:
_ = [ calc_and_print_bce_loss(samples_df, prompt_id_suffix) for prompt_id_suffix in PROMPT_ID_SUFFIX_LIST ]

SMALL BCE Loss: 	-78.68
MEDIUM BCE Loss: 	-19.59
LARGE BCE Loss: 	-39.70
