In [1]:
from data_gen import *
from models import *
from keras import models

Using TensorFlow backend.


In [2]:

# data generator for gofa word, root word, word feature
dg2 = DataGen(data="data/goffa.txt")

# length of a word
n_input_length = len(char2int)
n_steps_in = dg2.max_root_len
n_steps_out = dg2.max_output_len

In [3]:
model = models.load_model("model.h5")
root_input = model.get_layer("root_word_input").output
feature_input = model.get_layer("word_feature_input").output
target_input = model.get_layer("target_word_input").output

state_h = model.get_layer("concatnate").output
encoder_model = Model([root_input, feature_input], state_h)


decoder_state_input_h = Input(shape=(256,))
decoder_gru = model.get_layer("decoder_gru")
gru_outputs, state_h  = decoder_gru(target_input, initial_state=decoder_state_input_h)

decoder_dense = model.get_layer("train_output")
decoder_outputs = decoder_dense(gru_outputs)

decoder_model = Model([target_input, decoder_state_input_h], [decoder_outputs, state_h])

In [4]:
# infenc - inference encoder model
# infdec - inference decoder model
# train - training model that combines both
# n_input_length - the length of the input and the output
# word_feat_len - the length of the word feature vector
# n_units - size of the hidden memory in the RNN
# model, encoder_model, decoder_model = conv_model(n_input_length, n_input_length, dg2.word_feat_len, 256)
model.compile(optimizer='adam', loss='categorical_crossentropy')


In [5]:
# number of batches for the gofa data generator
batch_size = 128
n_batches = int(len(dg2.words) * .05 / batch_size) 
gen2 = dg2.cnn_gen_data(batch_size=batch_size, n_batches=n_batches)
print("Train size: ", (n_batches * batch_size))

Train size:  8320


In [6]:
# model train given the data generator, how many batches and number of epochs
history = model.fit_generator(gen2, steps_per_epoch=n_batches, epochs = 20)

Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20
Epoch 17/20
Epoch 18/20
Epoch 19/20
Epoch 20/20


In [7]:
# test data generator for gofa words

g_test_batches = int(len(dg2.words) * .02 / batch_size) 
gen3 = dg2.cnn_gen_data(batch_size=batch_size, n_batches=g_test_batches, trainset=False)
print("Train size: ", (g_test_batches * batch_size))

Train size:  3328


In [8]:

# shows sample examples and calculates accuracy

total, correct = 0, 0
in_word = 0
sims = []
for b in range(g_test_batches):
    # get data from test data generator
    [X1, X2, X3], y = next(gen3)
    for j in range(batch_size):
        word_features = X3[j].reshape((1, X3.shape[1])) 
        root_word_matrix = X1[j].reshape((1, X1.shape[1], X1.shape[2], 1))
        
        # predicts the target word given root word and features
        target = predict(encoder_model, decoder_model, root_word_matrix, word_features, n_steps_out, n_input_length)
        root = ''.join(dg2.one_hot_decode(X1[j]))#.replace('&', ' ')
        word = ''.join(dg2.one_hot_decode(y[j]))#.replace('&', ' ')
        targetS = ''.join(dg2.one_hot_decode(target))#.replace('&', ' ')
#         sims.append(dg.word_sim(word, targetS))
        
        # checks if the predicted and the real words are equal'
#         print(len(dg.one_hot_decode(y[j])), len(dg.one_hot_decode(target)))
#         print(len(dg.one_hot_decode(target)[:27]))
#         print(len(y[j]))
        if dg2.one_hot_decode(y[j]) == dg2.one_hot_decode(target)[:27]:
            correct += 1
        else:
            print(root, word.split('&')[0], '\t\t', targetS.split('&')[0])
        if root.strip() in targetS.strip():
            in_word += 1
#     print(b, root, word, targetS)
    total += batch_size
    

print('Exact Accuracy: %.2f%%' % (float(correct)/float(total)*100.0))

piig            piigokkonii 		 piigettennee
tooc            toocettennee 		 toocokkonii
munaqq          munaqqadinaa 		 munaqqadii
zawr            zawradinaa 		 zawradii
shanggoc        shanggocokkonii 		 shanggocettennee
happurssaa      happurssaaoppite 		 happurrsamoppite
higishshiball   higishshiballabiikke 		 hiishishablabaakke
kar             karadinaa 		 karadii
meesh           meeshettennee 		 meeshokkonii
yaayyam         yaayyamettennee 		 yaayyamokkonii
z               zideta 		 zidee
nashsh          nashshanee 		 nashshonee
paah            paahettennee 		 paahokkonii
keten''         keten''okkonii 		 keten''ettennee
shoppatt        shoppattadinaa 		 shoppattadii
wonzz           wonzzadinaa 		 wonzzadii
dummat          dummatokkonii 		 dummatettennee
piikkatt        piikkattii 		 piikkatteetii
waaqiwaaq       waaqiwaaqadinaa 		 waaqiwaaqadii
yaayyam         yaayyamadinaa 		 yaayyamadii
sig             sigettennee 		 sigekkenii
koyroy          koyroyadinaa 		 koyroyadii
hiillat

In [9]:
[97.57, 97.69, 97.54, 97.48, 97.72, 97.36, 97.51]

[97.57, 97.69, 97.54, 97.48, 97.72, 97.36, 97.51]

In [10]:
history.history['loss']

[0.0974451454786154,
 0.057033732246894106,
 0.05549450046741045,
 0.054814370119800934,
 0.05444869072391437,
 0.053862272919370575,
 0.05356800304009364,
 0.053443708557348986,
 0.05340436138212681,
 0.053390522071948415,
 0.04202403320142856,
 0.03813541886898188,
 0.026991238874884752,
 0.019109935481817678,
 0.018226230074651538,
 0.01814174318480162,
 0.018007075392121735,
 0.026805303765174288,
 0.018683376427865227,
 0.01816844432346093]

In [11]:
[0.08649586516504104,
 0.00981702023687271,
 0.0034727409643192705,
 0.002293258377064306,
 0.0017179189742399523,
 0.0013772513368166984,
 0.0011395107838325202,
 0.000961381251601359,
 0.0008240110254309212,
 0.0007225332513021735]

[1.6368019525821393,
 1.109701026402987,
 0.8467150734021114,
 0.6373297223678002,
 0.4980540312253512,
 0.3955386955004472,
 0.30929824755741997,
 0.24970667981184447,
 0.2069551529792639,
 0.17557245813883268,
 0.18030325586979207,
 0.14884408918710856,
 0.12150984463783411,
 0.10766618079864061,
 0.09909856835236916,
 0.08846385880158497,
 0.07828886463091923,
 0.07907314289074678,
 0.06622060502950962,
 0.05161330436284726]
46.63


46.63

In [12]:
"similarity" = 0.7878232042064027

SyntaxError: can't assign to literal (<ipython-input-12-fcfaffb2fed0>, line 1)