In [31]:
import json

# Read the JSON file.
# The encoding is pinned to UTF-8 so that decoding does not fall back to the
# platform's locale default (which is not UTF-8 on every system).
with open('../llm/data/mini_dev_refs.json', 'r', encoding='utf-8') as f:
    data = json.load(f)

# Function to check if we should keep the row
def should_keep(item):
    """Tell whether an item is free of error and warning markers.

    Args:
        item: mapping whose 'sql_refs_annotated' entry is an iterable of
            annotated reference strings.

    Returns:
        True when no entry contains '<error' or '<warning' (also True when
        there are no entries at all); False as soon as one entry does.
    """
    flagged_markers = ('<error', '<warning')
    for annotated_ref in item['sql_refs_annotated']:
        for marker in flagged_markers:
            if marker in annotated_ref:
                return False
    return True

# Build one 'TRUE'/'FALSE' flag per entry of `data`, in the same order:
# 'TRUE' when should_keep() accepts the entry, 'FALSE' otherwise.
filter_list = list(map(lambda row: 'TRUE' if should_keep(row) else 'FALSE', data))

# Persist the flags, one per line.
with open('../llm/data/mini_dev_filter.txt', 'w') as f:
    f.writelines(f"{item}\n" for item in filter_list)

# Summarise what was written.
kept_total = filter_list.count('TRUE')
dropped_total = filter_list.count('FALSE')
print(f"Total items: {len(filter_list)}")
print(f"Items without error: {kept_total}")
print(f"Items with error: {dropped_total}")
print("Filter file '../llm/data/mini_dev_filter.txt' has been created successfully.")

Total items: 500
Items without error: 347
Items with error: 153
Filter file '../llm/data/mini_dev_filter.txt' has been created successfully.


In [55]:

import random
from collections import defaultdict

# Group the items accepted by should_keep() into a two-level mapping:
# db_id -> difficulty -> list of items (in their order of appearance in `data`).
pivoted_data = defaultdict(lambda: defaultdict(list))
for item in data:
    if not should_keep(item):
        continue
    items_by_difficulty = pivoted_data[item["db_id"]]
    items_by_difficulty[item["difficulty"]].append(item)


# Number of items to draw per database and difficulty level.
# Kept as a defaultdict(dict) so that indexing an unknown database still
# yields an empty per-difficulty dict.
sample_counts = defaultdict(dict, {
    "california_schools":      {"simple": 2, "moderate": 3, "challenging": 1},
    "card_games":              {"simple": 3, "moderate": 7, "challenging": 1},
    "codebase_community":      {"simple": 4, "moderate": 5, "challenging": 1},
    "debit_card_specializing": {"simple": 3, "moderate": 2, "challenging": 1},
    "european_football_2":     {"simple": 3, "moderate": 5, "challenging": 2},
    "financial":               {"simple": 1, "moderate": 4, "challenging": 2},
    "formula_1":               {"simple": 6, "moderate": 5, "challenging": 2},
    "student_club":            {"simple": 4, "moderate": 4, "challenging": 1},
    "superhero":               {"simple": 3, "moderate": 5, "challenging": 2},
    "thrombosis_prediction":   {"simple": 2, "moderate": 5, "challenging": 3},
    "toxicology":              {"simple": 1, "moderate": 3, "challenging": 4},
})

# Draw the configured number of items from every (database, difficulty)
# bucket of pivoted_data and collect them into a single list.
micro_dev_data = []

for db_id, difficulties in sample_counts.items():
    for difficulty, sample_count in difficulties.items():
        # Seed of 8 is used because it's fairly close to the score of TA-gpt-4o with the full non-error non-warning set
        # A new generator with that seed is constructed for each bucket.
        bucket_rng = random.Random(8)
        bucket_items = pivoted_data[db_id][difficulty]
        micro_dev_data += bucket_rng.sample(bucket_items, k=sample_count)

# Order the sample by question_id and collect the ids for membership checks.
micro_dev_data.sort(key=lambda entry: entry["question_id"])
micro_dev_question_ids = set(entry["question_id"] for entry in micro_dev_data)
len(micro_dev_question_ids)


100

In [61]:
# Generate the filter list: an entry of `data` is flagged 'TRUE' only when its
# question_id is in micro_dev_question_ids AND that id has not been seen on an
# earlier entry; every other entry is flagged 'FALSE'.
question_ids_already_added = set()
filter_list = []
for item in data:
    question_id = item["question_id"]
    first_occurrence = question_id not in question_ids_already_added
    selected = first_occurrence and question_id in micro_dev_question_ids
    filter_list.append('TRUE' if selected else 'FALSE')
    question_ids_already_added.add(question_id)

# Write the filter list to a file, one flag per line.
with open('../llm/data/mini_dev_filter.txt', 'w') as f:
    for item in filter_list:
        f.write(f"{item}\n")

# In this cell 'TRUE' marks membership in the micro dev sample (filter_list is
# built from micro_dev_question_ids), not the absence of errors, so the counts
# are labelled as selected / not selected.
print(f"Total items: {len(filter_list)}")
print(f"Items selected for micro dev: {filter_list.count('TRUE')}")
print(f"Items not selected: {filter_list.count('FALSE')}")
print("Filter file '../llm/data/mini_dev_filter.txt' has been created successfully.")

# sorted() accepts the set directly and returns a list, which json.dumps can serialise.
print("\nmicro_dev_question_ids: " + json.dumps(sorted(micro_dev_question_ids), indent=2))

Total items: 500
Items without error: 100
Items with error: 400
Filter file '../llm/data/mini_dev_filter.txt' has been created successfully.

micro_dev_question_ids: [
  36,
  40,
  41,
  47,
  48,
  50,
  98,
  99,
  115,
  116,
  128,
  136,
  159,
  195,
  218,
  220,
  226,
  230,
  236,
  242,
  263,
  345,
  346,
  379,
  397,
  409,
  412,
  414,
  415,
  440,
  465,
  466,
  537,
  539,
  555,
  557,
  565,
  573,
  578,
  581,
  586,
  665,
  732,
  736,
  739,
  745,
  750,
  758,
  761,
  764,
  769,
  773,
  850,
  859,
  868,
  869,
  872,
  877,
  892,
  896,
  902,
  906,
  909,
  910,
  954,
  1037,
  1048,
  1057,
  1076,
  1079,
  1080,
  1088,
  1092,
  1102,
  1103,
  1155,
  1157,
  1164,
  1179,
  1189,
  1192,
  1195,
  1201,
  1281,
  1302,
  1340,
  1346,
  1357,
  1359,
  1376,
  1378,
  1380,
  1381,
  1399,
  1473,
  1476,
  1480,
  1493,
  1507,
  1515
]
