In [1]:
import json
import torch

from fairseq.models.bart import BARTModel
from utils import read_lines

In [2]:
# Load the path configuration used by the cells below. The context manager
# closes the file handle deterministically instead of leaving it open.
with open('../path_config.json') as config_file:
    PATH = json.load(config_file)

In [3]:
# Both checkpoints are loaded with the same checkpoint file name and data
# path; only the model directory taken from PATH differs.
_shared_kwargs = dict(checkpoint_file='checkpoint_best.pt',
                      data_name_or_path=PATH['data_name_or_path'])

bart = BARTModel.from_pretrained(PATH['xsum_cmlm_bos'], **_shared_kwargs)

bart_ent = BARTModel.from_pretrained(PATH['xsum_cmlm_ent_bos'], **_shared_kwargs)

#### Read XSum

In [4]:
# Read the XSum test split: one list of lines from the source file and one
# from the target file, which must be the same length.
# NOTE(review): the key is spelled 'xsum_fariseq'; it has to match the key in
# path_config.json, so it is kept as-is here.
xsum_dir = PATH['xsum_fariseq']
document_path = xsum_dir + '/test.source'
target_path = xsum_dir + '/test.target'
xsum_source, xsum_target = read_lines(document_path), read_lines(target_path)
print(len(xsum_source))
assert len(xsum_source) == len(xsum_target)

11301


#### Test One Example

In [5]:
from model import ConditionalSequenceGenerator
from utils import prepare_cmlm_ent_inputs, prepare_cmlm_inputs, get_probability

In [6]:
# Wrap the `bart_ent` checkpoint in the project's ConditionalSequenceGenerator;
# this wrapper is what the entity-probability cell below calls.
model_ent = ConditionalSequenceGenerator(bart_ent)

In [7]:
# Wrap the `bart` checkpoint in the project's ConditionalSequenceGenerator;
# this wrapper is what the CMLM-probability cell below calls.
model = ConditionalSequenceGenerator(bart)

#### Test on an Annotated Example

In [8]:
import spacy

# Load spaCy's 'en_core_web_sm' pipeline.
# NOTE(review): `nlp` is not referenced by any cell visible in this notebook
# section — confirm whether it is still needed.
nlp = spacy.load('en_core_web_sm')

In [9]:
# Load the annotated examples. The context manager closes the file handle
# deterministically instead of leaving it open.
with open('../data/annotated_with_probability_200.json', 'r') as annotated_file:
    data = json.load(annotated_file)
print(len(data))

200


In [17]:
# Position in `data` of the annotated example inspected by the cells below.
# NOTE(review): this cell's execution count (17) is out of sequence with the
# following cells (11-16), so it was re-run after them; the stored outputs
# below may have been produced with a different `data_index`.
data_index = 89
data[data_index]['pred']

"The first World Cup was held in India in 1950, and the country's first football team, the Indian Football Federation (IFF), won a place in the final of the inaugural tournament."

In [11]:
# The example's 'id' field is used as the row index into the XSum test lists.
INDEX = data[data_index]['id']
source, target = xsum_source[INDEX], xsum_target[INDEX]

In [12]:
def get_cmlm_probability_parallel(generator, src_input, tgt_input, position, entity):
    """Score a batch of entities with the given CMLM generator.

    Args:
        generator: Object exposing ``encode_decode``, ``tokenize_target`` and
            ``decode_func`` (e.g. a ``ConditionalSequenceGenerator``).
        src_input: Source-side inputs forwarded to ``generator.encode_decode``.
        tgt_input: Target-side inputs forwarded to ``generator.encode_decode``
            and ``generator.tokenize_target``.
        position: Per-item position information forwarded to
            ``get_probability``.
        entity: Per-item entity forwarded to ``get_probability``.

    Returns:
        list: One Python scalar per zipped (position, target, token_probs,
        entity) item, obtained from ``get_probability(...).item()``.
    """
    # Use the generator that was passed in. The previous body ignored this
    # parameter and always used the notebook-global `model`.
    token_probs = generator.encode_decode(src_input, tgt_input=tgt_input)
    _, target, _ = generator.tokenize_target(tgt_input, left_pad=False)

    probs = []
    for pos, token_ids, tok_probs, ent in zip(position, target, token_probs, entity):
        # Decode every token id on its own so get_probability receives a list
        # of per-token strings aligned with tok_probs.
        tokens = [generator.decode_func(token_id.unsqueeze(0)) for token_id in token_ids]
        probs.append(get_probability(pos, tokens, tok_probs, ent).item())

    return probs

In [13]:
def get_cmlm_ent_probability(generator, src_input, tgt_input):
    """Return one joint probability per row of the generator's output.

    Args:
        generator: Object exposing ``encode_decode(src_input, tgt_input)``
            that returns a 2-D tensor of per-position probabilities.
        src_input: Source-side inputs forwarded to ``encode_decode``.
        tgt_input: Target-side inputs forwarded to ``encode_decode``.

    Returns:
        list[float]: For each row, the product of the probabilities from
        column 3 onward.
    """
    tgt_probs = generator.encode_decode(src_input, tgt_input)
    # Drop the first three positions of every row.
    # NOTE(review): presumably these are prefix/special tokens — confirm
    # against prepare_cmlm_ent_inputs.
    tgt_probs = tgt_probs[:, 3:]

    # Allocate the float32 accumulator on the same device as the model output.
    # The previous hard-coded .cuda() failed on CPU-only machines and could
    # mismatch the device of tgt_probs.
    probs = torch.ones(tgt_probs.shape[0], device=tgt_probs.device)
    for column in tgt_probs.T:
        probs = probs * column

    return probs.tolist()

In [14]:
# Build the CMLM inputs for the selected prediction, then score its entities
# with `model` and print the resulting list.
example = data[data_index]
inputs = prepare_cmlm_inputs(source, example['pred'], ent_parts=example['ents'])
src_input, tgt_input, position, entity = inputs[0], inputs[1], inputs[2], inputs[3]
cmlm_probs = get_cmlm_probability_parallel(model, src_input, tgt_input, position, entity)
print(cmlm_probs)

[0.83642578125, 0.8125, 0.72802734375, 0.16259765625, 0.93408203125, 0.054901123046875]


In [15]:
# Build the entity-CMLM inputs for the same prediction, then score them with
# `model_ent` and print the resulting list.
example = data[data_index]
inputs = prepare_cmlm_ent_inputs(source, example['pred'], example['ents'])
ent_src_input, ent_tgt_input = inputs[0], inputs[1]
ent_probs = get_cmlm_ent_probability(model_ent, ent_src_input, ent_tgt_input)
print(ent_probs)

[0.822265625, 0.7896687984466553, 0.013885498046875, 0.211669921875, 0.10906982421875, 0.16192013025283813]


In [16]:
# Print the entity strings on one line and their stored 'xsum_cmlm_bos'
# values on the next, each item followed by ', '.
entities = data[data_index]['ents']
print(''.join(f"{item['ent']}, " for item in entities))
print(''.join(f"{item['xsum_cmlm_bos']}, " for item in entities), end='')

first, World Cup, India, 1950, first, the Indian Football Federation, 
0.83642578125, 0.81201171875, 0.72705078125, 0.1617431640625, 0.93408203125, 0.054931640625, 