In [1]:
import math
from tqdm import tqdm
import numpy as np
from typing import Dict
from collections import defaultdict
from torch.utils.data import DataLoader
import torch
from pymongo import MongoClient
from torch.utils.data import Dataset
from transformers import AutoConfig, AutoTokenizer, HfArgumentParser, T5ForConditionalGeneration, T5Config, PretrainedConfig
import os
from transformers import (
    LogitsProcessorList,
    MinLengthLogitsProcessor,
    BeamSearchScorer,
)
from typing import *

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
entities = []
with open("wd5m_aliases_entities_v3.txt", 'r') as f:
    for line in f:
        entities.append(line.split('\t')[1][:-1])

In [3]:
entities[:10], len(entities)

(['human',
  'United States of America',
  'taxon',
  'species',
  'United Kingdom',
  'English',
  'association football',
  'politician',
  'association football player',
  'UTC+01:00'],
 4818679)

In [4]:
entities = entities[:10]

In [5]:
class KGLMDataset(Dataset):
    def __init__(self, port, db, collection):
        self.client = MongoClient('localhost', port)
        self.db_name = db
        self.collection_name = collection
        self.collection = self.client[db][collection]
        self.tokenizer = AutoTokenizer.from_pretrained("t5-base")
        self.tokenizer.add_tokens(['[SEP]'], special_tokens=True)
        self.length = self.client[self.db_name].command("collstats", self.collection_name)['count']

    def  __getitem__(self, idx):
        item = {}
        doc = self.collection.find_one({'_id': str(idx)})
        item["input"] = doc['verbalization']
        item["outputs"] = doc['target']
        return item
        
    def __len__(self):
        return self.length
    
    def _collate_eval(self, batch):
        encode_plus_kwargs = {'truncation': True, 'padding': 'longest', 'pad_to_multiple_of': 1}

        
        inputs = [b['input'] for b in batch]
        inputs_tokenized = self.tokenizer.batch_encode_plus(list(inputs), max_length=512, return_tensors='pt',
                                                   **encode_plus_kwargs)
        
        target_text = [b["outputs"] for b in batch]

        return inputs_tokenized.input_ids, inputs_tokenized.attention_mask, target_text

In [6]:
dataset = KGLMDataset(27017, 'KGLM', 'test')
data_loader = DataLoader(
    dataset,
    batch_size=1,
    shuffle=False,
    num_workers=1,
    collate_fn=dataset._collate_eval
)

For now, this behavior is kept to avoid breaking backwards compatibility when padding/encoding with `truncation is True`.
- Be aware that you SHOULD NOT rely on t5-base automatically truncating your input to 512 when padding/encoding.
- If you want to encode/pad to sequences longer than 512 you can either instantiate this tokenizer with `model_max_length` or pass `max_length` when encoding/padding.


In [7]:
path = 'lr5e-05_constant_with_warmup_adamw_wd1e-03_512-512_bs64_iters4000000/run_1/'

model_cpt = os.path.join(path, 'model_best.pth')
config_path = os.path.join(path, 'config.json')

model_cfg = AutoConfig.from_pretrained('t5-small')
model = T5ForConditionalGeneration(config=model_cfg)

cpt = torch.load(model_cpt, map_location='cpu')
model.load_state_dict(cpt['model_state_dict'])

<All keys matched successfully>

In [8]:
class Args:
    def __init__(self, batch_size=1, chunk_size=50, num_workers=2, device='cuda'):
        self.batch_size = batch_size
        self.beam_size = 11
        self.num_predictions = 10
        self.length_penalty = 0.3
        self.num_workers = num_workers
        self.device = device

args = Args()

In [9]:
def eval_multi_old(model, dataset, args):
    num_workers = 1
    batch_size = args.batch_size
    model.cuda()
    model.eval()
    
    data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers,
                            collate_fn=dataset._collate_eval)
    loader = tqdm(data_loader, total=len(data_loader), unit="batches")
    i = 0
    beam_size = args.beam_size
    num_predictions = args.num_predictions
    length_penalty = args.length_penalty
    correct = 0
    print('Beams: %d, Predictions: %d, Length Penalty: %f' % (beam_size, num_predictions, length_penalty))
    
    for steps, batch in enumerate(loader):
        
        encoder_input_ids, attention_mask, target_text = batch
        encoder_input_ids = encoder_input_ids.cuda()
        attention_mask = attention_mask.cuda()
        input_ids = torch.ones((len(encoder_input_ids) * beam_size, 1), device=model.device, dtype=torch.long)
        
        input_ids = input_ids * model.config.decoder_start_token_id

        model_kwargs = {
            "encoder_outputs": model.get_encoder()(encoder_input_ids.repeat_interleave(beam_size, dim=0), return_dict=True)
        }

        beam_scorer = BeamSearchScorer(
            batch_size=len(encoder_input_ids),
            num_beams=beam_size,
            device=model.device,
            num_beam_hyps_to_keep=num_predictions,
            length_penalty = length_penalty
        )
        logits_processor = LogitsProcessorList([])
        outputs = model.beam_search(input_ids, beam_scorer, logits_processor=logits_processor, **model_kwargs, max_length=64)
        # outputs = model.generate(input_ids = encoder_input_ids)
        # target_text = dataset.tokenizer.batch_decode(labels, skip_special_tokens=True)
        predicted_text = dataset.tokenizer.batch_decode(outputs, skip_special_tokens=True)
        
        input_text = dataset.tokenizer.batch_decode(encoder_input_ids, skip_special_tokens=True)

        current_batch_size = len(encoder_input_ids)
        predicted_grouped = []
        for i in range(current_batch_size):
            predicted_grouped.append(predicted_text[i*num_predictions: (i+1)*num_predictions])

        for i in range(current_batch_size):
            target = target_text[i]
            predicted = set(predicted_grouped[i])
#             print(target, predxicted)
            if target in predicted:
                correct += 1
            
#         if steps % 100 == 0 and steps != 0:
#             print(correct/steps)
    accuracy = correct/len(dataset)
    return accuracy    

In [10]:
accuracy = eval_multi_old(model, dataset, args)

print(accuracy)

  0%|                                             | 0/5133 [00:00<?, ?batches/s]

Beams: 11, Predictions: 10, Length Penalty: 0.300000


 50%|█████████████████▏                | 2592/5133 [07:56<07:46,  5.44batches/s]


RuntimeError: CUDA out of memory. Tried to allocate 12.00 MiB (GPU 0; 10.92 GiB total capacity; 9.42 GiB already allocated; 9.56 MiB free; 10.27 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

In [None]:
data_loader = DataLoader(dataset, batch_size=16, shuffle=False, num_workers=2,
                        collate_fn=dataset._collate_eval)
loader = tqdm(data_loader, total=len(data_loader), unit="batches")

for steps, batch in enumerate(loader):
    encoder_input_ids, attention_mask, target_text = batch
    print(encoder_input_ids)
    input_ = dataset.tokenizer.batch_decode(encoder_input_ids, skip_special_tokens=True)
    output = dataset.tokenizer.batch_decode(model.generate(encoder_input_ids.to('cuda')), skip_special_tokens=True)
    print(input_, output, target_text)
    print()