In [None]:
from __future__ import absolute_import, division, print_function

import sys

sys.path.append("..")


import argparse
import glob
import logging
import os
import pickle
import random
import re
import shutil

import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler
from torch.utils.data.distributed import DistributedSampler
from torch.utils.tensorboard import SummaryWriter
from tqdm import trange
from tqdm.autonotebook import tqdm
from transformers import WEIGHTS_NAME, AdamW, BertConfig, BertTokenizer, get_linear_schedule_with_warmup

from src.baselines.cell_filling.cell_filling import *
from src.baselines.row_population.metric import average_precision, ndcg_at_k
from src.data_loader.ct_wiki_data_loaders import *
from src.data_loader.el_data_loaders import *
from src.data_loader.header_data_loaders import *
from src.data_loader.hybrid_data_loaders import *
from src.data_loader.re_data_loaders import *
from src.model import metric
from src.model.configuration import TableConfig
from src.model.model import (
    BertRE,
    HybridTableCER,
    HybridTableCT,
    HybridTableEL,
    HybridTableMaskedLM,
    HybridTableRE,
    TableHeaderRanking,
)
from src.utils.util import *

In [None]:
# Logger for this notebook's module namespace.
logger = logging.getLogger(__name__)

# Registry: task key -> (config class, model class, tokenizer class).
# Cells below unpack an entry as `config_class, model_class, _ = MODEL_CLASSES[task]`.
# "CF" is used for cell filling, "EL" for entity linking and "CT" for column type annotation.
MODEL_CLASSES = {
    "CER": (TableConfig, HybridTableCER, BertTokenizer),
    "CF": (TableConfig, HybridTableMaskedLM, BertTokenizer),
    "HR": (TableConfig, TableHeaderRanking, BertTokenizer),
    "CT": (TableConfig, HybridTableCT, BertTokenizer),
    "EL": (TableConfig, HybridTableEL, BertTokenizer),
    "RE": (TableConfig, HybridTableRE, BertTokenizer),
    "REBERT": (BertConfig, BertRE, BertTokenizer),
}

In [None]:
# Data directory used to load the test data; a leading "~" is expanded to the user's home.
data_dir = os.path.expanduser("~/turl-data")

In [None]:
# Model configuration file and compute device.
# Both values can be overridden without editing the notebook by setting the TURL_CONFIG and
# TURL_DEVICE environment variables. When they are unset, the original hard-coded values are
# used, so existing runs behave exactly as before.
config_name = os.environ.get(
    "TURL_CONFIG", "/home/fbelotti/projects/TURL/src/configs/table-base-config_v2.json"
)
device = torch.device(os.environ.get("TURL_DEVICE", "cuda:1"))

In [None]:
# Load the entity vocabulary from data_dir (see load_entity_vocab for what the
# ignore_bad_title / min_ent_count arguments filter).
entity_vocab = load_entity_vocab(data_dir, ignore_bad_title=True, min_ent_count=2)
# Reverse lookup: each entry's "wiki_id" field -> that entry's key in entity_vocab.
entity_wikid2id = {entity_vocab[x]["wiki_id"]: x for x in entity_vocab}

In [None]:
# os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
# os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# Table of Contents
This notebook shows examples of how to use the model components and run the evaluation for different tasks.
* [Pretrained and Cell Filling](#cf)
* [Entity Linking](#el)
* [Column Type Classification](#ct)
* [Relation Extraction](#re)

<a class="anchor" id="cf"></a>
# Pretrained and CF
Here we show how to use the pretrained model to get contextualized representations for a given input table.

We use the cell filling task for demonstration as it does not need task-specific finetuning.

In [None]:
# Build the cell-filling model from the shared config; attention maps are requested so that
# they are included in the model outputs inspected later in the notebook.
config_class, model_class, _ = MODEL_CLASSES["CF"]
config = config_class.from_pretrained(config_name)
config.output_attentions = True

# For CF, we use the base HybridTableMaskedLM, and directly load the pretrained checkpoint
checkpoint = "output/hybrid/v2/model_v1_table_0.2_0.6_0.7_10000_1e-4_candnew_0_adam/"
model = model_class(config, is_simple=True)
# map_location="cpu": deserialize the weights on the CPU instead of on whichever device they
# were saved from, so loading does not fail when that GPU is absent. The model is moved to
# `device` right after the state dict has been copied in.
checkpoint = torch.load(os.path.join(checkpoint, "pytorch_model.bin"), map_location="cpu")
model.load_state_dict(checkpoint)
model.to(device)
model.eval()
# load the module for cell filling baselines
CF = cell_filling(data_dir)

In [None]:
# Load the cell-filling test set (JSON) and show its first example.
# NOTE(review): `json` is not imported explicitly in this notebook; it presumably comes in
# through one of the `from ... import *` lines — confirm.
with open(os.path.join(data_dir, "CF_test_data.json"), "r") as f:
    dev_data = json.load(f)
print("example for cell filling")
display(dev_data[0])
# the dataset here is the dataloader for pretraining. We use it to pass the config to construct the cell filling example
dataset = WikiHybridTableDataset(
    data_dir,
    entity_vocab,
    max_cell=100,
    max_input_tok=350,
    max_input_ent=150,
    src="dev",
    max_length=[50, 10, 10],
    force_new=False,
    tokenizer=None,
    mode=0,
)
print("example of pretraining data")
# Only the first line of the JSON-lines file is parsed, as a sample to display.
with open(os.path.join(data_dir, "dev_tables.jsonl"), "r") as f:
    for line in f:
        example = json.loads(line.strip())
        break
display(example)

In [None]:
# This is an example of converting an arbitrary table to input
# Here we show an example for cell filling task
# The input entities are entities in the subject column, we append [ENT_MASK] and use its representation to match with the candidate entities
def CF_build_input(pgEnt, pgTitle, secTitle, caption, headers, core_entities, core_entities_text, entity_cand, config):
    """Build the model inputs for one cell-filling query over a two-column table.

    The token sequence is the tokenized page title, section title and (when it differs from
    the section title) caption, followed by the tokenized headers. The entity sequence is
    ``[page entity] + core entities + one [ENT_MASK] slot per core entity``. Every returned
    tensor has a leading batch dimension of size 1.

    Args:
        pgEnt: wiki id of the page entity, or -1 (mapped to entity id 0).
        pgTitle: page title text.
        secTitle: section title text.
        caption: table caption text; only tokenized when it differs from ``secTitle``.
        headers: header strings. The mask construction indexes the first and second header
            span, so at least two headers are required.
        core_entities: wiki ids of the subject-column entities; each must be a key of
            ``config.entity_wikid2id``.
        core_entities_text: surface text of each core entity (may be an empty string).
        entity_cand: wiki ids of the candidate entities to score against.
        config: object exposing ``tokenizer``, ``max_title_length``, ``max_header_length``,
            ``max_cell_length`` and ``entity_wikid2id`` (this notebook passes the
            ``WikiHybridTableDataset`` instance).

    Returns:
        Tuple of LongTensors: ``(input_tok, input_tok_type, input_tok_pos, input_tok_mask,
        input_ent, input_ent_text, input_ent_cell_length, input_ent_type,
        input_ent_mask_type, input_ent_mask, candidate_entity_set)``.
    """
    tokenized_pgTitle = config.tokenizer.encode(pgTitle, max_length=config.max_title_length, add_special_tokens=False)
    tokenized_meta = tokenized_pgTitle + config.tokenizer.encode(
        secTitle, max_length=config.max_title_length, add_special_tokens=False
    )
    if caption != secTitle:
        tokenized_meta += config.tokenizer.encode(caption, max_length=config.max_title_length, add_special_tokens=False)
    tokenized_headers = [
        config.tokenizer.encode(header, max_length=config.max_header_length, add_special_tokens=False)
        for header in headers
    ]
    input_tok = []
    input_tok_pos = []
    input_tok_type = []
    # Metadata tokens: type 0, positions counted from 0.
    tokenized_meta_length = len(tokenized_meta)
    input_tok += tokenized_meta
    input_tok_pos += list(range(tokenized_meta_length))
    input_tok_type += [0] * tokenized_meta_length
    # Header tokens: type 1, positions restart at 0 for every header; header_span records the
    # [start, end) token range of each header inside input_tok.
    header_span = []
    for tokenized_header in tokenized_headers:
        tokenized_header_length = len(tokenized_header)
        header_span.append([len(input_tok), len(input_tok) + tokenized_header_length])
        input_tok += tokenized_header
        input_tok_pos += list(range(tokenized_header_length))
        input_tok_type += [1] * tokenized_header_length

    # Entity slot 0 is the page entity (type 2); its text is the truncated page-title tokens.
    input_ent = [config.entity_wikid2id[pgEnt] if pgEnt != -1 else 0]
    input_ent_text = [tokenized_pgTitle[: config.max_cell_length]]
    input_ent_type = [2]

    # core entities in the subject column
    input_ent += [config.entity_wikid2id[entity] for entity in core_entities]
    input_ent_text += [
        (
            config.tokenizer.encode(entity_text, max_length=config.max_cell_length, add_special_tokens=False)
            if len(entity_text) != 0
            else []
        )
        for entity_text in core_entities_text
    ]
    input_ent_type += [3] * len(core_entities)

    # append [ent_mask]
    input_ent += [config.entity_wikid2id["[ENT_MASK]"]] * len(core_entities)
    input_ent_text += [[]] * len(core_entities)
    input_ent_type += [4] * len(core_entities)

    # Cells with no text are given length 1; the text matrix is zero-padded to the longest cell.
    input_ent_cell_length = [len(x) if len(x) != 0 else 1 for x in input_ent_text]
    max_cell_length = max(input_ent_cell_length)
    input_ent_text_padded = np.zeros([len(input_ent_text), max_cell_length], dtype=int)
    for i, x in enumerate(input_ent_text):
        input_ent_text_padded[i, : len(x)] = x
    assert len(input_ent) == 1 + 2 * len(core_entities)

    # Token-side mask over columns [tokens | page entity | core entities | mask slots].
    # NOTE(review): presumably 1 = visible and 0 = blocked — confirm against the model.
    input_tok_mask = np.ones([1, len(input_tok), len(input_tok) + len(input_ent)], dtype=int)
    # First-header tokens: zero the [ENT_MASK] slot columns.
    input_tok_mask[0, header_span[0][0] : header_span[0][1], len(input_tok) + 1 + len(core_entities) :] = 0
    # Second-header tokens: zero the core-entity columns.
    input_tok_mask[
        0, header_span[1][0] : header_span[1][1], len(input_tok) + 1 : len(input_tok) + 1 + len(core_entities)
    ] = 0
    # All tokens: zero the [ENT_MASK] slot columns (this also covers the first-header rows above).
    input_tok_mask[0, :, len(input_tok) + 1 + len(core_entities) :] = 0

    # build the mask for entities
    input_ent_mask = np.ones([1, len(input_ent), len(input_tok) + len(input_ent)], dtype=int)
    # Core-entity rows: zero the second-header columns; pair each one only with its own mask slot.
    input_ent_mask[0, 1 : 1 + len(core_entities), header_span[1][0] : header_span[1][1]] = 0
    input_ent_mask[0, 1 : 1 + len(core_entities), len(input_tok) + 1 + len(core_entities) :] = np.eye(
        len(core_entities), dtype=int
    )
    # Mask-slot rows: zero the first-header columns; pair each slot only with its own core
    # entity and with itself.
    input_ent_mask[0, 1 + len(core_entities) :, header_span[0][0] : header_span[0][1]] = 0
    input_ent_mask[0, 1 + len(core_entities) :, len(input_tok) + 1 : len(input_tok) + 1 + len(core_entities)] = np.eye(
        len(core_entities), dtype=int
    )
    input_ent_mask[0, 1 + len(core_entities) :, len(input_tok) + 1 + len(core_entities) :] = np.eye(
        len(core_entities), dtype=int
    )

    input_tok_mask = torch.LongTensor(input_tok_mask)
    input_ent_mask = torch.LongTensor(input_ent_mask)

    # Wrapping each list in another list adds the batch dimension of size 1.
    input_tok = torch.LongTensor([input_tok])
    input_tok_type = torch.LongTensor([input_tok_type])
    input_tok_pos = torch.LongTensor([input_tok_pos])

    input_ent = torch.LongTensor([input_ent])
    input_ent_text = torch.LongTensor([input_ent_text_padded])
    input_ent_cell_length = torch.LongTensor([input_ent_cell_length])
    input_ent_type = torch.LongTensor([input_ent_type])

    # Zero everywhere except the mask slots, which carry the [ENT_MASK] entity id.
    input_ent_mask_type = torch.zeros_like(input_ent)
    input_ent_mask_type[:, 1 + len(core_entities) :] = config.entity_wikid2id["[ENT_MASK]"]

    candidate_entity_set = [config.entity_wikid2id[entity] for entity in entity_cand]
    candidate_entity_set = torch.LongTensor([candidate_entity_set])

    return (
        input_tok,
        input_tok_type,
        input_tok_pos,
        input_tok_mask,
        input_ent,
        input_ent_text,
        input_ent_cell_length,
        input_ent_type,
        input_ent_mask_type,
        input_ent_mask,
        candidate_entity_set,
    )

In [None]:
# Cell-filling evaluation: for every test table, rank the candidate entities of each row with
# the neural model and with the H2H baseline. One dict per table is appended to `results`.
results = []
for table_id, pgEnt, pgTitle, secTitle, caption, (h1, h2), data_sample in tqdm(dev_data):
    result = []
    # Process the rows of the table in chunks of at most 100.
    while len(data_sample) != 0:
        core_entities = []
        core_entities_text = []
        target_entities = []
        all_entity_cand = set()
        entity_cand = []
        for (core_e, core_e_text), target_e in data_sample[:100]:
            assert target_e in entity_wikid2id
            core_entities.append(core_e)
            core_entities_text.append(core_e_text)
            target_entities.append(target_e)
            # Candidates for this row from the baseline module, restricted to entities that
            # exist in the vocabulary.
            cands = CF.get_cand_row(core_e, h2)
            cands = {key: value for key, value in cands.items() if key in entity_wikid2id}
            entity_cand.append(cands)
            all_entity_cand |= set(cands.keys())
        # The union of all rows' candidates is scored in one forward pass; the list order
        # fixes the column order of the prediction scores.
        all_entity_cand = list(all_entity_cand)
        (
            input_tok,
            input_tok_type,
            input_tok_pos,
            input_tok_mask,
            input_ent,
            input_ent_text,
            input_ent_text_length,
            input_ent_type,
            input_ent_mask_type,
            input_ent_mask,
            candidate_entity_set,
        ) = CF_build_input(
            pgEnt, pgTitle, secTitle, caption, [h1, h2], core_entities, core_entities_text, all_entity_cand, dataset
        )
        input_tok = input_tok.to(device)
        input_tok_type = input_tok_type.to(device)
        input_tok_pos = input_tok_pos.to(device)
        input_tok_mask = input_tok_mask.to(device)
        input_ent_text = input_ent_text.to(device)
        input_ent_text_length = input_ent_text_length.to(device)
        input_ent = input_ent.to(device)
        input_ent_type = input_ent_type.to(device)
        input_ent_mask_type = input_ent_mask_type.to(device)
        input_ent_mask = input_ent_mask.to(device)
        candidate_entity_set = candidate_entity_set.to(device)
        with torch.no_grad():
            tok_outputs, ent_outputs = model(
                input_tok,
                input_tok_type,
                input_tok_pos,
                input_tok_mask,
                input_ent_text,
                input_ent_text_length,
                input_ent_mask_type,
                input_ent,
                input_ent_type,
                input_ent_mask,
                candidate_entity_set,
            )
            num_sample = len(target_entities)
            # Skip the page entity and the core entities: the remaining rows are the
            # [ENT_MASK] slots, one per core entity, in the same order.
            ent_prediction_scores = ent_outputs[0][0, num_sample + 1 :].tolist()
        for i, target_e in enumerate(target_entities):
            predictions = ent_prediction_scores[i]
            if len(entity_cand[i]) == 0:
                # Nothing to rank for this row: record empty rankings.
                result.append([target_e, entity_cand[i], [], []])
            else:
                # Keep only the scores of this row's own candidates, then sort descending.
                tmp_cand_scores = []
                for j, cand_e in enumerate(all_entity_cand):
                    if cand_e in entity_cand[i]:
                        tmp_cand_scores.append([cand_e, predictions[j]])
                sorted_cand_scores = sorted(tmp_cand_scores, key=lambda z: z[1], reverse=True)
                sorted_cands = [z[0] for z in sorted_cand_scores]
                # use H2H as baseline
                base_sorted_cands = CF.rank_cand_h2h(h2, entity_cand[i])
                result.append([target_e, entity_cand[i], sorted_cands, base_sorted_cands])
        data_sample = data_sample[100:]
    results.append(
        {"pgTitle": pgTitle, "secTitle": secTitle, "caption": caption, "headers": [h1, h2], "result": result}
    )

In [None]:
# Inspect the model outputs. `tok_outputs` and `ent_outputs` still hold the results of the
# last forward pass executed by the cell-filling evaluation loop.
print("tok(metadata) outputs", len(tok_outputs))
print("tok prediction logits: [batch_size, num_toks, vocab_size]\n", tok_outputs[0].shape)
print("tok hidden states: [batch_size, num_toks, hidden_size]\n", tok_outputs[1].shape)
# Index 2 holds one attention tensor per layer; only the first layer's shape is printed.
print(
    "tok attention: n_layers*[batch_size, num_attention_headers, num_toks, num_toks+num_ents]\n",
    tok_outputs[2][0].shape,
)
print("entity(cell) outputs", len(ent_outputs))
print("ent prediction logits: [batch_size, num_ents, candidate_size]\n", ent_outputs[0].shape)
print("ent hidden states: [batch_size, num_ents, hidden_size]\n", ent_outputs[1].shape)
print(
    "ent attention: n_layers*[batch_size, num_attention_headers, num_ents, num_toks+num_ents]\n",
    ent_outputs[2][0].shape,
)

In [None]:
def get_precision(result):
    """Compute candidate recall and precision@{1, 3, 5, 10} for one table.

    Args:
        result: iterable of ``[target, candidates, neural_ranking, base_ranking]`` rows, where
            ``candidates`` is a container supporting ``in`` and the two rankings are lists
            ordered best-first.

    Returns:
        A tuple ``(recall, precision_neural, precision_base)``:

        * ``recall`` is the fraction of rows whose target is among the candidates.
        * each precision list holds p@1, p@3, p@5 and p@10, computed only over the rows
          counted in the recall.

        When no row has its target among the candidates (including an empty ``result``),
        ``(0, [0, 0, 0, 0], [0, 0, 0, 0])`` is returned.
    """
    cutoffs = (1, 3, 5, 10)
    recall = 0
    precision_neural = [0] * len(cutoffs)
    precision_base = [0] * len(cutoffs)
    for target_e, cand, p_neural, p_base in result:
        if target_e not in cand:
            continue
        recall += 1
        for slot, k in enumerate(cutoffs):
            # Slicing (rather than indexing element 0) keeps an empty ranking from raising
            # IndexError; it simply counts as a miss at every cutoff.
            if target_e in p_neural[:k]:
                precision_neural[slot] += 1
            if target_e in p_base[:k]:
                precision_base[slot] += 1
    if recall == 0:
        return 0, [0] * len(cutoffs), [0] * len(cutoffs)
    return recall / len(result), [z / recall for z in precision_neural], [z / recall for z in precision_base]

In [None]:
# Score every table with get_precision: (recall, neural p@k list, baseline p@k list).
final_results = [get_precision(table["result"]) for table in results]
print("recall", np.mean([scores[0] for scores in final_results]))
# Precision is macro-averaged over the tables whose recall is non-zero.
_scored_tables = [scores for scores in final_results if scores[0] != 0]
for _system, _system_idx in (("neural", 1), ("base", 2)):
    print(_system)
    for _k_idx, _label in enumerate(("p@1", "p@3", "p@5", "p@10")):
        print(_label, np.mean([scores[_system_idx][_k_idx] for scores in _scored_tables]))

<a class="anchor" id="el"></a>
# EL
Evaluate Entity Linking

In [None]:
# Recompute the per-table scores from `results` and print the macro-averaged summary:
# recall over all tables, precision@k over the tables whose recall is non-zero.
final_results = [get_precision(table["result"]) for table in results]
print("recall", np.mean([scores[0] for scores in final_results]))
_scored_tables = [scores for scores in final_results if scores[0] != 0]
for _system, _system_idx in (("neural", 1), ("base", 2)):
    print(_system)
    for _k_idx, _label in enumerate(("p@1", "p@3", "p@5", "p@10")):
        print(_label, np.mean([scores[_system_idx][_k_idx] for scores in _scored_tables]))

In [None]:
# Precision@k restricted to tables with non-zero recall whose second header does not
# contain the substring "team".
_kept_tables = [
    scores
    for idx, scores in enumerate(final_results)
    if scores[0] != 0 and "team" not in results[idx]["headers"][1]
]
for _system, _system_idx in (("neural", 1), ("base", 2)):
    print(_system)
    for _k_idx, _label in enumerate(("p@1", "p@3", "p@5", "p@10")):
        print(_label, np.mean([scores[_system_idx][_k_idx] for scores in _kept_tables]))

In [None]:
# Precision@k over the tables that have non-zero recall and whose second header does not
# contain the substring "team".
_kept_tables = [
    scores
    for idx, scores in enumerate(final_results)
    if scores[0] != 0 and "team" not in results[idx]["headers"][1]
]
for _system, _system_idx in (("neural", 1), ("base", 2)):
    print(_system)
    for _k_idx, _label in enumerate(("p@1", "p@3", "p@5", "p@10")):
        print(_label, np.mean([scores[_system_idx][_k_idx] for scores in _kept_tables]))

In [None]:
# load dbpedia types from depedia_type_vocab.txt
type_vocab = load_dbpedia_type_vocab(data_dir)
# Build the entity-linking config; the entity type vocabulary size follows the loaded type vocab.
config_class, model_class, _ = MODEL_CLASSES["EL"]
config = config_class.from_pretrained(config_name)
config.ent_type_vocab_size = len(type_vocab)
# mode selects the checkpoint sub-directory and which candidate features the evaluation loop
# passes to the model: 0 keeps everything, 1 drops candidate descriptions, 2 drops candidate types.
config.mode = 0

In [None]:
# Show the first example of the entity-linking test file.
with open(os.path.join(data_dir, "test_own.table_entity_linking.json"), "r") as f:
    example = json.load(f)[0]
display(example)

In [None]:
# load test data from [dataset].table_entity_linking.json
# Here src="test_own" selects the split (the same "test_own" prefix as the example file shown above).
test_dataset = ELDataset(
    data_dir,
    type_vocab,
    max_input_tok=500,
    src="test_own",
    max_length=[50, 10, 10, 100],
    force_new=False,
    tokenizer=None,
)

In [None]:
# Instantiate the entity-linking model and restore the fine-tuned weights for config.mode.
model = model_class(config, is_simple=True)
# load the checkpoint based on mode
# map_location="cpu": deserialize the weights on the CPU instead of on whichever device they
# were saved from, so loading does not fail when that GPU is absent. The model is moved to
# `device` right after the state dict has been copied in.
checkpoint = torch.load(
    f"output/EL/v2/{config.mode}/model_v1_table_0.2_0.6_0.7_10000_1e-4_candnew_0_adam/pytorch_model.bin",
    map_location="cpu",
)
model.load_state_dict(checkpoint)
model.to(device)
model.eval()

In [None]:
# Entity-linking evaluation: iterate over the test set in order, accumulate loss and accuracy
# per batch, and collect the ranked candidate indices and scores of every table.
test_batch_size = 10
test_sampler = SequentialSampler(test_dataset)
test_dataloader = ELLoader(test_dataset, sampler=test_sampler, batch_size=test_batch_size, is_train=False)

# Eval!
print("Num examples = %d" % len(test_dataset))
print("Batch size = %d" % test_batch_size)
test_loss = 0.0
test_acc = 0.0
nb_test_steps = 0
test_results = []

for batch in tqdm(test_dataloader, desc="Evaluating"):
    (
        table_id,
        input_tok,
        input_tok_type,
        input_tok_pos,
        input_tok_mask,
        input_ent_text,
        input_ent_text_length,
        input_ent_type,
        input_ent_mask,
        cand_name,
        cand_name_length,
        cand_description,
        cand_description_length,
        cand_type,
        cand_type_length,
        cand_mask,
        labels,
        entities_index,
    ) = batch
    # table_id and entities_index are not moved to the device; they are only used below to
    # label the stored results.
    input_tok = input_tok.to(device)
    input_tok_type = input_tok_type.to(device)
    input_tok_pos = input_tok_pos.to(device)
    input_tok_mask = input_tok_mask.to(device)
    input_ent_text = input_ent_text.to(device)
    input_ent_text_length = input_ent_text_length.to(device)
    input_ent_type = input_ent_type.to(device)
    input_ent_mask = input_ent_mask.to(device)
    cand_name = cand_name.to(device)
    cand_name_length = cand_name_length.to(device)
    cand_description = cand_description.to(device)
    cand_description_length = cand_description_length.to(device)
    cand_type = cand_type.to(device)
    cand_type_length = cand_type_length.to(device)
    cand_mask = cand_mask.to(device)
    labels = labels.to(device)

    # Feature ablation by mode: 1 drops candidate descriptions, 2 drops candidate types,
    # 0 keeps both; any other value is rejected.
    if config.mode == 1:
        cand_description = None
        cand_description_length = None
    elif config.mode == 2:
        cand_type = None
        cand_type_length = None
    elif config.mode != 0:
        raise Exception

    with torch.no_grad():
        outputs = model(
            input_tok,
            input_tok_type,
            input_tok_pos,
            input_tok_mask,
            input_ent_text,
            input_ent_text_length,
            input_ent_type,
            input_ent_mask,
            cand_name,
            cand_name_length,
            cand_description,
            cand_description_length,
            cand_type,
            cand_type_length,
            cand_mask,
            labels,
        )
        loss = outputs[0]
        prediction_scores = outputs[1]
        # Reshape the scores to [batch, input_ent_text.size(1) - 1, candidates] and rank the
        # candidates of every mention from best to worst.
        # NOTE(review): the "- 1" presumably skips a leading non-cell entity slot — confirm.
        predict_index = torch.argsort(
            prediction_scores.view(input_ent_text.size(0), input_ent_text.size(1) - 1, -1), descending=True
        )
        sorted_scores = (
            torch.gather(
                prediction_scores.view(input_ent_text.size(0), input_ent_text.size(1) - 1, -1), -1, predict_index
            )
        ).tolist()
        predict_index = predict_index.tolist()
        acc = metric.accuracy(prediction_scores, labels.view(-1), ignore_index=-1)
        # Per-table sizes used to trim padding from the stored rankings: cand_mask summed over
        # dim 1, and the number of labels different from -1.
        cand_length = cand_mask.sum(1).tolist()
        ent_length = (labels != -1).sum(1).tolist()
        for i, t_id in enumerate(table_id):
            test_results.append(
                [
                    t_id,
                    entities_index[i],
                    [x[: cand_length[i]] for x in predict_index[i][: ent_length[i]]],
                    [x[: cand_length[i]] for x in sorted_scores[i][: ent_length[i]]],
                ]
            )
        test_loss += loss.mean().item()
        test_acc += acc.item()
    nb_test_steps += 1

# Average the per-batch loss and accuracy over the number of batches.
test_loss = test_loss / nb_test_steps
test_acc = test_acc / nb_test_steps

result = {
    "eval_loss": test_loss,
    "eval_acc": test_acc,
}
for key in sorted(result.keys()):
    print("%s = %s" % (key, str(result[key])))

In [None]:
# We dump the predictions to a separate file and use another script for the official evaluation.
# The reason is that our entity linking is based on Wikidata lookup. In certain cases, the candidates
# do not contain the target entity; such test examples are still considered for metric calculation.
# However, since there is nothing to rank, we do not pass those examples here, so the set of test
# examples here is incomplete.
with open(os.path.join(data_dir, "test_own_entity_linking_results_0.pkl"), "wb") as f:
    pickle.dump(test_results, f)

<a class="anchor" id="ct"></a>
# CT
Evaluate column type annotation

In [None]:
# Show the first example of the column-type test file.
with open(os.path.join(data_dir, "test.table_col_type.json"), "r") as f:
    example = json.load(f)[0]
display(example)

In [None]:
# load type vocab from type_vocab.txt
# Note: this rebinds `type_vocab` and `test_dataset`, which previously held the
# entity-linking type vocabulary and test set.
type_vocab = load_type_vocab(data_dir)
test_dataset = WikiCTDataset(
    data_dir,
    entity_vocab,
    type_vocab,
    max_input_tok=500,
    src="test",
    max_length=[50, 10, 10],
    force_new=False,
    tokenizer=None,
)

In [None]:
# Inverse of type_vocab: index -> type name.
id2type = {idx: t for t, idx in type_vocab.items()}
# Starts out empty; nothing is added to it in this cell.
t2d_invalid = set()

In [None]:
@torch.no_grad()
def average_precision(output, relevance_labels):
    """Average precision of each row's ranking, computed without tracking gradients.

    Note: this definition replaces the ``average_precision`` name imported from
    ``src.baselines.row_population.metric`` at the top of the notebook.

    Args:
        output: score tensor; items are ranked along the last dimension, highest first.
        relevance_labels: integer tensor of the same shape as ``output`` holding the
            relevance of each item (non-zero means relevant).

    Returns:
        Tensor with the last dimension reduced: the mean of precision@rank taken at the
        ranks of the relevant items. Rows with no relevant item yield 0.
    """
    ranking = torch.argsort(output, dim=-1, descending=True)
    hits = torch.gather(relevance_labels, -1, ranking).float()
    running_hits = torch.cumsum(hits, dim=-1)
    ranks = torch.arange(start=1, end=running_hits.shape[-1] + 1, device=running_hits.device)
    # precision@rank, kept only at the positions that hold a relevant item
    precision_at_hits = (running_hits / ranks[None, :]) * hits
    num_relevant = torch.sum(hits, dim=-1)
    # avoid dividing by zero for rows without any relevant item
    num_relevant[num_relevant == 0] = 1
    return torch.sum(precision_at_hits, dim=-1) / num_relevant

In [None]:
# Empty containers for the column-type evaluation results.
per_type_accuracy = {}
per_type_precision = {}
per_type_recall = {}
per_type_f1 = {}
# NOTE(review): `map` shadows the Python builtin of the same name for the rest of the notebook.
map = {}
precision = {}
recall = {}
f1 = {}
# The evaluation loop creates one entry per mode (per_table_result[mode] = {}).
per_table_result = {}

In [None]:
from safetensors.torch import load_model

In [None]:
from tqdm.autonotebook import tqdm

# Fine-tuned column-type checkpoints. Only the last entry is loaded below.
# NOTE(review): these are absolute, machine-specific paths.
checkpoints = [
    "/home/fbelotti/projects/TURL/output/logs/turl/fine-tuning-ct/2024-02-07_16-45-51/version_0/checkpoints/checkpoint-80000/pytorch_model.bin",
    "/home/fbelotti/projects/TURL/output/logs/turl/fine-tuning-ct/2024-02-07_16-45-51/version_0/checkpoints/checkpoint-90000/pytorch_model.bin",
    "/home/fbelotti/projects/TURL/output/logs/turl/fine-tuning-ct/2024-02-07_16-45-51/version_0/checkpoints/checkpoint-last/pytorch_model.bin",
    "/home/fbelotti/projects/TURL/output/logs/turl/fine-tuning-ct/2024-02-07_18-24-49/version_0/checkpoints/checkpoint-last/model.safetensors",
]
for mode in range(6):
    # Every mode except 0 is skipped, so only the `[0]` entries of the result dictionaries
    # (and `per_table_result[0]`) are populated by this cell.
    if mode != 0:
        continue
    print("Mode:", mode)
    config_class, model_class, _ = MODEL_CLASSES["CT"]
    config = config_class.from_pretrained(config_name)
    config.class_num = len(type_vocab)
    config.mode = mode
    model = model_class(config, is_simple=True)
    checkpoint = checkpoints[-1]
    if checkpoint.endswith("safetensors"):
        missing, unexpected = load_model(model, checkpoint, strict=True)
        print(missing, unexpected)
    else:
        # Load onto the CPU first so the checkpoint does not depend on the GPU it was saved from;
        # the model is moved to `device` right after.
        checkpoint = torch.load(checkpoint, map_location="cpu")
        model.load_state_dict(checkpoint)
    model.to(device)
    model.eval()
    eval_batch_size = 20
    eval_sampler = SequentialSampler(test_dataset)
    eval_dataloader = CTLoader(test_dataset, sampler=eval_sampler, batch_size=eval_batch_size, is_train=False)
    eval_loss = 0.0
    eval_map = 0.0
    nb_eval_steps = 0
    eval_targets = []
    eval_prediction_scores = []
    eval_pred = []
    eval_mask = []
    per_table_result[mode] = {}
    for batch in tqdm(eval_dataloader, desc="Evaluating"):
        # First element holds the table ids; the remaining 13 elements are moved to `device`.
        table_ids, *batch_tensors = batch
        (
            input_tok,
            input_tok_type,
            input_tok_pos,
            input_tok_mask,
            input_ent_text,
            input_ent_text_length,
            input_ent,
            input_ent_type,
            input_ent_mask,
            column_entity_mask,
            column_header_mask,
            labels_mask,
            labels,
        ) = (tensor.to(device) for tensor in batch_tensors)

        # Ablations: withhold some of the inputs depending on `mode`.
        if mode in (1, 4, 5):
            # no token inputs: keep only the part of the entity mask past the token positions
            input_ent_mask = input_ent_mask[:, :, input_tok_mask.shape[1] :]
            input_tok = input_tok_type = input_tok_pos = input_tok_mask = None
        elif mode == 2:
            # no entity inputs
            input_tok_mask = input_tok_mask[:, :, : input_tok_mask.shape[1]]
            input_ent_text = input_ent_text_length = None
            input_ent = input_ent_type = input_ent_mask = None
        if mode in (3, 4):
            input_ent = None
        if mode == 5:
            input_ent_text = input_ent_text_length = None

        with torch.no_grad():
            outputs = model(
                input_tok,
                input_tok_type,
                input_tok_pos,
                input_tok_mask,
                input_ent_text,
                input_ent_text_length,
                input_ent,
                input_ent_type,
                input_ent_mask,
                column_entity_mask,
                column_header_mask,
                labels_mask,
                labels,
            )
            loss, prediction_scores = outputs[0], outputs[1]
            # Push the scores of masked-out types far below every real score.
            for l_i in t2d_invalid:
                prediction_scores[:, :, l_i] = -1000
            # Collect the scores per table id; the same id can show up several times, so a list is kept.
            for idx, table_id in enumerate(table_ids):
                valid = labels_mask[idx].nonzero().max().item() + 1
                if table_id not in per_table_result[mode]:
                    per_table_result[mode][table_id] = [[], labels_mask[idx, :valid], labels[idx, :valid]]
                per_table_result[mode][table_id][0].append(prediction_scores[idx, :valid])
            flat_scores = prediction_scores.view(-1, config.class_num)
            flat_labels = labels.view((-1, config.class_num))
            ap = metric.average_precision(flat_scores, flat_labels)
            # Named `batch_map` so that neither the builtin `map` nor the `map` dict is overwritten.
            batch_map = (ap * labels_mask.view(-1)).sum() / labels_mask.sum()
            eval_loss += loss.mean().item()
            eval_map += batch_map.item()
            eval_targets.extend(flat_labels.tolist())
            eval_prediction_scores.extend(flat_scores.tolist())
            eval_pred.extend((torch.sigmoid(flat_scores) > 0.5).tolist())
            eval_mask.extend(labels_mask.view(-1).tolist())
        nb_eval_steps += 1
    print(eval_map / nb_eval_steps)

    eval_targets = np.array(eval_targets)
    eval_prediction_scores = np.array(eval_prediction_scores)
    eval_mask = np.array(eval_mask)
    eval_prediction_ranks = np.argsort(np.argsort(-eval_prediction_scores))
    eval_pred = np.array(eval_pred)
    masked_pred = eval_mask[:, np.newaxis] * eval_pred
    masked_targets = eval_mask[:, np.newaxis] * eval_targets
    eval_tp = masked_pred * eval_targets
    # Per-type precision / recall / F1. A type without predictions (or without instances) gives 0/0.
    # `nan=0.0` is spelled out because the second positional parameter of `np.nan_to_num` is `copy`,
    # not the fill value: the previous `np.nan_to_num(x, 1)` calls also replaced NaN with 0.0.
    # NOTE(review): if 0/0 precision or recall was meant to count as 1, change the fill value here.
    with np.errstate(divide="ignore", invalid="ignore"):
        eval_precision = np.nan_to_num(np.sum(eval_tp, axis=0) / np.sum(masked_pred, axis=0), nan=0.0)
        eval_recall = np.nan_to_num(np.sum(eval_tp, axis=0) / np.sum(masked_targets, axis=0), nan=0.0)
        eval_f1 = np.nan_to_num(2 * eval_precision * eval_recall / (eval_precision + eval_recall), nan=0.0)
    per_type_instance_num = np.sum(masked_targets, axis=0)
    per_type_instance_num[per_type_instance_num == 0] = 1
    # A gold type counts as correct when it is ranked within the top-k scores of its row,
    # k being the number of gold types of that row.
    per_type_correct_instance_num = np.sum(
        eval_mask[:, np.newaxis] * (eval_prediction_ranks < eval_targets.sum(axis=1)[:, np.newaxis]) * eval_targets,
        axis=0,
    )
    per_type_accuracy[mode] = per_type_correct_instance_num / per_type_instance_num
    per_type_precision[mode] = eval_precision
    per_type_recall[mode] = eval_recall
    per_type_f1[mode] = eval_f1
    # Micro-averaged metrics over all (row, type) pairs.
    precision[mode] = np.sum(eval_tp) / np.sum(masked_pred)
    recall[mode] = np.sum(eval_tp) / np.sum(masked_targets)
    f1[mode] = 2 * precision[mode] * recall[mode] / (precision[mode] + recall[mode])

In [None]:
# Top-1 accuracy over columns for mode 3, after averaging the score tensors stored per table.
total_corr, total_valid = 0, 0
errors = []
for table_id, result in per_table_result[3].items():
    prediction_scores, label_mask, label = result
    prediction_scores = torch.stack(prediction_scores, 0).mean(0)
    top_types = prediction_scores.argmax(-1).tolist()
    # label[col, type] is non-zero when `type` is a gold type of that column
    current_corr = sum(label[col_idx, pred].item() for col_idx, pred in enumerate(top_types))
    n_valid = label_mask.sum().item()
    total_valid += n_valid
    total_corr += current_corr
    # remember every table with at least one wrongly typed column
    if current_corr != n_valid:
        errors.append(table_id)
print(total_corr / total_valid, total_valid)

In [None]:
# Per-type F1 for every mode, most frequent types first. Column order of the modes: 0, 4, 1, 3, 2, 5.
f1_mode_order = (0, 4, 1, 3, 2, 5)
for t, i in sorted(type_vocab.items(), key=lambda z: -per_type_instance_num[z[1]]):
    row = (t, per_type_instance_num[i]) + tuple(per_type_f1[m][i] for m in f1_mode_order)
    print("%s %.4f %.4f %.4f %.4f %.4f  %.4f %.4f" % row)
    print()

In [None]:
f1, precision, recall

Type mapping is used to map the types used in some other datasets to our types, so we can directly evaluate without retraining our model

In [None]:
t2d_type_mapping = {
    "Election": ["government.election"],
    "Film": ["film.film"],
    "mountain": ["geography.mountain"],
    "Building": ["architecture.building"],
    "RadioStation": ["broadcast.radio_station"],
    "TelevisionShow": ["tv.tv_program"],
    "Country": ["location.country"],
    "Airport": ["aviation.airport"],
    "AdministrativeRegion": ["location.region"],
    "University": ["education.university"],
    "Newspaper": ["book.newspaper"],
    "FictionalCharacter": ["fictional_universe.fictional_character"],
    "Currency": ["finance.currency"],
    "Novel": ["book.book"],
    "Wrestler": ["sports.pro_athlete"],
    "swimmer": ["sports.pro_athlete"],
    "GolfPlayer": ["sports.golfer", "sports.pro_athlete"],
    "Book": ["book.book"],
    "Political Party": ["government.political_party"],
    "Person": ["people.person"],
    "VideoGame": ["cvg.computer_videogame"],
    "Animal": ["biology.animal"],
    "PoliticalParty": ["government.political_party"],
    "BaseballPlayer": ["sports.pro_athlete"],
    "Monarch": ["royalty.monarch"],
    "Mountain": ["geography.mountain"],
    "City": ["location.citytown"],
    "Company": ["business.consumer_company"],
    "cricketer": ["sports.pro_athlete"],
    "Airline": ["aviation.airline"],
}
# All vocabulary types that appear as a mapping target ...
t2d_types = {target for targets in t2d_type_mapping.values() for target in targets}
# ... and the indices of every vocabulary type that does not.
t2d_invalid = [i for t, i in type_vocab.items() if t not in t2d_types]

In [None]:
t2d_type_mapping = {
    "City": ["location.citytown"],
    "VideoGame": ["cvg.computer_videogame"],
    "Mountain": ["geography.mountain"],
    "Museum": [],
    "Writer": ["film.writer", "tv.tv_writer", "music.writer", "book.author"],
    "Lake": [],
    "AdministrativeRegion": ["location.administrative_division"],
    "Book": ["book.book"],
    "Saint": [],
    "Monarch": ["royalty.monarch"],
    "Bird": [],
    "Plant": [],
    "Mayor": [],
    "Currency": ["finance.currency"],
    "MovieDirector": ["film.director"],
    "Company": [
        "film.production_company",
        "automotive.company",
        "business.consumer_company",
        "business.defunct_company",
    ],
    "Genre": [
        "cvg.cvg_genre",
        "film.film_genre",
        "broadcast.genre",
        "media_common.media_genre",
        "tv.tv_genre",
        "music.genre",
    ],
    "GovernmentType": ["government.governmental_body"],
    "Hospital": [],
    "Building": ["architecture.building"],
    "PoliticalParty": ["government.political_party"],
    "Language": ["language.human_language"],
    "Country": ["location.country"],
    "University": ["education.university"],
    "SportsTeam": ["sports.sports_team"],
    "RadioStation": ["broadcast.radio_station"],
    "Airport": ["aviation.airport"],
    "Airline": ["aviation.airline"],
    "Wrestler": [],
    "Newspaper": ["book.newspaper"],
    "Mammal": [],
    "MountainRange": [],
    "BaseballPlayer": ["baseball.baseball_player"],
    "AcademicJournal": [],
    "Scientist": [],
    "Continent": [],
    "Film": ["film.film"],
}

# All vocabulary types that appear as a mapping target ...
t2d_types = {target for targets in t2d_type_mapping.values() for target in targets}
# ... and the indices of every vocabulary type that does not.
t2d_invalid = [i for t, i in type_vocab.items() if t not in t2d_types]

In [None]:
t2d_type_mapping = {
    "Film": ["film.film"],
    "Lake": [],
    "Language": ["language.human_language"],
    "Country": ["location.country"],
    "Company": [
        "film.production_company",
        "automotive.company",
        "business.consumer_company",
        "business.defunct_company",
    ],
    "Person": ["people.person"],
    "VideoGame": ["cvg.computer_videogame"],
    "City": ["location.citytown"],
    "Currency": ["finance.currency"],
    "Bird": [],
    "Mountain": ["geography.mountain"],
    "Scientist": [],
    "Plant": [],
    "TelevisionShow": ["tv.tv_program"],
    "Animal": [],
    "AdministrativeRegion": ["location.administrative_division"],
    "Genre": [
        "cvg.cvg_genre",
        "film.film_genre",
        "broadcast.genre",
        "media_common.media_genre",
        "tv.tv_genre",
        "music.genre",
    ],
    "Newspaper": ["book.newspaper"],
    "Airport": ["aviation.airport"],
    "AcademicJournal": [],
    "PopulatedPlace": [],
    "Wrestler": [],
    "PoliticalParty": ["government.political_party"],
    "Cricketer": ["cricket.cricket_player"],
    "Eukaryote": [],
    "Saint": [],
    "Writer": ["film.writer", "tv.tv_writer", "music.writer", "book.author"],
    "Museum": [],
    "BaseballPlayer": ["baseball.baseball_player"],
    "EducationalInstitution": ["education.educational_institution"],
    "GovernmentType": ["government.governmental_body"],
    "SportsTeam": ["sports.sports_team"],
}

# Invert the mapping: vocabulary type -> external type name.
reverse_type_mapping = {}
for external_type, targets in t2d_type_mapping.items():
    for target in targets:
        reverse_type_mapping[target] = external_type

# All vocabulary types that appear as a mapping target ...
t2d_types = {target for targets in t2d_type_mapping.values() for target in targets}
# ... and the indices of every vocabulary type that does not.
t2d_invalid = [i for t, i in type_vocab.items() if t not in t2d_types]

In [None]:
errors

In [None]:
# Precision / recall / F1 for mode 4 after mapping vocabulary types to the external type names.
p = 0  # number of gold (column, external type) pairs
pred = 0  # number of predicted (column, external type) pairs
tp = 0  # pairs that are both gold and predicted
for table_id, result in per_table_result[4].items():
    prediction_scores, label_mask, label = result
    prediction_scores = torch.stack(prediction_scores, 0).mean(0)
    for col_idx in range(label.shape[0]):
        if label_mask[col_idx] == 0:
            continue
        # `nonzero(as_tuple=True)[0]` yields ALL non-zero indices. On a tensor, plain `.nonzero()[0]`
        # is only the first row of the (k, 1) index matrix, i.e. a single index.
        gt_ids = label[col_idx].nonzero(as_tuple=True)[0].tolist()
        # a type is predicted when its logit is positive (sigmoid > 0.5); the list may be empty
        pred_ids = (prediction_scores[col_idx] > 0).nonzero(as_tuple=True)[0].tolist()
        # A KeyError here means a type without an external counterpart was labelled or predicted.
        gt_t = {reverse_type_mapping[id2type[t]] for t in gt_ids}
        pred_t = {reverse_type_mapping[id2type[t]] for t in pred_ids}
        p += len(gt_t)
        pred += len(pred_t)
        tp += len(gt_t & pred_t)
# Guard the ratios so that an empty prediction / gold set reports 0 instead of raising.
precision = tp / pred if pred else 0.0
recall = tp / p if p else 0.0
f1 = 2 * precision * recall / (precision + recall) if (precision + recall) else 0.0
print(f1, precision, recall)

In [None]:
p

In [None]:
pred

In [None]:
tp

In [None]:
label[1].nonzero()[0].tolist()

In [None]:
1 if label_mask[1] == 0 else 0

In [None]:
type_vocab

In [None]:
per_table_result["64499281_8_7181683886563136802"][0][1].argsort(-1)

# CT - Semtab

In [None]:
# NOTE(review): this rebinds the global `data_dir` (a relative path); when the notebook is run
# top to bottom, every later cell that reads `data_dir` uses this directory.
data_dir = "data/Semtab"
# `type_vocab` is replaced by the vocabulary found under the new `data_dir`.
type_vocab = load_type_vocab(data_dir)
# Build the "wiki_test30" split of the column-type dataset.
# NOTE(review): `entity_vocab` must already exist from an earlier cell of the notebook.
test_dataset = WikiCTDataset(
    data_dir,
    entity_vocab,
    type_vocab,
    max_input_tok=500,
    src="wiki_test30",
    max_length=[50, 10, 10],
    force_new=False,
    tokenizer=None,
)

In [None]:
len(type_vocab)

In [None]:
# Inverse of `type_vocab`: vocabulary index -> type name.
id2type = dict(zip(type_vocab.values(), type_vocab.keys()))

In [None]:
def average_precision(output, relevance_labels):
    """Compute average precision for each row of scores.

    Args:
        output: score tensor; candidates are ranked along the last dimension
            (higher score = better rank).
        relevance_labels: tensor with the same shape as ``output``; non-zero
            entries mark relevant candidates.

    Returns:
        Tensor with one average-precision value per row. Rows that have no
        relevant candidate evaluate to 0.
    """
    with torch.no_grad():
        # Reorder the labels so that position k holds the label of the k-th best score.
        ranking = output.argsort(dim=-1, descending=True)
        ranked_hits = relevance_labels.gather(-1, ranking).float()
        hits_so_far = ranked_hits.cumsum(dim=-1)
        positions = torch.arange(start=1, end=hits_so_far.shape[-1] + 1, device=hits_so_far.device)
        # precision@k, kept only at the ranks where a relevant item sits
        precision_at_hits = hits_so_far / positions[None, :] * ranked_hits
        num_relevant = ranked_hits.sum(dim=-1)
        # avoid 0/0 for rows without any relevant item (their numerator is 0 anyway)
        num_relevant[num_relevant == 0] = 1
        return precision_at_hits.sum(dim=-1) / num_relevant

In [None]:
# Reset the result containers for the Semtab evaluation. The evaluation cell below stores
# its metrics under the key `mode` in the per-type / overall metric dictionaries.
per_type_accuracy = {}
per_type_precision = {}
per_type_recall = {}
per_type_f1 = {}
# NOTE(review): `map` shadows the Python builtin of the same name.
map = {}
precision = {}
recall = {}
f1 = {}

In [None]:
from tqdm.autonotebook import tqdm

# NOTE(review): relative checkpoint path; it resolves against the notebook's working directory.
checkpoint = "output/CT/Semtab/wiki_train70/4/model_v1_table_0.2_0.6_0.7_10000_1e-4_candnew_0_adam/pytorch_model.bin"
mode = 5
print("Mode:", mode)
config_class, model_class, _ = MODEL_CLASSES["CT"]
config = config_class.from_pretrained(config_name)
config.class_num = len(type_vocab)
config.mode = mode
model = model_class(config, is_simple=True)
# Load onto the CPU first so the checkpoint does not depend on the GPU it was saved from;
# the model is moved to `device` right after.
checkpoint = torch.load(checkpoint, map_location="cpu")
model.load_state_dict(checkpoint)
model.to(device)
model.eval()
eval_batch_size = 20
eval_sampler = SequentialSampler(test_dataset)
eval_dataloader = CTLoader(test_dataset, sampler=eval_sampler, batch_size=eval_batch_size, is_train=False)
eval_loss = 0.0
eval_map = 0.0
nb_eval_steps = 0
eval_targets = []
eval_prediction_scores = []
eval_pred = []
eval_mask = []
# table_id -> [list of score tensors, label mask, labels]
per_table_result = {}
for batch in tqdm(eval_dataloader, desc="Evaluating"):
    # First element holds the table ids; the remaining 13 elements are moved to `device`.
    table_ids, *batch_tensors = batch
    (
        input_tok,
        input_tok_type,
        input_tok_pos,
        input_tok_mask,
        input_ent_text,
        input_ent_text_length,
        input_ent,
        input_ent_type,
        input_ent_mask,
        column_entity_mask,
        column_header_mask,
        labels_mask,
        labels,
    ) = (tensor.to(device) for tensor in batch_tensors)

    # Ablations: withhold some of the inputs depending on `mode`.
    if mode in (1, 4, 5):
        # no token inputs: keep only the part of the entity mask past the token positions
        input_ent_mask = input_ent_mask[:, :, input_tok_mask.shape[1] :]
        input_tok = input_tok_type = input_tok_pos = input_tok_mask = None
    elif mode == 2:
        # no entity inputs
        input_tok_mask = input_tok_mask[:, :, : input_tok_mask.shape[1]]
        input_ent_text = input_ent_text_length = None
        input_ent = input_ent_type = input_ent_mask = None
    if mode in (3, 4):
        input_ent = None
    if mode == 5:
        input_ent_text = input_ent_text_length = None

    with torch.no_grad():
        outputs = model(
            input_tok,
            input_tok_type,
            input_tok_pos,
            input_tok_mask,
            input_ent_text,
            input_ent_text_length,
            input_ent,
            input_ent_type,
            input_ent_mask,
            column_entity_mask,
            column_header_mask,
            labels_mask,
            labels,
        )
        loss, prediction_scores = outputs[0], outputs[1]
        # Collect the scores per table id; the same id can show up several times, so a list is kept.
        for idx, table_id in enumerate(table_ids):
            valid = labels_mask[idx].nonzero().max().item() + 1
            if table_id not in per_table_result:
                per_table_result[table_id] = [[], labels_mask[idx, :valid], labels[idx, :valid]]
            per_table_result[table_id][0].append(prediction_scores[idx, :valid])

        flat_scores = prediction_scores.view(-1, config.class_num)
        eval_loss += loss.mean().item()
        eval_targets.extend(labels.view(-1, config.class_num).tolist())
        eval_prediction_scores.extend(flat_scores.tolist())
        # Single-label decision: a type is predicted when it attains the row maximum.
        eval_pred.extend((flat_scores == flat_scores.max(-1)[0][:, None]).tolist())
        eval_mask.extend(labels_mask.view(-1).tolist())
    nb_eval_steps += 1

eval_targets = np.array(eval_targets)
eval_prediction_scores = np.array(eval_prediction_scores)
eval_mask = np.array(eval_mask)
eval_prediction_ranks = np.argsort(np.argsort(-eval_prediction_scores))
eval_pred = np.array(eval_pred)
masked_pred = eval_mask[:, np.newaxis] * eval_pred
masked_targets = eval_mask[:, np.newaxis] * eval_targets
eval_tp = masked_pred * eval_targets
# Per-type precision / recall / F1. A type without predictions (or without instances) gives 0/0.
# `nan=0.0` is spelled out because the second positional parameter of `np.nan_to_num` is `copy`,
# not the fill value: the previous `np.nan_to_num(x, 1)` calls also replaced NaN with 0.0.
# NOTE(review): if 0/0 precision or recall was meant to count as 1, change the fill value here.
with np.errstate(divide="ignore", invalid="ignore"):
    eval_precision = np.nan_to_num(np.sum(eval_tp, axis=0) / np.sum(masked_pred, axis=0), nan=0.0)
    eval_recall = np.nan_to_num(np.sum(eval_tp, axis=0) / np.sum(masked_targets, axis=0), nan=0.0)
    eval_f1 = np.nan_to_num(2 * eval_precision * eval_recall / (eval_precision + eval_recall), nan=0.0)
per_type_instance_num = np.sum(masked_targets, axis=0)
per_type_instance_num[per_type_instance_num == 0] = 1
# A gold type counts as correct when it is ranked within the top-k scores of its row,
# k being the number of gold types of that row.
per_type_correct_instance_num = np.sum(
    eval_mask[:, np.newaxis] * (eval_prediction_ranks < eval_targets.sum(axis=1)[:, np.newaxis]) * eval_targets, axis=0
)
per_type_accuracy[mode] = per_type_correct_instance_num / per_type_instance_num
per_type_precision[mode] = eval_precision
per_type_recall[mode] = eval_recall
per_type_f1[mode] = eval_f1
# Micro-averaged metrics over all (row, type) pairs.
precision[mode] = np.sum(eval_tp) / np.sum(masked_pred)
recall[mode] = np.sum(eval_tp) / np.sum(masked_targets)
f1[mode] = 2 * precision[mode] * recall[mode] / (precision[mode] + recall[mode])

In [None]:
wiki_types = [
    "City",
    "VideoGame",
    "Mountain",
    "Writer",
    "Lake",
    "AdministrativeRegion",
    "Book",
    "Saint",
    "Monarch",
    "Bird",
    "Plant",
    "Currency",
    "Company",
    "Genre",
    "Building",
    "PoliticalParty",
    "Language",
    "Country",
    "University",
    "SportsTeam",
    "RadioStation",
    "Airport",
    "Wrestler",
    "Newspaper",
    "Mammal",
    "Mayor",
    "AcademicJournal",
    "Scientist",
    "Continent",
    "Film",
    "BaseballPlayer",
]
# Vocabulary types that are not in the `wiki_types` name list.
non_wiki_types = [name for name in type_vocab if name not in wiki_types]
# From here on `wiki_types` holds vocabulary indices instead of names.
wiki_types = {type_vocab[name] for name in wiki_types}
# Additive score mask: 0 for the selected types, -10000 for all other types.
wiki_type_mask = torch.full((len(type_vocab),), -10000.0).to(device)
for type_id in wiki_types:
    wiki_type_mask[type_id] = 0

In [None]:
# Per-type F1 and instance count for the configuration evaluated above. The results are stored
# under `per_type_f1[mode]`, so index with `mode` rather than a hard-coded key that may not exist.
for t, i in type_vocab.items():
    print(t, per_type_f1[mode][i], per_type_instance_num[i])

In [None]:
non_wiki_types

In [None]:
type_vocab

In [None]:
# Column-level accuracy restricted to the types allowed by `wiki_type_mask`.
total_corr, total_valid, errors = 0, 0, []
for table_id, result in per_table_result.items():
    prediction_scores, label_mask, label = result
    prediction_scores = torch.stack(prediction_scores, 0).mean(0)
    # Type 15 receives the larger of the scores of types 15 and 27.
    # NOTE(review): 15 and 27 are hard-coded vocabulary indices — confirm which types they denote.
    score_15, score_27 = prediction_scores[:, 15], prediction_scores[:, 27]
    prediction_scores[:, 15] = torch.where(score_15 > score_27, score_15, score_27)
    prediction_scores += wiki_type_mask.unsqueeze(0)
    # Per column: 1 when a best-scoring type is a gold type.
    is_best = prediction_scores == prediction_scores.max(-1, keepdim=True).values
    pred_acc = (is_best * label).sum(-1)
    n_correct = pred_acc.sum().item()
    n_valid = label_mask.sum().item()
    total_valid += n_valid
    total_corr += n_correct
    if n_correct != n_valid:
        errors.append(table_id)

In [None]:
total_corr

In [None]:
total_valid

In [None]:
errors

In [None]:
# For every table with an error, show each labelled column's gold type and its three best-scoring types.
for inspect_id in errors:
    print(inspect_id)
    prediction_scores, label_mask, label = per_table_result[inspect_id]
    prediction_scores = torch.stack(prediction_scores, 0).mean(0)
    for col_id in range(label.shape[0]):
        if label[col_id].sum().item() != 0:
            display(id2type[label[col_id].nonzero().item()])
            # ascending argsort, reversed, so the best score comes first
            ranked_types = prediction_scores[col_id].argsort().tolist()
            top3 = ranked_types[::-1][:3]
            display([[id2type[type_id], prediction_scores[col_id, type_id].item()] for type_id in top3])

In [None]:
label[col_id]

In [None]:
per_table_result["Baseball_Hall_of_Fame_balloting,_2015#0"][0][0].argsort()

<a class="anchor" id="re"></a>
# RE

In [None]:
# From here on `type_vocab` holds the relation vocabulary instead of the column-type vocabulary.
# NOTE(review): `data_dir` is whatever the most recently executed cell set it to; when the notebook
# is run top to bottom that is "data/Semtab" — confirm this is the intended directory.
type_vocab = load_relation_vocab(data_dir)
# NOTE(review): `config` is the object left behind by an earlier evaluation cell; the RE
# evaluation loop builds its own config and sets `class_num` again.
config.class_num = len(type_vocab)

In [None]:
# Build the "test" split of the relation-extraction dataset.
# NOTE(review): `entity_vocab` must already exist from an earlier cell of the notebook.
eval_dataset = REDataset(
    data_dir,
    entity_vocab,
    type_vocab,
    max_input_tok=500,
    src="test",
    max_length=[50, 10, 10],
    force_new=False,
    tokenizer=None,
)

In [None]:
# Reset the result containers for the relation-extraction evaluation. The evaluation loop below
# stores one entry per ablation `mode` in the per-type / overall metric dictionaries.
per_type_accuracy = {}
per_type_precision = {}
per_type_recall = {}
per_type_f1 = {}
# NOTE(review): `map` shadows the Python builtin of the same name.
map = {}
precision = {}
recall = {}
f1 = {}

In [None]:
from tqdm.autonotebook import tqdm

# NOTE(review): this rebinds the global `config_name` to a relative path.
config_name = "configs/table-base-config_v2.json"
# One fine-tuned relation-extraction checkpoint per ablation mode (list index == mode).
checkpoints = [
    "output/RE/v2/0/model_v1_table_0.2_0.6_0.7_10000_1e-4_candnew_0_adam/pytorch_model.bin",
    "output/RE/v2/1/model_v1_table_0.2_0.6_0.7_10000_1e-4_candnew_0_adam/pytorch_model.bin",
    "output/RE/v2/2/model_v1_table_0.2_0.6_0.7_10000_1e-4_candnew_0_adam/pytorch_model.bin",
    "output/RE/v2/3/model_v1_table_0.2_0.6_0.7_10000_1e-4_candnew_0_adam/pytorch_model.bin",
    "output/RE/v2/4/model_v1_table_0.2_0.6_0.7_10000_1e-4_candnew_0_adam/pytorch_model.bin",
    "output/RE/v2/5/model_v1_table_0.2_0.6_0.7_10000_1e-4_candnew_0_adam/pytorch_model.bin",
]
for mode in [0, 1, 2, 3, 4, 5]:
    print(mode)
    config_class, model_class, _ = MODEL_CLASSES["RE"]
    config = config_class.from_pretrained(config_name)
    config.class_num = len(type_vocab)
    config.mode = mode
    model = model_class(config, is_simple=True)
    checkpoint = checkpoints[mode]
    # Load onto the CPU first so the checkpoint does not depend on the GPU it was saved from;
    # the model is moved to `device` right after.
    checkpoint = torch.load(checkpoint, map_location="cpu")
    model.load_state_dict(checkpoint)
    model.to(device)
    model.eval()
    eval_batch_size = 20
    eval_sampler = SequentialSampler(eval_dataset)
    eval_dataloader = RELoader(eval_dataset, sampler=eval_sampler, batch_size=eval_batch_size, is_train=False)
    eval_loss = 0.0
    eval_map = 0.0
    nb_eval_steps = 0
    eval_targets = []
    eval_prediction_scores = []
    eval_pred = []
    eval_mask = []
    for batch in tqdm(eval_dataloader, desc="Evaluating"):
        # First element holds the table ids; the remaining 13 elements are moved to `device`.
        table_id, *batch_tensors = batch
        (
            input_tok,
            input_tok_type,
            input_tok_pos,
            input_tok_mask,
            input_ent_text,
            input_ent_text_length,
            input_ent,
            input_ent_type,
            input_ent_mask,
            column_entity_mask,
            column_header_mask,
            labels_mask,
            labels,
        ) = (tensor.to(device) for tensor in batch_tensors)

        # Ablations: withhold some of the inputs depending on `mode`.
        if mode in (1, 4, 5):
            # no token inputs: keep only the part of the entity mask past the token positions
            input_ent_mask = input_ent_mask[:, :, input_tok_mask.shape[1] :]
            input_tok = input_tok_type = input_tok_pos = input_tok_mask = None
        elif mode == 2:
            # no entity inputs
            input_tok_mask = input_tok_mask[:, :, : input_tok_mask.shape[1]]
            input_ent_text = input_ent_text_length = None
            input_ent = input_ent_type = input_ent_mask = None
        if mode in (3, 4):
            input_ent = None
        if mode == 5:
            input_ent_text = input_ent_text_length = None

        with torch.no_grad():
            outputs = model(
                input_tok,
                input_tok_type,
                input_tok_pos,
                input_tok_mask,
                input_ent_text,
                input_ent_text_length,
                input_ent,
                input_ent_type,
                input_ent_mask,
                column_entity_mask,
                column_header_mask,
                labels_mask,
                labels,
            )
            loss, prediction_scores = outputs[0], outputs[1]
            flat_scores = prediction_scores.view(-1, config.class_num)
            flat_labels = labels.view((-1, config.class_num))
            ap = metric.average_precision(flat_scores, flat_labels)
            # Named `batch_map` so that neither the builtin `map` nor the `map` dict is overwritten.
            batch_map = (ap * labels_mask.view(-1)).sum() / labels_mask.sum()
            eval_loss += loss.mean().item()
            eval_map += batch_map.item()
            eval_targets.extend(flat_labels.tolist())
            eval_prediction_scores.extend(flat_scores.tolist())
            eval_pred.extend((torch.sigmoid(flat_scores) > 0.5).tolist())
            eval_mask.extend(labels_mask.view(-1).tolist())
        nb_eval_steps += 1
    print(eval_map / nb_eval_steps)

    eval_targets = np.array(eval_targets)
    eval_prediction_scores = np.array(eval_prediction_scores)
    eval_mask = np.array(eval_mask)
    eval_prediction_ranks = np.argsort(np.argsort(-eval_prediction_scores))
    eval_pred = np.array(eval_pred)
    masked_pred = eval_mask[:, np.newaxis] * eval_pred
    masked_targets = eval_mask[:, np.newaxis] * eval_targets
    eval_tp = masked_pred * eval_targets
    # Per-type precision / recall / F1. A type without predictions (or without instances) gives 0/0.
    # `nan=0.0` is spelled out because the second positional parameter of `np.nan_to_num` is `copy`,
    # not the fill value: the previous `np.nan_to_num(x, 1)` calls also replaced NaN with 0.0.
    # NOTE(review): if 0/0 precision or recall was meant to count as 1, change the fill value here.
    with np.errstate(divide="ignore", invalid="ignore"):
        eval_precision = np.nan_to_num(np.sum(eval_tp, axis=0) / np.sum(masked_pred, axis=0), nan=0.0)
        eval_recall = np.nan_to_num(np.sum(eval_tp, axis=0) / np.sum(masked_targets, axis=0), nan=0.0)
        eval_f1 = np.nan_to_num(2 * eval_precision * eval_recall / (eval_precision + eval_recall), nan=0.0)
    per_type_instance_num = np.sum(masked_targets, axis=0)
    per_type_instance_num[per_type_instance_num == 0] = 1
    # A gold relation counts as correct when it is ranked within the top-k scores of its row,
    # k being the number of gold relations of that row.
    per_type_correct_instance_num = np.sum(
        eval_mask[:, np.newaxis] * (eval_prediction_ranks < eval_targets.sum(axis=1)[:, np.newaxis]) * eval_targets,
        axis=0,
    )
    per_type_accuracy[mode] = per_type_correct_instance_num / per_type_instance_num
    per_type_precision[mode] = eval_precision
    per_type_recall[mode] = eval_recall
    per_type_f1[mode] = eval_f1
    # Micro-averaged metrics over all (row, relation) pairs.
    precision[mode] = np.sum(eval_tp) / np.sum(masked_pred)
    recall[mode] = np.sum(eval_tp) / np.sum(masked_targets)
    f1[mode] = 2 * precision[mode] * recall[mode] / (precision[mode] + recall[mode])

In [None]:
# Per-relation F1, precision and recall (one line each), most frequent relations first.
# Column order of the modes: 5, 0, 1, 3, 2.
re_row_format = "%s %.4f %.4f %.4f %.4f %.4f  %.4f"
re_mode_order = (5, 0, 1, 3, 2)
for t, i in sorted(type_vocab.items(), key=lambda z: -per_type_instance_num[z[1]]):
    for per_type_metric in (per_type_f1, per_type_precision, per_type_recall):
        row = (t, per_type_instance_num[i]) + tuple(per_type_metric[m][i] for m in re_mode_order)
        print(re_row_format % row)
    print()

In [None]:
f1

In [None]:
precision

In [None]:
recall

In [None]:
# Build the dataset for the BERT relation-extraction baseline; this rebinds `eval_dataset`.
# NOTE(review): this loads the "dev" split, whereas the REDataset cell above loads "test" —
# confirm the split if the two models are meant to be compared on the same data.
eval_dataset = REBERTDataset(
    data_dir,
    entity_vocab,
    type_vocab,
    max_input_tok=500,
    src="dev",
    max_length=[50, 10, 10],
    force_new=False,
    tokenizer=None,
)

In [None]:
from tqdm.autonotebook import tqdm

# Evaluate the BERT relation-extraction baseline (BertRE) on `eval_dataset` and
# store its metrics under mode -1 in the shared per-mode result dicts
# (per_type_accuracy / per_type_precision / per_type_recall / per_type_f1 /
# precision / recall / f1), which must already exist.
mode = -1
# Dedicated name: the global `config_name` set at the top of the notebook is read
# again by later table-model cells and must not be overwritten here.
rebert_config_name = "configs/tiny-bert-config.json"
checkpoint_path = "output/RE_Bert/v2/pytorch_model.bin"
config_class, model_class, _ = MODEL_CLASSES["REBERT"]
config = config_class.from_pretrained(rebert_config_name)
config.num_labels = len(type_vocab)
config.mode = mode
model = model_class(config)
# Load to CPU first so the checkpoint does not depend on the GPU it was saved from;
# the model is moved to `device` right after.
checkpoint = torch.load(checkpoint_path, map_location="cpu")
model.load_state_dict(checkpoint)
model.to(device)
model.eval()
eval_batch_size = 20
eval_sampler = SequentialSampler(eval_dataset)
eval_dataloader = REBERTLoader(eval_dataset, sampler=eval_sampler, batch_size=eval_batch_size, is_train=False)
eval_loss = 0.0
eval_map = 0.0
nb_eval_steps = 0
eval_targets = []
eval_prediction_scores = []
eval_pred = []
for batch in tqdm(eval_dataloader, desc="Evaluating"):
    table_id, input_tok, input_tok_mask, labels = batch
    input_tok = input_tok.to(device)
    input_tok_mask = input_tok_mask.to(device)
    labels = labels.to(device)
    with torch.no_grad():
        outputs = model(input_tok, attention_mask=input_tok_mask, labels=labels)
        loss = outputs[0]
        prediction_scores = outputs[1]
        # Flatten to (n_examples, num_labels) once and reuse below.
        flat_scores = prediction_scores.view(-1, config.num_labels)
        flat_labels = labels.view(-1, config.num_labels)
        ap = metric.average_precision(flat_scores, flat_labels)
        # Batch-level mean AP (named so it does not shadow the builtin `map`).
        batch_map = ap.sum() / len(ap)
        eval_loss += loss.mean().item()
        eval_map += batch_map.item()
        eval_targets.extend(flat_labels.tolist())
        eval_prediction_scores.extend(flat_scores.tolist())
        # A label is predicted when its sigmoid probability exceeds 0.5.
        eval_pred.extend((torch.sigmoid(flat_scores) > 0.5).tolist())
    nb_eval_steps += 1
# Mean of the per-batch MAP values.
print(eval_map / nb_eval_steps)
eval_targets = np.array(eval_targets)
eval_prediction_scores = np.array(eval_prediction_scores)
# Rank of every label within its example (0 = highest score).
eval_prediction_ranks = np.argsort(np.argsort(-eval_prediction_scores))
eval_pred = np.array(eval_pred)
eval_tp = eval_pred * eval_targets
# Per-type metrics. 0/0 divisions yield NaN, which nan_to_num turns into 0.0.
# (The previous positional 1 / 0 arguments were interpreted by numpy as `copy`,
# not as a fill value, so the resulting numbers are unchanged.)
eval_precision = np.sum(eval_tp, axis=0) / np.sum(eval_pred, axis=0)
eval_precision = np.nan_to_num(eval_precision)
eval_recall = np.sum(eval_tp, axis=0) / np.sum(eval_targets, axis=0)
eval_recall = np.nan_to_num(eval_recall)
eval_f1 = 2 * eval_precision * eval_recall / (eval_precision + eval_recall)
eval_f1 = np.nan_to_num(eval_f1)
per_type_instance_num = np.sum(eval_targets, axis=0)
# Avoid dividing by zero for types that never occur in the targets.
per_type_instance_num[per_type_instance_num == 0] = 1
# A positive label counts as correct when it is ranked within the top-n labels of
# its example, n being the number of positive labels of that example.
per_type_correct_instance_num = np.sum(
    (eval_prediction_ranks < eval_targets.sum(axis=1)[:, np.newaxis]) * eval_targets, axis=0
)
per_type_accuracy[mode] = per_type_correct_instance_num / per_type_instance_num
per_type_precision[mode] = eval_precision
per_type_recall[mode] = eval_recall
per_type_f1[mode] = eval_f1
# Overall metrics pooled over all types.
precision[mode] = np.sum(eval_tp) / np.sum(eval_pred)
recall[mode] = np.sum(eval_tp) / np.sum(eval_targets)
f1[mode] = 2 * precision[mode] * recall[mode] / (precision[mode] + recall[mode])

In [None]:
from tqdm.autonotebook import tqdm

# kNN header baseline on the test set. For each table, retrieve the k nearest
# training tables by TF-IDF cosine distance over x[1][1 : x[5] + 1] (the token span
# this notebook prints as "caption") and let every neighbour vote for its headers
# with weight 1 / distance. Requires `tfidf`, `neigh`, `train_dataset` and
# `test_dataset` from the Attribute Recommendation cells.
for k in [3, 5, 10]:
    print(k)
    maps_base = []
    recalls = []
    for i, x in enumerate(tqdm(test_dataset)):
        header_count = {h: 0.0 for h in range(config.header_vocab_size)}
        dist, neighbor = neigh.kneighbors(tfidf.transform([x[1][1 : x[5] + 1]]), k, return_distance=True)
        dist = dist.reshape(-1)
        for j, n in enumerate(neighbor.reshape([-1])):
            # An exact match (distance 0) gets a fixed large vote instead of 1 / 0.
            vote = 100 if dist[j] == 0 else 1 / dist[j]
            for h in train_dataset[n][6]:
                header_count[h] += vote
        target_e = set(x[6])
        # Recall of the retrieved candidates: only targets that received a vote count.
        # (`header_count` is pre-filled with every vocabulary id, so testing key
        # membership alone would report 1.0 for every table.)
        recalls.append(sum(1 for z in target_e if header_count.get(z, 0.0) > 0) / len(target_e))
        ap = average_precision(
            [1 if z in target_e else 0 for z, _ in sorted(header_count.items(), key=lambda p: p[1], reverse=True)]
        )
        maps_base.append(ap)
    print(np.mean(maps_base))
    print(np.mean(recalls))

In [None]:
# Mean AP and mean recall of the most recently evaluated k (the lists are reset per k).
print(np.mean(maps_base))
print(np.mean(recalls))

In [None]:
from tqdm.autonotebook import tqdm

# Seeded kNN header baseline on the test set. The first header of each table
# (x[6][0]) is the seed and the remaining headers are the targets. Neighbours that
# also contain the seed header vote with full weight; the others are scaled down
# to 0.00001. Requires `tfidf`, `neigh`, `train_dataset` and `test_dataset` from the
# Attribute Recommendation cells.
for k in [3, 5, 10]:
    print(k)
    maps_base = []
    recalls = []
    for i, x in enumerate(tqdm(test_dataset)):
        header_count = {h: 0.0 for h in range(config.header_vocab_size)}
        dist, neighbor = neigh.kneighbors(tfidf.transform([x[1][1 : x[5] + 1]]), k, return_distance=True)
        dist = dist.reshape(-1)
        target_e = set(x[6][1:])
        seed = x[6][0]
        for j, n in enumerate(neighbor.reshape([-1])):
            # Fetch the neighbour's headers once instead of once per use.
            neighbor_headers = train_dataset[n][6]
            label_score = 1 if seed in neighbor_headers else 0.00001
            # An exact match (distance 0) gets a fixed large vote instead of 1 / 0.
            vote = label_score * 100 if dist[j] == 0 else label_score * 1 / dist[j]
            for h in neighbor_headers:
                if h != seed:
                    header_count[h] += vote
        # Recall of the retrieved candidates: only targets that received a vote count.
        # (`header_count` is pre-filled with every vocabulary id, so testing key
        # membership alone would report 1.0 for every table.)
        recalls.append(sum(1 for z in target_e if header_count.get(z, 0.0) > 0) / len(target_e))
        ap = average_precision(
            [1 if z in target_e else 0 for z, _ in sorted(header_count.items(), key=lambda p: p[1], reverse=True)]
        )
        maps_base.append(ap)
    print(np.mean(maps_base))
    print(np.mean(recalls))

In [None]:
from tqdm.autonotebook import tqdm

# Seeded kNN header baseline on the dev set (k = 10), compared table by table with
# the neural model: whenever the baseline AP beats `maps[i]`, print the table, its
# gold headers, the model's top-10, the nearest training table and the baseline's
# top-10. Requires `maps` and `results` from the header-ranking evaluation cells and
# `tfidf` / `neigh` from the Attribute Recommendation cells.
for k in [10]:
    maps_base = []
    recalls = []
    for i, x in enumerate(tqdm(eval_dataset)):
        header_count = {h: 0.0 for h in range(config.header_vocab_size)}
        dist, neighbor = neigh.kneighbors(tfidf.transform([x[1][1 : x[5] + 1]]), k, return_distance=True)
        dist = dist.reshape(-1)
        target_e = set(x[6][1:])
        seed = x[6][0]
        neighbor = neighbor.reshape([-1])
        for j, n in enumerate(neighbor):
            # Fetch the neighbour's headers once instead of once per use.
            neighbor_headers = train_dataset[n][6]
            label_score = 1 if seed in neighbor_headers else 0.00001
            # An exact match (distance 0) gets a fixed large vote instead of 1 / 0.
            vote = label_score * 100 if dist[j] == 0 else label_score * 1 / dist[j]
            for h in neighbor_headers:
                if h != seed:
                    header_count[h] += vote
        # Recall of the retrieved candidates: only targets that received a vote count.
        # (`header_count` is pre-filled with every vocabulary id, so testing key
        # membership alone would report 1.0 for every table.)
        recalls.append(sum(1 for z in target_e if header_count.get(z, 0.0) > 0) / len(target_e))
        sorted_base = sorted(header_count.items(), key=lambda p: p[1], reverse=True)
        ap = average_precision([1 if z in target_e else 0 for z, _ in sorted_base])
        maps_base.append(ap)
        if ap > maps[i]:
            print("base: {},ours: {}".format(ap, maps[i]))
            print(x[0])
            print("caption", train_dataset.tokenizer.decode(x[1][1 : x[5] + 1]))
            display([train_dataset.header_vocab[z] for z in eval_dataset[i][6]])
            ranked_our = np.argsort(results[i])[::-1]
            display([train_dataset.header_vocab[z] for z in ranked_our[:10]])
            print("neighbor")
            # Closest training table, fetched once for the printouts below.
            nearest = train_dataset[neighbor[0]]
            print(nearest[0])
            print(
                "caption",
                train_dataset.tokenizer.decode(nearest[1][1 : nearest[5] + 1]),
            )
            display([train_dataset.header_vocab[z] for z, _ in sorted_base[:10]])
    print(np.mean(maps_base))
    print(np.mean(recalls))

# Viz

In [None]:
# Build a model and read the raw state dict of a hybrid pre-training checkpoint.
# `checkpoint` starts as the directory path and ends as the loaded state dict.
# NOTE(review): the state dict is not applied to `model` in this cell, and
# `model_class` / `config` are whatever an earlier cell left behind — confirm.
checkpoint = "output/hybrid/model_v1_table_0.2_0.6_0.7_30000_1e-4_with_cand_0"
model = model_class(config, is_simple=True)
checkpoint = torch.load(os.path.join(checkpoint, "pytorch_model.bin"))

In [None]:
# Shape of the entity embedding matrix stored in the loaded state dict.
checkpoint["table.embeddings.ent_embeddings.weight"].shape

In [None]:
# Export the embeddings of the entities listed in entity_vocab_with_type.tsv as a
# TSV file (one entity per line, tab-separated floats) next to the checkpoint.
dump_loc = "output/hybrid/model_v1_table_0.2_0.6_0.7_30000_1e-4_with_cand_0"
entity_vocab_with_type = []
with open("data/wikisql_entity/entity_vocab_with_type.tsv", "r", encoding="utf8") as f:
    next(f)  # skip the first line (presumably a header row — confirm)
    # The first tab-separated column is the integer wiki id.
    entity_vocab_with_type.extend(int(line.strip().split("\t")[0]) for line in f)

ent_embedding_matrix = checkpoint["table.embeddings.ent_embeddings.weight"]
with open(os.path.join(dump_loc, "entity_embedding_with_type.tsv"), "w") as f_e:
    for wiki_id in entity_vocab_with_type:
        # Map wiki id -> embedding row index, then serialise that row.
        embedding_row = ent_embedding_matrix[entity_wikid2id[wiki_id]].tolist()
        f_e.write("{}\n".format("\t".join([str(z) for z in embedding_row])))

# CER

In [None]:
# Training split of the hybrid table dataset, built with mode=1.
# NOTE(review): the meaning of `mode` and of the three `max_length` entries is
# defined in WikiHybridTableDataset — confirm there.
train_dataset = WikiHybridTableDataset(
    data_dir,
    entity_vocab,
    max_cell=100,
    max_input_tok=350,
    max_input_ent=150,
    src="train",
    max_length=[50, 10, 10],
    force_new=False,
    tokenizer=None,
    mode=1,
)

In [None]:
# Iterates the whole training set and reports how many items it has; the
# `x[8] - 2` values themselves are discarded by len().
# NOTE(review): an aggregate of x[8] - 2 (e.g. a mean) may have been intended — confirm.
len([x[8] - 2 for x in train_dataset])

In [None]:
# Dev split of the hybrid table dataset, built with mode=0. The CER cells below
# use it for its `entity_wikid2id` mapping and pass it to CER_build_input.
dataset = WikiHybridTableDataset(
    data_dir,
    entity_vocab,
    max_cell=100,
    max_input_tok=350,
    max_input_ent=150,
    src="dev",
    max_length=[50, 10, 10],
    force_new=False,
    tokenizer=None,
    mode=0,
)

In [None]:
# Load the CER model (checkpoint directory named "...seed_1...") for evaluation.
# `checkpoint` starts as the file path and is rebound to the loaded state dict.
checkpoint = "./output/CER/v2/model_v1_table_0.2_0.6_0.7_10000_1e-4_candnew_0_adam_seed_1_10000/pytorch_model.bin"


config_class, model_class, _ = MODEL_CLASSES["CER"]
# Reads the global `config_name`; make sure no earlier cell has changed it.
config = config_class.from_pretrained(config_name)
config.output_attentions = True

model = model_class(config, is_simple=True)
checkpoint = torch.load(checkpoint)
model.load_state_dict(checkpoint)
model.to(device)
# Last expression of the cell: switches to eval mode and displays the model.
model.eval()

In [None]:
# NOTE: `all_entity_set` is assigned in the following cell; on a fresh kernel run
# that cell first, otherwise this raises NameError.
len(all_entity_set)

In [None]:
# Entities known to the model (keys of the dataset's wiki-id -> index mapping).
all_entity_set = set(dataset.entity_wikid2id.keys())
# Counter of test tables skipped by the evaluation loop.
tables_ignored = 0
# Pre-computed baseline results, indexed by table id in the evaluation loop.
# SECURITY: pickle.load executes arbitrary code; only load files you produced yourself.
cached_baseline = "data/wikitables_v2/CER_test_result.pkl"
with open(cached_baseline, "rb") as f:
    cached_baseline_result = pickle.load(f)

In [None]:
# CER (row population) evaluation with one seed entity per table. For every test
# table with at least 5 known subject-column entities, score the cached candidate
# entities with the neural model and store them next to the cached baseline scores
# (pee / pce / ple / pall) in `results[table_id]`.
seed_num = 1
results = {}
with open(os.path.join(data_dir, "test_tables.jsonl"), "r") as f:
    for line in tqdm(f):
        table = json.loads(line.strip())
        table_id = table.get("_id", "")
        pgEnt = table["pgId"]
        # Page entities unknown to the model are replaced by -1.
        if pgEnt not in all_entity_set:
            pgEnt = -1
        pgTitle = table.get("pgTitle", "").lower()
        secTitle = table.get("sectionTitle", "").lower()
        caption = table.get("tableCaption", "").lower()
        headers = table.get("processed_tableHeaders", [])
        rows = table.get("tableData", {})
        entity_columns = table.get("entityColumn", [])
        headers = [headers[j] for j in entity_columns]
        entity_cells = np.array(table.get("entityCell", [[]]))
        subject = table["subject_column"]
        # Linked entities of the subject column, in row order, as [text, id] pairs.
        core_entities = []
        num_rows = len(rows)
        for i in range(num_rows):
            if entity_cells[i, subject] == 1:
                entity = rows[i][subject]["surfaceLinks"][0]["target"]["id"]
                entity_text = rows[i][subject]["text"]
                core_entities.append([entity_text, entity])
        core_entities = [z for z in core_entities if z[1] in all_entity_set]
        if len(core_entities) < 5:
            tables_ignored += 1
            continue
        # The first `seed_num` entities are given to the model; the rest are targets.
        seed_entities = [z[1] for z in core_entities[:seed_num]]
        seed_entities_text = [z[0] for z in core_entities[:seed_num]]
        target_entities = set([z[1] for z in core_entities[seed_num:]])
        seeds_1, _, _, pall, pee, pce, ple, cand_e, cand_c = cached_baseline_result[table_id]
        if len(target_entities) == 0:
            tables_ignored += 1
            continue
        # The cached baseline must have been computed with the same seed entities.
        assert seeds_1 == set(seed_entities)
        cand_e = set([z for z in cand_e if z in all_entity_set and z not in seed_entities])
        cand_c = set([z for z in cand_c if z in all_entity_set and z not in seed_entities])
        # Keep the set for O(1) membership tests and the list for a stable order
        # that matches the order of the model's output scores.
        entity_cand_set = cand_e | cand_c
        entity_cand = list(entity_cand_set)

        # Restrict the baseline scores to the candidates the model also scores.
        pee = {k: v for k, v in pee.items() if k in entity_cand_set}
        pce = {k: v for k, v in pce.items() if k in entity_cand_set}
        ple = {k: v for k, v in ple.items() if k in entity_cand_set}
        pall = {k: v for k, v in pall.items() if k in entity_cand_set}

        (
            input_tok,
            input_tok_type,
            input_tok_pos,
            input_mask,
            input_ent,
            input_ent_text,
            input_ent_text_length,
            input_ent_type,
            candidate_entity_set,
        ) = CER_build_input(
            pgEnt, pgTitle, secTitle, caption, headers[0], seed_entities, seed_entities_text, entity_cand, dataset
        )

        input_tok = input_tok.to(device)
        input_tok_type = input_tok_type.to(device)
        input_tok_pos = input_tok_pos.to(device)
        input_ent = input_ent.to(device)
        input_ent_text = input_ent_text.to(device)
        input_ent_text_length = input_ent_text_length.to(device)
        input_ent_type = input_ent_type.to(device)
        input_mask = input_mask.to(device)
        candidate_entity_set = candidate_entity_set.to(device)

        with torch.no_grad():
            # NOTE(review): `input_mask` is passed twice (4th and 9th argument),
            # as in the original cell — confirm against the model's forward().
            ent_outputs = model(
                input_tok,
                input_tok_type,
                input_tok_pos,
                input_mask,
                input_ent,
                input_ent_text,
                input_ent_text_length,
                input_ent_type,
                input_mask,
                candidate_entity_set,
                None,
                None,
            )
            ent_prediction_scores = ent_outputs[0][0].tolist()

        # Score of candidate i belongs to entity_cand[i].
        p_neural = {entity: ent_prediction_scores[i] for i, entity in enumerate(entity_cand)}
        results[table_id] = {
            "pgTitle": pgTitle,
            "secTitle": secTitle,
            "caption": caption,
            "headers": headers,
            "cand_all": entity_cand,
            "cand_e": cand_e,
            "cand_c": cand_c,
            "seed_e": seed_entities,
            "target_e": target_entities,
            "p_neural": p_neural,
            "pee": pee,
            "pce": pce,
            "ple": ple,
            "pall": pall,
        }

In [None]:
# Average number of target entities per evaluated table.
np.mean([len(z["target_e"]) for _, z in results.items()])

In [None]:
# Mean average precision of the neural scores and of each cached baseline score
# set over all evaluated tables. Uses the dict-based `get_ap(scores, target_e)`.
print("map neural", np.mean([get_ap(r["p_neural"], r["target_e"]) for r in results.values()]))

# Same neural scores, but candidates outside cand_e are pushed to the bottom of the ranking.
cand_e_only_aps = []
for r in results.values():
    masked_scores = {z: score if z in r["cand_e"] else -10000 for z, score in r["p_neural"].items()}
    cand_e_only_aps.append(get_ap(masked_scores, r["target_e"]))
print("map neural - only cand_e", np.mean(cand_e_only_aps))

for label, score_key in (("map ee", "pee"), ("map le", "ple"), ("map ce", "pce"), ("map all", "pall")):
    print(label, np.mean([get_ap(r[score_key], r["target_e"]) for r in results.values()]))

In [None]:
# Load the CER model (checkpoint directory named "...seed_0...") for evaluation.
# `checkpoint` starts as the file path and is rebound to the loaded state dict.
checkpoint = "./output/CER/v2/model_v1_table_0.2_0.6_0.7_10000_1e-4_candnew_0_adam_seed_0_10000/pytorch_model.bin"


config_class, model_class, _ = MODEL_CLASSES["CER"]
# Reads the global `config_name`; make sure no earlier cell has changed it.
config = config_class.from_pretrained(config_name)
config.output_attentions = True

model = model_class(config, is_simple=True)
checkpoint = torch.load(checkpoint)
model.load_state_dict(checkpoint)
model.to(device)
# Last expression of the cell: switches to eval mode and displays the model.
model.eval()

In [None]:
# Entities known to the model (keys of the dataset's wiki-id -> index mapping).
all_entity_set = set(dataset.entity_wikid2id.keys())
# Counter of test tables skipped by the evaluation loop.
tables_ignored = 0
# NOTE(review): `dev_result` is not used by any visible cell — confirm it is still needed.
dev_result = {}
# Pre-computed baseline results for the zero-seed setting.
# SECURITY: pickle.load executes arbitrary code; only load files you produced yourself.
cached_baseline = "data/wikitables_v2/CER_test_result_seed_0.pkl"
with open(cached_baseline, "rb") as f:
    cached_baseline_result = pickle.load(f)

In [None]:
# CER (row population) evaluation without seed entities. For every test table with
# at least 5 known subject-column entities, score the cached candidates (cand_c)
# with the neural model and store them next to the cached baseline scores
# (pce / ple / pall) in `results[table_id]`. Unlike the one-seed run, these result
# dicts have no "cand_all", "cand_e" or "pee" entries.
seed_num = 0
results = {}
with open(os.path.join(data_dir, "test_tables.jsonl"), "r") as f:
    for line in tqdm(f):
        table = json.loads(line.strip())
        table_id = table.get("_id", "")
        pgEnt = table["pgId"]
        # Page entities unknown to the model are replaced by -1.
        if pgEnt not in all_entity_set:
            pgEnt = -1
        pgTitle = table.get("pgTitle", "").lower()
        secTitle = table.get("sectionTitle", "").lower()
        caption = table.get("tableCaption", "").lower()
        headers = table.get("processed_tableHeaders", [])
        rows = table.get("tableData", {})
        entity_columns = table.get("entityColumn", [])
        headers = [headers[j] for j in entity_columns]
        entity_cells = np.array(table.get("entityCell", [[]]))
        subject = table["subject_column"]
        # Linked entities of the subject column, in row order, as [text, id] pairs.
        core_entities = []
        num_rows = len(rows)
        for i in range(num_rows):
            if entity_cells[i, subject] == 1:
                entity = rows[i][subject]["surfaceLinks"][0]["target"]["id"]
                entity_text = rows[i][subject]["text"]
                core_entities.append([entity_text, entity])
        core_entities = [z for z in core_entities if z[1] in all_entity_set]
        if len(core_entities) < 5:
            tables_ignored += 1
            continue
        # No seeds: every known subject entity is a target.
        seed_entities = []
        seed_entities_text = []
        target_entities = set([z[1] for z in core_entities])
        _, target, _, pall, _, pce, ple, _, cand_c = cached_baseline_result[table_id]
        # The cached baseline must have been computed for the same target set.
        assert target == target_entities
        if len(target_entities) == 0:
            tables_ignored += 1
            continue
        cand_c = set([z for z in cand_c if z in all_entity_set])
        # List form fixes the candidate order that the model's scores follow.
        entity_cand = list(cand_c)

        # Restrict the baseline scores to the candidates the model also scores
        # (membership is tested against the set, not the list).
        pce = {k: v for k, v in pce.items() if k in cand_c}
        ple = {k: v for k, v in ple.items() if k in cand_c}
        pall = {k: v for k, v in pall.items() if k in cand_c}

        (
            input_tok,
            input_tok_type,
            input_tok_pos,
            input_mask,
            input_ent,
            input_ent_text,
            input_ent_text_length,
            input_ent_type,
            candidate_entity_set,
        ) = CER_build_input(
            pgEnt, pgTitle, secTitle, caption, headers[0], seed_entities, seed_entities_text, entity_cand, dataset
        )

        input_tok = input_tok.to(device)
        input_tok_type = input_tok_type.to(device)
        input_tok_pos = input_tok_pos.to(device)
        input_ent = input_ent.to(device)
        input_ent_text = input_ent_text.to(device)
        input_ent_text_length = input_ent_text_length.to(device)
        input_ent_type = input_ent_type.to(device)
        input_mask = input_mask.to(device)
        candidate_entity_set = candidate_entity_set.to(device)

        with torch.no_grad():
            # NOTE(review): `input_mask` is passed twice (4th and 9th argument),
            # as in the original cell — confirm against the model's forward().
            ent_outputs = model(
                input_tok,
                input_tok_type,
                input_tok_pos,
                input_mask,
                input_ent,
                input_ent_text,
                input_ent_text_length,
                input_ent_type,
                input_mask,
                candidate_entity_set,
                None,
                None,
            )
            ent_prediction_scores = ent_outputs[0][0].tolist()

        # Score of candidate i belongs to entity_cand[i].
        p_neural = {entity: ent_prediction_scores[i] for i, entity in enumerate(entity_cand)}
        results[table_id] = {
            "pgTitle": pgTitle,
            "secTitle": secTitle,
            "caption": caption,
            "headers": headers,
            "cand_c": cand_c,
            "seed_e": seed_entities,
            "target_e": target_entities,
            "p_neural": p_neural,
            "pce": pce,
            "ple": ple,
            "pall": pall,
        }

In [None]:
# Number of tables in the cached baseline results.
len(cached_baseline_result)

In [None]:
# Average number of target entities per evaluated table.
np.mean([len(z["target_e"]) for _, z in results.items()])

In [None]:
# Candidate recall (share of targets present in each candidate set) and average
# candidate-set size, for the combined set and for cand_e / cand_c separately.
# NOTE: needs the "cand_all" and "cand_e" keys, which only the one-seed evaluation
# loop stores; on `results` from the zero-seed loop this raises KeyError.
print(
    "recall all",
    np.mean([len(set(x["cand_all"]) & x["target_e"]) / len(x["target_e"]) for _, x in results.items()]),
    np.mean([len(set(x["cand_all"])) for _, x in results.items()]),
)
print(
    "recall e",
    np.mean([len(x["cand_e"] & x["target_e"]) / len(x["target_e"]) for _, x in results.items()]),
    np.mean([len(set(x["cand_e"])) for _, x in results.items()]),
)
print(
    "recall c",
    np.mean([len(x["cand_c"] & x["target_e"]) / len(x["target_e"]) for _, x in results.items()]),
    np.mean([len(set(x["cand_c"])) for _, x in results.items()]),
)

In [None]:
def get_ap(scores, target_e):
    """Average precision of a scored candidate dict against a target collection.

    Args:
        scores: mapping candidate -> score; candidates are ranked by descending score.
        target_e: container of relevant candidates (membership-tested with `in`).

    Returns:
        Whatever `average_precision` returns for the 0/1 relevance list of the ranking.
    """
    by_score_desc = sorted(scores.items(), key=lambda item: item[1], reverse=True)
    relevance = [int(candidate in target_e) for candidate, _ in by_score_desc]
    return average_precision(relevance)

In [None]:
# Mean average precision over all evaluated tables for the zero-seed run. The
# cand_e / pee variants stay commented out because the zero-seed results do not
# store those keys.
print("map neural", np.mean([get_ap(x["p_neural"], x["target_e"]) for _, x in results.items()]))
# print('map neural - only cand_e', np.mean([get_ap({z:score if z in x['cand_e'] else -10000 for z, score in x['p_neural'].items()},x['target_e']) for _,x in results.items()]))
# print('map ee', np.mean([get_ap(x['pee'],x['target_e']) for _,x in results.items()]))
print("map le", np.mean([get_ap(x["ple"], x["target_e"]) for _, x in results.items()]))
print("map ce", np.mean([get_ap(x["pce"], x["target_e"]) for _, x in results.items()]))
print("map all", np.mean([get_ap(x["pall"], x["target_e"]) for _, x in results.items()]))

In [None]:
# Sweep the interpolation weight w between the neural score and the cached "pee"
# score (w * neural + (1 - w) * pee) and report the resulting MAP.
# NOTE: needs the "pee" key, which only the one-seed evaluation loop stores.
for w in [0.999, 0.99, 0.9, 0.5, 0.1, 0.05, 0.06, 0.07, 0.08, 0.09, 0.01]:
    print(
        "map neural - ensemble {}".format(w),
        np.mean(
            [
                get_ap({z: w * score + (1 - w) * x["pee"][z] for z, score in x["p_neural"].items()}, x["target_e"])
                for _, x in results.items()
            ]
        ),
    )

In [None]:
# Collect tables worth a manual look: at least one target is among the candidates,
# and the neural ranking is either poor (AP < 0.4) or worse than the "pee" baseline.
# Needs the "cand_all" and "pee" keys stored by the one-seed evaluation loop.
inspect_ids = []
for table_id, x in results.items():
    # Named `cand_recall` so the global `recall` dict of per-mode results filled by
    # the evaluation cells is not overwritten with a float.
    cand_recall = len(set(x["cand_all"]) & x["target_e"]) / len(x["target_e"])
    ap_neural = get_ap(x["p_neural"], x["target_e"])
    ap_ee = get_ap(x["pee"], x["target_e"])
    if cand_recall != 0 and (ap_neural < 0.4 or ap_neural < ap_ee):
        inspect_ids.append(table_id)
print(len(inspect_ids))

In [None]:
def inspect_result(result):
    """Print a readable summary of one CER result entry.

    Shows the AP of the neural and "pee" rankings, the table context (page title,
    section title, caption, headers), the seed and target entity titles, and the
    top-10 entities of both rankings. Targets are printed as "[title:score]",
    other entities as "title:score".

    Args:
        result: one value of `results` produced by the one-seed CER evaluation
            loop (must contain "p_neural", "pee", "target_e", "seed_e", "pgTitle",
            "secTitle", "caption" and "headers").
    """

    def title_of(wiki_id):
        # wiki id -> vocabulary index -> entity record -> title
        return entity_vocab[entity_wikid2id[wiki_id]]["wiki_title"]

    def show_top10(label, scores):
        ranked = sorted(scores.items(), key=lambda z: z[1], reverse=True)
        print(label)
        entries = []
        for e, score in ranked[:10]:
            if e in result["target_e"]:
                entries.append("[%s:%f]" % (title_of(e), score))
            else:
                entries.append("%s:%.2f" % (title_of(e), score))
        print("; ".join(entries))

    ap_neural = get_ap(result["p_neural"], result["target_e"])
    ap_ee = get_ap(result["pee"], result["target_e"])
    print("ap_neural: {}\nap_ee: {}".format(ap_neural, ap_ee))
    print("{} - {} - {}".format(result["pgTitle"], result["secTitle"], result["caption"]))
    print(result["headers"])
    print("seed:")
    print("; ".join([title_of(e) for e in result["seed_e"]]))
    target_titles = [title_of(z) for z in result["target_e"]]
    print("target:\n%s" % ("; ".join(target_titles)))
    show_top10("neural:", result["p_neural"])
    show_top10("ee:", result["pee"])

In [None]:
# Example failure case, then the number of flagged tables whose first header is a
# sports-fixture style column.
inspect_result(results[inspect_ids[3]])
fixture_headers = ["opponent", "team 1", "home team"]
print(sum(1 for tid in inspect_ids if results[tid]["headers"][0] in fixture_headers))

In [None]:
# Example failure case, then the number of flagged tables from
# "miss dominican republic" pages.
inspect_result(results[inspect_ids[6]])
print(sum(1 for tid in inspect_ids if "miss dominican republic" in results[tid]["pgTitle"]))

In [None]:
# Example failure case, then the number of flagged tables whose first header is
# "constituency".
inspect_result(results[inspect_ids[32]])
print(sum(1 for tid in inspect_ids if results[tid]["headers"][0] == "constituency"))

# Attribute Recommendation

In [None]:
# Header-ranking ("HR") setup: TableConfig + TableHeaderRanking.
config_class, model_class, _ = MODEL_CLASSES["HR"]
# Reads the global `config_name`; make sure no earlier cell has changed it.
config = config_class.from_pretrained(config_name)
config.output_attentions = True

In [None]:
# Train / dev / test splits for header recommendation. These rebind
# `train_dataset` and `eval_dataset`, which earlier sections used for other tasks.
train_dataset = WikiHeaderDataset(
    data_dir, max_input_tok=350, src="train", max_length=[50, 10], force_new=False, tokenizer=None
)
eval_dataset = WikiHeaderDataset(
    data_dir, max_input_tok=350, src="dev", max_length=[50, 10], force_new=False, tokenizer=None
)
test_dataset = WikiHeaderDataset(
    data_dir, max_input_tok=350, src="test", max_length=[50, 10], force_new=False, tokenizer=None
)

In [None]:
# Size of the header vocabulary.
len(eval_dataset.header_vocab)

In [None]:
# Store the header vocabulary size on the config; later cells read it as
# `config.header_vocab_size`. Written through __dict__ rather than normal attribute assignment.
config.__dict__["header_vocab_size"] = len(eval_dataset.header_vocab)

In [None]:
# Load the header-ranking model for evaluation. `checkpoint` starts as the
# directory path and is rebound to the loaded state dict.
checkpoint = "output/HR/v2/1/model_v1_table_0.2_0.6_0.7_10000_1e-4_candnew_0_adam"
# Alternative checkpoint directory, kept for reference:
# checkpoint = "output/HR/bert_seed_0/"
model = model_class(config, is_simple=True)
checkpoint = torch.load(os.path.join(checkpoint, "pytorch_model.bin"))
model.load_state_dict(checkpoint)
model.to(device)
# Last expression of the cell: switches to eval mode and displays the model.
model.eval()

In [None]:
# Sequential dev loader, so `results[i]` lines up with `eval_dataset[i]` in the
# cells that follow.
# NOTE(review): `seed=1` presumably sets the number of seed headers — confirm in WikiHeaderLoader.
eval_batch_size = 64
eval_sampler = SequentialSampler(eval_dataset)
eval_dataloader = WikiHeaderLoader(
    eval_dataset, sampler=eval_sampler, batch_size=eval_batch_size, is_train=False, seed=1
)

In [None]:
# Run the header-ranking model over the dev loader and collect, per example, the
# list of prediction scores in `results`.
results = []
for batch in tqdm(eval_dataloader, desc="Evaluating"):
    # The first batch element is unused; the remaining six tensors go to `device`.
    _, *batch_tensors = batch
    input_tok, input_tok_type, input_tok_pos, input_mask, seed_header, target_header = (
        tensor.to(device) for tensor in batch_tensors
    )
    with torch.no_grad():
        header_outputs = model(input_tok, input_tok_type, input_tok_pos, input_mask, seed_header, target_header)
        header_loss = header_outputs[0]
        header_prediction_scores = header_outputs[1]
    results.extend(header_prediction_scores.tolist())

In [None]:
def get_ap(scores, target_e):
    """Average precision of an index-addressed score vector against target ids.

    This definition replaces the earlier dict-based `get_ap` for the rest of the
    notebook.

    Args:
        scores: sequence of scores; position i is the score of id i.
        target_e: iterable of relevant ids.

    Returns:
        Whatever `average_precision` returns for the 0/1 relevance list of the
        ids ranked by descending score.
    """
    relevant_ids = set(target_e)
    ranking = np.argsort(scores)[::-1]
    return average_precision([int(idx in relevant_ids) for idx in ranking])

In [None]:
# Average precision per dev example; `maps[i]` belongs to `eval_dataset[i]`.
maps = []
for i, x in tqdm(enumerate(results)):
    # Targets are item[6] without its first element (the seed header).
    ap = get_ap(x, eval_dataset[i][6][1:])
    #     if ap<0.7:
    #         display([train_dataset.header_vocab[z] for z in eval_dataset[i][6]])
    maps.append(ap)
# Mean average precision over the dev set.
print(np.mean(maps))

In [None]:
# Indices of dev examples whose average precision is below 0.5.
errors = [i for i, ap in enumerate(maps) if ap < 0.5]

In [None]:
# Top-10 predicted headers for dev example 1, followed by its gold target headers.
display([eval_dataset.header_vocab[x] for x in np.argsort(results[1])[::-1][:10]])
# Header ids live at item index 6 (index 5 is used as an integer slice bound
# elsewhere in this notebook); [1:] drops the seed header.
display([eval_dataset.header_vocab[x] for x in eval_dataset[1][6][1:]])

In [None]:
# Decoded input tokens of the first dev example.
eval_dataset.tokenizer.decode(eval_dataset[0][1])

In [None]:
def inspect(i):
    """Print the input text, AP, top-10 predictions and gold targets of dev example i.

    Reads the notebook globals `eval_dataset`, `maps` and `results`.

    Args:
        i: index into `eval_dataset` (and the aligned `maps` / `results` lists).
    """
    print(eval_dataset.tokenizer.decode(eval_dataset[i][1]))
    print(maps[i])
    print("; ".join([eval_dataset.header_vocab[x] for x in np.argsort(results[i])[::-1][:10]]))
    # Header ids live at item index 6 (index 5 is used as an integer slice bound
    # elsewhere in this notebook); [1:] drops the seed header.
    print("; ".join([eval_dataset.header_vocab[x] for x in eval_dataset[i][6][1:]]))

In [None]:
# Look at one of the low-AP dev examples.
inspect(errors[23])

In [None]:
# Export header names and the matching rows of the model's `cls` weight matrix as
# two line-aligned files: one header name per line, one tab-separated vector per line.
# NOTE(review): `dump_loc` differs from the checkpoint directory loaded above —
# confirm the weights are written to the intended location.
dump_loc = "output/HR/hybrid/model_v1_table_0.2_0.4_0.7_30000_1e-4_with_cand_0_seed_0/"
with open(os.path.join(dump_loc, "header_embedding.tsv"), "w") as f_e, open(
    os.path.join(dump_loc, "header_names.tsv"), "w", encoding="utf8"
) as f_n:
    for i, name in eval_dataset.header_vocab.items():
        f_n.write("{}\n".format(name))
        f_e.write("{}\n".format("\t".join([str(z) for z in model.cls.weight.data[i].tolist()])))

In [None]:
# Raw first training item, to inspect the tuple layout used by the cells around it.
train_dataset[0]

In [None]:
from sklearn.feature_extraction.text import TfidfVectorizer

# TF-IDF over already-tokenized input: the identity analyzer makes the vectorizer
# treat each token of the slice x[1][1 : x[5] + 1] as a term.
# The kNN baseline cells use `tfidf`, so run this cell before them.
tfidf = TfidfVectorizer(analyzer=lambda x: x, token_pattern=None)
train_tfidf = tfidf.fit_transform([x[1][1 : x[5] + 1] for x in train_dataset])

In [None]:
from sklearn.neighbors import NearestNeighbors

# Cosine-distance nearest-neighbour index over the training TF-IDF vectors.
# n_neighbors=1 is only the default; the kNN baseline cells pass their own k to
# kneighbors(). Run this cell before those cells.
neigh = NearestNeighbors(n_neighbors=1, metric="cosine", n_jobs=None)
neigh.fit(train_tfidf)