In [56]:
!pip install mlflow





In [57]:
import pickle
import pickle5 as p
import pandas as pd
from matplotlib import *
from matplotlib import pyplot as plt
import numpy as np
import mlflow
import os

In [58]:
def loadDictionaryFromPickleFile(dictionaryPath):
    ''' Load the object stored in a pickle file
    Args:
        dictionaryPath: path to the pickle file
    Return: the unpickled object. Usually a dictionary, but the file may hold
        any picklable object (calcMetrics also accepts a list of records).
    '''
    # NOTE: unpickling can execute arbitrary code; only load files you trust.
    # The context manager closes the file even if unpickling raises.
    with open(dictionaryPath, 'rb') as filePointer:
        return p.load(filePointer)

def loadDictionaryFromPickleFileList(dictionaryPath):
    ''' Load a pickled list of result records and index it by query table
    Args:
        dictionaryPath: path to a pickle file holding an iterable of dicts,
            each with the keys 'query_table' and 'result_set'
    Return: dictionary mapping every 'query_table' to its 'result_set'
        (if a query table occurs more than once, the last record wins)
    '''
    # NOTE: unpickling can execute arbitrary code; only load files you trust.
    # The context manager closes the file even if unpickling raises.
    with open(dictionaryPath, 'rb') as filePointer:
        records = p.load(filePointer)
    return {record['query_table']: record['result_set'] for record in records}

def saveDictionaryAsPickleFile(dictionary, dictionaryPath):
    ''' Save dictionary as a pickle file
    Args:
        dictionary: object to be saved
        dictionaryPath: filepath to which the dictionary will be saved
            (an existing file is overwritten)
    '''
    # The context manager flushes and closes the file even if pickling raises.
    with open(dictionaryPath, 'wb') as filePointer:
        pickle.dump(dictionary, filePointer, protocol=pickle.HIGHEST_PROTOCOL)


def calcMetrics(max_k, k_range, gtPath=None, resPath=None, record=True):
    ''' Calculate and log the performance metrics: MAP, Precision@k, Recall@k

    Precision@k is micro-averaged over the query tables whose groundtruth has
    at least k entries. Recall@k is the sum of the per-table recalls divided by
    the total number of tables in the result file. The value reported as MAP is
    the mean of Precision@k for k = 1..max_k.

    Side effects: prints the metrics, writes the per-table/per-k comparison to
    "curr_run_results.csv" in the working directory and, if record is True,
    logs three metrics to MLFlow.

    Args:
        max_k: the maximum K value (e.g. for SANTOS benchmark, max_k = 10. For TUS benchmark, max_k = 60)
        k_range: step between the K values that are printed (k_range, 2*k_range, ... up to max_k)
        gtPath: file path to the groundtruth pickle (query table -> unionable tables)
        resPath: file path to the raw results from the model; either a dict
            (query table -> ranked results) or a list of records with the keys
            'query_table' and 'result_set'
        record (boolean): to log in MLFlow or not
    Return: MAP, P@max_k, R@max_k
    Raises:
        ValueError: if max_k or k_range is smaller than 1
    '''
    if max_k < 1 or k_range < 1:
        raise ValueError("max_k and k_range must both be >= 1")

    groundtruth = loadDictionaryFromPickleFile(gtPath)
    resultFile = loadDictionaryFromPickleFile(resPath)
    if isinstance(resultFile, list):
        # List-of-records format: index it in memory instead of reading the pickle again.
        resultFile = {entry['query_table']: entry['result_set'] for entry in resultFile}

    # t28 tables have less than 60 results. So, skipping them in the analysis.
    skipped_query_prefix = "t_28dc8f7610402ea7"

    # The normalisation below does not depend on k, so do it once per table:
    # (groundtruth names without extension, top-max_k result names without
    #  extension, raw groundtruth length used for the ideal recall).
    evaluated_tables = []
    for table in resultFile:
        if table.split("____", 1)[0] == skipped_query_prefix:
            continue
        if table not in groundtruth:
            continue
        groundtruth_set = {x.split(".")[0] for x in set(groundtruth[table])}
        if not groundtruth_set:
            # Recall is undefined without any groundtruth; skipping the table
            # avoids a division by zero (it then contributes 0 to the recall sum).
            continue
        ranked_results = [x.split(".")[0] for x in resultFile[table][:max_k]]
        evaluated_tables.append((groundtruth_set, ranked_results, len(groundtruth[table])))

    # =============================================================================
    # Precision and recall
    # =============================================================================
    precision_array = []
    recall_array = []
    final_results = []
    for k in range(1, max_k + 1):
        true_positive = 0
        false_positive = 0
        rec = 0
        ideal_recall = []
        for groundtruth_set, ranked_results, groundtruth_size in evaluated_tables:
            result_set = ranked_results[:k]
            # find_intersection = true positives
            find_intersection = set(result_set).intersection(groundtruth_set)
            final_results.append({
                "groundtruth_set": groundtruth_set,
                "result_set": result_set,
                "intersection": find_intersection,
            })
            tp = len(find_intersection)
            fp = k - tp
            fn = len(groundtruth_set) - tp
            # Only tables that could fill all k slots with correct answers count towards precision.
            if len(groundtruth_set) >= k:
                true_positive += tp
                false_positive += fp
            rec += tp / (tp + fn)
            ideal_recall.append(k / groundtruth_size)

        retrieved = true_positive + false_positive
        precision = true_positive / retrieved if retrieved else 0.0
        # NOTE(review): the denominator counts every table in the result file,
        # including the skipped ones and those missing from the groundtruth — confirm this is intended.
        recall = rec / len(resultFile) if len(resultFile) else 0.0
        precision_array.append(precision)
        recall_array.append(recall)
        mean_ideal_recall = sum(ideal_recall) / len(ideal_recall) if ideal_recall else 0.0
        print(k, "IDEAL RECALL:", mean_ideal_recall)

    # K values to report; empty when k_range > max_k.
    used_k = list(range(k_range, max_k + 1, k_range))
    print("--------------------------")
    for k in used_k:
        print("Precision at k = ",k,"=", precision_array[k-1])
        print("Recall at k = ",k,"=", recall_array[k-1])
        print("--------------------------")

    mean_avg_pr = sum(precision_array) / max_k
    print("The mean average precision is:", mean_avg_pr)
    output_result_csv_file = "curr_run_results.csv"
    # One row per (k, evaluated table) pair.
    pd.DataFrame(final_results).to_csv(output_result_csv_file)

    # logging to mlflow
    if record: # if the user would like to log to MLFlow
        mlflow.log_metric("mean_avg_precision", mean_avg_pr)
        mlflow.log_metric("prec_k", precision_array[-1])
        mlflow.log_metric("recall_k", recall_array[-1])

    return mean_avg_pr, precision_array[-1], recall_array[-1]

In [296]:
# Evaluate one result pickle against the groundtruth pickle for k = 1..10,
# printing precision/recall at every k (k_range=1).
# record=True additionally logs the final metrics to MLFlow.
calcMetrics(max_k=10, k_range=1, 
            gtPath='../data/ugen_v1/santosUnionBenchmark.pickle', 
            resPath='../starmie-llm-results/vicuna7b_ugen_v1_sparse_20_icl-1_result.pickle', record=True)

1 IDEAL RECALL: 0.09999999999999996
2 IDEAL RECALL: 0.19999999999999993
3 IDEAL RECALL: 0.30000000000000027
4 IDEAL RECALL: 0.39999999999999986
5 IDEAL RECALL: 0.5
6 IDEAL RECALL: 0.6000000000000005
7 IDEAL RECALL: 0.6999999999999998
8 IDEAL RECALL: 0.7999999999999997
9 IDEAL RECALL: 0.8999999999999991
10 IDEAL RECALL: 1.0
--------------------------
Precision at k =  1 = 0.74
Recall at k =  1 = 0.07400000000000004
--------------------------
Precision at k =  2 = 0.65
Recall at k =  2 = 0.13000000000000003
--------------------------
Precision at k =  3 = 0.6066666666666667
Recall at k =  3 = 0.18199999999999997
--------------------------
Precision at k =  4 = 0.55
Recall at k =  4 = 0.21999999999999997
--------------------------
Precision at k =  5 = 0.544
Recall at k =  5 = 0.2720000000000001
--------------------------
Precision at k =  6 = 0.5333333333333333
Recall at k =  6 = 0.32000000000000006
--------------------------
Precision at k =  7 = 0.5085714285714286
Recall at k =  7 = 0.

(0.5544349206349206, 0.454, 0.45399999999999996)

In [None]:
# Load a stored result pickle and display its content for manual inspection.
# NOTE(review): this path is relative to the working directory ('data/...'), whereas the
# calcMetrics call uses '../data/...' — confirm which directory the notebook is run from.
test_dict = loadDictionaryFromPickleFile('data/d3l_santos/santos_benchmark_result_by_d3l.pickle')
test_dict

In [None]:
def create_datalake_and_query_folders(in_dir, out_dir, gtfile, query_col="query_table", datalake_col="data_lake_table"):
    ''' Create the "query" and "datalake" folders and copy the query tables
    Args:
        in_dir: folder that contains the source table files
        out_dir: folder in which the "query" and "datalake" sub-folders are created
        gtfile: path to the groundtruth CSV file
        query_col: name of the CSV column holding the query table file names
        datalake_col: name of the CSV column holding the data lake table file names
            (currently unused: only the query tables are copied, the "datalake"
            folder is created but left empty)
    Return: number of query tables copied by this call (tables already present
        in the "query" folder are not copied again)
    '''
    # Local import: shutil is not part of the notebook's import cell.
    import shutil

    os.makedirs(out_dir, exist_ok=True)
    query_folder = os.path.join(out_dir, "query")
    datalake_folder = os.path.join(out_dir, "datalake")
    os.makedirs(query_folder, exist_ok=True)
    os.makedirs(datalake_folder, exist_ok=True)
    gt_df = pd.read_csv(gtfile)
    copied = 0
    # A query table usually appears on many groundtruth rows; visit each one once.
    for query_table in gt_df[query_col].drop_duplicates():
        if os.path.isfile(os.path.join(query_folder, query_table)):
            continue
        shutil.copy(os.path.join(in_dir, query_table), query_folder)
        copied += 1
    return copied

In [None]:
# Populate <out_dir>/query with the query tables listed in the groundtruth CSV.
# in_dir is the folder the query table files are copied from; out_dir receives the
# "query" and "datalake" sub-folders.
in_dir = "data/table-union-search-benchmark/small/santos-query"
out_dir = "data/table-union-search-benchmark/small"
gtfile = "TUS_benchmark_relabeled_groundtruth.csv"
create_datalake_and_query_folders(in_dir, out_dir, gtfile, query_col="query_table", datalake_col="data_lake_table")

In [None]:
def create_gt_pickle(gtfile, out_pickle, query_col="query_table", datalake_col="data_lake_table", label="unionable"):
    ''' Build the groundtruth dictionary from a labelled CSV and pickle it
    Args:
        gtfile: path to the groundtruth CSV file
        out_pickle: filepath to which the dictionary will be saved
        query_col: name of the CSV column holding the query tables
        datalake_col: name of the CSV column holding the data lake tables
        label: name of the CSV column holding the label; only rows whose
            label equals 1 are kept
    The pickled dictionary maps every query table to the list of its distinct
    data lake tables, in order of first appearance in the CSV. The size and
    content of every entry are printed as well.
    '''
    gt_df = pd.read_csv(gtfile)
    unionable_rows = gt_df[gt_df[label] == 1]

    # Dict keys act as an insertion-ordered set, so duplicates are dropped
    # without rebuilding the list on every row and the order stays deterministic.
    tables_by_query = {}
    for query_table, datalake_table in zip(unionable_rows[query_col], unionable_rows[datalake_col]):
        tables_by_query.setdefault(query_table, {})[datalake_table] = None
    query_datalake_dict = {query: list(tables) for query, tables in tables_by_query.items()}

    for key, value in query_datalake_dict.items():
        print(len(value))
        print(key, value)
    with open(out_pickle, 'wb') as handle:
        pickle.dump(query_datalake_dict, handle, protocol=pickle.HIGHEST_PROTOCOL)