In [12]:
import pandas as pd
import numpy as np
from richtqdm import RichTqdm
from tenacity import retry, stop_after_attempt, wait_fixed
from openai import OpenAI
from google import genai
from google.genai import types
import time
from functools import partial
import os
import json

In [13]:
def compile_prompt(
    template: str,
    **kwargs,
):
    """
    Fill the placeholders of a prompt template and return the finished prompt.

    The template is rendered with ``str.format``, so each ``{name}`` field in
    it is replaced by the keyword argument of the same name.

    Args:
        template (str): Prompt text containing ``{name}``-style placeholders.
        **kwargs: Values to substitute into the placeholders, keyed by name.

    Returns:
        str: The template with its placeholders replaced.
    """
    compiled_prompt = template.format(**kwargs)
    return compiled_prompt

In [14]:
def extract_json_from_response(response: str):
    """
    Extract the JSON object embedded in a response string.

    The candidate JSON is the span from the first "{" to the last "}" of the
    response, so text before or after the object is ignored.

    Args:
        response (str): The response string.
    Returns:
        dict: The extracted JSON object.
    Raises:
        ValueError: If the response is not a string (for example ``None``) or
            does not contain valid JSON.
    """
    # A missing or non-text response has no ``index``/``rindex``; report it
    # through the documented ValueError instead of an AttributeError/TypeError.
    if not isinstance(response, str):
        raise ValueError("Response does not contain a valid JSON.")
    try:
        start = response.index("{")
        end = response.rindex("}") + 1
        return json.loads(response[start:end])
    except ValueError as exc:
        # Covers both a missing brace (str.index / str.rindex) and malformed
        # JSON: json.JSONDecodeError is a subclass of ValueError.
        raise ValueError("Response does not contain a valid JSON.") from exc

In [15]:
@retry(stop=stop_after_attempt(3), wait=wait_fixed(1))
def openai_get_response(
    client: OpenAI,
    model: str,
    prompt: str,
    max_tokens: int = 64,
):
    """
    Send a single user message to an OpenAI chat model and return the reply.

    The retry decorator stops after three attempts and waits a fixed 1 between
    attempts (presumably seconds; confirm against tenacity).

    Args:
        client (OpenAI): Client used to issue the chat completion request.
        model (str): Model identifier passed to the API.
        prompt (str): Text sent as the only message, with role "user".
        max_tokens (int): Value passed as ``max_tokens`` to the request.

    Returns:
        The ``content`` of the message in the first returned choice.
    """
    conversation = [
        {"role": "user", "content": prompt},
    ]
    completion = client.chat.completions.create(
        model=model,
        messages=conversation,
        timeout=15,
        max_tokens=max_tokens,
    )
    first_choice = completion.choices[0]
    return first_choice.message.content

In [16]:
@retry(stop=stop_after_attempt(3), wait=wait_fixed(2))
def google_get_response(
    client: genai.Client,
    model: str,
    prompt: str,
    rpm: int = 20,
    max_tokens: int = 64,
):
    """
    Send a prompt to a Google GenAI model and return the reply text.

    Every attempt first sleeps ``60 / rpm`` seconds as a simple rate limit.
    The retry decorator stops after three attempts and waits a fixed 2 between
    attempts (presumably seconds; confirm against tenacity).

    Args:
        client (genai.Client): Client used to issue the generation request.
        model (str): Model identifier passed to the API.
        prompt (str): Text sent as the only element of ``contents``.
        rpm (int): Requests-per-minute budget used to compute the sleep.
        max_tokens (int): Value passed as ``max_output_tokens``.

    Returns:
        The ``text`` attribute of the generation result.
    """
    # Throttle before the request; this also runs again on each retry because
    # the sleep sits inside the retried function.
    pause_seconds = 60 / rpm
    time.sleep(pause_seconds)

    # NOTE: no request timeout is configured here; an earlier
    # http_options=types.HttpOptions(timeout=15) setting is disabled.
    generation_config = types.GenerateContentConfig(
        max_output_tokens=max_tokens,
    )
    result = client.models.generate_content(
        model=model,
        contents=[prompt],
        config=generation_config,
    )
    return result.text

In [17]:
# Prompt sent to the model for each report. extract_info renders it through
# compile_prompt (str.format) with `report=...`, so `{report}` is the only
# placeholder; any literal braces added to this text must be doubled ({{ }}).
template = """
You are a clinical natural language processing (NLP) assistant specialized in radiology report understanding.

Extract the following fields from the mammogram report:

- "left_birads": BI-RADS score for the left breast (0–6 or null), ignore characters like 'a' or 'b' or 'c'.
- "right_birads": BI-RADS score for the right breast (0–6 or null), ignore characters like 'a' or 'b' or 'c'.

Return a JSON with these keys. Only extract what is explicitly stated. If uncertain or not mentioned, use null. Do not infer.

Report:

<report>
{report}
</report>

Make sure to return a JSON response. Do not include any other text or explanation.

"""

In [18]:
# Maps each provider alias accepted by extract_info to the model identifier
# passed to that provider's API. extract_info chooses the client by checking
# whether the alias contains "openai" or "google", then looks the alias up here.
models = {
    "openai-mini": "gpt-4.1-mini-2025-04-14",
    "openai": "gpt-4.1-2025-04-14",
    "google": "models/gemini-2.0-flash",
}

In [19]:
def extract_info(
    ref_df: pd.DataFrame,
    report_col: str,
    number_to_extract: int,
    provider: str,
):
    """
    Extract left/right BI-RADS scores from a random sample of unprocessed reports.

    A row is eligible when its report is present and not blank, both
    ``left_birads`` and ``right_birads`` are still missing, and the ``l_cc``,
    ``r_cc``, ``l_mlo`` and ``r_mlo`` columns are all non-null. Up to
    ``number_to_extract`` eligible rows are sampled with a fixed seed; each
    report is sent to the selected provider and the ``left_birads`` /
    ``right_birads`` values found in the JSON reply are written into a copy of
    ``ref_df``.

    Args:
        ref_df (pd.DataFrame): Source table; it is copied, not modified.
        report_col (str): Name of the column holding the report text.
        number_to_extract (int): Maximum number of reports to process.
        provider (str): Key of ``models``; must contain "openai" or "google".

    Returns:
        tuple: ``(df, n)`` where ``df`` is the updated copy and ``n`` is the
        number of reports sampled (0 when no row was eligible). Reports whose
        request or parsing failed are counted in ``n`` but left unfilled.

    Raises:
        ValueError: If ``provider`` contains neither "openai" nor "google".
        KeyError: If ``provider`` is not a key of ``models``, or if the
            ``GEMINI_API_KEY`` environment variable is missing for a Google
            provider.
    """
    # ------ Bind the provider-specific request function ------
    if "openai" in provider:
        client = OpenAI()
        get_response = partial(
            openai_get_response, client=client, model=models[provider]
        )
    elif "google" in provider:
        client = genai.Client(api_key=os.environ["GEMINI_API_KEY"])
        get_response = partial(
            google_get_response, client=client, model=models[provider]
        )
    else:
        raise ValueError(f"Unsupported provider: {provider}")
    print(f"Using {provider} model: {models[provider]}")

    df = ref_df.copy(deep=True)
    # if the extracted cols are non-existent, create them
    if "left_birads" not in df.columns:
        df["left_birads"] = None
    if "right_birads" not in df.columns:
        df["right_birads"] = None

    # ------ Select the rows that still need extraction ------
    # NOTE: a row stays eligible while BOTH scores are missing, so a report
    # whose reply holds null for both sides (or whose request failed) can be
    # sampled again by a later call.
    conditions = (
        (df[report_col].notna())
        & (df["left_birads"].isna())
        & (df["right_birads"].isna())
        & (df[report_col].str.strip().ne(""))
        & (df["l_cc"].notna())
        & (df["r_cc"].notna())
        & (df["l_mlo"].notna())
        & (df["r_mlo"].notna())
    )
    eligible = df.loc[conditions]

    # ------ Check if there are any patients to extract from ------
    n = min(number_to_extract, eligible.shape[0])
    if n == 0:
        print("No patients to extract")
        return df, 0
    if n < number_to_extract:
        # f-string so the actual count is shown instead of the literal "{n}".
        print(f"Only {n} patients available for extraction")

    # ------ Sample n patients to extract ------
    print(f"Extracting {n} patients")
    # Fixed seed keeps the sample reproducible for a given set of eligible rows.
    indices = eligible.sample(n=n, random_state=42, replace=False).index

    # ------ Extract the reports ------
    for index in (
        pbar := RichTqdm(
            indices,
            desc="Extracting from reports",
            total=len(indices),
            transient=True,
        )
    ):
        pbar.set_description(f"Extracting from report {index}")

        report = df.at[index, report_col]
        prompt = compile_prompt(template, report=report)
        try:
            response = get_response(prompt=prompt)
            response = dict(extract_json_from_response(response))
            # Try to fill in the extracted fields
            df.at[index, "left_birads"] = response.get("left_birads")
            df.at[index, "right_birads"] = response.get("right_birads")
        except Exception as e:
            # Best effort: a failed report is logged and skipped so that one
            # bad request does not abort the whole batch.
            print(f"Error: {str(e)[:50]}")
            continue

    # ------ Return the augmented reports ------
    print("Extraction complete!")
    return df, n

In [20]:
def recursive_fill(
    ref_df: pd.DataFrame,
    output_file: str,
    report_col: str,
    total_number_to_extract: int,
    batch_size: int,
    provider: str,
):
    """
    Run extract_info batch by batch, saving the table to CSV after each batch.

    Each batch starts from the table produced by the previous one. The loop
    ends early as soon as a batch reports that no report was left to extract;
    that final, empty batch is not written to disk.

    Args:
        ref_df (pd.DataFrame): Starting table; it is copied, not modified.
        output_file (str): CSV path written (without the index) after each batch.
        report_col (str): Name of the column holding the report text.
        total_number_to_extract (int): Overall number of reports to process;
            must be a multiple of ``batch_size``.
        batch_size (int): Number of reports requested per batch.
        provider (str): Provider alias forwarded to extract_info.

    Returns:
        None

    Raises:
        AssertionError: If ``total_number_to_extract`` is not divisible by
            ``batch_size``.
    """
    assert total_number_to_extract % batch_size == 0, (
        f"{total_number_to_extract} must be divisible by {batch_size}"
    )

    working_df = ref_df.copy(deep=True)
    batch_starts = range(0, total_number_to_extract, batch_size)
    for start in batch_starts:
        stop = start + batch_size
        print(f"Extracting from reports {start} to {stop}")

        working_df, extracted_count = extract_info(
            ref_df=working_df,
            report_col=report_col,
            number_to_extract=batch_size,
            provider=provider,
        )
        if extracted_count == 0:
            print("No more reports to extract")
            break

        # Checkpoint after every batch so completed batches survive a later
        # failure or interruption.
        working_df.to_csv(output_file, index=False)
        print(f"Saved extracted info to {output_file}")

In [21]:
# Process up to 9700 reports in batches of 485 using the "openai-mini" entry
# of `models`. The input and output paths are the same file, so the CSV is
# overwritten in place after every completed batch.
recursive_fill(
    ref_df=pd.read_csv("./data/extracted_info.csv"),
    output_file="./data/extracted_info.csv",
    report_col="report",
    total_number_to_extract=9700,
    batch_size=485,
    provider="openai-mini",
)

Output()

Extracting from reports 0 to 485
Using openai-mini model: gpt-4.1-mini-2025-04-14
Extracting 485 patients


Output()

Extraction complete!
Saved extracted info to ./data/extracted_info.csv
Extracting from reports 485 to 970
Using openai-mini model: gpt-4.1-mini-2025-04-14
Extracting 485 patients


Output()

Extraction complete!
Saved extracted info to ./data/extracted_info.csv
Extracting from reports 970 to 1455
Using openai-mini model: gpt-4.1-mini-2025-04-14
Extracting 485 patients


Output()

Extraction complete!
Saved extracted info to ./data/extracted_info.csv
Extracting from reports 1455 to 1940
Using openai-mini model: gpt-4.1-mini-2025-04-14
Extracting 485 patients


Output()

Extraction complete!
Saved extracted info to ./data/extracted_info.csv
Extracting from reports 1940 to 2425
Using openai-mini model: gpt-4.1-mini-2025-04-14
Extracting 485 patients


KeyboardInterrupt: 