In [1]:
import math
import torch
from fairseq.models.bart import BARTModel

In [2]:
bart = BARTModel.from_pretrained('./bart.large', checkpoint_file='model.pt')

In [3]:
bart.cuda()
bart.eval()
print('- activate evaluation mode')

- activate evaluation mode


In [4]:
encode_func = bart.encode
decode_func = bart.decode

In [5]:
bart_encoder = bart.model.encoder
bart_decoder = bart.model.decoder
print(type(bart.model))
print(type(bart_encoder))
print(type(bart_decoder))

<class 'fairseq.models.bart.model.BARTModel'>
<class 'fairseq.models.transformer.TransformerEncoder'>
<class 'fairseq.models.transformer.TransformerDecoder'>


In [6]:
test_inputs = ["Before she saw face his, Liz heard Allen's voice and began to cry.",
               "Before she saw saw his face, Liz heard Allen's voice and began to cry."]
test_optuputs = ["Before she saw his face, Liz heard Allen's voice and began to cry.",
                 "Before she saw his face, Liz heard Allen's voice and began to cry."]

#### Tokenization

In [7]:
from fairseq.data.data_utils import collate_tokens

In [8]:
input_tokens = [encode_func(s) for s in test_inputs]
output_tokens = [encode_func(s) for s in test_optuputs]

In [9]:
input_batch = collate_tokens(input_tokens, pad_idx=1)
output_batch = collate_tokens(output_tokens, pad_idx=1)

In [10]:
print(input_batch.shape)
print(output_batch.shape)

torch.Size([2, 19])
torch.Size([2, 18])


In [11]:
src_lengths = torch.tensor([17, 20], dtype=torch.long)
encoder_outputs = bart_encoder(input_batch.cuda(), src_lengths.cuda())

In [12]:
encoder_outputs.encoder_out.shape

torch.Size([19, 2, 1024])

In [13]:
# prev_output_tokens: [batch, tgt_len], encoder_out: all_outputs

In [14]:
init_input = torch.tensor([[2, 0], [2, 0]], dtype=torch.long).cuda()
max_decode_step = 20
pad_id = 1

for step in range(max_decode_step):
    decoder_outputs = bart_decoder(init_input, encoder_outputs)
    logits = decoder_outputs[0]  # [batch_size, seq_len, vocab]
    logits[:, -1, pad_id] = -math.inf  # never select pad
    
    _, sample = torch.topk(logits, 1, dim=-1)
    sample = sample.squeeze()
    if sample.ndim > 1:
        pred_tokens = sample[:, -1:]

    init_input = torch.cat([init_input, pred_tokens], dim=-1)

In [15]:
decode_func(init_input[0])

["Before she saw face his, Liz heard Allen's voice and began to cry.",
 '',
 '',
 '']

In [16]:
test = torch.tensor([    2,   347,  4405, 14725,     6, 18585,  1745,    35,  3742,    16,
             5,   144,  2247,  3944,    11,     5,   232,     6,    53,    24,
            18,    23,   810,   479, 14725,     6,  1745,    35,   318,     5,
         13678,   817,     5,  1593,   568,     6,    24,    74,  3549,  1161,
         18755,   479,   252,   224,  1769,  7993,    74,  8439,     5,   490,
          3742,     6,   712,  6886,     8,   850,   479, 14725,     6,  1745,
            35,   660,   490,  3742, 15885,   776,   434,     6,   592,     8,
           776, 11525,   479,     1], dtype=torch.long)

In [17]:
decode_func(test[:-3])

"Cory Booker, Angus King: Internet is the most powerful tool in the world, but it's at risk . Booker, King: If the FCC makes the wrong decision, it would kill net neutrality . They say fast lanes would destroy the open Internet, increase discrimination and prices . Booker, King: An open Internet promotes economic growth, social and economic"

In [47]:
encode_func("Before she saw face his. </s> Liz heard Allen's voice and began to cry.")

tensor([    0, 17206,    79,   794,   652,    39,     4, 14359,  1317,  3823,
           18,  2236,     8,   880,     7,  8930,     4,     2])

In [48]:
decode_func(torch.tensor([    0, 17206,    79,   794,   652,    39,     4, 14359,  1317,  3823,
           18,  2236,     8,   880,     7,  8930,     4,     2], dtype=torch.long))

"Before she saw face his. Liz heard Allen's voice and began to cry."

In [49]:
decode_func(torch.tensor([    0, 17206,    79,   794,   652,    39,     4, 49703,    29, 15698,
        14359,  1317,  3823,    18,  2236,     8,   880,     7,  8930,     4,
            2], dtype=torch.long))

"Before she saw face his. </s> Liz heard Allen's voice and began to cry."

In [51]:
import os
os.path.splitext("data/fairseq/text.jsonl")

('data/fairseq/text', '.jsonl')