In [1]:
import math
import torch
import torch.nn as nn

from fairseq.models.bart import BARTModel
from utils import read_lines

In [2]:
finetuned_bart = BARTModel.from_pretrained('/home/ml/users/cadencao/fairseq/checkpoints/xsum_cmlm_bos',
                                           checkpoint_file='checkpoint_best.pt',
                                           data_name_or_path='/home/ml/users/cadencao/XSum/fairseq_files/xsum-bin')

In [3]:
finetuned_bart.cuda()
finetuned_bart.eval()
finetuned_bart.half()
print('- fine-tuned bart model loaded.')

- fine-tuned bart model loaded.


In [4]:
bart = BARTModel.from_pretrained('/home/ml/users/cadencao/Downloads/BART_models/bart.large',
                                 checkpoint_file='model.pt',
                                 data_name_or_path='/home/ml/users/cadencao/Downloads/BART_models/bart.large')

In [5]:
bart.cuda()
bart.eval()
bart.half()
print('- bart model loaded.')

- bart model loaded.


In [6]:
encode_func = bart.encode
decode_func = bart.decode

#### Read XSum

In [7]:
document_path = '/home/ml/users/cadencao/XSum/fairseq_files/test.source'
target_path = '/home/ml/users/cadencao/XSum/fairseq_files/test.target'
xsum_source = read_lines(document_path)
xsum_target = read_lines(target_path)
print(len(xsum_source))
assert len(xsum_source) == len(xsum_target)

11301


#### Test Summarization Generation

In [8]:
import spacy

from fairseq.data.data_utils import collate_tokens
from utils import tokenize, tokenize_with_mask, generate_sequence, get_probability, get_cmlm_probability, get_prior_probability, cmlm_generate, prior_generate

nlp = spacy.load('en_core_web_sm')

In [9]:
INDEX = 9444

In [10]:
source = xsum_source[INDEX]
target = "Twin-to-twin transfusion syndrome (TTTS) is being tracked by a hospital in Cardiff in a bid to save the lives of babies born with the condition."
print(target)

ent_parts = nlp(target).to_json()['ents']
print(ent_parts, end='\n\n')

for e in ent_parts:
    entity = target[e['start']: e['end']]
    
    with torch.no_grad():
        masked_hypothesis = target[0: e['start']] + '<mask>' + target[e['end']:]
        prior = get_prior_probability(bart, target, masked_hypothesis, (e['start'], e['end']), entity)
#         print(target[0: e['start']] + '<mask>' + target[e['end']:])
#         print('- prior: {}'.format(prior_generate(bart, masked_hypothesis)))

        masked_hypothesis = target[0: e['start']] + '###' + target[e['end']:]
        masked_hypothesis = '<s> ' + masked_hypothesis + ' <\s> ' + source
        posterior = get_cmlm_probability(finetuned_bart,
                                         '<s> ' + target,
                                         masked_hypothesis,
                                         (e['start'] + 4, e['end'] + 4),
                                         entity, verbose=False)
#         print(target[0: e['start']] + '###' + target[e['end']:])
#         print('- posterior: {}'.format(cmlm_generate(finetuned_bart, masked_hypothesis, verbose=False)))

        print('- entity: {}'.format(entity))
        print('- prior: {}'.format(prior))
        print('- posterior: {}'.format(posterior))
        print('- ratio: {:.3f} / {:.3f} = {:.3f}'.format(posterior, prior, posterior / (prior + 1e-5)))
        print()

Twin-to-twin transfusion syndrome (TTTS) is being tracked by a hospital in Cardiff in a bid to save the lives of babies born with the condition.
[{'start': 35, 'end': 39, 'label': 'ORG'}, {'start': 75, 'end': 82, 'label': 'ORG'}]

- entity: TTTS
- prior: 0.00037639751099050045
- posterior: 0.7503890991210938
- ratio: 0.750 / 0.000 = 1942.013

- entity: Cardiff
- prior: 0.0011739730834960938
- posterior: 0.1402587890625
- ratio: 0.140 / 0.001 = 118.465



#### Read Prediction

In [11]:
import json

from tqdm import tqdm

In [12]:
def get_posterior(data):
    for INDEX in tqdm(range(len(data))):
        source = xsum_source[data[INDEX]['id']]
        avg_posterior = 0.

        for i, e in enumerate(data[INDEX]['ents']):
            target = data[INDEX]['pred']
            entity = e['ent']

            with torch.no_grad():
                try:
                    masked_hypothesis = target[0: e['start']] + '<mask>' + target[e['end']:]
                    prior = get_prior_probability(bart, target, masked_hypothesis, (e['start'], e['end']), entity)

                    masked_hypothesis = '<s> ' + target[0: e['start']] + '###' + target[e['end']:]
                    masked_input = masked_hypothesis + ' <\s> ' + source
                    posterior = get_cmlm_probability(finetuned_bart,
                                                     '<s> ' + target,
                                                     masked_input,
                                                     (e['start'] + 4, e['end'] + 4),
                                                     entity, verbose=False)

                    e['prior'], e['posterior'] = prior, posterior
                    avg_posterior += posterior
                except:
                    print('- Got an error!')

        if len(data[INDEX]['ents']) > 0:
            data[INDEX]['avg_posterior'] = avg_posterior / len(data[INDEX]['ents'])

In [13]:
target_path = 'preds/xsum_binary_smoothly_weighted.hypo'
xsum_preds = read_lines(target_path)
data = []

for i, t in tqdm(enumerate(xsum_preds[:500])):
    item = {}
    item['id'] = i
    item['pred'] = t
    item['ents'] = nlp(t).to_json()['ents']

    for e in item['ents']:
        e['ent'] = item['pred'][e['start']: e['end']]

    data.append(item)

get_posterior(data)

avg_posterior, counter = 0., 0.

for d in data:
    if 'avg_posterior' in d:
        avg_posterior += d['avg_posterior']
        counter += 1

print('checkpoint: {}'.format(INDEX))
print(avg_posterior / counter)

300it [00:02, 142.67it/s]
 77%|███████▋  | 232/300 [13:02<04:55,  4.35s/it]

- Got an error!


100%|██████████| 300/300 [17:09<00:00,  3.43s/it]

checkpoint: 9444
0.6241776715763897





In [14]:
# reference: 0.5103260373198956
# data-dropout: 0.6402941753973592
# xsum_official: 0.6002646216892829
# xsum self-trained on Cedar: 0.5992162492884309
# xsum_binary_smoothly_weighted (200): 0.6287983280432797
# xsum_binary_weighted (200): 0.6580150275871669

# xsum cedar checkpoint1 (100): 0.6132371826385169
# xsum cedar checkpoint2 (100): 0.61654346501808
# xsum cedar checkpoint3 (100): 0.6229014194340766

# xsum cedar checkpoint1 (500): 0.5925646644582644
# xsum cedar checkpoint2 (500): 0.60001381397845
# xsum cedar checkpoint3 (500): 0.603983573738802

# xsum cedar checkpoint1 (1000): 0.6070493213913469
# xsum cedar checkpoint2 (1000): 0.6027837024769428
# xsum cedar checkpoint3 (1000): 0.6016024746051729