In [None]:
from processing_arithmetics.arithmetics.treebanks import treebank
from processing_arithmetics.arithmetics.MathExpression import MathExpression
from processing_arithmetics.arithmetics.MathTreebank import MathTreebank
from processing_arithmetics.sequential.architectures import DiagnosticClassifier
import numpy as np
import re
import pandas as pd
from plotnine import *
import plotnine
import matplotlib
from ipywidgets import widgets
from IPython.display import clear_output


In [None]:
# Timestamp fragment shared by the file names of both pretrained models.
basename = "2018-10-11T11+02:00"
# Vocabulary handed to the diagnostic classifiers: integers -10..10 and two operators.
digits = np.arange(-10,11)
operators = ['+', '-']
# Names of the diagnostic classifiers attached to each model.
classifiers = "subtracting intermediate_locally intermediate_recursively grammatical depth minus1depth minus2depth minus3depth minus4depth minus1depth_count switch_mode".split()
# Alternative, smaller selection (kept for quick experiments):
#classifiers = "intermediate_locally intermediate_recursively".split()


def _load_diagnostic_classifier(model_path):
    """Create a DiagnosticClassifier, attach the pretrained model stored at
    `model_path`, compile it with the architecture's own losses, metrics and
    loss weights, and return the architecture object."""
    arch = DiagnosticClassifier(digits=digits,
                                operators=operators,
                                classifiers=classifiers)
    arch.add_pretrained_model(model=model_path)
    arch.model.compile(loss=arch.loss_functions,
                       optimizer='adam',
                       metrics=arch.metrics,
                       loss_weights=arch.loss_weights)
    return arch


# The two checkpoints differ only in the "_10_" / "_16_" part of the file name.
arch1 = _load_diagnostic_classifier(
    "./diagnoses/ScalarPrediction_GRU_infix_{}_10_dc8.h5".format(basename))
arch2 = _load_diagnostic_classifier(
    "./diagnoses/ScalarPrediction_GRU_infix_{}_16_dc8.h5".format(basename))


In [None]:
def plot_tb(tb, classifier_list):
    """Plot predicted versus expected diagnostic-classifier trajectories.

    Displays every expression in `tb`, runs both loaded architectures
    (`arch1`, labelled "10", and `arch2`, labelled "16" -- presumably the
    number in the checkpoint file names; confirm) on the test data generated
    from `tb`, and draws one step plot per (classifier, model) pair.

    Parameters
    ----------
    tb : treebank object
        Must expose `.examples`; the first element of each example is
        displayed and iterated with `.iterate('infix')`.
    classifier_list : collection of str
        Classifier names to include in the figure; its length also sets the
        figure height.

    Notes
    -----
    Sets `plotnine.options.figure_size` globally. The x-axis tick labels are
    taken from the first example in `tb` only.
    """
    # Show the expressions that are about to be plotted.
    for item in tb.examples:
        display(item[0])

    data = arch1.generate_test_data(tb, digits=digits)
    networks = [("10", arch1), ("16", arch2)]

    # predictions[dataset name][model label] -> array with the trailing axis dropped.
    predictions = {}
    for dataset_name, inputs, _targets in data:
        per_model = {}
        for label, network in networks:
            raw = np.array(network.model.predict(inputs))
            per_model[label] = raw.reshape(*raw.shape[:-1])
        predictions[dataset_name] = per_model

    test_predictions = predictions['test treebank']
    # Target values come from the first generated dataset.
    targets = data[0][2]

    frames = []
    for label, outputs in test_predictions.items():
        for example_idx, item in enumerate(tb.examples):
            n_tokens = len(list(item[0].iterate('infix')))
            # Only the last `n_tokens` time steps are plotted for this example.
            tail = slice(-n_tokens, None)
            for clf_idx, clf_name in enumerate(classifiers):
                frame = pd.DataFrame({
                    'model': label,
                    'example': example_idx,
                    'classifier': clf_name,
                    'prediction': outputs[clf_idx, example_idx, tail],
                    'expected': targets[clf_name][example_idx, tail, 0],
                })
                # reset_index() exposes the time step as an explicit 'index' column.
                frames.append(frame.reset_index())

    combined = pd.concat(frames).reset_index(drop=True)
    long_form = combined.melt(
        id_vars=['index', 'classifier', 'example', 'model'],
        value_vars=['expected', 'prediction'],
    )
    selected = long_form[long_form['classifier'].isin(classifier_list)]

    plotnine.options.figure_size = (12, 3 * len(classifier_list))
    axis_tokens = list(tb.examples[0][0].iterate('infix'))
    theplot = (
        ggplot(selected, aes(x="index", y="value"))
        + geom_step(aes(linetype="variable"))
        + scale_x_continuous(breaks=range(len(axis_tokens)), labels=axis_tokens)
        + facet_grid("classifier~model", scales="free")
    )
    display(theplot)

## Plotting trajectories 

Diagnostic classifiers can be used not only to evaluate the overall match with a specific hypothesis; we can also track the fit of our predictions over time, by comparing the trajectories of predicted variables with the trajectories of observed variables while the networks process different sentences. In the cell below, the predictions of the diagnostic classifiers on a sentence you input are depicted, along with their target trajectories as defined by the hypotheses. These trajectories confirm that the curve representing the cumulative strategy is much better predicted than the recursive one.

In [None]:


# Multi-select list offering every diagnostic classifier; two are preselected
# and the list is tall enough to show all options at once.
clas = widgets.SelectMultiple(
    description='Classifiers',
    options=classifiers,
    value=['intermediate_locally', 'intermediate_recursively'],
    rows=len(classifiers),
    disabled=False,
)

# Free-text box for the arithmetic expression to analyse.
text = widgets.Text(
    description="Expression",
    placeholder="( 1 + (3 - 8 ) )",
    disabled=False,
)

# Stack both controls vertically and render them.
vbox = widgets.VBox([clas, text])
display(vbox)

def _tokenize_expression(raw):
    """Split a raw arithmetic string into integer, parenthesis and operator tokens.

    Recognised tokens are integers (optionally signed), ``(``, ``)``, ``+`` and
    ``-``; every other character is skipped. A ``-`` directly attached to a
    number is kept as that number's sign only when it cannot be a binary
    minus: if the preceding token is a number or ``)``, the sign is split off
    so that ``3-8`` yields ``['3', '-', '8']`` rather than ``['3', '-8']``.
    """
    tokens = []
    for token in re.findall(r"(-?\d+|\(|\)|\+|\-)", raw):
        previous = tokens[-1] if tokens else ""
        # True when the previous token is an operand (number or closing bracket).
        after_operand = previous == ")" or previous.lstrip("-").isdigit()
        if after_operand and len(token) > 1 and token.startswith("-"):
            tokens.extend(["-", token[1:]])
        else:
            tokens.append(token)
    return tokens


def handle_submit(foo):
    """Callback for the expression text box.

    Clears the cell output, re-displays the widgets, tokenises the current
    value of `text`, adds it as the only example of a fresh `MathTreebank`
    and plots the classifiers currently selected in `clas` via `plot_tb`.

    Parameters
    ----------
    foo : object
        Argument supplied by the widget framework; not used.

    A message is printed instead of a plot when the input contains no
    recognisable tokens or when parsing/plotting raises `ValueError`.
    """
    clear_output()
    display(vbox)

    tokens = _tokenize_expression(text.value)
    if not tokens:
        # Nothing parseable was typed; skip the parser and report it directly.
        print("Sorry, that doesn't seem to be a valid expression")
        return

    tb = MathTreebank({}, digits=digits)
    try:
        tb.add_example_from_string(" ".join(tokens))
        plot_tb(tb, clas.value)
    except ValueError:
        print("Sorry, that doesn't seem to be a valid expression")
    
# Register handle_submit as the submit callback of the expression text box.
# NOTE(review): `Text.on_submit` appears to be deprecated in ipywidgets 8 in
# favour of `observe(..., 'value')` with `continuous_update=False` -- confirm
# against the installed ipywidgets version before upgrading.
text.on_submit(handle_submit)