In [2]:
import math
import torch
import torch.nn as nn

from fairseq.models.bart import BARTModel
from utils import read_lines, get_probability

In [6]:
# The XSum fine-tuned checkpoint directory is also used as the data directory.
xsum_bart_dir = '/home/ml/users/cadencao/Downloads/BART_models/bart.large.xsum'
finetuned_bart = BARTModel.from_pretrained(
    xsum_bart_dir,
    checkpoint_file='model.pt',
    data_name_or_path=xsum_bart_dir,
)

In [7]:
# Move the fine-tuned model to the GPU, put it in evaluation mode and cast it
# to half precision (each call returns the module, so they can be chained).
finetuned_bart.cuda().eval().half()
print('- fine-tuned bart model loaded.')

- fine-tuned bart model loaded.


In [8]:
# The pre-trained (not fine-tuned) checkpoint directory is also used as the data directory.
bart_large_dir = '/home/ml/users/cadencao/Downloads/BART_models/bart.large'
bart = BARTModel.from_pretrained(
    bart_large_dir,
    checkpoint_file='model.pt',
    data_name_or_path=bart_large_dir,
)

In [9]:
# Move the pre-trained model to the GPU, put it in evaluation mode and cast it
# to half precision (each call returns the module, so they can be chained).
bart.cuda().eval().half()
print('- bart model loaded.')

- bart model loaded.


In [10]:
# Short aliases for the pre-trained model's text <-> token-id helpers.
encode_func, decode_func = bart.encode, bart.decode

In [11]:
# Encoder / decoder modules of the XSum fine-tuned model.
bart_encoder, bart_decoder = finetuned_bart.model.encoder, finetuned_bart.model.decoder

In [12]:
# Show which classes implement the encoder and the decoder, one per line.
print(type(bart_encoder), type(bart_decoder), sep='\n')

<class 'fairseq.models.transformer.TransformerEncoder'>
<class 'fairseq.models.transformer.TransformerDecoder'>


#### Read XSum

In [14]:
# XSum test split: one source document and one reference summary per line.
document_path = '/home/ml/users/cadencao/XSum/fairseq_files/test.source'
target_path = '/home/ml/users/cadencao/XSum/fairseq_files/test.target'
xsum_source, xsum_target = (read_lines(path) for path in (document_path, target_path))
print(len(xsum_source))
# Documents and summaries must be aligned one-to-one.
assert len(xsum_source) == len(xsum_target)

11301


#### Generate Summary

In [15]:
from fairseq.data.data_utils import collate_tokens

In [20]:
def tokenize(input_str, append_bos=False, append_eos=True, left_pad=True):
    """BPE-encode a sentence (or multiple sentences) into a padded batch of ids.

    Each sentence is BPE-encoded with ``bart.bpe``, truncated so that it still
    fits within ``min(bart.max_positions)`` once the optional ``<s>`` / ``</s>``
    symbols are added, and mapped to ids with the BART source dictionary.

    Args:
        input_str (str or List[str]): input sentence(s) to be tokenized.
        append_bos (bool): prepend ``<s>`` to every sentence.
        append_eos (bool): append ``</s>`` to every sentence.
        left_pad (bool): forwarded to ``collate_tokens``; pad shorter sentences
            on the left when True.

    Return:
        input_ids (torch.Tensor): [batch_size, length], moved to the GPU.
        input_lengths (torch.Tensor): [batch_size], number of non-pad ids per row.
    """
    pad_idx = 1  # id used both for padding and for counting real tokens

    if isinstance(input_str, str):
        input_str = [input_str]

    # BPE tokens that fit once the slots for the special symbols are reserved.
    max_bpe_tokens = min(bart.max_positions) - sum([append_bos, append_eos])

    input_ids = []
    for sentence in input_str:
        bpe_tokens = bart.bpe.encode(sentence).split(" ")  # <mask>: 1279 27932 29
        if len(bpe_tokens) > max_bpe_tokens:
            bpe_tokens = bpe_tokens[:max_bpe_tokens]

        tokens = " ".join(bpe_tokens)
        if append_bos:
            tokens = "<s> " + tokens
        if append_eos:
            tokens = tokens + " </s>"

        # </s> was already added above when requested, so never let encode_line add it.
        ids = bart.task.source_dictionary.encode_line(tokens, append_eos=False).long()
        input_ids.append(ids)

    input_ids = collate_tokens(input_ids, pad_idx=pad_idx, left_pad=left_pad).cuda()
    # input_ids is already on the GPU, so the lengths are computed there as well.
    input_lengths = torch.sum(input_ids != pad_idx, dim=1)

    return input_ids, input_lengths

In [21]:
# Smoke test: tokenize one sentence with the defaults (no <s> prepended, </s> appended).
tokenize("New Celtic manager Brendan Rodgers says he has been discussing his future plans with some of his senior players.")

(tensor([[ 4030, 11955,  1044, 13015,  9122,   161,    37,    34,    57,  7345,
             39,   499,   708,    19,   103,     9,    39,   949,   472,     4,
              2]], device='cuda:0'), tensor([21], device='cuda:0'))

In [None]:
# "New Celtic manager Brendan Rodgers says he has been discussing his future plans with some of his senior players."
# [0,  4030, 11955,  1044, 13015,  9122,   161,    37,    34,    57,
#  7345,    39,   499,   708,    19,   103,     9,    39,   949,   472, 4,     2]

In [12]:
def generate_sequence(decoder, encoder_out, tgt_tokens=None, min_decode_step=10, max_decode_step=60, pad_id=1, eos_id=2, verbose=True):
    """Decode one sequence step by step and record the probability of each chosen token.

    At every step the decoder is run on the tokens chosen so far and the logits
    of the last position are turned into a distribution. The next token is the
    most probable one, or, when ``tgt_tokens`` is given, the target token for
    that step (forced decoding).

    Only a single sequence is decoded: the chosen token is read with
    ``.item()``, so the decoder prefix is created with exactly one row instead
    of depending on a notebook-level ``src_tokens`` variable.

    Args:
        decoder: decoder module, called as
            ``decoder(prev_tokens, encoder_out, features_only=False)``.
        encoder_out: encoder output, passed to ``decoder`` unchanged.
        tgt_tokens (torch.Tensor, optional): 1-D tensor of token ids to force,
            one per step. Decoding stops at its end if no ``eos_id`` is met.
        min_decode_step (int): ``eos_id`` cannot be selected while
            ``step + 1 < min_decode_step``.
        max_decode_step (int): maximum number of decoding steps.
        pad_id (int): id that is never selected.
        eos_id (int): id that ends decoding once it is chosen.
        verbose (bool): print every chosen token (except the final ``eos_id``)
            together with its probability.

    Returns:
        tuple: ``(output_ids, tokens, token_probs)`` where ``output_ids`` is a
        ``[1, 2 + num_tokens]`` tensor holding the two-token prefix followed by
        the chosen ids, ``tokens`` is the list of decoded token strings and
        ``token_probs`` the list of their probabilities as Python floats.
    """
    bos_id = 0  # start token; appears only in the prefix and is never selected

    # Decoder prefix: ids 2 and 0, one row (single-sequence decoding).
    prev_tokens = torch.tensor([[2, bos_id]], dtype=torch.long).cuda()
    token_probs, tokens = [], []

    # Never index past the end of the forced target.
    num_steps = max_decode_step
    if tgt_tokens is not None:
        num_steps = min(max_decode_step, len(tgt_tokens))

    # Inference only: no autograd graph is needed for any of the outputs.
    with torch.no_grad():
        for step in range(num_steps):
            decoder_outputs = decoder(prev_tokens, encoder_out, features_only=False)
            logits = decoder_outputs[0][:, -1, :]  # [batch_size, vocab]

            if step + 1 < min_decode_step:
                logits[:, eos_id] = -math.inf
            # never select pad, start token
            logits[:, pad_id] = -math.inf
            logits[:, bos_id] = -math.inf

            probs = torch.softmax(logits, dim=1)
            assert logits.shape == probs.shape
            attn = decoder_outputs[1]['attn'][0]  # [batch_size, prev_token_len, src_token_len]
            assert logits.dim() == 2 and attn.dim() == 3

            if tgt_tokens is not None:
                selected_token = tgt_tokens[step].unsqueeze(0)
            else:
                _, indices = torch.topk(probs, 5, dim=1)
                selected_token = indices[:, 0]

            prev_tokens = torch.cat([prev_tokens, selected_token.unsqueeze(1)], dim=-1)

            token_id = selected_token.item()
            token, prob = decode_func(selected_token), probs[0, token_id].item()
            token_probs.append(prob)
            tokens.append(token)

            if token_id == eos_id:
                break
            elif verbose:
                print("- {:02d}: {} ({:.2f})".format(step, token, prob), end='\n')

    return prev_tokens, tokens, token_probs

In [None]:
def tokenize_with_mask(input_sentence):
    """Convert a sentence containing ``<mask>`` placeholders into BART input ids.

    The text spans around every ``<mask>`` are BPE-encoded separately and then
    joined with the literal ``<mask>`` symbol. Encoding the whole sentence and
    string-replacing the BPE codes of ``<mask>`` only works when the
    placeholder follows a space and is not glued to neighbouring punctuation;
    a mask at the very start of the sentence would be missed.

    Args:
        input_sentence (str): sentence with zero or more ``<mask>`` placeholders.

    Returns:
        input_ids (torch.Tensor): [1, length] ids on the GPU, starting with
            ``<s>`` and ending with the end-of-sentence id added by ``encode_line``.
        src_lengths (torch.Tensor): [1], number of ids different from the pad id 1.
    """
    mask_symbol = '<mask>'

    # Trailing whitespace of a span belongs to the separator before the mask.
    span_codes = [bart.bpe.encode(span.rstrip()) for span in input_sentence.split(mask_symbol)]
    bpe_code = ' {} '.format(mask_symbol).join(span_codes).strip()

    input_ids = bart.task.source_dictionary.encode_line('<s> ' + bpe_code,
                                                        append_eos=True).long()
    input_ids = input_ids.unsqueeze(0).cuda()
    src_lengths = torch.sum(input_ids != 1, dim=1)
    return input_ids, src_lengths

#### Get Conditional Probability

In [None]:
import spacy

# spaCy pipeline registered under the 'en' shortcut; its `.ents` are used further down
# to pick the entities of the generated summary.
# NOTE(review): shortcut names such as 'en' are reportedly rejected by spaCy v3+ —
# confirm the installed spaCy version and which model 'en' is linked to.
nlp = spacy.load('en')

In [None]:
# Position of the XSum test example that the following cells analyse.
INDEX = 4

In [None]:
# tokenize the selected source document
src_tokens, src_lengths = tokenize(xsum_source[INDEX])

# encode the document and decode a summary with the fine-tuned model
encoder_out = finetuned_bart.model.encoder(src_tokens, src_lengths=src_lengths)
output_ids, tokens, token_probs = generate_sequence(bart_decoder, encoder_out, verbose=False)

print(output_ids.shape)

hypothesis = decode_func(output_ids[0])
print(hypothesis + '\n')

# Entities found by spaCy in the generated summary (one NER pass only).
ent_parts = [e.text for e in nlp(hypothesis).ents]
for entity in ent_parts:
    # Hide the entity, then let the pre-trained model re-score the same summary
    # (forced decoding on the generated ids, skipping the two-token prefix).
    masked_summary = hypothesis.replace(entity, '<mask>')

    masked_input, masked_lengths = tokenize_with_mask(masked_summary)
    masked_encoder_out = bart.model.encoder(masked_input, src_lengths=masked_lengths)
    masked_outputs = generate_sequence(bart.model.decoder,
                                       masked_encoder_out,
                                       tgt_tokens=output_ids[0][2:],
                                       verbose=False)
    masked_output_ids, masked_tokens, masked_token_probs = masked_outputs
    # Forced decoding must reproduce the generated summary exactly.
    assert decode_func(masked_output_ids[0]) == hypothesis

    # posterior: scored given the source document; prior: scored given the masked summary
    posterior = get_probability(entity, tokens, token_probs)
    prior = get_probability(entity, masked_tokens, masked_token_probs)

    print('- entity: {}'.format(entity))
    print('- prior: {}'.format(prior))
    print('- posterior: {}'.format(posterior))
    # 1e-5 keeps the ratio finite when the prior is zero
    print('- ratio: {:.3f} / {:.3f} = {:.3f}'.format(posterior, prior, posterior / (prior + 1e-5)))
    print()

In [17]:
# Reference summary of the same test example, to compare with the generated one.
xsum_target[INDEX]

'Ulster centre Stuart McCloskey has signed a contract extension which will see him remain at Kingspan Stadium until the summer of 2019.'

In [18]:
# Shape of the tokenized source batch returned by tokenize(): [batch_size, length].
src_tokens.shape

torch.Size([1, 271])

In [19]:
# Shape of the first element of the encoder output.
# NOTE(review): the recorded output reads as [src_len, batch, hidden] — confirm the axis meaning.
encoder_out[0].shape

torch.Size([271, 1, 1024])