In [1]:
import math
import torch
import torch.nn as nn
from fairseq.models.bart import BARTModel

In [2]:
bart = BARTModel.from_pretrained('/home/ml/cadencao/Downloads/BART_models/cnndm_rate07',
                                 checkpoint_file='checkpoint_best.pt',
                                 data_name_or_path='/home/ml/cadencao/Downloads/BART_models/bart.large.cnn')

In [3]:
bart.cuda()
bart.eval()
bart.half()
print('- activate evaluation mode')

- activate evaluation mode


In [4]:
encode_func = bart.encode
decode_func = bart.decode

In [5]:
bart_encoder = bart.model.encoder
bart_decoder = bart.model.decoder
print(type(bart.model))
print(type(bart_encoder))
print(type(bart_decoder))

<class 'fairseq.models.bart.model.BARTModel'>
<class 'fairseq.models.transformer.TransformerEncoder'>
<class 'fairseq.models.transformer.TransformerDecoder'>


#### Read CNN/DailyMail

In [6]:
def read_lines(file_path):
    files = []
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            files.append(line.strip())
    return files

In [7]:
document_path = '/home/ml/cadencao/Two-Steps-Summarization/datasets/cnn_dm/corrupted_nohup_rate07/test.source'
summarization_path = '/home/ml/cadencao/Two-Steps-Summarization/datasets/cnn_dm/corrupted_nohup_rate07/test.target'
# prediction_path = 'preds/cnndm_train_best_beam.hypo'
cnndm_source = read_lines(document_path)
cnndm_target = read_lines(summarization_path)
# cnndm_preds = read_lines(prediction_path)
print(len(cnndm_source))
assert len(cnndm_source) == len(cnndm_target)

11490


#### Tokenization

In [8]:
import spacy
import rouge

from fairseq.data.data_utils import collate_tokens

In [9]:
def prepare_results(p, r, f):
    return '\t{}:\t{}: {:5.2f}\t{}: {:5.2f}\t{}: {:5.2f}'.format(metric, 'P', 100.0 * p, 'R', 100.0 * r, 'F1', 100.0 * f)

aggregator = 'Avg'
apply_avg = aggregator == 'Avg'
apply_best = aggregator == 'Best'
print('Evaluation with {}'.format(aggregator))

evaluator = rouge.Rouge(metrics=['rouge-n', 'rouge-l', 'rouge-w'],
                        max_n=4,
                        limit_length=True,
                        length_limit=100,
                        length_limit_type='words',
                        apply_avg=apply_avg,
                        apply_best=apply_best,
                        alpha=0.5, # Default F1_score
                        weight_factor=1.2,
                        stemming=True)

Evaluation with Avg


In [10]:
nlp = spacy.load('en')

In [11]:
INDEX = 1234
test_inputs = [cnndm_source[INDEX]]
src_tokens = collate_tokens([encode_func(i) for i in test_inputs], pad_idx=1, left_pad=True)
src_tokens = src_tokens.cuda()
print(src_tokens.shape)

src_lengths = torch.sum(src_tokens != 1, dim=1)
print(src_lengths)

torch.Size([1, 413])
tensor([413], device='cuda:0')


In [12]:
# encoding

In [13]:
encoder_out = bart_encoder(src_tokens, src_lengths=src_lengths)
print(encoder_out.encoder_out.shape)  # [seq_len, batch_size, hidden_size]

torch.Size([413, 1, 1024])


In [14]:
# decoding

In [15]:
init_input = torch.tensor([[2, 0]] * src_tokens.shape[0], dtype=torch.long).cuda()
min_decode_step, max_decode_step, pad_id, eos_id = 55, 140, 1, 2
softmax = nn.Softmax(dim=1)
token_probs = []

for step in range(max_decode_step):
    decoder_outputs = bart_decoder(init_input, encoder_out, features_only=False)
    logits = decoder_outputs[0][:, -1, :]  # [batch_size, vocab]
    if step + 1 < min_decode_step:
        logits[:, eos_id] = -math.inf

    logits[:, pad_id], logits[:, 0] = -math.inf, -math.inf  # never select pad, start token
    probs = softmax(logits)
    assert logits.shape == probs.shape
    attn = decoder_outputs[1]['attn'][0]  # [batch_size, prev_token_len, src_token_len]
    assert logits.dim() == 2 and attn.dim() == 3

    # preds = torch.argmax(logits, dim=1)
    value, indices = torch.topk(probs, 5, dim=1)
    selected_token = indices[:, 0]

    if step == 44:
        selected_token = indices[:, 0]
    elif step == 70:
        selected_token = indices[:, 0]
    elif step == 20:
        selected_token = indices[:, 0]

    init_input = torch.cat([init_input, selected_token.unsqueeze(1)], dim=-1)
    print("- {:02d}: {} ({:.2f})".format(step, decode_func(selected_token), probs.squeeze()[selected_token.item()].item()), end='\n')
    token_probs.append(probs.squeeze()[selected_token.item()].item())
    
    if selected_token.item() == eos_id:
        break

token_probs_sorted = [i for i in token_probs]
token_probs_sorted.sort()
print('{}; {}; {}.'.format(token_probs.index(token_probs_sorted[0]),
                           token_probs.index(token_probs_sorted[1]),
                           token_probs.index(token_probs_sorted[2])))

- 00: Three (0.17)
- 01:  adults (0.92)
- 02:  taken (0.92)
- 03:  to (0.90)
- 04:  hospital (0.92)
- 05:  after (0.91)
- 06:  head (0.92)
- 07:  on (0.91)
- 08:  collision (0.93)
- 09:  at (0.91)
- 10:  French (0.92)
- 11: s (0.91)
- 12:  Forest (0.92)
- 13:  . (0.90)
- 14:  The (0.91)
- 15:  crash (0.92)
- 16:  was (0.91)
- 17:  between (0.91)
- 18:  a (0.91)
- 19:  Toyota (0.93)
- 20:  Kl (0.94)
- 21: ug (0.94)
- 22: ger (0.95)
- 23:  and (0.90)
- 24:  E (0.93)
- 25: - (0.90)
- 26: Type (0.92)
- 27:  Jaguar (0.93)
- 28:  . (0.90)
- 29:  Nobody (0.91)
- 30:  was (0.90)
- 31:  injured (0.94)
- 32:  in (0.91)
- 33:  the (0.90)
- 34:  crash (0.91)
- 35:  and (0.90)
- 36:  it (0.91)
- 37:  is (0.91)
- 38:  being (0.91)
- 39:  investigated (0.94)
- 40:  . (0.90)
- 41:  It (0.91)
- 42:  comes (0.92)
- 43:  after (0.91)
- 44:  police (0.92)
- 45:  warned (0.91)
- 46:  motorists (0.93)
- 47:  to (0.90)
- 48:  obey (0.91)
- 49:  road (0.93)
- 50:  rules (0.92)
- 51:  over (0.92)
- 52:  Easter

In [16]:
init_input

tensor([[    2,     0, 15622,  3362,   551,     7,  1098,    71,   471,    15,
          7329,    23,  1515,    29,  5761,   479,    20,  2058,    21,   227,
            10,  7261,  7507,  3252,  2403,     8,   381,    12, 40118, 22264,
           479, 13308,    21,  1710,    11,     5,  2058,     8,    24,    16,
           145,  6807,   479,    85,   606,    71,   249,  2449, 12100,     7,
         28616,   921,  1492,    81,  9274,   479,     2]], device='cuda:0')

In [17]:
hypothesis = decode_func(init_input[0])
print(hypothesis)

Three adults taken to hospital after head on collision at Frenchs Forest . The crash was between a Toyota Klugger and E-Type Jaguar . Nobody was injured in the crash and it is being investigated . It comes after police warned motorists to obey road rules over Easter .


In [18]:
print(cnndm_target[INDEX])

Three adults taken to hospital after head on collision at Frenchs Forest . The crash was between a Toyota Klugger and E-Type Jaguar . Nobody was injured in the crash and it is being investigated . It comes after police warned motorists to obey road rules over Easter .


In [19]:
# hypothesis = 'A French prosecutor says he is not aware of any video footage from on board the plane. German daily Bild and Paris Match claim to have found a cell phone video of the crash. A French Gendarmerie spokesman calls the reports "completely wrong" and "unwarranted" German airline Lufthansa says co-pilot Andreas Lubitz battled depression years before he took the controls.'
all_hypothesis = [hypothesis]
all_predictions = [cnndm_preds[INDEX]]
all_references = [cnndm_target[INDEX]]

# print('- original prediction:')
# scores = evaluator.get_scores(all_predictions, all_references)
# for metric, results in sorted(scores.items(), key=lambda x: x[0]):
#     print(prepare_results(results['p'], results['r'], results['f']))

print('- new prediction:')
scores = evaluator.get_scores(all_hypothesis, all_references)
for metric, results in sorted(scores.items(), key=lambda x: x[0]):
    print(prepare_results(results['p'], results['r'], results['f']))

NameError: name 'cnndm_preds' is not defined