In [18]:
import yaml
import os
import random
import numpy as np
import pathos.multiprocessing as mp

from simple_state_recurrent_model.model import SimpleRecurrentModel
from simple_state_recurrent_model.evaluation import Evaluator

# Load Data

In [2]:
# Root of the data repository. Set the DATA_MODEL_REPO_DIR environment variable
# to point at another checkout; the fallback is the original local path.
DATA_DIR = os.environ.get('DATA_MODEL_REPO_DIR', '/Users/frankwang/projects/data_model_repo')

In [3]:
# Read every YAML file in the money-detection dataset directory and flatten the
# records into two parallel lists: raw input strings and their label spans.
dataset_dir = os.path.join(DATA_DIR, 'datasets/money_detection/')
data = []
# Sorted so the record order (and therefore the leave-one-out fold indices) does
# not depend on the arbitrary order os.listdir returns.
for f in sorted(os.listdir(dataset_dir)):
    # Open relative to dataset_dir, not the current working directory.
    with open(os.path.join(dataset_dir, f), 'r') as fopen:
        # safe_load only builds plain Python objects and, unlike a bare
        # yaml.load(stream), does not need an explicit Loader argument.
        # An empty file parses to None, hence the `or []`.
        data += yaml.safe_load(fopen) or []
input_data = [d['input'] for d in data]
label_data = [d['labels'] for d in data]

# Setup Augmentor

In [4]:
class Augmentor(object):
    """Create extra training strings by swapping labelled spans for other labelled spans.

    ``input_data`` is a sequence of strings and ``label_data`` a parallel
    sequence whose i-th entry lists the ``[start, end]`` spans labelled in the
    i-th string. Every labelled substring is collected into ``self.labels``,
    the pool that replacements are sampled from. Generated samples accumulate
    in ``self.aug_input_data`` / ``self.aug_label_data``.
    """

    def __init__(self, input_data, label_data):
        self.input_data = input_data
        self.label_data = label_data

        # Pool of every labelled substring, in dataset order.
        self.labels = [
            text[span[0]:span[1]]
            for idx, text in enumerate(self.input_data)
            for span in self.label_data[idx]
        ]

        self.aug_input_data = []
        self.aug_label_data = []

    def substitute(self, input_string, labels):
        """Replace each labelled span of ``input_string`` with a random pool entry.

        Spans are consumed in the order given, copying the text between the end
        of the previous span and the start of the next one unchanged.

        Returns:
            ``(new_string, new_labels)`` where ``new_labels`` holds the
            ``[start, end]`` position of each replacement inside ``new_string``.
        """
        replacements = random.choices(self.labels, k=len(labels))
        pieces = []
        new_spans = []
        cursor = 0      # read position in input_string
        written = 0     # length of the output assembled so far

        for span, replacement in zip(labels, replacements):
            untouched = input_string[cursor:span[0]]
            start = written + len(untouched)
            written = start + len(replacement)
            pieces.append(untouched)
            pieces.append(replacement)
            new_spans.append([start, written])
            cursor = span[1]

        pieces.append(input_string[cursor:])
        return ''.join(pieces), new_spans

    def generate_aug_data(self, n):
        """Append ``n`` substituted variants of every input to the aug_* lists."""
        for text, spans in zip(self.input_data, self.label_data):
            for _ in range(n):
                aug_text, aug_spans = self.substitute(text, spans)
                self.aug_input_data.append(aug_text)
                self.aug_label_data.append(aug_spans)

In [5]:
# Build an augmentor over the full dataset and generate 10 variants per input;
# the results accumulate in a.aug_input_data / a.aug_label_data.
a = Augmentor(input_data, label_data)
a.generate_aug_data(10)

# Setup Model

In [25]:
# Overwriting the loo_cross_validation to do data augmentation as part of training

def loo_cross_validation(self, batch_size=32, train_rate=0.1, steps=1, epochs_per_step=100000, threads=4,
                         aug_per_sample=20):
    """Leave-one-out cross-validation that trains every fold on augmented data.

    For each sample index a fresh SimpleRecurrentModel is built from
    ``self.model_args``, the remaining samples are augmented with
    ``Augmentor``, and the model is trained in ``steps`` rounds of
    ``epochs_per_step`` epochs. After each round the held-out sample is run
    through ``model.compute_inference``.

    Args:
        batch_size, train_rate, epochs_per_step: forwarded to ``model._raw_train``.
        steps: number of training rounds (checkpoints) per fold.
        threads: worker count passed to ``mp.Pool``.
        aug_per_sample: augmented variants generated per training sample.

    Returns:
        dict mapping ``s * epochs_per_step`` (for s in 1..steps) to a list with
        one inference result per fold, ordered by held-out index. The same
        dict is stored on ``self.loo_cross_validation_results``. An empty
        dataset yields ``{}``.
    """
    def loo_x_validate(loo_cand):
        """Train without sample ``loo_cand`` and infer on it after every round."""
        model = SimpleRecurrentModel(**self.model_args)
        filtered_inputs = [x for i, x in enumerate(self.input_data) if i != loo_cand]
        filtered_targets = [t for i, t in enumerate(self.target_data) if i != loo_cand]

        # Augment only the training split so the held-out sample never leaks in.
        augmentor = Augmentor(filtered_inputs, filtered_targets)
        augmentor.generate_aug_data(aug_per_sample)
        processed_inputs, processed_targets = model.assemble_data(
            augmentor.aug_input_data, augmentor.aug_label_data)

        training_results = {}
        for s in range(1, steps + 1):
            # Training is cumulative: each round continues from the previous one.
            model._raw_train(processed_inputs, processed_targets, batch_size, train_rate, epochs_per_step)
            training_results[s * epochs_per_step] = model.compute_inference(self.input_data[loo_cand])
        return training_results

    # Nothing to hold out: return early instead of failing on results[0] below.
    if len(self.input_data) == 0:
        self.loo_cross_validation_results = {}
        return self.loo_cross_validation_results

    loo_candidates = range(len(self.input_data))
    pool = mp.Pool(threads)
    try:
        results = pool.map(loo_x_validate, loo_candidates)
    finally:
        # Always release the workers, even when a fold raises.
        pool.close()
        pool.join()

    # Regroup from per-fold dicts to per-checkpoint lists.
    processed_results = {}
    for k in results[0].keys():
        processed_results[k] = [r[k] for r in results]
    self.loo_cross_validation_results = processed_results
    return self.loo_cross_validation_results
    
# Monkey-patch: later evaluator.loo_cross_validation(...) calls use the function defined above.
Evaluator.loo_cross_validation = loo_cross_validation

In [26]:
# Characters kept by preprocess_fn: digits, lowercase ASCII letters, ASCII
# punctuation, space, and the euro and pound signs.
alphabet = '0123456789abcdefghijklmnopqrstuvwxyz!"#$%&\'()*+,-./:;<=>?@[\\]^_`{|}~ €£'

def preprocess_fn(s):
    """Lowercase ``s`` and replace every character outside ``alphabet`` with '#'."""
    return ''.join(ch if ch in alphabet else '#' for ch in s.lower())

In [27]:
# Model configuration handed to Evaluator. The cross-validation function in this
# notebook expands it as SimpleRecurrentModel(**self.model_args).
# NOTE(review): the exact meaning of window_size / window_shift is defined by the
# model package and is not visible here — confirm there before tuning.
model_args = {
    'alphabet': alphabet,
    'window_size': 5,
    'preprocess': preprocess_fn,
    'window_shift': 1    
}

In [28]:
# Evaluator is given the model kwargs, the raw input strings and their label spans.
evaluator = Evaluator(model_args, input_data, label_data)

In [33]:
# Run the patched leave-one-out CV: 15 checkpoints, 2000 epochs apart, so the
# CV results are keyed 2000, 4000, ..., 30000.
# NOTE(review): compute_accuracy_curve lives in the evaluation package; the meaning
# of its argument (10) is not visible here. The saved output below shows a single
# key `1`, which does not match these arguments — it is presumably stale.
evaluator.loo_cross_validation(steps=15, epochs_per_step=2000)
evaluator.compute_accuracy_curve(10)

{1: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]}

# Plots

In [36]:
# Per-position summary of the accuracy curves, taken across every entry (key) of
# evaluator.accuracy_curve_results. Position i of each output list aggregates
# element i of all curves.
accuracy_curves = list(evaluator.accuracy_curve_results.values())
# Curve length is read from the first curve instead of a hard-coded key, so the
# cell does not depend on which epochs_per_step the CV run used.
n_curve_points = len(accuracy_curves[0]) if accuracy_curves else 0
# One 1-D array per curve position, built once and shared by all three statistics.
per_position = [
    np.array([curve[i] for curve in accuracy_curves])
    for i in range(n_curve_points)
]
mean_data = [column.mean() for column in per_position]
stddev_data = [column.std() for column in per_position]
max_data = [column.max() for column in per_position]

In [37]:
import plotly.offline as plotly
import plotly.graph_objs as go

In [39]:
# Plot mean / max accuracy and a spread band against curve position (threshold).
# mean_data holds one value per curve position, so its length gives the x axis.
curve_positions = list(range(len(mean_data)))
# Plotly documents line.smoothing as a number in [0, 1.3]; use the maximum.
MAX_SPLINE_SMOOTHING = 1.3
axis_title_font = dict(family='sans serif', size=10, color='#727272')
frames = [
    go.Scatter(
        # NOTE(review): 800 * std**3 looks like a hand-tuned visual scaling of the
        # spread band rather than a statistical quantity — confirm intent.
        y=800*np.array(stddev_data)**3,
        x=curve_positions,
        mode='lines',
        name='spread',
        line={'smoothing': MAX_SPLINE_SMOOTHING, 'shape': 'spline','color':'rgba(26,150,65,0.1)'},
        fillcolor='rgba(26,150,65,0.05)',
        fill='tonexty',
    ),
    go.Scatter(
        y=mean_data,
        x=curve_positions,
        mode='lines',
        line={'smoothing': MAX_SPLINE_SMOOTHING, 'shape': 'spline', 'color':'rgb(35,159,255)'},
        name='mean',
    ),
    go.Scatter(
        y=max_data,
        x=curve_positions,
        mode='lines',
        line={'smoothing': MAX_SPLINE_SMOOTHING, 'shape': 'spline', 'color':'rgba(255,89,12, 0.3)'},
        name='max',
    ),
]
# Fonts are nested under title=dict(font=...) rather than the deprecated
# `titlefont` attribute, matching the title=dict(text=...) form used here.
layout = go.Layout(
    title=dict(
        text='Accuracy During Training at Different Thresholds', x=0.1,
        font=dict(family='sans serif', size=14, color='#727272')),
    xaxis=dict(
        range=[-1, 101], zeroline=False,
        title=dict(text="Threshold", font=axis_title_font)),
    yaxis=dict(
        zeroline=False,
        title=dict(text="Accuracy", font=axis_title_font)),
    legend=dict(
        x=0.05, y=0.15,
        font=dict(family='sans-serif', size=10, color='#727272'),
        bgcolor='rgba(0,0,0,0)'
    ),
)
fig = go.Figure(data=frames, layout=layout)
plotly.iplot(fig)

In [40]:
# Write the figure to a standalone HTML file (the saved output below shows temp-plot.html).
# image='svg' with image_filename asks plotly to offer an SVG download named
# 'threshold_selection' when that page is opened.
plotly.plot(fig, image_filename='threshold_selection', image='svg')

'file:///Users/frankwang/projects/data_model_repo/temp-plot.html'

In [None]:
# Plot accuracy against training epochs for four thresholds (0.2, 0.4, 0.6, 0.8).
curves_by_epoch = evaluator.accuracy_curve_results
# The x axis is the checkpoint keys themselves, so x and y always have the same
# length and the first point sits at the first checkpoint.
# NOTE(review): assumes the keys are epoch counts, as the `[2000]` lookups and the
# "Epochs" axis title elsewhere in this notebook suggest — confirm.
epochs = sorted(curves_by_epoch.keys())
n_thresholds = len(curves_by_epoch[epochs[0]]) if epochs else 0
# Plotly documents line.smoothing as a number in [0, 1.3]; use the maximum.
MAX_SPLINE_SMOOTHING = 1.3
axis_title_font = dict(family='sans serif', size=10, color='#727272')
frames = [
    go.Scatter(
        # Index of the i/5 threshold: equals i * 20 for a 101-point curve and
        # scales down for shorter curves.
        # NOTE(review): assumes thresholds are evenly spaced over the curve.
        y=[curves_by_epoch[k][i * (n_thresholds - 1) // 5] for k in epochs],
        x=epochs,
        mode='lines',
        opacity=0.4,
        line={'smoothing': MAX_SPLINE_SMOOTHING, 'shape': 'spline'},
        name="{} threshhold".format(i / 5)
    )
    for i in range(1, 5)
]
# Fonts are nested under title=dict(font=...) rather than the deprecated
# `titlefont` attribute, matching the title=dict(text=...) form used here.
layout = go.Layout(
    title=dict(
        text='Accuracy During Training at Different Thresholds', x=0.1,
        font=dict(family='sans serif', size=14, color='#727272')),
    xaxis=dict(
        zeroline=False,
        title=dict(text="Epochs", font=axis_title_font)),
    yaxis=dict(
        zeroline=False,
        title=dict(text="Test Accuracy", font=axis_title_font)),
    legend=dict(
        x=0.05, y=0.15,
        font=dict(family='sans-serif', size=10, color='#727272'),
        bgcolor='rgba(0,0,0,0)'),
)
fig = go.Figure(data=frames, layout=layout)
plotly.iplot(fig)

In [None]:
# Export through the `plotly` alias (plotly.offline) imported earlier in this notebook.
# image_filename carries no extension, like the 'threshold_selection' export;
# the format comes from image='svg'.
plotly.plot(fig, image_filename='accuracy_epochs', image='svg')