In [1]:
import math
import torch
from fairseq.models.bart import BARTModel

In [4]:
bart = BARTModel.from_pretrained('/home/ml/cadencao/Downloads/BART_models/bart.large', checkpoint_file='model.pt')

1042301B [00:00, 2043724.62B/s]
456318B [00:00, 1245308.90B/s]


In [5]:
bart.cuda()
bart.eval()
print('- activate evaluation mode')

- activate evaluation mode


In [6]:
encode_func = bart.encode
decode_func = bart.decode

In [7]:
bart_encoder = bart.model.encoder
bart_decoder = bart.model.decoder
print(type(bart.model))
print(type(bart_encoder))
print(type(bart_decoder))

<class 'fairseq.models.bart.model.BARTModel'>
<class 'fairseq.models.transformer.TransformerEncoder'>
<class 'fairseq.models.transformer.TransformerDecoder'>


In [8]:
test_inputs = ["Before she saw face his, Liz heard Allen's voice and began to cry.",
               "Before she saw saw his face, Liz heard Allen's voice and began to cry."]
test_optuputs = ["Before she saw his face, Liz heard Allen's voice and began to cry.",
                 "Before she saw his face, Liz heard Allen's voice and began to cry."]

#### Tokenization

In [9]:
from fairseq.data.data_utils import collate_tokens

In [10]:
input_tokens = [encode_func(s) for s in test_inputs]
output_tokens = [encode_func(s) for s in test_optuputs]

In [11]:
input_batch = collate_tokens(input_tokens, pad_idx=1)
output_batch = collate_tokens(output_tokens, pad_idx=1)

In [12]:
print(input_batch.shape)
print(output_batch.shape)

torch.Size([2, 19])
torch.Size([2, 18])


In [13]:
src_lengths = torch.tensor([17, 20], dtype=torch.long)
encoder_outputs = bart_encoder(input_batch.cuda(), src_lengths.cuda())

In [14]:
encoder_outputs.encoder_out.shape

torch.Size([19, 2, 1024])

In [15]:
# prev_output_tokens: [batch, tgt_len], encoder_out: all_outputs

In [16]:
init_input = torch.tensor([[2, 0], [2, 0]], dtype=torch.long).cuda()
max_decode_step = 20
pad_id = 1

for step in range(max_decode_step):
    decoder_outputs = bart_decoder(init_input, encoder_outputs)
    logits = decoder_outputs[0]  # [batch_size, seq_len, vocab]
    logits[:, -1, pad_id] = -math.inf  # never select pad
    
    _, sample = torch.topk(logits, 1, dim=-1)
    sample = sample.squeeze()
    if sample.ndim > 1:
        pred_tokens = sample[:, -1:]

    init_input = torch.cat([init_input, pred_tokens], dim=-1)

In [17]:
decode_func(init_input[0])

["Before she saw face his, Liz heard Allen's voice and began to cry.",
 '',
 '',
 '']

In [18]:
test = torch.tensor([    2,   347,  4405, 14725,     6, 18585,  1745,    35,  3742,    16,
             5,   144,  2247,  3944,    11,     5,   232,     6,    53,    24,
            18,    23,   810,   479, 14725,     6,  1745,    35,   318,     5,
         13678,   817,     5,  1593,   568,     6,    24,    74,  3549,  1161,
         18755,   479,   252,   224,  1769,  7993,    74,  8439,     5,   490,
          3742,     6,   712,  6886,     8,   850,   479, 14725,     6,  1745,
            35,   660,   490,  3742, 15885,   776,   434,     6,   592,     8,
           776, 11525,   479,     1], dtype=torch.long)

In [19]:
decode_func(test[:-3])

"Cory Booker, Angus King: Internet is the most powerful tool in the world, but it's at risk . Booker, King: If the FCC makes the wrong decision, it would kill net neutrality . They say fast lanes would destroy the open Internet, increase discrimination and prices . Booker, King: An open Internet promotes economic growth, social and economic"

In [20]:
encode_func("Before she saw face his. </s> Liz heard Allen's voice and began to cry.")

tensor([    0, 17206,    79,   794,   652,    39,     4, 49703,    29, 15698,
        14359,  1317,  3823,    18,  2236,     8,   880,     7,  8930,     4,
            2])

In [21]:
decode_func(torch.tensor([    0, 17206,    79,   794,   652,    39,     4, 14359,  1317,  3823,
           18,  2236,     8,   880,     7,  8930,     4,     2], dtype=torch.long))

"Before she saw face his. Liz heard Allen's voice and began to cry."

In [27]:
decode_func(torch.tensor([487,  5649,  1910], dtype=torch.long))

'Nemanja'

In [26]:
"Nemanja Vidic and Rio Ferdinand have both had poor starts to the season . Vidic joined Inter in a three-year deal this summer from Manchester United . Ferdinand moved to QPR after his Old Trafford career . Pair have conceded 18 league goals in 10 matches since leaving United . Vidic was sent off on his Inter debut against Torino . 32-year-old has given away two penalties in four Serie A matches . Ferdinand was thoroughly outplayed by Stoke's Peter Crouch this month ."

"Nemanja Vidic and Rio Ferdinand have both had poor starts to the season . Vidic joined Inter in a three-year deal this summer from Manchester United . Ferdinand moved to QPR after his Old Trafford career . Pair have conceded 18 league goals in 10 matches since leaving United . Vidic was sent off on his Inter debut against Torino . 32-year-old has given away two penalties in four Serie A matches . Ferdinand was thoroughly outplayed by Stoke's Peter Crouch this month ."

In [None]:
# prev_output_tokens
tensor([[    2,   487,  5649,  1910, 23551,   636,     8,  5716, 28855,    33,
           258,    56,  2129,  2012,     7,     5,   191,   479, 23551,   636,
          1770,  3870,    11,    10,   130,    12,   180,   432,    42,  1035,
            31,  2361,   315,   479, 28855,  1410,     7,  1209,  4454,    71,
            39,  3470, 13592,   756,   479, 34587,    33, 11508,   504,  1267,
          1175,    11,   158,  2856,   187,  1618,   315,   479, 23551,   636,
            21,  1051,   160,    15,    39,  3870,  2453,   136,  6623,  1696,
           479,  2107,    12,   180,    12,   279,    34,   576,   409,    80,
          6736,    11,   237, 14158,    83,  2856,   479, 28855,    21, 12826,
            66, 13089,    30, 15607,    18,  2155,   230, 35378,    42,   353,
           479],
        [    2, 12815,  5892,  2971,   161,    70,  2891,   227, 18752,     8,
          2924, 15136,     8,  2203,   197,    28,  4638,   479,  5095, 33018,
          1478,   197,    33,  2998,   984,  1749,     6,   151,   458, 25176,
            11,   491,  1913,    77,    37,  2867, 18123,   852,     6,  5736,
           384,   108, 21489,   161,   479,   427, 33018,  1478,   956,    55,
          1675,   899,     7,  2339,    15,  4952,    71,  5195,  4840,  6197,
           479,  5653,  7567,   197,    33,    57,   699,    59,    99,    39,
           780,  4988,    21,   608,     6,    26,  5736,   384,   108, 21489,
           479,   427,  5628,    34,  1433,  2641,  3770,  1220,  1235,     7,
           120,   128, 15605,   593,   108,     7, 15136,     8, 32130,   994,
           479]], device='cuda:0')

In [None]:
# target
tensor([[  487,  5649,  1910, 23551,   636,     8,  5716, 28855,    33,   258,
            56,  2129,  2012,     7,     5,   191,   479, 23551,   636,  1770,
          3870,    11,    10,   130,    12,   180,   432,    42,  1035,    31,
          2361,   315,   479, 28855,  1410,     7,  1209,  4454,    71,    39,
          3470, 13592,   756,   479, 34587,    33, 11508,   504,  1267,  1175,
            11,   158,  2856,   187,  1618,   315,   479, 23551,   636,    21,
          1051,   160,    15,    39,  3870,  2453,   136,  6623,  1696,   479,
          2107,    12,   180,    12,   279,    34,   576,   409,    80,  6736,
            11,   237, 14158,    83,  2856,   479, 28855,    21, 12826,    66,
         13089,    30, 15607,    18,  2155,   230, 35378,    42,   353,   479,
             2],
        [12815,  5892,  2971,   161,    70,  2891,   227, 18752,     8,  2924,
         15136,     8,  2203,   197,    28,  4638,   479,  5095, 33018,  1478,
           197,    33,  2998,   984,  1749,     6,   151,   458, 25176,    11,
           491,  1913,    77,    37,  2867, 18123,   852,     6,  5736,   384,
           108, 21489,   161,   479,   427, 33018,  1478,   956,    55,  1675,
           899,     7,  2339,    15,  4952,    71,  5195,  4840,  6197,   479,
          5653,  7567,   197,    33,    57,   699,    59,    99,    39,   780,
          4988,    21,   608,     6,    26,  5736,   384,   108, 21489,   479,
           427,  5628,    34,  1433,  2641,  3770,  1220,  1235,     7,   120,
           128, 15605,   593,   108,     7, 15136,     8, 32130,   994,   479,
             2]], device='cuda:0')

In [None]:
# src_lengths
tensor([1024, 1024], device='cuda:0')

In [None]:
# src_tokens
tensor([[  487,  5649,  1910,  ..., 35444, 20921,     2],
        [ 2765,   479,  3005,  ...,    12,  5532,     2]], device='cuda:0')