In [1]:
import math
import torch
import string
import torch.nn as nn

from fairseq.models.bart import BARTModel
from utils import read_lines, get_probability

In [2]:
# Load the fine-tuned BART checkpoint. Both locations below are absolute,
# machine-specific paths and need adjusting on any other machine.
bart_checkpoint_dir = '/home/ml/cadencao/Downloads/BART_models/checkpoints_xsum_clean/'
bart_data_dir = '/home/ml/cadencao/Two-Steps-Summarization/datasets/XSum/cleaned_files/cnn_dm-bin'
bart = BARTModel.from_pretrained(
    bart_checkpoint_dir,
    checkpoint_file='checkpoint5.pt',
    data_name_or_path=bart_data_dir,
)

In [3]:
# Move the model to the GPU, switch to evaluation mode, then cast to half
# precision -- applied in that order.
for prepare in (bart.cuda, bart.eval, bart.half):
    prepare()
print('- activate evaluation mode')

- activate evaluation mode


In [4]:
# Short aliases for the model's text -> ids and ids -> text helpers.
encode_func, decode_func = bart.encode, bart.decode

In [5]:
# Keep direct handles on the encoder and decoder so they can be called separately,
# and show the concrete class of each component.
bart_encoder, bart_decoder = bart.model.encoder, bart.model.decoder
for component in (bart.model, bart_encoder, bart_decoder):
    print(type(component))

<class 'fairseq.models.bart.model.BARTModel'>
<class 'fairseq.models.transformer.TransformerEncoder'>
<class 'fairseq.models.transformer.TransformerDecoder'>


#### Read XSum

In [6]:
# Source documents and reference summaries are read from two parallel files
# (absolute, machine-specific paths); both must have the same number of entries.
document_path = '/home/ml/cadencao/XSum/fairseq_files/test.source'
target_path = '/home/ml/cadencao/XSum/fairseq_files/test.target'
xsum_source, xsum_target = read_lines(document_path), read_lines(target_path)
num_documents = len(xsum_source)
print(num_documents)
assert num_documents == len(xsum_target)

11301


#### Tokenization

In [7]:
import spacy
import rouge

from fairseq.data.data_utils import collate_tokens

In [8]:
def prepare_results(p, r, f, metric_name=None):
    """Format precision/recall/F1 for one metric as a tab-separated line.

    Args:
        p, r, f: precision, recall and F1 as fractions; each is multiplied by
            100 and printed with two decimals.
        metric_name: label printed at the start of the line. When omitted, the
            module-level variable `metric` is used instead, which is what the
            three-argument form has always relied on; in that case `metric`
            must already be bound (e.g. as the caller's loop variable).

    Returns:
        The formatted string.
    """
    # Prefer the explicit label; only fall back to the global for old callers.
    label = metric if metric_name is None else metric_name
    return '\t{}:\t{}: {:5.2f}\t{}: {:5.2f}\t{}: {:5.2f}'.format(label, 'P', 100.0 * p, 'R', 100.0 * r, 'F1', 100.0 * f)

# `aggregator` decides which of the two flags handed to rouge.Rouge is True:
# 'Avg' sets apply_avg, 'Best' sets apply_best, anything else leaves both False.
aggregator = 'Avg'
apply_avg, apply_best = aggregator == 'Avg', aggregator == 'Best'
print('Evaluation with {}'.format(aggregator))

rouge_options = dict(
    metrics=['rouge-n', 'rouge-l', 'rouge-w'],
    max_n=4,
    limit_length=True,
    length_limit=100,
    length_limit_type='words',
    apply_avg=apply_avg,
    apply_best=apply_best,
    alpha=0.5,  # Default F1_score
    weight_factor=1.2,
    stemming=True,
)
evaluator = rouge.Rouge(**rouge_options)

Evaluation with Avg


In [9]:
# spaCy pipeline, used further down to extract named entities from the hypothesis.
# NOTE(review): 'en' is a shortcut name; spaCy 3.x no longer resolves shortcuts and
# expects a full package name such as 'en_core_web_sm' -- confirm the installed version.
nlp = spacy.load('en')

In [10]:
# Pick one test document, encode it and build a left-padded batch on the GPU.
INDEX = 11207
test_inputs = [xsum_source[INDEX]]
encoded_inputs = [encode_func(document) for document in test_inputs]
src_tokens = collate_tokens(encoded_inputs, pad_idx=1, left_pad=True).cuda()
print(src_tokens.shape)

# Per-row count of non-padding positions (padding id is 1).
src_lengths = (src_tokens != 1).sum(dim=1)
print(src_lengths)

torch.Size([1, 325])
tensor([325], device='cuda:0')


In [11]:
# encoding

In [12]:
# Encode the source once; the decoding cell passes this result to the decoder at
# every step. This is inference only, so run it under no_grad to avoid recording
# an autograd graph (eval() alone does not switch gradient tracking off).
with torch.no_grad():
    encoder_out = bart_encoder(src_tokens, src_lengths=src_lengths)
print(encoder_out.encoder_out.shape)  # [seq_len, batch_size, hidden_size]

torch.Size([325, 1, 1024])


In [13]:
# decoding

In [14]:
# Greedy decoding: feed the growing prefix to the decoder, pick the most probable
# next token, and record every chosen token together with its probability.
# The per-step bookkeeping uses `.item()`, so only a single-example batch is supported.
assert src_tokens.shape[0] == 1, 'this decoding cell only supports batch size 1'

init_input = torch.tensor([[2, 0]] * src_tokens.shape[0], dtype=torch.long).cuda()
min_decode_step, max_decode_step, pad_id, eos_id = 10, 60, 1, 2
softmax = nn.Softmax(dim=1)
token_probs, tokens = [], []

# Inference only: without no_grad every step would record an autograd graph.
with torch.no_grad():
    for step in range(max_decode_step):
        decoder_outputs = bart_decoder(init_input, encoder_out, features_only=False)
        logits = decoder_outputs[0][:, -1, :]  # [batch_size, vocab]
        if step + 1 < min_decode_step:
            logits[:, eos_id] = -math.inf  # do not allow EOS before min_decode_step tokens

        logits[:, pad_id], logits[:, 0] = -math.inf, -math.inf  # never select pad, start token
        probs = softmax(logits)
        assert logits.shape == probs.shape
        attn = decoder_outputs[1]['attn'][0]  # [batch_size, prev_token_len, src_token_len]
        assert logits.dim() == 2 and attn.dim() == 3

        # Top-5 candidates, best first. Column 0 is the greedy choice; to explore an
        # alternative at a given step, pick another column here.
        indices = torch.topk(probs, 5, dim=1)[1]
        selected_token = indices[:, 0]

        init_input = torch.cat([init_input, selected_token.unsqueeze(1)], dim=-1)
        token_id = selected_token.item()
        token, prob = decode_func(selected_token), probs[0, token_id].item()
        token_probs.append(prob)
        tokens.append(token)

        # The EOS token and its probability are recorded, but not printed.
        if token_id == eos_id:
            break
        print("- {:02d}: {} ({:.2f})".format(step, token, prob), end='\n')

assert len(tokens) == len(token_probs)
token_probs_sorted = sorted(token_probs)
# Positions of the (up to) three least probable tokens. Ranking positions directly,
# rather than looking sorted values up with list.index(), keeps tied probabilities
# from reporting the same position twice and works with fewer than three tokens.
lowest_positions = sorted(range(len(token_probs)), key=token_probs.__getitem__)[:3]
print('- Lowest Probabilities: {}'.format(lowest_positions))

- 00: ists (0.05)
- 01:  have (0.25)
- 02:  been (0.27)
- 03:  charged (0.22)
- 04:  with (0.78)
- 05:  the (0.22)
- 06:  murder (0.43)
- 07:  of (0.92)
- 08:  a (0.49)
- 09:  US (0.17)
- 10:  man (0.10)
- 11:  in (0.24)
- 12:  the (0.23)
- 13:  US (0.35)
- 14:  state (0.18)
- 15:  of (0.95)
- 16: , (0.09)
- 17:  in (0.12)
- 18:  a (0.34)
- 19:  case (0.13)
- 20:  that (0.42)
- 21:  killed (0.18)
- 22:  people (0.26)
- 23: . (0.35)
- Lowest Probabilities: [0, 16, 10]


In [15]:
# Look up 'Swansea' among the generated tokens and their probabilities.
# NOTE: `tokens` must still be the list filled by the decoding cell; a cell further
# down rebinds `tokens` to a BPE-id string, so re-run the decoding cell first if needed.
get_probability('Swansea', tokens, token_probs)

Target (Swansea) not found!!!


-1.0

In [16]:
# Decoder input after generation: the initial [2, 0] prefix followed by every selected token id.
init_input

tensor([[   2,    0, 1952,   33,   57, 1340,   19,    5, 1900,    9,   10,  382,
          313,   11,    5,  382,  194,    9,    6,   11,   10,  403,   14,  848,
           82,    4,    2]], device='cuda:0')

In [17]:
# Decode the first row of the generated id sequence back into text.
hypothesis = decode_func(init_input[0])
print(hypothesis)

ists have been charged with the murder of a US man in the US state of, in a case that killed people.


In [18]:
# Entities that the spaCy pipeline finds in the generated hypothesis.
nlp(hypothesis).ents

(US, US)

In [19]:
# Reference summary for the same example index.
xsum_target[INDEX]

'The prosecution has rested its case in the trial of a man accused of carrying out the 2013 Boston Marathon bombings.'

In [20]:
# Score the single hypothesis against its reference and print one line per metric,
# in alphabetical order of metric name.
all_hypothesis = [hypothesis]
all_references = [xsum_target[INDEX]]
scores = evaluator.get_scores(all_hypothesis, all_references)

# The loop variable has to be called `metric`: prepare_results reads it as a global.
for metric in sorted(scores):
    results = scores[metric]
    print(prepare_results(results['p'], results['r'], results['f']))

	rouge-1:	P: 36.36	R: 38.10	F1: 37.21
	rouge-2:	P:  9.52	R: 10.00	F1:  9.76
	rouge-3:	P:  0.00	R:  0.00	F1:  0.00
	rouge-4:	P:  0.00	R:  0.00	F1:  0.00
	rouge-l:	P: 29.09	R: 30.24	F1: 29.66
	rouge-w:	P: 19.09	R: 10.88	F1: 13.86


In [42]:
# Probe sentence for inspecting the tokenization; note the space before the final period.
SENTENCE = 'A body found in the ruins of a collapsed building at Didcot Power Station has been identified .'
encode_func(SENTENCE)

tensor([    0,   250,   809,   303,    11,     5, 24757,     9,    10,  7793,
          745,    23,  6553, 22921,  3029,  5088,    34,    57,  2006,   479,
            2])

In [58]:
# Decode every id on its own (skipping the first and last entries) to see the
# individual sub-word pieces.
# `torch_tokens` is assigned by a cell that appears further down (In [36]); on a
# fresh kernel that cell has to run before this one.
for token_id in torch_tokens[1:-1]:
    piece = decode_func(torch.tensor([token_id]))
    print(piece)

A
 body
 found
 in
 the
 ruins
 of
 a
 collapsed
 building
 at
 Did
cot
 Power
 Station
 has
 been
 identified
.


In [53]:
# Encoded with a leading space, ' Didcot' yields the same two ids (6553, 22921)
# that appear for it inside the encoded SENTENCE output.
encode_func(' Didcot')

tensor([    0,  6553, 22921,     2])

In [56]:
# bart.bpe.encode returns BPE ids as a space-separated string; these ids differ from
# the dictionary indices that encode_func produces for the same text.
bart.bpe.encode(' Didcot')

'7731 25557'

In [43]:
# BPE-id string for the probe sentence.
# NOTE: this rebinds `tokens`, which the decoding cell uses for its list of generated
# tokens; re-run the decoding cell before anything that expects that list.
tokens = bart.bpe.encode(SENTENCE)

In [44]:
# Show the BPE-id string.
tokens

'32 1767 1043 287 262 20073 286 257 14707 2615 379 7731 25557 4333 9327 468 587 5174 764'

In [45]:
# Whitespace-separated word count of the sentence, then the number of BPE ids.
for text in (SENTENCE, tokens):
    print(len(text.split()))

18
19


In [24]:
# Wrap the BPE-id string in the begin/end-of-sentence symbols.
bpe_sentence = ''.join(['<s> ', tokens, ' </s>'])

In [25]:
# Show the wrapped BPE-id string.
bpe_sentence

'<s> 32 1767 1043 287 262 20073 286 257 14707 2615 379 7731 25557 4333 9327 468 587 5174 13 </s>'

In [36]:
# Map the wrapped BPE-id string to source-dictionary indices; no extra EOS is
# appended because the string already ends with ' </s>'.
source_dictionary = bart.task.source_dictionary
torch_tokens = source_dictionary.encode_line(bpe_sentence, append_eos=False)

In [37]:
# Display the dictionary indices cast to a 64-bit integer tensor.
torch_tokens.long()

tensor([    0,   250,   809,   303,    11,     5, 24757,     9,    10,  7793,
          745,    23,  6553, 22921,  3029,  5088,    34,    57,  2006,     4,
            2])