In [1]:
import math
import torch
import torch.nn as nn 
from fairseq.models.bart import BARTModel

In [2]:
# Load the BART checkpoint from a local directory. The same directory is given
# both as the model location and as data_name_or_path.
bart_dir = '/home/ml/cadencao/Downloads/BART_models/bart.large.cnn'
bart = BARTModel.from_pretrained(
    bart_dir,
    checkpoint_file='model.pt',
    data_name_or_path=bart_dir,
)

In [3]:
# Same three calls, in the same order: move to GPU, switch to eval mode, cast to half precision.
for prepare in (bart.cuda, bart.eval, bart.half):
    prepare()
print('- activate evaluation mode')

- activate evaluation mode


In [4]:
# Keep direct handles on the two halves of the seq2seq model so they can be
# driven separately in the cells below.
bart_encoder = bart.model.encoder
bart_decoder = bart.model.decoder
for component in (bart.model, bart_encoder):
    print(type(component))

<class 'fairseq.models.bart.model.BARTModel'>
<class 'fairseq.models.transformer.TransformerEncoder'>


#### Read Data

In [5]:
from utils import read_lines

In [6]:
# Test split: `source` holds the documents, `target` the paired summaries
# (the two files must have the same number of entries).
source_path = '/home/ml/cadencao/Two-Steps-Summarization/datasets/cnn_dm/fairseq_files/test.source'
target_path = '/home/ml/cadencao/Two-Steps-Summarization/datasets/cnn_dm/fairseq_files/test.target'
source, target = read_lines(source_path), read_lines(target_path)
print(len(source))
assert len(target) == len(source)

11490


#### Tokenization

In [7]:
# Work on the first three documents only.
sentences = [source[i] for i in range(3)]
# NOTE: `input` shadows the Python builtin; the name is kept because a later
# cell reads `len(input)`.
input = list(map(bart.encode, sentences))
# Collate the encoded documents into one batch.
sample = bart._build_sample(input)

In [8]:
# Show the batch layout first, then the tensors that are fed to the model.
for view in (sample.keys(), sample['net_input']):
    print(view)

dict_keys(['id', 'nsentences', 'ntokens', 'net_input', 'target'])
{'src_tokens': tensor([[    0, 10169,  1090,  ...,  2951,   212,     2],
        [    1,     1,     1,  ...,   777,     4,     2],
        [    1,     1,     1,  ...,   266,     4,     2]], device='cuda:0'), 'src_lengths': tensor([1024,  952,  694], device='cuda:0')}


#### Encoding

In [9]:
net_input = sample['net_input']
src_tokens, src_lengths = net_input['src_tokens'], net_input['src_lengths']

# Inference only: without no_grad() the encoder pass records an autograd graph
# that is never used and only costs GPU memory.
with torch.no_grad():
    encoder_out = bart_encoder(src_tokens, src_lengths=src_lengths)
print(encoder_out.encoder_out.shape)  # [seq_len, batch_size, hidden_size]

torch.Size([1024, 3, 1024])


In [10]:
# One BOS id per encoded sentence, as a [len(input), 1] tensor created from src_tokens.
sample['net_input']['src_tokens'].new_full((len(input), 1), bart.task.source_dictionary.bos())

tensor([[0],
        [0],
        [0]], device='cuda:0')

#### Decoding

In [11]:
src_tokens.size()  # [batch_size, max_seq_length]

torch.Size([3, 1024])

In [12]:
bsz = src_tokens.shape[0]

# Length limits for generation, counted in decoding steps.
min_decode_step, max_decode_step = 55, 140

# Special-token ids are read from the model's dictionary rather than typed in
# by hand (for this checkpoint they are pad=1, eos=2, bos=0).
src_dict = bart.task.source_dictionary
pad_id, eos_id, bos_id = src_dict.pad(), src_dict.eos(), src_dict.bos()

softmax = nn.Softmax(dim=1)  # normalises along dim 1

In [13]:
# Output buffer: one row per sentence, pre-filled with padding. The two extra
# columns hold the EOS + BOS prefix that the decoder is primed with.
tokens = src_tokens.new_full((bsz, max_decode_step + 2), pad_id, dtype=torch.long)
tokens[:, 0] = eos_id
tokens[:, 1] = bos_id

# Per-row "already finished" markers, allocated directly on the tokens' device
# instead of building a Python list on the host and copying it over.
eos_flags = torch.zeros(bsz, dtype=torch.bool, device=tokens.device)

In [14]:
# Sanity check of the decoding buffers: flags are [batch_size], tokens are [batch_size, max_decode_step + 2].
for buffer in (eos_flags, tokens):
    print(buffer.shape)

torch.Size([3])
torch.Size([3, 142])


In [16]:
# Greedy decoding: at every step feed the prefix generated so far to the
# decoder and append the highest-scoring token to each row of `tokens`.
token_probs = []  # reserved for per-step probabilities of the chosen tokens; not populated here
last_step = max_decode_step - 1

with torch.no_grad():  # inference only: do not record an autograd graph per step
    for step in range(max_decode_step):
        # The whole prefix is re-decoded each step; only the newest position's
        # scores are needed.
        decoder_outputs = bart_decoder(tokens[:, :step + 2], encoder_out, features_only=False)
        logits = decoder_outputs[0][:, -1, :]  # [batch_size, vocab]

        # Never select the pad or start token.
        logits[:, pad_id] = -math.inf
        logits[:, bos_id] = -math.inf
        if step < min_decode_step:
            # Too early to stop: forbid EOS.
            logits[:, eos_id] = -math.inf
        elif step == last_step:
            # Length limit reached: force EOS so every row is terminated.
            logits[:, :eos_id] = -math.inf
            logits[:, eos_id + 1:] = -math.inf

        # Top-5 candidates are kept for inspection; only the best one is used.
        value, indices = torch.topk(logits, 5, dim=1)  # value, indices: [batch_size, top_k]

        # Rows that already emitted EOS only receive padding from now on.
        selected_token = indices[:, 0].masked_fill(eos_flags, pad_id)
        eos_flags = eos_flags | (selected_token == eos_id)

        tokens[:, step + 2] = selected_token

        # Stop once every row is finished; the rest of the buffer is padding.
        if eos_flags.all():
            tokens[:, step + 3:] = pad_id
            break

In [17]:
# Shape first, then the generated ids (prefix + summary + padding per row).
for view in (tokens.shape, tokens):
    print(view)

torch.Size([3, 142])
tensor([[    2,     0, 28586,  5644,   161,    37,    18,    45,  2542,     9,
           143,   569,  4338,    31,    15,   792,     5,  3286,     4,  1859,
           433,   690,    14,    51,    33,   303,  3551,  1028,   569,     9,
             5,  2058,     4,  1859,  5195,   226,  2951,   212,  1253,   102,
           161,  1029,    12,   642, 20366, 19494, 25500,  4494,    56, 12248,
          6943,     4, 25500,  4494,   174,    39,  2524,  1058,   334,    11,
          2338,    14,    37,    56,    10,    22,  5234, 24963,  3238,     9,
          3814,  6943,    60,   226,  2951,   212,  1253,   102,   161,     4,
             2,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,  

In [18]:
# Position of the first EOS in each row of tokens[:, 1:]. Column 0 is skipped
# because it holds the EOS that primes the decoder. A row with no EOS falls
# back to the last position, so `index` always has exactly one entry per row.
eos_mask = tokens[:, 1:] == eos_id
last_position = eos_mask.size(1) - 1
positions = torch.arange(eos_mask.size(1), device=tokens.device).expand_as(eos_mask)
index = positions.masked_fill(~eos_mask, last_position).min(dim=1)[0]
print(index)

tensor([79, 57, 63], device='cuda:0')


In [19]:
# Row 0 from the BOS column up to and including its EOS
# (index is relative to tokens[:, 1:], hence the +2 on the exclusive end).
print(tokens[0, 1:index[0] + 2])

tensor([    0, 28586,  5644,   161,    37,    18,    45,  2542,     9,   143,
          569,  4338,    31,    15,   792,     5,  3286,     4,  1859,   433,
          690,    14,    51,    33,   303,  3551,  1028,   569,     9,     5,
         2058,     4,  1859,  5195,   226,  2951,   212,  1253,   102,   161,
         1029,    12,   642, 20366, 19494, 25500,  4494,    56, 12248,  6943,
            4, 25500,  4494,   174,    39,  2524,  1058,   334,    11,  2338,
           14,    37,    56,    10,    22,  5234, 24963,  3238,     9,  3814,
         6943,    60,   226,  2951,   212,  1253,   102,   161,     4,     2],
       device='cuda:0')


In [20]:
# Turn the ids of row 0 (BOS through EOS) back into text.
bart.decode(tokens[0, 1:index[0] + 2])

'French prosecutor says he\'s not aware of any video footage from on board the plane. German media reports that they have found cell phone video of the crash. German airline Lufthansa says co-pilot Andreas Lubitz had battled depression. Lubitz told his flight training school in 2009 that he had a "previous episode of severe depression," Lufthansa says.'

In [21]:
# NOTE(review): hard-coded summary text, presumably pasted here for side-by-side
# comparison with the greedy output above; where it came from is not shown in
# this notebook -- confirm and load it programmatically if possible.
'A French prosecutor says he is not aware of any video footage from on board the plane. German daily Bild and Paris Match claim to have found a cell phone video of the crash. A French Gendarmerie spokesman calls the reports "completely wrong" and "unwarranted" German airline Lufthansa says co-pilot Andreas Lubitz battled depression years before he took the controls.'

'A French prosecutor says he is not aware of any video footage from on board the plane. German daily Bild and Paris Match claim to have found a cell phone video of the crash. A French Gendarmerie spokesman calls the reports "completely wrong" and "unwarranted" German airline Lufthansa says co-pilot Andreas Lubitz battled depression years before he took the controls.'