In [1]:
import json

from fairseq.models.bart import BARTModel
from utils import read_lines

In [2]:
PATH = json.load(open('../path_config.json'))

In [3]:
posterior_bart = BARTModel.from_pretrained(PATH['bart.large.xsum'],
                                           checkpoint_file='model.pt',
                                           data_name_or_path=PATH['bart.large.xsum'])

# posterior_bart = BARTModel.from_pretrained(PATH['bart.large.cnn'],
#                                            checkpoint_file='model.pt',
#                                            data_name_or_path=PATH['bart.large.cnn'])

# posterior_bart = BARTModel.from_pretrained(PATH['xsum_cmlm_bos'],
#                                            checkpoint_file='checkpoint_best.pt',
#                                            data_name_or_path=PATH['data_name_or_path'])

# posterior_bart = BARTModel.from_pretrained(PATH['xsum_cmlm_scratch_cedar_warmup_20000'],
#                                            checkpoint_file='checkpoint_best.pt',
#                                            data_name_or_path=PATH['data_name_or_path'])

In [4]:
# prior_bart = BARTModel.from_pretrained(PATH['bart.large'],
#                                        checkpoint_file='model.pt',
#                                        data_name_or_path=PATH['bart.large'])

prior_bart = BARTModel.from_pretrained(PATH['cnndm_cmlm_cedar'],
                                       checkpoint_file='checkpoint_best.pt',
                                       data_name_or_path=PATH['data_name_or_path'])

# prior_bart = BARTModel.from_pretrained(PATH['cnndm_cmlm_scratch_cedar_warmup_10000'],
#                                        checkpoint_file='checkpoint_best.pt',
#                                        data_name_or_path=PATH['data_name_or_path'])

#### Read XSum

In [5]:
document_path = PATH['xsum_fariseq'] + '/test.source'
target_path = PATH['xsum_fariseq'] + '/test.target'
xsum_source = read_lines(document_path)
xsum_target = read_lines(target_path)
print(len(xsum_source))
assert len(xsum_source) == len(xsum_target)

11301


#### Test One Example

In [6]:
import spacy

from model import ConditionalSequenceGenerator
from utils import prepare_cmlm_inputs, prepare_mlm_inputs, prepare_clm_inputs, get_cmlm_probability, get_prior_probability

nlp = spacy.load('en_core_web_sm')

In [7]:
INDEX = 9444

source = xsum_source[INDEX]
target = 'Twin-to-twin transfusion syndrome (TTTS) is being tracked by a hospital in Cardiff in a bid to save the lives of babies born with the condition.'

In [8]:
ent_parts = [{'start': 35, 'end': 39, 'label': 0, 'type': 'ORG', 'ent': 'TTTS'},
             {'start': 75, 'end': 82, 'label': 2, 'type': 'LOC', 'ent': 'Cardiff'}]

for e in ent_parts:
    print('{} - {}'.format(e, target[e['start']: e['end']]))

{'start': 35, 'end': 39, 'label': 0, 'type': 'ORG', 'ent': 'TTTS'} - TTTS
{'start': 75, 'end': 82, 'label': 2, 'type': 'LOC', 'ent': 'Cardiff'} - Cardiff


In [9]:
prior_model = ConditionalSequenceGenerator(prior_bart)
posterior_model = ConditionalSequenceGenerator(posterior_bart)

pri_args = prepare_cmlm_inputs(source, target, ent_parts)
pos_args = prepare_cmlm_inputs(source, target, ent_parts)

prior_probs = get_cmlm_probability(prior_model, pri_args[0], pri_args[1], pri_args[2], pri_args[3])
posterior_probs = get_cmlm_probability(posterior_model, pos_args[0], pos_args[1], pos_args[2], pos_args[3])

assert len(prior_probs) == len(posterior_probs)

In [10]:
print('- prior: {}'.format(prior_probs))
print('- posterior: {}'.format(posterior_probs))

- prior: [0.6904296875, 0.0042724609375]
- posterior: [0.27099609375, 0.147705078125]


#### Probability Calculation

In [11]:
from tqdm import tqdm
from utils import read_document

In [18]:
data = json.load(open('../data/Maynez_entity_data_with_prob.json', 'r'))
print(len(data))

500


In [19]:
data['21267591']['BERTS2S']

{'summary': "the leader of bahrain's main opposition party has said there needs to be dialogue between the kingdom's crown prince and government.",
 'summary_upper': "The leader of Bahrain 's main opposition party has said there needs to be dialogue between the kingdom 's crown prince and government .",
 'ents': [{'start': 14,
   'end': 21,
   'label': 0,
   'type': 'GPE',
   'ent': 'Bahrain',
   'bart.large': 0.11968994140625,
   'xsum_cmlm_bos': 0.9189453125,
   'cnndm_cmlm_cedar': 0.93115234375,
   'bart.large.xsum': 0.7041015625}]}

In [14]:
read_document(21267591, '/home/mcao610/scratch/summarization/XSum/xsum-preprocessed/document/')

'Sheikh Ali Salman told the BBC that for national dialogue to be meaningful , the government had to show its willingness to offer " concrete solutions " . " We want someone who can speak for the royal family , " he said . Crown Prince Salman al - Khalifa is seen as a reformist in a court divided on how to respond to opposition demands . Hardliners - centred around the unelected Prime Minister Sheikh Khalifa bin Salman al - Khalifa , who has been in his post since 1971 - are said to be opposed to a dialogue process which has only just been agreed between the government and six opposition societies . They fear that any concessions will only serve to encourage more demands from opposition leaders they deeply distrust . However , speaking to the BBC during a visit to London , Sheikh Salman insisted that now was the time for dialogue . " We welcome it , we are ready for it , " he said . " We believe that dialogue and negotiations are necessary . " The al - Wefaq leader acknowledged that man

In [15]:
for bbcid in tqdm(data.keys()):
    if bbcid == '39553812': continue  # corrupted sample
    source = read_document(int(bbcid), '/home/mcao610/scratch/summarization/XSum/xsum-preprocessed/document/')
    if source is None:
        print('- cannot read source: {}'.format(bbcid))
        continue
    
    for system in data[bbcid]:
        target = data[bbcid][system]['summary_upper']
        ents = data[bbcid][system]['ents']
        
        if len(ents) == 0: continue

        pri_args = prepare_cmlm_inputs(source, target, ents)
        pos_args = prepare_clm_inputs(source, target, ents)

        prior_probs = get_cmlm_probability(prior_model, pri_args[0], pri_args[1], pri_args[2], pri_args[3])
        posterior_probs = get_cmlm_probability(posterior_model, pos_args[0], pos_args[1], pos_args[2], pos_args[3])

        assert len(prior_probs) == len(posterior_probs) == len(ents), "{};\n {};\n {}".format(prior_probs, posterior_probs, ents)
        for i in range(len(prior_probs)):
            data[bbcid][system]['ents'][i]['cnndm_cmlm_cedar'] = prior_probs[i]
            data[bbcid][system]['ents'][i]['bart.large.xsum'] = posterior_probs[i]

 28%|██▊       | 140/500 [08:45<25:31,  4.25s/it]

- cannot read source: 33928888


100%|██████████| 500/500 [32:15<00:00,  3.87s/it]


In [16]:
# for INDEX in tqdm(range(len(data))):
#     source = xsum_source[data[INDEX]['id']]
#     target = data[INDEX]['pred']
    
#     pri_args = prepare_mlm_inputs(source, target, data[INDEX]['ents'])
#     pos_args = prepare_clm_inputs(source, target, data[INDEX]['ents'])

#     prior_probs = get_prior_probability(prior_model, pri_args[0], pri_args[1], pri_args[2], pri_args[3])
#     posterior_probs = get_cmlm_probability(posterior_model, pos_args[0], pos_args[1], pos_args[2], pos_args[3])
    
#     assert len(prior_probs) == len(posterior_probs) == len(data[INDEX]['ents']), "{};\n {};\n {}".format(prior_probs, posterior_probs, data[INDEX]['ents'])
#     for i in range(len(prior_probs)):
#         data[INDEX]['ents'][i]['bart.large'] = prior_probs[i]
#         data[INDEX]['ents'][i]['bart.large.xsum'] = posterior_probs[i]

In [17]:
with open('../data/Maynez_entity_data_with_prob.json', 'w') as fout:
    json.dump(data , fout)