In [1]:
import numpy as np
import torch
import torch.nn as nn
from tqdm import tqdm
from transformers import AdamW
from transformers import BertConfig
from transformers import BertModel
from transformers import BertTokenizer
from transformers import get_linear_schedule_with_warmup
from transformers import RobertaConfig
from transformers import RobertaModel
from transformers import RobertaTokenizer

In [2]:
from KBQA.appB.transformer_architectures.kb.knowbert import KnowBert
import logging
import sys


logger = logging.getLogger('knowbert-logger')
streamHandler = logging.StreamHandler(sys.stdout)
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
streamHandler.setFormatter(formatter)
logger.addHandler(streamHandler)
logger.error("This is the first error")
logger.setLevel(logging.DEBUG)


2022-06-10 09:15:30,338 - knowbert-logger - ERROR - This is the first error


In [3]:
encoder = KnowBert.load_pretrained_model()

2022-06-10 09:15:34,006 - knowbert-logger.main - INFO - Loaded Vocabulary
2022-06-10 09:15:34,007 - knowbert-logger.main - DEBUG - Vocabulary with namespaces:
 	Non Padded Namespaces: {'*labels', '*tags'}
 	Namespace: entity_wiki, Size: 470116 
 	Namespace: entity_wordnet, Size: 117663 



0it [00:00, ?it/s]

2022-06-10 09:16:19,879 - knowbert-logger.main - INFO - Loaded wiki embedding
2022-06-10 09:16:19,885 - knowbert-logger.main - INFO - Init Bert Encoder
2022-06-10 09:16:19,913 - knowbert-logger.main - INFO - Loaded wiki soldered KG
2022-06-10 09:16:26,082 - knowbert-logger.main - INFO - Loaded wordnet embedding
2022-06-10 09:16:26,086 - knowbert-logger.main - INFO - Init Bert Encoder
2022-06-10 09:16:26,106 - knowbert-logger.main - INFO - Loaded wordnet soldered KG


In [None]:
triple_encoder_config = BertConfig.from_pretrained("razent/spbert-mlm-wso-base")
triple_encoder = BertModel.from_pretrained(
    "razent/spbert-mlm-wso-base", config=triple_encoder_config
)

In [None]:
decoder_config = BertConfig.from_pretrained(
    "razent/spbert-mlm-wso-base"
)
decoder_config.is_decoder = True
decoder_config.add_cross_attention = True
decoder = BertModel.from_pretrained(
    "razent/spbert-mlm-wso-base", config=decoder_config
)

In [5]:
from KBQA.appB.transformer_architectures.kb.model import BertSeq2Seq

config = BertConfig.from_pretrained("bert-base-uncased")
device = torch.device("cuda:0")
tokenizer_uncased = BertTokenizer.from_pretrained("bert-base-uncased")

model = BertSeq2Seq(encoder=encoder,
                    triple_encoder=triple_encoder,
                    decoder=decoder,
                    config=config,
                    beam_size=2,
                    max_length=4,
                    sos_id=tokenizer_uncased.cls_token_id,
                    eos_id=tokenizer_uncased.sep_token_id,
                    device=device
                    )


In [4]:
from KBQA.appB.transformer_architectures.kb.knowbert_utils import KnowBertBatchifier

archive_file = 'https://allennlp.s3-us-west-2.amazonaws.com/knowbert/models/knowbert_wiki_wordnet_model.tar.gz'
batcher = KnowBertBatchifier(archive_file)

2022-06-10 09:17:09,264 - knowbert-logger.batchifier - INFO - Building Generators




2022-06-10 09:18:10,234 - knowbert-logger.wiki - INFO - duplicate_mentions_cnt: 6777
2022-06-10 09:18:10,235 - knowbert-logger.wiki - INFO - end of p_e_m reading. wall time: 0.9682023803393046 minutes
2022-06-10 09:18:10,235 - knowbert-logger.wiki - INFO - p_e_m_errors: 0
2022-06-10 09:18:10,236 - knowbert-logger.wiki - INFO - incompatible_ent_ids: 0
2022-06-10 09:18:39,660 - knowbert-logger.batchifier - INFO - Build Generators
2022-06-10 09:18:50,385 - knowbert-logger.batchifier - INFO - Done building candidate generators


In [7]:
import torch
sentences = ["Paris is located in France.", "KnowBert is a knowledge enhanced BERT"]
encoder.eval()
logger.setLevel(logging.DEBUG)
# batcher takes raw untokenized sentences
# and yields batches of tensors needed to run KnowBert
for batch in batcher.iter_batches(sentences, verbose=True):
    # model_output['contextual_embeddings'] is (batch_size, seq_len, embed_dim) tensor of top layer activations
    model_output = encoder(**batch)
    logger.info(model_output)

2022-06-10 09:31:50,216 - knowbert-logger.candidate-generator - DEBUG - offsets: [1, 2, 3, 4, 5, 6]
2022-06-10 09:31:50,216 - knowbert-logger.candidate-generator - DEBUG - word_piece_tokens: [['paris'], ['is'], ['located'], ['in'], ['france'], ['.']]
2022-06-10 09:31:50,217 - knowbert-logger.candidate-generator - DEBUG - tokens: ['Paris', 'is', 'located', 'in', 'France', '.']
2022-06-10 09:31:50,222 - knowbert-logger.batchifier - DEBUG - token_candidates: {'tokens': ['[CLS]', 'paris', 'is', 'located', 'in', 'france', '.', '[SEP]'], 'segment_ids': [0, 0, 0, 0, 0, 0, 0, 0], 'candidates': {'wiki': {'tokenized_text': ['Paris', 'is', 'located', 'in', 'France', '.'], 'candidate_spans': [[1, 1], [3, 3], [5, 5], [6, 6]], 'candidate_entities': [['Paris', 'Paris_(mythology)', 'Paris_Hilton', 'Paris,_Texas', 'Paris,_Kentucky', 'Paris_Las_Vegas', 'Paris_Masters', 'Paris_(Paris_Hilton_album)', 'Paris,_Missouri', 'Paris_Metropolitan_Area', 'Paris_(The_Cure_album)', 'Count_Paris', 'University_of_Pari

In [3]:
import pickle

with open("Batchifier.pkl", "wb") as pickle_file:
    pickle.dump(batcher, pickle_file)

In [2]:
import pickle

with open("Batchifier.pkl", "rb") as pickle_file:
    batcher = pickle.load(pickle_file)

In [7]:
tokenizer_cased = BertTokenizer.from_pretrained('bert-base-cased')

In [8]:
import logging

logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

In [9]:
class Example:
    """A single training/test example."""

    def __init__(self, idx, source, triples, target):
        self.idx = idx
        self.source = source
        self.triples = triples
        self.target = target


def read_examples(source_file, triples_file, target_file):
    """Read examples from filename."""
    examples = []
    with open(source_file, encoding="utf-8") as source_f:
        with open(triples_file, encoding="utf-8") as triples_f:
            with open(target_file, encoding="utf-8") as target_f:
                for idx, (source, triples, target) in enumerate(
                    zip(source_f, triples_f, target_f)
                ):
                    examples.append(
                        Example(
                            idx=idx,
                            source=source.strip(),
                            triples=triples.strip(),
                            target=target.strip(),
                        )
                    )
    return examples

class InputFeatures:
    """A single training/test features for a example."""

    def __init__(
        self,
        example_id,
        triples_ids,
        target_ids,
        triples_mask,
        target_mask,
    ):
        self.example_id = example_id
        self.triples_ids = triples_ids
        self.target_ids = target_ids
        self.triples_mask = triples_mask
        self.target_mask = target_mask

def replace_mask(text):
    return text.replace('[MASK]', ' [MASK] ')
def convert_examples_to_features(examples, 
                                 tokenizer, 
                                 max_triple_length, 
                                 max_target_length, 
                                 stage=None):
    features = []
    for example_index, example in enumerate(examples):
        # source handled elsewhere
        

        # triples
        triples_tokens = tokenizer.tokenize(example.triples)[: max_triple_length]
        triples_ids = tokenizer.convert_tokens_to_ids(triples_tokens)
        triples_mask = [1] * (len(triples_tokens))
        padding_length = max_triple_length - len(triples_ids)
        triples_ids += [tokenizer.pad_token_id] * padding_length
        triples_mask += [0] * padding_length

        # target
        if stage == "test" or stage == "predict":
            target_tokens = tokenizer.tokenize("None")
        else:
            target_tokens = tokenizer.tokenize(example.target)[
                : max_target_length - 2
            ]
        target_tokens = [tokenizer.cls_token] + target_tokens + [tokenizer.sep_token]
        target_ids = tokenizer.convert_tokens_to_ids(target_tokens)
        target_mask = [1] * len(target_ids)
        padding_length = max_target_length - len(target_ids)
        target_ids += [tokenizer.pad_token_id] * padding_length
        target_mask += [0] * padding_length

        if example_index < 5:
            if stage == "train":
                logger.info("*** Example ***")
                logger.info("idx: {}".format(example.idx))

                logger.info(
                    "triples_tokens: {}".format(
                        [x.replace("\u0120", "_") for x in triples_tokens]
                    )
                )
                logger.info("triples_ids: {}".format(" ".join(map(str, triples_ids))))
                logger.info("triples_mask: {}".format(" ".join(map(str, triples_mask))))

                logger.info(
                    "target_tokens: {}".format(
                        [x.replace("\u0120", "_") for x in target_tokens]
                    )
                )
                logger.info("target_ids: {}".format(" ".join(map(str, target_ids))))
                logger.info("target_mask: {}".format(" ".join(map(str, target_mask))))

        features.append(
            InputFeatures(
                example_index,
                triples_ids,
                target_ids,
                triples_mask,
                target_mask,
            )
        )
    return features

In [10]:
train_filename = "../bert_spbert_spbert_base/data/qald-9-small/preprocessed_data_files/qtq-qald-9-train-small"

train_examples = read_examples(
            train_filename + "." + "en",
            train_filename + ".triple",
            train_filename + "." + "sparql",
        )



In [16]:
print(train_examples[0].source)

List all boardgames by GMT .


In [11]:
train_features = convert_examples_to_features(train_examples, 
                                            tokenizer_cased, 
                                            max_triple_length=32, 
                                            max_target_length=32, 
                                            stage="train")

06/07/2022 07:47:59 - INFO - __main__ -   *** Example ***
06/07/2022 07:47:59 - INFO - __main__ -   idx: 0
06/07/2022 07:47:59 - INFO - __main__ -   triples_tokens: ['d', '##b', '##r', ':', 'Greenwich', '_', 'Mean', '_', 'Time', 'a', 'ya', '##go', ':', 'Time', '##P', '##eri', '##od', '##11', '##51', '##13', '##22', '##9', '.', 'd', '##b', '##r', ':', 'Greenwich', '_', 'Mean', '_', 'Time']
06/07/2022 07:47:59 - INFO - __main__ -   triples_ids: 173 1830 1197 131 14323 168 25030 168 2614 170 11078 2758 131 2614 2101 9866 5412 14541 24050 17668 20581 1580 119 173 1830 1197 131 14323 168 25030 168 2614
06/07/2022 07:47:59 - INFO - __main__ -   triples_mask: 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
06/07/2022 07:47:59 - INFO - __main__ -   target_tokens: ['[CLS]', 'select', 'variable', ':', 'u', '##ri', 'where', 'bracket', 'open', 'variable', ':', 'u', '##ri', 'd', '##bo', ':', 'publisher', 'd', '##b', '##r', ':', 'GM', '##T', '_', 'Games', 'bracket', 'close', '[SEP]']

In [12]:
batch = batcher.iter_batches([train_example.source for train_example in train_examples])
source_fields = next(batch)

source_fields['tokens']['tokens'] = source_fields['tokens']['tokens']['tokens']
source_fields['candidates']['wiki']['candidate_entities']['ids'] = \
    source_fields['candidates']['wiki']['candidate_entities']['ids']['token_characters']

candidate_mask = (source_fields['candidates']['wiki']['candidate_entities']['ids'] > 0).type(torch.uint8)
source_fields['candidates']['wiki']['candidate_entities']['ids'] -= candidate_mask
print(candidate_mask)

source_fields['candidates']['wordnet']['candidate_entities']['ids'] = \
    source_fields['candidates']['wordnet']['candidate_entities']['ids']['token_characters']

candidate_mask = (source_fields['candidates']['wordnet']['candidate_entities']['ids'] > 0).type(torch.uint8)
source_fields['candidates']['wordnet']['candidate_entities']['ids'] -= candidate_mask

print(source_fields)
all_source_ids = source_fields['tokens']['tokens']
all_source_segment_ids = source_fields['segment_ids']
print(all_source_ids)
print(all_source_segment_ids)
max_source_length = 32
padding_length = max(max_source_length - len(all_source_ids[0]), 0)
num_examples = len(all_source_ids)
if padding_length > 0:
    source_ids = torch.cat((all_source_ids, torch.full((num_examples,padding_length), fill_value=tokenizer_uncased.pad_token_id)), dim=1)
    source_segment_ids = torch.cat((all_source_segment_ids, torch.full((num_examples,padding_length), fill_value=0)), dim=1)
source_candidates = source_fields['candidates']
all_source_mask = all_source_ids > 0

print(all_source_ids)
print(all_source_mask)
print(all_source_segment_ids)

all_source_wiki_candidate_priors = source_fields['candidates']['wiki']['candidate_entity_priors']
all_source_wiki_candidate_ids = source_fields['candidates']['wiki']['candidate_entities']['ids']
all_source_wiki_candidate_spans = source_fields['candidates']['wiki']['candidate_spans']
all_source_wiki_candidate_segment_ids = source_fields['candidates']['wiki']['candidate_segment_ids']

all_source_wordnet_candidate_priors = source_fields['candidates']['wordnet']['candidate_entity_priors']
all_source_wordnet_candidate_ids = source_fields['candidates']['wordnet']['candidate_entities']['ids']
all_source_wordnet_candidate_spans = source_fields['candidates']['wordnet']['candidate_spans']
all_source_wordnet_candidate_segment_ids = source_fields['candidates']['wordnet']['candidate_segment_ids']



offsets: [1, 2, 4, 5, 7, 8]
word_piece_tokens: [['list'], ['all'], ['board', '##games'], ['by'], ['gm', '##t'], ['.']]
tokens: ['List', 'all', 'boardgames', 'by', 'GMT', '.']
name wiki
mention_generator <KBQA.appB.transformer_architectures.kb.wiki_linking_util.WikiCandidateMentionGenerator object at 0x7fa0448b64f0>
name wordnet
mention_generator <KBQA.appB.transformer_architectures.kb.wordnet.WordNetCandidateMentionGenerator object at 0x7fa03a5db280>
token_candidates: {'tokens': ['[CLS]', 'list', 'all', 'board', '##games', 'by', 'gm', '##t', '.', '[SEP]'], 'segment_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'candidates': {'wiki': {'tokenized_text': ['List', 'all', 'boardgames', 'by', 'GMT', '.'], 'candidate_spans': [[1, 1], [3, 4], [6, 7], [8, 8]], 'candidate_entities': [['List_(abstract_data_type)', 'List,_Schleswig-Holstein', 'President_of_Iran', 'Sniper_rifle', 'Angle_of_list', 'Prime_Minister_of_Poland', 'Robert_List', 'Prime_Minister_of_Iraq', 'Friedrich_List', 'Party-list_proportional

In [13]:
all_triples_ids = torch.tensor(
    [f.triples_ids for f in train_features], dtype=torch.long
)
all_triples_mask = torch.tensor(
    [f.triples_mask for f in train_features], dtype=torch.long
)
all_target_ids = torch.tensor(
    [f.target_ids for f in train_features], dtype=torch.long
)
all_target_mask = torch.tensor(
    [f.target_mask for f in train_features], dtype=torch.long
)

In [14]:
from torch.utils.data import TensorDataset
train_data = TensorDataset(
    all_source_ids,
    all_source_mask,
    all_source_wiki_candidate_priors,
    all_source_wiki_candidate_ids,
    all_source_wiki_candidate_spans,
    all_source_wiki_candidate_segment_ids,
    all_source_wordnet_candidate_priors,
    all_source_wordnet_candidate_ids,
    all_source_wordnet_candidate_spans,
    all_source_wordnet_candidate_segment_ids,
    all_triples_ids,
    all_triples_mask,
    all_target_ids,
    all_target_mask,
)

In [17]:
from torch.utils.data import RandomSampler
from torch.utils.data import DataLoader
train_sampler = RandomSampler(train_data)

train_dataloader = DataLoader(
    train_data,
    sampler=train_sampler,
    batch_size=2 // 1
)

num_train_optimization_steps = -1

# Prepare optimizer and schedule (linear warmup and decay)
no_decay = ["bias", "LayerNorm.weight"]
optimizer_grouped_parameters = [
    {
        "params": [
            p
            for n, p in model.named_parameters()
            if not any(nd in n for nd in no_decay)
        ],
        "weight_decay": 0.01,
    },
    {
        "params": [
            p
            for n, p in model.named_parameters()
            if any(nd in n for nd in no_decay)
        ],
        "weight_decay": 0.0,
    },
]
t_total = (
    len(train_dataloader)
    // 1
    * 2
)
optimizer = AdamW(
    optimizer_grouped_parameters, lr=5e-5, eps=1e-8
)
scheduler = get_linear_schedule_with_warmup(
    optimizer, num_warmup_steps=int(t_total * 0.1), num_training_steps=t_total
)

# Start training
logger.info("***** Running training *****")
logger.info("  Num examples = %d", len(train_examples))
logger.info("  Batch size = %d", 2)
logger.info("  Num epoch = %d", 2)

model.train()
dev_dataset = {}
nb_tr_examples, nb_tr_steps, tr_loss, global_step, best_bleu, best_loss = (
    0,
    0,
    0,
    0,
    -1,
    1e6,
)

06/07/2022 07:55:51 - INFO - __main__ -   ***** Running training *****
06/07/2022 07:55:51 - INFO - __main__ -     Num examples = 8
06/07/2022 07:55:51 - INFO - __main__ -     Batch size = 2
06/07/2022 07:55:51 - INFO - __main__ -     Num epoch = 2


In [19]:
for epoch in range(2):
    bar = tqdm(train_dataloader, total=len(train_dataloader))
    for batch in bar:
        batch = tuple(t.to(device) for t in batch)
        (
            source_ids,
            source_mask,
            source_wiki_candidate_priors,
            source_wiki_candidate_ids,
            source_wiki_candidate_spans,
            source_wiki_candidate_segment_ids,
            source_wordnet_candidate_priors,
            source_wordnet_candidate_ids,
            source_wordnet_candidate_spans,
            source_wordnet_candidate_segment_ids,
            triples_ids,
            triples_mask,
            target_ids,
            target_mask
        ) = batch

        source_candidates = {
            'wiki' : {
                'candidate_entity_priors' : source_wiki_candidate_priors,
                'candidate_entities' : {'ids' : source_wiki_candidate_ids},
                'candidate_spans' : source_wiki_candidate_spans,
                'candidate_segment_ids' : source_wiki_candidate_segment_ids
            },
            'wordnet' : {
                'candidate_entity_priors' : source_wordnet_candidate_priors,
                'candidate_entities' : {'ids' : source_wordnet_candidate_ids},
                'candidate_spans' : source_wordnet_candidate_spans,
                'candidate_segment_ids' : source_wordnet_candidate_segment_ids
            }
        }


        loss, _, _ = model(
            source_ids=source_ids,
            source_mask=source_mask,
            source_candidates=source_candidates,
            triples_ids=triples_ids,
            triples_mask=triples_mask,
            target_ids=target_ids,
            target_mask=target_mask,
        )

        tr_loss += loss.item()
        train_loss = round(
            tr_loss * -1 / (nb_tr_steps + 1), 4
        )
        bar.set_description("epoch {} loss {}".format(epoch, train_loss))
        nb_tr_examples += source_ids.size(0)
        nb_tr_steps += 1
        loss.backward()

        if (nb_tr_steps + 1) % -1 == 0:
            # Update parameters
            optimizer.step()
            optimizer.zero_grad()
            scheduler.step()
            global_step += 1

    if False and (epoch + 1) % 1 == 0:
        # Eval model with dev dataset
        tr_loss = 0
        nb_tr_examples, nb_tr_steps = 0, 0
        eval_flag = False

        # Calculate bleu
        if "dev_bleu" in dev_dataset:
            eval_examples, eval_data = dev_dataset["dev_bleu"]
        else:
            eval_examples = read_examples(
                args.dev_filename + "." + args.source,
                args.dev_filename + ".triple",
                args.dev_filename + "." + args.target,
            )
            eval_examples = random.sample(
                eval_examples, min(1000, len(eval_examples))
            )
            eval_features = convert_examples_to_features(
                eval_examples, tokenizer, args, stage="test"
            )
            all_source_ids = torch.tensor(
                [f.source_ids for f in eval_features], dtype=torch.long
            )
            all_source_mask = torch.tensor(
                [f.source_mask for f in eval_features], dtype=torch.long
            )
            all_triples_ids = torch.tensor(
                [f.triples_ids for f in eval_features], dtype=torch.long
            )
            all_triples_mask = torch.tensor(
                [f.triples_mask for f in eval_features], dtype=torch.long
            )
            eval_data = TensorDataset(
                all_source_ids,
                all_source_mask,
                all_triples_ids,
                all_triples_mask,
            )
            dev_dataset["dev_bleu"] = eval_examples, eval_data

        eval_sampler = SequentialSampler(eval_data)
        eval_dataloader = DataLoader(
            eval_data, sampler=eval_sampler, batch_size=args.eval_batch_size
        )

        model.eval()
        p = []
        for batch in eval_dataloader:
            batch = tuple(t.to(device) for t in batch)
            source_ids, source_mask, triples_ids, triples_mask = batch
            with torch.no_grad():
                preds = model(
                    source_ids=source_ids,
                    source_mask=source_mask,
                    triples_ids=triples_ids,
                    triples_mask=triples_mask,
                )
                for pred in preds:
                    t = pred[0].cpu().numpy()
                    t = list(t)
                    if 0 in t:
                        t = t[: t.index(0)]
                    text = tokenizer.decode(
                        t, clean_up_tokenization_spaces=False
                    )
                    p.append(text)
        model.train()
        predictions = []
        pred_str = []
        label_str = []
        with open(os.path.join(args.output_dir, "dev.output"), "w") as f, open(
            os.path.join(args.output_dir, "dev.gold"), "w"
        ) as f1:
            for ref, gold in zip(p, eval_examples):
                ref = ref.strip().replace("< ", "<").replace(" >", ">")
                ref = re.sub(
                    r' ?([!"#$%&\'(’)*+,-./:;=?@\\^_`{|}~]) ?', r"\1", ref
                )
                ref = ref.replace("attr_close>", "attr_close >").replace(
                    "_attr_open", "_ attr_open"
                )
                ref = ref.replace(" [ ", " [").replace(" ] ", "] ")
                ref = ref.replace("_obd_", " _obd_ ").replace(
                    "_oba_", " _oba_ "
                )

                pred_str.append(ref.split())
                label_str.append([gold.target.strip().split()])
                predictions.append(str(gold.idx) + "\t" + ref)
                f.write(str(gold.idx) + "\t" + ref + "\n")
                f1.write(str(gold.idx) + "\t" + gold.target + "\n")

        bl_score = corpus_bleu(label_str, pred_str) * 100

        logger.info("  {} = {} ".format("BLEU", str(round(bl_score, 4))))
        logger.info("  " + "*" * 20)
        if bl_score > best_bleu:
            logger.info("  Best bleu:%s", bl_score)
            logger.info("  " + "*" * 20)
            best_bleu = bl_score
            # Save best checkpoint for best bleu
            output_dir = os.path.join(args.output_dir, "checkpoint-best-bleu")
            if not os.path.exists(output_dir):
                os.makedirs(output_dir)
            model_to_save = (
                model.module if hasattr(model, "module") else model
            )  # Only save the model it-self
            output_model_file = os.path.join(output_dir, "pytorch_model.bin")
            torch.save(model_to_save.state_dict(), output_model_file)

  0%|          | 0/4 [00:03<?, ?it/s]

In forward
tokens tensor([[  101,  2029, 13586,  2515,  2250,  2859,  3710,  1029,   102,     0,
             0,     0,     0,     0],
        [  101,  2040,  2003,  1996,  3664,  1997,  2047,  2259,  2103,  1029,
           102,     0,     0,     0]], device='cuda:0')
candidates {'wiki': {'candidate_entity_priors': tensor([[[8.4444e-01, 1.5556e-01, 0.0000e+00, 0.0000e+00, 0.0000e+00,
          0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
          0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
          0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
          0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
          0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
         [9.4826e-01, 1.4411e-02, 8.9771e-03, 8.9771e-03, 6.1422e-03,
          5.9060e-03, 2.1262e-03, 1.6537e-03, 1.1812e-03, 1.1812e-03,
          2.3624e-04, 2.3624e-04, 2.3624e-04, 2.3624e-04, 2.3624e-04,
          0.0000e+00, 0.0000e+00, 0.0000e+00, 0.000




IndexError: too many indices for tensor of dimension 2