In [203]:
import json
# Load the sampled MedQA test split: one JSON object per line (JSONL).
medqa_test = []
# Explicit UTF-8 so decoding does not depend on the machine's locale.
with open("/data/jiwoong/workspace/MedAgents-2/datasets/MedQA/50_sampled_hard_medqa/test.jsonl", 'r', encoding="utf-8") as jsfile:
    for line in jsfile:
        # Tolerate blank lines (e.g. a trailing newline at end of file).
        if line.strip():
            medqa_test.append(json.loads(line))

# Render every item as "<question>\n\nOptions: (A) ... (B) ...".
queries = []
for test in medqa_test:
    options_text = " ".join(f"({key}) {value}" for key, value in test['options'].items())
    queries.append(f"{test['question']}\n\n" + "Options: " + options_text)

In [222]:
import copy
import os
import re
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import List, Dict, Any, Optional

from dotenv import load_dotenv
from openai import AzureOpenAI
from pymilvus import MilvusClient
from tqdm.auto import tqdm

import utils

# Load variables from a .env file into the process environment (python-dotenv).
load_dotenv()

# Milvus client used by the retrieval step; points at a local server.
retrieval_client = MilvusClient(
    uri="http://localhost:19530",
)

# Azure OpenAI chat client; the key is read from the `azure_api_key` environment variable.
llm_client = AzureOpenAI(
    api_version="2024-08-01-preview",
    azure_endpoint="https://azure-openai-miblab-ncu.openai.azure.com/",
    api_key=os.environ.get("azure_api_key"),
)


In [230]:
from typing import List, Dict, Any
class CIDER:
    """Iterative multi-expert pipeline for answering a medical multiple-choice query.

    For each query the pipeline picks expert roles, lets every expert propose
    retrieval queries, retrieves and summarizes documents into a knowledge base,
    asks every expert for an analysis, and checks the analyses for consensus.
    The loop repeats until consensus is reported or the iteration budget is spent.

    Attributes:
        model: Chat model name passed to the LLM client.
        current_knowledge: Document summaries accumulated for the query in progress.
        iteration_history: One record per completed iteration of the query in progress.
        expert_roles: Expert domains selected for the query in progress.
        seen_documents: Raw documents that have already been summarized.
        retrieve_topk: Number of documents requested from the retriever per query.
        rerank_topk: Number of reranked documents kept per query.
    """

    def __init__(self, model: str = "gpt-4o-mini"):
        self.model = model
        self.current_knowledge: List[str] = []
        self.iteration_history: List[Dict] = []
        self.expert_roles: List[str] = []
        self.seen_documents: set = set()
        self.retrieve_topk: int = 5
        self.rerank_topk: int = 5

    def _clean_text(self, text: str) -> str:
        """Strip `**` and `'''` markers plus surrounding whitespace; non-strings become ""."""
        return text.replace('**', '').replace("'''", '').strip() if isinstance(text, str) else ""

    def _call_llm(self, messages: List[Dict]) -> str:
        """Send `messages` to the chat-completions endpoint and return the cleaned reply text."""
        response = llm_client.chat.completions.create(
            model=self.model, messages=messages, temperature=0, max_tokens=2048
        )
        return self._clean_text(response.choices[0].message.content)

    def process_query(self, initial_query: str, max_iterations: int = 5) -> Dict[str, Any]:
        """Answer one query and return its final answer and iteration history.

        Each call works on a shallow copy of this instance with fresh per-query
        state, so queries never see each other's knowledge, roles or history.
        This also makes it safe to call from several threads on one instance.
        The per-query state is therefore not written back to this instance.

        Args:
            initial_query: The question text including its answer options.
            max_iterations: Maximum number of retrieve/analyze/consensus rounds.

        Returns:
            A dict with 'final_answer' (the picked option) and
            'iteration_history' (one record per completed round).
        """
        session = copy.copy(self)
        session.current_knowledge = []
        session.iteration_history = []
        session.expert_roles = []
        session.seen_documents = set()
        return session._run_session(initial_query, max_iterations)

    def _run_session(self, initial_query: str, max_iterations: int) -> Dict[str, Any]:
        """Run the iteration loop for one query, using this instance as its private state."""
        consensus_result = ""

        for iteration in range(max_iterations):
            if iteration == 0:
                # Roles are chosen once per query; the first round queries on the raw question.
                self.expert_roles = self._generate_expert_domains(initial_query, utils.medical_specialties_gpt_selected)
                queries = self._generate_expert_query(initial_query)
            else:
                # Later rounds also show the experts the previous consensus report.
                follow_up_context = f"Original Query: {initial_query} \n\n Previous Report: {consensus_result}"
                queries = self._generate_expert_query(follow_up_context)

            retrieved_docs = self._retrieve_queries(queries)
            if retrieved_docs:
                self._update_knowledge(retrieved_docs, initial_query)

            expert_analyses = self._expert_analysis(self.current_knowledge, initial_query)
            consensus_result = self._check_consensus(initial_query, expert_analyses)

            self.iteration_history.append({
                'iteration': iteration,
                'queries': queries,
                'docs': retrieved_docs,
                'analyses': expert_analyses,
                'consensus': consensus_result
            })

            if "consensus: yes" in consensus_result.lower():
                break

        # An answer is picked whether or not consensus was reached.
        return {
            'final_answer': self._final_answer_pick(consensus_result),
            'iteration_history': self.iteration_history
        }

    def _update_knowledge(self, new_documents: List[str], original_query: str):
        """Summarize not-yet-seen documents and append the summary to the knowledge base.

        Documents keep their retrieval order, so the prompt is deterministic.
        They are marked as seen only after a summary was produced, so a failed
        summarization can be retried in a later round.
        """
        already_known = self.seen_documents | set(self.current_knowledge)
        # dict.fromkeys removes in-batch duplicates while preserving order.
        unseen = [doc for doc in dict.fromkeys(new_documents) if doc not in already_known]
        combined_docs = "\n".join(unseen)
        if not combined_docs.strip():
            return
        summary = self._summarize_documents(combined_docs, original_query)
        if summary:
            self.current_knowledge.append(summary)
            self.seen_documents.update(unseen)

    def _summarize_documents(self, documents: str, original_query: str) -> str:
        """Ask the LLM for a query-focused summary of `documents`; returns "" on failure."""
        summary_prompt = (
            f"""Summarize the key insights from the following set of medical documents, considering their relevance to the original query:
Original Query: {original_query}

Documents: {documents}

Please provide a concise and objective summary of the most clinically relevant information."""
        )
        try:
            return self._call_llm([
                {"role": "system", "content": "You are a medical summarizer extracting key insights from documents."},
                {"role": "user", "content": summary_prompt}
            ])
        except Exception as e:
            print(f"Error summarizing documents: {str(e)}")
            return ""

    def _generate_expert_domains(self, query: str, medical_fields: List[str], num_fields: int = 5) -> List[str]:
        """Ask the LLM which medical fields the query belongs to.

        Args:
            query: The medical scenario to classify.
            medical_fields: Candidate field names offered to the model.
            num_fields: Number of fields requested in the output format.

        Returns:
            The parsed field names. A single parsed field is padded with
            'General Medicine'. If the call fails or nothing can be parsed,
            returns 'General Medicine' repeated `num_fields` times.
        """
        question_domain_format = "Medical Field: " + " | ".join([f"Field{i}" for i in range(num_fields)])
        question_classifier = "You are a medical expert who specializes in categorizing medical scenarios into specific areas of medicine."
        prompt_get_question_domain = (
            f"Complete these steps:\n"
            f"1. Read the medical scenario: '''{query}'''.\n"
            f"2. Classify into these subfields: {', '.join(medical_fields)}.\n"
            f"3. Output exactly in this format: '''{question_domain_format}'''."
        )

        try:
            response = self._call_llm([
                {"role": "system", "content": question_classifier},
                {"role": "user", "content": prompt_get_question_domain}
            ])

            # "ield: " matches both "Field: " and "field: ".
            if "ield: " not in response:
                raise ValueError("Delimiter 'ield: ' not found in the response.")
            domain_list = [domain.strip() for domain in response.split("ield: ")[1].split('|') if domain.strip()]
            if not domain_list:
                # An empty role list would silently disable every later step.
                raise ValueError("No medical fields found in the response.")
            if len(domain_list) == 1:
                domain_list.append('General Medicine')
            return domain_list
        except Exception as e:
            print(f"Error generating expert domains: {str(e)}")
            return ['General Medicine'] * num_fields

    def _generate_expert_query(self, context: str) -> List[str]:
        """Let every expert role propose up to three retrieval queries for `context`.

        Returns the queries of all roles concatenated in role order. A role
        whose LLM call fails is reported and skipped.
        """
        # `context` is interpolated here exactly once. It may contain literal
        # braces, so the result must not be passed through str.format again.
        query_prompt = (
            f"""Generate up to three specific medical queries addressing:

{context}

Consider:
1. Expert disagreements
2. Additional information needed
3. Remaining knowledge gaps

Format:
1st Query: <Primary concerns>
2nd Query: <Secondary aspects>
3rd Query: <Remaining gaps>

Make queries specific and targeted."""
        )

        all_queries = []
        for role in self.expert_roles:
            try:
                response = self._call_llm([
                    {"role": "system", "content": f"You are a medical expert in {role}."},
                    {"role": "user", "content": query_prompt}
                ])
            except Exception as e:
                print(f"Error generating queries for {role}: {str(e)}")
                continue

            for line in response.split("\n"):
                stripped = line.strip()
                if not stripped.startswith(("1st Query", "2nd Query", "3rd Query")):
                    continue
                # Split on the first colon only, so colons inside the query text survive.
                _, separator, remainder = stripped.partition(":")
                query_text = remainder.strip() if separator else stripped
                if query_text:
                    all_queries.append(query_text)
        return all_queries

    def _retrieve_queries(self, queries: List[str]) -> List[str]:
        """Retrieve and rerank documents for each query; return them de-duplicated in order."""
        retrieved_docs = []
        for query in queries:
            try:
                docs = utils.rerank(query, utils.retrieve(query, retrieval_client, topk=self.retrieve_topk))
                retrieved_docs.extend(docs[:self.rerank_topk])
            except Exception as e:
                print(f"Error retrieving documents for query '{query}': {str(e)}")

        # Remove duplicates while maintaining first-seen order.
        return list(dict.fromkeys(retrieved_docs))

    def _expert_analysis(self, documents_or_knowledgebase: List[str], query: str) -> List[Dict[str, str]]:
        """Ask every expert role to analyze the knowledge base and answer the query.

        Returns one dict per role with keys 'role' and 'analysis'.
        """
        analysis_prompt = (
            f"""Analyze this knowledge base and solve the query:

Current Knowledge: {self._format_docs_for_prompt(documents_or_knowledgebase)}

Original Query: {query}

Provide:
1. Key Information: <Critical information>
2. Remaining Questions: <Gaps to address>
3. Reasoning: <Justification>
4. Answer: <Conclude your response with the phrase \"the answer is ([option_id]) [answer_string]\">"""
        )

        return [
            {
                'role': role,
                'analysis': self._call_llm([
                    {"role": "system", "content": f"You are a {role} specialist analyzing medical information."},
                    {"role": "user", "content": analysis_prompt}
                ])
            }
            for role in self.expert_roles
        ]

    def _check_consensus(self, query: str, expert_analyses: List[Dict[str, str]]) -> str:
        """Ask the LLM whether the expert analyses agree; returns its free-text verdict."""
        consensus_prompt = (
            """Review the expert opinions below and determine whether they reach a consensus.

Instructions:
1. Indicate if there is a consensus: <yes/no>
2. If there is a consensus, provide the agreed-upon answer choice: <state the answer>
3. If there is no consensus, describe the disagreements.

Expert Analyses:
{}

Original Query:
{}
""".format(''.join([f"Expert ({a['role']}):\n{a['analysis']}\n\n" for a in expert_analyses]), query)
        )
        return self._call_llm([
            {"role": "system", "content": "You are evaluating expert opinions for consensus on a question."},
            {"role": "user", "content": consensus_prompt}
        ])

    def _format_docs_for_prompt(self, documents: List[str]) -> str:
        """Render documents as numbered "Document N:" sections separated by blank lines."""
        return "\n\n".join([f"Document {i+1}:\n{doc}" for i, doc in enumerate(documents)])

    def _final_answer_pick(self, text: str) -> str:
        """Ask the LLM to reduce the consensus text to a single option letter."""
        return self._call_llm([
            {"role": "system", "content": "You are an answer parser. Pick an answer even if there is no consensus."},
            {"role": "user", "content": f"Only output A, B, C, or D from {text}"}
        ])


In [231]:
# Build the pipeline with its default model; this one instance is passed to the batch run below.
cider_instance = CIDER()

def process_queries_with_threadpool(cider_instance, queries, max_workers=5):
    """Run `cider_instance.process_query` over `queries` in a thread pool.

    Args:
        cider_instance: Object exposing `process_query(query)`.
        queries: Sequence of query strings.
        max_workers: Size of the thread pool (defaults to 5).

    Returns:
        A list of results in the same order as `queries`.

    Raises:
        Whatever `process_query` raises; the first failing future re-raises
        from `future.result()`.
    """
    # Pre-size the result list so each result lands at its input position.
    result = [None] * len(queries)

    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        # Submit every query and remember which input index each future belongs to.
        futures = {executor.submit(cider_instance.process_query, query): idx for idx, query in enumerate(queries)}

        # Show a progress bar that advances as futures finish (in completion order).
        for future in tqdm(as_completed(futures), total=len(futures)):
            idx = futures[future]  # Look up the input position for this future.
            result[idx] = future.result()  # Store the result at that position.

    return result

# Run the whole batch (on a shallow copy of the query list) and keep the results.
result = process_queries_with_threadpool(cider_instance, list(queries))

  0%|          | 0/50 [00:00<?, ?it/s]

100%|██████████| 50/50 [20:41<00:00, 24.83s/it]


In [196]:
# Persist the batch results as one JSON document.
with open("/data/jiwoong/workspace/output/v5", 'w') as jsfile:
    json.dump(obj=result, fp=jsfile)

In [3]:
# Re-create the pipeline, naming the chat model explicitly.
cider_instance = CIDER(model="gpt-4o-mini")

In [8]:
# Pick a single item and render it as "<question>\n\nOptions: (A) ... (B) ...".
# NOTE(review): index 122 needs a dataset with more than 122 items — confirm which split is loaded.
test = medqa_test[122]
query = "{}\n\nOptions: {}".format(
    test['question'],
    " ".join(f"({key}) {value}" for key, value in test['options'].items()),
)

In [39]:
# Display the rendered query string for inspection.
query

'A 28-year-old gravida 1 at 32 weeks gestation is evaluated for an abnormal ultrasound that showed fetal microcephaly. Early in the 1st trimester, she had fevers and headaches for 1 week. She also experienced myalgias, arthralgias, and a pruritic maculopapular rash. The symptoms resolved without any medications. A week prior to her symptoms, she had traveled to Brazil where she spent most of the evenings hiking. She did not use any mosquito repellents. There is no personal or family history of chronic or congenital diseases. Medications include iron supplementation and a multivitamin. She received all of the recommended childhood vaccinations. She does not drink alcohol or smoke cigarettes. The IgM and IgG titers for toxoplasmosis were negative. Which of the following is the most likely etiologic agent?\n\nOptions: (A) Dengue virus (B) Rubella virus (C) Toxoplasmosis (D) Zika virus'

In [4]:
from concurrent.futures import ThreadPoolExecutor
from tqdm.auto import tqdm

# One slot per test item, filled in by index so results keep the input order.
difficulty_list = [None for _ in range(len(medqa_test))]

def assess_difficulty(index, test):
    """Render one test item as a query and score its difficulty.

    Args:
        index: Position of the item, passed through unchanged.
        test: Mapping with a 'question' string and an 'options' mapping.

    Returns:
        A tuple `(index, difficulty)` where `difficulty` is whatever
        `cider_instance.assess_query_difficulty` returns.
    """
    # NOTE(review): relies on the module-level `cider_instance` providing
    # `assess_query_difficulty` — confirm the class in use defines it.
    question_part = f"{test['question']}\n\n"
    rendered_options = " ".join(f"({label}) {text}" for label, text in test['options'].items())
    prompt = question_part + "Options: " + rendered_options
    difficulty = cider_instance.assess_query_difficulty(prompt)
    return index, difficulty

# Score every test item concurrently; each task returns (index, difficulty).
with ThreadPoolExecutor() as executor:
    futures = [executor.submit(assess_difficulty, i, test) for i, test in enumerate(medqa_test)]

    # Futures are consumed in submission order; the progress bar advances as each is collected.
    for future in tqdm(futures, total=len(medqa_test), desc="Processing Queries"):
        # Use a dedicated name so the module-level `result` (pipeline outputs) is not overwritten.
        index, difficulty = future.result()
        difficulty_list[index] = difficulty  # Store the label at the item's original position.

Processing Queries:   0%|          | 0/1273 [00:00<?, ?it/s]

In [None]:
# Number of items labelled 'easy'.
sum(1 for label in difficulty_list if label == 'easy')

754

In [None]:
# Load the baseline ("base_direct") output records, one JSON object per line.
base_direct = []
# Explicit UTF-8 so decoding does not depend on the machine's locale.
with open("/data/jiwoong/workspace/MedAgents-2/baselines/MedAgents/outputs/MedQA/gpt4omini-base_direct", 'r', encoding="utf-8") as jsfile:
    for line in jsfile:
        # Tolerate blank lines (e.g. a trailing newline at end of file).
        if line.strip():
            base_direct.append(json.loads(line))

In [18]:
# Baseline accuracy split by assessed difficulty: correct predictions per label
# divided by the number of items carrying that label.
# Both lists are indexed by item position, so they must line up one-to-one.
assert len(difficulty_list) == len(base_direct), "difficulty_list and base_direct must have the same length"

easy = 0
hard = 0
for difficulty, record in zip(difficulty_list, base_direct):
    if difficulty not in ('easy', 'hard'):
        continue
    if record['pred_answer'] == record['gold_answer']:
        if difficulty == 'easy':
            easy += 1
        else:
            hard += 1

easy_total = difficulty_list.count('easy')
hard_total = difficulty_list.count('hard')
# Print NaN instead of raising ZeroDivisionError when a bucket is empty.
print(easy / easy_total if easy_total else float('nan'))
print(hard / hard_total if hard_total else float('nan'))

0.8368700265251989
0.7341040462427746
