In [1]:
import math
import torch
import torch.nn as nn

from fairseq.models.bart import BARTModel
from utils import read_lines

In [2]:
import os

# Restrict this process to the GPU with index 1. This is set through the
# environment, so it has to run before any cell creates a CUDA context.
os.environ["CUDA_VISIBLE_DEVICES"]="1"

In [3]:
# Load the fine-tuned checkpoint (`checkpoint_best.pt` in the xsum_cmlm_bos
# directory) together with the binarized XSum data directory.
# NOTE(review): absolute, machine-specific paths — adjust before re-running elsewhere.
finetuned_bart = BARTModel.from_pretrained('/home/ml/cadencao/fairseq/checkpoints/xsum_cmlm_bos',
                                           checkpoint_file='checkpoint_best.pt',
                                           data_name_or_path='/home/ml/cadencao/XSum/fairseq_files/xsum-bin')

In [4]:
# Put the fine-tuned model on the GPU, switch it to evaluation mode and cast
# its weights to half precision before it is used for scoring.
finetuned_bart.cuda()
finetuned_bart.eval()
finetuned_bart.half()
print('- fine-tuned bart model loaded.')

- fine-tuned bart model loaded.


In [5]:
# Load the `bart.large` checkpoint (`model.pt`); the helpers in this notebook
# use this model for the "prior" probability of an entity.
# NOTE(review): absolute, machine-specific paths — adjust before re-running elsewhere.
bart = BARTModel.from_pretrained('/home/ml/cadencao/Downloads/BART_models/bart.large',
                                 checkpoint_file='model.pt',
                                 data_name_or_path='/home/ml/cadencao/Downloads/BART_models/bart.large')

In [6]:
# Same preparation as for the fine-tuned model: GPU, evaluation mode, half precision.
bart.cuda()
bart.eval()
bart.half()
print('- bart model loaded.')

- bart model loaded.


In [7]:
# Module-level aliases for the `bart` model's text<->token-id conversion.
# `tokenize` reads `encode_func`; `generate_sequence` and the probability
# helpers read `decode_func`, also when they are scoring `finetuned_bart`.
encode_func = bart.encode
decode_func = bart.decode

#### Read XSum

In [8]:
# Read the XSum test split (presumably one document / one summary per line —
# `read_lines` comes from the project-local `utils` module).
document_path = '/home/ml/cadencao/XSum/fairseq_files/test.source'
target_path = '/home/ml/cadencao/XSum/fairseq_files/test.target'
xsum_source = read_lines(document_path)
xsum_target = read_lines(target_path)
print(len(xsum_source))
# Sources and targets are indexed in parallel, so their lengths must agree.
assert len(xsum_source) == len(xsum_target)

11301


#### Generate Summary

In [9]:
from fairseq.data.data_utils import collate_tokens

In [10]:
def tokenize(src_input, verbose=False):
    """Encode one source string into a single-row batch of token ids on the GPU.

    Args:
        src_input: the raw input string, encoded with the module-level `encode_func`.
        verbose: when True, print the shapes of the two returned tensors.

    Returns:
        (src_tokens, src_lengths): the collated (left-padded, pad id 1) token
        tensor moved to CUDA, and the per-row count of non-pad tokens.
    """
    pad_idx = 1
    encoded_rows = [encode_func(text) for text in [src_input]]
    batch = collate_tokens(encoded_rows, pad_idx=pad_idx, left_pad=True).cuda()
    # Length of each row = number of positions that are not padding.
    lengths = (batch != pad_idx).sum(dim=1)

    if verbose:
        print('- src tokens: {};\n- src lengths: {}'.format(batch.shape, lengths.shape))
    return batch, lengths

In [11]:
def tokenize_with_mask(input_sentence):
    """Encode a sentence containing a literal '<mask>' into BART input ids.

    The sentence is BPE-encoded with `bart.bpe`, the BPE ids that spell the
    literal text '<mask>' are swapped back for the dictionary's '<mask>'
    symbol, and the result is converted to ids with '<s>' prepended and EOS
    appended.

    Args:
        input_sentence: raw text in which the span to predict is written as '<mask>'.

    Returns:
        (input_ids, src_lengths): a [1, seq_len] long tensor on CUDA and the
        count of non-pad (id != 1) positions per row.
    """
    import re  # local import so this cell does not depend on the import cell

    bpe_code = bart.bpe.encode(input_sentence)  # <mask>: 1279 27932 29
    # Replace the three ids only when they are whole, space-delimited ids.
    # A plain str.replace would also fire inside longer ids (for example
    # "51279 27932 290") and silently corrupt the neighbouring tokens.
    bpe_code = re.sub(r'(?<!\S)1279 27932 29(?!\S)', '<mask>', bpe_code)
    input_ids = bart.task.source_dictionary.encode_line('<s> ' + bpe_code,
                                                        append_eos=True).long()
    input_ids = input_ids.unsqueeze(0).cuda()
    src_lengths = torch.sum(input_ids != 1, dim=1)
    return input_ids, src_lengths

In [12]:
def generate_sequence(decoder, encoder_out, batch_size=1, tgt_tokens=None, min_decode_step=1, max_decode_step=100, pad_id=1, eos_id=2, verbose=True):
    """Decode token by token, either greedily or forced along `tgt_tokens`.

    Args:
        decoder: callable invoked as `decoder(prev_tokens, encoder_out, features_only=False)`;
            its first output holds the logits and its second an 'attn' entry.
        encoder_out: encoder output passed through to `decoder` unchanged.
        batch_size: number of rows in the initial decoder input. The body calls
            `.item()` on the selected token, so only 1 works in practice.
        tgt_tokens: optional 1-D tensor of token ids. When given, step `i`
            selects `tgt_tokens[i]` instead of the most probable token
            (teacher forcing), and the recorded probability is the model's
            probability of that forced token.
        min_decode_step: EOS is masked out while `step + 1 < min_decode_step`.
        max_decode_step: upper bound on the number of decoding steps.
        pad_id: id that is never allowed to be selected (together with id 0).
        eos_id: id that terminates decoding.
        verbose: print each selected token and its probability.

    Returns:
        (init_input, tokens, token_probs): the full decoder input including the
        `[2, 0]` prefix and, if reached, the EOS token; the decoded string of
        every selected token before EOS; and the matching probabilities.
    """
    init_input = torch.tensor([[2, 0]] * batch_size, dtype=torch.long).cuda()
    softmax = nn.Softmax(dim=1)
    token_probs, tokens = [], []

    # When teacher forcing, never step past the end of the target: the original
    # loop indexed tgt_tokens[step] for up to max_decode_step steps and raised
    # IndexError for a target that is shorter and does not contain EOS.
    num_steps = max_decode_step
    if tgt_tokens is not None:
        num_steps = min(max_decode_step, len(tgt_tokens))

    for step in range(num_steps):
        decoder_outputs = decoder(init_input, encoder_out, features_only=False)
        logits = decoder_outputs[0][:, -1, :]  # [batch_size, vocab]

        if step + 1 < min_decode_step:
            logits[:, eos_id] = -math.inf
        logits[:, pad_id], logits[:, 0] = -math.inf, -math.inf  # never select pad, start token

        probs = softmax(logits)
        assert logits.shape == probs.shape
        attn = decoder_outputs[1]['attn'][0]  # [batch_size, prev_token_len, src_token_len]
        assert logits.dim() == 2 and attn.dim() == 3

        if tgt_tokens is not None:
            selected_token = tgt_tokens[step].unsqueeze(0)
        else:
            # Greedy choice: index of the highest-probability token.
            _, indices = torch.topk(probs, 5, dim=1)
            selected_token = indices[:, 0]

        init_input = torch.cat([init_input, selected_token.unsqueeze(1)], dim=-1)
        token, prob = decode_func(selected_token), probs.squeeze()[selected_token.item()].item()

        # EOS ends decoding and is deliberately not recorded in tokens/token_probs.
        if selected_token.item() == eos_id:
            break
        elif verbose:
            print("- {:02d}: {} ({:.2f})".format(step, token, prob), end='\n')

        token_probs.append(prob)
        tokens.append(token)

    return init_input, tokens, token_probs

In [13]:
def get_probability(position, tokens, probs, entity):
    """Multiply the probabilities of the tokens that spell out an entity span.

    Args:
        position: (start, end) character offsets of the entity in the
            concatenation of `tokens`; `end` must coincide with a token end.
        tokens: decoded token strings, e.g. ['The', ' Archbishop', ' of', ...]
        probs: one probability per token, e.g. [0.50, 0.49, 0.88, ...]
        entity: the entity text, e.g. Rodgers

    Returns:
        The product of the probabilities of the covering tokens.

    Raises:
        AssertionError: if the inputs are misaligned, the span does not end on
            a token boundary, or the covering tokens do not contain `entity`.
    """
    assert len(tokens) == len(probs)

    # Character offset at which each token ends in the concatenated text.
    boundaries = []
    running = 0
    for piece in tokens:
        running += len(piece)
        boundaries.append(running)

    assert position[1] in boundaries
    cursor = boundaries.index(position[1])
    covering = [cursor]
    covered_chars = len(tokens[cursor])
    span_chars = position[1] - position[0]

    # Extend the window to the left until it is at least as long as the span.
    while covered_chars < span_chars:
        cursor -= 1
        assert cursor >= 0
        covering.insert(0, cursor)
        covered_chars += len(tokens[cursor])

    generated = ''.join(tokens[i] for i in covering)
    assert entity in generated, 'entity: {}; prob calculated: {}'.format(entity, generated)

    prob = 1.0
    for i in covering:
        prob *= probs[i]
    return prob

In [14]:
def get_cmlm_probability(bart_model, sentence, masked_sentence, position, entity, verbose=False):
    """Probability of `entity` when the decoder is forced to reproduce `sentence`.

    The encoder reads `masked_sentence` (tokenized with `tokenize`); the
    decoder is teacher-forced along the ids of `sentence` with the first id
    dropped.

    Args:
        bart_model: model exposing `.model.encoder`, `.model.decoder` and `.encode`.
        sentence: the target text the decoder is forced to emit.
        masked_sentence: the encoder input.
        position: (start, end) character offsets of `entity` inside `sentence`.
        entity: the entity text being scored.
        verbose: forwarded to `generate_sequence`.

    Returns:
        The product of the token probabilities covering the entity.

    Raises:
        AssertionError: if the decoded output, or the concatenation of the
            individually decoded tokens, differs from `sentence`.
            NOTE(review): the second check appears to fail when a non-ASCII
            character is split across tokens (see the 'Ôl' failure in the
            recorded run output) — confirm.
    """
    src_tokens, src_lengths = tokenize(masked_sentence)
    encoder_out = bart_model.model.encoder(src_tokens, src_lengths=src_lengths)
    forced_targets = bart_model.encode(sentence)[1:].cuda()

    output_ids, pieces, piece_probs = generate_sequence(bart_model.model.decoder,
                                                        encoder_out,
                                                        tgt_tokens=forced_targets,
                                                        verbose=verbose)

    decoded = decode_func(output_ids[0])
    assert decoded == sentence, '- generated: {}\n- target: {}'.format(decoded, sentence)
    joined = ''.join(pieces)
    assert joined == sentence, '- generated: {}\n- target: {}'.format(joined, sentence)

    return get_probability(position, pieces, piece_probs, entity)

In [15]:
def get_prior_probability(bart_model, sentence, masked_sentence, position, entity, verbose=False):
    """Probability of `entity` given only the sentence with the entity masked.

    The encoder reads `masked_sentence` (tokenized with `tokenize_with_mask`,
    so it should contain a literal '<mask>'); the decoder is teacher-forced
    along the ids of `sentence` with the first id dropped.

    Args:
        bart_model: model exposing `.model.encoder`, `.model.decoder` and `.encode`.
        sentence: the full, unmasked text the decoder is forced to emit.
        masked_sentence: `sentence` with the entity replaced by '<mask>'.
        position: (start, end) character offsets of `entity` inside `sentence`.
        entity: the entity text being scored.
        verbose: forwarded to `generate_sequence`.

    Returns:
        The product of the token probabilities covering the entity.

    Raises:
        AssertionError: if the decoded output differs from `sentence`.
    """
    src_tokens, src_lengths = tokenize_with_mask(masked_sentence)
    encoder_out = bart_model.model.encoder(src_tokens, src_lengths=src_lengths)
    forced_targets = bart_model.encode(sentence)[1:].cuda()

    output_ids, pieces, piece_probs = generate_sequence(bart_model.model.decoder,
                                                        encoder_out,
                                                        tgt_tokens=forced_targets,
                                                        verbose=verbose)

    decoded = decode_func(output_ids[0])
    assert decoded == sentence, '{}; {}'.format(decoded, sentence)

    return get_probability(position, pieces, piece_probs, entity)

In [16]:
def cmlm_generate(bart_model, masked_sentence, verbose=False):
    """Greedily decode from `masked_sentence` and return the decoded text.

    Args:
        bart_model: model exposing `.model.encoder` and `.model.decoder`.
        masked_sentence: encoder input, tokenized with `tokenize`.
        verbose: forwarded to `generate_sequence`.

    Returns:
        The decoded string of the generated id sequence.
    """
    src_tokens, src_lengths = tokenize(masked_sentence)
    encoder_out = bart_model.model.encoder(src_tokens, src_lengths=src_lengths)
    # No target tokens: generate_sequence picks the most probable token each step.
    output_ids, _, _ = generate_sequence(bart_model.model.decoder,
                                         encoder_out,
                                         tgt_tokens=None,
                                         verbose=verbose)

    return decode_func(output_ids[0])

In [17]:
def prior_generate(bart_model, masked_sentence):
    """Greedily decode from a '<mask>'-style input and return the decoded text.

    Args:
        bart_model: model exposing `.model.encoder` and `.model.decoder`.
        masked_sentence: encoder input, tokenized with `tokenize_with_mask`.

    Returns:
        The decoded string of the generated id sequence.
    """
    src_tokens, src_lengths = tokenize_with_mask(masked_sentence)
    encoder_out = bart_model.model.encoder(src_tokens, src_lengths=src_lengths)
    # No target tokens and no per-step printing.
    output_ids, _, _ = generate_sequence(bart_model.model.decoder,
                                         encoder_out,
                                         tgt_tokens=None,
                                         verbose=False)

    return decode_func(output_ids[0])

#### Get Conditional Probability

In [18]:
import spacy

# spaCy pipeline; the demo cell calls it on a summary and reads the entity
# spans from the pipeline output's `to_json()['ents']`.
# NOTE(review): 'en' is a shortcut model name — confirm it resolves with the
# installed spaCy version.
nlp = spacy.load('en')

In [19]:
# Index into `xsum_source` of the document used by the single-example cell.
INDEX = 9444

In [20]:
# Worked example: for every entity spaCy finds in one hand-written summary,
# compare its probability without the source document (prior, `bart`) with
# its probability given the source document (posterior, `finetuned_bart`).
source = xsum_source[INDEX]
target = "Twin-to-twin transfusion syndrome (TTTS) is being tracked by a hospital in Cardiff in a bid to save the lives of babies born with the condition."
print(target)

ent_parts = nlp(target).to_json()['ents']
print(ent_parts, end='\n\n')

for e in ent_parts:
    entity = target[e['start']: e['end']]

    with torch.no_grad():
        # Prior: the entity is replaced by '<mask>' and no source is given.
        masked_hypothesis = target[0: e['start']] + '<mask>' + target[e['end']:]
        prior = get_prior_probability(bart, target, masked_hypothesis, (e['start'], e['end']), entity)
        print(masked_hypothesis)
#         print('- prior: {}'.format(prior_generate(bart, masked_hypothesis)))

        # Posterior: the entity is replaced by '###' and the source document is
        # appended after a separator. The separator is the literal text
        # ' <\s> ' (with a backslash); it is spelled '\\s' here because '\s'
        # is an invalid escape sequence that newer Python versions warn about.
        masked_hypothesis = target[0: e['start']] + '###' + target[e['end']:]
        masked_hypothesis = '<s> ' + masked_hypothesis + ' <\\s> ' + source
        # Offsets shift by 4 because the scored sentence is prefixed with '<s> '.
        posterior = get_cmlm_probability(finetuned_bart,
                                         '<s> ' + target,
                                         masked_hypothesis,
                                         (e['start'] + 4, e['end'] + 4),
                                         entity, verbose=False)
#         print(target[0: e['start']] + '###' + target[e['end']:])
#         print('- posterior: {}'.format(cmlm_generate(finetuned_bart, masked_hypothesis, verbose=False)))

        print('- entity: {}'.format(entity))
        print('- prior: {}'.format(prior))
        print('- posterior: {}'.format(posterior))
        # 1e-5 keeps the ratio finite when the prior is zero.
        print('- ratio: {:.3f} / {:.3f} = {:.3f}'.format(posterior, prior, posterior / (prior + 1e-5)))
        print()

Twin-to-twin transfusion syndrome (TTTS) is being tracked by a hospital in Cardiff in a bid to save the lives of babies born with the condition.
[{'start': 35, 'end': 39, 'label': 'ORG'}, {'start': 75, 'end': 82, 'label': 'ORG'}]

Twin-to-twin transfusion syndrome (<mask>) is being tracked by a hospital in Cardiff in a bid to save the lives of babies born with the condition.
- entity: TTTS
- prior: 0.0003767434973269701
- posterior: 0.7491459846496582
- ratio: 0.749 / 0.000 = 1937.062

Twin-to-twin transfusion syndrome (TTTS) is being tracked by a hospital in <mask> in a bid to save the lives of babies born with the condition.
- entity: Cardiff
- prior: 0.0011692047119140625
- posterior: 0.1407470703125
- ratio: 0.141 / 0.001 = 119.358



#### Read Google Data

In [21]:
import json

from tqdm import tqdm

In [22]:
# Location of the entity-annotated JSON file, relative to the notebook.
google_data_path = '../Dataset/entity_data.json'

In [23]:
# Load the annotations with a context manager so the file handle is closed
# promptly, and decode explicitly as UTF-8 (the standard JSON encoding)
# instead of relying on the platform default.
with open(google_data_path, encoding='utf-8') as google_data_file:
    google_data = json.load(google_data_file)

In [24]:
# Number of top-level entries in the annotation file.
print(len(google_data))

500


In [25]:
def process_document(raw_doc):
    """Drop the leading boilerplate of a raw document and flatten it to one line.

    Leading lines are skipped until the first line that has more than one
    whitespace-separated word and is not one of the known boilerplate
    sentences. That line and every line after it (including blank ones) are
    joined with single spaces.

    Args:
        raw_doc: the raw document text, one sentence per line.

    Returns:
        The joined text, or '' when no line qualifies as a starting line.
    """
    boilerplate = (
        'Share this with',
        'Copy this link',
        'These are external links and will open in a new window',
    )

    lines = raw_doc.strip().split('\n')

    for first_kept, line in enumerate(lines):
        if len(line.split()) > 1 and line not in boilerplate:
            # Everything from the first real sentence onwards is kept as-is.
            return ' '.join(lines[first_kept:])

    return ' '.join([])

In [26]:
def read_document(bbcid):
    """Read `<bbcid>.document` from the preprocessed XSum folder and clean it.

    Args:
        bbcid: document identifier used to build the file name.

    Returns:
        The file content after `process_document`.
    """
    document_dir = '/home/ml/cadencao/XSum/xsum-preprocessed/document/'
    document_file = ''.join([document_dir, '{}.document'.format(bbcid)])

    with open(document_file, 'r') as handle:
        raw_text = handle.read()
    return process_document(raw_text)

In [27]:
# Spot-check: display the processed text of one document.
read_document(34687720)

'France \'s Dubuisson carded a 67 to tie with overnight leader Van Zyl of South Africa on 16 under par . McIlroy carded a third straight five under - par 67 to move to 15 under par with Thailand \'s Kiradech Aphibarnrat . The world number three \'s round included an eagle on the 12th as he bids to win his first title since May . " The 67s I \'ve shot this week have all been a little different and I feel like I \'ve played within myself for all of them , " said four - time major winner McIlroy of Northern Ireland . " I feel there \'s a low round out there for me and hopefully it \'s tomorrow . " McIlroy was level par for the day after 10 holes , dropping his first shots of the week by three - putting the third and 10th , the latter mistake prompting the 26 - year - old to throw his putter at his bag . But he hit back with a birdie on the par - five 11th and a towering four iron from 229 yards on the 13th set up an eagle from just four feet . The former world number one ruptured a ligame

In [28]:
# Score every annotated entity of every system summary: store the prior
# (`bart`, no source) and the posterior (`finetuned_bart`, source appended)
# on the entity dict under 'prior' / 'posterior'.
for bbcid in tqdm(google_data.keys()):
    try:
        source = read_document(bbcid)
    except Exception:
        # Best-effort: skip documents that cannot be read. `Exception` instead
        # of a bare `except` so Ctrl-C can still interrupt this long loop.
        print('- document {} does not exist!'.format(bbcid))
        continue

#     print(bbcid)
    for system in google_data[bbcid]:
        target = google_data[bbcid][system]['summary_upper']

        for e in google_data[bbcid][system]['ents']:
            try:
                entity = target[e['start']: e['end']]
                assert entity == e['ent']

                with torch.no_grad():
                    # Prior: the entity is replaced by '<mask>', no source given.
                    masked_hypothesis = target[0: e['start']] + '<mask>' + target[e['end']:]
                    prior = get_prior_probability(bart, target, masked_hypothesis, (e['start'], e['end']), entity)
#                     print('- prior int: {}'.format(target[0: e['start']] + '<mask>' + target[e['end']:]))
#                     print('- prior out: {}'.format(prior_generate(bart, masked_hypothesis)))

                    # Posterior: the entity is replaced by '###' and the source
                    # follows the literal separator ' <\s> ' (spelled '\\s'
                    # because '\s' is an invalid escape sequence).
                    masked_hypothesis = '<s> ' + target[0: e['start']] + '###' + target[e['end']:]
                    masked_input = masked_hypothesis + ' <\\s> ' + source
                    # Offsets shift by 4 because the sentence is prefixed with '<s> '.
                    posterior = get_cmlm_probability(finetuned_bart,
                                                     '<s> ' + target,
                                                     masked_input,
                                                     (e['start'] + 4, e['end'] + 4),
                                                     entity, verbose=False)
#                     print('- posterior int: {}'.format(masked_hypothesis))
#                     print('- posterior out: {}'.format(cmlm_generate(finetuned_bart, masked_input, verbose=False)))

#                     print('- entity: {}'.format(entity))
#                     print('- prior: {}'.format(prior))
#                     print('- posterior: {}'.format(posterior))
#                     print('- ratio: {:.3f} / {:.3f} = {:.3f}'.format(posterior, prior, posterior / (prior + 1e-5)))
#                     print('- label: {}'.format(e['label']))
#                     print()

                    e['prior'], e['posterior'] = prior, posterior
            except Exception as err:
                # Keep going on per-entity failures, but say what went wrong;
                # entities that fail simply never get 'prior' / 'posterior'.
                print('- error appears!')
                print(e)
                print('  {}: {}'.format(type(err).__name__, err))

 28%|██▊       | 140/500 [40:10<2:14:19, 22.39s/it]

- document 33928888 does not exist!


 31%|███▏      | 157/500 [45:02<1:38:24, 17.21s/it]


AssertionError: - generated: <s> Mae Gan O Gan Yng Nghymru Wedi Cael EU R AR ��l I 'R du Yng Nghymru .
- target: <s> Mae Gan O Gan Yng Nghymru Wedi Cael EU R AR Ôl I 'R du Yng Nghymru .

In [29]:
def get_label_with_probability(prior, posterior, posterior_threshold=0.1, prior_threshold=0.2):
    """Turn a (prior, posterior) probability pair into a binary label.

    Args:
        prior: entity probability without the source document.
        posterior: entity probability given the source document.
        posterior_threshold: a posterior strictly above this value yields 0.
        prior_threshold: a prior strictly above this value yields 0.

    Returns:
        0 if either probability exceeds its threshold, otherwise 1. The report
        cell at the end of the notebook names class 0 'Non-hallucination' and
        class 1 'Hallucination'.
    """
    if posterior > posterior_threshold or prior > prior_threshold:
        return 0
    else:
        return 1

In [40]:
# Build parallel lists of gold and predicted labels for every entity that was
# scored (has a 'prior') and carries an annotation label of 0, 1 or 2.
true_labels, pred_labels = [], []

for bbcid, systems in google_data.items():
    for system, record in systems.items():
        for e in record['ents']:
            if 'prior' not in e or e['label'] not in [0, 1, 2]:
                continue
            pred_labels.append(get_label_with_probability(e['prior'], e['posterior']))
            # Annotation labels 0 and 1 map to class 0; label 2 maps to class 1.
            true_labels.append(0 if e['label'] in (0, 1) else 1)

In [41]:
# for bbcid in google_data.keys():
#     for system in google_data[bbcid]:
#         if system == 'Gold':
#             continue
#         for e in google_data[bbcid][system]['ents']:
#             if e['label'] == -1:
#                 print('{} - {}'.format(bbcid, system))

In [42]:
from sklearn.metrics import classification_report

In [43]:
# Per-class precision / recall / F1; class 0 is reported as 'Non-hallucination'
# and class 1 as 'Hallucination'.
print(classification_report(true_labels, pred_labels, target_names=['Non-hallucination', 'Hallucination']))

                   precision    recall  f1-score   support

Non-hallucination       0.61      0.65      0.63       928
    Hallucination       0.55      0.51      0.53       783

         accuracy                           0.59      1711
        macro avg       0.58      0.58      0.58      1711
     weighted avg       0.58      0.59      0.58      1711

