In [11]:
import pprint
from keras.utils.data_utils import get_file
from keras.utils import np_utils
from keras.preprocessing.text import Tokenizer
from gensim.models import Word2Vec, KeyedVectors
from keras.preprocessing.sequence import skipgrams
from keras.models import Sequential
from keras.layers import Dense


# Load the corpus: one sentence per line, lines consisting of a bare newline
# are skipped. The context manager closes the file handle deterministically
# (the previous bare open() inside the comprehension was never closed), and the
# explicit encoding keeps decoding independent of the platform locale.
# NOTE(review): assumes swahili.txt is UTF-8 encoded -- confirm.
with open('swahili.txt', encoding='utf-8') as corpus_file:
    sentences = [line.strip() for line in corpus_file if line != '\n']

# Build the vocabulary and turn every sentence into a list of integer word ids.
tokenizer = Tokenizer(filters='!"#$%&()*+,-./:;<=>?@[\\]^_`{|}~\t\n')
tokenizer.fit_on_texts(sentences)
corpus = tokenizer.texts_to_sequences(sentences)

# Vocabulary size used for the one-hot width: number of known words plus one
# extra slot, so that the largest word id is still a valid index.
V = len(tokenizer.word_index) + 1
# Dimensionality of the learned word vectors (width of the hidden layer).
dim = 100
# Window size handed to skipgrams() when generating (word, context) pairs.
window_size = 2


# Skip-gram network expressed as two dense layers:
#   one-hot word (V) -> hidden layer (dim) -> softmax over the vocabulary (V).
# The kernel of the first layer is what later gets exported as word vectors.
#
# `units` is the Keras 2 name for the layer width; the previous `output_dim`
# keyword is the Keras 1 spelling. The second layer no longer repeats
# `input_dim`: Sequential infers it from the preceding layer.
model = Sequential()
model.add(Dense(units=dim, input_dim=V))
model.add(Dense(units=V, activation='softmax'))

# Targets are one-hot context words, hence categorical cross-entropy.
model.compile(loss='categorical_crossentropy', optimizer='rmsprop')
model.summary()


def generate_data(corpus, window_size, V):
    """Yield one training batch of one-hot skip-gram pairs per sentence.

    Args:
        corpus: iterable of sentences, each a sequence of integer word ids.
        window_size: window size forwarded to ``skipgrams``.
        V: vocabulary size; forwarded to ``skipgrams`` and used as the
            number of classes for the one-hot encoding.

    Yields:
        A tuple ``(X, y)`` of one-hot encoded arrays built from the first and
        second element of every pair that ``skipgrams`` produced for the
        sentence. Sentences for which no pair is produced are skipped.
    """
    for sentence in corpus:
        # negative_samples=0 is requested, and the returned labels are not
        # used by this generator, so they are discarded.
        pairs, _ = skipgrams(sentence, V, window_size, negative_samples=0, shuffle=True)
        if not pairs:
            continue
        first_ids, second_ids = zip(*pairs)
        inputs = np_utils.to_categorical(first_ids, V)
        targets = np_utils.to_categorical(second_ids, V)
        yield inputs, targets

# Train for 500 passes over the corpus; every sentence contributes one batch.
for epoch in range(500):
    # Running total of the per-batch losses for this pass (a sum, not a mean).
    loss = 0.
    batches = generate_data(corpus, window_size, V)
    for inputs, targets in batches:
        batch_loss = model.train_on_batch(inputs, targets)
        loss += batch_loss

    print(epoch, loss)


# Export the learned embeddings in the word2vec text format:
#   header line  "<number of words> <vector size>"
#   then one line per word: "<word> <v0> <v1> ...".
# The file is written as UTF-8 explicitly, because that is the encoding
# gensim's load_word2vec_format() uses by default when reading it back;
# relying on the platform default could corrupt non-ASCII tokens.
with open('vectors.txt', 'w', encoding='utf-8') as f:
    # First weight array of the model: the kernel of the first Dense layer,
    # indexed by word id along its first axis.
    vectors = model.get_weights()[0]
    word_index = tokenizer.word_index
    # The header count comes from the same mapping that produces the rows, so
    # the two cannot disagree (equal to V - 1).
    f.write('{} {}\n'.format(len(word_index), dim))
    for word, i in word_index.items():
        components = ' '.join(map(str, vectors[i, :]))
        f.write('{} {}\n'.format(word, components))





_________________________________________________________________
Layer (type)                 Output Shape              Param #   
dense_9 (Dense)              (None, 100)               82400     
_________________________________________________________________
dense_10 (Dense)             (None, 823)               83123     
Total params: 165,523
Trainable params: 165,523
Non-trainable params: 0
_________________________________________________________________
0 632.0674047470093
1 590.9814305305481
2 551.9261205196381
3 541.5976960659027
4 536.2399816513062
5 531.6352005004883
6 527.142338514328
7 522.4998347759247
8 517.5416448116302
9 512.226155757904
10 506.63676595687866
11 500.93324756622314
12 495.2650628089905
13 489.71039724349976
14 484.294424533844
15 479.0265107154846
16 473.9202483892441
17 468.97234189510345
18 464.17397379875183
19 459.50794184207916
20 454.9520412683487
21 450.48729407787323
22 446.0968806743622
23 441.7691322565079
24 437.50051176548004
25 433.29065

345 260.0315423011782
346 259.9880594015124
347 259.92031532526033
348 259.88089036941545
349 259.81532633304613
350 259.78708672523516
351 259.7170929908754
352 259.68649864196794
353 259.622743368149
354 259.5873366594316
355 259.5285778045656
356 259.5066117048265
357 259.4405205249788
358 259.41292393207567
359 259.3601538538934
360 259.3373239636423
361 259.2788910865785
362 259.24954223632824
363 259.20510029792797
364 259.17308151721966
365 259.1234946846963
366 259.0986908078195
367 259.05837363004696
368 259.03116327524197
369 258.9826759696008
370 258.9679327607156
371 258.92548364400875
372 258.91143876314175
373 258.85414212942135
374 258.8461007475854
375 258.80029302835476
376 258.7926230132581
377 258.73920130729687
378 258.73333948850643
379 258.68355745077145
380 258.67899385094654
381 258.6314989626409
382 258.61696735024464
383 258.5712977945806
384 258.5745568573476
385 258.5159317255021
386 258.5085132122041
387 258.46496880054485
388 258.4532687664033
389 258.4129

In [16]:
# Read the exported text-format vectors back in through gensim and print the
# entries most similar to 'mikate' together with their similarity scores.
w2v = KeyedVectors.load_word2vec_format('./vectors.txt', binary=False)
neighbours = w2v.most_similar(positive=['mikate'])
pprint.pprint(neighbours)



[('mihogo', 0.8438014388084412),
 ('alitembeza', 0.67585289478302),
 ('walipoteza', 0.561606764793396),
 ('biashara', 0.5053472518920898),
 ('zote', 0.4827842712402344),
 ('kusukuma', 0.4470747411251068),
 ('kuchemshwa', 0.3926873207092285),
 ('ajali', 0.372442364692688),
 ('walibakia', 0.3716486692428589),
 ('nyingine', 0.3694137930870056)]


  if np.issubdtype(vec.dtype, np.int):
