# Load Datasets

In [1]:
from google.colab import drive
import json
drive.mount('/content/drive', force_remount=True)

Mounted at /content/drive


In [2]:
with open("/content/drive/My Drive/sharded_dataset.json") as f:
  data = json.load(f)

In [3]:
D2T = []
Code = []
GSM8K = []
DB = []
API = []
for task in data:
  if "code" in task['task']:
    Code.append(task)
  if "math" in task['task']:
    GSM8K.append(task)
  if "database" in task['task']:
    DB.append(task)
  if "actions" in task['task']:
    API.append(task)
  if "data2text" in task['task']:
    D2T.append(task)

In [None]:
import tiktoken
def count_tokens(text: str, model: str = "gpt-4"):
    encoding = tiktoken.encoding_for_model(model)
    tokens = encoding.encode(text)
    return len(tokens)

In [None]:
with open("/content/drive/MyDrive/Context_reset/MATH/MATH_GPT4o/GSM8K_run0_CORRECTED.json", 'r') as f:
  data = json.load(f)

In [None]:
tokens = []

for i, entry in enumerate(data):
  t = []
  tk = 0
  if entry['resets'] == 0:
    for i, msg in enumerate(entry['chat_history']):
      if msg['role'] == 'user':
        tk += count_tokens(msg['content'])
        t.append(tk)
      elif msg['role'] == 'assistant':
        tk += count_tokens(msg['content'])
      elif msg['role'] == 'system':
        tk += count_tokens(msg['content'])
    tokens.append(np.mean(t))

In [None]:
import ast
token_res = []
for i, entry in enumerate(data):
  t = []
  tk = 0
  if entry['resets'] == 1:
    parts = entry['chat_history'].split("AFTER RESET")
    list1 = ast.literal_eval(parts[0].strip())
    list2 = ast.literal_eval(parts[1].strip())
    for msg in list1:
      if msg['role'] == 'user':
        tk += count_tokens(msg['content'])
        t.append(tk)
      elif msg['role'] == 'assistant':
        tk += count_tokens(msg['content'])
      elif msg['role'] == 'system':
        tk += count_tokens(msg['content'])
    tk = 0
    for msg in list2:
      if msg['role'] == 'user':
        tk += count_tokens(msg['content'])
        t.append(tk)
      elif msg['role'] == 'assistant':
        tk += count_tokens(msg['content'])
      elif msg['role'] == 'system':
        tk += count_tokens(msg['content'])
    token_res.append(np.mean(t))

#Utility Functions

In [21]:
!wget https://raw.githubusercontent.com/microsoft/lost_in_conversation/main/tasks/actions/eval_bfcl.py

--2025-06-23 02:09:58--  https://raw.githubusercontent.com/microsoft/lost_in_conversation/main/tasks/actions/eval_bfcl.py
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.111.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 31177 (30K) [text/plain]
Saving to: ‘eval_bfcl.py.1’


2025-06-23 02:09:58 (2.87 MB/s) - ‘eval_bfcl.py.1’ saved [31177/31177]



In [22]:
!wget https://raw.githubusercontent.com/microsoft/lost_in_conversation/main/tasks/actions/task_actions.py

--2025-06-23 02:09:58--  https://raw.githubusercontent.com/microsoft/lost_in_conversation/main/tasks/actions/task_actions.py
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.110.133, 185.199.111.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 3605 (3.5K) [text/plain]
Saving to: ‘task_actions.py.1’


2025-06-23 02:09:58 (29.6 MB/s) - ‘task_actions.py.1’ saved [3605/3605]



In [23]:
pip install sacrebleu



In [24]:
import sacrebleu
def D2T_evaluator_function(extracted_answer, sample):
    # ToTTo has multiple references per example
    references = sample["references"]
    bleu = sacrebleu.corpus_bleu([extracted_answer.strip()], [[ref.strip()] for ref in references])
    return bleu.score / 100.0

In [25]:
from eval_bfcl import parallel_function_checker_enforce_order, parallel_function_checker_no_order, ast_checker, ast_parse

In [26]:
import re

def extract_function_block(text):
    start = text.find('[')
    if start == -1:
        return ''

    level = 0
    for i in range(start, len(text)):
        if text[i] == '[':
            level += 1
        elif text[i] == ']':
            level -= 1
            if level == 0:
                block = text[start:i+1]
                return clean_function_block(block)
    return ''

def clean_function_block(block):
    block = block.replace('\n', '').replace('\r', '').replace('\t', '')
    block = ' '.join(block.split())

    # Remove "..." wrapping function calls only
    block = re.sub(r'"\s*([a-zA-Z_][a-zA-Z0-9_\.]*\s*\([^"]*\))\s*"', r'\1', block)

    # Remove space after [ and before ]
    block = re.sub(r'\[\s+', '[', block)
    block = re.sub(r'\s+\]', ']', block)

    return block


def extract_all_function_blocks(text):
    blocks = []
    start_positions = [i for i, c in enumerate(text) if c == '[']

    for start in start_positions:
        level = 0
        found = False
        for i in range(start, len(text)):
            if text[i] == '[':
                level += 1
            elif text[i] == ']':
                level -= 1
                if level == 0:
                    block = text[start:i+1]
                    blocks.append(clean_function_block(block))
                    found = True
                    break
        if found:
            # Skip any nested [ inside this block — move to next outer [
            continue
    return blocks

def evaluator_function(predicted_answer, sample):
    """
    Evaluate if the predicted function call matches the expected format and functionality.
    """

    try:
        decoded_output = ast_parse(predicted_answer.strip(), sample["language"])
    except Exception as e:
        # print(f"\033[94mPredicted answer:{predicted_answer}\033[0m")
        return {"is_correct": False, "error": "Failing to parse the predicted answer as an AST"}

    result = ast_checker(
        sample["function"],
        decoded_output,
        sample["reference_answer"],
        sample["language"],
        sample["test_category"],
        "gpt-4o"
    )
    score = 1 if result["valid"] else 0
    return {"is_correct": result["valid"], "score": score, "error": result["error"]}



In [27]:
import re

def extract_sql_query(text):
    # Match content inside ```sql ... ``` block
    match = re.search(r'```sql(.*?)```', text, re.DOTALL | re.IGNORECASE)
    if match:
        # Clean leading/trailing whitespace
        return match.group(1).strip()
    else:
        return None


def extract_sql_queries(text):
    """
    Extract all SQL queries from a text blob.
    Captures both fenced code blocks (```sql ... ```) and standalone statements ending with semicolons.
    """
    queries = []

    # 1) Fenced SQL blocks
    fenced = re.findall(r'```sql\s*(.*?)```', text, flags=re.IGNORECASE | re.DOTALL)
    queries.extend(q.strip() for q in fenced if q.strip())

    # 2) Standalone statements (SELECT/INSERT/UPDATE/DELETE/CREATE/ALTER/DROP) ending with ;
    #    Avoid re-capturing fenced blocks by skipping matches that span lines containing ```
    stmt_pattern = re.compile(
        r'(?:\b(?:SELECT|INSERT|UPDATE|DELETE|CREATE|ALTER|DROP)\b.+?;)',
        flags=re.IGNORECASE | re.DOTALL
    )
    for m in stmt_pattern.finditer(text):
        snippet = m.group().strip()
        # skip if it's identical to a fenced block
        if not any(snippet == f.strip() or snippet in f for f in fenced):
            queries.append(snippet)
    # if not queries:
    #   queries.append(text)
    return queries

In [28]:
import re

def extract_all_function_blocks_and_names(code):
    """
    Extracts all top-level Python function blocks (def ...) along with any import
    statements that appear immediately before each function. Returns a list of tuples:
    [(full_code_with_imports, function_name), ...].
    """
    lines = code.strip().splitlines()
    n = len(lines)
    results = []
    import_lines = []
    i = 0

    while i < n:
        line = lines[i]

        # If this line is an import, collect it and move on
        if re.match(r'^\s*import\s+\w', line) or re.match(r'^\s*from\s+\w+\s+import\s+', line):
            import_lines.append(line)
            i += 1
            continue

        # If this line is a top‐level function definition
        func_match = re.match(r'^(\s*)def\s+([A-Za-z_]\w*)\s*\(.*\)\s*:', line)
        if func_match:
            func_indent = len(func_match.group(1))
            func_name = func_match.group(2)

            # Collect the entire function block
            func_block = [lines[i]]
            j = i + 1
            while j < n:
                next_line = lines[j]
                # Blank lines inside the block are allowed
                if next_line.strip() == "":
                    func_block.append(next_line)
                    j += 1
                    continue

                # Check indentation: if indent > func_indent, it's still inside
                indent_level = len(next_line) - len(next_line.lstrip())
                if indent_level > func_indent:
                    func_block.append(next_line)
                    j += 1
                else:
                    break

            # Combine the imports collected so far with this function block
            full_code = "\n".join(import_lines + [""] + func_block).rstrip()
            results.append((full_code, func_name))

            # Reset import_lines for the next function
            import_lines = []
            # Continue scanning from the line after this function block
            i = j
            continue

        # Neither an import nor a function definition: move on
        i += 1

    return results


In [29]:
import re

def extract_first_function_block_and_name(code):
    """
    Extracts the first top-level Python function (def ...) block and its name,
    along with any import statements above it.
    Returns the full function code (with imports) and function name.
    """
    lines = code.strip().splitlines()
    import_lines = []
    func_start_idx = None
    func_indent = None
    func_name = None

    # Collect top-level import statements before the function
    for i, line in enumerate(lines):
        if re.match(r'^\s*import\s+\w', line) or re.match(r'^\s*from\s+\w+\s+import\s+', line):
            import_lines.append(line)
        # elif re.match(r'^\s*def\s+[a-zA-Z_]\w*\s*\(.*\)\s*:', line):
        #     func_start_idx = i
        #     match = re.match(r'^(\s*)def\s+([a-zA-Z_][a-zA-Z0-9_]*)\s*\(', line)
        #     func_indent = len(match.group(1))
        #     func_name = match.group(2)
        #     break
        elif re.match(r'^\s*def\s+[a-zA-Z_]\w*\s*\(.*\)\s*(->\s*[^\s:]+)?\s*:', line):
            func_start_idx = i
            match = re.match(r'^(\s*)def\s+([a-zA-Z_][a-zA-Z0-9_]*)\s*\(.*\)\s*(->\s*[^\s:]+)?\s*:', line)
            func_indent = len(match.group(1))
            func_name = match.group(2)
            break

    if func_start_idx is None:
        return None, None

    # Collect lines in the function block
    func_lines = lines[func_start_idx:]
    collected = [func_lines[0]]

    for line in func_lines[1:]:
        if line.strip() == "":
            collected.append(line)
        elif len(line) - len(line.lstrip()) >= func_indent + 1:
            collected.append(line)
        else:
            break

    return "\n".join(import_lines + [""] + collected).rstrip(), func_name


In [30]:
import json
import numpy as np
import sys
import tempfile
import subprocess
import importlib.util
import ast

def run_function_and_check(func_name, user_code, test_cases):
    with tempfile.NamedTemporaryFile(suffix=".py", delete=False, mode="w") as f:
        user_code_path = f.name
        f.write("import math\n")
        f.write(user_code)

    # print(user_code)
    runner_code = f"""
import json
import sys
import ast
import math
import importlib.util

spec = importlib.util.spec_from_file_location("tempmod", "{user_code_path}")
tempmod = importlib.util.module_from_spec(spec)
spec.loader.exec_module(tempmod)

failures = []
def parse_argument(line):
    line = line.strip()
    if line.startswith('"') and line.endswith('"'):
        return line[1:-1]  # Strip outer double quotes, treat as string
    try:
        return ast.literal_eval(line)
    except:
        return line  # fallback

for idx, case in enumerate({test_cases}):
    input_lines = case["input"].splitlines()
    args = [parse_argument(line) for line in input_lines]

    try:
        expected = ast.literal_eval(case["output"])
    except:
        expected = case["output"]

    if expected == "true" or expected == "Yes" or expected == "yes":
        expected = True

    if expected == "false" or expected == "No" or expected == "no":
        expected = False

    try:
        got = getattr(tempmod, "{func_name}")(*args)
    except Exception as e:
        failures.append(f"{{args}} raised {{e!r}}")
        continue

    if got == "Yes" or got == "yes":
      got = True
    if got == "No" or got == "no":
      got = False

    if got and got != expected:
        args = args.reverse()
        try:
          got = getattr(tempmod, "{func_name}")(*args)
        except Exception as e:
          failures.append(f"{{args}} raised {{e!r}}")
          continue

        if got == "Yes" or got == "yes":
          got = True
        if got == "No" or got == "no":
          got = False

        if got != expected:
          failures.append(f"Test {{idx}} :- {{args}}: got={{got!r}}, expected={{expected!r}}")

if failures:
    print(json.dumps({{"ok": False, "errors": failures}}))
    sys.exit(1)
else:
    print(json.dumps({{"ok": True}}))
    sys.exit(0)
"""

    with tempfile.NamedTemporaryFile(suffix=".py", delete=False, mode="w") as f:
        runner_path = f.name
        f.write(runner_code)

    try:
      result = subprocess.run(
          [sys.executable, runner_path],
          capture_output=True,
          text=True,
          timeout=3
      )
    except:
      print("TIMEOUT ERROR")
      return False

    # print("STDOUT:", result.stdout.strip())
    # print("STDERR:", result.stderr.strip())

    if result.returncode == 0:
        return True
    else:
        return False

In [31]:
import re

def L_extract_sections(text):
    pattern = r"<\|start_header_id\|>assistant<\|end_header_id\|>(.*?)<\|eot_id\|>"
    return re.findall(pattern, text, flags=re.DOTALL)

def PHI_extract_sections(text):
    pattern = r"<\|im_start\|>assistant<\|im_sep\|>(.*?)<\|im_end\|>"
    return re.findall(pattern, text, flags=re.DOTALL)

# Coding Evaluation

In [152]:
import numpy as np
def evaluate_human_eval(data_path, dataset, numruns, numQ):
    """
    For each line in the HumanEval JSON, generate code, check it,
    add 'correct' key (True/False), and write to a new JSON file with _CORRECTED suffix.
    """

    finalcorr = []
    for i in range(numruns):
        correct = []
        res = []
        with open(data_path + f"_run{i}.json", "r") as f:
            output = json.load(f)
        for x in range(numQ):
            final_output = output[x]['final_output']
            function, func_name = extract_first_function_block_and_name(final_output)
            test_cases = ast.literal_eval(dataset[x]["public_test_cases"])

            if function and func_name:
                passed = run_function_and_check(func_name, function, test_cases)
                if passed:
                    correct.append(1)
                    output[x]['correct'] = True
            else:
              passed = False
            if not passed:
                for function in extract_all_function_blocks_and_names(output[x]['chat_history']):
                    passed = run_function_and_check(function[1], function[0], test_cases)
                    if passed:
                        correct.append(1)
                        output[x]['correct'] = True
                        break
                if not passed:
                    correct.append(0)
                    output[x]['correct'] = False


            res.append(output[x]["resets"])
        with open(data_path + f"_run{i}_CORRECTED.json", "w") as f_out:
            json.dump(output, f_out, indent=2)

        finalcorr.append(correct)

    sums = 0
    for i in range(len(finalcorr)):
      print(f"Run {i} Score: {np.sum(finalcorr[i])/numQ}")
      sums += np.sum(finalcorr[i]) / numQ

    print(f"Average {sums / len(finalcorr)}")

    return finalcorr


In [None]:
import numpy as np
def GPT_evaluate_human_eval(data_path, dataset, numruns, numQ):
    """
    For each line in the HumanEval JSON, generate code, check it,
    add 'correct' key (True/False), and write to a new JSON file with _CORRECTED suffix.
    """

    finalcorr = []
    for i in range(numruns):
        correct = []
        res = []
        with open(data_path + f"_run{i}.json", "r") as f:
            output = json.load(f)
        for x in range(numQ):
            final_output = output[x]['final_output']
            function, func_name = extract_first_function_block_and_name(final_output)
            test_cases = ast.literal_eval(dataset[x]["public_test_cases"])

            if function and func_name:
                passed = run_function_and_check(func_name, function, test_cases)
                if passed:
                    correct.append(1)
                    output[x]['correct'] = True
            else:
              passed = False

            if not passed:
                if output[x]["resets"] > 0:
                  for function in extract_all_function_blocks_and_names(output[x]['chat_history']):
                      passed = run_function_and_check(function[1], function[0], test_cases)
                      if passed:
                          correct.append(1)
                          output[x]['correct'] = True
                          break
                  if not passed:
                      correct.append(0)
                      output[x]['correct'] = False
                else:
                  for entry in output[x]['chat_history']:
                      if entry["role"] == "assistant":
                        function, func_name = extract_first_function_block_and_name(entry['content'])
                        if function and func_name:
                          passed = run_function_and_check(func_name, function, test_cases)
                          if passed:
                              correct.append(1)
                              output[x]['correct'] = True
                              break
                  if not passed:
                      correct.append(0)
                      output[x]['correct'] = False

            res.append(output[x]["resets"])
        with open(data_path + f"_run{i}_CORRECTED.json", "w") as f_out:
            json.dump(output, f_out, indent=2)

        finalcorr.append(correct)

    sums = 0
    for i in range(len(finalcorr)):
      print(f"Run {i} Score: {np.sum(finalcorr[i])/numQ}")
      sums += np.sum(finalcorr[i]) / numQ

    print(f"Average {sums / len(finalcorr)}")

    return finalcorr


# GSM8K Evaluation

In [None]:
import json
import re

def GPT_Eval_GSM8K(numruns, data_path, numQ=50, dataset=GSM8K):
    correct = []
    for x in range(numruns):
        corr = []
        with open(data_path + f"_run{x}.json", 'r') as f:
            NRoutput = json.load(f)

        for i in range(numQ):
            answer = re.findall(r'####\s*(.*)', GSM8K[i]['answer'])
            if NRoutput[i]['final_output']:
                model_answer = re.findall(r'<Answer>\s*(.*?)(?:</Answer>|$)', NRoutput[i]['final_output'], re.DOTALL)
            else:
                model_answer = []

            is_correct = False
            if answer and model_answer:
                model_answer[0] = model_answer[0].replace(",", "")
                answer[0] = answer[0].replace(",", "")
                if answer[0] in model_answer[0]:
                    is_correct = True
            if is_correct == False:
              model_answer = []
              if NRoutput[i]['resets'] == 0:
                for msg in NRoutput[i]['chat_history']:
                  if msg['role'] == 'assistant':
                    matches = re.findall(r'<Answer>(.*?)(?:<\/Answer>|$)', msg['content'], flags=re.DOTALL)
                    model_answer.extend([m.strip() for m in matches])
                    for ans in model_answer:
                      ans = ans.replace(",", "")
                      if answer[0] in ans:
                        print("Now correct: ")


            corr.append(1 if is_correct else 0)
            NRoutput[i]['correct'] = is_correct

        with open(data_path + f"_run{x}_CORRECTED.json", "w") as f_out:
            json.dump(NRoutput, f_out, indent=2)

        correct.append(corr)

    sums = 0
    for i in range(len(correct)):
      print(f"Run {i} Score: {np.sum(correct[i])/numQ}")
      sums += np.sum(correct[i]) / numQ

    print(f"Average {sums / len(correct)}")

In [None]:
import json
import re

def Eval_GSM8K(numruns, data_path, numQ=50, dataset=GSM8K):
    correct = []
    for x in range(numruns):
        corr = []
        with open(data_path + f"_run{x}.json", 'r') as f:
            NRoutput = json.load(f)

        for i in range(numQ):
            answer = re.findall(r'####\s*(.*)', GSM8K[i]['answer'])
            if NRoutput[i]['final_output']:
                model_answer = re.findall(r'<Answer>\s*(.*?)(?:</Answer>|$)', NRoutput[i]['final_output'], re.DOTALL)
            else:
                model_answer = []

            is_correct = False
            if answer and model_answer:
                model_answer[0] = model_answer[0].replace(",", "")
                answer[0] = answer[0].replace(",", "")
                if answer[0] in model_answer[0]:
                    is_correct = True
            if is_correct == False:
              model_answer = []
              assistant_blocks = re.findall(r'<\|start_header_id\|>assistant<\|end_header_id\|>(.*?)(?=<\|eot_id\|>)', NRoutput[i]['chat_history'], flags=re.DOTALL)
              if not assistant_blocks:
                assistant_blocks = re.findall(r'<\|im_start\|>assistant<\|im_sep\|>(.*?)(?=<\|im_end\|>)', NRoutput[i]['chat_history'], flags=re.DOTALL)
              for block in assistant_blocks:
                matches = re.findall(r'<Answer>(.*?)(?:<\/Answer>|$)', block, flags=re.DOTALL)
                model_answer.extend([m.strip() for m in matches])

              # model_answer = re.findall(r'<\|start_header_id\|>assistant<\|end_header_id\|>.*?(?:<Answer>(.*?)<\/Answer>|<Answer>(.*?)$)', NRoutput[i]['chat_history'], re.DOTALL)
              for ans in model_answer:
                ans = ans.replace(",", "")
                if answer[0] in ans:
                    # print(f"Now deeming correct: {ans}\n\nACTUAL:{answer[0]}\n\n")
                    is_correct = True

            corr.append(1 if is_correct else 0)
            NRoutput[i]['correct'] = is_correct

        # Write corrected output
        with open(data_path + f"_run{x}_CORRECTED.json", "w") as f_out:
            json.dump(NRoutput, f_out, indent=2)

        correct.append(corr)

    sums = 0
    for i in range(len(correct)):
      print(f"Run {i} Score: {np.sum(correct[i])/numQ}")
      sums += np.sum(correct[i]) / numQ

    print(f"Average {sums / len(correct)}")


# API Evaluation

In [40]:
import json

def GPT_API_eval(numruns, numQ, dataset, data_path):

    avg = []
    for i in range(numruns):
        correct_count = 0
        with open(data_path + f"_run{i}.json") as f:
            output = json.load(f)

        for x in range(numQ):
            raw_model_output = output[x]['final_output']
            model_answer = re.findall(r'<Answer>\s*(.*?)(?:</Answer>|$)', raw_model_output, re.DOTALL)
            if model_answer:
              raw_model_output = model_answer[0]
            mod_ans = extract_function_block(raw_model_output)
            mod_ans = clean_function_block(mod_ans)
            corr = evaluator_function(predicted_answer=mod_ans, sample=dataset[x])

            if "Failing to parse the predicted answer as an AST" in corr["error"]:
              mod_ans = extract_function_block(f"[{raw_model_output}]")
              mod_ans = clean_function_block(mod_ans)
              corr = evaluator_function(predicted_answer=mod_ans, sample=dataset[x])

            is_correct = False
            if corr["is_correct"]:
                correct_count += 1
                is_correct = True
            else:
                if output[x]['resets'] == 0:
                  for entry in output[x]['chat_history']:
                    if entry["role"] == "assistant":
                      for block in extract_all_function_blocks(entry['content']):
                          corr = evaluator_function(predicted_answer=block, sample=dataset[x])
                          if corr["is_correct"]:
                              correct_count += 1
                              is_correct = True
                              break
                      if is_correct:
                        break
                else:
                  for block in extract_all_function_blocks(output[x]['chat_history']):
                    corr = evaluator_function(predicted_answer=block, sample=dataset[x])
                    if corr["is_correct"]:
                        correct_count += 1
                        is_correct = True
                        break

            output[x]['correct'] = is_correct


        with open(data_path + f"_run{i}_CORRECTED.json", "w") as f_out:
            json.dump(output, f_out, indent=2)

        avg.append(correct_count / numQ)
        print(f"Run {i}: {correct_count/numQ}")

    print(f"Average Accuracy: {np.mean(avg)}")

In [99]:
import json

def API_eval(numruns, numQ, dataset, data_path):

    avg = []
    for i in range(numruns):
        correct_count = 0
        with open(data_path + f"_run{i}.json") as f:
            output = json.load(f)

        for x in range(numQ):
            raw_model_output = output[x]['final_output']
            model_answer = re.findall(r'<Answer>\s*(.*?)(?:</Answer>|$)', raw_model_output, re.DOTALL)
            if model_answer:
              raw_model_output = model_answer[0]
            mod_ans = extract_function_block(raw_model_output)
            mod_ans = clean_function_block(mod_ans)
            corr = evaluator_function(predicted_answer=mod_ans, sample=dataset[x])

            if "Failing to parse the predicted answer as an AST" in corr["error"]:
              mod_ans = extract_function_block(f"[{raw_model_output}]")
              mod_ans = clean_function_block(mod_ans)
              corr = evaluator_function(predicted_answer=mod_ans, sample=dataset[x])

            is_correct = False
            if corr["is_correct"]:
                correct_count += 1
                is_correct = True
            else:
                for block in extract_all_function_blocks(output[x]['chat_history']):
                    corr = evaluator_function(predicted_answer=block, sample=dataset[x])
                    if corr["is_correct"]:
                        correct_count += 1
                        is_correct = True
                        break

            output[x]['correct'] = is_correct


        with open(data_path + f"_run{i}_CORRECTED.json", "w") as f_out:
            json.dump(output, f_out, indent=2)

        avg.append(correct_count / numQ)
        print(f"Run {i}: {correct_count/numQ}")

    print(f"Average Accuracy: {np.mean(avg)}")


# Database Evaluation

In [None]:
import os
import json
import sqlite3
import numpy as np
from glob import glob
from tqdm import tqdm


def DB_eval(dataset, numruns, numQ, data_path):
  DB_FOLDER    = '/content/drive/My Drive/Context_reset/DATABASE/DATABASE_PHI-4/spider/database/'                  # folder containing *.sqlite
  avg = []
  def run_query(db_path, sql):
      conn = sqlite3.connect(db_path)
      conn.text_factory = lambda b: b.decode('utf-8', errors='replace')
      cur  = conn.cursor()
      try:
          cur.execute(sql)
          rows = cur.fetchall()
      finally:
          conn.close()
      return rows

  for i in range(numruns):
    with open(data_path + f"_run{i}.json", 'r') as f:
      entries = json.load(f)
    correct = 0
    res = []
    for e in range(numQ):
        mismatches = False
        db_id     = dataset[e]['db_id']
        ref_sql   = dataset[e]['reference_sql']

        for llm_sql in extract_sql_queries(entries[e]['final_output']):

          db_path = os.path.join(DB_FOLDER, f"{db_id}/{db_id}.sqlite")

          ref_res = run_query(db_path, ref_sql)
          try:
              llm_res = run_query(db_path, llm_sql)
          except Exception as ex:
              mismatches = True
              continue

          if ref_res != llm_res:
              mismatches = True
          else:
              mismatches = False
              break

        if mismatches == True:
          mismatches = False
          for llm_sql in extract_sql_queries(entries[e]['chat_history']):
            try:
              llm_res = run_query(db_path, llm_sql)
            except Exception as ex:
              mismatches = True
              continue

            if ref_res != llm_res:
              mismatches = True

            if mismatches == False:
              break

        if mismatches == False:
          correct += 1
          entries[e]['correct'] = True
        else:
          entries[e]['correct'] = False

        res.append(entries[e]['resets'])

    avg.append(correct / numQ)
    print(f"Accuracy = {correct / numQ}")

    with open(data_path + f"_run{i}_CORRECTED.json", "w") as f_out:
      json.dump(entries, f_out, indent=2)

  print(f"Average = {np.mean(avg)}")

In [None]:
import os
import json
import sqlite3
import numpy as np
from glob import glob
from tqdm import tqdm


def GPT_DB_eval(dataset, numruns, numQ, data_path):
  DB_FOLDER    = '/content/drive/My Drive/Context_reset/DATABASE/DATABASE_PHI-4/spider/database/'
  avg = []
  def run_query(db_path, sql):
      conn = sqlite3.connect(db_path)
      # decode any bytes with replacement on errors
      conn.text_factory = lambda b: b.decode('utf-8', errors='replace')
      cur  = conn.cursor()
      try:
          cur.execute(sql)
          rows = cur.fetchall()
      finally:
          conn.close()
      return rows

  for i in range(numruns):
    with open(data_path + f"_run{i}.json", 'r') as f:
      entries = json.load(f)
    correct = 0
    res = []
    for e in range(numQ):
        mismatches = False
        db_id     = dataset[e]['db_id']
        ref_sql   = dataset[e]['reference_sql']

        for llm_sql in extract_sql_queries(entries[e]['final_output']):

          db_path = os.path.join(DB_FOLDER, f"{db_id}/{db_id}.sqlite")

          ref_res = run_query(db_path, ref_sql)
          try:
              llm_res = run_query(db_path, llm_sql)
          except Exception as ex:
              mismatches = True
              continue

          if ref_res != llm_res:
              mismatches = True
          else:
              mismatches = False
              break

        if mismatches == True:
          mismatches = False

          if entries[e]["resets"] == 0:
            for entry in entries[e]['chat_history']:
              if entry["role"] == "assistant":
                for llm_sql in extract_sql_queries(entry['content']):
                  try:
                    llm_res = run_query(db_path, llm_sql)
                  except Exception as ex:
                    mismatches = True
                    continue

                  if ref_res != llm_res:
                    mismatches = True

                  if mismatches == False:
                    break
              if mismatches == False:
                break
          else:
            mismatches = False
            for llm_sql in extract_sql_queries(entries[e]['chat_history']):
              try:
                llm_res = run_query(db_path, llm_sql)
              except Exception as ex:
                mismatches = True
                continue

              if ref_res != llm_res:
                print(ref_res, "\n", llm_res, "\n\n")
                mismatches = True

              if mismatches == False:
                break

        if mismatches == False:
          correct += 1
          entries[e]['correct'] = True
        else:
          entries[e]['correct'] = False

        res.append(entries[e]['resets'])

    avg.append(correct / numQ)
    print(f"Accuracy = {correct / numQ}")

    with open(data_path + f"_run{i}_CORRECTED.json", "w") as f_out:
      json.dump(entries, f_out, indent=2)

  print(f"Average = {np.mean(avg)}")

# D2T Evaluation

In [86]:
def D2T_Eval(numruns, numQ, data_path, dataset):
  scores = []
  for x in range(numruns):
      score = []
      with open(data_path + f"_run{x}.json", 'r') as f:
          outputs = json.load(f)

      for i in range(numQ):
        curr_score = D2T_evaluator_function(outputs[i]['final_output'], dataset[i])

        score.append(curr_score)

        outputs[i]["score"] = curr_score
      with open(data_path + f"_run{x}_CORRECTED.json", "w") as f_out:
          json.dump(outputs, f_out, indent=2)

      scores.append(score)

  for score in scores:
    print(f"Average: {round(np.mean(score),2)}")

  return scores

# Calculating Aptitude and Unreliability

In [74]:
def get_scores(modelname, numruns, numQ):
    datasets = ['CODING', 'GSM8K', 'API', 'DB', 'D2T']
    correct = {name: [] for name in datasets}
    resetNum = 0
    
    file_configs = {
        'LLAMA': {
            'CODING': f"CODING/CODING_{modelname.upper()}/CODING_0.03_run",
            'GSM8K': f"MATH/MATH_{modelname.upper()}/GSM8K_0.03_run",
            'API': f"APIs/APIs_{modelname.upper()}/API_0.03_run",
            'DB': f"DATABASE/DATABASE_{modelname.upper()}/DB_0.03_run",
            'D2T': f"D2T/D2T_{modelname.upper()}/D2T_0.03_run"
        }
    }
    
    default_config = {
        'CODING': f"CODING/CODING_{modelname}/Code_run",
        'GSM8K': f"MATH/MATH_{modelname}/GSM8K_run",
        'API': f"APIs/APIs_{modelname}/API_run",
        'DB': f"DATABASE/DATABASE_{modelname}/DB_run",
        'D2T': f"D2T/D2T_{modelname}/D2T_run"
    }
    
    config = file_configs.get(modelname, default_config)
    
    for run in range(numruns):
        for name in datasets:
            file_path = f"/content/drive/My Drive/Context_reset/{config[name]}{run}_CORRECTED.json"
            
            with open(file_path) as f:
                data = json.load(f)[:50 if name == 'D2T' else numQ]
            
            temp_arr = []
            for entry in data:
                resetNum += entry['resets']
                score = entry['score'] if name == 'D2T' else (1 if entry['correct'] else 0)
                temp_arr.append(score if score else 0)
            
            correct[name].append(np.array(temp_arr).T)
    
    return correct, resetNum

## Aptitude & Reliability

In [98]:
def calculate_aptitude_and_unreliability(runs):
    """
    Calculate aptitude and unreliability for each dataset, then average across datasets.
    Pass 'correct' array returned by 'get_scores' function
    """
    
    dataset_aptitudes = []
    dataset_unreliabilities = []
    detailed_results = {}
    
    for dataset_name, dataset_runs in runs.items():
        runs_array = np.array(dataset_runs)
        runs_transposed = runs_array.T
        
        question_aptitudes = np.percentile(runs_transposed, 90, axis=1)
        
        question_unreliabilities = (np.percentile(runs_transposed, 90, axis=1) - 
                                  np.percentile(runs_transposed, 10, axis=1))
        
        dataset_aptitude = np.mean(question_aptitudes)
        dataset_unreliability = np.mean(question_unreliabilities)
        
        dataset_aptitudes.append(dataset_aptitude)
        dataset_unreliabilities.append(dataset_unreliability)
    
    overall_aptitude = np.mean(dataset_aptitudes)
    overall_unreliability = np.mean(dataset_unreliabilities)
    
    print("OVERALL RESULTS:")
    print(f"Average Aptitude: {overall_aptitude:.4f}")
    print(f"Average Unreliability: {overall_unreliability:.4f}")
    
    return overall_aptitude, overall_unreliability, detailed_results