In [1]:
import json
import math
import torch
import torch.nn as nn

from tqdm import tqdm
from fairseq.models.bart import BARTModel
from utils import read_lines, get_cmlm_probability

In [2]:
finetuned_bart = BARTModel.from_pretrained('/home/ml/users/cadencao/fairseq/checkpoints/xsum_cmlm_bos',
                                           checkpoint_file='checkpoint_best.pt',
                                           data_name_or_path='/home/ml/users/cadencao/XSum/fairseq_files/xsum-bin')

In [3]:
finetuned_bart.cuda()
finetuned_bart.eval()
finetuned_bart.half()
print('- fine-tuned bart model loaded.')

- fine-tuned bart model loaded.


In [4]:
bart = BARTModel.from_pretrained('/home/ml/users/cadencao/Downloads/BART_models/bart.large',
                                 checkpoint_file='model.pt',
                                 data_name_or_path='/home/ml/users/cadencao/Downloads/BART_models/bart.large')

In [5]:
bart.cuda()
bart.eval()
bart.half()
print('- bart model loaded.')

- bart model loaded.


In [6]:
encode_func = bart.encode
decode_func = bart.decode

#### Read XSum

In [7]:
document_path = '/home/ml/users/cadencao/XSum/fairseq_files/test.source'
target_path = '/home/ml/users/cadencao/XSum/fairseq_files/test.target'
xsum_source = read_lines(document_path)
xsum_target = read_lines(target_path)

assert len(xsum_source) == len(xsum_target)
print('- load {} samples.'.format(len(xsum_source)))

- load 11301 samples.


#### Get Posterior

In [8]:
def get_posterior(source, target, ents):
    """Get the weight of target sample.
    
    Args:
        source: str
        target: str
        ents: [{start: 13, end: 23, label: "ORG"}, ...]
    
    Return:
        posteriors: list of posterior probilities
    
    """
    if len(ents) == 0:
        return [1.0]

    posteriors = []
    for e in ents:
        entity = target[e['start']: e['end']]
        assert entity in target
        masked_hypothesis = '<s> ' + target[0: e['start']] + '###' + target[e['end']:]

        masked_input = masked_hypothesis + ' <\s> ' + source
        with torch.no_grad():
            posterior = get_cmlm_probability(finetuned_bart,
                                             '<s> ' + target,
                                             masked_input,
                                             (e['start'] + 4, e['end'] + 4),
                                             entity, verbose=False)
        posteriors.append(posterior)

    assert len(posteriors) == len(ents)
    return posteriors

#### Inference

In [9]:
import spacy

nlp = spacy.load('en_core_web_sm')

In [10]:
prediction_file_path = '/home/ml/users/cadencao/fairseq/preds/xsum_regularized_original_cpb.hypo'

xsum_preds = read_lines(prediction_file_path)
assert len(xsum_source) == len(xsum_preds)

In [11]:
posteriors, extracted_ents = [], []

for i, (s, p) in tqdm(enumerate(zip(xsum_source[:300], xsum_preds[:300]))):
    p = ' '.join(p.split())
    ents = nlp(p).to_json()['ents']

    extracted_ents.append(ents)
    posteriors.append(get_posterior(s, p, ents))

300it [14:38,  2.93s/it]


In [12]:
ent_labels, summary_labels = [], []
posterior_sum, counter = 0, 0.

for ps in posteriors:
    if min(ps) > 0.1:
        summary_labels.append(0)
    else:
        summary_labels.append(1) # non-consistent
        
    for p in ps:
        posterior_sum += p
        if p > 0.1:
            ent_labels.append(0)
        else:
            ent_labels.append(1) # hallucinated
            
        counter += 1

print('- average posterior: {}'.format(posterior_sum / counter))
print('- percentage of non-factual summary: {}'.format(sum(summary_labels) / len(summary_labels)))
print('- percentage of hallucinated entities: {}'.format(sum(ent_labels) / len(ent_labels)))

- average posterior: 0.6350243526848589
- percentage of non-factual summary: 0.2866666666666667
- percentage of hallucinated entities: 0.13087674714104194


In [13]:
# - xsum_official (300 samples)
# - average posterior: 0.6042669351904417
# - percentage of non-factual summary: 0.4
# - percentage of hallucinated entities: 0.16062683643486778

# - xsum_cedar_min_weighted_last (300 samples)
# - average posterior: 0.6248949521352261
# - percentage of non-factual summary: 0.32
# - percentage of hallucinated entities: 0.14774114774114774

# - xsum_cedar_avg_weighted_best (300 samples)
# - average posterior: 0.6288346728699337
# - percentage of non-factual summary: 0.2966666666666667
# - percentage of hallucinated entities: 0.1323529411764706

# - cedar_data_drop (300 samples)
# - average posterior: 0.6622435274226233
# - percentage of non-factual summary: 0.2733333333333333
# - percentage of hallucinated entities: 0.13150684931506848

# - xsum_regularized_cpb_elr2_cpb.hypo (300 samples)
# - average posterior: 0.6239574796195876
# - percentage of non-factual summary: 0.28
# - percentage of hallucinated entities: 0.13586291309669524

# - xsum_regularized_cp1_elr2_cpb (300 samples)
# - average posterior: 0.6314182049151594
# - percentage of non-factual summary: 0.29333333333333333
# - percentage of hallucinated entities: 0.13506815365551425

# - xsum_regularized_cpb_elr10_cp5 (300 samples)
# - average posterior: 0.6293614379030468
# - percentage of non-factual summary: 0.27
# - percentage of hallucinated entities: 0.13451776649746192

# - xsum_regularized_cpb_elr10_cp11 (300 samples)
# - average posterior: 0.6293614379030468
# - percentage of non-factual summary: 0.27
# - percentage of hallucinated entities: 0.13451776649746192

# - xsum_regularized_cpb_elr30_cp9
# - average posterior: 0.6387096765756954
# - percentage of non-factual summary: 0.2733333333333333
# - percentage of hallucinated entities: 0.13127413127413126

# -xsum_regularized_original_cpb
# - average posterior: 0.6350243526848589
# - percentage of non-factual summary: 0.2866666666666667
# - percentage of hallucinated entities: 0.13087674714104194

# - xsum_clean_cp1 (300 samples)
# - average posterior: 0.6434800207703358
# - percentage of non-factual summary: 0.32666666666666666
# - percentage of hallucinated entities: 0.14478527607361963

# - xsum target (300 samples)
# - average posterior: 0.5156889161579035
# - percentage of non-factual summary: 0.55
# - percentage of hallucinated entities: 0.2813953488372093

In [14]:
# Average posterior prob:

# reference: 0.5103260373198956
# data-dropout: 0.6402941753973592
# xsum_official: 0.6002646216892829
# xsum self-trained on Cedar: 0.5992162492884309

# xsum_binary_smoothly_weighted (200): 0.6287983280432797
# xsum_cedar_min_weighted_best (200): 0.6249476049270435
# xsum_cedar_min_weighted_last (200): 0.6207262391602929
# xsum_binary_weighted (200): 0.6580150275871669

# xsum cedar checkpoint1 (100): 0.6132371826385169
# xsum cedar checkpoint2 (100): 0.61654346501808
# xsum cedar checkpoint3 (100): 0.6229014194340766

# xsum cedar checkpoint1 (500): 0.5925646644582644
# xsum cedar checkpoint2 (500): 0.60001381397845
# xsum cedar checkpoint3 (500): 0.603983573738802

# xsum cedar checkpoint1 (1000): 0.6070493213913469
# xsum cedar checkpoint2 (1000): 0.6027837024769428
# xsum cedar checkpoint3 (1000): 0.6016024746051729