In [217]:
import gensim
import os
import numpy as np
import nltk
from keras.models import Sequential
from keras.layers.recurrent import LSTM, SimpleRNN
from keras.layers import Activation, TimeDistributed, Dense, RepeatVector, recurrent, Embedding
from keras.preprocessing.sequence import pad_sequences
import gensim
import json

In [332]:
def read_lines(file_path, separator, name_idx, content_idx, start_line = 0, limit = 10000):
    ln = 1
    prev_name = None
    sentences = []
    for line in open(file_path, 'r', encoding="utf-8"):
        if ln >= limit:
            break
        elif ln < start_line:
            ln += 1
            continue
        ln += 1
        items = line.split(separator)
        name = items[name_idx].lower()
        content = items[content_idx].lower().replace('?', '').replace('!', '').replace( '.', '')
        words = content.split()
        if prev_name != name:
            sentences.append(words)
        else:
            #print(prev_name, "spoke again")
            sentences[-1].extend([" "] + words )
        prev_name = name
    return sentences



sentences = read_lines("movie_lines.txt", " +++$+++ ", 3, 4, 0, 3000)


In [326]:
SENTENCE_LENGTH = 20
VOCAB_SIZE = 5000

In [327]:
#flatten_sentences = [word for sentence in sentences for word in sentence ]



In [328]:
def build_vocab_dict(sentences, padding_char, unknown_char):
    word_freq = nltk.FreqDist(np.hstack(sentences))
    word_freq.pop(padding_char, None)
    word_freq.pop(unknown_char, None)
    vocab = word_freq.most_common(VOCAB_SIZE - 2)
    vocab.insert(0, (padding_char,1))
    vocab.append( (unknown_char,1))
    
    vocab_dict = {pair[0]: id for id, pair in enumerate(vocab)}
    idx_dict = {idx:word for word, idx in vocab_dict.items()}
    return vocab_dict, idx_dict

vocab_dict, idx_dict = build_vocab_dict(vocab, ' ', 'UNK')


In [320]:
#print(vocab_dict, idx_dict)



In [329]:
f = open("new_vocab_dict.json", 'w') 
f.write(json.dumps(idx_dict)) 
f.close() 

In [322]:
def sentence_to_vec(sentences, vocab_dict, unknown_char, sentence_length):
    l = len(sentences)
    vec = []
    unk_idx = vocab_dict[unknown_char]

    for sen in sentences:
        vec.append( [vocab_dict[x] if x in vocab_dict else unk_idx for x in sen ][:20])

    padded = pad_sequences(vec, maxlen=sentence_length, dtype='int32')
    return padded

In [334]:
def to_one_hot(vec, sentence_length, vocab_length):
    print((len(vec), sentence_length, vocab_length))
    res = np.zeros((len(vec), sentence_length, vocab_length))
    for i, sen in enumerate(vec):
        for j, num in enumerate(sen):
            res[i, j, num] = 1
    return res

x_sentences = [sentences[i] for i in range(len(sentences)) if i % 2 == 0]
y_sentences = [sentences[i] for i in range(len(sentences)) if i % 2 == 1]

x_vec = sentence_to_vec(x_sentences, vocab_dict, 'UNK', SENTENCE_LENGTH)
y_vec = sentence_to_vec(y_sentences, vocab_dict, 'UNK', SENTENCE_LENGTH)
y_vec = to_one_hot(y_vec, SENTENCE_LENGTH, VOCAB_SIZE)

print(x_vec.shape, y_vec.shape)

(1316, 20, 5000)
(1316, 20) (1316, 20, 5000)


In [335]:
def create_model(x_vocab_len, x_max_len, y_vocab_len, y_max_len, hidden_size, num_layers):
    model = Sequential()

    # Creating encoder network
    model.add(Embedding(x_vocab_len, 1024, input_length=x_max_len, mask_zero=True))
    model.add(LSTM(hidden_size))
    model.add(RepeatVector(y_max_len))

    # Creating decoder network
    for _ in range(num_layers):
        model.add(LSTM(hidden_size, return_sequences=True))
    model.add(TimeDistributed(Dense(y_vocab_len)))
    model.add(Activation('softmax'))
    model.compile(loss='categorical_crossentropy',
            optimizer='rmsprop',
            metrics=['accuracy'])
    return model
model = create_model(VOCAB_SIZE, SENTENCE_LENGTH, VOCAB_SIZE, SENTENCE_LENGTH, 1024, 3)
#model.summary()

In [None]:
model.fit(x_vec, y_vec, batch_size=100, nb_epoch=10)
model.save_weights("new_chatbot_model_1.h5")



Epoch 1/10
Epoch 2/10
Epoch 3/10

array([[   0,    0,    0, ...,  344,  290,  310],
       [   0,    0,    0, ...,  259,  904,  342],
       [   0,    0,    0, ...,    0,  592,  366],
       ...,
       [   0,    0,    0, ...,    0,    0, 2329],
       [   0,    0,    0, ...,  286,  304, 1040],
       [   0,    0,    0, ...,  304,  388, 1597]])

In [239]:
model.save_weights("new_chatbot_model.h5")

In [350]:
sen = "cat"
sen = sen.lower().replace('?', '').replace('!', '').replace( '.', '')
vec = sentence_to_vec([sen], vocab_dict, 'UNK', SENTENCE_LENGTH)
#print(model.predict(vec).shape)
res = model.predict(vec)

print(res)
#print(res.shape)
vec_y = np.argmax(res, axis=2)
" ".join([idx_dict[x] for x in vec_y[0]])

[[[4.8426989e-01 1.2735189e-06 1.4036017e-06 ... 1.2599357e-06
   1.3708780e-06 3.9419573e-02]
  [8.0347425e-01 1.1827537e-08 1.1309058e-08 ... 1.2676928e-08
   1.2604789e-08 2.5429886e-02]
  [8.2979429e-01 7.6190352e-09 7.2584956e-09 ... 9.7925001e-09
   7.8122815e-09 1.5682936e-02]
  ...
  [5.2858281e-01 2.9685561e-08 2.7241001e-08 ... 2.6880072e-08
   2.9743541e-08 4.1117113e-02]
  [4.8911610e-01 3.4107543e-08 3.0700665e-08 ... 2.9767632e-08
   3.4237967e-08 4.5728393e-02]
  [4.3478626e-01 4.1325073e-08 3.7692200e-08 ... 3.5850061e-08
   4.2872465e-08 5.1098447e-02]]]


'                                       '

In [380]:
sen = "hi"
sen = sen.lower().replace('?', '').replace('!', '').replace( '.', '')
vec = sentence_to_vec([sen], vocab_dict, 'UNK', SENTENCE_LENGTH)
#print(model.predict(vec).shape)
res = model.predict(vec)

print(res)
#print(res.shape)
vec_y = np.argmax(res, axis=2)
" ".join([idx_dict[x] for x in vec_y[0]])

[[[9.86587226e-01 1.87109384e-10 2.15507043e-10 ... 1.47583265e-10
   1.44955561e-10 4.42900870e-04]
  [9.82708693e-01 6.77484513e-10 4.87038354e-10 ... 5.88700810e-10
   5.54170432e-10 1.48002140e-03]
  [9.90999222e-01 1.01016639e-09 9.88095494e-10 ... 9.42031342e-10
   8.18319856e-10 9.42815212e-04]
  ...
  [1.81067497e-01 1.19653748e-07 1.05419694e-07 ... 1.07494287e-07
   1.22014626e-07 5.74353300e-02]
  [1.73538074e-01 1.13325335e-07 9.85676394e-08 ... 9.96040654e-08
   1.14987287e-07 5.81073985e-02]
  [1.75065488e-01 1.18403037e-07 1.02444368e-07 ... 1.02040005e-07
   1.18716791e-07 5.90325072e-02]]]


'                                       '