In [1]:
from summary_reader import read_tensors
from sampling import sample_errors

In [2]:
# Visualization, exporting utils
import csv
from IPython.display import HTML, display

def html_table(errors, keys=None):
    """ Returns HTML string for a table of the errors.

    Args:
        errors: List of dictionaries. All shown elements of the dictionaries
                are converted with str() before being placed in the table.
        keys: Iterable of keys from the dictionaries that should be shown,
              in column order. If None (or empty), will use keys from the
              first element of |errors|.

    Returns:
        HTML string containing one header row and one row per element of
        |errors|, or an empty string if |errors| is empty. Keys and values
        are inserted verbatim (they are not HTML-escaped).

    Raises:
        KeyError: if an element of |errors| lacks one of the keys.
    """
    if not errors:
        return ""

    # Materialize once: |keys| is traversed for the header and again for
    # every row, so a one-shot iterator would otherwise yield empty rows.
    keys = list(keys or errors[0].keys())

    header = "<tr><th>" + "</th><th>".join(str(key) for key in keys) + "</th></tr>"
    # Cells are looked up directly instead of through a str.format template,
    # so keys containing format syntax (".", "[", "{", digits...) work too.
    rows = [
        "<tr><td>" + "</td><td>".join(str(error[key]) for key in keys) + "</td></tr>"
        for error in errors
    ]
    return "<table>" + header + "".join(rows) + "</table>"

def show_table(errors):
    """ Displays |errors| as an HTML table, using the default columns of html_table. """
    markup = html_table(errors)
    display(HTML(markup))

def export_csv(errors, filename, transform_fn=lambda k, x: str(x), keys=None):
    """ Writes the errors to a CSV file with a header row.

    Args:
        errors: List of dictionaries, one per output row.
        filename: Path of the CSV file to create (overwritten if it exists).
        transform_fn: Callable (key, value) -> cell content, applied to every
                      exported field. Defaults to str(value).
        keys: Iterable of keys to export, in column order. If None (or
              empty), will use keys from the first element of |errors|.

    If |errors| is empty and no keys are given, an empty file is created.

    Raises:
        KeyError: if an element of |errors| lacks one of the keys.
    """
    import sys

    # Materialize once: |keys| is traversed for the header and again for
    # every row, so a one-shot iterator would otherwise yield empty rows.
    # With no rows and no explicit keys there are no columns to derive.
    keys = list(keys or (errors[0].keys() if errors else ()))

    # The csv module does its own line-ending handling: it wants a binary
    # file on Python 2 and a text file opened with newline='' on Python 3.
    if sys.version_info[0] >= 3:
        fp = open(filename, 'w', newline='')
    else:
        fp = open(filename, 'wb')

    with fp:
        writer = csv.writer(fp)
        if keys or errors:
            writer.writerow(keys)
        for e in errors:
            writer.writerow([transform_fn(key, e[key]) for key in keys])


# Glob-style pattern for the event files of the dnn-regularized run.
filename = "../results/dnn-regularized/error_analysis/events.out.tfevents.*"
# read_tensors / sample_errors are project helpers not shown here; presumably
# the iterator yields one dict of tensors per record -- TODO confirm.
iterator = read_tensors(filename)
# NOTE(review): the predicate keeps records whose 'is_error' is falsy, yet the
# result is stored as |errors| -- confirm the flag's polarity against the
# code that writes the summaries.
errors = sample_errors(iterator, lambda x: not x['is_error'])

def different_party_aye(x):
    """ Tests for a record whose sponsor and voter parties differ and whose label is [1]. """
    parties_differ = x['SponsorParty'] != x['VoterParty']
    return parties_differ and x['label'] == [1]

def same_party_nay(x):
    """ Tests for a record whose sponsor and voter parties match and whose label is [0]. """
    parties_match = x['SponsorParty'] == x['VoterParty']
    return parties_match and x['label'] == [0]

def evaluate(errors, functions):
    """ Computes, for each predicate, the fraction of errors it matches.

    Args:
        errors: Sized collection (e.g. list) of records.
        functions: Iterable of predicates taking one record. Results are
                   keyed by each predicate's __name__, so predicates sharing
                   a name (e.g. several lambdas) overwrite each other.

    Returns:
        Dictionary mapping predicate name to a float in [0, 1]. When |errors|
        is empty every fraction is reported as 0.0.
    """
    total = len(errors)
    results = {}
    for f in functions:
        if total == 0:
            # Nothing to match against; avoid dividing by zero.
            results[f.__name__] = 0.0
        else:
            matches = sum(1 for e in errors if f(e))
            # float() keeps true division under Python 2 as well.
            results[f.__name__] = float(matches) / total

    return results

# Fraction of the sampled records matching each party/vote pattern.
# Written as a call with a single argument so the line prints the same under
# Python 2 and is also valid syntax under Python 3.
print(evaluate(errors, [different_party_aye, same_party_nay]))

show_table(errors)
# Vocabulary file: one token per line; clean() indexes this list with the
# integer ids stored in 'BillTitle'.
with open('../data/vocab.txt') as fp:
    tokens = [line.strip() for line in fp]

def clean(key, value):
    """ Converts one field of a record to its CSV cell text.

    'BillTitle' values are sequences of indices into the module-level
    |tokens| list and are rendered as space-separated tokens; for every
    other key the first element of |value| is rendered with str().
    """
    if key != 'BillTitle':
        return str(value[0])

    words = (tokens[index] for index in value)
    return " ".join(words)

# Write the sampled records to CSV, converting each field with clean().
export_csv(errors, 'dnn-regularized-errors.csv', transform_fn=clean)

{'different_party_aye': 0.44, 'same_party_nay': 0.16}


BillTitle,is_error,VoterState,SponsorParty,prediction,label,VoterParty,VoterAge
[ 4 13 38 1379 21 19 28 4 18 2 1 826 41 5 186  5 1 23 5 93 96 473 35 240 41 1898 3 2 6 7  0 0 0 0 0 0 0 0 0 0 0 0 0 0 0  0 0 0 0 0 0 0 0 0 0],[False],['CT'],['UNK'],[ True],[0],['democrat'],[69]
[ 4 13 38 374 21 19 28 4 198 864 1422 361 3 97 160  1001 3 54 864 1325 3 2 6 7 0 0 0 0 0 0  0 0 0 0 0 0 0 0 0 0 0 0 0 0 0  0 0 0 0 0 0 0 0 0 0],[False],['CA'],['republican'],[False],[1],['democrat'],[72]
[12 10 2 42 3 65 31 2 1 8 9 14 15 16 73 3 2 6 7 0 0 0 0 0 0  0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0  0 0 0 0 0],[False],['IL'],['republican'],[ True],[0],['democrat'],[72]
[ 323 27 285 4 1 287 4 18 2 20 493 67 2 1 21  19 40 3 2 263 237 17 1 2109 5 82 620 0 0 0  0 0 0 0 0 0 0 0 0 0 0 0 0 0 0  0 0 0 0 0 0 0 0 0 0],[False],['IL'],['democrat'],[False],[1],['republican'],[82]
[ 20 45 4 179 1 21 19 1368 306 98 264 0 0 0 0  0 0 0 0 0 0 0 0 0 0 0 0 0 0 0  0 0 0 0 0 0 0 0 0 0 0 0 0 0 0  0 0 0 0 0 0 0 0 0 0],[False],['CO'],['republican'],[False],[1],['democrat'],[62]
[ 4 163 58 117 203 850 5 1146 817 152 46 2385 210 4 1146  578 789 590 43 294 147 0 0 0 0 0 0 0 0 0  0 0 0 0 0 0 0 0 0 0 0 0 0 0 0  0 0 0 0 0 0],[False],['CA'],['UNK'],[False],[1],['democrat'],[57]
[12 92 68 10 2 1 8 9 79 3 2 6 7 0 0 0 0 0 0 0 0 0 0 0 0  0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0  0 0 0 0 0],[False],['PA'],['democrat'],[ True],[0],['republican'],[89]
[ 4 63 1 82 346 1542 43 113 4 77 17 20 353 4 77  17 1 22 322 5 1033 3 32 6 607 3 2 6 7 0  0 0 0 0 0 0 0 0 0 0 0 0 0 0 0  0 0 0 0 0 0 0 0 0 0],[False],['SD'],['republican'],[False],[1],['democrat'],[70]
[ 4 1920 1 1921 31 3 695 5 1113 1114 1922 3 2 6 7  0 0 0 0 0 0 0 0 0 0 0 0 0 0 0  0 0 0 0 0 0 0 0 0 0 0 0 0 0 0  0 0 0 0 0 0 0 0 0 0],[False],['IL'],['UNK'],[ True],[0],['democrat'],[73]
[ 4 292 1 107 2012 5 1 37 993 86 55 3 4 18 2  52 55 4 108 1542 205 2 2297 1400 54 3 3 2 6 7  0 0 0 0 0 0 0 0 0 0 0 0 0 0 0  0 0],[False],['IL'],['democrat'],[ True],[0],['republican'],[71]
