In [1]:
import os
import random

import numpy as np
import pandas as pd

# --- Configuration -----------------------------------------------------------
CSV_PATH = 'data/WikiTableQuestions/'  # NOTE(review): not referenced by any visible cell
KEY = 'WikiSQL'      # benchmark to perturb: 'WikiTableQuestions' or 'WikiSQL'
DATASET = 'dev'      # split: 'train', 'test', or any name containing 'dev'.
                     # For WikiTableQuestions the LAST character selects the random-split
                     # file number, e.g. 'dev1' -> random-split-1-dev.tsv.
SEED = 0             # seeds `random` so the row swaps made by perturb_table are reproducible
random.seed(SEED)

# --- Load the question file for the chosen benchmark/split ----------------------
# Normalise every 'dev*' variant to 'dev'; anything else must be 'train' or 'test'.
if 'dev' in DATASET:
    split_name = 'dev'
elif DATASET in ('train', 'test'):
    split_name = DATASET
else:
    raise ValueError(f"Dataset should be in ['train','test','dev'], got {DATASET!r}")

if KEY == 'WikiTableQuestions':
    if split_name == 'dev' and not DATASET[-1].isdigit():
        # Without a trailing digit the path would become e.g. 'random-split-v-dev.tsv'.
        raise ValueError(f"WikiTableQuestions dev split needs a trailing split number (e.g. 'dev1'), got {DATASET!r}")
    wtq_files = {
        'train': 'WikiTableQA/data/training.tsv',
        'test': 'WikiTableQA/data/pristine-unseen-tables.tsv',
        'dev': f'WikiTableQA/data/random-split-{DATASET[-1]}-dev.tsv',
    }
    iter_item = pd.read_csv(wtq_files[split_name], delimiter='\t')
elif KEY == 'WikiSQL':
    iter_item = pd.read_csv(f'WikiSQL/annotated/{split_name}.csv')
else:
    # Fail here instead of with a NameError on `iter_item` in a later cell.
    raise ValueError(f"KEY should be in ['WikiTableQuestions','WikiSQL'], got {KEY!r}")


# Stop-word list, one word per line.
# NOTE(review): not referenced by any live code in the visible cells.
with open('stopwords.txt', encoding='utf-8') as f:
    stopwords = [line.strip() for line in f]

In [2]:
def perturb_table(df, target, question, rng=None):
    """Move the answer cell to another row so the table no longer supports the old answer.

    The first cell whose lower-cased string form equals ``str(target).lower()`` is
    swapped with the cell in the same column of a randomly chosen row that contains
    no matching cell at all.

    Args:
        df: table to perturb. It is MODIFIED IN PLACE and also returned.
        target: the gold answer value; it is compared as a case-insensitive string.
        question: the question text. It is currently unused; an earlier keyword-based
            fallback that matched question words against cells was disabled. The
            parameter is kept for interface compatibility.
        rng: object with a ``choice`` method (e.g. ``random.Random(seed)``).
            Defaults to the module-level ``random`` generator.

    Returns:
        ``(df, new_answer)``, where ``new_answer`` is the value now sitting in the
        original answer position. If no row matches, or every row matches (so there
        is no row to swap with), returns the sentinel strings ``('None', 'None')``.
    """
    rng = random if rng is None else rng
    target_lower = str(target).lower()

    def lower_cell(x):
        return str(x).lower()

    # DataFrame.applymap was renamed DataFrame.map in pandas 2.1, and applymap now warns.
    # The check is done on the class, so that on older pandas a column named "map"
    # cannot shadow the lookup.
    if hasattr(pd.DataFrame, 'map'):
        lowered = df.map(lower_cell)
    else:
        lowered = df.applymap(lower_cell)
    match_row_idx, match_col_idx = np.where(lowered == target_lower)

    matched_rows = set(match_row_idx)
    if len(matched_rows) in (0, df.shape[0]):
        return 'None', 'None'

    row, col = match_row_idx[0], match_col_idx[0]
    # Only rows containing no match at all are eligible, so the swapped-in value differs from the target.
    other_rows = sorted(set(range(df.shape[0])) - matched_rows)
    new_row = rng.choice(other_rows)
    df.iloc[row, col], df.iloc[new_row, col] = df.iloc[new_row, col], df.iloc[row, col]
    return df, df.iloc[row, col]

In [3]:
def get_item_table(row):
    """Load the WikiTableQuestions table, question and gold answer for one question row.

    ``row['context']`` is a path relative to ``WikiTableQuestions/`` ending in a
    4-character extension (presumably ``.csv``). The tab-separated file with the same
    stem and a ``.tsv`` extension is tried first. If reading it fails for any reason,
    ``row['context']`` itself is read as a comma-separated file.

    Returns:
        ``(table DataFrame, row['utterance'], row['targetValue'])``.
    """
    stem = row['context'][:-4]
    try:
        pd_table = pd.read_csv(f"WikiTableQuestions/{stem}.tsv", delimiter='\t')
    except Exception:
        # Best-effort fallback for a missing or unparsable .tsv. Catching Exception rather
        # than using a bare `except:` still lets KeyboardInterrupt/SystemExit propagate.
        pd_table = pd.read_csv(f"WikiTableQuestions/{row['context']}")
    question = row['utterance']
    answer = row['targetValue']
    return pd_table, question, answer

def get_sql_table(row, DATASET):
    """Load the WikiSQL table, question and answer for one annotated question row.

    Every split name containing 'dev' is mapped onto the single 'dev' table directory.
    The table is read from ``WikiSQL/csv/<split>/<row['table_id']>.csv``.

    Returns:
        ``(table DataFrame, row['question'], row['answer'])``.
    """
    split_dir = 'dev' if 'dev' in DATASET else DATASET
    table = pd.read_csv(f"WikiSQL/csv/{split_dir}/{row['table_id']}.csv")
    return table, row['question'], row['answer']


In [4]:
# Perturb the table of every question, write each perturbed table to disk, and record
# the new answer plus a 0/1 flag per question. Values are collected in lists and
# assigned once: per-row `.loc` creation of a new column can upcast mixed str/int
# answers (e.g. 3 -> 3.0) and emits pandas incompatible-dtype warnings.
perturbed_answers = []   # new target value per question; None when no swap was possible
perturb_flags = []       # 1 if a perturbed table was written for the question, else 0

for idx, row in iter_item.iterrows():
    if KEY == 'WikiTableQuestions':
        pd_table, question, answer = get_item_table(row)
        # WikiTableQuestions rows have no 'table_id' column; mirror the table's context path instead.
        out_path = f'{KEY}/csv_perturbed/{DATASET}/{row["context"]}'
    else:
        pd_table, question, answer = get_sql_table(row, DATASET)
        # NOTE(review): questions sharing a table_id overwrite this file; only the last
        # question's perturbed table survives, so earlier targetValue_new values may not match it.
        out_path = f'{KEY}/csv_perturbed/{DATASET}/{row["table_id"]}.csv'
    # Answers may be stored as a stringified one-element list such as "['x']"; peel that off.
    answer = str(answer).strip('[]').strip("'")

    new_table, new_answer = perturb_table(pd_table, answer, question)

    if isinstance(new_table, str):   # perturb_table returns the string 'None' when it cannot swap
        perturbed_answers.append(None)
        perturb_flags.append(0)
    else:
        os.makedirs(os.path.dirname(out_path), exist_ok=True)
        new_table.to_csv(out_path, index=False)
        perturbed_answers.append(new_answer)
        perturb_flags.append(1)

iter_item['targetValue_new'] = perturbed_answers
iter_item['perturb_flag'] = perturb_flags

In [5]:
# Keep only questions whose table was perturbed and that received a new answer,
# then write them out as the seq2seq input file for this split.
keep_mask = (iter_item['perturb_flag'] == 1) & iter_item['targetValue_new'].notna()
perturbed_rows = iter_item.loc[keep_mask].reset_index(drop=True)
perturbed_rows.to_csv(f'{KEY}/seq2seq_data/{DATASET}_perturb.csv', index=False)