In [1]:
import json

from fairseq.models.bart import BARTModel
from utils import read_lines

In [2]:
# Machine-specific locations of checkpoints and datasets, keyed by name
# (e.g. 'bart.large', 'xsum_fariseq'). The context manager closes the file
# instead of leaking the handle as json.load(open(...)) would.
with open('../path_config.json') as config_file:
    PATH = json.load(config_file)

In [3]:
# Posterior model, scored below with prepare_cmlm_inputs (source + masked
# summary), i.e. presumably conditioned on the source document.
# Other posteriors that were tried (same call, with checkpoint_file='model.pt'
# and data_name_or_path set to the model path itself):
#   PATH['bart.large.xsum']
#   PATH['bart.large.cnn']
posterior_bart = BARTModel.from_pretrained(
    PATH['xsum_cmlm_bos'],
    checkpoint_file='checkpoint_best.pt',
    data_name_or_path=PATH['data_name_or_path'],
)

In [4]:
# Prior model: plain pretrained BART, scored below on entity-masked summaries
# that do not include the source (see the pri_args output).
prior_bart = BARTModel.from_pretrained(
    PATH['bart.large'],
    checkpoint_file='model.pt',
    data_name_or_path=PATH['bart.large'],
)

# Other priors that were tried (same call, with
# checkpoint_file='checkpoint_best.pt' and
# data_name_or_path=PATH['data_name_or_path']):
#   PATH['cnndm_cmlm_cedar']
#   PATH['cnndm_cmlm_scratch_cedar_warmup_10000']

#### Read XSum

In [5]:
# XSum test split: parallel document / reference-summary files, read with the
# project's read_lines helper; the assert checks they pair up one-to-one.
xsum_dir = PATH['xsum_fariseq']
document_path = xsum_dir + '/test.source'
target_path = xsum_dir + '/test.target'

xsum_source, xsum_target = read_lines(document_path), read_lines(target_path)

print(len(xsum_source))
assert len(xsum_source) == len(xsum_target)

11301


#### Test One Example

In [6]:
import spacy
import torch

from model import ConditionalSequenceGenerator
from utils import prepare_cmlm_inputs, prepare_mlm_inputs, prepare_clm_inputs, get_cmlm_probability, get_probability, get_probability_parallel

# Small English spaCy pipeline; used below to extract named entities from summaries.
nlp = spacy.load('en_core_web_sm')

In [7]:
# Wrap each fairseq BART hub model in the project's ConditionalSequenceGenerator;
# these wrappers (not the raw BART objects) are what the scoring helpers below
# receive as `model` / `generator`.
prior_model = ConditionalSequenceGenerator(prior_bart)
posterior_model = ConditionalSequenceGenerator(posterior_bart)

In [8]:
def mask_filling(model, src_input, tgt_input=None):
    """Decode masked source sentence(s) with ``model``.

    Args:
        model: wrapper exposing ``tokenize_with_mask``, ``tokenize``,
            ``encode_sequence`` and ``decode_sequence`` (in this notebook a
            ConditionalSequenceGenerator).
        src_input: list of source strings, presumably each containing a
            ``<mask>`` token.
        tgt_input: optional list of target strings, one per source; when
            given, their token ids are passed to ``decode_sequence`` as
            ``target_ids``.

    Returns:
        Whatever ``model.decode_sequence`` returns (callers in this notebook
        unpack it as ``(init_input, tokens, token_probs)``).

    Raises:
        AssertionError: if ``tgt_input`` is given and its length differs from
            ``src_input``.
    """
    src_ids, src_lengths = model.tokenize_with_mask(src_input)

    if tgt_input is None:
        tgt_ids = None
    else:
        assert len(src_input) == len(tgt_input), "source & target length should match."
        # left_pad=False: presumably right-padded targets; the lengths are unused.
        tgt_ids, _tgt_lengths = model.tokenize(tgt_input, left_pad=False)

    # Scoring only: no autograd graph is needed while encoding/decoding.
    with torch.no_grad():
        encoded = model.encode_sequence(src_ids, src_lengths)
        # NOTE(review): [2, 0] are presumably fairseq's eos and bos indices,
        # i.e. decoding is forced to start with "</s> <s>" — confirm.
        decoded = model.decode_sequence(encoded,
                                        target_ids=tgt_ids,
                                        prefix_tokens=[2, 0])
    return decoded

In [9]:
def get_prior_probability(generator, src_input, tgt_input, position, entity):
    """Score each masked entity with ``generator`` and return its probability.

    Index ``i`` of every list describes one entity, as produced by
    ``prepare_mlm_inputs``:

    - ``src_input[i]``: the summary with that entity replaced by ``<mask>``.
    - ``tgt_input[i]``: the unmasked summary.
    - ``position[i]``: the entity's ``(start, end)`` character span.
    - ``entity[i]``: the entity's surface string.

    Returns:
        list[float]: one value per entity, in input order, as computed by
        ``get_probability`` from the decoded tokens and token probabilities.

    Raises:
        AssertionError: if the input lists do not all have the same length.
    """
    assert len(src_input) == len(tgt_input), "source & target length should match."
    # zip() below would silently drop entities if these lengths disagreed.
    assert len(position) == len(entity) == len(src_input), \
        "position, entity and source length should match."

    # The first element of the decoder output (the initial input) is not needed.
    _init_input, tokens, token_probs = mask_filling(generator, src_input, tgt_input)

    return [get_probability(pos, tok, tok_prob, ent).item()
            for pos, tok, tok_prob, ent in zip(position, tokens, token_probs, entity)]

In [31]:
# Single worked example: an XSum test document paired with a hard-coded
# summary (not read from xsum_target) whose entities are scored below.
INDEX = 9185

source = xsum_source[INDEX]
target = 'A baby pine marten has been captured on camera for the first time in Wales as part of a campaign to reintroduce the animal to Ceredigion.'

In [32]:
# Entity spans in `target` from spaCy NER: each entry carries character
# offsets 'start'/'end' and a spaCy 'label' (see the printed output).
# A hand-specified alternative looks like:
# ent_parts = [{'start': 35, 'end': 39, 'label': 0, 'type': 'ORG', 'ent': 'TTTS'},
#              {'start': 75, 'end': 82, 'label': 2, 'type': 'LOC', 'ent': 'Cardiff'}]
ent_parts = nlp(target).to_json()['ents']

for ent in ent_parts:
    span_text = target[ent['start']:ent['end']]
    print('{} - {}'.format(ent, span_text))

{'start': 55, 'end': 60, 'label': 'ORDINAL'} - first
{'start': 69, 'end': 74, 'label': 'GPE'} - Wales


In [33]:
# Prior inputs: masked summaries only; posterior inputs also take the source.
pri_args = prepare_mlm_inputs(source, target, ent_parts)
pos_args = prepare_cmlm_inputs(source, target, ent_parts)

# Alternative prior scoring via get_prior_probability (defined above):
# prior_probs = get_prior_probability(prior_model, *pri_args[:4])
prior_probs = get_probability_parallel(prior_model, *pri_args[:4], mask_filling=True)
posterior_probs = get_probability_parallel(posterior_model, *pos_args[:4])

# One prior and one posterior score per entity.
assert len(prior_probs) == len(posterior_probs)

In [34]:
# Per-entity probabilities, in the same order as ent_parts.
print(f'- prior: {prior_probs}')
print(f'- posterior: {posterior_probs}')

- prior: [0.95849609375, 0.1553955078125]
- posterior: [0.91455078125, 0.85498046875]


In [35]:
# Inspect the prior inputs: one masked copy of the summary per entity, the
# unmasked summaries, (start, end) offsets, and the entity strings.
pri_args

(['A baby pine marten has been captured on camera for the <mask> time in Wales as part of a campaign to reintroduce the animal to Ceredigion.',
  'A baby pine marten has been captured on camera for the first time in <mask> as part of a campaign to reintroduce the animal to Ceredigion.'],
 ['A baby pine marten has been captured on camera for the first time in Wales as part of a campaign to reintroduce the animal to Ceredigion.',
  'A baby pine marten has been captured on camera for the first time in Wales as part of a campaign to reintroduce the animal to Ceredigion.'],
 [(55, 60), (69, 74)],
 ['first', 'Wales'])

#### Probability Calculation

In [36]:
from tqdm import tqdm

In [37]:
# Annotated examples: each has an 'id' into xsum_source, the summary 'pred',
# and 'ents' dicts that accumulate per-model probabilities. The context
# manager closes the file instead of leaking the handle.
with open('../data/annotated_with_probability_200.json', 'r') as fin:
    data = json.load(fin)
print(len(data))

200


In [40]:
# Inspect one annotated example: 'id' indexes xsum_source, 'pred' is the summary
# being scored, and each entity dict holds one probability per model-name key.
data[55]

{'id': 10943,
 'pred': "A powerful cyclone has killed at least 11 people and injured more than 100 in Vanuatu, the Pacific nation's president has said.",
 'ents': [{'start': 30,
   'end': 41,
   'label': 2,
   'type': 'CARDINAL',
   'ent': 'at least 11',
   'bart.large': 0.0215301513671875,
   'xsum_cmlm_bos': 0.02984619140625,
   'bart.large.xsum': 0.0200347900390625,
   'cnndm_cmlm_cedar': 0.007183074951171875,
   'cnndm_cmlm_scratch_cedar_warmup_10000': 1.9073486328125e-06,
   'xsum_cmlm_scratch_cedar_warmup_20000': 0.092529296875},
  {'start': 61,
   'end': 74,
   'label': 2,
   'type': 'CARDINAL',
   'ent': 'more than 100',
   'bart.large': 0.05804443359375,
   'xsum_cmlm_bos': 0.0843505859375,
   'bart.large.xsum': 0.06317138671875,
   'cnndm_cmlm_cedar': 0.01030731201171875,
   'cnndm_cmlm_scratch_cedar_warmup_10000': 4.7087669372558594e-05,
   'xsum_cmlm_scratch_cedar_warmup_20000': 0.003948211669921875},
  {'start': 78,
   'end': 85,
   'label': 0,
   'type': 'GPE',
   'ent': 

In [14]:
# Keys under which each entity dict stores the prior / posterior probability.
# NOTE(review): these must name the checkpoints actually loaded into
# `prior_model` / `posterior_model`. With the loading cells as written
# (posterior = PATH['xsum_cmlm_bos'], prior = PATH['bart.large']), a
# top-to-bottom run stores those models' scores under the names below,
# overwriting values already in `data`. The execution counter (In [14] before
# In [31]) shows this cell was last run with different models loaded.
# Update the keys or the checkpoints before re-running.
PRIOR_KEY = 'cnndm_cmlm_scratch_cedar_warmup_10000'
POSTERIOR_KEY = 'xsum_cmlm_scratch_cedar_warmup_20000'

# Loop names are distinct from INDEX / source / target / pri_args / prior_probs
# so this cell does not silently change the single-example section's state.
for example in tqdm(data):
    example_ents = example['ents']
    example_source = xsum_source[example['id']]
    example_target = example['pred']

    # Both models are scored on the same CMLM-style inputs, so build them once.
    cmlm_args = prepare_cmlm_inputs(example_source, example_target, example_ents)
    example_prior = get_cmlm_probability(prior_model, *cmlm_args[:4])
    example_posterior = get_cmlm_probability(posterior_model, *cmlm_args[:4])

    assert len(example_prior) == len(example_posterior) == len(example_ents), \
        "{};\n {};\n {}".format(example_prior, example_posterior, example_ents)

    # Mutates the entity dicts inside `data` in place.
    for ent, p_prior, p_posterior in zip(example_ents, example_prior, example_posterior):
        ent[PRIOR_KEY] = p_prior
        ent[POSTERIOR_KEY] = p_posterior

100%|██████████| 200/200 [03:40<00:00,  1.10s/it]


In [15]:
# NOTE: this writes back to the same file loaded above, replacing the stored
# annotations/probabilities in place. Uncomment only deliberately.
# with open('../data/annotated_with_probability_200.json', 'w') as fout:
#     json.dump(data , fout)