In [1]:
import json
import math
import torch
import torch.nn as nn

from fairseq.models.bart import BARTModel
from utils import read_lines

In [2]:
# Restore the fine-tuned BART checkpoint together with its binarized data directory.
finetuned_bart = BARTModel.from_pretrained(
    '/home/ml/users/cadencao/fairseq/checkpoints/xsum_cmlm_bos',
    checkpoint_file='checkpoint_best.pt',
    data_name_or_path='/home/ml/users/cadencao/XSum/fairseq_files/xsum-bin',
)

In [3]:
# Same three calls in the same order (cuda -> eval -> half), chained because
# each of them returns the module itself on a torch.nn.Module.
finetuned_bart.cuda().eval().half()
print('- fine-tuned bart model loaded.')

- fine-tuned bart model loaded.


In [4]:
# Short aliases for the model's encode / decode bound methods.
encode_func, decode_func = finetuned_bart.encode, finetuned_bart.decode

#### Read XSum

In [5]:
# Paths of the parallel source-document / target-summary files.
document_path = '/home/ml/users/cadencao/XSum/fairseq_files/train.source'
target_path = '/home/ml/users/cadencao/XSum/fairseq_files/train.target'

# Read both files (source first, then target) with the same helper.
xsum_source, xsum_target = (read_lines(path) for path in (document_path, target_path))

print(len(xsum_source))
# The two files must stay aligned: one target per source.
assert len(xsum_source) == len(xsum_target)

203575


In [6]:
# Load the pre-computed entity annotations. The context manager closes the
# file handle deterministically instead of leaving it to the garbage collector.
with open('../Build_dataset/xsum_train_ents.json', 'r') as ents_file:
    xsum_ents = json.load(ents_file)

In [7]:
# One annotation record per training sample.
assert len(xsum_ents) == len(xsum_source)
# Show the record fields and the entity spans of the first sample, one per line.
print(xsum_ents[0].keys(), xsum_ents[0]['ents'], sep='\n')

dict_keys(['id', 'src ents', 'ents'])
[{'start': 0, 'end': 13, 'label': 'ORG'}]


#### Test Summarization Generation

In [8]:
import json
import spacy

from tqdm import tqdm
from fairseq.data.data_utils import collate_tokens
from utils import tokenize, tokenize_with_mask, generate_sequence, get_cmlm_probability, get_prior_probability, cmlm_generate, prior_generate

# Load the spaCy pipeline registered under the name 'en_core_web_sm'.
# NOTE(review): `nlp` is not referenced by any cell visible in this part of the
# notebook -- confirm it is needed further down before keeping the load cost.
nlp = spacy.load('en_core_web_sm')

In [9]:
def get_weight(source, target, ents, model=None, prob_fn=None):
    """Score every entity span of ``target`` with the conditional masked LM.

    For each entity the span is replaced by the ``###`` placeholder inside the
    target, the masked target is joined to the source document, and the
    probability helper is asked to score the original entity text at that
    position of the (prefixed) target.

    Args:
        source: Source document text.
        target: Target summary text; the entity offsets index into this string.
        ents: Sequence of dicts carrying integer ``'start'`` / ``'end'``
            character offsets into ``target`` (other keys are ignored).
        model: Model object handed to ``prob_fn``. Defaults to the
            notebook-level ``finetuned_bart``.
        prob_fn: Callable invoked as
            ``prob_fn(model, prefixed_target, masked_input, span, entity, verbose=False)``.
            Defaults to ``get_cmlm_probability``.

    Returns:
        A list with one ``prob_fn`` result per entity, in input order, or
        ``[1.0]`` when there are no entities.

    Raises:
        AssertionError: If an entity span does not lie inside ``target``.
    """
    if not ents:
        return [1.0]

    # Fall back to the notebook globals only when nothing was injected.
    if model is None:
        model = finetuned_bart
    if prob_fn is None:
        prob_fn = get_cmlm_probability

    bos_prefix = '<s> '
    # Entity offsets are relative to `target`; shift them past the prefix.
    shift = len(bos_prefix)
    # Literal backslash, spelled with an explicit escape: the runtime value is
    # identical to the former ' <\s> ' but no longer relies on an invalid
    # escape sequence (a SyntaxWarning on recent Python versions).
    separator = ' <\\s> '
    prefixed_target = bos_prefix + target

    posteriors = []
    with torch.no_grad():
        for e in ents:
            start, end = e['start'], e['end']
            # A slice is always a substring, so check the span bounds instead.
            assert 0 <= start <= end <= len(target), \
                'entity span ({}, {}) is outside the target'.format(start, end)
            entity = target[start:end]

            masked_hypothesis = bos_prefix + target[:start] + '###' + target[end:]
            masked_input = masked_hypothesis + separator + source

            posteriors.append(prob_fn(model,
                                      prefixed_target,
                                      masked_input,
                                      (start + shift, end + shift),
                                      entity, verbose=False))

    return posteriors

In [10]:
# Smoke test on the first training sample; the cell displays the returned list.
get_weight(source=xsum_source[0], target=xsum_target[0], ents=xsum_ents[0]['ents'])

[0.73991930403281]

In [11]:
# Preview the per-entity weights of the first few training samples.
# Bounding the iterables up front (instead of a manual counter + break) lets
# tqdm know the total, so the progress bar runs to completion.
num_preview = min(10, len(xsum_source))
preview = zip(xsum_source[:num_preview],
              xsum_target[:num_preview],
              xsum_ents[:num_preview])
for source, target, ents in tqdm(preview, total=num_preview):
    print(get_weight(source, target, ents['ents']))

1it [00:00,  1.09it/s]

[0.73991930403281]


2it [00:03,  1.55s/it]

[0.916015625, 0.90234375, 0.92431640625]


3it [00:06,  1.92s/it]

[0.00072479248046875, 0.1496827631490305, 0.04168701171875]


4it [00:12,  3.13s/it]

[0.57421875, 0.6819409048184752, 0.13623046875, 6.873070624029316e-09, 0.93896484375, 0.0008949282569603589, 0.0092620849609375]


5it [00:15,  3.00s/it]

[0.85205078125, 0.5870727326118719, 0.617599068980856]


6it [00:19,  3.20s/it]

[0.88232421875, 0.8101534843444824, 0.89208984375, 0.9208984375]


7it [00:21,  3.04s/it]

[0.84228515625, 0.955078125, 0.96728515625]


8it [00:24,  2.82s/it]

[0.05348589984350838, 0.057025096903089434, 0.003514811396598816]


9it [00:30,  4.01s/it]

[0.00018777873229680608, 0.9033203125, 0.96533203125, 0.94384765625, 0.5778204803355038, 0.8798134326934814, 0.05303359031677246]


9it [00:35,  3.92s/it]

[0.5727882519116267, 0.0006327629089355469, 0.81903076171875, 0.904296875, 0.17190804299525553]



