## Exploring: entity extraction and masked-input preparation

In [3]:
import spacy

# spaCy English pipeline; its named-entity recognizer provides the entity spans used below.
nlp = spacy.load('en_core_web_sm')

# Example pair: `source` is the input article and `prediction` is the summary whose entities
# are inspected (presumably a model-generated XSum summary -- TODO confirm).
source = 'The city was brought to a standstill on 15 December last year when a gunman held 18 hostages for 17 hours. Family members of victims Tori Johnson and Katrina Dawson were in attendance. Images of the floral tributes that filled the city centre in the wake of the siege were projected on to the cafe and surrounding buildings in an emotional twilight ceremony. Prime Minister Malcolm Turnbull gave an address saying a "whole nation resolved to answer hatred with love". "Testament to the spirit of Australians is that with such unnecessary, thoughtless tragedy, an amazing birth of mateship, unity and love occurs. Proud to be Australian," he said. How the Sydney siege unfolded. New South Wales Premier Mike Baird has also announced plans for a permanent memorial to be built into the pavement in Martin Place. Clear cubes containing flowers will be embedded into the concrete and will shine with specialised lighting. It is a project inspired by the massive floral tributes that were left in the days after the siege. "Something remarkable happened here. As a city we were drawn to Martin Place. We came in shock and in sorrow but every step we took was with purpose," he said on Tuesday.'
prediction = 'Sydney has marked the first anniversary of the siege at the Waverley cafe in which two women were killed by a gunman in the Australian city.'

# Named entities of the summary: dicts with character offsets 'start'/'end' into `prediction`
# (spaCy's Doc.to_json() format); slice them back out to get each entity's surface text.
entities = nlp(prediction).to_json()['ents']
ent_text = [prediction[e['start']: e['end']] for e in entities]

print('entities:')
print(type(entities))

print('ent text:')
print(type(ent_text))
print(ent_text)

entitites:
<class 'list'>
ent text:
<class 'list'>
['Sydney', 'first', 'Waverley', 'two', 'Australian']


In [17]:
def prepare_clm_inputs(source, target, ent_parts=None, mask_token='<mask>'):
    """Build prefix-completion inputs for scoring each entity of ``target``. For BART only.

    Unlike the masked-LM inputs (where the mask replaces the entity inside the full
    sentence), each input here contains only the text of ``target`` to the LEFT of an
    entity followed by a single mask token. The matching target is that same prefix
    extended through the end of the entity.

    Args:
        source: Source document. It is not used to build the inputs; it is kept so the
            signature matches the other ``prepare_*_inputs`` helpers.
        target: Summary text whose entities are scored.
        ent_parts: Entity spans as dicts with character offsets ``'start'`` and ``'end'``
            into ``target`` (the shape of spaCy's ``Doc.to_json()['ents']``). If ``None``,
            they are computed with the module-level ``nlp`` pipeline.
        mask_token: Token appended after each prefix. Defaults to BART's ``'<mask>'``.

    Returns:
        A 4-tuple of parallel lists with one entry per entity,
        ``(inputs, targets, positions, entities)``, where ``inputs[i]`` is
        ``target[:start] + mask_token``, ``targets[i]`` is ``target[:end]``,
        ``positions[i]`` is ``(start, end)`` and ``entities[i]`` is ``target[start:end]``.
    """
    if ent_parts is None:
        ent_parts = nlp(target).to_json()['ents']

    inputs, targets = [], []
    positions, entities = [], []

    for e in ent_parts:
        start, end = e['start'], e['end']
        inputs.append(target[:start] + mask_token)
        # The target stops right after the entity: the text to its right is never scored.
        targets.append(target[:end])
        entities.append(target[start:end])
        positions.append((start, end))

    return inputs, targets, positions, entities

In [34]:
# `prepare_mlm_inputs` is otherwise only imported further down the notebook, so on a
# fresh kernel (Restart & Run All) this cell would raise NameError; import it here.
from EntFA.utils import prepare_mlm_inputs

prepare_mlm_inputs(source, prediction, ent_parts=entities)

(['<mask> has marked the first anniversary of the siege at the Waverley cafe in which two women were killed by a gunman in the Australian city.',
  'Sydney has marked the <mask> anniversary of the siege at the Waverley cafe in which two women were killed by a gunman in the Australian city.',
  'Sydney has marked the first anniversary of the siege at the <mask> cafe in which two women were killed by a gunman in the Australian city.',
  'Sydney has marked the first anniversary of the siege at the Waverley cafe in which <mask> women were killed by a gunman in the Australian city.',
  'Sydney has marked the first anniversary of the siege at the Waverley cafe in which two women were killed by a gunman in the <mask> city.'],
 ['Sydney has marked the first anniversary of the siege at the Waverley cafe in which two women were killed by a gunman in the Australian city.',
  'Sydney has marked the first anniversary of the siege at the Waverley cafe in which two women were killed by a gunman in th

## Original pipeline: prior vs. posterior entity probabilities

In [1]:
import json
import torch

from fairseq.models.bart import BARTModel

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
# Checkpoint and data locations (relative to the working directory).
CMLM_MODEL_PATH, MLM_MODEL_PATH, DATA_NAME_OR_PATH = (
    'BART_models/xsum_cedar_cmlm',                # CMLM checkpoint dir -> posterior model
    'BART_models/bart.large',                     # pretrained BART-large -> prior model
    'summarization/XSum/fairseq_files/xsum-bin',  # fairseq XSum data dir (dictionaries)
)

In [5]:
# Posterior model: the CMLM checkpoint (per its name, BART fine-tuned on XSum).
# fairseq's BART README calls `.eval()` after loading to disable dropout; without it the
# scored probabilities would vary between runs. `.eval()` returns the module itself.
bart = BARTModel.from_pretrained(
    CMLM_MODEL_PATH,
    checkpoint_file='checkpoint_best.pt',
    data_name_or_path=DATA_NAME_OR_PATH,
).eval()

2022-03-15 14:57:32 | INFO | fairseq.file_utils | loading archive file /home/mila/c/caomeng/scratch/BART_models/xsum_cedar_cmlm
2022-03-15 14:57:32 | INFO | fairseq.file_utils | loading archive file /home/mila/c/caomeng/scratch/summarization/XSum/fairseq_files/xsum-bin
2022-03-15 14:57:41 | INFO | fairseq.tasks.translation | [source] dictionary: 50264 types
2022-03-15 14:57:41 | INFO | fairseq.tasks.translation | [target] dictionary: 50264 types


In [6]:
# Prior model: pretrained BART-large used for plain mask filling.
# Switch to eval mode (disables dropout, as in fairseq's BART README) so the prior
# probabilities are deterministic. `.eval()` returns the module itself.
prior_bart = BARTModel.from_pretrained(
    MLM_MODEL_PATH,
    checkpoint_file='model.pt',
    data_name_or_path=MLM_MODEL_PATH,
).eval()

2022-03-15 14:57:52 | INFO | fairseq.file_utils | loading archive file /home/mila/c/caomeng/scratch/BART_models/bart.large
2022-03-15 14:57:52 | INFO | fairseq.file_utils | loading archive file /home/mila/c/caomeng/scratch/BART_models/bart.large
2022-03-15 14:57:59 | INFO | fairseq.tasks.denoising | dictionary: 50264 types


#### Build Prior & Posterior Model

In [7]:
from EntFA.model import ConditionalSequenceGenerator
from EntFA.utils import prepare_cmlm_inputs, prepare_mlm_inputs, get_probability_parallel

In [8]:
# Wrap both hub models in the same scoring interface: `model` (CMLM checkpoint) is used
# for the posterior probabilities, `prior_model` (BART-large) for the prior ones.
model, prior_model = map(ConditionalSequenceGenerator, (bart, prior_bart))

#### Test on One Sample

In [9]:
import spacy

SPACY_MODEL_NAME = 'en_core_web_sm'  # English spaCy pipeline used for NER on the summary
nlp = spacy.load(SPACY_MODEL_NAME)

In [10]:
# Example pair (identical to the one in the exploration cell): `source` is the input article
# and `prediction` is the summary whose entities are scored; both are passed to the
# input-preparation helpers.
source = 'The city was brought to a standstill on 15 December last year when a gunman held 18 hostages for 17 hours. Family members of victims Tori Johnson and Katrina Dawson were in attendance. Images of the floral tributes that filled the city centre in the wake of the siege were projected on to the cafe and surrounding buildings in an emotional twilight ceremony. Prime Minister Malcolm Turnbull gave an address saying a "whole nation resolved to answer hatred with love". "Testament to the spirit of Australians is that with such unnecessary, thoughtless tragedy, an amazing birth of mateship, unity and love occurs. Proud to be Australian," he said. How the Sydney siege unfolded. New South Wales Premier Mike Baird has also announced plans for a permanent memorial to be built into the pavement in Martin Place. Clear cubes containing flowers will be embedded into the concrete and will shine with specialised lighting. It is a project inspired by the massive floral tributes that were left in the days after the siege. "Something remarkable happened here. As a city we were drawn to Martin Place. We came in shock and in sorrow but every step we took was with purpose," he said on Tuesday.'
prediction = 'Sydney has marked the first anniversary of the siege at the Waverley cafe in which two women were killed by a gunman in the Australian city.'

In [11]:
# Run spaCy NER on the summary and recover each entity's surface text from the
# character offsets ('start'/'end') in the serialized doc.
entities = nlp(prediction).to_json()['ents']
ent_text = [prediction[slice(span['start'], span['end'])] for span in entities]
print(ent_text)

['Sydney', 'first', 'Waverley', 'two', 'Australian']


In [12]:
# Posterior: score every entity with the CMLM model, which also receives the source.
# Only the first four prepared fields are forwarded to the scorer.
inputs = prepare_cmlm_inputs(source, prediction, ent_parts=entities)
posteriors = get_probability_parallel(model, *inputs[:4])

In [13]:
# Prior: score the same entities with plain BART in mask-filling mode.
# Only the first four prepared fields are forwarded to the scorer.
inputs = prepare_mlm_inputs(source, prediction, ent_parts=entities)
priors = get_probability_parallel(prior_model, *inputs[:4], mask_filling=True)

In [14]:
# Side-by-side table of per-entity prior (BART-large) and posterior (CMLM) probabilities.
# The name column is sized from the longest entity (at least 8 chars), and the numeric
# columns use fixed right-aligned widths. Alignment therefore no longer depends on tab
# stops, which broke for short values such as 0.5 or for entity names of 16+ characters.
name_width = max([8] + [len(e) for e in ent_text])
print('{:<{w}}  {:>12}  {:>12}'.format('', 'Prior', 'Posterior', w=name_width))
for e, pri, pos in zip(ent_text, priors, posteriors):
    print('{:<{w}}  {:>12.6}  {:>12.6}'.format(e, pri, pos, w=name_width))

        	Prior		Posterior
Sydney  	0.00366783	0.946777
first   	0.116516	0.325928
Waverley	0.0179596	0.00888062
two     	0.0629272	0.858887
Australian	0.00283623	0.911133
