In [1]:
import math
import torch
import string
import torch.nn as nn

from fairseq.models.bart import BARTModel
from utils import read_lines, get_probability

In [2]:
bart = BARTModel.from_pretrained('/home/ml/cadencao/Downloads/BART_models/bart.large',
                                 checkpoint_file='model.pt',
                                 data_name_or_path='/home/ml/cadencao/Downloads/BART_models/bart.large')

In [3]:
bart.cuda()
bart.eval()
bart.half()
print('- activate evaluation mode')

- activate evaluation mode


In [4]:
encode_func = bart.encode
decode_func = bart.decode

In [5]:
bart_encoder = bart.model.encoder
bart_decoder = bart.model.decoder
print(type(bart.model))
print(type(bart_encoder))
print(type(bart_decoder))

<class 'fairseq.models.bart.model.BARTModel'>
<class 'fairseq.models.transformer.TransformerEncoder'>
<class 'fairseq.models.transformer.TransformerDecoder'>


#### Read XSum

In [6]:
document_path = '/home/ml/cadencao/XSum/fairseq_files/test.source'
target_path = '/home/ml/cadencao/XSum/fairseq_files/test.target'
xsum_source = read_lines(document_path)
xsum_target = read_lines(target_path)
print(len(xsum_source))
assert len(xsum_source) == len(xsum_target)

11301


#### Tokenization

In [7]:
import spacy

from fairseq.data.data_utils import collate_tokens

In [8]:
nlp = spacy.load('en')

#### Build a Masked Input

In [9]:
input_sentence = "A biscuit maker has gone into administration, with the loss of <mask> jobs, after the UK voted to leave the European Union."
pred_sentence  = "A biscuit maker has gone into administration, with the loss of more than 100 jobs, after the UK voted to leave the European Union."

In [10]:
bpe_code = bart.bpe.encode(input_sentence)  # 1279 27932 29
input_ids = bart.task.source_dictionary.encode_line('<s> ' + bpe_code.replace('1279 27932 29', '<mask>'), append_eos=True).long()
print(input_ids)

tensor([    0,   250, 39315,  6439,  4403,    34,  1613,    88,   942,     6,
           19,     5,   872,     9, 50264,  1315,     6,    71,     5,   987,
         2763,     7,   989,     5,   796,  1332,     4,     2])


In [11]:
decode_func(input_ids.long())

'A biscuit maker has gone into administration, with the loss of<mask> jobs, after the UK voted to leave the European Union.'

In [12]:
INDEX = 6824
# test_inputs = [xsum_source[INDEX]]
# src_tokens = collate_tokens([encode_func(i) for i in test_inputs], pad_idx=1, left_pad=True)
tgt_tokens = encode_func(pred_sentence).cuda()
src_tokens = input_ids.unsqueeze(0).cuda()
print(tgt_tokens.shape)
print(src_tokens.shape)

src_lengths = torch.sum(src_tokens != 1, dim=1)
print(src_lengths)

torch.Size([30])
torch.Size([1, 28])
tensor([28], device='cuda:0')


In [13]:
# encoding

In [14]:
encoder_out = bart_encoder(src_tokens, src_lengths=src_lengths)
print(encoder_out.encoder_out.shape)  # [seq_len, batch_size, hidden_size]

torch.Size([28, 1, 1024])


In [15]:
# decoding

In [16]:
tgt_tokens

tensor([    0,   250, 39315,  6439,  4403,    34,  1613,    88,   942,     6,
           19,     5,   872,     9,    55,    87,   727,  1315,     6,    71,
            5,   987,  2763,     7,   989,     5,   796,  1332,     4,     2],
       device='cuda:0')

In [17]:
init_input = torch.tensor([[2, 0]] * src_tokens.shape[0], dtype=torch.long).cuda()
min_decode_step, max_decode_step, pad_id, eos_id = 10, 60, 1, 2
softmax = nn.Softmax(dim=1)
token_probs, tokens = [], []

for step in range(max_decode_step):
    decoder_outputs = bart_decoder(init_input, encoder_out, features_only=False)
    logits = decoder_outputs[0][:, -1, :]  # [batch_size, vocab]
    if step + 1 < min_decode_step:
        logits[:, eos_id] = -math.inf

    logits[:, pad_id], logits[:, 0] = -math.inf, -math.inf  # never select pad, start token
    probs = softmax(logits)
    assert logits.shape == probs.shape

    value, indices = torch.topk(probs, 5, dim=1)
    selected_token = indices[:, 0]
    # selected_token = tgt_tokens[1:][step].unsqueeze(0)

    init_input = torch.cat([init_input, selected_token.unsqueeze(1)], dim=-1)
    token, prob = decode_func(selected_token), probs.squeeze()[selected_token.item()].item()
    token_probs.append(prob)
    tokens.append(token)

    if selected_token.item() == eos_id:
        break
    else:
        print("- {:02d}: {} ({:.5f})".format(step, token, prob), end='\n')

assert len(tokens) == len(token_probs)
token_probs_sorted = [i for i in token_probs]
token_probs_sorted.sort()
print('- Lowest Probabilities: [{}, {}, {}]'.format(token_probs.index(token_probs_sorted[0]),
                                                    token_probs.index(token_probs_sorted[1]),
                                                    token_probs.index(token_probs_sorted[2])))

- 00: A (0.96094)
- 01:  bisc (1.00000)
- 02: uit (1.00000)
- 03:  maker (1.00000)
- 04:  has (0.99902)
- 05:  gone (1.00000)
- 06:  into (1.00000)
- 07:  administration (1.00000)
- 08: , (1.00000)
- 09:  with (1.00000)
- 10:  the (1.00000)
- 11:  loss (1.00000)
- 12:  of (1.00000)
- 13:  more (0.15210)
- 14:  than (0.99902)
- 15:  100 (0.12085)
- 16:  jobs (0.79443)
- 17: , (0.96582)
- 18:  after (0.99609)
- 19:  the (1.00000)
- 20:  UK (1.00000)
- 21:  voted (1.00000)
- 22:  to (1.00000)
- 23:  leave (1.00000)
- 24:  the (1.00000)
- 25:  European (1.00000)
- 26:  Union (1.00000)
- 27: . (1.00000)
- Lowest Probabilities: [15, 13, 16]


In [18]:
get_probability('Ryan Chase', tokens, token_probs)

Target (Ryan Chase) not found!!!


-1.0

In [19]:
hypothesis = decode_func(init_input[0])
print(hypothesis)

A biscuit maker has gone into administration, with the loss of more than 100 jobs, after the UK voted to leave the European Union.


In [20]:
input_sentence

'A biscuit maker has gone into administration, with the loss of <mask> jobs, after the UK voted to leave the European Union.'