In [1]:
from transformers import AutoTokenizer, AutoModel
import functorch
import regex as re 
import torch 
import torch.nn.functional as F
import numpy as np
from functools import lru_cache
from nltk.util import ngrams 
from collections import Counter

In [209]:
def group_duplicates(embeddings, lst, mean=True):
    embeddings = torch.cat([e.reshape(1, -1) for i, e in enumerate(embeddings) if lst[i]!=None], dim=0)
    lst = [i for i in lst if i!=None]
    output = [None for _ in range(len(set(lst)))]
    i = 0
    for idx, i in enumerate(lst):
        if(i!=None):
            if(output[i] == None):
                output[i] = embeddings[idx, :].reshape(1, -1)
            else:
                output[i] = torch.cat((output[i], embeddings[idx, :].reshape(1, -1)), dim=0)
    if(mean):
        for idx, val in enumerate(output):
            output[idx] = torch.mean(output[idx], dim=0).reshape(1, -1)
    return output

def get_facts(source, token_split=True):
    words = source.split(":")[1].split(" ")
    facts = re.split(r'<[^>]*>', " ".join(words))
    facts = [re.sub(r'_', ' ', facts[i].strip()) for i in range(len(facts))]
    if(token_split):
        return tuple([re.sub(r'_', ' ', f) for i in range(len(facts)) for f in facts[i].split()]), [i-1 for i, word in enumerate(facts) for _ in range(len(word.split()))]
    return tuple(facts), []

@lru_cache(maxsize=10000)
def get_embedding(tokens, model, tokenizer, split_into_words=False, parent_device = 'cuda:4'):
    with torch.no_grad():
        tokenized_facts = tokenizer(tokens, padding=True, truncation=True, max_length=512, is_split_into_words=split_into_words, return_tensors="pt").to(parent_device)
        #print(tokenizer.decode(tokenized_facts[1].ids))
        batch_states = model(**tokenized_facts).hidden_states
        batch_output = torch.stack([batch_states[i] for i in range(len(batch_states))])
        batch_output = batch_output.squeeze()
        batch_final_hidden_state = torch.mean(batch_output[:, :, ...], dim=0)
        return batch_final_hidden_state[:, 1:-1, :], list(map(lambda i: tokenized_facts.word_ids(i)[1:-1], range(len(tokens))))

def entailment_prob(fact_embeddings, generated_ngram, threshold=None): 
    if(generated_ngram.dim() == 1):
        generated_ngram = generated_ngram.reshape(1, -1)
    similarities = functorch.vmap(lambda row_a: F.cosine_similarity(row_a, fact_embeddings))(generated_ngram)
    if(threshold):
        #print(torch.max(similarities, dim=1).values.cpu())
        return np.mean((torch.max(similarities, dim=1).values > threshold).int().cpu().numpy())
    return np.mean(torch.max(similarities, dim=1).values.cpu().numpy())

def get_coverage_reward(generated, source,  model, parentTokenizer, parent_device):
    generated = list(map(lambda gen: re.sub(r"[',.।()]", '', gen), generated))
    generated_tokens = tuple(map(lambda gen: tuple(re.sub(r'[ ]{2,}', ' ', gen.strip()).split()),  generated))
    generated_embeddings, g_idx = get_embedding(generated_tokens, model, parentTokenizer, split_into_words=True, parent_device=parent_device)
    gen_emb = list(map(lambda ge, gi: group_duplicates(ge, gi), generated_embeddings, g_idx))
    # gen_emb = [] 
    # for ge, gi in zip(generated_embeddings, g_idx):
    #     gen_emb.append(group_duplicates(ge, gi))

    fact_pos_pairs = tuple(map(lambda src: get_facts(src, token_split=True), source))
    facts = tuple(fact_pos_pairs[i][0] for i in range(len(fact_pos_pairs)))
    #print(facts)
    fact_pos = tuple(fact_pos_pairs[i][1] for i in range(len(fact_pos_pairs)))
    fact_embeddings, f_idx = get_embedding(facts, model, parentTokenizer, split_into_words=True, parent_device=parent_device)
    #print(fact_embeddings[0].shape)
    #print(f_idx[0])
    fact_embeddings = list(map(lambda fe, fi: group_duplicates(fe, fi), fact_embeddings, f_idx))
    fact_emb = [torch.cat(fe).squeeze() for fe in fact_embeddings]
    batch_eps = []
    for generated_tokens_, gen_emb_, fact_emb_ in zip(generated_tokens, gen_emb, fact_emb):
        g_dict = {t: emb for t, emb in zip(generated_tokens_, gen_emb_)}
        generated_ngrams = Counter(list(ngrams(generated_tokens_, 1)))
        eps = [] 
        for ngram, count in generated_ngrams.items():
            ngram_embedding = torch.cat([g_dict[i] for i in ngram if i in g_dict]).squeeze()
            ep = entailment_prob(fact_emb_, ngram_embedding, 0.4)
            eps.append(ep)
        batch_eps.append(np.mean(eps))
    return batch_eps

# def get_coverage_reward(generated, source, model, parentTokenizer, parent_device):


In [50]:
get_embedding(tuple(('this is the first sentence', 'this is the second sentence')), parentModel, parentTok, False, 'cuda')

(tensor([[[-0.4804,  0.0539,  0.0544,  ..., -0.3140,  0.1942,  0.1003],
          [ 0.2996, -0.0531,  0.0894,  ...,  0.0893,  0.2084, -0.0393],
          [-0.2180,  0.1041, -0.0069,  ...,  0.0867,  0.0623, -0.1641],
          [-0.6242, -0.0263, -0.1341,  ...,  0.2337, -0.1200,  0.0210],
          [-0.3984,  0.0523, -0.0879,  ...,  0.0384, -0.0524, -0.0337]],
 
         [[-0.4739,  0.0522,  0.0419,  ..., -0.3060,  0.1482,  0.1512],
          [ 0.3545, -0.0529,  0.1093,  ...,  0.0431,  0.1868, -0.0357],
          [-0.1713,  0.1216,  0.0378,  ...,  0.0866,  0.0663, -0.1286],
          [-0.7763,  0.0850, -0.0959,  ...,  0.1394, -0.1650, -0.0215],
          [-0.4028,  0.0396, -0.0659,  ...,  0.0565, -0.0691, -0.0366]]],
        device='cuda:0'),
 [[0, 1, 2, 3, 4], [0, 1, 2, 3, 4]])

In [3]:
parentModel =  AutoModel.from_pretrained("google/muril-base-cased", output_hidden_states=True).to('cuda')
parentTok =  AutoTokenizer.from_pretrained("google/muril-base-cased", padding='max_length', truncation='max_length', max_length=512)
parentModel.eval()

Some weights of the model checkpoint at google/muril-base-cased were not used when initializing BertModel: ['cls.predictions.decoder.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(197285, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0): BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
         

In [211]:
get_coverage_reward(gen, src, parentModel, parentTok, 'cuda')

[0.6666666666666666, 0.47058823529411764]

In [5]:
src = ['generate english high : <H> charlie townsend <R> date_of_birth <T> 07 november 1876 <R> date_of_death <T> 17 october 1958 <R> occupation <T> cricketer', 'generate english high : <H> matthew kleinveldt <R> country_of_citizenship <T> united kingdom <R> date_of_birth <T> 10 august 1989 <R> occupation <T> cricketer']
gen = ['Charles Henry Dunsend ( 7 November 1876 - 17 October 1958 ) was an English first-class cricketer.', 'Matthew William Kleinveldt ( born 10 August 1989 ) is an English cricketer who plays for Lancashire and England.' ]

('Charles Henry Dunsend 7 November 1876 - 17 October 1958 was an English first-class cricketer', 'Matthew William Kleinveldt born 10 August 1989 is an English cricketer who plays for Lancashire and England')
tensor([[[ 0.0835,  0.0800, -0.0204,  ...,  0.1786,  0.2951, -0.0416],
         [ 0.3666, -0.1116,  0.1555,  ...,  0.0649,  0.2151, -0.0020],
         [-0.4001,  0.0111, -0.1273,  ..., -0.0014,  0.0995,  0.0403],
         ...,
         [-0.0201, -0.0308,  0.0537,  ..., -0.1912, -0.2950,  0.0696],
         [-0.2440,  0.1058, -0.0462,  ...,  0.1693,  0.0591, -0.0303],
         [-0.3973,  0.0672,  0.0983,  ..., -0.2901, -0.1923, -0.0645]],

        [[ 0.1944,  0.0336, -0.1154,  ...,  0.2096,  0.1232,  0.0467],
         [ 0.0797, -0.0440,  0.1365,  ...,  0.2682,  0.1908,  0.0426],
         [ 0.0481,  0.0775,  0.0491,  ...,  0.0195,  0.0543,  0.0102],
         ...,
         [-0.6054, -0.0656,  0.0046,  ..., -0.1141,  0.0216,  0.0881],
         [ 0.0234,  0.0373,  0.1339,  ...,  0.0184, 

TypeError: <lambda>() missing 1 required positional argument: 'gi'

In [47]:
get_coverage_reward(gen[0], src[0], parentModel, parentTok, 'cuda')

0.6666666666666666