In [1]:
%matplotlib inline

import matplotlib
import numpy as np
import matplotlib.pyplot as plt

import seq2seq
from seq2seq import *

import re

# Directory (with trailing slash) that parse_file reads its fact file from.
datadir = 'data/'

  from ._conv import register_converters as _register_converters
Using TensorFlow backend.


In [2]:
def parse_file(filename='geobase.txt', data_dir=None):
    """Parse the GeoBase Prolog fact file into a small knowledge base.

    Only ``state(...)`` facts are converted for now; lines for every other
    predicate (``city``, ``river``, ``border``, ``highlow``, ``mountain``,
    ``road``, ``lake``, ...) are skipped.

    A state fact such as
    ``state('alabama','al','montgomery',3894.0e+3,51.7e+3,22,'birmingham',...).``
    is expanded into ``(relation, value, state_name)`` triples, e.g.
    ``('abbreviation', 'al', 'alabama')`` or ``('city', 'birmingham', 'alabama')``.
    Single quotes are removed and all values are kept as strings.

    Parameters
    ----------
    filename : str
        Name of the fact file inside ``data_dir``.
    data_dir : str, optional
        Directory that contains ``filename``. Defaults to the notebook-level
        ``datadir`` setting.

    Returns
    -------
    dict
        ``{'state': [(relation, value, state_name), ...]}`` when at least one
        state fact was found, otherwise an empty dict.

    Raises
    ------
    ValueError
        If a line starting with ``state`` is not a well-formed ``state(...).`` fact.
    """
    import os

    if data_dir is None:
        data_dir = datadir

    # Meaning of the 2nd..10th arguments of a state fact (the 1st is the state
    # name); the last four arguments are cities of the state.
    relations = ['abbreviation', 'capital', 'population', 'area',
                 'state_number', 'city', 'city', 'city', 'city']
    # Matched against the stripped line, so a last line without a trailing
    # newline (or with CRLF endings) is still accepted; '\.' is the literal
    # Prolog clause terminator.
    state_fact = re.compile(r'state\((.*)\)\.$')

    kb = {}
    with open(os.path.join(data_dir, filename), 'r') as f:
        for line in f:
            if not line.startswith('state'):
                # Other predicates are not converted yet.
                continue
            m = state_fact.match(line.replace("'", '').strip())
            if m is None:
                raise ValueError('Malformed state fact: %r' % line)
            data = m.group(1).split(',')
            facts = kb.setdefault('state', [])
            for rel, subj in zip(relations, data[1:]):
                facts.append((rel, subj, data[0]))
    return kb

In [3]:
# Load the knowledge base from the default GeoBase fact file.
kb = parse_file(filename='geobase.txt')

In [4]:
def generate_sentence(rel, subj, obj):
    """Render a (relation, subject, object) triple as a plain English sentence.

    Example: ('capital', 'montgomery', 'alabama') -> 'montgomery is capital of alabama'.
    """
    fields = {'rel': rel, 'subj': subj, 'obj': obj}
    return '%(subj)s is %(rel)s of %(obj)s' % fields

In [5]:
def generate_fol(rel, subj, obj):
    """Render a (relation, subject, object) triple as a first-order-logic term.

    Example: ('capital', 'montgomery', 'alabama') -> 'capital ( montgomery, alabama )'.
    The spaces around the parentheses let a whitespace-splitting tokenizer see
    '(' and ')' as separate tokens.
    """
    fields = {'rel': rel, 'subj': subj, 'obj': obj}
    return '%(rel)s ( %(subj)s, %(obj)s )' % fields

In [6]:
# Parallel corpus built from the state facts: English sentences (source side)
# and the matching first-order-logic terms (target side), aligned by index.
src = [generate_sentence(*fact) for fact in kb['state']]
tar = [generate_fol(*fact) for fact in kb['state']]

# Single-argument print(...) behaves the same under Python 2 and Python 3.
print(src[:10])
print(tar[:10])

['al is abbreviation of alabama', 'montgomery is capital of alabama', '3894.0e+3 is population of alabama', '51.7e+3 is area of alabama', '22 is state_number of alabama', 'birmingham is city of alabama', 'mobile is city of alabama', 'montgomery is city of alabama', 'huntsville is city of alabama', 'ak is abbreviation of alaska']
['abbreviation ( al, alabama )', 'capital ( montgomery, alabama )', 'population ( 3894.0e+3, alabama )', 'area ( 51.7e+3, alabama )', 'state_number ( 22, alabama )', 'city ( birmingham, alabama )', 'city ( mobile, alabama )', 'city ( montgomery, alabama )', 'city ( huntsville, alabama )', 'abbreviation ( ak, alaska )']


In [8]:
from keras.preprocessing.text import Tokenizer
from keras.preprocessing.sequence import pad_sequences

def _encode_texts(texts, tokenizer):
    """Fit ``tokenizer`` on ``texts`` and return their token ids as a 3-D array.

    Sequences are padded with ``pad_sequences`` (its defaults: zeros, padding at
    the front) to the length of the longest one. A trailing feature axis of size
    1 is then added, so the result has shape
    ``(len(texts), longest_sequence, 1)``.
    """
    tokenizer.fit_on_texts(texts)
    sequences = tokenizer.texts_to_sequences(texts)
    padded = pad_sequences(sequences, maxlen=max(len(seq) for seq in sequences))
    rows, steps = padded.shape
    return padded.reshape((rows, steps, 1))


# Source side: Keras' default filters, so all punctuation is stripped.
# NOTE(review): '.', '+', '-' and '_' are filtered on BOTH sides, which splits
# tokens such as 'state_number' and '3894.0e+3' into fragments.
src_tokenizer = Tokenizer()
src_inputs = _encode_texts(src, src_tokenizer)
src_m, src_n = src_inputs.shape[:2]

# Target side: Keras' default filter set minus '(' and ')', so the FOL
# parentheses survive as tokens.
TAR_FILTERS = '!"#$%&*+,-./:;<=>?@[\\]^_`{|}~\t\n'
tar_tokenizer = Tokenizer(filters=TAR_FILTERS)
tar_inputs = _encode_texts(tar, tar_tokenizer)
tar_m, tar_n = tar_inputs.shape[:2]

# `tokenizer` stays bound to the target-side tokenizer for any cell using that name.
tokenizer = tar_tokenizer

In [33]:
def test_seq2seq(src_inputs, tar_inputs, hidden_dim=100):
    _, input_length, input_dim = src_inputs.shape
    _, output_length, output_dim = tar_inputs.shape
    
    models = []
    
    # SIMPLE SEQ2SEQ
    
    # epoch 100: 7873.7063
    # models += [SimpleSeq2Seq(output_dim=output_dim, hidden_dim=hidden_dim, output_length=output_length, input_shape=(input_length, input_dim))]
    
    # models += [SimpleSeq2Seq(output_dim=output_dim, hidden_dim=hidden_dim, output_length=output_length, input_shape=(input_length, input_dim), depth=2)]
    # models += [SimpleSeq2Seq(output_dim=output_dim, hidden_dim=hidden_dim, output_length=output_length, input_shape=(input_length, input_dim), depth=4)]
    
    # epoch 100: 12396.6243
    # models += [SimpleSeq2Seq(output_dim=output_dim, hidden_dim=hidden_dim, output_length=output_length, input_shape=(input_length, input_dim), depth=8)]

    # SEQ2SEQ
 
    models += [Seq2Seq(output_dim=output_dim, hidden_dim=hidden_dim, output_length=output_length, input_shape=(input_length, input_dim))]
    models += [Seq2Seq(output_dim=output_dim, hidden_dim=hidden_dim, output_length=output_length, input_shape=(input_length, input_dim), peek=True)]
    models += [Seq2Seq(output_dim=output_dim, hidden_dim=hidden_dim, output_length=output_length, input_shape=(input_length, input_dim), depth=4)]
    # explodes
    models += [Seq2Seq(output_dim=output_dim, hidden_dim=hidden_dim, output_length=output_length, input_shape=(input_length, input_dim), peek=True, depth=4)]
    
    # Attention
    
    models += [AttentionSeq2Seq(output_dim=output_dim, hidden_dim=hidden_dim, output_length=output_length, input_shape=(input_length, input_dim))]
    #models += [AttentionSeq2Seq(output_dim=output_dim, hidden_dim=100, output_length=output_length, input_shape=(input_length, input_dim), depth=2)]
    #models += [AttentionSeq2Seq(output_dim=output_dim, hidden_dim=200, output_length=output_length, input_shape=(input_length, input_dim), depth=2)]
    #models += [AttentionSeq2Seq(output_dim=output_dim, hidden_dim=300, output_length=output_length, input_shape=(input_length, input_dim), depth=4)]
    
    histories = []
    for model in models:
        # model.compile(loss='mean_absolute_error', optimizer='adam', metrics=['accuracy'])
        model.compile(loss='mse', optimizer='rmsprop', metrics=['accuracy'])
        history = model.fit(src_inputs, tar_inputs, nb_epoch=100)
        histories.append(history)
    return histories

In [34]:
# Train every model variant configured in test_seq2seq on the encoded corpus.
# NOTE(review): interrupting this cell discards all histories, because
# test_seq2seq only returns them after the last model has finished.
histories = test_seq2seq(src_inputs=src_inputs, tar_inputs=tar_inputs)



Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 16/100
Epoch 17/100
Epoch 18/100
Epoch 19/100
Epoch 20/100
Epoch 21/100
Epoch 22/100
Epoch 23/100
Epoch 24/100
Epoch 25/100
Epoch 26/100
Epoch 27/100
Epoch 28/100
Epoch 29/100
Epoch 30/100
Epoch 31/100
Epoch 32/100
Epoch 33/100
Epoch 34/100
Epoch 35/100
Epoch 36/100
Epoch 37/100
Epoch 38/100
Epoch 39/100
Epoch 40/100
Epoch 41/100
Epoch 42/100
Epoch 43/100
Epoch 44/100
Epoch 45/100
Epoch 46/100
Epoch 47/100
Epoch 48/100
Epoch 49/100
Epoch 50/100
Epoch 51/100
Epoch 52/100
Epoch 53/100
Epoch 54/100
Epoch 55/100
Epoch 56/100
Epoch 57/100
Epoch 58/100
Epoch 59/100
Epoch 60/100
Epoch 61/100
Epoch 62/100
Epoch 63/100
Epoch 64/100
Epoch 65/100
Epoch 66/100
Epoch 67/100
Epoch 68/100
Epoch 69/100
Epoch 70/100
Epoch 71/100
Epoch 72/100
Epoch 73/100
Epoch 74/100
Epoch 75/100
Epoch 76/100
Epoch 77/100
Epoch 78

Epoch 85/100
Epoch 86/100
Epoch 87/100
Epoch 88/100
Epoch 89/100
Epoch 90/100
Epoch 91/100
Epoch 92/100
Epoch 93/100
Epoch 94/100
Epoch 95/100
Epoch 96/100
Epoch 97/100
Epoch 98/100
Epoch 99/100
Epoch 100/100
Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 16/100
Epoch 17/100
Epoch 18/100
Epoch 19/100
Epoch 20/100
Epoch 21/100
Epoch 22/100
Epoch 23/100
Epoch 24/100
Epoch 25/100
Epoch 26/100
Epoch 27/100
Epoch 28/100
Epoch 29/100
Epoch 30/100
Epoch 31/100
Epoch 32/100
Epoch 33/100
Epoch 34/100
Epoch 35/100
Epoch 36/100
Epoch 37/100
Epoch 38/100
Epoch 39/100
Epoch 40/100
Epoch 41/100
Epoch 42/100
Epoch 43/100
Epoch 44/100
Epoch 45/100
Epoch 46/100
Epoch 47/100
Epoch 48/100
Epoch 49/100
Epoch 50/100
Epoch 51/100
Epoch 52/100
Epoch 53/100
Epoch 54/100
Epoch 55/100
Epoch 56/100
Epoch 57/100
Epoch 58/100
Epoch 59/100
Epoch 60/100
Epoch 61/100
Epoch 6

Epoch 68/100
Epoch 69/100
Epoch 70/100
Epoch 71/100
Epoch 72/100
Epoch 73/100
Epoch 74/100
Epoch 75/100
Epoch 76/100
Epoch 77/100
Epoch 78/100
Epoch 79/100
Epoch 80/100
Epoch 81/100
Epoch 82/100
Epoch 83/100
Epoch 84/100
Epoch 85/100
Epoch 86/100
Epoch 87/100
Epoch 88/100
Epoch 89/100
Epoch 90/100
Epoch 91/100
Epoch 92/100
Epoch 93/100
Epoch 94/100
Epoch 95/100
Epoch 96/100
Epoch 97/100
Epoch 98/100
Epoch 99/100
Epoch 100/100
Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 16/100
Epoch 17/100
Epoch 18/100
Epoch 19/100
Epoch 20/100
Epoch 21/100
Epoch 22/100
Epoch 23/100
Epoch 24/100
Epoch 25/100
Epoch 26/100
Epoch 27/100
Epoch 28/100
Epoch 29/100
Epoch 30/100
Epoch 31/100
Epoch 32/100
Epoch 33/100
Epoch 34/100
Epoch 35/100
Epoch 36/100
Epoch 37/100
Epoch 38/100
Epoch 39/100
Epoch 40/100
Epoch 41/100
Epoch 42/100
Epoch 43/100
Epoch 44/100
Epoch 4

Epoch 51/100
Epoch 52/100
Epoch 53/100
Epoch 54/100
Epoch 55/100
Epoch 56/100
Epoch 57/100
Epoch 58/100
Epoch 59/100
Epoch 60/100
Epoch 61/100
Epoch 62/100
Epoch 63/100
Epoch 64/100
Epoch 65/100
Epoch 66/100
Epoch 67/100
Epoch 68/100
Epoch 69/100
Epoch 70/100
Epoch 71/100
Epoch 72/100
Epoch 73/100
Epoch 74/100
Epoch 75/100
Epoch 76/100
Epoch 77/100
Epoch 78/100
Epoch 79/100
Epoch 80/100
Epoch 81/100
Epoch 82/100
Epoch 83/100
Epoch 84/100
Epoch 85/100
Epoch 86/100
Epoch 87/100

KeyboardInterrupt: 

In [None]:
def plot_history(history):
    """Plot the loss (and, when recorded, accuracy) curves of a Keras training run.

    Parameters
    ----------
    history : object
        Anything with a ``history`` dict mapping metric names to per-epoch
        values, such as the object returned by ``model.fit``. Keys containing
        ``'loss'`` / ``'acc'`` are plotted; keys that also contain ``'val'``
        are drawn as validation curves.

    One figure is drawn for loss and, only if some accuracy metric exists, a
    second one for accuracy. Training curves are blue and validation curves
    green; each legend entry shows the final-epoch value. If no training loss
    is present, a message is printed and nothing is drawn.
    """
    metrics = history.history
    metric_names = list(metrics.keys())

    def _names(tag, validation):
        # Metric keys containing `tag`, restricted to validation or training ones.
        return [name for name in metric_names
                if tag in name and ('val' in name) == validation]

    loss_names = _names('loss', False)
    if not loss_names:
        print('Loss is missing in history')
        return

    # Assumes one recorded value per epoch, so the training loss fixes the x-axis.
    epochs = range(1, len(metrics[loss_names[0]]) + 1)

    def _draw(quantity, train_names, val_names):
        # One new figure per quantity ('Loss' / 'Accuracy'), drawn on its own Axes.
        _, ax = plt.subplots()
        for names, colour, kind in ((train_names, 'b', 'Training'),
                                    (val_names, 'g', 'Validation')):
            for name in names:
                final = format(metrics[name][-1], '.5f')
                label = '%s %s (%s)' % (kind, quantity.lower(), final)
                ax.plot(epochs, metrics[name], colour, label=label)
        ax.set_title(quantity)
        ax.set_xlabel('Epochs')
        ax.set_ylabel(quantity)
        ax.legend()

    _draw('Loss', loss_names, _names('loss', True))

    # Skip the accuracy figure entirely when no accuracy metric was recorded.
    acc_names = _names('acc', False)
    val_acc_names = _names('acc', True)
    if acc_names or val_acc_names:
        _draw('Accuracy', acc_names, val_acc_names)

    plt.show()

In [None]:
# Plot the curves of every model trained by test_seq2seq. The SimpleSeq2Seq
# configurations are disabled there, so these are the Seq2Seq and
# AttentionSeq2Seq runs. Each model's figures follow a printed header.
for model_index, history in enumerate(histories):
    print('Model %d of %d' % (model_index + 1, len(histories)))
    plot_history(history)

In [None]:
# SEQ2SEQ: draw the loss/accuracy figures for each trained model.
# NOTE(review): this iterates over the same `histories` as the plotting cell
# above, so one of the two cells is probably redundant.
for history in list(histories):
    plot_history(history)