In [1]:
!pip install google-cloud-aiplatform google-auth pandas openpyxl




You should consider upgrading via the 'C:\Users\ankit\AppData\Local\Programs\Python\Python310\python.exe -m pip install --upgrade pip' command.


In [None]:
# Population extraction

In [4]:
# =========================
# IMPORTS
# =========================
import os
import json
import pandas as pd

from google import genai
from google.genai import types
from google.oauth2 import service_account
import vertexai

from pydantic import BaseModel, Field, ValidationError, ConfigDict
from typing import List, Optional, Literal


# =========================
# AUTH CONFIG
# =========================
# Service-account key file, resolved relative to the current working directory.
# NOTE(review): the key file name and project ID are hard-coded; consider reading
# them from environment variables and keeping the key file out of version control.
SERVICE_ACCOUNT_FILE = "vigilant-armor-455313-m8-54c19548a094.json"
PROJECT_ID = "vigilant-armor-455313-m8"
LOCATION = "us-central1"

# Credentials limited to the cloud-platform OAuth scope.
credentials = service_account.Credentials.from_service_account_file(
    SERVICE_ACCOUNT_FILE
).with_scopes(["https://www.googleapis.com/auth/cloud-platform"])

# NOTE(review): this cell only talks to the model through `client` below, so this
# vertexai.init() call looks redundant here — confirm before removing.
vertexai.init(project=PROJECT_ID, location=LOCATION, credentials=credentials)

# google-genai client routed through Vertex AI (vertexai=True) using the same credentials.
client = genai.Client(
    vertexai=True,
    project=PROJECT_ID,
    location=LOCATION,
    credentials=credentials
)

MODEL_NAME = "gemini-2.5-pro"


# =========================
# PYDANTIC SCHEMA (UPDATED FOR NEW MULTI-TRIAL + STRING-ONLY OUTPUT)
# =========================
# Allowed values of ArmRecord.arm_type.
ArmType = Literal[
    "Experimental",
    "Comparator",
    "Single-arm",
    "External control",
    "Dose level",
    "Cohort",
    "Other"
]

# Allowed values of PopulationRecord.population_type.
PopulationType = Literal[
    "Overall",
    "Analysis set",
    "Cohort",
    "Baseline characteristic",
    "Subgroup",
    "Other"
]

# Allowed values of IntegratedRecord.integrated_type.
IntegratedType = Literal[
    "Integrated population",
    "Pooled analysis",
    "Other"
]


# Nested object stored in TrialRecord.design_summary: a single optional string `type`.
# extra="forbid" makes validation reject any key other than the declared field.
class DesignSummary(BaseModel):
    model_config = ConfigDict(extra="forbid")
    type: Optional[str] = None


# Nested object stored in TrialRecord.trial_population_details: a single optional
# string `type`. Unknown keys are rejected (extra="forbid").
class TrialPopulationDetails(BaseModel):
    model_config = ConfigDict(extra="forbid")
    type: Optional[str] = None


# One entry of MultiTrialExtractionOutput.trial_records.
# Required: trial_key, design_summary, trial_population_details; everything else
# defaults to None / empty list. Unknown keys are rejected (extra="forbid").
class TrialRecord(BaseModel):
    model_config = ConfigDict(extra="forbid")
    trial_key: str  # e.g., "TRIAL_1"
    trial_id_list: List[str] = Field(default_factory=list)
    trial_label: Optional[str] = None
    phase: Optional[str] = None
    study_name: Optional[str] = None
    allocation: Optional[str] = None
    design_summary: DesignSummary  # flattened to design_summary.* columns in the Excel export
    trial_population_details: TrialPopulationDetails  # flattened to trial_population_details.* columns
    overall_N: Optional[str] = None  # STRING (e.g., "312") or null


# One entry of MultiTrialExtractionOutput.arm_records.
# Only arm_key is required; arm_type, when given, must be one of the ArmType literals.
class ArmRecord(BaseModel):
    model_config = ConfigDict(extra="forbid")
    arm_key: str  # e.g., "ARM_1"
    arm_name: Optional[str] = None
    arm_type: Optional[ArmType] = None
    treatment_description: Optional[str] = None
    dose_schedule: Optional[str] = None


# One entry of MultiTrialExtractionOutput.population_records.
# population_key and population_type are required; population_type must be one of
# the PopulationType literals.
class PopulationRecord(BaseModel):
    model_config = ConfigDict(extra="forbid")
    population_key: str  # e.g., "POP_1"
    population_type: PopulationType
    # NOTE(review): parent/child presumably hold keys or labels of related populations —
    # confirm against the prompt in Pooled_Population.txt.
    parent: Optional[str] = None
    child: Optional[str] = None
    population_description: Optional[str] = None
    N: Optional[str] = None  # STRING (e.g., "45") or null


# One entry of MultiTrialExtractionOutput.trial_arm_links: a trial_key plus a list
# of arm keys (empty list by default).
class TrialArmLink(BaseModel):
    model_config = ConfigDict(extra="forbid")
    trial_key: str
    linked_arm_keys: List[str] = Field(default_factory=list)


# One entry of MultiTrialExtractionOutput.trial_population_links: a trial_key plus
# lists of population keys and arm keys (both empty by default).
class TrialPopulationLink(BaseModel):
    model_config = ConfigDict(extra="forbid")
    trial_key: str
    linked_population_keys: List[str] = Field(default_factory=list)
    linked_arm_keys: List[str] = Field(default_factory=list)


# One entry of MultiTrialExtractionOutput.integrated_records.
# integrated_key and integrated_type are required; integrated_type must be one of
# the IntegratedType literals. All list fields default to empty lists.
class IntegratedRecord(BaseModel):
    model_config = ConfigDict(extra="forbid")
    integrated_key: str  # e.g., "INTEGRATED_1"
    integrated_type: IntegratedType
    source_trial_keys: List[str] = Field(default_factory=list)
    population_description: Optional[str] = None
    N: Optional[str] = None  # STRING or null
    linked_population_keys: List[str] = Field(default_factory=list)
    linked_arm_keys: List[str] = Field(default_factory=list)


# Top-level response model. Its model_json_schema() is passed to the model call as
# "response_json_schema", and the reply is validated with model_validate_json().
# Documented with `#` comments on purpose: pydantic copies a class docstring into the
# generated JSON schema as "description", which would change what is sent to the model.
# Every list defaults to empty, so an empty JSON object validates.
class MultiTrialExtractionOutput(BaseModel):
    model_config = ConfigDict(extra="forbid")
    trial_records: List[TrialRecord] = Field(default_factory=list)
    arm_records: List[ArmRecord] = Field(default_factory=list)
    population_records: List[PopulationRecord] = Field(default_factory=list)
    trial_arm_links: List[TrialArmLink] = Field(default_factory=list)
    trial_population_links: List[TrialPopulationLink] = Field(default_factory=list)
    integrated_records: List[IntegratedRecord] = Field(default_factory=list)


# =========================
# LOAD PROMPT FROM TXT
# =========================
PROMPT_FILE = "Pooled_Population.txt"  # put your updated markdown prompt here
with open(PROMPT_FILE, "r", encoding="utf-8") as f:
    prompt_text = f.read()


# =========================
# PATHS
# =========================
input_folder = "images_input"
json_output = "json_output"
excel_output = "excel_output"

# Output folders are created on demand; the input folder must already exist.
os.makedirs(json_output, exist_ok=True)
os.makedirs(excel_output, exist_ok=True)

# os.listdir() returns entries in arbitrary order; sorting makes the batch order
# (and therefore logs and partial results after a failure) reproducible.
image_files = sorted(
    f for f in os.listdir(input_folder)
    if f.lower().endswith((".jpg", ".jpeg", ".png", ".webp"))
)


def _safe_json_text(model_text: str, image_file: str) -> str:
    txt = (model_text or "").strip()
    if not (txt.startswith("{") and txt.endswith("}")):
        raise RuntimeError(f"Model output not valid JSON / truncated for {image_file}:\n{txt[:500]}")
    return txt


# =========================
# PROCESS IMAGES
# =========================
# For each image: send prompt + image to the model, validate the JSON reply against
# MultiTrialExtractionOutput, then write <image_id>.json and <image_id>.xlsx.
# Any failure raises and stops the batch at that image.
for image_file in image_files:
    image_id = os.path.splitext(image_file)[0]
    image_path = os.path.join(input_folder, image_file)

    # Decide MIME based on extension (.jpg/.jpeg fall through to the JPEG default)
    ext = os.path.splitext(image_file)[1].lower()
    mime = "image/jpeg"
    if ext == ".png":
        mime = "image/png"
    elif ext == ".webp":
        mime = "image/webp"

    with open(image_path, "rb") as f:
        image_bytes = f.read()

    image_part = types.Part.from_bytes(data=image_bytes, mime_type=mime)

    # Structured output: the model is constrained to the pydantic-derived JSON schema.
    response = client.models.generate_content(
        model=MODEL_NAME,
        contents=[prompt_text, image_part],
        config={
            "response_mime_type": "application/json",
            "response_json_schema": MultiTrialExtractionOutput.model_json_schema(),
            "temperature": 0.1
        }
    )

    raw_json = _safe_json_text(response.text, image_file)

    try:
        parsed = MultiTrialExtractionOutput.model_validate_json(raw_json)
    except ValidationError as e:
        # Chain the original exception so the pydantic traceback is not lost.
        raise RuntimeError(f"Schema validation failed for {image_file}\n{e}") from e

    # =========================
    # SAVE JSON
    # =========================
    # The raw (already validated) model text is stored verbatim.
    with open(os.path.join(json_output, f"{image_id}.json"), "w", encoding="utf-8") as f:
        f.write(raw_json)

    # =========================
    # SAVE EXCEL (FLATTEN NESTED OBJECTS)
    # =========================
    out_xlsx = os.path.join(excel_output, f"{image_id}.xlsx")
    with pd.ExcelWriter(out_xlsx, engine="openpyxl") as writer:

        # trial_records (flatten design_summary.* and trial_population_details.*)
        trial_df = pd.json_normalize([t.model_dump() for t in parsed.trial_records])
        trial_df.to_excel(writer, sheet_name="trial_records", index=False)

        # arm_records
        arm_df = pd.DataFrame([a.model_dump() for a in parsed.arm_records])
        arm_df.to_excel(writer, sheet_name="arm_records", index=False)

        # population_records
        pop_df = pd.DataFrame([p.model_dump() for p in parsed.population_records])
        pop_df.to_excel(writer, sheet_name="population_records", index=False)

        # trial_arm_links
        tal_df = pd.DataFrame([x.model_dump() for x in parsed.trial_arm_links])
        tal_df.to_excel(writer, sheet_name="trial_arm_links", index=False)

        # trial_population_links
        tpl_df = pd.DataFrame([x.model_dump() for x in parsed.trial_population_links])
        tpl_df.to_excel(writer, sheet_name="trial_population_links", index=False)

        # integrated_records
        integ_df = pd.DataFrame([x.model_dump() for x in parsed.integrated_records])
        integ_df.to_excel(writer, sheet_name="integrated_records", index=False)

    print(f"✅ Processed {image_file} -> {out_xlsx}")

print("🎯 Processing complete.")


✅ Processed 320682.jpg -> excel_output\320682.xlsx
🎯 Processing complete.


In [None]:
#km extraction

In [6]:
###############################################################
# UPDATED FOR NEW SCHEMA CHANGE:
# - Added "trial_label" to each arm_level_survival_outcomes row
# - Excel export includes trial_label (resolved from trial_metadata if missing)
#
# OUTPUT STRICTLY LIMITED TO:
# {
#   "trial_metadata": {...},
#   "arm_level_survival_outcomes": [...]
# }
###############################################################

import json
import re
from pathlib import Path
from typing import Dict, Any, Optional, List, Literal

import vertexai
from vertexai.preview.generative_models import GenerativeModel, Part
from google.oauth2 import service_account

import pandas as pd
from pydantic import BaseModel, Field, field_validator, ConfigDict


# =============================================================
# 0) PLACEHOLDERS (EDIT THESE)
# =============================================================
# NOTE(review): key file name and project ID are hard-coded; consider environment variables.
SERVICE_ACCOUNT_FILE = "vigilant-armor-455313-m8-54c19548a094.json"
PROJECT_ID = "vigilant-armor-455313-m8"
LOCATION = "us-central1"

MODEL_NAME = "gemini-2.5-pro"

# Inputs: prompt template, poster image, and the JSON written by the population
# extraction cell for the same image (json_output/<image_id>.json).
PROMPT_TXT_FILE = "KM.txt"
POSTER_IMAGE_PATH = "images_input/320682.jpg"
INPUT2_SCHEMA_JSON_PATH = "json_output/320682.json"

# Outputs are written relative to the current working directory.
OUTPUT_JSON_PATH = "survival_output.json"
OUTPUT_EXCEL_PATH = "survival_output.xlsx"


# =============================================================
# 1) PYDANTIC OUTPUT SCHEMA (ONLY THE 2 KEYS YOU WANT)
#    NEW: trial_label added to outcomes row
# =============================================================
# NOTE(review): the population-extraction cell's PopulationType also allows
# "Baseline characteristic". A row carrying that value would fail validation here,
# because this cell does not normalize population_type — confirm this is intended.
PopulationType = Literal["Overall", "Analysis set", "Cohort", "Subgroup", "Other"]
TimeUnit = Literal["months", "years", "weeks", "days"]

# Trial-level header of the survival output (SurvivalOutput.trial_metadata).
# All three fields are optional strings; unknown keys are rejected (extra="forbid").
class TrialMetadata(BaseModel):
    model_config = ConfigDict(extra="forbid")

    trial_id: Optional[str] = None
    phase: Optional[str] = None
    study_name: Optional[str] = None
    # NOTE: trial_label is NOT in trial_metadata in your current output schema.
    # We'll keep it out to match your "trial_metadata" schema exactly.


class ArmLevelSurvivalOutcome(BaseModel):
    # One row of SurvivalOutput.arm_level_survival_outcomes.
    # Required: survival_outcome_id and population_type; every other field is optional.
    # Unknown keys are rejected (extra="forbid").
    model_config = ConfigDict(extra="forbid")

    survival_outcome_id: int
    trial_id: Optional[str] = None
    trial_label: Optional[str] = None   # added with the trial_label schema change
    arm_description: Optional[str] = None

    population_type: PopulationType
    population_description: Optional[str] = None

    endpoint_description: Optional[str] = None
    endpoint_name: Optional[str] = None
    endpoint_label: Optional[str] = None

    assessment_type: Optional[str] = None
    review_board: Optional[str] = None
    review_criteria: Optional[str] = None
    other_details: Optional[str] = None

    arm_n: Optional[int] = None
    median_survival: Optional[str] = None
    survival_rate: Optional[str] = None
    events_n: Optional[int] = None
    assessment_denominator_n: Optional[int] = None

    p_value: Optional[float] = None
    time_unit: Optional[TimeUnit] = None

    # ---------------- Validators ----------------
    @field_validator("survival_outcome_id")
    @classmethod
    def survival_id_positive(cls, v: int) -> int:
        # IDs are 1-based: zero and negatives are rejected.
        if v > 0:
            return v
        raise ValueError("survival_outcome_id must be >= 1")

    @field_validator("arm_n", "events_n", "assessment_denominator_n")
    @classmethod
    def non_negative_ints(cls, v: Optional[int]) -> Optional[int]:
        # None means "not reported" and is passed through unchanged.
        if v is not None and v < 0:
            raise ValueError("Count fields must be >= 0")
        return v

    @field_validator("p_value")
    @classmethod
    def p_value_range(cls, v: Optional[float]) -> Optional[float]:
        # None is allowed; any reported value must lie in [0, 1].
        if v is not None and (v < 0 or v > 1):
            raise ValueError("p_value must be between 0 and 1")
        return v


# Top-level survival result: exactly the two keys "trial_metadata" and
# "arm_level_survival_outcomes". Both have defaults, so an empty dict validates.
class SurvivalOutput(BaseModel):
    model_config = ConfigDict(extra="forbid")

    trial_metadata: TrialMetadata = Field(default_factory=TrialMetadata)
    arm_level_survival_outcomes: List[ArmLevelSurvivalOutcome] = Field(default_factory=list)


# =============================================================
# 2) HELPERS
# =============================================================
def load_text(path: str) -> str:
    """Return the full contents of the file at ``path``, decoded as UTF-8."""
    with Path(path).open("r", encoding="utf-8") as handle:
        return handle.read()

def load_json(path: str) -> Dict[str, Any]:
    """Read the UTF-8 file at ``path`` and return the parsed JSON document."""
    with Path(path).open("r", encoding="utf-8") as handle:
        return json.load(handle)

def load_image_part(image_path: str) -> Part:
    """Load an image file and wrap it as a ``Part`` with a matching MIME type.

    The file is read before the extension is checked, so a missing file raises
    before an unsupported extension does.

    Raises:
        ValueError: If the (case-insensitive) suffix is not .jpg, .jpeg, .png or .webp.
    """
    source = Path(image_path)
    image_bytes = source.read_bytes()

    mime_by_suffix = {
        ".jpg": "image/jpeg",
        ".jpeg": "image/jpeg",
        ".png": "image/png",
        ".webp": "image/webp",
    }
    suffix = source.suffix.lower()
    mime = mime_by_suffix.get(suffix)
    if mime is None:
        raise ValueError(f"Unsupported image type: {suffix}. Use jpg/png/webp.")
    return Part.from_data(data=image_bytes, mime_type=mime)

def extract_first_json_object(text: str) -> Dict[str, Any]:
    """
    Extract the first JSON object from model text output.

    Parsing starts at the first "{" and stops at the end of that object, so
    markdown fences, leading prose and trailing text (even text containing
    braces) are ignored. The previous greedy "{.*}" search spanned to the last
    "}" in the text and failed whenever trailing text contained a brace.

    Args:
        text: Raw model output; None is treated as empty.

    Returns:
        The decoded JSON object as a dict.

    Raises:
        ValueError: If the text is empty or contains no "{".
        json.JSONDecodeError: (a ValueError subclass) if the text from the first
            "{" onward is not a valid JSON object, e.g. truncated output.
    """
    text = (text or "").strip()
    if not text:
        raise ValueError("Empty model output text.")

    start = text.find("{")
    if start == -1:
        raise ValueError("No JSON object found in model output.")

    # raw_decode parses exactly one JSON value beginning at `start` and ignores
    # whatever follows it. Only the first "{" is tried on purpose: falling back to
    # later braces could silently return a nested fragment of a truncated object.
    obj, _end = json.JSONDecoder().raw_decode(text, start)
    return obj

def build_prompt_from_txt(prompt_txt: str, input2_schema: Dict[str, Any]) -> str:
    """
    Combine the prompt template with the INPUT 2 JSON.

    The JSON is rendered with 2-space indentation and non-ASCII characters kept
    as-is. If the template contains the {INPUT_2_SCHEMA_JSON} placeholder, every
    occurrence is replaced; otherwise the JSON is appended under a short heading.
    """
    placeholder = "{INPUT_2_SCHEMA_JSON}"
    rendered = json.dumps(input2_schema, ensure_ascii=False, indent=2)
    if placeholder not in prompt_txt:
        return "".join([prompt_txt, "\n\nINPUT 2 JSON schema reference:\n", rendered])
    return prompt_txt.replace(placeholder, rendered)


# ---------------- Strict filtering (TOP + ROW) ----------------
# Keys kept in "trial_metadata" by clean_to_survival_only(); mirrors TrialMetadata's fields.
ALLOWED_TM_KEYS = {"trial_id", "phase", "study_name"}

# Keys kept in each outcome row by clean_to_survival_only(); mirrors the fields of
# ArmLevelSurvivalOutcome. Keep the two in sync: a key missing here is dropped before validation.
ALLOWED_OUTCOME_KEYS = {
    "survival_outcome_id",
    "trial_id",
    "trial_label",           # NEW
    "arm_description",
    "population_type",
    "population_description",
    "endpoint_description",
    "endpoint_name",
    "endpoint_label",
    "assessment_type",
    "review_board",
    "review_criteria",
    "other_details",
    "arm_n",
    "median_survival",
    "survival_rate",
    "events_n",
    "assessment_denominator_n",
    "p_value",
    "time_unit",
}

def _ensure_list(x: Any) -> List[Any]:
    if x is None:
        return []
    if isinstance(x, list):
        return x
    if isinstance(x, dict):
        return [x]
    return []

def _prune_dict(d: Any, allowed: set) -> Dict[str, Any]:
    if not isinstance(d, dict):
        return {}
    return {k: d.get(k) for k in allowed if k in d}

def clean_to_survival_only(parsed: Dict[str, Any]) -> Dict[str, Any]:
    """
    Reduce a parsed model reply to the strict survival shape:
      {"trial_metadata": {...}, "arm_level_survival_outcomes": [...]}

    Other top-level keys are dropped, and keys not listed in ALLOWED_TM_KEYS /
    ALLOWED_OUTCOME_KEYS are removed from the nested dicts so that the
    extra="forbid" models do not reject them. Non-dict rows are skipped.
    """
    metadata_source = parsed.get("trial_metadata", {}) or {}
    outcome_rows = _ensure_list(parsed.get("arm_level_survival_outcomes", []))

    return {
        "trial_metadata": _prune_dict(metadata_source, ALLOWED_TM_KEYS),
        "arm_level_survival_outcomes": [
            _prune_dict(row, ALLOWED_OUTCOME_KEYS)
            for row in outcome_rows
            if isinstance(row, dict)
        ],
    }


# =============================================================
# 3) AUTH + MODEL INIT
# =============================================================
credentials = service_account.Credentials.from_service_account_file(SERVICE_ACCOUNT_FILE)
vertexai.init(project=PROJECT_ID, location=LOCATION, credentials=credentials)
model = GenerativeModel(MODEL_NAME)


# =============================================================
# 4) LOAD INPUTS (2 inputs)
# =============================================================
prompt_txt = load_text(PROMPT_TXT_FILE)
input2_schema = load_json(INPUT2_SCHEMA_JSON_PATH)
final_prompt = build_prompt_from_txt(prompt_txt, input2_schema)

image_part = load_image_part(POSTER_IMAGE_PATH)


# =============================================================
# 5) CALL MODEL
# =============================================================
response = model.generate_content(
    contents=[final_prompt, image_part],
    generation_config={"temperature": 0.0},
)

raw_text = response.text or ""


# =============================================================
# 6) PARSE + FILTER + VALIDATE + SAVE
# =============================================================
parsed = extract_first_json_object(raw_text)

# Keep ONLY the 2 requested keys (and prune extra keys inside them)
filtered = clean_to_survival_only(parsed)

# trial_label is not injected here: trial_metadata has no trial_label, so a row
# without one simply validates to None.
# No by_name/by_alias argument: the models define no aliases, and those keyword
# arguments only exist in newer pydantic releases.
validated = SurvivalOutput.model_validate(filtered)

Path(OUTPUT_JSON_PATH).write_text(
    validated.model_dump_json(indent=2),
    encoding="utf-8"
)

print("✅ Survival extraction complete and validated.")
print(f"Saved JSON: {OUTPUT_JSON_PATH}")


# =============================================================
# 7) EXPORT TO EXCEL (Flatten outcomes with trial_metadata columns)
# =============================================================
data = json.loads(Path(OUTPUT_JSON_PATH).read_text(encoding="utf-8"))
trial_metadata = data.get("trial_metadata", {}) or {}
arm_outcomes = data.get("arm_level_survival_outcomes", []) or []

rows = []
for row in arm_outcomes:
    if not isinstance(row, dict):
        continue
    # Row values take precedence over trial-level values, but a null row value must
    # not erase a trial-level one. Every row carries a "trial_id" key (possibly
    # null), so a plain {**trial_metadata, **row} always discarded the metadata trial_id.
    merged = dict(trial_metadata)
    merged.update(
        {key: value for key, value in row.items() if value is not None or key not in merged}
    )
    rows.append(merged)

df = pd.DataFrame(rows)
df.to_excel(OUTPUT_EXCEL_PATH, index=False)

print(f"✅ JSON successfully converted to Excel: {OUTPUT_EXCEL_PATH}")


✅ Survival extraction complete and validated.
Saved JSON: survival_output.json
✅ JSON successfully converted to Excel: survival_output.xlsx


In [None]:
# Baseline pooled

In [7]:
###############################################################
# UPDATED FOR NEW BASELINE CHARACTERISTICS SCHEMA (bc_types ONLY)
#
# Output STRICTLY LIMITED TO:
# {
#   "bc_types": [...]
# }
#
# - Adds arm_key + population_key in output rows (per your new schema)
# - Supports generated keys (arm_gen_1, pop_gen_1) in LLM output
# - Validates with Pydantic (extra="forbid")
# - Exports bc_types to Excel (one row per bc_types record)
###############################################################

import json
import re
from pathlib import Path
from typing import Dict, Any, Optional, List, Literal

import vertexai
from vertexai.preview.generative_models import GenerativeModel, Part
from google.oauth2 import service_account

import pandas as pd
from pydantic import BaseModel, Field, field_validator, ConfigDict


# =============================================================
# 0) PLACEHOLDERS (EDIT THESE)
# =============================================================
# NOTE(review): this key file name differs from the one used by the population and KM
# cells, and the recorded output of this cell is a FileNotFoundError for it — confirm
# which key file is current.
SERVICE_ACCOUNT_FILE = "vigilant-armor-455313-m8-1d642ef84a8c.json"
PROJECT_ID = "vigilant-armor-455313-m8"
LOCATION = "us-central1"

MODEL_NAME = "gemini-2.5-pro"

# Inputs: prompt template, poster image, and the population-extraction JSON for that image.
PROMPT_TXT_FILE = "Baseline.txt"  # <-- your modified baseline prompt text file
POSTER_IMAGE_PATH = "images_input/NCT04171700_453611_4GT.jpg"
INPUT2_SCHEMA_JSON_PATH = "json_output/NCT04171700_453611_4GT.json"

# Outputs are written relative to the current working directory.
OUTPUT_JSON_PATH = "baseline_output.json"
OUTPUT_EXCEL_PATH = "baseline_output.xlsx"


# =============================================================
# 1) PYDANTIC OUTPUT SCHEMA (bc_types ONLY)
# =============================================================
# Allowed values of BaselineCharacteristic.population_type (required field).
PopulationType = Literal["Overall", "Analysis set", "Cohort", "Subgroup", "Other"]
# Allowed values of BaselineCharacteristic.baseline_parent; None is a valid member.
BaselineParent = Literal["Overall", "Cohort", "Subgroup", "Other", None]


class BaselineCharacteristic(BaseModel):
    # One row of BaselineOutput.bc_types.
    # Required: baseline_id and population_type; every other field is optional.
    # Unknown keys are rejected (extra="forbid").
    model_config = ConfigDict(extra="forbid")

    baseline_id: int = Field(..., description="Sequential ID starting from 1")

    trial_id: Optional[str] = None
    trial_label: Optional[str] = None

    arm_key: Optional[str] = None
    arm_description: Optional[str] = None

    population_key: Optional[str] = None
    population_type: PopulationType
    population_description: Optional[str] = None

    baseline_parent: BaselineParent = None
    parent_description: Optional[str] = None

    baseline_category_label: Optional[str] = None
    group_label: Optional[str] = None
    group_text: Optional[str] = None

    measure: Optional[str] = None
    measure_value: Optional[str] = None

    population_n: Optional[int] = None
    population_percentage: Optional[float] = None

    # ---------------- Validators ----------------
    @field_validator("baseline_id")
    @classmethod
    def baseline_id_positive(cls, v: int) -> int:
        # IDs are 1-based: zero and negatives are rejected.
        if v > 0:
            return v
        raise ValueError("baseline_id must be >= 1")

    @field_validator("population_n")
    @classmethod
    def non_negative_n(cls, v: Optional[int]) -> Optional[int]:
        # None means "not reported" and is passed through unchanged.
        if v is not None and v < 0:
            raise ValueError("population_n must be >= 0")
        return v

    @field_validator("population_percentage")
    @classmethod
    def percent_range(cls, v: Optional[float]) -> Optional[float]:
        # None is allowed; any reported value must lie in [0, 100].
        if v is not None and (v < 0 or v > 100):
            raise ValueError("population_percentage must be between 0 and 100")
        return v


# Top-level baseline result: a single key "bc_types" holding the characteristic rows.
# The list defaults to empty; unknown top-level keys are rejected (extra="forbid").
class BaselineOutput(BaseModel):
    model_config = ConfigDict(extra="forbid")

    bc_types: List[BaselineCharacteristic] = Field(default_factory=list)


# =============================================================
# 2) HELPERS
# =============================================================
def load_text(path: str) -> str:
    """Return the full contents of the file at ``path``, decoded as UTF-8."""
    with Path(path).open("r", encoding="utf-8") as handle:
        return handle.read()


def load_json(path: str) -> Dict[str, Any]:
    """Read the UTF-8 file at ``path`` and return the parsed JSON document."""
    with Path(path).open("r", encoding="utf-8") as handle:
        return json.load(handle)


def load_image_part(image_path: str) -> Part:
    """Load an image file and wrap it as a ``Part`` with a matching MIME type.

    The file is read before the extension is checked, so a missing file raises
    before an unsupported extension does.

    Raises:
        ValueError: If the (case-insensitive) suffix is not .jpg, .jpeg, .png or .webp.
    """
    source = Path(image_path)
    image_bytes = source.read_bytes()

    mime_by_suffix = {
        ".jpg": "image/jpeg",
        ".jpeg": "image/jpeg",
        ".png": "image/png",
        ".webp": "image/webp",
    }
    suffix = source.suffix.lower()
    mime = mime_by_suffix.get(suffix)
    if mime is None:
        raise ValueError(f"Unsupported image type: {suffix}. Use jpg/png/webp.")
    return Part.from_data(data=image_bytes, mime_type=mime)


def extract_first_json_object(text: str) -> Dict[str, Any]:
    """
    Extract the first JSON object from model text output.
    Handles cases where model wraps JSON in markdown fences or extra text.

    Parsing starts at the first "{" and stops at the end of that object, so
    trailing text (even text containing braces) is ignored. The previous greedy
    "{.*}" search spanned to the last "}" in the text and failed in that case.

    Args:
        text: Raw model output; None is treated as empty.

    Returns:
        The decoded JSON object as a dict.

    Raises:
        ValueError: If the text is empty or contains no "{".
        json.JSONDecodeError: (a ValueError subclass) if the text from the first
            "{" onward is not a valid JSON object, e.g. truncated output.
    """
    text = (text or "").strip()
    if not text:
        raise ValueError("Empty model output text.")

    # Remove markdown code fences if present
    text = re.sub(r"^```(?:json)?\s*", "", text)
    text = re.sub(r"\s*```$", "", text)

    start = text.find("{")
    if start == -1:
        raise ValueError("No JSON object found in model output.")

    # raw_decode parses exactly one JSON value beginning at `start` and ignores
    # whatever follows it. Only the first "{" is tried on purpose: falling back to
    # later braces could silently return a nested fragment of a truncated object.
    obj, _end = json.JSONDecoder().raw_decode(text, start)
    return obj


def build_prompt_from_txt(prompt_txt: str, input2_schema: Dict[str, Any]) -> str:
    """
    Combine the baseline prompt template with the INPUT 2 JSON.

    The JSON is rendered with 2-space indentation and non-ASCII characters kept
    as-is. If the template contains the {INPUT_2_SCHEMA_JSON} placeholder, every
    occurrence is replaced; otherwise the JSON is appended under a short heading.
    """
    placeholder = "{INPUT_2_SCHEMA_JSON}"
    rendered = json.dumps(input2_schema, ensure_ascii=False, indent=2)
    if placeholder not in prompt_txt:
        return "".join([prompt_txt, "\n\nINPUT 2 JSON schema reference:\n", rendered])
    return prompt_txt.replace(placeholder, rendered)


# ---------------- Strict filtering ----------------
# NOTE(review): ALLOWED_TOP_KEYS is not referenced by the code visible in this cell;
# clean_to_bc_only() reads "bc_types" directly.
ALLOWED_TOP_KEYS = {"bc_types"}

# Keys kept in each bc_types row by clean_to_bc_only(); mirrors the fields of
# BaselineCharacteristic. Keep the two in sync: a key missing here is dropped before validation.
ALLOWED_BC_KEYS = {
    "baseline_id",
    "trial_id",
    "trial_label",
    "arm_key",
    "arm_description",
    "population_key",
    "population_type",
    "population_description",
    "baseline_parent",
    "parent_description",
    "baseline_category_label",
    "group_label",
    "group_text",
    "measure",
    "measure_value",
    "population_n",
    "population_percentage",
}


def _ensure_list(x: Any) -> List[Any]:
    if x is None:
        return []
    if isinstance(x, list):
        return x
    if isinstance(x, dict):
        return [x]
    return []


def _prune_dict(d: Any, allowed: set) -> Dict[str, Any]:
    if not isinstance(d, dict):
        return {}
    return {k: d.get(k) for k in allowed if k in d}


def _to_int_or_none(x: Any) -> Optional[int]:
    if x is None:
        return None
    if isinstance(x, int):
        return x
    if isinstance(x, float):
        return int(x)
    if isinstance(x, str):
        s = x.strip()
        # pick first integer found (handles "12 (34%)" -> 12)
        m = re.search(r"-?\d+", s)
        if not m:
            return None
        try:
            return int(m.group(0))
        except Exception:
            return None
    return None


def _to_float_percent_or_none(x: Any) -> Optional[float]:
    if x is None:
        return None
    if isinstance(x, (int, float)):
        return float(x)
    if isinstance(x, str):
        s = x.strip()
        s = s.replace("%", "").strip()
        # allow decimal
        m = re.search(r"-?\d+(?:\.\d+)?", s)
        if not m:
            return None
        try:
            return float(m.group(0))
        except Exception:
            return None
    return None


def _normalize_population_type(v: Any) -> str:
    """
    Keep strict allowed values. If missing/invalid, set to 'Other'
    so Pydantic validation passes and output stays consistent.
    """
    allowed = {"Overall", "Analysis set", "Cohort", "Subgroup", "Other"}
    if isinstance(v, str):
        s = v.strip()
        if s in allowed:
            return s
    return "Other"


def _normalize_baseline_parent(v: Any) -> Optional[str]:
    allowed = {"Overall", "Cohort", "Subgroup", "Other"}
    if v is None:
        return None
    if isinstance(v, str):
        s = v.strip()
        if s in allowed:
            return s
    return None


def clean_to_bc_only(parsed: Dict[str, Any]) -> Dict[str, Any]:
    """
    Reduce a parsed model reply to the strict baseline shape:
      {"bc_types": [...]}

    For every dict row under "bc_types" (non-dict rows are skipped):
    - keys outside ALLOWED_BC_KEYS are dropped
    - population_type / baseline_parent are normalized to allowed values
    - population_n / population_percentage are parsed from strings if needed
    - baseline_id is replaced by the row's 1-based position among the kept rows
      when it is missing, not an int, or not positive
    """
    cleaned_rows: List[Dict[str, Any]] = []
    for raw_row in _ensure_list(parsed.get("bc_types", [])):
        if not isinstance(raw_row, dict):
            continue

        row = _prune_dict(raw_row, ALLOWED_BC_KEYS)

        # Enum-like fields first, then numeric fields (the model may send "12" or "12%")
        row.update(
            population_type=_normalize_population_type(row.get("population_type")),
            baseline_parent=_normalize_baseline_parent(row.get("baseline_parent")),
            population_n=_to_int_or_none(row.get("population_n")),
            population_percentage=_to_float_percent_or_none(row.get("population_percentage")),
        )

        # Position is counted over kept rows only, in model output order
        current_id = row.get("baseline_id")
        if not isinstance(current_id, int) or current_id <= 0:
            row["baseline_id"] = len(cleaned_rows) + 1

        cleaned_rows.append(row)

    return {"bc_types": cleaned_rows}


# =============================================================
# 3) AUTH + MODEL INIT
# =============================================================
credentials = service_account.Credentials.from_service_account_file(SERVICE_ACCOUNT_FILE)
vertexai.init(project=PROJECT_ID, location=LOCATION, credentials=credentials)
model = GenerativeModel(MODEL_NAME)


# =============================================================
# 4) LOAD INPUTS (2 inputs)
# =============================================================
prompt_txt = load_text(PROMPT_TXT_FILE)
input2_schema = load_json(INPUT2_SCHEMA_JSON_PATH)
final_prompt = build_prompt_from_txt(prompt_txt, input2_schema)

image_part = load_image_part(POSTER_IMAGE_PATH)


# =============================================================
# 5) CALL MODEL
# =============================================================
response = model.generate_content(
    contents=[final_prompt, image_part],
    generation_config={"temperature": 0.0},
)

raw_text = response.text or ""


# =============================================================
# 6) PARSE + FILTER + VALIDATE + SAVE JSON
# =============================================================
parsed = extract_first_json_object(raw_text)
filtered = clean_to_bc_only(parsed)

# No by_name/by_alias argument: the models define no aliases, and those keyword
# arguments only exist in newer pydantic releases.
validated = BaselineOutput.model_validate(filtered)

Path(OUTPUT_JSON_PATH).write_text(
    validated.model_dump_json(indent=2, exclude_none=False),
    encoding="utf-8",
)

print("✅ Baseline characteristics extraction complete and validated.")
print(f"Saved JSON: {OUTPUT_JSON_PATH}")


# =============================================================
# 7) EXPORT TO EXCEL (bc_types only)
# =============================================================
data = json.loads(Path(OUTPUT_JSON_PATH).read_text(encoding="utf-8"))
bc_rows = data.get("bc_types", []) or []

df = pd.DataFrame(bc_rows)

# Consistent column order (same as the BaselineCharacteristic field order)
preferred_cols = [
    "baseline_id",
    "trial_id",
    "trial_label",
    "arm_key",
    "arm_description",
    "population_key",
    "population_type",
    "population_description",
    "baseline_parent",
    "parent_description",
    "baseline_category_label",
    "group_label",
    "group_text",
    "measure",
    "measure_value",
    "population_n",
    "population_percentage",
]
# Every validated row carries all of these keys, so this only reorders columns.
# With zero rows it still creates the header row instead of an empty sheet.
df = df.reindex(columns=preferred_cols)

df.to_excel(OUTPUT_EXCEL_PATH, index=False)

print(f"✅ JSON successfully converted to Excel: {OUTPUT_EXCEL_PATH}")


FileNotFoundError: [Errno 2] No such file or directory: 'vigilant-armor-455313-m8-1d642ef84a8c.json'

In [None]:
#Response outcomes

In [1]:
###############################################################
# UPDATED FOR RESPONSE OUTCOMES SCHEMA
#
# OUTPUT STRICTLY LIMITED TO:
# {
#   "trial_metadata": {...},
#   "arm_level_response_outcomes": [...]
# }
###############################################################

import json
import re
from pathlib import Path
from typing import Dict, Any, Optional, List, Literal

import vertexai
from vertexai.preview.generative_models import GenerativeModel, Part
from google.oauth2 import service_account

import pandas as pd
from pydantic import BaseModel, Field, field_validator, ConfigDict


# =============================================================
# 0) PLACEHOLDERS
# =============================================================
# NOTE(review): same key file name as the baseline cell, whose recorded output is a
# FileNotFoundError for this file — confirm the key file exists before running.
SERVICE_ACCOUNT_FILE = "vigilant-armor-455313-m8-1d642ef84a8c.json"
PROJECT_ID = "vigilant-armor-455313-m8"
LOCATION = "us-central1"

MODEL_NAME = "gemini-2.5-pro"

# Inputs: prompt template, poster image, and the population-extraction JSON for that image.
PROMPT_TXT_FILE = "RESPONSE.txt"
POSTER_IMAGE_PATH = "images_input/320682.jpg"
INPUT2_SCHEMA_JSON_PATH = "json_output/320682.json"

# Outputs are written relative to the current working directory.
OUTPUT_JSON_PATH = "response_output.json"
OUTPUT_EXCEL_PATH = "response_output.xlsx"


# =============================================================
# 1) PYDANTIC OUTPUT SCHEMA
# =============================================================
# Allowed values of ArmLevelResponseOutcome.population_type (required field).
PopulationType = Literal["Overall", "Analysis set", "Cohort", "Subgroup", "Other"]
# Allowed values of ArmLevelResponseOutcome.response_metric_class.
ResponseMetricClass = Literal["rate", "duration", "time_to_response"]
# Allowed values of ResultObject.duration_unit.
TimeUnit = Literal["months", "years", "weeks", "days"]


class TrialMetadata(BaseModel):
    """Trial-level identifiers stored under the ``trial_metadata`` output key.

    Every field is optional and defaults to ``None``.
    """
    model_config = ConfigDict(extra="forbid")  # unknown keys fail validation

    # Field order is the order used when the model is dumped to JSON.
    trial_id: Optional[str] = None
    phase: Optional[str] = None
    study_name: Optional[str] = None


class ResultObject(BaseModel):
    """Numeric payload nested under an outcome row's ``result`` key.

    All fields are optional and default to ``None``. The Excel export at the
    end of this cell copies each field into a ``result_<field>`` column.
    """
    model_config = ConfigDict(extra="forbid")  # unknown keys inside ``result`` fail validation

    # NOTE(review): presumably the count / percentage / range reported for the
    # outcome — confirm the intended meaning against the prompt template.
    n: Optional[int] = None
    percentage: Optional[float] = None
    min: Optional[float] = None
    max: Optional[float] = None

    # Comparative statistics; no range checks are applied to either value.
    p_value: Optional[float] = None
    odds_ratio: Optional[float] = None

    # Duration-style values; duration_unit is restricted to the TimeUnit literals.
    median: Optional[float] = None
    min_duration: Optional[float] = None
    max_duration: Optional[float] = None
    duration_unit: Optional[TimeUnit] = None


class ArmLevelResponseOutcome(BaseModel):
    """One entry of ``arm_level_response_outcomes``.

    ``response_outcome_id`` and ``population_type`` are required; every other
    field is optional and defaults to ``None``. Unknown keys are rejected.
    """
    model_config = ConfigDict(extra="forbid")

    # Row identifier; checked by ``id_positive`` below.
    response_outcome_id: int

    # Trial the row belongs to.
    trial_id: Optional[str] = None
    trial_label: Optional[str] = None

    arm_description: Optional[str] = None

    # Population the outcome was measured in (restricted to PopulationType).
    population_type: PopulationType
    population_description: Optional[str] = None

    # Free-text assessment context.
    assessment_type: Optional[str] = None
    review_board: Optional[str] = None
    review_criteria: Optional[str] = None
    other_details: Optional[str] = None

    # Counts; checked by ``non_negative_ints`` below.
    arm_n: Optional[int] = None
    assessment_denominator_n: Optional[int] = None

    response_type_name: Optional[str] = None
    response_metric_class: Optional[ResponseMetricClass] = None

    # Nested numeric payload.
    result: Optional[ResultObject] = None

    # ---------------- Validators ----------------
    @field_validator("response_outcome_id")
    @classmethod
    def id_positive(cls, v: int) -> int:
        """Accept identifiers starting at 1; reject zero and negatives."""
        if v >= 1:
            return v
        raise ValueError("response_outcome_id must be >= 1")

    @field_validator("arm_n", "assessment_denominator_n")
    @classmethod
    def non_negative_ints(cls, v: Optional[int]) -> Optional[int]:
        """Pass ``None`` through untouched; reject negative counts."""
        if v is not None and v < 0:
            raise ValueError("Count fields must be >= 0")
        return v


class ResponseOutput(BaseModel):
    """Top-level shape of the validated output: trial metadata plus outcome rows.

    Both fields have default factories, so a missing key yields empty metadata
    or an empty list rather than a validation error.
    """
    model_config = ConfigDict(extra="forbid")  # any other top-level key fails validation

    trial_metadata: TrialMetadata = Field(default_factory=TrialMetadata)
    arm_level_response_outcomes: List[ArmLevelResponseOutcome] = Field(default_factory=list)


# =============================================================
# 2) HELPERS
# =============================================================
def load_text(path: str) -> str:
    """Return the entire contents of the file at ``path`` decoded as UTF-8."""
    with open(path, mode="r", encoding="utf-8") as handle:
        contents = handle.read()
    return contents


def load_json(path: str) -> Dict[str, Any]:
    """Parse the UTF-8 encoded JSON file at ``path`` and return the decoded value."""
    with open(path, mode="r", encoding="utf-8") as handle:
        return json.load(handle)


def load_image_part(image_path: str) -> Part:
    """Read an image file and wrap its bytes in a ``Part`` with a matching MIME type.

    The MIME type is chosen from the file suffix (case-insensitive). Previously
    every suffix other than ``.jpg``/``.jpeg`` was labelled ``image/png``, so a
    ``.webp`` file was sent with the wrong type; it is now labelled
    ``image/webp``. Unrecognised suffixes still fall back to ``image/png`` so
    existing callers see no change.

    Args:
        image_path: Path to the image file on disk.

    Returns:
        A ``Part`` built with ``Part.from_data`` from the file's raw bytes.

    Raises:
        OSError: If the file cannot be read (e.g. ``FileNotFoundError``).
    """
    mime_by_suffix = {
        ".jpg": "image/jpeg",
        ".jpeg": "image/jpeg",
        ".png": "image/png",
        ".webp": "image/webp",
    }
    source = Path(image_path)
    image_bytes = source.read_bytes()
    # Fallback keeps the historical behaviour for suffixes not listed above.
    mime = mime_by_suffix.get(source.suffix.lower(), "image/png")
    return Part.from_data(data=image_bytes, mime_type=mime)


def extract_first_json_object(text: str) -> Dict[str, Any]:
    """Return the first complete JSON object embedded in ``text``.

    The text is scanned for ``{`` characters and a JSON decode is attempted at
    each one in turn; the first position that decodes successfully wins.
    Anything before or after that object (markdown fences, commentary, a
    second object, stray braces) is ignored. The earlier greedy regex took
    everything from the first ``{`` to the *last* ``}``, so any trailing brace
    made the decode fail.

    Args:
        text: Raw model output that may wrap JSON in other text.

    Returns:
        The decoded object as a dict.

    Raises:
        ValueError: If no position in ``text`` starts a decodable JSON object.
    """
    decoder = json.JSONDecoder()
    start = text.find("{")
    while start != -1:
        try:
            # raw_decode stops at the end of the object and tolerates trailing text.
            obj, _end = decoder.raw_decode(text, start)
        except json.JSONDecodeError:
            start = text.find("{", start + 1)
            continue
        return obj
    raise ValueError("No JSON found in model output")


def build_prompt_from_txt(prompt_txt: str, input2_schema: Dict[str, Any]) -> str:
    """Fill the ``{INPUT_2_SCHEMA_JSON}`` placeholder in the prompt template.

    The schema is serialised as 2-space-indented JSON with non-ASCII characters
    kept verbatim, and substituted for every occurrence of the placeholder. A
    template without the placeholder is returned unchanged.
    """
    placeholder = "{INPUT_2_SCHEMA_JSON}"
    rendered_schema = json.dumps(input2_schema, indent=2, ensure_ascii=False)
    pieces = prompt_txt.split(placeholder)
    return rendered_schema.join(pieces)


# ---------------- Strict filtering ----------------
# Keys kept from the model's "trial_metadata" object; matches the fields
# declared on TrialMetadata.
ALLOWED_TM_KEYS = {"trial_id", "phase", "study_name"}

# Keys kept from each entry of "arm_level_response_outcomes"; matches the
# fields declared on ArmLevelResponseOutcome. Anything else is dropped before
# validation, because the models reject unknown keys.
ALLOWED_OUTCOME_KEYS = {
    "response_outcome_id",
    "trial_id",
    "trial_label",
    "arm_description",
    "population_type",
    "population_description",
    "assessment_type",
    "review_board",
    "review_criteria",
    "other_details",
    "arm_n",
    "assessment_denominator_n",
    "response_type_name",
    "response_metric_class",
    "result",
}


def _ensure_list(x: Any) -> List[Any]:
    if isinstance(x, list):
        return x
    if isinstance(x, dict):
        return [x]
    return []


def _prune_dict(d: Any, allowed: set) -> Dict[str, Any]:
    return {k: d.get(k) for k in allowed if isinstance(d, dict) and k in d}


def clean_to_response_only(parsed: Dict[str, Any]) -> Dict[str, Any]:
    """Reduce parsed model output to the keys the response schema accepts.

    Filtering is applied at three levels:
      * ``trial_metadata`` is cut down to ``ALLOWED_TM_KEYS``;
      * each outcome row is cut down to ``ALLOWED_OUTCOME_KEYS``;
      * each row's nested ``result`` dict is cut down to the fields declared on
        ``ResultObject``.

    The nested level was previously left untouched, so a single unexpected key
    inside ``result`` made validation fail (``ResultObject`` forbids extras)
    even though unexpected keys elsewhere were silently dropped.

    Args:
        parsed: Decoded JSON object returned by the model.

    Returns:
        A dict with exactly two keys, ``trial_metadata`` and
        ``arm_level_response_outcomes``.
    """
    result_keys = set(ResultObject.model_fields)

    tm = _prune_dict(parsed.get("trial_metadata", {}), ALLOWED_TM_KEYS)

    rows = []
    for raw_row in _ensure_list(parsed.get("arm_level_response_outcomes", [])):
        row = _prune_dict(raw_row, ALLOWED_OUTCOME_KEYS)
        # Only dict payloads are pruned; None or any other value is passed
        # through so validation treats it exactly as before.
        if isinstance(row.get("result"), dict):
            row["result"] = _prune_dict(row["result"], result_keys)
        rows.append(row)

    return {
        "trial_metadata": tm,
        "arm_level_response_outcomes": rows,
    }


# =============================================================
# 3) AUTH + MODEL INIT
# =============================================================
# Build credentials from the service-account key file (no explicit scopes are requested here).
credentials = service_account.Credentials.from_service_account_file(SERVICE_ACCOUNT_FILE)
# Initialise the Vertex AI SDK for the configured project/location with those credentials.
vertexai.init(project=PROJECT_ID, location=LOCATION, credentials=credentials)
# Model handle used for the generate_content call in section 5.
model = GenerativeModel(MODEL_NAME)


# =============================================================
# 4) LOAD INPUTS
# =============================================================
# Prompt template text, read as UTF-8.
prompt_txt = load_text(PROMPT_TXT_FILE)
# JSON document that is embedded into the prompt.
input2_schema = load_json(INPUT2_SCHEMA_JSON_PATH)
# Template with the {INPUT_2_SCHEMA_JSON} placeholder replaced by that JSON.
final_prompt = build_prompt_from_txt(prompt_txt, input2_schema)
# Poster image wrapped as a Part so it can be sent next to the prompt text.
image_part = load_image_part(POSTER_IMAGE_PATH)


# =============================================================
# 5) CALL MODEL
# =============================================================
# Single request carrying the prompt text followed by the poster image;
# temperature is pinned to 0.0.
response = model.generate_content(
    contents=[final_prompt, image_part],
    generation_config={"temperature": 0.0},
)

# A falsy .text becomes "", so the parser in section 6 raises its own
# "No JSON found in model output" error rather than failing on None.
raw_text = response.text or ""


# =============================================================
# 6) PARSE + FILTER + VALIDATE + SAVE
# =============================================================
# Pull the JSON object out of the raw model text, then drop keys outside the allow-lists.
parsed = extract_first_json_object(raw_text)
filtered = clean_to_response_only(parsed)

# Schema validation; raises if a required field is missing or a value is not allowed.
validated = ResponseOutput.model_validate(filtered)

# Persist the validated model as indented JSON, encoded as UTF-8.
Path(OUTPUT_JSON_PATH).write_text(
    validated.model_dump_json(indent=2),
    encoding="utf-8"
)

print("âœ… Response extraction complete and validated.")
print(f"Saved JSON: {OUTPUT_JSON_PATH}")


# =============================================================
# 7) EXPORT TO EXCEL (FLATTEN result OBJECT)
# =============================================================
# Read the saved JSON back once, decoding it as UTF-8 (the encoding it was
# written with). An earlier duplicate of this section read the same file with
# the platform default encoding and never used the result, so it is removed.
data = json.loads(Path(OUTPUT_JSON_PATH).read_text(encoding="utf-8"))
trial_metadata = data.get("trial_metadata", {}) or {}
arm_outcomes = data.get("arm_level_response_outcomes", []) or []

# Fields of the nested result object, in the order their columns are emitted.
result_fields = (
    "n",
    "percentage",
    "min",
    "max",
    "p_value",
    "odds_ratio",
    "median",
    "min_duration",
    "max_duration",
    "duration_unit",
)

flattened_rows = []

for row in arm_outcomes:
    if not isinstance(row, dict):
        continue

    # A missing or null result is treated as empty, giving None in every result_* column.
    result = row.get("result") or {}

    # Trial-level keys come first, then the row's own keys without the nested
    # object; the loaded row itself is left unmodified.
    # NOTE(review): on a key collision the row value wins, so an outcome's own
    # trial_id (even None) replaces the trial-level trial_id — confirm intended.
    flattened = {
        **trial_metadata,
        **{key: value for key, value in row.items() if key != "result"},
    }

    # ---- result fields flattened ----
    for field_name in result_fields:
        flattened[f"result_{field_name}"] = result.get(field_name)

    flattened_rows.append(flattened)

df = pd.DataFrame(flattened_rows)

df.to_excel(OUTPUT_EXCEL_PATH, index=False)

print(f"âœ… JSON successfully converted to Excel with flattened result columns: {OUTPUT_EXCEL_PATH}")



âœ… Response extraction complete and validated.
Saved JSON: response_output.json
âœ… JSON successfully converted to Excel with flattened result columns: response_output.xlsx
