In [1]:
#Computes the accuracies for the outputs from the EACL 2017 experiments on
#joint incremental utterance segmentation and disfluency detection
#this assumes the experiments are in simple_rnn_disf/rnn_disf_detection/experiments/
from __future__ import division
%matplotlib inline
import pandas as pd
import numpy as np
import sys
from collections import defaultdict
import matplotlib.pyplot as plt

from copy import deepcopy
sys.path.append("../../../../")
# from mumodo.mumodoIO import open_intervalframe_from_textgrid

In [2]:
#add the evaluation module functions
from deep_disfluency.evaluation.disf_evaluation import incremental_output_disfluency_eval_from_file
from deep_disfluency.evaluation.disf_evaluation import final_output_disfluency_eval_from_file
from deep_disfluency.evaluation.eval_utils import get_tag_data_from_corpus_file
from deep_disfluency.evaluation.eval_utils import rename_all_repairs_in_line_with_index
from deep_disfluency.evaluation.eval_utils import sort_into_dialogue_speakers
from deep_disfluency.evaluation.results_utils import convert_to_latex

In [3]:
# Locations of every file this notebook reads.
# The incremental output of each system is assumed to exist already.
experiment_dir = "../../../experiments"

# Selects the '_partial' variants of the evaluation and output file names.
# NOTE(review): the previous comment on this flag read "No partial words in
# these experiments, removed", which does not match the flag being True --
# confirm which is intended.
partial_words = True
partial = '_partial' if partial_words else ''

# the evaluation files (as text files), heldout first and test second
disf_dir = "../../../data/disfluency_detection/switchboard"
disf_file_template = disf_dir + "/swbd_disf_{0}{1}_data_timings.csv"
disfluency_files = [
    disf_file_template.format("heldout", partial),
    disf_file_template.format("test", partial),
]

# (experiment sub-directory, display name) for every system to evaluate
allsystemsfinal = [
    ("033/epoch_45", 'RNN (timing)'),
    ("035/epoch_6", 'LSTM (timing)'),
    # ("036/epoch_15", 'LSTM (complex tags)'),
    # ("037/epoch_6", 'LSTM (disf only) (timing)'),
    # ("038/epoch_8", 'LSTM (TTO only) (timing)'),
]

In [4]:
# Read the "ASR good ranges" files for the heldout and test divisions into
# lists with one entry per line (trailing newline removed).
# Each file is read inside a `with` block so its handle is closed as soon as
# the list has been built.
asr_ranges_dir = "../../../data/disfluency_detection/swda_divisions_disfluency_detection"
with open(asr_ranges_dir + "/swbd_disf_heldout_ASR_good_ranges.text") as ranges_file:
    good_asr_heldout = [line.strip("\n") for line in ranges_file]
with open(asr_ranges_dir + "/swbd_disf_test_ASR_good_ranges.text") as ranges_file:
    good_asr_test = [line.strip("\n") for line in ranges_file]

# Incremental Evaluation

In [5]:
# create final output files for the final output evaluation (and do incremental evaluation first:
# NB this takes a while!
DO_INCREMENTAL_EVAL = True
if DO_INCREMENTAL_EVAL:
    all_incremental_results = {}
    all_incremental_error_dicts = {}
    for system, system_name in allsystemsfinal:
        print system
        #if 'complex' in system: break
        hyp_dir = experiment_dir + "/" + system
        #hyp_dir = experiment_dir
        for division, disf_file in zip(["heldout","test"], disfluency_files):
            print "*" * 30, division, "*" * 30
            IDs, timings, words, pos_tags, labels = get_tag_data_from_corpus_file(disf_file)
            gold_data = {} #map from the file name to the data
            for dialogue,a,b,c,d in zip(IDs, timings, words, pos_tags, labels):
                # if "asr" in division and not dialogue[:4] in good_asr: continue
                gold_data[dialogue] = (a,b,c,d)
            inc_filename = hyp_dir + "/swbd_disf_{0}{1}_data_output_increco".format(division, partial) + ".text"
            final_output_name = inc_filename.replace("_increco", "_final")
            results, error_analysis = incremental_output_disfluency_eval_from_file(
                                                                             inc_filename,
                                                                             gold_data,
                                                                             utt_eval=True,
                                                                             error_analysis=True,
                                                                             word=True,
                                                                             interval=True,
                                                                             outputfilename=final_output_name)
            for k,v in results.items():
                print k,v
            all_incremental_results[division + "_" + system] = deepcopy(results)
            if "heldout" in division:
                # only do the error analyses on the heldout data
                all_incremental_error_dicts[division + "_" + system] = deepcopy(error_analysis)


033/epoch_45
****************************** heldout ******************************
loading data ../../../data/disfluency_detection/switchboard/swbd_disf_heldout_partial_data_timings.csv
loaded 102 sequences
102 speakers
incremental output disfluency evaluation
word= True interval= True utt_eval= True




writing final output to file ../../../experiments/033/epoch_45/swbd_disf_heldout_partial_data_output_final.text
edit_overhead_rel_<rm None
t_t_detection_<e_interval 0.271064261422
t_t_detection_<rps_interval 0.189058147714
t_t_detection_<rms_interval nan
edit_overhead_rel_interval 2.93701049825
delayed_acc_<rm_4_interval None
t_t_detection_final_t/>_interval None
delayed_acc_<rm_4_word None
delayed_acc_<rm_1_word None
t_t_detection_final_t/>_word None
t_t_detection_t/>_interval 1.01139070944
edit_overhead_rel_tto None
t_t_detection_t/>_word 0.269863013699
delayed_acc_<rm_5_interval None
delayed_acc_<rm_2_interval None
delayed_acc_<rm_1_interval None
delayed_acc_<rm_mean_word None
delayed_acc_<rm_3_word None
delayed_acc_<rm_2_word None
delayed_acc_<rm_mean_interval None
processing_overhead_interval None
t_t_detection_<e_word 0.0169444444444
delayed_acc_<rm_6_interval None
t_t_detection_<rms_word nan
edit_overhead_rel_word 2.93701049825
delayed_acc_<rm_5_word None
delayed_acc_<rm_3_inter

In [6]:
# Table of selected incremental metrics on the test data, built by
# convert_to_latex; a placeholder string is displayed instead when the
# incremental evaluation was switched off.
final = "No incremental results here"
if DO_INCREMENTAL_EVAL:
    incremental_metrics = ['t_t_detection_t/>_interval',
                           't_t_detection_<rps_word',
                           'edit_overhead_rel_word']
    display_results = {
        # 'RNN joint task (timing)': all_incremental_results['test_033/epoch_45'],
        'LSTM joint task (timing)': all_incremental_results['test_035/epoch_6'],
    }
    final = convert_to_latex(display_results,
                             eval_level=['word'],
                             inc=True,
                             utt_seg=False,
                             only_include=incremental_metrics)
final

Unnamed: 0,System (eval. method),TTD$_{tto}$ (time in s),TTD$_{rps}$ (word),EO (word)
0,LSTM joint task (timing) (transcript),1.185,1.002,3.551


# Final output evaluation

In [7]:
VERBOSE = True
all_results = {}
all_error_dicts = {}
for system, system_name in allsystemsfinal:
    print system
    #if 'complex' in system: break
    hyp_dir = experiment_dir
    for division, disf_file in zip(["heldout", "test"],disfluency_files):
        #if not division == "heldout": continue
        print "*" * 30, division, "*" * 30
        IDs, timings, words, pos_tags, labels = get_tag_data_from_corpus_file(disf_file)
        gold_data = {} #map from the file name to the data
        for dialogue,a,b,c,d in zip(IDs, timings, words, pos_tags, labels):
            # if "asr" in division and not dialogue[:4] in good_asr: continue
            d = rename_all_repairs_in_line_with_index(list(d))
            gold_data[dialogue] = (a,b,c,d)

        #the below does just the final output evaluation, assuming a final output file, faster
        hyp_file = hyp_dir + '/' + system + "/" + "swbd_disf_{0}{1}_data_output_final.text".format(division,
                                                                                                        partial)
        word = True  # world-level analyses
        error = True # get an error analysis
        results,speaker_rate_dict,error_analysis = final_output_disfluency_eval_from_file(
                                                        hyp_file,
                                                        gold_data,
                                                        utt_eval="disf only" not in system_name,
                                                        error_analysis=error,
                                                        word=word,
                                                        interval=False,
                                                        outputfilename=None
                                                    )
        #the below does incremental and final output in one, also outputting the final outputs
        #derivable from the incremental output, takes quite a while
        for k,v in results.items():
            print k,v
        all_results[division + "_" + system] = deepcopy(results)
        if "heldout" in division:
            # only do the error analyses on the heldout data
            all_error_dicts[division + "_" + system] = deepcopy(error_analysis)


033/epoch_45
****************************** heldout ******************************
loading data ../../../data/disfluency_detection/switchboard/swbd_disf_heldout_partial_data_timings.csv
loaded 102 sequences
102 speakers
final output disfluency evaluation
word= True interval= False utt_eval= True
word
pearson_r_p_value_rps_number 2.00127665348e-57
p_<rm_word 0
spearman_rank_p_value_rps_rate_per_word 7.61361432954e-39
r_<rpnsub_word 0
pearson_r_p_value_rps_rate_per_utt 2.12135073355e-47
p_t/>_word 0.766506922258
f1_t/>_word 0.612292641429
r_<rm.<i.<rp_word 0.184617528215
DSER_word 69.8123229462
r_<i_word 0
f1_<rpndel_word 0
f1_<rp_word 0.510400616333
pearson_r_correl_rps_number 0.960620179499
p_<rpnrep_word 0
p_<i_word 0
r_<e_word 0.901325854156
p_<rpnsub_word 0
r_t/>_relaxed_word 0
SegER None
r_<rm_word 0
p_<rms_word 0
p_<rps_word 0.827997489014
r_<rpn_word 0
pearson_r_p_value_rps_rate_per_word 1.03789571867e-47
p_<rpndel_word 0
r_<rms_word 0
f1_<e_word 0.913200723327
NIST_SU None
r_<rp

In [8]:
print all_results.keys()

['heldout_033/epoch_45', 'heldout_035/epoch_6', 'test_033/epoch_45', 'test_035/epoch_6']


In [9]:
# Table of selected word-level F-scores on the test data, built by
# convert_to_latex from the final output results.
final_metrics = ['f1_t/>_word', 'f1_<rps_word', 'f1_<e_word']
display_results = {
    # 'RNN (joint task)': all_results['test_033/epoch_45'],
    'LSTM (joint task)': all_results['test_035/epoch_6'],
}
final = convert_to_latex(display_results,
                         eval_level=['word'],
                         inc=False,
                         utt_seg=False,
                         only_include=final_metrics)
final

Unnamed: 0,System (eval. method),$F_{TTO}$ (per word),$F_{rps}$ (per word),$F_{e}$ (per word)
0,LSTM (joint task) (transcript),0.684,0.666,0.915


# Repair Error Analysis

In [10]:
# rps and rms errors
#Error analyses on exact match ('rms') and getting the right repair start ('rps')
target_tags = ['<rms', '<rps']

for div,all_error in all_error_dicts.items():
    # print div, type(all_error)
   
    if type(all_error) == bool: continue
    if "test" in div: continue
    #if not 'TTO only' in div or "asr" in div: continue
    for tag, errors in all_error.items():
        if tag not in target_tags:
            continue
        print "*" * 30, div, tag, "*" * 30
        # print errors
        # continue
        #if not 'TTO only' in div or "asr" in div: continue
        error = {"TP" : {}, "FP" : {}, "FN": {} }
        for k,v in errors.items():
            #if k == "FP":
            #    continue
            # print k, len(v)
            typedict = defaultdict(int)
            lendict = defaultdict(int)
            for repair in v:

                #print repair.gold_context
                onset = ""
                if tag == "<rps" or tag == "<rms":
                    
                    
                    for i in range(0,len(repair.gold_context)):
                        if repair.gold_context[i] == "+|+":
                            onset = repair.gold_context[i+1]
                            break

                    word = onset.split("|")[0]
                    #if k == "FP":
                    #    onset = gold_onset
                    if "<e" in onset and not tag == "<e":
                        typedict["<e"]+=1
                    else:
                        if word in ["and","or","but","so","because","that","although"]:
                            typedict["CC"]+=1
                        elif word in ["i","we","they","im","ive","he","she","id"]:
                            typedict["subj"]+=1
                        elif word in ["you","the"] or "$" in word:
                            typedict["proper_other"]+=1
                        elif word in ["yeah","no","okay","yes","right","uh-huh"]:
                            typedict["ack"]+=1
                        elif word in ["it","its"]:
                            typedict["it"]+=1
                        else:
                            typedict[word]+=1
                
                if tag == "<rps" or tag == "<rms": # and not k == 'FP':
                    if k == "TP" and len(repair.reparandumWords) > 8:
                        # should not be getting any over 8 words
                        print "** overlength repair!"
                        print repair
                    lendict[len(repair.reparandumWords) + len(repair.interregnumWords)]+=1
                    repair_type = None
                    if repair.type:
                        repair_type = repair.type 
                        typedict[repair_type]+=1

            error[k]['len'] = deepcopy(lendict)
            error[k]['type'] = deepcopy(typedict)

                
        for mode in ['type', 'len']:
            #q1. THE RECALL RATES FOR VARIOUS GOLD REPAIR TYPES
            print mode, "*" * 30
            tps = error['TP'][mode]
            fns = error['FN'][mode]
            fps = error['FP'][mode]

            total_tps = 0
            total_fns = 0
            total_fps = 0
            top_n = 50
            all_items = list(set(tps.keys() + fns.keys()))
            # print all_items
            for k in sorted(all_items,  reverse=False):
                #print k, "*" * 30
                if mode == 'type' and k not in ["rep", "del", "sub"]:
                    continue
                recall_total = tps[k] + fns[k]
                recall = 0 if tps[k] == 0 else tps[k]/recall_total
                precision_total = tps[k] + fps[k]
                precision = 0 if tps[k] == 0 else tps[k]/precision_total
                fscore = 0 if precision == 0 or recall == 0 else (2 * (precision * recall))/(precision + recall)
                # print k, ':', tps[k], "out of", recall_total
                #print k, ':', tps[k], "out of", precision_total
                total_tps += tps[k]
                total_fns += fns[k]
                total_fps += fps[k]
                print " & ".join([str(k), "({0}/{1})".format(tps[k],recall_total), 
                                  '{0:.3f}'.format(fscore)]) + "\\\\"
                top_n-=1
                if top_n <= 0:
                    break
            print total_tps/(total_fns + total_tps)

            if False:
                #q2. ERROR TYPE SUMMARY
                print "*" * 30
                total = sum(fns.values()+tps.values())

                errormass = 0
                errortotal = 0
                top_n = 20
                for k,v in sorted(tps.items(),key= lambda x: x[1],reverse=True):
                    print k,"&",v,"&",'{0:.2f}'.format(v/total)
                    errormass +=(v/total * 100)
                    errortotal+=v
                    top_n-=1
                    if top_n <= 0: break
                print "total &",errortotal,"&",'{0:.2f}'.format(errormass)

****************************** heldout_033/epoch_45 <rps ******************************
type ******************************
del & (25/132) & 0.318\\
rep & (734/1022) & 0.836\\
sub & (560/1061) & 0.691\\
0.595485327314
len ******************************
0 & (1/1) & 0.007\\
1 & (960/1258) & 0.866\\
2 & (243/531) & 0.628\\
3 & (64/222) & 0.448\\
4 & (31/106) & 0.453\\
5 & (12/50) & 0.387\\
6 & (4/25) & 0.276\\
7 & (3/11) & 0.429\\
8 & (1/6) & 0.286\\
9 & (0/1) & 0.000\\
10 & (0/1) & 0.000\\
11 & (0/2) & 0.000\\
15 & (0/1) & 0.000\\
0.595485327314
****************************** heldout_033/epoch_45 <rms ******************************
type ******************************
del & (0/132) & 0.000\\
rep & (0/1022) & 0.000\\
sub & (0/1061) & 0.000\\
0.0
len ******************************
1 & (0/1258) & 0.000\\
2 & (0/531) & 0.000\\
3 & (0/223) & 0.000\\
4 & (0/106) & 0.000\\
5 & (0/50) & 0.000\\
6 & (0/25) & 0.000\\
7 & (0/11) & 0.000\\
8 & (0/6) & 0.000\\
9 & (0/1) & 0.000\\
10 & (0/1) & 0.000\\


# Utterance Segmentation Analysis

In [11]:
#Error analyses
error = {"TP" : {}, "FP" : {}, "FN": {} }
 
for div,all_error in all_error_dicts.items():
    print div, type(all_error)
    if type(all_error) == bool: continue
    if "test" in div: continue
    #if not 'TTO only' in div or "asr" in div: continue
    for tag,errors in all_error.items():
        print div, tag
        #if not 'TTO only' in div or "asr" in div: continue
        if not tag == "t/>": continue
        for k,v in errors.items():
    
            print ""
            #if k == "FP": continue
            print k, len(v)
            typedict = defaultdict(int)
            lendict = defaultdict(int)
            print v[0]
            for repair in v:
                #if len(repair)==0: continue
                #print "*"
                #print repair
                #print repair.gold_context
                onset = ""
                if tag == "<rps":
                    
                    for i in rcange(0,len(repair.gold_context)):
                        if repair.gold_context[i] == "+|+":
                            onset = repair.gold_context[i+1]
                            break
                else:
                    gold_onset = ""
                    onset = ""
                    word = ""
                    if len(repair.gold_tags_right_context)>1:
                        gold_onset = repair.gold_tags_right_context[1]
                        onset = repair.tags_right_context[1]
                        word = repair.words_right_context[1]
                    #penult = repair.tags_left_context[-1]
                    #print repair
                    if k == "FP":
                        onset = gold_onset
                    if "<rps" in onset:
                        typedict["<rps"]+=1
                    elif "<e" in onset:
                        typedict["<e"]+=1
                    else:
                        if word in ["and","or","but","so","because","that","although"]:
                            typedict["CC"]+=1
                        elif word in ["i","we","they","im","ive","he","she","id"]:
                            typedict["subj"]+=1
                        elif word in ["you","the"] or "$" in word:
                            typedict["proper_other"]+=1
                        elif word in ["yeah","no","okay","yes","right","uh-huh"]:
                            typedict["ack"]+=1
                        elif word in ["it","its"]:
                            typedict["it"]+=1
                        else:
                            typedict[word]+=1
                         
                
                #if "<t" in repair.gold_context
                #for t in ["tt","cc","ct","tc"]:
                #    if "<" + t + ">" in onset:
                #        typedict[t]+=1
                #        if not t[0]=='t':
                #            print repair
                if tag == "<rps" and not k == 'FP':
                    lendict[repair.type]+=1
                error[k]['len'] = deepcopy(lendict)
                error[k]['type'] = deepcopy(typedict)

#tp = deepcopy(lendict)
#q1. THE RECALL RATES FOR VARIOUS GOLD REPAIRS
tp = error['TP']['len']
print tp
print error['FN']['len']
for k,v in sorted(error['FN']['len'].items()):
    print " & ".join([k, "({0})".format(v + tp[k]), 
                      '{0:.1f}'.format(100 * float(tp[k])/float(v+ tp[k]))]) + "\\\\"

tps = error['TP']['type']
fns = error['FN']['type']
fps = error['FP']['type']

total = sum(fns.values()+tps.values())

errormass = 0
errortotal = 0
for k,v in sorted(fns.items(),key= lambda x: x[1],reverse=True):
    print k,"&",v,"&",'{0:.2f}'.format(v/total * 100)
    errormass +=(v/total * 100)
    errortotal+=v
print "total &",errortotal,"&",'{0:.2f}'.format(errormass)

heldout_033/epoch_45 <type 'dict'>
heldout_033/epoch_45 <rps
heldout_033/epoch_45 <rms
heldout_033/epoch_45 <e
heldout_033/epoch_45 t/>

FP 877
left context=point|<f/><cc/> is|<f/><cc/> comfortable|<f/><cc/> upper|<f/><cc/> middle|<f/><cc/>
right context=class|<f/><ct/> i|<f/><tc/> guess|<f/><cc/> you|<f/><cc/> might|<f/><cc/>
gold left context=
gold right context=
type = None

TP 2879
left context=uh|<e/><cc/> families|<f/><cc/> of|<f/><cc/> rather|<f/><cc/> modest|<f/><cc/>
right context=means|<f/><ct/> and|<f/><tc/> uh|<e/><cc/> our|<f/><cc/> family|<f/><cc/>
gold left context=uh|<e/><cc/><diact type="sd"/> families|<f/><cc/><diact type="sd"/> of|<f/><cc/><diact type="sd"/> rather|<f/><cc/><diact type="sd"/> modest|<f/><cc/><diact type="sd"/>
gold right context=means|<f/><ct/><diact type="sd"/> and|<f/><tc/><diact type="sd"/> uh|<e/><cc/><diact type="sd"/> our|<f/><cc/><diact type="sd"/> family|<f/><cc/><diact type="sd"/>
type = None

FN 2769
left context=
right context=turn|<f/><cc

In [12]:
# TODO for paper/future
# - check WER for ASR results and exclude those with high WER, given they might have high overlap :(
# - need to adjust the time-to-detection scores based on the time it comes in from IncReco??
#      Also for time-to-detection we can only use word ends unless we re-do the mapping - just needs explanation
# - delayed accuracy based on time, or not bother? Do a moving window instead and plot this over time - average moving-window accuracy
# - error analysis plots
# - 036: full task with LSTM - should improve massively over 034, which also needs re-running
# - reproduce 027 (with full training data, efficiently) and re-run with LSTM - not much time.
# - Q2 TODO: the extent to which the network is memorizing - need to plug these in with the repair gold standards