In [11]:
import os

import torch
import torch.nn as nn

#### Read XSum Dataset

In [12]:
def read_lines(file_path):
    """Read a UTF-8 text file and return its lines with surrounding whitespace removed.

    Args:
        file_path (str): path of the file to read.

    Returns:
        list[str]: one stripped string per line, in file order.
    """
    with open(file_path, 'r', encoding='utf-8') as handle:
        return [raw.strip() for raw in handle]

In [13]:
# Plain-text XSum validation documents and their reference summaries, one example per line.
document_path = '/home/ml/cadencao/XSum/fairseq_files/val.source'
target_path = '/home/ml/cadencao/XSum/fairseq_files/val.target'
xsum_source, xsum_target = read_lines(document_path), read_lines(target_path)
print(len(xsum_source))
# Line i of the source file pairs with line i of the target file, so the counts must agree.
assert len(xsum_target) == len(xsum_source)

11307


In [14]:
# BPE-encoded counterparts of the validation files (each line is a space-separated list of BPE ids).
document_bpe_path = '/home/ml/cadencao/XSum/fairseq_files/val.bpe.source'
target_bpe_path = '/home/ml/cadencao/XSum/fairseq_files/val.bpe.target'
xsum_bpe_source, xsum_bpe_target = read_lines(document_bpe_path), read_lines(target_bpe_path)
print(len(xsum_bpe_source))
# The BPE files must line up with each other and with the plain-text targets loaded earlier.
assert len(xsum_target) == len(xsum_bpe_target) == len(xsum_bpe_source)

11307


#### Load BART

In [15]:
from fairseq.models.bart import BARTModel

In [16]:
# The same directory is used both as the model location and as `data_name_or_path`.
bart_dir = '/home/ml/cadencao/Downloads/BART_models/bart.large.xsum'
bart = BARTModel.from_pretrained(bart_dir, checkpoint_file='model.pt', data_name_or_path=bart_dir)

In [17]:
# Device / mode / precision switches, in this order: move to GPU, switch to eval
# mode, then cast to half precision (presumably the standard torch.nn.Module methods).
# NOTE(review): no later visible cell switches back to training mode, so the
# loss/backward cells further down run with the model still in eval mode — confirm
# that this is intended.
bart.cuda()
bart.eval()
bart.half()
print('- activate evaluation mode')

- activate evaluation mode


In [18]:
# Short aliases for the model's text -> ids and ids -> text helpers.
encode_func, decode_func = bart.encode, bart.decode

In [19]:
bart.task

<fairseq.tasks.translation.TranslationTask at 0x7effadefa8d0>

In [20]:
bart.task.src_dict

<fairseq.data.dictionary.Dictionary at 0x7effadefe090>

#### Build Input to BART

In [10]:
from fairseq.tasks.translation import TranslationTask

In [11]:
decode_func(torch.LongTensor([ 3750, 12533,  3622,  7551,   871,   263,  4177,   102,    34,    26,
           37,    74,    45,   619,     5,  1164,     9,    10,   239,  2937,
         4029,   114,    37,    58,     7,  1962,  2361,   315,     4,     2]))

'Atletico Madrid goalkeeper David de Gea has said he would not feel the pressure of a high transfer fee if he were to join Manchester United.'

In [12]:
xsum_target[33]

'Striker Robert Vittek is the headline absentee from the provisional 27-man squad Slovakia coach Jan Kozak has named for Euro 2016.'

In [13]:
encode_func('Atletico Madrid goalkeeper David de Gea has said he would not feel the pressure of a high transfer fee if he were to join Manchester United.')

tensor([    0,  3750, 12533,  3622,  7551,   871,   263,  4177,   102,    34,
           26,    37,    74,    45,   619,     5,  1164,     9,    10,   239,
         2937,  4029,   114,    37,    58,     7,  1962,  2361,   315,     4,
            2])

In [14]:
path = '/home/ml/cadencao/XSum/test_files/xsum-bin'

In [15]:
# Load the source- and target-side vocabularies stored under `path`.
src_dict, tgt_dict = (
    TranslationTask.load_dictionary(os.path.join(path, 'dict.{}.txt'.format(side)))
    for side in ('source', 'target')
)
# Both sides must agree on the pad / eos / unk indices.
assert (src_dict.pad(), src_dict.eos(), src_dict.unk()) == (tgt_dict.pad(), tgt_dict.eos(), tgt_dict.unk())
print('[source] dictionary: {} types'.format(len(src_dict)))
print('[target] dictionary: {} types'.format(len(tgt_dict)))

[source] dictionary: 50264 types
[target] dictionary: 50264 types


In [16]:
src_dict

<fairseq.data.dictionary.Dictionary at 0x7f8b181aad10>

In [17]:
def bpe_to_ids(fairseq_dict, sentence, addl_sentence=None, max_positions=1024, no_bos=True):
    """Map a space-separated BPE string to dictionary indices, truncating it to fit.

    Args:
        fairseq_dict (fairseq.data.dictionary.Dictionary): dictionary whose
            ``encode_line`` converts the symbol string to indices.
        sentence (str): bpe encoded sentence (symbols separated by single spaces).
        addl_sentence (str): optional second bpe encoded sentence; when given it is
            appended to ``sentence`` after an ``</s>`` separator.
        max_positions (int): maximum number of space-separated symbols in the string
            handed to ``encode_line``, special symbols included.
        no_bos (bool): if True, NO ``<s>`` symbol is prepended; if False, ``<s>`` is
            prepended.

    Returns:
        The result of ``encode_line`` converted with ``.long()``. The encoded string
        always ends with ``</s>``.
    """
    prefix = '' if no_bos else '<s> '
    suffix = ' </s>'
    # Symbols left for real content once the trailing </s> (and <s>, if used) are reserved.
    budget = max_positions - (1 if no_bos else 2)

    text = (sentence + suffix + ' ' + addl_sentence) if addl_sentence else sentence

    pieces = text.split(' ')
    if len(pieces) > budget:
        text = ' '.join(pieces[:budget])

    encoded = fairseq_dict.encode_line(prefix + text + suffix, append_eos=False)
    return encoded.long()

In [18]:
xsum_bpe_target[0]

'12510 1966 3701 4881 4409 423 587 1043 6717 286 34759 262 20858 422 262 12983 286 734 12353 19105 257 3249 546 1693 6630 13'

In [19]:
ids = bpe_to_ids(src_dict, xsum_bpe_target[0])
print(ids)

tensor([15622,   320,  1754,  1470,  1321,    33,    57,   303,  2181,     9,
        28304,     5, 15331,    31,     5,  7314,     9,    80,  4585,  9886,
           10,   529,    59,   633,  2599,     4,     2])


In [20]:
decode_func(ids)

'Three former Air France employees have been found guilty of ripping the shirts from the backs of two executives fleeing a meeting about job cuts.'

In [21]:
from fairseq.data.language_pair_dataset import collate

In [22]:
# First three validation pairs as id/source/target dicts; both sides are encoded
# with `src_dict`.
samples = [
    {
        'id': i,
        'source': bpe_to_ids(src_dict, xsum_bpe_source[i]),
        'target': bpe_to_ids(src_dict, xsum_bpe_target[i]),
    }
    for i in range(3)
]

In [23]:
# Read the special-symbol indices from the dictionary that encoded `samples`
# rather than hard-coding 1 and 2, so the batch cannot drift from the vocabulary.
pad_idx, eos_idx = src_dict.pad(), src_dict.eos()
# Sources are left-padded, targets right-padded; input_feeding=True makes collate
# also build `prev_output_tokens` for the decoder.
batch = collate(samples, pad_idx, eos_idx, left_pad_source=True, left_pad_target=False, input_feeding=True)

In [24]:
batch

{'id': tensor([2, 0, 1]),
 'nsentences': 3,
 'ntokens': 65,
 'net_input': {'src_tokens': tensor([[ 100,   33,   57,  ...,  502,    4,    2],
          [   1,    1,    1,  ..., 1427,    4,    2],
          [   1,    1,    1,  ...,  443,    4,    2]]),
  'src_lengths': tensor([668, 376, 165]),
  'prev_output_tokens': tensor([[    2, 29042,  1252,    11,  5295,    28,   357,   160,    11,    50,
              66,     9,     5,   796,  1332,   116,     1,     1,     1,     1,
               1,     1,     1,     1,     1,     1,     1],
          [    2, 15622,   320,  1754,  1470,  1321,    33,    57,   303,  2181,
               9, 28304,     5, 15331,    31,     5,  7314,     9,    80,  4585,
            9886,    10,   529,    59,   633,  2599,     4],
          [    2,  9497,    33,   703, 14363,  3156,    25,   233,     9,  4941,
              88,    41,  2080,   751,    41,  1586,  7450,  2681,  2003,    94,
              76,     4,     1,     1,     1,     1,     1]])},
 'target': te

In [25]:
import random

In [26]:
class DataLoader(object):
    """In-memory loader that turns two parallel BPE files into collated batches.

    Every example is encoded once at construction time and stored in ``self.data``
    as a dict with the keys ``'id'``, ``'source'`` and ``'target'``. Iterating the
    loader itself yields those raw examples; ``batch_iter`` yields collated batches.
    """

    def __init__(self, fairseq_dict, source_path, target_path, max_positions=1024, no_bos=True, pad_idx=1, eos_idx=2):
        """
        Args:
            fairseq_dict (fairseq.data.dictionary.Dictionary): fairseq dictionary.
            source_path (str): path to bpe encoded source.
            target_path (str): path to bpe encoded target.
            max_positions (int): max sentence length, special symbols included.
            no_bos (bool): if True, do NOT prepend the ``<s>`` symbol.
            pad_idx (int): padding index passed to ``collate``.
            eos_idx (int): end-of-sentence index passed to ``collate``.

        Raises:
            AssertionError: if the two files have a different number of lines.
        """
        self.fairseq_dict = fairseq_dict
        self.source_path = source_path
        self.target_path = target_path

        self.max_positions = max_positions
        self.no_bos = no_bos
        self.pad_idx = pad_idx
        self.eos_idx = eos_idx

        source = DataLoader.read_lines(source_path)
        target = DataLoader.read_lines(target_path)
        assert len(source) == len(target), "Source and target size do NOT match!"

        self.data = self.build_sample(source, target)

    def build_sample(self, source, target):
        """Encode aligned source/target lines into a list of example dicts.

        Args:
            source (list[str]): bpe encoded source lines.
            target (list[str]): bpe encoded target lines, aligned with ``source``.

        Returns:
            list[dict]: one ``{'id', 'source', 'target'}`` dict per line pair, where
            ``'id'`` is the line index.
        """
        return [
            {
                'id': i,
                'source': self.bpe_to_ids(s),
                'target': self.bpe_to_ids(t)
            }
            for i, (s, t) in enumerate(zip(source, target))
        ]

    @staticmethod
    def read_lines(file_path):
        """Return the whitespace-stripped lines of a UTF-8 text file."""
        with open(file_path, 'r', encoding='utf-8') as f:
            return [line.strip() for line in f]

    def _collate(self, samples):
        """Collate examples with the fixed padding / input-feeding options of this loader."""
        return collate(samples, self.pad_idx, self.eos_idx,
                       left_pad_source=True,
                       left_pad_target=False,
                       input_feeding=True)

    def batch_iter(self, batch_size):
        """Yield collated batches of consecutive examples from ``self.data``.

        Args:
            batch_size (int): number of examples per batch; must be at least 1.

        Yields:
            The value returned by ``collate`` for each group of examples. The final
            batch holds the remaining examples and may be smaller than ``batch_size``.

        Raises:
            ValueError: if ``batch_size`` is smaller than 1. Because this is a
                generator, the error surfaces when iteration starts.
        """
        # A size of 0 used to collate an empty batch and then the whole dataset at
        # once; negative sizes silently produced one giant batch. Reject both.
        if batch_size < 1:
            raise ValueError('batch_size must be at least 1, got {!r}'.format(batch_size))

        samples = []
        for d in self.data:
            samples.append(d)
            if len(samples) == batch_size:
                yield self._collate(samples)
                samples = []

        # Leftover examples that did not fill a complete batch.
        if samples:
            yield self._collate(samples)

    def bpe_to_ids(self, sentence, addl_sentence=None):
        """Convert a space-separated BPE string to dictionary indices.

        The string is truncated so that, together with the special symbols, it has
        at most ``self.max_positions`` space-separated symbols.

        Args:
            sentence (str): bpe encoded sentence.
            addl_sentence (str): optional second bpe encoded sentence, appended
                after an ``</s>`` separator.

        Returns:
            The result of ``fairseq_dict.encode_line`` converted with ``.long()``.
        """
        extra_tokens = 2
        bos, eos = '<s> ', ' </s>'

        if self.no_bos:
            bos = ''
            extra_tokens = 1

        if addl_sentence:
            tokens = sentence + eos + ' ' + addl_sentence
        else:
            tokens = sentence

        # Split once and reuse the pieces for both the length check and the cut.
        pieces = tokens.split(' ')
        limit = self.max_positions - extra_tokens
        if len(pieces) > limit:
            tokens = ' '.join(pieces[:limit])
        bpe_sentence = bos + tokens + eos

        tokens = self.fairseq_dict.encode_line(bpe_sentence, append_eos=False)
        return tokens.long()

    def shuffle(self):
        """Shuffle the stored examples in place using the global ``random`` state."""
        random.shuffle(self.data)

    def __len__(self):
        """Number of stored examples (not batches)."""
        return len(self.data)

    def __iter__(self):
        """Iterate over the raw, un-collated example dicts."""
        return iter(self.data)

In [27]:
val = DataLoader(src_dict, document_bpe_path, target_bpe_path)

In [28]:
val.shuffle()

In [29]:
# Smoke test: walk the whole validation set in batches of 24 and discard every
# batch — this only checks that collation runs without raising.
for b in val.batch_iter(24):
    pass

#### Loss

In [30]:
test_batch = next(val.batch_iter(4))

In [31]:
test_batch

{'id': tensor([11131,  2688,   793,  2162]),
 'nsentences': 4,
 'ntokens': 114,
 'net_input': {'src_tokens': tensor([[ 1708,   172,  2206,  ...,    98,  3571,     2],
          [18801, 20083,    16,  ...,   274,  1630,     2],
          [    1,     1,     1,  ...,   599,     4,     2],
          [    1,     1,     1,  ...,  5751,     4,     2]]),
  'src_lengths': tensor([1024, 1024,  468,  191]),
  'prev_output_tokens': tensor([[    2,   771,  4575,    16,    45,     5,    78,   631,    14,   606,
               7,  1508,    77,  1686,    59,     5,    70,    12,  4310,     6,
            3228,    12,  4416,  9556,  2422,  2690,     9,     5,  2762,     4],
          [    2, 31230,   241,  1554,   257,  9688,   718, 18879,  2886,     7,
               5,  1156,   526,    25,  6375,   471,   704, 10125, 15038,   817,
             237,  1022,    13,     5,  5310,  3076,   177,   136,  5295,     4],
          [    2, 29111,   415,  1728, 10394,   922,  4324,  1008,    70,     9,
         

In [32]:
from fairseq.utils import move_to_cuda

In [33]:
test_batch = move_to_cuda(test_batch)

In [34]:
net_output = bart.model(**test_batch['net_input'])

In [35]:
print(net_output[0].shape) # [batch_size, target_len, vocab_size]

torch.Size([4, 30, 50265])


In [36]:
net_output[1].keys()

dict_keys(['attn', 'inner_states'])

In [37]:
# Log-probabilities for the decoder output of the test batch.
lprobs = bart.model.get_normalized_probs(net_output, log_probs=True)
# Flatten every leading dimension so each row is one target position:
# for this batch [4, 30, 50265] becomes [120, 50265].
lprobs = lprobs.view(-1, lprobs.size(-1))

In [38]:
lprobs.shape

torch.Size([120, 50265])

In [39]:
target = test_batch["target"].view(-1, 1)

In [40]:
target.shape

torch.Size([120, 1])

In [41]:
def label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=None, reduce=True):
    """Negative log-likelihood loss with label smoothing.

    Args:
        lprobs (Tensor): log-probabilities; the last dimension is the vocabulary.
        target (Tensor): gold indices. May lack the trailing singleton dimension,
            in which case it is added before gathering.
        epsilon (float): smoothing weight; ``1 - epsilon`` goes to the NLL term and
            ``epsilon / vocab_size`` to the sum over all vocabulary entries.
        ignore_index (int): target value whose positions contribute zero to both
            terms. ``None`` disables masking.
        reduce (bool): if True, both terms are summed to scalars.

    Returns:
        tuple: ``(loss, nll_loss)``. With ``reduce=False`` the trailing singleton
        dimension is squeezed only when ``ignore_index`` is ``None``.
    """
    index = target.unsqueeze(-1) if target.dim() == lprobs.dim() - 1 else target
    nll_loss = -lprobs.gather(dim=-1, index=index)
    smooth_loss = -lprobs.sum(dim=-1, keepdim=True)

    if ignore_index is None:
        nll_loss, smooth_loss = nll_loss.squeeze(-1), smooth_loss.squeeze(-1)
    else:
        # Zero out both terms at ignored (e.g. padding) positions.
        ignored = index.eq(ignore_index)
        for term in (nll_loss, smooth_loss):
            term.masked_fill_(ignored, 0.)

    if reduce:
        nll_loss, smooth_loss = nll_loss.sum(), smooth_loss.sum()

    uniform_share = epsilon / lprobs.size(-1)
    return (1. - epsilon) * nll_loss + uniform_share * smooth_loss, nll_loss

In [42]:
# Label-smoothed loss over the flattened test batch; positions whose target is the
# padding index contribute nothing, and both terms are summed to scalars.
padding_idx, eps = 1, 0.1
loss, nll_loss = label_smoothed_nll_loss(lprobs, target, eps, ignore_index=padding_idx, reduce=True)

In [43]:
loss

tensor(929.9940, device='cuda:0', grad_fn=<AddBackward0>)

#### Arguments

In [44]:
class Args:
    """Plain attribute container standing in for parsed command-line arguments.

    An instance is handed to ``FairseqAdam`` and ``PolynomialDecaySchedule`` in the
    cells below; the attributes are grouped by which of the two they configure.
    """

    def __init__(self):
        # --- learning-rate scheduler ---
        self.warmup_updates = 500
        self.end_learning_rate = 0.0
        self.total_num_update = 20000
        self.power = 1.0
        self.force_anneal = None

        # --- optimizer ---
        self.lr = [3e-5]  # a list, not a bare float
        self.adam_betas = '(0.9, 0.999)'  # kept as a string
        self.adam_eps = 1e-8
        self.weight_decay = 0.01

In [45]:
args = Args()

#### Optimizer

In [46]:
from fairseq.optim.adam import Adam, FairseqAdam

In [47]:
# Only parameters that require gradients are given to the optimizer.
params = [p for p in bart.model.parameters() if p.requires_grad]

In [48]:
adam = FairseqAdam(args, params)

#### Scheduler

In [49]:
from fairseq.optim.lr_scheduler.polynomial_decay_schedule import PolynomialDecaySchedule

In [50]:
scheduler = PolynomialDecaySchedule(args, adam)

In [51]:
scheduler.step_update(0)

0.0

#### Calculate Gradient

In [52]:
adam.backward(loss)

#### Clip

In [53]:
grad_norm = adam.clip_grad_norm(0.1, aggregate_norm_fn=None)

#### Try BERT encoder

In [54]:
src_dict.encode_line('464 1294 3845 468 8606 3352 284 31833 <mask> 2485 6973 11 1390 262 17504 286 3777 4200 284 661 319 8649 2342 8341 13', append_eos=False)

tensor([  133,   382,  1112,    34,  3946,   708,     7, 16888, 50264,  1751,
         5656,     6,   217,     5, 20627,     9,  2398,   647,     7,    82,
           15,  4952,  1183,  8204,     4], dtype=torch.int32)

In [55]:
decode_func(torch.tensor([  133,   382,  1112,    34,  3946,   708,     7, 16888,     3,  1751,
          5656,     6,   217,     5, 20627,     9,  2398,   647,     7,    82,
            15,  4952,  1183,  8204,     4,     2]))

'The US Senate has rejected plans to tighten<unk> gun controls, including the restriction of weapons sales to people on terrorism watch lists.'

In [56]:
decode_func(torch.tensor([ 133,   382,  1112,    34,  3946,   708,     7, 16888, 50264,  1751,
         5656,     6,   217,     5, 20627,     9,  2398,   647,     7,    82,
           15,  4952,  1183,  8204,     4]))

'The US Senate has rejected plans to tighten<mask> gun controls, including the restriction of weapons sales to people on terrorism watch lists.'

In [57]:
type(bart.model)

fairseq.models.bart.model.BARTModel

In [58]:
def test_func(target, src_tokens, src_lengths, prev_output_tokens):
    print(src_tokens.shape)
    print(src_lengths.shape)
    print(prev_output_tokens.shape)
    print(target.shape)

In [59]:
test_func(test_batch['target'], **test_batch['net_input'])

torch.Size([4, 1024])
torch.Size([4])
torch.Size([4, 30])
torch.Size([4, 30])


In [61]:
from fairseq.tasks.translation import TranslationTask