In [None]:

import json
import time
import base64
import os
from os.path import join
from tqdm import tqdm
from loguru import logger
from openai import AzureOpenAI
from rank_bm25 import BM25Okapi
import nltk
from nltk.tokenize import word_tokenize
from typing import List, Dict, Any
import re

In [None]:
# word_tokenize needs NLTK's Punkt tokenizer data. Older NLTK releases load it from the
# 'punkt' package; recent ones look for 'punkt_tab'. Ensure both are present, downloading
# each only when it is missing, so re-running this cell does not hit the network.
for resource_path, package in (
    ("tokenizers/punkt", "punkt"),
    ("tokenizers/punkt_tab/english/", "punkt_tab"),
):
    try:
        nltk.data.find(resource_path)
    except LookupError:
        print(f"Downloading '{package}' resource for NLTK...")
        nltk.download(package)

[nltk_data] Downloading package punkt to
[nltk_data]     C:\Users\user\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [None]:
# Planner prompt. Placeholders: task_type, problem, m (number of agents / roles to select).
# The response is parsed for an "Influential Roles:" numbered list.
PLANNER_PROMPT = """
You are a planner to solve a {task_type} problem. Here is the problem for which you have to plan:
{problem}

First draft required strictly greater than {m} {task_type} specialized roles labeling "Specialized Roles" to solve the problem collaboratively with
reasoning behind your draft of each role. Format the roles clearly, for example:
Specialized Roles:
1. Role Name: Reasoning of what this agent should focus on.
2. Role Name: Reasoning...
...
m. Role Name: Reasoning...
m + 1. Role Name: Reasoning...
...

Then select exactly the highly {m} {task_type} influential roles labeling "Influential Roles" from the prior drafted "Specialized Roles" by re-checking the reasoning behind your selection and
assign the prior selected "Influential Roles" among exactly the {m} agents to solve the problem. Format the roles clearly, for example:
Influential Roles:
1. Role Name: Reasoning of what this agent should focus on.
2. Role Name: Reasoning...
...
m. Role Name: Reasoning...
"""

# Per-agent prompt for each collaboration round. Placeholders: role, task_type, problem,
# external_retrieval, self_retrieval, prev_response, response_store, sig_figs_instruction.
# sig_figs_instruction is "" when the item specifies no significant-figure requirement.
DYNAMIC_AGENT_PROMPT = """
You are a {role}. Your task is to solve a {task_type} problem. Here is the problem that you have to
solve:
{problem}

You were also given a couple of similar problems to the problem above along
with their reasoning and solutions to aid you in solving the problem at hand. Here are the similar
problems you were given:
{external_retrieval}
{self_retrieval}

And here was your original response:
{prev_response}

Also here is the leading responses with execution results from the response store:
{response_store}

Think carefully about where you went wrong, relating with responses in the response store. Then, try to
fix the solution producing a thought later reply with a solution to be executed and judged again. You can
integrate a Python tool to execute the calculations while replying your solution if required.
First provide your reasoning in <think></think> tags, then give the final answer in <answer></answer> tags.
{sig_figs_instruction}
"""

# Judge prompt. Placeholders: task_type, problem, candidate_response.
# The reply is parsed for a "Score: 0|1" verdict, so the expected form is spelled out.
JUDGE_PROMPT = """
You are a judge. Your task is to judge the candidate solution of a {task_type} problem. Here is the
problem for which the candidate solution you have to judge:
{problem}

And here is the candidate response which to judge:
{candidate_response}

Please produce a score labeling "Score" (if the response is correct, it should be 1 otherwise should be 0) with reasoning
behind your judgement of the candidate solution to the problem. Write the score on its own line in the form "Score: 1" or "Score: 0".
"""

# Answer-extraction prompt. Placeholders: task_type, response.
VERIFIER_PROMPT = """
You are an answer extractor. Your task is to extract answer from the response to a {task_type}
problem. Here is the response for which the answer you have to extract:
{response}

Please extract the answer which should be a single numerical number inside from the <answer></answer> block from the response.
"""


In [None]:
class EpisodicMemory:
    """Persistent store of solved problems, searchable with BM25 over the problem text.

    Each entry is a dict with keys ``problem``, ``solution``, ``subject`` and
    ``language``. When ``memory_file`` is given, entries are loaded from that JSON
    file on construction (starting empty if it is missing, empty or unreadable).
    """

    def __init__(self, memory_file=None):
        self.memory = []  # List of dicts: {problem, solution, subject, language}
        self.tokenized_corpus = []  # tokenized problem texts of entries that have one
        self.bm25 = None  # BM25 index over tokenized_corpus; None when nothing is indexed
        self.memory_file = memory_file
        if memory_file:
            self.load_memory_safe(memory_file)

    def _reset(self):
        """Drop all entries and the search index."""
        self.memory = []
        self.tokenized_corpus = []
        self.bm25 = None

    def load_memory_safe(self, filepath: str):
        """Load entries from ``filepath``; on a missing/empty/unreadable file, start empty.

        Load errors are reported with a printed warning instead of being raised.
        """
        if not (os.path.exists(filepath) and os.stat(filepath).st_size > 0):
            self._reset()
            return
        try:
            self.load_memory(filepath)
        except json.JSONDecodeError as e:
            print(f"Warning: Could not load episodic memory from {filepath} due to JSON error: {e}. Starting empty.")
            self._reset()
        except Exception as e:
            print(f"Warning: Error loading episodic memory from {filepath}: {e}. Starting empty.")
            self._reset()

    def load_memory(self, filepath: str):
        """Replace the current entries with those stored in ``filepath`` and rebuild the index.

        Raises:
            json.JSONDecodeError: if the file is not valid JSON.
            ValueError: if the top-level JSON value is not a list of entries.
        """
        with open(filepath, 'r', encoding='utf-8') as f:
            loaded = json.load(f)
        # A dict here would later break update()'s append(), so reject it up front.
        if not isinstance(loaded, list):
            raise ValueError(f"expected a JSON list of entries, got {type(loaded).__name__}")
        self.memory = loaded
        self.tokenized_corpus = [word_tokenize(entry['problem'].lower()) for entry in self.memory if entry.get('problem')]
        self.bm25 = BM25Okapi(self.tokenized_corpus) if self.tokenized_corpus else None

    def retrieve(self, query: str, subject: str = None, k=3) -> List[Dict[str, str]]:
        """Return up to ``k`` stored entries whose problem text best matches ``query``.

        Only entries that have both problem and solution text are ranked, so the
        result is not shortened by unusable hits. When ``subject`` is truthy, only
        entries with that exact subject are considered.

        Returns:
            A list of ``{"problem", "solution", "subject"}`` dicts, best match first.
        """
        if not self.bm25 or k <= 0:
            return []

        candidates = [
            entry for entry in self.memory
            if entry.get('problem') and entry.get('solution')
            and (not subject or entry.get('subject') == subject)
        ]
        if not candidates:
            return []

        # Scores are computed against an index over the candidate subset only.
        corpus = [word_tokenize(entry['problem'].lower()) for entry in candidates]
        scores = BM25Okapi(corpus).get_scores(word_tokenize(query.lower()))
        ranked = sorted(range(len(candidates)), key=lambda i: scores[i], reverse=True)[:k]
        return [
            {
                "problem": candidates[i]["problem"],
                "solution": candidates[i]["solution"],
                "subject": candidates[i].get("subject", ""),
            }
            for i in ranked
        ]

    def update(self, problem: str, solution: str, subject: str = None, language: str = None):
        """Append a solved problem and refresh the index; ignored if problem or solution is empty."""
        if not problem or not solution:
            return
        self.memory.append({
            'problem': problem,
            'solution': solution,
            'subject': subject,
            'language': language
        })
        # Tokenize only the new entry instead of re-tokenizing the whole corpus.
        self.tokenized_corpus.append(word_tokenize(problem.lower()))
        self.bm25 = BM25Okapi(self.tokenized_corpus)

    def save_memory(self, filepath: str):
        """Write all entries to ``filepath`` as UTF-8 JSON.

        The data is written to a temporary sibling file and then moved over
        ``filepath``, so a crash mid-write leaves the previous file intact.
        """
        tmp_path = f"{filepath}.tmp"
        try:
            with open(tmp_path, 'w', encoding='utf-8') as f:
                json.dump(self.memory, f, ensure_ascii=False, indent=2)
            os.replace(tmp_path, filepath)
        except Exception:
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
            raise

In [None]:
class SharedMemory:
    """Leaderboard of the best-scoring agent responses, capped at ``m`` entries.

    Entries are dicts with ``agent``, ``response`` and ``score`` keys. The list is
    kept in descending score order; equal scores keep their insertion order
    (Python's sort is stable), so earlier entries outrank later ties.
    """

    def __init__(self, m: int):
        self.m = m  # maximum number of entries retained
        self.memory = []  # entries, highest score first

    def update(self, new_entries: List[Dict]):
        """Merge ``new_entries`` into the leaderboard and keep only the top ``m`` by score."""
        pool = self.memory
        pool.extend(new_entries)
        pool.sort(key=lambda entry: entry["score"], reverse=True)
        self.memory = pool[:self.m]

In [None]:
class XolverVisionSolver:
    """Multi-agent vision-language solver following a Xolver-style plan/solve/judge loop.

    Per item:
    1. A planner picks ``agents`` roles.
    2. For ``rounds`` rounds, each role-agent answers. It sees retrieved examples,
       its previous answer and the shared leaderboard.
    3. A judge scores each answer 0/1, and a ``SharedMemory`` keeps the top ``agents`` answers.
    4. The best answer is returned and appended to the episodic memory.
    """

    # Subject codes used as-is for the task type; any other subject falls back to 'EM'.
    KNOWN_TASK_TYPES = ('CM', 'EM', 'ACG', 'OPT', 'AMONP', 'QMIT', 'TSM')

    # Dataset fields copied into each prediction record, in output order.
    _OUTPUT_FIELDS = ("index", "question", "subject", "img_category", "vision_relevance",
                      "language", "level", "sig_figs", "caption")

    def __init__(self, endpoint: str = None, deployment: str = None, api_key: str = None,
                 agents: int = 3, rounds: int = 2, memory_file: str = "episodic_memory.json",
                 image_root: str = "starting_kit_latest", max_completion_tokens: int = 100000):
        """Create the Azure OpenAI client and load the episodic memory.

        Args:
            endpoint: Azure endpoint; falls back to $ENDPOINT_URL, then a built-in default.
            deployment: deployment/model name; falls back to $DEPLOYMENT_NAME, then "o4-mini".
            api_key: API key; falls back to $AZURE_OPENAI_API_KEY. Do not hard-code keys.
            agents: number of role-agents (also the leaderboard capacity).
            rounds: number of collaboration rounds per item.
            memory_file: JSON file backing the episodic memory; falsy disables load/save.
            image_root: directory prefixed to each item's ``image_path`` entries.
            max_completion_tokens: completion budget passed on every model call.
        """
        self.endpoint = endpoint or os.getenv("ENDPOINT_URL", "https://qcri-llm-rag-3.openai.azure.com/")
        self.deployment = deployment or os.getenv("DEPLOYMENT_NAME", "o4-mini")
        self.api_key = api_key or os.getenv("AZURE_OPENAI_API_KEY", api_key)

        self.client = AzureOpenAI(
            azure_endpoint=self.endpoint,
            api_key=self.api_key,
            api_version="2025-01-01-preview",
        )

        self.agents = agents
        self.rounds = rounds
        self.image_root = image_root
        self.max_completion_tokens = max_completion_tokens
        self.episodic_memory = EpisodicMemory(memory_file)

    def call_model(self, messages: List[Dict[str, Any]]) -> str:
        """Send chat ``messages`` (text and/or image parts) and return the reply text.

        Returns "" (with a logged warning) when the API returns no message content,
        so callers never receive ``None``. Any API error is logged and re-raised.
        """
        try:
            response = self.client.chat.completions.create(
                model=self.deployment,
                messages=messages,
                max_completion_tokens=self.max_completion_tokens,
                stop=None,
                stream=False,
            )
            content = response.choices[0].message.content
        except Exception as e:
            logger.error(f"Model call failed: {e}")
            raise
        if not content:
            logger.warning("Model returned empty content.")
            return ""
        return content

    def encode_image(self, image_path: str) -> str:
        """Return the file at ``image_path`` encoded as a base64 ASCII string."""
        with open(image_path, "rb") as image_file:
            return base64.b64encode(image_file.read()).decode('utf-8')

    @staticmethod
    def _vision_message(text: str, base64_images: List[str]) -> List[Dict[str, Any]]:
        """Build a one-message chat: a text part followed by one PNG data-URL part per image."""
        content = [{"type": "text", "text": text}]
        content.extend(
            {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{img}"}}
            for img in base64_images
        )
        return [{"role": "user", "content": content}]

    def build_enhanced_prompt(self, item: Dict) -> tuple:
        """Build the question text and resolve image paths for one dataset item.

        Returns:
            ``(question, image_paths, task_type, sig_figs_instruction)``:
            - ``question`` is caption + question + a language-specific answer-format line.
            - ``task_type`` is the item's subject code if known, else 'EM'.
            - ``sig_figs_instruction`` is "" when ``sig_figs`` is missing or unusable.
        """
        tgt_path = [f"{self.image_root}/{path}" for path in item['image_path']]
        question = item['caption'] + "\n" + item['question']

        # Subject codes are upper-case; normalise so 'em', ' EM ' or None still resolve.
        subject = (item.get('subject') or '').strip().upper()
        task_type = subject if subject in self.KNOWN_TASK_TYPES else 'EM'

        if item['language'] == 'English':
            question += "\nPlease analyze the visual content carefully and solve this problem step by step. First provide your reasoning in <think></think> tags, then give the final answer in <answer></answer> tags."
        else:
            question += "\n请仔细分析视觉内容并逐步解决这个问题。首先在<think></think>标签中提供推理过程，然后在<answer></answer>标签中给出最终答案。"

        sig_figs_instruction = ""
        try:
            if item['sig_figs']:
                sf = str(int(item['sig_figs']))
                if item['language'] == 'English':
                    sig_figs_instruction = f"The final answer should retain {sf} significant figures."
                else:
                    sig_figs_instruction = f"最终答案应保留{sf}位有效数字。"
        except (KeyError, TypeError, ValueError, OverflowError):
            # Missing or non-integer sig_figs (e.g. NaN, "", inf): give no instruction.
            pass

        return question, tgt_path, task_type, sig_figs_instruction

    def extract_roles(self, planner_response: str, m: int) -> List[str]:
        """Parse up to ``m`` distinct role names from the planner's "Influential Roles" list.

        Markdown decoration (``**``, ``#``, backticks) around names is stripped. If fewer
        than ``m`` roles are found, the list is padded with "expert agent N" roles.
        """
        roles = []
        roles_section = re.search(
            r"Influential Roles\**\s*:\s*(.+?)(?:\n[A-Z][a-zA-Z ]+?:|\Z)",
            planner_response or "",
            flags=re.S | re.I,
        )
        if roles_section:
            roles_text = roles_section.group(1).strip()
            # One name per numbered line ("1. Name: ..." or "1) Name: ..."), never spanning lines.
            for name in re.findall(r"^\s*\d+[.)]\s*([^:\n]+?)\s*:", roles_text, flags=re.M):
                cleaned = name.strip().strip("*#_` ").strip()
                if cleaned:
                    roles.append(cleaned)

        roles = list(dict.fromkeys(roles))[:m]  # de-duplicate, preserving order
        roles += [f"expert agent {i + 1}" for i in range(len(roles), m)]
        return roles

    def parse_score(self, score_str: str) -> int:
        """Return the judge's 0/1 verdict from a "Score: <0|1>" line; 0 if none is found.

        Tolerates Markdown bold (``**Score:** 1``, ``Score: **1**``) and a full-width colon.
        """
        if not score_str:
            return 0
        match = re.search(r"Score\**\s*[:：]\s*\**\s*([01])\b", score_str, flags=re.I)
        return int(match.group(1)) if match else 0

    def model_self_recall(self, problem: str, task_type: str) -> str:
        """Ask the model (text only) to recall a related but distinct problem with its solution."""
        recall_prompt = f"""
        Recall from your knowledge a relevant but distinct {task_type} problem with visual elements labeling "Problem" and its solution labeling "Response",
        different from: {problem}

        Provide a complete problem and solution that involves similar concepts.
        """
        messages = [{"role": "user", "content": recall_prompt}]
        return self.call_model(messages)

    def solve_with_xolver(self, item: Dict) -> str:
        """Run the plan -> collaborate -> judge loop for one item and return the best response.

        Side effects: logs the extracted final answer, appends (question, best response)
        to the episodic memory and saves it when it has a backing file. Returns
        "No solution found" (without touching memory) if no agent produced a response.
        """
        question, img_paths, task_type, sig_figs_instruction = self.build_enhanced_prompt(item)
        base64_images = [self.encode_image(img_path) for img_path in img_paths]

        shared_memory = SharedMemory(self.agents)

        # Step 1: planning - choose one role per agent.
        planner_prompt = PLANNER_PROMPT.format(task_type=task_type, problem=question, m=self.agents)
        planner_response = self.call_model(self._vision_message(planner_prompt, base64_images))
        roles = self.extract_roles(planner_response, self.agents)
        logger.info(f"Assigned roles: {roles}")

        agent_responses = [""] * len(roles)

        # Retrieval does not change during the loop, so query and format it once.
        external_retrieval = self.episodic_memory.retrieve(question, subject=item.get('subject'), k=3)
        if external_retrieval:
            external_retrieval_text = "\n\n".join(
                f"Problem:\n{entry['problem']}\n\nSolution:\n{entry['solution']}"
                for entry in external_retrieval
            )
        else:
            external_retrieval_text = "None"
        # Without external examples, each agent gets one self-recalled example,
        # generated on first use and reused in later rounds rather than re-requested.
        self_recall_by_agent = {}

        # Step 2: multi-round collaboration.
        for _ in range(self.rounds):
            # All agents in a round see the leaderboard as it stood when the round began.
            response_store_text = "\n".join(
                f"Agent: {entry['agent']}\nResponse: {entry['response']}\nScore: {entry['score']}\n"
                for entry in shared_memory.memory
            ) or "None"

            for i, role in enumerate(roles):
                if external_retrieval:
                    self_retrieval_text = "None"
                else:
                    if i not in self_recall_by_agent:
                        self_recall_by_agent[i] = self.model_self_recall(question, task_type)
                    self_retrieval_text = self_recall_by_agent[i]

                dynamic_prompt = DYNAMIC_AGENT_PROMPT.format(
                    role=role,
                    task_type=task_type,
                    problem=question,
                    external_retrieval=external_retrieval_text,
                    self_retrieval=self_retrieval_text,
                    prev_response=agent_responses[i] or "None",
                    response_store=response_store_text,
                    sig_figs_instruction=sig_figs_instruction,
                )
                response = self.call_model(self._vision_message(dynamic_prompt, base64_images))
                agent_responses[i] = response
                if not response:
                    # Nothing to judge; keep empty answers off the leaderboard.
                    logger.warning(f"Agent {role} returned an empty response; skipping judge.")
                    continue

                judge_prompt = JUDGE_PROMPT.format(
                    task_type=task_type,
                    problem=question,
                    candidate_response=response,
                )
                score = self.parse_score(self.call_model(self._vision_message(judge_prompt, base64_images)))
                logger.info(f"Agent {role} scored: {score}")

                shared_memory.update([{"agent": role, "response": response, "score": score}])

        # Step 3: pick the best response and extract the final answer.
        if not shared_memory.memory:
            logger.warning("No agent produced a response; skipping answer extraction and memory update.")
            return "No solution found"
        best_response = max(shared_memory.memory, key=lambda x: x["score"])["response"]

        verifier_prompt = VERIFIER_PROMPT.format(task_type=task_type, response=best_response)
        final_answer = self.call_model([{"role": "user", "content": verifier_prompt}])
        logger.info(f"Extracted final answer: {final_answer}")

        self.episodic_memory.update(
            problem=question,
            solution=best_response,
            subject=item.get('subject'),
            language=item.get('language')
        )
        if self.episodic_memory.memory_file:
            self.episodic_memory.save_memory(self.episodic_memory.memory_file)

        return best_response

    def _solve_with_retries(self, item: Dict, use_xolver: bool, max_retries: int, retry_delay: float) -> str:
        """Solve one item, retrying on any exception; returns an error marker after the last failure."""
        for attempt in range(1, max_retries + 1):
            try:
                if use_xolver:
                    return self.solve_with_xolver(item)
                # Simple path: one model call, with the sig-figs instruction appended to the question.
                question, img_paths, _, sig_figs_instruction = self.build_enhanced_prompt(item)
                if sig_figs_instruction:
                    question += "\n" + sig_figs_instruction
                base64_images = [self.encode_image(img_path) for img_path in img_paths]
                return self.inference_one_step(question, base64_images)
            except Exception as e:
                logger.error(f"Attempt {attempt} failed: {e}")
                if attempt < max_retries:
                    logger.info(f"Waiting {retry_delay} seconds before retrying...")
                    time.sleep(retry_delay)
        logger.error("Max retries reached. Using fallback.")
        return "ERROR: Max retries reached."

    def run_inference(self, json_path: str, output_path: str, use_xolver: bool = True,
                      max_retries: int = 2, retry_delay: float = 15):
        """Solve every item in ``json_path`` and write prediction records to ``output_path``.

        Each item is attempted up to ``max_retries`` times, sleeping ``retry_delay``
        seconds between attempts. After the final failure its prediction is
        "ERROR: Max retries reached.". Fields missing from an item are written as null.
        The output file is rewritten after every item, so an interrupted run keeps
        the predictions made so far.
        """
        results = []

        with open(json_path, 'r', encoding='utf-8') as f:
            data = json.load(f)

        for input_item in tqdm(data):
            logger.info(f"Processing item {input_item.get('index')}")
            response = self._solve_with_retries(input_item, use_xolver, max_retries, retry_delay)

            record = {field: input_item.get(field) for field in self._OUTPUT_FIELDS}
            record["prediction"] = response
            results.append(record)

            with open(output_path, 'w', encoding='utf-8') as f:
                json.dump(results, f, ensure_ascii=False, indent=4)

    def inference_one_step(self, question: str, base64_images: List[str]) -> str:
        """Answer with a single model call on the question and images (no planning or judging)."""
        return self.call_model(self._vision_message(question, base64_images))

In [None]:
if __name__ == '__main__':
    # Run configuration. The API key is deliberately not set here: passing None makes
    # XolverVisionSolver read it from the AZURE_OPENAI_API_KEY environment variable,
    # so no credential ever needs to be typed into (and saved with) the notebook.
    DATA_DIR = "starting_kit_latest"
    INPUT_JSON = f"{DATA_DIR}/mini.json"
    OUTPUT_JSON = f"{DATA_DIR}/xolver_prediction_gptmini.json"

    solver = XolverVisionSolver(
        endpoint="https://qcri-llm-rag-3.openai.azure.com/",
        deployment="o4-mini",
        api_key=None,
        agents=3,
        rounds=2,
        memory_file="vision_episodic_memory.json",
    )

    solver.run_inference(INPUT_JSON, OUTPUT_JSON, use_xolver=True)

  0%|          | 0/200 [00:00<?, ?it/s][32m2025-06-28 21:43:45.762[0m | [1mINFO    [0m | [36m__main__[0m:[36mrun_inference[0m:[36m267[0m - [1mProcessing item 11[0m
[32m2025-06-28 21:44:00.142[0m | [1mINFO    [0m | [36m__main__[0m:[36msolve_with_xolver[0m:[36m146[0m - [1mAssigned roles: ['Thermodynamics Expert', 'Statistical Mechanics Theorist', 'Molecular Spectroscopy Specialist'][0m
[32m2025-06-28 21:44:17.916[0m | [1mINFO    [0m | [36m__main__[0m:[36msolve_with_xolver[0m:[36m223[0m - [1mAgent Thermodynamics Expert scored: 1[0m
[32m2025-06-28 21:44:35.905[0m | [1mINFO    [0m | [36m__main__[0m:[36msolve_with_xolver[0m:[36m223[0m - [1mAgent Statistical Mechanics Theorist scored: 1[0m
[32m2025-06-28 21:45:09.505[0m | [1mINFO    [0m | [36m__main__[0m:[36msolve_with_xolver[0m:[36m223[0m - [1mAgent Molecular Spectroscopy Specialist scored: 1[0m
[32m2025-06-28 21:45:33.633[0m | [1mINFO    [0m | [36m__main__[0m:[36msolve_with_