In [1]:
import json
import torch

from fairseq.models.bart import BARTModel
from utils import read_lines

In [2]:
PATH = json.load(open('../path_config.json'))

In [3]:
bart = BARTModel.from_pretrained(PATH['xsum_cmlm_bos'],
                                 checkpoint_file='checkpoint_best.pt',
                                 data_name_or_path=PATH['data_name_or_path'])

bart_ent = BARTModel.from_pretrained(PATH['xsum_cmlm_ent_bos'],
                                     checkpoint_file='checkpoint_best.pt',
                                     data_name_or_path=PATH['data_name_or_path'])

#### Read XSum

In [4]:
document_path = PATH['xsum_fariseq'] + '/test.source'
target_path = PATH['xsum_fariseq'] + '/test.target'
xsum_source = read_lines(document_path)
xsum_target = read_lines(target_path)
print(len(xsum_source))
assert len(xsum_source) == len(xsum_target)

11301


#### Test One Example

In [5]:
import spacy

from model import ConditionalSequenceGenerator
from utils import prepare_cmlm_ent_inputs, get_cmlm_probability, get_cmlm_ent_probability, prepare_cmlm_inputs, get_probability

nlp = spacy.load('en_core_web_sm')

In [6]:
INDEX = 1110

source = xsum_source[INDEX]
target = xsum_target[INDEX]

In [7]:
model_ent = ConditionalSequenceGenerator(bart_ent)

In [8]:
model = ConditionalSequenceGenerator(bart)

#### Test

In [9]:
data = json.load(open('../data/annotated_with_probability_200.json', 'r'))
print(len(data))

200


In [10]:
data_index = 87
data[data_index]

{'id': 1110,
 'pred': 'Chelsea manager Jose Mourinho said the Premier League is "not the best league in the world" after his side\'s 1-0 defeat by West Ham was like "football from the 19th Century".',
 'ents': [{'start': 0,
   'end': 7,
   'label': 0,
   'type': 'ORG',
   'ent': 'Chelsea',
   'bart.large': 0.079345703125,
   'xsum_cmlm_bos': 0.9365234375,
   'bart.large.xsum': 0.76708984375,
   'cnndm_cmlm_cedar': 0.92529296875,
   'cnndm_cmlm_scratch_cedar_warmup_10000': 0.197265625,
   'xsum_cmlm_scratch_cedar_warmup_20000': 0.0281982421875},
  {'start': 16,
   'end': 20,
   'label': 0,
   'type': 'PERSON',
   'ent': 'Jose',
   'bart.large': 0.56982421875,
   'xsum_cmlm_bos': 0.939453125,
   'bart.large.xsum': 0.93017578125,
   'cnndm_cmlm_cedar': 0.94091796875,
   'cnndm_cmlm_scratch_cedar_warmup_10000': 0.79833984375,
   'xsum_cmlm_scratch_cedar_warmup_20000': 0.96875},
  {'start': 21,
   'end': 29,
   'label': 0,
   'type': 'PERSON',
   'ent': 'Mourinho',
   'bart.large': 0.997558

In [11]:
inputs = prepare_cmlm_inputs(source, data[data_index]['pred'], ent_parts=data[data_index]['ents'])

In [12]:
def get_cmlm_probability_parallel(generator, src_input, tgt_input, position, entity):
    token_probs = model.encode_decode(src_input, tgt_input=tgt_input)
    _, target, _ = model.tokenize_target(tgt_input, left_pad=False)
    
    probs = []
    for p, tok, tokp, e in zip(position, target, token_probs, entity):
        tok = [model.decode_func(i.unsqueeze(0)) for i in tok]
        probs.append(get_probability(p, tok, tokp, e).item())
    
    return probs

In [13]:
get_cmlm_probability_parallel(model, inputs[0], inputs[1], inputs[2], inputs[3])

[0.93701171875,
 0.9404296875,
 0.91650390625,
 0.1610107421875,
 0.8046875,
 0.64892578125]

In [14]:
inputs = prepare_cmlm_ent_inputs(source, target, nlp(target).to_json()['ents'])
target_probs = model.encode_decode(inputs[0], inputs[1])
print(get_cmlm_ent_probability(model_ent, inputs[0], inputs[1]))

[0.7605633735656738, 0.7426369786262512, 0.02257954515516758, 0.7699408531188965, 0.8309576511383057]


In [15]:
inputs[3]

IndexError: tuple index out of range

In [None]:
for e in data[data_index]['ents']:
    print(e['ent'])
    print(e['xsum_cmlm_bos'])
    print()