# Prior probability

In [188]:
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    AutoModel
)

# Base (not fine-tuned) BART checkpoint used as the prior model.
PRIOR_CHECKPOINT = "facebook/bart-large"

# One tokenizer instance, shared by the prior model and by bart-large-xsum later on.
bart_tokenizer = AutoTokenizer.from_pretrained(PRIOR_CHECKPOINT)
# Loading the encoder-decoder checkpoint as a causal LM leaves the encoder
# weights unused -- that is the expected "Some weights ... were not used" warning below.
causal_prior_model = AutoModelForCausalLM.from_pretrained(PRIOR_CHECKPOINT)
# NOTE(review): not referenced by any cell shown here -- confirm it is still needed.
bart_model = AutoModel.from_pretrained(PRIOR_CHECKPOINT)

Some weights of the model checkpoint at facebook/bart-large were not used when initializing BartForCausalLM: ['encoder.layers.0.fc2.weight', 'encoder.layers.1.self_attn.out_proj.bias', 'encoder.layers.4.self_attn.out_proj.bias', 'encoder.layers.8.self_attn.q_proj.weight', 'encoder.layers.10.self_attn_layer_norm.weight', 'encoder.layers.0.final_layer_norm.bias', 'encoder.layers.4.fc2.bias', 'encoder.layers.4.final_layer_norm.bias', 'encoder.layers.7.fc1.bias', 'encoder.layers.4.self_attn.q_proj.weight', 'encoder.layers.9.fc1.weight', 'encoder.layers.5.self_attn.out_proj.bias', 'encoder.layers.6.self_attn.out_proj.weight', 'encoder.layers.1.fc1.bias', 'encoder.layers.5.self_attn.q_proj.bias', 'encoder.layers.4.self_attn.v_proj.weight', 'encoder.layers.10.fc2.weight', 'encoder.layers.10.self_attn_layer_norm.bias', 'encoder.layers.6.self_attn.k_proj.bias', 'encoder.layers.3.self_attn.out_proj.bias', 'encoder.layers.4.self_attn.q_proj.bias', 'encoder.layers.5.self_attn.q_proj.weight', 'enco

## Load EntFA data

In [178]:
from src.data_utils import load_EntFA

# EntFA split to analyse; its records carry per-entity hallucination labels.
ENTFA_SPLIT = "test"
data_entfa = load_EntFA(ENTFA_SPLIT)

In [184]:
from datasets import load_dataset
xsum_test = load_dataset("xsum")["test"]

# Index the XSum test articles by their id (a later duplicate id would overwrite an earlier one).
xsum_test_by_id = {}
for article in xsum_test:
    xsum_test_by_id[article["id"]] = article

# Check that the EntFA source text is the same as the original XSum document.
# NOTE(review): assumes data_entfa[0] is XSum article "37989821" -- confirm.
xsum_document = xsum_test_by_id["37989821"]["document"]
entfa_source = data_entfa[0]["source"]
print("Is Xsum doc equal to enfa source?", xsum_document == entfa_source)
print("...with replaced newlines?", xsum_document.replace("\n", " ") == entfa_source)

Using custom data configuration default
Reusing dataset xsum (/Users/anton164/.cache/huggingface/datasets/xsum/default/1.2.0/32c23220eadddb1149b16ed2e9430a05293768cfffbdfd151058697d4c11f934)


  0%|          | 0/3 [00:00<?, ?it/s]

Is Xsum doc equal to enfa source? False
...with replaced newlines? True


The texts only match once the newlines in the XSum document are replaced with spaces.

In [185]:
from src.data_utils import lookup_entfa_by_prediction

# System summary whose entity annotations are inspected below.
SYDNEY_PREDICTION = (
    "Sydney has marked the first anniversary of the siege at the Waverley cafe "
    "in which two women were killed by a gunman in the Australian city."
)

entfa_sydney = lookup_entfa_by_prediction(data_entfa, SYDNEY_PREDICTION)
entfa_sydney["prediction"], entfa_sydney["entities"]

('Sydney has marked the first anniversary of the siege at the Waverley cafe in which two women were killed by a gunman in the Australian city.',
 [{'start': 0,
   'end': 6,
   'label': 'Non-hallucinated',
   'type': 'ORG',
   'ent': 'Sydney'},
  {'start': 22,
   'end': 27,
   'label': 'Non-hallucinated',
   'type': 'ORDINAL',
   'ent': 'first'},
  {'start': 60,
   'end': 68,
   'label': 'Non-factual Hallucination',
   'type': 'ORG',
   'ent': 'Waverley'},
  {'start': 83,
   'end': 86,
   'label': 'Non-hallucinated',
   'type': 'CARDINAL',
   'ent': 'two'},
  {'start': 124,
   'end': 134,
   'label': 'Non-hallucinated',
   'type': 'NORP',
   'ent': 'Australian'}])

## Prior probability

In [199]:
from src.masked_probability import prior_causal_probability
# Per-token probabilities of the EntFA summary under the causal prior model.
# No source document is passed to this call. The output appears to hold
# (token, token_id, probability) tuples.
prior_token_probs = prior_causal_probability(
    causal_prior_model,
    bart_tokenizer,
    entfa_sydney["prediction"],
)
prior_token_probs

[('<s>', 0, 2.1549179541474617e-12),
 ('S', 104, 1.780824277375359e-06),
 ('yd', 9611, 2.7244109332968947e-06),
 ('ney', 2596, 3.707136642105979e-08),
 (' has', 34, 0.027413560077548027),
 (' marked', 4760, 1.1275883480266202e-06),
 (' the', 5, 0.00018941506277769804),
 (' first', 78, 0.0023676673881709576),
 (' anniversary', 4038, 5.814571068185614e-06),
 (' of', 9, 0.0017079596873372793),
 (' the', 5, 0.003059038193896413),
 (' siege', 19951, 3.6858385787930104e-10),
 (' at', 23, 0.0022077469620853662),
 (' the', 5, 0.0014361185021698475),
 (' W', 305, 0.002821172820404172),
 ('aver', 9903, 3.773652963445784e-07),
 ('ley', 607, 7.092088338822577e-08),
 (' cafe', 16381, 1.8298884185696807e-07),
 (' in', 11, 0.03579285740852356),
 (' which', 61, 2.48848186856776e-06),
 (' two', 80, 4.862476998823695e-05),
 (' women', 390, 0.00013294632663019001),
 (' were', 58, 0.00033202560734935105),
 (' killed', 848, 4.035213407860283e-08),
 (' by', 30, 4.279447966837324e-05),
 (' a', 10, 0.00526131

# Posterior probability

In [65]:
from transformers import (
    AutoModelForSeq2SeqLM
)

# Conditional (posterior) model -- presumably BART-large fine-tuned on XSum, per the checkpoint name.
POSTERIOR_CHECKPOINT = "facebook/bart-large-xsum"
bart_xsum = AutoModelForSeq2SeqLM.from_pretrained(POSTERIOR_CHECKPOINT)

## What would bart-xsum have generated?

Generate summaries from the document both as-is and with its newlines replaced by spaces.

In [70]:
from src.generation_utils import generate_summaries
# XSum id of the Sydney-siege article (see the generated summaries below).
SYDNEY_XSUM_ID = "35099282"
doc = xsum_test_by_id[SYDNEY_XSUM_ID]["document"]

# Summarize the raw document and a newline-flattened copy; the flattened copy
# is the form that matches EntFA's stored source text.
docs_to_summarize = [
    doc,
    doc.replace("\n", " "),
]

# Stored under a dedicated name: the forced-decoding cell further down rebinds
# `sums`, which previously discarded these beam-search generations.
generated_sums = generate_summaries(
    bart_xsum,
    bart_tokenizer,
    docs_to_summarize,
    None,  # NOTE(review): meaning of this positional argument is not visible here -- confirm
    num_beams=4,
)
generated_sums

['Sydney has marked the first anniversary of the siege at the Waverley cafe in which two people were killed.',
 'Sydney has marked the first anniversary of the siege at the Waverley cafe in which two women were killed.']

In [186]:
# The EntFA-annotated summary, shown for comparison with bart-large-xsum's generations.
sydney_reference_summary = entfa_sydney["prediction"]
sydney_reference_summary

'Sydney has marked the first anniversary of the siege at the Waverley cafe in which two women were killed by a gunman in the Australian city.'

In [189]:
from src.masked_probability import forceful_conditional_generation
# Score the fixed EntFA summary with bart-large-xsum, conditioned on each entry of
# docs_to_summarize. Judging by the output, each token gets one probability per document.
# NOTE: this rebinds `sums`, replacing whatever it held before this cell ran.
forced_result = forceful_conditional_generation(
    bart_xsum,
    bart_tokenizer,
    entfa_sydney["prediction"],
    docs_to_summarize,
)
preds, sums = forced_result
preds

[('S', 104, tensor([0.1869, 0.2390])),
 ('yd', 9611, tensor([0.9439, 0.9353])),
 ('ney', 2596, tensor([0.9553, 0.9534])),
 (' has', 34, tensor([0.5095, 0.5292])),
 (' marked', 4760, tensor([0.3043, 0.2741])),
 (' the', 5, tensor([0.6830, 0.7088])),
 (' first', 78, tensor([0.6297, 0.6220])),
 (' anniversary', 4038, tensor([0.9380, 0.9416])),
 (' of', 9, tensor([0.8882, 0.8881])),
 (' the', 5, tensor([0.8174, 0.8127])),
 (' siege', 19951, tensor([0.3841, 0.3650])),
 (' at', 23, tensor([0.7233, 0.7404])),
 (' the', 5, tensor([0.5626, 0.5693])),
 (' W', 305, tensor([0.0992, 0.1038])),
 ('aver', 9903, tensor([0.7166, 0.6745])),
 ('ley', 607, tensor([0.9400, 0.9403])),
 (' cafe', 16381, tensor([0.5410, 0.5290])),
 (' in', 11, tensor([0.3207, 0.3224])),
 (' which', 61, tensor([0.7555, 0.8088])),
 (' two', 80, tensor([0.6819, 0.7448])),
 (' women', 390, tensor([0.3642, 0.3930])),
 (' were', 58, tensor([0.6122, 0.6101])),
 (' killed', 848, tensor([0.8426, 0.8378])),
 (' by', 30, tensor([0.3655,