In [1]:
import math
import os

import torch
import torch.nn as nn
from fairseq.models.bart import BARTModel

In [2]:
# Directory with the BART-large-CNN checkpoint (model.pt). data_name_or_path
# points at the same directory (presumably where the dictionary files live).
# Override it with the BART_MODEL_DIR environment variable instead of editing
# the absolute path; the default keeps the original location.
BART_MODEL_DIR = os.environ.get('BART_MODEL_DIR',
                                '/home/ml/cadencao/Downloads/BART_models/bart.large.cnn')
bart = BARTModel.from_pretrained(BART_MODEL_DIR,
                                 checkpoint_file='model.pt',
                                 data_name_or_path=BART_MODEL_DIR)

In [3]:
# Move the model to the GPU, put it in eval mode (presumably disabling dropout)
# and cast it to fp16. The three calls run in exactly this order.
for prepare_model in (bart.cuda, bart.eval, bart.half):
    prepare_model()
print('- activate evaluation mode')

- activate evaluation mode


In [4]:
# Short aliases for BART's text -> token-id and token-id -> text helpers.
encode_func, decode_func = bart.encode, bart.decode

In [5]:
# Direct handles on the seq2seq model's encoder and decoder so they can be
# called step by step in the decoding cells below.
bart_encoder, bart_decoder = bart.model.encoder, bart.model.decoder
for component in (bart.model, bart_encoder, bart_decoder):
    print(type(component))

<class 'fairseq.models.bart.model.BARTModel'>
<class 'fairseq.models.transformer.TransformerEncoder'>
<class 'fairseq.models.transformer.TransformerDecoder'>


#### Read CNN/DailyMail test set

In [6]:
def read_lines(file_path):
    """Return every line of a UTF-8 text file with surrounding whitespace stripped.

    Blank lines are kept as '' so that index i of the result matches line i
    of the file.
    """
    with open(file_path, encoding='utf-8') as handle:
        return [raw_line.strip() for raw_line in handle]

In [8]:
# CNN/DailyMail test split, one example per line: line i of test.source is an
# article and line i of test.target is its reference summary.
cnndm_dir = '/home/ml/cadencao/Two-Steps-Summarization/datasets/cnn_dm/fairseq_files/'
document_path = cnndm_dir + 'test.source'
summarization_path = cnndm_dir + 'test.target'

cnndm_source, cnndm_target = read_lines(document_path), read_lines(summarization_path)
print(len(cnndm_source))
# Articles and summaries must stay aligned one-to-one.
assert len(cnndm_source) == len(cnndm_target)

11490


#### ROUGE evaluator and tokenization

In [9]:
import spacy
import rouge

from fairseq.data.data_utils import collate_tokens

In [10]:
def prepare_results(p, r, f, metric=None):
    """Format one ROUGE metric's precision / recall / F1 as a tab-separated line.

    Args:
        p, r, f: precision, recall and F1 as fractions; they are printed
            multiplied by 100 with two decimals.
        metric: label for the line (e.g. 'rouge-1'). When omitted, the
            notebook-global ``metric`` is used. This keeps existing 3-argument
            callers working: the original version silently relied on a
            ``for metric, results in ...`` loop variable leaking into globals.
            Pass it explicitly to avoid that hidden-state dependency.

    Raises:
        NameError: if no ``metric`` is passed and no global ``metric`` exists.
    """
    if metric is None:
        if 'metric' not in globals():
            raise NameError("prepare_results() needs a metric label; pass metric=...")
        metric = globals()['metric']
    return '\t{}:\t{}: {:5.2f}\t{}: {:5.2f}\t{}: {:5.2f}'.format(
        metric, 'P', 100.0 * p, 'R', 100.0 * r, 'F1', 100.0 * f)

# py-rouge evaluator: ROUGE-1..4 (rouge-n with max_n=4), ROUGE-L and ROUGE-W.
# Scores are aggregated according to `aggregator` ('Avg' or 'Best').
aggregator = 'Avg'
apply_avg, apply_best = aggregator == 'Avg', aggregator == 'Best'
print('Evaluation with {}'.format(aggregator))

evaluator = rouge.Rouge(
    metrics=['rouge-n', 'rouge-l', 'rouge-w'],
    max_n=4,
    # Presumably truncates texts to `length_limit` words before scoring;
    # TODO confirm against py-rouge docs.
    limit_length=True,
    length_limit=100,
    length_limit_type='words',
    apply_avg=apply_avg,
    apply_best=apply_best,
    alpha=0.5,  # Default F1_score
    weight_factor=1.2,  # weighting used by ROUGE-W
    stemming=True,
)

Evaluation with Avg


In [11]:
# English spaCy pipeline. The 'en' shortcut link only exists in spaCy 2.x;
# spaCy 3 removed shortcut links and raises OSError (E941) for it, so fall
# back to the full package name. When 'en' resolves, behavior is unchanged.
try:
    nlp = spacy.load('en')
except OSError:
    nlp = spacy.load('en_core_web_sm')

In [56]:
# Encode one test article (chosen by INDEX) into a left-padded batch of token
# ids on the GPU. Id 1 is used as the padding index here.
INDEX = 3205
test_inputs = [cnndm_source[INDEX]]
encoded_inputs = [encode_func(text) for text in test_inputs]
src_tokens = collate_tokens(encoded_inputs, pad_idx=1, left_pad=True).cuda()
print(src_tokens.shape)

# True (unpadded) length of each row = number of non-pad positions.
src_lengths = (src_tokens != 1).sum(dim=1)
print(src_lengths)

torch.Size([1, 836])
tensor([836], device='cuda:0')


In [57]:
# encoding

In [58]:
# Run the encoder once; encoder_out is reused at every decoding step.
# no_grad: this is pure inference, so don't build or retain an autograd graph
# (the parameters still require grad, so it would otherwise be tracked).
with torch.no_grad():
    encoder_out = bart_encoder(src_tokens, src_lengths=src_lengths)
print(encoder_out.encoder_out.shape)  # [seq_len, batch_size, hidden_size]

torch.Size([836, 1, 1024])


In [59]:
# decoding

In [114]:
# Greedy decoding, one token per step, calling bart_decoder directly so the
# per-step distribution and cross-attention can be inspected.
#   - Each row is seeded with [2, 0]: eos id 2, then id 0 (presumably BOS).
#   - EOS is blocked until min_decode_step tokens exist; PAD is never picked.
#   - The .item() / .squeeze() calls assume batch size 1.
# The full prefix is re-fed every step (no incremental state), so cost grows
# quadratically with the output length.
init_input = torch.tensor([[2, 0]] * src_tokens.shape[0], dtype=torch.long).cuda()
min_decode_step, max_decode_step, pad_id, eos_id = 55, 140, 1, 2
softmax = nn.Softmax(dim=1)

with torch.no_grad():  # inference only: don't build an autograd graph per step
    for step in range(max_decode_step):
        decoder_outputs = bart_decoder(init_input, encoder_out, features_only=False)
        logits = decoder_outputs[0][:, -1, :]  # [batch_size, vocab], last position
        if step + 1 < min_decode_step:
            logits[:, eos_id] = -math.inf  # too early to stop
        probs = softmax(logits)

        assert logits.shape == probs.shape
        attn = decoder_outputs[1]['attn'][0]  # [batch_size, prev_token_len, src_token_len]
        assert logits.dim() == 2 and attn.dim() == 3
        probs[:, pad_id] = 0.0  # never select pad

        # Greedy pick. top-5 candidates are kept in `indices` for manual inspection.
        _, indices = torch.topk(probs, 5, dim=1)
        selected_token = indices[:, 0]

        init_input = torch.cat([init_input, selected_token.unsqueeze(1)], dim=-1)
        print("- {:02d}: {} ({:.2f})".format(step, decode_func(selected_token), probs.squeeze()[selected_token.item()]), end='\n')

        if selected_token.item() == eos_id:
            break

print(attn[0].shape)

- 00: St (0.30)
- 01: uart (0.92)
- 02:  McC (0.91)
- 03: all (0.88)
- 04:  is (0.28)
- 05:  relaxed (0.39)
- 06:  about (0.71)
- 07:  Rangers (0.32)
- 08: ' (0.38)
- 09:  chances (0.42)
- 10:  of (0.71)
- 11:  promotion (0.45)
- 12: . (0.83)
- 13:  The (0.29)
- 14:  G (0.48)
- 15: ers (0.95)
- 16:  currently (0.29)
- 17:  lead (0.74)
- 18:  the (0.81)
- 19:  way (0.28)
- 20:  in (0.88)
- 21:  the (0.86)
- 22:  Scottish (0.39)
- 23:  Championship (0.94)
- 24: . (0.82)
- 25:  McC (0.23)
- 26: all (0.90)
- 27: 's (0.34)
- 28:  side (0.67)
- 29:  face (0.20)
- 30:  Dumb (0.70)
- 31: arton (0.94)
- 32:  on (0.62)
- 33:  Saturday (0.89)
- 34:  in (0.44)
- 35:  the (0.33)
- 36:  Scottish (0.21)
- 37:  Cup (0.41)
- 38: . (0.63)
- 39:  Rangers (0.16)
- 40: ' (0.14)
- 41:  final (0.68)
- 42: - (0.69)
- 43: day (0.94)
- 44:  tie (0.69)
- 45:  with (0.80)
- 46:  Hearts (0.95)
- 47:  was (0.50)
- 48:  pushed (0.57)
- 49:  back (0.90)
- 50:  to (0.64)
- 51:  Sunday (0.27)
- 52: , (0.54)
- 53:  May 

In [115]:
# Token ids produced by the decoding loop, including its [2, 0] seed prefix.
init_input

tensor([[    2,     0,  5320, 41962,  3409,  1250,    16, 11956,    59,  5706,
           108,  3255,     9,  6174,     4,    20,   272,   268,   855,   483,
             5,   169,    11,     5,  5411,  3261,     4,  3409,  1250,    18,
           526,   652, 37098, 24735,    15,   378,    11,     5,  5411,   968,
             4,  5706,   108,   507,    12,  1208,  3318,    19, 18784,    21,
          3148,   124,     7,   395,     6,   392,   155,     4,     2]],
       device='cuda:0')

In [116]:
# Convert the first (only) generated row of ids back into text.
hypothesis_ids = init_input[0]
hypothesis = decode_func(hypothesis_ids)
print(hypothesis)

Stuart McCall is relaxed about Rangers' chances of promotion. The Gers currently lead the way in the Scottish Championship. McCall's side face Dumbarton on Saturday in the Scottish Cup. Rangers' final-day tie with Hearts was pushed back to Sunday, May 3.


In [117]:
# Gold reference summary for the same article, for comparison with the hypothesis.
cnndm_target[INDEX]

'Rangers are currently second in the Championship with three games to go . Finishing third would mean playing two extra play-off matches vs fourth . But manager Stuart McCall\xa0is\xa0relaxed about the prospect of finishing third .'

In [113]:
# Score the hypothesis against the gold summary. The commented-out line is an
# alternative hypothesis that can be swapped in for comparison.
# hypothesis = "Stuart McCall is relaxed about the prospect of Rangers finishing third in the Scottish Championship. The Gers currently lead the way in the race for second spot in the Championship, with a one-point advantage over Hibernian and a six-point cushion separating them from Queen of the South. Finishing immediately behind newly-crowned champions Hearts would spare the Ibrox men the bother of two extra hazardous games in the play-offs."
all_hypothesis, all_references = [hypothesis], [cnndm_target[INDEX]]

scores = evaluator.get_scores(all_hypothesis, all_references)

# Keep the loop variable named `metric`: prepare_results (as originally
# written) reads it as a global to label each line. Keys are unique, so
# sorting the items orders them by metric name.
for metric, results in sorted(scores.items()):
    print(prepare_results(results['p'], results['r'], results['f']))

	rouge-1:	P: 33.33	R: 66.67	F1: 44.44
	rouge-2:	P: 18.31	R: 37.14	F1: 24.53
	rouge-3:	P: 11.43	R: 23.53	F1: 15.38
	rouge-4:	P:  8.70	R: 18.18	F1: 11.76
	rouge-l:	P: 24.02	R: 42.79	F1: 30.77
	rouge-w:	P: 14.63	R: 14.29	F1: 14.45


In [84]:
# 	rouge-1:	P: 32.56	R: 38.89	F1: 35.44
# 	rouge-2:	P: 11.90	R: 14.29	F1: 12.99
# 	rouge-3:	P:  7.32	R:  8.82	F1:  8.00
# 	rouge-4:	P:  5.00	R:  6.06	F1:  5.48
# 	rouge-l:	P: 22.03	R: 25.55	F1: 23.66
# 	rouge-w:	P: 13.95	R:  8.14	F1: 10.28