In [1]:
import os
import json
import math
import torch
import torch.nn as nn

from fairseq.models.bart import BARTModel
from utils import read_lines

from transformers import BartTokenizer

In [2]:
# Load the path configuration used below to locate model checkpoints and data
# directories. A context manager closes the file handle deterministically
# instead of leaving it to the garbage collector.
with open('path_config.json') as config_file:
    PATH = json.load(config_file)

In [3]:
# Model loaded from PATH['xsum_cmlm_bos']; it is wrapped as `posterior_model`
# further down in the notebook.
finetuned_bart = BARTModel.from_pretrained(
    PATH['xsum_cmlm_bos'],
    checkpoint_file='checkpoint_best.pt',
    data_name_or_path=PATH['data_name_or_path'],
)

In [4]:
# Model loaded from PATH['cnndm_cmlm_cedar']; it is wrapped as `prior_model`
# further down in the notebook.
# Disabled alternative: load PATH['bart.large'] with checkpoint_file='model.pt'
# and data_name_or_path=PATH['bart.large'].
bart = BARTModel.from_pretrained(
    PATH['cnndm_cmlm_cedar'],
    checkpoint_file='checkpoint_best.pt',
    data_name_or_path=PATH['data_name_or_path'],
)

#### Read XSum

In [5]:
# Read the XSum test sources and targets; both files must contain the same
# number of entries.
document_path, target_path = (
    PATH['xsum_fariseq'] + '/test.' + split for split in ('source', 'target')
)
xsum_source, xsum_target = read_lines(document_path), read_lines(target_path)
print(len(xsum_source))
assert len(xsum_source) == len(xsum_target)

11301


#### Generate Summary

In [6]:
from fairseq.data.data_utils import collate_tokens

In [7]:
class ConditionalSequenceGenerator:
    """Conditional sequence generator for calculating prior and posterior probability.

    Wraps a BART model object and exposes greedy or teacher-forced decoding
    that also returns the probability assigned to every emitted token.
    On construction the model is moved to CUDA, put in eval mode and cast to
    half precision.
    """

    def __init__(self, bart):
        """Store the model, grab its helper callables and prepare it for inference.

        Args:
            bart: model object providing ``encode``, ``decode``, ``bpe``,
                ``max_positions``, ``task.source_dictionary`` and ``model``.
        """
        self.bart = bart
        self.tokenizer = BartTokenizer.from_pretrained("facebook/bart-large")

        self.encode_func = bart.encode
        self.decode_func = bart.decode
        self.max_positions = bart.max_positions
        self.encode_line = bart.task.source_dictionary.encode_line

        self._initialize()

    def _initialize(self):
        """Move the model to CUDA, switch it to evaluation mode and cast it to fp16."""
        self.bart.cuda()
        self.bart.eval()
        self.bart.half()

    @staticmethod
    def _as_batch(text):
        """Wrap a single string into a one-element list; pass anything else through."""
        return [text] if isinstance(text, str) else text

    def tokenize(self, input_str, append_bos=False, append_eos=True, left_pad=True):
        """BPE-encode a sentence (or multiple sentences).

        Args:
            input_str (str or List[str]): input sentence(s) to be tokenized.
            append_bos (bool): prepend "<s>" to every sentence.
            append_eos (bool): append "</s>" to every sentence.
            left_pad (bool): pad on the left (True) or on the right (False).

        Return:
            input_ids (torch.Tensor): [batch_size, length], padded with id 1.
            input_lengths (torch.Tensor): [batch_size], number of non-pad ids.
        """
        input_str = self._as_batch(input_str)

        # Room left for BPE tokens once the optional special tokens are counted.
        budget = min(self.max_positions) - sum([append_bos, append_eos])

        input_ids = []
        for ins in input_str:
            tokens = self.bart.bpe.encode(ins)  # <mask>: 1279 27932 29
            pieces = tokens.split(" ")
            if len(pieces) > budget:
                tokens = " ".join(pieces[:budget])

            tokens = "<s> " + tokens if append_bos else tokens
            tokens = tokens + " </s>" if append_eos else tokens
            # EOS was already added textually above, so do not add it again.
            ids = self.encode_line(tokens, append_eos=False).long()
            input_ids.append(ids)

        input_ids = collate_tokens(input_ids, pad_idx=1, left_pad=left_pad).cuda()
        input_lengths = torch.sum(input_ids != 1, dim=1).cuda()

        return input_ids, input_lengths

    def tokenize_with_mask(self, input_str):
        """Tokenize sentence with a special <mask> token in it.

        Uses the Hugging Face ``BartTokenizer`` rather than the model's own BPE.
        NOTE(review): this assumes both share the same vocabulary ids and
        pad id 1 -- confirm.

        Args:
            input_str (str or List[str]): input sentence to be tokenized.

        Return:
            input_ids (torch.Tensor): [batch_size, length]
            input_lengths (torch.Tensor): [batch_size], number of ids != 1.
        """
        input_ids = self.tokenizer(input_str, return_tensors='pt', padding=True)['input_ids'].cuda()
        input_lengths = torch.sum(input_ids != 1, dim=1).cuda()
        return input_ids, input_lengths

    def generate(self, src_input, tgt_input=None):
        """Conditional generation.

        Args:
            src_input (str or List[str]): encoder input(s).
            tgt_input (str or List[str], optional): when given, decoding is
                teacher-forced on these targets instead of greedy.

        Return:
            The ``(init_input, tokens, token_probs)`` tuple of ``decode_sequence``.
        """
        # Normalise first so the length check below compares batch sizes and
        # not the character counts of two plain strings.
        src_input = self._as_batch(src_input)
        tgt_input = self._as_batch(tgt_input)

        input_ids, lengths = self.tokenize(src_input, append_bos=False)

        target_ids = None
        if tgt_input is not None:
            assert len(src_input) == len(tgt_input), "source & target length should match."
            target_ids, _ = self.tokenize(tgt_input, append_bos=False, left_pad=False)

        with torch.no_grad():
            encoder_output = self.encode_sequence(input_ids, lengths)
            decoder_output = self.decode_sequence(encoder_output,
                                                  target_ids=target_ids,
                                                  prefix_tokens=[2])
        return decoder_output

    def mask_filling(self, src_input, tgt_input=None):
        """Filling the mask in sentence(s).

        Args:
            src_input (str or List[str]): sentence(s) containing "<mask>".
            tgt_input (str or List[str], optional): when given, decoding is
                teacher-forced on these targets instead of greedy.

        Return:
            The ``(init_input, tokens, token_probs)`` tuple of ``decode_sequence``.
        """
        # Normalise first so the length check below compares batch sizes and
        # not the character counts of two plain strings.
        src_input = self._as_batch(src_input)
        tgt_input = self._as_batch(tgt_input)

        input_ids, lengths = self.tokenize_with_mask(src_input)

        target_ids = None
        if tgt_input is not None:
            assert len(src_input) == len(tgt_input), "source & target length should match."
            target_ids, _ = self.tokenize(tgt_input, left_pad=False)

        with torch.no_grad():
            encoder_output = self.encode_sequence(input_ids, lengths)
            decoder_output = self.decode_sequence(encoder_output,
                                                  target_ids=target_ids,
                                                  prefix_tokens=[2, 0])
        return decoder_output

    def encode_sequence(self, input_ids, lengths):
        """Run the model's encoder on a padded batch and return its output."""
        return self.bart.model.encoder(input_ids, src_lengths=lengths)

    def decode_sequence(
        self,
        encoder_out,
        target_ids=None,
        min_decode_step=3,
        max_decode_step=100,
        pad_id=1,
        eos_id=2,
        prefix_tokens=(2, 0),
    ):
        """Decode step by step, recording the probability of every chosen token.

        Args:
            encoder_out: encoder output as returned by ``encode_sequence``.
            target_ids (torch.Tensor, optional): [batch_size, length]; when
                given, the token at each step is taken from this tensor
                (teacher forcing), otherwise the most probable token is used.
            min_decode_step (int): EOS is masked out while ``step + 1`` is
                below this value.
            max_decode_step (int): hard limit on the number of decoding steps.
                This limit also applies when ``target_ids`` is given.
            pad_id (int): id written for sequences that already finished.
            eos_id (int): id that marks a sequence as finished.
            prefix_tokens (sequence of int): ids the decoder input starts with.

        Return:
            init_input (torch.Tensor): prefix followed by every selected id,
                [batch_size, len(prefix_tokens) + steps].
            tokens (List[List[str]]): decoded string of every selected id,
                ``'<pad>'`` for positions after a sequence finished.
            token_probs (torch.Tensor): [batch_size, steps] probability of each
                selected id, 1.0 for positions after a sequence finished.
        """
        batch_size = encoder_out['encoder_padding_mask'][0].shape[0]
        init_input = torch.tensor([list(prefix_tokens)] * batch_size, dtype=torch.long).cuda()
        token_probs, tokens = None, [[] for _ in range(batch_size)]
        end_mask = torch.zeros(batch_size, dtype=torch.bool).cuda()

        softmax = nn.Softmax(dim=1)
        for step in range(max_decode_step):
            decoder_outputs = self.bart.model.decoder(init_input, encoder_out, features_only=False)
            logits = decoder_outputs[0][:, -1, :]  # logits: [batch_size, vocab]

            if step + 1 < min_decode_step:
                logits[:, eos_id] = -math.inf  # mask <EOS> token when within minimal step
            logits[:, pad_id], logits[:, 0] = -math.inf, -math.inf  # never select <PAD> & <BOS> token
            probs = softmax(logits)  # probs: [batch_size, vocab]

            # select tokens
            if target_ids is not None:
                selected_token = target_ids[:, step]
            else:
                _, indices = torch.topk(probs, 5, dim=1)
                selected_token = indices[:, 0]

            # Finished sequences only emit padding from here on.
            selected_token = selected_token.masked_fill(end_mask, pad_id)
            init_input = torch.cat([init_input, selected_token.unsqueeze(1)], dim=-1)

            probs = torch.gather(probs, 1, selected_token.unsqueeze(1)).detach()
            # Padding after EOS must not change a product of probabilities.
            probs = probs.masked_fill(end_mask.unsqueeze(1), 1.0)

            # str & probability
            token_probs = probs if token_probs is None else torch.cat([token_probs, probs], dim=-1)
            for t, s in zip(tokens, selected_token):
                t.append(self.decode_func(s.unsqueeze(0)) if s.item() != pad_id else '<pad>')

            # stop generation when all finished
            end_mask = torch.logical_or(end_mask, selected_token == eos_id)
            if end_mask.sum().item() == batch_size:
                break

        return init_input, tokens, token_probs

In [8]:
def get_probability(position, tokens, probs, entity):
    """Return the joint probability of the tokens covering a character span.

    Args:
        position: (start, end) character offsets into the concatenation of
            ``tokens``; ``end`` has to coincide with the end of a token.
        tokens: decoded token strings, e.g. ['The', ' Archbishop', ' of', ...]
        probs: per-token probabilities aligned with ``tokens``,
            e.g. [0.50, 0.49, 0.88, ...]
        entity: text expected inside the selected tokens, e.g. Rodgers

    Returns:
        The product of ``probs`` over the selected tokens.

    Raises:
        AssertionError: on a length mismatch, when ``end`` is not a token
            boundary, when the span runs past the first token, or when
            ``entity`` is not found in the selected tokens.
    """
    assert len(tokens) == len(probs), "Tokens and token probabilities does not match."

    # Character offset at which every token ends.
    boundaries = []
    offset = 0
    for piece in tokens:
        offset += len(piece)
        boundaries.append(offset)

    assert position[1] in boundaries, "- {}\n- {}\n- {}\n- {}\n- {}\n".format(position, tokens, probs, entity, boundaries)
    stop = boundaries.index(position[1])

    # Grow the window to the left until it holds at least as many characters
    # as the requested span.
    start = stop
    covered = len(tokens[stop])
    while covered < (position[1] - position[0]):
        start -= 1
        assert start >= 0
        covered += len(tokens[start])

    window = range(start, stop + 1)
    generated = ''.join(tokens[i] for i in window)
    assert entity in generated, 'entity: {}; span: {}'.format(entity, generated)

    prob = 1.0
    for i in window:
        prob *= probs[i]
    return prob

In [9]:
def get_cmlm_probability(generator, src_input, tgt_input, position, entity):
    """Score every entity span with ``generator.generate`` forced onto the targets.

    Args:
        generator: object whose ``generate(src_input, tgt_input=...)`` returns
            ``(init_input, tokens, token_probs)``.
        src_input: batch of source strings.
        tgt_input: batch of target strings passed on as ``tgt_input``.
        position: one (start, end) character span per batch item.
        entity: one entity string per batch item.

    Returns:
        List with one Python number per batch item, obtained by calling
        ``.item()`` on the result of ``get_probability``.
    """
    _, tokens, token_probs = generator.generate(src_input, tgt_input=tgt_input)

    return [
        get_probability(span, item_tokens, item_probs, ent).item()
        for span, item_tokens, item_probs, ent in zip(position, tokens, token_probs, entity)
    ]

In [10]:
def get_prior_probability(generator, src_input, tgt_input, position, entity):
    """Score every entity span with ``generator.mask_filling`` forced onto the targets.

    Args:
        generator: object whose ``mask_filling(src_input, tgt_input)`` returns
            ``(init_input, tokens, token_probs)``.
        src_input: batch of masked source strings.
        tgt_input: batch of target strings; must be as long as ``src_input``.
        position: one (start, end) character span per batch item.
        entity: one entity string per batch item.

    Returns:
        List with one Python number per batch item, obtained by calling
        ``.item()`` on the result of ``get_probability``.
    """
    assert len(src_input) == len(tgt_input), "source & target length should match."
    _, tokens, token_probs = generator.mask_filling(src_input, tgt_input)

    return [
        get_probability(span, item_tokens, item_probs, ent).item()
        for span, item_tokens, item_probs, ent in zip(position, tokens, token_probs, entity)
    ]

#### Read Google Data

In [11]:
from tqdm import tqdm

In [12]:
# Load the entity-annotated summaries (document id -> system -> summary record).
# A context manager closes the file handle deterministically.
with open('../Dataset/entity_data.json') as data_file:
    google_data = json.load(data_file)

In [13]:
# Print the number of records, then display one record to show its fields
# ('summary', 'summary_upper' and the list of annotated 'ents').
print(len(google_data))
google_data['29347895']['BERTS2S']

500


{'summary': 'veteran classical music conductor christopher hogwood has died at the age of 83.',
 'summary_upper': 'Veteran classical music conductor Christopher Hogwood has died at the age of 83 .',
 'ents': [{'start': 34,
   'end': 45,
   'label': 0,
   'type': 'PERSON',
   'ent': 'Christopher'},
  {'start': 46, 'end': 53, 'label': 0, 'type': 'PERSON', 'ent': 'Hogwood'},
  {'start': 66,
   'end': 79,
   'label': 0,
   'type': 'DATE',
   'ent': 'the age of 83'}]}

In [14]:
def process_document(raw_doc):
    """Drop leading boilerplate lines of a raw document and join the rest with spaces.

    Lines are skipped from the top until the first line that has more than one
    whitespace-separated word and is not one of the known boilerplate
    sentences; that line and every line after it are kept unchanged.

    Args:
        raw_doc (str): raw document text with newline-separated lines.

    Returns:
        str: the kept lines joined by a single space ('' when no line qualifies).
    """
    TRIVIAL_SENTS = [
        'Share this with',
        'Copy this link',
        'These are external links and will open in a new window',
    ]

    lines = raw_doc.strip().split('\n')

    for first_kept, line in enumerate(lines):
        is_content = len(line.split()) > 1 and line not in TRIVIAL_SENTS
        if is_content:
            # Everything from the first content line onwards is kept as is.
            return ' '.join(lines[first_kept:])

    return ''

In [15]:
def read_document(bbcid, folder='/home/mcao610/scratch/summarization/XSum/xsum-preprocessed/document/'):
    """Read ``<folder><bbcid>.document`` and return its processed text.

    Args:
        bbcid: document id used to build the file name.
        folder (str): directory prefix; it is concatenated as is, so it must
            end with a path separator.

    Returns:
        The file content after ``process_document``.
    """
    with open(folder + '{}.document'.format(bbcid), 'r') as handle:
        raw_text = handle.read()
    return process_document(raw_text)

In [16]:
def check_document_exist(bbcid, folder='/home/mcao610/scratch/summarization/XSum/xsum-preprocessed/document/'):
    """Tell whether ``<folder><bbcid>.document`` exists on disk.

    Args:
        bbcid: document id used to build the file name.
        folder (str): directory prefix; it is concatenated as is, so it must
            end with a path separator.

    Returns:
        bool: result of ``os.path.exists`` for the built path.
    """
    file_name = '{}.document'.format(bbcid)
    return os.path.exists(folder + file_name)

In [17]:
# Spot-check: display the processed text of one document.
read_document(34687720)

'France \'s Dubuisson carded a 67 to tie with overnight leader Van Zyl of South Africa on 16 under par . McIlroy carded a third straight five under - par 67 to move to 15 under par with Thailand \'s Kiradech Aphibarnrat . The world number three \'s round included an eagle on the 12th as he bids to win his first title since May . " The 67s I \'ve shot this week have all been a little different and I feel like I \'ve played within myself for all of them , " said four - time major winner McIlroy of Northern Ireland . " I feel there \'s a low round out there for me and hopefully it \'s tomorrow . " McIlroy was level par for the day after 10 holes , dropping his first shots of the week by three - putting the third and 10th , the latter mistake prompting the 26 - year - old to throw his putter at his bag . But he hit back with a birdie on the par - five 11th and a towering four iron from 229 yards on the 13th set up an eagle from just four feet . The former world number one ruptured a ligame

#### Inference

In [18]:
def prepare_clm_inputs(source, target, ent_parts=None):
    """Build one (source, target) pair per entity for the conditional language model.

    Args:
        source (str): source document, repeated once per entity.
        target (str): summary, repeated once per entity.
        ent_parts (list of dict, optional): entities with 'start' and 'end'
            character offsets into ``target``.

    Returns:
        tuple: ``(inputs, targets, positions, entities)`` -- four lists of
        equal length; ``positions`` holds (start, end) tuples and ``entities``
        the matching slices of ``target``.
    """
    if ent_parts is None:
        # NOTE(review): `nlp` is not defined in the visible part of this
        # notebook; this fallback needs it to exist in the session.
        ent_parts = nlp(target).to_json()['ents']

    positions = [(ent['start'], ent['end']) for ent in ent_parts]
    entities = [target[start:end] for start, end in positions]
    inputs = [source] * len(positions)
    targets = [target] * len(positions)

    return inputs, targets, positions, entities

In [19]:
def prepare_mlm_inputs(source, target, ent_parts=None):
    """Build one masked copy of the summary per entity for the masked language model.

    Args:
        source (str): source document (not used when building the inputs).
        target (str): summary text.
        ent_parts (list of dict, optional): entities with 'start' and 'end'
            character offsets into ``target``.

    Returns:
        tuple: ``(inputs, targets, positions, entities)`` -- ``inputs`` are
        copies of ``target`` with the entity replaced by '<mask>',
        ``positions`` holds (start, end) tuples and ``entities`` the matching
        slices of ``target``.
    """
    if ent_parts is None:
        # NOTE(review): `nlp` is not defined in the visible part of this
        # notebook; this fallback needs it to exist in the session.
        ent_parts = nlp(target).to_json()['ents']

    inputs, targets, positions, entities = [], [], [], []
    for ent in ent_parts:
        start, end = ent['start'], ent['end']
        inputs.append(''.join((target[:start], '<mask>', target[end:])))
        targets.append(target)
        entities.append(target[start:end])
        positions.append((start, end))

    return inputs, targets, positions, entities

In [20]:
def prepare_cmlm_inputs(source, target, ent_parts=None):
    """Build the inputs for the conditional masked language model.

    For every entity the summary is copied with the entity replaced by '###',
    wrapped as ``'<s> ' + masked summary + ' <\\s> ' + source`` and paired with
    the target ``'<s> ' + target``.

    Args:
        source (str): source document appended after the masked summary.
        target (str): summary text.
        ent_parts (list of dict, optional): entities with 'start' and 'end'
            character offsets into ``target``.

    Returns:
        tuple: ``(inputs, targets, positions, entities)`` -- ``positions`` are
        the entity offsets shifted by the length of the '<s> ' prefix so that
        they index into the returned targets; ``entities`` are the matching
        slices of ``target``.
    """
    if ent_parts is None:
        # NOTE(review): `nlp` is not defined in the visible part of this
        # notebook; this fallback needs it to exist in the session.
        ent_parts = nlp(target).to_json()['ents']

    bos_prefix = '<s> '
    # The separator is written with a backslash exactly as before. It is
    # escaped here because a bare "\s" is an invalid escape sequence that newer
    # Python versions warn about.
    # NOTE(review): possibly meant to be "</s>"; kept byte-identical because
    # the model inputs depend on it -- confirm against the training data.
    separator = ' <\\s> '
    shift = len(bos_prefix)

    inputs, targets = [], []
    positions, entities = [], []

    for e in ent_parts:
        masked_hypothesis = target[0: e['start']] + '###' + target[e['end']:]
        inputs.append(bos_prefix + masked_hypothesis + separator + source)
        targets.append(bos_prefix + target)

        entities.append(target[e['start']: e['end']])
        positions.append((e['start'] + shift, e['end'] + shift))

    return inputs, targets, positions, entities

In [21]:
# Display the example record again before scoring (its entities have no
# 'prior' / 'posterior' fields yet).
google_data['29347895']['BERTS2S']

{'summary': 'veteran classical music conductor christopher hogwood has died at the age of 83.',
 'summary_upper': 'Veteran classical music conductor Christopher Hogwood has died at the age of 83 .',
 'ents': [{'start': 34,
   'end': 45,
   'label': 0,
   'type': 'PERSON',
   'ent': 'Christopher'},
  {'start': 46, 'end': 53, 'label': 0, 'type': 'PERSON', 'ent': 'Hogwood'},
  {'start': 66,
   'end': 79,
   'label': 0,
   'type': 'DATE',
   'ent': 'the age of 83'}]}

In [22]:
# Wrap both loaded models: `bart` (from PATH['cnndm_cmlm_cedar']) gives the
# prior scores, `finetuned_bart` (from PATH['xsum_cmlm_bos']) the posterior
# scores. Constructing a generator moves its model to CUDA in fp16 eval mode.
prior_model = ConditionalSequenceGenerator(bart)
posterior_model = ConditionalSequenceGenerator(finetuned_bart)

In [23]:
# Document id that is deliberately excluded from scoring.
# NOTE(review): the reason for the exclusion is not recorded in this notebook.
SKIPPED_BBCID = '39553812'

# Ids that did not receive real scores: the skipped id and every id whose
# source document is missing on disk.
unsuccessful_ids = []

for bbcid in tqdm(google_data.keys()):
    if bbcid == SKIPPED_BBCID:
        # Entities of this record are left untouched (no 'prior'/'posterior').
        unsuccessful_ids.append(bbcid)
        continue

    systems = google_data[bbcid]

    if check_document_exist(bbcid):
        # The source only depends on the document id, so read it once for all
        # systems instead of once per system.
        source = read_document(bbcid)

        for system in systems:
            ents = systems[system]['ents']
            if len(ents) == 0:
                continue

            target = systems[system]['summary_upper']

            # Both models are scored on the same inputs, so build them once.
            inputs, targets, positions, entities = prepare_cmlm_inputs(source, target, ents)

            prior_probs = get_cmlm_probability(prior_model, inputs, targets, positions, entities)
            posterior_probs = get_cmlm_probability(posterior_model, inputs, targets, positions, entities)

            assert len(prior_probs) == len(posterior_probs) == len(ents)
            for ent, prior, posterior in zip(ents, prior_probs, posterior_probs):
                ent['prior'] = prior
                ent['posterior'] = posterior
    else:
        # No source document: mark every entity as unscored.
        unsuccessful_ids.append(bbcid)
        for system in systems:
            for ent in systems[system]['ents']:
                ent['prior'] = None
                ent['posterior'] = None

100%|██████████| 500/500 [34:09<00:00,  4.10s/it]


In [24]:
# Inspect one scored record: every entity now carries 'prior' and 'posterior'.
google_data['34802406']

{'BERTS2S': {'summary': 'young people in scotland are more likely than their peers to commit violent crimes, according to new research.',
  'summary_upper': 'Young people in Scotland are more likely than their peers to commit violent crimes , according to new research .',
  'ents': [{'start': 16,
    'end': 24,
    'label': 0,
    'type': 'GPE',
    'ent': 'Scotland',
    'prior': 0.1915283203125,
    'posterior': 0.8310546875}]},
 'TConvS2S': {'summary': 'young girls in scotland are more likely to be linked to violent crime, according to a new study.',
  'summary_upper': 'Young girls in Scotland are more likely to be linked to violent crime , according to a new study .',
  'ents': [{'start': 15,
    'end': 23,
    'label': 0,
    'type': 'GPE',
    'ent': 'Scotland',
    'prior': 0.1826171875,
    'posterior': 0.364501953125}]},
 'Gold': {'summary': "scotland's criminal justice system punishes poorer people and makes it difficult for them to escape poverty, according to an academic st

In [25]:
# Sanity check: the number of records is the same as when the data was loaded.
len(google_data)

500

#### Save Data

In [26]:
# Write the annotated records to disk. The context manager makes sure the file
# is flushed and closed before the notebook moves on.
with open("google_data_with_proba_2CMLM.json", "w") as out_file:
    json.dump(google_data, out_file)

In [27]:
# done