In [1]:
import math
import os
import string

import torch
import torch.nn as nn

from fairseq.models.bart import BARTModel
from utils import read_lines, get_probability

In [2]:
# Pretrained BART fine-tuned on XSum. The checkpoint directory and `data_name_or_path`
# point at the same directory, so it is defined once; set the BART_XSUM_DIR environment
# variable to use a different location without editing the notebook.
BART_XSUM_DIR = os.environ.get('BART_XSUM_DIR', '/home/ml/cadencao/Downloads/BART_models/bart.large.xsum')
bart = BARTModel.from_pretrained(BART_XSUM_DIR,
                                 checkpoint_file='model.pt',
                                 data_name_or_path=BART_XSUM_DIR)

In [3]:
# Prepare the hub model for inference, in this order: move to GPU, eval mode, cast to fp16.
for _setup in (bart.cuda, bart.eval, bart.half):
    _setup()
print('- activate evaluation mode')

- activate evaluation mode


In [4]:
# Short aliases for the hub interface's text -> ids and ids -> text helpers.
encode_func, decode_func = bart.encode, bart.decode

In [5]:
# Direct handles on the encoder and decoder, used below for step-by-step decoding.
bart_encoder, bart_decoder = bart.model.encoder, bart.model.decoder
print(*(type(part) for part in (bart.model, bart_encoder, bart_decoder)), sep='\n')

<class 'fairseq.models.bart.model.BARTModel'>
<class 'fairseq.models.transformer.TransformerEncoder'>
<class 'fairseq.models.transformer.TransformerDecoder'>


#### Read XSum

In [6]:
# XSum test split in fairseq's raw-text layout: test.source (articles) and test.target
# (summaries), read line by line. Set XSUM_DIR to point at another copy of the data.
XSUM_DIR = os.environ.get('XSUM_DIR', '/home/ml/cadencao/XSum/fairseq_files')
document_path = os.path.join(XSUM_DIR, 'test.source')
target_path = os.path.join(XSUM_DIR, 'test.target')
xsum_source = read_lines(document_path)
xsum_target = read_lines(target_path)
print(len(xsum_source))
# Articles and summaries are paired by line position, so the counts must match.
assert len(xsum_source) == len(xsum_target), \
    'source/target line counts differ: {} vs {}'.format(len(xsum_source), len(xsum_target))

11301


#### Evaluation setup (ROUGE, spaCy) and source tokenization

In [7]:
import spacy
import rouge

from fairseq.data.data_utils import collate_tokens

In [8]:
def prepare_results(p, r, f, metric=None):
    """Format one ROUGE result as a tab-separated row of percentages.

    Args:
        p: precision (presumably a fraction in [0, 1]; it is multiplied by 100).
        r: recall, same scale as ``p``.
        f: F-score, same scale as ``p``.
        metric: metric name printed at the start of the row, e.g. ``'rouge-1'``.
            If omitted, a module-global ``metric`` is used when one exists (older
            callers rely on the loop variable of the ROUGE printing loop), else ``''``.

    Returns:
        A string of the form ``<TAB>name:<TAB>P: xx.xx<TAB>R: xx.xx<TAB>F1: xx.xx``.
    """
    if metric is None:
        # Backward compatibility: callers that don't pass `metric` relied on it being an
        # implicit global. Passing it explicitly avoids that hidden-state dependency.
        metric = globals().get('metric', '')
    return '\t{}:\t{}: {:5.2f}\t{}: {:5.2f}\t{}: {:5.2f}'.format(
        metric, 'P', 100.0 * p, 'R', 100.0 * r, 'F1', 100.0 * f)

# ROUGE aggregation mode: 'Avg' or 'Best' (presumably averaging over pairs vs. keeping the
# best-scoring pair — see py-rouge's apply_avg / apply_best options).
aggregator = 'Avg'
apply_avg = (aggregator == 'Avg')
apply_best = (aggregator == 'Best')
print('Evaluation with ' + aggregator)

evaluator = rouge.Rouge(
    metrics=['rouge-n', 'rouge-l', 'rouge-w'],
    max_n=4,                    # ROUGE-n for n = 1..4
    limit_length=True,          # presumably truncates texts to `length_limit` units
    length_limit=100,
    length_limit_type='words',
    apply_avg=apply_avg,
    apply_best=apply_best,
    alpha=0.5,                  # Default F1_score
    weight_factor=1.2,          # presumably the ROUGE-W weighting exponent
    stemming=True,
)

Evaluation with Avg


In [9]:
# English spaCy pipeline, used later for named-entity recognition on generated summaries.
# spaCy v3 removed shortcut links such as 'en'; 'en_core_web_sm' is the package the 'en'
# shortcut installed by default, and the fallback keeps older link-only installs working.
try:
    nlp = spacy.load('en_core_web_sm')
except OSError:
    nlp = spacy.load('en')

In [10]:
# Encode one XSum test article and left-pad it into a [batch_size, src_len] batch on GPU.
INDEX = 6220
test_inputs = [xsum_source[INDEX]]
# Take the pad id from the model's dictionary instead of hard-coding it, so padding and
# the length count below always use the same value.
pad_idx = bart.task.source_dictionary.pad()
src_tokens = collate_tokens([encode_func(i) for i in test_inputs], pad_idx=pad_idx, left_pad=True).cuda()
print(src_tokens.shape)

src_lengths = src_tokens.ne(pad_idx).sum(dim=1)  # number of non-pad tokens per row
print(src_lengths)

torch.Size([1, 220])
tensor([220], device='cuda:0')


In [11]:
# encoding

In [12]:
# Run the encoder once; the decoding loop reuses `encoder_out` at every step.
# Inference only, so skip building an autograd graph (saves GPU memory).
with torch.no_grad():
    encoder_out = bart_encoder(src_tokens, src_lengths=src_lengths)
print(encoder_out.encoder_out.shape)  # [seq_len, batch_size, hidden_size]

torch.Size([220, 1, 1024])


In [13]:
# decoding

In [14]:
# Greedy step-by-step decoding conditioned on `encoder_out`, recording the probability of
# every selected token. The decoder is primed with [eos, bos]; special ids are read from
# the model dictionary rather than hard-coded.
bart_dict = bart.task.source_dictionary
min_decode_step, max_decode_step = 10, 60
pad_id, eos_id, bos_id = bart_dict.pad(), bart_dict.eos(), bart_dict.bos()
# Optional manual overrides for exploration: {step: k} picks the k-th most likely token
# (0 = argmax, k < 5) at that step. Empty means plain greedy decoding.
forced_ranks = {}

# The per-token bookkeeping below (.item(), probs[0, ...]) handles one example at a time.
assert src_tokens.shape[0] == 1, 'this cell decodes a single example at a time'
init_input = torch.tensor([[eos_id, bos_id]] * src_tokens.shape[0], dtype=torch.long).cuda()
softmax = nn.Softmax(dim=1)
token_probs, tokens = [], []

with torch.no_grad():  # inference only: no autograd graph needed
    for step in range(max_decode_step):
        decoder_outputs = bart_decoder(init_input, encoder_out, features_only=False)
        logits = decoder_outputs[0][:, -1, :]  # [batch_size, vocab]
        if step + 1 < min_decode_step:
            # Block eos for the first (min_decode_step - 1) steps.
            logits[:, eos_id] = -math.inf

        logits[:, pad_id] = -math.inf  # never select pad
        logits[:, bos_id] = -math.inf  # never select the start token
        probs = softmax(logits)
        assert logits.shape == probs.shape
        attn = decoder_outputs[1]['attn'][0]  # [batch_size, prev_token_len, src_token_len]
        assert logits.dim() == 2 and attn.dim() == 3

        _, indices = torch.topk(probs, 5, dim=1)
        selected_token = indices[:, forced_ranks.get(step, 0)]

        init_input = torch.cat([init_input, selected_token.unsqueeze(1)], dim=-1)
        token = decode_func(selected_token)
        prob = probs[0, selected_token.item()].item()
        token_probs.append(prob)
        tokens.append(token)

        if selected_token.item() == eos_id:
            break
        print("- {:02d}: {} ({:.2f})".format(step, token, prob))

assert len(tokens) == len(token_probs)
# Steps with the (up to) three lowest probabilities. Sorting positions by value keeps
# tied probabilities distinct (list.index would repeat the first match) and does not
# fail when fewer than three tokens were generated.
lowest_steps = sorted(range(len(token_probs)), key=token_probs.__getitem__)[:3]
print('- Lowest Probabilities: [{}]'.format(', '.join(str(i) for i in lowest_steps)))

- 00: The (0.60)
- 01:  body (0.74)
- 02:  of (0.90)
- 03:  a (0.59)
- 04:  man (0.60)
- 05:  who (0.40)
- 06:  died (0.53)
- 07:  when (0.32)
- 08:  the (0.48)
- 09:  collapsed (0.05)
- 10:  Swansea (0.08)
- 11:  Power (0.10)
- 12:  Station (0.93)
- 13:  building (0.58)
- 14:  collapsed (0.67)
- 15:  has (0.82)
- 16:  been (0.89)
- 17:  removed (0.26)
- 18:  from (0.87)
- 19:  the (0.90)
- 20:  site (0.85)
- 21: . (0.58)
- Lowest Probabilities: [9, 10, 11]


In [15]:
# Probability assigned to 'Swansea', an entity in the generated summary that does not
# appear in the reference summary. `get_probability` comes from the local utils module;
# presumably it looks up the generated token matching the word — TODO confirm in utils.
get_probability('Swansea', tokens, token_probs)

0.07952880859375

In [16]:
# Full generated id sequence: the [2, 0] prefix the loop started from, then one id per step.
init_input

tensor([[    2,     0,   133,   809,     9,    10,   313,    54,   962,    77,
             5,  7793, 15338,  3029,  5088,   745,  7793,    34,    57,  2928,
            31,     5,  1082,     4,     2]], device='cuda:0')

In [17]:
# Convert the generated ids of batch item 0 back into text.
hypothesis = decode_func(init_input[0])
print(hypothesis)

The body of a man who died when the collapsed Swansea Power Station building collapsed has been removed from the site.


In [18]:
# Named entities spaCy detects in the generated summary.
nlp(hypothesis).ents

(Swansea Power Station,)

In [19]:
# Reference (gold) summary for the same article, to compare against `hypothesis`.
xsum_target[INDEX]

'A body found in the ruins of a collapsed building at Didcot Power Station has been identified.'

In [20]:
# Score the single generated summary against its reference and print one row per metric,
# in alphabetical metric order.
all_hypothesis, all_references = [hypothesis], [xsum_target[INDEX]]
scores = evaluator.get_scores(all_hypothesis, all_references)

for metric in sorted(scores):
    results = scores[metric]
    print(prepare_results(results['p'], results['r'], results['f']))

	rouge-1:	P: 47.62	R: 58.82	F1: 52.63
	rouge-2:	P: 15.00	R: 18.75	F1: 16.67
	rouge-3:	P:  0.00	R:  0.00	F1:  0.00
	rouge-4:	P:  0.00	R:  0.00	F1:  0.00
	rouge-l:	P: 44.74	R: 53.36	F1: 48.67
	rouge-w:	P: 32.48	R: 22.77	F1: 26.77


In [21]:
# Sample sentence used to rebuild `encode_func`'s output by hand in the following cells.
SENTENCE = 'The US Senate has rejected plans to tighten gun controls, including the restriction of weapons sales to people on terrorism watch lists.'
encode_func(SENTENCE)

tensor([    0,   133,   382,  1112,    34,  3946,   708,     7, 16888,  1751,
         5656,     6,   217,     5, 20627,     9,  2398,   647,     7,    82,
           15,  4952,  1183,  8204,     4,     2])

In [22]:
# Dictionary ids for ' Didcot' (the entity in the reference summary, with a leading space),
# wrapped in the <s> ... </s> ids 0 and 2.
encode_func(' Didcot')

tensor([    0,  6553, 22921,     2])

In [23]:
# Raw BPE ids for the same string, before mapping through the fairseq dictionary;
# note they differ from the ids encode_func returns.
bart.bpe.encode(' Didcot')

'7731 25557'

In [24]:
# BPE encoding of SENTENCE as a space-separated string of BPE ids (not dictionary indices).
# NOTE(review): this rebinds `tokens`, which the decoding cell filled with the generated
# token strings passed to `get_probability`; re-running that earlier cell after this one
# would receive this string instead. Consider a distinct name such as `bpe_tokens`.
tokens = bart.bpe.encode(SENTENCE)

In [25]:
# Space-separated BPE ids for SENTENCE.
tokens

'464 1294 3845 468 8606 3352 284 31833 2485 6973 11 1390 262 17504 286 3777 4200 284 661 319 8649 2342 8341 13'

In [26]:
# Whitespace word count vs. BPE piece count for the same sentence.
print(len(SENTENCE.split()), len(tokens.split()), sep='\n')

22
24


In [39]:
# Surround the BPE ids with sentence delimiters so the dictionary-encoded result can be
# compared with encode_func(SENTENCE).
bpe_sentence = ' '.join(['<s>', tokens, '</s>'])
print(bpe_sentence)

<s> 464 1294 3845 468 8606 3352 284 31833 2485 6973 11 1390 262 17504 286 3777 4200 284 661 319 8649 2342 8341 13 </s>


In [40]:
# Map the delimited BPE-id string to dictionary indices. '</s>' is already in the string,
# so no extra eos is appended. add_if_not_exist=False keeps unseen symbols from being
# inserted into (and growing) the model's shared dictionary; they map to <unk> instead.
torch_tokens = bart.task.source_dictionary.encode_line(bpe_sentence, append_eos=False, add_if_not_exist=False)

In [41]:
# As a LongTensor: the ids match encode_func(SENTENCE) exactly.
torch_tokens.long()

tensor([    0,   133,   382,  1112,    34,  3946,   708,     7, 16888,  1751,
         5656,     6,   217,     5, 20627,     9,  2398,   647,     7,    82,
           15,  4952,  1183,  8204,     4,     2])

In [42]:
# Round trip: decoding the hand-built ids recovers SENTENCE.
decode_func(torch_tokens.long())

'The US Senate has rejected plans to tighten gun controls, including the restriction of weapons sales to people on terrorism watch lists.'

In [43]:
# Original sentence, for comparison with the decoded round trip.
SENTENCE

'The US Senate has rejected plans to tighten gun controls, including the restriction of weapons sales to people on terrorism watch lists.'