In [43]:
from summary_reader import read_tensors

In [44]:
import numpy as np
import itertools

def take(iterable, k):
    return list(itertools.islice(iterable, k))

def reservoir_sample(iterator, k):
    """ Samples k elements uniformly from iterator. """
    r = take(iterator, k)
    for i, x in enumerate(iterator):
        j = np.random.randint(0, i+k)
        if j < k:
            r[j] = x
    return r


def sample_errors(example_iterator, is_error_fn, N=100, is_shuffled=False):
    """ Returns |N| elements of example_iterator that satisfy |is_error_fn|.
    
    Args:
        example_iterator: iterable of dictionary of numpy arrays.
        is_error_fn: function (dict of ndarray) -> bool. Returns true if
                     example is an error.
        N: number of examples to draw.
        is_shuffled: if true, assumes example iterator is pre-shuffled.
    """
    errors = itertools.ifilter(is_error_fn, example_iterator)
    if is_shuffled:
        return take(errors, N)
    else:
        return reservoir_sample(errors, N)
    

def unbatch(batch_iterator):
    """ Slices elements of |batch_iterator| along the 0th axis, yielding
        one at a time.
    """
    for batch in batch_iterator:
        for idx in xrange(len(batch.values()[0])):
            yield dict((k, batch[k][idx]) for k in batch.keys())    

In [70]:
# Visualization, exporting utils
import csv
from IPython.display import HTML, display

def html_table(errors, keys=None):
    """ Returns HTML string for a table of the errors.
    
    Args:
        errors: List of dictionaries. All elements of dictionaries will be represented
                as strings in the table.
        keys: Keys from the dictionaries that should be shown. If None, will use keys
              from first element of the |errors|.

    Returns:
        Valid HTML string.
    """
    if not errors:
        return ""

    keys = keys or errors[0].keys()
    s = "<table>"
    s += "<tr><th>" + "</th><th>".join(keys) + "</th></tr>"
    row_template = "<tr><td>"+"</td><td>".join("{"+key+"}" for key in keys) + "</td></tr>"
    s += "".join(row_template.format(**error) for error in errors)
    s += "</table>"
    return s

def show_table(errors):
    display(HTML(html_table(errors)))

def export_csv(errors, filename, transform_fn=lambda k, x: str(x), keys=None):
    keys = keys or errors[0].keys()
    with open(filename, 'w') as fp:
        writer = csv.writer(fp)
        writer.writerow(keys)
        for e in errors:
            writer.writerow([transform_fn(key, e[key]) for key in keys])


filename = "../results/dnn-regularized/error_analysis/events.out.tfevents.*"
iterator = read_tensors(filename)
errors = sample_errors(unbatch(iterator), lambda x: not x['is_error'])

def different_party_aye(x):
    return x['SponsorParty'] != x['VoterParty'] and x['label'] == [1]

def same_party_nay(x):
    return x['SponsorParty'] == x['VoterParty'] and x['label'] == [0]

def evaluate(errors, functions):
    results = {}
    for f in functions:
        results[f.__name__] = sum(1.0 for e in errors if f(e)) / len(errors)
        
    return results

print evaluate(errors, [different_party_aye, same_party_nay])

show_table(errors)
with open('../data/vocab.txt') as fp:
    tokens = [line.strip() for line in fp]

def clean(key, value):
    global tokens
    if key == 'BillTitle':
        return " ".join([tokens[x] for x in value])
    else:
        return str(value[0])

export_csv(errors, 'dnn-regularized-errors.csv', transform_fn=clean)

{'different_party_aye': 0.38, 'same_party_nay': 0.23}


BillTitle,is_error,VoterState,SponsorParty,prediction,label,VoterParty,VoterAge
[ 4 78 1 279 443 11 5 69 0 0 0 0 0 0 0 0 0 0  0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0  0 0 0 0 0 0 0 0 0 0 0],[False],['MA'],['UNK'],[False],[1],['democrat'],[68]
[12 10 2 1 23 5 39 2 1 8 9 14 15 16 69 3 2 6 7 0 0 0 0 0 0  0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0  0 0 0 0 0],[False],['CT'],['republican'],[ True],[0],['republican'],[71]
[ 4 13 38 1282 21 19 28 4 18 20 898 1424 4 173 3  472 31 17 1 1419 2226 17 1 750 0 0 0 0 0 0  0 0 0 0 0 0 0 0 0 0 0 0 0 0 0  0 0 0 0 0 0 0 0 0 0],[False],['NJ'],['democrat'],[False],[1],['republican'],[47]
[ 4 26 10 2 8 9 273 2 34 36 5 1 23 5 39 2 34 59  3 2 39 36 5 1 23 5 42 4 114 34 95 121 2 52 8 9  3 2 6 7 0 0 0 0 0 0 0 0 0 0 0 0 0 0  0],[False],['OK'],['republican'],[ True],[0],['republican'],[82]
[ 4 159 463 17 1 107 77 880 29 167 20 1293 320 2 1  1294 5 1295 193 697 3 6 107 46 850 3 2 6 7 0  0 0 0 0 0 0 0 0 0 0 0 0 0 0 0  0 0 0 0 0 0 0 0 0 0],[False],['FL'],['republican'],[ True],[0],['republican'],[74]
[12 10 2 42 3 65 31 2 1 8 9 14 15 16 73 3 2 6 7 0 0 0 0 0 0  0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0  0 0 0 0 0],[False],['MN'],['republican'],[ True],[0],['democrat'],[75]
[12 92 68 10 2 1 8 9 69 3 2 6 7 0 0 0 0 0 0 0 0 0 0 0 0  0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0  0 0 0 0 0],[False],['CT'],['republican'],[ True],[0],['democrat'],[69]
[12 92 68 10 2 1 8 9 69 3 2 6 7 0 0 0 0 0 0 0 0 0 0 0 0  0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0  0 0 0 0 0],[False],['AR'],['republican'],[ True],[0],['democrat'],[82]
[ 4 13 1 161 67 11 5 247 4 18 2 236 35 1 67  484 202 2 81 431 1513 276 5 1514 1549 1303 237 2 22 431  2 22 86 33 5 1 17 1 67 332 444 708 431 996 17  147 613 187 27 35 67 3 2 6 7],[False],['PA'],['republican'],[ True],[0],['republican'],[85]
[ 4 163 545 2 43 1245 223 3 1622 42 0 0 0 0 0  0 0 0 0 0 0 0 0 0 0 0 0 0 0 0  0 0 0 0 0 0 0 0 0 0 0 0 0 0 0  0 0 0 0 0 0 0 0 0 0],[False],['FL'],['republican'],[ True],[0],['republican'],[59]


IndexError: list index out of range