In [34]:
import math
import torch
import torch.nn as nn

from fairseq.models.bart import BARTModel
from utils import read_lines

In [35]:
finetuned_bart = BARTModel.from_pretrained('/home/mcao610/scratch/BART_models/xsum_cmlm_ent',
                                           checkpoint_file='checkpoint_best.pt',
                                           data_name_or_path='/home/mcao610/scratch/summarization/XSum/fairseq_files/xsum-bin')

In [36]:
finetuned_bart.cuda()
finetuned_bart.eval()
finetuned_bart.half()
print('- fine-tuned bart model loaded.')

- fine-tuned bart model loaded.


In [37]:
bart = BARTModel.from_pretrained('/home/mcao610/scratch/BART_models/bart.large',
                                 checkpoint_file='model.pt',
                                 data_name_or_path='/home/mcao610/scratch/BART_models/bart.large')

In [38]:
bart.cuda()
bart.eval()
bart.half()
print('- bart model loaded.')

- bart model loaded.


In [39]:
encode_func = bart.encode
decode_func = bart.decode

#### Read XSum

In [40]:
document_path = '/home/mcao610/scratch/summarization/XSum/fairseq_files/test.source'
target_path = '/home/mcao610/scratch/summarization/XSum/fairseq_files/test.target'
xsum_source = read_lines(document_path)
xsum_target = read_lines(target_path)
print(len(xsum_source))
assert len(xsum_source) == len(xsum_target)

11301


#### Generate Summary

In [41]:
from fairseq.data.data_utils import collate_tokens
from utils import get_probability

In [42]:
def tokenize(src_input, verbose=False):
    src_inputs = [src_input]  # list of input string
    src_tokens = collate_tokens([encode_func(i) for i in src_inputs], pad_idx=1, left_pad=True)
    src_tokens = src_tokens.cuda()
    src_lengths = torch.sum(src_tokens != 1, dim=1)
    
    if verbose:
        print('- src tokens: {};\n- src lengths: {}'.format(src_tokens.shape, src_lengths.shape))
    return src_tokens, src_lengths

In [43]:
def tokenize_with_mask(input_sentence):
    bpe_code = bart.bpe.encode(input_sentence)  # <mask>: 1279 27932 29
    input_ids = bart.task.source_dictionary.encode_line('<s> ' + bpe_code.replace('1279 27932 29', '<mask>'), 
                                                        append_eos=True).long()
    input_ids = input_ids.unsqueeze(0).cuda()
    src_lengths = torch.sum(input_ids != 1, dim=1)
    return input_ids, src_lengths

In [44]:
def generate_sequence(decoder, encoder_out, batch_size=1, tgt_tokens=None, min_decode_step=1, max_decode_step=100, pad_id=1, eos_id=2, verbose=True):
    init_input = torch.tensor([[2, 0]] * batch_size, dtype=torch.long).cuda()
    softmax = nn.Softmax(dim=1)
    token_probs, tokens = [], []

    for step in range(max_decode_step):
        decoder_outputs = decoder(init_input, encoder_out, features_only=False)
        logits = decoder_outputs[0][:, -1, :]  # [batch_size, vocab]
        
        if step + 1 < min_decode_step:
            logits[:, eos_id] = -math.inf
        logits[:, pad_id], logits[:, 0] = -math.inf, -math.inf  # never select pad, start token

        probs = softmax(logits)
        assert logits.shape == probs.shape
        attn = decoder_outputs[1]['attn'][0]  # [batch_size, prev_token_len, src_token_len]
        assert logits.dim() == 2 and attn.dim() == 3

        if tgt_tokens is not None:
            selected_token = tgt_tokens[step].unsqueeze(0)
        else:
            value, indices = torch.topk(probs, 5, dim=1)
            selected_token = indices[:, 0]

        init_input = torch.cat([init_input, selected_token.unsqueeze(1)], dim=-1)
        token, prob = decode_func(selected_token), probs.squeeze()[selected_token.item()].item()
        
        if selected_token.item() == eos_id:
            break
        elif verbose:
            print("- {:02d}: {} ({:.2f})".format(step, token, prob), end='\n')

        token_probs.append(prob)
        tokens.append(token)

    return init_input, tokens, token_probs

In [45]:
def get_cmlm_probability(bart_model, masked_sentence, entity, verbose=False):
    masked_input, masked_lengths = tokenize(masked_sentence)
    masked_outputs = generate_sequence(bart_model.model.decoder,
                                       bart_model.model.encoder(masked_input,
                                                                src_lengths=masked_lengths),
                                       tgt_tokens=bart_model.encode(entity)[1:].cuda(),
                                       verbose=verbose)
    masked_output_ids, masked_tokens, masked_token_probs = masked_outputs
    assert decode_func(masked_output_ids[0]) == entity
    assert ''.join(masked_tokens) == entity
    
    prob = 1.0
    for i in range(3, len(masked_token_probs)):
        prob *= masked_token_probs[i]
    return prob

In [46]:
def get_prior_probability(bart_model, sentence, masked_sentence, position, entity, verbose=False):
    masked_input, masked_lengths = tokenize_with_mask(masked_sentence)
    masked_outputs = generate_sequence(bart_model.model.decoder,
                                       bart_model.model.encoder(masked_input,
                                                                src_lengths=masked_lengths),
                                       tgt_tokens=bart_model.encode(sentence)[1:].cuda(),
                                       verbose=verbose)
    masked_output_ids, masked_tokens, masked_token_probs = masked_outputs
    assert decode_func(masked_output_ids[0]) == sentence, '{}; {}'.format(decode_func(masked_output_ids[0]), sentence)

    return get_probability(position, masked_tokens, masked_token_probs, entity)

In [47]:
def cmlm_generate(bart_model, masked_sentence, verbose=False):
    masked_input, masked_lengths = tokenize(masked_sentence)
    masked_outputs = generate_sequence(bart_model.model.decoder,
                                       bart_model.model.encoder(masked_input, 
                                                                src_lengths=masked_lengths),
                                       tgt_tokens=None,
                                       verbose=verbose)
    masked_output_ids, masked_tokens, masked_token_probs = masked_outputs
    
    return decode_func(masked_output_ids[0])

In [48]:
def prior_generate(bart_model, masked_sentence):
    masked_input, masked_lengths = tokenize_with_mask(masked_sentence)
    masked_outputs = generate_sequence(bart_model.model.decoder,
                                       bart_model.model.encoder(masked_input, 
                                                                src_lengths=masked_lengths),
                                       tgt_tokens=None,
                                       verbose=False)
    masked_output_ids, masked_tokens, masked_token_probs = masked_outputs
    
    return decode_func(masked_output_ids[0])

#### Get Conditional Probability

In [49]:
import spacy

nlp = spacy.load('en_core_web_sm')

In [50]:
INDEX = 9444

In [51]:
source = xsum_source[INDEX]
target = "Twin-to-twin transfusion syndrome (TTTS) is being tracked by a hospital in Cardiff in a bid to save the lives of babies born with the condition."
print(target)

ent_parts = nlp(target).to_json()['ents']
print(ent_parts, end='\n\n')

for e in ent_parts:
    entity = target[e['start']: e['end']]
    
    with torch.no_grad():
        masked_hypothesis = target[0: e['start']] + '<mask>' + target[e['end']:]
        prior = get_prior_probability(bart, target, masked_hypothesis, (e['start'], e['end']), entity)
        print(target[0: e['start']] + '<mask>' + target[e['end']:])
        print('- prior: {}'.format(prior_generate(bart, masked_hypothesis)))

        masked_hypothesis = target[0: e['start']] + '###' + target[e['end']:]
        masked_hypothesis = '<s> ' + masked_hypothesis + ' <\s> ' + source
        posterior = get_cmlm_probability(finetuned_bart,
                                         masked_hypothesis,
                                         '<s> ' + entity, 
                                         verbose=False)
        print(target[0: e['start']] + '###' + target[e['end']:])
        print('- posterior: {}'.format(cmlm_generate(finetuned_bart, masked_hypothesis, verbose=False)))

        print('- entity: {}'.format(entity))
        print('- prior: {}'.format(prior))
        print('- posterior: {}'.format(posterior))
        print('- ratio: {:.3f} / {:.3f} = {:.3f}'.format(posterior, prior, posterior / (prior + 1e-5)))
        print()

Twin-to-twin transfusion syndrome (TTTS) is being tracked by a hospital in Cardiff in a bid to save the lives of babies born with the condition.
[{'start': 75, 'end': 82, 'label': 'DATE'}]

Twin-to-twin transfusion syndrome (TTTS) is being tracked by a hospital in <mask> in a bid to save the lives of babies born with the condition.
- prior: Twin-to-twin transfusion syndrome (TTTS) is being tracked by a hospital in the UK and a charity in a bid to save the lives of babies born with the condition.
Twin-to-twin transfusion syndrome (TTTS) is being tracked by a hospital in ### in a bid to save the lives of babies born with the condition.
- posterior: <s> London
- entity: Cardiff
- prior: 0.0011692047119140625
- posterior: 0.061309814453125
- ratio: 0.061 / 0.001 = 51.993



#### Read Annotated Data

In [52]:
import json

from tqdm import tqdm

In [53]:
data = json.load(open('annotated.json', 'r'))
print(len(data))

180


In [54]:
data[55]

{'id': 10943,
 'pred': "A powerful cyclone has killed at least 11 people and injured more than 100 in Vanuatu, the Pacific nation's president has said.",
 'ents': [{'start': 30,
   'end': 41,
   'label': 2,
   'type': 'CARDINAL',
   'ent': 'at least 11'},
  {'start': 61,
   'end': 74,
   'label': 2,
   'type': 'CARDINAL',
   'ent': 'more than 100'},
  {'start': 78, 'end': 85, 'label': 0, 'type': 'GPE', 'ent': 'Vanuatu'},
  {'start': 91, 'end': 98, 'label': 1, 'type': 'LOC', 'ent': 'Pacific'}],
 'hallucinations': ['killed at least 11 people and injured more than 100',
  "the Pacific nation's president has said."]}

In [55]:
prior_posterior = []

for INDEX in tqdm(range(len(data))):
    source = xsum_source[data[INDEX]['id']]
    
    for i, e in enumerate(data[INDEX]['ents']):
        target = data[INDEX]['pred']
        entity = target[e['start']: e['end']]
        
        with torch.no_grad():
            masked_hypothesis = target[0: e['start']] + '<mask>' + target[e['end']:]
            prior = get_prior_probability(bart, target, masked_hypothesis, (e['start'], e['end']), entity)

            masked_hypothesis = target[0: e['start']] + '###' + target[e['end']:]
            masked_hypothesis = '<s> ' + masked_hypothesis + ' <\s> ' + source
            posterior = get_cmlm_probability(finetuned_bart,
                                             masked_hypothesis,
                                             '<s> ' + entity, 
                                             verbose=False)
#             print(target[0: e['start']] + '###' + target[e['end']:])
#             print(cmlm_generate(finetuned_bart, masked_hypothesis, verbose=False))

            assert len(data[INDEX]['hallucination ents']) == len(data[INDEX]['correctness']), 'INDEX: {}'.format(INDEX)
            if i not in data[INDEX]['hallucination ents']:
                label = 0
            else:
                if data[INDEX]['correctness'][data[INDEX]['hallucination ents'].index(i)]:
                    label = 1
                else:
                    label = 2
                    assert not data[INDEX]['correctness'][data[INDEX]['hallucination ents'].index(i)]

            prior_posterior.append({'id': data[INDEX]['id'], 
                                    'prior': prior, 
                                    'posterior': posterior, 
                                    'entity': entity, 
                                    'entity pos': e, 
                                    'label': label})

  0%|          | 0/180 [00:00<?, ?it/s]


KeyError: 'hallucination ents'

In [None]:
print(len(prior_posterior))
print(prior_posterior[0])

In [None]:
# import json

In [None]:
# with open('prior_posterior.json', 'w') as fout:
#     json.dump(prior_posterior , fout)

#### Draw Diagram

In [None]:
# prior_posterior = json.load(open('prior_posterior.json', 'r'))
# print(len(prior_posterior))

In [None]:
%matplotlib inline

import matplotlib.pyplot as plt

In [None]:
fig, ax = plt.subplots(figsize=(20.0, 10.0))
colors = ['tab:blue', 'tab:orange', 'tab:green']

no_hallucinated = [(p['prior'], p['posterior']) for p in prior_posterior if p['label'] == 0]
hallucinated_true = [(p['prior'], p['posterior']) for p in prior_posterior if p['label'] == 1]
hallucinated_false = [(p['prior'], p['posterior']) for p in prior_posterior if p['label'] == 2]

# ax.scatter([i[0] for i in no_hallucinated], 
#            [i[1] for i in no_hallucinated], c='tab:blue', s=[i[1]*100 + 40 for i in no_hallucinated], label='Non-hallucination', alpha=0.7)

ax.scatter([i[0] for i in hallucinated_true], 
           [i[1] for i in hallucinated_true], c='tab:green', s=[i[1]*100 + 40 for i in hallucinated_true], label='Hallucination True', alpha=0.65)
ax.scatter([i[0] for i in hallucinated_false], 
           [i[1] for i in hallucinated_false], c='tab:orange', s=[i[1]*100 + 40 for i in hallucinated_false], label='Hallucination False', alpha=0.6)

ax.scatter([1.0], [1.0], c='tab:gray', s=10)

ax.set_xlabel('Prior Probability')
ax.set_ylabel('Posterior Probability')
ax.legend()
ax.grid(True)

plt.savefig('foo.png')
plt.show()

In [None]:
for p in prior_posterior:
    if p['label'] == 2 and p['posterior'] > 0.5:
        print(p)

#### Classification

In [None]:
prior_posterior_dict = {}

for p in prior_posterior:
    if p['id'] not in prior_posterior_dict:
        prior_posterior_dict[p['id']] = []
    prior_posterior_dict[p['id']].append(p)

In [None]:
prior_posterior[0]

In [None]:
true_label = []
ent_pred_label = []
prob_pred_label = []

for p in prior_posterior:
    source = xsum_source[p['id']]
    
    if p['entity'].lower() in source.lower():
        ent_pred_label.append(0)
    else:
        ent_pred_label.append(1)
        
    if p['label'] == 0 or p['label'] == 1:
        true_label.append(0)
    else:
        true_label.append(1)

    if p['posterior'] > 0.1 or p['prior'] > 0.2:
        prob_pred_label.append(0)
    else:
        prob_pred_label.append(1)

In [None]:
from sklearn.metrics import classification_report

In [None]:
print(classification_report(true_label, ent_pred_label, target_names=['Non-hallucination', 'Hallucination']))

In [None]:
print(classification_report(true_label, prob_pred_label, target_names=['Non-hallucination', 'Hallucination']))

In [None]:
len(data)

In [None]:
data[0]

In [None]:
total_ents = 0

for d in data:
    total_ents += len(d['hallucination ents'])

In [None]:
total_ents

In [None]:
89 / 326

In [None]:
import numpy as np

# Fixing random state for reproducibility
np.random.seed(19680801)

mu, sigma = 100, 15
x = mu + sigma * np.random.randn(10000)

# the histogram of the data
n, bins, patches = plt.hist(x, 100, density=True, facecolor='g', alpha=0.75)


plt.xlabel('Smarts')
plt.ylabel('Probability')
plt.title('Histogram of IQ')
plt.text(60, .025, r'$\mu=100,\ \sigma=15$')
plt.xlim(40, 160)
plt.ylim(0, 0.03)
plt.grid(True)
plt.show()

In [None]:
source = xsum_source[8770]

In [None]:
source