In [1]:
import os
import math
import torch
import torch.nn as nn

from tqdm import tqdm
from fairseq.models.bart import BARTModel
from utils import read_lines

In [2]:
# The same directory is passed both as the model location and as data_name_or_path.
BART_DIR = '/home/ml/cadencao/Downloads/BART_models/bart.large.xsum'

bart = BARTModel.from_pretrained(
    BART_DIR,
    checkpoint_file='model.pt',
    data_name_or_path=BART_DIR,
)

In [3]:
# Same three steps, in the same order: move to the GPU, switch to eval mode,
# cast to half precision. Each nn.Module call returns the module, so they chain.
bart.cuda().eval().half()
print('- activate evaluation mode')

- activate evaluation mode


In [4]:
def encode_func(x):
    """Turn raw text into a long tensor of source-dictionary ids.

    The text is BPE-encoded, the literal ' </s>' marker is appended by hand,
    and the result is looked up with append_eos=False (presumably so the
    dictionary does not add a second end-of-sentence id -- confirm against
    fairseq's Dictionary.encode_line).
    """
    bpe_text = bart.bpe.encode(x) + ' </s>'
    token_ids = bart.task.source_dictionary.encode_line(bpe_text, append_eos=False)
    return token_ids.long()


# Decoding is delegated to the hub interface unchanged.
decode_func = bart.decode

In [5]:
# Keep direct handles on the two halves of the underlying seq2seq model.
bart_encoder, bart_decoder = bart.model.encoder, bart.model.decoder

# Print the concrete classes: full model first, then encoder, then decoder.
for component in (bart.model, bart_encoder, bart_decoder):
    print(type(component))

<class 'fairseq.models.bart.model.BARTModel'>
<class 'fairseq.models.transformer.TransformerEncoder'>
<class 'fairseq.models.transformer.TransformerDecoder'>


#### Mask Span

In [6]:
# input: (sentence, span)
# output: one flag per token -- 0 for tokens covered by the span, 1 elsewhere,
#         e.g. [1, 1, 1, 0, 0, 0, 1, 1, 1]

In [7]:
# Example input used by the next cells to inspect how text is split into token ids.
SENTENCE = "Mohammad Javad Zarif has spent more time with John Kerry than any other foreign minister . He once participated in a takeover of the Iranian Consulate in San Francisco . The Iranian foreign minister tweets in English ."
print(SENTENCE)

Mohammad Javad Zarif has spent more time with John Kerry than any other foreign minister . He once participated in a takeover of the Iranian Consulate in San Francisco . The Iranian foreign minister tweets in English .


In [8]:
# Display the id tensor for the example sentence (last expression, so the notebook renders it).
encode_func(SENTENCE)

tensor([29880, 41007, 24942,   625, 17122,  1594,    34,  1240,    55,    86,
           19,   610,  9153,    87,   143,    97,  1093,  1269,   479,    91,
          683,  7849,    11,    10, 10260,     9,     5,  5051,  9051, 10246,
           11,   764,  2659,   479,    20,  5051,  1093,  1269,  6245,    11,
         2370,   479,     2])

In [9]:
# Decode every id on its own so the individual word pieces become visible.
sentence_ids = encode_func(SENTENCE)
print([decode_func(torch.tensor([token_id])) for token_id in sentence_ids])

['Moh', 'ammad', ' Jav', 'ad', ' Zar', 'if', ' has', ' spent', ' more', ' time', ' with', ' John', ' Kerry', ' than', ' any', ' other', ' foreign', ' minister', ' .', ' He', ' once', ' participated', ' in', ' a', ' takeover', ' of', ' the', ' Iranian', ' Cons', 'ulate', ' in', ' San', ' Francisco', ' .', ' The', ' Iranian', ' foreign', ' minister', ' tweets', ' in', ' English', ' .', '']


In [10]:
def get_indices(target, tokens):
    """Find the positions of the tokens that spell out `target`.

    Starting from each token in turn, the token (with surrounding whitespace
    stripped) is extended with the following tokens, taken verbatim, until
    the concatenation equals `target` or stops being a prefix of it.

    Args:
        target: 'Justin Martin'
        tokens: ['The', ' Archbishop', ' of', ...]

    Returns:
        A flat list of token indices covering every occurrence found, in
        order of the occurrence's starting token. Overlapping occurrences
        are all reported, so an index can appear more than once. An empty
        `target` matches nothing and yields [].
    """
    all_indices = []
    if not target:
        # Without this guard every token that strips to '' (e.g. a decoded
        # end-of-sentence piece) would count as a match for the empty string.
        return all_indices

    n_tokens = len(tokens)
    for start in range(n_tokens):
        candidate = tokens[start].strip()
        # Text is only ever appended to the candidate, so it can reach
        # `target` only while it is a prefix of it.
        if not target.startswith(candidate):
            continue

        end = start
        while candidate != target and end + 1 < n_tokens:
            end += 1
            candidate += tokens[end]  # later pieces keep their leading space
            if not target.startswith(candidate):
                break

        if candidate == target:
            all_indices.extend(range(start, end + 1))
    return all_indices

def build_mask(target, tokens):
    """Return a per-token 0/1 mask that zeroes the tokens spelling `target`.

    Args:
        target: 'Justin Martin'
        tokens: ['The', ' Archbishop', ' of', ...]

    Returns:
        A torch.long tensor with one entry per token: 0 at every index
        reported by get_indices(target, tokens), 1 everywhere else.
    """
    covered = set(get_indices(target, tokens))
    flags = [0 if position in covered else 1 for position in range(len(tokens))]
    return torch.tensor(flags, dtype=torch.long)

In [11]:
# Rebinds SENTENCE to a second example; the next cells tokenise it and mask "Oldham".
SENTENCE = 'The League One match between Oldham and Blackpool has been postponed because of a waterlogged pitch.'

In [12]:
# Decode every id on its own so the individual word pieces become visible.
sentence_ids = encode_func(SENTENCE)
print([decode_func(torch.tensor([token_id])) for token_id in sentence_ids])

['The', ' League', ' One', ' match', ' between', ' Old', 'ham', ' and', ' Black', 'pool', ' has', ' been', ' postponed', ' because', ' of', ' a', ' water', 'log', 'ged', ' pitch', '.', '']


In [13]:
# Word pieces of the example sentence, then the mask that zeroes the pieces of "Oldham".
sentence_pieces = [decode_func(torch.tensor([token_id])) for token_id in encode_func(SENTENCE)]
build_mask("Oldham", sentence_pieces)

tensor([1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1])

#### Test

In [14]:
from random import randint

In [15]:
# Target side of the XSum test split; read_lines comes from the local utils module
# (presumably it returns one string per line -- confirm against utils.read_lines).
target_path = '/home/ml/cadencao/XSum/fairseq_files/test.target'
xsum_target = read_lines(target_path)

In [16]:
# Round-trip check: sample a random word span from each target, locate it in
# the word pieces with get_indices, and verify the located pieces spell the span.
for tgt in tqdm(xsum_target):
    tokens = tgt.split()
    if not tokens:
        # A blank or whitespace-only line has no words to sample from;
        # randint(0, -1) would raise ValueError.
        continue

    # sample a span (both ends inclusive)
    bos_index = randint(0, len(tokens) - 1)
    eos_index = randint(bos_index, len(tokens) - 1)
    span = ' '.join(tokens[bos_index: eos_index + 1])

    ids = encode_func(tgt)
    word_piece = [decode_func(torch.tensor([i])) for i in ids]
    indices = get_indices(span, word_piece)

    if len(indices) == 0:
        # The span could not be located in the word pieces; skip this target.
        continue

    # get_indices reports every occurrence; keep only the first run of
    # consecutive indices, i.e. the first occurrence.
    first = last = indices[0]
    for i in indices[1:]:
        if i - last != 1:
            break
        last = i

    extracted = ''.join(word_piece[first: last + 1]).strip()
    assert extracted == span, "- tgt: {}; span: {}; extracted: {}".format(tgt, span, extracted)

100%|██████████| 11301/11301 [00:32<00:00, 351.97it/s]


#### Generate

In [36]:
import spacy
from tqdm import tqdm

# 'en' is a shortcut link that only older spaCy releases resolve; spaCy v3
# rejects shortcuts with an OSError. Keep the original call where it works
# and fall back to the explicit package name otherwise.
# NOTE(review): the fallback assumes en_core_web_sm is the model the 'en'
# link pointed to -- confirm, since a different model may tag entities differently.
try:
    nlp = spacy.load('en')
except OSError:
    nlp = spacy.load('en_core_web_sm')

In [42]:
# Training split: the target file is read first, then the source file.
train_target_path, train_source_path = (
    '/home/ml/cadencao/XSum/fairseq_files/train.target',
    '/home/ml/cadencao/XSum/fairseq_files/train.source',
)
xsum_train_target, xsum_train_source = (
    read_lines(path) for path in (train_target_path, train_source_path)
)
print(len(xsum_train_target))

203575


In [43]:
# Build one 0/1 mask per training target. A token is set to 0 when it spells
# a word of a named entity (as tagged by spaCy on the target) and that word
# does not occur in the paired source text.

# zip() stops at the shorter input, so a length mismatch would silently
# pair the wrong lines and produce a mask file that does not line up.
if len(xsum_train_source) != len(xsum_train_target):
    raise ValueError('source/target line counts differ: {} vs {}'.format(
        len(xsum_train_source), len(xsum_train_target)))

target_masks = []

# total= gives tqdm a real progress bar; zip() alone has no length.
for s, t in tqdm(zip(xsum_train_source, xsum_train_target), total=len(xsum_train_target)):
    tokens = [decode_func(torch.tensor([i])) for i in encode_func(t)]
    mask = torch.ones(len(tokens), dtype=torch.long)

    # NER
    t_ents = [e.text for e in nlp(t).ents]
    for e in t_ents:
        for ep in e.split():
            # Plain substring test against the raw source string.
            if ep not in s:
                tmp_mask = build_mask(ep, tokens)
                # Zero out every position that tmp_mask marks with 0.
                mask.masked_fill_((1 - tmp_mask).bool(), 0)

    # add processed
    target_masks.append(mask)

203575it [51:22, 66.04it/s]


In [51]:
# One mask per line; every value is followed by a single space, then a newline.
with open('train.mask', 'w') as mask_file:
    for mask in target_masks:
        row = ''.join('{} '.format(flag) for flag in mask.tolist())
        mask_file.write(row + '\n')

In [52]:
# Read the file back as a list of int lists, one per line; split() with no
# arguments already ignores the trailing space and the newline.
with open('train.mask', 'r') as mask_file:
    masks = [[int(flag) for flag in row.split()] for row in mask_file]

tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1])