In [None]:
import os
import pandas as pd
import numpy as np
import config
from data_processing.loader import load_data, load_xml_as_dataframe
from generation.prompt_generator import (
    generate_prompt_baseline,
    generate_prompt_fewshot,
    generate_prompt_rag,
    generate_prompt_rag_summary,
    generate_prompt_rag_synthetic_cases
)
from generation.answers_generator import generate_answers
from generation.synthetic_data_generator import (
    generate_synthetic_answers,
    generate_reasoning,
    summarize_articles,
    generate_synthetic_case
)
from postprocess import prepare_submission_file
from rag.retrieval import build_faiss_index, embed_text, retrieve_articles, load_faiss_index
from transformers import AutoModel, AutoTokenizer
import json
import pickle

In [None]:
# Create the API client once; every generation and submission call below
# receives this same object.
client = config.get_client()

In [None]:
# Model id per experiment family. These strings are handed verbatim to the
# generation / submission helpers.
BASELINE_MODEL = "google/gemini-2.5-pro"        # section 1
FEWSHOT_MODEL = "google/gemini-2.0-flash"       # section 2
RAG_MODEL = "google/gemini-2.5-pro"             # section 3
SUMMARIZATION_MODEL = "google/gemini-2.5-pro"   # passed to prepare_submission_file

In [None]:
# `result` is the parsed case file; its "cases" entry drives every experiment
# below. `ground_truth` is the JSON key, used when synthesizing exemplar answers.
result = load_data(config.XML_FILE)
ground_truth = load_data(config.JSON_KEY_FILE)

### 0. Exemplars and retrieval preparation

#### 0.1. Exemplar synthesis

In [None]:
# Models used to synthesize few-shot exemplars (answers, then reasoning).
SYNTHETIC_ANSWER_MODEL = "google/gemini-2.5-pro"
REASONING_MODEL = "google/gemini-2.5-pro"

In [None]:
# Draft synthetic answers from the parsed cases (`result`). The ground-truth
# key is also given to the generator. The output file is the input of the
# reasoning step in the next cell.
generate_synthetic_answers(
    output_path="data/syn_answer.pkl",
    result=result,
    ground_truth=ground_truth,
    model_name=SYNTHETIC_ANSWER_MODEL,
    client=client,
)

In [None]:
# Augment the synthetic answers with model-written reasoning.
# Reads data/syn_answer.pkl and writes data/syn_answer_with_reasoning.pkl.
generate_reasoning(
    input_path="data/syn_answer.pkl",
    output_path="data/syn_answer_with_reasoning.pkl",
    model_name=REASONING_MODEL,
    client=client,
)

#### 0.2. Summarize articles

In [None]:
SUMMARIZATION_MODEL_ARTICLES = "google/gemini-2.5-pro"  # model for per-question article summaries (0.2)

In [None]:
# Summaries are generated from the articles retrieved for each clinician
# question. In this notebook `retrieved_articles` is otherwise only built by
# the RAG setup cell in section 3, which sits *below* this cell. On a fresh
# kernel (Restart & Run All) this cell therefore raised NameError. Build it
# here when it is missing, using the same steps and variable names as the
# section 3 setup cell. If an earlier cell already defined it, reuse that value.
if "retrieved_articles" not in globals():
    df_article = pd.read_csv(config.ARTICLE_FILE)
    df_query = load_xml_as_dataframe(config.XML_FILE)
    query_model = AutoModel.from_pretrained(config.EMBEDDING_MODEL)
    query_tokenizer = AutoTokenizer.from_pretrained(config.EMBEDDING_MODEL)
    question_embeddings = embed_text(
        df_query['Clinician Question'].tolist(), query_model, query_tokenizer
    )
    index = load_faiss_index(config.VECTOR_DATABASE_FILE)
    retrieved_articles = retrieve_articles(np.array(question_embeddings), index, df_article)

summarize_articles(
    client=client,
    model_name=SUMMARIZATION_MODEL_ARTICLES,
    retrieved_articles=retrieved_articles,
    output_path=config.SUMMARIZED_ARTICLES_FILE
)

#### 0.3. Generate synthetic cases

In [None]:
SYNTHETIC_CASE_MODEL = "google/gemini-2.5-pro"  # model that writes synthetic cases from articles (0.3)

In [None]:
# Synthetic cases are generated from the retrieved articles. In the original
# layout nothing above this cell defines `retrieved_articles`; the RAG setup
# cell is in section 3, further down. A fresh kernel (Restart & Run All)
# therefore hit NameError here. Build it when it is missing, using the same
# steps and variable names as the section 3 setup cell. If an earlier cell
# already defined it, reuse that value. This also lets the cell run on its own
# when the article summaries already exist on disk.
if "retrieved_articles" not in globals():
    df_article = pd.read_csv(config.ARTICLE_FILE)
    df_query = load_xml_as_dataframe(config.XML_FILE)
    query_model = AutoModel.from_pretrained(config.EMBEDDING_MODEL)
    query_tokenizer = AutoTokenizer.from_pretrained(config.EMBEDDING_MODEL)
    question_embeddings = embed_text(
        df_query['Clinician Question'].tolist(), query_model, query_tokenizer
    )
    index = load_faiss_index(config.VECTOR_DATABASE_FILE)
    retrieved_articles = retrieve_articles(np.array(question_embeddings), index, df_article)

generate_synthetic_case(
    client=client,
    model_name=SYNTHETIC_CASE_MODEL,
    retrieved_articles=retrieved_articles,
    output_path=config.SYNTHETIC_CASES_FILE
)

### 1. Baseline experiment

In [None]:
# Baseline: the prompt is built from the case alone, with no exemplars and no
# retrieval. No `max_retries` is passed, so generate_answers' own default applies.
baseline_answers = generate_answers(
    client=client,
    model_name=BASELINE_MODEL,
    cases=result["cases"],
    prompt_fn=generate_prompt_baseline,
)
baseline_submission_path = f"{config.OUTPUT_DIR}/submission_baseline.json"
prepare_submission_file(
    baseline_answers,
    baseline_submission_path,
    result,
    client,
    SUMMARIZATION_MODEL,
)

### 2. Few-shot experiments

In [None]:
# Exemplar pools for the few-shot experiments:
#   all_cases      - LLM-generated exemplars (one dict per row), used in 2.2 / 2.3
#   basic_examples - the examples in prompts/example.json, used in 2.1
# NOTE: read_pickle can execute arbitrary code; only load files this pipeline produced.
df = pd.read_pickle(config.EXEMPLARS_FILE)
all_cases = df.to_dict(orient="records")
# JSON text is UTF-8 (RFC 8259). Set the encoding explicitly so non-ASCII
# prompt text is not decoded with the platform default (e.g. cp1252 on Windows).
with open('prompts/example.json', 'r', encoding='utf-8') as f:
    basic_examples = json.load(f)

#### 2.1. Few-shot: basic

In [None]:
# 2.1 Basic few-shot: prompts are built from the examples in prompts/example.json.
fewshot1_answers = generate_answers(
    prompt_fn=lambda case: generate_prompt_fewshot(case, basic_examples),
    model_name=FEWSHOT_MODEL,
    client=client,
    cases=result["cases"],
    max_retries=5,
)
fewshot1_submission_path = f"{config.OUTPUT_DIR}/submission_fewshot1.json"
prepare_submission_file(
    fewshot1_answers,
    fewshot1_submission_path,
    result,
    client,
    SUMMARIZATION_MODEL,
)

#### 2.2. Few-shot: LLM-generated exemplars

In [None]:
# 2.2 Few-shot with the LLM-generated exemplar pool (`all_cases`), without reasoning.
fewshot2_answers = generate_answers(
    prompt_fn=lambda case: generate_prompt_fewshot(case, all_cases),
    model_name=FEWSHOT_MODEL,
    client=client,
    cases=result["cases"],
    max_retries=5,
)
fewshot2_submission_path = f"{config.OUTPUT_DIR}/submission_fewshot2.json"
prepare_submission_file(
    fewshot2_answers,
    fewshot2_submission_path,
    result,
    client,
    SUMMARIZATION_MODEL,
)

#### 2.3. Few-shot: exemplars with reasoning

In [None]:
# 2.3 Same exemplar pool as 2.2, but generate_prompt_fewshot is asked to
# include the exemplars' reasoning (add_reasoning=True).
fewshot3_answers = generate_answers(
    prompt_fn=lambda case: generate_prompt_fewshot(case, all_cases, add_reasoning=True),
    model_name=FEWSHOT_MODEL,
    client=client,
    cases=result["cases"],
    max_retries=5,
)
fewshot3_submission_path = f"{config.OUTPUT_DIR}/submission_fewshot3.json"
prepare_submission_file(
    fewshot3_answers,
    fewshot3_submission_path,
    result,
    client,
    SUMMARIZATION_MODEL,
)

### 3. RAG experiments

In [None]:
# Retrieval setup shared by the RAG experiments: embed every clinician
# question and look up articles in the FAISS index stored on disk.
df_article = pd.read_csv(config.ARTICLE_FILE)
df_query = load_xml_as_dataframe(config.XML_FILE)

# Query-side embedding model and its tokenizer.
query_model = AutoModel.from_pretrained(config.EMBEDDING_MODEL)
query_tokenizer = AutoTokenizer.from_pretrained(config.EMBEDDING_MODEL)

# Embed the questions, then *load* the existing vector index; it is not rebuilt here.
question_embeddings = embed_text(
    df_query['Clinician Question'].tolist(),
    query_model,
    query_tokenizer,
)
index = load_faiss_index(config.VECTOR_DATABASE_FILE)

# The RAG prompts below index this by case position (`retrieved_articles[idx]`).
retrieved_articles = retrieve_articles(
    np.array(question_embeddings), index, df_article
)


#### 3.1. RAG: full articles

In [None]:
# 3.1 RAG with the full retrieved articles. The prompt function takes the
# case position as a second argument, hence use_index=True.
rag1_answers = generate_answers(
    prompt_fn=lambda case, idx: generate_prompt_rag(case, retrieved_articles[idx]),
    use_index=True,
    model_name=RAG_MODEL,
    client=client,
    cases=result["cases"],
    max_retries=5,
)
rag1_submission_path = f"{config.OUTPUT_DIR}/submission_gemini_rag1.json"
prepare_submission_file(
    rag1_answers,
    rag1_submission_path,
    result,
    client,
    SUMMARIZATION_MODEL,
)

#### 3.2. RAG: article summaries

In [None]:
# Inputs for 3.2:
#   summarized_articles - summaries written in section 0.2 (summarize_articles
#                         output_path=config.SUMMARIZED_ARTICLES_FILE)
#   example             - first entry of prompts/example.json, passed to
#                         generate_prompt_rag_summary
# NOTE: pickle.load can execute arbitrary code; only load files this pipeline produced.
with open(config.SUMMARIZED_ARTICLES_FILE, 'rb') as f:
    summarized_articles = pickle.load(f)
# JSON text is UTF-8 (RFC 8259). Set the encoding explicitly instead of
# relying on the platform default text encoding.
with open('prompts/example.json', 'r', encoding='utf-8') as f:
    example = json.load(f)[0]

In [None]:
# 3.2 RAG with article summaries in place of full text. The prompt for case
# `idx` receives summarized_articles[idx] plus the loaded example.
rag2_answers = generate_answers(
    prompt_fn=lambda case, idx: generate_prompt_rag_summary(case, summarized_articles[idx], example),
    use_index=True,
    model_name=RAG_MODEL,
    client=client,
    cases=result["cases"],
    max_retries=5,
)
rag2_submission_path = f"{config.OUTPUT_DIR}/submission_rag_summary.json"
prepare_submission_file(
    rag2_answers,
    rag2_submission_path,
    result,
    client,
    SUMMARIZATION_MODEL,
)

#### 3.3. RAG: synthetic cases

In [None]:
# Synthetic cases written in section 0.3 (output_path=config.SYNTHETIC_CASES_FILE),
# converted to one dict per CSV row for prompt construction.
df_synthetic = pd.read_csv(config.SYNTHETIC_CASES_FILE)
all_synthetic_cases = df_synthetic.to_dict(orient="records")

In [None]:
# 3.3 Prompts built from the pool of synthetic cases. The prompt function
# takes only the case, so use_index is not set.
rag3_answers = generate_answers(
    prompt_fn=lambda case: generate_prompt_rag_synthetic_cases(case, all_synthetic_cases),
    model_name=RAG_MODEL,
    client=client,
    cases=result["cases"],
    max_retries=5,
)
rag3_submission_path = f"{config.OUTPUT_DIR}/submission_rag_synthetic_cases.json"
prepare_submission_file(
    rag3_answers,
    rag3_submission_path,
    result,
    client,
    SUMMARIZATION_MODEL,
)
