In [1]:
import torch
import torch.nn as nn

from fairseq.models.bart import BARTModel

In [2]:
# Directory of the pretrained bart.large model; passed both as the model location
# and as `data_name_or_path`.
bart_dir = '/home/ml/cadencao/Downloads/BART_models/bart.large'
bart = BARTModel.from_pretrained(bart_dir, checkpoint_file='model.pt', data_name_or_path=bart_dir)

In [3]:
# Move the model to the GPU and switch it to evaluation mode.
# (torch.nn.Module.cuda() and .eval() each return the module, so they chain.)
bart.cuda().eval()
print('- Activate evaluation mode')

- Activate evaluation mode


In [4]:
# Short aliases for the model's text <-> token-id conversion methods.
encode_func, decode_func = bart.encode, bart.decode

In [5]:
# Direct handles on the encoder and decoder sub-modules of the loaded model.
bart_encoder, bart_decoder = bart.model.encoder, bart.model.decoder

#### Read XSum

In [6]:
from utils import read_lines

In [None]:
# Paths of the XSum training split: one document / one summary per line.
document_path = '/home/ml/cadencao/XSum/fairseq_files/train.source'
target_path = '/home/ml/cadencao/XSum/fairseq_files/train.target'

# Read source first, then target (same order as the paths tuple).
xsum_source, xsum_target = (read_lines(path) for path in (document_path, target_path))
print(len(xsum_source))

# Documents and summaries must stay aligned line by line.
assert len(xsum_target) == len(xsum_source)

#### Read Loss Data

In [None]:
from os import listdir
from os.path import isfile, join

In [None]:
# Index of the training example whose per-token losses are analysed.
ID = 4
# Directory holding the dumped loss files.
folder = 'cedar_losses/'

# Keep regular files whose name starts with "<ID>." (so ID 4 does not pick up "41.*").
# `startswith` also copes with a file named exactly "<ID>", which the previous
# character-index check (`f[len(str(ID))]`) would have crashed on with IndexError.
prefix = str(ID) + '.'
# NOTE(review): ordering is lexicographic, so an unpadded "4.10..." sorts before
# "4.2..." — confirm the file names are zero-padded if order encodes the epoch.
files = sorted(f for f in listdir(folder) if isfile(join(folder, f)) and f.startswith(prefix))

In [None]:
# Display the selected loss files (already sorted by name) to check the selection.
files

#### Extract Losses

In [None]:
import numpy as np

In [None]:
# Collect, for every loss file, the per-token loss of the first sample in the batch.
# All files must describe the same target sequence; the first file defines it.
target = None
losses = []

for f in files:
    # NOTE(review): torch.load unpickles the file — only use on trusted dumps.
    s = torch.load(join(folder, f), map_location='cpu')
    first_target = s['sample']['target'][0]
    if target is not None:
        assert torch.all(target.eq(first_target)).item()
    else:
        target = first_target
    # Reshape the loss to the shape of the target batch and keep the first row.
    token_loss = s['token_loss'].view(s['sample']['target'].shape)[0]
    assert token_loss.shape == target.shape
    losses.append(token_loss.numpy())

# Stack to (n_files, n_tokens), then transpose so each row is one token position.
losses = np.transpose(np.array(losses))

In [None]:
# Sanity check: loss matrix shape, then the target token ids, one per line of output.
print(losses.shape, target, sep='\n')

#### Analysis

In [None]:
from utils import tokenize, decode_sequence, get_probability

In [None]:
# Decode the target token ids loaded from the loss files (presumably back to text).
print(decode_func(target))

In [None]:
# List every target position (all but the last id) next to its decoded token.
for i, t in enumerate(target[:-1]):
    print('- {} {}'.format(i, decode_func(t.unsqueeze(0))))

In [None]:
%matplotlib inline

import matplotlib.pyplot as plt

In [None]:
# Token positions to draw; None draws every position.
# Subsets tried earlier: [0, 1, 2, 3, 18, 20, 21, 24, 25, 26]
#                        [0, 1, 2, 20, 21, 22, 23, 24, 25, 26, 27]
display_index = None

# One x value per loaded loss file (second axis of `losses`).
x = np.arange(losses.shape[1])

fig, axs = plt.subplots(figsize=(20.0, 10.0))

count = 0  # number of curves actually drawn
for i, t in enumerate(target):
    if t.item() == 1:  # presumably the padding id — stop at the first one (TODO confirm)
        break
    word = decode_func(t.unsqueeze(dim=0))
    if display_index is not None and i not in display_index:
        continue
    curve = losses[i]
    axs.plot(x, curve, label=word)

    # Label each curve with its token just above the curve's highest loss value.
    text_index = np.argmax(curve)
    axs.text(text_index, curve[text_index] + 0.05, word, horizontalalignment='center')
    count += 1

axs.set(xlabel='Epoch', ylabel='Training loss')
axs.grid(True)

# Legend is anchored outside the axes, on the right-hand side.
axs.legend(loc='center left', bbox_to_anchor=(1, 0.5))

fig.tight_layout()
fig.savefig('foo.png')
plt.show()

In [None]:
# Show the line at index ID of the XSum train.target file.
xsum_target[ID]

In [None]:
import re

# BPE id sequence for the literal text "<mask>" (1279 27932 29), matched only when it
# lines up with whole ids. A plain substring replace would also fire inside longer
# ids, e.g. turning "51279 27932 29" into "5<mask>".
_MASK_BPE_PATTERN = re.compile(r'(?<!\S)1279 27932 29(?!\S)')


def tokenize_with_mask(input_sentence):
    """Convert a sentence containing literal "<mask>" markers into model input ids.

    The sentence is BPE-encoded, every id run that spells "<mask>" is collapsed back
    into the single "<mask>" symbol, "<s>" is prepended, and the result is mapped to
    dictionary indices with an end-of-sentence symbol appended.

    Args:
        input_sentence: Raw text, optionally containing "<mask>".

    Returns:
        A tuple ``(input_ids, src_lengths)``: ``input_ids`` is a ``(1, seq_len)`` long
        tensor moved to the GPU; ``src_lengths`` is a 1-element tensor counting the
        ids that differ from 1 (presumably the padding index — TODO confirm).
    """
    bpe_code = bart.bpe.encode(input_sentence)  # <mask>: 1279 27932 29
    bpe_code = _MASK_BPE_PATTERN.sub('<mask>', bpe_code)
    input_ids = bart.task.source_dictionary.encode_line('<s> ' + bpe_code,
                                                        append_eos=True).long()
    input_ids = input_ids.unsqueeze(0).cuda()
    src_lengths = torch.sum(input_ids != 1, dim=1)
    return input_ids, src_lengths

In [None]:
# Decode every target id except the last one on its own, as a 1-element tensor.
# unsqueeze(dim=1) turns the 1-D id vector into rows of shape (1,).
word_pieces = [decode_func(single_id) for single_id in target[:-1].unsqueeze(dim=1)]
print(word_pieces)

In [None]:
import math

# For every word piece: replace it with "<mask>" in the reference summary, encode the
# masked summary, run `decode_sequence` with the original target ids (presumably
# teacher forcing — confirm in utils), and print -log(probability) of that word piece.
# NOTE(review): str.replace(..., 1) and list.index() both pick the FIRST occurrence,
# so a word piece that appears more than once is always scored at its first position.
with torch.no_grad():
    for wp in word_pieces:
        masked_target = xsum_target[ID].replace(wp, '<mask>', 1)
        masked_input, masked_lengths = tokenize_with_mask(masked_target)
        masked_encoder_out = bart_encoder(masked_input, src_lengths=masked_lengths)

        masked_outputs = decode_sequence(decode_func, bart_decoder, masked_encoder_out,
                                         tgt_tokens=target.cuda(), verbose=False)
        masked_output_ids, masked_tokens, masked_token_probs, token_logits = masked_outputs
        # The forced decode must reproduce the reference summary exactly.
        assert decode_func(masked_output_ids[0]) == xsum_target[ID]

        wp_position = masked_tokens.index(wp)
        loss = -math.log(masked_token_probs[wp_position])
        print('- {}: {}'.format(wp, loss))

In [None]:
# Decode the three ids that the cells below use as the mask placeholder, to check
# what text they stand for (presumably the literal "<mask>" — confirm from output).
decode_func(torch.tensor([41552, 43776, 15698]))

In [None]:
# Show the raw target token ids.
target

In [None]:
# tokenize
src_tokens, src_lengths = tokenize(xsum_source[ID], encode_func)

In [None]:
# Target ids with id 0 (used as the start token elsewhere in this notebook) prepended.
target_ = torch.cat((torch.tensor([0]), target))

In [None]:
# Preview the encoder input built below: the first id of target_ followed by the
# three mask-placeholder ids.
torch.cat([target_[:1], torch.tensor([41552, 43776, 15698])], dim=0)

In [None]:
# Encode a batch of one sequence: the first id of target_ followed by the three
# mask-placeholder ids (4 ids in total, hence src_lengths == [4]).
encoder_out = bart_encoder(
    torch.cat((target_[:1], torch.tensor([41552, 43776, 15698]))).unsqueeze(0).cuda(),
    src_lengths=torch.tensor([4]).cuda(),
)

In [None]:
# Same 4-id sequence as fed to the encoder above, shown after moving it to the GPU.
torch.cat([target_[:1], torch.tensor([41552, 43776, 15698])], dim=0).cuda()

In [None]:
# Run the decoder on the 2-id prefix [2, 0] against the encoder output.
# NOTE(review): unlike the loops in this notebook, this call is not wrapped in
# torch.no_grad().
decoder_outputs = bart_decoder(torch.tensor([[2, 0]], dtype=torch.long).cuda(), encoder_out, features_only=False)

In [None]:
# Keep the scores of the last decoder position only
# (assumes decoder_outputs[0] is (batch, seq, vocab) — TODO confirm).
logits = decoder_outputs[0][:, -1, :]

In [None]:
# Check the shape of the last-position logits.
logits.shape

In [None]:
# Show the raw target token ids again for reference.
target

In [None]:
# Softmax normalising over dim 1 (the second axis of the 2-D logits).
softmax = nn.Softmax(dim=1)

In [None]:
# Special-token tensors are identical for every position, so build them once.
start_token, end_token = torch.tensor([0]), torch.tensor([2])
mask_token = torch.tensor([41552, 43776, 15698])  # three-id mask placeholder
decoder_prefix = torch.tensor([2, 0])

# For each target position i (all but the last id): encode
# [start, target[:i], mask placeholder, end], feed the decoder [2, 0, target[:i]],
# and print the negative log-probability of the true id at position i.
# NOTE(review): the encoder input drops everything after position i — confirm that
# this prefix-only source is intended.
with torch.no_grad():
    for i, t in enumerate(target[:-1]):
        src_tokens = torch.cat([start_token, target[: i], mask_token, end_token], dim=0).unsqueeze(dim=0).cuda()
        src_lengths = torch.tensor([src_tokens.shape[1]]).cuda()
        prev_output_tokens = torch.cat([decoder_prefix, target[: i]], dim=0).unsqueeze(dim=0).cuda()

        encoder_out = bart_encoder(src_tokens, src_lengths=src_lengths)
        decoder_outputs = bart_decoder(prev_output_tokens, encoder_out, features_only=False)
        logits = decoder_outputs[0][:, -1, :]  # [batch_size, vocab]

        # log_softmax stays finite where softmax followed by math.log would hit
        # log(0) and raise ValueError for a probability that underflows to zero.
        log_probs = torch.log_softmax(logits, dim=1)
        token = decode_func(t.unsqueeze(dim=0))
        loss = -log_probs.squeeze()[t.item()].item()
        print('- {}: {}'.format(token, loss))

In [None]:
# Length of the sequence [0, <three mask-placeholder ids>, 2].
torch.tensor([    0, 41552, 43776, 15698,     2]).shape[0]