In [1]:
import math
import torch
import torch.nn as nn
from fairseq.models.bart import BARTModel

In [2]:
# Directory of the bart.large.cnn checkpoint; it is passed both as the model
# directory and as data_name_or_path.
# NOTE(review): machine-specific absolute path -- adjust for your environment.
bart_model_dir = '/home/ml/cadencao/Downloads/BART_models/bart.large.cnn'
bart = BARTModel.from_pretrained(
    bart_model_dir,
    checkpoint_file='model.pt',
    data_name_or_path=bart_model_dir,
)

In [3]:
# Same three calls, same order as before: move to GPU, switch to eval mode,
# convert to half precision. Each call returns the module, so they chain.
bart.cuda().eval().half()
print('- activate evaluation mode')

- activate evaluation mode


In [4]:
# Short aliases: encode_func is applied to source strings below, decode_func
# to generated token ids.
encode_func, decode_func = bart.encode, bart.decode

In [5]:
# Keep direct handles on the encoder and decoder so they can be called
# step by step in the cells below.
bart_encoder, bart_decoder = bart.model.encoder, bart.model.decoder
print(type(bart.model), type(bart_encoder), type(bart_decoder), sep='\n')

<class 'fairseq.models.bart.model.BARTModel'>
<class 'fairseq.models.transformer.TransformerEncoder'>
<class 'fairseq.models.transformer.TransformerDecoder'>


#### Read CNN/DailyMail

In [6]:
def read_lines(file_path):
    """Read a UTF-8 text file and return its lines as a list of strings.

    Leading and trailing whitespace (including the newline) is stripped from
    every line; blank lines are kept as empty strings so positions in the
    returned list match line positions in the file.
    """
    with open(file_path, 'r', encoding='utf-8') as handle:
        return [raw_line.strip() for raw_line in handle]

In [7]:
# Three parallel text files, one example per line: source documents,
# target summaries, and predictions read from prediction_path.
document_path = '/home/ml/cadencao/cnn-dailymail/cnn_dm/train.source'
summarization_path = '/home/ml/cadencao/cnn-dailymail/cnn_dm/train.target'
prediction_path = 'preds/cnndm_train_best_beam.hypo'

cnndm_source, cnndm_target, cnndm_preds = (
    read_lines(path)
    for path in (document_path, summarization_path, prediction_path)
)

print(len(cnndm_source))
# The files must be line-aligned so that one index addresses the same example
# in all three lists.
assert len(cnndm_source) == len(cnndm_target) == len(cnndm_preds)

287227


#### ROUGE setup and tokenization

In [8]:
import spacy
import rouge

from fairseq.data.data_utils import collate_tokens

In [9]:
def prepare_results(p, r, f, metric=None):
    """Format one ROUGE result row as a tab-separated string.

    Args:
        p: precision as a fraction (shown multiplied by 100).
        r: recall as a fraction (shown multiplied by 100).
        f: F-score as a fraction (shown multiplied by 100).
        metric: label printed at the start of the row, e.g. ``'rouge-1'``.
            When omitted, the module-level name ``metric`` is used, which is
            how existing three-argument callers (looping with a global
            ``metric`` variable) supply the label.

    Returns:
        A string of the form ``'\\t<metric>:\\tP: xx.xx\\tR: xx.xx\\tF1: xx.xx'``.

    Raises:
        NameError: if ``metric`` is not passed and no module-level ``metric``
            exists.
    """
    if metric is None:
        # Legacy behaviour: pick up the caller's global loop variable.
        module_globals = globals()
        if 'metric' not in module_globals:
            raise NameError("name 'metric' is not defined; pass metric=... explicitly")
        metric = module_globals['metric']
    return '\t{}:\t{}: {:5.2f}\t{}: {:5.2f}\t{}: {:5.2f}'.format(metric, 'P', 100.0 * p, 'R', 100.0 * r, 'F1', 100.0 * f)

# Aggregation mode; it is turned into the two boolean flags that rouge.Rouge
# takes as apply_avg / apply_best.
aggregator = 'Avg'
apply_avg, apply_best = (aggregator == 'Avg'), (aggregator == 'Best')
print('Evaluation with {}'.format(aggregator))

evaluator = rouge.Rouge(
    metrics=['rouge-n', 'rouge-l', 'rouge-w'],
    max_n=4,
    limit_length=True,
    length_limit=100,
    length_limit_type='words',
    apply_avg=apply_avg,
    apply_best=apply_best,
    alpha=0.5,  # Default F1_score
    weight_factor=1.2,
    stemming=True,
)

Evaluation with Avg


In [10]:
# Load a spaCy English pipeline through the 'en' shortcut name.
# NOTE(review): `nlp` is not used in any cell visible here, and the 'en'
# shortcut presumably only resolves on spaCy 2.x -- confirm against the
# installed version.
nlp = spacy.load('en')

In [11]:
# Which training example to inspect.
INDEX = 2210
test_inputs = [cnndm_source[INDEX]]

# Encode every document, then batch them into one tensor, left-padded with
# pad id 1, and move the batch to the GPU.
encoded_docs = [encode_func(document) for document in test_inputs]
src_tokens = collate_tokens(encoded_docs, pad_idx=1, left_pad=True).cuda()
print(src_tokens.shape)

# Per-row count of non-pad (id != 1) tokens.
src_lengths = (src_tokens != 1).sum(dim=1)
print(src_lengths)

torch.Size([1, 450])
tensor([450], device='cuda:0')


In [12]:
# encoding

In [13]:
# Run the encoder once; the decoding cell below reuses this output at every step.
# This is inference only, so disable autograd to avoid keeping a graph (and
# its activations) alive on the GPU.
with torch.no_grad():
    encoder_out = bart_encoder(src_tokens, src_lengths=src_lengths)
print(encoder_out.encoder_out.shape)  # [seq_len, batch_size, hidden_size]

torch.Size([450, 1, 1024])


In [14]:
# decoding

In [15]:
# Manual greedy decoding that records the probability of every emitted token.
# NOTE: the per-step bookkeeping uses .item(), so this cell assumes batch size 1.
init_input = torch.tensor([[2, 0]] * src_tokens.shape[0], dtype=torch.long).cuda()
min_decode_step, max_decode_step, pad_id, eos_id = 55, 140, 1, 2
softmax = nn.Softmax(dim=1)
token_probs = []

# Inference only: no_grad avoids building an autograd graph at every step.
with torch.no_grad():
    for step in range(max_decode_step):
        # The decoder is re-run on the whole prefix; only the last position is used.
        decoder_outputs = bart_decoder(init_input, encoder_out, features_only=False)
        logits = decoder_outputs[0][:, -1, :]  # [batch_size, vocab]
        if step + 1 < min_decode_step:
            logits[:, eos_id] = -math.inf  # do not allow EOS before the minimum length

        logits[:, pad_id], logits[:, 0] = -math.inf, -math.inf  # never select pad, start token
        probs = softmax(logits)
        assert logits.shape == probs.shape
        attn = decoder_outputs[1]['attn'][0]  # [batch_size, prev_token_len, src_token_len]
        assert logits.dim() == 2 and attn.dim() == 3

        # Top-5 candidates are computed, but the most probable one is always
        # taken. To force an alternative at a given step, pick indices[:, k] here.
        value, indices = torch.topk(probs, 5, dim=1)
        selected_token = indices[:, 0]
        selected_prob = probs[0, selected_token.item()].item()

        init_input = torch.cat([init_input, selected_token.unsqueeze(1)], dim=-1)
        print("- {:02d}: {} ({:.2f})".format(step, decode_func(selected_token), selected_prob), end='\n')
        token_probs.append(selected_prob)

        if selected_token.item() == eos_id:
            break

# Report the steps with the three lowest token probabilities. Indices are
# ranked directly (stable sort) rather than looked up with list.index on the
# sorted values: index() returns the first match, so tied probabilities --
# likely with an fp16 model -- would be reported as the same step repeatedly.
token_probs_sorted = sorted(token_probs)
lowest_steps = sorted(range(len(token_probs)), key=token_probs.__getitem__)[:3]
print('; '.join(str(step_idx) for step_idx in lowest_steps) + '.')

- 00: U (0.23)
- 01: . (0.90)
- 02: N (0.81)
- 03: . (0.86)
- 04:  Security (0.61)
- 05:  Council (0.92)
- 06:  resolution (0.20)
- 07:  expresses (0.28)
- 08:  intent (0.73)
- 09:  to (0.90)
- 10:  send (0.77)
- 11:  U (0.42)
- 12: . (0.91)
- 13: N (0.90)
- 14: . (0.90)
- 15:  peace (0.67)
- 16: keeping (0.79)
- 17:  forces (0.75)
- 18:  back (0.72)
- 19: . (0.73)
- 20:  Resolution (0.43)
- 21:  sponsored (0.20)
- 22:  by (0.90)
- 23:  U (0.52)
- 24: . (0.90)
- 25: S (0.92)
- 26: ., (0.63)
- 27:  in (0.38)
- 28:  one (0.81)
- 29:  of (0.90)
- 30:  final (0.64)
- 31:  Bush (0.94)
- 32:  Administration (0.64)
- 33:  initiatives (0.92)
- 34:  at (0.45)
- 35:  U (0.74)
- 36: . (0.90)
- 37: N (0.85)
- 38: . (0.90)
- 39:  Follow (0.21)
- 40: s (0.90)
- 41:  exit (0.88)
- 42:  of (0.90)
- 43:  U (0.47)
- 44: . (0.90)
- 45: N (0.91)
- 46: .- (0.90)
- 47: backed (0.94)
- 48: , (0.79)
- 49:  Ethiopian (0.93)
- 50:  peace (0.95)
- 51: keeping (0.91)
- 52:  force (0.95)
- 53:  that (0.59)
- 54:  

In [16]:
# Display the full id sequence: the initial [2, 0] prefix followed by every
# token appended during decoding.
init_input

tensor([[    2,     0,   791,     4,   487,     4,  2010,  1080,  3547, 15994,
          5927,     7,  2142,   121,     4,   487,     4,  1987, 12609,  1572,
           124,     4, 22565,  7966,    30,   121,     4,   104,   482,    11,
            65,     9,   507,  3516,  4237,  5287,    23,   121,     4,   487,
             4,  2184,    29,  4205,     9,   121,     4,   487,  3358,  6996,
             6, 24128,  1987, 12609,  1370,    14,  2121,    80,    12,   180,
          9737,     4,     2]], device='cuda:0')

In [17]:
# Turn the decoded ids of the first (only) batch row back into text.
hypothesis = decode_func(init_input[0])
print(hypothesis)

U.N. Security Council resolution expresses intent to send U.N. peacekeeping forces back. Resolution sponsored by U.S., in one of final Bush Administration initiatives at U.N. Follows exit of U.N.-backed, Ethiopian peacekeeping force that completed two-year deployment.


In [18]:
# Prediction loaded from prediction_path for the same example, for comparison.
cnndm_preds[INDEX]

'U.N. Security Council resolution expresses intent to send peacekeeping forces back to Somalia. Resolution was sponsored by the U.S., in one of final Bush Administration initiatives. Resolution follows exit of U.N.-backed, Ethiopian peacekeeping force that completed two-year deployment in Somalia.'

In [19]:
# Target summary (from summarization_path) for the same example.
cnndm_target[INDEX]

'Resolution expresses intent to send U.N. peacekeeping forces to Somalia . The resolution in war-torn country was sponsored by the United States . Ethiopian peacekeeping force completed two-year deployment in Somalia . Regional leaders fear vacuum will be filled by Islamic extremist groups .'

In [20]:
# Optional manual override of the decoded hypothesis (left disabled):
# hypothesis = 'A French prosecutor says he is not aware of any video footage from on board the plane. German daily Bild and Paris Match claim to have found a cell phone video of the crash. A French Gendarmerie spokesman calls the reports "completely wrong" and "unwarranted" German airline Lufthansa says co-pilot Andreas Lubitz battled depression years before he took the controls.'

# get_scores is called with lists, so wrap each single text in a list.
all_hypothesis = [hypothesis]
all_predictions = [cnndm_preds[INDEX]]
all_references = [cnndm_target[INDEX]]

# Score the stored prediction first, then the freshly decoded hypothesis,
# both against the same reference.
# The inner loop variable must remain a module-level name called `metric`:
# prepare_results reads that global to label each row.
for heading, candidates in (('- original prediction:', all_predictions),
                            ('- new prediction:', all_hypothesis)):
    print(heading)
    scores = evaluator.get_scores(candidates, all_references)
    for metric, results in sorted(scores.items(), key=lambda item: item[0]):
        print(prepare_results(results['p'], results['r'], results['f']))

- original prediction:
	rouge-1:	P: 57.78	R: 60.47	F1: 59.09
	rouge-2:	P: 38.64	R: 40.48	F1: 39.53
	rouge-3:	P: 23.26	R: 24.39	F1: 23.81
	rouge-4:	P: 14.29	R: 15.00	F1: 14.63
	rouge-l:	P: 57.16	R: 59.37	F1: 58.24
	rouge-w:	P: 40.35	R: 19.90	F1: 26.65
- new prediction:
	rouge-1:	P: 46.51	R: 46.51	F1: 46.51
	rouge-2:	P: 33.33	R: 33.33	F1: 33.33
	rouge-3:	P: 24.39	R: 24.39	F1: 24.39
	rouge-4:	P: 17.50	R: 17.50	F1: 17.50
	rouge-l:	P: 50.63	R: 50.63	F1: 50.63
	rouge-w:	P: 36.92	R: 17.40	F1: 23.65
