In [None]:
import pandas as pd
import numpy as np
from richtqdm import RichTqdm
from tenacity import retry, stop_after_attempt, wait_fixed
from openai import OpenAI
from google import genai
from google.genai import types
import time
from functools import partial
from rich import print
import os

In [None]:
def compile_prompt(
    template: str,
    **kwargs,
):
    """
    Fill the named placeholders of a prompt template.

    The template uses ``str.format`` syntax (e.g. ``"{report}"``); every
    keyword argument is made available as a replacement field. Keyword
    arguments that the template does not reference are ignored.

    Args:
        template (str): Prompt text containing ``{name}`` placeholders.
        **kwargs: Values substituted for the placeholders, keyed by name.

    Returns:
        str: The template with all placeholders substituted.

    Raises:
        KeyError: If the template references a name missing from ``kwargs``.
    """
    filled_prompt = template.format(**kwargs)
    return filled_prompt

In [None]:
def decode_response(
    respone: str,
    enclosing: tuple = ("<", ">"),
):
    """
    Extract the text between the first pair of enclosing markers.

    The text after the first occurrence of the opening marker and before the
    next closing marker is returned, stripped of surrounding whitespace. If
    the opening marker is absent (or the markers cannot be used as
    separators), the whole response is returned stripped instead.

    Args:
        respone (str): The response string to decode. ``None`` (a model
            reply without text) is accepted and decodes to ``""``.
        enclosing (tuple): ``(start, end)`` markers that enclose the payload.

    Returns:
        str: The decoded response; ``""`` when ``respone`` is ``None``.
    """
    # A reply without text would otherwise fail on ``.strip()`` below, outside
    # the fallback, and abort the caller's loop.
    if respone is None:
        return ""

    start, end = enclosing
    try:
        result = respone.split(start)[1].split(end)[0]
    except (IndexError, ValueError, TypeError):
        # IndexError: opening marker not present.
        # ValueError / TypeError: marker is empty or not usable as a separator.
        result = respone
    return result.strip()

In [None]:
@retry(stop=stop_after_attempt(3), wait=wait_fixed(1))
def openai_get_response(
    client: OpenAI,
    model: str,
    prompt: str,
):
    """
    Send one user prompt to an OpenAI chat model and return the reply text.

    The call is wrapped in tenacity's ``retry``: it is attempted up to 3 times
    with a fixed wait of 1 between attempts.

    Args:
        client (OpenAI): Client used to issue the chat-completion request.
        model (str): Model identifier passed to the API.
        prompt (str): Text sent as the single ``user`` message.

    Returns:
        The ``message.content`` of the first choice in the completion.
    """
    user_turn = {"role": "user", "content": prompt}
    completion = client.chat.completions.create(
        model=model,
        messages=[user_turn],
        timeout=15,
    )
    first_choice = completion.choices[0]
    return first_choice.message.content

In [None]:
@retry(stop=stop_after_attempt(3), wait=wait_fixed(2))
def google_get_response(
    client: genai.Client,
    model: str,
    prompt: str,
    rpm: int = 20,
    max_tokens: int = 512,
):
    """
    Send one prompt to a Gemini model and return the reply text.

    The call is wrapped in tenacity's ``retry``: it is attempted up to 3 times
    with a fixed wait of 2 between attempts. Each attempt first sleeps
    ``60 / rpm`` seconds as a simple client-side throttle.

    Args:
        client (genai.Client): Client used to issue the generation request.
        model (str): Model identifier passed to the API.
        prompt (str): Text sent as the only content item.
        rpm (int): Requests-per-minute budget used to derive the sleep time.
        max_tokens (int): Forwarded as ``max_output_tokens``.

    Returns:
        The ``text`` attribute of the generation result.
    """
    # The pause lives inside the retried function, so it applies to every attempt.
    seconds_between_calls = 60 / rpm
    time.sleep(seconds_between_calls)

    # NOTE(review): no request timeout is configured; an `http_options` timeout
    # was previously left commented out here - confirm whether one is wanted.
    generation_config = types.GenerateContentConfig(
        max_output_tokens=max_tokens,
    )
    result = client.models.generate_content(
        model=model,
        contents=[prompt],
        config=generation_config,
    )
    return result.text

In [None]:
# Prompt template sent to the model for every report.
# `{report}` is the only placeholder; it is filled via
# `compile_prompt(template, report=...)`. The model is asked to wrap its answer
# in <> so that `decode_response` can extract it with its default ("<", ">")
# markers.
template = """
You are a highly experienced radiologist specializing in breast imaging.
Your task is to paraphrase the provided mammogram report text, strictly adhering to these instructions:
Preserve all medical facts and clinical findings from the original text.
Do NOT add or omit any medical details.
Maintain the original diagnostic meaning, clarity, and specificity.
Only alter the sentence structure, phrasing, or synonyms where medically equivalent and clearly appropriate.
Ensure the paraphrase is phrased naturally and professionally, as expected in clinical mammography reports.
Preserve key medical terminology (e.g., BI-RADS categories, anatomical terms, pathology findings), but you may substitute medically approved synonyms if and only if they fully maintain the original meaning.

The mammogram report in <> is as follows:

<{report}>

Your paraphrased report should be enclosed in <>
"""

In [None]:
# Maps the `provider` key accepted by `augment_reports` to the model identifier
# that is passed as `model=` to the corresponding client call.
# Keys containing "openai" are routed to the OpenAI client, keys containing
# "google" to the Gemini client.
models = {
    "openai-mini": "gpt-4.1-mini-2025-04-14",
    "openai": "gpt-4.1-2025-04-14",
    "google": "models/gemini-2.0-flash",
}

In [None]:
def augment_reports(
    ref_df: pd.DataFrame,
    report_col: str,
    aug_report_col: str,
    number_to_augment: int,
    provider: str,
    required_cols: tuple = ("l_cc", "r_cc", "l_mlo", "r_mlo"),
):
    """
    Paraphrase a random sample of not-yet-augmented reports with an LLM.

    A row is eligible when its report is non-null and non-blank, its augmented
    report is still missing, and every column in ``required_cols`` is non-null.
    Up to ``number_to_augment`` eligible rows are sampled (fixed seed 42) and
    each report is sent to the selected provider. Rows whose request fails or
    yields no text are left untouched so a later call can pick them up again.

    Args:
        ref_df (pd.DataFrame): Source frame; it is copied, never modified.
        report_col (str): Column holding the original report text.
        aug_report_col (str): Column receiving the paraphrase (created if absent).
        number_to_augment (int): Maximum number of rows to sample.
        provider (str): Key into ``models``; must contain "openai" or "google".
        required_cols (tuple): Columns that must be non-null for a row to be
            eligible. Defaults to the four mammogram view columns.

    Returns:
        tuple[pd.DataFrame, int]: The updated copy and the number of rows that
        were sampled for augmentation (0 when nothing was eligible). The count
        includes rows whose request failed.

    Raises:
        ValueError: If ``provider`` contains neither "openai" nor "google".
    """
    # ------ Select the provider-specific request function ------
    if "openai" in provider:
        client = OpenAI()
        get_response = partial(
            openai_get_response, client=client, model=models[provider]
        )
    elif "google" in provider:
        client = genai.Client(api_key=os.environ["GEMINI_API_KEY"])
        get_response = partial(
            google_get_response, client=client, model=models[provider]
        )
    else:
        raise ValueError(f"Unsupported provider: {provider}")
    print(f"Using [green]{provider}[/green] model: [bold blue]{models[provider]}")

    df = ref_df.copy(deep=True)
    # if the aug col is non-existent, create it
    if aug_report_col not in df.columns:
        df[aug_report_col] = None
    elif pd.api.types.is_numeric_dtype(df[aug_report_col]):
        # An all-NaN column loaded from CSV is float64; writing strings into it
        # relies on implicit upcasting, which recent pandas versions deprecate.
        df[aug_report_col] = df[aug_report_col].astype(object)

    # ------ Find the rows that can still be augmented ------
    conditions = (
        df[report_col].notna()
        & df[aug_report_col].isna()
        & df[report_col].str.strip().ne("")
    )
    for col in required_cols:
        conditions &= df[col].notna()
    candidates = df.loc[conditions]

    # ------ Check if there are any reports to augment ------
    n = min(number_to_augment, candidates.shape[0])
    if n == 0:
        print("[red]No reports to augment[/red]")
        return df, 0
    elif n < number_to_augment:
        print(f"[yellow]Only {n} reports available for augmentation[/yellow]")

    # ------ Sample n reports to augment ------
    print(f"[blue]Augmenting {n} reports[/blue]")
    indices = candidates.sample(n=n, random_state=42, replace=False).index

    # ------ Augment the reports ------
    for index in (
        pbar := RichTqdm(
            indices,
            desc="Augmenting reports",
            unit="report",
            total=len(indices),
            transient=True,
        )
    ):
        pbar.set_description(f"Augmenting report {index}")

        report = df.at[index, report_col]
        prompt = compile_prompt(template, report=report)
        try:
            response = get_response(prompt=prompt)
        except Exception as e:
            # Best effort: log the failure and move on to the next report.
            print(f"[red]Error[/red]: {str(e)[:50]}")
            continue

        # The API can return no text at all; never store that as a paraphrase,
        # otherwise the row would look augmented and not be retried.
        augmented_report = decode_response(response) if response is not None else ""
        if not augmented_report:
            print(f"[yellow]Empty response for report {index}; skipping[/yellow]")
            continue

        df.at[index, aug_report_col] = augmented_report
    # ------ Return the augmented reports ------
    print("[green]Augmentation complete![/green]")
    return df, n

In [None]:
def recursive_fill(
    ref_df: pd.DataFrame,
    output_file: str,
    report_col: str,
    aug_report_col: str,
    total_number_to_augment: int,
    batch_size: int,
    provider: str,
):
    """
    Augment reports batch by batch, saving a CSV checkpoint after each batch.

    ``augment_reports`` is called repeatedly with at most ``batch_size`` rows
    until ``total_number_to_augment`` rows have been requested or no eligible
    rows remain. ``total_number_to_augment`` does not have to be a multiple of
    ``batch_size``; the final batch is simply smaller.

    Args:
        ref_df (pd.DataFrame): Source frame; it is copied, never modified.
        output_file (str): CSV path that is overwritten after every batch.
        report_col (str): Column holding the original report text.
        aug_report_col (str): Column receiving the paraphrased text.
        total_number_to_augment (int): Total number of rows to request.
        batch_size (int): Maximum number of rows per batch; must be positive.
        provider (str): Provider key forwarded to ``augment_reports``.

    Returns:
        None. Results are written to ``output_file``.

    Raises:
        ValueError: If ``batch_size`` is not positive.
    """
    # Explicit check instead of `assert`, which is removed under `python -O`.
    if batch_size <= 0:
        raise ValueError("batch_size must be a positive integer")

    df = ref_df.copy(deep=True)
    for i in range(0, total_number_to_augment, batch_size):
        # The last batch may be smaller when the total is not a multiple.
        current_batch = min(batch_size, total_number_to_augment - i)
        print(f"[blue]Augmenting reports {i} to {i + current_batch}[/blue]")
        df, n = augment_reports(
            ref_df=df,
            report_col=report_col,
            aug_report_col=aug_report_col,
            number_to_augment=current_batch,
            provider=provider,
        )
        if n == 0:
            print("[red]No more reports to augment[/red]")
            break
        # save the augmented reports
        df.to_csv(output_file, index=False)
        print(f"[green]Saved augmented reports to {output_file}[/green]")

In [None]:
# Run the augmentation: load the input CSV, request paraphrases for up to 1120
# reports in batches of 112 using the "google" provider, and write the frame to
# the output CSV after each batch.
recursive_fill(
    ref_df=pd.read_csv("./data/mammo-aug-oai-08-ggl-04.csv"),
    output_file="./data/mammo-aug-oai-08-ggl-05.csv",
    report_col="report",
    aug_report_col="aug_report",
    total_number_to_augment=1120,
    batch_size=112,
    provider="google",
)