In [23]:
import json
import os
from transformers import BertModel, BertTokenizer
import torch
from tqdm import tqdm
from collections import OrderedDict
import torch.nn as nn
from time import perf_counter

In [2]:
class DualEncoder(nn.Module):
    """Two-tower encoder: one encoder for mentions, one for entities.

    Both encoders are called with ``input_ids`` / ``attention_mask`` keyword
    arguments. The vector at sequence position 0 of an encoder's first output
    is used as the embedding (the [CLS] position when the encoders are BERT
    models).
    """

    def __init__(self, mention_encoder,
                 entity_encoder,
                 type_loss=None):
        """
        :param mention_encoder: module used to embed mentions
        :param entity_encoder: module used to embed candidates and entities
        :param type_loss: accepted so existing callers keep working; this
                          class does not use it
        """
        super().__init__()
        self.mention_encoder = mention_encoder
        self.entity_encoder = entity_encoder

    @staticmethod
    def _first_token_embeds(encoder, token_ids, masks):
        """Run ``encoder`` and return the position-0 vectors of its first output."""
        return encoder(input_ids=token_ids, attention_mask=masks)[0][:, 0, :]

    def encode(self, mention_token_ids=None,
               mention_masks=None,
               candidate_token_ids=None,
               candidate_masks=None,
               entity_token_ids=None,
               entity_masks=None):
        """Embed whichever of the three inputs are provided.

        :param mention_token_ids: mention token ids, passed to the mention
                                  encoder as-is
        :param mention_masks: attention mask for ``mention_token_ids``
        :param candidate_token_ids: candidate token ids, size: B X C X L
        :param candidate_masks: attention mask for ``candidate_token_ids``
                                (same size), or None
        :param entity_token_ids: entity token ids, passed to the entity
                                 encoder as-is
        :param entity_masks: attention mask for ``entity_token_ids``
        :return: (mention_embeds, candidates_embeds, entity_embeds); each
                 element is None when the matching ``*_token_ids`` is None.
                 candidates_embeds is reshaped back to B X C X H.
        """
        candidates_embeds = None
        mention_embeds = None
        entity_embeds = None
        # candidate_token_ids and mention_token_ids not None during training
        # mention_token_ids not None for embedding mentions during inference
        # entity_token_ids not None for embedding entities during inference
        if candidate_token_ids is not None:
            B, C, L = candidate_token_ids.size()
            # B X C X L --> BC X L
            # reshape (not view) so non-contiguous inputs, e.g. after a
            # permute or slice, are accepted instead of raising.
            flat_ids = candidate_token_ids.reshape(-1, L)
            flat_masks = None
            if candidate_masks is not None:
                flat_masks = candidate_masks.reshape(-1, L)
            candidates_embeds = self._first_token_embeds(
                self.entity_encoder, flat_ids, flat_masks
            ).reshape(B, C, -1)
        if mention_token_ids is not None:
            mention_embeds = self._first_token_embeds(
                self.mention_encoder, mention_token_ids, mention_masks
            )
        if entity_token_ids is not None:
            # for getting all the entity embeddings
            entity_embeds = self._first_token_embeds(
                self.entity_encoder, entity_token_ids, entity_masks
            )
        return mention_embeds, candidates_embeds, entity_embeds

    def forward(self,
                mention_token_ids=None,
                mention_masks=None,
                candidate_token_ids=None,
                candidate_masks=None,
                entity_token_ids=None,
                entity_masks=None
                ):
        """Delegate to :meth:`encode` with the same arguments.

        :param mention_token_ids: see :meth:`encode`
        :param mention_masks: see :meth:`encode`
        :param candidate_token_ids: size: B X C X L
        :param candidate_masks: size: B X C X L
        :param entity_token_ids: see :meth:`encode`
        :param entity_masks: see :meth:`encode`
        :return: (mention_embeds, candidates_embeds, entity_embeds)
        """
        return self.encode(mention_token_ids, mention_masks,
                           candidate_token_ids, candidate_masks,
                           entity_token_ids, entity_masks)

In [3]:
def _extract_prefixed(state_dict, prefix):
    """Return the entries of ``state_dict`` whose key starts with ``prefix``, prefix removed."""
    return OrderedDict(
        (key[len(prefix):], value)
        for key, value in state_dict.items()
        if key.startswith(prefix)
    )


def load_model(is_init, config_path, model_path, device, type_loss,
               blink=True):
    """Build a DualEncoder and populate it from a checkpoint.

    :param is_init: if True, the two BERT towers are initialised (from the
                    ``context_encoder.bert_model.*`` / ``cand_encoder.bert_model.*``
                    entries of the checkpoint when ``blink`` is True) and
                    wrapped in a DualEncoder; if False, ``state_dict['sd']``
                    is loaded into the DualEncoder as a whole
    :param config_path: path to a JSON file; its ``"bert_model"`` entry names
                        the pretrained BERT to instantiate when ``blink`` is True
    :param model_path: checkpoint path read with ``torch.load``
    :param device: ``torch.device`` the checkpoint tensors are mapped onto
    :param type_loss: forwarded unchanged to DualEncoder
    :param blink: selects the BERT name from the config (True) or
                  'bert-large-uncased' (False), and whether the per-tower
                  weights are copied from the checkpoint when ``is_init``
    :return: the DualEncoder; it is not moved to ``device`` here
    """
    with open(config_path) as json_file:
        params = json.load(json_file)
    bert_name = params["bert_model"] if blink else 'bert-large-uncased'
    ctxt_bert = BertModel.from_pretrained(bert_name)
    cand_bert = BertModel.from_pretrained(bert_name)
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    # map_location=device covers both CPU and CUDA, and avoids failing when
    # the checkpoint was saved from a CUDA index that does not exist here.
    state_dict = torch.load(model_path, map_location=device)
    if is_init and blink:
        # strict=False: keys that do not line up between the checkpoint and
        # BertModel are ignored silently rather than raising.
        ctxt_bert.load_state_dict(
            _extract_prefixed(state_dict, 'context_encoder.bert_model.'),
            strict=False)
        cand_bert.load_state_dict(
            _extract_prefixed(state_dict, 'cand_encoder.bert_model.'),
            strict=False)
    model = DualEncoder(ctxt_bert, cand_bert, type_loss)
    if not is_init:
        model.load_state_dict(state_dict['sd'])
    return model

In [4]:
# Initialise both towers from the BLINK-style retriever checkpoint, on CPU.
model = load_model(
    is_init=True,
    config_path='EntQA/models/biencoder_wiki_large.json',
    model_path='EntQA/retriever.pt',
    device=torch.device('cpu'),
    type_loss=None,
    blink=True,
)

Downloading config.json:   0%|          | 0.00/571 [00:00<?, ?B/s]

Downloading model.safetensors:   0%|          | 0.00/1.34G [00:00<?, ?B/s]

In [6]:
# Tokenizer used for every text below.
# NOTE(review): the name is hard-coded here, while load_model reads the BERT
# name from the config's "bert_model" entry; confirm the two agree.
tokenizer = BertTokenizer.from_pretrained('bert-large-uncased')

Downloading tokenizer_config.json:   0%|          | 0.00/28.0 [00:00<?, ?B/s]

Downloading vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

Downloading tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

In [15]:
# Text fixtures: an entity description, a related passage, and an unrelated
# passage used as a negative control.
# `entity` is a raw string because it contains literal backslashes
# (\displaystyle, \omega); in a non-raw string these are invalid escape
# sequences and trigger a warning on recent Python versions.
entity = r"""Ordinal numbers (or ordinals) are numbers that show something's order, for example: 1st, 2nd, 3rd, 4th, 5th.

Suppose a person has four different T-shirts, and then lays them in front of the person, from left to right.

    At the far left, there is the red T-shirt.
    Right of that is the blue one.
    Then there is the yellow one.
    And finally, at the far right is an orange T-shirt.

If the person then starts counting the shirts from the left, he would first see the red shirt. So the red shirt is the first T-shirt. The blue shirt is the second T-shirt. The yellow shirt is the third one, and the orange T-shirt is the fourth one.

The first, second, third, and fourth in this case are ordinal numbers. They result from the fact that the person has many objects, and they give them an order (hence 'ordinal'). The person then simply counts those objects, and gives the ordinal numbers to them.

In set theory, ordinals are also ordinal numbers people use to order infinite sets. An example is the set ω 0 {\displaystyle \omega _{0}} (or ω {\displaystyle \omega } for short), which is the set containing all natural numbers (including 0).[1][2] This is the smallest ordinal number that is infinite, and there are many more (such as ω {\displaystyle \omega } + 1).[3] """
passage = """People use symbols to represent numbers; they call them numerals. Common places where numerals are used are for labeling, as in telephone numbers, for ordering, as in serial numbers, or to put a unique identifier, as in an ISBN, a unique number that can identify a book.
    Cardinal numbers are used to measure how many items are in a set. For example, {A,B,C} has size "3".
    Ordinal numbers are used to specify a certain element in a set or sequence (first, second, third).

Numbers are also used for other things like counting. Numbers are used when things are measured. Numbers are used to study how the world works. Mathematics is a way to use numbers to learn about the world and make things. The study of the rules of the natural world is called science. The work that uses numbers to make things is called engineering. """
passage_false = """A keloid is a type of scar that can form where somebody has an injury.[1] Keloids are tough and get larger over time, not going away. They can become as big as 30 centimeters long. They are shaped irregularly, rising high above the skin."""

In [29]:
# Tokenize each text as its own single-item batch with identical settings
# (PyTorch tensors, truncated at max_length=100).
entity_tokens, passage_tokens, passage_false_tokens = (
    tokenizer([text], return_tensors='pt', padding=True, max_length=100, truncation=True)
    for text in (entity, passage, passage_false)
)

In [30]:
# Embed the two passages with the mention encoder and the entity text with
# the entity encoder. The module is called directly (model(...)) rather than
# through .forward() so that nn.Module's call machinery, including any
# registered hooks, is not bypassed.
with torch.no_grad():
    passage_embedding, _, entity_embedding = model(
        mention_token_ids=passage_tokens['input_ids'],
        mention_masks=passage_tokens['attention_mask'],
        entity_token_ids=entity_tokens['input_ids'],
        entity_masks=entity_tokens['attention_mask'],
    )
    passage_embedding_false, _, _ = model(
        mention_token_ids=passage_false_tokens['input_ids'],
        mention_masks=passage_false_tokens['attention_mask'],
    )

In [31]:
# Cosine similarity between the related passage and the entity description.
print(nn.functional.cosine_similarity(passage_embedding, entity_embedding, dim=1))

tensor([0.8839])


In [32]:
# Cosine similarity between the unrelated passage and the entity description.
print(nn.functional.cosine_similarity(passage_embedding_false, entity_embedding, dim=1))

tensor([0.5848])


In [33]:
# Rough latency check: wall-clock time for 100 mention-encoder passes over
# the same single-passage batch. The outputs are discarded so the benchmark
# does not overwrite the notebook-level `passage_embedding`.
start = perf_counter()
with torch.no_grad():
    for _ in tqdm(range(100)):
        model(mention_token_ids=passage_tokens['input_ids'],
              mention_masks=passage_tokens['attention_mask'])
end = perf_counter()
print(end - start)

100%|██████████| 100/100 [00:13<00:00,  7.50it/s]

13.33136409916915





In [35]:
# Show the entity embedding computed in the embedding cell above.
entity_embedding

tensor([[ 0.4833,  0.2775, -0.9148,  ..., -0.6494, -0.1603,  0.6491]])