In [1]:
import math
import torch
import torch.nn as nn

from fairseq.models.bart import BARTModel
from utils import read_lines, get_probability

In [2]:
finetuned_bart = BARTModel.from_pretrained('/home/ml/cadencao/Downloads/BART_models/bart.large.xsum',
                                           checkpoint_file='model.pt',
                                           data_name_or_path='/home/ml/cadencao/Downloads/BART_models/bart.large.xsum')

In [3]:
finetuned_bart.cuda()
finetuned_bart.eval()
finetuned_bart.half()
print('- fine-tuned bart model loaded.')

- fine-tuned bart model loaded.


In [4]:
bart = BARTModel.from_pretrained('/home/ml/cadencao/Downloads/BART_models/bart.large',
                                 checkpoint_file='model.pt',
                                 data_name_or_path='/home/ml/cadencao/Downloads/BART_models/bart.large')

In [5]:
bart.cuda()
bart.eval()
bart.half()
print('- bart model loaded.')

- bart model loaded.


In [6]:
encode_func = bart.encode
decode_func = bart.decode

In [7]:
bart_encoder = finetuned_bart.model.encoder
bart_decoder = finetuned_bart.model.decoder

In [8]:
print(type(bart_encoder))
print(type(bart_decoder))

<class 'fairseq.models.transformer.TransformerEncoder'>
<class 'fairseq.models.transformer.TransformerDecoder'>


#### Read XSum

In [9]:
document_path = '/home/ml/cadencao/XSum/fairseq_files/test.source'
target_path = '/home/ml/cadencao/XSum/fairseq_files/test.target'
xsum_source = read_lines(document_path)
xsum_target = read_lines(target_path)
print(len(xsum_source))
assert len(xsum_source) == len(xsum_target)

11301


#### Generate Summary

In [10]:
from fairseq.data.data_utils import collate_tokens

In [11]:
def tokenize(src_input, verbose=False):
    src_inputs = [src_input]  # list of input string
    src_tokens = collate_tokens([encode_func(i) for i in src_inputs], pad_idx=1, left_pad=True)
    src_tokens = src_tokens.cuda()
    src_lengths = torch.sum(src_tokens != 1, dim=1)
    
    if verbose:
        print('- src tokens: {};\n- src lengths: {}'.format(src_tokens.shape, src_lengths.shape))
    return src_tokens, src_lengths

In [12]:
def generate_sequence(decoder, encoder_out, tgt_tokens=None, min_decode_step=10, max_decode_step=60, pad_id=1, eos_id=2, verbose=True):
    init_input = torch.tensor([[2, 0]] * src_tokens.shape[0], dtype=torch.long).cuda()
    softmax = nn.Softmax(dim=1)
    token_probs, tokens = [], []

    for step in range(max_decode_step):
        decoder_outputs = decoder(init_input, encoder_out, features_only=False)
        logits = decoder_outputs[0][:, -1, :]  # [batch_size, vocab]
        
        if step + 1 < min_decode_step:
            logits[:, eos_id] = -math.inf
        logits[:, pad_id], logits[:, 0] = -math.inf, -math.inf  # never select pad, start token

        probs = softmax(logits)
        assert logits.shape == probs.shape
        attn = decoder_outputs[1]['attn'][0]  # [batch_size, prev_token_len, src_token_len]
        assert logits.dim() == 2 and attn.dim() == 3

        if tgt_tokens is not None:
            selected_token = tgt_tokens[step].unsqueeze(0)
        else:
            value, indices = torch.topk(probs, 5, dim=1)
            selected_token = indices[:, 0]

#             if step == 1:
#                 selected_token = indices[:, 0]
#             elif step == 10:
#                 selected_token = indices[:, 0]
#             elif step == 12:
#                 selected_token = indices[:, 0]

        init_input = torch.cat([init_input, selected_token.unsqueeze(1)], dim=-1)
        token, prob = decode_func(selected_token), probs.squeeze()[selected_token.item()].item()
        token_probs.append(prob)
        tokens.append(token)

        if selected_token.item() == eos_id:
            break
        elif verbose:
            print("- {:02d}: {} ({:.2f})".format(step, token, prob), end='\n')

    return init_input, tokens, token_probs

In [64]:
def tokenize_with_mask(input_sentence):
    bpe_code = bart.bpe.encode(input_sentence)  # <mask>: 1279 27932 29
    input_ids = bart.task.source_dictionary.encode_line('<s> ' + bpe_code.replace('1279 27932 29', '<mask>'), 
                                                        append_eos=True).long()
    input_ids = input_ids.unsqueeze(0).cuda()
    src_lengths = torch.sum(input_ids != 1, dim=1)
    return input_ids, src_lengths

In [160]:
'For $235, #### will ship 6'.replace('will', '##')

'For $235, #### ## ship 6'

In [153]:
bart.bpe.encode('For $235, #### will ship 6')

'1890 720 22370 11 1303 21017 481 4074 718'

In [159]:
bart.bpe.encode('For $###. he will ship 6')

'1890 720 21017 13 339 481 4074 718'

In [155]:
bart.bpe.encode('#### , he will ship 6')

'4242 837 339 481 4074 718'

In [156]:
bart.bpe.encode('####, he will ship 6')

'4242 11 339 481 4074 718'

In [158]:
bart.bpe.decode('1303')

' #'

In [103]:
input_ids = bart.task.source_dictionary.encode_line('27 27932 29 8276 50026 7372 22559 23780 418 2539 468 4488 257 2775 1279 27932 22330 543 481 766 683 3520 379 10578 6839 10499 1566 262 3931 286 13130 13',
                                                    append_eos=True).long()

In [104]:
decode_func(input_ids)

'<mask> somebody Ulster centre Stuart McCloskey has signed a contract <mask>, which will see him remain at Kingspan Stadium until the summer of 2019.'

#### Get Conditional Probability

In [14]:
import spacy

nlp = spacy.load('en')

In [15]:
INDEX = 4

In [16]:
# tokenize
src_tokens, src_lengths = tokenize(xsum_source[INDEX])

# encode & decode
encoder_out = finetuned_bart.model.encoder(src_tokens, src_lengths=src_lengths)
output_ids, tokens, token_probs = generate_sequence(bart_decoder, encoder_out, verbose=False)

print(output_ids.shape)

hypothesis = decode_func(output_ids[0])
print(hypothesis + '\n')

ent_parts = ' '.join([e.text for e in nlp(hypothesis).ents]).split()
ent_parts = [e.text for e in nlp(hypothesis).ents]
for entity in ent_parts:
    masked_summary = hypothesis.replace(entity, '<mask>')

    masked_input, masked_lengths = tokenize_with_mask(masked_summary)
    masked_outputs = generate_sequence(bart.model.decoder,
                                       bart.model.encoder(masked_input,
                                                          src_lengths=masked_lengths),
                                       tgt_tokens=output_ids[0][2:],
                                       verbose=False)
    masked_output_ids, masked_tokens, masked_token_probs = masked_outputs
    assert decode_func(masked_output_ids[0]) == hypothesis

    posterior = get_probability(entity, tokens, token_probs)
    prior = get_probability(entity, masked_tokens, masked_token_probs)
    
    print('- entity: {}'.format(entity))
    print('- prior: {}'.format(prior))
    print('- posterior: {}'.format(posterior))
    print('- ratio: {:.3f} / {:.3f} = {:.3f}'.format(posterior, prior, posterior / (prior + 1e-5)))
    print()

torch.Size([1, 23])
Ireland and Ulster back Stuart McCloskey has signed a new three-year contract with the province.

- entity: Ireland
- prior: 2.6226043701171875e-06
- posterior: 0.423583984375
- ratio: 0.424 / 0.000 = 33557.574

- entity: Ulster
- prior: 0.1962890625
- posterior: 0.88818359375
- ratio: 0.888 / 0.196 = 4.525

- entity: Stuart McCloskey
- prior: 0.00516355092622689
- posterior: 0.32229159308417366
- ratio: 0.322 / 0.005 = 62.296

- entity: three-year
- prior: 0.1528004475403577
- posterior: 0.17944228858686984
- ratio: 0.179 / 0.153 = 1.174



In [17]:
xsum_target[INDEX]

'Ulster centre Stuart McCloskey has signed a contract extension which will see him remain at Kingspan Stadium until the summer of 2019.'

In [18]:
src_tokens.shape

torch.Size([1, 271])

In [19]:
encoder_out[0].shape

torch.Size([271, 1, 1024])