In [1]:
import math
import torch
import torch.nn as nn

from fairseq.models.bart import BARTModel
from utils import read_lines, get_probability

In [2]:
finetuned_bart = BARTModel.from_pretrained('/home/ml/cadencao/fairseq/checkpoints/xsum_cmlm_bos_ent',
                                           checkpoint_file='checkpoint_best.pt',
                                           data_name_or_path='/home/ml/cadencao/XSum/fairseq_files/xsum-bin')

In [3]:
finetuned_bart.cuda()
finetuned_bart.eval()
finetuned_bart.half()
print('- fine-tuned bart model loaded.')

- fine-tuned bart model loaded.


In [4]:
bart = BARTModel.from_pretrained('/home/ml/cadencao/Downloads/BART_models/bart.large',
                                 checkpoint_file='model.pt',
                                 data_name_or_path='/home/ml/cadencao/Downloads/BART_models/bart.large')

In [5]:
bart.cuda()
bart.eval()
bart.half()
print('- bart model loaded.')

- bart model loaded.


In [6]:
encode_func = bart.encode
decode_func = bart.decode

#### Read XSum

In [7]:
document_path = '/home/ml/cadencao/XSum/fairseq_files/test.source'
target_path = '/home/ml/cadencao/XSum/fairseq_files/test.target'
xsum_source = read_lines(document_path)
xsum_target = read_lines(target_path)
print(len(xsum_source))
assert len(xsum_source) == len(xsum_target)

11301


#### Generate Summary

In [8]:
from fairseq.data.data_utils import collate_tokens

In [9]:
def tokenize(src_input, verbose=False):
    src_inputs = [src_input]  # list of input string
    src_tokens = collate_tokens([encode_func(i) for i in src_inputs], pad_idx=1, left_pad=True)
    src_tokens = src_tokens.cuda()
    src_lengths = torch.sum(src_tokens != 1, dim=1)
    
    if verbose:
        print('- src tokens: {};\n- src lengths: {}'.format(src_tokens.shape, src_lengths.shape))
    return src_tokens, src_lengths

In [10]:
def tokenize_with_mask(input_sentence):
    bpe_code = bart.bpe.encode(input_sentence)  # <mask>: 1279 27932 29
    input_ids = bart.task.source_dictionary.encode_line('<s> ' + bpe_code.replace('1279 27932 29', '<mask>'), 
                                                        append_eos=True).long()
    input_ids = input_ids.unsqueeze(0).cuda()
    src_lengths = torch.sum(input_ids != 1, dim=1)
    return input_ids, src_lengths

In [11]:
def generate_sequence(decoder, encoder_out, batch_size=1, tgt_tokens=None, min_decode_step=1, max_decode_step=100, pad_id=1, eos_id=2, verbose=True):
    init_input = torch.tensor([[2, 0]] * batch_size, dtype=torch.long).cuda()
    softmax = nn.Softmax(dim=1)
    token_probs, tokens = [], []

    for step in range(max_decode_step):
        decoder_outputs = decoder(init_input, encoder_out, features_only=False)
        logits = decoder_outputs[0][:, -1, :]  # [batch_size, vocab]
        
        if step + 1 < min_decode_step:
            logits[:, eos_id] = -math.inf
        logits[:, pad_id], logits[:, 0] = -math.inf, -math.inf  # never select pad, start token

        probs = softmax(logits)
        assert logits.shape == probs.shape
        attn = decoder_outputs[1]['attn'][0]  # [batch_size, prev_token_len, src_token_len]
        assert logits.dim() == 2 and attn.dim() == 3

        if tgt_tokens is not None:
            selected_token = tgt_tokens[step].unsqueeze(0)
        else:
            value, indices = torch.topk(probs, 5, dim=1)
            selected_token = indices[:, 0]

        init_input = torch.cat([init_input, selected_token.unsqueeze(1)], dim=-1)
        token, prob = decode_func(selected_token), probs.squeeze()[selected_token.item()].item()
        
        if selected_token.item() == eos_id:
            break
        elif verbose:
            print("- {:02d}: {} ({:.2f})".format(step, token, prob), end='\n')

        token_probs.append(prob)
        tokens.append(token)

    return init_input, tokens, token_probs

In [12]:
def get_probability(position, tokens, probs, entity):
    """Get probability of the given target.

    Args:
        position: (start, end)
        tokens: ['The', ' Archbishop', ' of', ...]
        probs: [0.50, 0.49, 0.88, ...]
        entity: Rodgers
    """
    assert len(tokens) == len(probs)
    
    end_pointer, end_pos = 0, []
    for t in tokens:
        end_pointer += len(t)
        end_pos.append(end_pointer)
    
    assert position[1] in end_pos
    last_index = end_pos.index(position[1])
    indexes = [last_index]
    total_length = len(tokens[last_index])
    
    while total_length < (position[1] - position[0]):
        last_index -= 1
        assert last_index >= 0
        indexes.append(last_index)
        total_length += len(tokens[last_index])
    
    indexes.reverse()
    
    generated = ''.join([tokens[i] for i in indexes])
    assert entity in generated, 'entity: {}; prob calculated: {}'.format(entity, generated)
    
    prob = 1.0
    for i in indexes:
        prob *= probs[i]
    return prob

In [13]:
def get_cmlm_probability(bart_model, masked_sentence, entity, verbose=False):
    masked_input, masked_lengths = tokenize(masked_sentence)
    masked_outputs = generate_sequence(bart_model.model.decoder,
                                       bart_model.model.encoder(masked_input,
                                                                src_lengths=masked_lengths),
                                       tgt_tokens=bart_model.encode(entity)[1:].cuda(),
                                       verbose=verbose)
    masked_output_ids, masked_tokens, masked_token_probs = masked_outputs
    assert decode_func(masked_output_ids[0]) == entity
    assert ''.join(masked_tokens) == entity
    
    prob = 1.0
    for i in range(3, len(masked_token_probs)):
        prob *= masked_token_probs[i]
    return prob

In [14]:
def get_prior_probability(bart_model, sentence, masked_sentence, position, entity, verbose=False):
    masked_input, masked_lengths = tokenize_with_mask(masked_sentence)
    masked_outputs = generate_sequence(bart_model.model.decoder,
                                       bart_model.model.encoder(masked_input,
                                                                src_lengths=masked_lengths),
                                       tgt_tokens=bart_model.encode(sentence)[1:].cuda(),
                                       verbose=verbose)
    masked_output_ids, masked_tokens, masked_token_probs = masked_outputs
    assert decode_func(masked_output_ids[0]) == sentence, '{}; {}'.format(decode_func(masked_output_ids[0]), sentence)

    return get_probability(position, masked_tokens, masked_token_probs, entity)

In [15]:
def cmlm_generate(bart_model, masked_sentence, verbose=False):
    masked_input, masked_lengths = tokenize(masked_sentence)
    masked_outputs = generate_sequence(bart_model.model.decoder,
                                       bart_model.model.encoder(masked_input, 
                                                                src_lengths=masked_lengths),
                                       tgt_tokens=None,
                                       verbose=verbose)
    masked_output_ids, masked_tokens, masked_token_probs = masked_outputs
    
    return decode_func(masked_output_ids[0])

In [16]:
def prior_generate(bart_model, masked_sentence):
    masked_input, masked_lengths = tokenize_with_mask(masked_sentence)
    masked_outputs = generate_sequence(bart_model.model.decoder,
                                       bart_model.model.encoder(masked_input, 
                                                                src_lengths=masked_lengths),
                                       tgt_tokens=None,
                                       verbose=False)
    masked_output_ids, masked_tokens, masked_token_probs = masked_outputs
    
    return decode_func(masked_output_ids[0])

#### Get Conditional Probability

In [17]:
import spacy

nlp = spacy.load('en')

In [18]:
INDEX = 7079

In [19]:
source = xsum_source[INDEX]
target = "New Celtic manager Brendan Rodgers has met the club's captain for the first time as he prepares for his first game in charge of the club."
print(target)

ent_parts = nlp(target).to_json()['ents']
print(ent_parts, end='\n\n')

for e in ent_parts:
    entity = target[e['start']: e['end']]
    
    masked_hypothesis = target[0: e['start']] + '<mask>' + target[e['end']:]
    prior = get_prior_probability(bart, target, masked_hypothesis, (e['start'], e['end']), entity)
#     print(target[0: e['start']] + '<mask>' + target[e['end']:])
#     print('- generated: {}'.format(prior_generate(bart, masked_hypothesis)))
    
    masked_hypothesis = target[0: e['start']] + '###' + target[e['end']:]
    masked_hypothesis = '<s> ' + masked_hypothesis + ' <\s> ' + source
    posterior = get_cmlm_probability(finetuned_bart,
                                     masked_hypothesis,
                                     '<s> ' + entity, 
                                     verbose=False)
#     print(target[0: e['start']] + '###' + target[e['end']:])
#     print(cmlm_generate(finetuned_bart, masked_hypothesis))

    print('- entity: {}'.format(entity))
    print('- prior: {}'.format(prior))
    print('- posterior: {}'.format(posterior))
    print('- ratio: {:.3f} / {:.3f} = {:.3f}'.format(posterior, prior, posterior / (prior + 1e-5)))
    print()

New Celtic manager Brendan Rodgers has met the club's captain for the first time as he prepares for his first game in charge of the club.
[{'start': 0, 'end': 10, 'label': 'NORP'}, {'start': 19, 'end': 34, 'label': 'PERSON'}, {'start': 70, 'end': 75, 'label': 'ORDINAL'}, {'start': 104, 'end': 109, 'label': 'ORDINAL'}]

- entity: New Celtic
- prior: 0.00017683696933090687
- posterior: 0.2803802490234375
- ratio: 0.280 / 0.000 = 1500.668

- entity: Brendan Rodgers
- prior: 0.05584716796875
- posterior: 0.8109626770019531
- ratio: 0.811 / 0.056 = 14.519

- entity: first
- prior: 0.99072265625
- posterior: 0.9189453125
- ratio: 0.919 / 0.991 = 0.928

- entity: first
- prior: 0.91552734375
- posterior: 0.892578125
- ratio: 0.893 / 0.916 = 0.975

