In [1]:
import pandas as pd
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import numpy as np
import torch
from torch.nn.functional import softmax
from copy import deepcopy
import enchant
from sklearn.metrics.pairwise import cosine_similarity
from scipy import sparse
from tqdm import trange, tqdm
import json
import nltk
import spacy
import string
import re
import pickle
import inflect
nlp = spacy.load('en_core_web_sm')

# Global Variables

In [2]:
bigram_reservered_words = ['against', 'and', 'area', 'average', 'since', 'away', 'section', 'by', 'class', 'club', 'code', 'cup', 'current', 'date', 'data', 'district', 'elected', 'engine', 'episode', 'event', 'final', 'finish', 'first', 'for', 'from', 'game', 'games', 'goals', 'gold', 'grid', 'height', 'high', 'home', 'id', 'in', 'incumbent', 'international', 'laps', 'league', 'list', 'log', 'loss', 'losses', 'lost', 'method', 'age', 'name', 'nation', 'no', 'notes', 'number', 'of', 'one', 'two', 'three', 'four', 'yes', 'no', 'yards', 'five' 'other', 'outcome', 'overall', 'par', 'party', 'per', 'pick', 'played', 'player', 'points', 'pos', 'rank', 'record', 'region', 'release', 'report', 'res', 'result', 'results','round', 'score', 'season', 'second', 'series', 'singles', 'start', 'end', 'state', 'status', 'table', 'team', 'types', 'the', 'first', 'second', 'third', 'time', 'to', 'total', 'type', 'up', 'week', 'weeks', 'year', 'unit', 'version', 'years', 'ends', 'ended', 'min', 'max', 'make', 'statistics', 'stats', 'in', 'on', 'to', 'see', 'feet', 'subject']
preps = ["aboard","about","above","across","after","against","along","amid","among","as","at","before","behind","below","beneath","beside","besides","between","beyond","but","by","concerning","considering","despite","down","during","except","excepting","excluding","following","for","from","in","inside","into","like","minus","near","of","off","on","onto","opposite","outside","over","past","per","plus","regarding","round","save","since","than","through","to","toward","towards","under","underneath","unlike","until","up","upon","versus","via","with","within","without"]
bigram_reservered_words = list(set(bigram_reservered_words + preps))



### Templates to be used for checking
# template1 = lambda table_name, col: f"We are told of the {table_name}'s {trim_col_name(table_name, col)}."
# template2 = lambda table_name, col: f"We are informed of the {trim_col_name(table_name, col)} of the {table_name}."
# template3 = lambda table_name, col: f"We know {trim_col_name(table_name, col)} of the {table_name}."
# template4 = lambda table_name, col: f"We collect {table_name}'s {trim_col_name(table_name, col)}."
# TEMPLATES = [template1, template2, template3, template4] if STRICT_MODE else [template1, template2]
tablename_black_list = ["statistics", "data", "table", "summary", "sketch", "list"]
date_marks = ["date", "dates", "year", "years", "month", "months",
              "day", "days", "daytime", "minute", "minutes", "second", "seconds", "time"]
num_marks = ["num", "number", "sum", "amount", "count", "total", "#", "No.", "no.", "scores",
             "rating", "rank", "height", "weight", "age", "time", "times", "temperature",
             "year", "years", "month", "months", "day", "days", "minute", "minutes", 
             "second", "seconds", "average", "sum", "grade", "fee", "cost", "value",
             "rate"]  # words explicitly has numeric implications



def template1(table_name, col, col_type):
    """Template 1 placeholder filling for give table name, column name, and column type"""
    table_name = get_singular_word(table_name)
    capital_tname = table_name[0].capitalize() + table_name[1:]
    trimmed_col_name = trim_col_naive(table_name, col)
    type_prompt = trim_type(col_type)
    if col_type == "date" and any([dm in trimmed_col_name for dm in date_marks]):
        type_prompt = ""
    if col_type == "number" and any([nm in trimmed_col_name for nm in num_marks]):
        type_prompt = ""
    return f"{capital_tname} {trimmed_col_name}{trim_type(col_type)}."

In [3]:
TEMPLATES = [template1]
device = "cuda"
nli_tokenizer = AutoTokenizer.from_pretrained("roberta-large-mnli")
nli_model = AutoModelForSequenceClassification.from_pretrained("roberta-large-mnli").to(device)

Some weights of the model checkpoint at roberta-large-mnli were not used when initializing RobertaForSequenceClassification: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
- This IS expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


# Preprocessing Functions

In [4]:
inflect_engine = inflect.engine()
def get_singular_word(word):
    """
    Reduce a given word (string) to singular form
    """
    ans = inflect_engine.singular_noun(word)
    return ans if ans else word


def read_dense_table_vectors(path, delim="\t"):
    """
    Read backend dense table vectors
    """
    with open(path, "r") as f:
        tid2vals = {}
        for line in f.readlines():
            if len(line) == 0: continue
            units = line.split(delim)
            table_id, vals = units[0], units[1:] 
            tid2vals[table_id] = np.array(vals).astype(float)
    return tid2vals


def trim_col(tname, col):
    """
    Normalize table name.
    :tname: table name
    :col: column name to be trained
    """
    if tname == '' or col == '': return col
    if tname == ' ' or col == ' ': return col
    tname_tokens = nlp(tname)
    col_tokens = nlp(col)
    if tname_tokens[-1].lemma_ == col_tokens[0].lemma_:
        return col_tokens[1:]
    return col    
    
# template1 = lambda table_name, col: f"{table_name[0].capitalize() + table_name[1:]} {trim_col(table_name, col)}."

def trim_col_naive(tname, col):
    if tname == '' or col == '': return col
    if tname.lower() == col.lower():
        return 'name'
    return col.lower()
        

def trim_type(col_type):
    if col_type in ["text", "bool"]:
        return ""
    return " " + "time" if col_type == "date" else "number"
    

# Load datasets

In [5]:
with open("./data/tid2tables.pkl","rb") as f:
     tid2tables = pickle.load(f)
        
idx2tid = {i:tid for i,tid in enumerate(tid2tables.keys())}
tid2idx = {tid:i for i,tid in enumerate(tid2tables.keys())}

spiders = []
with open("./data/spider/spider-tables-synonym.values-bm.jsonl", "r") as f:
    for line in f.readlines():
        table = json.loads(line)
        spiders.append(table)
        
wtqs = []
with open("./data/WTQ/wtq.tables.values-bm.jsonl", "r") as f: # no table name, no domain name
    for line in f.readlines():
        table = json.loads(line) # dict_keys(['file_name', 'table_name', 'column_types', 'column_names', 'column_values'])
        wtqs.append(table)

wsqls = []
with open("./data/wikisql/wikisql.tables.dev.test.refer.values-bm.jsonl", "r") as f:
    for line in f.readlines():
        table = json.loads(line) # dict_keys(['refer_cols_index', 'domain', 'table_id', 'table_name', 'column_names', 'column_types', 'column_values'])
        wsqls.append(table)
        
wsqls = []
with open("./data/wikisql/wikisql.tables.dev.test.refer.values-bm.jsonl", "r") as f:
    for line in f.readlines():
        table = json.loads(line) # dict_keys(['refer_cols_index', 'domain', 'table_id', 'table_name', 'column_names', 'column_types', 'column_values'])
        wsqls.append(table)
        
wsqls_train = []
with open("./data/wikisql/wikisql.tables.train.refer.values-bm.jsonl", "r") as f:
    for line in f.readlines():
        table = json.loads(line) # dict_keys(['refer_cols_index', 'domain', 'table_id', 'table_name', 'column_names', 'column_types', 'column_values'])
        wsqls_train.append(table)
        
with open("./data/wikisql/wikisql.dev.test.tid2question.json", "r") as f:
    wsql_table2qs = json.load(f)
    
with open("./data/wikisql/wikisql_train.tables2question.json", "r") as f:
    wsql_train_table2qs = json.load(f)
    
    
with open("./data/spider/spider-table2questions.json", "r") as f:
    spider_table2qs = json.load(f)

with open("./data/WTQ/wtq-table2questions.json", "r") as f:
    wtq_table2qs = json.load(f)

    
idx2word, word2idx = {}, {}
word2vec = {}
with open("./data/numberbatch/nb_emb.txt", "r") as f:
    cnt = -2
    for line in tqdm(f.readlines(), desc="Building word2vec...", leave=True):
        cnt += 1
        if cnt == -1: continue
        units = line.split(" ")
        word, emb = units[0], np.array(units[1:]).astype(float)
        word2vec[word] = emb
        idx2word[cnt] = word
        word2idx[word] = cnt
EMB_DIM=300

with open("./data/syndict_pipeline.json") as f:
    synonym_dic = json.load(f)

Building word2vec...: 100%|███████████████████████████████████████████████████████████████████████████████| 516783/516783 [01:45<00:00, 4875.82it/s]


# Dense Retrieval Setup

In [6]:
import os
import torch
import torch.nn as nn
import numpy as np
from transformers import TapasModel, TapasConfig, TapasTokenizer, BertModel, BertTokenizer


def build_projection_layer(weight_path: str):
    with open(weight_path, 'rb') as f:
        weights = torch.from_numpy(np.load(f))
    linear = nn.Linear(weights.size(0), weights.size(1), bias=False)
    linear.weight.data = weights
    return linear


MAX_LEN = 1024
DUMMY_TABLE = pd.DataFrame({})

In [7]:
basepath = os.path.join("tapas-torch", "tapas_retrieval")
table_model_path = os.path.join(basepath, "tapas_nq_hn_retriever_large_table", "checkpoint")
table_model = TapasModel.from_pretrained(table_model_path).to(device)
tapas_tokenizer = TapasTokenizer.from_pretrained(table_model_path)
table_model_config = TapasConfig.from_pretrained(table_model_path)
query_model_path = os.path.join(basepath, "tapas_nq_hn_retriever_large_query", "checkpoint")
query_model = TapasModel.from_pretrained(query_model_path).to(device)
text_projection_layer = build_projection_layer(os.path.join(basepath, "projection_layer", "text_projection.npy")).to(device)
table_projection_layer = build_projection_layer(os.path.join(basepath, "projection_layer", "table_projection.npy")).to(device)

def form_table(dic_table, col_name_key="column_names", max_row_limit=10, max_cell_val_len=50): # output a dataframe
    """
    Build source table text for dense vector computation. Done via resampling strategy.
    :dic_table: table as a dictionary
    :col_name_key: the key name in passed dic_table storing column names (as a list).
    :max_row_limit: maximum number of rows for cosntructed table 
    :max_cell_val_len: cell values will be truncated to this length.
    """
    col_names = dic_table["column_names"]
    col_vals = {k : list(set(v)) for k,v in dic_table["column_values"].items()}
    try:
        longest_unique = min(max([len(v) for v in col_vals.values()]), max_row_limit)
    except:
        if len(dic_table[col_name_key]) == 0:
            return DUMMY_TABLE
        else:
            return pd.DataFrame({k:[] for k in dic_table[col_name_key]})
    col2vals = {n : [str(elem)[:max_cell_val_len] for elem in np.random.choice(v, longest_unique, replace=True)] for n,v in col_vals.items()}
    return pd.DataFrame(col2vals)


    
    
# takes a while to load 615144 * 2 vectors...
benchmark_dense_a = read_dense_table_vectors(path="./benchmark/benchmark_dense_A.txt")
benchmark_dense_b = read_dense_table_vectors(path="./benchmark/benchmark_dense_B.txt")
idx2tid = {i:k for i,k in enumerate(benchmark_dense_a.keys())}
tid2idx = {k:i for i,k in idx2tid.items()}
bm_mat_A = torch.stack([torch.Tensor(vs) for vs in benchmark_dense_a.values()], dim = 0).to(device)
bm_mat_B = torch.stack([torch.Tensor(vs) for vs in benchmark_dense_b.values()], dim = 0).to(device)
bm_mat_A = bm_mat_A / torch.norm(bm_mat_A, dim=-1).unsqueeze(-1)
bm_mat_B = bm_mat_B / torch.norm(bm_mat_B, dim=-1).unsqueeze(-1)

# Core NLI algorithms

In [8]:
def trim_col_name(table_name, col_name):
    if table_name == col_name:
        return table_name + " name"
    return col_name

def check_spell(col_name):
    """
    Check whether a column name (multiwords allowed) is valid english word.
    """
    return all([checker.check(w) for w in col_name.split(" ") if w != ""])


def batchify(pair_dict):
    """
    form batch of a pair of ori-rpl.
    Two directions.
    """
    split_idx = []
    batch_ori = []
    batch_rpl = []
    prev_end_idx = 0
    for dic in pair_dict:
        key_map, pairs, _, _ = dic.values()
        split_idx.append((prev_end_idx, prev_end_idx + 2 * len(pairs),))
        prev_end_idx = prev_end_idx + 2 * len(pairs) # 2 * because of reverse
        for (ori, rpl) in pairs:
            batch_ori.append(ori)
            batch_ori.append(rpl) # reverse
            batch_rpl.append(rpl)
            batch_rpl.append(ori) # reverse
    return split_idx, batch_ori, batch_rpl


def aggregate(split_idx, scores, strict=True):
    contras, neus, ents = [],[],[]
    for s,e in split_idx:
        one_rpl_scores = scores[s:e,:]
        if strict:
            """ For REPLACE cols
                Prefer high PRECISION of repalceablility! (If we REPLACE with a UNreplaceable col, we run into trouble)
                Reject as many LOW confidence candidate as possible.
                If NLI give HIGH ent-score, then two columns should almost always be mutally replaceable!
            """
            contra, neu, ent = torch.min(one_rpl_scores, dim=0)[0].squeeze()
        else:
            """ For ADD cols
               Prefer high RECALL of repalceablility! (If we ADD a replaceable col, we run into trouble)
               Accept as many LOW confidence candidate as possible.
               If NLI still suggests LOW ent-score, then two columns should almost always be mutally UNreplaceable!
            """
            contra, neu, ent = torch.max(one_rpl_scores, dim=0)[0].squeeze() # Prefer high recall
        contras.append(float(contra.item()))
        neus.append(float(neu.item()))
        ents.append(float(ent.item()))
    return contras, neus, ents


def construct_pairs_for_nli_test(tables, table_id_key="table_id", table_name_key="table_name",
                                 col_type_key="column_types", col_name_key="column_names",
                                 pending_rpls_key="column_names_syn"):
    """
    Given a list of dictionary-represneted tables, and pending replacements cols,
    construct pairs of ori-rpl
    :table_id_key: key name in table dict for table name 
    :col_type_key: key name in table dict for types of columns
    :col_name_key: key name in table dict for names of columns
    :pending_rpls_key: key name in table dict for pending keys to be replaced
    """
    assert isinstance(tables, list), "Please pass a list of tables."
    assert table_id_key in tables[0], "Each Table must have an id."
    assert table_name_key in tables[0], "Table name is required, but the key is missing."
    assert col_name_key in tables[0], "column name key is required but missing"
    assert col_type_key in tables[0], "column type key is required but missing"
    assert pending_rpls_key in tables[0], "Pending replacement columns is required, but the key is missing."
    constructed_pairs = []
    for i in trange(len(tables)):
        tab = tables[i]
        tname = tab[table_name_key] if tab[table_name_key] != "s" else tab[table_name_key][:-1]
        pending_rpls = tab[pending_rpls_key]
        col2type = {col: tp for col,tp in zip(tab[col_name_key], tab[col_type_key])}
        for ori_col, rpl_col_list in pending_rpls.items():
            for rpl_col in rpl_col_list:
                if not check_spell(rpl_col): continue
                rpl_dic = {"key_map": None, "pairs": [], "table_id": tab[table_id_key], "table_name": tab[table_name_key]}
                for template in TEMPLATES:
                    sent_ori = template(tname, ori_col, col2type[ori_col])
                    sent_rpl = template(tname, rpl_col, col2type[ori_col])
                    rpl_dic["key_map"] = (ori_col, rpl_col,)
                    rpl_dic["pairs"] += ((sent_ori, sent_rpl,),)
                constructed_pairs.append(rpl_dic)
    return constructed_pairs


def nli_test_across_tables(constructed_pairs, batch_size=256):
    """
    The major interface for NLI verification.
    Given constructed pairs (results from construct_pairs_for_nli_test),
    use batch computation to speed up the verfication process.
    """
    assert batch_size % 8 == 0, "Batch size must be a multiple of 8."
    results = []
    completed_pairs = 0
    total_batches = len(constructed_pairs) // batch_size + 1
    pbar = tqdm(total = total_batches)
    with torch.no_grad():
        while completed_pairs < len(constructed_pairs):
            batch_contras, batch_neus, batch_ents = [],[],[]
            prev_completed = completed_pairs
            completed_pairs = min(completed_pairs + batch_size, len(constructed_pairs))
            batch = constructed_pairs[prev_completed:completed_pairs]
            split_idx, batch_ori, batch_rpl = batchify(batch)
            inputs = nli_tokenizer(batch_ori, batch_rpl, padding="longest", return_tensors="pt").to(device)
            logits = nli_model(**inputs).logits
            scores = softmax(logits, dim=1) # [batch, 3]
            batch_contras, batch_neus, batch_ents = aggregate(split_idx, scores)
            del inputs; del logits; del scores; torch.cuda.empty_cache()
            batch_contras, batch_neus, batch_ents = np.array(batch_contras), np.array(batch_neus), np.array(batch_ents)
            for b, c, n, e in zip(batch, batch_contras, batch_neus, batch_ents):            
                results.append({"key_map": b["key_map"], "scores": (c, n, e,)})
            pbar.update(1)
    pbar.close()
    return results



In [9]:
def trim_name(text):
    for ch in ['\\','`','*','{','}','[',']','(',')','>', '<', '#','+','\'', '"']:
        if ch in text:
            text = text.replace(ch, "")
    text.replace("-", " ")
    text.replace(".", " ")
    return text

def extract_emb(list_of_names):
    """
    Extract numberbatch word embeddings for a given list of strings
    """
    assert isinstance(list_of_names, list), "Expected list as input"
    output_matrix = np.zeros([len(list_of_names), EMB_DIM])
    for i,name in enumerate(list_of_names):
        name = trim_name(name)
        units = name.split() # notice "_" is covered by our nb_emb!
        name_emb = np.zeros(EMB_DIM)
        for word in units:
            if "_" in word and word2vec.get(word, None) is None:
                sub_words = word.split("_")
                local_emb_mat = extract_emb(sub_words)
                emb = np.mean(local_emb_mat, axis=0)
            else:
                emb = word2vec.get(word, np.zeros(EMB_DIM))
            name_emb += emb
        name_emb /= len(units)
        output_matrix[i,:] = name_emb
    return output_matrix
    
def reranker(tgt_names, cand_names, topk=10):
    """
    Do reranking (usually among few hundreds of candidates) and return topk per numberbatch word2vec similarity
    """
    if len(cand_names) == 0:
        return {}
    tgt_mat = extract_emb(tgt_names)
    cand_mat = extract_emb(cand_names)
    sim_mat = tgt_mat @ cand_mat.T
    topk = min(len(cand_names), topk)
    top_scores, top_idx = [v.squeeze().numpy() for v in torch.topk(torch.Tensor(sim_mat), topk, dim=-1)]
    rec_dic = {}
    for i, tgt in enumerate(tgt_names):
        if len(top_idx.shape) == 0: top_idx = np.array([top_idx])
        if len(top_idx.shape) == 1: top_idx = top_idx[None, :]
        
        rec_dic[tgt] = [cand_names[idx] for idx in top_idx[i]]
    return rec_dic


def retriver(query_table, queries=None, retrieve_strategy="query_dense", topk_tables=50, col_name_key="column_names", target_expand_keys=100):
    """
    Tapas based dense retrieval for finding topk most similar tabels from table base.
    :query_table: The table whose topk similar will be found
    :quries: the user NL queries attached with the query_table
    :retrieve_strategy: Choose from ["query_dense", "table_dense"], qd uses NL query as retrieval query vector,
                        and td uses table as retreival query vector.
    :topk_tables: Return k most similar tables
    """
    ori_cols = set(query_table[col_name_key])
    top_tables_tid = []
    if retrieve_strategy == "query_dense":
        top_tables_tid = retrieve_tables_query_dense(queries,k=topk_tables)
    elif retrieve_strategy == "table_dense":
        pass
    elif retrieve_strategy == "tfidf":
        pass
    else:
        raise NotImplementedError
    top_tables = [tid2tables[tid] for tid in top_tables_tid]
    expanded_cols = set()
    for t in top_tables:
        if len(expanded_cols) >= target_expand_keys: break
        expanded_cols = expanded_cols.union(t[col_name_key])
    expanded_cols = expanded_cols.difference(ori_cols)
    return list(ori_cols), list(expanded_cols)


def retrieve_tables_tfidf(query_table, tfidf_mat, col_name_key="column_names", table_doc_key="doc"):
    """
    TF-IDF based retrieval for finding most similar tables from DB.
    """
    assert table_doc_key in query_table and col_name_key in query_table
    query_tfidf = vectorizer.transform(query_table)
    scores = cosine_similarity(query_tfidf, tfidf_mat)[0]
    top_scores, indices = [t.squeeze().numpy() for t in torch.topk(torch.Tensor(scores), 1000)]
    return [idx2tid[i] for i in indices]


def retrieve_tables_query_dense(queries, k=50):
    """'
    Interface for finding most similar table via dense retrieval. call goes from here.
    """
    assert isinstance(queries, list), "input queries must be a list of strings"
    torch.cuda.empty_cache()
    with torch.no_grad():
        q_inputs = tapas_tokenizer(table=DUMMY_TABLE, queries=queries, padding=True, truncation=True, return_tensors="pt").to(device)
        qb = query_model(**q_inputs).pooler_output
        qb = text_projection_layer(qb)
        qb = qb / torch.norm(qb, dim=-1).unsqueeze(-1)
        cos = torch.matmul(qb, bm_mat_B.transpose(0, 1))
        cos = torch.mean(cos, dim=0)
        top_score, top_idx = [v.data.cpu().numpy() for v in torch.topk(cos, k=k)]
        
#         top_idx = top_idx.data.cpu().numpy()

    #### ab means  table encoded with encoder A, query encoded with encoder B.
    return  [idx2tid[i] for i in top_idx]

checker = enchant.Dict("en_US")
def _ends_with_id(string):
    if len(string) < 2: return False
    return string[-2:].lower() == "id"

def _fill_type_info(string, col_type, delim):
    """
    Add type description for a given column
    """
    if col_type == "date" and not any([dm in string for dm in date_marks]):
        return delim.join([string, "time"])
    if col_type == "number" and not any([nm in string for nm in num_marks]):
        return delim.join([string, "number"])
    return string


def contains_number(text):
    """
    Judege whether the passed string contain number
    """
    return len(re.findall("[-+]?[.]?[\d]+(?:,\d\d\d)*[\.]?\d*(?:[eE][-+]?\d+)?", text)) > 0


def _get_replacement(tok1, tok2, tok1_is_reserved, tok2_is_reserved):
    """
    Given a bi-gram, replce the word whose IDF is higher with its synonym.
    """
    if tok1_is_reserved and tok2_is_reserved:
        return (None, None)
    if tok1_is_reserved and (not check_spell(tok2) or contains_number(tok2)):
        return (None, None)
    if tok2_is_reserved and (not check_spell(tok1) or contains_number(tok1)):
        return (None, None)
    if tok1_is_reserved:
        syn_dic = synonym_dic.get(tok2.lower(), None)
        return (tok2, syn_dic) if syn_dic is not None else (None, None)
    if tok2_is_reserved:
        syn_dic = synonym_dic.get(tok1.lower(), None)
        return (tok1, syn_dic) if syn_dic is not None else (None, None)
    # both are not reserved, pick one with higher tfidf val
    def extract_idf(vocab):
        vocab_idx = vectorizer.vocabulary_.get(vocab, None)
        idf = 0 if vocab_idx is None else vectorizer.idf_[vocab_idx]
        return idf
    first_tgt = tok1 if extract_idf(tok1) <= extract_idf(tok2) else tok2  # rare is better
    second_tgt = tok2 if first_tgt == tok1 else tok1
    syn_dic_first = synonym_dic.get(first_tgt.lower(), None)
    if syn_dic_first is not None: return (first_tgt, syn_dic_first)
    syn_dic_second = synonym_dic.get(second_tgt.lower(), None)
    return (second_tgt, syn_dic_second) if syn_dic_second is not None else (None, None)


# REPLACE & ADD Interface

In [10]:
def normalize_token(token):
    """
    Do strict noramlization for a given token. All punctuations will be removed.
    """
    def remove_articles(text):
        return re.sub(r'\b(a|an|the)\b', ' ', text)
    def white_space_fix(text):
        return ' '.join(text.split())
    def remove_punc(text):
        exclude = set(string.punctuation)
        return ''.join(ch for ch in text if ch not in exclude)
    def lower(text):
        return text.lower()
    return white_space_fix(remove_articles(remove_punc(lower(token))))

def trim_retrieval_results(replacement_dict):
    """
    All overly short (char len < 4) / misspelled / contains numbers tokens are not replaceable. Filter them out.
    """
    out_dic = deepcopy(replacement_dict)
    for col in replacement_dict.keys():
        replacements = replacement_dict[col]
        filter_rpls = [r for r in replacements if len(r) > 4]
        out_dic[col] = filter_rpls
    return out_dic

def consider_rpl(token):
    """
    Judge whether a token is suitable for replacement.
    All overly short (char len < 4) / misspelled / contains numbers col are not replaceable.
    """
    norm_token = normalize_token(token)
    if len(norm_token) < 4: return False, token
    if contains_number(norm_token): return False, token
    if not check_spell(norm_token): return False, token
    return True, norm_token
    

def replace_and_add_for_give_tables(path, table2qs, batch_size=512,
                                    replace_threshold=0.75, 
                                    add_threshold=0.4,
                                    output_prefix="", output_dir="./processed_data",
                                    delim=" ",
                                    topk_tables=50,
                                    max_cands_per_col=10,
                                    table_name_key="table_name",
                                    col_type_key="column_types",
                                    col_name_key="column_names"):
    """
    The highest-level interface for replacement and addition across all tables stored in a given path.
    One call, handle all.
    :path: Target tables path
    :table2qs: The queries corresponding to the each of the tables.
    :batch_size: bsz for NLI checking. 512 recommended.
    :replace_threshold: If NLI entailment score is higher than this threshold under STRICT mode, then the rpl pair is accepted.
    :add_treshold: If NLI entailment score is lower than this threshold under LOOSE mode, then the add pair is accepted.
    :output_prefix: File name prefix for output file.
    :output_dir: Output file directory.
    :delim: Delimiator for column names. Single white space by default.
    :topk_tables: How many most similar tables to consider from dense retrieval.
    :max_cands_per_col: Max nubmer of pairs to be considered for each column (both add and rpl).
                        This directly influences the final amount to be checked by NLI.
    """
    
    # STEP 0: Prepare tables
    print("STEP 0 : Prepare tables...\n")
    tables = []
    tables_template = {}
    cnt = 0
    with open(path, "r") as f:
        for line in f.readlines():
            table = json.loads(line)
            table_copy = deepcopy(table)
            table_copy["rpls_retrieval"] = {}
            table_copy["rpls_syndict"] = {}
            tables.append(table_copy)
            table["REPLACE"] = {tname:[] for tname in table[col_name_key]}
            table["ADD"] = {tname:[] for tname in table[col_name_key]}
            tables_template[table["table_id"]] = table
#             if cnt == 5: break
            cnt += 1
            
    # STEP 1: "retrieval"  for add / replacement
    print('STEP 1: Dense retrieval for add / replacement...\n')
    for tab in tqdm(tables, position=0):
        tid = tab["table_id"]
        queries = table2qs.get(tid, None)
        if queries is not None:
            queries = queries[:10] if len(queries) > 10 else queries
            ori_cols, expanded_cols = retriver(query_table=tab, queries=queries, topk_tables=topk_tables)
            rec_dic = reranker(ori_cols, expanded_cols, topk=max_cands_per_col) # We will find analog & synonyms in this list
            tab["rpls_retrieval"] = trim_retrieval_results(rec_dic)
            
                
    # STEP 2 : synonym dict for replacement
    print('STEP 2 : synonym dict for replacement...\n')
    for tab in tqdm(tables, position=1):
        if isinstance(tab["column_types"], list):
            col2type = {c:t for c,t in zip(tab["column_names"], tab["column_types"])}
        else:
            col2type = tab["column_types"]
        for col in tab["column_names"]:
            col_type = col2type.get(col, "text")
            if _ends_with_id(col): continue
            tokens = [w.lower() for w in col.split(delim)]
            keep_original, normalized_tokens = [], []
            tok2syn = {}
            for tok in tokens:
                can_rpl, tok = consider_rpl(tok)
                normalized_tokens.append(tok)
                syn_dic = synonym_dic.get(tok, None)
                if syn_dic is None: syn_dic = synonym_dic.get(get_singular_word(tok), None)
                keep_ori = (not can_rpl) or (syn_dic is None)  # skip eihter because not replaceable or not in dic
                keep_original.append(keep_ori)
                if keep_ori == True: continue
                rec_dic = reranker(tgt_names=[tok], cand_names=list(set(syn_dic["synonyms"])), topk=10)
                tok2syn.update(rec_dic)
            syn_rpl_candidates = set() # Genereate syn-replaced candidates
            patience = 5 # if 5 in steps there is no new candidate added, break the loop.
            
            while True:
                if len(syn_rpl_candidates) >= max_cands_per_col or patience == 0:
                    syn_rpl_candidates = list(syn_rpl_candidates.difference(set([" ".join(normalized_tokens)])))
                    break
                rpl_threshold = 1 if len(tokens) == 1 else (0.75 if len(tokens) == 2 else 0.5)
                do_rpl_coins = np.random.rand(len(tokens)) <= rpl_threshold  # only keep original for 20% of time
                new_cand = []
                for i, tok in enumerate(normalized_tokens):
                    if not keep_original[i] and do_rpl_coins[i]:
                        all_syns = tok2syn.get(tok, [tok])
                        syn = np.random.choice(all_syns)
                        new_cand.append(syn)
                    else:
                        new_cand.append(tok)
                new_cand = delim.join(new_cand)
                len_before = len(syn_rpl_candidates)
                syn_rpl_candidates.add(new_cand)
                if len(syn_rpl_candidates) > len_before:
                    patience = 5
                else:
                    patience -= 1
            tab["rpls_syndict"].update({col : syn_rpl_candidates})
    
    
    # #STEP 3: filter syn dict replacement with NLI
    print('STEP 3: filter syn dict replacement with NLI...\n')
    STRICT_MODE = True
    constructed_pairs_rpl_syndict = construct_pairs_for_nli_test(tables, pending_rpls_key="rpls_syndict",
                                                     table_name_key=table_name_key,
                                                     col_type_key=col_type_key,
                                                     col_name_key=col_name_key)
    results_rpl_syndict = nli_test_across_tables(constructed_pairs_rpl_syndict, batch_size=batch_size)
    for i, dic in enumerate(results_rpl_syndict):
        table_id = constructed_pairs_rpl_syndict[i]["table_id"]
        table = tables_template[table_id]
        ent = dic["scores"][2]
        if ent >= replace_threshold:
            ori, rpl = dic["key_map"]
#             print(f"{ori} -> {rpl}")
            table["REPLACE"][ori].append(rpl)  # update REPLACE key
#     print(results_rpl_syndict)
    
    
    # STEP 4: filter retrieval replacement with NLI
    print('STEP 4: filter retrieval replacement with NLI...\n')
    STRICT_MODE = True
    constructed_pairs_rpl_retrieval = construct_pairs_for_nli_test(tables, pending_rpls_key="rpls_retrieval",
                                                     table_name_key=table_name_key,
                                                     col_type_key=col_type_key,
                                                     col_name_key=col_name_key)
    results_rpl_retrieval = nli_test_across_tables(constructed_pairs_rpl_retrieval, batch_size=batch_size)
    for i, dic in enumerate(results_rpl_retrieval):
        table_id = constructed_pairs_rpl_retrieval[i]["table_id"]
        table = tables_template[table_id]
        ent = dic["scores"][2]
        if ent >= replace_threshold:
            ori, rpl = dic["key_map"]
            table["REPLACE"][ori].append(rpl)  # update REPLACE key

    
    
    # STEP 5: prune all replaceable from retrieval results & filter substring overlap from original col
    print('STEP 5: prune all replaceable from retrieval results & filter substring overlap from original col...\n')
    for table in tables:
        tid = table["table_id"]
        all_columns = table[col_name_key]
        rpl_dict = tables_template[tid]["REPLACE"]
        for rpl_col in table["rpls_retrieval"].keys():
            rpl_candidates = table["rpls_retrieval"][rpl_col]
            add_candidates = [] # ADD operation candidates comes from here
            for rpl in rpl_candidates:
                if any([rpl in c for c in all_columns]) or any([rpl in c for c in rpl_dict[rpl_col]]):
                    continue
                add_candidates.append(rpl)
            table["rpls_retrieval"][rpl_col] = add_candidates
#     print(tables[8]["rpls_retrieval"])
    
    
    
    # STEP 6 :  filter leftover retrieval ADD candidates with NLI
    print('STEP 6 :  filter leftover retrieval ADD candidates with NLI\n')
    STRICT_MODE = False
    constructed_pairs_add_retrieval = construct_pairs_for_nli_test(tables, pending_rpls_key="rpls_retrieval",
                                                     table_name_key=table_name_key,
                                                     col_type_key=col_type_key,
                                                     col_name_key=col_name_key)
    results_add_retrieval = nli_test_across_tables(constructed_pairs_add_retrieval, batch_size=batch_size)
    for i, dic in enumerate(results_add_retrieval):
        table_id = constructed_pairs_add_retrieval[i]["table_id"]
        table = tables_template[table_id]
        ent = dic["scores"][2]
        if ent <= add_threshold:
            ori, rpl = dic["key_map"]
            table["ADD"][ori].append(rpl)  # update ADD key
#     print(tables_template["SPIDER_8"]["ADD"])
    
    # STEP 7: Write replace + add results to new file
    print('STEP 7: Write REPLACE & ADD results to new file \n')
    with open(f"{output_dir}/{output_prefix}-pipeline-output.jsonl", "w") as f:
        for table in tables_template.values():
            json.dump(table, f)
            f.write("\n")

# Leave For Running

In [11]:
# replace_and_add_for_give_tables("./data/wikisql/wikisql.tables.train.refer.values-bm.jsonl",
#                          table2qs=wsql_train_table2qs,
#                          output_prefix="wsql-train",
#                          table_name_key="pred_table_name",
#                          replace_threshold=0.70,
#                          batch_size=512)

In [12]:
replace_and_add_for_give_tables("./data/spider/spider-tables-synonym.values-bm.jsonl",
                         table2qs=spider_table2qs,
                         output_prefix="spider",
                         table_name_key="table_name",
                         replace_threshold=0.50,
                         batch_size=512)

STEP 0 : Prepare tables...

STEP 1: Dense retrieval for add / replacement...



  name_emb /= len(units)
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████| 876/876 [01:54<00:00,  7.67it/s]


STEP 2 : synonym dict for replacement...




  0%|                                                                                                                       | 0/876 [00:00<?, ?it/s][A
  3%|███▎                                                                                                         | 27/876 [00:00<00:03, 260.47it/s][A
  8%|████████▋                                                                                                    | 70/876 [00:00<00:02, 352.14it/s][A
 13%|█████████████▌                                                                                              | 110/876 [00:00<00:02, 367.65it/s][A
 17%|██████████████████                                                                                          | 147/876 [00:00<00:02, 350.89it/s][A
 21%|██████████████████████▌                                                                                     | 183/876 [00:00<00:02, 342.37it/s][A
 25%|██████████████████████████▉                                                       

STEP 3: filter syn dict replacement with NLI...



100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████| 876/876 [00:01<00:00, 504.77it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████| 36/36 [01:22<00:00,  2.29s/it]


STEP 4: filter retrieval replacement with NLI...



100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████| 876/876 [00:02<00:00, 324.90it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████| 56/56 [02:11<00:00,  2.35s/it]


STEP 5: prune all replaceable from retrieval results & filter substring overlap from original col...

STEP 6 :  filter leftover retrieval ADD candidates with NLI



100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████| 876/876 [00:02<00:00, 387.41it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████| 46/46 [02:05<00:00,  2.72s/it]


STEP 7: Write REPLACE & ADD results to new file 



In [13]:
replace_and_add_for_give_tables("./data/WTQ/wtq.tables.values-bm.jsonl",
                         table2qs=wtq_table2qs,
                         output_prefix="wtq",
                         table_name_key="pred_table_name",
                         replace_threshold=0.50,
                         batch_size=512)

STEP 0 : Prepare tables...

STEP 1: Dense retrieval for add / replacement...



  name_emb /= len(units)
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████| 2108/2108 [02:37<00:00, 13.37it/s]


STEP 2 : synonym dict for replacement...




  0%|                                                                                                                      | 0/2108 [00:00<?, ?it/s][A
  1%|█▏                                                                                                          | 23/2108 [00:00<00:09, 214.38it/s][A
  2%|██▎                                                                                                         | 45/2108 [00:00<00:09, 215.84it/s][A
  3%|███▋                                                                                                        | 71/2108 [00:00<00:08, 233.10it/s][A
  5%|████▊                                                                                                       | 95/2108 [00:00<00:08, 230.06it/s][A
  6%|██████▏                                                                                                    | 121/2108 [00:00<00:08, 237.68it/s][A
  7%|███████▋                                                                          

STEP 3: filter syn dict replacement with NLI...



100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████| 2108/2108 [00:04<00:00, 460.07it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████| 108/108 [03:42<00:00,  2.06s/it]


STEP 4: filter retrieval replacement with NLI...



100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████| 2108/2108 [00:05<00:00, 388.40it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████| 120/120 [04:15<00:00,  2.13s/it]


STEP 5: prune all replaceable from retrieval results & filter substring overlap from original col...

STEP 6 :  filter leftover retrieval ADD candidates with NLI



100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████| 2108/2108 [00:04<00:00, 466.42it/s]
100%|███████████████████████████████████████████| 96/96 [03:27<00:00,  2.16s/it]


STEP 7: Write REPLACE & ADD results to new file 



In [14]:
replace_and_add_for_give_tables("./data/wikisql/wikisql.tables.dev.test.refer.values-bm.jsonl",
                         table2qs=wsql_table2qs,
                         output_prefix="wikisql",
                         table_name_key="pred_table_name",
                         replace_threshold=0.50,
                         batch_size=512)

STEP 0 : Prepare tables...

STEP 1: Dense retrieval for add / replacement...



  name_emb /= len(units)
100%|███████████████████████████████████████| 7946/7946 [10:46<00:00, 12.29it/s]


STEP 2 : synonym dict for replacement...




  0%|                                                  | 0/7946 [00:00<?, ?it/s][A
  0%|                                        | 11/7946 [00:00<01:13, 108.23it/s][A
  0%|▏                                       | 25/7946 [00:00<01:03, 124.48it/s][A
  1%|▏                                       | 44/7946 [00:00<00:51, 152.86it/s][A
  1%|▎                                       | 63/7946 [00:00<00:47, 166.87it/s][A
  1%|▍                                       | 83/7946 [00:00<00:44, 176.47it/s][A
  1%|▌                                      | 105/7946 [00:00<00:41, 188.95it/s][A
  2%|▌                                      | 124/7946 [00:00<00:41, 188.24it/s][A
  2%|▋                                      | 145/7946 [00:00<00:40, 192.70it/s][A
  2%|▊                                      | 165/7946 [00:00<00:44, 176.14it/s][A
  2%|▉                                      | 183/7946 [00:01<00:44, 175.73it/s][A
  3%|▉                                      | 201/7946 [00:01<00:44, 172.53

STEP 3: filter syn dict replacement with NLI...



100%|██████████████████████████████████████| 7946/7946 [00:19<00:00, 409.96it/s]
100%|█████████████████████████████████████████| 440/440 [15:41<00:00,  2.14s/it]


STEP 4: filter retrieval replacement with NLI...



100%|██████████████████████████████████████| 7946/7946 [00:26<00:00, 295.05it/s]
100%|█████████████████████████████████████████| 564/564 [21:13<00:00,  2.26s/it]


STEP 5: prune all replaceable from retrieval results & filter substring overlap from original col...

STEP 6 :  filter leftover retrieval ADD candidates with NLI



100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████| 7946/7946 [00:22<00:00, 349.85it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████| 454/454 [17:39<00:00,  2.33s/it]


STEP 7: Write REPLACE & ADD results to new file 



In [15]:
replace_and_add_for_give_tables("./data/wikisql/wikisql.tables.train.refer.values-bm.jsonl",
                         table2qs=wsql_train_table2qs,
                         output_prefix="wikisql-train",
                         table_name_key="pred_table_name",
                         replace_threshold=0.50,
                         batch_size=256)

STEP 0 : Prepare tables...

STEP 1: Dense retrieval for add / replacement...



  name_emb /= len(units)
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████| 7864/7864 [14:12<00:00,  9.23it/s]


STEP 2 : synonym dict for replacement...




  0%|                                                                                                                      | 0/7864 [00:00<?, ?it/s][A
  0%|▏                                                                                                           | 12/7864 [00:00<01:08, 114.43it/s][A
  0%|▎                                                                                                           | 25/7864 [00:00<01:04, 121.81it/s][A
  1%|▋                                                                                                           | 47/7864 [00:00<00:47, 163.24it/s][A
  1%|▉                                                                                                           | 64/7864 [00:00<00:47, 165.17it/s][A
  1%|█▏                                                                                                          | 83/7864 [00:00<00:44, 173.02it/s][A
  1%|█▎                                                                                

STEP 3: filter syn dict replacement with NLI...



100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████| 7864/7864 [00:19<00:00, 401.76it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████| 820/820 [15:58<00:00,  1.17s/it]


STEP 4: filter retrieval replacement with NLI...



100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████| 7864/7864 [00:28<00:00, 278.62it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████| 1114/1114 [22:31<00:00,  1.21s/it]


STEP 5: prune all replaceable from retrieval results & filter substring overlap from original col...

STEP 6 :  filter leftover retrieval ADD candidates with NLI



100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████| 7864/7864 [00:23<00:00, 339.55it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████| 907/907 [18:30<00:00,  1.22s/it]


STEP 7: Write REPLACE & ADD results to new file 

