In [None]:
import json
import pandas as pd
from tqdm import tqdm
import re
import statistics
import nltk

In [None]:
# Input/output locations -- fill these in before running the notebook.
TRAIN_FILE = ''            # JSON list of training rows (each with 'answers' and 'positive_ctxs')
GOLD_PASSAGE_FILE = ''     # JSON object whose 'data' key holds the gold-passage records
RESULT_FILE = ''           # NOTE(review): not referenced anywhere in this notebook
OUTPUT_CF_TRAIN_PATH = ''  # destination of the counterfactual training JSON

### Answer Span Detection and Removal

In [None]:
import math
import re
import random
import nltk
# Seed Python's global RNG so any random step later in the kernel is reproducible.
random.seed(0)

# Masking-related settings and accumulators.
# NOTE(review): none of these names are read by the code visible in this
# notebook -- confirm before removing them.
WINDOWING = False
make_random_mask = False
span_len = 16
exact_answers, not_include_answer = [], []

def find_and_mask(answers, txt):
    """Cut every case-insensitive occurrence of each answer string out of `txt`.

    Each answer is matched literally, except that the single spaces between its
    words may be absent in `txt` (so "New York" also matches "NewYork").
    Answers are applied one after another to the progressively masked text.

    Args:
        answers: iterable of answer strings; empty/whitespace-only answers are
            skipped (an empty pattern would match at every position).
        txt: the text to mask.

    Returns:
        The masked text, or None if no answer occurred in `txt`. Every kept
        segment is followed by one space (so a masked result always ends with a
        space and is never empty); callers rely on this to use the return value
        as a "contains an answer" flag.
    """
    masked_txt = txt
    masked = False
    for ans in answers:
        if not ans.strip():
            continue
        # Escape each word so '.', '?', '(' etc. are matched literally; the words
        # may be joined by one optional space in the text.
        pattern = re.compile(
            '(?: )?'.join(re.escape(word) for word in ans.split(' ')),
            re.I,
        )
        matches = list(pattern.finditer(masked_txt))
        if not matches:
            continue

        # Keep the text between consecutive matches, using each match's real
        # span (not the pattern length, which differs when escapes were added
        # or an optional space was not present).
        pieces = []
        prev_end = 0
        for match in matches:
            pieces.append(masked_txt[prev_end:match.start()] + ' ')
            prev_end = match.end()
        pieces.append(masked_txt[prev_end:] + ' ')

        masked_txt = ''.join(pieces)
        masked = True

    return masked_txt if masked else None

### Synthesizing Counterfactual Passages for AAR Evaluation

In [None]:
import json
# The gold-passage file is a JSON object; its 'data' key holds the list of
# passage records processed below. JSON is UTF-8, so do not rely on the
# platform's default text encoding.
with open(GOLD_PASSAGE_FILE, 'r', encoding='utf-8') as f:
    raw_gold_passages = json.load(f)['data']

In [None]:
# Drop gold records whose 'context' is empty/falsy -- there is no passage text to mask.
gold_passages = [gold for gold in raw_gold_passages if gold['context']]

In [None]:
import spacy
import copy
import nltk
import random
from tqdm import tqdm
nlp = spacy.load("en_core_web_sm")  # NOTE(review): `nlp` is not used in this cell

removed_percentage = []  # NOTE(review): never appended to in this notebook
masked_gold_infos = []   # one enriched copy of every gold row that could be masked
count = 0                # rows skipped because the answer sentence was their only sentence

for row_idx, row in tqdm(enumerate(gold_passages), total=len(gold_passages)):
    original_passage = row['context']
    answers = row['short_answers']
    passage = original_passage
    masked_passage = ''
    short_answer = ''
    answer_sentence = ''
    answer_masked_passage = ''
    trunc_passage = ''
    tok_diff = 0
    answer_percent = -1.0
    masked = False

    # Variant 1: every occurrence of every answer string cut out of the passage.
    answer_mask = find_and_mask(answers, original_passage)
    if answer_mask:
        answer_masked_passage = answer_mask

    # Variant 2: drop the first sentence that contains an answer.
    original_sentences = nltk.tokenize.sent_tokenize(original_passage)
    for sent_idx, sentence in enumerate(original_sentences):
        if not find_and_mask(answers, sentence):
            continue

        cf_sentences = original_sentences[:sent_idx] + original_sentences[sent_idx + 1:]
        if not cf_sentences:
            # Removing the answer sentence would leave nothing; skip this row.
            count += 1
            break
        masked_passage = ' '.join(cf_sentences)

        # Tokenize each passage once and reuse the token lists below.
        passage_tokens = nltk.word_tokenize(passage)
        masked_tokens = nltk.word_tokenize(masked_passage)
        tok_diff = len(passage_tokens) - len(masked_tokens)

        # Variant 3: a passage of roughly the sentence-masked passage's token
        # length that still contains the answer sentence. Drop tok_diff tokens
        # from the end of the masked passage, then re-insert the answer sentence
        # at its original sentence index. An explicit stop index keeps every token
        # when tok_diff <= 0 (`[:-tok_diff]` returned [] for tok_diff == 0).
        # NOTE(review): ' '.join of word tokens does not restore the original
        # spacing, and re-split sentence indices may not align exactly.
        keep = max(len(masked_tokens) - tok_diff, 0)
        rest_sents = nltk.sent_tokenize(' '.join(masked_tokens[:keep]))
        trunc_passage = ' '.join(rest_sents[:sent_idx] + [sentence] + rest_sents[sent_idx:])

        short_answer = copy.deepcopy(answers)
        answer_sentence = sentence
        answer_percent = len(nltk.word_tokenize(answer_sentence)) / len(passage_tokens) * 100
        masked = True
        break

    if masked:
        masked_row = copy.deepcopy(row)
        masked_row['qid'] = row_idx
        masked_row['short_answers'] = short_answer
        masked_row['answer_sentence'] = answer_sentence
        masked_row['answer_mask_passage'] = answer_masked_passage
        masked_row['sentence_mask_passage'] = masked_passage
        masked_row['answer_passage_trunc'] = trunc_passage
        masked_row['answer_percent'] = answer_percent
        masked_row['token_difference'] = tok_diff
        masked_gold_infos.append(masked_row)

In [None]:
import csv
# One tab-separated row per masked gold passage, no header row. The third
# column is written empty -- NOTE(review): presumably a placeholder expected by a
# downstream reader; confirm. `short_answers` is a list and is written as its repr.
# newline='' lets the csv module control line endings (avoids extra '\r' on
# Windows); UTF-8 avoids encode errors on non-ASCII passages under other locales.
with open('que_psg_ans_counterfactual_gold_info.csv', 'w', newline='', encoding='utf-8') as f:
    writer = csv.writer(f, delimiter='\t')
    writer.writerows(
        [
            line['qid'],
            line['question'],
            '',
            line['context'],
            line['title'],
            line['answer_mask_passage'],
            line['sentence_mask_passage'],
            line['answer_passage_trunc'],
            line['short_answers'],
            line['answer_sentence'],
            line['answer_percent'],
            line['token_difference'],
        ]
        for line in masked_gold_infos
    )

### Synthesizing Counterfactual Passages for Training

In [None]:
import json
# Training rows: a JSON list whose items carry 'answers' and 'positive_ctxs'.
# JSON is UTF-8, so do not rely on the platform's default text encoding.
with open(TRAIN_FILE, 'r', encoding='utf-8') as f:
    train_data = json.load(f)

In [None]:
import spacy
import copy
import re
from tqdm import tqdm
import nltk
nlp = spacy.load("en_core_web_sm")  # NOTE(review): `nlp` is not used in this cell

# For every training row, pair each positive context with a counterfactual
# negative made of its answer-free sentences. Positive contexts whose every
# sentence contains an answer are dropped together with their (empty) negative.
cf_data = list()

for row in tqdm(train_data):
    answers = row['answers']
    entry = copy.deepcopy(row)
    entry['positive_ctxs'] = list()
    entry['cf_negative_ctxs'] = list()

    for ctx in row['positive_ctxs']:
        original_title = ctx['title']
        original_sentences = nltk.tokenize.sent_tokenize(ctx['text'])

        nonanswer_sentences = [
            sentence for sentence in original_sentences
            if not find_and_mask(answers, sentence)
        ]

        # Remove literal (case-sensitive) answer occurrences from the title.
        # Answers are plain text, not regexes, so str.replace is used instead of
        # re.sub; this cannot raise on metacharacters such as "C++". A falsy
        # title (e.g. None or '') is passed through unchanged.
        nonanswer_title = original_title
        if original_title:
            for ans in answers:
                if isinstance(ans, str):
                    nonanswer_title = nonanswer_title.replace(ans, '')

        if nonanswer_sentences:
            entry['cf_negative_ctxs'].append({
                'title': nonanswer_title,
                'text': ' '.join(nonanswer_sentences),
            })
            entry['positive_ctxs'].append(ctx)
    cf_data.append(entry)

In [None]:
# Persist the counterfactual training set as pretty-printed JSON.
with open(OUTPUT_CF_TRAIN_PATH, mode='w') as out_file:
    json.dump(cf_data, out_file, indent=4)