In [2]:
from tensorflow.contrib.framework import checkpoint_utils
from scipy.spatial.distance import cosine, euclidean
import numpy as np

from plotly.offline import download_plotlyjs, init_notebook_mode, plot, iplot
init_notebook_mode(connected=True)
import plotly.graph_objs as go
from collections import defaultdict

In [3]:
# Step numbers of the checkpoints read by get_diffs: one every
# CHECKPOINT_INTERVAL steps starting at 0, NUM_CHECKPOINTS in total
# (0, 1000, ..., 309000).
CHECKPOINT_INTERVAL = 1000
NUM_CHECKPOINTS = 310
x_range = np.arange(0, NUM_CHECKPOINTS * CHECKPOINT_INTERVAL, CHECKPOINT_INTERVAL)

In [4]:
def _is_tracked_variable(name):
    """Return True for checkpoint variables whose drift should be measured.

    Excludes any variable whose name contains "Adam", "write" or "beta".
    """
    # NOTE(review): "beta" would also exclude e.g. normalization offsets if the
    # model has any — confirm that is intended.
    return not any(token in name for token in ("Adam", "write", "beta"))


def get_diffs(run, steps=None):
    """Per-variable weight change between consecutive checkpoints of a run.

    Loads ``<run>/save_<step>.ckpt`` for each step in ``steps`` and, for every
    variable accepted by ``_is_tracked_variable``, computes the Euclidean
    distance between the flattened tensors of consecutive checkpoints divided
    by the tensor's element count.

    Args:
        run: Directory containing the ``save_<step>.ckpt`` checkpoints.
        steps: Ordered iterable of checkpoint step numbers. Defaults to the
            notebook-level ``x_range``.

    Returns:
        A list with one dict per consecutive pair of checkpoints (one fewer
        entry than ``steps``). Element ``j`` maps variable name to the
        size-normalised distance between ``steps[j]`` and ``steps[j + 1]``.
    """
    if steps is None:
        steps = x_range
    diffs = list()
    last = None
    for step in steps:
        reader = checkpoint_utils.load_checkpoint("%s/save_%d.ckpt" % (run, step))
        names = sorted(filter(_is_tracked_variable, reader.get_variable_to_shape_map()))
        current = {name: reader.get_tensor(name) for name in names}

        if last is not None:
            # Normalise by element count so large and small tensors are comparable.
            diffs.append({
                name: euclidean(value.reshape(-1), last[name].reshape(-1)) / float(value.size)
                for name, value in current.items()
            })

        last = current
    return diffs

In [5]:
def collect_by_variable(run):
    """Regroup ``get_diffs(run)`` into one delta series per variable.

    Returns a ``defaultdict(list)`` mapping variable name to its deltas, in
    checkpoint order.
    """
    series = defaultdict(list)
    for step_diffs in get_diffs(run):
        for name, delta in step_diffs.items():
            series[name].append(delta)
    return series


machine_key = collect_by_variable("combined-run")
human_key = collect_by_variable("attention-run")

In [6]:
# One line per tracked variable of the attention run.
# get_diffs yields one delta per *pair* of checkpoints, so each series is one
# element shorter than x_range. Plot delta j (x_range[j] -> x_range[j+1]) at
# the later step so x and y have matching lengths, rather than letting plotly
# silently drop the surplus x value.
delta_steps = x_range[1:]
traces = [
    go.Scatter(x=delta_steps, y=vs, mode='lines', name=k)
    for k, vs in human_key.items()
]
iplot(traces)

In [13]:
def read_log(folder, limit=301):
    """Read the last comma-separated column of ``<folder>/log.csv`` as floats.

    Blank lines (e.g. the empty string produced by a trailing newline) are
    skipped instead of crashing ``float``.

    Args:
        folder: Run directory containing ``log.csv``.
        limit: Maximum number of rows to return (301 by default); ``None``
            returns every row.

    Returns:
        List of floats, one per non-blank row, in file order.
    """
    with open("%s/log.csv" % folder) as handle:
        rows = [line for line in handle.read().split("\n") if line.strip()]
    if limit is not None:
        rows = rows[:limit]
    return [float(row.split(",")[-1]) for row in rows]

machine_accs = read_log("combined-run")
# NOTE(review): human_accs is loaded but never plotted below — either add an
# "All-Step Accuracy" trace or drop this read.
human_accs = read_log("attention-run")

# Weight deltas exist per consecutive checkpoint pair (one fewer than x_range),
# so each is plotted at the later checkpoint. Accuracies are one per log row,
# plotted against the first len(machine_accs) checkpoint steps. x is in
# thousands of iterations.
delta_x = x_range[1:] / 1000
acc_x = x_range[:len(machine_accs)] / 1000

trace1 = go.Scatter(
    x=delta_x,
    y=human_key["decoder/LSTMCell/W_0"],
    mode='lines',
    name="All-Step Weight Delta",
    marker=dict(color="orange"),
)

trace2 = go.Scatter(
    x=delta_x,
    y=machine_key["decoder/LSTMCell/W_0"],
    mode='lines',
    name="Last-Step Weight Delta",
    marker=dict(color="pink"),
)

machine_trace = go.Scatter(
    x=acc_x,
    y=machine_accs,
    name="Last-Step Accuracy",
    marker=dict(color="green"),
    yaxis="y2",  # accuracy on its own right-hand axis
)

data = [trace2, trace1, machine_trace]
layout = go.Layout(
    title='Accuracy and Weight Updates',
    xaxis=dict(title="Iteration (10^3)"),
    yaxis=dict(
        title='Weight Delta'
    ),
    yaxis2=dict(
        title='Accuracy',
        overlaying='y',
        side='right'
    )
)
fig = go.Figure(data=data, layout=layout)

iplot(fig)