In [None]:
import torch
torch.backends.cuda.enable_mem_efficient_sdp(False)
import pandas as pd
from tqdm import tqdm
import gc
import re
import sys
import subprocess
import math
import random
from collections import Counter
from numpy.random import choice
import numpy as np

from transformers import (
    AutoModelForCausalLM, 
    AutoTokenizer, 
    AutoConfig,
    StoppingCriteria,
    StoppingCriteriaList,
    set_seed,
    BitsAndBytesConfig
)
import transformers
from typing import Optional
import time
import logging

In [None]:
# Model identifier passed to the from_pretrained() loaders.
model_path = "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct"
# Token budget: used as max_length for generate() and as the prompt-length cutoff.
max_tokens = 4096
# Number of sampling rounds per problem handed to predict() by the evaluation loop.
self_consistency_count = 17
# NOTE(review): currently unused at generation time -- the temperature argument
# is commented out in model_generate().
temperature = 0.6
# Nucleus-sampling threshold forwarded to model.generate().
top_p = 1.0

In [None]:
def run_code(code: str, remove_tmp=True, timeout=7):
    """Run ``code`` as a standalone Python script and capture its output.

    SECURITY: the script is executed unsandboxed with the privileges of the
    notebook kernel. Only feed it code you are prepared to run locally.

    Args:
        code: Python source to execute.
        remove_tmp: Delete the temporary ``tmp.py`` script afterwards.
        timeout: Maximum run time in seconds before the child is aborted.

    Returns:
        ``(output, success)`` where ``output`` is the combined stdout/stderr
        text (or ``"Timeout"`` / the exception message) and ``success`` is
        ``True`` only when the script exited with status 0.
    """
    import os  # local import: `os` is not among the notebook's top-level imports

    script_path = "tmp.py"
    # Explicit UTF-8 so non-ASCII characters in the code do not depend on the locale.
    with open(script_path, "w", encoding="utf-8") as f:
        f.write(code)
    try:
        # Use the kernel's own interpreter so the script sees the same packages.
        output = subprocess.check_output(
            [sys.executable or "python3", script_path],
            stderr=subprocess.STDOUT,
            timeout=timeout,
        )
        return output.decode("utf-8", errors="replace"), True
    except subprocess.CalledProcessError as e:
        # errors="replace": undecodable bytes must not raise from inside the handler.
        return e.output.decode("utf-8", errors="replace"), False
    except subprocess.TimeoutExpired:
        return "Timeout", False
    except Exception as e:
        return str(e), False
    finally:
        # Portable cleanup (no dependency on an external `rm` binary).
        if remove_tmp and os.path.exists(script_path):
            os.remove(script_path)
            
def get_last_python_code(text: str):
    """Return the body of the last ```python fenced block in ``text``.

    An empty string is returned when ``text`` contains no complete block.
    """
    blocks = re.findall(r"```python\n(.*?)\n```", text, flags=re.DOTALL)
    return blocks[-1] if blocks else ""

def get_answer(text: str):
    """Extract the final ``\\boxed{<digits>}`` answer from ``text``.

    The LAST boxed integer is used: earlier occurrences can come from text
    that was embedded in the prompt (hints, quoted model output) rather than
    from the model's concluding answer.

    Returns:
        The answer as a ``float``, or ``None`` when ``text`` holds no boxed
        integer (or is not a string).
    """
    try:
        boxed_values = re.findall(r'\\boxed\{(\d+)\}', text)
    except TypeError:
        # Non-string input (e.g. None) has no answer to extract.
        return None
    if not boxed_values:
        return None
    return float(boxed_values[-1])

In [None]:
# Release cached CUDA memory and run a garbage-collection pass before the
# model is loaded in the next cell.
torch.cuda.empty_cache()
gc.collect()

In [None]:
# Load the config, tokenizer and 4-bit (NF4, double-quantized) model.
# trust_remote_code=True is passed to every loader, not only the model, so a
# checkpoint that ships custom code resolves consistently without prompting.
config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)

tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype='bfloat16',
) 
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    device_map='auto',
    torch_dtype="auto",
    trust_remote_code=True,
    config=config,
    quantization_config=quantization_config)

class StoppingCriteriaPythonCode(StoppingCriteria):
    """Stop once the decoded text ends with a ``` fence that closes a block.

    Only the first sequence of the batch is inspected. A fence counts as
    closing when the total number of ``` markers in the text is even.
    """
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
        text_so_far = tokenizer.decode(input_ids[0], skip_special_tokens=True)
        if not text_so_far.endswith("```"):
            return False
        return text_so_far.count("```") % 2 == 0
    
class StoppingCriteriaAnswer(StoppingCriteria):
    """Stop once a ``\\boxed{<digits>}`` answer shows up at the end of the text.

    Only the last 20 characters of the first sequence are searched.
    """
    def __init__(self):
        super().__init__()
        self.pattern = re.compile(r'\\boxed\{(\d+)\}')

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
        tail = tokenizer.decode(input_ids[0], skip_special_tokens=True)[-20:]
        return self.pattern.search(tail) is not None
    

In [None]:
def seed_everything(seed):
    """Seed every RNG used by the notebook and request deterministic cuDNN.

    Seeds Python's ``random``, NumPy, PyTorch (CPU and all CUDA devices) and
    finally calls transformers' ``set_seed``. Also exports ``PYTHONHASHSEED``
    so that child processes started afterwards inherit it.

    Args:
        seed: Integer seed applied to all generators.
    """
    import os  # local import: `os` is not among the notebook's top-level imports

    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # every visible GPU, not just the current one
    torch.backends.cudnn.deterministic = True
    # benchmark mode auto-tunes kernels per run, which can pick different
    # algorithms between runs; it has to be off for reproducibility.
    torch.backends.cudnn.benchmark = False
    set_seed(seed)
    
def clean_memory():
    """Run five rounds of CUDA-cache release, garbage collection and a 0.2 s pause."""
    rounds_left = 5
    while rounds_left > 0:
        torch.cuda.empty_cache()
        gc.collect()
        time.sleep(0.2)
        rounds_left -= 1
# Seed all random number generators once, before any sampling takes place.
seed_everything(42)

In [None]:
# Prompt template for the code-writing predictors. It contains two `{}` slots:
# predict_code() fills the first with the problem text and the second with a
# literal "{}", so the `\boxed{}` instruction survives str.format() unchanged.
# The trailing "First, " primes the model to begin listing its approach.
code = """Below is a math problem you are to solve (positive numerical answer):
\"{}\"
To accomplish this, first determine a sympy-based approach for solving the problem by listing each step to take and what functions need to be called in each step. Be clear so even an idiot can follow your instructions, and remember, your final answer should be integer, not an algebraic expression!
Write the entire script covering all the steps (use comments and document it well) and print the result. After solving the problem, output the final numerical answer within \\boxed{}.

Approach:
First, """

In [None]:
def model_generate(prompt: str, stopping_criteria = None, past_key_values = None, return_past_key_values = False) -> str:
    """Sample a continuation of ``prompt`` from the global ``model``.

    Args:
        prompt: Full text to continue.
        stopping_criteria: Optional ``StoppingCriteriaList`` forwarded to
            ``model.generate``.
        past_key_values: Optional cache from an earlier call; forwarded to
            ``model.generate`` only when truthy.
            NOTE(review): assumes the cache matches a prefix of the freshly
            re-tokenized ``prompt`` -- TODO confirm for multi-turn use.
        return_past_key_values: When true, return ``(text, past_key_values)``
            instead of just ``text``.

    Returns:
        The decoded sequence (prompt plus continuation, special tokens
        stripped), or a ``(text, past_key_values)`` tuple when
        ``return_past_key_values`` is set. If the prompt already exceeds
        ``max_tokens`` it is returned unchanged, in the same shape.
    """
    model_inputs = tokenizer(prompt, return_tensors='pt').to(model.device)
    if len(model_inputs['input_ids'][0]) > max_tokens:
        # Over budget: nothing can be generated. Keep the return shape that
        # the caller asked for so tuple-unpacking callers do not break.
        if return_past_key_values:
            return prompt, past_key_values
        return prompt

    generate_kwargs = dict(
        max_length=max_tokens,
        return_dict_in_generate=True,
        do_sample=True,
        # temperature is deliberately not forwarded (it was disabled here).
        top_p=top_p,
        num_return_sequences=1,
        stopping_criteria=stopping_criteria,
        pad_token_id=tokenizer.eos_token_id,
    )
    if past_key_values:
        generate_kwargs["past_key_values"] = past_key_values

    generation_output = model.generate(**model_inputs, **generate_kwargs)
    generated_text = tokenizer.decode(generation_output.sequences[0], skip_special_tokens=True)
    if return_past_key_values:
        return generated_text, generation_output.past_key_values
    return generated_text

In [None]:
def predict_code(prompt: str, problem: str, return_text_output=False, max_turns=10) -> Optional[float]:
    """Solve ``problem`` through alternating generation and code execution.

    Each turn generates until a python block closes or a boxed answer appears,
    runs the last python block found in the transcript and appends its result
    (or error) for the next turn. The loop ends on a boxed answer, after three
    failed executions, when the transcript reaches ``max_tokens`` tokens, or
    after ``max_turns`` turns.

    Returns:
        The extracted answer (``None`` if absent), or ``(answer, transcript)``
        when ``return_text_output`` is true.
    """
    transcript = prompt.format(problem,"{}")
    kv_cache = None
    failed_runs = 0
    boxed_answer = re.compile(r'\\boxed\{(\d+)\}')

    for _ in range(max_turns):
        encoded = tokenizer(transcript, return_tensors='pt').to(model.device)
        if encoded['input_ids'].shape[-1] >= max_tokens:
            break
        clean_memory()
        criteria = StoppingCriteriaList([StoppingCriteriaPythonCode(), StoppingCriteriaAnswer()])
        transcript, kv_cache = model_generate(transcript, criteria, kv_cache, True)

        # A boxed answer at the very end means the model is done.
        if boxed_answer.search(transcript[-20:]):
            break

        run_output, ran_ok = run_code(get_last_python_code(transcript))
        if ran_ok:
            transcript += f"\nCode Result: {run_output} \n"
            continue

        failed_runs += 1
        transcript += f"\nCode Error: {run_output} \nYour code has an error. Please review the problem and your code and try again.\n"
        if 'is not defined' in run_output:
            transcript += "\nYou need to define the variable or function that is not defined in your code.\n"
        if failed_runs >= 3:
            break

    answer = get_answer(transcript)
    if return_text_output:
        return answer, transcript
    return answer
    


In [None]:
def predict_simple(prompt: str, problem: str, return_text_output=False) -> Optional[float]:
    """Single-pass prediction: generate until a boxed answer and extract it.

    Returns:
        The extracted answer (``None`` if absent), or ``(answer, text)`` when
        ``return_text_output`` is true.
    """
    filled_prompt = prompt.format(problem,"{}")
    completion = model_generate(filled_prompt, StoppingCriteriaList([StoppingCriteriaAnswer()]))
    answer = get_answer(completion)
    if not return_text_output:
        return answer
    return answer, completion

In [None]:
def predict_dup(prompt: str, problem: str, return_text_output=False) -> Optional[float]:
    """Three-stage prediction: core question, useful facts, then predict_code.

    The model is first asked for the core question, then for the
    problem-solving information related to it. Both are appended to the
    problem text, and the enriched text is solved with ``predict_code``.
    """
    def _continuation(request: str) -> str:
        # The generated text echoes the request; keep only what follows it.
        generated = model_generate(request, StoppingCriteriaList([StoppingCriteriaAnswer()]))
        return generated[len(request):]

    core_question = _continuation(
        f'{problem}\nPlease extract the core question, only the most comprehensive and detailed one!'
    )
    problem_solving_info = _continuation(
        f'{problem}\nNote: Please extract the question-solving information related to the problem({core_question}), only extract the most useful information, and list them one by one!'
    )

    enriched_problem = f'{problem}\nHint: {problem_solving_info}\n{core_question}\n'
    clean_memory()
    return predict_code(prompt, enriched_problem, return_text_output)
    
    
    

In [None]:
# (predictor, prompt template) pairs; predict() draws one at random per round.
# Both predictors share the same `code` template.
predictor_set = [(predictor, code) for predictor in (predict_code, predict_dup)]

In [None]:
def predict(problem: str, self_consistency_counts=1, return_text_output = False) -> Optional[float]:
    """Answer ``problem`` by majority vote over several sampled predictions.

    Each round picks a random ``(predictor, prompt)`` pair from
    ``predictor_set`` and records its answer. Rounds are skipped while the
    leading answer's count exceeds ``sqrt(round_index)``, and sampling stops
    once an answer has been seen more than five times.

    Failed rounds (answer ``None``) are kept in the returned answer list but
    never take part in the vote, so a majority of failures cannot mask a
    valid answer.

    Args:
        problem: Problem statement.
        self_consistency_counts: Maximum number of rounds (must be > 0).
        return_text_output: Also return the per-round answers and texts.

    Returns:
        The most frequent non-``None`` answer, or ``None`` when every round
        failed. With ``return_text_output`` the result is
        ``(final_answer, predicted_answer, predicted_text_output)``.
    """
    assert self_consistency_counts > 0
    predicted_answer = []
    predicted_text_output = []

    best_answer = None
    best_answer_count = -1

    for i in range(self_consistency_counts):
        logging.debug("=============================================Running iteration: " + str(i))
        if best_answer_count > np.sqrt(i):
            logging.debug("Skipping iteration due to sufficient best count.")
            continue

        clean_memory()
        chosen_predictor, prompt = random.choice(predictor_set)
        answer, text_output = chosen_predictor(prompt, problem, return_text_output=True)
        logging.debug(text_output)
        predicted_answer.append(answer)
        predicted_text_output.append(text_output)

        # Rank only the valid answers; None marks a round with no answer.
        ranked = Counter(a for a in predicted_answer if a is not None).most_common(1)
        if ranked:
            top_answer, top_count = ranked[0]
            if top_count > best_answer_count:
                logging.debug("Found new best answer.")
                best_answer = top_answer
                best_answer_count = top_count
            if top_count > 5:
                logging.debug("Found sufficient occurrences of the best answer.")
                break

        logging.debug("Best so far: %s (count %s)", best_answer, best_answer_count)
        logging.debug(Counter(predicted_answer).most_common())

    # Final vote over valid answers; ties resolve to the answer seen first.
    ranked = Counter(a for a in predicted_answer if a is not None).most_common(1)
    final_answer = ranked[0][0] if ranked else None

    if return_text_output:
        return final_answer, predicted_answer, predicted_text_output
    return final_answer

In [None]:
from datasets import Dataset

# Load the validation CSV, drop the columns the evaluation does not read,
# and shuffle with a fixed seed.
val_dataset = (
    Dataset.from_csv('dataset/AIMO_val.csv')
    .remove_columns(['id', 'subfield', 'solution'])
    .shuffle(seed=42)
)

In [None]:
# Evaluate on the validation set, printing the running accuracy after each problem.
logging.basicConfig(level=logging.DEBUG)
scores = []
for row in tqdm(val_dataset):
    predicted_answer = predict(row['problem'], self_consistency_count)
    # NOTE(review): predict() yields a float or None; this assumes row['answer']
    # is numeric so that == can match -- TODO confirm against the CSV schema.
    scores.append(predicted_answer == row['answer'])
    print("Current Score:", sum(scores)/len(scores))
# NOTE(review): divides by len(scores); an empty dataset would raise ZeroDivisionError here.
print(f"Accuracy: {sum(scores)/len(scores)}")