In [1]:
import argparse
import glob
import io
import multiprocessing as mp
import os
import copy
import sqlite3
import sys
from contextlib import redirect_stderr, suppress

import pandas as pd
import tqdm.autonotebook as tqdm
from func_timeout import FunctionTimedOut, func_timeout
from tabulate import tabulate

from utils import read_jsonl_file, read_tsv_file

  import tqdm.autonotebook as tqdm


In [2]:
def execute_sql(datum):
    """Run the predicted and reference SQL for one example and compare results.

    Args:
        datum: dict with ``"db_path"`` (path passed to ``sqlite3.connect``),
            ``"predicted_sql"`` (candidate query) and ``"SQL"`` (reference query).

    Returns:
        ``"pass: incorrect-empty"`` if the predicted query returns no rows,
        ``"pass: correct"`` if both queries return the same *set* of rows,
        otherwise ``"pass: incorrect"``.

    Raises:
        sqlite3.Error: if either query fails to execute.
    """
    # NOTE(review): the predicted SQL is model output and runs against the real
    # database file; consider opening it read-only (URI ``mode=ro``) so a
    # generated DDL/DML statement cannot alter the benchmark DB.
    conn = sqlite3.connect(datum["db_path"])
    try:
        cursor = conn.cursor()
        cursor.execute(datum["predicted_sql"])
        predicted_res = set(cursor.fetchall())
        cursor.execute(datum["SQL"])
        ground_truth_res = set(cursor.fetchall())
    finally:
        # Always release the handle, including when a query raises; otherwise
        # one connection leaks per evaluated example.
        conn.close()

    # Rows are compared as sets, so row order and duplicate rows are ignored.
    # An empty prediction gets its own label even if the reference is empty too.
    if not predicted_res:
        return "pass: incorrect-empty"
    if predicted_res == ground_truth_res:
        return "pass: correct"
    return "pass: incorrect"


def execute_sql_robust(args, datum, i):
    """Evaluate one example under a timeout, turning failures into labels.

    Args:
        args: object providing ``meta_time_out``, the timeout passed to
            ``func_timeout``.
        datum: example dict forwarded unchanged to ``execute_sql``.
        i: index of the example. It is echoed back so results that arrive out
            of order (e.g. from a process pool) can be matched to their input.

    Returns:
        ``(i, evaluation)``, where ``evaluation`` is a label from
        ``execute_sql``, ``"error: timeout"``, or ``"error: <error>"`` for any
        other exception.
    """
    # Default label, so ``evaluation`` is always bound even if ``suppress``
    # swallows an error before it is assigned.
    evaluation = "error: <error>"
    stderr = io.StringIO()
    with suppress(RuntimeError, ReferenceError), redirect_stderr(stderr):
        try:
            evaluation = func_timeout(
                args.meta_time_out, execute_sql, args=(datum,)
            )
        except KeyboardInterrupt:
            sys.exit(0)
        except FunctionTimedOut:
            evaluation = "error: timeout"
        except Exception:
            # Previously this opened an interactive ``code.interact`` REPL.
            # Inside a multiprocessing worker that blocks or dies on EOF and
            # stalls the whole pool, so the failure is now just recorded.
            evaluation = "error: <error>"
    return (i, evaluation)


def _accuracy_table(data_eval, column):
    """Group ``if_correct`` by ``column`` into correct/total/accuracy, sorted descending by ``column``."""
    results = (
        data_eval.groupby([column])
        .agg(
            correct=("if_correct", "sum"),
            total=("if_correct", "count"),
        )
        .reset_index()
    )
    results["accuracy"] = results["correct"] / results["total"]
    return results.sort_values(column, ascending=False)


def evaluate(args, data):
    """Execute every not-yet-evaluated example and print accuracy reports.

    Args:
        args: object providing ``num_cpus`` (worker processes; ``<= 1`` runs
            serially in this process) and ``meta_time_out`` (used by
            ``execute_sql_robust``).
        data: list of example dicts. Each needs ``"difficulty"`` and
            ``"db_id"`` plus the keys ``execute_sql`` reads. Examples that
            already carry a non-empty ``"evaluation"`` are skipped.

    Returns:
        ``data`` itself, with ``"evaluation"`` filled in for every example.
    """
    exec_results = []
    todo = [(i, datum) for i, datum in enumerate(data) if not datum.get("evaluation")]

    # Context manager closes the progress bar even if evaluation is interrupted.
    with tqdm.tqdm(total=len(data), desc="Processing", unit="item") as pbar:
        # Already-evaluated examples count as done immediately.
        pbar.update(len(data) - len(todo))

        def result_callback(result):
            exec_results.append(result)
            pbar.update(1)

        if args.num_cpus > 1:
            # Exiting the ``with`` terminates the workers, so an interrupt
            # (e.g. KeyboardInterrupt during ``join``) does not orphan them.
            with mp.Pool(processes=args.num_cpus) as pool:
                for i, datum in todo:
                    pool.apply_async(
                        execute_sql_robust,
                        args=(args, datum, i),
                        callback=result_callback,
                        # Without an error_callback apply_async drops worker
                        # exceptions silently. The example would then never get
                        # an "evaluation" and the report below would KeyError.
                        error_callback=lambda exc, i=i: result_callback((i, "error: <error>")),
                    )
                pool.close()
                pool.join()
        else:
            for i, datum in todo:
                result_callback(execute_sql_robust(args, datum, i))

    for i, evaluation in exec_results:
        data[i]["evaluation"] = evaluation

    data_eval = pd.DataFrame.from_records(
        [
            {
                "difficulty": datum["difficulty"],
                "evaluation": datum["evaluation"],
                "db_id": datum["db_id"],
            }
            for datum in data
        ],
        # Explicit columns keep the frame well-formed when ``data`` is empty.
        columns=["difficulty", "evaluation", "db_id"],
    )
    data_eval["if_correct"] = data_eval["evaluation"] == "pass: correct"
    print("Overall Accuracy: ", data_eval["if_correct"].mean())
    for column in ("db_id", "difficulty"):
        print(tabulate(_accuracy_table(data_eval, column), headers="keys", tablefmt="fancy_grid"))
    return data


def _weighted(weight, value):
    """Return ``weight * value``, with a zero weight contributing exactly 0.0.

    This avoids ``0 * -inf == nan``. With ``logprob_alpha == 1``, that nan made
    every candidate with a missing reward unselectable.
    """
    return 0.0 if weight == 0 else weight * value


def select(data, logprob_alpha=0.5):
    """Choose one SQL candidate per example, in place, by a weighted score.

    Each candidate ``sql`` in ``datum["responses"]`` is scored as
    ``logprob_alpha * max(all_logprobs) + (1 - logprob_alpha) * reward``.
    A missing or ``None`` reward, or an empty ``all_logprobs`` list, counts as
    ``-inf``. The highest-scoring candidate (the first one on ties) is stored
    in ``datum["predicted_sql"]``.

    Args:
        data: list of example dicts whose ``"responses"`` maps SQL text to a
            dict with ``"all_logprobs"`` (list of numbers) and ``"reward"``
            (number or None).
        logprob_alpha: weight on the log-probability term; the reward term
            gets ``1 - logprob_alpha``.

    Returns:
        None. ``data`` is modified in place. Examples with no candidates get
        the placeholder ``"dummy"``.
    """
    reward_weight = 1.0 - logprob_alpha
    for datum in data:
        best_sql, best_score = "dummy", None
        for sql, response in datum["responses"].items():
            logprob = max(response["all_logprobs"], default=float("-inf"))
            reward = response.get("reward")
            if reward is None:
                reward = float("-inf")
            score = _weighted(logprob_alpha, logprob) + _weighted(reward_weight, reward)
            # Strict ">" keeps the first candidate on ties. Starting from None
            # guarantees a real candidate wins even when every score is -inf;
            # otherwise the sentinel "dummy" would later be executed as SQL.
            if best_score is None or score > best_score:
                best_sql, best_score = sql, score
        datum["predicted_sql"] = best_sql

In [3]:
# Reward-annotated outputs of three sampling runs (the default seed first),
# loaded in the order in which they are merged below.
datas = [
    read_jsonl_file(path)
    for path in (
        "../output/with_rewards/data_reward_all.jsonl",
        "../output_seed314159/with_rewards/data_reward_all.jsonl",
        "../output_seed8675309/with_rewards/data_reward_all.jsonl",
    )
]

In [37]:
# Scratch inspection of one SQL candidate across two seed runs (executed out of order
# as In [37]). Candidates are stored under datum["responses"], keyed by SQL text.
# Indexing the datum directly by the SQL string is presumably what raised the recorded
# KeyError; the lookups below go through "responses". Remove this cell before sharing.
# datas[0][204]["responses"]["SELECT COUNT(molecule_id) FROM molecule WHERE molecule_id LIKE 'TR0__' AND label = '+'"]
# datas[1][204]["responses"]["SELECT COUNT(molecule_id) FROM molecule WHERE molecule_id LIKE 'TR0__' AND label = '+'"]

KeyError: "SELECT COUNT(molecule_id) FROM molecule WHERE molecule_id LIKE 'TR0__' AND label = '+'"

In [4]:
# Merge the seed runs into one candidate pool per question. Runs must list the
# same questions in the same order. A SQL string produced by several runs keeps
# the first run's entry and accumulates the later runs' "sources" and
# "all_logprobs". Inputs in `datas` are deep-copied and never mutated.
for j in range(1, len(datas)):
    # Without this check a longer later run would be silently truncated and a
    # shorter one would fail with an opaque IndexError.
    assert len(datas[j]) == len(datas[0]), (
        f"run {j} has {len(datas[j])} examples, expected {len(datas[0])}"
    )

data = []
for i in tqdm.trange(len(datas[0])):
    datum = copy.deepcopy(datas[0][i])
    for j in range(1, len(datas)):
        other = datas[j][i]
        assert other['question_id'] == datum['question_id'], (
            f"example {i}: run {j} has question_id {other['question_id']!r}, "
            f"expected {datum['question_id']!r}"
        )
        for response, response_data in other['responses'].items():
            if response in datum['responses']:
                # NOTE(review): only sources/logprobs are pooled; "reward" and any
                # other field stay from the first run that produced this SQL.
                # Presumably the reward depends only on the SQL text; confirm.
                merged = datum['responses'][response]
                merged['sources'].extend(copy.deepcopy(response_data['sources']))
                merged['all_logprobs'].extend(copy.deepcopy(response_data['all_logprobs']))
            else:
                datum['responses'][response] = copy.deepcopy(response_data)
    data.append(datum)


  0%|          | 0/1534 [00:00<?, ?it/s]

100%|██████████| 1534/1534 [00:10<00:00, 141.51it/s]


In [5]:
# Weight on the log-probability term in select(); the reward term gets 1 - LOGPROB_ALPHA.
LOGPROB_ALPHA = 0.4
select(data, logprob_alpha=LOGPROB_ALPHA)

In [6]:
# Create a dummy args object
class Args:
    """Stand-in for the argument namespace that the evaluation helpers expect.

    ``evaluate`` reads ``num_cpus`` and ``execute_sql_robust`` reads
    ``meta_time_out``. The path fields are not read in this notebook.
    Every field can be overridden by keyword; the defaults keep ``Args()``
    usable as before.
    """

    def __init__(
        self,
        rewards_dir="",
        gt_sql_file="",
        output_dir="",
        num_cpus=None,
        meta_time_out=30.0,
    ):
        self.rewards_dir = rewards_dir
        self.gt_sql_file = gt_sql_file
        self.output_dir = output_dir
        # Previously a fixed 40 workers. Cap at the host's core count so the pool
        # does not oversubscribe smaller machines.
        self.num_cpus = num_cpus if num_cpus is not None else min(40, os.cpu_count() or 1)
        # Seconds allowed per example before it is labelled "error: timeout".
        self.meta_time_out = meta_time_out


args = Args()
# evaluate() returns the (mutated) input list; assign it rather than leaving it
# as the cell's last expression, which would render all examples in the output.
data = evaluate(args, data)

Processing:   0%|          | 0/1534 [00:00<?, ?item/s]

KeyboardInterrupt: 