# Entity Detection
Entity detection is a basic task when working with knowledge graphs.

![title](knowledge-graph.jpg)

In [None]:
#import libraries for data preprocessing
import re
from fuzzywuzzy import fuzz
import json

from utils import process_original_entity, repalce_punc, processed_text, process_entity
from nltk.tokenize.treebank import TreebankWordTokenizer

# Module-level Treebank tokenizer instance.
# NOTE(review): `tokenizer` is not referenced anywhere else in the visible part of this file.
tokenizer = TreebankWordTokenizer()

## Preprocessing
Let us prepare the data for the further task.

In [None]:
def get_indices(src_list, pattern_list):
    """Find the first place where ``pattern_list`` occurs inside ``src_list``.

    Args:
        src_list: Sequence to search in (e.g. the tokens of a sentence).
        pattern_list: Sequence of consecutive items to look for.

    Returns:
        ``range(start, start + len(pattern_list))`` covering the first
        occurrence, or ``None`` when the pattern does not occur (including
        when ``src_list`` is empty).
    """
    # An empty source never matched in the original either (its loop body
    # was never entered), so keep returning None for it.
    if len(src_list) == 0:
        return None

    width = len(pattern_list)
    # Only try start positions that leave room for the whole pattern. The
    # previous version kept indexing src_list[i + j] past the end of the list
    # and raised IndexError on a partial match near the tail.
    for start in range(len(src_list) - width + 1):
        window = src_list[start:start + width]
        if all(have == want for have, want in zip(window, pattern_list)):
            return range(start, start + width)
    return None


def get_ngram(tokens):
    """Enumerate every contiguous n-gram of ``tokens``.

    Each entry is ``(text, start, end)`` where ``text`` is the space-joined
    slice ``tokens[start:end]``. Entries are ordered by increasing n-gram
    length and, within one length, by increasing start index.
    """
    total = len(tokens)
    return [
        (" ".join(tokens[begin:begin + size]), begin, begin + size)
        for size in range(1, total + 1)
        for begin in range(total - size + 1)
    ]

In [None]:
def reverse_linking(sent, dbpedia_text, original):
    """Label the tokens of ``sent`` that mention the entity.

    The sentence is split on whitespace. If ``dbpedia_text`` or ``original``
    occurs in the sentence as a whitespace-delimited substring, the matching
    tokens are tagged "I" (an exact match). Otherwise the n-gram of the
    sentence with the highest ``fuzz.token_sort_ratio`` against
    ``dbpedia_text`` is tagged instead. All other tokens stay "O".

    Args:
        sent: Question text.
        dbpedia_text: Entity surface form; also the target of the fuzzy fallback.
        original: Alternative entity surface form; checked second, so its
            result replaces the one obtained for ``dbpedia_text``.

    Returns:
        Tuple ``(entity_text, label, exact_match)``: the space-joined tokens
        tagged "I", the space-joined tag sequence, and whether an exact match
        was found. A sentence without tokens yields ``("", "", False)``.
    """
    tokens = sent.split()
    label = ["O"] * len(tokens)

    entity_span = None
    # Order matters: `original` is tried last and overwrites the span found
    # for `dbpedia_text` (even with None if get_indices finds nothing).
    for candidate in (dbpedia_text, original):
        pattern = r'(^|\s)(%s)($|\s)' % (re.escape(candidate))
        if re.search(pattern, sent):
            entity_span = get_indices(tokens, candidate.split())

    exact_match = entity_span is not None
    if exact_match:
        for i in entity_span:
            label[i] = "I"
    elif tokens:
        # max() returns the first best-scoring n-gram, the same element a
        # stable descending sort would place first, without sorting them all.
        # The `elif tokens` guard avoids the IndexError the previous version
        # raised on a sentence with no tokens.
        top = max(get_ngram(tokens), key=lambda x: fuzz.token_sort_ratio(x[0], dbpedia_text))
        for i in range(top[1], top[2]):
            label[i] = 'I'

    entity_text = " ".join(t for l, t in zip(label, tokens) if l == 'I')
    label = " ".join(label)
    return entity_text, label, exact_match

### Get our datasets for train, validation and test

In [None]:
# question = "what film did peter menzies jr. do cinematography for" entity = "Peter_Menzies_Jr." processed_query
# = processed_text(repalce_punc(question)) processed_candidate = process_entity(repalce_punc(entity))
# processed_candidate_original = process_original_entity(repalce_punc(entity)) entity_text, label, exact_match =
# reverse_linking(processed_query, processed_candidate, processed_candidate_original) print("{}\t{}\t{}\t{
# }\n".format(question, label, entity_text, str(exact_match))) exit()
# For every fold, read data/<fold>.json, derive the entity tag sequence of each
# question with reverse_linking, and write one tab-separated line per question
# to data/<fold>.txt. The share of exact matches is printed per fold.
folds = ["train", "valid", "test"]
for fold in folds:
    exact_match_counter = 0
    total = 0
    # Context managers close both files; the previous version left them open.
    with open("data/{}.json".format(fold), "rt", encoding="utf-8") as fin:
        json_data = json.load(fin)
    with open("data/{}.txt".format(fold), "wt", encoding="utf-8") as fout:
        for instance in json_data["Questions"]:
            total += 1
            idx = instance["ID"]
            sub = instance["Subject"]
            # Only the first entry of PredicateList is used.
            pre = instance["PredicateList"][0]["Predicate"]
            direction = instance["PredicateList"][0]["Direction"]
            constraint = instance["PredicateList"][0]["Constraint"]
            free_pre = instance["FreebasePredicate"]
            question = instance["Query"]
            # Strip the DBpedia resource prefix to get the bare entity name.
            entity = sub.replace("http://dbpedia.org/resource/", "")
            processed_query = processed_text(repalce_punc(question))
            processed_candidate = process_entity(repalce_punc(entity))
            processed_candidate_original = process_original_entity(repalce_punc(entity))
            entity_text, label, exact_match = reverse_linking(processed_query, processed_candidate,
                                                              processed_candidate_original)
            fout.write("{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\n".format(idx, processed_query, sub, pre, direction,
                                                                 pre + "@" + direction + "@" + str(constraint),
                                                                 free_pre, label))  # entity_text, str(exact_match)
            if exact_match:
                exact_match_counter += 1
    # A fold without questions reports 0.0 instead of raising ZeroDivisionError.
    exact_match_ratio = exact_match_counter / total if total else 0.0
    print("{}\t{} / {} : {}".format(fold, exact_match_counter, total, exact_match_ratio))

## Main part - training model 

In [None]:
#import libraries
import torch
import torch.nn as nn
import time
import os
import numpy as np
from torchtext import data
import random

In [None]:
# add some methods for evaluation and creating SQdataset
from entity_detection.nn.args import get_args
from entity_detection.nn.evaluation import evaluation
from entity_detection.nn.sq_entity_dataset import SQdataset
from entity_detection.nn.entity_detection import EntityDetection

In [None]:
# Build the run configuration. Only --entity_detection_mode is passed explicitly here;
# the remaining options presumably keep the defaults declared in args.py (not visible in this file).
args = get_args(["--entity_detection_mode", "LSTM"])

In [None]:
# Reproducibility: feed the same configured seed to PyTorch, NumPy and the
# standard-library RNG, then set cuDNN's deterministic flag.
for seed_rng in (torch.manual_seed, np.random.seed, random.seed):
    seed_rng(args.seed)
torch.backends.cudnn.deterministic = True

In [None]:
# A GPU will be very useful for this task.
# Default to the CPU and switch to CUDA only when PyTorch reports it is available.
device = torch.device("cpu")
if torch.cuda.is_available():
    print("Note: You are using GPU for training")
    # "cuda" is the device type PyTorch recognizes; the previous "gpu" string
    # made torch.device() raise as soon as a GPU was actually present.
    device = torch.device("cuda")
    # Make the configured GPU the current CUDA device and seed its RNG.
    torch.cuda.set_device(args.gpu)
    torch.cuda.manual_seed(args.seed)

### Let us handle the data:

In [None]:
# Set up the data for training
# TEXT: field for the question text; lower=True lowercases it.
TEXT = data.Field(lower=True)
# ED: field for the entity-detection tags, with torchtext's default settings.
# NOTE(review): presumably the "I"/"O" label sequence written during preprocessing — confirm against SQdataset.
ED = data.Field()

### Create SQdataset with fields: id, sub, entity, relation, obj, text, ed

In [None]:
# Load the three splits from args.data_dir using the TEXT and ED fields.
train, dev, test = SQdataset.splits(TEXT, ED, args.data_dir)
# Both vocabularies are built over train, dev and test together.
TEXT.build_vocab(train, dev, test)
ED.build_vocab(train, dev, test)

In [None]:
# Allocate one embedding row of size args.words_dim per vocabulary entry.
TEXT.vocab.vectors = torch.Tensor(len(TEXT.vocab), args.words_dim)
# Fill every row with values drawn uniformly from [-0.25, 0.25); no pretrained vectors are loaded here.
# `token` is unused; only the row index matters.
for i, token in enumerate(TEXT.vocab.itos):
    TEXT.vocab.vectors[i] = torch.FloatTensor(args.words_dim).uniform_(-0.25, 0.25)

In [None]:
# Defines an iterator that loads batches of data from a Dataset.
# Only the training iterator shuffles; none of the three sorts or repeats.
train_iter = data.Iterator(train, batch_size=args.batch_size, device=device, train=True, repeat=False,
                           sort=False, shuffle=True, sort_within_batch=False)
dev_iter = data.Iterator(dev, batch_size=args.batch_size, device=device, train=False, repeat=False,
                         sort=False, shuffle=False, sort_within_batch=False)
test_iter = data.Iterator(test, batch_size=args.batch_size, device=device, train=False, repeat=False,
                          sort=False, shuffle=False, sort_within_batch=False)

# `config` is the same object as `args`; the vocabulary size is attached to it for the model.
config = args
config.words_num = len(TEXT.vocab)

# Our model
if args.dataset == 'EntityDetection':
    # Number of output labels equals the size of the tag vocabulary.
    config.label = len(ED.vocab)
    model = EntityDetection(config)
else:
    # The previous `raise("Error Dataset")` tried to raise a plain string, which
    # Python rejects with "TypeError: exceptions must derive from BaseException".
    raise ValueError("Error Dataset")

In [None]:
# Initialise the model's embedding weights from the vectors stored on the TEXT vocabulary.
model.embed.weight.data.copy_(TEXT.vocab.vectors)

# Move the model to the selected device.
model = model.to(device)

print(config)

In [None]:
# Report the size of the vocabulary, of each data split and of the tag set.
size_report = (
    ("VOCAB num", TEXT.vocab),
    ("Train instance", train),
    ("Dev instance", dev),
    ("Test instance", test),
    ("Entity Type", ED.vocab),
)
for caption, sized in size_report:
    print(caption, len(sized))

### Let's see what the model looks like

In [None]:
# Show the model's printed representation.
print(model)

## Training
Now it is time to set up the training loop.

In [None]:
# Optimise only the parameters that require gradients.
parameter = filter(lambda p: p.requires_grad, model.parameters())
optimizer = torch.optim.Adam(parameter, lr=args.lr, weight_decay=args.weight_decay)
# NOTE(review): NLLLoss expects log-probabilities — confirm the model's output layer provides them.
criterion = nn.NLLLoss()

# Bookkeeping shared with the training loop below.
early_stop = False
best_dev_F = 0
best_dev_P = 0
best_dev_R = 0
iterations = 0
iters_not_improved = 0
# Number of dev evaluations per epoch; patience is expressed in dev evaluations.
num_dev_in_epoch = (len(train) // args.batch_size // args.dev_every) + 1
patience = args.patience * num_dev_in_epoch # for early stopping
epoch = 0
start = time.time()
header = '  Time Epoch Iteration Progress    (%Epoch)   Loss   Dev/Loss     Accuracy  Dev/Accuracy'
# NOTE(review): dev_log_template is not used anywhere in the visible part of this file.
dev_log_template = ' '.join('{:>6.0f},{:>5.0f},{:>9.0f},{:>5.0f}/{:<5.0f} {:>7.0f}%,{:>8.6f},{:8.6f},{:12.4f},{:12.4f}'.split(','))
log_template = ' '.join('{:>6.0f},{:>5.0f},{:>9.0f},{:>5.0f}/{:<5.0f} {:>7.0f}%,{:>8.6f},{},{},{}'.split(','))
# Snapshots go to <args.save_path>/<lower-cased entity detection mode>.
save_path = os.path.join(args.save_path, args.entity_detection_mode.lower())
os.makedirs(save_path, exist_ok=True)
print(header)

# Lookup array from tag index to tag string.
if args.dataset == 'EntityDetection':
    index2tag = np.array(ED.vocab.itos)
else:
    # NOTE(review): this branch only prints, so index2tag stays undefined and later cells would fail.
    print("Wrong Dataset")

### It will be long...

In [None]:
# Training loop: at most 100 epochs. Every args.dev_every iterations the model
# is scored on the dev set; the best model by dev F1 is saved, and training
# stops once more than `patience` consecutive dev evaluations fail to improve.
for _ in range(100):
    if early_stop:
        print("Early Stopping. Epoch: {}, Best Dev F1: {}".format(epoch, best_dev_F))
        break
    epoch += 1
    train_iter.init_epoch()
    n_correct, n_total = 0, 0

    for batch_idx, batch in enumerate(train_iter):
        # Batch size : (Sentence Length, Batch_size)
        iterations += 1
        model.train()
        optimizer.zero_grad()
        scores = model(batch)
        # Entity Detection
        # A sentence counts as correct only if every position is tagged correctly:
        # matches are summed along dim 0 and compared with the size of that dimension.
        n_correct += torch.sum((torch.sum((torch.max(scores, 1)[1].view(batch.ed.size()).data == batch.ed.data),
                                          dim=0) == batch.ed.size()[0])).item()
        loss = criterion(scores, batch.ed.view(-1, 1)[:, 0])

        n_total += batch.batch_size
        loss.backward()
        # clip the gradient
        torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip_gradient)
        optimizer.step()

        # evaluate performance on validation set periodically
        if iterations % args.dev_every == 0:
            model.eval()
            dev_iter.init_epoch()

            gold_list = []
            pred_list = []

            # No gradients are needed for scoring; disabling autograd avoids
            # building a graph for every dev batch.
            with torch.no_grad():
                for dev_batch in dev_iter:
                    answer = model(dev_batch)
                    index_tag = np.transpose(torch.max(answer, 1)[1].view(dev_batch.ed.size()).cpu().data.numpy())
                    gold_list.append(np.transpose(dev_batch.ed.cpu().data.numpy()))
                    pred_list.append(index_tag)

            P, R, F = evaluation(gold_list, pred_list, index2tag, type=False)
            print("{} Precision: {:10.6f}% Recall: {:10.6f}% F1 Score: {:10.6f}%".format(
                "Dev", 100. * P, 100. * R, 100. * F))

            # update model
            if F > best_dev_F:
                best_dev_F = F
                best_dev_P = P
                best_dev_R = R
                iters_not_improved = 0
                snapshot_path = os.path.join(save_path, args.specify_prefix + '_best_model.pt')
                # The path is the same on every improvement, so the previous best snapshot is overwritten.
                torch.save(model, snapshot_path)
            else:
                iters_not_improved += 1
                if iters_not_improved > patience:
                    # Leave the batch loop; the flag is reported at the top of the next epoch iteration.
                    early_stop = True
                    break

        if iterations % args.log_every == 1:
            # print progress message
            print(log_template.format(time.time() - start,
                                          epoch, iterations, 1 + batch_idx, len(train_iter),
                                          100. * (1 + batch_idx) / len(train_iter), loss.item(), ' ' * 8,
                                          100. * n_correct / n_total, ' ' * 12))

### Evaluation
In this task, evaluation is simple sequence matching.

In [None]:
# Lookup arrays from vocabulary index to tag string and to word string.
index2tag = np.array(ED.vocab.itos)
index2word = np.array(TEXT.vocab.itos)

# Results go to <args.results_path>/<lower-cased entity detection mode>.
results_path = os.path.join(args.results_path, args.entity_detection_mode.lower())
# NOTE(review): the exists() check is redundant for directories, since exist_ok=True already tolerates them.
if not os.path.exists(results_path):
    os.makedirs(results_path, exist_ok=True)

In [None]:
def convert(file_name, id_file, output_file):
    """Turn tagged sentences into one line of entity mentions per question.

    ``file_name`` holds tab-separated lines whose first column is the
    space-separated sentence and whose second column is the space-separated
    predicted tags. ``id_file`` holds one id per line, aligned with
    ``file_name``. For every pair, each run of tokens tagged "I" (a run ends
    at a token tagged "O") becomes one mention; "<pad>" tokens are dropped.
    If no mention is found, the whole sentence without "<pad>" tokens is used.

    Each output line is the id followed by the mentions, joined with
    " %%%% ". Processing stops at the end of the shorter input file.

    Args:
        file_name: Path of the tagged-sentence file to read.
        id_file: Path of the id file to read.
        output_file: Path of the file to write (overwritten).
    """
    # tqdm is never imported in this file, so the original `tqdm(...)` call
    # raised NameError. Use it when another cell has brought it into scope and
    # fall back to plain iteration otherwise.
    progress = globals().get("tqdm", iter)

    # Context managers close all three files (and flush the output).
    with open(file_name) as fin, open(id_file) as fid, open(output_file, "w") as fout:
        for line, line_id in progress(zip(fin, fid)):
            columns = line.strip().split('\t')
            sent = columns[0].strip().split()
            pred = columns[1].strip().split()

            query_list = []
            query_text = []
            for token, label in zip(sent, pred):
                if label == 'I':
                    # Padding never belongs to a mention.
                    if token != '<pad>':
                        query_text.append(token)
                elif label == 'O' and query_text:
                    # Only an "O" tag ends a span; any other tag leaves it open.
                    query_list.append(" ".join(query_text))
                    query_text = []
            # Flush a mention that runs to the end of the sentence.
            if query_text:
                query_list.append(" ".join(query_text))
            # Nothing tagged: fall back to the whole (unpadded) sentence.
            if not query_list:
                query_list.append(" ".join(token for token in sent if token != '<pad>'))
            fout.write(" %%%% ".join([line_id.strip()] + query_list) + "\n")

In [None]:
def predict(dataset_iter=test_iter, dataset=test, data_name="test"):
    """Run the model over a dataset, report P/R/F1 and write the detected mentions.

    Every batch is tagged by the model; question, predicted tags and gold tags
    are written to a temporary file ``tmp<data_name>.txt`` in the working
    directory. After the metrics are printed, ``convert`` combines that file
    with ``lineids_<data_name>.txt`` from ``args.data_dir`` into
    ``query.<data_name>`` under ``results_path``, and the temporary file is
    removed.

    Args:
        dataset_iter: Iterator over the batches to tag.
        dataset: Unused; kept so existing calls keep working.
        data_name: Split name used for file names and for the printed metrics.
    """
    print("Dataset: {}".format(data_name))
    model.eval()
    dataset_iter.init_epoch()

    fname = "{}.txt".format(data_name)
    temp_file = 'tmp'+fname

    gold_list = []
    pred_list = []

    # The context manager closes the results file even if a batch fails;
    # no_grad skips autograd bookkeeping, which inference does not need.
    with open(temp_file, 'w') as results_file, torch.no_grad():
        for data_batch in dataset_iter:
            scores = model(data_batch)
            # Transpose so each row corresponds to one question.
            index_tag = np.transpose(torch.max(scores, 1)[1].view(data_batch.ed.size()).cpu().data.numpy())
            index_gold = np.transpose(data_batch.ed.cpu().data.numpy())
            index_question = np.transpose(data_batch.text.cpu().data.numpy())
            tag_array = index2tag[index_tag]
            question_array = index2word[index_question]
            gold_array = index2tag[index_gold]
            gold_list.append(index_gold)
            pred_list.append(index_tag)
            for question, label, gold in zip(question_array, tag_array, gold_array):
                results_file.write("{}\t{}\t{}\n".format(" ".join(question), " ".join(label), " ".join(gold)))

    P, R, F = evaluation(gold_list, pred_list, index2tag, type=False)
    # Label the metrics with the evaluated split; the previous version always printed "Dev".
    print("{} Precision: {:10.6f}% Recall: {:10.6f}% F1 Score: {:10.6f}%".format(data_name, 100. * P, 100. * R,
                                                                                 100. * F))

    try:
        convert(temp_file, os.path.join(args.data_dir, "lineids_{}.txt".format(data_name)),
                os.path.join(results_path, "query.{}".format(data_name)))
    finally:
        # Remove the temporary file even when convert fails.
        os.remove(temp_file)

In [None]:
# run the model on the dev set and write the output to a file
# (data_name="valid" selects lineids_valid.txt and produces query.valid)
predict(dataset_iter=dev_iter, dataset=dev, data_name="valid")

In [None]:
# run the model on the test set and write the output to a file
# (data_name="test" selects lineids_test.txt and produces query.test)
predict(dataset_iter=test_iter, dataset=test, data_name="test")

## There is more!
Check out the repository with full baselines on this data: https://github.com/castorini/BuboQA 
This seminar is also based on that repository.