In [4]:
import math
import torch
import torch.nn as nn
from fairseq.models.bart import BARTModel

In [5]:
# Load the BART checkpoint. The same directory is passed both as the model
# location and as `data_name_or_path`, so it is named once here.
BART_XSUM_DIR = '/home/ml/cadencao/Downloads/BART_models/bart.large.xsum'
bart = BARTModel.from_pretrained(
    BART_XSUM_DIR,
    checkpoint_file='model.pt',
    data_name_or_path=BART_XSUM_DIR,
)

In [6]:
# Prepare the loaded model for GPU inference.
bart.cuda()  # move the model to the GPU
bart.eval()  # switch to evaluation mode
bart.half()  # cast to half precision
print('- activate evaluation mode')

- activate evaluation mode


In [7]:
# Short aliases for the model's text -> token-id and token-id -> text helpers.
encode_func, decode_func = bart.encode, bart.decode

In [8]:
# Keep direct handles on the encoder and decoder sub-modules, then print the
# concrete class of the wrapped model and of each sub-module.
bart_encoder, bart_decoder = bart.model.encoder, bart.model.decoder
for component in (bart.model, bart_encoder, bart_decoder):
    print(type(component))

<class 'fairseq.models.bart.model.BARTModel'>
<class 'fairseq.models.transformer.TransformerEncoder'>
<class 'fairseq.models.transformer.TransformerDecoder'>


#### Read XSum

In [9]:
def read_lines(file_path):
    """Read a UTF-8 text file and return its lines as a list of strings.

    Leading and trailing whitespace (including the newline) is stripped from
    every line; blank lines are kept as empty strings.

    Args:
        file_path: path of the text file to read.

    Returns:
        A list with one stripped string per line, in file order.
    """
    with open(file_path, 'r', encoding='utf-8') as handle:
        return [raw_line.strip() for raw_line in handle]

In [10]:
# Load the XSum test split: one source document and one reference summary per
# line, in two parallel files that must have the same number of lines.
document_path = '/home/ml/cadencao/XSum/fairseq_files/test.source'
target_path = '/home/ml/cadencao/XSum/fairseq_files/test.target'
xsum_source, xsum_target = (read_lines(path) for path in (document_path, target_path))
print(len(xsum_source))
assert len(xsum_target) == len(xsum_source)

11301


#### Tokenization

In [11]:
import spacy
import rouge

from fairseq.data.data_utils import collate_tokens

In [12]:
def prepare_results(p, r, f, metric_name=None):
    """Format precision, recall and F1 of one metric as a tab-separated line.

    Args:
        p: precision as a fraction; shown multiplied by 100.
        r: recall as a fraction; shown multiplied by 100.
        f: F1 as a fraction; shown multiplied by 100.
        metric_name: label printed in front of the scores. When omitted, the
            module-level variable `metric` is used instead, which keeps
            existing three-argument calls made inside a
            `for metric, ... in ...` loop working.

    Returns:
        The formatted string, e.g. '\\trouge-1:\\tP: 48.00\\tR: 50.00\\tF1: 48.98'.

    Raises:
        NameError: if `metric_name` is omitted and no global `metric` exists.
    """
    # Prefer the explicit label; only fall back to the global when none is given.
    label = metric if metric_name is None else metric_name
    return '\t{}:\t{}: {:5.2f}\t{}: {:5.2f}\t{}: {:5.2f}'.format(label, 'P', 100.0 * p, 'R', 100.0 * r, 'F1', 100.0 * f)

# How per-sample ROUGE scores are combined: 'Avg' or 'Best'. The two flags
# below are derived from this single setting.
aggregator = 'Avg'
apply_avg, apply_best = aggregator == 'Avg', aggregator == 'Best'
print('Evaluation with {}'.format(aggregator))

# ROUGE-1..4, ROUGE-L and ROUGE-W, with inputs limited to 100 words.
rouge_options = dict(
    metrics=['rouge-n', 'rouge-l', 'rouge-w'],
    max_n=4,
    limit_length=True,
    length_limit=100,
    length_limit_type='words',
    apply_avg=apply_avg,
    apply_best=apply_best,
    alpha=0.5,  # Default F1_score
    weight_factor=1.2,
    stemming=True,
)
evaluator = rouge.Rouge(**rouge_options)

Evaluation with Avg


In [13]:
# Load a spaCy pipeline under the name 'en'. `nlp` is not referenced by any
# other cell visible in this notebook.
# NOTE(review): 'en' presumably relies on a shortcut link to an installed
# English model, which newer spaCy releases may not support - confirm against
# the installed spaCy version.
nlp = spacy.load('en')

In [14]:
# Pick one test document and turn it into a left-padded batch of token ids on
# the GPU (padding id 1).
INDEX = 922
test_inputs = [xsum_source[INDEX]]
encoded_inputs = [encode_func(document) for document in test_inputs]
src_tokens = collate_tokens(encoded_inputs, pad_idx=1, left_pad=True).cuda()
print(src_tokens.shape)

# Per-sequence count of positions that are not the padding id.
src_lengths = (src_tokens != 1).sum(dim=1)
print(src_lengths)

torch.Size([1, 900])
tensor([900], device='cuda:0')


In [15]:
# encoding

In [16]:
# Run the encoder once over the source batch. This notebook only performs
# inference, so autograd is switched off to avoid recording (and keeping alive)
# a computation graph for the forward pass.
with torch.no_grad():
    encoder_out = bart_encoder(src_tokens, src_lengths=src_lengths)
print(encoder_out.encoder_out.shape)  # [seq_len, batch_size, hidden_size]

torch.Size([900, 1, 1024])


In [17]:
# decoding

In [18]:
# Greedy decoding: one token per iteration, re-running the decoder on the whole
# prefix generated so far. The prefix starts as [2, 0], i.e. the EOS id followed
# by the start token.
init_input = torch.tensor([[2, 0]] * src_tokens.shape[0], dtype=torch.long).cuda()
min_decode_step, max_decode_step, pad_id, eos_id = 10, 60, 1, 2
softmax = nn.Softmax(dim=1)
token_probs = []  # probability of the token chosen at each step

# The .item() calls below require exactly one sequence in the batch.
assert init_input.shape[0] == 1, 'decoding loop supports a batch size of 1 only'

# Inference only: no autograd graph is needed for any of the decoder passes.
with torch.no_grad():
    for step in range(max_decode_step):
        decoder_outputs = bart_decoder(init_input, encoder_out, features_only=False)
        logits = decoder_outputs[0][:, -1, :]  # [batch_size, vocab]

        # Forbid EOS until at least `min_decode_step` tokens have been produced.
        if step + 1 < min_decode_step:
            logits[:, eos_id] = -math.inf

        logits[:, pad_id], logits[:, 0] = -math.inf, -math.inf  # never select pad, start token
        probs = softmax(logits)
        assert logits.shape == probs.shape
        attn = decoder_outputs[1]['attn'][0]  # [batch_size, prev_token_len, src_token_len]
        assert logits.dim() == 2 and attn.dim() == 3

        # Keep the five most likely candidates for inspection and take the best
        # one. To force a non-greedy choice at some step, pick another column of
        # `indices` here.
        value, indices = torch.topk(probs, 5, dim=1)
        selected_token = indices[:, 0]

        init_input = torch.cat([init_input, selected_token.unsqueeze(1)], dim=-1)
        selected_prob = probs.squeeze()[selected_token.item()].item()
        print("- {:02d}: {} ({:.2f})".format(step, decode_func(selected_token), selected_prob))
        token_probs.append(selected_prob)

        if selected_token.item() == eos_id:
            break

# Report the decoding steps whose chosen token had the lowest probability (up to
# three, least confident first). Sorting the step indices rather than looking
# the sorted values up with list.index keeps tied probabilities from reporting
# the same step twice, and copes with fewer than three decoded steps.
token_probs_sorted = sorted(token_probs)  # ascending probabilities, kept for inspection
lowest_prob_steps = sorted(range(len(token_probs)), key=token_probs.__getitem__)[:3]
print('{}.'.format('; '.join(str(step_index) for step_index in lowest_prob_steps)))

- 00: The (0.23)
- 01:  exams (0.23)
- 02:  regulator (0.45)
- 03:  Of (0.67)
- 04: qual (0.93)
- 05:  has (0.48)
- 06:  been (0.64)
- 07:  accused (0.88)
- 08:  of (0.80)
- 09:  failing (0.43)
- 10:  to (0.86)
- 11:  properly (0.10)
- 12:  apply (0.12)
- 13:  its (0.43)
- 14:  own (0.88)
- 15:  rules (0.56)
- 16:  on (0.25)
- 17:  the (0.40)
- 18:  way (0.92)
- 19:  GC (0.49)
- 20: SE (0.90)
- 21: s (0.43)
- 22:  are (0.82)
- 23:  graded (0.61)
- 24: , (0.40)
- 25:  according (0.25)
- 26:  to (0.90)
- 27:  a (0.75)
- 28:  report (0.52)
- 29: . (0.79)
- 30:  (0.90)
11; 12; 1.


In [19]:
# Show the full sequence of token ids: the initial [2, 0] prefix followed by
# every token selected during decoding.
init_input

tensor([[    2,     0,   133, 15734,  8199,  1525, 27702,    34,    57,  1238,
             9,  4551,     7,  5083,  3253,    63,   308,  1492,    15,     5,
           169, 18397,  3388,    29,    32, 35175,     6,   309,     7,    10,
           266,     4,     2]], device='cuda:0')

In [20]:
# Convert the generated token ids of the single sequence in the batch back to
# text.
generated_ids = init_input[0]
hypothesis = decode_func(generated_ids)
print(hypothesis)

The exams regulator Ofqual has been accused of failing to properly apply its own rules on the way GCSEs are graded, according to a report.


In [21]:
# Reference summary of the same test example, for comparison with the
# generated hypothesis.
xsum_target[INDEX]

'England exams regulator Ofqual breached its own rules in allowing controversial changes to the way English GCSEs were graded this summer, it is claimed.'

In [22]:
# Score the generated summary against the reference with the ROUGE evaluator.
all_hypothesis = [hypothesis]
all_references = [xsum_target[INDEX]]
scores = evaluator.get_scores(all_hypothesis, all_references)

# Print one line per metric in alphabetical order. The loop variable must stay
# named `metric`: prepare_results reads it as a global for the line label.
for metric in sorted(scores):
    results = scores[metric]
    print(prepare_results(results['p'], results['r'], results['f']))

	rouge-1:	P: 48.00	R: 50.00	F1: 48.98
	rouge-2:	P: 20.83	R: 21.74	F1: 21.28
	rouge-3:	P:  8.70	R:  9.09	F1:  8.89
	rouge-4:	P:  0.00	R:  0.00	F1:  0.00
	rouge-l:	P: 46.60	R: 48.21	F1: 47.39
	rouge-w:	P: 31.22	R: 17.22	F1: 22.20
