In [2]:
from collections import defaultdict
import itertools
import json
import sys
import os
from datasets import load_dataset

from matplotlib import pyplot as plt

# Root of the multidoc2dial checkout. Override it with the MULTIDOC2DIAL_ROOT environment
# variable instead of editing this absolute path. os.path.join(..., '') guarantees a trailing
# separator so string concatenation such as `_PATH + 'data/...'` keeps working.
_PATH = os.path.join(
    os.environ.get(
        'MULTIDOC2DIAL_ROOT',
        '/home/sireesh/current_semester/11-797_QA/11-797-multidoc2dial/multidoc2dial/',
    ),
    '',
)
# Make the repo importable (e.g. `scripts.rag.utils_rag`); the guard keeps re-runs of this
# cell from appending the same entry again and again.
if _PATH not in sys.path:
    sys.path.append(_PATH)

In [3]:
def _load_json(path):
    """Parse the JSON file at `path`, closing the file handle afterwards."""
    with open(path, 'r') as f:
        return json.load(f)


def _read_lines(path):
    """Return every line of the text file at `path`, stripped of surrounding whitespace."""
    with open(path, 'r') as f:
        return [line.strip() for line in f]


val_data = _load_json(_PATH + 'data/multidoc2dial/multidoc2dial_dial_validation.json')
docs = _load_json(_PATH + 'data/multidoc2dial/multidoc2dial_doc.json')

# Per-query validation files; the analysis below treats row i of every val.* file as
# describing the same query turn (lengths are checked in the next cells).
_GEN_DIR = _PATH + 'data/mdd_all/dd-generation-structure/'

# dialogue act of each query -- TODO: check whether these are all `query_condition`
DAs      = _read_lines(_GEN_DIR + 'val.da')
# domain - dmv, ssa, studentaid, va
domains  = _read_lines(_GEN_DIR + 'val.domain')
# gold passage ID to answer this question
pids     = _read_lines(_GEN_DIR + 'val.pids')
# query IDs, in the format {ID}_{turn}
qids     = _read_lines(_GEN_DIR + 'val.qids')
# the actual text of the query
sources  = _read_lines(_GEN_DIR + 'val.source')
# Gold responses
targets  = _read_lines(_GEN_DIR + 'val.target')
# Gold article titles
titles   = _read_lines(_GEN_DIR + 'val.titles')

In [4]:
# Sanity check: every per-query validation file should have the same number of rows.
tuple(len(column) for column in (DAs, domains, pids, qids, sources, targets, titles))

(4201, 4201, 4201, 4201, 4201, 4201, 4201)

In [5]:
# Each line of Results_retrieval.txt presumably lists the retrieved titles and then the
# retrieved pids (10 of each in the printed example further down). After splitting on both
# '####' and tabs, the first half of the fields are titles and the second half are pids.
with open('Results_retrieval.txt', 'r') as f:
    retr_results = [line.replace('####', '\t').replace('\n', '').split('\t') for line in f]
grounding_titles = [fields[:len(fields) // 2] for fields in retr_results]
grounding_pids = [fields[len(fields) // 2:] for fields in retr_results]

# One generated response per line.
with open('Results_generation.txt', 'r') as f:
    hypos = [line.strip() for line in f]

# Model outputs must be row-aligned with the gold val.* files.
assert len(DAs) == len(domains) == len(pids) == len(qids) == len(sources) == len(targets) == len(titles) == len(retr_results) == len(hypos)

In [6]:
# Number of generated responses (one per validation query).
n_hypos = len(hypos)
n_hypos

4201

In [7]:
from scripts.rag.utils_rag import exact_match_score, f1_score

def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
    """Score `prediction` against every reference and return the best score.

    Args:
        metric_fn: callable ``metric_fn(prediction, reference) -> score``.
        prediction: the predicted response.
        ground_truths: iterable of references; must be non-empty (``max`` raises
            ValueError otherwise).
    """
    scores = [metric_fn(prediction, reference) for reference in ground_truths]
    return max(scores)

from datasets import load_metric
metric = load_metric("sacrebleu")

generation_metrics = []                 # one metrics dict per query, row-aligned with qids
metrics_by_query = defaultdict(list)    # dialogue id -> metrics dicts of its turns
metrics_by_turn_no = defaultdict(list)  # turn id (kept as a string) -> metrics dicts

for qid, prediction, ground_truth in zip(qids, hypos, targets):
    # qids look like '{dialogue_id}_{turn}'; split only on the last '_' so an id that
    # itself contains '_' cannot break the unpacking.
    query_group, turn_id = qid.rsplit("_", 1)
    em_i = metric_max_over_ground_truths(exact_match_score, prediction, [ground_truth])
    f1_i = metric_max_over_ground_truths(f1_score, prediction, [ground_truth])
    # Sentence-level sacreBLEU of this single example. Averaging these scores later is
    # NOT the same as corpus-level BLEU over all predictions.
    sbleu_i = metric.compute(predictions=[prediction], references=[[ground_truth]])["score"]
    metrics_dict = {'EM' : em_i, 'F1' : f1_i, 'sacrebleu' : sbleu_i}
    generation_metrics.append(metrics_dict)
    metrics_by_query[query_group].append(metrics_dict)
    metrics_by_turn_no[turn_id].append(metrics_dict)


In [8]:
def text2line(text):
    """Flatten `text` onto a single line.

    Newlines, carriage returns and tabs are each replaced by one space, then leading
    and trailing whitespace is stripped.
    """
    return text.translate(str.maketrans("\n\r\t", "   ")).strip()

def split_text_section(spans, title):
    """Group a document's spans into section-level passages.

    Consecutive spans are merged into one passage while they share the previous span's
    ``id_sec`` or (stripped) ``title``; a span matching neither starts a new passage.
    Each passage is prefixed with its parent-title path and " // ", then flattened with
    ``text2line``.

    Args:
        spans: mapping of span id -> span dict, iterated in insertion order. Each span
            needs "id_sec", "title", "text_sp" and "parent_titles" (a list of dicts
            with a "text" key).
        title: the document title, used as the header when a span has at most one
            parent title.

    Returns:
        ``(passages, subtitles)``: two lists of equal length, where ``subtitles[i]`` is
        the parent-title path of ``passages[i]``. When that passage's last span has at
        most one parent title, the subtitle is the raw ``title`` instead.
    """
    def clean(raw_title):
        # '/' -> '-' and drop everything from the first '#' (e.g. a '#3_0' doc-id suffix).
        return raw_title.replace("/", "-").split("#", 1)[0]

    def parent_titles_of(span):
        # Only spans with more than one parent title carry their own header path.
        if len(span["parent_titles"]) > 1:
            return [clean(ele["text"]) for ele in span["parent_titles"]]
        return None

    def passage_text(buff, span):
        # Header comes from `span` (the last span of the passage) or the cleaned doc title.
        header = parent_titles_of(span) or [clean(title)]
        body = " ".join(buff).replace("\n", " ")
        return text2line(" / ".join(header) + " // " + body)

    def subtitle_of(span):
        # Unlike the passage header, the fallback here is the raw (unsanitized) title.
        parents = parent_titles_of(span)
        return " / ".join(parents) if parents else title

    buff = []
    pre_sec, pre_title, pre_span = None, None, None
    passages = []
    subtitles = []

    for span in spans.values():
        same_section = pre_sec == span["id_sec"] or pre_title == span["title"].strip()
        if buff and not same_section:
            # Flush the finished passage. Header AND subtitle both come from pre_span, the
            # last span in it, so subtitles[i] always describes passages[i].
            passages.append(passage_text(buff, pre_span))
            subtitles.append(subtitle_of(pre_span))
            buff = []
        buff.append(span["text_sp"])
        pre_sec = span["id_sec"]
        pre_span = span
        pre_title = span["title"].strip()
    if buff:
        passages.append(passage_text(buff, pre_span))
        subtitles.append(subtitle_of(pre_span))
    return passages, subtitles

In [9]:
# Rebuild the section-level passage collection. Passage index i in `all_passages` is what
# the retrieval pids refer to (e.g. pid 1806 below is the gold passage of the first query);
# doc_passages[doc_id] = (index of the doc's first passage, number of passages).
doc_passages = {}
all_passages = []
start_idx = 0
for domain_docs in docs['doc_data'].values():
    for ex in domain_docs.values():
        passages, _ = split_text_section(ex["spans"], ex["title"])  # subtitles unused here
        all_passages.extend(passages)
        doc_passages[ex["doc_id"]] = (start_idx, len(passages))
        start_idx += len(passages)

# passage index -> {"text": passage text, "title": owning doc_id (e.g. '...#3_0')}
passage_map = {}
for doc_id, (psg_start_ix, n_psgs) in doc_passages.items():
    for offset in range(n_psgs):
        passage_map[psg_start_ix + offset] = {"text": all_passages[psg_start_ix + offset], "title": doc_id}

# A repeated doc_id would overwrite its doc_passages entry and silently drop passages here.
assert len(passage_map) == len(all_passages), "duplicate doc_id: passages missing from passage_map"

In [10]:
# Inspect one reconstructed passage: 1806 is the gold pid of the first validation query.
example_pid = 1806
passage_map[example_pid]

{'text': 'Top 5 DMV Mistakes and How to Avoid Them //   3. Letting Insurance Lapse   Because we all pay indirectly for crashes involving uninsured motorists ,  New York State requires every motorist to maintain auto insurance every single day a vehicle is registered.  DMV works with insurance companies to electronically monitor your insurance coverage ,  and we know when coverage is dropped for any reason.  When that happens ,  we mail you an insurance inquiry letter to allow you to clear up the problem.  We send 500,000 inquiry letters a year.  If the inquiry letter does not resolve the problem ,  we must suspend the vehicle registration and ,  if it persists, your driver license!We suspend 300,000 registrations a year for failure to maintain insurance.  If you fail to maintain an updated address with us ,  you won t learn that you have an insurance problem ,  and we will suspend your registration and license.  Make sure you turn in your vehicle s license plates at DMV before you canc

In [None]:
# Group validation rows by dialogue: dialogue_id -> row indices (into qids and the other
# row-aligned lists) of its turns, in file order.
# NOTE(review): the original loop never filled dialogue_map; this structure is a best guess.
dialogue_map = {}

# enumerate() is required: iterating qids directly yields bare id strings, and unpacking a
# string into (i, qid) raises ValueError.
for i, qid in enumerate(qids):
    dialogue_id = qid.rsplit("_", 1)[0]  # qids look like '{dialogue_id}_{turn}'
    dialogue_map.setdefault(dialogue_id, []).append(i)


In [13]:
# Aggregate generation metrics over the first N_EXAMPLES validation queries that pass
# keep_example(), optionally printing each one. 'sacrebleu' is the mean of per-example
# sentence-level scores, not corpus-level BLEU. EM and F1 are multiplied by 100
# (presumably they are in [0, 1]) so all three print on a similar scale.
N_EXAMPLES = 5   # number of leading queries to scan; set to None to use every query
DO_PRINT = True  # print per-example details (set False for full runs)


def keep_example(k):
    """Filtering condition: return False to leave example `k` out of the aggregate."""
    return True


def show_example(k):
    """Print gold data, retrieval results, the prediction and its metrics for example `k`."""
    print('Query ID       : ', qids[k])
    print('Domain         : ', domains[k])
    print('Query          : ', sources[k])
    print('Dialogue Act   : ', DAs[k])
    print('Passage (Gold) : ', pids[k])
    print('Title   (Gold) : ', titles[k])
    print('Response(Gold) : ', targets[k])
    print('PID    (Retr)  : ', grounding_pids[k])
    print('Titles (Retr)  : ', grounding_titles[k])
    print('Pred Response  : ', hypos[k])
    print('Perf (gen)     : ', generation_metrics[k])
    print()


overall_perf = {'EM' : 0.0, 'F1' : 0.0, 'sacrebleu' : 0.0}
total = 0
# Clamp so a too-large N_EXAMPLES cannot index past the end of the data.
n_scan = len(qids) if N_EXAMPLES is None else min(N_EXAMPLES, len(qids))
for k in range(n_scan):
    if not keep_example(k):
        continue
    overall_perf['EM'] += int(generation_metrics[k]['EM'])
    overall_perf['F1'] += generation_metrics[k]['F1']
    overall_perf['sacrebleu'] += generation_metrics[k]['sacrebleu']
    total += 1
    if DO_PRINT:
        show_example(k)

assert total > 0, "no examples passed keep_example(); nothing to average"

em = 100.0 * overall_perf['EM'] / total
f1 = 100.0 * overall_perf['F1'] / total
sb = overall_perf['sacrebleu'] / total

print(f"F1: {f1: .2f}")
print(f"EM: {em: .2f}")
print(f"sacrebleu: {sb: .2f}")
print(f"all: {f1: .2f} & {em: .2f} & {sb: .2f} ")

Query ID       :  1409501a35697e0ce68561e29577b90a_1
Domain         :  dmv
Query          :  My insurance ended so what should i do[SEP]
Dialogue Act   :  query_condition
Passage (Gold) :  1806
Title   (Gold) :  Top 5 DMV Mistakes and How to Avoid Them#3_0
Response(Gold) :  You will need to get insurance or we will suspend your registration and license
PID    (Retr)  :  ['1806', '2829', '2825', '2441', '2340', '2456', '1421', '2346', '2004', '2455']
Titles (Retr)  :  ['Top 5 DMV Mistakes and How to Avoid Them#3_0', 'Pay insurance lapse civil penalty#1_0', 'Pay insurance lapse civil penalty#1_0', 'Provide proof of insurance#1_0', 'Insurance lapses#3_0', 'Respond to DMV insurance letters and orders#3_0', 'Access Your VA Life Insurance Policy Online | Veterans Affairs#1_0', 'Insurance lapses#3_0', 'How insurance premium reduction works#3_0', 'Respond to DMV insurance letters and orders#3_0']
Pred Response  :  you want to know what to do if your insurance lapsed?
Perf (gen)     :  {'EM': F

In [15]:
# Dump the first dialogue of every domain turn by turn, to eyeball the raw validation format.
turn_fields = [
    ("turn_id      : ", 'turn_id'),
    ("role         : ", 'role'),
    ("Dialogue Act : ", 'da'),
    ("Utterance    : ", 'utterance'),
    ("Reference    : ", 'references'),
]
for domain, flows in val_data['dial_data'].items():
    for flow in flows[:1]:  # only the first dialogue of each domain
        print(flow['dial_id'])
        print("---------------------")
        for turn in flow['turns']:
            for label, key in turn_fields:
                print(label, turn[key])
            print()


1409501a35697e0ce68561e29577b90a
---------------------
turn_id      :  1
role         :  user
Dialogue Act :  query_condition
Utterance    :  My insurance ended so what should i do
Reference    :  [{'label': 'precondition', 'id_sp': '23', 'doc_id': 'Top 5 DMV Mistakes and How to Avoid Them#3_0'}]

turn_id      :  2
role         :  agent
Dialogue Act :  respond_solution
Utterance    :  You will need to get insurance or we will suspend your registration and license
Reference    :  [{'label': 'solution', 'id_sp': '24', 'doc_id': 'Top 5 DMV Mistakes and How to Avoid Them#3_0'}, {'label': 'solution', 'id_sp': '25', 'doc_id': 'Top 5 DMV Mistakes and How to Avoid Them#3_0'}, {'label': 'solution', 'id_sp': '26', 'doc_id': 'Top 5 DMV Mistakes and How to Avoid Them#3_0'}]

turn_id      :  3
role         :  user
Dialogue Act :  query_condition
Utterance    :  Don't do that I'll get insurance
Reference    :  [{'label': 'precondition', 'id_sp': '28', 'doc_id': 'Top 5 DMV Mistakes and How to Avoid T