In [None]:

import gc
import warnings
warnings.filterwarnings('ignore')
import random
import scipy as sp
import numpy as np
import pandas as pd
import math
from glob import glob
from pathlib import Path
import joblib
import pickle
import itertools
from tqdm.auto import tqdm
import re
import torch
import vllm

# Initialize LLM and Tokenizer
# `llm` and `tokenizer` are module-level globals: batch_message_generate reads
# both of them directly instead of taking them as parameters, so this cell must
# run before any generation cell.
llm = vllm.LLM(
    "Qwen/Qwen2.5-Math-7B-Instruct",  # Ensure this model is supported
    tensor_parallel_size=1,  # or 4 based on available resources
    gpu_memory_utilization=0.95, 
    trust_remote_code=True,
    dtype="half", 
    enforce_eager=True,
)
# Tokenizer that belongs to the loaded model; used later to render chat
# templates into prompt strings.
tokenizer = llm.get_tokenizer()

In [1]:
def generate_text_vllm(requests, tokenizer, model):
    """Generate one completion per request and return the completion texts.

    Args:
        requests: Prompts forwarded unchanged to ``model.generate``.
        tokenizer: Accepted for interface compatibility; not used here.
        model: Object exposing ``generate(requests, sampling_params=...,
            use_tqdm=...)`` whose results carry ``outputs[0].text``.

    Returns:
        list: The text of the first output of every response, in the order
        the responses were returned.
    """
    # Fixed sampling configuration: temperature 0, seed 42, up to 1024 tokens.
    params = vllm.SamplingParams(temperature=0.00, seed=42, max_tokens=1024)
    completions = model.generate(requests, sampling_params=params, use_tqdm=False)
    return [completion.outputs[0].text for completion in completions]

In [None]:
import os
import tempfile
import subprocess

class PythonREPL:
    """Execute a Python source string in a separate interpreter process.

    Calling an instance writes the source to a temporary file, runs it, and
    returns ``(succeeded, text)`` where ``text`` is the stripped stdout on
    success or a cleaned-up error message on failure.
    """

    def __init__(self, timeout=5):
        """Store the per-run time limit.

        Args:
            timeout: Seconds the child process may run before it is aborted.
        """
        self.timeout = timeout

    def __call__(self, query):
        """Run ``query`` as a script and report the outcome.

        Args:
            query: Python source code to execute.

        Returns:
            tuple[bool, str]: ``(True, stdout)`` when the process exits with
            status 0; ``(False, message)`` when it times out or exits with a
            non-zero status. In the latter case ``message`` is stderr with the
            temporary script path replaced by ``<temporary_file>``.
        """
        # Local import keeps this block self-contained; `sys` is not imported
        # in the cell that defines the class.
        import sys

        # Prefer the interpreter running this code so the child sees the same
        # installed packages; fall back to whatever `python3` is on PATH when
        # the current executable is unknown.
        interpreter = sys.executable or "python3"

        with tempfile.TemporaryDirectory() as temp_dir:
            temp_file_path = os.path.join(temp_dir, "tmp.py")
            with open(temp_file_path, "w", encoding="utf-8") as f:
                f.write(query)

            try:
                result = subprocess.run(
                    [interpreter, temp_file_path],
                    capture_output=True,
                    check=False,
                    text=True,
                    # Undecodable bytes in the child's output must not raise
                    # here; they are replaced instead.
                    errors="replace",
                    timeout=self.timeout,
                )
            except subprocess.TimeoutExpired:
                return False, f"Execution timed out after {self.timeout} seconds."

        if result.returncode == 0:
            return True, result.stdout.strip()

        # Hide the throwaway script path so tracebacks are stable and readable.
        error_msg = result.stderr.strip()
        return False, error_msg.replace(temp_file_path, "<temporary_file>")

In [None]:
import re


def extract_python_code(text):
    """Collect the bodies of all ```python fenced blocks found in ``text``.

    Whitespace directly inside the fences is dropped. The bodies are joined
    with a blank line between them; an empty string is returned when no
    fenced block is present.
    """
    fenced_block = re.compile(r'```python\s*(.*?)\s*```', re.DOTALL)
    snippets = [found.group(1) for found in fenced_block.finditer(text)]
    return "\n\n".join(snippets)


def _echo_statement(source, target):
    """Build a ``print`` statement that shows an assignment target as ``name=value``.

    The usual form is ``print(f"{target=}")`` using the target's original
    source text. When that text cannot be embedded in an f-string on the
    running interpreter (for example it contains double quotes), an equivalent
    string-concatenation form is produced instead.
    """
    import ast

    label = ast.get_source_segment(source, target)
    if label is not None:
        statement = f'print(f"{{{label}=}}")'
        try:
            compile(statement, "<echo>", "exec")
            return statement
        except (SyntaxError, ValueError):
            pass
    expression = ast.unparse(target)
    label_literal = repr(expression + "=")
    return f"print({label_literal} + repr(({expression})))"


def process_python_code(query):
    """Prepare model-written code for execution.

    ``import math``, ``import numpy as np`` and ``import sympy as sp`` are
    prepended, and after every top-level assignment a ``print`` is inserted
    that echoes the assigned target, so intermediate values appear in stdout.

    Assignments are found with the ``ast`` module rather than by looking for
    ``=`` in the text, so comparisons, keyword arguments, ``def``/``if``
    headers and similar lines are left alone, and the echo for a multi-line
    assignment is placed after its last line.

    Args:
        query: Python source code (may be empty).

    Returns:
        str: The instrumented source. If ``query`` does not parse, the source
        is returned with only the imports prepended so the interpreter reports
        the original error.
    """
    import ast

    source = ("import math\nimport numpy as np\nimport sympy as sp\n" + query).strip()
    try:
        statements = ast.parse(source).body
    except (SyntaxError, ValueError, RecursionError):
        return source

    echoes_by_line = {}  # 1-based line number -> print statements to insert after it
    pending = []
    for position, node in enumerate(statements):
        if isinstance(node, ast.Assign):
            target = node.targets[0]
        elif isinstance(node, ast.AugAssign):
            target = node.target
        elif isinstance(node, ast.AnnAssign) and node.value is not None:
            target = node.target
        else:
            target = None
        if target is not None:
            pending.append(_echo_statement(source, target))

        # Statements separated by ";" share a line; wait until the last one on
        # that line has ended before inserting anything.
        has_next = position + 1 < len(statements)
        if has_next and statements[position + 1].lineno <= node.end_lineno:
            continue
        if pending:
            echoes_by_line.setdefault(node.end_lineno, []).extend(pending)
            pending = []

    new_rows = []
    for line_number, row in enumerate(source.split("\n"), start=1):
        new_rows.append(row)
        new_rows.extend(echoes_by_line.get(line_number, ()))
    return "\n".join(new_rows)


def extract_boxed_text(text):
    """Return the content of the first ``\\boxed{...}`` expression in ``text``.

    The marker searched for is ``oxed{``. Braces inside the content are
    balanced, so ``\\boxed{\\frac{1}{2}}`` yields ``\\frac{1}{2}`` rather than
    being cut at the first ``}``. The content must end on the same line as the
    marker. If the braces never balance on that line, the text up to the first
    ``}`` is returned; if the line has no ``}`` at all, the next marker is
    tried.

    Args:
        text: String to search.

    Returns:
        str: The boxed content, or ``""`` when nothing is found (or the first
        box is empty).
    """
    marker = "oxed{"
    search_from = 0
    while True:
        start = text.find(marker, search_from)
        if start == -1:
            return ""
        content_start = start + len(marker)
        line_end = text.find("\n", content_start)
        if line_end == -1:
            line_end = len(text)

        depth = 1
        first_close = -1
        for position in range(content_start, line_end):
            char = text[position]
            if char == "{":
                depth += 1
            elif char == "}":
                if first_close == -1:
                    first_close = position
                depth -= 1
                if depth == 0:
                    return text[content_start:position]

        if first_close != -1:
            # Unbalanced braces: keep the shortest-match result.
            return text[content_start:first_close]
        search_from = start + 1


from collections import Counter
import random
def _integer_value(answer):
    """Return ``answer`` as an ``int`` if it denotes a whole number, else ``None``.

    Accepts integer literals (``"42"``), integer-valued decimals (``"42.0"``,
    ``"1e3"``) and numeric objects. Non-numeric input, fractions, ``nan`` and
    ``inf`` give ``None``.
    """
    try:
        as_float = float(answer)
    except (TypeError, ValueError, OverflowError):
        return None
    if not as_float.is_integer():
        return None
    try:
        # Exact for integer literals, with no float rounding on large values.
        return int(answer)
    except (TypeError, ValueError, OverflowError):
        return int(as_float)


def select_answer(answers):
    """Pick the most frequent integer answer and reduce it modulo 1000.

    Each candidate that denotes a whole number gets one vote plus a tiny
    random amount (below 0.001) so that ties are broken at random. Candidates
    that are not whole numbers are ignored.

    Args:
        answers: Iterable of candidate answers (typically strings).

    Returns:
        int: The winning answer modulo 1000, or 210 when no candidate is a
        whole number.
    """
    counter = Counter()
    for answer in answers:
        value = _integer_value(answer)
        if value is not None:
            counter[value] += 1 + random.random() / 1_000
    if not counter:
        return 210
    _, answer = max((votes, candidate) for candidate, votes in counter.items())
    return answer % 1000

In [None]:
# Sampling settings read as a global by batch_message_generate.
# The class is referenced as `vllm.SamplingParams` because only the `vllm`
# module itself is imported in this notebook; a bare `SamplingParams` name is
# not defined.
sampling_params = vllm.SamplingParams(
    temperature=1.0,              # randomness of the sampling
    min_p=0.01,
    skip_special_tokens=True,     # Whether to skip special tokens in the output.
    max_tokens=1800,
    # Generation stops at the end of a fenced code block, and the closing
    # fence is kept in the output so the code can be extracted afterwards.
    stop=["```\n"],
    include_stop_str_in_output=True,
)

def batch_message_generate(list_of_messages) -> list[list[dict]]:
    """Generate one assistant turn for every conversation.

    Each conversation is rendered to a prompt with the global ``tokenizer``'s
    chat template, all prompts are sent to the global ``llm`` in one call
    using the global ``sampling_params``, and the first output of each result
    is appended to its conversation as an ``assistant`` message.

    Args:
        list_of_messages: List of conversations, each a list of
            ``{'role': ..., 'content': ...}`` dicts. Modified in place.

    Returns:
        The same ``list_of_messages`` object, with one extra message per
        conversation.
    """
    prompts = []
    for conversation in list_of_messages:
        rendered = tokenizer.apply_chat_template(
            conversation=conversation,
            tokenize=False,
            add_generation_prompt=True
        )
        prompts.append(rendered)

    generations = llm.generate(prompts=prompts, sampling_params=sampling_params)

    for conversation, generation in zip(list_of_messages, generations):
        reply = {'role': 'assistant', 'content': generation.outputs[0].text}
        conversation.append(reply)

    return list_of_messages

In [None]:
def batch_message_filter(list_of_messages) -> tuple[list[list[dict]], list[str]]:
    """Separate conversations that produced a boxed answer from those that did not.

    The last message of every conversation is scanned with
    ``extract_boxed_text``. A non-empty result is collected as an answer and
    the conversation is dropped; otherwise the conversation is kept.

    Returns:
        ``(conversations_without_answer, extracted_answers)``.
    """
    unfinished = []
    answers = []
    for conversation in list_of_messages:
        boxed = extract_boxed_text(conversation[-1]['content'])
        if not boxed:
            unfinished.append(conversation)
            continue
        answers.append(boxed)
    return unfinished, answers

In [None]:
def batch_message_execute(list_of_messages) -> list[list[dict]]:
    """Run the code from each conversation's last message and append the result.

    For every conversation, the Python code in the last message is extracted,
    instrumented with ``process_python_code`` and executed with ``PythonREPL``.
    The output (or the text of any exception raised while executing) is
    appended as a ``user`` message.

    A progress character is printed per step: ``c`` before execution, then
    ``o`` on success, ``e`` on a failed run, or ``f`` when an exception was
    raised. A newline is printed after the whole batch.

    Returns:
        The same ``list_of_messages`` object, modified in place.
    """
    for conversation in list_of_messages:
        latest_reply = conversation[-1]['content']
        script = process_python_code(extract_python_code(latest_reply))
        try:
            print('c', end='')
            succeeded, output = PythonREPL()(script)
            print('o' if succeeded else 'e', end='')
        except Exception as error:
            print('f', end='')
            output = str(error)
        conversation.append({'role': 'user', 'content': output})
    print()
    return list_of_messages

In [None]:
def create_starter_messages(question, index):
    """Build the opening chat messages for one attempt at ``question``.

    Four prompt variants are used in rotation, selected by ``index % 4``:

    * 0 - system prompt asking for natural-language reasoning integrated with
      Python programs, followed by the question.
    * 1 - system prompt asking for step-by-step reasoning, followed by the
      question.
    * 2 - a single user message asking to translate the problem into sympy.
    * 3 - a single user message asking to translate the problem into Python
      code using numpy and sympy.

    Every variant asks for the final answer inside ``\\boxed{}``.

    Args:
        question: Problem statement text.
        index: Attempt number; only its remainder modulo 4 matters.

    Returns:
        list[dict]: Messages with ``role`` and ``content`` keys.
    """
    cycle_size = 4
    variant = index % cycle_size
    # NOTE: the backslash in "\\boxed" must be escaped; a bare "\b" inside a
    # normal string literal is the backspace character.
    if variant == 3:
        return [
            {"role": "user", "content": "Translate this problem into Python code.\n\n" + question + "\n\nStart by importing numpy and sympy. If you get a confident answer after running the sympy code, put your final answer within \\boxed{}"}
        ]
    if variant == 2:
        return [
            {"role": "user", "content": "Translate the following problem into sympy.\n\n" + question + "\n\nStart by importing sympy. If you get a confident answer after running the sympy code, put your final answer within \\boxed{}"}
        ]
    if variant == 1:
        # https://github.com/QwenLM/Qwen2.5-Math?tab=readme-ov-file#-hugging-face-transformers
        return [
            {"role": "system", "content": "Please reason step by step, and put your final answer within \\boxed{}."},
            {"role": "user", "content": question}
        ]
    # https://github.com/QwenLM/Qwen2.5-Math?tab=readme-ov-file#-hugging-face-transformers
    return [
        {"role": "system", "content": "Please integrate natural language reasoning with Python programs to solve the problem above, and put your final answer within \\boxed{}."},
        {"role": "user", "content": question}
    ]

In [None]:
def predict_for_question(question: str) -> int:
    """Solve one problem by sampling several attempts and voting on the answer.

    When the ``KAGGLE_IS_COMPETITION_RERUN`` environment variable is not set,
    only one specific reference question is actually solved; any other
    question immediately returns 210 (presumably to keep non-competition runs
    short - confirm with the notebook author).

    Otherwise 16 starter conversations are created and up to 4 rounds of
    generate -> collect boxed answers -> execute code are run, stopping early
    once every conversation has produced an answer. The collected answers are
    passed to ``select_answer``.

    Args:
        question: Problem statement.

    Returns:
        int: The selected answer.
    """
    import os

    reference_question = "Triangle $ABC$ has side length $AB = 120$ and circumradius $R = 100$. Let $D$ be the foot of the perpendicular from $C$ to the line $AB$. What is the greatest possible length of segment $CD$?"
    is_competition_rerun = os.getenv('KAGGLE_IS_COMPETITION_RERUN')
    if not is_competition_rerun and question != reference_question:
        return 210

    question += "\nIf the final answer is a number larger than 1 million, take modulo 1000."
    print(question)

    conversations = [create_starter_messages(question, index) for index in range(16)]

    collected_answers = []
    for _ in range(4):
        conversations = batch_message_generate(conversations)
        conversations, newly_boxed = batch_message_filter(conversations)
        collected_answers.extend(newly_boxed)
        if not conversations:
            # Every attempt has already produced a boxed answer.
            break
        conversations = batch_message_execute(conversations)

    print(collected_answers)
    answer = select_answer(collected_answers)
    print(answer)

    print("\n\n")
    return answer

In [None]:
# NOTE(review): `pl` is not imported in the cells visible here; presumably
# `import polars as pl` lives elsewhere in the notebook - confirm, otherwise
# this definition fails on a fresh kernel.
def predict(id_: pl.DataFrame, question: pl.DataFrame) -> pl.DataFrame | pd.DataFrame:
    """Answer a single problem delivered as one-element frames.

    Args:
        id_: Frame whose first item is the problem id.
        question: Frame whose first item is the problem text.

    Returns:
        A ``pl.DataFrame`` with columns ``id`` and ``answer``.
    """
    problem_id = id_.item(0)
    print("------")
    print(problem_id)

    problem_text = question.item(0)
    answer = predict_for_question(problem_text)
    print(problem_text)
    print("------\n\n\n")
    return pl.DataFrame({'id': problem_id, 'answer': answer})