In [1]:
from __future__ import division
%matplotlib inline
import pandas as pd
import numpy as np
import sys
from collections import defaultdict
import matplotlib.pyplot as plt
from copy import deepcopy
# add the deep_disfluency module to the path
sys.path.append("../../../..")

In [2]:
from deep_disfluency.evaluation.disf_evaluation import incremental_output_disfluency_eval_from_file
from deep_disfluency.evaluation.disf_evaluation import final_output_disfluency_eval_from_file
from deep_disfluency.evaluation.eval_utils import get_tag_data_from_corpus_file
from deep_disfluency.evaluation.eval_utils import rename_all_repairs_in_line_with_index
from deep_disfluency.evaluation.eval_utils import sort_into_dialogue_speakers
from deep_disfluency.evaluation.results_utils import convert_to_latex

In [3]:
# Locations of every file the evaluation needs.
# The incremental output of each system is assumed to exist already.
experiment_dir = "../../../experiments"

# No partial words in these experiments (they were removed), so the
# file names carry no '_partial' infix.
partial_words = False
partial = '_partial' if partial_words else ''

# The evaluation files (as text files), in heldout/test order.
disf_dir = "../../../data/disfluency_detection/switchboard"
disfluency_files = [
    "{0}/swbd_disf_heldout{1}_data_timings.csv".format(disf_dir, partial),
    "{0}/swbd_disf_test{1}_data_timings.csv".format(disf_dir, partial),
]

# Systems to evaluate, given as sub-directories of experiment_dir.
allsystemsfinal = [
    "021/epoch_40",
    "041/epoch_16",
]

# Incremental Evaluation

In [4]:
# Incremental evaluation. As a side effect the evaluation function writes the
# final output file for each system/division (see outputfilename below); those
# files are what the final output evaluation reads afterwards.
DO_INCREMENTAL_EVAL = True
VERBOSE = False
if DO_INCREMENTAL_EVAL:
    all_incremental_results = {}
    all_incremental_error_dicts = {}
    for system in allsystemsfinal:
        print(system)
        hyp_dir = "/".join([experiment_dir, system])
        for division, disf_file in zip(["heldout", "test"], disfluency_files):
            print(" ".join(["*" * 30, division, "*" * 30]))
            IDs, timings, words, pos_tags, labels = get_tag_data_from_corpus_file(disf_file)
            # gold standard: dialogue ID -> (timings, words, POS tags, labels)
            gold_data = dict(zip(IDs, zip(timings, words, pos_tags, labels)))
            inc_filename = "{0}/swbd_disf_{1}{2}_data_output_increco.text".format(
                hyp_dir, division, partial)
            final_output_name = inc_filename.replace("_increco", "_final")
            results, error_analysis = incremental_output_disfluency_eval_from_file(
                inc_filename,
                gold_data,
                utt_eval=True,
                error_analysis=True,
                word=True,
                interval=False,
                outputfilename=final_output_name)
            if VERBOSE:
                for metric, value in results.items():
                    print(" ".join([str(metric), str(value)]))
            result_key = "{0}_{1}".format(division, system)
            all_incremental_results[result_key] = deepcopy(results)
            # the error analyses are only kept for the heldout data
            if "heldout" in division:
                all_incremental_error_dicts[result_key] = deepcopy(error_analysis)


021/epoch_40
****************************** heldout ******************************
loading data ../../../data/disfluency_detection/switchboard/swbd_disf_heldout_data_timings.csv
loaded 102 sequences
102 speakers
incremental output disfluency evaluation
word= True interval= False utt_eval= True




writing final output to file ../../../experiments/021/epoch_40/swbd_disf_heldout_data_output_final.text
****************************** test ******************************
loading data ../../../data/disfluency_detection/switchboard/swbd_disf_test_data_timings.csv
loaded 100 sequences
100 speakers
incremental output disfluency evaluation
word= True interval= False utt_eval= True
writing final output to file ../../../experiments/021/epoch_40/swbd_disf_test_data_output_final.text
041/epoch_16
****************************** heldout ******************************
loading data ../../../data/disfluency_detection/switchboard/swbd_disf_heldout_data_timings.csv
loaded 102 sequences
102 speakers
incremental output disfluency evaluation
word= True interval= False utt_eval= True
writing final output to file ../../../experiments/041/epoch_16/swbd_disf_heldout_data_output_final.text
****************************** test ******************************
loading data ../../../data/disfluency_detection/switc

In [6]:
# Summary table of the incremental metrics on the test division for both
# systems; falls back to a plain message when the incremental evaluation
# was switched off.
if DO_INCREMENTAL_EVAL:
    display_results = {}
    display_results['RNN (window length=2) (+ POS)'] = all_incremental_results['test_021/epoch_40']
    display_results['LSTM (window length=2) (+ POS)'] = all_incremental_results['test_041/epoch_16']
    final = convert_to_latex(
        display_results,
        eval_level=['word'],
        inc=True,
        utt_seg=False,
        only_include=[
            't_t_detection_<rms_word',
            't_t_detection_<rps_word',
            'edit_overhead_rel_word',
        ])
else:
    final = "No incremental results here"
final

Unnamed: 0,System (eval. method),TTD$_{rms}$ (word),TTD$_{rps}$ (word),EO (word)
0,LSTM (window length=2) (+ POS) (transcript),2.369,1.083,3.136
1,RNN (window length=2) (+ POS) (transcript),2.413,1.087,3.259


# Final output evaluation

In [7]:
# Final output evaluation: score each system's final output file against the
# gold standard, for both the heldout and the test division. This only needs
# the final output files, so it is faster than the incremental evaluation.
all_results = {}
all_error_dicts = {}
VERBOSE = False
word = True   # word-level analyses
error = True  # get an error analysis
for system in allsystemsfinal:
    print(system)
    hyp_dir = experiment_dir
    for division, disf_file in zip(["heldout", "test"], disfluency_files):
        print(" ".join(["*" * 30, division, "*" * 30]))
        IDs, timings, words, pos_tags, labels = get_tag_data_from_corpus_file(disf_file)
        # gold standard: dialogue ID -> (timings, words, POS tags, labels), with
        # the labels passed through rename_all_repairs_in_line_with_index first
        gold_data = {}
        for dialogue, timing, tokens, pos, tags in zip(IDs, timings, words, pos_tags, labels):
            renamed_tags = rename_all_repairs_in_line_with_index(list(tags))
            gold_data[dialogue] = (timing, tokens, pos, renamed_tags)

        hyp_file = "/".join([
            hyp_dir,
            system,
            "swbd_disf_{0}{1}_data_output_final.text".format(division, partial),
        ])
        # speaker_rate_dict is returned by the evaluation but not used here
        results, speaker_rate_dict, error_analysis = final_output_disfluency_eval_from_file(
            hyp_file,
            gold_data,
            utt_eval=False,
            error_analysis=error,
            word=word,
            interval=False,
            outputfilename=None)
        if VERBOSE:
            for metric, value in results.items():
                print(" ".join([str(metric), str(value)]))
        result_key = "{0}_{1}".format(division, system)
        all_results[result_key] = deepcopy(results)
        # the error analyses are only kept for the heldout data
        if "heldout" in division:
            all_error_dicts[result_key] = deepcopy(error_analysis)


021/epoch_40
****************************** heldout ******************************
loading data ../../../data/disfluency_detection/switchboard/swbd_disf_heldout_data_timings.csv
loaded 102 sequences
102 speakers
final output disfluency evaluation
word= True interval= False utt_eval= False
word
****************************** test ******************************
loading data ../../../data/disfluency_detection/switchboard/swbd_disf_test_data_timings.csv
loaded 100 sequences
100 speakers
final output disfluency evaluation
word= True interval= False utt_eval= False
word
041/epoch_16
****************************** heldout ******************************
loading data ../../../data/disfluency_detection/switchboard/swbd_disf_heldout_data_timings.csv
loaded 102 sequences
102 speakers
final output disfluency evaluation
word= True interval= False utt_eval= False
word
****************************** test ******************************
loading data ../../../data/disfluency_detection/switchboard/swbd_di

In [8]:
# Show which "<division>_<system>" entries the final output evaluation produced.
print(list(all_results))

['heldout_041/epoch_16', 'heldout_021/epoch_40', 'test_021/epoch_40', 'test_041/epoch_16']


In [9]:
# Summary table of the final output F-scores on the test division for both systems.
display_results = {}
display_results['RNN (window length=2) (+ POS)'] = all_results['test_021/epoch_40']
display_results['LSTM (window length=2) (+ POS)'] = all_results['test_041/epoch_16']
final = convert_to_latex(
    display_results,
    eval_level=['word'],
    inc=False,
    utt_seg=False,
    only_include=[
        'f1_<rm_word',
        'f1_<rps_word',
        'f1_<e_word',
    ])
final

Unnamed: 0,System (eval. method),$F_{rm}$ (per word),$F_{rps}$ (per word),$F_{e}$ (per word)
0,LSTM (window length=2) (+ POS) (transcript),0.62,0.721,0.887
1,RNN (window length=2) (+ POS) (transcript),0.627,0.721,0.856


# Error Analysis

In [10]:
# Error analyses on exact match ('<rms') and on getting the right repair start ('<rps').
# For every heldout system and target tag this prints, per repair type and per
# repair length, "key & (TP/(TP+FN)) & F-score\\" rows followed by the overall
# recall over the printed rows.
target_tags = ['<rms', '<rps']

REPAIR_TYPES = ["rep", "del", "sub"]  # the only keys reported in 'type' mode
MAX_REPORT_ROWS = 50                  # cap on the rows printed per report
MAX_REPARANDUM_WORDS = 8              # true positives longer than this get flagged
SHOW_ERROR_TYPE_SUMMARY = False       # optional extra summary, disabled by default
ERROR_TYPE_SUMMARY_ROWS = 20          # cap on the rows of that extra summary

# Onset words that are counted under a shared label; any other word is
# counted under the word itself.
ONSET_WORD_GROUPS = [
    ("CC", ["and", "or", "but", "so", "because", "that", "although"]),
    ("subj", ["i", "we", "they", "im", "ive", "he", "she", "id"]),
    ("proper_other", ["you", "the"]),
    ("ack", ["yeah", "no", "okay", "yes", "right", "uh-huh"]),
    ("it", ["it", "its"]),
]


def _onset_after_marker(gold_context):
    """Return the token directly after the first "+|+" in gold_context.

    Returns "" when there is no "+|+" token, or when it is the last token
    (an unconditional lookup of the following token would raise IndexError).
    """
    for position, token in enumerate(gold_context):
        if token == "+|+":
            if position + 1 < len(gold_context):
                return gold_context[position + 1]
            return ""
    return ""


def _onset_label(onset):
    """Map an onset token ("word|..." format) to the key it is counted under."""
    if "<e" in onset:
        return "<e"
    onset_word = onset.split("|")[0]
    for label, members in ONSET_WORD_GROUPS:
        # words containing "$" are also counted as "proper_other"
        if onset_word in members or (label == "proper_other" and "$" in onset_word):
            return label
    return onset_word


def _count_repairs(repairs, outcome):
    """Count the repairs of one outcome ("TP", "FP" or "FN").

    Returns (typedict, lendict): typedict counts onset labels and repair
    types, lendict counts reparandum + interregnum lengths.
    """
    # NOTE(review): onset labels and repair types share typedict, so an onset
    # word spelled "rep", "del" or "sub" would be added to that repair type's
    # count - confirm this cannot occur in the data.
    typedict = defaultdict(int)
    lendict = defaultdict(int)
    for repair in repairs:
        typedict[_onset_label(_onset_after_marker(repair.gold_context))] += 1
        if outcome == "TP" and len(repair.reparandumWords) > MAX_REPARANDUM_WORDS:
            # should not be getting any over 8 words
            print("** overlength repair!")
            print(repair)
        lendict[len(repair.reparandumWords) + len(repair.interregnumWords)] += 1
        if repair.type:
            typedict[repair.type] += 1
    return typedict, lendict


def _ratio(hits, total):
    """Return hits/total as a float, or 0 when there are no hits."""
    if hits == 0:
        return 0
    return hits / float(total)


def _print_rate_table(error, mode):
    """Print one row per key of error[...][mode], then the overall recall.

    Keys come from the TP and FN counts only; in 'type' mode only the keys in
    REPAIR_TYPES are reported. At most MAX_REPORT_ROWS rows are printed.
    """
    tps = error['TP'][mode]
    fns = error['FN'][mode]
    fps = error['FP'][mode]
    keys = sorted(set(tps) | set(fns))
    if mode == 'type':
        keys = [key for key in keys if key in REPAIR_TYPES]

    total_tps = 0
    total_fns = 0
    for key in keys[:MAX_REPORT_ROWS]:
        tp = tps.get(key, 0)
        fn = fns.get(key, 0)
        fp = fps.get(key, 0)
        recall = _ratio(tp, tp + fn)
        precision = _ratio(tp, tp + fp)
        if precision == 0 or recall == 0:
            fscore = 0
        else:
            fscore = (2 * (precision * recall)) / (precision + recall)
        total_tps += tp
        total_fns += fn
        print(" & ".join([str(key),
                          "({0}/{1})".format(tp, tp + fn),
                          '{0:.3f}'.format(fscore)]) + "\\\\")
    # overall recall over the printed rows; 0.0 instead of a ZeroDivisionError
    # when nothing was counted
    gold_total = total_tps + total_fns
    print(total_tps / float(gold_total) if gold_total else 0.0)


def _print_error_type_summary(error, mode):
    """Print the most frequent TP keys with their share of all TP + FN counts."""
    tps = error['TP'][mode]
    fns = error['FN'][mode]
    print("*" * 30)
    total = sum(fns.values()) + sum(tps.values())
    if not total:
        return
    errormass = 0
    errortotal = 0
    ranked = sorted(tps.items(), key=lambda item: item[1], reverse=True)
    for key, count in ranked[:ERROR_TYPE_SUMMARY_ROWS]:
        share = count / float(total)
        print(" ".join([str(key), "&", str(count), "&", '{0:.2f}'.format(share)]))
        errormass += share * 100
        errortotal += count
    print(" ".join(["total &", str(errortotal), "&", '{0:.2f}'.format(errormass)]))


for div, all_error in all_error_dicts.items():
    # skip entries that hold a bool instead of an error dict, and any test division
    if isinstance(all_error, bool) or "test" in div:
        continue
    for tag, errors in all_error.items():
        if tag not in target_tags:
            continue
        print(" ".join(["*" * 30, div, tag, "*" * 30]))
        # start every outcome with empty counts so that an outcome missing
        # from `errors` reports zeros instead of raising KeyError
        error = {}
        for outcome in ("TP", "FP", "FN"):
            error[outcome] = {'len': defaultdict(int), 'type': defaultdict(int)}
        for k, v in errors.items():
            typedict, lendict = _count_repairs(v, k)
            error[k] = {'len': lendict, 'type': typedict}

        for mode in ['type', 'len']:
            print(mode + " " + "*" * 30)
            _print_rate_table(error, mode)
            if SHOW_ERROR_TYPE_SUMMARY:
                _print_error_type_summary(error, mode)

****************************** heldout_041/epoch_16 <rps ******************************
type ******************************
del & (17/101) & 0.286\\
rep & (821/1001) & 0.894\\
sub & (368/844) & 0.573\\
0.6197327852
len ******************************
0 & (0/1) & 0.000\\
1 & (850/1096) & 0.856\\
2 & (246/467) & 0.659\\
3 & (68/205) & 0.482\\
4 & (25/88) & 0.424\\
5 & (8/46) & 0.296\\
6 & (7/26) & 0.424\\
7 & (0/8) & 0.000\\
8 & (2/5) & 0.571\\
9 & (0/1) & 0.000\\
10 & (0/1) & 0.000\\
11 & (0/1) & 0.000\\
15 & (0/1) & 0.000\\
0.6197327852
****************************** heldout_041/epoch_16 <rms ******************************
type ******************************
del & (1/93) & 0.021\\
rep & (868/1069) & 0.871\\
sub & (235/784) & 0.411\\
0.567317574512
len ******************************
1 & (828/1095) & 0.827\\
2 & (218/469) & 0.580\\
3 & (44/207) & 0.319\\
4 & (12/86) & 0.222\\
5 & (2/46) & 0.078\\
6 & (0/26) & 0.000\\
7 & (0/8) & 0.000\\
8 & (0/5) & 0.000\\
9 & (0/1) & 0.000\\
10 & (0/1) &