# Predefined

In [None]:
from __future__ import absolute_import, division, print_function

import argparse
import glob
import logging
import os
import pickle
import random
import re
import shutil

import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset, SequentialSampler, RandomSampler
from torch.utils.data.distributed import DistributedSampler

try:
    from torch.utils.tensorboard import SummaryWriter
except:
    from tensorboardX import SummaryWriter

from tqdm import trange
from tqdm.autonotebook import tqdm

from data_loader.hybrid_data_loaders import *
from data_loader.header_data_loaders import *
from data_loader.CT_Wiki_data_loaders import *
from data_loader.RE_data_loaders import *
from data_loader.EL_data_loaders import *
from model.configuration import TableConfig
from model.model import HybridTableMaskedLM, HybridTableCER, TableHeaderRanking, HybridTableCT,HybridTableEL,HybridTableRE,BertRE
from model.transformers import BertConfig,BertTokenizer, WEIGHTS_NAME, AdamW, get_linear_schedule_with_warmup
from utils.util import *
from baselines.row_population.metric import average_precision,ndcg_at_k
from baselines.cell_filling.cell_filling import *
from model import metric

logger = logging.getLogger(__name__)

# Expose only GPU 0 (GPUs ordered by PCI bus id). These variables are only
# honoured if they are set before CUDA is initialised in this process, so set
# them before anything else touches the device.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# Task name -> (config class, model class, tokenizer class).
MODEL_CLASSES = {
    'CER': (TableConfig, HybridTableCER, BertTokenizer),
    'CF' : (TableConfig, HybridTableMaskedLM, BertTokenizer),
    'HR': (TableConfig, TableHeaderRanking, BertTokenizer),
    'CT': (TableConfig, HybridTableCT, BertTokenizer),
    'EL': (TableConfig, HybridTableEL, BertTokenizer),
    'RE': (TableConfig, HybridTableRE, BertTokenizer),
    'REBERT': (BertConfig, BertRE, BertTokenizer)
}

# Data directory; the vocabularies and the test data are loaded from here.
data_dir = r"G:\CPSC448\TURL\data\wikitables_v2"

config_name = "configs/table-base-config_v2.json"
device = torch.device('cuda')

# Entity vocabulary (entity_vocab.txt). It is loaded once; loading it a second
# time with the same arguments would only repeat an expensive file read.
entity_vocab = load_entity_vocab(data_dir, ignore_bad_title=True, min_ent_count=2)
# Maps each entry's 'wiki_id' back to its key in entity_vocab.
entity_wikid2id = {entity_vocab[x]['wiki_id']: x for x in entity_vocab}

# Type vocabulary (type name -> index) and its inverse (index -> type name).
type_vocab = load_type_vocab(data_dir)
id2type = {idx: t for t, idx in type_vocab.items()}

# Type indices whose scores predict() overwrites with -1000; empty here, so no
# type is suppressed.
t2d_invalid = set()

def average_precision(output, relevance_labels):
    """Average precision of a ranking by score, computed along the last dim.

    Items are ranked by ``output`` (highest first) and the matching entries of
    ``relevance_labels`` are read in that order. Precision@k is accumulated
    only at ranks holding a relevant item and divided by the number of
    relevant items. A row with no relevant items gets 0 (its divisor is
    replaced by 1). Runs under ``torch.no_grad()``.

    Note: this definition shadows the ``average_precision`` imported from
    ``baselines.row_population.metric``.
    """
    with torch.no_grad():
        ranking = output.argsort(dim=-1, descending=True)
        hits = relevance_labels.gather(-1, ranking).float()
        n_ranks = hits.shape[-1]
        rank_positions = torch.arange(1, n_ranks + 1, device=hits.device).unsqueeze(0)
        # precision@k, kept only at the ranks that are hits
        precision_at_hits = hits.cumsum(dim=-1) / rank_positions * hits
        n_relevant = hits.sum(dim=-1)
        n_relevant = torch.where(n_relevant == 0, torch.ones_like(n_relevant), n_relevant)
        result = precision_at_hits.sum(dim=-1) / n_relevant

    return result

In [None]:
# `json` was previously only available through a star import; import it explicitly.
import json

# Cached, preprocessed WikiCT test split. wrappedPredict deletes this file so the
# dataset is rebuilt from TEST_JSON. The "procressed" spelling is kept as-is;
# presumably it matches the directory the dataset loader writes (confirm).
DATASET_PATH = os.path.join(data_dir, "procressed_WikiCT", "test.pickle")

# One fine-tuned checkpoint per sub-directory 0..5.
CHECKPOINT_DIR = r"G:\CPSC448\TURL\data\pre-trained_models\checkpoints"
CHECKPOINTS = [os.path.join(CHECKPOINT_DIR, str(i), "pytorch_model.bin") for i in range(6)]

# TEST_JSON_ALL is the full test set. TEST_JSON is the file wrappedPredict
# overwrites with the tables to predict on.
TEST_JSON_ALL = os.path.join(data_dir, "test.all_table_col_type.json")
TEST_JSON = os.path.join(data_dir, "test.table_col_type.json")

# Decode explicitly as UTF-8: the default locale encoding (e.g. cp1252 on
# Windows) mis-decodes or rejects non-ASCII table text.
with open(TEST_JSON_ALL, 'r', encoding='utf-8') as f:
    ALL_TABLES = json.load(f)
with open(TEST_JSON, 'r', encoding='utf-8') as f:
    TEST_TABLES = json.load(f)

In [None]:
# JSON helpers: read one table from a JSON file, or write a list of tables to one
def readTable(json_path, table_num):
    """Load the JSON document at *json_path* and return ``document[table_num]``.

    The file is decoded as UTF-8 explicitly. Without that, ``open()`` uses the
    locale encoding (e.g. cp1252 on Windows), which mis-decodes or rejects
    non-ASCII table text.
    """
    with open(json_path, 'r', encoding='utf-8') as f:
        return json.load(f)[table_num]

def writeJson(json_path, tables):
    """Serialise *tables* as JSON into *json_path*, overwriting any existing file."""
    with open(json_path, mode='w') as out_file:
        json.dump(tables, out_file)

# Get the Logits

### Predefined

In [None]:
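# Run the column-type model loaded from `checkpoint` on `test_dataset` and collect its logits.
# `mode` selects which inputs the model receives (e.g. mode 4 drops the token inputs and the entity ids).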
def predict(test_dataset, checkpoint, mode):
    """Run the column-type (CT) model from *checkpoint* over *test_dataset*.

    Args:
        test_dataset: dataset consumed by ``CTLoader`` (batched 20 at a time,
            in sequential order).
        checkpoint: path to a state dict loadable with ``torch.load``.
        mode: input-ablation mode, also stored as ``config.mode``:
            1 drops the token inputs (``input_tok*``) and trims the entity mask;
            2 drops every entity input and trims the token mask;
            3 drops the entity ids (``input_ent``);
            4 = mode 1 plus dropping ``input_ent``;
            5 = mode 1 plus dropping the entity text inputs;
            any other value feeds every input unchanged.

    Returns:
        ``(logits, per_table_result)``:
        ``logits`` is a list with one prediction-score tensor per batch, in
        loader order. ``per_table_result`` is
        ``{mode: {table_id: [score_slices, labels_mask, labels]}}``, where
        ``score_slices`` gets one score tensor per occurrence of ``table_id``.
        All three entries are cropped to the table's last labelled column.
        Tables with no labelled column are skipped.
    """
    print(f"Mode: {mode}")
    config_class, model_class, _ = MODEL_CLASSES['CT']
    config = config_class.from_pretrained(config_name)
    config.class_num = len(type_vocab)
    config.mode = mode
    model = model_class(config, is_simple=True)
    # map_location puts the weights on `device` regardless of where they were saved.
    state_dict = torch.load(checkpoint, map_location=device)
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()

    eval_batch_size = 20
    eval_sampler = SequentialSampler(test_dataset)
    eval_dataloader = CTLoader(test_dataset, sampler=eval_sampler, batch_size=eval_batch_size, is_train=False)

    per_table_result = {mode: {}}
    mode_results = per_table_result[mode]
    logits = []
    with torch.no_grad():
        for batch in tqdm(eval_dataloader, desc="Evaluating"):
            table_ids, *tensors = batch
            (input_tok, input_tok_type, input_tok_pos, input_tok_mask,
             input_ent_text, input_ent_text_length, input_ent, input_ent_type, input_ent_mask,
             column_entity_mask, column_header_mask, labels_mask, labels) = [t.to(device) for t in tensors]

            if mode in (1, 4, 5):
                # Drop the first input_tok_mask.shape[1] columns of the entity mask
                # (presumably the token positions). This must run before
                # input_tok_mask is cleared on the next line.
                input_ent_mask = input_ent_mask[:, :, input_tok_mask.shape[1]:]
                input_tok = input_tok_type = input_tok_pos = input_tok_mask = None
            if mode == 2:
                input_tok_mask = input_tok_mask[:, :, :input_tok_mask.shape[1]]
                input_ent_text = input_ent_text_length = input_ent = input_ent_type = input_ent_mask = None
            if mode in (3, 4):
                input_ent = None
            if mode == 5:
                input_ent_text = input_ent_text_length = None

            outputs = model(input_tok, input_tok_type, input_tok_pos, input_tok_mask,
                            input_ent_text, input_ent_text_length, input_ent, input_ent_type, input_ent_mask,
                            column_entity_mask, column_header_mask, labels_mask, labels)
            prediction_scores = outputs[1]
            # Force excluded type ids to a very low score so they never rank first.
            for l_i in t2d_invalid:
                prediction_scores[:, :, l_i] = -1000

            for idx, table_id in enumerate(table_ids):
                labelled_cols = labels_mask[idx].nonzero()
                if labelled_cols.numel() == 0:
                    # No labelled column: nothing to score, and .max() of an
                    # empty tensor would raise.
                    continue
                valid = labelled_cols.max().item() + 1
                if table_id not in mode_results:
                    mode_results[table_id] = [[], labels_mask[idx, :valid], labels[idx, :valid]]
                mode_results[table_id][0].append(prediction_scores[idx, :valid])
            logits.append(prediction_scores)
    return logits, per_table_result

# For each table: the number of entity cells in its longest column of table[6]
# (0 when the table has no columns or only empty ones).
LENGTH_ALL_TABLES = [
    max((len(column) for column in table_entry[6]), default=0)
    for table_entry in ALL_TABLES
]

In [None]:
def wrappedPredict(tables, checkpoint=None, mode=4):
    """Predict column types for *tables* by rebuilding the WikiCT test split from them.

    Writes *tables* to TEST_JSON and deletes the cached dataset pickle at
    DATASET_PATH. It then builds a new ``WikiCTDataset`` for ``src="test"``
    and runs ``predict`` on it.

    Args:
        tables: list of tables in the same format as ALL_TABLES.
        checkpoint: weights file to load; ``None`` (default) means
            ``CHECKPOINTS[4]``.
        mode: input-ablation mode passed to ``predict`` (default 4, matching
            the default checkpoint).

    Returns:
        The ``(logits, per_table_result)`` pair returned by ``predict``.
    """
    if checkpoint is None:
        checkpoint = CHECKPOINTS[4]

    writeJson(TEST_JSON, tables)
    # Remove the cached pickle so the dataset is rebuilt from the JSON just
    # written (presumably WikiCTDataset reuses the pickle when present — confirm).
    if os.path.exists(DATASET_PATH):
        os.remove(DATASET_PATH)
    test_dataset = WikiCTDataset(data_dir, entity_vocab, type_vocab, max_input_tok=500, src="test", max_length=[50, 10, 10], force_new=False, tokenizer=None)

    return predict(test_dataset, checkpoint, mode)

# Clean (unperturbed) run. per_table_result is kept too: the evaluation cell
# below reads it, and would otherwise hit a NameError on a fresh top-to-bottom run.
LOGITS, per_table_result = wrappedPredict(ALL_TABLES)

#### Evaluation

In [None]:
# Column-type accuracy for the mode-4 results in per_table_result. Each
# column's scores are averaged over the table's occurrences. The column scores
# label[col, best_type], i.e. 1 when the top-scoring type is a gold type if
# labels are 0/1 multi-hot (presumably). Tables with any miss go into `errors`.
total_corr = 0
total_valid = 0
errors = []
for table_id, (score_list, label_mask, label) in per_table_result[4].items():
    mean_scores = torch.stack(score_list, 0).mean(0)
    best_types = mean_scores.argmax(-1).tolist()
    current_corr = 0
    for col_idx, type_idx in enumerate(best_types):
        current_corr += label[col_idx, type_idx].item()
    n_labelled = label_mask.sum().item()
    total_valid += n_labelled
    total_corr += current_corr
    if current_corr != n_labelled:
        errors.append(table_id)
print(total_corr/total_valid, total_valid)

## Adversarial attacks

In [None]:
import copy

def extractColumn(file_path, column_index):
    """Return field *column_index* of every tab-separated line of a UTF-8 file.

    Each line is whitespace-stripped at both ends before splitting on tabs.
    Lines with too few fields are skipped, so results taken for different
    column indices can have different lengths.
    """
    with open(file_path, 'r', encoding='utf-8') as handle:
        split_lines = (raw.strip().split('\t') for raw in handle)
        return [fields[column_index] for fields in split_lines if len(fields) > column_index]

def switchEntities(tables, percentage):
    """Return a deep copy of *tables* with a random fraction of entity cells replaced.

    For each table, ``int(n_cells * percentage)`` cells of ``table[6]`` are
    sampled without replacement across all columns. Each sampled cell's ``[1]``
    entry becomes ``[int(col1), col2]``, taken from one randomly chosen line of
    ``entity_vocab.txt``. Column 1 is presumably the wiki id and column 2 the
    title. The input tables are not modified.

    Raises:
        ValueError: if the vocabulary's column-1 and column-2 lists differ in
            length (ids and titles would be misaligned), if the vocabulary is
            empty while cells must be replaced, or if ``percentage`` is outside
            [0, 1].
    """
    vocab_path = os.path.join(data_dir, "entity_vocab.txt")
    entity_list = extractColumn(vocab_path, 2)
    entity_id_list = extractColumn(vocab_path, 1)
    # extractColumn skips short lines independently per column. A line missing
    # column 2 would shift every later title relative to its id.
    if len(entity_list) != len(entity_id_list):
        raise ValueError(
            f"entity_vocab.txt columns 1 and 2 have different lengths "
            f"({len(entity_id_list)} vs {len(entity_list)}); ids and titles would be misaligned"
        )
    tables_copy = copy.deepcopy(tables)

    for table_copy in tables_copy:
        columns = table_copy[6]
        # (position within the column, column index) of every entity cell
        entity_position = [
            (row_index, col_index)
            for col_index, column in enumerate(columns)
            for row_index in range(len(column))
        ]
        random_entity_positions = random.sample(entity_position, int(len(entity_position) * percentage))

        for row_index, col_index in random_entity_positions:
            # randrange(n) yields 0..n-1, always a valid index; randint's upper
            # bound is inclusive and could produce len(entity_list).
            rand_num = random.randrange(len(entity_list))
            columns[col_index][row_index][1] = [int(entity_id_list[rand_num]), entity_list[rand_num]]

    return tables_copy

def maskEntities(tables, percentage):
    """Return a deep copy of *tables* with a random fraction of entity texts masked.

    For each table, ``int(n_cells * percentage)`` cells of ``table[6]`` are
    sampled without replacement across all columns. Each sampled cell's
    ``cell[1][1]`` is set to ``'ENT_MASK'``; ``cell[1][0]`` (presumably the
    entity id) is left untouched. The input tables are not modified.

    Unlike switchEntities no replacement entities are needed, so the entity
    vocabulary file is not read.

    Raises:
        ValueError: (from ``random.sample``) if ``percentage`` is outside [0, 1].
    """
    tables_copy = copy.deepcopy(tables)

    for table_copy in tables_copy:
        columns = table_copy[6]
        # (position within the column, column index) of every entity cell
        entity_position = [
            (row_index, col_index)
            for col_index, column in enumerate(columns)
            for row_index in range(len(column))
        ]
        for row_index, col_index in random.sample(entity_position, int(len(entity_position) * percentage)):
            columns[col_index][row_index][1][1] = 'ENT_MASK'

    return tables_copy

In [None]:
# Adversarial run: swap 80% of the entity cells of every table for random
# vocabulary entities, then predict again on the perturbed tables.
switched_tables = switchEntities(tables=ALL_TABLES, percentage=0.8)
logits, per_table_result = wrappedPredict(tables=switched_tables)

In [None]:
# Column-type accuracy for the current per_table_result (the entity-switched
# run): same metric as the clean-run evaluation, scores averaged per table.
total_corr = 0
total_valid = 0
errors = []
for table_id, result in per_table_result[4].items():
    score_list, label_mask, label = result
    top_types = torch.stack(score_list, 0).mean(0).argmax(-1).tolist()
    current_corr = sum(label[c, t].item() for c, t in enumerate(top_types))
    labelled = label_mask.sum().item()
    total_valid += labelled
    total_corr += current_corr
    if current_corr != labelled:
        errors.append(table_id)
print(total_corr/total_valid, total_valid)

# Entity masking experiments

In [None]:
# Sanity check: predict on 30 copies of table 859 in a single run and show how
# many batches of logits come back.
repeated_tables = [ALL_TABLES[859]] * 30
logits, _ = wrappedPredict(repeated_tables)
len(logits)


In [None]:
import csv

def deleteOutOfRowBound(tables, length_tables, indices_tables, row_idx):
    """Keep only the tables whose length is greater than *row_idx*.

    The three lists are parallel: selection is driven by ``length_tables`` and
    the same positions are kept in all three. Returns the filtered
    ``(tables, length_tables, indices_tables)`` as new lists.
    """
    keep = [pos for pos, n_rows in enumerate(length_tables) if n_rows > row_idx]
    kept_lengths = [length_tables[pos] for pos in keep]
    kept_tables = [tables[pos] for pos in keep]
    kept_indices = [indices_tables[pos] for pos in keep]
    return kept_tables, kept_lengths, kept_indices

        
# Mask the given row of all the tables in tables
def maskRowOfTables(tables, row_idx):
    """Set the entity text of cell *row_idx* to 'ENT_MASK' in every column of every table.

    *row_idx* indexes each column's cell list in ``table[6]``. Columns with too
    few cells are skipped. *tables* is mutated in place and also returned.

    Raises:
        IndexError: if an existing cell is not shaped like ``[_, [_, text, ...]]``.
            Only the row lookup is guarded, so malformed data is reported
            rather than silently skipped.
    """
    for table in tables:
        for column in table[6]:
            try:
                cell = column[row_idx]
            except IndexError:
                # Incomplete column: it has no cell at this row.
                continue
            cell[1][1] = 'ENT_MASK'

    return tables

def writeCSV(table_num, logits_difference_row, directory=r'G:\CPSC448\TURL\data\logits_difference'):
    """Append *logits_difference_row* as one CSV row to ``table_<table_num>.csv``.

    Args:
        table_num: table index used in the file name.
        logits_difference_row: iterable of values written as one row.
        directory: output directory (created if missing); defaults to the
            original hard-coded location.

    Append mode creates the file when it does not exist, so one code path
    covers both new and existing files. Calling this repeatedly for the same
    table accumulates rows, including across re-runs.
    """
    os.makedirs(directory, exist_ok=True)
    file_path = os.path.join(directory, f'table_{table_num}.csv')
    with open(file_path, 'a', newline='') as f:
        csv.writer(f).writerow(logits_difference_row)

In [None]:
# Inspect the clean run for a single table. LOGITS holds one score tensor per
# evaluation batch (239 batches observed for this test set), each presumably
# shaped (tables_in_batch, columns, 255 types) — confirm.
# NOTE(review): LOGITS[0][0] is the first table of the first batch (presumably
# ALL_TABLES[0]), not table1 (ALL_TABLES[1]) — confirm this pairing is intended.
table1 = copy.deepcopy(ALL_TABLES[1])
table1[6][0] = table1[6][0][:12]  # show only the first 12 entity cells of column 0
display(table1)
display(LOGITS[0][0][4])
type_vocab['time.event']

In [None]:
# Row-masking sensitivity experiment. For row position r = 0 .. MAX_MASKED_ROWS-1:
#   1. mask the entity text at position r of every column, in every table that
#      is long enough to have that row;
#   2. re-run prediction on those tables;
#   3. for each table, append one CSV row (via writeCSV) holding, per column,
#      |clean logit - masked logit| for the column's first gold type, where the
#      clean logits come from LOGITS.
# NOTE: writeCSV appends, so re-running this cell adds duplicate rows to existing
# CSVs; clear the output directory first.
# NOTE(review): the positional mapping below assumes the dataset keeps every
# input table, in order, and that logits columns line up with table[7]. Confirm
# that WikiCTDataset neither drops tables nor truncates columns.

EVAL_BATCH_SIZE = 20   # must equal eval_batch_size in predict(); used to locate tables in batched logits
MAX_MASKED_ROWS = 10

# Working copies; these shrink as short tables drop out for larger row indices.
copy_all_tables = copy.deepcopy(ALL_TABLES)
copy_all_lengths = copy.deepcopy(LENGTH_ALL_TABLES)
copy_all_table_indices = list(range(len(ALL_TABLES)))

row_idx = 0
while copy_all_tables and row_idx < MAX_MASKED_ROWS:
    copy_all_tables, copy_all_lengths, copy_all_table_indices = deleteOutOfRowBound(
        copy_all_tables, copy_all_lengths, copy_all_table_indices, row_idx)
    if not copy_all_tables:
        break

    # maskRowOfTables mutates its argument, so mask a deep copy and keep the
    # unmasked tables for the next row position.
    masked_tables = maskRowOfTables(copy.deepcopy(copy_all_tables), row_idx)
    logits_masked, _ = wrappedPredict(masked_tables)

    for batch_idx, batch_logits in enumerate(logits_masked):
        for table_index_in_batch, masked_table_logits in enumerate(batch_logits):
            # Position among the remaining tables -> index into ALL_TABLES.
            table_index_in_remaining = batch_idx * EVAL_BATCH_SIZE + table_index_in_batch
            table_index_in_all = copy_all_table_indices[table_index_in_remaining]
            # The clean run covered all tables, so its batch/slot follow from the global index.
            clean_batch, clean_slot = divmod(table_index_in_all, EVAL_BATCH_SIZE)
            correct_table_logits = LOGITS[clean_batch][clean_slot]

            correct_labels = ALL_TABLES[table_index_in_all][7]
            logits_difference_row = []
            for col_idx, col_types in enumerate(correct_labels):
                label_index = type_vocab[col_types[0]]
                clean_logit = correct_table_logits[col_idx][label_index].item()
                masked_logit = masked_table_logits[col_idx][label_index].item()
                logits_difference_row.append(abs(clean_logit - masked_logit))

            writeCSV(table_index_in_all, logits_difference_row)
    row_idx += 1
