In [1]:
import math
from tqdm import tqdm
import numpy as np
from typing import Dict
from collections import defaultdict
from torch.utils.data import DataLoader
import torch
from pymongo import MongoClient
from torch.utils.data import Dataset
from transformers import AutoConfig, AutoTokenizer, HfArgumentParser, T5ForConditionalGeneration, T5Config, PretrainedConfig
import os
from transformers import (
    LogitsProcessorList,
    MinLengthLogitsProcessor,
    BeamSearchScorer,
)
from typing import *

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Load the entity labels: one record per line, tab-separated, label in the
# second column. The list keeps file order.
entities = []
with open("wd5m_aliases_entities_v3.txt", 'r', encoding='utf-8') as f:
    for line in f:
        # Strip only the line terminator *before* splitting. Slicing the last
        # character off the second column would truncate the label whenever the
        # final line has no trailing newline or a line has extra columns.
        entities.append(line.rstrip('\n').split('\t')[1])

In [3]:
# Sanity check: show the first ten entity labels and the total number loaded.
entities[:10], len(entities)

(['human',
  'United States of America',
  'taxon',
  'species',
  'United Kingdom',
  'English',
  'association football',
  'politician',
  'association football player',
  'UTC+01:00'],
 4818679)

In [4]:
class KGLMDataset(Dataset):
    """Dataset of verbalized queries stored in a MongoDB collection on localhost.

    Item ``idx`` is the document whose ``_id`` equals ``str(idx)``. Its
    ``verbalization`` field is returned as the model input text and its
    ``target`` field as the expected output text.
    """

    def __init__(self, port, db, collection):
        """Connect to MongoDB, load the tokenizer and read the collection size.

        Args:
            port: Port of the MongoDB server running on ``localhost``.
            db: Name of the database.
            collection: Name of the collection holding the examples.
        """
        self.port = port
        self.db_name = db
        self.collection_name = collection
        self._connect()
        self.tokenizer = AutoTokenizer.from_pretrained("t5-base")
        self.tokenizer.add_tokens(['[SEP]'], special_tokens=True)
        self.length = self.client[self.db_name].command("collstats", self.collection_name)['count']

    def _connect(self):
        """(Re)create the client/collection handles and remember the owning process."""
        self.client = MongoClient('localhost', self.port)
        self.collection = self.client[self.db_name][self.collection_name]
        self._owner_pid = os.getpid()

    def _ensure_connection(self):
        """Open a fresh connection when running in a process other than the creator.

        A MongoClient must not be shared across ``fork``; DataLoader worker
        processes therefore need their own client instead of the inherited one.
        """
        if getattr(self, '_owner_pid', None) != os.getpid():
            self._connect()

    def __getitem__(self, idx):
        """Return ``{"input": ..., "outputs": ...}`` for the document with id ``str(idx)``.

        Raises:
            IndexError: If ``idx`` is an integer outside ``[0, len(self))``.
            KeyError: If no document with that id exists in the collection.
        """
        self._ensure_connection()
        doc = self.collection.find_one({'_id': str(idx)})
        if doc is None:
            # find_one returns None for unknown ids; fail with a clear message
            # instead of a "'NoneType' object is not subscriptable" TypeError.
            if isinstance(idx, int) and not 0 <= idx < self.length:
                raise IndexError(f"index {idx} is out of range for dataset of size {self.length}")
            raise KeyError(f"no document with _id={str(idx)!r} in collection")
        return {"input": doc['verbalization'], "outputs": doc['target']}

    def __len__(self):
        """Number of documents reported by ``collstats`` at construction time."""
        return self.length

    def _collate_eval(self, batch):
        """Collate a list of items into an evaluation batch.

        Returns:
            Tuple ``(input_ids, attention_mask, target_text, inputs)``: the
            tokenized inputs (padded to the longest sequence in the batch,
            truncated to 512 tokens), followed by the raw target strings and
            the raw input strings.
        """
        encode_plus_kwargs = {'truncation': True, 'padding': 'longest', 'pad_to_multiple_of': 1}

        inputs = [b['input'] for b in batch]
        inputs_tokenized = self.tokenizer.batch_encode_plus(list(inputs), max_length=512, return_tensors='pt',
                                                   **encode_plus_kwargs)

        target_text = [b["outputs"] for b in batch]

        return inputs_tokenized.input_ids, inputs_tokenized.attention_mask, target_text, inputs




In [5]:
# Test split, read from the 'test' collection of the 'KGLM' database on the
# local MongoDB server (port 27017).
dataset = KGLMDataset(port=27017, db='KGLM', collection='test')

# Stand-alone loader over the same dataset: one example per batch, original
# order, collated with the dataset's evaluation collate function.
data_loader = DataLoader(
    dataset,
    collate_fn=dataset._collate_eval,
    batch_size=1,
    num_workers=1,
    shuffle=False,
)

For now, this behavior is kept to avoid breaking backwards compatibility when padding/encoding with `truncation is True`.
- Be aware that you SHOULD NOT rely on t5-base automatically truncating your input to 512 when padding/encoding.
- If you want to encode/pad to sequences longer than 512 you can either instantiate this tokenizer with `model_max_length` or pass `max_length` when encoding/padding.


In [6]:
# Directory of the training run whose best checkpoint is evaluated.
path = 'lr5e-05_constant_with_warmup_adamw_wd1e-03_512-512_bs64_iters4000000/run_1/'

model_cpt = os.path.join(path, 'model_best.pth')
# NOTE(review): config_path is computed but not read in this cell; the
# architecture below comes from the 't5-small' configuration instead.
config_path = os.path.join(path, 'config.json')

# Build an untrained T5 with the 't5-small' architecture, then fill in the
# trained weights from the checkpoint (deserialized onto the CPU).
model_cfg = AutoConfig.from_pretrained('t5-small')
model = T5ForConditionalGeneration(model_cfg)

cpt = torch.load(model_cpt, map_location=torch.device('cpu'))
# Kept as the last expression so the cell displays the key-matching report.
model.load_state_dict(cpt['model_state_dict'])

<All keys matched successfully>

In [7]:
class Args:
    """Plain container for the settings read by ``Evaluator``.

    Attributes:
        batch_size: Number of queries per DataLoader batch.
        num_workers: Number of DataLoader worker processes.
        chunk_size: Number of candidate entities scored per forward pass.
        device: Torch device string the model and tensors are moved to.
        entity_strings: Candidate entity labels to rank.
    """

    # NOTE: the `entity_strings` default is bound to the global `entities`
    # list once, when this class body is executed.
    def __init__(self, entity_strings=entities, batch_size=2, chunk_size=100, num_workers=2, device='cuda'):
        self.batch_size, self.num_workers = batch_size, num_workers
        self.chunk_size, self.device = chunk_size, device
        self.entity_strings = entity_strings


args = Args()

In [8]:
class Evaluator:
    """Rank all candidate entities for every query and print ranking metrics.

    Each candidate in ``args.entity_strings`` is scored by the sum of the
    log-probabilities the model assigns to its tokens given the query.
    Candidates are processed ``args.chunk_size`` at a time. Only unfiltered
    metrics are computed.

    TODO: filtered ranking (masking other known-correct answers of the same
    query before ranking) is not implemented; it needs a per-query lookup of
    known answers that this dataset class does not provide.
    """

    def __init__(self, dataset: KGLMDataset, model, args):
        """Store the settings, move the model to the device and build the loader.

        Args:
            dataset: Dataset providing ``tokenizer`` and ``_collate_eval``.
            model: Sequence-to-sequence model called with ``input_ids``,
                ``attention_mask`` and ``labels``; its output must expose
                ``logits``.
            args: Object with ``device``, ``num_workers``, ``batch_size``,
                ``chunk_size`` and ``entity_strings`` attributes.
        """
        self.device = args.device
        self.dataset = dataset
        self.model = model.to(self.device)
        self.num_workers = args.num_workers
        self.batch_size = args.batch_size
        self.chunk_size = args.chunk_size
        self.entity_strings = args.entity_strings
        # Label -> position in entity_strings; for duplicated labels the last
        # occurrence wins.
        self.ent2id = {ent: i for i, ent in enumerate(self.entity_strings)}
        self.data_loader = DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            collate_fn=dataset._collate_eval,
        )

    def _score_chunk(self, input_ids, attention_mask, chunk_start, chunk_end):
        """Score entities ``[chunk_start, chunk_end)`` against every query in the batch.

        Returns:
            Tensor of shape ``(batch, chunk)`` holding, for each query/entity
            pair, the summed log-probability of the entity's non-padding tokens.
        """
        # Use the real batch size: the last batch of an epoch can be smaller
        # than the configured batch_size.
        actual_batch_size = input_ids.shape[0]
        current_chunk_size = chunk_end - chunk_start

        # Row b * chunk + c pairs query b with entity c: queries are repeated
        # element-wise, the entity block is tiled once per query.
        input_ids_repeated = torch.repeat_interleave(input_ids, current_chunk_size, dim=0)
        attention_mask_repeated = torch.repeat_interleave(attention_mask, current_chunk_size, dim=0)

        entity_token_ids = self.dataset.tokenizer.batch_encode_plus(
            self.entity_strings[chunk_start:chunk_end],
            padding='max_length',
            truncation=True,
            max_length=32,
            return_tensors="pt",
        ).input_ids.to(self.device)
        labels = entity_token_ids.repeat([actual_batch_size, 1])

        logits = self.model(
            input_ids=input_ids_repeated,
            attention_mask=attention_mask_repeated,
            labels=labels,
        ).logits
        log_probs = torch.log_softmax(logits, dim=2)

        # Pick the log-probability of each target token, then zero the
        # positions whose target is token id 0 (treated as padding) so they do
        # not contribute to the sum.
        token_log_probs = torch.gather(log_probs, 2, labels.unsqueeze(-1)).squeeze(-1)
        token_log_probs = token_log_probs.masked_fill(labels == 0, 0.0)

        return token_log_probs.sum(dim=1).view(actual_batch_size, current_chunk_size)

    @torch.no_grad()
    def eval(self):
        """Run the evaluation and print MR, MRR, Hits@1 and Hits@10 (unfiltered).

        Raises:
            KeyError: If a target label is not contained in ``entity_strings``.
        """
        self.model.eval()
        num_entities = len(self.entity_strings)
        num_chunks = math.ceil(num_entities / self.chunk_size)
        unfiltered_ranks = []

        loader = tqdm(self.data_loader, total=len(self.data_loader), unit="batch")
        for batch in loader:
            input_ids, attention_mask, label_strings, _input_strings = batch
            input_ids = input_ids.to(self.device)
            attention_mask = attention_mask.to(self.device)

            summed_logit_chunks = []
            # process chunk by chunk
            for chunk_number in tqdm(range(num_chunks), unit="chunk", leave=False):
                chunk_start = self.chunk_size * chunk_number
                chunk_end = min(chunk_start + self.chunk_size, num_entities)
                summed_logit_chunks.append(
                    self._score_chunk(input_ids, attention_mask, chunk_start, chunk_end)
                )
            summed_logits = torch.cat(summed_logit_chunks, dim=1)

            for summed_logits_per_triple, label in zip(summed_logits, label_strings):
                # todo: ties are resolved by argsort order; change to mean rank
                arg_sorted = torch.argsort(summed_logits_per_triple, descending=True)
                entity_id = self.ent2id[label]
                rank = (
                    (arg_sorted == entity_id)
                    .nonzero(as_tuple=True)[0]
                    .item()
                )
                unfiltered_ranks.append(rank)

        if not unfiltered_ranks:
            print("No queries were evaluated; metrics are undefined.")
            return

        # add 1 to have best rank 1 not 0
        ranks = np.array(unfiltered_ranks, dtype=np.float32) + 1
        print("MR", ranks.mean())
        print("MRR", np.power(ranks, -1).mean())
        print("Hits@1", (ranks == 1).mean())
        print("Hits@10", (ranks <= 10).mean())


In [9]:
# Wire the test dataset, the restored model and the evaluation settings together.
evaluator = Evaluator(dataset=dataset, model=model, args=args)

In [10]:
# Long-running: every query in the dataset is scored against all entity
# strings, chunk by chunk.
evaluator.eval()


  0%|                                                 | 0/48187 [00:00<?, ?it/s][A
  0%|                                       | 1/48187 [00:00<3:33:28,  3.76it/s][A
  0%|                                       | 2/48187 [00:00<2:44:54,  4.87it/s][A
  0%|                                       | 3/48187 [00:00<2:40:05,  5.02it/s][A
  0%|                                       | 4/48187 [00:00<2:40:23,  5.01it/s][A
  0%|                                       | 5/48187 [00:01<2:40:44,  5.00it/s][A
  0%|                                       | 6/48187 [00:01<2:41:26,  4.97it/s][A
  0%|                                       | 7/48187 [00:01<2:41:45,  4.96it/s][A
  0%|                                       | 8/48187 [00:01<2:43:36,  4.91it/s][A
  0%|                                       | 9/48187 [00:01<2:41:28,  4.97it/s][A
  0%|                                      | 10/48187 [00:02<2:41:47,  4.96it/s][A
  0%|                                      | 11/48187 [00:02<2:41:27,  4.97

  0%|                                      | 89/48187 [00:17<2:41:35,  4.96it/s][A
  0%|                                      | 90/48187 [00:18<2:42:34,  4.93it/s][A
  0%|                                      | 91/48187 [00:18<2:42:37,  4.93it/s][A
  0%|                                      | 92/48187 [00:18<2:42:46,  4.92it/s][A
  0%|                                      | 93/48187 [00:18<2:42:58,  4.92it/s][A
  0%|                                      | 94/48187 [00:19<2:41:54,  4.95it/s][A
  0%|                                      | 95/48187 [00:19<2:42:09,  4.94it/s][A
  0%|                                      | 96/48187 [00:19<2:41:58,  4.95it/s][A
  0%|                                      | 97/48187 [00:19<2:42:22,  4.94it/s][A
  0%|                                      | 98/48187 [00:19<2:42:54,  4.92it/s][A
  0%|                                      | 99/48187 [00:20<2:41:41,  4.96it/s][A
  0%|                                     | 100/48187 [00:20<2:41:42,  4.96i

  0%|▏                                    | 186/48187 [00:37<2:43:45,  4.89it/s][A
  0%|▏                                    | 187/48187 [00:38<2:44:38,  4.86it/s][A
  0%|▏                                    | 188/48187 [00:38<2:43:53,  4.88it/s][A
  0%|▏                                    | 189/48187 [00:38<2:43:36,  4.89it/s][A
  0%|▏                                    | 190/48187 [00:38<2:43:38,  4.89it/s][A
  0%|▏                                    | 191/48187 [00:38<2:43:27,  4.89it/s][A
  0%|▏                                    | 192/48187 [00:39<2:43:33,  4.89it/s][A
  0%|▏                                    | 193/48187 [00:39<2:45:20,  4.84it/s][A
  0%|▏                                    | 194/48187 [00:39<2:43:47,  4.88it/s][A
  0%|▏                                    | 195/48187 [00:39<2:44:14,  4.87it/s][A
  0%|▏                                    | 196/48187 [00:39<2:44:39,  4.86it/s][A
  0%|▏                                    | 197/48187 [00:40<2:44:24,  4.86i

  1%|▏                                    | 283/48187 [00:57<2:46:58,  4.78it/s][A
  1%|▏                                    | 284/48187 [00:58<2:46:29,  4.80it/s][A
  1%|▏                                    | 285/48187 [00:58<2:46:01,  4.81it/s][A
  1%|▏                                    | 286/48187 [00:58<2:46:16,  4.80it/s][A
  1%|▏                                    | 287/48187 [00:58<2:45:32,  4.82it/s][A
  1%|▏                                    | 288/48187 [00:58<2:45:26,  4.83it/s][A
  1%|▏                                    | 289/48187 [00:59<2:47:42,  4.76it/s][A
  1%|▏                                    | 290/48187 [00:59<2:45:56,  4.81it/s][A
  1%|▏                                    | 291/48187 [00:59<2:45:15,  4.83it/s][A
  1%|▏                                    | 292/48187 [00:59<2:45:03,  4.84it/s][A
  1%|▏                                    | 293/48187 [00:59<2:45:24,  4.83it/s][A
  1%|▏                                    | 294/48187 [01:00<2:45:42,  4.82i

KeyboardInterrupt: 