In [None]:
import os
import json
from typing import Dict, Any, Optional, Tuple

import pandas as pd

In [21]:
def load_data(path: str) -> pd.DataFrame:
    if path.endswith(".csv"):
        df = pd.read_csv(path)
    elif path.endswith(".parquet"):
        df = pd.read_parquet(path)
    else:
        raise ValueError("Unsupported file type; use .csv or .parquet")
    
    # Extract gene and mutation
    df[["gene", "position", "aa"]] = df["mutation"].str.split(":", expand=True)
    df["mutation"] = df["position"] + df["aa"]
    
    # Rename certain columns
    df = df.rename(columns={
        "country_name": "country",
        "site_name": "site",
        "denominator": "n_samples",
        "publication_year": "year_pub"
    })

    # Minimal sanity checks / renaming (adjust to your real column names)
    expected_cols = {
        "country_name", "site_name", "year", "gene", "mutation",
        "prevalence", "n_samples", "study_id",
        "authors", "year_pub", "url"
    }
    missing = expected_cols - set(df.columns)
    if missing:
        raise ValueError(f"Missing columns in data: {missing}")

    # Ensure types
    df["year"] = df["year"].astype(int)
    df["year_pub"] = df["year_pub"].astype(int)
    df["prevalence"] = df["prevalence"].astype(float)
    df["n_samples"] = df["n_samples"].astype(int)

    return df


In [22]:
prev_df = load_data("../data/raw/all_who_get_prevalence.csv")

ValueError: Missing columns in data: {'site_name', 'country_name'}

In [23]:
prev_df = pd.read_csv("../data/raw/all_who_get_prevalence.csv")
prev_df.head()
prev_df


Unnamed: 0,study_id,study_name,study_type,authors,publication_year,url,survey_id,country_name,site_name,latitude,...,collection_end,collection_day,time_notes,numerator,denominator,prevalence,prevalence_lower,prevalence_upper,year,mutation
0,pf7k_1192_PF_ML_FAIRHURST_SM,1192-PF-ML-FAIRHURST-SM,peer_reviewed,MalariaGen,2023.0,https://pubmed.ncbi.nlm.nih.gov/36864926/,pf7k_1192_PF_ML_FAIRHURST_SM_MalariaGen_Koulik...,Mali,Koulikoro,13.625339,...,2016-12-31,2016-07-01,Automated midpoint,0,88,0.0,0.0,4.105263,2016,k13:446:I
1,pf7k_1197_PF_ML_DIAKITE_SM,1197-PF-ML-DIAKITE-SM,peer_reviewed,MalariaGen,2023.0,https://pubmed.ncbi.nlm.nih.gov/36864926/,pf7k_1197_PF_ML_DIAKITE_SM_MalariaGen_Koulikor...,Mali,Koulikoro,13.625339,...,2016-12-31,2016-07-01,Automated midpoint,0,68,0.0,0.0,5.280304,2016,k13:446:I
2,pf7k_1197_PF_ML_DIAKITE_SM,1197-PF-ML-DIAKITE-SM,peer_reviewed,MalariaGen,2023.0,https://pubmed.ncbi.nlm.nih.gov/36864926/,pf7k_1197_PF_ML_DIAKITE_SM_MalariaGen_Kayes_2016,Mali,Kayes,13.872913,...,2016-12-31,2016-07-01,Automated midpoint,0,21,0.0,0.0,16.109762,2016,k13:446:I
3,pf7k_1197_PF_ML_DIAKITE_SM,1197-PF-ML-DIAKITE-SM,peer_reviewed,MalariaGen,2023.0,https://pubmed.ncbi.nlm.nih.gov/36864926/,pf7k_1197_PF_ML_DIAKITE_SM_MalariaGen_Kayes_2015,Mali,Kayes,13.872913,...,2015-12-31,2015-07-02,Automated midpoint,0,35,0.0,0.0,10.003244,2015,k13:446:I
4,pf7k_1197_PF_ML_DIAKITE_SM,1197-PF-ML-DIAKITE-SM,peer_reviewed,MalariaGen,2023.0,https://pubmed.ncbi.nlm.nih.gov/36864926/,pf7k_1197_PF_ML_DIAKITE_SM_MalariaGen_Koulikor...,Mali,Koulikoro,13.625339,...,2014-12-31,2014-07-02,Automated midpoint,0,14,0.0,0.0,23.163576,2014,k13:446:I
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
59795,wwarn_wwarn_50748_Koko,wwarn_50748_Koko,peer_reviewed,Koko,1000.0,https://pubmed.ncbi.nlm.nih.gov/35477399,wwarn_wwarn_50748_Koko_WWARN_Montserrado_2018,Liberia,Montserrado,6.447234,...,2018-12-31,2018-07-02,Automated midpoint,0,0,,,,2018,k13:568:G
59796,wwarn_wwarn_50748_Koko,wwarn_50748_Koko,peer_reviewed,Koko,1000.0,https://pubmed.ncbi.nlm.nih.gov/35477399,wwarn_wwarn_50748_Koko_WWARN_Margibi_2018,Liberia,Margibi,6.529594,...,2018-12-31,2018-07-02,Automated midpoint,0,0,,,,2018,k13:568:G
59797,wwarn_wwarn_50748_Koko,wwarn_50748_Koko,peer_reviewed,Koko,1000.0,https://pubmed.ncbi.nlm.nih.gov/35477399,wwarn_wwarn_50748_Koko_WWARN_Nimba_2018,Liberia,Nimba,6.961391,...,2018-12-31,2018-07-02,Automated midpoint,0,0,,,,2018,k13:568:G
59798,wwarn_wwarn_50748_Koko,wwarn_50748_Koko,peer_reviewed,Koko,1000.0,https://pubmed.ncbi.nlm.nih.gov/35477399,wwarn_wwarn_50748_Koko_WWARN_GrandCapeMount_2018,Liberia,Grand Cape Mount,6.816468,...,2018-12-31,2018-07-02,Automated midpoint,0,0,,,,2018,k13:568:G


In [36]:
# =========================
# 2. LLM: NL QUESTION → STRUCTURED QUERY
# =========================

def llm_question_to_query(question: str) -> Dict[str, Any]:
    """
    Ask the LLM to turn a free-text question into a JSON query spec.

    Returns a dict with keys:
      country (str or null)
      mutation (str or null)
      year_min (int or null)
      year_max (int or null)
    You can extend this schema as needed.
    """
    system_msg = {
        "role": "system",
        "content": (
            "You convert user questions about malaria genomic prevalence "
            "into a strict JSON query specification. "
            "Supported fields: country (string or null), mutation (string or null), "
            "year_min (integer or null), year_max (integer or null). "
            "If a field is not specified in the question, set it to null. "
            "Return ONLY valid JSON, no extra text."
        ),
    }
    user_msg = {
        "role": "user",
        "content": question,
    }

    resp = client.chat.completions.create(
        model=LLM_MODEL,
        messages=[system_msg, user_msg],
        temperature=0.0,
    )

    raw = resp.choices[0].message.content.strip()
    try:
        query = json.loads(raw)
    except json.JSONDecodeError:
        raise ValueError(f"LLM did not return valid JSON:\n{raw}")

    # Basic normalization / defaults
    query.setdefault("country", None)
    query.setdefault("mutation", None)
    query.setdefault("year_min", None)
    query.setdefault("year_max", None)

    return query


# =========================
# 3. APPLY QUERY TO YOUR DATA
# =========================

def filter_data(df: pd.DataFrame, query: Dict[str, Any]) -> pd.DataFrame:
    subset = df.copy()

    if query.get("country"):
        subset = subset[subset["country"].str.lower() ==
                        query["country"].lower()]

    if query.get("mutation"):
        # Assuming mutation column has strings like "K13_561H"
        subset = subset[subset["mutation"].str.lower() ==
                        query["mutation"].lower()]

    if query.get("year_min") is not None:
        subset = subset[subset["year"] >= int(query["year_min"])]

    if query.get("year_max") is not None:
        subset = subset[subset["year"] <= int(query["year_max"])]

    return subset


def summarize_prevalence(subset: pd.DataFrame) -> Dict[str, Any]:
    """
    Produce a compact summary of the numeric + geographic patterns
    to feed into the LLM.
    """
    if subset.empty:
        return {"has_data": False}

    # Yearly prevalence summary
    year_grp = (
        subset.groupby("year")
        .apply(
            lambda g: pd.Series(
                {
                    "n_samples": int(g["n_samples"].sum()),
                    "mean_prevalence": float(
                        (g["prevalence"] * g["n_samples"]).sum()
                        / max(g["n_samples"].sum(), 1)
                    ),
                }
            )
        )
        .reset_index()
        .sort_values("year")
    )

    # Regional/site summary (if you have region column, swap that in)
    site_grp = (
        subset.groupby("site")
        .apply(
            lambda g: pd.Series(
                {
                    "n_samples": int(g["n_samples"].sum()),
                    "mean_prevalence": float(
                        (g["prevalence"] * g["n_samples"]).sum()
                        / max(g["n_samples"].sum(), 1)
                    ),
                    "first_year": int(g["year"].min()),
                    "last_year": int(g["year"].max()),
                }
            )
        )
        .reset_index()
        .sort_values("mean_prevalence", ascending=False)
    )

    # Study-level metadata
    study_grp = (
        subset.groupby(["study_id", "authors", "year_pub"])
        .agg(
            n_samples=("n_samples", "sum"),
            year_min=("year", "min"),
            year_max=("year", "max"),
            mean_prev=("prevalence", "mean"),
        )
        .reset_index()
        .sort_values("year_pub")
    )

    return {
        "has_data": True,
        "yearly": year_grp.to_dict(orient="records"),
        "by_site": site_grp.to_dict(orient="records"),
        "by_study": study_grp.to_dict(orient="records"),
    }


# =========================
# 4. LLM: DATA SUMMARY → NARRATIVE ANSWER
# =========================

def llm_generate_answer(
    question: str,
    query: Dict[str, Any],
    summary: Dict[str, Any],
) -> str:
    """
    Ask the LLM to produce a scientific-style narrative based ONLY
    on the provided summary.
    """
    system_msg = {
        "role": "system",
        "content": (
            "You are an assistant helping a malaria genomicist interpret pre-computed "
            "prevalence data and study metadata. You must ONLY use the numeric summaries "
            "and study information provided. Do not invent new data or new studies. "
            "If data are sparse or missing, say so explicitly. "
            "Write clearly in a scientific but concise style."
        ),
    }

    # Provide structured summary as JSON so the model can 'see' the numbers
    context = {
        "original_question": question,
        "parsed_query": query,
        "data_summary": summary,
    }

    user_msg = {
        "role": "user",
        "content": (
            "Here is the structured data for the user's question:\n\n"
            + json.dumps(context, indent=2)
            + "\n\n"
            "1. Summarize the temporal trend (if any).\n"
            "2. Comment on geographic/site differences.\n"
            "3. Cite studies as 'Authors year_pub' where relevant.\n"
            "4. Mention important caveats (e.g. few samples, uneven sites).\n"
        ),
    }

    resp = client.chat.completions.create(
        model=LLM_MODEL,
        messages=[system_msg, user_msg],
        temperature=0.2,
    )

    answer = resp.choices[0].message.content.strip()
    return answer


# =========================
# 5. MAIN ENTRYPOINT
# =========================

def answer_question(df: pd.DataFrame, question: str) -> Tuple[Dict[str, Any], str]:
    """
    High-level pipeline:
      question -> query -> filtered data -> summary -> LLM narrative
    Returns (query_dict, answer_text).
    """
    print(f"\nUser question: {question}\n")

    # 1) natural language -> structured query
    query = llm_question_to_query(question)
    print("Parsed query:", query)

    # 2) filter data
    subset = filter_data(df, query)
    print(f"Subset size: {len(subset)} rows")

    # 3) summarize
    summary = summarize_prevalence(subset)

    # 4) LLM narrative
    answer = llm_generate_answer(question, query, summary)
    return query, answer

In [31]:
from openai import OpenAI
import os

from dotenv import load_dotenv
load_dotenv()

# Choose an LLM model
LLM_MODEL = "gpt-4o-mini"  # or another chat-capable model

client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))

In [30]:
from dotenv import load_dotenv
import os

# Load from project root OR from a specific path
load_dotenv(dotenv_path=os.path.join(os.path.dirname("../api_key.env"), 'api_key.env'))

True

In [37]:
question = (
        "What is the trend of K13 561H in Uganda after 2012, and how consistent "
        "are the findings across study sites and studies?"
    )

query_dict, answer_text = answer_question(prev_df, question)


User question: What is the trend of K13 561H in Uganda after 2012, and how consistent are the findings across study sites and studies?



RateLimitError: Error code: 429 - {'error': {'message': 'You exceeded your current quota, please check your plan and billing details. For more information on this error, read the docs: https://platform.openai.com/docs/guides/error-codes/api-errors.', 'type': 'insufficient_quota', 'param': None, 'code': 'insufficient_quota'}}