In [1]:
import json

from fairseq.models.bart import BARTModel
from utils import read_lines

In [2]:
# Load the path configuration that maps short names to checkpoint / dataset locations.
# A context manager guarantees the file handle is closed as soon as parsing finishes,
# instead of leaving it open until the garbage collector gets to it.
with open('../path_config.json') as config_file:
    PATH = json.load(config_file)

In [3]:
# Posterior model: wrapped further down and used to score masked entities.
#
# Alternative configurations that were used here before
# (PATH key for the model dir | checkpoint_file | data_name_or_path):
#   'bart.large.xsum' | 'model.pt'           | PATH['bart.large.xsum']
#   'bart.large.cnn'  | 'model.pt'           | PATH['bart.large.cnn']
#   'xsum_cmlm_bos'   | 'checkpoint_best.pt' | PATH['data_name_or_path']
posterior_bart = BARTModel.from_pretrained(
    PATH['xsum_cmlm_scratch_cedar_warmup_20000'],
    checkpoint_file='checkpoint_best.pt',
    data_name_or_path=PATH['data_name_or_path'],
)

In [4]:
# Prior model: wrapped further down and used to score masked entities.
#
# Alternative configurations that were used here before
# (PATH key for the model dir | checkpoint_file | data_name_or_path):
#   'cnndm_cmlm_cedar'                      | 'checkpoint_best.pt' | PATH['data_name_or_path']
#   'cnndm_cmlm_scratch_cedar_warmup_10000' | 'checkpoint_best.pt' | PATH['data_name_or_path']
prior_bart = BARTModel.from_pretrained(
    PATH['bart.large'],
    checkpoint_file='model.pt',
    data_name_or_path=PATH['bart.large'],
)

#### Read XSum

In [5]:
# The XSum test split is stored as two parallel line-based files in one directory.
xsum_dir = PATH['xsum_fariseq']
document_path = xsum_dir + '/test.source'
target_path = xsum_dir + '/test.target'

xsum_source, xsum_target = read_lines(document_path), read_lines(target_path)
print(len(xsum_source))
# Sources and targets are indexed together later, so the files must be aligned.
assert len(xsum_target) == len(xsum_source)

11301


#### Test One Example

In [6]:
import spacy

from model import ConditionalSequenceGenerator
from utils import prepare_cmlm_inputs, prepare_mlm_inputs, prepare_clm_inputs, get_cmlm_probability, get_probability

# spaCy pipeline; used below to extract entity spans (`ents`) from a summary string.
nlp = spacy.load('en_core_web_sm')

In [7]:
def get_prior_probability(generator, src_input, tgt_input, position, entity):
    """Score every masked entity with the prior model via mask filling.

    Args:
        generator: object exposing ``mask_filling(src_input, tgt_input)``. The call
            must return a 3-tuple; its second and third items are iterated in
            parallel as per-example tokens and token probabilities.
        src_input: batch of model inputs, one per entity.
        tgt_input: batch of targets; must be as long as ``src_input``.
        position: per-example entity position, forwarded to ``get_probability``.
        entity: per-example entity, forwarded to ``get_probability``.

    Returns:
        list: one Python scalar per example, produced by calling ``.item()`` on
        the value returned by ``get_probability``.

    Raises:
        AssertionError: if ``src_input`` and ``tgt_input`` differ in length.
    """
    assert len(src_input) == len(tgt_input), "source & target length should match."

    # The first element of the mask-filling output is not needed for scoring.
    _, tokens, token_probs = generator.mask_filling(src_input, tgt_input)

    # NOTE: zip() stops at the shortest sequence, so the result has as many
    # entries as the shortest of the four per-example sequences.
    return [
        get_probability(p, tok, tokp, e).item()
        for p, tok, tokp, e in zip(position, tokens, token_probs, entity)
    ]

In [8]:
# Single example used for the sanity check below: one XSum test document
# paired with a fixed summary string.
INDEX = 10943
source = xsum_source[INDEX]
target = (
    "A powerful cyclone has killed at least 11 people and injured "
    "more than 100 in Vanuatu, the Pacific nation's president has said."
)

In [9]:
# Entity spans detected by spaCy in the summary. Each entry is a dict carrying
# character offsets under 'start' / 'end' (plus a 'label'). A hand-written list of
# dicts with 'start' / 'end' keys was previously used here instead.
ent_parts = nlp(target).to_json()['ents']

# Show every span dict next to the text it covers.
for ent in ent_parts:
    span_text = target[ent['start']: ent['end']]
    print(ent, '-', span_text)

{'start': 30, 'end': 41, 'label': 'CARDINAL'} - at least 11
{'start': 61, 'end': 74, 'label': 'CARDINAL'} - more than 100
{'start': 78, 'end': 85, 'label': 'GPE'} - Vanuatu


In [10]:
# Wrap both BART models in the generator interface used for entity scoring.
prior_model, posterior_model = (
    ConditionalSequenceGenerator(prior_bart),
    ConditionalSequenceGenerator(posterior_bart),
)

# The prior is scored on MLM-style inputs, the posterior on CMLM-style inputs.
pri_args = prepare_mlm_inputs(source, target, ent_parts)
pos_args = prepare_cmlm_inputs(source, target, ent_parts)

# Only the first four prepared items are forwarded to the scoring functions.
prior_probs = get_prior_probability(prior_model, *pri_args[:4])
posterior_probs = get_cmlm_probability(posterior_model, *pos_args[:4])

# One probability per entity from each model.
assert len(posterior_probs) == len(prior_probs)

In [11]:
# Report the per-entity probabilities of both models for the example above.
print('- prior:', prior_probs)
print('- posterior:', posterior_probs)

- prior: [0.0213775634765625, 0.05877685546875, 0.0002529621124267578]
- posterior: [0.092041015625, 0.003955841064453125, 0.0030040740966796875]


#### Probability Calculation

In [11]:
from tqdm import tqdm

In [12]:
# Load the annotated records; the context manager closes the file handle
# right after parsing instead of leaking it for the kernel's lifetime.
with open('../data/annotated_with_probability_200.json', 'r') as fin:
    data = json.load(fin)
print(len(data))

200


In [13]:
# Inspect one annotated record: a dict with 'id', 'pred' and a list of entity dicts
# under 'ents' (in the recorded output this is id 10943, the example scored above).
data[55]

{'id': 10943,
 'pred': "A powerful cyclone has killed at least 11 people and injured more than 100 in Vanuatu, the Pacific nation's president has said.",
 'ents': [{'start': 30,
   'end': 41,
   'label': 2,
   'type': 'CARDINAL',
   'ent': 'at least 11',
   'bart.large': 0.0215301513671875,
   'xsum_cmlm_bos': 0.02984619140625,
   'bart.large.xsum': 0.0200347900390625,
   'cnndm_cmlm_cedar': 0.007183074951171875},
  {'start': 61,
   'end': 74,
   'label': 2,
   'type': 'CARDINAL',
   'ent': 'more than 100',
   'bart.large': 0.05804443359375,
   'xsum_cmlm_bos': 0.0843505859375,
   'bart.large.xsum': 0.06317138671875,
   'cnndm_cmlm_cedar': 0.01030731201171875},
  {'start': 78,
   'end': 85,
   'label': 0,
   'type': 'GPE',
   'ent': 'Vanuatu',
   'bart.large': 0.00024771690368652344,
   'xsum_cmlm_bos': 0.857421875,
   'bart.large.xsum': 0.736328125,
   'cnndm_cmlm_cedar': 0.8759765625},
  {'start': 91,
   'end': 98,
   'label': 1,
   'type': 'LOC',
   'ent': 'Pacific',
   'bart.large'

In [14]:
# Keys under which the scores computed below are stored on every entity dict.
# NOTE(review): each key names a checkpoint. Confirm that `prior_model` and
# `posterior_model` wrap the checkpoints these keys refer to before running,
# otherwise the stored scores are mislabeled.
PRIOR_KEY = 'cnndm_cmlm_scratch_cedar_warmup_10000'
POSTERIOR_KEY = 'xsum_cmlm_scratch_cedar_warmup_20000'

# NOTE: the loop rebinds the notebook-level names INDEX / source / target that the
# single-example cells also use; re-run those cells from their definitions if needed.
for INDEX, record in enumerate(tqdm(data)):
    source = xsum_source[record['id']]
    target = record['pred']
    ents = record['ents']

    # Both models are scored through the CMLM path on identical inputs,
    # so the inputs are prepared once and shared.
    pri_args = pos_args = prepare_cmlm_inputs(source, target, ents)

    prior_probs = get_cmlm_probability(prior_model, pri_args[0], pri_args[1], pri_args[2], pri_args[3])
    posterior_probs = get_cmlm_probability(posterior_model, pos_args[0], pos_args[1], pos_args[2], pos_args[3])

    # Every entity must receive exactly one score from each model.
    assert len(prior_probs) == len(posterior_probs) == len(ents), "{};\n {};\n {}".format(prior_probs, posterior_probs, ents)

    # Attach the scores to the entity dicts in place (mutates `data`).
    for ent, prior_p, posterior_p in zip(ents, prior_probs, posterior_probs):
        ent[PRIOR_KEY] = prior_p
        ent[POSTERIOR_KEY] = posterior_p

100%|██████████| 200/200 [03:40<00:00,  1.10s/it]


In [15]:
# with open('../data/annotated_with_probability_200.json', 'w') as fout:
#     json.dump(data , fout)