In [1]:
import random
import glob
import os
import time
import numpy as np
import pandas as pd
import _pickle as cPickle
from numpy.linalg import norm
from sentence_transformers import SentenceTransformer
import utilities as utl
from sklearn.model_selection import train_test_split
import prepare_dataset_utilities as prepare_utl


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Use existing TUS benchmark with SANTOS relabel, a design decision to discuss about.
benchmark_name = "tus_benchmark"
groundtruth_file_path = r"data" + os.sep + benchmark_name + os.sep + "groundtruth.csv"
table_location = r"data"+ os.sep + benchmark_name
separator = ","

In [3]:
groundtruth_file = pd.read_csv(groundtruth_file_path)
# If a benchmark only contains unionable pairs, add unionable column with all 1 values. Eg. tus benchmark
if "unionable" not in groundtruth_file.columns:
    groundtruth_file['unionable'] = 1
# groundtruth_file['unionable'] = groundtruth_file['unionable'].replace(2, 1)
groundtruth_file

Unnamed: 0,serial_num,query_table,intent_col_index,data_lake_table,intent_col_name,tree_level,unionable
0,1,t_1934eacab8c57857____c4_1____0.csv,1,t_1934eacab8c57857____c2_0____0.csv,Country/region name,2,1
1,2,t_1934eacab8c57857____c4_1____0.csv,1,t_1934eacab8c57857____c2_0____1.csv,Country/region name,2,1
2,3,t_1934eacab8c57857____c4_1____0.csv,1,t_1934eacab8c57857____c2_0____2.csv,Country/region name,2,1
3,4,t_1934eacab8c57857____c4_1____0.csv,1,t_1934eacab8c57857____c2_0____3.csv,Country/region name,2,1
4,5,t_1934eacab8c57857____c4_1____0.csv,1,t_1934eacab8c57857____c2_0____4.csv,Country/region name,2,1
...,...,...,...,...,...,...,...
12156,12157,t_356fc1eaad97f93b____c18_0____4.csv,2,t_356fc1eaad97f93b____c23_0____0.csv,Location,2,1
12157,12158,t_356fc1eaad97f93b____c18_0____4.csv,2,t_356fc1eaad97f93b____c23_0____1.csv,Location,2,1
12158,12159,t_356fc1eaad97f93b____c18_0____4.csv,2,t_356fc1eaad97f93b____c23_0____2.csv,Location,2,1
12159,12160,t_356fc1eaad97f93b____c18_0____4.csv,2,t_356fc1eaad97f93b____c23_0____3.csv,Location,2,1


In [4]:
# Loading the ground truth pairs.
groundtruth_positive_dictionary = {}
groundtruth_negative_dictionary = {} # this is for negative pairs from the same cluster.
for id, row in groundtruth_file.iterrows():
    if str(row['unionable']) == "1":
        if row['query_table'] in groundtruth_positive_dictionary:
            groundtruth_positive_dictionary[row['query_table']].add(row['data_lake_table'])
        else:
            groundtruth_positive_dictionary[row['query_table']] = {row['data_lake_table']}
    else:
        if row['query_table'] in groundtruth_negative_dictionary:
            groundtruth_negative_dictionary[row['query_table']].add(row['data_lake_table'])
        else:
            groundtruth_negative_dictionary[row['query_table']] = {row['data_lake_table']}

In [5]:
all_query_tables = list(set(groundtruth_positive_dictionary.keys()).union(set(groundtruth_negative_dictionary.keys())))
print("Total queries:", len(all_query_tables))
print(all_query_tables[0:5])

Total queries: 125
['t_c701386b7c10b107____c11_0____3.csv', 't_356fc1eaad97f93b____c12_1____2.csv', 't_356fc1eaad97f93b____c16_1____0.csv', 't_1934eacab8c57857____c12_1____2.csv', 't_1934eacab8c57857____c11_0____4.csv']


In [6]:
all_data_lake_tables = glob.glob(table_location + os.sep + "datalake/*")
all_data_lake_tables = [table.rsplit(os.sep,1)[-1] for table in all_data_lake_tables]
print("Total data lake tables:", len(all_data_lake_tables))
print(all_data_lake_tables[0:5])

Total data lake tables: 1530
['t_013a2f8c584d44d7____c10_0____0.csv', 't_013a2f8c584d44d7____c10_0____1.csv', 't_013a2f8c584d44d7____c10_0____2.csv', 't_013a2f8c584d44d7____c10_0____3.csv', 't_013a2f8c584d44d7____c10_0____4.csv']


In [7]:
# Prepare positive table pairs. Any table pairs marked as unionable in the ground truth are positive samples.
positive_sample_pairs = set()
for table1 in groundtruth_positive_dictionary:
    current_positives = groundtruth_positive_dictionary[table1]
    positive_sample_pairs.add((table1,table1)) # to ensure that the same table is there.
    for table2 in current_positives:
        if (table2, table1) not in positive_sample_pairs:
            positive_sample_pairs.add((table1,table2))

positive_same_table = set()
positive_different_table = set()
for item in positive_sample_pairs:
    if item[0] == item[1]:
        positive_same_table.add(item)
    else:
        positive_different_table.add(item)
print("Count same: ", len(positive_same_table))
print("Count different: ", len(positive_different_table))
print("Total samples: ", len(positive_sample_pairs))

print(list(positive_sample_pairs)[0:5])

Count same:  125
Count different:  10690
Total samples:  10815
[('t_c701386b7c10b107____c12_1____4.csv', 't_c701386b7c10b107____c20_0____4.csv'), ('t_1934eacab8c57857____c12_1____2.csv', 't_1934eacab8c57857____c15_0____2.csv'), ('t_67c3f7ce5eab8804____c6_0____0.csv', 't_356fc1eaad97f93b____c18_1____3.csv'), ('t_93f3d6f7fc6aa6ff____c13_1____2.csv', 't_93f3d6f7fc6aa6ff____c16_0____2.csv'), ('t_c701386b7c10b107____c20_1____2.csv', 't_c701386b7c10b107____c11_1____3.csv')]


In [8]:
negative_same_cluster = set()
negative_different_cluster = set()
negative_sample_pairs = set()
if len(groundtruth_negative_dictionary) > 0: # this is for ugen or other similar benchmarks with non-unionable tables marked in the groundtruth, may need update for other benchmarks.
    for table1 in groundtruth_negative_dictionary:
        current_negatives = groundtruth_negative_dictionary[table1]
        for table2 in current_negatives:
            if (table1, table2) in positive_sample_pairs or (table2, table1) in positive_sample_pairs:
                print("Red flag!!! The table pair in negative sample is found in positive sample.")
            else:
                if (table2, table1) not in negative_same_cluster:
                    negative_same_cluster.add((table1,table2))

    for query in all_query_tables:
        for datalake in all_data_lake_tables:
            if (query,datalake) in positive_sample_pairs or (datalake, query) in positive_sample_pairs or (query,datalake) in negative_same_cluster or (datalake, query) in negative_same_cluster:
                continue
            else:
                negative_different_cluster.add((query, datalake))
    negative_sample_pairs = negative_same_cluster.union(negative_different_cluster)

else: # this is for tus benchmark, may need update for other benchmarks.
    # Prepare negative table pairs. Any table pairs with one table as a query table marked as non-unionable in the ground truth are negative samples.
    for query in all_query_tables:
        for datalake in all_data_lake_tables:
            if (query,datalake) not in positive_sample_pairs and (datalake, query) not in positive_sample_pairs:
                negative_sample_pairs.add((query, datalake))

    # t_1934eacab8c57857____c4_1____0.csv
    for item in negative_sample_pairs:
        if item[0].split("____",1)[0] == item[1].split("____",1)[0]:
            negative_same_cluster.add(item)
        else:
            negative_different_cluster.add(item)

print("Count same cluster: ", len(negative_same_cluster))
print("Count different cluster: ", len(negative_different_cluster))
print("Total samples: ", len(negative_sample_pairs))

print(list(negative_sample_pairs)[0:5])

if len(negative_sample_pairs.intersection(positive_sample_pairs)) > 0:
    print("Red flag!!! Overlap found between positive and negative samples.")
    # print(len(negative_sample_pairs.intersection(positive_sample_pairs)))

Count same cluster:  13130
Count different cluster:  165770
Total samples:  178900
[('t_1934eacab8c57857____c11_0____1.csv', 't_93f3d6f7fc6aa6ff____c13_0____2.csv'), ('t_c701386b7c10b107____c12_1____1.csv', 't_ca85e8f9eef5b9d5____c10_0____4.csv'), ('t_c701386b7c10b107____c14_0____0.csv', 't_1934eacab8c57857____c12_1____3.csv'), ('t_c701386b7c10b107____c20_1____0.csv', 't_356fc1eaad97f93b____c11_1____0.csv'), ('t_93f3d6f7fc6aa6ff____c13_1____0.csv', 't_1934eacab8c57857____c1_0____2.csv')]


Select up to 5 tuples from each positive and negative pairs. Note that is going to be random.

In [9]:
# selecting 50k positive samples and 50k negative samples
# positive: ~ 5k from the same table schema type (125 pairs, ~ 40 from each) and ~45k from different schema table type (10690 pairs, ~ 4 from each). The same table is easy case for which BERT is already working well. So, focusing here on different tables.
# negative: ~ 25k from the same table cluster (25000 pairs, ~ 1 from each) and ~ 25k from different table clusters (166125 pairs, ~ 1/6 from each )

# Note that we have more negatives, so this decision is made to solve the imbalance. Indeed it does not matter how much data is available in the open-world. 
# We should care the closed-world case for the model.

selected_positive_tuples_same = prepare_utl.CreateSampleDataPoints(positive_same_table, 5, "1", table_location, separator= separator, pairs = "quadratic") # pairs = quadratic or linear
print("Total positive samples selected from the same table:", len(selected_positive_tuples_same))
prepare_utl.SaveSampleDataPoints(r"finetune_data" + os.sep + benchmark_name + "_corrected" + os.sep + "positive_same_table.txt", selected_positive_tuples_same)

selected_positive_tuples_different = prepare_utl.CreateSampleDataPoints(positive_different_table, 5, "1", table_location, separator= separator, pairs = "quadratic")
#balance the dataset:
# if benchmark_name == "ugen_benchmark":
selected_positive_tuples_different = list(selected_positive_tuples_different)
random.shuffle(selected_positive_tuples_different)
selected_positive_tuples_different = set(selected_positive_tuples_different[0:27313]) #to balance the dataset for ugen benchmark
print("Total positive samples selected from cross (different) tables:", len(selected_positive_tuples_different))
prepare_utl.SaveSampleDataPoints(r"finetune_data" + os.sep + benchmark_name + "_corrected" + os.sep + "positive_different_tables.txt", selected_positive_tuples_different)

all_positive_samples = selected_positive_tuples_same.union(selected_positive_tuples_different)
print("Total positive samples selected:", len(all_positive_samples))
prepare_utl.SaveSampleDataPoints(r"finetune_data"+ os.sep + benchmark_name + "_corrected" +os.sep + "all_positive_samples.txt", all_positive_samples)


selected_negative_tuples_same = prepare_utl.CreateSampleDataPoints(negative_same_cluster, 5, "0", table_location , separator= separator, pairs = "quadratic")
#balance the dataset:
# if benchmark_name == "ugen_benchmark":
selected_negative_tuples_same = list(selected_negative_tuples_same)
random.shuffle(selected_negative_tuples_same)
selected_negative_tuples_same = set(selected_negative_tuples_same[0:15000]) #to balance the dataset for ugen benchmark
print("Total negative samples selected from the same table cluster:", len(selected_negative_tuples_same))
prepare_utl.SaveSampleDataPoints(r"finetune_data"+ os.sep + benchmark_name + "_corrected" + os.sep + "negative_same_table_cluster.txt", selected_negative_tuples_same)

negative_different_sample_ratio = 1/6 # reduce negative different tables to 1/6 by sampling
negative_different_sample_size = round(len(negative_different_cluster) * negative_different_sample_ratio)
negative_different_cluster_sampled = random.sample(list(negative_different_cluster), k = negative_different_sample_size)
selected_negative_tuples_different = prepare_utl.CreateSampleDataPoints(negative_different_cluster_sampled, 5, "0", table_location , separator= separator, pairs = "linear")
#balance the dataset:
# if benchmark_name == "ugen_benchmark":
selected_negative_tuples_different = list(selected_negative_tuples_different)
random.shuffle(selected_negative_tuples_different)
selected_negative_tuples_different = set(selected_negative_tuples_different[0:15000]) #to balance the dataset for ugen benchmark

print("Total negative samples selected from different table cluster:", len(selected_negative_tuples_different))
prepare_utl.SaveSampleDataPoints(r"finetune_data" + os.sep + benchmark_name + "_corrected" + os.sep + "negative_different_table_cluster.txt", selected_negative_tuples_different)

all_negative_samples = selected_negative_tuples_same.union(selected_negative_tuples_different)
print("Total negative samples selected:", len(all_negative_samples))
prepare_utl.SaveSampleDataPoints(r"finetune_data"+ os.sep + benchmark_name + "_corrected" + os.sep + "all_negative_samples.txt", all_negative_samples)

Total positive samples selected from the same table: 2687
Total positive samples selected from cross (different) tables: 27313
Total positive samples selected: 30000
Total negative samples selected from the same table cluster: 15000
Total negative samples selected from different table cluster: 15000
Total negative samples selected: 30000
